In [1]:
import torch
from torch.utils.data import Dataset, DataLoader
import os, os.path 
import numpy 
import pickle
from glob import glob
import matplotlib.pyplot as plt

"""Change to the data folder"""
new_path = "../new_train/"
val_path = "../new_val_in"

# number of sequences in each dataset
# train:205942  val:3200 test: 36272 
# sequences sampled at 10HZ rate

### Create a dataset class 

In [2]:
class ArgoverseDataset(Dataset):
    """Dataset class for Argoverse"""
    def __init__(self, data_path: str, transform=None):
        super(ArgoverseDataset, self).__init__()
        self.data_path = data_path
        self.transform = transform

        self.pkl_list = glob(os.path.join(self.data_path, '*'))
        self.pkl_list.sort()
        
    def __len__(self):
        return len(self.pkl_list)

    def __getitem__(self, idx):

        pkl_path = self.pkl_list[idx]
        with open(pkl_path, 'rb') as f:
            data = pickle.load(f)
            
        if self.transform:
            data = self.transform(data)

        return data


# intialize a dataset
train_dataset  = ArgoverseDataset(data_path=new_path)
val_dataset = ArgoverseDataset(data_path=val_path)
#print((val_dataset[0]))
#print(len(train_dataset[0]))

### Create a loader to enable batch processing

In [3]:
batch_sz = 1

def my_collate(batch):
    """ collate lists of samples into batches, create [ batch_sz x agent_sz x seq_len x feature] """
    inp = [numpy.dstack([scene['p_in'], scene['v_in']]) for scene in batch]
    out = [numpy.dstack([scene['p_out'], scene['v_out']]) for scene in batch]
# print(inp.size)
    print("gap")
#     print(out.size)
    #inp = np.concatenate((inp, out), axis=0)
    inp = torch.LongTensor(inp)
    print(inp.shape)
    print("after")
    print(inp)
    out = torch.LongTensor(out)
    return [inp, out]

train_loader = DataLoader(train_dataset,batch_size=batch_sz, shuffle = False, collate_fn=my_collate, num_workers=0)

In [4]:

def val_collate(batch):
    """ collate lists of samples into batches, create [ batch_sz x agent_sz x seq_len x feature] """
    inp = [numpy.dstack([scene['p_in'], scene['v_in']]) for scene in batch]
    inp = torch.LongTensor(inp)
    return inp

val_loader = DataLoader(val_dataset,batch_size=batch_sz, shuffle = False, collate_fn=my_collate, num_workers=0)

In [5]:
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from torch.autograd import Variable

class RNNModel(nn.Module):
    def __init__(self, input_size, output_size, hidden_dim, n_layers):
        super(RNNModel, self).__init__()
        
        # Number of hidden dimensions
        self.hidden_dim = hidden_dim
        
        # Number of hidden layers
        self.n_layers = n_layers
        
        # RNN
        self.rnn = nn.RNN(input_size, hidden_dim, n_layers, batch_first=True, nonlinearity='relu')
        #print(self.rnn)
        
        # Readout layer
        self.fc = nn.Linear(hidden_dim, output_size)
    
    def forward(self, x):
        batch_size = x.size(0)
        
        hidden = self.init_hidden(batch_size)
        #print(hidden.shape)
        out, hidden = self.rnn(x, hidden)
        out = out.contiguous().view(-1, self.hidden_dim)
        out = self.fc(out)
        
        return out
        """
        # Initialize hidden state with zeros
        h0 = Variable(torch.zeros(self.layer_dim, x.size(0), self.hidden_dim))
            
        # One time step
        print(x)
        #print("gap")
        #print(h0.shape)
        out, hn = self.rnn(x, h0)
        out = self.fc(out[:, -1, :]) 
        return out
        """
    
    def init_hidden(self, batch_size):
        hidden = torch.zeros(self.n_layers, batch_size, self.hidden_dim)
        return hidden

