# Import library

In [2]:
import numpy as np

import os
import torch
torch.manual_seed(0)


from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.models.resnet import resnet18,resnet50,resnet101
from tqdm import tqdm
from typing import Dict
from torch import functional as F

from l5kit.configs import load_config_data
from l5kit.data import LocalDataManager, ChunkedDataset
from l5kit.dataset import AgentDataset, EgoDataset
from l5kit.rasterization import build_rasterizer
from l5kit.evaluation import write_pred_csv, compute_metrics_csv, read_gt_csv, create_chopped_dataset
from l5kit.evaluation.chop_dataset import MIN_FUTURE_STEPS
from l5kit.evaluation.metrics import neg_multi_log_likelihood, time_displace
from l5kit.geometry import transform_points
from l5kit.visualization import PREDICTED_POINTS_COLOR, TARGET_POINTS_COLOR, draw_trajectory
from prettytable import PrettyTable
from pathlib import Path

# Set environment

In [2]:
DIR_INPUT = "../../../lyft/data/lyft-motion-prediction-autonomous-vehicles"
# set env variable for data
os.environ["L5KIT_DATA_FOLDER"] = DIR_INPUT
dm = LocalDataManager(None)
VALIDATION = False
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Set Configuration

In [3]:
cfg = {
    'format_version': 4,
    'model_params': {
        'model_architecture': 'resnet18',
        'history_num_frames': 10,
        'history_step_size': 1,
        'history_delta_time': 0.1,
        
        'future_num_frames': 50,
        'future_step_size': 1,
        'future_delta_time': 0.1
    },
    
    'raster_params': {
        'raster_size': [1, 1],
        'pixel_size': [0.5, 0.5],
        'ego_center': [0.25, 0.5],
        'map_type': 'py_semantic',
        'satellite_map_key': 'aerial_map/aerial_map.png',
        'semantic_map_key': 'semantic_map/semantic_map.pb',
        'dataset_meta_key': 'meta.json',
        'filter_agents_threshold': 0.5,
        'disable_traffic_light_faces': True
    },
    
    'train_data_loader': {
        'key': 'scenes/train.zarr',
        'batch_size': 32,
        'shuffle': True,
        'num_workers': 4
    },
    
    'val_data_loader': {
        'key': 'scenes/validate.zarr',
        'batch_size': 32,
        'shuffle': False,
        'num_workers': 4
    },
    
    'test_data_loader': {
        'key': 'scenes/test.zarr',
        'batch_size': 32,
        'shuffle': False,
        'num_workers': 4
    },
    
    'train_params': {
        'checkpoint_every_n_steps': 5000,
        'max_num_steps': 10000,
        'eval_every_n_steps': 500
        
    }
}

# Read validation set

In [4]:
import gc

val_cfg = cfg["val_data_loader"]

# Rasterizer
rasterizer = build_rasterizer(cfg, dm)

# Train dataset/dataloader
val_zarr = ChunkedDataset(dm.require(val_cfg["key"])).open()
val_dataset = AgentDataset(cfg, val_zarr, rasterizer)
val_dataloader = DataLoader(val_dataset,
                              shuffle=val_cfg["shuffle"],
                              batch_size=val_cfg["batch_size"],
                              num_workers=val_cfg["num_workers"])


    

# Read train set

In [5]:
import gc

train_cfg = cfg["train_data_loader"]

# Rasterizer
rasterizer = build_rasterizer(cfg, dm)

# Train dataset/dataloader
train_zarr = ChunkedDataset(dm.require(train_cfg["key"])).open()
train_dataset = AgentDataset(cfg, train_zarr, rasterizer)
train_dataloader = DataLoader(train_dataset,
                              shuffle=train_cfg["shuffle"],
                              batch_size=train_cfg["batch_size"],
                              num_workers=train_cfg["num_workers"])

gc.collect()



13

# Read test set

In [4]:
import gc

test_cfg = cfg["test_data_loader"]

# Rasterizer
rasterizer = build_rasterizer(cfg, dm)

# Train dataset/dataloader
test_zarr = ChunkedDataset(dm.require(test_cfg["key"])).open()
test_dataset = AgentDataset(cfg, test_zarr, rasterizer)
test_dataloader = DataLoader(test_dataset,
                              shuffle=test_cfg["shuffle"],
                              batch_size=test_cfg["batch_size"],
                              num_workers=test_cfg["num_workers"])

