In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import cv2

In [2]:
class CustomLSTMCell(nn.Module):
    """Single-step LSTM cell built from one fused linear projection.

    The current input and the previous hidden state are concatenated and mapped
    by a single ``nn.Linear`` to the pre-activations of all four gates, laid out
    along the feature axis in the order: forget, input, output, candidate.

    Args:
        input_size: number of features in ``x_t``.
        hidden_size: number of features in the hidden and cell states.
    """

    def __init__(self, input_size, hidden_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size

        # Fused projection [x_t, h_prev] -> [f | i | o | g], each hidden_size wide.
        self.xh2gates = nn.Linear(input_size + hidden_size, 4 * hidden_size)

    def forward(self, x_t, h_prev, c_prev):
        """Advance the cell by one time step.

        Args:
            x_t: (batch_size, input_size) input at the current step.
            h_prev: (batch_size, hidden_size) previous hidden state.
            c_prev: (batch_size, hidden_size) previous cell state.

        Returns:
            (h_t, c_t): new hidden and cell states, each (batch_size, hidden_size).
        """
        pre = self.xh2gates(torch.cat((x_t, h_prev), dim=1))  # (batch, 4 * hidden)
        width = pre.shape[1] // 4

        # Forget, input and output gates are contiguous, so squash them together.
        squashed = torch.sigmoid(pre[:, : 3 * width])
        forget = squashed[:, :width]
        remember = squashed[:, width : 2 * width]
        expose = squashed[:, 2 * width :]
        candidate = torch.tanh(pre[:, 3 * width :])

        c_t = forget * c_prev + remember * candidate
        h_t = expose * torch.tanh(c_t)
        return h_t, c_t


In [3]:
class CustomLSTM(nn.Module):
    """Sequence model: a CustomLSTMCell unrolled over time plus a linear head.

    The cell is applied step by step starting from all-zero hidden and cell
    states; only the final hidden state is projected by ``fc`` to
    ``output_size`` outputs.

    Args:
        input_size: number of features per time step.
        hidden_size: width of the recurrent hidden/cell state.
        output_size: number of outputs produced by the final linear layer.
    """

    def __init__(self, input_size, hidden_size, output_size):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size

        self.lstm_cell = CustomLSTMCell(input_size, hidden_size)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Run the LSTM over a batch of sequences and return one output per sequence.

        Args:
            x: (batch_size, seq_len, input_size) input sequences.

        Returns:
            (batch_size, output_size) tensor: ``fc`` applied to the last hidden
            state. For ``seq_len == 0`` this is ``fc`` applied to the zero state.
        """
        batch_size, seq_len, _ = x.size()
        # new_zeros matches both x's device and dtype. torch.zeros(..., device=...)
        # always produced float32 states, which broke half/bfloat16 models once
        # the state was concatenated with x and fed to the cell's linear layer.
        h_t = x.new_zeros(batch_size, self.hidden_size)
        c_t = x.new_zeros(batch_size, self.hidden_size)

        for t in range(seq_len):
            x_t = x[:, t, :]  # (batch_size, input_size)
            h_t, c_t = self.lstm_cell(x_t, h_t, c_t)

        # Use last hidden state for output
        return self.fc(h_t)  # (batch_size, output_size)


In [41]:
# Hyperparameters
batch_size = 2048
input_size = 8
seq_len = 10
hidden_size = 64
output_size = 1  # binary classification: one logit per sequence
lr = 1e-3
epochs = 15
SEED = 42  # single seed for data generation, split, shuffling and weight init

# Seed the global torch RNG so that X, DataLoader shuffling and the model's
# initial weights are reproducible under Restart & Run All (cells run in order).
torch.manual_seed(SEED)

# Synthetic data: N sequences of standard-normal features (~320 MB as float32).
N = 1000000
X = torch.randn(N, seq_len, input_size)
# Label is 1 when the sum of all inputs in a sequence is positive.
# NOTE: despite its name, `y_logits` is just that raw sum, not a model output.
y_logits = X.sum(dim=(1, 2))
y = (y_logits > 0).float().unsqueeze(1)  # binary labels (N, 1)
print(X.shape)
dataset = torch.utils.data.TensorDataset(X, y)
train_len = int(0.8 * N)
# A dedicated seeded generator keeps the split fixed independently of how much
# global randomness earlier code consumed.
train_ds, val_ds = torch.utils.data.random_split(
    dataset,
    [train_len, N - train_len],
    generator=torch.Generator().manual_seed(SEED),
)
print(f"Train size: {len(train_ds)}, Val size: {len(val_ds)}")

train_dataloader = torch.utils.data.DataLoader(train_ds, batch_size=batch_size, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_ds, batch_size=batch_size, shuffle=False)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")
model = CustomLSTM(input_size, hidden_size, output_size).to(device)
criterion = nn.BCEWithLogitsLoss()  # expects raw logits; applies sigmoid internally
optimizer = torch.optim.Adam(model.parameters(), lr=lr)

torch.Size([1000000, 10, 8])
Train size: 800000, Val size: 200000
Using device: cuda


In [42]:
# Training loop: each epoch runs one optimisation pass over the training split,
# then one evaluation pass over the validation split, and prints both metrics.


def run_epoch(model, loader, criterion, device, optimizer=None, max_grad_norm=1.0):
    """Run one pass of ``model`` over ``loader`` and return (mean_loss, accuracy).

    With an ``optimizer`` the pass trains: train mode, gradients computed,
    clipped to ``max_grad_norm`` (total norm) and applied every batch.
    Without one it only evaluates: eval mode and no autograd graph.

    Args:
        model: classifier producing logits shaped like the targets
            (here (batch, 1) for binary classification).
        loader: iterable of (inputs, targets) batches; targets are 0/1 floats.
        criterion: loss on (logits, targets) returning a scalar; assumed to
            average over the batch (the BCEWithLogitsLoss default).
        device: device every batch is moved to.
        optimizer: optimizer to step, or None for evaluation only.
        max_grad_norm: gradient-clipping threshold, used only when training.

    Returns:
        (mean per-sample loss, fraction of correctly classified samples).

    Raises:
        ValueError: if ``loader`` yields no samples.
    """
    training = optimizer is not None
    model.train(training)  # train(False) is equivalent to eval()
    loss_sum, n_correct, n_seen = 0.0, 0, 0
    with torch.set_grad_enabled(training):
        for xb, yb in loader:
            xb = xb.to(device)
            yb = yb.to(device)

            if training:
                optimizer.zero_grad()  # zero the parameter gradients
            logits = model(xb)  # forward pass
            loss = criterion(logits, yb)
            if training:
                loss.backward()
                torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=max_grad_norm)
                optimizer.step()  # update weights

            # Weight the batch-mean loss by batch size so the epoch average is
            # per sample even when the last batch is shorter.
            loss_sum += loss.item() * xb.size(0)
            preds = (torch.sigmoid(logits) > 0.5).float()
            n_correct += (preds == yb).sum().item()
            n_seen += yb.size(0)

    if n_seen == 0:
        raise ValueError("loader yielded no samples")
    return loss_sum / n_seen, n_correct / n_seen


for epoch in range(1, epochs + 1):
    epoch_loss, epoch_acc = run_epoch(model, train_dataloader, criterion, device, optimizer=optimizer)
    val_loss, val_acc = run_epoch(model, val_loader, criterion, device)
    print(f"Epoch {epoch:02d} | train_loss: {epoch_loss:.4f} | train_acc: {epoch_acc:.4f} "
          f"| val_loss: {val_loss:.4f} | val_acc: {val_acc:.4f}")

Epoch 01 | train_loss: 0.1587 | train_acc: 0.9381 | val_loss: 0.0471 | val_acc: 0.9861
Epoch 02 | train_loss: 0.0372 | train_acc: 0.9888 | val_loss: 0.0295 | val_acc: 0.9919
Epoch 03 | train_loss: 0.0271 | train_acc: 0.9915 | val_loss: 0.0254 | val_acc: 0.9904
Epoch 04 | train_loss: 0.0221 | train_acc: 0.9928 | val_loss: 0.0200 | val_acc: 0.9928
Epoch 05 | train_loss: 0.0195 | train_acc: 0.9932 | val_loss: 0.0194 | val_acc: 0.9920
Epoch 06 | train_loss: 0.0169 | train_acc: 0.9942 | val_loss: 0.0174 | val_acc: 0.9933
Epoch 07 | train_loss: 0.0154 | train_acc: 0.9945 | val_loss: 0.0134 | val_acc: 0.9959
Epoch 08 | train_loss: 0.0147 | train_acc: 0.9945 | val_loss: 0.0147 | val_acc: 0.9939
Epoch 09 | train_loss: 0.0131 | train_acc: 0.9952 | val_loss: 0.0132 | val_acc: 0.9946
Epoch 10 | train_loss: 0.0128 | train_acc: 0.9951 | val_loss: 0.0121 | val_acc: 0.9953
Epoch 11 | train_loss: 0.0116 | train_acc: 0.9956 | val_loss: 0.0102 | val_acc: 0.9965
Epoch 12 | train_loss: 0.0113 | train_acc: 