In [1]:
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import os, os.path 
import numpy as np
import pickle
from glob import glob
import matplotlib.pyplot as plt
import random
import pandas as pd
from tqdm import tqdm
from IPython.display import display

"""Change to the data folder"""
# Folders that hold the pickled training / validation scenes.
train_path = "./data/new_train/new_train/"
val_path="./data/new_val_in/new_val_in"

# Prefer the GPU, but fall back to the CPU so the notebook still runs on a
# machine without CUDA.
device='cuda' if torch.cuda.is_available() else 'cpu'
# number of sequences in each dataset
# train:205942  val:3200 test: 36272 
# sequences sampled at 10HZ rate

### Create a dataset class 

In [2]:
class ArgoverseDataset(Dataset):
    """Map-style dataset that returns one unpickled file per index.

    Every entry found directly inside ``data_path`` is treated as a pickle
    file. The optional ``transform`` callable is applied to each loaded
    object before it is returned.
    """

    def __init__(self, data_path: str, transform=None):
        super().__init__()
        self.data_path = data_path
        self.transform = transform
        # Sorted so that the index -> file mapping is stable between runs.
        self.pkl_list = sorted(glob(os.path.join(self.data_path, '*')))

    def __len__(self):
        return len(self.pkl_list)

    def __getitem__(self, idx):
        # NOTE: unpickling can execute arbitrary code; only point data_path
        # at files from a trusted source.
        with open(self.pkl_list[idx], 'rb') as handle:
            scene = pickle.load(handle)

        return self.transform(scene) if self.transform else scene


# Initialize the training and validation datasets from the folders configured above.
train_dataset  = ArgoverseDataset(data_path=train_path)
val_dataset=ArgoverseDataset(data_path=val_path)

### Create a loader to enable batch processing

In [3]:
batch_sz =256  # number of scenes per mini-batch (passed as batch_size to both loaders)
n_workers=4  # passed as num_workers to both loaders

def my_collate(batch):
    """Collate a list of scene dicts into a pair of float32 tensors.

    For every scene the position and velocity arrays are concatenated along
    the last axis (``np.dstack``), then the scenes are stacked along a new
    leading batch axis. With per-scene arrays of shape
    [agent_sz x seq_len x 2] this yields [batch_sz x agent_sz x seq_len x 4].

    Parameters
    ----------
        batch - list of dict
            scenes providing the keys 'p_in', 'v_in', 'p_out' and 'v_out'

    Returns
    -------
        [inp, out] - list with the input tensor and the target tensor
    """
    inp = [np.dstack([scene['p_in'], scene['v_in']]) for scene in batch]
    out = [np.dstack([scene['p_out'], scene['v_out']]) for scene in batch]
    # Merge the scenes into one contiguous float32 array first: handing a
    # Python list of ndarrays straight to the tensor constructor is very slow.
    inp = torch.from_numpy(np.asarray(inp, dtype=np.float32))
    out = torch.from_numpy(np.asarray(out, dtype=np.float32))
    return [inp, out]

# Training loader: shuffled batches, collated into [inp, out] tensors by my_collate.
train_loader=DataLoader(
    train_dataset,
    batch_size=batch_sz,
    shuffle=True,
    collate_fn=my_collate,
    num_workers=n_workers
)

# Validation loader: same batching, but the dataset order is kept (shuffle=False).
val_loader = DataLoader(
    val_dataset,
    batch_size=batch_sz, 
    shuffle = False, 
    collate_fn=my_collate, 
    num_workers=n_workers
)

### LSTM Model

In [4]:
class LSTM_model(nn.Module):
    """Stacked LSTM followed by a per-time-step linear read-out.

    The network maps a sequence of feature vectors to a sequence of the same
    length and feature size, so the output at step t can be trained as the
    prediction for step t+1.

    Parameters
    ----------
        device - string or torch.device
            stored on the instance as ``self.device``; the module itself is
            moved by the caller (``model.to(device)``)
        hidden_dim - int, default 2048
            size of the LSTM hidden state
        num_layers - int, default 3
            number of stacked LSTM layers
        feature_dim - int, default 240
            size of each input / output vector
    """
    def __init__(self, device, hidden_dim=2048, num_layers=3, feature_dim=240):
        super(LSTM_model, self).__init__()

        self.hidden_dim=hidden_dim
        self.num_layers=num_layers
        self.feature_dim=feature_dim
        self.device=device
        self.lstm=nn.LSTM(
            input_size=self.feature_dim,
            hidden_size=self.hidden_dim,
            num_layers=self.num_layers,
            batch_first=True
        )

        # A kernel_size=1 convolution applies the same linear map at every
        # time step: hidden_dim channels -> feature_dim channels.
        self.linear=nn.Conv1d(
            in_channels=self.hidden_dim,
            out_channels=self.feature_dim,
            kernel_size=1
        )

    def forward(self, x):
        """Return one output vector per input step (batch x steps x feature_dim)."""
        # batch_size x timesteps x feature_dim
        x,_=self.lstm(x)
        # batch_size x timesteps x hidden_dim -> batch_size x hidden_dim x timesteps
        # (Conv1d expects channels before the time axis)
        x=x.transpose(1,2)
        x=self.linear(x)
        x=x.transpose(1,2)

        return x

    def forward_test(self, x, num_steps=30):
        """Roll the model forward autoregressively.

        The observed sequence ``x`` (batch x steps x feature_dim) is consumed
        first; the prediction for the last step is then fed back as the next
        input, ``num_steps`` times in total.

        Returns a tensor of shape batch x num_steps x feature_dim.
        """
        if num_steps<1:
            # torch.cat cannot concatenate an empty list
            return x.new_zeros((len(x), 0, self.feature_dim))

        res=[]
        # Allocate the initial state next to the input (same device and dtype)
        # instead of on the device string stored at construction time.
        state_shape=(self.num_layers, len(x), self.hidden_dim)
        h=x.new_zeros(state_shape)
        c=x.new_zeros(state_shape)
        for step in range(num_steps):
            x, (h,c)=self.lstm(x, (h,c))
            # keep only the newest step; it becomes the next input
            x=x[:,-1:]
            x=x.transpose(1,2)
            x=self.linear(x)
            x=x.transpose(1,2)
            res.append(x)
        res=torch.cat(res,1)
        return res