gc.collect()




13

# Define Seq-to-Seq model

In [5]:
class EncoderLSTM_LyftModel(nn.Module):
    
    def __init__(self, cfg):
        super(EncoderLSTM_LyftModel, self).__init__()
        
        self.input_sz  = 3
        self.hidden_sz = 128
        self.num_layer = 1
        self.sequence_length = 11        
        
        self.Encoder_lstm = nn.LSTM(self.input_sz,self.hidden_sz,self.num_layer,batch_first=True)
       
    def forward(self,inputs):
        '''
        Implemented Encoder with LSTM to extract 
        temporal information on history trajectories
        '''        
        output,hidden_state = self.Encoder_lstm(inputs)
        
        return output,hidden_state
    
class DecoderLSTM_LyftModel(nn.Module):
    def __init__(self, cfg):
        super(DecoderLSTM_LyftModel, self).__init__()
        
        self.input_sz  = 128 
        self.hidden_sz = 128
        self.num_layer = 1
        self.sequence_len_de = 1
        self.batch_sz = 32
        
        num_targets = 2 * cfg["model_params"]["future_num_frames"]
        
        self.encoderLSTM = EncoderLSTM_LyftModel(cfg)

        
        self.Decoder_lstm = nn.LSTM( self.input_sz,self.hidden_sz,self.num_layer,batch_first=True)


    def forward(self,inputs):
        '''
        With last hidden state, cell state trained on encoder, 
        predict 50 future trajectories with LSTM.
        Input to the decoder is zero vector of shape (batch_sz, 50, 128) 
        since most of the current position is at 0
        '''
        
        # Last hidden state and cell state on encoder                
        _,hidden_state = self.encoderLSTM(inputs)
        
        # Dummy future trajectories                
        result = torch.zeros((self.batch_sz, 50, 128))
        
        # Input to decoder                
        inout_to_dec = torch.zeros(inputs.shape[0],self.sequence_len_de,self.input_sz).to(device)

        # Predict 50 future trajectories consecutively                
        for i in range(50):
            inout_to_dec,hidden_state = self.Decoder_lstm(inout_to_dec,(hidden_state[0],hidden_state[1]) )          
            result[:,i,:] = inout_to_dec[:,0,:]
        return result

# Generator Implementation

In [6]:
class Generator(nn.Module):
    def __init__(self, starting_shape):
        super(Generator, self).__init__()
        self.input_sz  = 128 # hidden state size from seq-to-seq
        self.hidden_sz = 256 
        self.num_layer = 1
        self.sequence_len_de = 1
        self.interlayer = 512
        self.DecoderLSTM_LyftModel = DecoderLSTM_LyftModel(cfg)
        
        num_targets = 2
        
        self.fcn_en_state_dec_state = nn.Sequential(nn.Linear(in_features=self.hidden_sz, out_features=self.hidden_sz),
                                                    nn.ReLU(inplace=True),
                                                    nn.Linear(in_features = self.hidden_sz, out_features = self.interlayer),
                                                    nn.ReLU(inplace=True),
                                                    nn.Linear(in_features = self.interlayer, out_features = self.hidden_sz),
                                                    nn.ReLU(inplace=True),
                                                    nn.Linear(in_features=self.hidden_sz, out_features=num_targets))


    def forward(self,inputs):
        '''
        With latent temporal vector from decoder, 
        add noise to the generator and generate
        50 future trajectories of each agents
        '''
        
        # Temporal vector from decoder                
        decoderLSTM = self.DecoderLSTM_LyftModel(inputs)
        
        # Noise from Gaussian Distribution                
        noise_gen = torch.randn(decoderLSTM.shape, device=device) 
        
        # Add noise to the latent vector                
        combine = torch.cat([decoderLSTM, noise_gen], dim = 2)
        
        # Generate Trajectories                
        fc_out = self.fcn_en_state_dec_state(combine.squeeze(dim=0))
        return fc_out

    

# Discriminator Implementation

