In [1]:
import gc
import os

import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.preprocessing import StandardScaler
from torch.utils.data import TensorDataset, DataLoader, Dataset
from tqdm import tqdm

from read_data import getData
from read_data import LargeDataset




In [2]:
# Load the training and test arrays through the project helper from read_data
# ("data" is presumably the directory that holds the raw files -- confirm in read_data).
# NOTE(review): in the recorded output the message printed during this call reports
# the same shape for both arrays, whereas the tuple displayed as the cell result shows
# a different (smaller) test array. Trust the displayed tuple; the message printed by
# getData presumably reuses the training shape and should be fixed in read_data.
traindata, testData = getData("data")
# Show both shapes as the cell result.
traindata.shape, testData.shape

Training Data's shape is (10000, 50, 110, 6) and Test Data's is (10000, 50, 110, 6)


((10000, 50, 110, 6), (2100, 50, 50, 6))

In [3]:
# Per-feature normalisation statistics: reduce over every axis except the last
# one, so each entry of the last axis gets its own mean and standard deviation.
train_mean = traindata.mean(axis=(0, 1, 2))
train_std = traindata.std(axis=(0, 1, 2))
# A feature with zero spread would cause a division by zero when normalising;
# give it a unit scale instead.
train_std[train_std == 0] = 1.0

In [4]:
def createDataset(data, window_size = 40, forecast_horizon = 10, stride = 1):
    """Cut sliding (history, future) windows out of a 4-D array.

    Parameters
    ----------
    data : array-like with 4 axes
        Indexed as ``data[sample, row, step, feature]``. Windows slide along
        axis 2 (``step``). Axis 1 presumably enumerates agents/entities --
        confirm against the loader in ``read_data``.
    window_size : int, default 40
        Number of consecutive steps in each input window.
    forecast_horizon : int, default 10
        Number of steps directly after the input window used as the target.
    stride : int, default 1
        Step between consecutive window start positions. The default
        reproduces the original behaviour (every possible start position).

    Returns
    -------
    X : ndarray, shape (n_samples * n_windows, data.shape[1], window_size, n_features)
        Input windows containing every row of axis 1.
    y : ndarray, shape (n_samples * n_windows, forecast_horizon, min(2, n_features))
        Targets: the first two features of row 0 on axis 1 over the horizon.

    Windows are ordered sample by sample, and by increasing start position
    within a sample. When no window fits (or there are no samples) the
    returned arrays are empty but still carry the trailing dimensions above,
    so downstream shape-dependent code keeps working.

    Raises
    ------
    ValueError
        If ``data`` is not 4-D or any of the size arguments is smaller than 1.
    """
    data = np.asarray(data)
    if data.ndim != 4:
        raise ValueError(f"data must have 4 dimensions, got {data.ndim}")
    if window_size < 1 or forecast_horizon < 1 or stride < 1:
        raise ValueError("window_size, forecast_horizon and stride must all be >= 1")

    n_samples, n_rows, n_steps, n_features = data.shape
    n_targets = min(2, n_features)  # mirrors the `:2` slice on the feature axis

    # An empty range (series shorter than window + horizon) yields zero windows.
    starts = range(0, n_steps - window_size - forecast_horizon + 1, stride)
    n_windows = len(starts)

    # Preallocate once and fill one window offset at a time for all samples,
    # instead of appending one small view per (sample, offset) to Python lists.
    X = np.empty((n_samples, n_windows, n_rows, window_size, n_features), dtype=data.dtype)
    y = np.empty((n_samples, n_windows, forecast_horizon, n_targets), dtype=data.dtype)

    for k, t in enumerate(starts):
        split = t + window_size
        X[:, k] = data[:, :, t:split, :]
        y[:, k] = data[:, 0, split:split + forecast_horizon, :n_targets]

    # Merging the (sample, window) axes keeps the sample-major ordering.
    return (
        X.reshape(n_samples * n_windows, n_rows, window_size, n_features),
        y.reshape(n_samples * n_windows, forecast_horizon, n_targets),
    )


# Build the sliding-window dataset with createDataset's default window sizes.
# NOTE(review): the shape displayed below for X is (610000, 50, 40, 6), i.e. about
# 7.3e9 elements -- tens of gigabytes depending on the dtype. Consider building the
# windows lazily inside the Dataset instead of materialising them all up front.
X, Y = createDataset(traindata)
# Show the resulting shapes as the cell result.
X.shape, Y.shape