### train model

In [5]:
# NOTE(review): interactive help lookup left over from development; it only prints tqdm's documentation and can be removed.
help(tqdm)

Help on class tqdm in module tqdm.std:

class tqdm(tqdm.utils.Comparable)
 |  tqdm(*_, **__)
 |  
 |  Decorate an iterable object, returning an iterator which acts exactly
 |  like the original iterable, but prints a dynamically updating
 |  progressbar every time a value is requested.
 |  
 |  Method resolution order:
 |      tqdm
 |      tqdm.utils.Comparable
 |      builtins.object
 |  
 |  Methods defined here:
 |  
 |  __bool__(self)
 |  
 |  __del__(self)
 |  
 |  __enter__(self)
 |  
 |  __exit__(self, exc_type, exc_value, traceback)
 |  
 |  __hash__(self)
 |      Return hash(self).
 |  
 |  __init__(self, iterable=None, desc=None, total=None, leave=True, file=None, ncols=None, mininterval=0.1, maxinterval=10.0, miniters=None, ascii=None, disable=False, unit='it', unit_scale=False, dynamic_ncols=False, smoothing=0.3, bar_format=None, initial=0, position=None, postfix=None, unit_divisor=1000, write_bytes=None, lock_args=None, nrows=None, colour=None, delay=0, gui=False, **kwargs

In [6]:
import warnings
warnings.filterwarnings(action='ignore')

def train_lstm_model(lstm_model, data_loader, n_epochs, filename, ema_weight, device='cuda:0', verbose=False):
    '''
    Train LSTM model
    ----------
    Parameters
    ----------
        lstm_model - callable
            model class (or factory); it is instantiated as lstm_model(device)
            and must provide forward_test(x, num_steps)
        data_loader - torch DataLoader class 
            training data for model; yields (inp, out) batches laid out as
            batch x agents x steps x features
        n_epochs - int
            number of epochs to train
        filename - string
            filepath to save training data to
        ema_weight - float
            float between (0.0,1.0) for the exponential moving average weight
        device - string, default 'cuda:0'
            choose to run on gpu ('cuda:0') or cpu ('cpu')
        verbose - boolean, default False
            If true print training progress every 10 training iterations
            (the progress bar is disabled in that case)
    -------
    Returns
    -------
         Trained LSTM model
    '''
    import os  # local import: only needed to prepare the output folder

    model=lstm_model(device).to(device)
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3)

    # Running averages are plain Python floats. Keeping them as tensors
    # derived from `loss` would keep every iteration's autograd graph
    # reachable for the whole run.
    loss_ema=None
    loss_ema2=None

    data=[]

    for epoch in range(n_epochs):
        progress=tqdm(data_loader, desc='Epoch %i/%i'%(epoch+1,n_epochs), disable=verbose)
        for i_batch, sample_batch in enumerate(progress):
            inp, out=sample_batch
            inp, out=inp.to(device), out.to(device)

            n_batch, n_agents, n_in, n_feat=inp.shape
            n_out=out.shape[2]

            # batch x agents x (n_in+n_out) x features
            #   -> batch x (n_in+n_out) x (agents*features)
            mixed=torch.cat([inp, out], 2).transpose(1,2).reshape((n_batch, n_in+n_out, n_agents*n_feat))

            # Teacher forcing: feed every step but the last and keep the
            # predictions that line up with the n_out target steps.
            y_pred=model(mixed[:,:-1])[:,-n_out:]
            y_pred=y_pred.reshape((n_batch,n_out,n_agents,n_feat)).transpose(1,2)

            loss=(torch.mean((y_pred-out)**2))**0.5
            optimizer.zero_grad() # set gradient to zero
            loss.backward() # backwards propagation
            optimizer.step() # parameter update

            loss_val=loss.item()
            if loss_ema is None:
                loss_ema=loss_val
            loss_ema=loss_ema*ema_weight+loss_val*(1-ema_weight)

            # Second metric: error of the autoregressive rollout that sees
            # only the observed input steps (no gradient needed).
            with torch.no_grad():
                history=inp.transpose(1,2).reshape((n_batch,n_in,n_agents*n_feat))
                y_pred2=model.forward_test(history, num_steps=n_out)
                y_pred2=y_pred2.reshape((n_batch,n_out,n_agents,n_feat)).transpose(1,2)
                loss2_val=(torch.mean((y_pred2-out)**2)**0.5).item()
            if loss_ema2 is None:
                loss_ema2=loss2_val
            loss_ema2=loss_ema2*ema_weight+loss2_val*(1-ema_weight)

            if verbose and i_batch%10==0:
                # distinct labels: teacher-forced loss vs. rollout loss
                loss_str='loss_full %i %i %f %f'%(epoch,i_batch,loss_ema,loss_val)
                loss2_str='loss_rollout %i %i %f %f'%(epoch,i_batch,loss_ema2,loss2_val)
                print(loss_str)
                print(loss2_str)

            data.append([epoch, i_batch, loss_ema, loss_val, loss_ema2, loss2_val])

    columns=["epoch","iteration","loss_ema","loss", "loss_ema2", "loss2"]
    # One row per iteration; built directly so integer columns stay integers
    # and an empty run still produces the expected columns.
    df=pd.DataFrame(data, columns=columns)
    display(df)

    # Make sure the target folder exists so a long run is not lost on save.
    if isinstance(filename, (str, os.PathLike)):
        out_dir=os.path.dirname(filename)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
    df.to_csv(filename)
    return model

