In [1]:
import pickle
import random
import numpy as np
import os
import os.path as osp
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
from SGN.model import SGN
from SGN.data import NTUDataLoaders, AverageMeter
from SGN.util import make_dir, get_num_classes
from sklearn.metrics import confusion_matrix, accuracy_score, f1_score, precision_score, recall_score

In [2]:
# Hyper Parameters
dataset = 'NTU'
device = torch.device('cuda:0')
# device = torch.device('cpu')
seg = 20
lr = 5e-5
epochs = 500
utility_classes = 120
privacy_classes = 106
validation_acc_freq = 10 #-1 to disable

# Data

In [3]:
with open('ntu/X.pkl', 'rb') as f:
    X = pickle.load(f)

In [4]:
# clean data
to_del = []
for file in X:
    if type(X[file]) == list:
        to_del.append(file)
print('to delete', len(to_del))
for file in to_del:
    del X[file]

to delete 28814


In [5]:
# pad or trim data to 75 frames. when padding, repeat the last frame
# input is of shape (frames, 75)
T = 75
for file in X:
    if X[file].shape[0] < T:
        X[file] = np.pad(X[file], ((0, T - X[file].shape[0]), (0, 0)), mode='edge')
    elif X[file].shape[0] > T:
        X[file] = X[file][:T, :]

In [6]:
for file in X:
    X[file] = torch.tensor(X[file])

In [7]:
a = {}
p = {}
for file in X:
    if file[16:20] not in a:
        a[file[16:20]] = {}
    if file[8:12] not in a[file[16:20]]:
        a[file[16:20]][file[8:12]] = []
    a[file[16:20]][file[8:12]].append(file)
    
    if file[8:12] not in p:
        p[file[8:12]] = set()
    p[file[8:12]].add(file[16:20])

In [8]:
def gen_samples(samples):
    x, y = [], []
    for _ in range(samples):
        # sample two random p
        p1, p2 = random.sample(list(p.keys()), 2)
        # find overlapping a
        a1 = p[p1]
        a2 = p[p2]
        a12 = a1.intersection(a2)
        if len(a12) == 0:
            continue
        # sample two random a
        a1, a2 = random.sample(list(a12), 2)
        # sample x and y
        x1 = random.sample(a[a1][p1], 1)[0]
        x2 = random.sample(a[a2][p2], 1)[0]
        y1 = random.sample(a[a1][p2], 1)[0]
        y2 = random.sample(a[a2][p1], 1)[0]
        x.append([x1, x2])
        y.append([y1, y2])
    return x, y

batch_size = 32
train_x, train_y = gen_samples(30000)
val_x, val_y = gen_samples(10000)

In [9]:
class Data(Dataset):
    def __init__(self, X, y):
        self.X = X
        self.y = y
        
    def __getitem__(self, index):
        actors = [float(self.X[index][0][9:12]), float(self.X[index][1][9:12])]
        actions = [float(self.X[index][0][17:20]), float(self.X[index][1][17:20])]
        return X[self.X[index][0]], X[self.X[index][1]], X[self.y[index][0]],  X[self.y[index][1]], actors, actions
    
    def __len__(self):
        return len(self.X)

In [10]:
train_data = Data(train_x, train_y)
train_dl = DataLoader(train_data, batch_size=batch_size, shuffle=True)
val_data = Data(val_x, val_y)
val_dl = DataLoader(val_data, batch_size=batch_size, shuffle=True)

# Model

## Adversary

In [11]:
# Input is size of latent space
class Adversary(nn.Module):
    def __init__(self, num_classes):
        super(Adversary, self).__init__()
        self.conv = nn.ConvTranspose1d(192, 64, 3, 1, 1)
        self.ref = nn.ReflectionPad1d(3)
        self.pool = nn.MaxPool1d(2, 2)
        self.fc1 = nn.Linear(64 * 8, 256)
        self.fc2 = nn.Linear(256, num_classes)

    def forward(self, x):
        x = self.ref(x)
        x = self.conv(x)
        x = self.pool(x)
        x = x.view(x.shape[0], -1)
        x = F.relu(self.fc1(x))
        x = F.softmax(self.fc2(x), dim=1)
        return x

