# Imports and Hyperparameters

In [1]:
import pickle
import random
import numpy as np
import os
import os.path as osp
import time
from collections import defaultdict
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from SGN.model import SGN
from SGN.data import NTUDataLoaders, AverageMeter
from SGN.util import make_dir, get_num_classes
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score, precision_score, recall_score

In [2]:
# Hyper Parameters
dataset = 'NTU'
device = torch.device('cuda:0')
# device = torch.device('cpu')
seg = 20
lr = 5e-5
epochs = 500
utility_classes = 120
privacy_classes = 106
validation_acc_freq = 3 #-1 to disable
encoded_channels = 32
batch_size = 32

# Data

In [3]:
# load data
with open('ntu/X.pkl', 'rb') as f:
    X = pickle.load(f)

# clean data
to_del = []
for file in X:
    if type(X[file]) == list:
        to_del.append(file)
print('to delete', len(to_del))
for file in to_del:
    del X[file]

# pad or trim data to 75 frames. when padding, repeat the last frame
# input is of shape (frames, 75)
T = 75
for file in X:
    if X[file].shape[0] < T:
        X[file] = np.pad(X[file], ((0, T - X[file].shape[0]), (0, 0)), mode='edge')
    elif X[file].shape[0] > T:
        X[file] = X[file][:T, :]

# convert to tensor
for file in X:
    X[file] = torch.tensor(X[file]).float()

to delete 28814


In [4]:
def parse_file_name(file_name):
    """Parses the filename into a dictionary of parts."""
    parts = file_name.split('C')
    S = parts[0][1:]
    parts = parts[1].split('P')
    C = parts[0]
    parts = parts[1].split('R')
    P = parts[0]
    parts = parts[1].split('A')
    R = parts[0]
    parts = parts[1].split('.')
    A = parts[0]
    return {'S': S, 'C': C, 'P': P, 'R': R, 'A': A}

def organize_data(data):
    organized_data = defaultdict(list)
    for file_name, content in data.items():
        parts = parse_file_name(file_name)
        organized_data[(parts['C'], parts['R'])].append((parts['P'], parts['A'], content))
    return organized_data

def sample_data(organized_data):
    # Pick a random (C, R) pair
    cr_pair = random.choice(list(organized_data.keys()))

    # Get all (P, A, content) tuples for this (C, R) pair
    pa_list = organized_data[cr_pair]

    # Pick 2 unique P values and 2 unique A values
    random.shuffle(pa_list)
    unique_p = set()
    unique_a = set()
    for p, a, content in pa_list:
        if len(unique_p) < 2:
            unique_p.add(p)
        if len(unique_a) < 2:
            unique_a.add(a)
        if len(unique_p) == 2 and len(unique_a) == 2:
            break

    if len(unique_p) < 2 or len(unique_a) < 2:
        raise Exception(f'Not enough unique P or A values for (C, R) pair {cr_pair}')

    # Form all four (P, A) pairs and get the corresponding content
    sampled_data = []
    for p in unique_p:
        for a in unique_a:
            for pa_content in pa_list:
                if pa_content[0] == p and pa_content[1] == a:
                    sampled_data.append(pa_content)
                    break

    return sampled_data

def gen_samples(samples, data):
    d = []
    for _ in range(samples):
        failed = 0
        d_ = []
        while len(d_) != 4:
            d_ = sample_data(data)
            failed += 1
            if failed > 100:
                print('failed to sample data')
                break
        d.append(d_)
    return np.array(d)

organized_data = organize_data(X)

In [5]:
class Data(Dataset):
    def __init__(self, sampled_data):
        self.data = sampled_data

    def __getitem__(self, index):
        return self.data[index][0][2], self.data[index][1][2], self.data[index][2][2], self.data[index][3][2], [float(self.data[index][0][0]), float(self.data[index][3][0])], [float(self.data[index][0][1]), float(self.data[index][1][1])]
    
    def __len__(self):
        return len(self.data)

train_data = gen_samples(5000, organized_data)
val_data = gen_samples(1000, organized_data)
train_dataset = Data(train_data)
val_dataset = Data(val_data)
train_dl = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_dl = DataLoader(val_dataset, batch_size=batch_size, shuffle=True)

  return np.array(d)


# Model

## Adversary

