In [2]:
import pickle
import random
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

# Data

In [3]:
with open('ntu/X.pkl', 'rb') as f:
    X = pickle.load(f)

In [4]:
# clean data
to_del = []
for file in X:
    if type(X[file]) == list:
        to_del.append(file)
print('to delete', len(to_del))
for file in to_del:
    del X[file]

to delete 28814


In [5]:
# pad or trim data to 75 frames. when padding, repeat the last frame
# input is of shape (frames, 75)
T = 75
for file in X:
    if X[file].shape[0] < T:
        X[file] = np.pad(X[file], ((0, T - X[file].shape[0]), (0, 0)), mode='edge')
    elif X[file].shape[0] > T:
        X[file] = X[file][:T, :]

In [6]:
for file in X:
    X[file] = torch.tensor(X[file])

In [7]:
a = {}
p = {}
for file in X:
    if file[16:20] not in a:
        a[file[16:20]] = {}
    if file[8:12] not in a[file[16:20]]:
        a[file[16:20]][file[8:12]] = []
    a[file[16:20]][file[8:12]].append(file)
    
    if file[8:12] not in p:
        p[file[8:12]] = set()
    p[file[8:12]].add(file[16:20])

In [158]:
def gen_samples(samples):
    x, y = [], []
    for _ in range(samples):
        # sample two random p
        p1, p2 = random.sample(list(p.keys()), 2)
        # find overlapping a
        a1 = p[p1]
        a2 = p[p2]
        a12 = a1.intersection(a2)
        if len(a12) == 0:
            continue
        # sample two random a
        a1, a2 = random.sample(list(a12), 2)
        # sample x and y
        x1 = random.sample(a[a1][p1], 1)[0]
        x2 = random.sample(a[a2][p2], 1)[0]
        y1 = random.sample(a[a1][p2], 1)[0]
        y2 = random.sample(a[a2][p1], 1)[0]
        x.append([x1, x2])
        y.append([y1, y2])
    return x, y

batch_size = 16
train_x, train_y = gen_samples(1000)
val_x, val_y = gen_samples(200)