In [7]:
%timeit -r1

train_lstm_model(
    lstm_model=LSTM_model,
    data_loader=train_loader,
    n_epochs=1,
    filename='save-data/lsmt-test-090-2.csv',
    ema_weight=0.9,
    verbose=False
)

Epoch 1/1: 100%|██████████| 805/805 [31:37<00:00,  2.36s/it]


Unnamed: 0,epoch,iteration,loss_ema,loss,loss_ema2,loss2
0,0.0,0.0,0.0,0.0,0.0,0.0
1,0.0,1.0,2.0,3.0,4.0,5.0
2,578.336304,576.158752,578.187561,579.960693,582.095825,582.515503
3,578.336365,556.560852,596.447266,595.919128,601.312317,586.292847
4,578.275879,576.041382,578.035278,579.786621,581.900757,582.304321
5,578.275879,555.930969,595.98053,595.549011,600.928162,585.935974


LSTM_model(
  (lstm): LSTM(240, 2048, num_layers=3, batch_first=True)
  (linear): Conv1d(2048, 240, kernel_size=(1,), stride=(1,))
)

The model above takes about 2.15 ± 0.1 s per training iteration.

In [8]:
%timeit -r1
train_lstm_model(
    lstm_model=LSTM_model,
    data_loader=train_loader,
    n_epochs=1,
    filename='save-data/lstm-test_075-2.csv',
    ema_weight=0.75,
    verbose=True
)

  0%|          | 0/1 [00:04<?, ?it/s]

loss_full 0 0 607.931396 607.931396
loss_full 0 0 607.870544 607.870544


  0%|          | 0/1 [00:26<?, ?it/s]

loss_full 0 10 574.160217 561.601929
loss_full 0 10 573.792419 561.216553


  0%|          | 0/1 [00:48<?, ?it/s]

loss_full 0 20 569.431274 583.384094
loss_full 0 20 569.035400 582.973999


  0%|          | 0/1 [01:10<?, ?it/s]

loss_full 0 30 567.939087 569.132446
loss_full 0 30 567.540771 568.732727


  0%|          | 0/1 [01:32<?, ?it/s]

loss_full 0 40 565.619873 557.560486
loss_full 0 40 565.229126 557.171875


  0%|          | 0/1 [01:54<?, ?it/s]

loss_full 0 50 573.362000 563.320007
loss_full 0 50 572.977295 562.943054


  0%|          | 0/1 [02:16<?, ?it/s]

loss_full 0 60 563.549316 558.083130
loss_full 0 60 563.169800 557.710388


  0%|          | 0/1 [02:38<?, ?it/s]

loss_full 0 70 563.033325 592.305298
loss_full 0 70 562.669006 591.936768


  0%|          | 0/1 [03:00<?, ?it/s]

loss_full 0 80 563.271484 553.431763
loss_full 0 80 562.913452 553.080505


  0%|          | 0/1 [03:22<?, ?it/s]

loss_full 0 90 560.703735 583.915833
loss_full 0 90 560.348206 583.552307


  0%|          | 0/1 [03:44<?, ?it/s]

loss_full 0 100 535.484070 517.107056
loss_full 0 100 535.141602 516.767700


  0%|          | 0/1 [04:06<?, ?it/s]

