In [1]:
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
import argparse
import time
import shutil
import os
os.environ["CUDA_VISIBLE_DEVICES"] = '0,1,2,3'
import os.path as osp
import csv
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import torchvision.models as models
from transform_cnn import VA
from data_cnn import NTUDataLoaders, AverageMeter,  make_dir, get_cases, get_num_classes
import matplotlib
import matplotlib.pyplot as plt

In [2]:
# Run options for this notebook, declared as argparse options so they mirror
# the command-line script. parse_args(args=[]) ignores the kernel's own
# sys.argv, so the defaults below are exactly what the notebook uses; edit
# them here to change a run. `%(default)s` keeps help text in sync with the
# actual defaults.
parser = argparse.ArgumentParser(description='View adaptive')
parser.add_argument('--model', type=str, default='VA',
                    help='the neural network to use')
parser.add_argument('--dataset', type=str, default='NTU',
                    help='select dataset to evaluate')
parser.add_argument('--max_epoches', type=int, default=100,
                    help='maximum number of epochs to run')
parser.add_argument('--lr', type=float, default=0.0001,
                    help='initial learning rate')
parser.add_argument('--lr_factor', type=float, default=0.1,
                    help='the ratio to reduce lr on each step')
parser.add_argument('--optimizer', type=str, default='Adam',
                    help='the optimizer type')
parser.add_argument('--print_freq', '-p', type=int, default=20,
                    help='print frequency (default: %(default)s)')
parser.add_argument('-b', '--batch_size', type=int, default=1,
                    help='mini-batch size (default: %(default)s)')
parser.add_argument('--num_classes', type=int, default=60,
                    help='the number of classes')
parser.add_argument('--case', type=int, default=1,
                    help='select which case')
parser.add_argument('--aug', type=int, default=1,
                    help='data augmentation')
parser.add_argument('--workers', type=int, default=8,
                    help='number of data loading workers')
parser.add_argument('--monitor', type=str, default='val_acc',
                    help='quantity to monitor (default: %(default)s)')
parser.add_argument('--train', type=int, default=1,
                    help='train or test')
args = parser.parse_args(args=[])

In [3]:
# Number of action classes for the selected dataset (args.dataset), as reported
# by data_cnn.get_num_classes. The bare name on the last line makes Jupyter
# display the value as the cell output.
num_classes = get_num_classes(args.dataset)
num_classes

60

In [4]:
# Data loading: build the train / validation DataLoaders for the selected
# dataset, case and augmentation setting, and report the split sizes.
# This notebook only exports samples to images; no model, loss, optimizer or
# LR scheduler is needed here. `num_classes` is already set by the cell above.
ntu_loaders = NTUDataLoaders(args.dataset, args.case, args.aug)
train_loader = ntu_loaders.get_train_loader(args.batch_size, args.workers)
val_loader = ntu_loaders.get_val_loader(args.batch_size, args.workers)
train_size = ntu_loaders.get_train_size()
val_size = ntu_loaders.get_val_size()
print('Train on %d samples, validate on %d samples' %
      (train_size, val_size))



Train on 35763 samples, validate on 20815 samples


In [5]:
# Export every training sample as a PNG, one sub-folder per action label:
#   <base_dir>/<label>/image-<label>-<k>.png
# <k> continues from the number of files already in that label's folder, so
# re-running this cell appends another copy of every sample instead of
# regenerating the export. Empty base_dir first for a clean run.
base_dir = "/home/fatema/Documents/action_rec/project/img_ske_all/train"
filename_format = "image-{0}-{1}.png"
# Optional subset of labels to export, e.g. [40, 41, 42, 43, 44, 45, 46, 47, 48];
# None exports every label.
action_filter = None

# Running per-label file counter. It is seeded from os.listdir() the first time
# a label is seen, instead of re-listing the folder for every sample (which
# made the export quadratic in the number of images per label).
file_counts = {}

