In [6]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import os
from sklearn.preprocessing import MinMaxScaler
import sys
sys.path.append("..")  # Add parent directory to the system path
sys.path.append("/mnt/home/network-predictive-analysis/")
import config
import torch

from neuralforecast import NeuralForecast
from neuralforecast.models import LSTM, NHITS, RNN
# %%capture
from neuralforecast.core import NeuralForecast
from neuralforecast.models import TSMixer, TSMixerx, NHITS, MLPMultivariate
from neuralforecast.losses.pytorch import MSE, MAE
from neuralforecast import NeuralForecast, DistributedConfig

# Toggle between a Google Colab session (project lives on the mounted Drive)
# and a local checkout of the project.
COLAB=False

if COLAB:
    from google.colab import drive
    drive.mount('/content/drive')
    proj_path = config.drive_path
else:
    # NOTE(review): absolute, machine-specific path — consider moving it into `config`.
    proj_path = "/mnt/home/network-predictive-analysis/"

# Raw and processed data directories. `exist_ok=True` makes the creation
# idempotent and avoids the check-then-create race of a separate
# `os.path.exists` test.
data_path = os.path.join(proj_path, config.data_path)
os.makedirs(data_path, exist_ok=True)

proc_data_path = os.path.join(proj_path, config.processed_data_path)
os.makedirs(proc_data_path, exist_ok=True)

# Input files, resolved inside the raw data directory.
application_data_file = os.path.join(data_path, config.application_data_filename)
router_data_file = os.path.join(data_path, config.router_data_filename)

#Y_df = pd.read_parquet(os.path.join(proj_path, config.processed_data_path, "intermediate_batches/Y_df_multivariate_0.parquet"))


# TimeLLM multivariate

#### TimeLLM model definition

In [7]:
#| export
import math
from typing import Optional

import torch
import torch.nn as nn

from neuralforecast.common._base_windows import BaseWindows
from neuralforecast.common._base_multivariate import BaseMultivariate
from neuralforecast.common._modules import RevIN

from neuralforecast.losses.pytorch import MAE

try:
    from transformers import AutoModel, AutoTokenizer, AutoConfig
    IS_TRANSFORMERS_INSTALLED = True
except ImportError:
    IS_TRANSFORMERS_INSTALLED = False

import warnings

class ReplicationPad1d(nn.Module):
    """
    Right-pads the last dimension by repeating the final element.

    Only the last entry of `padding` is used: it is the number of copies of
    the trailing value appended along the last axis. The input is indexed as
    a 3-D tensor.
    """
    def __init__(self, padding):
        super().__init__()
        self.padding = padding

    def forward(self, input):
        # Keep the trailing element as a width-1 slice, then tile it.
        n_extra = self.padding[-1]
        tail = input[:, :, -1:]
        return torch.cat((input, tail.repeat(1, 1, n_extra)), dim=-1)
    
class TokenEmbedding(nn.Module):
    """
    Embeds each position of a sequence with a bias-free, circularly padded
    1-D convolution (kernel size 3) from `c_in` to `d_model` channels.
    """
    def __init__(self, c_in, d_model):
        super().__init__()
        # NOTE(review): if `torch.__version__` is a plain `str` this is a
        # lexicographic comparison — confirm for the torch builds in use.
        if torch.__version__ >= '1.5.0':
            padding = 1
        else:
            padding = 2
        self.tokenConv = nn.Conv1d(
            in_channels=c_in,
            out_channels=d_model,
            kernel_size=3,
            padding=padding,
            padding_mode='circular',
            bias=False,
        )
        # The convolution is the only Conv1d in this module, so initialise it directly.
        nn.init.kaiming_normal_(
            self.tokenConv.weight, mode='fan_in', nonlinearity='leaky_relu')

    def forward(self, x):
        # Conv1d convolves over the last axis, so swap the last two axes of
        # the input before the convolution and swap them back afterwards.
        channels_first = x.permute(0, 2, 1)
        embedded = self.tokenConv(channels_first)
        return embedded.transpose(1, 2)
    