((610000, 50, 40, 6), (610000, 10, 2))

In [5]:
# Small subset for quick experiments: the first 10,000 windows and their targets.
# NOTE(review): createDataset emits windows sample by sample, so this slice only
# covers the windows of the first few samples -- it is not a random subset of X.
# Basic slicing returns views, so the full X / Y arrays stay alive in memory.
a, b = X[:10000], Y[:10000]
# Show the subset shapes as the cell result.
a.shape, b.shape

((10000, 50, 40, 6), (10000, 10, 2))

In [6]:
# Select the compute device. Apple's MPS backend stays the first choice (the
# original hard-coded it), but the notebook now also runs on machines without
# MPS by falling back to CUDA and finally to the CPU.
_mps_backend = getattr(torch.backends, "mps", None)  # absent in old torch builds
if _mps_backend is not None and _mps_backend.is_available():
    device = torch.device("mps")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

class SmallNetwork(nn.Module):
    """LSTM encoder followed by pooling and a small per-step MLP head.

    The input may have any number of middle dimensions; they are flattened
    into a single sequence axis before the LSTM, so only the first (batch)
    and last (feature) dimensions are interpreted.

    Parameters (the defaults reproduce the original hard-coded network, with
    the same attribute names and therefore the same state_dict keys)
    ----------
    input_size : int, default 6
        Size of the last input dimension.
    hidden_size : int, default 64
        LSTM hidden size; also the BatchNorm channel count.
    num_layers : int, default 5
        Number of stacked LSTM layers.
    forecast_horizon : int, default 10
        Number of output steps (target length of the adaptive pooling).
    output_size : int, default 2
        Number of values predicted per output step.

    NOTE(review): for a 4-D input [batch, d1, d2, features] the flattening puts
    element (i, j) of the middle axes at sequence position i * d2 + j, and the
    adaptive pooling then averages contiguous chunks of that flattened
    sequence. Output step k therefore summarises a chunk of the flattened
    axis rather than "time step k" -- confirm this is the intended design.
    """

    def __init__(self, input_size = 6, hidden_size = 64, num_layers = 5, forecast_horizon = 10, output_size = 2):
        super().__init__()
        self.lstm = nn.LSTM(input_size = input_size, hidden_size = hidden_size, num_layers = num_layers, batch_first = True)
        self.batch_norm1 = nn.BatchNorm1d(hidden_size)
        self.pool = nn.AdaptiveAvgPool1d(forecast_horizon)
        self.linear1 = nn.Linear(hidden_size, 32)
        self.linear2 = nn.Linear(32, 16)
        self.linear3 = nn.Linear(16, output_size)

    def forward(self, x):
        """Map [batch, ..., input_size] to [batch, forecast_horizon, output_size]."""
        # reshape (not view) so non-contiguous inputs, e.g. permuted or sliced
        # tensors, are accepted instead of raising a RuntimeError.
        x = x.reshape(x.size(0), -1, x.size(-1))  # [batch, seq_len, input_size]

        x, _ = self.lstm(x)  # [batch, seq_len, hidden]; final states are unused
        x = x.permute(0, 2, 1)  # [batch, hidden, seq_len]

        x = self.pool(x)  # [batch, hidden, forecast_horizon]
        x = self.batch_norm1(x)  # normalises over the hidden channels
        x = x.permute(0, 2, 1)  # [batch, forecast_horizon, hidden]

        # The MLP head is applied independently to every output step.
        x = self.linear1(x)
        x = torch.nn.functional.leaky_relu(x, negative_slope=0.01)
        x = self.linear2(x)
        x = torch.nn.functional.leaky_relu(x, negative_slope=0.01)
        x = self.linear3(x)
        return x

# Instantiate the network with freshly initialised weights. It is moved to the
# selected device in the training cell, so the call below is left disabled.
model = SmallNetwork()
# model.to(device)

# Disabled shape smoke test, kept for reference: feeds a tiny random batch
# through the model and prints the output shape.
# test = torch.randn(2, 2, 2, 6)
# out = model(test)
# print(out.shape)

