<a href="https://colab.research.google.com/github/AnilOsmanTur/Spatio-Temporal-Event-Prediction/blob/main/Prediction_with_RNN.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Prediction wtih LSTM model

In [1]:
! pip install pytorch-lightning --quiet
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.data import random_split
import pytorch_lightning as pl
import numpy as np

pl.seed_everything(42) # for reproducablity
print('imports done')

[K     |████████████████████████████████| 808kB 23.2MB/s 
[K     |████████████████████████████████| 112kB 55.2MB/s 
[K     |████████████████████████████████| 276kB 50.8MB/s 
[K     |████████████████████████████████| 645kB 58.2MB/s 
[K     |████████████████████████████████| 829kB 56.2MB/s 
[K     |████████████████████████████████| 1.3MB 56.1MB/s 
[K     |████████████████████████████████| 143kB 56.8MB/s 
[K     |████████████████████████████████| 296kB 56.4MB/s 
[?25h  Building wheel for future (setup.py) ... [?25l[?25hdone


Global seed set to 42


imports done


## Data Loading and Spliting

In [2]:
data = np.load('data.npy')
n_sample = data.shape[0]
flat_data = data.reshape(n_sample, -1)
print('data shape after flattening', flat_data.shape)
n_train = int(0.8 * n_sample )
n_test = n_sample - n_train
train_data = flat_data[:n_train]
test_data = flat_data[n_train:]
print('Training data split', train_data.shape)
print('Testing data split', test_data.shape)

data shape after flattening (4500, 9)
Training data split (3600, 9)
Testing data split (900, 9)


In [3]:
# data generator function to split and connect small sequences
def data_generator(data, seq_size=5):
  n_sample = data.shape[0]
  dataset = []
  labels = []
  for end in range(seq_size,n_sample):
    start = end - seq_size
    dataset.append(data[start:end])
    labels.append(data[end])
  return np.array(dataset).astype(np.float32), np.array(labels).astype(np.float32)
print('Done')

Done


In [4]:
train_set, train_labels = data_generator(train_data, seq_size=5)
test_set, test_labels = data_generator(test_data, seq_size=5)
print(train_set.shape, train_labels.shape)
print(test_set.shape, test_labels.shape)

(3595, 5, 9) (3595, 9)
(895, 5, 9) (895, 9)


## Model Creation

In [30]:
class SpatioTempPredictor(pl.LightningModule):
  def __init__(self, seq_size=5):
    super().__init__()
    self.seq_size=seq_size
    self.lstm = nn.LSTM(9, 32, 2)
    self.predictor = nn.Sequential(
      nn.Linear(32, 16),
      nn.ReLU(),
      nn.Linear(16, 9),
      nn.Sigmoid()
      )
    self.euclidian_dist = nn.PairwiseDistance(p=2)

  def hamming_dist(self, a, b):
    return torch.sum((a!=b).float())

  def forward(self, x):
    b_size = x.shape[0]
    x = x.permute(1,0,2)
    h0 = torch.randn(2, b_size, 32).to(x.device)
    c0 = torch.randn(2, b_size, 32).to(x.device)
    output, (hn, cn) = self.lstm(x, (h0, c0))
    prediction = self.predictor(output[-1])
    return prediction

  def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=1e-3)
    return optimizer

  def training_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    loss = F.mse_loss(y_hat, y)
    return loss

  def validation_step(self, batch, batch_idx):
    x, y = batch
    y_hat = self(x)
    loss = F.mse_loss(y_hat, y)
    e_dist = self.euclidian_dist(y_hat, y).sum()
    h_dist = self.hamming_dist(y_hat, y)

    # Calling self.log will surface up scalars for you in TensorBoard
    self.log('val_loss', loss, prog_bar=True)
    self.log('val_dist', e_dist, prog_bar=True)
    self.log('val_hdist', h_dist, prog_bar=True)

    return loss

  def test_step(self, batch, batch_idx):
    # Here we just reuse the validation_step for testing
    return self.validation_step(batch, batch_idx)


# data
dataset = list(zip(train_set, train_labels))
n_train = int(len(dataset)*0.9)
n_val = len(dataset) - n_train
train_split, val_split = torch.utils.data.random_split(dataset, [n_train, n_val])


train_loader = DataLoader(train_split, batch_size=16)
val_loader = DataLoader(val_split, batch_size=16)
test_loader = DataLoader(list(zip(test_set, test_labels)), batch_size=16)

# model
model = SpatioTempPredictor(seq_size=5)
print('Done')

Done


In [31]:
# training
trainer = pl.Trainer(gpus=1, max_epochs=50, progress_bar_refresh_rate=20)
trainer.fit(model, train_loader, val_loader)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name           | Type             | Params
----------------------------------------------------
0 | lstm           | LSTM             | 14.0 K
1 | predictor      | Sequential       | 681   
2 | euclidian_dist | PairwiseDistance | 0     
----------------------------------------------------
14.6 K    Trainable params
0         Non-trainable params
14.6 K    Total params
0.059     Total estimated model params size (MB)


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

Global seed set to 42




HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…




In [32]:
trainer.test(model, train_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'val_dist': 3.21528697013855,
 'val_hdist': 143.89149475097656,
 'val_loss': 0.02280138060450554}
--------------------------------------------------------------------------------


[{'val_dist': 3.21528697013855,
  'val_hdist': 143.89149475097656,
  'val_loss': 0.02280138060450554}]

In [34]:
# evaluation
trainer.test(model, test_loader)

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Testing', layout=Layout(flex='2'), max=…


--------------------------------------------------------------------------------
DATALOADER:0 TEST RESULTS
{'val_dist': 3.2082228660583496,
 'val_hdist': 143.8491668701172,
 'val_loss': 0.02240746095776558}
--------------------------------------------------------------------------------


[{'val_dist': 3.2082228660583496,
  'val_hdist': 143.8491668701172,
  'val_loss': 0.02240746095776558}]

In [None]:
%load_ext tensorboard
%tensorboard --logdir lightning_logs/