# Imports and Hyperparameters

In [1]:
import pickle
import random
import numpy as np
import os
import os.path as osp
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from SGN.model import SGN
from SGN.data import NTUDataLoaders, AverageMeter
from SGN.util import make_dir, get_num_classes
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score, precision_score, recall_score

In [19]:
# Hyper Parameters
dataset = 'NTU'
# Use the first GPU when available; fall back to CPU so the notebook still
# runs (slowly) on machines without CUDA instead of failing on device use.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
seg = 20
lr = 5e-5
epochs = 500
utility_classes = 120
privacy_classes = 106
validation_acc_freq = 10  # -1 to disable
encoded_channels = 16

# Seed the Python, NumPy and torch RNGs up front: pair sampling below uses
# `random`, and DataLoader shuffling uses torch's generator.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# Data

In [4]:
# load data
# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load ntu/X.pkl from a trusted source.
with open('ntu/X.pkl', 'rb') as f:
    X = pickle.load(f)

# clean data: drop entries stored as plain lists — the padding step below
# needs array values that expose .shape.
to_del = [file for file, sample in X.items() if isinstance(sample, list)]
print('to delete', len(to_del))
for file in to_del:
    del X[file]

# pad or trim data to 75 frames. when padding, repeat the last frame
# input is of shape (frames, 75)
T = 75


def _pad_or_trim(frames, length=T):
    """Return a 2-D sequence with exactly `length` rows.

    Shorter sequences are padded at the end by repeating their last row
    (np.pad mode='edge'); longer ones keep only their first `length` rows;
    sequences already `length` rows long come back as a full slice.

    Raises:
        ValueError: if `frames` has no rows, since there is no last frame
            to repeat.
    """
    n_frames = frames.shape[0]
    if n_frames == 0:
        raise ValueError('cannot pad a sequence with 0 frames by repeating its last frame')
    if n_frames < length:
        return np.pad(frames, ((0, length - n_frames), (0, 0)), mode='edge')
    return frames[:length]


# pad/trim and convert to float tensors in a single pass (key order of X is kept)
X = {file: torch.tensor(_pad_or_trim(frames)).float() for file, frames in X.items()}

to delete 28814


In [7]:
# Index every clip two ways, keyed by 4-character fields sliced from the file
# name: file[8:12] (actor field) and file[16:20] (action field) — the same
# fields the Data class decodes numerically as actors / actions.
#   a: action -> actor -> list of file names
#   p: actor -> set of actions that actor performed
a = {}
p = {}
for file in X:
    actor_key = file[8:12]
    action_key = file[16:20]
    a.setdefault(action_key, {}).setdefault(actor_key, []).append(file)
    p.setdefault(actor_key, set()).add(action_key)

In [8]:
def gen_samples(samples, actions_by_actor=None, files_by_action=None):
    """Sample actor-swapped pairs of file names.

    Each of ``samples`` attempts draws two distinct actors p1, p2 and two
    distinct actions a1, a2 that both actors performed, then emits
        x = [file(p1, a1), file(p2, a2)]
        y = [file(p2, a1), file(p1, a2)]
    so each y entry keeps the action of the matching x entry but comes from
    the other actor. Attempts where the two actors share fewer than two
    actions are skipped, so fewer than ``samples`` pairs may be returned.

    Args:
        samples: number of sampling attempts.
        actions_by_actor: mapping actor -> set of actions performed.
            Defaults to the notebook-level ``p`` (looked up at call time).
        files_by_action: mapping action -> actor -> list of file names.
            Defaults to the notebook-level ``a`` (looked up at call time).

    Returns:
        (x, y): two aligned lists of ``[file, file]`` pairs.
    """
    if actions_by_actor is None:
        actions_by_actor = p
    if files_by_action is None:
        files_by_action = a
    # Loop-invariant: build the candidate actor list once, not per attempt.
    actors = list(actions_by_actor)
    x, y = [], []
    for _ in range(samples):
        p1, p2 = random.sample(actors, 2)
        # sorted() pins the order: iteration order of a set of str depends on
        # hash randomization, which would defeat random.seed reproducibility.
        shared = sorted(actions_by_actor[p1].intersection(actions_by_actor[p2]))
        # Two *distinct* shared actions are needed; with exactly one,
        # random.sample(shared, 2) would raise ValueError.
        if len(shared) < 2:
            continue
        a1, a2 = random.sample(shared, 2)
        x1 = random.choice(files_by_action[a1][p1])
        x2 = random.choice(files_by_action[a2][p2])
        y1 = random.choice(files_by_action[a1][p2])
        y2 = random.choice(files_by_action[a2][p1])
        x.append([x1, x2])
        y.append([y1, y2])
    return x, y

batch_size = 32

# gen_samples(n) makes n sampling attempts and skips unusable ones, so each
# split can end up with fewer than n pairs.
# NOTE(review): both splits draw from the same pool of files, so a file may
# appear in both train and validation pairs — confirm this is intended.
train_attempts = 30000
val_attempts = 10000
train_x, train_y = gen_samples(train_attempts)
val_x, val_y = gen_samples(val_attempts)

In [9]:
class Data(Dataset):
    """Dataset over aligned (x, y) file-name pairs such as gen_samples returns.

    Item ``i`` is ``(x1, x2, y1, y2, actors, actions)``: the sequences looked
    up for the two names in ``X[i]`` and the two in ``y[i]``, followed by the
    numeric actor fields (name[9:12]) and action fields (name[17:20]) of the
    two x names, each parsed as float.

    Args:
        X: list of ``[name, name]`` input pairs.
        y: list of ``[name, name]`` target pairs, aligned with ``X``.
        data: optional mapping name -> sequence. When omitted, the
            module-level ``X`` dict is used, looked up at access time.
    """

    def __init__(self, X, y, data=None):
        self.X = X
        self.y = y
        self.data = data

    def __getitem__(self, index):
        # self.X holds name pairs; the sequences live in `data`, which by
        # default is the module-level X dict (shadowed here by self.X's name).
        lookup = X if self.data is None else self.data
        x_first, x_second = self.X[index]
        y_first, y_second = self.y[index]
        actors = [float(x_first[9:12]), float(x_second[9:12])]
        actions = [float(x_first[17:20]), float(x_second[17:20])]
        return lookup[x_first], lookup[x_second], lookup[y_first], lookup[y_second], actors, actions

    def __len__(self):
        return len(self.X)

# Training batches are reshuffled every epoch. The validation pairs are
# already randomly sampled, so shuffling them again adds nothing; a fixed
# order keeps evaluation reproducible.
# NOTE(review): assumes the validation loop only computes order-independent
# aggregates — confirm against the training code.
train_data = Data(train_x, train_y)
train_dl = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_data = Data(val_x, val_y)
val_dl = DataLoader(val_data, batch_size=batch_size, shuffle=False)

# Model