In [None]:
# Discriminator with Relu
class Discriminator_0(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.input = 100
        self.lstm_hidden = 128
        self.hidden_sz_1 = 6400
        self.hidden_sz_2 = 3200
        self.hidden_sz_3 = 1600
        self.hidden_sz_4 = 800

        
        self.input_sz = 2
        self.num_layer = 1
        self.batch_sz = 32
        self.fc= nn.Sequential(
            nn.Linear(in_features=self.hidden_sz_1, out_features=self.hidden_sz_2),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=self.hidden_sz_2, out_features=self.hidden_sz_3),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=self.hidden_sz_3, out_features=self.hidden_sz_4),
            nn.ReLU(inplace=True),
            nn.Linear(in_features=self.hidden_sz_4, out_features=1)
        )
        self.lstm = nn.LSTM(self.input_sz,self.lstm_hidden,self.num_layer,batch_first=True)


        
    def forward(self, input):
        output, _ = self.lstm(input)
        output = output.reshape((self.batch_sz, -1))
        fc = self.fc(output)
        return fc
    
    

In [7]:
# Discriminator with WeakyRelu, dropout, and batch norm
class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator, self).__init__()
        self.lstm_hidden = 128
        self.hidden_sz_1 = 6400
        self.hidden_sz_2 = 3200
        self.hidden_sz_3 = 1600
        self.hidden_sz_4 = 800

        
        self.input_sz = 2
        self.num_layer = 1
        self.batch_sz = 32
        
        self.fc= nn.Sequential(
            nn.Linear(in_features=self.hidden_sz_1, out_features=self.hidden_sz_2),
            nn.BatchNorm1d(self.hidden_sz_2),            
            nn.LeakyReLU(inplace=True),
            nn.Dropout(0.5),
            nn.Linear(in_features=self.hidden_sz_2, out_features=self.hidden_sz_3),            
            nn.BatchNorm1d(self.hidden_sz_3),                        
            nn.LeakyReLU(inplace=True),
            nn.Dropout(0.5),            
            nn.Linear(in_features=self.hidden_sz_3, out_features=self.hidden_sz_4),
            nn.BatchNorm1d(self.hidden_sz_4),                                    
            nn.LeakyReLU(inplace=True),
            nn.Dropout(0.5),            
            nn.Linear(in_features=self.hidden_sz_4, out_features=1)
        )
        self.lstm = nn.LSTM(self.input_sz,self.lstm_hidden,self.num_layer,batch_first=True)


        
    def forward(self, input):
        '''
        Input to the discriminator is either real trajectory or
        generated fake trajectory. Discriminator would differentiate
        if trajectories are real/fake with LSTM and FC
        '''                        
        output, _ = self.lstm(input)
        output = output.reshape((self.batch_sz, -1))
        fc = self.fc(output)
        return fc

In [8]:
import matplotlib.pyplot as plt

In [9]:
save_path = './yaw_v2_bdl_1129/'

# GAN Implementation

In [3]:
# Use this to put tensors on GPU/CPU automatically when defining tensors
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu') 

class GAN(nn.Module):

    def __init__(self):
        super(GAN, self).__init__()
        self.num_epoch = 1
        self.batch_size = 32
        self.log_step = 1000
        self.visualize_step = 1
        self.code_size = 64
        self.learning_rate_gen =1e-3
        self.learning_rate_dis =1e-3
        self.val = 100

        # Define L2 Loss for the generator                
        self._l2_loss = nn.MSELoss()

        
        # Define the generator and both discriminator 
        self._discriminator = Discriminator().to(device)
        self._generator = Generator(self.code_size).to(device)

        # Loss function for the discriminator                
        self._classification_loss = nn.BCEWithLogitsLoss()

        # Apply weight initialization here
        self._discriminator.apply(self._weight_initialization)
        self._generator.apply(self._weight_initialization)

        # Hyper parameter for the Adam Optimizer
        betas = (0.5, 0.999)
        
        # Optimizer for Generator and Discriminator                
        self._generator_optimizer = torch.optim.Adam(self._generator.parameters(), lr = self.learning_rate_gen, betas = betas)
        self._discriminator_optimizer = torch.optim.Adam(self._discriminator.parameters(), lr = self.learning_rate_dis, betas = betas)

    # custom weights initialization for both networks
    # apply the custom weight initialization
        
    def _weight_initialization(self, m):
        # custom weights initialization for both networks
        classname = m.__class__.__name__
        if classname.find('Conv') != -1:
            nn.init.normal_(m.weight.data, 0.0, 0.02)
        elif classname.find('BatchNorm') != -1:
            nn.init.normal_(m.weight.data, 1.0, 0.02)
            nn.init.constant_(m.bias.data, 0)

    # Discriminator Loss                        
    def _loss(self, logits, labels):
        return self._classification_loss(logits, labels)

    # Generator Loss        
    def _reconstruction_loss(self, generated, target):
        return self._l2_loss(generated, target)

    # Training function
    def forward(self, train_dataloader):