In [7]:
# Wrap the small debugging subset (a, b) for batched, shuffled loading.
# NOTE(review): LargeDataset lives in read_data and is not visible here. It is
# handed the training mean/std, presumably to normalise the inputs -- confirm
# whether the targets are normalised as well: the recorded losses (~1e7) suggest
# the MSE is computed on un-normalised targets.
trainDataset = LargeDataset(a, b, train_mean, train_std) # testing for small dataset a, b
trainDataLoader = DataLoader(trainDataset, batch_size=128, shuffle=True, num_workers=2)

model.to(device)

# Training setup
epochs = 100
lossFn = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=1e-4)

# Checkpointing: create the target directory up front so the first torch.save
# cannot fail (after a full epoch of work) just because the folder is missing.
checkpointDir = "./models"
checkpointEvery = 5
os.makedirs(checkpointDir, exist_ok=True)

for each_epoch in range(epochs):
    model.train()
    runningLoss = 0.0
    loop = tqdm(trainDataLoader, desc=f"Epoch [{each_epoch+1}/{epochs}]")

    for batchX, batchY in loop:
        batchX, batchY = batchX.to(device, non_blocking=True), batchY.to(device, non_blocking=True)
        output = model(batchX)
        loss = lossFn(output, batchY)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        runningLoss += loss.item()

    # Mean of the per-batch losses for this epoch.
    avgLoss = runningLoss / len(trainDataLoader)

    # Save (and report) every `checkpointEvery` epochs, and always after the
    # final epoch: previously the last checkpoint was written at epoch index 95,
    # so the weights of the fully trained model were never stored.
    isLastEpoch = each_epoch == epochs - 1
    if each_epoch % checkpointEvery == 0 or isLastEpoch:
        # File names keep the original 0-based epoch index.
        torch.save(model.state_dict(), os.path.join(checkpointDir, f'small_model_{each_epoch}.pth'))
        print(f"Epoch {each_epoch + 1}, Training Loss: {avgLoss:.4f}")


Epoch [1/100]: 100%|██████████| 79/79 [00:48<00:00,  1.63it/s]


Epoch 1, Training Loss: 11501282.7278


Epoch [2/100]: 100%|██████████| 79/79 [00:46<00:00,  1.70it/s]
Epoch [3/100]: 100%|██████████| 79/79 [00:47<00:00,  1.67it/s]
Epoch [4/100]: 100%|██████████| 79/79 [00:47<00:00,  1.67it/s]
Epoch [5/100]: 100%|██████████| 79/79 [00:49<00:00,  1.60it/s]
Epoch [6/100]: 100%|██████████| 79/79 [00:46<00:00,  1.68it/s]


Epoch 6, Training Loss: 11465956.8101


Epoch [7/100]: 100%|██████████| 79/79 [00:46<00:00,  1.70it/s]
Epoch [8/100]: 100%|██████████| 79/79 [00:45<00:00,  1.73it/s]
Epoch [9/100]: 100%|██████████| 79/79 [00:47<00:00,  1.65it/s]
Epoch [10/100]: 100%|██████████| 79/79 [00:46<00:00,  1.70it/s]
Epoch [11/100]: 100%|██████████| 79/79 [00:47<00:00,  1.67it/s]


Epoch 11, Training Loss: 11334876.1519


Epoch [12/100]: 100%|██████████| 79/79 [00:46<00:00,  1.69it/s]
Epoch [13/100]: 100%|██████████| 79/79 [00:47<00:00,  1.67it/s]
Epoch [14/100]: 100%|██████████| 79/79 [00:47<00:00,  1.66it/s]
Epoch [15/100]: 100%|██████████| 79/79 [00:46<00:00,  1.69it/s]
Epoch [16/100]: 100%|██████████| 79/79 [00:46<00:00,  1.69it/s]


Epoch 16, Training Loss: 11012610.7405


Epoch [17/100]: 100%|██████████| 79/79 [00:46<00:00,  1.69it/s]
Epoch [18/100]: 100%|██████████| 79/79 [00:47<00:00,  1.68it/s]
Epoch [19/100]: 100%|██████████| 79/79 [00:46<00:00,  1.68it/s]
Epoch [20/100]: 100%|██████████| 79/79 [00:47<00:00,  1.67it/s]
Epoch [21/100]: 100%|██████████| 79/79 [00:47<00:00,  1.66it/s]


Epoch 21, Training Loss: 10346710.9114


