In [None]:
import torch

torch.manual_seed(42)
import lightning.pytorch as pl
import numpy as np
import pandas as pd
import torch.nn.functional as F
import torch.optim as optim
import torchmetrics
import wandb
from pytorch_lightning.loggers import WandbLogger
from sklearn.preprocessing import (
    LabelEncoder,
    MinMaxScaler,
    RobustScaler,
    StandardScaler,
)
from torch import nn, optim, utils
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.utils.data import DataLoader, TensorDataset, random_split
from tqdm.notebook import tqdm

torch.set_float32_matmul_precision("high")

In [None]:
from statsmodels.tsa.seasonal import STL, seasonal_decompose


def get_period(series_name, default=6):
    """Map an M4-style series id (e.g. ``"M123"``) to an STL seasonal period.

    The first character of the id selects the frequency:
    D -> 7, W -> 4, M -> 12, Q -> 4, Y -> 2.

    Args:
        series_name: Series identifier. ``None``/empty names and unknown
            prefixes fall back to ``default``. Non-string names (e.g. an
            integer index label) are converted with ``str`` instead of
            raising ``TypeError``.
        default: Period returned when no prefix matches (6, as before).

    Returns:
        The seasonal period as an int.
    """
    # NOTE(review): there is no "H" (hourly) entry, so such ids use `default`; confirm intended.
    mapping = {"D": 7, "W": 4, "M": 12, "Q": 4, "Y": 2}
    if series_name is None:
        return default
    return mapping.get(str(series_name)[:1], default)


def seasonal_decomposition(ts, period=None):
    """Denoise one series by keeping only the trend + seasonal parts of an STL fit.

    Args:
        ts: One series (a DataFrame row when used through
            ``DataFrame.apply(..., axis=1)``). Missing values are dropped
            before fitting; ``ts.name`` selects the period when ``period``
            is not given.
        period: Seasonal period passed to STL. ``None`` (default) derives it
            from ``ts.name`` via ``get_period``; an explicit value is now
            honoured instead of being silently overwritten.

    Returns:
        trend + seasonal (the STL residual is discarded), indexed by the
        non-missing positions of ``ts``.
    """
    if period is None:
        period = get_period(ts.name)
    observed = ts.dropna()
    fit = STL(observed, period=period).fit()
    return fit.seasonal + fit.trend

In [None]:
df_raw = pd.read_parquet("data/m4_preprocessed.parquet")
lengths = df_raw.no_of_datapoints.values

le = LabelEncoder()
y = le.fit_transform(df_raw.best_model.values)
classes = dict(enumerate(le.classes_))

# Every column except the label and the length is a time step. Select them by
# name rather than assuming these two columns come last.
value_cols = df_raw.columns.drop(["best_model", "no_of_datapoints"])

# STL-denoise each series (row). Rows come back indexed only by their
# non-missing positions; reindex restores the full column grid (NaN padding)
# even if some column ends up missing from every row.
denoised = df_raw[value_cols].apply(seasonal_decomposition, axis=1).reindex(
    columns=value_cols
)

# Transposing makes each series a column, so StandardScaler standardises every
# series independently. NaN padding is ignored by the scaler and then zero-filled.
scaler = StandardScaler()
df = pd.DataFrame(
    scaler.fit_transform(denoised.T).T,
    columns=value_cols,
    index=df_raw.index,
).fillna(0.0)
sequences = torch.tensor(df.values, dtype=torch.float32)  # the model runs in float32

# pack_padded_sequence needs 0 < length <= padded length for every sample.
assert len(sequences) == len(lengths) == len(y)
assert lengths.min() > 0 and lengths.max() <= sequences.shape[1]

In [None]:
# W&B run for this experiment; tag the run config with the model family and the
# series representation (trend + seasonal) it was trained on.
wandb_logger = WandbLogger(
    project="ts-classification",
    name="lstm.ts=trend+seasonal",
)
run_config = wandb_logger.experiment.config
run_config.update({"model": "LSTM", "ts": "trend+seasonal"})

