# Pytorch Baseline - Train

**Notes**
- Do not forget to enable the GPU (TPU) for training
- You have to add `kaggle_l5kit` as utility script
- Parts of the code below is from the [official example](https://github.com/lyft/l5kit/blob/master/examples/agent_motion_prediction/agent_motion_prediction.ipynb)
- [Baseline inference notebook](https://www.kaggle.com/pestipeti/pytorch-baseline-inference)

In [1]:
import numpy as np

import os
import torch
torch.manual_seed(0)

from torch import nn, optim
from torch.utils.data import DataLoader
from torchvision.models.resnet import resnet18,resnet50,resnet101
from tqdm import tqdm
from typing import Dict
from torch import functional as F

from l5kit.configs import load_config_data
from l5kit.data import LocalDataManager, ChunkedDataset
from l5kit.dataset import AgentDataset, EgoDataset
from l5kit.rasterization import build_rasterizer
from l5kit.evaluation import write_pred_csv, compute_metrics_csv, read_gt_csv, create_chopped_dataset
from l5kit.evaluation.chop_dataset import MIN_FUTURE_STEPS
from l5kit.evaluation.metrics import neg_multi_log_likelihood, time_displace
from l5kit.geometry import transform_points
from l5kit.visualization import PREDICTED_POINTS_COLOR, TARGET_POINTS_COLOR, draw_trajectory
from prettytable import PrettyTable
from pathlib import Path

In [2]:
DIR_INPUT = "/media/ubuntu/Data/project/lyft/lyft-motion-prediction-autonomous-vehicles/"

In [3]:
cfg = {
    'format_version': 4,
    'model_params': {
        'model_architecture': 'resnet18',
        
        'history_num_frames': 10,
        'history_step_size': 1,
        'history_delta_time': 0.1,
        
        'future_num_frames': 50,
        'future_step_size': 1,
        'future_delta_time': 0.1
    },
    
    'raster_params': {
        'raster_size': [1, 1],
        'pixel_size': [0.5, 0.5],
        'ego_center': [0.25, 0.5],
        'map_type': 'py_semantic',
        'satellite_map_key': 'aerial_map/aerial_map.png',
        'semantic_map_key': 'semantic_map/semantic_map.pb',
        'dataset_meta_key': 'meta.json',
        'filter_agents_threshold': 0.5
    },
    
    'train_data_loader': {
        'key': 'scenes/train.zarr',
        'batch_size': 32,
        'shuffle': True,
        'num_workers': 4
    },
    
    'val_data_loader': {
        'key': 'scenes/validate.zarr',
        'batch_size': 32,
        'shuffle': False,
        'num_workers': 4
    },
    
    'test_data_loader': {
        'key': 'scenes/test.zarr',
        'batch_size': 32,
        'shuffle': False,
        'num_workers': 4
    },
    
    'train_params': {
        'checkpoint_every_n_steps': 5000,
        'max_num_steps': 10000,
        'eval_every_n_steps': 500

        
    }
}

In [4]:
# set env variable for data
os.environ["L5KIT_DATA_FOLDER"] = DIR_INPUT
dm = LocalDataManager(None)
VALIDATION = True

## Dataset, dataloader

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
# ===== INIT DATASET
train_cfg = cfg["train_data_loader"]

# Rasterizer
rasterizer = build_rasterizer(cfg, dm)

# Train dataset/dataloader
train_zarr = ChunkedDataset(dm.require(train_cfg["key"])).open()
train_dataset = AgentDataset(cfg, train_zarr, rasterizer)
train_dataloader = DataLoader(train_dataset,
                              shuffle=train_cfg["shuffle"],
                              batch_size=train_cfg["batch_size"])
                              #num_workers=train_cfg["num_workers"])

print(train_dataset)
print(len(train_dataset))

+------------+------------+------------+-----------------+----------------------+----------------------+----------------------+---------------------+
| Num Scenes | Num Frames | Num Agents | Total Time (hr) | Avg Frames per Scene | Avg Agents per Frame | Avg Scene Time (sec) | Avg Frame frequency |
+------------+------------+------------+-----------------+----------------------+----------------------+----------------------+---------------------+
|   16265    |  4039527   | 320124624  |      112.19     |        248.36        |        79.25         |        24.83         |        10.00        |
+------------+------------+------------+-----------------+----------------------+----------------------+----------------------+---------------------+
22496709


In [7]:
# ===== INIT  VAL DATASET
val_cfg = cfg["val_data_loader"]

# Rasterizer
rasterizer = build_rasterizer(cfg, dm)

# Train dataset/dataloader
val_zarr = ChunkedDataset(dm.require(val_cfg["key"])).open()
val_dataset = AgentDataset(cfg, val_zarr, rasterizer)
val_dataloader = DataLoader(val_dataset,
                              shuffle=val_cfg["shuffle"],
                              batch_size=val_cfg["batch_size"])
                              #num_workers=train_cfg["num_workers"])

print(val_dataset)
print(len(val_dataset))

+------------+------------+------------+-----------------+----------------------+----------------------+----------------------+---------------------+
| Num Scenes | Num Frames | Num Agents | Total Time (hr) | Avg Frames per Scene | Avg Agents per Frame | Avg Scene Time (sec) | Avg Frame frequency |
+------------+------------+------------+-----------------+----------------------+----------------------+----------------------+---------------------+
|   16220    |  4030296   | 312617887  |      111.97     |        248.48        |        77.57         |        24.85         |        10.00        |
+------------+------------+------------+-----------------+----------------------+----------------------+----------------------+---------------------+
21624612


## Model

In [8]:
class EncoderLSTM_LyftModel(nn.Module):
    
    def __init__(self, cfg):
        super(EncoderLSTM_LyftModel, self).__init__()
        
        self.input_sz  = 2
        self.hidden_sz = 1024
        self.num_layer = 2
        self.sequence_length = 11
        self.bz  = cfg["train_data_loader"]["batch_size"]
        hidden_state  = torch.randn(self.num_layer,self.bz,self.hidden_sz,requires_grad=False).to(device)
        cell_state    = torch.randn( self.num_layer,self.bz, self.hidden_sz,requires_grad=False).to(device)
        self.hidden_encoder = (hidden_state,cell_state)
        
        self.Encoder_lstm = nn.LSTM(self.input_sz,self.hidden_sz,self.num_layer,batch_first=True,dropout=0.2)
       
    def forward(self,inputs):
        
        output,hidden_state = self.Encoder_lstm(inputs,self.hidden_encoder)
        
        return output,hidden_state
    
class DecoderLSTM_LyftModel(nn.Module):
    def __init__(self, cfg):
        super(DecoderLSTM_LyftModel, self).__init__()
        
        self.input_sz  = 40 #(2000 from fcn_en_output reshape to 50*40)
        self.hidden_sz = 256
        self.hidden_sz_en = 1024
        self.num_layer = 2
        self.sequence_len_en = 11
        self.sequence_len_de = 50
        self.interlayer1 = 512
        self.interlayer2 = 1024
        self.interlayer3 = 2048
        self.interlayer4 = 3072
        self.interlayer5 = 4096

        
        self.bz = cfg["train_data_loader"]["batch_size"]
        num_targets = 2 * cfg["model_params"]["future_num_frames"]
        
        self.encoderLSTM = EncoderLSTM_LyftModel (cfg)

        
        self.Decoder_lstm = nn.LSTM(40,self.hidden_sz,self.num_layer,batch_first=True,dropout=0.25)


        self.fcn_en_output = nn.Sequential(nn.Linear(in_features=self.hidden_sz_en*self.sequence_len_en, out_features=self.interlayer5),
                            #nn.BatchNorm1d(self.interlayer5),
                            nn.ReLU(inplace=True),
                            nn.Dropout(p=0.20),
                            nn.Linear(in_features=self.interlayer5, out_features=self.interlayer4),
                            #nn.BatchNorm1d(self.interlayer4),
                            nn.ReLU(inplace=True),
                            nn.Dropout(p=0.20),               
                            nn.Linear(in_features=self.interlayer4, out_features=self.interlayer3),          
                            nn.ReLU(inplace=True),
                            nn.Dropout(p=0.20),
                            nn.Linear(in_features=self.interlayer3, out_features=self.input_sz * self.sequence_len_de) )
        
        self.fcn_en_hidden = nn.Sequential(nn.Linear(in_features=self.hidden_sz_en, out_features=self.interlayer3),
                            nn.ReLU(inplace=True),
                            nn.Dropout(p=0.25),
                            nn.Linear(in_features=self.interlayer3, out_features=self.interlayer2),
                            nn.ReLU(inplace=True),
                            nn.Dropout(p=0.25),
                            #nn.Linear(in_features=self.interlayer2, out_features=self.interlayer2),
                            #nn.ReLU(inplace=True),
                            #nn.Dropout(p=0.25),
                            #nn.Linear(in_features=self.interlayer2, out_features=self.interlayer1),
                            #nn.ReLU(inplace=True),
                            #nn.Dropout(p=0.20),
                            nn.Linear(in_features=self.interlayer2, out_features=self.hidden_sz)  )
        
        self.fcn_en_cell_state = nn.Sequential(nn.Linear(in_features=self.hidden_sz_en, out_features=self.interlayer3),
                                 nn.ReLU(inplace=True),
                                 nn.Dropout(p=0.25),              
                                 nn.Linear(in_features=self.interlayer3, out_features=self.interlayer2),
                                 nn.ReLU(inplace=True),
                                 nn.Dropout(p=0.20),
                                 #nn.Linear(in_features=self.interlayer4, out_features=self.interlayer2),
                                # nn.ReLU(inplace=True),
                                 #nn.Dropout(p=0.25),              
                                 #nn.Linear(in_features=self.interlayer2, out_features=self.interlayer1),
                                 #nn.ReLU(inplace=True),
                                 #nn.Dropout(p=0.20),
                                 nn.Linear(in_features=self.interlayer2, out_features=self.hidden_sz) )

        
        self.fcn_de_output = nn.Sequential(nn.Linear(in_features=self.hidden_sz*self.sequence_len_de, out_features=self.interlayer5),
                             #nn.BatchNorm1d(self.interlayer5),
                             nn.ReLU(inplace=True),
                             nn.Dropout(p=0.25),
                             nn.Linear(in_features=self.interlayer5, out_features=self.interlayer3),
                             nn.ReLU(inplace=True),
                             nn.Dropout(p=0.25),
                             nn.Linear(in_features=self.interlayer3, out_features=self.interlayer1),
                             #nn.BatchNorm1d(self.interlayer3),
                             #nn.ReLU(inplace=True),
                             #nn.Linear(in_features=self.interlayer3, out_features=self.interlayer2),
                             #nn.ReLU(inplace=True),
                             nn.Dropout(p=0.25), 
                             #nn.Linear(in_features=self.interlayer2, out_features=self.interlayer1),
                             nn.ReLU(inplace=True),
                             nn.Linear(in_features=self.interlayer1, out_features=num_targets)  )

        


    def forward(self,inputs):

        output,hidden_state = self.encoderLSTM(inputs)
        
        #calling FCN connecting encoder and decoder
        input_to_dec_out        = self.fcn_en_output(output.reshape(self.bz,-1))
        input_to_dec_hidden     = self.fcn_en_hidden(hidden_state[0].reshape(self.num_layer,self.bz,-1))
        input_to_dec_cell_state = self.fcn_en_hidden(hidden_state[1].reshape(self.num_layer,self.bz,-1))
        
        #reshaping

        input_to_dec_out     =   input_to_dec_out.reshape(self.bz,self.sequence_len_de,-1)
        input_to_dec_hidden  = input_to_dec_hidden.reshape(self.num_layer,self.bz,-1)
        input_to_dec_cell_state = input_to_dec_cell_state.reshape(self.num_layer,self.bz,-1)
        
        decoder_out,_       = self.Decoder_lstm(input_to_dec_out,(input_to_dec_hidden,input_to_dec_cell_state) )          
        
        fc_out = self.fcn_de_output(decoder_out.reshape(decoder_out.shape[0],-1))
        
        return fc_out

In [9]:
# ==== INIT MODEL
model = DecoderLSTM_LyftModel(cfg)
model.to(device)
#optimizer = optim.SGD(model.parameters(), lr=1e-2,momentum=0.9)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=0.0005)

lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer,step_size=7000,gamma=0.1)
#lr_scheduler = CyclicLR(optimizer, base_lr=1e-2, max_lr=1e-1,cycle_momentum = True)
# Later we have to filter the invalid steps.
criterion = nn.MSELoss(reduction="none")

In [10]:
model

DecoderLSTM_LyftModel(
  (encoderLSTM): EncoderLSTM_LyftModel(
    (Encoder_lstm): LSTM(2, 1024, num_layers=2, batch_first=True, dropout=0.2)
  )
  (Decoder_lstm): LSTM(40, 256, num_layers=2, batch_first=True, dropout=0.25)
  (fcn_en_output): Sequential(
    (0): Linear(in_features=11264, out_features=4096, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.2, inplace=False)
    (3): Linear(in_features=4096, out_features=3072, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.2, inplace=False)
    (6): Linear(in_features=3072, out_features=2048, bias=True)
    (7): ReLU(inplace=True)
    (8): Dropout(p=0.2, inplace=False)
    (9): Linear(in_features=2048, out_features=2000, bias=True)
  )
  (fcn_en_hidden): Sequential(
    (0): Linear(in_features=1024, out_features=2048, bias=True)
    (1): ReLU(inplace=True)
    (2): Dropout(p=0.25, inplace=False)
    (3): Linear(in_features=2048, out_features=1024, bias=True)
    (4): ReLU(inplace=True)
    (5): Dropout(p=0.25, inp

In [11]:
#checkpoint = torch.load('/media/ubuntu/Data/project/lyft/l5kit-1.0.6/examples/agent_motion_prediction/model/model_state_last_40k_18_false.pth'))
#model.load_state_dict(checkpoint['model_state_dict'])
#optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
#loss=checkpoint['loss']

In [12]:
#device

## Training

In [13]:
# ==== TRAIN LOOP
tr_it = iter(train_dataloader)
vl_it = iter(val_dataloader)


progress_bar = tqdm(range(cfg["train_params"]["max_num_steps"]))
losses_train = []
losses_mean_train = []
losses_val = []
losses_mean_val = []

for itr in progress_bar:
    try:
        data = next(tr_it)
    except StopIteration:
        tr_it = iter(train_dataloader)
        data = next(tr_it)
    model.train()
    torch.set_grad_enabled(True)

    # Forward pass
    history_positions = data['history_positions'].to(device)
    history_availabilities = data['history_availabilities'].to(device)
    target_availabilities = data["target_availabilities"].unsqueeze(-1).to(device)
    targets_position = data["target_positions"].to(device)

    outputs = model(history_positions)

    loss = criterion(outputs.reshape(targets_position.shape), targets_position)
    # not all the output steps are valid, but we can filter them out from the loss using availabilities
    loss = loss * target_availabilities
    loss = loss.mean()

    # Backward pass
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    
    losses_train.append(loss.item())
    losses_mean_train.append(np.mean(losses_train))
    
    # Validation
    if VALIDATION :#& ( cfg["train_params"]["max_num_steps"] % cfg["train_params"]["eval_every_n_steps"] ==0 ):
        with torch.no_grad():
            try:
                val_data = next(vl_it)
            except StopIteration:
                vl_it = iter(val_dataloader)
                val_data = next(vl_it)

            model.eval()
            # Forward pass
            target_availabilities_val = val_data["target_availabilities"].unsqueeze(-1).to(device)
            targets_val = val_data["target_positions"].to(device)
            history_positions_val = data['history_positions'].to(device)
            history_availabilities_val = data['history_availabilities'].to(device)

            outputs_val = model(history_positions_val)
                    
            loss_v = criterion(outputs_val.reshape(targets_val.shape), targets_val)
            # not all the output steps are valid, but we can filter them out from the loss using availabilities
            loss_v = loss_v * target_availabilities_val
            loss_v = loss_v.mean()

            losses_val.append(loss_v.item())

            losses_mean_val.append(np.mean(losses_val))


        desc = f" TrainLoss: {round(loss.item(), 4)} ValLoss: {round(loss_v.item(), 4)} TrainMeanLoss: {np.mean(losses_train)} ValMeanLoss: {np.mean(losses_val)}" 
    else:
        desc = f" TrainLoss: {round(loss.item(), 4)}"


        #if len(losses_train)>0 and loss < min(losses_train):
        #    print(f"Loss improved from {min(losses_train)} to {loss}")
    lr_scheduler.step()

    progress_bar.set_description(desc)

 TrainLoss: 9.8307 ValLoss: 276.8224 TrainMeanLoss: 11.089555873060226 ValMeanLoss: 223.07350069697304: 100%|██████████| 10000/10000 [9:41:30<00:00,  3.49s/it]   


In [14]:
torch.save({
            'model_state_dict' : model.state_dict(),
            'optimizer_state_dict' : optimizer.state_dict(),
            'loss' : loss },
            '/media/ubuntu/Data/project/lyft/l5kit-1.0.6/examples/agent_motion_prediction/model/model_state_last_ENDC_lstm_10k.pth')

In [17]:
# ===== INIT DATASET
test_cfg = cfg["test_data_loader"]

# Rasterizer
rasterizer = build_rasterizer(cfg, dm)

# Test dataset/dataloader
test_zarr = ChunkedDataset(dm.require(test_cfg["key"])).open()
test_mask = np.load(f"{DIR_INPUT}/scenes/mask.npz")["arr_0"]
test_dataset = AgentDataset(cfg, test_zarr, rasterizer, agents_mask=test_mask)
test_dataloader = DataLoader(test_dataset,
                             shuffle=test_cfg["shuffle"],
                             batch_size=test_cfg["batch_size"],
                             num_workers=test_cfg["num_workers"])


print(test_dataloader)

<torch.utils.data.dataloader.DataLoader object at 0x7fd02dde75d0>


In [19]:
model.eval()

future_coords_offsets_pd = []
timestamps = []
agent_ids = []

with torch.no_grad():
    dataiter = tqdm(test_dataloader)
    
    for data in dataiter:

        history_positions = data['history_positions'].to(device)
        targets = data["target_positions"].to(device)

        outputs = model(history_positions)
        outputs = outputs.reshape(targets.shape)
        future_coords_offsets_pd.append(outputs.cpu().numpy().copy())
        timestamps.append(data["timestamp"].numpy().copy())
        agent_ids.append(data["track_id"].numpy().copy())




  0%|          | 0/2223 [00:00<?, ?it/s][A[A

  0%|          | 1/2223 [00:00<33:37,  1.10it/s][A[A

  0%|          | 2/2223 [00:01<28:39,  1.29it/s][A[A

  0%|          | 4/2223 [00:01<21:45,  1.70it/s][A[A

  0%|          | 5/2223 [00:01<17:42,  2.09it/s][A[A

  0%|          | 6/2223 [00:02<21:17,  1.74it/s][A[A

  0%|          | 8/2223 [00:03<17:52,  2.06it/s][A[A

  0%|          | 10/2223 [00:03<16:15,  2.27it/s][A[A

  1%|          | 12/2223 [00:04<13:54,  2.65it/s][A[A

  1%|          | 14/2223 [00:05<16:05,  2.29it/s][A[A

  1%|          | 18/2223 [00:06<13:28,  2.73it/s][A[A

  1%|          | 20/2223 [00:06<10:20,  3.55it/s][A[A

  1%|          | 21/2223 [00:06<10:13,  3.59it/s][A[A

  1%|          | 22/2223 [00:07<14:11,  2.59it/s][A[A

  1%|          | 25/2223 [00:07<11:45,  3.12it/s][A[A

  1%|          | 26/2223 [00:09<20:37,  1.78it/s][A[A

  1%|▏         | 30/2223 [00:10<17:12,  2.12it/s][A[A

  1%|▏         | 33/2223 [00:10<12:40,  2.88

 17%|█▋        | 386/2223 [02:29<20:22,  1.50it/s][A[A

 17%|█▋        | 389/2223 [02:29<15:33,  1.96it/s][A[A

 18%|█▊        | 390/2223 [02:31<23:26,  1.30it/s][A[A

 18%|█▊        | 393/2223 [02:31<17:36,  1.73it/s][A[A

 18%|█▊        | 394/2223 [02:32<20:07,  1.51it/s][A[A

 18%|█▊        | 395/2223 [02:32<15:23,  1.98it/s][A[A

 18%|█▊        | 397/2223 [02:32<11:46,  2.58it/s][A[A

 18%|█▊        | 398/2223 [02:34<22:44,  1.34it/s][A[A

 18%|█▊        | 401/2223 [02:34<16:33,  1.83it/s][A[A

 18%|█▊        | 402/2223 [02:36<30:31,  1.01s/it][A[A

 18%|█▊        | 403/2223 [02:37<23:02,  1.32it/s][A[A

 18%|█▊        | 406/2223 [02:38<20:27,  1.48it/s][A[A

 18%|█▊        | 410/2223 [02:40<18:28,  1.64it/s][A[A

 19%|█▊        | 414/2223 [02:42<17:25,  1.73it/s][A[A

 19%|█▉        | 418/2223 [02:43<15:18,  1.96it/s][A[A

 19%|█▉        | 422/2223 [02:44<13:24,  2.24it/s][A[A

 19%|█▉        | 426/2223 [02:46<12:48,  2.34it/s][A[A

 19%|█▉       

 31%|███▏      | 695/2223 [04:39<13:47,  1.85it/s][A[A

 31%|███▏      | 696/2223 [04:40<17:36,  1.45it/s][A[A

 31%|███▏      | 697/2223 [04:40<13:36,  1.87it/s][A[A

 31%|███▏      | 699/2223 [04:41<12:20,  2.06it/s][A[A

 31%|███▏      | 700/2223 [04:41<13:27,  1.89it/s][A[A

 32%|███▏      | 701/2223 [04:42<15:46,  1.61it/s][A[A

 32%|███▏      | 703/2223 [04:43<12:27,  2.03it/s][A[A

 32%|███▏      | 704/2223 [04:43<09:36,  2.63it/s][A[A

 32%|███▏      | 705/2223 [04:44<16:04,  1.57it/s][A[A

 32%|███▏      | 707/2223 [04:45<14:07,  1.79it/s][A[A

 32%|███▏      | 709/2223 [04:46<13:56,  1.81it/s][A[A

 32%|███▏      | 711/2223 [04:47<13:22,  1.88it/s][A[A

 32%|███▏      | 712/2223 [04:47<11:28,  2.20it/s][A[A

 32%|███▏      | 713/2223 [04:49<19:28,  1.29it/s][A[A

 32%|███▏      | 715/2223 [04:49<16:03,  1.56it/s][A[A

 32%|███▏      | 717/2223 [04:50<15:57,  1.57it/s][A[A

 32%|███▏      | 719/2223 [04:51<14:44,  1.70it/s][A[A

 32%|███▏     

 48%|████▊     | 1059/2223 [07:30<11:25,  1.70it/s][A[A

 48%|████▊     | 1062/2223 [07:31<08:51,  2.18it/s][A[A

 48%|████▊     | 1063/2223 [07:32<13:33,  1.43it/s][A[A

 48%|████▊     | 1066/2223 [07:32<10:10,  1.90it/s][A[A

 48%|████▊     | 1067/2223 [07:34<15:20,  1.26it/s][A[A

 48%|████▊     | 1071/2223 [07:35<12:41,  1.51it/s][A[A

 48%|████▊     | 1074/2223 [07:35<09:21,  2.05it/s][A[A

 48%|████▊     | 1075/2223 [07:36<12:31,  1.53it/s][A[A

 48%|████▊     | 1078/2223 [07:37<10:34,  1.80it/s][A[A

 49%|████▊     | 1079/2223 [07:39<15:20,  1.24it/s][A[A

 49%|████▊     | 1082/2223 [07:40<12:39,  1.50it/s][A[A

 49%|████▊     | 1083/2223 [07:41<16:16,  1.17it/s][A[A

 49%|████▉     | 1086/2223 [07:41<11:49,  1.60it/s][A[A

 49%|████▉     | 1087/2223 [07:42<14:58,  1.26it/s][A[A

 49%|████▉     | 1090/2223 [07:44<13:15,  1.42it/s][A[A

 49%|████▉     | 1091/2223 [07:45<16:23,  1.15it/s][A[A

 49%|████▉     | 1094/2223 [07:46<13:40,  1.38it/s][A[

 63%|██████▎   | 1396/2223 [10:10<07:27,  1.85it/s][A[A

 63%|██████▎   | 1400/2223 [10:12<07:25,  1.85it/s][A[A

 63%|██████▎   | 1404/2223 [10:14<07:16,  1.88it/s][A[A

 63%|██████▎   | 1408/2223 [10:16<07:10,  1.89it/s][A[A

 64%|██████▎   | 1412/2223 [10:18<06:49,  1.98it/s][A[A

 64%|██████▎   | 1416/2223 [10:19<06:34,  2.04it/s][A[A

 64%|██████▍   | 1420/2223 [10:21<06:37,  2.02it/s][A[A

 64%|██████▍   | 1424/2223 [10:23<06:12,  2.15it/s][A[A

 64%|██████▍   | 1428/2223 [10:24<05:35,  2.37it/s][A[A

 64%|██████▍   | 1432/2223 [10:26<05:30,  2.39it/s][A[A

 65%|██████▍   | 1436/2223 [10:28<05:36,  2.34it/s][A[A

 65%|██████▍   | 1440/2223 [10:29<05:14,  2.49it/s][A[A

 65%|██████▍   | 1442/2223 [10:30<04:29,  2.90it/s][A[A

 65%|██████▍   | 1444/2223 [10:31<05:48,  2.23it/s][A[A

 65%|██████▌   | 1446/2223 [10:32<05:34,  2.32it/s][A[A

 65%|██████▌   | 1448/2223 [10:32<05:15,  2.45it/s][A[A

 65%|██████▌   | 1450/2223 [10:33<05:11,  2.48it/s][A[

 80%|███████▉  | 1773/2223 [13:04<09:09,  1.22s/it][A[A

 80%|███████▉  | 1776/2223 [13:04<06:32,  1.14it/s][A[A

 80%|███████▉  | 1777/2223 [13:06<07:44,  1.04s/it][A[A

 80%|███████▉  | 1778/2223 [13:06<06:00,  1.23it/s][A[A

 80%|████████  | 1781/2223 [13:07<05:09,  1.43it/s][A[A

 80%|████████  | 1784/2223 [13:08<03:49,  1.91it/s][A[A

 80%|████████  | 1785/2223 [13:10<07:42,  1.06s/it][A[A

 80%|████████  | 1789/2223 [13:11<06:04,  1.19it/s][A[A

 81%|████████  | 1793/2223 [13:13<05:15,  1.36it/s][A[A

 81%|████████  | 1796/2223 [13:14<03:53,  1.83it/s][A[A

 81%|████████  | 1797/2223 [13:15<06:11,  1.15it/s][A[A

 81%|████████  | 1801/2223 [13:17<04:59,  1.41it/s][A[A

 81%|████████  | 1802/2223 [13:17<03:59,  1.76it/s][A[A

 81%|████████  | 1805/2223 [13:19<03:54,  1.78it/s][A[A

 81%|████████▏ | 1809/2223 [13:20<03:42,  1.86it/s][A[A

 82%|████████▏ | 1813/2223 [13:22<03:33,  1.92it/s][A[A

 82%|████████▏ | 1817/2223 [13:24<03:14,  2.08it/s][A[

 94%|█████████▎| 2079/2223 [15:26<02:42,  1.13s/it][A[A

 94%|█████████▎| 2083/2223 [15:27<02:07,  1.10it/s][A[A

 94%|█████████▍| 2087/2223 [15:29<01:42,  1.32it/s][A[A

 94%|█████████▍| 2090/2223 [15:30<01:19,  1.67it/s][A[A

 94%|█████████▍| 2091/2223 [15:32<02:24,  1.09s/it][A[A

 94%|█████████▍| 2094/2223 [15:33<01:51,  1.15it/s][A[A

 94%|█████████▍| 2095/2223 [15:33<01:42,  1.24it/s][A[A

 94%|█████████▍| 2098/2223 [15:35<01:24,  1.49it/s][A[A

 94%|█████████▍| 2099/2223 [15:36<01:50,  1.12it/s][A[A

 95%|█████████▍| 2102/2223 [15:37<01:23,  1.44it/s][A[A

 95%|█████████▍| 2103/2223 [15:38<02:00,  1.00s/it][A[A

 95%|█████████▍| 2107/2223 [15:40<01:36,  1.20it/s][A[A

 95%|█████████▍| 2111/2223 [15:42<01:17,  1.44it/s][A[A

 95%|█████████▌| 2114/2223 [15:42<00:56,  1.93it/s][A[A

 95%|█████████▌| 2115/2223 [15:44<01:34,  1.14it/s][A[A

 95%|█████████▌| 2119/2223 [15:45<01:18,  1.32it/s][A[A

 96%|█████████▌| 2123/2223 [15:48<01:08,  1.46it/s][A[

RuntimeError: Expected hidden[0] size (2, 18, 1024), got (2, 32, 1024)

In [None]:
write_pred_csv('/media/ubuntu/Data/project/lyft/l5kit-1.0.6/examples/agent_motion_prediction/submission_ENDC_lstm_10k.csv',
               timestamps=np.concatenate(timestamps),
               track_ids=np.concatenate(agent_ids),
               coords=np.concatenate(future_coords_offsets_pd))

write_pred_csv('/kaggle/working/submission.csv',
               timestamps=np.concatenate(timestamps),
               track_ids=np.concatenate(agent_ids),
               coords=np.concatenate(future_coords_offsets_pd))