In [1]:
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Data

In [2]:
# Load the pickled dataset; later cells treat X as a mapping keyed by sample name.
# NOTE(review): pickle.load can execute arbitrary code - only open a trusted file.
with open('ntu/X.pkl', 'rb') as pkl_file:
    X = pickle.load(pkl_file)

In [3]:
# Clean data: entries whose value is a plain list are removed from X,
# so only entries that are not plain lists remain for the cells below.
to_del = [file for file in X if type(X[file]) is list]
print('to delete', len(to_del))
# Keys were collected first so the mapping is not resized while it is iterated.
for file in to_del:
    del X[file]

to delete 28814


In [4]:
# Force every sequence to exactly T frames along axis 0.
# Short sequences are extended by repeating their last frame (mode='edge');
# long sequences keep only their first T frames.
# Input is of shape (frames, 75).
T = 75
for file, seq in X.items():
    n_frames = seq.shape[0]
    if n_frames < T:
        # pad only at the end of the frame axis, never along the feature axis
        X[file] = np.pad(seq, ((0, T - n_frames), (0, 0)), mode='edge')
    elif n_frames > T:
        X[file] = seq[:T, :]

In [5]:
# Replace every stored sequence with a torch tensor (torch.tensor copies the data).
for file, seq in X.items():
    X[file] = torch.tensor(seq)

In [6]:
# Build two lookup tables from the sample keys.
#   a[action][person] -> list of sample keys for that (action, person) pair
#   p[person]         -> set of actions available for that person
# The person code is characters 8-11 of the key and the action code is
# characters 16-19 (the Data class labels the same slices "actors"/"actions").
a = {}
p = {}
for file in X:
    person = file[8:12]
    action = file[16:20]

    a.setdefault(action, {}).setdefault(person, []).append(file)
    p.setdefault(person, set()).add(action)

In [7]:
# Number of sampling attempts. Attempts are skipped when the two people share
# fewer than two actions, so the resulting lists may be shorter than `samples`.
samples = 500
batch_size = 16
train_x = []
train_y = []

# Loop-invariant: the list of people to draw from.
people = list(p)

for _ in range(samples):
    # sample two distinct people
    p1, p2 = random.sample(people, 2)
    # actions recorded for both of them
    shared_actions = p[p1] & p[p2]
    # Two distinct actions are drawn below, so at least two must be shared.
    # (Checking only for an empty set let a single shared action reach
    # random.sample(..., 2), which raises ValueError.)
    if len(shared_actions) < 2:
        continue
    # sample two distinct actions; sorted() gives random.sample a sequence
    # whose order does not depend on set iteration order
    a1, a2 = random.sample(sorted(shared_actions), 2)
    # inputs: (a1 by p1) and (a2 by p2)
    x1 = random.choice(a[a1][p1])
    x2 = random.choice(a[a2][p2])
    # targets: the same actions with the people swapped
    y1 = random.choice(a[a1][p2])
    y2 = random.choice(a[a2][p1])
    train_x.append([x1, x2])
    train_y.append([y1, y2])

In [8]:
class Data(Dataset):
    """Dataset of paired samples addressed by their keys.

    Args:
        X: sequence of ``[key_a, key_b]`` pairs naming the two input samples.
        y: sequence of ``[key_a, key_b]`` pairs naming the two target samples,
            aligned index-by-index with ``X``.
        tensors: optional mapping from sample key to its data. When omitted,
            the notebook-level global ``X`` mapping is looked up at access
            time, which is the original behaviour.

    Note:
        ``self.X`` stores the key pairs, not the data itself.
    """

    def __init__(self, X, y, tensors=None):
        self.X = X
        self.y = y
        self.tensors = tensors

    def __getitem__(self, index):
        """Return ``(x1, x2, y1, y2, actors, actions)`` for one pair.

        ``actors`` and ``actions`` are two-element lists holding characters
        8-11 and 16-19 of the two input keys, respectively.
        """
        # Fall back to the module-level mapping only when none was injected.
        store = self.tensors if self.tensors is not None else X

        in_a, in_b = self.X[index][0], self.X[index][1]
        tgt_a, tgt_b = self.y[index][0], self.y[index][1]

        actors = [in_a[8:12], in_b[8:12]]
        actions = [in_a[16:20], in_b[16:20]]
        return store[in_a], store[in_b], store[tgt_a], store[tgt_b], actors, actions

    def __len__(self):
        """Number of input pairs."""
        return len(self.X)

In [9]:
# Wrap the sampled key pairs in the Dataset and serve them in shuffled mini-batches.
data = Data(X=train_x, y=train_y)
dl = DataLoader(dataset=data, batch_size=batch_size, shuffle=True)