In [None]:
class LSTMDataLoader(pl.LightningDataModule):
    """DataModule serving ``(sequences, lengths, labels)`` batches for the LSTM.

    The samples are split 70/10/20 into train/val/test exactly once, with a
    dedicated seeded generator. Lightning calls ``setup`` again on every
    ``trainer.fit`` / ``validate`` / ``test``; re-splitting there would hand
    the evaluation calls a different test set that overlaps the training data.

    Args:
        sequences: Padded sequences, one row per series (iterated row-wise).
        lengths: True (unpadded) length of each series.
        y: Integer class label of each series.
        batch_size: Batch size for all three dataloaders.
        seed: Seed of the generator used for the train/val/test split.
    """

    def __init__(self, sequences, lengths, y, batch_size=32, seed=42):
        super().__init__()
        self.sequences = sequences
        self.lengths = lengths
        self.y = torch.as_tensor(y, dtype=torch.long)
        self.batch_size = batch_size
        self.seed = seed
        self.train_dataset = self.val_dataset = self.test_dataset = None

    def setup(self, stage=None):
        """Create the train/val/test splits on the first call; later calls are no-ops."""
        if self.train_dataset is not None:
            return

        # No length sort is needed: random_split permutes the samples anyway,
        # and the model packs with enforce_sorted=False.
        dataset = list(zip(self.sequences, self.lengths, self.y))

        test_size = int(0.2 * len(dataset))
        val_size = int(0.1 * len(dataset))
        train_size = len(dataset) - test_size - val_size

        generator = torch.Generator().manual_seed(self.seed)
        self.train_dataset, self.val_dataset, self.test_dataset = random_split(
            dataset, [train_size, val_size, test_size], generator=generator
        )

    def collate_fn(self, batch):
        """Stack ``(sequence, length, label)`` triples into batch tensors.

        Returns:
            ``(sequences, lengths, labels)``; sequences are stacked as-is
            (currently ``(batch, seq_len)``) and labels are int64.
        """
        sequences, lengths, labels = zip(*batch)
        # TODO: feed trend and seasonal as separate channels, i.e.
        # (batch, seq_len) -> (batch, seq_len, input_dim). The zero padding past
        # each series' length must be kept out of the decomposition.
        return (
            torch.stack(sequences),
            torch.tensor(lengths),
            torch.stack(labels).long(),
        )

    def _make_loader(self, dataset, shuffle):
        """Build a DataLoader over one split with the shared batch size and collate_fn."""
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            collate_fn=self.collate_fn,
        )

    def train_dataloader(self):
        return self._make_loader(self.train_dataset, shuffle=True)

    def val_dataloader(self):
        return self._make_loader(self.val_dataset, shuffle=False)

    def test_dataloader(self):
        return self._make_loader(self.test_dataset, shuffle=False)