Epoch [22/100]: 100%|██████████| 79/79 [00:47<00:00,  1.65it/s]
Epoch [23/100]: 100%|██████████| 79/79 [00:47<00:00,  1.66it/s]
Epoch [24/100]: 100%|██████████| 79/79 [00:46<00:00,  1.68it/s]
Epoch [25/100]: 100%|██████████| 79/79 [00:47<00:00,  1.68it/s]
Epoch [26/100]: 100%|██████████| 79/79 [00:47<00:00,  1.67it/s]


Epoch 26, Training Loss: 9163756.4873


Epoch [27/100]: 100%|██████████| 79/79 [00:47<00:00,  1.68it/s]
Epoch [28/100]: 100%|██████████| 79/79 [00:47<00:00,  1.67it/s]
Epoch [29/100]: 100%|██████████| 79/79 [00:47<00:00,  1.67it/s]
Epoch [30/100]: 100%|██████████| 79/79 [00:47<00:00,  1.67it/s]
Epoch [31/100]: 100%|██████████| 79/79 [00:47<00:00,  1.68it/s]


Epoch 31, Training Loss: 9000408.8987


Epoch [32/100]: 100%|██████████| 79/79 [00:47<00:00,  1.67it/s]
Epoch [33/100]: 100%|██████████| 79/79 [00:47<00:00,  1.67it/s]
Epoch [34/100]: 100%|██████████| 79/79 [00:47<00:00,  1.68it/s]
Epoch [35/100]: 100%|██████████| 79/79 [00:47<00:00,  1.67it/s]
Epoch [36/100]: 100%|██████████| 79/79 [00:47<00:00,  1.67it/s]


Epoch 36, Training Loss: 8242695.6835


Epoch [37/100]: 100%|██████████| 79/79 [00:48<00:00,  1.64it/s]
Epoch [38/100]: 100%|██████████| 79/79 [00:47<00:00,  1.65it/s]
Epoch [39/100]: 100%|██████████| 79/79 [00:47<00:00,  1.67it/s]
Epoch [40/100]: 100%|██████████| 79/79 [00:47<00:00,  1.65it/s]
Epoch [41/100]: 100%|██████████| 79/79 [00:48<00:00,  1.64it/s]


Epoch 41, Training Loss: 7037117.5253


Epoch [42/100]: 100%|██████████| 79/79 [00:48<00:00,  1.62it/s]
Epoch [43/100]: 100%|██████████| 79/79 [00:48<00:00,  1.64it/s]
Epoch [44/100]: 100%|██████████| 79/79 [00:47<00:00,  1.66it/s]
Epoch [45/100]: 100%|██████████| 79/79 [00:47<00:00,  1.68it/s]
Epoch [46/100]: 100%|██████████| 79/79 [00:47<00:00,  1.67it/s]


Epoch 46, Training Loss: 6601832.9367


Epoch [47/100]: 100%|██████████| 79/79 [00:46<00:00,  1.69it/s]
Epoch [48/100]: 100%|██████████| 79/79 [00:47<00:00,  1.66it/s]
Epoch [49/100]: 100%|██████████| 79/79 [00:46<00:00,  1.71it/s]
Epoch [50/100]: 100%|██████████| 79/79 [00:46<00:00,  1.71it/s]
Epoch [51/100]: 100%|██████████| 79/79 [00:46<00:00,  1.70it/s]


Epoch 51, Training Loss: 6234388.5633


Epoch [52/100]: 100%|██████████| 79/79 [00:46<00:00,  1.71it/s]
Epoch [53/100]: 100%|██████████| 79/79 [00:47<00:00,  1.66it/s]
Epoch [54/100]: 100%|██████████| 79/79 [00:46<00:00,  1.70it/s]
Epoch [55/100]: 100%|██████████| 79/79 [00:46<00:00,  1.69it/s]
Epoch [56/100]: 100%|██████████| 79/79 [00:47<00:00,  1.67it/s]


Epoch 56, Training Loss: 4032840.6044


Epoch [57/100]: 100%|██████████| 79/79 [00:45<00:00,  1.74it/s]
Epoch [58/100]: 100%|██████████| 79/79 [00:44<00:00,  1.77it/s]
Epoch [59/100]: 100%|██████████| 79/79 [08:15<00:00,  6.27s/it]   
Epoch [60/100]: 100%|██████████| 79/79 [00:46<00:00,  1.70it/s]
Epoch [61/100]: 100%|██████████| 79/79 [00:45<00:00,  1.75it/s]


Epoch 61, Training Loss: 3436684.9715