# Model

In [10]:
# Training configuration.
# `dataset` and `seg` are only referenced by the commented-out SGN models.
dataset = 'NTU'
seg = 75
# Learning rate handed to the Adam optimizer.
# NOTE(review): 0.1 is far above Adam's default of 1e-3 - confirm this is intentional.
lr = 1e-1
# Number of passes over the DataLoader.
epochs = 100

In [55]:
from SGN.model import SGN

class Encoder(nn.Module):
    """Three-stage convolutional encoder.

    Each stage is a 3x3 convolution (padding 1, so spatial size is kept),
    a ReLU, and a 2x2 max-pool that halves height and width (rounding down).
    Channels go 1 -> 32 -> 16 -> 8.
    """

    def __init__(self):
        super().__init__()

        self.enc1 = nn.Conv2d(1, 32, kernel_size=3, stride=1, padding=1)
        self.enc2 = nn.Conv2d(32, 16, kernel_size=3, stride=1, padding=1)
        self.enc3 = nn.Conv2d(16, 8, kernel_size=3, stride=1, padding=1)

        self.pool = nn.MaxPool2d(kernel_size=2, stride=2)
        self.relu = nn.ReLU()

    def forward(self, x):
        """Encode ``x``.

        A channel axis is inserted at dim 1 before the first convolution, so
        ``x`` is expected without one, e.g. ``(batch, 75, 75)`` which yields
        ``(batch, 8, 9, 9)``.
        """
        feats = x.unsqueeze(1)
        for conv in (self.enc1, self.enc2, self.enc3):
            feats = self.pool(self.relu(conv(feats)))
        return feats

class Decoder(nn.Module):
    """Convolutional decoder mapping a 16-channel code to a 75x75 map.

    Three 3x3 transposed convolutions (stride 1, padding 1, so spatial size
    is kept) take channels 16 -> 32 -> 64 -> 1. Nearest-neighbour upsampling
    doubles the spatial size after each of the first two, and a final
    upsample resizes the result to exactly 75x75.
    """

    def __init__(self):
        super().__init__()

        self.dec1 = nn.ConvTranspose2d(16, 32, kernel_size=3, stride=1, padding=1)
        self.dec2 = nn.ConvTranspose2d(32, 64, kernel_size=3, stride=1, padding=1)
        self.dec3 = nn.ConvTranspose2d(64, 1, kernel_size=3, stride=1, padding=1)

        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.up75 = nn.Upsample(size=(75, 75), mode='nearest')

        self.relu = nn.ReLU()
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """Decode ``x`` of shape ``(batch, 16, h, w)`` into ``(batch, 75, 75)``.

        The last activation is a sigmoid, so every output value lies in [0, 1].
        NOTE(review): reconstruction targets outside [0, 1] cannot be matched -
        confirm the inputs are scaled accordingly.
        """
        hidden = self.up(self.relu(self.dec1(x)))
        hidden = self.up(self.relu(self.dec2(hidden)))
        image = self.up75(self.sigmoid(self.dec3(hidden)))
        # drop the singleton channel axis
        return image.squeeze(1)

