In [None]:
import pandas as pd
import torch
from transformers import AutoModelForCausalLM

# Load the pretrained Timer time-series foundation model ("thuml/timer-base-84m") from the
# Hugging Face Hub as a causal-LM wrapper; its `.model` attribute is the backbone used below.
# NOTE: trust_remote_code=True executes modelling code shipped with the checkpoint — only use
# it with repositories you trust (consider pinning a `revision=` for reproducibility).
model = AutoModelForCausalLM.from_pretrained(
    "thuml/timer-base-84m",
    trust_remote_code=True,
)
from sklearn.preprocessing import (
    LabelEncoder,
    MinMaxScaler,
    RobustScaler,
    StandardScaler,
)

# # prepare input
# batch_size, lookback_length = 1, 2880
# seqs = torch.randn(batch_size, lookback_length)

# # generate forecast
# prediction_length = 96
# normed_output = model.generate(seqs, max_new_tokens=prediction_length)

# print(normed_output.shape)

In [None]:
df_raw = pd.read_parquet("data/m4_preprocessed.parquet")

# Encode the `best_model` target column as integer class ids; `classes` maps each id back
# to the original label so predictions can be decoded later.
le = LabelEncoder()
y = le.fit_transform(df_raw["best_model"].values)
classes = dict(enumerate(le.classes_))

# Every column except the label and the length metadata is treated as a series value.
# Column labels are taken from this frame itself rather than assuming the two dropped
# columns happen to be the last two (`df.columns[:-2]` mislabelled features otherwise).
features = df_raw.drop(columns=["best_model", "no_of_datapoints"])

# RobustScaler scales column-wise, so transpose to scale each series (row) by its own
# median/IQR, then transpose back. RobustScaler ignores NaNs when fitting and keeps them in
# the output; they are then replaced with 0.0. (Presumably NaNs are padding for series
# shorter than the widest one — TODO confirm against the preprocessing that wrote the file.)
scaler = RobustScaler()
df = pd.DataFrame(
    scaler.fit_transform(features.T).T,
    columns=features.columns,
    index=features.index,
).fillna(0.0)
df["best_model"] = y

In [None]:
import pytorch_lightning as pl
import torch
import torch.nn as nn
import torchmetrics
from pytorch_lightning.loggers import TensorBoardLogger
from torch.utils.data import DataLoader, Dataset

# Define the dataset class


class TimeSeriesDataset(Dataset):
    """Map-style dataset yielding ``(features, label)`` pairs built from a DataFrame.

    All columns other than ``best_model`` become a float32 feature matrix; ``best_model``
    becomes an int64 label vector (integer class ids, as CrossEntropyLoss expects).
    """

    def __init__(self, df):
        label_column = "best_model"
        feature_frame = df.drop(columns=[label_column])
        # Convert once up front so that indexing in __getitem__ is just a tensor slice.
        self.data = torch.tensor(feature_frame.values, dtype=torch.float32)
        self.labels = torch.tensor(df[label_column].values, dtype=torch.long)

    def __len__(self):
        # Number of rows (series) in the feature matrix.
        return self.data.shape[0]

    def __getitem__(self, idx):
        sample = self.data[idx]
        target = self.labels[idx]
        return sample, target


# Modify the model for classification


