In [47]:
import torch
from torch.utils.data import Dataset, DataLoader
import os, os.path 
from torch.autograd import Variable
import numpy 
import pickle
from glob import glob
import torch.nn as nn

"""Change to the data folder"""
new_path = "../new_train/"

# number of sequences in each dataset
# train:205942  val:3200 test: 36272 
# sequences sampled at 10HZ rate

### Create a dataset class 

In [48]:
class ArgoverseDataset(Dataset):
    """Dataset class for Argoverse"""
    def __init__(self, data_path: str, transform=None):
        super(ArgoverseDataset, self).__init__()
        self.data_path = data_path
        self.transform = transform

        self.pkl_list = glob(os.path.join(self.data_path, '*'))
        self.pkl_list.sort()
        
    def __len__(self):
        return len(self.pkl_list)

    def __getitem__(self, idx):

        pkl_path = self.pkl_list[idx]
        with open(pkl_path, 'rb') as f:
            data = pickle.load(f)
            
        if self.transform:
            data = self.transform(data)

        return data


# intialize a dataset
val_dataset  = ArgoverseDataset(data_path=new_path)

In [58]:
class RNNModel(nn.Module):
    def __init__(self, input_dim, hidden_dim, layer_dim, output_dim):
        super(RNNModel, self).__init__()
        
        # Number of hidden dimensions
        self.hidden_dim = hidden_dim
        
        # Number of hidden layers
        self.layer_dim = layer_dim
        
        # RNN
        self.rnn = nn.RNN(input_dim, hidden_dim, layer_dim, batch_first=True, nonlinearity='relu')
        
        # Readout layer
        self.fc1 = nn.Linear(hidden_dim, 16)
        #self.fc2 = nn.Linear(16, output_dim)
    
    def forward(self, x):
        
        # Initialize hidden state with zeros
        h0 = Variable(torch.zeros(self.layer_dim, x.size(0), self.hidden_dim))
            
        # One time step
        out, hn = self.rnn(x, h0)
        print("out1",out.shape)
        out = self.fc1(out[:, -1, :]) 
        print("out2",out.shape)
        out = self.fc2(out[:, :]) 
        #print("out3 ",out.shape)
        return out

# batch_size, epoch and iteration
batch_size = 100
n_iters = 8000
num_epochs = 10


### Create a loader to enable batch processing

In [59]:
batch_sz = 16

def my_collate(batch):
    """ collate lists of samples into batches, create [ batch_sz x agent_sz x seq_len x feature] """
    inp = [numpy.dstack([scene['p_in'], scene['v_in']]) for scene in batch]
    out = [numpy.dstack([scene['p_out'], scene['v_out']]) for scene in batch]
    inp = torch.LongTensor(inp)
    out = torch.LongTensor(out)
    return [inp, out]

val_loader = DataLoader(val_dataset,batch_size=batch_sz, shuffle = False, collate_fn=my_collate, num_workers=0)
print(val_loader)

<torch.utils.data.dataloader.DataLoader object at 0x000002D90EC83DC0>


### Visualize the batch of sequences

In [60]:
import matplotlib.pyplot as plt
import random

agent_id = 0

def show_sample_batch(sample_batch, agent_id):
    """visualize the trajectory for a batch of samples with a randon agent"""
    inp, out = sample_batch
    batch_sz = inp.size(0)
    agent_sz = inp.size(1)
    
    fig, axs = plt.subplots(1,batch_sz, figsize=(15, 3), facecolor='w', edgecolor='k')
    fig.subplots_adjust(hspace = .5, wspace=.001)
    axs = axs.ravel()   
    for i in range(batch_sz):
        axs[i].xaxis.set_ticks([])
        axs[i].yaxis.set_ticks([])
        
        # first two feature dimensions are (x,y) positions
        axs[i].scatter(inp[i, agent_id,:,0], inp[i, agent_id,:,1])
        axs[i].scatter(out[i, agent_id,:,0], out[i, agent_id,:,1])


# Create RNN
input_dim = 60    # input dimension
hidden_dim = 20  # hidden layer dimension
layer_dim = 1     # number of hidden layers
output_dim = 10   # output dimension
model = RNNModel(input_dim, hidden_dim, layer_dim, output_dim)
# Cross Entropy Loss 
error = nn.CrossEntropyLoss()

# SGD Optimizer
learning_rate = 0.05
optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate)   
seq_dim = 19  
loss_list = []
iteration_list = []
accuracy_list = []
count = 0
for epoch in range(1):
    for i_batch, sample_batch in enumerate(val_loader):
        inp, out = sample_batch
        print(inp.shape)
        
        train  = Variable(inp.view(-1, seq_dim, input_dim))
        labels = Variable(out )
            
        # Clear gradients
        optimizer.zero_grad()
        
        # Forward propagation
        outputs = model(train.float())
        # Calculate softmax and ross entropy loss
        print(labels.shape)
        print(outputs.shape)
        loss = error(outputs, labels)
        
        # Calculating gradients
        loss.backward()
        
        # Update parameters
        optimizer.step()
        
        count += 1
        
        if count == 1:
            # Calculate Accuracy         
            correct = 0
            total = 0
            # Iterate through test dataset
            for images, labels in test_loader:
                images = Variable(images.view(-1, seq_dim, input_dim))
                
                # Forward propagation
                outputs = model(images)
                
                # Get predictions from the maximum value
                predicted = torch.max(outputs.data, 1)[1]
                
                # Total number of labels
                total += labels.size(0)
                
                correct += (predicted == labels).sum()
            
            accuracy = 100 * correct / float(total)
            
            # store loss and iteration
            loss_list.append(loss.data)
            iteration_list.append(count)
            accuracy_list.append(accuracy)
            if count == 1:
                # Print Loss
                print('Iteration: {}  Loss: {}  Accuracy: {} %'.format(count, loss.data[0], accuracy))
                
                
        break

torch.Size([16, 60, 19, 4])
out1 torch.Size([64, 19, 100])
out2 torch.Size([64, 16])
torch.Size([16, 60, 30, 4])
torch.Size([1024])


  inp = torch.LongTensor(inp)
  out = torch.LongTensor(out)


IndexError: Dimension out of range (expected to be in range of [-1, 0], but got 1)