for inputs, _maxmin, target in train_loader:
    # Walk the batch sample by sample: int(target) on the whole batch only
    # worked for batch_size == 1. The tensors are used on the CPU directly;
    # a GPU round-trip is not needed just to save an image.
    for sample, label in zip(inputs, target.reshape(-1)):
        act = int(label)
        if action_filter is not None and act not in action_filter:
            continue

        image_path_dir = os.path.join(base_dir, str(act))
        if act not in file_counts:
            # makedirs also creates base_dir itself when it is missing.
            os.makedirs(image_path_dir, exist_ok=True)
            file_counts[act] = len(os.listdir(image_path_dir))

        # Take the next number whose file does not exist yet, so an existing
        # image is never overwritten (e.g. when the folder has numbering gaps).
        while True:
            file_counts[act] += 1
            image_save_path = os.path.join(
                image_path_dir, filename_format.format(act, file_counts[act]))
            if not os.path.exists(image_save_path):
                break

        # Keep index 0 of the sample's leading axis, exactly as before.
        # NOTE(review): presumably this is the first channel of a (C, H, W)
        # image. TODO: confirm against NTUDataLoaders and whether all
        # channels should be saved instead.
        plt.imsave(image_save_path, sample[0].cpu().numpy())


      

In [6]:
# Export every validation sample as a PNG, one sub-folder per action label:
#   <base_dir>/<label>/image-<label>-<k>.png
# <k> continues from the number of files already in that label's folder, so
# re-running this cell appends another copy of every sample instead of
# regenerating the export. Empty base_dir first for a clean run.
base_dir = "/home/fatema/Documents/action_rec/project/img_ske_all/val"
filename_format = "image-{0}-{1}.png"
# Optional subset of labels to export, e.g. list(range(59)); None exports
# every label.
action_filter = None

# Running per-label file counter. It is seeded from os.listdir() the first time
# a label is seen, instead of re-listing the folder for every sample (which
# made the export quadratic in the number of images per label).
file_counts = {}

for inputs, _maxmin, target in val_loader:
    # Walk the batch sample by sample: int(target) on the whole batch only
    # worked for batch_size == 1. The tensors are used on the CPU directly;
    # a GPU round-trip is not needed just to save an image.
    for sample, label in zip(inputs, target.reshape(-1)):
        act = int(label)
        if action_filter is not None and act not in action_filter:
            continue

        image_path_dir = os.path.join(base_dir, str(act))
        if act not in file_counts:
            # makedirs also creates base_dir itself when it is missing.
            os.makedirs(image_path_dir, exist_ok=True)
            file_counts[act] = len(os.listdir(image_path_dir))

        # Take the next number whose file does not exist yet, so an existing
        # image is never overwritten (e.g. when the folder has numbering gaps).
        while True:
            file_counts[act] += 1
            image_save_path = os.path.join(
                image_path_dir, filename_format.format(act, file_counts[act]))
            if not os.path.exists(image_save_path):
                break

        # Keep index 0 of the sample's leading axis, exactly as before.
        # NOTE(review): presumably this is the first channel of a (C, H, W)
        # image. TODO: confirm against NTUDataLoaders and whether all
        # channels should be saved instead.
        plt.imsave(image_save_path, sample[0].cpu().numpy())


      

In [7]:
# Count how many train / val samples belong to the selected action classes.
# Matching labels are collected into train_y / val_y and each total is printed.
# NOTE(review): this iterates the full loaders, loading every input just to
# read its label. If NTUDataLoaders exposes the label arrays directly, reading
# those would be much cheaper. TODO: check data_cnn.
action = [40, 41, 42, 43, 44, 45, 46, 47, 48]
selected_actions = set(action)  # O(1) membership test per label

train_y = []
for _inputs, _maxmin, target in train_loader:
    # reshape(-1) visits every label in the batch; int(target) on the whole
    # tensor only worked when batch_size == 1.
    train_y.extend(c for c in map(int, target.reshape(-1)) if c in selected_actions)
print(len(train_y))

val_y = []
for _inputs, _maxmin, target in val_loader:
    val_y.extend(c for c in map(int, target.reshape(-1)) if c in selected_actions)
print(len(val_y))
                     

5397
3125


In [8]:
# NOTE(review): scratch cell containing only commented-out code. It mirrors the
# per-sample "squeeze -> [0] -> numpy" preprocessing of the image-export cells
# and has no effect when run. Consider deleting it.
# data=torch.squeeze(inputs, 0)
# print(data.shape)
# data=data[0]
# print(data.shape)
# data=data.numpy()


In [9]:
# NOTE(review): commented-out preview cell. `x` is not defined in any cell shown
# here, so this would raise NameError if uncommented. Consider deleting the
# cell, or pointing it at a real array.
# from matplotlib import pyplot as plt
# plt.imshow(x.astype('uint8'))
# plt.show()