class FineTunedTimerForPrediction(pl.LightningModule):
    """Frozen Timer backbone with a trainable MLP head for series classification.

    The backbone (``original_model.model``) is frozen; only the classification head is
    optimised. Each input batch is encoded by the backbone, the token sequence of hidden
    states is pooled to one vector per series, and the head maps it to class logits.

    Args:
        original_model: Hugging Face causal-LM wrapper whose ``.model`` attribute is the
            backbone returning hidden states.
        num_classes: Number of target classes.
        lr: Adam learning rate for the head (default keeps the previous 1e-3).
        pooling: How to reduce the backbone's token sequence to one vector:
            ``"last"`` (default), ``"first"`` or ``"mean"``. The backbone comes from a
            causal LM, so under causal masking only the last token has attended to every
            input patch; ``"first"`` (the old behaviour) sees only the first patch.
        hidden_size: Width of the backbone's hidden states. Defaults to
            ``original_model.config.hidden_size`` when present, otherwise 1024.

    Raises:
        ValueError: if ``pooling`` is not one of the supported modes.
    """

    def __init__(self, original_model, num_classes, lr=1e-3, pooling="last", hidden_size=None):
        super().__init__()
        if pooling not in ("first", "last", "mean"):
            raise ValueError(f"pooling must be 'first', 'last' or 'mean', got {pooling!r}")
        self.lr = lr
        self.pooling = pooling

        self.model = original_model.model  # Keep the original feature extractor

        # Freeze original model weights; only the head below is trained.
        for param in self.model.parameters():
            param.requires_grad = False

        # Read the head's input width from the backbone config instead of hard-coding it.
        if hidden_size is None:
            config = getattr(original_model, "config", None)
            hidden_size = getattr(config, "hidden_size", 1024)

        # New classification head
        self.classification_head = nn.Sequential(
            nn.Linear(hidden_size, 512),
            nn.ReLU(),
            nn.Dropout(0.3),
            nn.Linear(512, 256),
            nn.ReLU(),
            nn.Linear(256, num_classes),
        )

        self.loss_function = nn.CrossEntropyLoss()

        # Accuracy & F1 Score (separate instances per stage so their states never mix)
        self.train_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.val_acc = torchmetrics.Accuracy(task="multiclass", num_classes=num_classes)
        self.val_f1 = torchmetrics.F1Score(task="multiclass", num_classes=num_classes)

    def forward(self, x):
        """Return class logits of shape ``(batch, num_classes)`` for input series ``x``."""
        output = self.model(x)  # e.g. a ModelOutput with last_hidden_state, or a tuple
        if hasattr(output, "last_hidden_state"):
            hidden = output.last_hidden_state
        elif isinstance(output, tuple):
            hidden = output[0]  # return_dict=False style output
        else:
            raise ValueError(f"Unexpected model output type: {type(output)}")

        # assumes hidden is (batch, num_tokens, hidden_size), as the indexing implies — TODO confirm
        if self.pooling == "first":
            pooled = hidden[:, 0, :]
        elif self.pooling == "mean":
            pooled = hidden.mean(dim=1)
        else:
            pooled = hidden[:, -1, :]
        return self.classification_head(pooled)

    def training_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_function(logits, y)

        # Update metric state, then log the Metric object itself so Lightning handles
        # step values, epoch aggregation and resetting correctly.
        self.train_acc(logits, y)
        self.log("train_loss", loss, prog_bar=True, logger=True)
        self.log("train_acc", self.train_acc, prog_bar=True, logger=True)

        return loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        logits = self(x)
        loss = self.loss_function(logits, y)

        # Accumulate over the whole validation epoch; logging the Metric objects makes
        # Lightning call compute() once at epoch end. Averaging per-batch values (the old
        # behaviour) is wrong for F1, which does not decompose over batches.
        self.val_acc.update(logits, y)
        self.val_f1.update(logits, y)

        self.log("val_loss", loss, prog_bar=True, logger=True)
        self.log("val_acc", self.val_acc, prog_bar=True, logger=True)
        self.log("val_f1", self.val_f1, prog_bar=True, logger=True)

        return loss

    def configure_optimizers(self):
        # Only the head is optimised; the backbone parameters are frozen above.
        return torch.optim.Adam(self.classification_head.parameters(), lr=self.lr)


# Load dataset


