In [None]:
import torch
import torch.nn as nn
import torch.optim as optim6
import matplotlib.pyplot as plt
from torch.autograd import Variable # storing data while learning
from config import CFG ,Args
from baselineUtils import load_datasets,distance_metrics
from utils import ScheduledOptim,visualize_preds
from train import train_attn_mdn 
from test import test_mdn 
from torch.optim.lr_scheduler import LambdaLR 
from model import Attention_GMM #,Attention_GMM_Encoder,Transformer_MDN
# from torch.utils.data.distributed import  DistributedSampler
# from torch.nn.parallel import DistributedDataParallel as DDP
# from torch.distributed import init_process_group,destroy_process_group
import os

# Device and data loaders

In [None]:
# Runtime configuration: the target device and batch size come from CFG, and
# `args` is simply an alias for the Args object imported from config.
args = Args
batch_size = CFG.batch_size
device = CFG.device
print(f"Using {device} device")

In [None]:
# load_datasets returns the train/val/test splits followed by `mean` and `std`,
# which are later handed to training and testing (presumably normalisation
# statistics — TODO confirm in baselineUtils).
(train_dataset, val_dataset, test_dataset,
 mean, std) = load_datasets(args)

In [None]:
# Only the training split is shuffled. Validation and test iterate in a fixed
# order so evaluation passes see the same batches on every run (an unseeded
# shuffle of the validation set made the per-epoch evaluation order random).
train_dl = torch.utils.data.DataLoader(
    train_dataset, batch_size=CFG.batch_size, shuffle=True, num_workers=CFG.num_workers)
val_dl = torch.utils.data.DataLoader(
    val_dataset, batch_size=CFG.batch_size, shuffle=False, num_workers=CFG.num_workers)
test_dl = torch.utils.data.DataLoader(
    test_dataset, batch_size=CFG.batch_size, shuffle=False, num_workers=CFG.num_workers)

In [None]:
# Model hyper-parameters. Everything is read from CFG except the two sequence
# lengths below.
in_features, out_features = CFG.in_features, CFG.out_features
num_heads = CFG.num_heads
num_encoder_layers, num_decoder_layers = CFG.num_encoder_layers, CFG.num_decoder_layers
embedding_size = CFG.embd_size
n_hidden = CFG.n_hidden
gaussians = CFG.gaussians
drp = CFG.drop_out
add_features = CFG.add_features

# Sequence lengths in time steps: these equal the enc_seq (8) and dec_seq (12)
# values that test_mdn is called with.
max_length = 8
forecast_window = 12

In [None]:
# Build the attention/GMM model (Attention_GMM). To train the plain transformer
# instead, copy it from commented.py; Attention_GMM_Encoder (see the commented
# import) is an alternative model class with the same constructor arguments.
attn_mdn = Attention_GMM(
    device,
    in_features,
    out_features,
    num_heads,
    num_encoder_layers,
    num_decoder_layers,
    embedding_size,
    n_gaussians=gaussians,
    n_hidden=n_hidden,
    dropout=drp,
).to(device)

# Xavier-uniform initialisation for every parameter with more than one
# dimension; 1-D parameters (typically biases) keep their default init.
for p in attn_mdn.parameters():
    if p.dim() <= 1:
        continue
    nn.init.xavier_uniform_(p)

# Adam wrapped by ScheduledOptim, which also receives lr_mul, d_model and the
# warm-up step count from CFG (presumably a warm-up LR schedule — TODO confirm).
optimizer = ScheduledOptim(
    torch.optim.Adam(attn_mdn.parameters(), betas=(0.9, 0.98), eps=1e-09),
    CFG.lr_mul,
    CFG.d_model,
    CFG.n_warmup_steps,
)

In [None]:
if Args.mode == 'test':
    # Evaluate a previously saved model instead of training a new one.
    PATH = Args.model_path
    # torch.load unpickles a whole nn.Module here: only load checkpoints you trust.
    # map_location remaps the stored tensors onto `device`, so a checkpoint saved
    # on a GPU can still be deserialised on a CPU-only machine.
    # NOTE(review): PyTorch >= 2.6 defaults to weights_only=True, which rejects
    # full pickled modules; pass weights_only=False there (trusted files only).
    attn_mdn = torch.load(PATH, map_location=device).to(device)
else:
    # Any other mode trains the freshly initialised model. The returned loss and
    # validation ADE/FDE histories are plotted by the following cells, which only
    # run when Args.mode == 'train'.
    loss_train, loss_eval, val_mad, val_fad = train_attn_mdn(
        train_dl, val_dl, test_dl, attn_mdn, optimizer, add_features,
        mixtures=gaussians, epochs=CFG.epochs, mean=mean, std=std,
    )

In [None]:
if Args.mode == 'train':
    # Training and evaluation loss history returned by train_attn_mdn, one point
    # per list entry (per epoch, as the title states).
    # A fresh figure avoids drawing onto a stale, reused figure on cell re-runs.
    fig, ax = plt.subplots()
    ax.set_title("Training and Evaluation Loss Per Epoch", fontsize=16)
    ax.plot(loss_train, color='Blue', label='Training Loss')
    ax.plot(loss_eval, color='Green', label='Evaluation Loss')
    ax.set_xlabel("Epoch")
    ax.set_ylabel("Loss")
    ax.legend(loc="upper right")
    plt.show()

In [None]:
if Args.mode == 'train':
    # Validation ADE / FDE histories returned by train_attn_mdn.
    # Assumes one entry per epoch — TODO confirm against train_attn_mdn.
    # A fresh figure avoids drawing onto a stale, reused figure on cell re-runs.
    fig, ax = plt.subplots()
    ax.set_title("Evaluation Error", fontsize=16)
    ax.plot(val_mad, color='Green', label="ADE Validation")
    ax.plot(val_fad, color='Red', label="FDE Validation")
    ax.set_xlabel("Epoch")
    ax.set_ylabel("ADE / FDE")
    ax.legend(loc="upper right")
    plt.show()

In [None]:
# Test the model. The encoder/decoder sequence lengths reuse max_length (8) and
# forecast_window (12) from the hyper-parameter cell instead of repeating the
# literals, so changing them there propagates here.
# `best_candiates` keeps its original (misspelled) name for compatibility.
(batch_preds, batch_gts, avg_mad, avg_fad, candidate_trajs,
 candidate_weights, best_candiates, src_trajs) = test_mdn(
    test_dl, attn_mdn, device,
    add_features=add_features, mixtures=gaussians,
    enc_seq=max_length, dec_seq=forecast_window,
    mode='feed', loss_mode='mdn', mean=mean, std=std,
)


Shape of candidate trajectories: (num_batches, batch_size, x, 2, 12, 2), where x is the number of candidate trajectories.
Shape of ground truth: (num_batches, batch_size, 12, 2).

In [None]:
# Visualize the predictions when Args.visualize is set: passes the source
# trajectories, ground truth and candidate trajectories returned by test_mdn.
if Args.visualize:
    visualize_preds(
        src_trajs,
        batch_gts,
        candidate_trajs,
    )