<a href="https://colab.research.google.com/github/Series-Parallel/UCR_Time_Series_Classification_Deep_Learning_From_Scratch/blob/main/MLP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install lightning > /dev/null

In [None]:
import pandas as pd
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import lightning as L
from lightning.pytorch.callbacks import ModelCheckpoint
from lightning.pytorch.callbacks import Callback

from torch.utils.data import DataLoader, TensorDataset
from torch.optim import Adadelta
from sklearn.preprocessing import OneHotEncoder
from sklearn.preprocessing import MinMaxScaler

In [None]:
def reducer(filename, delimiter=','):
  """Load a split file into a feature matrix and a label vector.

  Each row of the file is one sample: column 0 holds the class label and the
  remaining columns hold the series values.

  Args:
    filename: path to a delimited text file of numbers.
    delimiter: column separator passed to ``np.loadtxt`` (default ',').

  Returns:
    (X, Y): X is a 2-D float array of shape (n_samples, n_columns - 1) and
    Y is a 1-D float array of the raw labels, of length n_samples.
  """
  # ndmin=2 keeps a single-row file 2-D; without it np.loadtxt returns a 1-D
  # array and the column slicing below raises IndexError.
  data = np.loadtxt(filename, delimiter=delimiter, ndmin=2)
  Y = data[:, 0]
  X = data[:, 1:]
  return X, Y

In [None]:
# Load the Adiac train and test splits; reducer returns (features, labels).
(x_train, y_train), (x_test, y_test) = (
    reducer("Adiac_TRAIN.txt"),
    reducer("Adiac_TEST.txt"),
)

In [None]:
# Number of distinct classes, counted on the training split. The model learns
# from that split, and the test split is not guaranteed to contain every class.
classes = len(np.unique(y_train))

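Encode the class labels as integer indices from 0 to (number of classes − 1), the target format `nn.CrossEntropyLoss` expects.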

In [None]:
# Encode the raw labels as contiguous indices 0..K-1 using ONE mapping shared
# by both splits. Min-max rescaling followed by astype(int) has three failure
# modes:
#   (a) it can truncate values such as 2.9999999999999996 down to the wrong class;
#   (b) it merges classes when the raw labels are not evenly spaced
#       (e.g. {1, 2, 5} -> {0, 0, 2});
#   (c) it assigns different indices in train and test whenever the two splits
#       have different minimum/maximum labels.
label_values = np.unique(y_train)
unseen_test = ~np.isin(y_test, label_values)
if unseen_test.any():
    raise ValueError(
        f"Test labels not present in the training split: {np.unique(y_test[unseen_test])}"
    )
y_train = np.searchsorted(label_values, y_train)
y_test = np.searchsorted(label_values, y_test)

In [None]:
# int64 (torch.long) label tensors, used as class-index targets for the loss.
y_train_tensor, y_test_tensor = (
    torch.tensor(labels, dtype=torch.long) for labels in (y_train, y_test)
)

In [None]:
# Standardise both splits with a single scalar mean/std computed over the whole
# training matrix, so no test statistics enter preprocessing.
x_train_mean, x_train_std = x_train.mean(), x_train.std()
x_train, x_test = (
    (x_train - x_train_mean) / x_train_std,
    (x_test - x_train_mean) / x_train_std,
)

In [None]:
# float32 feature tensors for the model.
input_train, input_test = (
    torch.tensor(x_train, dtype=torch.float32),
    torch.tensor(x_test, dtype=torch.float32),
)

In [None]:
# Mini-batches of 16 training samples, reshuffled each time the loader is iterated.
train_dataset = TensorDataset(input_train, y_train_tensor)
train_dataloader = DataLoader(dataset=train_dataset, shuffle=True, batch_size=16)

In [None]:
# The test split doubles as the validation set; batches are served in fixed order.
val_dataset = TensorDataset(input_test, y_test_tensor)
val_dataloader = DataLoader(dataset=val_dataset, batch_size=16, shuffle=False)

In [None]:
class MLP(L.LightningModule):
  """Fully connected classifier: three 500-unit ReLU hidden layers with dropout.

  Dropout rates are 0.1 on the input, 0.2 and 0.2 after the first two hidden
  layers, and 0.3 before the output layer. Trained with cross-entropy loss and
  Adadelta (lr=0.1).

  Args:
    input_dim: number of input features per sample.
    output_dim: number of classes (size of the logit vector).
    seed: seed passed to ``L.seed_everything`` when the module is built, so
      weight initialisation is reproducible. ``None`` skips reseeding. The
      default is the value this class always used, so existing calls behave
      exactly as before.
  """

  def __init__(self, input_dim, output_dim, seed=813306):
    super().__init__()

    # Seeding here (rather than once per run) means every MLP(...) call resets
    # the global RNGs; pass seed=None to construct a model without doing so.
    if seed is not None:
      L.seed_everything(seed)

    self.model = nn.Sequential(
        nn.Dropout(0.1),
        nn.Linear(input_dim, 500),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(500, 500),
        nn.ReLU(),
        nn.Dropout(0.2),
        nn.Linear(500, 500),
        nn.ReLU(),
        nn.Dropout(0.3),
        nn.Linear(500, output_dim),
    )

    # Applied to the raw (un-softmaxed) output of the final Linear layer.
    self.loss_fn = nn.CrossEntropyLoss()

  def forward(self, x):
    """Return unnormalised class logits for a batch of feature vectors."""
    return self.model(x)

  def training_step(self, batch, batch_idx):
    """Compute and log the cross-entropy loss on one training batch.

    ``batch_idx`` is the index of the batch within the epoch; Lightning passes
    it as the second positional argument.
    """
    x, y = batch
    logits = self(x)
    loss = self.loss_fn(logits, y)
    self.log("train_loss", loss)
    return loss

  def validation_step(self, batch, batch_idx):
    """Log validation loss and accuracy for one batch and return both."""
    x, y = batch
    logits = self(x)
    val_loss = self.loss_fn(logits, y)
    # Fraction of samples whose highest-scoring class matches the target.
    acc = (torch.argmax(logits, dim=1) == y).float().mean()
    self.log('val_loss', val_loss, prog_bar=True)
    self.log('val_acc', acc, prog_bar=True)
    return {"val_loss": val_loss, "val_acc": acc}

  def configure_optimizers(self):
    """Optimise all model parameters with Adadelta at lr=0.1."""
    return Adadelta(self.parameters(), lr=0.1)

In [None]:
# Width of the input layer: number of columns in the feature matrix.
input_dim = input_train.size(1)

In [None]:
# CrossEntropyLoss requires every target index to be < output_dim, so size the
# output layer from the largest label index rather than the number of distinct
# labels (the two differ whenever the encoded labels have gaps). Both splits are
# included so validation targets are in range too. For contiguous labels
# 0..K-1 this equals K, the same value as before.
output_dim = max(int(y_train_tensor.max()), int(y_test_tensor.max())) + 1

In [None]:
# Instantiate the classifier (MLP.__init__ also reseeds the global RNGs).
model = MLP(input_dim, output_dim)

In [None]:
# Keep only the single checkpoint with the lowest validation loss.
# NOTE(review): the validation loader is built from the test split, so picking
# the best checkpoint here is model selection on test data, and the test
# accuracy reported later is optimistic. Holding out part of the training
# split for validation would avoid this.
checkpoint_callback = ModelCheckpoint(
    filename='best-checkpoint',
    monitor='val_loss',  # or 'val_acc' together with mode='max'
    mode='min',
    save_top_k=1,
    verbose=True,
)

In [None]:
# Fixed budget of 5000 epochs; checkpoint_callback tracks the best one by val_loss.
trainer = L.Trainer(callbacks=[checkpoint_callback], max_epochs=5000)

In [None]:
# Train on train_dataloader, running validation on val_dataloader during training.
trainer.fit(model=model, train_dataloaders=train_dataloader, val_dataloaders=val_dataloader)

In [None]:
# Reload the best checkpoint into a fresh model and score it on the test split.
# Constructing MLP(...) also reseeds the global RNGs (see MLP.__init__).
best_path = checkpoint_callback.best_model_path
if not best_path:
    raise RuntimeError("ModelCheckpoint did not save any checkpoint; was validation run?")

best_model = MLP(input_dim, output_dim)
# map_location="cpu": the checkpoint may have been written from GPU memory,
# while this freshly built model and input_test live on the CPU.
checkpoint = torch.load(best_path, map_location="cpu")
best_model.load_state_dict(checkpoint['state_dict'])

# Evaluate on test set
best_model.eval()  # disables the dropout layers
with torch.no_grad():
    logits = best_model(input_test)
    predictions = torch.argmax(logits, dim=1)
    accuracy = (predictions == y_test_tensor).float().mean()

# best_model_score is the value of the monitored metric (val_loss as configured),
# not an accuracy, so label it with the monitor's name.
print(f"Best validation {checkpoint_callback.monitor}: {checkpoint_callback.best_model_score.item():.4f}")
print(f"Test Accuracy (Best Model): {accuracy.item():.4f}")