In [6]:
# Input is size of latent space
class Adversary_Emb(nn.Module):
    def __init__(self, num_classes):
        super(Adversary_Emb, self).__init__()
        self.conv = nn.ConvTranspose1d(encoded_channels, 64, 3, 1, 1)
        self.ref = nn.ReflectionPad1d(3)
        self.pool = nn.MaxPool1d(2, 2)
        self.fc1 = nn.Linear(640, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        x = self.ref(x)
        x = self.conv(x)
        x = self.pool(x)
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.softmax(self.fc2(x), dim=1)
        return x
    
class Discriminator(nn.Module): # 1 = real, 0 = fake
    def __init__(self):
        super(Discriminator, self).__init__()

        self.enc1 = nn.Conv1d(in_channels=75, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.enc2 = nn.Conv1d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.enc3 = nn.Conv1d(in_channels=32, out_channels=16, kernel_size=3, stride=1, padding=1)
        self.enc4 = nn.Conv1d(in_channels=16, out_channels=8, kernel_size=3, stride=1, padding=1)
        self.ref1 = nn.ReflectionPad1d(3)
        self.ref2 = nn.ReflectionPad1d(3)
        self.ref3 = nn.ReflectionPad1d(3)
        self.ref4 = nn.ReflectionPad1d(3)
        self.fc1 = nn.Linear(80, 32)
        self.fc2 = nn.Linear(32, 1)

        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)

        self.acti = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.ref1(x)
        x = self.acti(self.enc1(x))
        x = self.pool(x)
        
        x = self.ref2(x)
        x = self.acti(self.enc2(x))
        x = self.pool(x)
        
        x = self.ref3(x)
        x = self.acti(self.enc3(x))
        x = self.pool(x)

        x = self.ref4(x)
        x = self.acti(self.enc4(x))
        x = self.pool(x)

        #flatten
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = torch.sigmoid(self.fc2(x))

        return x

## Motion Retargeting

In [14]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()

        self.enc1 = nn.Conv1d(in_channels=75, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.enc2 = nn.Conv1d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1)
        self.enc3 = nn.Conv1d(in_channels=32, out_channels=encoded_channels, kernel_size=3, stride=1, padding=1)
        # self.enc4 = nn.Conv1d(in_channels=160, out_channels=192, kernel_size=3, stride=1, padding=1)
        self.ref1 = nn.ReflectionPad1d(3)
        self.ref2 = nn.ReflectionPad1d(3)
        self.ref3 = nn.ReflectionPad1d(3)
        # self.ref4 = nn.ReflectionPad1d(3)

        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)

        self.acti = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.ref1(x)
        x = self.acti(self.enc1(x))
        x = self.pool(x)
        
        x = self.ref2(x)
        x = self.acti(self.enc2(x))
        x = self.pool(x)
        
        x = self.ref3(x)
        x = self.acti(self.enc3(x))
        x = self.pool(x)

        # x = self.ref4(x)
        # x = self.acti(self.enc4(x))
        # x = self.pool(x)

        return x

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()

        self.dec1 = nn.ConvTranspose1d(in_channels=encoded_channels*2, out_channels=48, kernel_size=3, stride=1, padding=1)
        self.dec2 = nn.ConvTranspose1d(in_channels=48, out_channels=64, kernel_size=3, stride=1, padding=1)
        self.dec3 = nn.ConvTranspose1d(in_channels=64, out_channels=75, kernel_size=3, stride=1, padding=1)
        # self.dec4 = nn.ConvTranspose1d(in_channels=96, out_channels=75, kernel_size=3, stride=1, padding=1)

        self.ref1 = nn.ReflectionPad1d(3)
        self.ref2 = nn.ReflectionPad1d(3)
        self.ref3 = nn.ReflectionPad1d(3)
        # self.ref4 = nn.ReflectionPad1d(3)
 
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.up75 = nn.Upsample(size=75, mode='nearest') 

        self.acti = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.ref1(x)
        x = self.acti(self.dec1(x))
        x = self.up(x)

        x = self.ref2(x)
        x = self.acti(self.dec2(x))
        x = self.up(x)

        x = self.ref3(x)
        x = self.acti(self.dec3(x))
        # x = self.up(x)

        # x = self.ref4(x)
        # x = self.acti(self.dec4(x))
        x = self.up75(x)
        return x