In [159]:
class Data(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y
        
    def __getitem__(self, index):
        actors = [self.X[index][0][8:12], self.X[index][1][8:12]]
        actions = [self.X[index][0][16:20], self.X[index][1][16:20]]
        return X[self.X[index][0]], X[self.X[index][1]], X[self.y[index][0]],  X[self.y[index][1]], actors, actions
    
    def __len__(self):
        return len(self.X)

In [160]:
train_data = Data(train_x, train_y)
train_dl = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_data = Data(val_x, val_y)
val_dl = DataLoader(val_data, batch_size=batch_size, shuffle=True)

# Model

In [39]:
dataset = 'NTU'
seg = 75
lr = 1e-4
epochs = 100

In [155]:
from SGN.model import SGN

class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()

        self.enc1 = nn.Conv1d(in_channels=75, out_channels=96, kernel_size=3, stride=1, padding=1)
        self.enc2 = nn.Conv1d(in_channels=96, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.enc3 = nn.Conv1d(in_channels=128, out_channels=160, kernel_size=3, stride=1, padding=1)
        self.enc4 = nn.Conv1d(in_channels=160, out_channels=192, kernel_size=3, stride=1, padding=1)
        self.ref1 = nn.ReflectionPad1d(3)
        self.ref2 = nn.ReflectionPad1d(3)
        self.ref3 = nn.ReflectionPad1d(3)
        self.ref4 = nn.ReflectionPad1d(3)

        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)

        self.acti = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.ref1(x)
        x = self.acti(self.enc1(x))
        x = self.pool(x)
        
        x = self.ref2(x)
        x = self.acti(self.enc2(x))
        x = self.pool(x)
        
        x = self.ref3(x)
        x = self.acti(self.enc3(x))
        x = self.pool(x)

        x = self.ref4(x)
        x = self.acti(self.enc4(x))
        x = self.pool(x)

        return x

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()

        self.dec1 = nn.ConvTranspose1d(in_channels=384, out_channels=160, kernel_size=3, stride=1, padding=1)
        self.dec2 = nn.ConvTranspose1d(in_channels=160, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.dec3 = nn.ConvTranspose1d(in_channels=128, out_channels=96, kernel_size=3, stride=1, padding=1)
        self.dec4 = nn.ConvTranspose1d(in_channels=96, out_channels=75, kernel_size=3, stride=1, padding=1)

        self.ref1 = nn.ReflectionPad1d(3)
        self.ref2 = nn.ReflectionPad1d(3)
        self.ref3 = nn.ReflectionPad1d(3)
        self.ref4 = nn.ReflectionPad1d(3)
 
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.up75 = nn.Upsample(size=75, mode='nearest') 

        self.acti = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.ref1(x)
        x = self.acti(self.dec1(x))
        x = self.up(x)

        x = self.ref2(x)
        x = self.acti(self.dec2(x))
        x = self.up(x)

        x = self.ref3(x)
        x = self.acti(self.dec3(x))
        x = self.up(x)

        x = self.ref4(x)
        x = self.acti(self.dec4(x))
        x = self.up75(x)
        return x

class AutoEncoder(nn.Module):
    def __init__(self):
        super(AutoEncoder, self).__init__()

        self.static_encoder = Encoder()
        self.dynamic_encoder = Encoder()
        self.decoder = Decoder()

        # Adversarial Models
        # self.utility = SGN(120, dataset, seg, batch_size, 1) # Action Recognition
        # self.privacy = SGN(106, dataset, seg, batch_size, 1) # Re-Identification

        self.end_effectors_ = [ # from ntu, (joint idx-1, kinematic chain length)
            (19, 5), # left foot
            (15, 5), # right foot
            (23, 8), # left hand tip
            (24, 8), # left thumb
            (21, 8), # right hand tip
            (22, 8), # right thumb
            (3, 5)   # head
        ]

        self.end_effectors = torch.tensor([19, 15, 23, 24, 21, 22, 3]).cuda() * 3
        self.chain_lengths = torch.tensor([5, 5, 8, 8, 8, 8, 5]).cuda()
        self.weights = torch.tensor([1./5, 1./5, 1./8, 1./8, 1./8, 1./8, 1./5]).cuda()

    def cross(self, x1, x2):
        d1 = self.dynamic_encoder(x1)
        d2 = self.dynamic_encoder(x2)
        s1 = self.static_encoder(x1)
        s2 = self.static_encoder(x2)
        
        out1 = self.decoder(torch.cat((d1, s1), dim=1))
        out2 = self.decoder(torch.cat((d2, s2), dim=1))
        out12 = self.decoder(torch.cat((d1, s2), dim=1))
        out21 = self.decoder(torch.cat((d2, s1), dim=1))

        return out1, out2, out12, out21
    
    def loss(self, x1, x2, y1, y2, actors, actions):
        d1 = self.dynamic_encoder(x1)
        d2 = self.dynamic_encoder(x2)
        s1 = self.static_encoder(x1)
        s2 = self.static_encoder(x2)

        out1 = self.decoder(torch.cat((d1, s1), dim=1))
        out2 = self.decoder(torch.cat((d2, s2), dim=1))
        out12 = self.decoder(torch.cat((d1, s2), dim=1))
        out21 = self.decoder(torch.cat((d2, s1), dim=1))
        
        # reconstruction loss
        rec_loss = self.reconstruction_loss(x1, out1) + self.reconstruction_loss(x2, out2)
        # print('Reconstruction Loss: ', rec_loss.item())
        
        # cross reconstruction loss
        cross_loss = self.cross_reconstruction_loss(x1, x2, out12, out21)
        # print('Cross Reconstruction Loss: ', cross_loss.item())
        
        # end effector loss
        end_effector_loss = self.end_effector_loss(out1, x1) + self.end_effector_loss(out2, x2) + self.end_effector_loss(out12, x2) + self.end_effector_loss(out21, x1)
        # print('End Effector Loss: ', end_effector_loss.item())

        # triplet loss
        # sgn latent privacy loss (adversarial)
        # sgn latent utility loss (adversarial)

        return rec_loss + cross_loss + end_effector_loss

    def reconstruction_loss(self, x, y):
        # return torch.square(torch.norm(x - y, dim=1)).mean()
        return F.mse_loss(x, y)
    
    def cross_reconstruction_loss(self, x1, x2, y1, y2):
        return F.mse_loss(x1, y2) + F.mse_loss(x2, y1)
        # return torch.square(torch.norm(x1 - y2, dim=1)).mean() + torch.square(torch.norm(x2 - y1, dim=1)).mean()
    
    def end_effector_loss(self, x, y):
        # slice to get the end effector joints
        x_ee = x[:, :, self.end_effectors.unsqueeze(-1) + torch.arange(3).cuda()] 
        y_ee = y[:, :, self.end_effectors.unsqueeze(-1) + torch.arange(3).cuda()]

        # calculate velocities
        x_vel = torch.norm(x_ee[:, 1:] - x_ee[:, :-1], dim=-1) / self.chain_lengths.unsqueeze(0)
        y_vel = torch.norm(y_ee[:, 1:] - y_ee[:, :-1], dim=-1) / self.chain_lengths.unsqueeze(0)
        
        # compute mse loss for each joint
        losses = F.mse_loss(x_vel, y_vel, reduction='none')

        # take sum over end effectors
        loss = losses.sum(dim=1)

        # take mean over batch
        loss = loss.mean()
        
        return loss

    def forward(self, x):
        dyn = self.dynamic_encoder(x)   
        sta = self.static_encoder(x)
        x = self.decoder(torch.cat((dyn, sta), dim=1))
        return x

In [156]:
model = AutoEncoder().cuda()
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

In [161]:
for epoch in range(epochs):
    losses = []
    for (x1, x2, y1, y2, actors, actions) in train_dl:
        # Move tensors to the configured device
        x1, x2, y1, y2 = x1.float().cuda(), x2.float().cuda(), y1.float().cuda(), y2.float().cuda()
        
        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        loss = model.loss(x1, x2, y1, y2, actors, actions)

        # Backward and optimize
        loss.backward()
        optimizer.step()

        # Track the loss
        losses.append(loss.item())

    # Decay learning rate
    scheduler.step()

    # Print the training loss
    print(f'Epoch {epoch+1}/{epochs} Loss: {np.mean(losses)}')

    # Validation
    with torch.no_grad():
        val_losses = []
        for (x1, x2, y1, y2, actors, actions) in val_dl:
            x1, x2, y1, y2 = x1.float().cuda(), x2.float().cuda(), y1.float().cuda(), y2.float().cuda()
            loss = model.loss(x1, x2, y1, y2, actors, actions)
            val_losses.append(loss.item())
        print(f'Epoch {epoch+1}/{epochs} Validation Loss: {np.mean(val_losses)}')

Epoch 1/100 Loss: 0.1145985740072587
Epoch 1/100 Validation Loss: 0.0939013184979558
Epoch 2/100 Loss: 0.09222336814684026
Epoch 2/100 Validation Loss: 0.10371694155037403
Epoch 3/100 Loss: 0.0923240180401241
Epoch 3/100 Validation Loss: 0.08904875256121159
Epoch 4/100 Loss: 0.08755642274285064
Epoch 4/100 Validation Loss: 0.09087808709591627
Epoch 5/100 Loss: 0.08544388161424328
Epoch 5/100 Validation Loss: 0.0916650490835309
Epoch 6/100 Loss: 0.08607951924204826
Epoch 6/100 Validation Loss: 0.14908325299620628
Epoch 7/100 Loss: 0.11639690508737284
Epoch 7/100 Validation Loss: 0.09761720057576895
Epoch 8/100 Loss: 0.08506756626507815
Epoch 8/100 Validation Loss: 0.08533747680485249
Epoch 9/100 Loss: 0.08162144505802323
Epoch 9/100 Validation Loss: 0.08846228383481503
Epoch 10/100 Loss: 0.08067770378992838
Epoch 10/100 Validation Loss: 0.08669823966920376
Epoch 11/100 Loss: 0.0803046391948181
Epoch 11/100 Validation Loss: 0.08417336083948612
Epoch 12/100 Loss: 0.07762494562741588
Epoch