In [45]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


# LSTM Experiment

In [None]:
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader
from torch_geometric.data import Batch

In [47]:
# Read the 'data' array out of each .npz archive.
# np.load on an .npz returns a lazy NpzFile that keeps the underlying file
# open; the `with` block closes it once the array has been materialised.
# The archive names stay bound afterwards (as closed handles); only the
# extracted arrays are meant to be used downstream.
with np.load('../../train.npz') as train_npz:
    train_data = train_npz['data']
with np.load('../../test_input.npz') as test_npz:
    test_data = test_npz['data']

In [48]:
# Sanity-check the raw array shapes before slicing.
print(train_data.shape, test_data.shape)

# One-off input/target split kept around for later cells.
# NOTE(review): axis meanings are assumed to be (sample, agent, timestep, feature) — confirm.
# Target: agent index 0 only, axis-2 positions from 50 onward, first two features.
Y_train = train_data[:, 0, 50:, :2]
# Input: every agent, first 50 positions along axis 2, all features.
X_train = train_data[:, :, :50, :]

(10000, 50, 110, 6) (2100, 50, 50, 6)


In [49]:
import sys

# Extend the import path with the parent directory so the helper modules
# imported below can be resolved. The membership check keeps sys.path from
# accumulating duplicate entries each time this cell is re-executed.
if '..' not in sys.path:
    sys.path.append('..')
from TrajectoryDataset import TrajectoryDatasetTrain, TrajectoryDatasetTest
from utils import train_model

# Same for the directory two levels up, which is expected to contain the
# `models` package.
if '../..' not in sys.path:
    sys.path.append('../..')
from models.lstm import LSTMModel

In [50]:
# Seed the torch and NumPy global RNGs up front so the run is repeatable.
# NOTE(review): the two libraries are seeded with different values (251 / 42);
# harmless, but confirm it is intentional.
torch.manual_seed(251)
np.random.seed(42)

# Passed to both dataset wrappers below.
# NOTE(review): presumably a coordinate normalisation factor — confirm in TrajectoryDataset.
scale = 5.0

# Hold out the last 5% of entries along axis 0 for validation.
# The split is a contiguous tail slice, not a shuffled sample.
N = len(train_data)
val_size = int(0.05 * N)
train_size = N - val_size

# Augmentation is switched on for the training split only.
train_dataset = TrajectoryDatasetTrain(train_data[:train_size], scale=scale, augment=True)
val_dataset = TrajectoryDatasetTrain(train_data[train_size:], scale=scale, augment=False)

# Batch.from_data_list is handed over directly as the collate function: the
# previous `lambda x: Batch.from_data_list(x)` wrapper added nothing, and a
# lambda cannot be pickled if worker processes (num_workers > 0) are enabled
# under a spawn start method.
train_dataloader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=Batch.from_data_list)
val_dataloader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=Batch.from_data_list)

# Pick the fastest available backend: Apple-Silicon MPS first, then CUDA,
# otherwise fall back to the CPU. Every branch reports its choice so the cell
# output always records which device the run used.
if torch.backends.mps.is_available():
    device = torch.device('mps')
    print("Using Apple Silicon GPU")
elif torch.cuda.is_available():
    device = torch.device('cuda')
    print("Using CUDA GPU")
else:
    device = torch.device('cpu')
    print("Using CPU")

Using Apple Silicon GPU


In [None]:
# Network with its default constructor arguments.
model = LSTMModel()

# Mean-squared-error objective.
criterion = nn.MSELoss()

# Adam with L2 weight decay.
optimizer = optim.Adam(
    model.parameters(),
    lr=1e-3,
    weight_decay=1e-4,
)

# Quarter the learning rate every 20 epochs; other schedulers are worth trying.
scheduler = optim.lr_scheduler.StepLR(
    optimizer,
    step_size=20,
    gamma=0.25,
)

# Passed to train_model as its early-stopping patience.
# NOTE(review): presumably counted in epochs without validation improvement — confirm in utils.train_model.
early_stopping_patience = 10

In [None]:
# Put the network on the selected device before training starts.
model = model.to(device)

# Launch training. Arguments are positional and must stay in this order:
# model, train loader, validation loader, device, optimizer, loss,
# LR scheduler, early-stopping patience, checkpoint name.
# NOTE(review): the checkpoint is later read back as "lstm_weights.pt", so
# train_model presumably appends the ".pt" suffix — confirm in utils.
train_model(
    model, train_dataloader, val_dataloader,
    device,
    optimizer, criterion, scheduler,
    early_stopping_patience,
    "lstm_weights",
)