Epoch [62/100]: 100%|██████████| 79/79 [00:45<00:00,  1.73it/s]
Epoch [63/100]: 100%|██████████| 79/79 [00:45<00:00,  1.73it/s]
Epoch [64/100]: 100%|██████████| 79/79 [00:45<00:00,  1.72it/s]
Epoch [65/100]: 100%|██████████| 79/79 [00:45<00:00,  1.74it/s]
Epoch [66/100]: 100%|██████████| 79/79 [00:45<00:00,  1.72it/s]


Epoch 66, Training Loss: 3127214.4019


Epoch [67/100]: 100%|██████████| 79/79 [00:48<00:00,  1.61it/s]
Epoch [68/100]: 100%|██████████| 79/79 [00:46<00:00,  1.69it/s]
Epoch [69/100]: 100%|██████████| 79/79 [00:46<00:00,  1.70it/s]
Epoch [70/100]: 100%|██████████| 79/79 [00:47<00:00,  1.65it/s]
Epoch [71/100]: 100%|██████████| 79/79 [00:46<00:00,  1.69it/s]


Epoch 71, Training Loss: 2937235.0807


Epoch [72/100]: 100%|██████████| 79/79 [00:46<00:00,  1.70it/s]
Epoch [73/100]: 100%|██████████| 79/79 [00:48<00:00,  1.64it/s]
Epoch [74/100]: 100%|██████████| 79/79 [00:49<00:00,  1.58it/s]
Epoch [75/100]: 100%|██████████| 79/79 [00:47<00:00,  1.68it/s]
Epoch [76/100]: 100%|██████████| 79/79 [00:47<00:00,  1.67it/s]


Epoch 76, Training Loss: 2899833.9161


Epoch [77/100]: 100%|██████████| 79/79 [00:46<00:00,  1.71it/s]
Epoch [78/100]: 100%|██████████| 79/79 [00:46<00:00,  1.69it/s]
Epoch [79/100]: 100%|██████████| 79/79 [00:46<00:00,  1.69it/s]
Epoch [80/100]: 100%|██████████| 79/79 [00:46<00:00,  1.69it/s]
Epoch [81/100]: 100%|██████████| 79/79 [00:46<00:00,  1.69it/s]


Epoch 81, Training Loss: 2778107.7737


Epoch [82/100]: 100%|██████████| 79/79 [00:47<00:00,  1.65it/s]
Epoch [83/100]: 100%|██████████| 79/79 [00:47<00:00,  1.66it/s]
Epoch [84/100]: 100%|██████████| 79/79 [00:46<00:00,  1.71it/s]
Epoch [85/100]: 100%|██████████| 79/79 [00:46<00:00,  1.71it/s]
Epoch [86/100]: 100%|██████████| 79/79 [00:47<00:00,  1.67it/s]


Epoch 86, Training Loss: 2884849.3180


Epoch [87/100]: 100%|██████████| 79/79 [00:46<00:00,  1.71it/s]
Epoch [88/100]: 100%|██████████| 79/79 [00:45<00:00,  1.73it/s]
Epoch [89/100]: 100%|██████████| 79/79 [00:46<00:00,  1.69it/s]
Epoch [90/100]: 100%|██████████| 79/79 [00:47<00:00,  1.66it/s]
Epoch [91/100]: 100%|██████████| 79/79 [00:45<00:00,  1.72it/s]


Epoch 91, Training Loss: 2771577.6709


Epoch [92/100]: 100%|██████████| 79/79 [00:46<00:00,  1.71it/s]
Epoch [93/100]: 100%|██████████| 79/79 [00:45<00:00,  1.74it/s]
Epoch [94/100]: 100%|██████████| 79/79 [00:45<00:00,  1.72it/s]
Epoch [95/100]: 100%|██████████| 79/79 [00:47<00:00,  1.68it/s]
Epoch [96/100]: 100%|██████████| 79/79 [00:48<00:00,  1.64it/s]


Epoch 96, Training Loss: 2630476.0190


Epoch [97/100]: 100%|██████████| 79/79 [00:46<00:00,  1.71it/s]
Epoch [98/100]: 100%|██████████| 79/79 [00:48<00:00,  1.64it/s]
Epoch [99/100]: 100%|██████████| 79/79 [00:47<00:00,  1.68it/s]
Epoch [100/100]: 100%|██████████| 79/79 [00:45<00:00,  1.74it/s]