class AutoEncoder(nn.Module):
    """Autoencoder with two encoders sharing one decoder.

    Every input is encoded twice, by ``dynamic_encoder`` and by
    ``static_encoder``; the two codes are concatenated along the channel
    axis and decoded. Codes from different inputs can be mixed to produce
    "cross" reconstructions.

    NOTE(review): the training pairs suggest the dynamic code is meant to
    carry the action and the static code the person - confirm.
    """

    def __init__(self):
        super().__init__()

        self.static_encoder = Encoder()
        self.dynamic_encoder = Encoder()
        self.decoder = Decoder()

        # Adversarial Models (not wired in yet)
        # self.utility = SGN(120, dataset, seg, batch_size, 1) # Action Recognition
        # self.privacy = SGN(106, dataset, seg, batch_size, 1) # Re-Identification

    def _decode(self, dyn, sta):
        """Decode one (dynamic, static) code pair joined along the channel axis."""
        return self.decoder(torch.cat((dyn, sta), dim=1))

    def cross(self, x1, x2):
        """Return ``(out1, out2, out12, out21)`` for two inputs.

        ``out1``/``out2`` are plain reconstructions of ``x1``/``x2``.
        ``out12`` combines the dynamic code of ``x1`` with the static code of
        ``x2``; ``out21`` combines the dynamic code of ``x2`` with the static
        code of ``x1``.
        """
        d1 = self.dynamic_encoder(x1)
        d2 = self.dynamic_encoder(x2)
        s1 = self.static_encoder(x1)
        s2 = self.static_encoder(x2)

        out1 = self._decode(d1, s1)
        out2 = self._decode(d2, s2)
        out12 = self._decode(d1, s2)
        out21 = self._decode(d2, s1)

        return out1, out2, out12, out21

    def loss(self, x1, x2, y1, y2, actors, actions):
        """Reconstruction loss plus cross-reconstruction loss.

        Args:
            x1, x2: the two input batches.
            y1: target for ``out12`` (dynamic code of ``x1``, static code of ``x2``).
            y2: target for ``out21`` (dynamic code of ``x2``, static code of ``x1``).
            actors, actions: accepted for the planned losses listed below;
                currently unused.

        Returns:
            Scalar tensor: the sum of the two loss terms.
        """
        out1, out2, out12, out21 = self.cross(x1, x2)

        # reconstruction loss
        rec_loss = self.reconstruction_loss(x1, out1) + self.reconstruction_loss(x2, out2)

        # cross reconstruction loss
        # The helper pairs its 1st argument with its 4th and its 2nd with its
        # 3rd, so this compares y1 with out12 and y2 with out21. Previously the
        # swapped outputs were compared against x1/x2 and y1/y2 were ignored.
        cross_loss = self.cross_reconstruction_loss(y1, y2, out21, out12)

        # end effector loss
        # triplet loss
        # sgn latent privacy loss (adversarial)
        # sgn latent utility loss (adversarial)

        return rec_loss + cross_loss

    def reconstruction_loss(self, x, y):
        """Squared L2 norm of ``x - y`` taken over dim 1, averaged over the rest."""
        return torch.square(torch.norm(x - y, dim=1)).mean()

    def cross_reconstruction_loss(self, x1, x2, y1, y2):
        """Sum of two reconstruction terms with crossed pairing.

        Compares ``x1`` with ``y2`` and ``x2`` with ``y1``, each as a squared
        L2 norm over dim 1 averaged over the remaining dims.
        """
        return torch.square(torch.norm(x1 - y2, dim=1)).mean() + torch.square(torch.norm(x2 - y1, dim=1)).mean()

    def forward(self, x):
        """Reconstruct ``x`` from its own dynamic and static codes."""
        dyn = self.dynamic_encoder(x)
        sta = self.static_encoder(x)
        return self._decode(dyn, sta)

In [56]:
# Instantiate the model on the CPU (append .cuda() to move it to a GPU)
# and let Adam optimise all of its parameters with the configured learning rate.
model = AutoEncoder()
optimizer = torch.optim.Adam(params=model.parameters(), lr=lr)

In [57]:
# Train for `epochs` passes over the DataLoader, printing the mean batch loss per epoch.
for epoch in range(epochs):
    losses = []
    for batch in dl:
        *clips, actors, actions = batch
        # cast to float32 (add .cuda() per tensor here when training on a GPU)
        x1, x2, y1, y2 = (clip.float() for clip in clips)

        # clear gradients left over from the previous step, then update
        optimizer.zero_grad()
        loss = model.loss(x1, x2, y1, y2, actors, actions)
        loss.backward()
        optimizer.step()

        losses.append(loss.item())
    print(f'Epoch {epoch+1}/{epochs} Loss: {np.mean(losses)}')

torch.Size([16, 8, 9, 9])
torch.Size([16, 75, 75])
torch.Size([16, 8, 9, 9])
torch.Size([16, 75, 75])
torch.Size([16, 8, 9, 9])
torch.Size([16, 75, 75])
torch.Size([16, 8, 9, 9])
torch.Size([16, 75, 75])
torch.Size([16, 8, 9, 9])
torch.Size([16, 75, 75])
torch.Size([16, 8, 9, 9])
torch.Size([16, 75, 75])
torch.Size([16, 8, 9, 9])
torch.Size([16, 75, 75])
torch.Size([16, 8, 9, 9])
torch.Size([16, 75, 75])
torch.Size([16, 8, 9, 9])
torch.Size([16, 75, 75])
torch.Size([16, 8, 9, 9])
torch.Size([16, 75, 75])
torch.Size([16, 8, 9, 9])
torch.Size([16, 75, 75])
torch.Size([16, 8, 9, 9])
torch.Size([16, 75, 75])
torch.Size([16, 8, 9, 9])
torch.Size([16, 75, 75])
torch.Size([16, 8, 9, 9])
torch.Size([16, 75, 75])
torch.Size([16, 8, 9, 9])
torch.Size([16, 75, 75])
torch.Size([16, 8, 9, 9])
torch.Size([16, 75, 75])


KeyboardInterrupt: 