In [6]:
def make_a_histogram(sample_batch, agent_id, xPos, yPos, xVel, yVel):
    inp, out = sample_batch
    batch_sz = inp.size(0)
    #agent_sz = inp.size(1)
    
    for i in range(batch_sz):
        #hist_data_xPos = np.zeros((60,19));
        #hist_data_yPos = np.zeros((60,19));
        #hist_data_xVel = np.zeros((60,19));
        hist_data_yVel = np.zeros((60,19));
        
        for j in range(60):
            #hist_data_xPos[j] = (inp[i, j,:,0])
            #hist_data_yPos[j] = (inp[i, j,:,1])
            #hist_data_xVel[j] = (inp[i, j,:,2])
            hist_data_yVel[j] = (inp[i, j,:,3])
            
        for j in range(len(hist_data_yVel)):
            for k in range(len(hist_data_yVel[j])):
                #xPos.append(hist_data_xPos[j][k])
                #yPos.append(hist_data_yPos[j][k])
                #xVel.append(hist_data_xVel[j][k])
                yVel.append(hist_data_yVel[j][k])
    
    """
    hist_data_xPos = np.zeros((60,19));
    hist_data_yPos = np.zeros((60,19));
    hist_data_xVel = np.zeros((60,19));
    hist_data_yVel = np.zeros((60,19));
    
    for i in range(60):
        hist_data_xPos[i] = (inp[0, i,:,0])
        hist_data_yPos[i] = (inp[0, i,:,1])
        hist_data_xVel[i] = (inp[0, i,:,2])
        hist_data_yVel[i] = (inp[0, i,:,3])
    
    xPos = np.zeros(60*19)
    for i in range(len(hist_data_xPos)):
        for j in range(len(hist_data_xPos[i])):
            xPos[i*19+j] = hist_data_xPos[i][j]
    
    #hist_data_xPos = hist_data_xPos.flatten()
    hist_data_yVel = hist_data_yPos.flatten()
    hist_data_xPos = hist_data_xVel.flatten()
    hist_data_yVel = hist_data_yVel.flatten()
    #print(xPos)
    
    n,bins,patches = plt.hist(x=xPos,bins='auto',alpha=0.7,rwidth=0.85)
    plt.grid(axis='y',alpha=0.75)
    maxfreq = n.max()
    plt.ylim(ymax=np.ceil(maxfreq/10) * 10 if maxfreq % 10 else maxfreq + 10)
    """

In [7]:
def show_sample_batch(sample_batch, agent_id):
    """visualize the trajectory for a batch of samples with a randon agent"""
    inp, out = sample_batch
    batch_sz = inp.size(0)
    agent_sz = inp.size(1)
    
    fig, axs = plt.subplots(1,batch_sz, figsize=(15, 3), facecolor='w', edgecolor='k')
    fig.subplots_adjust(hspace = .5, wspace=.001)
    axs = axs.ravel()   
    for i in range(batch_sz):
        axs[i].xaxis.set_ticks([])
        axs[i].yaxis.set_ticks([])
        
        # first two feature dimensions are (x,y) positions
        axs[i].scatter(inp[i, agent_id,:,0], inp[i, agent_id,:,1])
        axs[i].scatter(out[i, agent_id,:,0], out[i, agent_id,:,1])

### Visualize the batch of sequences

In [14]:

import random
import numpy as np

agent_id = 0
learning_rate = 0.01
momentum = 0.5
device = "cpu"
input_dim = 4    # input dimension
hidden_dim = 100  # hidden layer dimension
layer_dim = 1     # number of hidden layers
output_dim = 4   # output dimension

n_epochs = 100
lr=0.01

# Define Loss, Optimizer
#model = RNNModel(input_dim, output_dim, hidden_dim, layer_dim).to(device)
model = RNNModel(input_size=input_dim, output_size=output_dim, hidden_dim=12, n_layers=1)
optimizer = optim.SGD(model.parameters(), lr=learning_rate,momentum=momentum)


    
    
model.train()