def get_dataloaders(df, batch_size=512, train_fraction=0.8, seed=42, num_workers=None):
    """Split ``df`` into train/validation subsets and wrap each in a DataLoader.

    Args:
        df: Frame accepted by ``TimeSeriesDataset`` (feature columns + ``best_model``).
        batch_size: Samples per batch for both loaders.
        train_fraction: Share of rows assigned to training; the remainder validates.
        seed: Seed for the split so train/val membership is reproducible across runs and
            kernel restarts. ``None`` falls back to torch's global RNG (old behaviour).
        num_workers: DataLoader worker processes. ``None`` means up to 31 (the previous
            hard-coded value), capped at this machine's CPU count.

    Returns:
        ``(train_loader, val_loader)``; only the training loader shuffles.

    Raises:
        ValueError: if ``train_fraction`` is not strictly between 0 and 1.
    """
    import os

    if not 0.0 < train_fraction < 1.0:
        raise ValueError(f"train_fraction must be in (0, 1), got {train_fraction}")

    dataset = TimeSeriesDataset(df)
    train_size = int(train_fraction * len(dataset))
    val_size = len(dataset) - train_size

    split_kwargs = {} if seed is None else {"generator": torch.Generator().manual_seed(seed)}
    train_dataset, val_dataset = torch.utils.data.random_split(
        dataset, [train_size, val_size], **split_kwargs
    )

    if num_workers is None:
        num_workers = min(31, os.cpu_count() or 1)
    loader_kwargs = dict(
        batch_size=batch_size,
        num_workers=num_workers,
        pin_memory=torch.cuda.is_available(),  # pinning only helps host->GPU copies
        persistent_workers=num_workers > 0,  # don't re-spawn workers every epoch
    )
    train_loader = DataLoader(train_dataset, shuffle=True, **loader_kwargs)
    val_loader = DataLoader(val_dataset, **loader_kwargs)
    return train_loader, val_loader


# Training script


def train_model(df, original_model, num_classes=10, epochs=10, batch_size=512, seed=42):
    """Train a classification head on top of ``original_model`` and return it.

    Args:
        df: Preprocessed frame with feature columns and an integer ``best_model`` label.
        original_model: Pretrained causal-LM wrapper whose ``.model`` is the backbone.
        num_classes: Number of target classes.
        epochs: Maximum number of training epochs.
        batch_size: Batch size for both the train and validation loaders.
        seed: Global seed (Python, NumPy, torch, dataloader workers) so head
            initialisation, dropout and shuffling are reproducible. ``None`` skips seeding.

    Returns:
        The trained ``FineTunedTimerForPrediction`` module.
    """
    if seed is not None:
        pl.seed_everything(seed, workers=True)

    train_loader, val_loader = get_dataloaders(df, batch_size=batch_size)

    # Named `classifier` so it isn't confused with the notebook-global pretrained `model`.
    classifier = FineTunedTimerForPrediction(original_model, num_classes)

    # Logging setup
    logger = TensorBoardLogger("log/", name="TimerClassification")

    trainer = pl.Trainer(
        max_epochs=epochs, accelerator="gpu" if torch.cuda.is_available() else "cpu", logger=logger
    )

    trainer.fit(classifier, train_loader, val_loader)

    return classifier

In [None]:
# Keep a handle on the trained classifier; previously it was only echoed as cell output
# and could not be used for evaluation or saving afterwards.
classifier = train_model(df, model, num_classes=len(classes), epochs=100)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
HPU available: False, using: 0 HPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]

  | Name                | Type               | Params | Mode 
-------------------------------------------------------------------
0 | model               | TimerModel         | 84.0 M | eval 
1 | classification_head | Sequential         | 657 K  | train
2 | loss_function       | CrossEntropyLoss   | 0      | train
3 | train_acc           | MulticlassAccuracy | 0      | train
4 | val_acc             | MulticlassAccuracy | 0      | train
5 | val_f1              | MulticlassF1Score  | 0      | train
-------------------------------------------------------------------
657 K     Trainable params
84.0 M    Non-trainable params
84.7 M    Total params
338.807   Total estimated model params size (MB)
11        Modules in train mode
117       Modules in eval mode


Sanity Checking: |                                        | 0/? [00:00<?, ?it/s]

Training: |                                               | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]

Validation: |                                             | 0/? [00:00<?, ?it/s]