Epoch:   1%|          | 1/100 [00:06<09:55,  6.01s/epoch]

Epoch 000 | Learning rate 0.001000 | train normalized MSE   1.7489 | val normalized MSE   0.5056, | val MAE   2.7305 | val MSE  24.7755


Epoch:   2%|▏         | 2/100 [00:11<09:00,  5.52s/epoch]

Epoch 001 | Learning rate 0.001000 | train normalized MSE   0.4617 | val normalized MSE   0.3849, | val MAE   2.2651 | val MSE  18.8610


Epoch:   3%|▎         | 3/100 [00:16<08:37,  5.33s/epoch]

Epoch 002 | Learning rate 0.001000 | train normalized MSE   0.3937 | val normalized MSE   0.3558, | val MAE   2.1808 | val MSE  17.4329


Epoch:   4%|▍         | 4/100 [00:21<08:31,  5.33s/epoch]

Epoch 003 | Learning rate 0.001000 | train normalized MSE   0.3623 | val normalized MSE   0.3535, | val MAE   2.2102 | val MSE  17.3198


Epoch:   5%|▌         | 5/100 [00:27<08:35,  5.43s/epoch]

Epoch 004 | Learning rate 0.001000 | train normalized MSE   0.3435 | val normalized MSE   0.3201, | val MAE   2.1834 | val MSE  15.6873


Epoch:   6%|▌         | 6/100 [00:32<08:26,  5.38s/epoch]

Epoch 005 | Learning rate 0.001000 | train normalized MSE   0.3215 | val normalized MSE   0.3251, | val MAE   2.1465 | val MSE  15.9319


Epoch:   7%|▋         | 7/100 [00:37<08:13,  5.31s/epoch]

Epoch 006 | Learning rate 0.001000 | train normalized MSE   0.3207 | val normalized MSE   0.2748, | val MAE   1.8676 | val MSE  13.4635


Epoch:   8%|▊         | 8/100 [00:42<07:59,  5.21s/epoch]

Epoch 007 | Learning rate 0.001000 | train normalized MSE   0.2997 | val normalized MSE   0.2975, | val MAE   1.9553 | val MSE  14.5796


Epoch:   9%|▉         | 9/100 [00:47<07:45,  5.12s/epoch]

Epoch 008 | Learning rate 0.001000 | train normalized MSE   0.2945 | val normalized MSE   0.2681, | val MAE   1.8193 | val MSE  13.1357


Epoch:  10%|█         | 10/100 [00:52<07:29,  4.99s/epoch]

Epoch 009 | Learning rate 0.001000 | train normalized MSE   0.2808 | val normalized MSE   0.2605, | val MAE   1.8214 | val MSE  12.7623


Epoch:  11%|█         | 11/100 [00:57<07:25,  5.00s/epoch]

Epoch 010 | Learning rate 0.001000 | train normalized MSE   0.2774 | val normalized MSE   0.2443, | val MAE   1.7347 | val MSE  11.9711


Epoch:  12%|█▏        | 12/100 [01:02<07:37,  5.20s/epoch]

Epoch 011 | Learning rate 0.001000 | train normalized MSE   0.2679 | val normalized MSE   0.2548, | val MAE   1.7906 | val MSE  12.4854


Epoch:  13%|█▎        | 13/100 [01:08<07:33,  5.21s/epoch]

Epoch 012 | Learning rate 0.001000 | train normalized MSE   0.2588 | val normalized MSE   0.2337, | val MAE   1.6402 | val MSE  11.4495


Epoch:  14%|█▍        | 14/100 [01:13<07:29,  5.23s/epoch]

Epoch 013 | Learning rate 0.001000 | train normalized MSE   0.2553 | val normalized MSE   0.2317, | val MAE   1.7239 | val MSE  11.3550


Epoch:  15%|█▌        | 15/100 [01:18<07:21,  5.20s/epoch]

Epoch 014 | Learning rate 0.001000 | train normalized MSE   0.2510 | val normalized MSE   0.2258, | val MAE   1.6930 | val MSE  11.0651


Epoch:  16%|█▌        | 16/100 [01:23<07:20,  5.24s/epoch]