for i_batch, sample_batch in enumerate(train_loader):
    inp, out = sample_batch
    #make_a_histogram(sample_batch, agent_id, xPos, yPos, xVel, yVel)
    """TODO:
      Deep learning model
      training routine
    """
    inp, out = inp.to(device), out.to(device)
    optimizer.zero_grad()
    outList = np.zeros((1,60,30,4))
#     print(outList[0])
#     print(inp.shape)
    for i in range(30):
        output = model(inp[0].float())
        output = output.reshape(1,60,19,4)
#         print("output",output.shape)
        for j in range(60):
#             print(output[0][j][18])
#             print(outList[0][j][18])
            outList[0][j][i] = output[0][j][18].detach().numpy()
        for k in range(60):
            for j in range(1,19):
                inp[0][k][j-1] = inp[0][k][j]
            inp[0][k][18] = output[0][k][18]
    print(outList)
    tensorOut = torch.tensor(outList)
    print("tensor",tensorOut[0].shape)
    #print(output)
    print("out",out[0].shape)
    #print(out[0].shape)
    
    loss = F.nll_loss(tensorOut,out[0])
    loss.backward()
    optimizer.step()
    
    show_sample_batch(sample_batch, agent_id)
    break
    

gap
torch.Size([1, 60, 19, 4])
after
tensor([[[[3277, 1947,    0,    0],
          [3277, 1947,    0,    0],
          [3277, 1947,    0,    0],
          ...,
          [3277, 1947,    0,    0],
          [3277, 1947,    0,    0],
          [3277, 1947,    0,    0]],

         [[3277, 1977,    0,    0],
          [3277, 1977,    0,    0],
          [3277, 1977,    1,    1],
          ...,
          [3277, 1977,    0,   -1],
          [3277, 1977,    0,    0],
          [3277, 1977,    0,    0]],

         [[3232, 1922,    0,    0],
          [3232, 1922,    0,   -1],
          [3232, 1923,    0,    2],
          ...,
          [3232, 1922,    0,    0],
          [3232, 1922,    0,    0],
          [3232, 1922,    0,    0]],

         ...,

         [[   0,    0,    0,    0],
          [   0,    0,    0,    0],
          [   0,    0,    0,    0],
          ...,
          [   0,    0,    0,    0],
          [   0,    0,    0,    0],
          [   0,    0,    0,    0]],

         [[   0,

  inp = torch.LongTensor(inp)
  out = torch.LongTensor(out)


[[[[-3.69117981e+02 -8.43493958e+01  1.02473801e+02 -3.33031219e+02]
   [-1.37434525e+02 -1.33703964e+02 -4.72872276e+01 -8.11193924e+01]
   [-2.67913570e+01 -6.65094604e+01 -5.07154884e+01 -7.52995834e+01]
   ...
   [-2.12396741e-01  1.43416166e-01 -3.69407833e-01  9.62876678e-02]
   [-2.12396741e-01  1.43416166e-01 -3.69407833e-01  9.62876678e-02]
   [-2.12396741e-01  1.43416166e-01 -3.69407833e-01  9.62876678e-02]]

  [[-3.70553619e+02 -8.64100723e+01  1.02465767e+02 -3.33263763e+02]
   [-1.37782471e+02 -1.33623444e+02 -4.65024719e+01 -8.25036621e+01]
   [-2.71529083e+01 -6.74711151e+01 -5.16878090e+01 -7.52501678e+01]
   ...
   [-2.12396741e-01  1.43416166e-01 -3.69407833e-01  9.62876678e-02]
   [-2.12396741e-01  1.43416166e-01 -3.69407833e-01  9.62876678e-02]
   [-2.12396741e-01  1.43416166e-01 -3.69407833e-01  9.62876678e-02]]

  [[-3.64133301e+02 -8.33060913e+01  1.01062294e+02 -3.28468567e+02]
   [-1.35573868e+02 -1.31879623e+02 -4.66353035e+01 -8.00256271e+01]
   [-2.64764614e

ValueError: Expected input batch_size (1) to match target batch_size (60).