loss_full 0 110 551.103516 549.269104
loss_full 0 110 550.768921 548.939331


  0%|          | 0/1 [04:28<?, ?it/s]

loss_full 0 120 533.067017 530.916138
loss_full 0 120 532.739685 530.596191


  0%|          | 0/1 [04:50<?, ?it/s]

loss_full 0 130 539.607849 538.252991
loss_full 0 130 539.283630 537.933594


  0%|          | 0/1 [05:12<?, ?it/s]

loss_full 0 140 527.888000 525.553955
loss_full 0 140 527.566345 525.236755


  0%|          | 0/1 [05:33<?, ?it/s]

loss_full 0 150 528.028625 523.317139
loss_full 0 150 527.715149 523.009216


  0%|          | 0/1 [05:55<?, ?it/s]

loss_full 0 160 521.867493 516.456970
loss_full 0 160 521.553833 516.138916


  0%|          | 0/1 [06:16<?, ?it/s]

loss_full 0 170 531.292236 546.336487
loss_full 0 170 530.980835 546.025208


  0%|          | 0/1 [06:38<?, ?it/s]

loss_full 0 180 525.296875 543.682251
loss_full 0 180 525.001953 543.387634


  0%|          | 0/1 [07:00<?, ?it/s]

loss_full 0 190 520.205200 544.012329
loss_full 0 190 519.908875 543.712830


  0%|          | 0/1 [07:21<?, ?it/s]

loss_full 0 200 522.434021 546.168213
loss_full 0 200 522.141541 545.875916


  0%|          | 0/1 [07:43<?, ?it/s]

loss_full 0 210 516.203064 516.759888
loss_full 0 210 515.922119 516.477173


  0%|          | 0/1 [08:05<?, ?it/s]

loss_full 0 220 518.889709 539.860046
loss_full 0 220 518.606567 539.575989


  0%|          | 0/1 [08:26<?, ?it/s]

loss_full 0 230 502.287598 504.323456
loss_full 0 230 502.011230 504.047089


  0%|          | 0/1 [08:48<?, ?it/s]

loss_full 0 240 514.332825 509.208527
loss_full 0 240 514.061584 508.937347


  0%|          | 0/1 [09:10<?, ?it/s]

loss_full 0 250 507.020264 519.607239
loss_full 0 250 506.755310 519.346741


  0%|          | 0/1 [09:31<?, ?it/s]

loss_full 0 260 503.141632 517.043579
loss_full 0 260 502.880219 516.776855


  0%|          | 0/1 [09:53<?, ?it/s]

loss_full 0 270 503.853729 528.327759
loss_full 0 270 503.598328 528.070129


  0%|          | 0/1 [10:15<?, ?it/s]

loss_full 0 280 494.464447 489.513580
loss_full 0 280 494.216400 489.269196


  0%|          | 0/1 [10:34<?, ?it/s]

loss_full 0 290 495.616730 512.884399
loss_full 0 290 495.366577 512.633789


  0%|          | 0/1 [10:53<?, ?it/s]

loss_full 0 300 488.990265 488.250519
loss_full 0 300 488.748901 488.009491


  0%|          | 0/1 [11:12<?, ?it/s]

loss_full 0 310 484.692444 483.371704
loss_full 0 310 484.455841 483.126831


  0%|          | 0/1 [11:31<?, ?it/s]

loss_full 0 320 496.042053 485.578430
loss_full 0 320 495.804993 485.345337


  0%|          | 0/1 [11:49<?, ?it/s]

loss_full 0 330 473.240631 462.179565
loss_full 0 330 473.015015 461.961151


  0%|          | 0/1 [12:08<?, ?it/s]

loss_full 0 340 488.296722 492.584839
loss_full 0 340 488.071991 492.362518


  0%|          | 0/1 [12:27<?, ?it/s]

loss_full 0 350 475.138000 460.156311
loss_full 0 350 474.915070 459.927277


  0%|          | 0/1 [12:46<?, ?it/s]

loss_full 0 360 485.110596 496.988525
loss_full 0 360 484.889404 496.770782


  0%|          | 0/1 [13:05<?, ?it/s]

loss_full 0 370 475.602417 467.347748
loss_full 0 370 475.390198 467.130554


  0%|          | 0/1 [13:24<?, ?it/s]

loss_full 0 380 461.309357 453.711578
loss_full 0 380 461.103271 453.503143


  0%|          | 0/1 [13:43<?, ?it/s]

loss_full 0 390 466.579926 477.353882
loss_full 0 390 466.376892 477.143494


  0%|          | 0/1 [14:02<?, ?it/s]

loss_full 0 400 472.223267 482.213837
loss_full 0 400 472.020386 482.003296


  0%|          | 0/1 [14:21<?, ?it/s]

loss_full 0 410 467.669952 464.669220
loss_full 0 410 467.471680 464.467468


  0%|          | 0/1 [14:40<?, ?it/s]

loss_full 0 420 464.614105 468.908539
loss_full 0 420 464.419006 468.721497


  0%|          | 0/1 [14:59<?, ?it/s]