#         num_train = train_samples.shape[0]
        for epoch in range(self.num_epoch):
            # Save train loss for discriminator and generator                
            dis_losses = []
            gen_losses = []
            
            # Save validation loss for discriminator and generator                                
            val_dis_losses = []
            val_gen_losses = []
            step = 0

            # smooth the loss curve so that it does not fluctuate too much
            smooth_factor = 0.95
            plot_dis_s = 0
            plot_gen_s = 0
            plot_ws = 0


            max_steps = int(len(iter(train_dataloader)))

            # Fake label for discriminator            
            fake_label = torch.zeros([self.batch_size, 1], device=device)
            
            # Real label for discriminator                                    
            real_label = torch.ones([self.batch_size, 1], device=device)
            
            # Train generator and discriminator                            
            self._generator.train()
            self._discriminator.train()
            
            print('Start training ...')

            progress_bar = tqdm(range(cfg["train_params"]["max_num_steps"]))

            tr_it = iter(train_dataloader)

            
            for itr in progress_bar:
                try:
                    data = next(tr_it)
                except StopIteration:
                    tr_it = iter(train_dataloader)
                    data = next(tr_it)
                step += 1
                
                ################################################################################
                # Train the discriminator on all-real images first                             #
                ################################################################################
                # history_positions should be flipped since it's reversed
                history_positions = torch.flip(data['history_positions'], [1])
                history_yaws = torch.flip(data['history_yaws'], [1])
                history_pos_yaws = torch.cat([history_positions, history_yaws], dim = 2).to(device)

                target_availabilities = data["target_availabilities"].unsqueeze(-1).to(device)
                targets_position = data["target_positions"].to(device) 

                # Eliminate discriminator gradients                
                self._discriminator_optimizer.zero_grad()

                # Calculate BCE loss on discriminator                
                real_dis_out = self._discriminator(targets_position)
                real_dis_loss = self._loss(real_dis_out, real_label)

                # Calculate real discriminator loss gradients
                real_dis_loss.backward()

                ################################################################################
                # Train the discriminator with an all fake batch                               #
                ################################################################################
                # Detach the fake samples from the gradient calculation 
                # when feeding to the discriminator, we don't want the discriminator to 
                # receive gradient info from the Generator

                fake_samples = self._generator(history_pos_yaws).detach()
                fake_dis_out = self._discriminator(fake_samples)

                numGenerated = fake_dis_out.shape[0]

                fake_dis_label = fake_label[:numGenerated] 
                fake_dis_loss = self._loss(fake_dis_out, fake_dis_label).requires_grad_()


                # Calculate fake discriminator loss gradients
                fake_dis_loss.backward()
                
                # Update the discriminator weights                                                
                self._discriminator_optimizer.step()

                ################################################################################
                # Train the generator                                                          #
                ################################################################################                
                # Get new samples from updated discriminator. No need to detach
                # from gradient calculation here, we want the Generator to receive
                # gradient info from the discriminator so it can learn better.

                # Eliminate all generator gradients first

                self._generator_optimizer.zero_grad()

                # Generate future trajectories with history positions
                fake_samples_gen = self._generator(history_pos_yaws)

                gen_loss = self._reconstruction_loss(fake_samples_gen * target_availabilities, targets_position * target_availabilities)


                # Calculate the generator loss gradients
                gen_loss.backward()
                
                # Update the generator weights                                
                self._generator_optimizer.step()

                # Add up discriminator loss                                
                dis_loss = real_dis_loss + fake_dis_loss

                # Apply smoothing factors for the plot                                
                plot_dis_s = plot_dis_s * smooth_factor + dis_loss * (1 - smooth_factor)
                plot_gen_s = plot_gen_s * smooth_factor + gen_loss * (1 - smooth_factor)
                plot_ws = plot_ws * smooth_factor + (1 - smooth_factor)

                dis_losses.append(plot_dis_s / plot_ws)
                gen_losses.append(plot_gen_s / plot_ws)

                # Validation
                
                with torch.no_grad():
                    try:
                        val_data = next(vl_it)
                    except:
                        vl_it = iter(val_dataloader)
                        val_data = next(vl_it)
                    
                    gan.eval()
                    self._generator.eval()
                    self._discriminator.eval()
                    
                    # Fetch Validation data                                        
                    val_history_positions = torch.flip(val_data['history_positions'], [1]).to(device)
                    val_history_yaws = torch.flip(val_data['history_yaws'], [1]).to(device)
                    
                    val_target_availabilities = val_data["target_availabilities"].unsqueeze(-1).to(device)
                    val_targets_position = val_data["target_positions"].to(device)
                    val_history_pos_yaws = torch.cat([val_history_positions, val_history_yaws], dim = 2).to(device)

                    # Calculate Validation loss of discriminator
                    val_real_dis_out = self._discriminator(val_targets_position)
                    val_real_dis_loss = self._loss(val_real_dis_out, real_label)

                    # Calculate Validation loss of generator            
                    val_fake_samples_gen = self._generator(val_history_pos_yaws)
                    val_gen_loss = self._reconstruction_loss(val_fake_samples_gen * val_target_availabilities, val_targets_position * val_target_availabilities)

                    
                    # Calculate Validation loss of discriminator                                        
                    val_fake_dis_out = self._discriminator(val_fake_samples_gen)
                    val_fake_dis_label = fake_label[:numGenerated] 
                    val_fake_dis_loss = self._loss(val_fake_dis_out, val_fake_dis_label)

                    val_dis_loss = val_real_dis_loss + val_fake_dis_loss

                    # Plot Validation loss
                    val_plot_dis_s = plot_dis_s * smooth_factor + val_dis_loss * (1 - smooth_factor)
                    val_plot_gen_s = plot_gen_s * smooth_factor + val_gen_loss * (1 - smooth_factor)
                    val_plot_ws = plot_ws * smooth_factor + (1 - smooth_factor)
                    val_dis_losses.append(val_plot_dis_s / val_plot_ws)
                    val_gen_losses.append(val_plot_gen_s / val_plot_ws)

                if step % self.log_step == 0:
                    print('Iteration {0}/{1}: dis loss = {2:.4f}, gen loss = {3:.4f}'.format(step, max_steps, dis_loss, gen_loss))

                # Save our loss graph                    
                if step % self.val == 0:
                    fig = plt.figure(figsize = (8, 8))   
                    plt.plot(dis_losses)
                    plt.plot(val_dis_losses)
                    
                    plt.title('discriminator loss')
                    plt.xlabel('iterations')
                    plt.ylabel('loss')
                    plt.show()

                    plt.plot(gen_losses)
                    plt.plot(val_gen_losses)                    
                    plt.title('generator loss')
                    plt.xlabel('iterations')
                    plt.ylabel('loss')
                    plt.show()
                    
            # Save model weights                    
            checkPoint(self)
            torch.save(gan.state_dict(), save_path + "gan_yaw_v2_epoch" + str(epoch)+ ".pt")

            # Save loss graph            
            fig = plt.figure(figsize = (8, 8))   

            plt.plot(dis_losses)
            plt.plot(val_dis_losses)

            plt.title('discriminator loss')
            plt.xlabel('iterations')
            plt.ylabel('loss')
            plt.savefig(save_path + "discriminator_yaw_v2_epoch_" + str(epoch) + '.png')

            plt.plot(gen_losses)
            plt.plot(val_gen_losses)                    
            plt.title('generator loss')
            plt.xlabel('iterations')
            plt.ylabel('loss')
            plt.savefig(save_path + "generator_yaw_v2_epoch_" + str(epoch) + '.png')
        print('... Done!')