class AutoEncoder(nn.Module):
    def __init__(self, adv_lr=1e-3, use_adv=True):
        super(AutoEncoder, self).__init__()

        # AutoEncoder Models
        self.static_encoder = Encoder()
        self.dynamic_encoder = Encoder()
        self.decoder = Decoder()

        # Adversarial Models
        self.use_adv = use_adv
        if use_adv:
            self.priv_adv = Adversary_Emb(privacy_classes)
            self.util_adv = Adversary_Emb(utility_classes)
            self.discriminator = Discriminator()

            self.priv_optim = torch.optim.Adam(self.priv_adv.parameters(), lr=adv_lr)
            self.util_optim = torch.optim.Adam(self.util_adv.parameters(), lr=adv_lr)
            self.discriminator_optim = torch.optim.Adam(self.discriminator.parameters(), lr=adv_lr)

            # Freeze Adversarial Models
            self.priv_adv.eval()
            self.util_adv.eval()
            self.discriminator.eval()

        # Loss Functions
        self.triplet_loss = nn.TripletMarginLoss()
        self.bce_loss = nn.BCELoss()
        
        # Info for loss functions
        self.end_effectors = torch.tensor([19, 15, 23, 24, 21, 22, 3]).to(device) * 3
        self.chain_lengths = torch.tensor([5, 5, 8, 8, 8, 8, 5]).to(device)

        # Lambdas for discounted loss
        self.lambda_rec = 3
        self.lambda_cross = 1
        self.lambda_ee = 5
        self.lambda_trip = 0.1
        self.lambda_latent = 1
        self.lambda_adv_util = 5
        self.lambda_adv_priv = 5
        self.lambda_adv_disc = 10

        # Loss Toggles
        self.use_rec_loss = True
        self.use_cross_loss = True
        self.use_ee_loss = True
        self.use_trip_loss = True
        self.use_latent_consistency = True

    def cross(self, x1, x2):
        d1 = self.dynamic_encoder(x1)
        d2 = self.dynamic_encoder(x2)
        s1 = self.static_encoder(x1)
        s2 = self.static_encoder(x2)
        
        x1_hat = self.decoder(torch.cat((d1, s1), dim=1))
        x2_hat = self.decoder(torch.cat((d2, s2), dim=1))
        y1_hat = self.decoder(torch.cat((d1, s2), dim=1))
        y2_hat = self.decoder(torch.cat((d2, s1), dim=1))

        return x1_hat, x2_hat, y1_hat, y2_hat
    
    def eval(self, x1, x2):
        dynamic = self.dynamic_encoder(x1)
        static = self.static_encoder(x2)
        return self.decoder(torch.cat((dynamic, static), dim=1))
    
    def loss(self, x1, x2, y1, y2, actors, actions, cross = True, reconstruction = True, emb_adv = True, discrim_adv = True):
        d1 = self.dynamic_encoder(x1) # A1
        d2 = self.dynamic_encoder(x2) # A2
        s1 = self.static_encoder(x1) # P1
        s2 = self.static_encoder(x2) # P2

        x1_hat = self.decoder(torch.cat((d1, s1), dim=1)) # P1, A1
        x2_hat = self.decoder(torch.cat((d2, s2), dim=1)) # P2, A2
        y1_hat = self.decoder(torch.cat((d1, s2), dim=1)) # P2, A1
        y2_hat = self.decoder(torch.cat((d2, s1), dim=1)) # P1, A2

        d12 = self.dynamic_encoder(y1) # A1
        d21 = self.dynamic_encoder(y2) # A2
        s12 = self.static_encoder(y1) # P2
        s21 = self.static_encoder(y2) # P1

        x1_hat_ = self.decoder(torch.cat((d12, s21), dim=1)) # P1, A1
        x2_hat_ = self.decoder(torch.cat((d21, s12), dim=1)) # P2, A2
        y1_hat_ = self.decoder(torch.cat((d12, s12), dim=1)) # P2, A1
        y2_hat_ = self.decoder(torch.cat((d21, s21), dim=1)) # P1, A2
        
        # initialize all losses to 0 tensor
        rec_loss = torch.zeros(1).to(device)
        cross_loss = torch.zeros(1).to(device)
        end_effector_loss = torch.zeros(1).to(device)
        triplet_loss = torch.zeros(1).to(device)
        latent_consistency_loss = torch.zeros(1).to(device)
        privacy_loss = torch.zeros(1).to(device)
        privacy_loss_dyn = torch.zeros(1).to(device)
        privacy_loss_stat = torch.zeros(1).to(device)
        privacy_acc_dyn = torch.zeros(1).to(device)
        privacy_acc_stat = torch.zeros(1).to(device)
        utility_loss = torch.zeros(1).to(device)
        utility_loss_dyn = torch.zeros(1).to(device)
        utility_loss_stat = torch.zeros(1).to(device)
        utility_acc_dyn = torch.zeros(1).to(device)
        utility_acc_stat = torch.zeros(1).to(device)
        discriminator_loss = torch.zeros(1).to(device)
        discriminator_acc = torch.zeros(1).to(device)
                        
        # reconstruction loss
        if self.use_rec_loss and reconstruction:
            rec_loss = self.reconstruction_loss(x1, x1_hat) + self.reconstruction_loss(x2, x2_hat) + self.reconstruction_loss(y1, y1_hat_) + self.reconstruction_loss(y2, y2_hat_)
            # print('Reconstruction Loss: ', rec_loss.item())
        
        # cross reconstruction loss
        if self.use_cross_loss and cross:
            cross_loss = self.cross_loss(y1, y2, y1_hat, y2_hat) + self.cross_loss(x1, x2, x1_hat_, x2_hat_)
            # print('Cross Reconstruction Loss: ', cross_loss.item())
        
        # end effector loss
        if self.use_ee_loss:
            if reconstruction:
                end_effector_loss += self.end_effector_loss(x1_hat, x1) + self.end_effector_loss(x2_hat, x2)
            if cross:
                end_effector_loss += self.end_effector_loss(y1_hat, y1) + self.end_effector_loss(y2_hat, y2)
            # print('End Effector Loss: ', end_effector_loss.item())

        # triplet loss
        if self.use_trip_loss:
            triplet_loss = self.triplet_loss(d12, d1, d2) + self.triplet_loss(d21, d2, d1) + self.triplet_loss(s12, s1, s2) + self.triplet_loss(s21, s2, s1) 
            # print('Triplet Loss: ', triplet_loss.item())

        # latent consistency loss
        if self.use_latent_consistency:
            latent_consistency_loss = self.latent_consistency_loss(d1, d12) + self.latent_consistency_loss(d2, d21) + self.latent_consistency_loss(s1, s21) + self.latent_consistency_loss(s2, s12)
            # print('Latent Consistency Loss: ', latent_consistency_loss.item())

        # adversarial loss
        if self.use_adv and emb_adv:
            # sgn latent privacy loss (adversarial)
            adv_priv_y1, adv_priv_y2 = actors[0] - 1, actors[1] - 1
            adv_priv_y1, adv_priv_y2 = torch.eye(privacy_classes)[adv_priv_y1.long()].to(device), torch.eye(privacy_classes)[adv_priv_y2.long()].to(device)
            privacy_loss_dyn = -(self.adv_loss(self.priv_adv, d1, adv_priv_y1) + self.adv_loss(self.priv_adv, d2, adv_priv_y2))
            privacy_loss_stat = (self.adv_loss(self.priv_adv, s1, adv_priv_y1) + self.adv_loss(self.priv_adv, s2, adv_priv_y2))
            privacy_loss = privacy_loss_dyn + privacy_loss_stat
            # print('Privacy Loss Dynamic: ', privacy_loss_dyn.item(), '\tPrivacy Loss Static: ', privacy_loss_stat.item(), '\tPrivacy Loss: ', privacy_loss.item())

            privacy_acc_dyn = (self.adv_accuracy(self.priv_adv, d1, adv_priv_y1) + self.adv_accuracy(self.priv_adv, d2, adv_priv_y2)) / 2
            privacy_acc_stat = (self.adv_accuracy(self.priv_adv, s1, adv_priv_y1) + self.adv_accuracy(self.priv_adv, s2, adv_priv_y2)) / 2
            # print('Privacy Accuracy Dynamic: ', privacy_acc_dyn.item(), '\tPrivacy Accuracy Static: ', privacy_acc_stat.item())
            
            # sgn latent utility loss (adversarial)
            adv_util_y1, adv_util_y2 = actions[0] - 1, actions[1] - 1
            adv_util_y1, adv_util_y2 = torch.eye(utility_classes)[adv_util_y1.long()].to(device), torch.eye(utility_classes)[adv_util_y2.long()].to(device)
            utility_loss_dyn = (self.adv_loss(self.util_adv, d1, adv_util_y1) + self.adv_loss(self.util_adv, d2, adv_util_y2))
            utility_loss_stat = -(self.adv_loss(self.util_adv, s1, adv_util_y1) + self.adv_loss(self.util_adv, s2, adv_util_y2))
            utility_loss = utility_loss_dyn + utility_loss_stat
            # print('Utility Loss Dynamic: ', utility_loss_dyn.item(), '\tUtility Loss Static: ', utility_loss_stat.item(), '\tUtility Loss: ', utility_loss.item())

            utility_acc_dyn = (self.adv_accuracy(self.util_adv, d1, adv_util_y1) + self.adv_accuracy(self.util_adv, d2, adv_util_y2)) / 2
            utility_acc_stat = (self.adv_accuracy(self.util_adv, s1, adv_util_y1) + self.adv_accuracy(self.util_adv, s2, adv_util_y2)) / 2
            # print('Utility Accuracy Dynamic: ', utility_acc_dyn.item(), '\tUtility Accuracy Static: ', utility_acc_stat.item())

        if self.use_adv and discrim_adv:
            # discrimnator (adversarial)
            discrim_out_fake = self.discriminator(torch.cat((x1_hat, x2_hat, y1_hat, y2_hat, x1_hat_, x2_hat_, y1_hat_, y2_hat_)))
            discriminator_loss = self.bce_loss(discrim_out_fake, torch.ones_like(discrim_out_fake))
            discriminator_acc = torch.sum(torch.round(discrim_out_fake) == 0).float() / (8 * batch_size)
            # print('Discriminator Loss: ', discriminator_loss.item(), '\tDiscriminator Accuracy: ', discriminator_acc.item())

        losses = {
            'rec_loss': rec_loss.item(),
            'cross_loss': cross_loss.item(),
            'end_effector_loss': end_effector_loss.item(),
            'triplet_loss': triplet_loss.item(),
            'latent_consistency_loss': latent_consistency_loss.item(),
            'privacy_loss': privacy_loss.item(),
            'privacy_loss_dyn': privacy_loss_dyn.item(),
            'privacy_loss_stat': privacy_loss_stat.item(),
            'privacy_acc_dyn': privacy_acc_dyn.item(),
            'privacy_acc_stat': privacy_acc_stat.item(),
            'utility_loss': utility_loss.item(),
            'utility_loss_dyn': utility_loss_dyn.item(),
            'utility_loss_stat': utility_loss_stat.item(),
            'utility_acc_dyn': utility_acc_dyn.item(),
            'utility_acc_stat': utility_acc_stat.item(),
            'discriminator_loss': discriminator_loss.item(),
            'discriminator_acc': discriminator_acc.item()
        }

        return rec_loss * self.lambda_rec \
                + cross_loss * self.lambda_cross \
                + end_effector_loss * self.lambda_ee \
                + triplet_loss * self.lambda_trip \
                + latent_consistency_loss * self.lambda_latent \
                + privacy_loss * self.lambda_adv_priv \
                + utility_loss * self.lambda_adv_util \
                + discriminator_loss * self.lambda_adv_disc, \
                x1_hat, x2_hat, y1_hat, y2_hat, losses

    def reconstruction_loss(self, x, y):
        # return torch.square(torch.norm(x - y, dim=1)).mean()
        return F.mse_loss(x, y)
    
    def cross_loss(self, x1, x2, y1, y2):
        # return torch.square(torch.norm(x1 - y2, dim=1)).mean() + torch.square(torch.norm(x2 - y1, dim=1)).mean()
        return F.mse_loss(x1, y1) + F.mse_loss(x2, y2)
    
    def latent_consistency_loss(self, x, y):
        return F.mse_loss(x, y)
    
    def end_effector_loss(self, x, y):
        # slice to get the end effector joints
        x_ee = x[:, :, self.end_effectors.unsqueeze(-1) + torch.arange(3).to(device)] 
        y_ee = y[:, :, self.end_effectors.unsqueeze(-1) + torch.arange(3).to(device)]

        # calculate velocities
        x_vel = torch.norm(x_ee[:, 1:] - x_ee[:, :-1], dim=-1) / self.chain_lengths.unsqueeze(0)
        y_vel = torch.norm(y_ee[:, 1:] - y_ee[:, :-1], dim=-1) / self.chain_lengths.unsqueeze(0)
        
        # compute mse loss for each joint
        losses = F.mse_loss(x_vel, y_vel, reduction='none')

        # take sum over end effectors
        loss = losses.sum(dim=1)

        # take mean over batch
        loss = loss.mean()
        
        return loss
    
    def adv_loss(self, model, x, y):
        return F.cross_entropy(model(x), y)
    
    def adv_accuracy(self, model, x, y):
        return (model(x).argmax(dim=1) == y.argmax(dim=1)).float().mean()

    def train_adv(self, x1, x2, y1, y2, actors, actions, train_emb = True, train_discrim = True):
        if not self.use_adv: return 0,0
        # freeze encoders/decoder
        self.dynamic_encoder.eval()
        self.static_encoder.eval()
        self.decoder.eval()

        # unfreeze adversaries
        self.priv_adv.train()
        self.util_adv.train()
        self.discriminator.train()

        # zero out gradients
        self.priv_optim.zero_grad()
        self.util_optim.zero_grad()
        self.discriminator_optim.zero_grad()

        # encode
        d1 = self.dynamic_encoder(x1)
        d2 = self.dynamic_encoder(x2)
        d3 = self.dynamic_encoder(y1)
        d4 = self.dynamic_encoder(y2)
        s1 = self.static_encoder(x1)
        s2 = self.static_encoder(x2)
        s3 = self.static_encoder(y1)
        s4 = self.static_encoder(y2)

        # decode
        x1_hat = self.decoder(torch.cat((d1, s1), dim=1))
        x2_hat = self.decoder(torch.cat((d2, s2), dim=1))
        y1_hat = self.decoder(torch.cat((d3, s3), dim=1))
        y2_hat = self.decoder(torch.cat((d4, s4), dim=1))

        # instantiate losses
        priv_loss = torch.zeros(1).to(device)
        util_loss = torch.zeros(1).to(device)
        discriminator_loss = torch.zeros(1).to(device)

        if train_emb:
            # train privacy adversary
            p1, p2 = actors[0] - 1, actors[1] - 1
            p1, p2 = torch.eye(privacy_classes)[p1.long()].to(device), torch.eye(privacy_classes)[p2.long()].to(device)
            priv_loss = F.cross_entropy(self.priv_adv(s1), p1) + F.cross_entropy(self.priv_adv(s2), p2) + F.cross_entropy(self.priv_adv(s3), p2) + F.cross_entropy(self.priv_adv(s4), p1)
            priv_loss.backward(retain_graph=True)
            self.priv_optim.step()
            
            # train utility adversary
            a1, a2 = actions[0] - 1, actions[1] - 1
            a1, a2 = torch.eye(utility_classes)[a1.long()].to(device), torch.eye(utility_classes)[a2.long()].to(device)
            util_loss = F.cross_entropy(self.util_adv(d1), a1) + F.cross_entropy(self.util_adv(d2), a2) + F.cross_entropy(self.util_adv(d3), a1) + F.cross_entropy(self.util_adv(d4), a2)
            util_loss.backward(retain_graph=True)
            self.util_optim.step()

        if train_discrim:
            # train discriminator
            output_real = self.discriminator(torch.cat((x1, x2, y1, y2)))
            output_fake = self.discriminator(torch.cat((x1_hat, x2_hat, y1_hat, y2_hat)))
            discriminator_loss = self.bce_loss(output_real, torch.ones_like(output_real)) + self.bce_loss(output_fake, torch.zeros_like(output_fake))
            discriminator_loss.backward()
            self.discriminator_optim.step()

        # unfreeze encoders/decoder
        self.dynamic_encoder.train()
        self.static_encoder.train()
        self.decoder.train()

        # freeze adversaries
        self.priv_adv.eval()
        self.util_adv.eval()
        self.discriminator.eval()

        return priv_loss.item(), util_loss.item(), discriminator_loss.item()

    def forward(self, x):
        dyn = self.dynamic_encoder(x)   
        sta = self.static_encoder(x)
        x = self.decoder(torch.cat((dyn, sta), dim=1))
        return x