loss_full 0 430 467.024475 493.874207
loss_full 0 430 466.874207 493.784241


  0%|          | 0/1 [15:17<?, ?it/s]

loss_full 0 440 456.592468 446.530579
loss_full 0 440 456.466370 446.423920


  0%|          | 0/1 [15:36<?, ?it/s]

loss_full 0 450 457.704224 456.369598
loss_full 0 450 457.938904 456.719208


  0%|          | 0/1 [15:55<?, ?it/s]

loss_full 0 460 454.784943 427.529877
loss_full 0 460 454.743469 427.235626


  0%|          | 0/1 [16:14<?, ?it/s]

loss_full 0 470 454.932281 432.659302
loss_full 0 470 454.738037 432.460205


  0%|          | 0/1 [16:33<?, ?it/s]

loss_full 0 480 465.400635 438.312805
loss_full 0 480 465.147614 438.042999


  0%|          | 0/1 [16:52<?, ?it/s]

loss_full 0 490 457.854828 462.534760
loss_full 0 490 457.655579 462.220551


  0%|          | 0/1 [17:11<?, ?it/s]

loss_full 0 500 452.764862 429.624115
loss_full 0 500 452.738586 429.412323


  0%|          | 0/1 [17:30<?, ?it/s]

loss_full 0 510 454.895111 472.852661
loss_full 0 510 454.740540 472.676758


  0%|          | 0/1 [17:49<?, ?it/s]

loss_full 0 520 452.558990 458.042999
loss_full 0 520 452.396545 457.882355


  0%|          | 0/1 [18:08<?, ?it/s]

loss_full 0 530 454.941620 456.832245
loss_full 0 530 454.782471 456.673584


  0%|          | 0/1 [18:27<?, ?it/s]

loss_full 0 540 456.475616 464.318970
loss_full 0 540 456.321655 464.169128


  0%|          | 0/1 [18:46<?, ?it/s]

loss_full 0 550 450.230225 434.227051
loss_full 0 550 450.079468 434.081757


  0%|          | 0/1 [19:05<?, ?it/s]

loss_full 0 560 445.063995 438.294769
loss_full 0 560 444.908020 438.113464


  0%|          | 0/1 [19:24<?, ?it/s]

loss_full 0 570 458.785339 489.228912
loss_full 0 570 458.637787 489.078583


  0%|          | 0/1 [19:42<?, ?it/s]

loss_full 0 580 447.103027 453.759491
loss_full 0 580 446.958191 453.616150


  0%|          | 0/1 [20:01<?, ?it/s]

loss_full 0 590 439.128540 412.757843
loss_full 0 590 438.991150 412.623047


  0%|          | 0/1 [20:20<?, ?it/s]

loss_full 0 600 436.603729 443.416382
loss_full 0 600 436.472107 443.292786


  0%|          | 0/1 [20:39<?, ?it/s]

loss_full 0 610 445.968903 459.273590
loss_full 0 610 445.841095 459.149567


  0%|          | 0/1 [20:58<?, ?it/s]

loss_full 0 620 445.431671 447.358551
loss_full 0 620 445.308289 447.234558


  0%|          | 0/1 [21:17<?, ?it/s]

loss_full 0 630 449.549927 455.037994
loss_full 0 630 449.428162 454.918762


  0%|          | 0/1 [21:36<?, ?it/s]

loss_full 0 640 437.832947 423.631653
loss_full 0 640 437.721741 423.533875


  0%|          | 0/1 [21:55<?, ?it/s]

loss_full 0 650 432.142578 430.046204
loss_full 0 650 432.037048 429.940491


  0%|          | 0/1 [22:14<?, ?it/s]

loss_full 0 660 436.128265 432.244049
loss_full 0 660 436.024963 432.144104


  0%|          | 0/1 [22:33<?, ?it/s]

loss_full 0 670 432.797668 441.228210
loss_full 0 670 432.692505 441.121307


  0%|          | 0/1 [22:52<?, ?it/s]

loss_full 0 680 431.905792 422.012695
loss_full 0 680 431.804260 421.915924


  0%|          | 0/1 [23:11<?, ?it/s]

loss_full 0 690 433.344025 463.986847
loss_full 0 690 433.241638 463.886322


  0%|          | 0/1 [23:29<?, ?it/s]

loss_full 0 700 427.303619 427.043091
loss_full 0 700 427.205200 426.946228


  0%|          | 0/1 [23:48<?, ?it/s]

loss_full 0 710 422.470215 435.883392
loss_full 0 710 422.378876 435.792572


  0%|          | 0/1 [24:07<?, ?it/s]

loss_full 0 720 424.379059 408.044708
loss_full 0 720 424.286407 407.942261


  0%|          | 0/1 [24:26<?, ?it/s]

loss_full 0 730 427.473511 435.941193
loss_full 0 730 427.384857 435.850494


  0%|          | 0/1 [24:45<?, ?it/s]

loss_full 0 740 428.428284 420.834595
loss_full 0 740 428.343201 420.751251


  0%|          | 0/1 [25:04<?, ?it/s]