## Motion Retargeting

In [12]:
class Encoder(nn.Module):
    def __init__(self):
        super(Encoder, self).__init__()

        self.enc1 = nn.Conv1d(in_channels=75, out_channels=96, kernel_size=3, stride=1, padding=1)
        self.enc2 = nn.Conv1d(in_channels=96, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.enc3 = nn.Conv1d(in_channels=128, out_channels=160, kernel_size=3, stride=1, padding=1)
        self.enc4 = nn.Conv1d(in_channels=160, out_channels=192, kernel_size=3, stride=1, padding=1)
        self.ref1 = nn.ReflectionPad1d(3)
        self.ref2 = nn.ReflectionPad1d(3)
        self.ref3 = nn.ReflectionPad1d(3)
        self.ref4 = nn.ReflectionPad1d(3)

        self.pool = nn.MaxPool1d(kernel_size=2, stride=2)

        self.acti = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.ref1(x)
        x = self.acti(self.enc1(x))
        x = self.pool(x)
        
        x = self.ref2(x)
        x = self.acti(self.enc2(x))
        x = self.pool(x)
        
        x = self.ref3(x)
        x = self.acti(self.enc3(x))
        x = self.pool(x)

        x = self.ref4(x)
        x = self.acti(self.enc4(x))
        x = self.pool(x)

        return x

class Decoder(nn.Module):
    def __init__(self):
        super(Decoder, self).__init__()

        self.dec1 = nn.ConvTranspose1d(in_channels=384, out_channels=160, kernel_size=3, stride=1, padding=1)
        self.dec2 = nn.ConvTranspose1d(in_channels=160, out_channels=128, kernel_size=3, stride=1, padding=1)
        self.dec3 = nn.ConvTranspose1d(in_channels=128, out_channels=96, kernel_size=3, stride=1, padding=1)
        self.dec4 = nn.ConvTranspose1d(in_channels=96, out_channels=75, kernel_size=3, stride=1, padding=1)

        self.ref1 = nn.ReflectionPad1d(3)
        self.ref2 = nn.ReflectionPad1d(3)
        self.ref3 = nn.ReflectionPad1d(3)
        self.ref4 = nn.ReflectionPad1d(3)
 
        self.up = nn.Upsample(scale_factor=2, mode='nearest')
        self.up75 = nn.Upsample(size=75, mode='nearest') 

        self.acti = nn.LeakyReLU(0.2)

    def forward(self, x):
        x = self.ref1(x)
        x = self.acti(self.dec1(x))
        x = self.up(x)

        x = self.ref2(x)
        x = self.acti(self.dec2(x))
        x = self.up(x)

        x = self.ref3(x)
        x = self.acti(self.dec3(x))
        x = self.up(x)

        x = self.ref4(x)
        x = self.acti(self.dec4(x))
        x = self.up75(x)
        return x

class AutoEncoder(nn.Module):
    def __init__(self, adv_lr=1e-3):
        super(AutoEncoder, self).__init__()

        # AutoEncoder Models
        self.static_encoder = Encoder()
        self.dynamic_encoder = Encoder()
        self.decoder = Decoder()

        # Adversarial Models
        self.priv_adv = Adversary(privacy_classes)
        self.util_adv = Adversary(utility_classes)
        self.priv_optim = torch.optim.Adam(self.priv_adv.parameters(), lr=adv_lr)
        self.util_optim = torch.optim.Adam(self.util_adv.parameters(), lr=adv_lr)

        # Freeze Adversarial Models
        self.priv_adv.eval()
        self.util_adv.eval()

        # Loss Functions
        self.triplet_loss = nn.TripletMarginLoss()
        
        # Info for loss functions
        self.end_effectors = torch.tensor([19, 15, 23, 24, 21, 22, 3]).to(device) * 3
        self.chain_lengths = torch.tensor([5, 5, 8, 8, 8, 8, 5]).to(device)

        # Lambdas for discounted loss
        self.lambda_rec = 1
        self.lambda_cross = 1
        self.lambda_ee = 2
        self.lambda_trip = 0.25
        self.lambda_adv_util = 0.5
        self.lambda_adv_priv = 0.5

    def cross(self, x1, x2):
        d1 = self.dynamic_encoder(x1)
        d2 = self.dynamic_encoder(x2)
        s1 = self.static_encoder(x1)
        s2 = self.static_encoder(x2)
        
        x1_hat = self.decoder(torch.cat((d1, s1), dim=1))
        x2_hat = self.decoder(torch.cat((d2, s2), dim=1))
        y1_hat = self.decoder(torch.cat((d1, s2), dim=1))
        y2_hat = self.decoder(torch.cat((d2, s1), dim=1))

        return x1_hat, x2_hat, y1_hat, y2_hat
    
    def eval(self, x1, x2):
        dynamic = self.dynamic_encoder(x1)
        static = self.static_encoder(x2)
        return self.decoder(torch.cat((dynamic, static), dim=1))
    
    def loss(self, x1, x2, y1, y2, actors, actions):
        d1 = self.dynamic_encoder(x1)
        d2 = self.dynamic_encoder(x2)
        s1 = self.static_encoder(x1)
        s2 = self.static_encoder(x2)

        x1_hat = self.decoder(torch.cat((d1, s1), dim=1))
        x2_hat = self.decoder(torch.cat((d2, s2), dim=1))
        y1_hat = self.decoder(torch.cat((d1, s2), dim=1))
        y2_hat = self.decoder(torch.cat((d2, s1), dim=1))

        d12 = self.dynamic_encoder(y1)
        d21 = self.dynamic_encoder(y2)
        s12 = self.static_encoder(y1)
        s21 = self.static_encoder(y2)
                
        # reconstruction loss
        rec_loss = self.reconstruction_loss(x1, x1_hat) + self.reconstruction_loss(x2, x2_hat)
        # print('Reconstruction Loss: ', rec_loss.item())
        
        # cross reconstruction loss
        cross_loss = self.cross_loss(y1, y2, y1_hat, y2_hat)
        # print('Cross Reconstruction Loss: ', cross_loss.item())
        
        # end effector loss
        end_effector_loss = self.end_effector_loss(x1_hat, x1) + self.end_effector_loss(x2_hat, x2) + self.end_effector_loss(y1_hat, y1) + self.end_effector_loss(y2_hat, y2)
        # print('End Effector Loss: ', end_effector_loss.item())

        # triplet loss
        triplet_loss = self.triplet_loss(d12, d1, d2) + self.triplet_loss(d21, d2, d1) + self.triplet_loss(s12, s1, s2) + self.triplet_loss(s21, s2, s1) 
        # print('Triplet Loss: ', triplet_loss.item())

        # sgn latent privacy loss (adversarial)
        adv_priv_y1, adv_priv_y2 = actors[0] - 1, actors[1] - 1
        adv_priv_y1, adv_priv_y2 = torch.eye(privacy_classes)[adv_priv_y1.long()].to(device), torch.eye(privacy_classes)[adv_priv_y2.long()].to(device)
        privacy_loss_dyn = -(self.adv_loss(self.priv_adv, d1, adv_priv_y1) + self.adv_loss(self.priv_adv, d2, adv_priv_y2))
        privacy_loss_stat = self.adv_loss(self.priv_adv, s1, adv_priv_y1) + self.adv_loss(self.priv_adv, s2, adv_priv_y2)
        privacy_loss = privacy_loss_dyn + privacy_loss_stat
        # print('Privacy Loss Dynamic: ', privacy_loss_dyn.item(), '\tPrivacy Loss Static: ', privacy_loss_stat.item(), '\tPrivacy Loss: ', privacy_loss.item())
        
        # sgn latent utility loss (adversarial)
        adv_util_y1, adv_util_y2 = actions[0] - 1, actions[1] - 1
        adv_util_y1, adv_util_y2 = torch.eye(utility_classes)[adv_util_y1.long()].to(device), torch.eye(utility_classes)[adv_util_y2.long()].to(device)
        utility_loss_dyn = self.adv_loss(self.util_adv, d1, adv_util_y1) + self.adv_loss(self.util_adv, d2, adv_util_y2)
        utility_loss_stat = -(self.adv_loss(self.util_adv, s1, adv_util_y1) + self.adv_loss(self.util_adv, s2, adv_util_y2))
        utility_loss = utility_loss_dyn + utility_loss_stat
        # print('Utility Loss Dynamic: ', utility_loss_dyn.item(), '\tUtility Loss Static: ', utility_loss_stat.item(), '\tUtility Loss: ', utility_loss.item())

        return rec_loss * self.lambda_rec + cross_loss * self.lambda_cross + end_effector_loss * self.lambda_ee + triplet_loss * self.lambda_trip + privacy_loss * self.lambda_adv_priv + utility_loss * self.lambda_adv_util, \
                x1_hat, x2_hat, y1_hat, y2_hat

    def reconstruction_loss(self, x, y):
        # return torch.square(torch.norm(x - y, dim=1)).mean()
        return F.mse_loss(x, y)
    
    def cross_loss(self, x1, x2, y1, y2):
        return F.mse_loss(x1, y1) + F.mse_loss(x2, y2)
        # return torch.square(torch.norm(x1 - y2, dim=1)).mean() + torch.square(torch.norm(x2 - y1, dim=1)).mean()
    
    def end_effector_loss(self, x, y):
        # slice to get the end effector joints
        x_ee = x[:, :, self.end_effectors.unsqueeze(-1) + torch.arange(3).to(device)] 
        y_ee = y[:, :, self.end_effectors.unsqueeze(-1) + torch.arange(3).to(device)]

        # calculate velocities
        x_vel = torch.norm(x_ee[:, 1:] - x_ee[:, :-1], dim=-1) / self.chain_lengths.unsqueeze(0)
        y_vel = torch.norm(y_ee[:, 1:] - y_ee[:, :-1], dim=-1) / self.chain_lengths.unsqueeze(0)
        
        # compute mse loss for each joint
        losses = F.mse_loss(x_vel, y_vel, reduction='none')

        # take sum over end effectors
        loss = losses.sum(dim=1)

        # take mean over batch
        loss = loss.mean()
        
        return loss
    
    def adv_loss(self, model, x, y):
        return F.mse_loss(model(x), y)

    def train_adv(self, x1, x2, actors, actions):
        # freeze encoders/decoder
        self.dynamic_encoder.eval()
        self.static_encoder.eval()
        self.decoder.eval()

        # unfreeze adversaries
        self.priv_adv.train()
        self.util_adv.train()

        # zero out gradients
        self.priv_optim.zero_grad()
        self.util_optim.zero_grad()

        # encode
        d1 = self.dynamic_encoder(x1)
        d2 = self.dynamic_encoder(x2)
        s1 = self.static_encoder(x1)
        s2 = self.static_encoder(x2)

        # train privacy adversary
        adv_priv_y1, adv_priv_y2 = actors[0] - 1, actors[1] - 1
        adv_priv_y1, adv_priv_y2 = torch.eye(privacy_classes)[adv_priv_y1.long()].to(device), torch.eye(privacy_classes)[adv_priv_y2.long()].to(device)
        priv_loss = F.mse_loss(self.priv_adv(d1), adv_priv_y1) + F.mse_loss(self.priv_adv(d2), adv_priv_y2)
        priv_loss.backward(retain_graph=True)
        self.priv_optim.step()
        
        # train utility adversary
        adv_util_y1, adv_util_y2 = actions[0] - 1, actions[1] - 1
        adv_util_y1, adv_util_y2 = torch.eye(utility_classes)[adv_util_y1.long()].to(device), torch.eye(utility_classes)[adv_util_y2.long()].to(device)
        util_loss = F.mse_loss(self.util_adv(d1), adv_util_y1) + F.mse_loss(self.util_adv(d2), adv_util_y2)
        util_loss.backward(retain_graph=True)
        self.util_optim.step()

        # unfreeze encoders/decoder
        self.dynamic_encoder.train()
        self.static_encoder.train()
        self.decoder.train()

        # freeze adversaries
        self.priv_adv.eval()
        self.util_adv.eval()

    def forward(self, x):
        dyn = self.dynamic_encoder(x)   
        sta = self.static_encoder(x)
        x = self.decoder(torch.cat((dyn, sta), dim=1))
        return x

## Utility/Privacy Evaluation

In [13]:
def test(test_loader, model):
    acces = AverageMeter()
    # load learnt model that obtained best performance on validation set
    model.eval()

    label_output = list()
    pred_output = list()

    for i, t in enumerate(test_loader):
        inputs = t[0]
        target = t[1]
        with torch.no_grad():
            output = model(inputs.cuda())
            output = output.view(
                (-1, inputs.size(0)//target.size(0), output.size(1)))
            output = output.mean(1)

        label_output.append(target.cpu().numpy())
        pred_output.append(output.cpu().numpy())

        acc = accuracy(output.data, target.cuda())
        acces.update(acc[0], inputs.size(0))

    label_output = np.concatenate(label_output, axis=0)
    pred_output = np.concatenate(pred_output, axis=0)

    label_index = np.argmax(label_output, axis=1)
    pred_index = np.argmax(pred_output, axis=1)

    f1 = f1_score(label_index, pred_index, average='macro', zero_division=0)
    precision = precision_score(label_index, pred_index, average='macro', zero_division=0)
    recall = recall_score(label_index, pred_index, average='macro', zero_division=0)

    return acces.avg, f1, precision, recall
    
def accuracy(output, target):
    batch_size = target.size(0)
    _, pred = output.topk(1, 1, True, True)
    pred = pred.t()
    target = torch.argmax(target, dim=1)  # Add this line to convert one-hot targets to class indices
    correct = pred.eq(target.view(1, -1).expand_as(pred))
    correct = correct.view(-1).float().sum(0, keepdim=True)
    return correct.mul_(100.0 / batch_size)
    
def sgn_eval(train_x, train_y, test_x, test_y, val_x, val_y, case, model):
    # Data loading
    ntu_loaders = NTUDataLoaders(dataset, case, seg=20, train_X=train_x, train_Y=train_y, test_X=test_x, test_Y=test_y, val_X=val_x, val_Y=val_y, aug=0)
    test_loader = ntu_loaders.get_test_loader(batch_size, 16)

    # Test
    return test(test_loader, model)

# Instantiate Models

In [14]:
model = AutoEncoder().to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99)

In [15]:
load_model = True
if load_model:
    model.load_state_dict(torch.load('pretrained/MR.pt'))

In [16]:
load_util = True
train_util = False
sgn_ar = SGN(utility_classes, None, seg, batch_size, 0).to(device)
sgn_priv = SGN(privacy_classes, None, seg, batch_size, 0).to(device)

if load_util:
    sgn_priv.load_state_dict(torch.load('SGN/pretrained/privacy.pt')['state_dict'])
    sgn_ar.load_state_dict(torch.load('SGN/pretrained/action.pt')['state_dict'])

# Training

## Train Motion Retargeting

In [17]:
sgn_train_x = np.zeros((batch_size, 300, 150))
sgn_train_y = np.zeros((batch_size, 1))
sgn_val_x = np.zeros((batch_size, 300, 150))
sgn_val_y = np.zeros((batch_size, 1))

best_loss = 0#float('inf')

for epoch in range(epochs):
    losses = []
    eval_X_known = []
    eval_Y_known = []
    eval_X_rec = []
    eval_Y_rec = []
    eval_X = []
    eval_Y = []

    for (x1, x2, y1, y2, actors, actions) in train_dl:
        # Move tensors to the configured device
        x1, x2, y1, y2 = x1.float().to(device), x2.float().to(device), y1.float().to(device), y2.float().to(device)
        
        # Train adversaries
        model.train_adv(x1, x2, actors, actions)

        # Zero the gradients
        optimizer.zero_grad()

        # Forward pass
        loss, _, _, _, _ = model.loss(x1, x2, y1, y2, actors, actions)

        # Backward and optimize
        loss.backward()
        optimizer.step()

        # Track the loss
        losses.append(loss.item())

    # Decay learning rate
    scheduler.step()

    # Validation
    with torch.no_grad():
        val_losses = []
        
        for (x1, x2, y1, y2, actors, actions) in val_dl:
            x1, x2, y1, y2 = x1.float().to(device), x2.float().to(device), y1.float().to(device), y2.float().to(device)
            
            loss, x1_hat, x2_hat, y1_hat, y2_hat = model.loss(x1, x2, y1, y2, actors, actions)
            val_losses.append(loss.item())

            if validation_acc_freq != -1 and (epoch+1) % validation_acc_freq == 0:
                eval_X_known.append(x1.cpu().numpy())
                eval_X_known.append(x2.cpu().numpy())
                eval_X_known.append(y1.cpu().numpy())
                eval_X_known.append(y2.cpu().numpy())

                eval_Y_known.append(actions[0].cpu().numpy())
                eval_Y_known.append(actions[1].cpu().numpy())
                eval_Y_known.append(actions[0].cpu().numpy())
                eval_Y_known.append(actions[1].cpu().numpy())

                eval_X_rec.append(x1_hat.cpu().numpy())
                eval_X_rec.append(x2_hat.cpu().numpy())
                eval_X.append(y1_hat.cpu().numpy())
                eval_X.append(y2_hat.cpu().numpy())

                eval_Y_rec.append(actions[0].cpu().numpy())
                eval_Y_rec.append(actions[1].cpu().numpy())
                eval_Y.append(actions[0].cpu().numpy())
                eval_Y.append(actions[1].cpu().numpy())

    # Print loss/accuracy
    print(f'--------------------\nEpoch {epoch+1}/{epochs}\nTraining Loss:\t\t{np.mean(losses)}\nValidation Loss:\t{np.mean(val_losses)}')

    # Save model
    if np.mean(val_losses) < best_loss:
        best_loss = np.mean(val_losses)
        torch.save(model.state_dict(), 'pretrained/MR.pt')

    # Test Accuracy
    if validation_acc_freq != -1 and (epoch+1) % validation_acc_freq == 0:
        eval_X_known = np.concatenate(eval_X_known)
        eval_X_rec = np.concatenate(eval_X_rec)
        eval_X = np.concatenate(eval_X)
        
        eval_X_known = np.pad(eval_X_known, ((0,0), (0,225), (0,75)), 'constant')
        eval_X_rec = np.pad(eval_X_rec, ((0,0), (0,225), (0,75)), 'constant')
        eval_X = np.pad(eval_X, ((0,0), (0,225), (0,75)), 'constant')

        eval_Y = np.concatenate(eval_Y) - 1
        eval_Y = np.eye(utility_classes)[eval_Y.astype(int)]
        eval_Y_known = np.concatenate(eval_Y_known) - 1
        eval_Y_known = np.eye(utility_classes)[eval_Y_known.astype(int)]
        eval_Y_rec = np.concatenate(eval_Y_rec) - 1
        eval_Y_rec = np.eye(utility_classes)[eval_Y_rec.astype(int)]

        #const
        acc_known, f1_known, prec_known, recall_known = sgn_eval(sgn_train_x, sgn_train_y, eval_X_known, eval_Y_known, sgn_val_x, sgn_val_y, 1, sgn_ar)
        #rec
        acc_rec, f1_rec, prec_rec, recall_rec = sgn_eval(sgn_train_x, sgn_train_y, eval_X_rec, eval_Y_rec, sgn_val_x, sgn_val_y, 1, sgn_ar)
        #cross
        acc_cross, f1_cross, prec_cross, recall_cross = sgn_eval(sgn_train_x, sgn_train_y, eval_X, eval_Y, sgn_val_x, sgn_val_y, 1, sgn_ar)
        print(f'\nConstant Accuracy:\t{acc_known}\nRec Accuracy:\t\t{acc_rec}\nCross Accuracy:\t\t{acc_cross}\n')
        print(f'Constant F1:\t\t{f1_known}\nRec F1:\t\t\t{f1_rec}\nCross F1:\t\t{f1_cross}\n')
        print(f'Constant Precision:\t{prec_known}\nRec Precision:\t\t{prec_rec}\nCross Precision:\t{prec_cross}\n')
        print(f'Constant Recall:\t{recall_known}\nRec Recall:\t\t{recall_rec}\nCross Recall:\t\t{recall_cross}\n')
    else: print('\n')

--------------------
Epoch 1/500
Training Loss:		0.7747870730166149
Validation Loss:	0.7722333897365613


--------------------
Epoch 2/500
Training Loss:		0.7685709472437551
Validation Loss:	0.7617712104588412


--------------------
Epoch 3/500
Training Loss:		0.7679472823340193
Validation Loss:	0.7668127491233054


--------------------
Epoch 4/500
Training Loss:		0.7661604440974114
Validation Loss:	0.761003422268321


--------------------
Epoch 5/500
Training Loss:		0.7634297315227357
Validation Loss:	0.7591296588436941


--------------------
Epoch 6/500
Training Loss:		0.7602364404084987
Validation Loss:	0.7555258073163836


--------------------
Epoch 7/500
Training Loss:		0.7595639066364532
Validation Loss:	0.7619602998320976


--------------------
Epoch 8/500
Training Loss:		0.7576761616575987
Validation Loss:	0.7631520950392391


--------------------
Epoch 9/500
Training Loss:		0.755744564690088
Validation Loss:	0.7545436856786857


--------------------
Epoch 10/500
Training Loss:

In [51]:
torch.save(model.state_dict(), 'pretrained/MR.pt')

# Retargeting

In [66]:
X_hat_random = {}
X_hat_constant = {}

x2_const = X[random.sample(list(X.keys()), 1)[0]].float().cuda().unsqueeze(0)
with torch.no_grad():
    for file in X:
        x1 = X[file].unsqueeze(0)
        x2_random = X[random.sample(list(X.keys()), 1)[0]].unsqueeze(0)
        X_hat_random[file] = model.eval(x1.float().cuda(), x2_random.float().cuda()).cpu().numpy().squeeze()
        X_hat_constant[file] = model.eval(x1.float().cuda(), x2_const).cpu().numpy().squeeze()

# Save results
with open('results/X_hat_random.pkl', 'wb') as f:
    pickle.dump(X_hat_random, f)
with open('results/X_hat_constant.pkl', 'wb') as f:
    pickle.dump(X_hat_constant, f)

since Python 3.9 and will be removed in a subsequent version.
  x2_const = X[random.sample(X.keys(), 1)[0]].float().cuda().unsqueeze(0)
since Python 3.9 and will be removed in a subsequent version.
  x2_random = X[random.sample(X.keys(), 1)[0]].unsqueeze(0)