Epoch 015 | Learning rate 0.001000 | train normalized MSE   0.2448 | val normalized MSE   0.2448, | val MAE   1.7998 | val MSE  11.9942


Epoch:  17%|█▋        | 17/100 [01:28<07:09,  5.18s/epoch]

Epoch 016 | Learning rate 0.001000 | train normalized MSE   0.2420 | val normalized MSE   0.2188, | val MAE   1.7260 | val MSE  10.7194


Epoch:  18%|█▊        | 18/100 [01:34<07:03,  5.16s/epoch]

Epoch 017 | Learning rate 0.001000 | train normalized MSE   0.2397 | val normalized MSE   0.2163, | val MAE   1.5860 | val MSE  10.5968


Epoch:  19%|█▉        | 19/100 [01:39<07:08,  5.29s/epoch]

Epoch 018 | Learning rate 0.001000 | train normalized MSE   0.2343 | val normalized MSE   0.2176, | val MAE   1.5942 | val MSE  10.6631


Epoch:  20%|██        | 20/100 [01:45<07:07,  5.34s/epoch]

Epoch 019 | Learning rate 0.000250 | train normalized MSE   0.2333 | val normalized MSE   0.2332, | val MAE   1.7780 | val MSE  11.4265


Epoch:  21%|██        | 21/100 [01:50<06:57,  5.28s/epoch]

Epoch 020 | Learning rate 0.000250 | train normalized MSE   0.2133 | val normalized MSE   0.1992, | val MAE   1.4355 | val MSE   9.7596


Epoch:  22%|██▏       | 22/100 [01:55<06:40,  5.14s/epoch]

Epoch 021 | Learning rate 0.000250 | train normalized MSE   0.2079 | val normalized MSE   0.1996, | val MAE   1.5475 | val MSE   9.7815


Epoch:  23%|██▎       | 23/100 [02:00<06:33,  5.12s/epoch]

Epoch 022 | Learning rate 0.000250 | train normalized MSE   0.2049 | val normalized MSE   0.1992, | val MAE   1.4484 | val MSE   9.7600


Epoch:  24%|██▍       | 24/100 [02:05<06:28,  5.11s/epoch]

Epoch 023 | Learning rate 0.000250 | train normalized MSE   0.2034 | val normalized MSE   0.1948, | val MAE   1.4719 | val MSE   9.5458


Epoch:  25%|██▌       | 25/100 [02:10<06:19,  5.06s/epoch]

Epoch 024 | Learning rate 0.000250 | train normalized MSE   0.2024 | val normalized MSE   0.1997, | val MAE   1.4936 | val MSE   9.7877


Epoch:  26%|██▌       | 26/100 [02:14<06:07,  4.96s/epoch]

Epoch 025 | Learning rate 0.000250 | train normalized MSE   0.2027 | val normalized MSE   0.1986, | val MAE   1.5071 | val MSE   9.7322


Epoch:  27%|██▋       | 27/100 [02:19<05:59,  4.92s/epoch]

Epoch 026 | Learning rate 0.000250 | train normalized MSE   0.2022 | val normalized MSE   0.1986, | val MAE   1.4908 | val MSE   9.7303


Epoch:  28%|██▊       | 28/100 [02:25<06:01,  5.02s/epoch]

Epoch 027 | Learning rate 0.000250 | train normalized MSE   0.2004 | val normalized MSE   0.1992, | val MAE   1.5209 | val MSE   9.7625


Epoch:  29%|██▉       | 29/100 [02:30<05:59,  5.07s/epoch]

Epoch 028 | Learning rate 0.000250 | train normalized MSE   0.2016 | val normalized MSE   0.1953, | val MAE   1.4961 | val MSE   9.5679


Epoch:  30%|███       | 30/100 [02:35<05:57,  5.10s/epoch]

Epoch 029 | Learning rate 0.000250 | train normalized MSE   0.2012 | val normalized MSE   0.1934, | val MAE   1.5054 | val MSE   9.4750


Epoch:  31%|███       | 31/100 [02:40<05:56,  5.16s/epoch]

Epoch 030 | Learning rate 0.000250 | train normalized MSE   0.1990 | val normalized MSE   0.1993, | val MAE   1.4759 | val MSE   9.7672


Epoch:  32%|███▏      | 32/100 [02:45<05:49,  5.14s/epoch]

Epoch 031 | Learning rate 0.000250 | train normalized MSE   0.1992 | val normalized MSE   0.1929, | val MAE   1.4940 | val MSE   9.4510