loss_full 0 750 421.276062 420.922577
loss_full 0 750 421.192261 420.841370


  0%|          | 0/1 [25:23<?, ?it/s]

loss_full 0 760 426.882812 428.746399
loss_full 0 760 426.797363 428.662537


  0%|          | 0/1 [25:42<?, ?it/s]

loss_full 0 770 424.626526 430.390442
loss_full 0 770 424.547577 430.312012


  0%|          | 0/1 [26:01<?, ?it/s]

loss_full 0 780 427.561371 416.591705
loss_full 0 780 427.480225 416.506287


  0%|          | 0/1 [26:20<?, ?it/s]

loss_full 0 790 424.154419 428.606842
loss_full 0 790 424.080109 428.534271


  0%|          | 0/1 [26:39<?, ?it/s]

loss_full 0 800 418.693634 417.673828
loss_full 0 800 418.623749 417.607544


100%|██████████| 1/1 [26:46<00:00, 1606.16s/it]


LSTM_model(
  (lstm): LSTM(240, 2048, num_layers=3, batch_first=True)
  (linear): Conv1d(2048, 240, kernel_size=(1,), stride=(1,))
)

In [None]:
%timeit -r1
train_lstm_model(
    lstm_model=LSTM_model,
    data_loader=train_loader,
    n_epochs=1,
    filename='save-data/lstm-test1_05.csv',
    ema_weight=0.5,
    verbose=False
)

In [None]:
%timeit -r1
train_lstm_model(
    lstm_model=LSTM_model,
    data_loader=train_loader,
    n_epochs=1,
    filename='save-data/lstm-test1_099.csv',
    ema_weight=0.99,
    verbose=True
)

### Simple Linear Model

In [9]:
class simple_linear_model(nn.Module):
    """Single fully connected layer from the flattened input to the flattened output.

    Parameters
    ----------
        device - string or torch.device
            stored on the instance as ``self.device``; the module itself is
            moved by the caller (``model.to(device)``)
        n_agents - int, default 60
        n_features - int, default 4
        in_steps - int, default 19
            number of observed time steps per agent
        out_steps - int, default 30
            number of predicted time steps per agent

    With the defaults the layer maps 240*19 = 4560 inputs to 240*30 = 7200
    outputs.
    """
    def __init__(self, device, n_agents=60, n_features=4, in_steps=19, out_steps=30):
        super(simple_linear_model, self).__init__()

        self.device=device
        self.n_agents=n_agents
        self.n_features=n_features
        self.in_steps=in_steps
        self.out_steps=out_steps

        #simple single layer linear model
        self.linear=nn.Linear(
            in_features=n_agents*n_features*in_steps,
            out_features=n_agents*n_features*out_steps
        )

    def forward(self, x):
        """Map a flattened batch (batch x in_features) to batch x out_features."""
        return self.linear(x)

    def forward_test(self, x, num_steps=30):
        """Predict the first ``num_steps`` future steps in a single pass.

        The model is not recurrent, so there is nothing to roll out: the
        input is flattened per scene, passed through ``forward`` and the
        result is reshaped to batch x n_agents x out_steps x n_features,
        the same layout train_linear_model uses for its targets.

        Raises ValueError if ``num_steps`` is outside 0..out_steps.
        """
        if not 0<=num_steps<=self.out_steps:
            raise ValueError(
                'num_steps must be between 0 and %i, got %r'%(self.out_steps, num_steps)
            )
        flat=x.reshape((len(x), -1))
        pred=self.forward(flat)
        pred=pred.reshape((len(x), self.n_agents, self.out_steps, self.n_features))
        return pred[:, :, :num_steps]