In [None]:
class LSTMClassifier(pl.LightningModule):
    """Bidirectional LSTM classifier for padded, variable-length sequences.

    The last layer's final forward and backward hidden states are concatenated,
    layer-normalised and projected to ``num_classes`` logits.

    Args:
        input_dim: Number of features per time step.
        hidden_dim: LSTM hidden size per direction.
        num_layers: Number of stacked LSTM layers.
        num_classes: Number of target classes.
        learning_rate: Initial AdamW learning rate.
        dropout: Dropout between stacked LSTM layers (0.3 as before). It is
            passed as 0 when ``num_layers == 1``, where nn.LSTM would not
            apply it anyway.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim=128,
        num_layers=2,
        num_classes=10,
        learning_rate=1e-3,
        dropout=0.3,
    ):
        super().__init__()
        self.save_hyperparameters()
        self.lstm = nn.LSTM(
            input_dim,
            hidden_dim,
            num_layers,
            batch_first=True,
            bidirectional=True,
            dropout=dropout if num_layers > 1 else 0.0,
        )
        self.fc = nn.Linear(hidden_dim * 2, num_classes)  # x2: bidirectional
        self.layer_norm = nn.LayerNorm(hidden_dim * 2)
        self.loss_fn = nn.CrossEntropyLoss()

        # Give each stage its own metric state. A single shared metric mixes
        # train/val/test state, and state that is never reset keeps growing
        # (AUROC with the default thresholds=None stores every prediction).
        metric_kwargs = {
            "task": "multiclass",
            "num_classes": num_classes,
            "average": "macro",
        }
        metrics = torchmetrics.MetricCollection(
            {
                "accuracy": torchmetrics.Accuracy(**metric_kwargs),
                "f1score": torchmetrics.F1Score(**metric_kwargs),
                "auroc": torchmetrics.AUROC(**metric_kwargs),
            }
        )
        # Prefixes reproduce the logged keys, e.g. "val_f1score".
        self.train_metrics = metrics.clone(prefix="train_")
        self.val_metrics = metrics.clone(prefix="val_")
        self.test_metrics = metrics.clone(prefix="test_")

        self.lr = learning_rate

    def forward(self, x, lengths):
        """Return class logits for a padded batch.

        Args:
            x: Padded batch, either ``(batch, seq_len)`` for univariate input
                or ``(batch, seq_len, input_dim)``.
            lengths: True length of each sequence (moved to CPU for packing).
        """
        if x.dim() == 2:
            x = x.unsqueeze(-1)  # univariate: add the feature axis
        packed_x = pack_padded_sequence(
            x.float(), lengths.cpu(), batch_first=True, enforce_sorted=False
        )
        _, (hn, _) = self.lstm(packed_x)

        # hn[-2] / hn[-1]: final forward / backward states of the last layer.
        hn = torch.cat((hn[-2], hn[-1]), dim=1)
        hn = self.layer_norm(hn)
        return self.fc(hn)

    def _common_step(self, batch, batch_idx):
        """Run the model on a batch; return ``(loss, preds, probs, y)``."""
        x, lengths, y = batch
        y_hat = self.forward(x, lengths)
        loss = self.loss_fn(y_hat, y)
        preds = torch.argmax(y_hat, dim=1)
        probs = torch.softmax(y_hat, dim=1)
        return loss, preds, probs, y

    def training_step(self, batch, batch_idx):
        # Log through self.log so the Trainer's logger/log_every_n_steps apply
        # and the step does not fail when the Trainer has no logger.
        lr = self.trainer.optimizers[0].param_groups[0]["lr"]
        self.log("learning_rate", lr, on_step=True, on_epoch=False)

        loss, _, probs, y = self._common_step(batch, batch_idx)
        # Multiclass accuracy/F1 take the argmax of probs; AUROC uses probs directly.
        batch_metrics = self.train_metrics(probs, y)
        self.log("train_loss", loss, prog_bar=True)
        self.log_dict(batch_metrics, prog_bar=True)
        return loss

    def on_train_epoch_end(self):
        # Forward calls accumulate state; clear it so it does not grow across epochs.
        self.train_metrics.reset()

    def validation_step(self, batch, batch_idx):
        loss, _, probs, y = self._common_step(batch, batch_idx)
        self.val_metrics.update(probs, y)
        self.log("val_loss", loss, prog_bar=False)
        return loss

    def on_validation_epoch_end(self):
        # Dataset-level macro metrics (a mean of per-batch macro F1/AUROC differs).
        self.log_dict(self.val_metrics.compute(), prog_bar=False)
        self.val_metrics.reset()

    def test_step(self, batch, batch_idx):
        loss, _, probs, y = self._common_step(batch, batch_idx)
        self.test_metrics.update(probs, y)
        self.log("test_loss", loss, prog_bar=False)
        return loss

    def on_test_epoch_end(self):
        self.log_dict(self.test_metrics.compute(), prog_bar=False)
        self.test_metrics.reset()

    def configure_optimizers(self):
        """AdamW with cosine annealing warm restarts, stepped once per epoch."""
        optimizer = optim.AdamW(self.parameters(), lr=self.lr, weight_decay=1e-4)

        scheduler = {
            "scheduler": optim.lr_scheduler.CosineAnnealingWarmRestarts(
                optimizer,
                T_0=10,  # first restart after 10 epochs
                T_mult=2,  # each cycle is twice as long as the previous one
                eta_min=3e-5,  # learning-rate floor
            ),
            "interval": "epoch",
        }
        return {"optimizer": optimizer, "lr_scheduler": scheduler}

    # Alternative tried earlier: the same AdamW with
    # ReduceLROnPlateau(mode="min", factor=0.5, patience=10) monitoring "val_loss".

In [None]:
# Data module over all series, plus a univariate classifier (input_dim=1) with one
# output logit per distinct label in y.
ds = LSTMDataLoader(sequences=sequences, lengths=lengths, y=y, batch_size=778)
n_classes = np.unique(y).size
model = LSTMClassifier(input_dim=1, num_classes=n_classes)

In [None]:
# Keep only the single best checkpoint, ranked by validation macro-F1.
checkpoint_callback = pl.callbacks.ModelCheckpoint(
    save_top_k=1,
    monitor="val_f1score",
    mode="max",
    dirpath="checkpoints/",
    filename="lstm-model-classifier-{epoch}-{val_f1score}",
)

trainer = pl.Trainer(
    logger=wandb_logger,
    accelerator="auto",
    # One device of whichever accelerator is found. A device-index list such as
    # [0] is rejected when only the CPU accelerator is available.
    devices=1,
    min_epochs=1,
    max_epochs=1000,
    # precision="16-mixed",
    # An epoch has 45 training batches, fewer than the default interval of 50.
    log_every_n_steps=10,
    enable_model_summary=True,
    callbacks=[
        # NOTE: early stopping watches val_loss, while checkpointing keeps the best val_f1score.
        pl.callbacks.EarlyStopping("val_loss", patience=15, verbose=False),
        checkpoint_callback,
    ],
    enable_checkpointing=True,
)

# Name the checkpoint after the scaler actually used. Runs with different
# preprocessing then neither overwrite nor resume from each other's checkpoint.
ckpt_path = f"model_checkpoints/lstm_classifier-{type(scaler).__name__}.ckpt"
finetune = False
trainer.fit(model, ds, ckpt_path=ckpt_path if finetune else None)
trainer.save_checkpoint(ckpt_path)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
/home/pranav-pc/projects/ts/ts/classification/.venv/lib/python3.11/site-packages/lightning/pytorch/callbacks/model_checkpoint.py:654: Checkpoint directory /home/pranav-pc/projects/ts/nbs/src/checkpoints exists and is not empty.
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name       | Type               | Params | Mode 
----------------------------------------------------------
0 | lstm       | LSTM               | 529 K  | train
1 | fc         | Linear             | 1.8 K  | train
2 | layer_norm | LayerNorm          | 512    | train
3 | loss_fn    | CrossEntropyLoss   | 0      | train
4 | accuracy   | MulticlassAccuracy | 0      | train
5 | f1_score   | MulticlassF1Score  | 0      | train
6 | auroc      | MulticlassAUROC    | 0      | train
----------------------------------------------------------
531 K     Trainable params
0         Non-trainable params
531 K     Tot

Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

/home/pranav-pc/projects/ts/ts/classification/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.
/home/pranav-pc/projects/ts/ts/classification/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.
/home/pranav-pc/projects/ts/ts/classification/.venv/lib/python3.11/site-packages/lightning/pytorch/loops/fit_loop.py:310: The number of training batches (45) is smaller than the logging interval Trainer(log_every_n_steps=50). Set a lower value for log_every_n_steps if you want to see logs for the training epoch.


Training: |                                               | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

In [None]:
# Evaluate the checkpoint selected by val_f1score, not the last-epoch weights
# left in `model` after early stopping.
trainer.validate(model, ds, ckpt_path="best");

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Validation: |                                             | 0/? [00:00<?, ?it/s]

In [None]:
# Report test metrics for the checkpoint selected on validation (val_f1score),
# not for whatever weights the final training epoch left behind.
trainer.test(model, ds, ckpt_path="best");

LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/home/pranav-pc/projects/ts/ts/classification/.venv/lib/python3.11/site-packages/lightning/pytorch/trainer/connectors/data_connector.py:425: The 'test_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=31` in the `DataLoader` to improve performance.


Testing: |                                                | 0/? [00:00<?, ?it/s]

In [None]:
# End the active W&B run (the one opened by `wandb_logger`); its run history and
# summary are printed below.
wandb.finish()

0,1
epoch,▁▁▁▁▁▂▂▂▂▂▂▂▂▂▂▃▃▃▃▃▄▄▅▅▅▅▅▆▆▆▆▆▆▇▇▇▇███
learning_rate,▆▄▂▁▁▇▇▅▄▃▂▂▁▁██████▇▇▆▆▆▅▄▃▃▂▂▂▁▁▁▁████
step,▁▁▁▁▁▂▂▂▂▂▃▃▃▃▄▄▄▄▄▄▄▅▅▅▅▅▅▅▅▅▆▆▆▇▇▇▇▇██
test_accuracy,▁
test_auroc,▁
test_f1score,▁
test_loss,▁
train_accuracy,▁▁▂▂▂▃▃▃▄▂▃▄▄▃▄▄▄▄▄▅▅▆▄▅▆▆▆▆▇▇▆▆▇██▇█▇▅▆
train_auroc,▁▁▂▁▂▁▂▃▂▂▃▅▃▄▄▄▅▄▄▄▅▄▅▅▅▆▆▅▅▆▇▇▇█▇██▇▆▇
train_f1score,▁▁▁▂▂▂▂▁▃▃▃▄▄▄▅▄▄▄▅▅▅▆▅▆▅▅▆▆▆▆▇▇▆▆▆▇██▇▆

0,1
epoch,74.0
learning_rate,0.001
step,3329.0
test_accuracy,0.33511
test_auroc,0.73479
test_f1score,0.31113
test_loss,1.60452
train_accuracy,0.32096
train_auroc,0.74481
train_f1score,0.30213