Epoch:  33%|███▎      | 33/100 [02:50<05:36,  5.02s/epoch]

Epoch 032 | Learning rate 0.000250 | train normalized MSE   0.1991 | val normalized MSE   0.1948, | val MAE   1.5074 | val MSE   9.5441


Epoch:  34%|███▍      | 34/100 [02:55<05:21,  4.88s/epoch]

Epoch 033 | Learning rate 0.000250 | train normalized MSE   0.1970 | val normalized MSE   0.2001, | val MAE   1.5127 | val MSE   9.8039


Epoch:  35%|███▌      | 35/100 [02:59<05:14,  4.85s/epoch]

Epoch 034 | Learning rate 0.000250 | train normalized MSE   0.1977 | val normalized MSE   0.1899, | val MAE   1.4772 | val MSE   9.3072


Epoch:  36%|███▌      | 36/100 [03:05<05:18,  4.97s/epoch]

Epoch 035 | Learning rate 0.000250 | train normalized MSE   0.1954 | val normalized MSE   0.1962, | val MAE   1.4668 | val MSE   9.6135


Epoch:  37%|███▋      | 37/100 [03:10<05:19,  5.07s/epoch]

Epoch 036 | Learning rate 0.000250 | train normalized MSE   0.1975 | val normalized MSE   0.1973, | val MAE   1.4932 | val MSE   9.6688


Epoch:  38%|███▊      | 38/100 [03:15<05:12,  5.04s/epoch]

Epoch 037 | Learning rate 0.000250 | train normalized MSE   0.1961 | val normalized MSE   0.1962, | val MAE   1.5241 | val MSE   9.6159


Epoch:  39%|███▉      | 39/100 [03:20<05:08,  5.06s/epoch]

Epoch 038 | Learning rate 0.000250 | train normalized MSE   0.1958 | val normalized MSE   0.1931, | val MAE   1.4825 | val MSE   9.4618


Epoch:  40%|████      | 40/100 [03:25<04:59,  5.00s/epoch]

Epoch 039 | Learning rate 0.000063 | train normalized MSE   0.1941 | val normalized MSE   0.1963, | val MAE   1.4632 | val MSE   9.6201


Epoch:  41%|████      | 41/100 [03:30<04:50,  4.93s/epoch]

Epoch 040 | Learning rate 0.000063 | train normalized MSE   0.1892 | val normalized MSE   0.1916, | val MAE   1.4467 | val MSE   9.3890


Epoch:  42%|████▏     | 42/100 [03:34<04:43,  4.88s/epoch]

Epoch 041 | Learning rate 0.000063 | train normalized MSE   0.1879 | val normalized MSE   0.1895, | val MAE   1.4206 | val MSE   9.2842


Epoch:  43%|████▎     | 43/100 [03:39<04:37,  4.86s/epoch]

Epoch 042 | Learning rate 0.000063 | train normalized MSE   0.1878 | val normalized MSE   0.1881, | val MAE   1.4375 | val MSE   9.2174


Epoch:  44%|████▍     | 44/100 [03:44<04:30,  4.84s/epoch]

Epoch 043 | Learning rate 0.000063 | train normalized MSE   0.1858 | val normalized MSE   0.1894, | val MAE   1.4238 | val MSE   9.2830


Epoch:  45%|████▌     | 45/100 [03:49<04:30,  4.93s/epoch]

Epoch 044 | Learning rate 0.000063 | train normalized MSE   0.1878 | val normalized MSE   0.1902, | val MAE   1.4402 | val MSE   9.3183


Epoch:  46%|████▌     | 46/100 [03:55<04:36,  5.11s/epoch]

Epoch 045 | Learning rate 0.000063 | train normalized MSE   0.1861 | val normalized MSE   0.1879, | val MAE   1.4143 | val MSE   9.2094


Epoch:  47%|████▋     | 47/100 [04:00<04:37,  5.24s/epoch]

Epoch 046 | Learning rate 0.000063 | train normalized MSE   0.1877 | val normalized MSE   0.1874, | val MAE   1.4183 | val MSE   9.1819


Epoch:  48%|████▊     | 48/100 [04:06<04:35,  5.29s/epoch]

Epoch 047 | Learning rate 0.000063 | train normalized MSE   0.1855 | val normalized MSE   0.1895, | val MAE   1.4235 | val MSE   9.2843