In [21]:
def train_linear_model(lmodel, data_loader, n_epochs, filename, ema_weight, device='cuda', verbose=False):
    '''
    Train a feed-forward (linear) model
    ----------
    Parameters
    ----------
        lmodel - callable
            model class (or factory); it is instantiated as lmodel(device)
        data_loader - torch DataLoader class 
            training data for model; yields (inp, out) batches
        n_epochs - int
            number of epochs to train
        filename - string
            filepath to save training data to
        ema_weight - float
            float between (0.0,1.0) for the exponential moving average weight
        device - string, default 'cuda'
            choose to run on gpu ('cuda') or cpu ('cpu')
        verbose - boolean, default False
            If true print training progress every 10 training iterations
            (the progress bar is disabled in that case)
    -------
    Returns
    -------
         Trained model
    '''
    import os  # local import: only needed to prepare the output folder

    model=lmodel(device).to(device)
    optimizer=torch.optim.Adam(model.parameters(), lr=1e-3)

    # Running average as a plain Python float. Keeping it as a tensor derived
    # from `loss` would keep every iteration's autograd graph reachable for
    # the whole run.
    loss_ema=None

    data=[]

    for epoch in range(n_epochs):
        progress=tqdm(data_loader, desc='Epoch %i/%i'%(epoch+1,n_epochs), disable=verbose)
        for i_batch, sample_batch in enumerate(progress):
            inp, out=sample_batch
            inp, out=inp.to(device), out.to(device)

            # Flatten each scene to one vector; the flat prediction is
            # reshaped back to the layout of the target tensor.
            y_pred=model(inp.reshape((len(inp),-1))).reshape(out.shape)

            loss=(torch.mean((y_pred-out)**2))**0.5
            optimizer.zero_grad() # set gradient to zero
            loss.backward() # backwards propagation
            optimizer.step() # parameter update

            loss_val=loss.item()
            if loss_ema is None:
                loss_ema=loss_val
            loss_ema=loss_ema*ema_weight+loss_val*(1-ema_weight)

            if verbose and i_batch%10==0:
                loss_str='loss %i %i %f %f'%(epoch,i_batch,loss_ema,loss_val)
                print(loss_str)

            data.append([epoch, i_batch, loss_ema, loss_val])

    columns=["epoch","iteration","loss_ema","loss"]
    # One row per iteration; built directly so integer columns stay integers
    # and an empty run still produces the expected columns.
    df=pd.DataFrame(data, columns=columns)
    display(df)

    # Make sure the target folder exists so a long run is not lost on save.
    if isinstance(filename, (str, os.PathLike)):
        out_dir=os.path.dirname(filename)
        if out_dir:
            os.makedirs(out_dir, exist_ok=True)
    df.to_csv(filename)
    return model

In [22]:
# Train the single-layer baseline for one epoch; the loss log is written to the CSV named below.
train_linear_model(
    lmodel=simple_linear_model,
    data_loader=train_loader,
    n_epochs=1,
    filename='save-data/simple-linear-090.csv',
    ema_weight=0.90,
    verbose=False
)

Epoch 1/1: 100%|██████████| 805/805 [04:01<00:00,  3.33it/s]


Unnamed: 0,epoch,iteration,loss_ema,loss
0,0.0,0.0,674.189087,674.189148
1,0.0,1.0,667.140930,603.708069
2,0.0,2.0,646.925781,464.989563
3,0.0,3.0,621.590942,393.577423
4,0.0,4.0,603.362366,439.305542
...,...,...,...,...
800,0.0,800.0,53.114674,36.886452
801,0.0,801.0,54.724529,69.213257
802,0.0,802.0,54.805630,55.535534
803,0.0,803.0,54.290657,49.655933


simple_linear_model(
  (linear): Linear(in_features=4560, out_features=7200, bias=True)
)

In [23]:
class multilayer_linear_model(nn.Module):
    """Two fully connected layers with a ReLU in between.

    Parameters
    ----------
        device - string or torch.device
            stored on the instance as ``self.device``; the module itself is
            moved by the caller (``model.to(device)``)
        hidden_dim - int, default 8192
            width of the hidden layer
        n_agents - int, default 60
        n_features - int, default 4
        in_steps - int, default 19
            number of observed time steps per agent
        out_steps - int, default 30
            number of predicted time steps per agent

    With the defaults the network maps 240*19 = 4560 inputs to
    240*30 = 7200 outputs.
    """
    def __init__(self, device, hidden_dim=8192, n_agents=60, n_features=4, in_steps=19, out_steps=30):
        super(multilayer_linear_model, self).__init__()

        self.hidden_dim=hidden_dim
        self.device=device
        self.n_agents=n_agents
        self.n_features=n_features
        self.in_steps=in_steps
        self.out_steps=out_steps

        self.linear=nn.Sequential(
            nn.Linear(
                in_features=n_agents*n_features*in_steps,
                out_features=self.hidden_dim
            ),
            nn.ReLU(),
            nn.Linear(
                in_features=self.hidden_dim,
                out_features=n_agents*n_features*out_steps
            )
        )

    def forward(self, x):
        """Map a flattened batch (batch x in_features) to batch x out_features."""
        return self.linear(x)

    def forward_test(self, x, num_steps=30):
        """Predict the first ``num_steps`` future steps in a single pass.

        The model is not recurrent, so there is nothing to roll out: the
        input is flattened per scene, passed through ``forward`` and the
        result is reshaped to batch x n_agents x out_steps x n_features,
        the same layout train_linear_model uses for its targets.

        Raises ValueError if ``num_steps`` is outside 0..out_steps.
        """
        if not 0<=num_steps<=self.out_steps:
            raise ValueError(
                'num_steps must be between 0 and %i, got %r'%(self.out_steps, num_steps)
            )
        flat=x.reshape((len(x), -1))
        pred=self.forward(flat)
        pred=pred.reshape((len(x), self.n_agents, self.out_steps, self.n_features))
        return pred[:, :, :num_steps]

In [25]:
# Train the two-layer model for one epoch. The log gets its own file so it
# does not overwrite the single-layer results in simple-linear-090.csv.
train_linear_model(
    lmodel=multilayer_linear_model,
    data_loader=train_loader,
    n_epochs=1,
    filename='save-data/multilayer-linear-090.csv',
    ema_weight=0.90,
    verbose=False
)