# Save model parameters

In [14]:
def checkPoint(model):
    torch.save(model.state_dict(), save_path + 'gan.pt' )
    torch.save(model._generator.state_dict(), save_path + 'gen.pt')
    torch.save(model._discriminator.state_dict(), save_path + 'dis.pt')
    torch.save(model._generator_optimizer.state_dict(), save_path + 'gen_optim.pt')
    torch.save(model._discriminator_optimizer.state_dict(), save_path + 'dis_optim.pt')  

In [16]:
import pdb
%pdb

Automatic pdb calling has been turned ON


# Train GAN Model
#### train with discriminator lr=1e-3

In [None]:
gan = GAN()
gan(train_dataloader)
checkPoint(gan)

# evaluation on test dataset

In [14]:
def calculate_MSE_loss(predicted, targets, availabilities, num_iter):
    loss = nn.MSELoss(reduction = 'mean')
    
    predicted_ = torch.stack(predicted)
    targets_ = torch.stack(targets)
    availabilities_ = torch.stack(availabilities)
    availabilities_ = availabilities_.view(-1, 32, 50, 1)
    loss_temp = loss(predicted_ * availabilities_, targets_ * availabilities_)
    iter_loss = (num_iter, loss_temp)
    return iter_loss

# Load model saved

In [15]:
file_name = save_path + "gan_yaw_v2_0.pt"