Epoch:  49%|████▉     | 49/100 [04:11<04:27,  5.24s/epoch]

Epoch 048 | Learning rate 0.000063 | train normalized MSE   0.1858 | val normalized MSE   0.1899, | val MAE   1.4518 | val MSE   9.3054


Epoch:  50%|█████     | 50/100 [04:16<04:19,  5.19s/epoch]

Epoch 049 | Learning rate 0.000063 | train normalized MSE   0.1857 | val normalized MSE   0.1870, | val MAE   1.4479 | val MSE   9.1608


Epoch:  51%|█████     | 51/100 [04:21<04:16,  5.23s/epoch]

Epoch 050 | Learning rate 0.000063 | train normalized MSE   0.1849 | val normalized MSE   0.1878, | val MAE   1.4181 | val MSE   9.2014


Epoch:  52%|█████▏    | 52/100 [04:26<04:08,  5.17s/epoch]

Epoch 051 | Learning rate 0.000063 | train normalized MSE   0.1857 | val normalized MSE   0.1868, | val MAE   1.4213 | val MSE   9.1533


Epoch:  53%|█████▎    | 53/100 [04:31<03:59,  5.09s/epoch]

Epoch 052 | Learning rate 0.000063 | train normalized MSE   0.1853 | val normalized MSE   0.1872, | val MAE   1.4172 | val MSE   9.1750


Epoch:  54%|█████▍    | 54/100 [04:36<03:51,  5.03s/epoch]

Epoch 053 | Learning rate 0.000063 | train normalized MSE   0.1850 | val normalized MSE   0.1870, | val MAE   1.4279 | val MSE   9.1621


Epoch:  55%|█████▌    | 55/100 [04:41<03:43,  4.97s/epoch]

Epoch 054 | Learning rate 0.000063 | train normalized MSE   0.1840 | val normalized MSE   0.1878, | val MAE   1.4356 | val MSE   9.2008


Epoch:  56%|█████▌    | 56/100 [04:46<03:38,  4.96s/epoch]

Epoch 055 | Learning rate 0.000063 | train normalized MSE   0.1836 | val normalized MSE   0.1878, | val MAE   1.4120 | val MSE   9.2003


Epoch:  57%|█████▋    | 57/100 [04:50<03:30,  4.89s/epoch]

Epoch 056 | Learning rate 0.000063 | train normalized MSE   0.1840 | val normalized MSE   0.1896, | val MAE   1.4294 | val MSE   9.2924


Epoch:  58%|█████▊    | 58/100 [04:55<03:23,  4.85s/epoch]

Epoch 057 | Learning rate 0.000063 | train normalized MSE   0.1844 | val normalized MSE   0.1882, | val MAE   1.4218 | val MSE   9.2214


Epoch:  59%|█████▉    | 59/100 [05:00<03:19,  4.88s/epoch]

Epoch 058 | Learning rate 0.000063 | train normalized MSE   0.1848 | val normalized MSE   0.1884, | val MAE   1.4090 | val MSE   9.2333


Epoch:  59%|█████▉    | 59/100 [05:05<03:32,  5.18s/epoch]

Epoch 059 | Learning rate 0.000016 | train normalized MSE   0.1839 | val normalized MSE   0.1885, | val MAE   1.4240 | val MSE   9.2383
Early stop!





In [None]:
# Wrap the test inputs with the same scale used for training.
test_dataset = TrajectoryDatasetTest(test_data, scale=scale)
# NOTE(review): test_loader is not used by anything visible in this cell
# (predictions are generated through test_dataset below); kept in case a later
# cell relies on it.
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False,
                         collate_fn=Batch.from_data_list)

# Restore the weights from the "lstm_weights.pt" checkpoint.
# NOTE(review): presumably this file is written by train_model — confirm in utils.
# map_location remaps the stored tensors onto the current device, so a
# checkpoint written on MPS/CUDA still loads on a machine without that backend.
lstm_weights = torch.load("lstm_weights.pt", map_location=device)
model.load_state_dict(lstm_weights)

# also save entire model
# (this pickles the whole module object, so loading it later requires the
# LSTMModel class to be importable under the same module path)
torch.save(model, "full_lstm_model.pt")

# Switch to inference mode so layers such as dropout behave deterministically
# while the submission predictions are produced.
model.eval()
test_dataset.generate_submission_predictions(model, device, output_file_name='lstm_submission')