# Knowledge distillation, an alternative for finetuning Lag-Llama
This notebook demonstrates how knowledge distillation can be applied to Lag-Llama. Using the pretrained, openly available Lag-Llama model as a teacher, we train a smaller and faster student network based on the same Lag-Llama architecture. Distillation transfers knowledge from the larger teacher model to the compact student model, with the goal of approaching the teacher's forecasting performance at a much lower computational cost and inference time. With this setup, we explore whether knowledge distillation can make time series forecasting more efficient without sacrificing accuracy.

## Prepare the repository

We first clone the [GitHub repository](https://github.com/time-series-foundation-models/lag-llama/) that contains the Lag-Llama architecture and install its required packages.

In [None]:
# !git clone -b update-gluonts https://github.com/time-series-foundation-models/lag-llama/

In [None]:
#cd /content/lag-llama

In [None]:
!pip install -U -r requirements.txt  # this may take a while; errors displayed by Colab can be ignored

**Restart your runtime now, and then continue**

In [None]:
# cd /content/lag-llama

In [None]:
# from google.colab import drive
# drive.mount('/content/lag-llama/drive')

We then download the pretrained model weights from [HuggingFace](https://huggingface.co/time-series-foundation-models/Lag-Llama) 🤗

In [None]:
!huggingface-cli download time-series-foundation-models/Lag-Llama lag-llama.ckpt --local-dir .

## Imports

We import the required packages and the Lag-Llama estimator object, which we use to make predictions.

In [None]:
# Standard library imports
from itertools import islice
from math import prod
import time
from types import MethodType


# Third-party imports
import numpy as np
import pandas as pd
from matplotlib import pyplot as plt
import matplotlib.dates as mdates
from tqdm import tqdm

import torch
import torch.nn.functional as F
from torch.distributions import StudentT

from lightning.pytorch.callbacks import Callback, EarlyStopping


# GluonTS imports
from gluonts.dataset.common import ListDataset

from gluonts.model.forecast_generator import DistributionForecastGenerator
from gluonts.dataset.repository.datasets import get_dataset

from gluonts.torch.util import repeat_along_dim, take_last
from gluonts.evaluation import make_evaluation_predictions, Evaluator


# Lag-Llama imports
from lag_llama.gluon.lightning_module import LagLlamaLightningModule
from lag_llama.gluon.estimator import LagLlamaEstimator

import optuna

In [None]:
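import sys
from types import ModuleType

# Compatibility shim: the pretrained checkpoint refers to classes in
# `gluonts.torch.modules.loss` (`DistributionLoss`, `NegativeLogLikelihood`),
# a module that recent GluonTS versions may no longer provide. Registering stub
# modules and classes under that path lets `torch.load` unpickle the checkpoint.

# Create dummy module hierarchy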
def create_dummy_module(module_path):
    """
    Create a dummy module hierarchy for the given path.
    Returns the leaf module.
    """
    parts = module_path.split('.')
    current = ''
    parent = None

    for part in parts:
        current = current + '.' + part if current else part
        if current not in sys.modules:
            module = ModuleType(current)
            sys.modules[current] = module
            if parent:
                setattr(sys.modules[parent], part, module)
        parent = current

    return sys.modules[module_path]

# Create the dummy gluonts module hierarchy
gluonts_module = create_dummy_module('gluonts.torch.modules.loss')

# Create dummy classes for the specific loss functions
class DistributionLoss:
    def __init__(self, *args, **kwargs):
        pass

    def __call__(self, *args, **kwargs):
        return 0.0

    def __getattr__(self, name):
        return lambda *args, **kwargs: None

class NegativeLogLikelihood:
    def __init__(self, *args, **kwargs):
        pass

    def __call__(self, *args, **kwargs):
        return 0.0

    def __getattr__(self, name):
        return lambda *args, **kwargs: None

class Dataset:
    def __init__(self, train, test):
        self.train = train
        self.test = test

# Add the specific classes to the module
gluonts_module.DistributionLoss = DistributionLoss
gluonts_module.NegativeLogLikelihood = NegativeLogLikelihood
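
Why do stub modules make the checkpoint loadable? `pickle`, which `torch.load` relies on when `weights_only=False`, stores each object's class as a dotted module path plus a class name and re-imports that path when loading. The self-contained sketch below simulates this with a made-up module `legacy_lib.loss` and class `OldLoss` (hypothetical names, not part of GluonTS): loading fails once the module is gone, and succeeds again after a placeholder class is registered under the old path in `sys.modules`, which is the idea behind the cell above.

```python
import pickle
import sys
from types import ModuleType


def register_module(path):
    """Create (or reuse) empty modules for every prefix of a dotted path."""
    parts = path.split(".")
    for i in range(1, len(parts) + 1):
        name = ".".join(parts[:i])
        if name not in sys.modules:
            sys.modules[name] = ModuleType(name)
            if i > 1:  # attach the child module to its parent package
                setattr(sys.modules[".".join(parts[: i - 1])], parts[i - 1], sys.modules[name])
    return sys.modules[path]


class OldLoss:
    """Stands in for a class that an older library version defined (hypothetical)."""


# 1) "Old library": the class lives at legacy_lib.loss.OldLoss, so pickle records that path
OldLoss.__module__ = "legacy_lib.loss"
register_module("legacy_lib.loss").OldLoss = OldLoss
blob = pickle.dumps({"loss": OldLoss(), "lr": 1e-3})

# 2) "Library upgrade": the module disappears, so unpickling fails
del sys.modules["legacy_lib.loss"], sys.modules["legacy_lib"]
try:
    pickle.loads(blob)
    failed_without_stub = False
except ModuleNotFoundError:
    failed_without_stub = True


# 3) Register a placeholder class under the old path and load again
class PlaceholderLoss:
    def __init__(self, *args, **kwargs):
        pass


register_module("legacy_lib.loss").OldLoss = PlaceholderLoss
restored = pickle.loads(blob)
print(failed_without_stub, type(restored["loss"]).__name__, restored["lr"])
```

The restored object is an instance of the placeholder, not of the original class, so this trick only suits objects that are not actually used after loading.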

## Importing the data

We can use two datasets for training and testing:


1.   The `electricity` dataset in GluonTS (loaded with `get_dataset("electricity")`) contains hourly electricity consumption data for 321 customers over several years. Each time series represents the electricity usage of a single customer, and the dataset is commonly used for benchmarking time series forecasting models. It includes both training and test splits, allowing models to learn from past consumption patterns and predict future electricity demand.
2.   TODO




In [None]:
from gluonts.dataset.field_names import FieldName
from collections import namedtuple
from pandas.tseries.frequencies import to_offset

Dataset = namedtuple("Dataset", ["train", "validation", "test"])

def load_dataset_from_csv(
    data_csv: str,
    timestamps_csv: str,
    train_test_ratio: float = 0.8,
    train_val_ratio: float = 0.8,
    context_length: int = 168,
    prediction_length: int = 24,
):
    """Load a wide-format CSV dataset and split it by series into train/validation/test.

    Each row of `data_csv` is one time series and each column one time step;
    `timestamps_csv` has the same layout and holds the timestamp of every value.
    The first column of both files is dropped (it is assumed to be an index column).
    A fraction `train_test_ratio` of the series goes to train + validation, and
    `train_val_ratio` of that fraction goes to train; split sizes are truncated with int().
    `context_length` and `prediction_length` are currently unused.
    """
    # Load data (drop the leading index column)
    data_df = pd.read_csv(data_csv).iloc[:, 1:]
    timestamps_df = pd.read_csv(timestamps_csv).iloc[:, 1:]

    # Infer the frequency from the most common time step in the first row
    inferred_deltas = timestamps_df.iloc[0].apply(pd.to_datetime).diff().dropna()
    inferred_freq = inferred_deltas.mode()[0]
    freq_str = to_offset(inferred_freq).freqstr

    print(f"Dataset frequency: {freq_str}")

    # Keep only the first timestamp of each series: it is the series' start date
    timestamps_df = timestamps_df.iloc[:, 0].apply(pd.to_datetime)

    assert data_df.shape[0] == timestamps_df.shape[0], "data and timestamps must contain the same series"

    # Split by series (not by time): train, then validation, then test
    n_series = len(data_df)
    n_train = int(train_test_ratio * train_val_ratio * n_series)
    n_validation = int(train_test_ratio * (1 - train_val_ratio) * n_series)

    def build_split(indices, desc):
        entries = []
        for i in tqdm(indices, desc=desc):
            entries.append({
                FieldName.START: timestamps_df.iloc[i],  # first timestamp of the series
                FieldName.TARGET: data_df.iloc[i].values.astype(np.float32),
            })
        return ListDataset(entries, freq=freq_str)

    train_ds = build_split(range(n_train), "Building train dataset")
    validation_ds = build_split(range(n_train, n_train + n_validation), "Building validation dataset")
    test_ds = build_split(range(n_train + n_validation, n_series), "Building test dataset")

    return Dataset(train=train_ds, validation=validation_ds, test=test_ds)
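
`load_dataset_from_csv` expects a file layout that is not documented elsewhere in this notebook. Reading the code, it appears to be a wide layout: one row per series, one column per time step, plus a leading index column that `.iloc[:, 1:]` drops, with the timestamps file mirroring the data file cell by cell. The sketch below is an assumption based on that reading: it writes two toy files (made-up data, hypothetical file names) and repeats the loader's frequency and start-date inference using only pandas and NumPy.

```python
import os
import tempfile

import numpy as np
import pandas as pd
from pandas.tseries.frequencies import to_offset


def write_toy_csvs(folder, n_series=5, n_steps=48):
    """Write toy data/timestamp CSVs in the assumed wide layout: one row per series."""
    step = pd.Timedelta(hours=1)
    starts = pd.date_range("2019-01-01", periods=n_series, freq=pd.Timedelta(days=1))
    timestamps = [
        [str(t) for t in pd.date_range(start, periods=n_steps, freq=step)] for start in starts
    ]
    values = np.random.default_rng(0).random((n_series, n_steps)).astype(np.float32)

    data_csv = os.path.join(folder, "toy_data.csv")
    timestamps_csv = os.path.join(folder, "toy_indexes.csv")
    # index=True (the default) writes the leading index column that the loader drops
    pd.DataFrame(values).to_csv(data_csv)
    pd.DataFrame(timestamps).to_csv(timestamps_csv)
    return data_csv, timestamps_csv


def inspect_csvs(data_csv, timestamps_csv):
    """Repeat the loader's parsing steps: drop the index column, infer frequency and starts."""
    data_df = pd.read_csv(data_csv).iloc[:, 1:]
    timestamps_df = pd.read_csv(timestamps_csv).iloc[:, 1:]
    deltas = timestamps_df.iloc[0].apply(pd.to_datetime).diff().dropna()
    freq = to_offset(deltas.mode()[0])
    starts = timestamps_df.iloc[:, 0].apply(pd.to_datetime)
    return data_df.shape, freq, starts


with tempfile.TemporaryDirectory() as folder:
    shape, freq, starts = inspect_csvs(*write_toy_csvs(folder))
print(shape, freq.freqstr, starts.iloc[0])
```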


In [None]:
# Data splits ratios and context/prediction lengths
train_test_ratio = 0.7
train_val_ratio = 0.7
#prediction_length = dataset.metadata.prediction_length
#context_length = prediction_length*3
dataset_context_length = 168
dataset_prediction_length = 24

#dataset = get_dataset("electricity")
dataset = load_dataset_from_csv("datasets/data2019.csv", "datasets/data_indexes2019.csv", train_test_ratio, train_val_ratio, dataset_context_length, dataset_prediction_length)

plt.figure(figsize=(20, 15))
date_formater = mdates.DateFormatter("%Y")
plt.rcParams.update({'font.size': 15})

for idx, entry in enumerate(islice(dataset.train, 9)):
    ax = plt.subplot(3, 3, idx + 1)
    t = pd.date_range(
        start=entry["start"].to_timestamp(),
        periods=len(entry["target"]),
        freq=entry["start"].freq,
    )
    ax.plot(t, entry["target"], label=f"Series {idx}")
    ax.set_title(f"Train Sample {idx}")
    ax.xaxis.set_major_formatter(date_formater)
    plt.xticks(rotation=60)

plt.tight_layout()
plt.legend(loc="upper right")
plt.show()

## Modifying the training step

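We use a custom training step for the Lag-Llama model. It is based on the original one, but adds an extra loss component for knowledge distillation. The training step itself is not defined in this notebook; the `LagLlamaEstimator` used below exposes it through the `distillation_teacher_estimator` and `distillation_loss_weight` arguments.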
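
As an illustration only, the sketch below shows one common way to build such an objective for a forecaster with a Student's t output. It is an assumption, not the exact loss of the modified estimator: the student is trained on the negative log-likelihood of the observed target, plus a distillation term that scores samples drawn from the frozen teacher's distribution under the student's distribution (a Monte Carlo estimate of the teacher-to-student cross-entropy). A weight in [0, 1] blends the two, mirroring the role of `distillation_loss_weight`; the returned names match the `student_loss` and `distillation_loss` metrics that `LossPlotCallback` reads. The function name and `num_teacher_samples` are hypothetical.

```python
import torch
from torch.distributions import StudentT


def distillation_loss_sketch(student_distr, teacher_distr, target, weight, num_teacher_samples=16):
    """Hypothetical loss: (1 - weight) * NLL(target) + weight * cross-entropy(teacher, student).

    student_distr / teacher_distr: torch distributions with batch shape (batch, time).
    target: observed values with shape (batch, time).
    """
    # Supervised term: how well the student explains the ground truth
    student_loss = -student_distr.log_prob(target).mean()

    # Distillation term: Monte Carlo estimate of E_teacher[-log p_student(y)];
    # the teacher is frozen, so no gradient flows through its samples
    with torch.no_grad():
        teacher_samples = teacher_distr.sample((num_teacher_samples,))  # (samples, batch, time)
    distillation_loss = -student_distr.log_prob(teacher_samples).mean()

    total_loss = (1.0 - weight) * student_loss + weight * distillation_loss
    return total_loss, student_loss, distillation_loss


# Toy example with random parameters (not real model outputs)
torch.manual_seed(0)
batch, horizon = 4, 24
df = torch.full((batch, horizon), 3.0)
teacher = StudentT(df=df, loc=torch.zeros(batch, horizon), scale=torch.ones(batch, horizon))
student_loc = torch.zeros(batch, horizon, requires_grad=True)
student = StudentT(df=df, loc=student_loc, scale=torch.ones(batch, horizon))
target = torch.randn(batch, horizon)

total, nll, distill = distillation_loss_sketch(student, teacher, target, weight=0.3)
total.backward()  # gradients reach the student's parameters only
print(total.item(), nll.item(), distill.item(), student_loc.grad.shape)
```

Sampling from the teacher avoids needing a closed-form KL divergence between two Student's t distributions, which `torch.distributions` does not provide.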

## Defining custom callbacks for training

In [None]:
from IPython.display import display, update_display

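class LossPlotCallback(Callback):
    """Live-plot the logged losses at the end of every training epoch.

    Reads `train_loss` and, when available, `student_loss`/`distillation_loss`
    (distillation runs) and `val_loss` (runs with validation data).
    """

    def __init__(self):
        super().__init__()
        self.train_losses = []
        self.student_losses = []
        self.teacher_losses = []
        self.val_losses = []
        # One display id per callback instance, so plots from earlier runs
        # (e.g. previous Optuna trials) are not overwritten by later ones
        self.display_id = f"loss_plot_display_{id(self)}"
        self.display_initialized = False

    def on_train_epoch_end(self, trainer, pl_module):
        metrics = trainer.callback_metrics
        train = metrics.get("train_loss")
        student = metrics.get("student_loss")
        teacher = metrics.get("distillation_loss")
        val = metrics.get("val_loss")

        if train is None:
            return  # nothing logged yet, nothing to plot

        self.train_losses.append(train.cpu().item())
        if student is not None and teacher is not None:
            self.student_losses.append(student.cpu().item())
            self.teacher_losses.append(teacher.cpu().item())
        if val is not None:
            self.val_losses.append(val.cpu().item())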

        self._plot_losses()

    def _plot_losses(self):
        fig = plt.figure(figsize=(10, 5))
        epochs = range(len(self.train_losses))

        # Plot each line and label with last value
        total_label = f"Training ({self.train_losses[-1]:.4f})"
        plt.plot(epochs, self.train_losses, label=total_label, marker='o')

        if self.teacher_losses:
            student_label = f"Student ({self.student_losses[-1]:.4f})"
            distill_label = f"Distillation ({self.teacher_losses[-1]:.4f})"
            plt.plot(epochs, self.student_losses, label=student_label, marker='s')
            plt.plot(epochs, self.teacher_losses, label=distill_label, marker='x')

        if self.val_losses:
            val_epochs = range(len(self.val_losses))
            val_label = f"Validation ({self.val_losses[-1]:.4f})"
            plt.plot(val_epochs, self.val_losses, label=val_label, linestyle='--', marker='^', color='orange')

        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Training & Validation Losses")
        plt.legend(fontsize="small")
        plt.grid(True)
        plt.xlim(-0.5, max(len(self.train_losses) - 0.5, 10.5))
        plt.xticks(np.arange(0, max(len(self.train_losses), 11), round(len(self.train_losses) // 10) + 1))
        plt.tight_layout()
        
        # Display or update the plot
        if not self.display_initialized:
            display(fig, display_id=self.display_id)
            self.display_initialized = True
        else:
            update_display(fig, display_id=self.display_id)

        plt.close(fig)

## Importing the model weights

In [None]:
# Set fixed random seed for reproducibility
torch.manual_seed(42)

# Model context/prediction lengths
prediction_length = dataset_prediction_length
context_length = dataset_context_length

# Use Cuda if available
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Load pretrained Lag-Llama model
ckpt_path = "./lag-llama.ckpt"
ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
estimator_args = ckpt["hyper_parameters"]["model_kwargs"]

# Output distribution for the model
distr_output = "studentT"

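# Linear RoPE scaling: stretch the rotary positional embeddings when the model must handle
# more steps (context + prediction) than the context length used during pretraining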
use_rope_scaling = True
rope_scaling_arguments = {
        "type": "linear",
        "factor": max(1.0, (context_length + prediction_length) / estimator_args["context_length"]),
    }
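
The scaling factor above is simple arithmetic: the sequence length the model must handle (`context_length + prediction_length`) divided by the context length stored in the checkpoint, floored at 1. With linear scaling (also known as position interpolation), positions are divided by this factor before the rotary angles are computed, so a longer sequence is squeezed into the position range seen during pretraining. The sketch below uses hypothetical helper names and a toy head dimension, and assumes that Lag-Llama's `"linear"` option follows this scheme and that the checkpoint's context length is 32 (check `estimator_args["context_length"]`).

```python
import torch


def linear_rope_factor(context_length, prediction_length, pretrained_context_length):
    """Same formula as the notebook: positions are never stretched (factor >= 1)."""
    return max(1.0, (context_length + prediction_length) / pretrained_context_length)


def rope_angles(seq_len, head_dim, factor=1.0, base=10000.0):
    """Rotary angles theta[pos, i] = (pos / factor) * base^(-2i / head_dim)."""
    inv_freq = 1.0 / (base ** (torch.arange(0, head_dim, 2, dtype=torch.float32) / head_dim))
    positions = torch.arange(seq_len, dtype=torch.float32) / factor  # linear scaling
    return torch.outer(positions, inv_freq)  # shape (seq_len, head_dim // 2)


factor = linear_rope_factor(168, 24, pretrained_context_length=32)  # assumed checkpoint value
scaled = rope_angles(168 + 24, head_dim=16, factor=factor)
print(factor, scaled.shape, scaled[-1, 0].item())
```

Under that assumption the factor is 6, so the last of the 192 positions maps to about 31.8, just inside the pretrained range of 0 to 31.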

## Initializing the models

In [None]:
# Instantiate the Lag-Llama Estimator
lag_llama_estimator = LagLlamaEstimator(
        ckpt_path=ckpt_path,
        prediction_length=prediction_length,
        context_length=context_length,
        distr_output=distr_output,

        input_size=estimator_args["input_size"],
        n_layer=estimator_args["n_layer"],
        n_embd_per_head=estimator_args["n_embd_per_head"],
        n_head=estimator_args["n_head"],
        scaling=estimator_args["scaling"],
        time_feat=estimator_args["time_feat"],
        rope_scaling=rope_scaling_arguments if use_rope_scaling else None,

        nonnegative_pred_samples=False,
        #use_single_pass_sampling=True,
        aug_prob=0,
        device=device
    )

lag_llama_lightning_module = lag_llama_estimator.create_lightning_module()
lag_llama_transformation = lag_llama_estimator.create_transformation()
lag_llama_predictor = lag_llama_estimator.create_predictor(lag_llama_transformation, lag_llama_lightning_module)

## Training/Finetuning the models

In [None]:
def train_estimator(estimator, train_data, validation_data):
    start_time = time.time()
    trainOutput = estimator.train_model(training_data=train_data, validation_data=validation_data, cache_data=True, shuffle_buffer_length=1000)
    end_time = time.time()
    return trainOutput, end_time - start_time
    

def create_optuna_objective( 
        # Dataset
        train_data, 
        val_data,
        dataset_prediction_length,
        dataset_context_length,

        # Fixed training hyperparameters
        num_parallel_samples=100,
        max_epochs=100,
        batch_size=32,

        lr_list = [1e-4, 5e-4, 1e-3, 5e-3, 1e-2, 5e-2],

        # Model hyperparameters
        device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
        ckpt_path="./lag-llama.ckpt",
        distr_output = "studentT",
        use_rope_scaling = True,
        
        # Model selection
        distillation_teacher_estimator=None,
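        ):
    """Build an Optuna objective that trains a half-size Lag-Llama student.

    Each trial samples a learning rate from `lr_list` and, when a
    `distillation_teacher_estimator` is given, a distillation loss weight in
    [0.01, 0.99]. The objective returns the validation loss of the last epoch.
    """
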
    def objective(trial):

        prediction_length = dataset_prediction_length
        context_length = dataset_context_length

        # Batch size and other hyperparameters
        lr = trial.suggest_categorical("learning_rate", lr_list)

        if distillation_teacher_estimator is not None:
            distillation_loss_weight = trial.suggest_float("distillation_loss_weight", 0.01, 0.99)

        # Load pretrained Lag-Llama model
        ckpt = torch.load(ckpt_path, map_location=device, weights_only=False)
        estimator_args = ckpt["hyper_parameters"]["model_kwargs"]
        
        rope_scaling_arguments = {
            "type": "linear",
            "factor": max(1.0, (context_length + prediction_length) / estimator_args["context_length"]),
        }

        estimator = LagLlamaEstimator(
            ckpt_path=None,
            prediction_length=prediction_length,
            context_length=context_length,
            distr_output=distr_output,

            input_size=estimator_args["input_size"],
            n_layer=round(estimator_args["n_layer"]/2),
            n_embd_per_head=round(estimator_args["n_embd_per_head"]/2),
            n_head=round(estimator_args["n_head"]/2),
            scaling=estimator_args["scaling"],
            time_feat=estimator_args["time_feat"],
            rope_scaling=rope_scaling_arguments if use_rope_scaling else None,

            nonnegative_pred_samples=False,
            #use_single_pass_sampling=True,
            aug_prob=0,
            device=device,

            batch_size=batch_size,
            lr=lr,
            num_parallel_samples=num_parallel_samples,
            trainer_kwargs = { # <- lightning trainer arguments
                "max_epochs": max_epochs,
                "callbacks": [
                    LossPlotCallback(),
                    optuna.integration.PyTorchLightningPruningCallback(trial, monitor="val_loss"),
                ],
                "enable_progress_bar": True,
            },

            distillation_teacher_estimator=distillation_teacher_estimator,
            distillation_loss_weight=distillation_loss_weight if distillation_teacher_estimator is not None else 0.0,
        )

        print("\n--------------------------------------------------------------")
        print(f"Trial {trial.number}:")
        for key, value in trial.params.items():
            print(f"  {key}: {value}")
        print("--------------------------------------------------------------\n\n")

        # Train the model
        trainOutput, training_time = train_estimator(estimator, train_data, val_data)
        print(f"  Training time: {training_time:.2f} seconds")

        return trainOutput.trainer.callback_metrics["val_loss"].cpu().item()  # Ensure the metric is on CPU
    
    return objective

In [None]:
sampler =  optuna.samplers.TPESampler() # optuna.samplers.GridSampler(search_space)

# optuna.pruners.MedianPruner()
pruner = optuna.pruners.PatientPruner(None, patience=5, min_delta=1e-5) 
# pruner = optuna.pruners.PatientPruner(
#     wrapped_pruner=optuna.pruners.SuccessiveHalvingPruner(
#         min_resource=5,         # allow models to train at least 5 epochs
#         reduction_factor=4,     # prune ~75% of weak trials
#         min_early_stopping_rate=0  # enable pruning from trial 0
#     ),
#     patience=5,                 # allow some slack in improvement
#     min_delta=1e-5              # very small improvement required
# )

study = optuna.create_study(sampler=sampler, pruner=pruner, direction="minimize", study_name="Distillation hyperparameter study")  # or "maximize" for accuracy
objective = create_optuna_objective(
    train_data=dataset.train,
    val_data=dataset.validation,
    dataset_prediction_length=dataset_prediction_length,
    dataset_context_length=dataset_context_length,
    distillation_teacher_estimator=lag_llama_estimator,  # Use the pretrained model for distillation
)
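
The `PatientPruner` above has no wrapped pruner, so it acts as a patience-based early stop inside Optuna. The function below is a plain-Python paraphrase of how we read Optuna's implementation for a minimization study (`should_prune` is a hypothetical name; verify the rule against the documentation of your installed Optuna version): it compares the best value reported before the last `patience + 1` steps with the best value within them, and prunes only if the recent best is worse by more than `min_delta`. Note also that `study.best_trial` only considers completed trials, so a trial stopped this way does not compete for the best trial.

```python
def should_prune(intermediate_values, patience=5, min_delta=1e-5):
    """Patience rule for a minimized metric, e.g. val_loss reported once per epoch.

    intermediate_values: values reported at consecutive steps (oldest first).
    """
    # Not enough history yet: never prune
    if len(intermediate_values) <= patience + 1:
        return False
    best_before = min(intermediate_values[: -patience - 1])  # older than the patience window
    best_recent = min(intermediate_values[-patience - 1 :])  # the last patience + 1 values
    # Prune only when the recent best is clearly worse than the earlier best
    return best_before + min_delta < best_recent


improving = [1.0, 0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3]
overfitting = [1.0, 0.5, 0.6, 0.6, 0.7, 0.7, 0.8, 0.8]
flat = [1.0, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5, 0.5]
print(should_prune(improving), should_prune(overfitting), should_prune(flat))
```

Under this reading, a perfectly flat plateau is not pruned; only a recent best that is worse than the earlier best by more than `min_delta` triggers pruning.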

In [None]:
study_start_time = time.time()
study.optimize(objective, n_trials=100)
study_end_time = time.time()
print(f"Study completed in {study_end_time - study_start_time:.2f} seconds")

In [None]:
best_trial = study.best_trial
print("Best trial:")
print("  lr:", best_trial.params["learning_rate"])
print("  distillation_weight:", best_trial.params["distillation_loss_weight"])
print("  val_loss:", best_trial.value)

optuna.visualization.plot_contour(study)

In [None]:
# Batch size and other hyperparameters
batch_size = 32
num_parallel_samples = 100
max_epochs = 100
early_stopping_patience = 5
early_stopping_delta = 1e-5

# Extract the best hyperparameters from the study
lr = best_trial.params.get("learning_rate")
distillation_loss_weight = best_trial.params.get("distillation_loss_weight", 0.0)


# Instantiate the Small model without Distillation
lag_llama_small_estimator = LagLlamaEstimator(
        ckpt_path=None,
        prediction_length=prediction_length,
        context_length=context_length,
        distr_output=distr_output,

        input_size=estimator_args["input_size"],
        n_layer=round(estimator_args["n_layer"]/2),
        n_embd_per_head=round(estimator_args["n_embd_per_head"]/2),
        n_head=round(estimator_args["n_head"]/2),
        scaling=estimator_args["scaling"],
        time_feat=estimator_args["time_feat"],
        rope_scaling=rope_scaling_arguments if use_rope_scaling else None,

        nonnegative_pred_samples=False,
        #use_single_pass_sampling=True,
        aug_prob=0,
        device=device,

        batch_size=batch_size,
        lr=lr,
        num_parallel_samples=num_parallel_samples,
        trainer_kwargs = { # <- lightning trainer arguments
            "max_epochs": max_epochs,
            "callbacks": [
                EarlyStopping(monitor="val_loss", min_delta=early_stopping_delta, 
                              patience=early_stopping_patience, mode="min"),
                LossPlotCallback(),
            ],
            "enable_progress_bar": True,
            }
    )

# Instantiate the Distilled model
lag_llama_distilled_estimator = LagLlamaEstimator(
        ckpt_path=None,
        prediction_length=prediction_length,
        context_length=context_length,
        distr_output=distr_output,

        input_size=estimator_args["input_size"],
        n_layer=round(estimator_args["n_layer"]/2),
        n_embd_per_head=round(estimator_args["n_embd_per_head"]/2),
        n_head=round(estimator_args["n_head"]/2),
        scaling=estimator_args["scaling"],
        time_feat=estimator_args["time_feat"],
        rope_scaling=rope_scaling_arguments if use_rope_scaling else None,

        nonnegative_pred_samples=False,
        #use_single_pass_sampling=True,
        aug_prob=0,
        device=device,

        batch_size=batch_size,
        lr=lr,
        num_parallel_samples=num_parallel_samples,
        trainer_kwargs = { # <- lightning trainer arguments
            "max_epochs": max_epochs,
            "callbacks": [
                EarlyStopping(monitor="val_loss", min_delta=early_stopping_delta, 
                              patience=early_stopping_patience, mode="min"),
                LossPlotCallback(),
            ],
            "enable_progress_bar": True,
            },
        distillation_teacher_estimator=lag_llama_estimator,
        distillation_loss_weight=distillation_loss_weight,
    )
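
Every student in this notebook halves three of the teacher's hyperparameters with `round(x / 2)`. Two details are easy to miss: Python 3's `round` rounds halves to the nearest even integer, so an odd value such as 9 becomes 4 rather than 5, and, assuming the model width is `n_embd_per_head * n_head` as the parameter names suggest, halving both shrinks the width roughly fourfold. The hypothetical helper below makes this explicit. The teacher values used (`n_layer=8`, `n_head=9`, `n_embd_per_head=16`) are our assumption about the released checkpoint; print `estimator_args` to confirm them.

```python
def student_config(teacher_kwargs):
    """Mirror the notebook's round(x / 2) rule and report teacher/student widths."""
    student = {key: round(teacher_kwargs[key] / 2) for key in ("n_layer", "n_embd_per_head", "n_head")}
    # Assumed width definition: embedding size per head times number of heads
    teacher_width = teacher_kwargs["n_embd_per_head"] * teacher_kwargs["n_head"]
    student_width = student["n_embd_per_head"] * student["n_head"]
    return student, teacher_width, student_width


# Assumed teacher hyperparameters (check estimator_args in the notebook)
teacher_kwargs = {"n_layer": 8, "n_embd_per_head": 16, "n_head": 9}
student, teacher_width, student_width = student_config(teacher_kwargs)
print(student, teacher_width, student_width)
```

With these assumed values the student has 4 layers, 4 heads and a width of 32, against a width of 144 for the teacher.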

In [None]:
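# lag_llama_small_output, lag_llama_small_training_time = train_estimator(lag_llama_small_estimator, dataset.train, dataset.validation)
# lag_llama_small_predictor = lag_llama_small_output.predictor  # train_model returns a TrainOutput; keep its predictor

In [None]:
# lag_llama_distilled_output, lag_llama_distilled_training_time = train_estimator(lag_llama_distilled_estimator, dataset.train, dataset.validation)
# lag_llama_distilled_predictor = lag_llama_distilled_output.predictor

In [None]:
# lag_llama_finetuned_output, lag_llama_finetuned_training_time = train_estimator(lag_llama_estimator, dataset.train, dataset.validation)
# lag_llama_finetuned_predictor = lag_llama_finetuned_output.predictor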

## Evaluating the models

In [None]:
# def graph_predictions(dataset, predictor, num_samples):
#     forecasts_it, tss_it = make_evaluation_predictions(
#         dataset=dataset, predictor=predictor, num_samples=num_samples
#     )

#     forecasts = list(tqdm(forecasts_it, total=len(dataset), desc="Forecasting batches"))
#     tss = list(tqdm(tss_it, total=len(dataset), desc="Ground truth"))

#     # Plot the first 9 time series
#     plt.figure(figsize=(20, 15))
#     date_formater = mdates.DateFormatter('%b, %d')
#     plt.rcParams.update({'font.size': 15})

#     # Iterate through the first 9 series, and plot the predicted samples
#     for idx, (forecast, ts) in islice(enumerate(zip(forecasts, tss)), 9):
#         ax = plt.subplot(3, 3, idx+1)

#         plt.plot(ts[-4 * prediction_length:].to_timestamp(), label="target")
#         forecast.plot(color='g')
#         plt.xticks(rotation=60)
#         ax.xaxis.set_major_formatter(date_formater)
#         ax.set_title(forecast.item_id)

#     plt.gcf().tight_layout()
#     plt.legend()
#     plt.show()
#     return forecasts, tss

In [None]:
# forecasts_small, tss_small = graph_predictions(dataset.test, lag_llama_small_predictor, num_parallel_samples)

In [None]:
# forecasts_dis, tss_dis = graph_predictions(dataset.test, lag_llama_distilled_predictor, num_parallel_samples)

In [None]:
# forecasts_big, tss_big = graph_predictions(dataset.test, lag_llama_predictor, num_parallel_samples)

In [None]:
# forecasts_finetuned, tss_finetuned = graph_predictions(dataset.test, lag_llama_finetuned_predictor, num_parallel_samples)

## Getting some metrics

In [None]:
# evaluator = Evaluator(quantiles=[0.1, 0.5, 0.9])

# metrics_small, _ = evaluator(tss_small, forecasts_small)
# metrics_dis, _ = evaluator(tss_dis, forecasts_dis)
# df_metrics_small = pd.DataFrame.from_records(metrics_small, index=["Small Model"]).transpose()
# df_metrics_small.loc["Training time (sec)"] = lag_llama_small_training_time
# df_metrics_dis = pd.DataFrame.from_records(metrics_dis, index=["Distilled Model"]).transpose()
# df_metrics_dis.loc["Training time (sec)"] = lag_llama_distilled_training_time
# display(pd.concat([df_metrics_small, df_metrics_dis], axis=1))

In [None]:
# metrics_big, _ = evaluator(tss_big, forecasts_big)
# metrics_finetuned, _ = evaluator(tss_finetuned, forecasts_finetuned)

# df_metrics_big = pd.DataFrame.from_records(metrics_big, index=["Big Pretrained Model"]).transpose()
# df_metrics_big.loc["Training time (sec)"] = 0.0
# df_metrics_finetuned = pd.DataFrame.from_records(metrics_finetuned, index=["Big Finetuned Model"]).transpose()
# df_metrics_finetuned.loc["Training time (sec)"] = lag_llama_finetuned_training_time

# display(pd.concat([df_metrics_big, df_metrics_finetuned], axis=1))

In [None]:
# # Display all metrics
# df_metrics = pd.concat([df_metrics_small, df_metrics_dis, df_metrics_big, df_metrics_finetuned], axis=1)
# display(df_metrics)
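
The metric tables above record training time only, while the introduction motivates distillation by faster inference. One way to check that claim is to time each predictor on the test set. The helper below is a generic sketch (`time_inference` is a hypothetical name). Because GluonTS predictors return forecasts lazily from `predict`, the callable should consume the iterator, for example `lambda: list(lag_llama_predictor.predict(dataset.test))`.

```python
import time
from statistics import median


def time_inference(predict_fn, repeats=3, warmup=1):
    """Return the median wall-clock time in seconds of predict_fn() over `repeats` runs.

    `warmup` untimed calls run first, so one-off costs (e.g. lazy initialisation)
    are not counted.
    """
    for _ in range(warmup):
        predict_fn()
    timings = []
    for _ in range(repeats):
        start = time.perf_counter()
        predict_fn()
        timings.append(time.perf_counter() - start)
    return median(timings)


# Toy usage: time a pure-Python workload (replace with a predictor call in the notebook)
calls = []
elapsed = time_inference(lambda: calls.append(sum(range(100_000))))
print(f"{elapsed:.6f} s over {len(calls)} calls")
```

Using the median of a few runs makes the comparison between the small, distilled and full models less sensitive to one-off slowdowns.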