Epoch 1/1: 100%|██████████| 805/805 [04:07<00:00,  3.25it/s]


Unnamed: 0,epoch,iteration,loss_ema,loss
0,0.0,0.0,599.770264,599.770264
1,0.0,1.0,688.255798,1484.625732
2,0.0,2.0,665.488220,460.580475
3,0.0,3.0,649.831238,508.918549
4,0.0,4.0,638.979004,541.309204
...,...,...,...,...
800,0.0,800.0,16.430296,14.909383
801,0.0,801.0,16.421009,16.337429
802,0.0,802.0,16.275198,14.962907
803,0.0,803.0,17.299162,26.514839


multilayer_linear_model(
  (linear): Sequential(
    (0): Linear(in_features=4560, out_features=8192, bias=True)
    (1): ReLU()
    (2): Linear(in_features=8192, out_features=7200, bias=True)
  )
)

In [25]:
# NOTE(review): `simple_linear_df` is not assigned in any cell shown in this notebook, so this
# cell presumably relies on leftover kernel state and would fail on a fresh run — confirm or remove.
simple_linear_df.to_csv("simple_linear.csv")
simple_linear_df

In [27]:
# Train the two-layer model for 20 epochs and log the loss every 10th iteration
# into `sequential1_df`.
torch.cuda.empty_cache()

model=multilayer_linear_model(device).to(device)
optimizer=torch.optim.Adam(model.parameters(), lr=1e-3)

n_epochs=20
ema_weight=0.99  # smoothing factor; the two EMA weights must sum to 1

# Running average as a plain Python float so no autograd history is kept alive.
loss_ema=None
records=[]

for epoch in range(n_epochs):
    for i_batch, sample_batch in enumerate(train_loader):
        inp, out=sample_batch
        inp, out=inp.to(device), out.to(device)

        # Flatten each scene to one vector; reshape the flat prediction back
        # to batch x 60 agents x 30 steps x 4 features.
        y_pred=model(inp.reshape((len(inp),-1))).reshape((-1,60,30,4))

        loss=(torch.mean((y_pred-out)**2))**0.5
        optimizer.zero_grad() # set gradient to zero
        loss.backward() # backwards propagation
        optimizer.step() # parameter update

        loss_val=loss.item()
        if loss_ema is None:
            loss_ema=loss_val
        loss_ema=loss_ema*ema_weight+loss_val*(1-ema_weight)

        if i_batch%10==0:
            records.append([epoch, i_batch, loss_ema, loss_val])

# Build the frame once from the collected rows; DataFrame.append returned a
# new frame (and no longer exists in pandas 2), so nothing was ever stored.
sequential1_df=pd.DataFrame(records, columns=["epoch","iteration","loss_ema","loss"])

NameError: name 'sequential_linear_model' is not defined

In [None]:
# Save the loss log collected by the training loop above, then display it.
sequential1_df.to_csv("sequential1.csv")
sequential1_df

### Visualize the batch of sequences

In [None]:
import matplotlib.pyplot as plt
import random

agent_id = 0  # index of the agent whose trajectory is plotted

def show_sample_batch(sample_batch, agent_id):
    """Plot the input and output trajectory of one agent for every scene in a batch.

    Parameters
    ----------
        sample_batch - pair (inp, out)
            tensors indexed as [scene, agent, step, feature]; the first two
            features are used as the (x, y) position
        agent_id - int
            index of the agent to draw in every scene

    One subplot is drawn per scene, side by side. Nothing is drawn for an
    empty batch.
    """
    inp, out = sample_batch
    batch_sz = inp.size(0)
    if batch_sz == 0:
        # plt.subplots cannot create a grid with zero columns
        return

    # squeeze=False always returns a 2-D array of Axes, so a batch holding a
    # single scene works as well (otherwise a bare Axes object is returned).
    fig, axs = plt.subplots(1,batch_sz, figsize=(15, 3), facecolor='w', edgecolor='k', squeeze=False)
    fig.subplots_adjust(hspace = .5, wspace=.001)
    axs = axs.ravel()
    for i in range(batch_sz):
        axs[i].xaxis.set_ticks([])
        axs[i].yaxis.set_ticks([])

        # first two feature dimensions are (x,y) positions
        axs[i].scatter(inp[i, agent_id,:,0], inp[i, agent_id,:,1])
        axs[i].scatter(out[i, agent_id,:,0], out[i, agent_id,:,1])

        #set labels
        axs[i].set_ylabel("out")
        axs[i].set_xlabel("inp")
        

        
# Preview only the first few scenes of the first validation batch: one subplot
# per scene of a full batch is unreadable in a single 15-inch-wide figure.
n_preview = 4
for i_batch, sample_batch in enumerate(val_loader):
    inp, out = sample_batch
    show_sample_batch((inp[:n_preview], out[:n_preview]), agent_id)
    break

In [None]:
# Sanity check of one validation batch: report the tensor shapes instead of
# dumping every value of the batch into the notebook output.
for sample in val_loader:
    print("test")
    print([tuple(part.shape) for part in sample])
    break