## Utility/Privacy Evaluation

In [16]:
def test(test_loader, model):
    acces = AverageMeter()
    # load learnt model that obtained best performance on validation set
    model.eval()

    label_output = list()
    pred_output = list()

    for i, t in enumerate(test_loader):
        inputs = t[0]
        target = t[1]
        with torch.no_grad():
            output = model(inputs.cuda())
            output = output.view(
                (-1, inputs.size(0)//target.size(0), output.size(1)))
            output = output.mean(1)

        label_output.append(target.cpu().numpy())
        pred_output.append(output.cpu().numpy())

        acc = accuracy(output.data, target.cuda())
        acces.update(acc[0], inputs.size(0))

    label_output = np.concatenate(label_output, axis=0)
    pred_output = np.concatenate(pred_output, axis=0)

    label_index = np.argmax(label_output, axis=1)
    pred_index = np.argmax(pred_output, axis=1)

    f1 = f1_score(label_index, pred_index, average='macro', zero_division=0)
    precision = precision_score(label_index, pred_index, average='macro', zero_division=0)
    recall = recall_score(label_index, pred_index, average='macro', zero_division=0)

    return acces.avg, f1, precision, recall
    
def accuracy(output, target):
    batch_size = target.size(0)
    _, pred = output.topk(1, 1, True, True)
    pred = pred.t()
    target = torch.argmax(target, dim=1)  # Add this line to convert one-hot targets to class indices
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    correct = correct.view(-1).float().sum(0, keepdim=True)
    return correct.mul_(100.0 / batch_size)
    
def sgn_eval(train_x, train_y, test_x, test_y, val_x, val_y, case, model):
    # Data loading
    ntu_loaders = NTUDataLoaders(dataset, case, seg=20, train_X=train_x, train_Y=train_y, test_X=test_x, test_Y=test_y, val_X=val_x, val_Y=val_y, aug=0)
    test_loader = ntu_loaders.get_test_loader(batch_size, 16)

    # Test
    return test(test_loader, model)

## Instantiate Models

In [17]:
model = AutoEncoder().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

In [18]:
load_model = False
if load_model:
    model.load_state_dict(torch.load('pretrained/MR.pt'))

In [19]:
load_util = True
train_util = False
sgn_ar = SGN(utility_classes, None, seg, batch_size, 0).to(device)
sgn_priv = SGN(privacy_classes, None, seg, batch_size, 0).to(device)

if load_util:
    sgn_priv.load_state_dict(torch.load('SGN/pretrained/privacy.pt')['state_dict'])
    sgn_ar.load_state_dict(torch.load('SGN/pretrained/action.pt')['state_dict'])

# Training

## Train Motion Retargeting

In [20]:
sgn_train_x, sgn_train_y, sgn_val_x, sgn_val_y = np.zeros((batch_size, 300, 150)), np.zeros((batch_size, 1)), np.zeros((batch_size, 300, 150)), np.zeros((batch_size, 1))

best_loss = float('inf')

def train(epoch, train_ae = True, train_cross = True, train_discrim = True, train_emb_adv = True, run_eval = True, use_adv_loss = True, run_sgn_eval = False, save = True):
    global best_loss
    # Assertions
    assert train_ae or train_cross or train_discrim or train_emb_adv, "At least one of the training objectives must be True"
    assert not (run_sgn_eval and not run_eval), "If run_sgn_eval is True, then run_eval must be True"
    
    # Store eval values for validation
    eval_X_known, eval_Y_known, eval_X_rec, eval_Y_rec, eval_X, eval_Y = [], [], [], [], [], []

    # Losses for printing
    losses = []
    rec_loss, cross_loss, end_effector_loss, triplet_loss, latent_consistency_loss, privacy_loss, privacy_loss_dyn, privacy_loss_stat, privacy_acc_dyn, privacy_acc_stat, priv_training_loss, utility_loss, utility_loss_dyn, utility_loss_stat, utility_acc_dyn, utility_acc_stat, util_training_loss, discriminator_loss, discriminator_train_losses, discriminator_training_acc = [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], []

    for (x1, x2, y1, y2, actors, actions) in train_dl:
        # Move tensors to the configured device
        x1, x2, y1, y2 = x1.float().to(device), x2.float().to(device), y1.float().to(device), y2.float().to(device)
        
        if train_discrim or train_emb_adv:
            # Train the discriminator
            priv_train_loss, util_train_loss, discriminator_train_loss = model.train_adv(x1, x2, y1, y2, actors, actions, train_emb=train_emb_adv, train_discrim=train_discrim)
            
            # Track the loss
            priv_training_loss.append(priv_train_loss)
            util_training_loss.append(util_train_loss)
            discriminator_train_losses.append(discriminator_train_loss)

        # Zero the gradients
        optimizer.zero_grad()

        # Train the autoencoder/cross reconstruction
        if train_ae or train_cross:
            # Forward pass
            loss, _, _, _, _, losses_ = model.loss(x1, x2, y1, y2, actors, actions, cross=train_cross, reconstruction=train_ae, emb_adv=(train_emb_adv and use_adv_loss), discrim_adv=(train_discrim and use_adv_loss))

            # Backward and optimize
            loss.backward()
            optimizer.step()

            # Track the loss
            losses.append(loss.item())
            rec_loss.append(losses_['rec_loss'])
            cross_loss.append(losses_['cross_loss'])
            end_effector_loss.append(losses_['end_effector_loss'])
            latent_consistency_loss.append(losses_['latent_consistency_loss'])
            triplet_loss.append(losses_['triplet_loss'])
            privacy_loss.append(losses_['privacy_loss'])
            privacy_loss_dyn.append(losses_['privacy_loss_dyn'])
            privacy_loss_stat.append(losses_['privacy_loss_stat'])
            privacy_acc_dyn.append(losses_['privacy_acc_dyn'])
            privacy_acc_stat.append(losses_['privacy_acc_stat'])
            utility_loss.append(losses_['utility_loss'])
            utility_loss_dyn.append(losses_['utility_loss_dyn'])
            utility_loss_stat.append(losses_['utility_loss_stat'])
            utility_acc_dyn.append(losses_['utility_acc_dyn'])
            utility_acc_stat.append(losses_['utility_acc_stat'])
            discriminator_loss.append(losses_['discriminator_loss'])
            discriminator_training_acc.append(losses_['discriminator_acc'])
        
    # Decay learning rate
    scheduler.step()

    # Validation
    if run_eval:
        with torch.no_grad():
            val_losses = []
            val_rec_loss, val_cross_loss, val_end_effector_loss, val_triplet_loss, val_latent_consistency_loss, val_privacy_loss, val_privacy_loss_dyn, val_privacy_loss_stat, val_privacy_acc_dyn, val_privacy_acc_stat, val_utility_loss, val_utility_loss_dyn, val_utility_loss_stat, val_utility_acc_dyn, val_utility_acc_stat, val_discriminator_loss, val_discriminator_acc = [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], [], []
            
            for (x1, x2, y1, y2, actors, actions) in val_dl:
                x1, x2, y1, y2 = x1.float().to(device), x2.float().to(device), y1.float().to(device), y2.float().to(device)
                
                loss, x1_hat, x2_hat, y1_hat, y2_hat, losses_ = model.loss(x1, x2, y1, y2, actors, actions, cross=train_cross, reconstruction=train_ae, emb_adv=(train_emb_adv and use_adv_loss), discrim_adv=(train_discrim and use_adv_loss))
                val_losses.append(loss.item())
                val_rec_loss.append(losses_['rec_loss'])
                val_cross_loss.append(losses_['cross_loss'])
                val_end_effector_loss.append(losses_['end_effector_loss'])
                val_triplet_loss.append(losses_['triplet_loss'])
                val_latent_consistency_loss.append(losses_['latent_consistency_loss'])
                val_privacy_loss.append(losses_['privacy_loss'])
                val_privacy_loss_dyn.append(losses_['privacy_loss_dyn'])
                val_privacy_loss_stat.append(losses_['privacy_loss_stat'])
                val_privacy_acc_dyn.append(losses_['privacy_acc_dyn'])
                val_privacy_acc_stat.append(losses_['privacy_acc_stat'])
                val_utility_loss.append(losses_['utility_loss'])
                val_utility_loss_dyn.append(losses_['utility_loss_dyn'])
                val_utility_loss_stat.append(losses_['utility_loss_stat'])
                val_utility_acc_dyn.append(losses_['utility_acc_dyn'])
                val_utility_acc_stat.append(losses_['utility_acc_stat'])
                val_discriminator_loss.append(losses_['discriminator_loss'])
                val_discriminator_acc.append(losses_['discriminator_acc'])

                if run_sgn_eval:
                    eval_X_known.append(x1.cpu().numpy())
                    eval_X_known.append(x2.cpu().numpy())
                    eval_X_known.append(y1.cpu().numpy())
                    eval_X_known.append(y2.cpu().numpy())

                    eval_Y_known.append(actions[0].cpu().numpy())
                    eval_Y_known.append(actions[1].cpu().numpy())
                    eval_Y_known.append(actions[0].cpu().numpy())
                    eval_Y_known.append(actions[1].cpu().numpy())

                    eval_X_rec.append(x1_hat.cpu().numpy())
                    eval_X_rec.append(x2_hat.cpu().numpy())
                    eval_X.append(y1_hat.cpu().numpy())
                    eval_X.append(y2_hat.cpu().numpy())

                    eval_Y_rec.append(actions[0].cpu().numpy())
                    eval_Y_rec.append(actions[1].cpu().numpy())
                    eval_Y.append(actions[0].cpu().numpy())
                    eval_Y.append(actions[1].cpu().numpy())

    # Print loss/accuracy
    print(f'--------------------\nEpoch {epoch+1}/{epochs}')
    
    if train_ae or train_cross:
        print(f'Training Loss:\t\t\t{np.mean(losses)}\nValidation Loss:\t\t{np.mean(val_losses)}')
        print('\nTraining Losses:')
        print(f'Reconstruction Loss:\t\t{np.mean(rec_loss)}\nCross Reconstruction Loss:\t{np.mean(cross_loss)}\nEnd Effector Loss:\t\t{np.mean(end_effector_loss)}\nTriplet Loss:\t\t\t{np.mean(triplet_loss)}\nLatent Consistency Loss:\t{np.mean(latent_consistency_loss)}')
        print(f'Privacy Loss:\t\t\t{np.mean(privacy_loss)}\nPrivacy Loss Dyn:\t\t{np.mean(privacy_loss_dyn)}\nPrivacy Loss Stat:\t\t{np.mean(privacy_loss_stat)}')
        print(f'Utility Loss:\t\t\t{np.mean(utility_loss)}\nUtility Loss Dyn:\t\t{np.mean(utility_loss_dyn)}\nUtility Loss Stat:\t\t{np.mean(utility_loss_stat)}')
        print(f'Discriminator Loss:\t\t{np.mean(discriminator_loss)}')

    if run_eval:
        print('\nValidation Losses:')
        print(f'Val Reconstruction Loss:\t{np.mean(val_rec_loss)}\nVal Cross Reconstruction Loss:\t{np.mean(val_cross_loss)}\nVal End Effector Loss:\t\t{np.mean(val_end_effector_loss)}\nVal Triplet Loss:\t\t{np.mean(val_triplet_loss)}\nVal Latent Consistency Loss:\t{np.mean(val_latent_consistency_loss)}')
        print(f'Val Privacy Loss:\t\t{np.mean(val_privacy_loss)}\nVal Privacy Loss Dyn:\t\t{np.mean(val_privacy_loss_dyn)}\nVal Privacy Loss Stat:\t\t{np.mean(val_privacy_loss_stat)}')
        print(f'Val Utility Loss:\t\t{np.mean(val_utility_loss)}\nVal Utility Loss Dyn:\t\t{np.mean(val_utility_loss_dyn)}\nVal Utility Loss Stat:\t\t{np.mean(val_utility_loss_stat)}')
        print(f'Val Discriminator Loss:\t\t{np.mean(val_discriminator_loss)}')
    
    if train_emb_adv or train_discrim:
        print('\nAdversary Losses')
        if train_emb_adv:
            print(f'Privacy Training Loss:\t\t{np.mean(priv_training_loss)}\nUtility Training Loss:\t\t{np.mean(util_training_loss)}\nDiscriminator Training Loss:\t{np.mean(discriminator_train_losses)}')
            if train_ae or train_cross:
                print(f'Privacy Acc Dyn:\t\t{np.mean(privacy_acc_dyn)}\nPrivacy Acc Stat:\t\t{np.mean(privacy_acc_stat)}')
                print(f'Utility Acc Dyn:\t\t{np.mean(utility_acc_dyn)}\nUtility Acc Stat:\t\t{np.mean(utility_acc_stat)}')
            print(f'Val Privacy Acc Dyn:\t\t{np.mean(val_privacy_acc_dyn)}\nVal Privacy Acc Stat:\t\t{np.mean(val_privacy_acc_stat)}')
            print(f'Val Utility Acc Dyn:\t\t{np.mean(val_utility_acc_dyn)}\nVal Utility Acc Stat:\t\t{np.mean(val_utility_acc_stat)}')
    
    if train_ae or train_cross: print(f'Discriminator Acc:\t\t{np.mean(discriminator_training_acc)}')
    if run_eval: print(f'Val Discriminator Acc:\t\t{np.mean(val_discriminator_acc)}')

    # Save model
    if save and np.mean(val_losses) < best_loss:
        best_loss = np.mean(val_losses)
        torch.save(model.state_dict(), 'pretrained/MR.pt')

    # Test Accuracy
    # TODO: Add re-identification model to eval
    if run_sgn_eval and run_eval:
        eval_X_known = np.concatenate(eval_X_known)
        eval_X_rec = np.concatenate(eval_X_rec)
        eval_X = np.concatenate(eval_X)
        
        eval_X_known = np.pad(eval_X_known, ((0,0), (0,225), (0,75)), 'constant')
        eval_X_rec = np.pad(eval_X_rec, ((0,0), (0,225), (0,75)), 'constant')
        eval_X = np.pad(eval_X, ((0,0), (0,225), (0,75)), 'constant')

        eval_Y = np.concatenate(eval_Y) - 1
        eval_Y = np.eye(utility_classes)[eval_Y.astype(int)]
        eval_Y_known = np.concatenate(eval_Y_known) - 1
        eval_Y_known = np.eye(utility_classes)[eval_Y_known.astype(int)]
        eval_Y_rec = np.concatenate(eval_Y_rec) - 1
        eval_Y_rec = np.eye(utility_classes)[eval_Y_rec.astype(int)]

        #const
        acc_known, f1_known, prec_known, recall_known = sgn_eval(sgn_train_x, sgn_train_y, eval_X_known, eval_Y_known, sgn_val_x, sgn_val_y, 1, sgn_ar)
        #rec
        acc_rec, f1_rec, prec_rec, recall_rec = sgn_eval(sgn_train_x, sgn_train_y, eval_X_rec, eval_Y_rec, sgn_val_x, sgn_val_y, 1, sgn_ar)
        #cross
        acc_cross, f1_cross, prec_cross, recall_cross = sgn_eval(sgn_train_x, sgn_train_y, eval_X, eval_Y, sgn_val_x, sgn_val_y, 1, sgn_ar)
        print(f'\nConstant Accuracy:\t{acc_known}\nRec Accuracy:\t\t{acc_rec}\nCross Accuracy:\t\t{acc_cross}\n')
        print(f'Constant F1:\t\t{f1_known}\nRec F1:\t\t\t{f1_rec}\nCross F1:\t\t{f1_cross}\n')
        print(f'Constant Precision:\t{prec_known}\nRec Precision:\t\t{prec_rec}\nCross Precision:\t{prec_cross}\n')
        print(f'Constant Recall:\t{recall_known}\nRec Recall:\t\t{recall_rec}\nCross Recall:\t\t{recall_cross}\n')
    else: print('\n')

In [22]:
training_stages = [
    {'epochs': 1000, 'ae': True, 'cross': False, 'discrim': False, 'emb': True, 'eval': True, 'use_adv_loss': False, 'sgn_eval': False, 'save': False},
    {'epochs': 200, 'ae': False, 'cross': False, 'discrim': True, 'emb': False, 'eval': True, 'use_adv_loss': False, 'sgn_eval': True, 'save': False},
    {'epochs': 500, 'ae': True, 'cross': True, 'discrim': False, 'emb': True, 'eval': True, 'use_adv_loss': True, 'sgn_eval': True, 'save': True}
]

for stage in training_stages:
    print('\n\n\nMoving to new stage\n\n\n')
    print(stage)
    for epoch in range(stage['epochs']):
        if stage['sgn_eval']:
            if validation_acc_freq > 0 and epoch % validation_acc_freq == 0: use_sgn = True
            else: use_sgn = False
        else: use_sgn = False
        train(epoch, train_ae=stage['ae'], train_cross=stage['cross'], train_discrim=stage['discrim'], train_emb_adv=stage['emb'], run_eval=stage['eval'], use_adv_loss=stage['use_adv_loss'], run_sgn_eval= use_sgn, save=stage['save'])




Moving to new stage



{'epochs': 1, 'ae': True, 'cross': False, 'discrim': False, 'emb': True, 'eval': True, 'use_adv_loss': False, 'sgn_eval': False, 'save': False}
--------------------
Epoch 1/500
Training Loss:			32.08292624601133
Validation Loss:		31.797529458999634

Training Losses:
Reconstruction Loss:		10.45112990118136
Cross Reconstruction Loss:	0.0
End Effector Loss:		0.021308674221965158
Triplet Loss:			4.01490972148385
Latent Consistency Loss:	0.2215024739219125
Privacy Loss:			0.0
Privacy Loss Dyn:		0.0
Privacy Loss Stat:		0.0
Utility Loss:			0.0
Utility Loss Dyn:		0.0
Utility Loss Stat:		0.0
Discriminator Loss:		0.0

Validation Losses:
Val Reconstruction Loss:	10.38505944609642
Val Cross Reconstruction Loss:	0.0
Val End Effector Loss:		0.016292909043841064
Val Triplet Loss:		3.999788910150528
Val Latent Consistency Loss:	0.16090828622691333
Val Privacy Loss:		0.0
Val Privacy Loss Dyn:		0.0
Val Privacy Loss Stat:		0.0
Val Utility Loss:		0.0
Val Utility Loss Dyn:		0.0
Va

In [None]:
train(501, run_sgn_eval=True)

In [None]:
# manual save
#torch.save(model.state_dict(), 'pretrained/MR.pt')

# Retargeting

In [None]:
def retarget():
    X_hat_random = {}
    X_hat_constant = {}

    x2_const = X[random.sample(list(X.keys()), 1)[0]].float().cuda().unsqueeze(0)
    with torch.no_grad():
        for file in X:
            x1 = X[file].unsqueeze(0)
            x2_random = X[random.sample(list(X.keys()), 1)[0]].unsqueeze(0)
            X_hat_random[file] = model.eval(x1.float().cuda(), x2_random.float().cuda()).cpu().numpy().squeeze()
            X_hat_constant[file] = model.eval(x1.float().cuda(), x2_const).cpu().numpy().squeeze()

    # Save results
    with open('results/X_hat_random.pkl', 'wb') as f:
        pickle.dump(X_hat_random, f)
    with open('results/X_hat_constant.pkl', 'wb') as f:
        pickle.dump(X_hat_constant, f)

# retarget()