In [7]:
import os

import torch
import torch.nn as nn
from torchvision import transforms as tfs
from torch.utils.data import DataLoader
import numpy as np
from matplotlib import pyplot as plt

from utils import hyperparameters
from dataset import RobotDataset, dataset_explore
from generator import Generator
from discriminator import Discriminator
from training import train_gan, test_gan
from losses import generator_loss, discriminator_loss
from metrics import ADE, FDE
from training import trimm, reconstruct

In [2]:
# Dataset location. Override it with the SOCIALPLAYGROUND_DATASET environment variable
# instead of editing this cell; the default is the original author's local path.
# Joining with '' guarantees a trailing separator, as the original literal had.
path = os.path.join(
    os.environ.get('SOCIALPLAYGROUND_DATASET',
                   '/home/felpipe/proyectos/Tesis/SocialPlayGround/Dataset/'),
    '')

# Values shared between the network config and the data pipeline, kept in one place
# so they cannot drift apart.
IMG_W, IMG_H = 320, 239
# NOTE(review): 128 is used both as the generator's latent_dim and as the second
# argument of RobotDataset. It is presumably the size of the noise vector the dataset
# emits (trimmed['noise'] later feeds the generator). Confirm against RobotDataset.
LATENT_DIM = 128
BATCH_SIZE = 8

seq_len = dataset_explore(path)

netparams = hyperparameters(w=IMG_W,
                            h=IMG_H,
                            latent_dim=LATENT_DIM,
                            history_length=8,
                            future_length=12,
                            cnn_filters=["16", "32", "64", "128", "256"],
                            lin_neurons=["256", "256"],
                            enc_layers=2,
                            lstm_dim=128,
                            output_dim=8,
                            up_criterion=0.9,
                            down_criterion=0.0,
                            alpha=0.15,
                            beta=0.15,
                            attention='add')

# NOTE(review): torchvision's Resize expects (height, width), so this yields images
# 320 px tall x 239 px wide, while netparams declares w=320, h=239. The order is kept
# as-is to preserve current behaviour. Confirm which orientation the CNN expects.
# The 4-channel Normalize presumably means the frames are RGBA. TODO confirm.
data_transforms = tfs.Compose([tfs.Resize((IMG_W, IMG_H)),
                               tfs.ToTensor(),
                               tfs.Normalize([0.5] * 4, [0.5] * 4)])

data_set = RobotDataset(path, LATENT_DIM, seq_len, data_transforms)

# shuffle=False iterates the dataset in order, so the inspected batch is always the first one.
data_loader = DataLoader(data_set, batch_size=BATCH_SIZE, shuffle=False)

In [3]:
# Seed the RNGs so weight initialisation (and any torch/numpy sampling the dataset
# does when a batch is drawn below) is repeatable across Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
gen = Generator(netparams, device).to(device)
dis = Discriminator(netparams, device).to(device)

# Put both networks in evaluation mode; the cells below only run forward passes.
gen.eval()
dis.eval()

# Take a single batch to inspect. next(iter(...)) raises StopIteration on an empty
# loader instead of silently leaving `batch` undefined for the next cell.
batch = next(iter(data_loader))

  return torch._C._cuda_getDeviceCount() > 0


In [4]:
# Split the batch into history/future windows. The fifth argument is half the history
# length; its exact role is defined inside `trimm` (TODO document there).
trimmed = trimm(batch,
                netparams['seq_len'],
                netparams['history'],
                netparams['predict_seq'],
                int(netparams['history'] / 2),
                trim_mode='relative')

# Only the first window is inspected. Fail loudly instead of leaving the route
# variables undefined when trimm yields no windows.
assert trimmed['steps'] > 0, 'trimm returned no windows for this batch'
step = 0

imgs = trimmed['imgs'][step].to(device)
z = trimmed['noise'][step].to(device)
past_vel = trimmed['past_vel'][step].to(device)
real_vel = trimmed['future_vel'][step].to(device)
past_obj = trimmed['past_target'][step].to(device)
real_obj = trimmed['future_target'][step].to(device)

# Model inputs/targets: trajectory, velocity and target features concatenated along
# dim 2 (presumably the per-step feature axis, judging by the printed shapes).
past_routes = torch.cat((trimmed['past_traj'][step].to(device), past_vel, past_obj), dim=2)
real_routes = torch.cat((trimmed['future_traj'][step].to(device), real_vel, real_obj), dim=2)

print(f'Shape of past_routes {past_routes.shape}')
print(f'Shape of real_routes {real_routes.shape}')

Shape of past_routes torch.Size([8, 8, 8])
Shape of real_routes torch.Size([8, 12, 8])


In [5]:
# Generate a future from the observed past, then let the discriminator score both the
# ground-truth future and the generated one, each conditioned on the same past.
fake_routes = gen(imgs, z, past_routes)
real_output = dis(imgs, real_routes, past_routes)
fake_output = dis(imgs, fake_routes, past_routes)

for label, tensor in (('fake_routes', fake_routes),
                      ('real_output', real_output),
                      ('fake_output', fake_output)):
    print(f'Shape of {label} {tensor.shape}')

Shape of fake_routes torch.Size([8, 12, 8])
Shape of real_output torch.Size([8, 1])
Shape of fake_output torch.Size([8, 1])


In [11]:
# Average and final displacement error of the generated routes against the ground truth.
# NOTE(review): both route tensors carry 8 features per step (trajectory + velocity +
# target). Check whether ADE/FDE slice out the position columns or use all of them.
for metric in (ADE, FDE):
    print(metric(real_routes, fake_routes))

0.5812401
1.0492758


In [10]:
# Evaluate both adversarial losses on the batch inspected above (no backward pass here).
dis_loss = discriminator_loss(real_output, fake_output, netparams)
gen_loss = generator_loss(fake_output, fake_routes, real_routes, netparams)
# Show the values as the cell's output; previously they were computed but never displayed.
dis_loss, gen_loss