class PatchEmbedding(nn.Module):
    """
    Cuts each series into overlapping patches and embeds every patch.

    The last axis is first extended by `stride` replicated values, then
    unfolded into windows of `patch_len` taken every `stride` steps. Each
    window is embedded to `d_model` features by `TokenEmbedding`, followed
    by dropout.
    """
    def __init__(self, d_model, patch_len, stride, dropout):
        super().__init__()
        self.patch_len = patch_len
        self.stride = stride

        # Patching: replicate the last value `stride` times on the right.
        self.padding_patch_layer = ReplicationPad1d((0, stride))
        # Input encoding: project each patch onto a d_model-dimensional space.
        self.value_embedding = TokenEmbedding(patch_len, d_model)
        # Residual dropout applied to the embedded patches.
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Return `(embedded_patches, n_vars)` where `n_vars` is `x.shape[1]`."""
        n_vars = x.shape[1]

        padded = self.padding_patch_layer(x)
        patches = padded.unfold(dimension=-1, size=self.patch_len, step=self.stride)
        # Merge the first two axes so every (sample, variable) pair is one row.
        patches = torch.flatten(patches, start_dim=0, end_dim=1)

        embedded = self.value_embedding(patches)
        return self.dropout(embedded), n_vars
    
class FlattenHead(nn.Module):
    """
    Output head: merges the last two axes, applies a linear map from `nf`
    to `target_window` features, then dropout.
    """
    def __init__(self, n_vars, nf, target_window, head_dropout=0):
        super().__init__()
        self.n_vars = n_vars  # stored only; not used in forward
        self.flatten = nn.Flatten(start_dim=-2)
        self.linear = nn.Linear(nf, target_window)
        self.dropout = nn.Dropout(head_dropout)

    def forward(self, x):
        # flatten -> linear -> dropout, applied in sequence
        return self.dropout(self.linear(self.flatten(x)))
    
class HiddenStateProjection(nn.Module):
    """
    Projects flattened features to a hidden vector of size `hidden_size`.

    The last two axes of the input are merged and passed through a linear
    layer whose input width is `d_ff`, so the merged width must equal `d_ff`.
    """
    def __init__(self, d_ff, hidden_size):
        super().__init__()
        self.flatten = nn.Flatten(start_dim=-2)
        self.linear = nn.Linear(d_ff, hidden_size)

    def forward(self, x):
        # Merge the last two axes, then map them into the hidden space.
        flat = self.flatten(x)
        hidden = self.linear(flat)
        return hidden

    
class ReprogrammingLayer(nn.Module):
    """
    Multi-head cross-attention from patch embeddings to a set of source
    (prototype) embeddings.

    Queries come from `target_embedding` (width `d_model`); keys and values
    come from `source_embedding` / `value_embedding` (width `d_llm`). The
    attended result is projected back to width `d_llm`.
    """
    def __init__(self, d_model, n_heads, d_keys=None, d_llm=None, attention_dropout=0.1):
        super().__init__()

        # Per-head width defaults to an even split of d_model across heads.
        if not d_keys:
            d_keys = d_model // n_heads
        inner_dim = d_keys * n_heads

        self.query_projection = nn.Linear(d_model, inner_dim)
        self.key_projection = nn.Linear(d_llm, inner_dim)
        self.value_projection = nn.Linear(d_llm, inner_dim)
        self.out_projection = nn.Linear(inner_dim, d_llm)
        self.n_heads = n_heads
        self.dropout = nn.Dropout(attention_dropout)

    def forward(self, target_embedding, source_embedding, value_embedding):
        """Attend from a 3-D target to 2-D sources; returns (B, L, d_llm)."""
        n_batch, n_patches, _ = target_embedding.shape
        n_source, _ = source_embedding.shape
        n_heads = self.n_heads

        # Split each projection into heads along a new axis.
        queries = self.query_projection(target_embedding).view(n_batch, n_patches, n_heads, -1)
        keys = self.key_projection(source_embedding).view(n_source, n_heads, -1)
        values = self.value_projection(value_embedding).view(n_source, n_heads, -1)

        attended = self.reprogramming(queries, keys, values)
        # Merge the head axis back into the feature axis before projecting out.
        merged = attended.reshape(n_batch, n_patches, -1)
        return self.out_projection(merged)

    def reprogramming(self, target_embedding, source_embedding, value_embedding):
        """Scaled dot-product attention over the source axis, per head."""
        _, _, _, head_dim = target_embedding.shape
        scale = 1. / math.sqrt(head_dim)

        scores = torch.einsum("blhe,she->bhls", target_embedding, source_embedding)
        weights = torch.softmax(scale * scores, dim=-1)
        weights = self.dropout(weights)

        return torch.einsum("bhls,she->blhe", weights, value_embedding)
    

    #| export

class TimeLLM(BaseWindows):

    """ TimeLLM (hidden-state variant)

    Time-LLM is a reprogramming framework to repurpose an off-the-shelf LLM for time series forecasting.

    It trains a reprogramming layer that translates the observed series into a language task. This is fed to the LLM.
    In this variant the usual forecasting head is replaced by a `HiddenStateProjection`, so `forward` / `forecast`
    return a hidden vector of size `hidden_size` per window instead of `h` numerical predictions.
    `forward` takes the in-sample tensor directly and adds a trailing variable axis of size 1, and the final
    reshape in `forecast` only succeeds for a single variable per window.

    **Parameters:**
    `h`: int, Forecast horizon (used in the prompt text).
    `input_size`: int, autoregressive inputs size, y=[1,2,3,4] input_size=2 -> y_[t-2:t]=[1,2].
    `patch_len`: int=16, length of patch.
    `stride`: int=8, stride of patch.
    `d_ff`: int=128, dimension of fcn (number of leading LLM hidden features kept).
    `top_k`: int=5, number of top lags to report in the prompt.
    `d_llm`: int=768, hidden dimension of LLM. # LLama7b:4096; GPT2-small:768; BERT-base:768
    `d_model`: int=32, dimension of model.
    `n_heads`: int=8, number of heads in attention layer.
    `enc_in`: int=7, encoder input size.
    `dec_in`: int=7, decoder input size.
    `llm` = None, Path to pretrained LLM model to use. If not specified, it will use GPT-2 from https://huggingface.co/openai-community/gpt2
    `llm_config` = Deprecated and ignored; the configuration is loaded from the selected model.
    `llm_tokenizer` = Deprecated and ignored; the tokenizer is loaded from the selected model.
    `llm_num_hidden_layers` = 32, hidden layers in LLM (stored only).
    `llm_output_attention`: bool = True, whether to output attention in encoder (stored only).
    `llm_output_hidden_states`: bool = True, whether to output hidden states (stored only).
    `prompt_prefix`: str=None, prompt to inform the LLM about the dataset.
    `dropout`: float=0.1, dropout rate.
    `stat_exog_list`: str list, static exogenous columns.
    `hist_exog_list`: str list, historic exogenous columns.
    `futr_exog_list`: str list, future exogenous columns.
    `loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
    `valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
    `learning_rate`: float=1e-4, Learning rate between (0, 1).
    `max_steps`: int=5, maximum number of training steps.
    `num_lr_decays`: int=0, Number of learning rate decays, evenly distributed across max_steps.
    `early_stop_patience_steps`: int=-1, Number of validation iterations before early stopping.
    `val_check_steps`: int=100, Number of training steps between every validation loss check.
    `batch_size`: int=32, number of different series in each batch.
    `valid_batch_size`: int=None, number of different series in each validation and test batch, if None uses batch_size.
    `windows_batch_size`: int=1024, number of windows to sample in each training batch, default uses all.
    `inference_windows_batch_size`: int=1024, number of windows to sample in each inference batch.
    `start_padding_enabled`: bool=False, if True, the model will pad the time series with zeros at the beginning, by input size.
    `step_size`: int=1, step size between each window of temporal data.
    `scaler_type`: str='identity', type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).
    `random_seed`: int=1, random_seed for pytorch initializer and numpy generators.
    `num_workers_loader`: int=0, workers to be used by `TimeSeriesDataLoader`.
    `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
    `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
    `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.
    `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).
    `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.
    `dataloader_kwargs`: dict, optional. Accepted for signature compatibility but currently NOT forwarded to the base class.
    `hidden_size`: int=128, size of the hidden vector returned by `forward` / `forecast`.

    Note: `alias` and `**trainer_kwargs` are not accepted by this variant (the pass-through is commented out).

    **References:**
    -[Ming Jin, Shiyu Wang, Lintao Ma, Zhixuan Chu, James Y. Zhang, Xiaoming Shi, Pin-Yu Chen, Yuxuan Liang, Yuan-Fang Li, Shirui Pan, Qingsong Wen. "Time-LLM: Time Series Forecasting by Reprogramming Large Language Models"](https://arxiv.org/abs/2310.01728)

    """

    SAMPLING_TYPE = 'windows'
    EXOGENOUS_FUTR = False
    EXOGENOUS_HIST = False
    EXOGENOUS_STAT = False

    def __init__(self,
                 h,
                 input_size,
                 patch_len: int = 16,
                 stride: int = 8,
                 d_ff: int = 128,
                 top_k: int = 5,
                 d_llm: int = 768,
                 d_model: int = 32,
                 n_heads: int = 8,
                 enc_in: int = 7,
                 dec_in: int  = 7,
                 llm = None,
                 llm_config = None,
                 llm_tokenizer = None,
                 llm_num_hidden_layers = 32,
                 llm_output_attention: bool = True,
                 llm_output_hidden_states: bool = True,
                 prompt_prefix: Optional[str] = None,
                 dropout: float = 0.1,
                 stat_exog_list = None,
                 hist_exog_list = None,
                 futr_exog_list = None,
                 loss = MAE(),
                 valid_loss = None,
                 learning_rate: float = 1e-4,
                 max_steps: int = 5,
                 val_check_steps: int = 100,
                 batch_size: int = 32,
                 valid_batch_size: Optional[int] = None,
                 windows_batch_size: int = 1024,
                 inference_windows_batch_size: int = 1024,
                 start_padding_enabled: bool = False,
                 step_size: int = 1,
                 num_lr_decays: int = 0,
                 early_stop_patience_steps: int = -1,
                 scaler_type: str = 'identity',
                 num_workers_loader: int = 0,
                 drop_last_loader: bool = False,
                 random_seed: int = 1,
                 optimizer = None,
                 optimizer_kwargs = None,
                 lr_scheduler = None,
                 lr_scheduler_kwargs = None,
                 dataloader_kwargs = None,
                 #**trainer_kwargs)
                 
                 hidden_size = 128,
                ):
        super(TimeLLM, self).__init__(h=h,
                                      input_size=input_size,
                                      hist_exog_list=hist_exog_list,
                                      stat_exog_list=stat_exog_list,
                                      futr_exog_list = futr_exog_list,
                                      loss=loss,
                                      valid_loss=valid_loss,
                                      max_steps=max_steps,
                                      learning_rate=learning_rate,
                                      num_lr_decays=num_lr_decays,
                                      early_stop_patience_steps=early_stop_patience_steps,
                                      val_check_steps=val_check_steps,
                                      batch_size=batch_size,
                                      valid_batch_size=valid_batch_size,
                                      windows_batch_size=windows_batch_size,
                                      inference_windows_batch_size=inference_windows_batch_size,
                                      start_padding_enabled=start_padding_enabled,
                                      step_size=step_size,
                                      scaler_type=scaler_type,
                                      num_workers_loader=num_workers_loader,
                                      drop_last_loader=drop_last_loader,
                                      random_seed=random_seed,
                                      optimizer=optimizer,
                                      optimizer_kwargs=optimizer_kwargs,
                                      lr_scheduler=lr_scheduler,
                                      lr_scheduler_kwargs=lr_scheduler_kwargs,
                                      #dataloader_kwargs=dataloader_kwargs,
                                      #**trainer_kwargs
                                      )
        
        # Architecture
        self.patch_len = patch_len
        self.stride = stride
        self.d_ff = d_ff
        self.top_k = top_k
        self.d_llm = d_llm
        self.d_model = d_model
        self.dropout = dropout
        self.n_heads = n_heads
        self.enc_in = enc_in
        self.dec_in = dec_in

        DEFAULT_MODEL = "openai-community/gpt2"

        if llm is None:
            if not IS_TRANSFORMERS_INSTALLED:
                raise ImportError(
                    "Please install `transformers` to use the default LLM."
                )
                  
            print(f"Using {DEFAULT_MODEL} as default.")
            model_name = DEFAULT_MODEL
        else:
            model_name = llm

        if llm_config is not None or llm_tokenizer is not None:
            warnings.warn("'llm_config' and 'llm_tokenizer' parameters are deprecated and will be ignored. "
                        "The config and tokenizer will be automatically loaded from the specified model.", 
                        DeprecationWarning)

        # Load config, model and tokenizer from the same checkpoint; fall back
        # to the default model when the requested one cannot be loaded.
        try:
            self.llm_config = AutoConfig.from_pretrained(model_name)
            self.llm = AutoModel.from_pretrained(model_name, config=self.llm_config)
            self.llm_tokenizer = AutoTokenizer.from_pretrained(model_name)
            print(f"Successfully loaded model: {model_name}")
        except EnvironmentError:
            print(f"Failed to load {model_name}. Loading the default model ({DEFAULT_MODEL})...")
            self.llm_config = AutoConfig.from_pretrained(DEFAULT_MODEL)
            self.llm = AutoModel.from_pretrained(DEFAULT_MODEL, config=self.llm_config)
            self.llm_tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL)

        self.llm_num_hidden_layers = llm_num_hidden_layers
        self.llm_output_attention = llm_output_attention
        self.llm_output_hidden_states = llm_output_hidden_states
        # NOTE(review): when `prompt_prefix` is None the prompt f-string below
        # renders the literal text "None" — confirm whether that is intended.
        self.prompt_prefix = prompt_prefix

        # Batched tokenization uses padding=True, so a pad token must exist.
        if self.llm_tokenizer.eos_token:
            self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token
        else:
            pad_token = '[PAD]'
            self.llm_tokenizer.add_special_tokens({'pad_token': pad_token})
            self.llm_tokenizer.pad_token = pad_token

        # The LLM backbone is frozen; only the layers defined below are trained.
        for param in self.llm.parameters():
            param.requires_grad = False

        self.patch_embedding = PatchEmbedding(
            self.d_model, self.patch_len, self.stride, self.dropout)
        
        # Map the full vocabulary embedding matrix down to `num_tokens` prototypes.
        self.word_embeddings = self.llm.get_input_embeddings().weight
        self.vocab_size = self.word_embeddings.shape[0]
        self.num_tokens = 1024
        self.mapping_layer = nn.Linear(self.vocab_size, self.num_tokens)

        self.reprogramming_layer = ReprogrammingLayer(self.d_model, self.n_heads, self.d_ff, self.d_llm)

        self.patch_nums = int((input_size - self.patch_len) / self.stride + 2)
        self.head_nf = self.d_ff * self.patch_nums

        # Hidden-state head. `forecast` hands it a tensor whose last two axes
        # are (d_ff, patch_nums); after flattening the width is
        # d_ff * patch_nums == head_nf, so the linear layer must be sized with
        # head_nf (sizing it with d_ff only works when patch_nums == 1).
        # The forecasting FlattenHead that used to be built here was
        # immediately overwritten and is no longer constructed; as a side
        # effect the seeded initial weights of this head differ from before.
        self.output_projection = HiddenStateProjection(
            d_ff=self.head_nf, hidden_size=hidden_size
        )
        self.hidden_size = hidden_size

        self.normalize_layers = RevIN(self.enc_in, affine=False)

    def _build_prompts(self, x_enc):
        """Build one text prompt per row of `x_enc` from its summary statistics.

        `x_enc` is the normalized input reshaped to (B * N, T, 1), as produced
        in `_llm_hidden_states`.
        """
        min_values = torch.min(x_enc, dim=1)[0]
        max_values = torch.max(x_enc, dim=1)[0]
        medians = torch.median(x_enc, dim=1).values
        lags = self.calcute_lags(x_enc)
        # Sum of first differences: positive means the window ends above its start.
        trends = x_enc.diff(dim=1).sum(dim=1)

        prompts = []
        for b in range(x_enc.shape[0]):
            min_values_str = str(min_values[b].tolist()[0])
            max_values_str = str(max_values[b].tolist()[0])
            median_values_str = str(medians[b].tolist()[0])
            lags_values_str = str(lags[b].tolist())
            prompts.append(
                f"<|start_prompt|>{self.prompt_prefix}"
                f"Task description: forecast the next {str(self.h)} steps given the previous {str(self.input_size)} steps information; "
                "Input statistics: "
                f"min value {min_values_str}, "
                f"max value {max_values_str}, "
                f"median value {median_values_str}, "
                f"the trend of input is {'upward' if trends[b] > 0 else 'downward'}, "
                f"top 5 lags are : {lags_values_str}<||>"
            )
        return prompts

    def _llm_hidden_states(self, x_enc):
        """Run the shared pipeline: normalize, prompt, patch, reprogram, LLM.

        Args:
        - x_enc (torch.Tensor): Input tensor of shape [B, T, N].

        Returns:
        - hidden (torch.Tensor): The first `d_ff` features of the LLM's last
          hidden state for the concatenated [prompt tokens, patch tokens].
        - n_vars (int): Number of variables reported by the patch embedding.
        """
        x_enc = self.normalize_layers(x_enc, 'norm')

        B, T, N = x_enc.size()
        # One row per (sample, variable) pair for the prompt statistics.
        x_enc = x_enc.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)
        prompt = self._build_prompts(x_enc)
        x_enc = x_enc.reshape(B, N, T).permute(0, 2, 1).contiguous()

        prompt = self.llm_tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=2048).input_ids
        prompt_embeddings = self.llm.get_input_embeddings()(prompt.to(x_enc.device))  # (batch, prompt_token, dim)

        source_embeddings = self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)

        x_enc = x_enc.permute(0, 2, 1).contiguous()
        enc_out, n_vars = self.patch_embedding(x_enc.to(torch.float32))
        enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
        # Prompt tokens come first, reprogrammed patch tokens last.
        llm_enc_out = torch.cat([prompt_embeddings, enc_out], dim=1)

        dec_out = self.llm(inputs_embeds=llm_enc_out).last_hidden_state
        return dec_out[:, :, :self.d_ff], n_vars

    def forecast(self, x_enc):
        """Return one hidden vector per window.

        Args:
        - x_enc (torch.Tensor): Input tensor of shape [B, T, N]. The final
          reshape requires N == 1.

        Returns:
        - torch.Tensor of shape [B, hidden_size]. The output is left in the
          normalized space (no RevIN de-normalization is applied).
        """
        B = x_enc.shape[0]
        dec_out, n_vars = self._llm_hidden_states(x_enc)

        dec_out = torch.reshape(
            dec_out, (-1, n_vars, dec_out.shape[-2], dec_out.shape[-1]))
        dec_out = dec_out.permute(0, 1, 3, 2).contiguous()

        # Keep only the trailing patch tokens (the prompt tokens come first).
        dec_out = self.output_projection(dec_out[:, :, :, -self.patch_nums:])

        # (B, n_vars, hidden_size) -> (B, hidden_size). For n_vars == 1 this is
        # element-for-element what the former permute/reshape chain produced;
        # for n_vars > 1 the reshape raises, as the former code did.
        return dec_out.reshape(B, self.hidden_size)
        
    def calcute_lags(self, x_enc):
        """Return the indices of the `top_k` largest autocorrelation values.

        The autocorrelation is computed through the FFT as
        irfft(rfft(x) * conj(rfft(x))) along the time axis and averaged over
        dim 1 before `torch.topk` is applied.
        """
        # Query and key are the same series, so one transform is enough.
        x_fft = torch.fft.rfft(x_enc.permute(0, 2, 1).contiguous(), dim=-1)
        res = x_fft * torch.conj(x_fft)
        corr = torch.fft.irfft(res, dim=-1)
        mean_value = torch.mean(corr, dim=1)
        _, lags = torch.topk(mean_value, self.top_k, dim=-1)
        return lags
    
    def forward(self, windows_batch):
        """Return the hidden vectors for a batch of in-sample windows.

        `windows_batch` is used directly as the in-sample tensor (it is not
        indexed with 'insample_y'); a trailing variable axis of size 1 is
        added before calling `forecast`.
        """
        insample_y = windows_batch #['insample_y']

        x = insample_y.unsqueeze(-1)

        return self.forecast(x)
    
    
    def get_embeddings(self, x_enc):
        """
        Extract embeddings from the LLM before the final projection.

        Args:
        - x_enc (torch.Tensor): Input tensor of shape [batch_size, input_size, enc_in].

        Returns:
        - embeddings (torch.Tensor): The first `d_ff` features of the LLM's
          last hidden state for every prompt and patch token.
        """
        embeddings, _ = self._llm_hidden_states(x_enc)
        return embeddings




In [None]:

class TimeLLM(BaseWindows):

    """ TimeLLM

    Time-LLM is a reprogramming framework to repurpose an off-the-shelf LLM for time series forecasting.

    It trains a reprogramming layer that translates the observed series into a language task. This is fed to the LLM and an output
    projection layer translates the output back to numerical predictions.

    **Parameters:**
    `h`: int, Forecast horizon. 
    `input_size`: int, autorregresive inputs size, y=[1,2,3,4] input_size=2 -> y_[t-2:t]=[1,2].
    `patch_len`: int=16, length of patch.
    `stride`: int=8, stride of patch.
    `d_ff`: int=128, dimension of fcn.
    `top_k`: int=5, number of lags returned by `calcute_lags`.
    `d_llm`: int=768, hidden dimension of LLM. # LLama7b:4096; GPT2-small:768; BERT-base:768
    `d_model`: int=32, dimension of model.
    `n_heads`: int=8, number of heads in attention layer.
    `enc_in`: int=7, encoder input size (used to size the output head and the RevIN layer).
    `dec_in`: int=7, decoder input size (stored on the instance only).
    `llm` = None, Path to pretrained LLM model to use. If not specified, it will use GPT-2 from https://huggingface.co/openai-community/gpt2"
    `llm_config` = Deprecated and ignored; the configuration is always loaded from the selected model.
    `llm_tokenizer` = Deprecated and ignored; the tokenizer is always loaded from the selected model.
    `llm_num_hidden_layers` = 32, hidden layers in LLM (stored on the instance only).
    `llm_output_attention`: bool = True, whether to output attention in encoder (stored on the instance only).
    `llm_output_hidden_states`: bool = True, whether to output hidden states (stored on the instance only).
    `prompt_prefix`: str=None, prompt to inform the LLM about the dataset. `None` adds nothing to the prompt.
    `dropout`: float=0.1, dropout rate.
    `stat_exog_list`: str list, static exogenous columns.
    `hist_exog_list`: str list, historic exogenous columns.
    `futr_exog_list`: str list, future exogenous columns.
    `loss`: PyTorch module, instantiated train loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
    `valid_loss`: PyTorch module=`loss`, instantiated valid loss class from [losses collection](https://nixtla.github.io/neuralforecast/losses.pytorch.html).
    `learning_rate`: float=1e-4, Learning rate between (0, 1).
    `max_steps`: int=5, maximum number of training steps.
    `num_lr_decays`: int=0, Number of learning rate decays, evenly distributed across max_steps.
    `early_stop_patience_steps`: int=-1, Number of validation iterations before early stopping.
    `val_check_steps`: int=100, Number of training steps between every validation loss check.
    `batch_size`: int=32, number of different series in each batch.
    `valid_batch_size`: int=None, number of different series in each validation and test batch, if None uses batch_size.
    `windows_batch_size`: int=1024, number of windows to sample in each training batch, default uses all.
    `inference_windows_batch_size`: int=1024, number of windows to sample in each inference batch.
    `start_padding_enabled`: bool=False, if True, the model will pad the time series with zeros at the beginning, by input size.
    `step_size`: int=1, step size between each window of temporal data.
    `scaler_type`: str='identity', type of scaler for temporal inputs normalization see [temporal scalers](https://nixtla.github.io/neuralforecast/common.scalers.html).
    `random_seed`: int=1, random_seed for pytorch initializer and numpy generators.
    `num_workers_loader`: int=0, workers to be used by `TimeSeriesDataLoader`.
    `drop_last_loader`: bool=False, if True `TimeSeriesDataLoader` drops last non-full batch.
    `alias`: not accepted by this notebook copy of the class.
    `optimizer`: Subclass of 'torch.optim.Optimizer', optional, user specified optimizer instead of the default choice (Adam).
    `optimizer_kwargs`: dict, optional, list of parameters used by the user specified `optimizer`.    
    `lr_scheduler`: Subclass of 'torch.optim.lr_scheduler.LRScheduler', optional, user specified lr_scheduler instead of the default choice (StepLR).
    `lr_scheduler_kwargs`: dict, optional, list of parameters used by the user specified `lr_scheduler`.
    `dataloader_kwargs`: dict, optional. Accepted for signature compatibility but NOT forwarded to the base class in this copy.
    `**trainer_kwargs`: not accepted by this notebook copy of the class (the pass-through is commented out).

    **References:**
    -[Ming Jin, Shiyu Wang, Lintao Ma, Zhixuan Chu, James Y. Zhang, Xiaoming Shi, Pin-Yu Chen, Yuxuan Liang, Yuan-Fang Li, Shirui Pan, Qingsong Wen. "Time-LLM: Time Series Forecasting by Reprogramming Large Language Models"](https://arxiv.org/abs/2310.01728)
    
    """

    SAMPLING_TYPE = 'windows'
    EXOGENOUS_FUTR = False
    EXOGENOUS_HIST = False
    EXOGENOUS_STAT = False

    def __init__(self,
                 h,
                 input_size,
                 patch_len: int = 16,
                 stride: int = 8,
                 d_ff: int = 128,
                 top_k: int = 5,
                 d_llm: int = 768,
                 d_model: int = 32,
                 n_heads: int = 8,
                 enc_in: int = 7,
                 dec_in: int  = 7,
                 llm = None,
                 llm_config = None,
                 llm_tokenizer = None,
                 llm_num_hidden_layers = 32,
                 llm_output_attention: bool = True,
                 llm_output_hidden_states: bool = True,
                 prompt_prefix: Optional[str] = None,
                 dropout: float = 0.1,
                 stat_exog_list = None,
                 hist_exog_list = None,
                 futr_exog_list = None,
                 loss = MAE(),
                 valid_loss = None,
                 learning_rate: float = 1e-4,
                 max_steps: int = 5,
                 val_check_steps: int = 100,
                 batch_size: int = 32,
                 valid_batch_size: Optional[int] = None,
                 windows_batch_size: int = 1024,
                 inference_windows_batch_size: int = 1024,
                 start_padding_enabled: bool = False,
                 step_size: int = 1,
                 num_lr_decays: int = 0,
                 early_stop_patience_steps: int = -1,
                 scaler_type: str = 'identity',
                 num_workers_loader: int = 0,
                 drop_last_loader: bool = False,
                 random_seed: int = 1,
                 optimizer = None,
                 optimizer_kwargs = None,
                 lr_scheduler = None,
                 lr_scheduler_kwargs = None,
                 dataloader_kwargs = None,
                 #**trainer_kwargs)
                ):
        super(TimeLLM, self).__init__(h=h,
                                      input_size=input_size,
                                      hist_exog_list=hist_exog_list,
                                      stat_exog_list=stat_exog_list,
                                      futr_exog_list = futr_exog_list,
                                      loss=loss,
                                      valid_loss=valid_loss,
                                      max_steps=max_steps,
                                      learning_rate=learning_rate,
                                      num_lr_decays=num_lr_decays,
                                      early_stop_patience_steps=early_stop_patience_steps,
                                      val_check_steps=val_check_steps,
                                      batch_size=batch_size,
                                      valid_batch_size=valid_batch_size,
                                      windows_batch_size=windows_batch_size,
                                      inference_windows_batch_size=inference_windows_batch_size,
                                      start_padding_enabled=start_padding_enabled,
                                      step_size=step_size,
                                      scaler_type=scaler_type,
                                      num_workers_loader=num_workers_loader,
                                      drop_last_loader=drop_last_loader,
                                      random_seed=random_seed,
                                      optimizer=optimizer,
                                      optimizer_kwargs=optimizer_kwargs,
                                      lr_scheduler=lr_scheduler,
                                      lr_scheduler_kwargs=lr_scheduler_kwargs,
                                      #dataloader_kwargs=dataloader_kwargs,
                                      #**trainer_kwargs
                                      )
        
        # Architecture
        self.patch_len = patch_len
        self.stride = stride
        self.d_ff = d_ff
        self.top_k = top_k
        self.d_llm = d_llm
        self.d_model = d_model
        self.dropout = dropout
        self.n_heads = n_heads
        self.enc_in = enc_in
        self.dec_in = dec_in

        DEFAULT_MODEL = "openai-community/gpt2"

        # Both the default and the user-supplied model are loaded through
        # `transformers`, so fail early with a clear message in either case
        # (previously a named model without the package died with NameError).
        if not IS_TRANSFORMERS_INSTALLED:
            raise ImportError(
                "Please install `transformers` to load the LLM used by TimeLLM."
            )

        if llm is None:
            print(f"Using {DEFAULT_MODEL} as default.")
            model_name = DEFAULT_MODEL
        else:
            model_name = llm

        if llm_config is not None or llm_tokenizer is not None:
            warnings.warn("'llm_config' and 'llm_tokenizer' parameters are deprecated and will be ignored. "
                        "The config and tokenizer will be automatically loaded from the specified model.", 
                        DeprecationWarning,
                        stacklevel=2)

        try:
            self.llm_config = AutoConfig.from_pretrained(model_name)
            self.llm = AutoModel.from_pretrained(model_name, config=self.llm_config)
            self.llm_tokenizer = AutoTokenizer.from_pretrained(model_name)
            print(f"Successfully loaded model: {model_name}")
        except EnvironmentError:
            # Unknown / unreachable model name: fall back to the default checkpoint.
            print(f"Failed to load {model_name}. Loading the default model ({DEFAULT_MODEL})...")
            self.llm_config = AutoConfig.from_pretrained(DEFAULT_MODEL)
            self.llm = AutoModel.from_pretrained(DEFAULT_MODEL, config=self.llm_config)
            self.llm_tokenizer = AutoTokenizer.from_pretrained(DEFAULT_MODEL)

        self.llm_num_hidden_layers = llm_num_hidden_layers
        self.llm_output_attention = llm_output_attention
        self.llm_output_hidden_states = llm_output_hidden_states
        self.prompt_prefix = prompt_prefix

        # Batched tokenization with padding=True needs a pad token.
        if self.llm_tokenizer.eos_token:
            self.llm_tokenizer.pad_token = self.llm_tokenizer.eos_token
        else:
            # NOTE(review): a brand-new '[PAD]' token enlarges the tokenizer vocabulary
            # while the LLM embedding matrix is left untouched -- confirm whether
            # the embeddings need resizing for tokenizers without an EOS token.
            pad_token = '[PAD]'
            self.llm_tokenizer.add_special_tokens({'pad_token': pad_token})
            self.llm_tokenizer.pad_token = pad_token

        # The LLM backbone is frozen; only the layers created below are trained.
        for param in self.llm.parameters():
            param.requires_grad = False

        self.patch_embedding = PatchEmbedding(
            self.d_model, self.patch_len, self.stride, self.dropout)
        
        self.word_embeddings = self.llm.get_input_embeddings().weight
        self.vocab_size = self.word_embeddings.shape[0]
        self.num_tokens = 1024
        # Maps the full vocabulary onto `num_tokens` learned prototype embeddings.
        self.mapping_layer = nn.Linear(self.vocab_size, self.num_tokens)

        self.reprogramming_layer = ReprogrammingLayer(self.d_model, self.n_heads, self.d_ff, self.d_llm)

        self.patch_nums = int((input_size - self.patch_len) / self.stride + 2)
        self.head_nf = self.d_ff * self.patch_nums

        self.output_projection = FlattenHead(self.enc_in, self.head_nf, self.h, head_dropout=self.dropout)

        self.normalize_layers = RevIN(self.enc_in, affine=False)

    def forecast(self, x_enc):
        """Run the prompt + reprogrammed-patch pipeline through the frozen LLM.

        `x_enc` is unpacked as (batch, time, variables). One text prompt holding
        the min / max / median of the normalized input is built per
        (batch, variable) series, embedded with the LLM input embeddings and
        prepended to the reprogrammed patch embeddings. The first `d_ff`
        channels of the LLM output for the last `patch_nums` positions are
        projected by the output head and de-normalized.
        """

        x_enc = self.normalize_layers(x_enc, 'norm')

        B, T, N = x_enc.size()
        x_enc = x_enc.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)

        # One host transfer per statistic instead of one per series.
        min_values = torch.min(x_enc, dim=1)[0][:, 0].tolist()
        max_values = torch.max(x_enc, dim=1)[0][:, 0].tolist()
        medians = torch.median(x_enc, dim=1).values[:, 0].tolist()

        # A missing prefix must contribute nothing, not the literal text "None".
        prompt_prefix = "" if self.prompt_prefix is None else self.prompt_prefix

        # The trend and top-lag statistics are intentionally left out of this
        # prompt, so they are no longer computed here; `calcute_lags` remains
        # available should they be reinstated.
        prompt = [
            (
                f"<|start_prompt|>{prompt_prefix}"
                f"The dataset contains application iteration times for the milc workload. "
                f"Task description: forecast the next {str(self.h)} steps given the previous {str(self.input_size)} steps information. "
                "Input statistics: "
                f"min value {str(min_value)}, "
                f"max value {str(max_value)}, "
                f"median value {str(median_value)}, "
            )
            for min_value, max_value, median_value in zip(min_values, max_values, medians)
        ]

        x_enc = x_enc.reshape(B, N, T).permute(0, 2, 1).contiguous()

        prompt = self.llm_tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=2048).input_ids
        prompt_embeddings = self.llm.get_input_embeddings()(prompt.to(x_enc.device))  # (batch, prompt_token, dim)

        source_embeddings = self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)

        x_enc = x_enc.permute(0, 2, 1).contiguous()
        enc_out, n_vars = self.patch_embedding(x_enc.to(torch.float32))
        enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
        llm_enc_out = torch.cat([prompt_embeddings, enc_out], dim=1)
        dec_out = self.llm(inputs_embeds=llm_enc_out).last_hidden_state
        dec_out = dec_out[:, :, :self.d_ff]

        dec_out = torch.reshape(
            dec_out, (-1, n_vars, dec_out.shape[-2], dec_out.shape[-1]))
        dec_out = dec_out.permute(0, 1, 3, 2).contiguous()

        # Only the trailing positions (the patch tokens, not the prompt) feed the head.
        dec_out = self.output_projection(dec_out[:, :, :, -self.patch_nums:])
        dec_out = dec_out.permute(0, 2, 1).contiguous()

        dec_out = self.normalize_layers(dec_out, 'denorm')

        return dec_out
        
    def calcute_lags(self, x_enc):
        """Return the indices of the `top_k` largest FFT-based autocorrelation values.

        The autocorrelation is computed as irfft(rfft(x) * conj(rfft(x))) along
        the last axis of `x_enc.permute(0, 2, 1)` and averaged over dim 1 before
        `torch.topk` is applied.
        """
        q_fft = torch.fft.rfft(x_enc.permute(0, 2, 1).contiguous(), dim=-1)
        k_fft = torch.fft.rfft(x_enc.permute(0, 2, 1).contiguous(), dim=-1)
        res = q_fft * torch.conj(k_fft)
        corr = torch.fft.irfft(res, dim=-1)
        mean_value = torch.mean(corr, dim=1)
        _, lags = torch.topk(mean_value, self.top_k, dim=-1)
        return lags
    
    def forward(self, windows_batch):
        """Forecast from `windows_batch['insample_y']` and map through the loss domain.

        A trailing feature axis is added to the in-sample values, the last `h`
        steps of the forecast are kept and passed to `self.loss.domain_map`.
        """
        insample_y = windows_batch['insample_y']

        x = insample_y.unsqueeze(-1)

        y_pred = self.forecast(x)
        y_pred = y_pred[:, -self.h:, :]
        y_pred = self.loss.domain_map(y_pred)
        
        return y_pred
    

#### GNN + TimeLLM

In [5]:
class PositionalEncoding(nn.Module):
    """Add fixed sinusoidal position encodings to a sequence-first tensor, then apply dropout.

    Args:
        d_model (int): Size of the last (feature) axis. Odd sizes are supported:
            the cosine half then simply has one column fewer than the sine half.
        dropout (float): Dropout probability applied after the addition.
        max_len (int): Number of positions precomputed in the `pe` buffer.
    """

    def __init__(self, d_model, dropout=0.1, max_len=5000):
        super(PositionalEncoding, self).__init__()
        self.dropout = nn.Dropout(p=dropout)

        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
        pe[:, 0::2] = torch.sin(position * div_term)
        # For odd d_model there are ceil(d/2) even columns but only floor(d/2)
        # odd ones; slicing keeps the assignment shapes aligned (no-op when even).
        pe[:, 1::2] = torch.cos(position * div_term[: d_model // 2])
        # Stored as (max_len, 1, d_model) so it broadcasts over axis 1 of the input.
        pe = pe.unsqueeze(0).transpose(0, 1)
        self.register_buffer('pe', pe)

    def forward(self, x):
        """Add the encoding of positions 0..x.size(0)-1 along axis 0 of `x`, then apply dropout."""
        x = x + self.pe[:x.size(0), :]
        return self.dropout(x)


class Transformer(nn.Module):
    """
    Transformer-based module for creating temporal node embeddings.

    Args:
        dim_model (int): The dimension of the model's hidden states.
        num_heads_TR (int): The number of attention heads.
        num_encoder_layers_TR (int): The number of encoder layers.
        num_decoder_layers_TR (int): The number of decoder layers.
        dropout_p_TR (float): Dropout probability.
    """

    def __init__(
        self, dim_model, num_heads_TR, num_encoder_layers_TR,
        num_decoder_layers_TR, dropout_p_TR):

        super().__init__()

        self.pos_encoder = PositionalEncoding(dim_model)
        # Encoder/decoder depths were previously passed to each other's keyword.
        self.transformer = nn.Transformer(
            d_model=dim_model,
            nhead=num_heads_TR,
            num_encoder_layers=num_encoder_layers_TR,
            num_decoder_layers=num_decoder_layers_TR,
            dropout=dropout_p_TR)

    def forward(self, src):
        """
        Forward pass of the Transformer module.

        The decoder target is the first time step of `src` (index 0 on the
        sequence axis).
        NOTE(review): earlier documentation described the target as the *last*
        element of `src`; the code has always used index 0 -- confirm intent.

        Args:
            src (torch.Tensor): Input sequence with dimensions
                                (batch_size, seq_len, num_of_nodes, node_embedds_size).

        Returns:
            torch.Tensor: Temporal node embeddings with dimensions
                          (batch_size, 1, num_of_nodes, node_embedds_size).
        """
        batch_size, seq_len, num_nodes, node_embedds_size = src.shape

        # nn.Transformer is sequence-first here (batch_first=False): axis 0 is the
        # sequence, axis 1 the batch. Put time on axis 0 and fold (batch, node)
        # into axis 1 so that attention and positional encodings never mix
        # different samples of the batch. For batch_size == 1 this is the same
        # layout as simply dropping the batch axis.
        src = src.permute(1, 0, 2, 3).reshape(
            seq_len, batch_size * num_nodes, node_embedds_size)
        trg = src[:1]

        src = self.pos_encoder(src)
        trg = self.pos_encoder(trg)

        temporal_node_embeddings = self.transformer(src, trg)

        # (1, batch * nodes, emb) -> (batch, 1, nodes, emb)
        temporal_node_embeddings = temporal_node_embeddings.reshape(
            batch_size, 1, num_nodes, node_embedds_size)

        return temporal_node_embeddings

In [6]:
from torch_geometric_temporal.nn import STConv

class STGCNModel(nn.Module):
    """One STConv block followed by the temporal `Transformer` module.

    Args:
        num_features: Input channels handed to the STConv block.
        hidden_channels: Hidden and output channels of the STConv block, also
            the model dimension of the transformer (which uses 8 heads).
        num_nodes: Number of graph nodes.
        edge_index: Graph connectivity kept on the instance and passed to the
            STConv block on every forward call.
        K: Forwarded to STConv.
        kernel_size: Forwarded to STConv.
    """

    def __init__(self, num_features, hidden_channels, num_nodes, edge_index=None, K=2, kernel_size=1):
        super().__init__()

        # Graph description; edge_index is a plain attribute (not a registered
        # buffer), so the caller is responsible for its device placement.
        self.num_nodes = num_nodes
        self.edge_index = edge_index

        # Spatio-temporal convolution; input and output widths are decoupled so
        # that the transformer always sees `hidden_channels` features.
        self.stconv1 = STConv(
            in_channels=num_features,
            hidden_channels=hidden_channels,
            out_channels=hidden_channels,
            num_nodes=num_nodes,
            K=K,
            kernel_size=kernel_size,
        )

        # A second STConv block (hidden_channels -> out_channels) and a final
        # 1x1 Conv2d were tried here and are currently disabled.

        self.transformer = Transformer(
            hidden_channels,
            num_heads_TR=8,
            num_encoder_layers_TR=2,
            num_decoder_layers_TR=2,
            dropout_p_TR=0.1,
        )

    def forward(self, x):
        """Embed a 4-D input and return a tensor viewed as (batch_size, num_nodes, -1).

        Expected input shape: [batch_size, input_seq_len, num_nodes, num_features].
        """
        # Unpacking also enforces that the input has exactly four axes.
        batch_size, _seq_len, _num_nodes, _num_features = x.size()

        conv_out = self.stconv1(x, edge_index=self.edge_index)
        temporal_out = self.transformer(conv_out)

        return temporal_out.view(batch_size, self.num_nodes, -1)


In [7]:
class CombinedModel(nn.Module):
    """Fuse STGCN embeddings of the exogenous features with TimeLLM embeddings of the target history.

    Args:
        num_features: Number of exogenous features per node (all channels of the
            input except channel 0).
        num_nodes: Number of graph nodes.
        edge_index: Graph connectivity handed to the STGCN.
        node_mask: Mask over nodes (converted to bool) selecting the nodes to predict.
        timellm: Module called with a 2-D tensor whose rows are per-node target
            histories. NOTE(review): assumed to return one embedding of size
            `llm_hidden_size` per row -- confirm against the TimeLLM variant in use.
        llm_input_size: Number of trailing time steps fed to `timellm`.
        gnn_input_size: Number of trailing time steps fed to the STGCN.
        llm_hidden_size: Embedding width expected from `timellm`.
        gnn_hidden_size: Hidden width of the STGCN.
    """

    def __init__(self, num_features, num_nodes, edge_index, node_mask, timellm,
                 llm_input_size=12, gnn_input_size=3, llm_hidden_size=32, gnn_hidden_size=32):
        super(CombinedModel, self).__init__()
        
        self.node_mask = node_mask.bool()
        self.edge_index = edge_index
        self.llm_input_size = llm_input_size
        self.gnn_input_size = gnn_input_size
        
        # init stgcn and timellm
        self.stgcn = STGCNModel(
            num_features=num_features,
            hidden_channels=gnn_hidden_size,
            num_nodes=num_nodes,
            edge_index=edge_index
        )
        self.timellm = timellm
        
        # fc layer to combine the two models. It must take llm_hidden_size + gnn_hidden_size as input and output one value
        self.fc = nn.Sequential(
            nn.Linear(llm_hidden_size + gnn_hidden_size, 1)
        )

    def forward(self, X):
        """
        Combine STGCN and TimeLLM outputs for predictions.

        Args:
        - X (torch.Tensor): Input tensor of shape [batch_size, seq_len, num_nodes, num_features + 1].
          Channel 0 is the target history; the remaining channels are the exogenous features.

        Returns:
        - Y_pred (torch.Tensor): Output of the fusion layer for the masked nodes,
          with leading axes [batch_size, num_nodes_subset].
        """
        # STGCN for exogenous features
        assert X.size(1) >= self.gnn_input_size, "Input sequence length must be at least the GNN input size"
        X_stgcn = X[:, -self.gnn_input_size:, :, 1:]  # Exclude the first feature
        stgcn_out = self.stgcn(X_stgcn)  # viewed by STGCNModel as [batch_size, num_nodes, -1]

        # Extract relevant nodes from STGCN output
        stgcn_out_subset = stgcn_out[..., self.node_mask, :]
        
        # TimeLLM for historical target values
        assert X.size(1) >= self.llm_input_size, "Input sequence length must be at least the LLM input size"
        timellm_inputs = X[:, -self.llm_input_size:, self.node_mask, 0]  # [batch_size, llm_input_size, num_nodes_subset]
        
        # The time axis must be last before flattening, otherwise every row would
        # interleave values of different nodes instead of holding one node's history.
        # Result: [batch_size * num_nodes_subset, llm_input_size], batch-major.
        timellm_inputs = timellm_inputs.permute(0, 2, 1).reshape(-1, self.llm_input_size)
        
        # Get TimeLLM embeddings (one row per (batch, node) pair)
        timellm_embeddings = self.timellm(timellm_inputs)
        
        # Regroup the rows per batch element: [batch_size, num_nodes_subset, embedding]
        timellm_embeddings = timellm_embeddings.reshape(
            -1, stgcn_out_subset.size(1), timellm_embeddings.size(-1))

        # combine embeddings along the feature dimension
        combined_embeddings = torch.cat([stgcn_out_subset, timellm_embeddings], dim=-1)

        # Apply fc layer
        return self.fc(combined_embeddings)
        

In [None]:

import pickle
adj_mx_path = "/mnt/home/network-predictive-analysis/data/processed/1056.milc512+lammps512+ur32.cont.20240730/adj.pkl"
# Read adj_mx: element [2] of the pickled object is used as the adjacency matrix.
# NOTE: pickle.load can execute arbitrary code -- only load files from a trusted source.
with open(adj_mx_path, "rb") as adj_file:
    adj_mx = torch.tensor(pickle.load(adj_file)[2])

# transform to edge_index: one column per non-zero adjacency entry
edge_index = adj_mx.nonzero(as_tuple=False).t().contiguous()

node_mask_path = "/mnt/home/network-predictive-analysis/data/processed/1056.milc512+lammps512+ur32.cont.20240730/GNN/6in_1out_24features_0/active_nodes.npz"

# read npz (the archive is closed as soon as the array has been copied out)
with np.load(node_mask_path) as node_mask_file:
    node_mask_original = torch.tensor(node_mask_file["active_nodes"])

# Only the first 50 entries of the stored mask are honoured; every other one of
# the 3960 nodes stays masked out.
# NOTE(review): 3960 and 50 are hard-coded -- presumably the node count and a
# debugging subset; confirm and derive them from the data if possible.
node_mask = torch.zeros(3960).bool()
node_mask[:50] = node_mask_original[:50].bool()


In [None]:
# Build sliding-window feature / target tensors from the stored X and Y arrays.

X_Y_path = "/mnt/home/network-predictive-analysis/data/processed/1056.milc512+lammps512+ur32.cont.20240730/X_Y_Tensors/MILC/X_Y.npz"

# read npz
data = np.load(X_Y_path)
X, Y = data["x"], data["y"]
iterations = X.shape[0] - 1

num_timesteps_in, num_timesteps_out = 12, 1
window_len = num_timesteps_in + num_timesteps_out

# (start, end) bounds of every window of `window_len` consecutive steps
indices = [(start, start + window_len) for start in range(iterations - window_len + 1)]

features, target = [], []
for lb, ub in indices:
    split = lb + num_timesteps_in  # first step that belongs to the target
    features.append(X[lb:split, :, :])
    target.append(Y[split:ub, :, :])

features = torch.tensor(np.array(features), dtype=torch.float32)
# keep only the nodes selected by `node_mask` (defined in the previous cell)
target = torch.tensor(np.array(target), dtype=torch.float32)[:, :, node_mask, :]

features.shape, target.shape

### Train

In [None]:

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from sklearn.preprocessing import MinMaxScaler
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.tensorboard import SummaryWriter

# Use the GPU when one is visible to PyTorch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Training loop
# Text file holding a single float; train_model re-reads it at the start of every epoch.
lr_config_path = "/mnt/home/network-predictive-analysis/lr_config.txt"

def train_model(
    model, train_loader, val_loader, criterion, optimizer, num_epochs, device, log_dir="runs/experiment"
):
    """Train `model` for `num_epochs` epochs and log the progress to TensorBoard.

    At the start of every epoch the learning rate is re-read from the file named
    by the module-level `lr_config_path` and applied to every parameter group;
    if the file cannot be read or parsed the current rate is kept and a message
    is printed. The learning rate and the mean train / validation loss per batch
    are written to TensorBoard and the losses are printed.

    Args:
        model: Module to train; moved to `device`.
        train_loader: Iterable of (X_batch, Y_batch) pairs used for optimisation.
        val_loader: Iterable of (X_batch, Y_batch) pairs used for validation.
        criterion: Loss called as `criterion(outputs, Y_batch)`.
            NOTE(review): outputs and targets are passed through unchanged --
            make sure their shapes match exactly, otherwise the loss may broadcast.
        optimizer: Optimizer stepping the model parameters.
        num_epochs: Number of passes over `train_loader`.
        device: Device the model and every batch are moved to.
        log_dir: Directory for the TensorBoard event files.

    Returns:
        None.
    """
    # Initialize TensorBoard writer
    writer = SummaryWriter(log_dir=log_dir)

    # try/finally: the event file is flushed and closed even if training fails.
    try:
        model.to(device)
        for epoch in range(num_epochs):
            # Read and update learning rate (best effort: a missing or malformed
            # file must never interrupt training).
            try:
                with open(lr_config_path, "r") as f:
                    lr = float(f.read().strip())
                for param_group in optimizer.param_groups:
                    param_group["lr"] = lr
            except Exception as e:
                print(f"Could not read learning rate file: {e}")

            # Log learning rate to TensorBoard
            writer.add_scalar("Learning Rate", optimizer.param_groups[0]["lr"], epoch)

            model.train()
            train_loss = 0
            for X_batch, Y_batch in train_loader:
                X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)

                # Forward pass
                optimizer.zero_grad()
                outputs = model(X_batch)

                # Compute loss
                loss = criterion(outputs, Y_batch)
                loss.backward()
                optimizer.step()

                train_loss += loss.item()

            # Validation
            model.eval()
            val_loss = 0
            with torch.no_grad():
                for X_batch, Y_batch in val_loader:
                    X_batch, Y_batch = X_batch.to(device), Y_batch.to(device)
                    outputs = model(X_batch)
                    loss = criterion(outputs, Y_batch)
                    val_loss += loss.item()

            # Per-batch averages; max(..., 1) keeps an empty loader from dividing by zero.
            mean_train_loss = train_loss / max(len(train_loader), 1)
            mean_val_loss = val_loss / max(len(val_loader), 1)

            # Log metrics to TensorBoard
            writer.add_scalar("Loss/Train", mean_train_loss, epoch)
            writer.add_scalar("Loss/Validation", mean_val_loss, epoch)

            print(f"Epoch {epoch + 1}/{num_epochs}, Train Loss: {mean_train_loss}, Val Loss: {mean_val_loss}")
    finally:
        writer.close()




# Create DataLoader
# NOTE(review): X_train, X_train_scaled, Y_train_scaled, X_val_scaled and
# Y_val_scaled are not defined in the cells shown here -- this cell relies on
# state created elsewhere and will fail on a fresh kernel unless that code runs first.
batch_size = 8
train_dataset = TensorDataset(X_train_scaled, Y_train_scaled)
val_dataset = TensorDataset(X_val_scaled, Y_val_scaled)
train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

# Initialize model, loss function, and optimizer
llm_input_size = 12
llm_hidden_size = 32
# NOTE(review): `hidden_size` is not a parameter of the TimeLLM class defined in
# the forecasting cell of this notebook; this call presumably targets the
# embedding variant defined earlier -- confirm which definition is active.
# The model is placed on `device` (not a hard-coded "cuda:0") so the cell also
# runs on CPU-only hosts and stays consistent with the tensors moved below.
timellm_model = TimeLLM(
            h=1, input_size=llm_input_size, d_ff=128, d_llm=768, enc_in=1, top_k=2,
            llm="openai-community/gpt2", hidden_size=llm_hidden_size
        ).to(device)

model = CombinedModel(
    num_features=X_train.shape[-1] - 1,  # channel 0 is the target, the rest are exogenous
    num_nodes=X_train.shape[2],
    edge_index=edge_index.to(device),
    node_mask=node_mask.to(device),
    timellm=timellm_model,
    llm_input_size=llm_input_size,
    llm_hidden_size=llm_hidden_size,
)
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Train the model
train_model(model, train_loader, val_loader, criterion, optimizer, num_epochs=2000, device=device, log_dir="/mnt/home/network-predictive-analysis/runs/experiment_3")