ganModel = GAN()
ganModel.load_state_dict(torch.load(file_name))
ganModel


GAN(
  (_l2_loss): MSELoss()
  (_discriminator): Discriminator(
    (fc): Sequential(
      (0): Linear(in_features=6400, out_features=3200, bias=True)
      (1): BatchNorm1d(3200, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (2): LeakyReLU(negative_slope=0.01, inplace=True)
      (3): Dropout(p=0.5, inplace=False)
      (4): Linear(in_features=3200, out_features=1600, bias=True)
      (5): BatchNorm1d(1600, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (6): LeakyReLU(negative_slope=0.01, inplace=True)
      (7): Dropout(p=0.5, inplace=False)
      (8): Linear(in_features=1600, out_features=800, bias=True)
      (9): BatchNorm1d(800, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (10): LeakyReLU(negative_slope=0.01, inplace=True)
      (11): Dropout(p=0.5, inplace=False)
      (12): Linear(in_features=800, out_features=1, bias=True)
    )
    (lstm): LSTM(2, 128, batch_first=True)
  )
  (_generator): Generator(
   

# Predict on test function

In [23]:
def test(ganModel):
    future_coords_offsets_pd = []
    real_target = []
    target_availabilities = []
    timestamps = []
    agent_ids = []
    # test_losses = []
    num_iter = 0

    with torch.no_grad():
        dataiter = tqdm(test_dataloader)

        for data in dataiter:
            history_positions = data['history_positions'].to(device)
#             history_positions = torch.flip(data['history_positions'], [1]).to(device)
            history_yaws = data['history_yaws']
#             history_yaws = torch.flip(data['history_yaws'], [1])
    #         timestamps.append(data["timestamp"].numpy().copy())
    #         agent_ids.append(data["track_id"].numpy().copy())
            real_target.append(data["target_positions"])
            target_availabilities.append(data["target_availabilities"])

            # Set the generator to evaluation mode, to make batchnorm stats stay fixed
            ganModel._generator.eval()
            future_coords_offsets_pd.append(ganModel._generator(torch.cat([history_positions, history_yaws], dim = 2)))
            num_iter += 1
            
            if num_iter % 1000 == 0 :
                break
    #             iter_loss = calculate_MSE_loss(future_coords_offsets_pd, real_target, target_availabilities)
    #             test_losses.append(iter_loss)
    #             print("iteration: ", iter_loss[0], ", loss: ", iter_loss[1])

    #             torch.save(future_coords_offsets_pd, save_path + "ganv2_future.pt")
    #             torch.save(real_target, save_path + "ganv2_target.pt")
    #             torch.save(target_availabilities, save_path + "ganv2_t_avail.pt")
    #             torch.save(timestamps, "./v2_test_result_1128/ganv2_timestamps.pt")
    #             torch.save(agent_ids, "./v2_test_result_1128/ganv2_ids.pt")

    torch.save(future_coords_offsets_pd, save_path + "ganv2_future_all.pt")
    torch.save(real_target, save_path + "ganv2_target_all.pt")
    torch.save(target_availabilities, save_path + "ganv2_t_avail_all.pt")
    torch.save(timestamps, save_path + "ganv2_timestamps_all.pt")
    torch.save(agent_ids, save_path + "ganv2_ids_all.pt")

    avg_loss = calculate_MSE_loss(future_coords_offsets_pd, real_target, target_availabilities, num_iter)
    # test_losses.append(avg_loss)
    print("MSE on the test set: ", avg_loss)

# test on the testset with discriminator, leakyRelu, dropout, batch norm,  lr=1e-3

In [24]:
test(ganModel)

  1%|          | 999/185990 [02:45<8:29:32,  6.05it/s] 


MSE on the test set:  (1000, tensor(2.1781))


# Diversity check

In [16]:
import pandas as pd

def make_N_samples(ganModel, n=100):
    future_coords_offsets_pd = []
    real_target = []
    target_availabilities = []
    timestamps = []
    agent_ids = []
    test_losses = []
    num_iter = 1
    loss = nn.MSELoss(reduction = 'mean')
    test_df = pd.DataFrame()
    batch_size = 32

    with torch.no_grad():
        dataiter = tqdm(test_dataloader)

        # GET DATA
        for i, data in enumerate(dataiter):
            if i == 0:
                history_positions = torch.flip(data['history_positions'], [1]).to(device)
                history_yaws = torch.flip(data['history_yaws'], [1])
                real_target.append(data["target_positions"])
                target_availabilities.append(data["target_availabilities"])
            else:
                break

        # RUN 100 ITER on 1 BATCH
        for j in range(n):
            ganModel._generator.eval()
            real_target  = ganModel._generator(torch.cat([history_positions, history_yaws], dim = 2))

            batch_size = real_target.shape[0]
            time_length = real_target.shape[1]


            for q in range(batch_size):
                for t in range(time_length):
                    # save data as dictionary for the dataframe
                    dic = {  
                        'agent_id': q,
                        't' : t,
                        'target_x' : data["target_positions"][q,t,0].numpy(),
                        'target_y' : data["target_positions"][q,t,1].numpy(),
                        'predict_x' : real_target[q,t,0].numpy(),
                        'predict_y' : real_target[q,t,1].numpy(),
                        'avail' : data["target_availabilities"][q,t].numpy()
                    }
                    if dic['avail'] == 1:
                        dic['loss'] = (dic['target_x'] - dic['predict_x']) ** 2 + (dic['target_y'] - dic['predict_y']) ** 2
                    else:
                        dic['loss'] = 0
                    test_df = test_df.append(dic, ignore_index = True)         
            print("j th iter: ", j)
        return(test_df)  

In [17]:
# save to csv file
test_df = make_N_samples(ganModel)
test_df.to_csv(save_path + 'test_df_yaw_bdl.csv', index = False)

  0%|          | 1/185990 [00:00<50:41:55,  1.02it/s]


j th iter:  0
j th iter:  1
j th iter:  2
j th iter:  3
j th iter:  4
j th iter:  5
j th iter:  6
j th iter:  7
j th iter:  8
j th iter:  9
j th iter:  10
j th iter:  11
j th iter:  12
j th iter:  13
j th iter:  14
j th iter:  15
j th iter:  16
j th iter:  17
j th iter:  18
j th iter:  19
j th iter:  20
j th iter:  21
j th iter:  22
j th iter:  23
j th iter:  24
j th iter:  25
j th iter:  26
j th iter:  27
j th iter:  28
j th iter:  29
j th iter:  30
j th iter:  31
j th iter:  32
j th iter:  33
j th iter:  34
j th iter:  35
j th iter:  36
j th iter:  37
j th iter:  38
j th iter:  39
j th iter:  40
j th iter:  41
j th iter:  42
j th iter:  43
j th iter:  44
j th iter:  45
j th iter:  46
j th iter:  47
j th iter:  48
j th iter:  49
j th iter:  50
j th iter:  51
j th iter:  52
j th iter:  53
j th iter:  54
j th iter:  55
j th iter:  56
j th iter:  57
j th iter:  58
j th iter:  59
j th iter:  60
j th iter:  61
j th iter:  62
j th iter:  63
j th iter:  64
j th iter:  65
j th iter:  66
j th 

In [None]:
import pandas as pd
ds = pd.read_csv(save_path + 'test_df_yaw_bdl.csv')
ds

In [None]:
def load_ckp(model, optimizer, checkpoint_fpath = './yaw_v2_1127/'):
    gen_optim = torch.load(checkpoint_fpath+'')
    dis_optim = torch.load(checkpoint_fpath+'')
    gan_model = torch.load(checkpoint_fpath+'')
    generator = torch.load(checkpoint_fpath+'')
    dis_model = torch.load(checkpoint_fpath+'')
    
    model.load_state_dict(checkpoint['state_dict'])
    model._generator_optimizer.load_state_dict(checkpoint['optimizer'])
    return model, optimizer, checkpoint['epoch']