# CSV Dataset

In [1]:
import random
from typing import List
import os
import tqdm

import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler

import spacetimeformer as stf

import matplotlib.pyplot as plt


class CSVTimeSeries:
    """Chronological train/val/test view over a time-stamped table.

    The table comes either from a CSV file (``data_path``) or from an
    in-memory DataFrame (``raw_df``). Calendar features are appended by
    ``stf.data.timefeatures.time_features``, the rows are split by position
    into train / val / test (the last rows form the test set, the rows just
    before them the val set), and - when ``normalize`` is true - the target
    and exogenous columns are standardized with statistics fitted on the
    training rows only.

    Args:
        data_path: Path of the CSV file; only read when ``raw_df`` is None.
        raw_df: Already-loaded table. It is never modified in place.
        target_cols: Columns to forecast. Empty means "every column except
            the time column".
        ignore_cols: Columns to drop, or the string ``"all"`` to drop every
            column that is neither a target nor the time column.
        remove_target_from_context_cols: Target columns that consumers should
            leave out of the context window (validated here, used elsewhere).
        time_col_name: Name of the timestamp column.
        read_csv_kwargs: Extra keyword arguments for ``pd.read_csv``.
        val_split: Fraction of rows assigned to the validation set.
        test_split: Fraction of rows assigned to the test set; ``0.0`` makes
            the test set an alias of the validation set.
        normalize: Whether to standardize target and exogenous columns.
        drop_all_nan: Drop every row containing at least one NaN.
        time_features: Feature names forwarded to the ``stf`` helper
            (presumably the calendar fields to derive - confirm upstream).
    """

    def __init__(
        self,
        data_path: str = None,
        raw_df: pd.DataFrame = None,
        target_cols: List[str] = [],
        ignore_cols: List[str] = [],
        remove_target_from_context_cols: List[str] = [],
        time_col_name: str = "Datetime",
        read_csv_kwargs={},
        val_split: float = 0.15,
        test_split: float = 0.15,
        normalize: bool = True,
        drop_all_nan: bool = False,
        time_features: List[str] = [
            "year",
            "month",
            "day",
            "weekday",
            "hour",
            "minute",
        ],
    ):
        # NOTE: the list/dict defaults above are shared between calls; they
        # are only ever read or rebound here, never mutated.
        assert data_path is not None or raw_df is not None

        # Record the path in both input modes so the attribute always exists.
        self.data_path = data_path
        if raw_df is None:
            assert os.path.exists(self.data_path)
            raw_df = pd.read_csv(
                self.data_path,
                **read_csv_kwargs,
            )

        if drop_all_nan:
            # Non-inplace: a caller-supplied frame must not be altered.
            raw_df = raw_df.dropna(axis=0, how="any")

        self.time_col_name = time_col_name
        assert self.time_col_name in raw_df.columns

        if not target_cols:
            target_cols = raw_df.columns.tolist()
            target_cols.remove(time_col_name)
        else:
            # Private copy so later edits to the caller's list cannot leak in.
            target_cols = list(target_cols)

        if ignore_cols:
            if ignore_cols == "all":
                ignore_cols = raw_df.columns.difference(target_cols).tolist()
                ignore_cols.remove(self.time_col_name)
            # Non-inplace for the same reason as the dropna above.
            raw_df = raw_df.drop(columns=ignore_cols)

        time_df = pd.to_datetime(raw_df[self.time_col_name], format="%Y-%m-%d %H:%M")
        df = stf.data.timefeatures.time_features(
            time_df,
            raw_df,
            time_col_name=self.time_col_name,
            use_features=time_features,
        )
        # Whatever the helper added on top of the raw columns is a time feature.
        self.time_cols = df.columns.difference(raw_df.columns)

        # Train/Val/Test Split using holdout approach #

        def mask_intervals(mask, intervals, cond):
            """Set ``mask`` to ``cond`` on rows whose timestamp lies in any interval (inclusive)."""
            timestamps = df[self.time_col_name]
            for interval_low, interval_high in intervals:
                # An open bound falls back to the first/last timestamp itself;
                # a bare integer year cannot be compared with datetimes.
                if interval_low is None:
                    interval_low = timestamps.iloc[0]
                if interval_high is None:
                    interval_high = timestamps.iloc[-1]
                mask[(timestamps >= interval_low) & (timestamps <= interval_high)] = cond
            return mask

        # The test set always holds at least one row, even for test_split == 0.
        test_cutoff = len(time_df) - max(round(test_split * len(time_df)), 1)
        val_cutoff = test_cutoff - round(val_split * len(time_df))

        val_intervals = [(time_df.iloc[val_cutoff], time_df.iloc[test_cutoff - 1])]
        test_intervals = [(time_df.iloc[test_cutoff], time_df.iloc[-1])]

        # Start from all-True (train) / all-False (val, test) boolean Series.
        train_mask = df[self.time_col_name] > pd.Timestamp.min
        val_mask = df[self.time_col_name] > pd.Timestamp.max
        test_mask = df[self.time_col_name] > pd.Timestamp.max
        train_mask = mask_intervals(train_mask, test_intervals, False)
        train_mask = mask_intervals(train_mask, val_intervals, False)
        val_mask = mask_intervals(val_mask, val_intervals, True)
        test_mask = mask_intervals(test_mask, test_intervals, True)

        if not train_mask.any():
            print(f"No training data detected for file {data_path}")

        self._scaler = StandardScaler()

        self.target_cols = target_cols
        for col in remove_target_from_context_cols:
            assert (
                col in self.target_cols
            ), "`remove_target_from_context_cols` should be target cols that you want to remove from the context"

        self.remove_target_from_context_cols = list(remove_target_from_context_cols)
        # Exogenous columns: everything that is not a time feature, a target,
        # or the timestamp column itself.
        not_exo_cols = self.time_cols.tolist() + target_cols
        self.exo_cols = df.columns.difference(not_exo_cols).tolist()
        self.exo_cols.remove(self.time_col_name)

        self._train_data = df[train_mask]
        self._val_data = df[val_mask]
        if test_split == 0.0:
            print("`test_split` set to 0. Using Val set as Test set.")
            self._test_data = df[val_mask]
        else:
            self._test_data = df[test_mask]

        self.normalize = normalize
        if normalize:
            # Fit on training rows only to avoid leaking val/test statistics.
            self._scaler = self._scaler.fit(
                self._train_data[target_cols + self.exo_cols].values
            )
        self._train_data = self.apply_scaling_df(self._train_data)
        self._val_data = self.apply_scaling_df(self._val_data)
        self._test_data = self.apply_scaling_df(self._test_data)

    def make_hists(self):
        """Save one ``<col>-hist.png`` per target/exogenous column comparing train and test."""
        for col in self.target_cols + self.exo_cols:
            train = self._train_data[col]
            test = self._test_data[col]
            bins = np.linspace(-5, 5, 80)  # warning: edit bucket limits
            plt.hist(train, bins, alpha=0.5, label="Train", density=True)
            plt.hist(test, bins, alpha=0.5, label="Test", density=True)
            plt.legend(loc="upper right")
            plt.title(col)
            plt.tight_layout()
            plt.savefig(f"{col}-hist.png")
            plt.clf()

    def get_slice(self, split, start, stop, skip):
        """Return rows ``start:stop:skip`` (positional) of the requested split."""
        assert split in ["train", "val", "test"]
        if split == "train":
            return self.train_data.iloc[start:stop:skip]
        elif split == "val":
            return self.val_data.iloc[start:stop:skip]
        else:
            return self.test_data.iloc[start:stop:skip]

    def apply_scaling(self, array):
        """Standardize an array whose last axis holds the first ``dim`` scaled columns."""
        if not self.normalize:
            return array
        dim = array.shape[-1]
        return (array - self._scaler.mean_[:dim]) / self._scaler.scale_[:dim]

    def apply_scaling_df(self, df):
        """Return a copy of ``df`` with target and exogenous columns standardized."""
        if not self.normalize:
            return df
        scaled = df.copy(deep=True)
        cols = self.target_cols + self.exo_cols
        dtype = df[cols].values.dtype
        scaled[cols] = (
            df[cols].values - self._scaler.mean_.astype(dtype)
        ) / self._scaler.scale_.astype(dtype)
        return scaled

    def reverse_scaling_df(self, df):
        """Return a copy of ``df`` with the standardization of ``apply_scaling_df`` undone."""
        if not self.normalize:
            return df
        scaled = df.copy(deep=True)
        cols = self.target_cols + self.exo_cols
        dtype = df[cols].values.dtype
        scaled[cols] = (
            df[cols].values * self._scaler.scale_.astype(dtype)
        ) + self._scaler.mean_.astype(dtype)
        return scaled

    def reverse_scaling(self, array):
        """Undo ``apply_scaling`` on an array."""
        if not self.normalize:
            return array
        # self._scaler is fit for target_cols + exo_cols
        # if the array dim is less than this length we start
        # slicing from the target cols
        dim = array.shape[-1]
        return (array * self._scaler.scale_[:dim]) + self._scaler.mean_[:dim]

    @property
    def train_data(self):
        """Training rows (scaled when ``normalize`` is true)."""
        return self._train_data

    @property
    def val_data(self):
        """Validation rows (scaled when ``normalize`` is true)."""
        return self._val_data

    @property
    def test_data(self):
        """Test rows (scaled when ``normalize`` is true)."""
        return self._test_data

    def length(self, split):
        """Return the number of rows in ``split`` ("train", "val" or "test")."""
        return {
            "train": len(self.train_data),
            "val": len(self.val_data),
            "test": len(self.test_data),
        }[split]


class CSVTorchDset(Dataset):
    """Sliding-window torch dataset over one split of a ``CSVTimeSeries``.

    Each item is a window of ``context_points + target_points`` rows, sampled
    every ``time_resolution`` rows, returned as four float tensors:
    context time features, context values, target time features, target values.
    """

    def __init__(
        self,
        csv_time_series: CSVTimeSeries,
        split: str = "train",
        context_points: int = 128,
        target_points: int = 32,
        time_resolution: int = 1,
    ):
        assert split in ["train", "val", "test"]
        self.split = split
        self.series = csv_time_series
        self.context_points = context_points
        self.target_points = target_points
        self.time_resolution = time_resolution

        # Rows spanned by one full window; every start index that leaves room
        # for a complete window is valid.
        rows_per_window = time_resolution * (context_points + target_points)
        last_start_exclusive = self.series.length(split) - rows_per_window + 1
        self._slice_start_points = list(range(0, last_start_exclusive))

    def __len__(self):
        """Number of complete windows available in the split."""
        return len(self._slice_start_points)

    def _torch(self, *dfs):
        """Convert each DataFrame to a float32 tensor."""
        tensors = []
        for frame in dfs:
            tensors.append(torch.from_numpy(frame.values).float())
        return tuple(tensors)

    def __getitem__(self, i):
        """Return ``(ctxt_x, ctxt_y, trgt_x, trgt_y)`` for the i-th window."""
        series = self.series
        first_row = self._slice_start_points[i]
        end_row = first_row + self.time_resolution * (
            self.context_points + self.target_points
        )
        window = series.get_slice(
            self.split,
            start=first_row,
            stop=end_row,
            skip=self.time_resolution,
        )
        # The raw timestamp column is not a model input.
        window = window.drop(columns=[series.time_col_name])

        context_rows = window.iloc[: self.context_points]
        target_rows = window.iloc[self.context_points :]

        ctxt_x = context_rows[series.time_cols]
        trgt_x = target_rows[series.time_cols]

        # Context carries targets plus exogenous columns, minus any targets
        # explicitly excluded from the context.
        ctxt_y = context_rows[series.target_cols + series.exo_cols].drop(
            columns=series.remove_target_from_context_cols
        )
        trgt_y = target_rows[series.target_cols]

        return self._torch(ctxt_x, ctxt_y, trgt_x, trgt_y)

# DataModule

In [12]:
import warnings

import torch
from torch.utils.data import DataLoader


class DataModule:
    """Factory for train/val/test ``DataLoader`` objects built from one dataset class.

    Args:
        datasetCls: Callable invoked as ``datasetCls(**dataset_kwargs, split=...)``.
        dataset_kwargs: Keyword arguments for ``datasetCls``. A ``"split"``
            entry is ignored because the split is chosen per loader. The
            caller's dict is not modified.
        batch_size: Batch size for every loader.
        workers: ``num_workers`` for every loader.
        collate_fn: Optional collate function passed to each loader.
        overfit: When true, every loader serves the shuffled train split.
    """

    def __init__(
        self,
        datasetCls,
        dataset_kwargs: dict,
        batch_size: int,
        workers: int,
        collate_fn=None,
        overfit: bool = False,
    ):
        super().__init__()
        self.datasetCls = datasetCls
        self.batch_size = batch_size
        # Work on a filtered copy instead of deleting from the caller's dict.
        self.dataset_kwargs = {
            key: value for key, value in dataset_kwargs.items() if key != "split"
        }
        self.workers = workers
        self.collate_fn = collate_fn
        if overfit:
            warnings.warn("Overriding val and test dataloaders to use train set!")
        self.overfit = overfit

    def train_dataloader(self, shuffle=True):
        """Return the loader for the train split."""
        return self._make_dloader("train", shuffle=shuffle)

    def val_dataloader(self, shuffle=False):
        """Return the loader for the val split (train split when overfitting)."""
        return self._make_dloader("val", shuffle=shuffle)

    def test_dataloader(self, shuffle=False):
        """Return the loader for the test split (train split when overfitting)."""
        return self._make_dloader("test", shuffle=shuffle)

    def _make_dloader(self, split, shuffle=False):
        """Instantiate the dataset for ``split`` and wrap it in a ``DataLoader``."""
        if self.overfit:
            # Overfit mode deliberately ignores the requested split/shuffle.
            split = "train"
            shuffle = True
        return DataLoader(
            self.datasetCls(**self.dataset_kwargs, split=split),
            shuffle=shuffle,
            batch_size=self.batch_size,
            num_workers=self.workers,
            collate_fn=self.collate_fn,
        )

# RevIN

In [2]:
"""
Reversible Instance Normalization from 
https://github.com/ts-kim/RevIN
"""

import torch
import torch.nn as nn


class MovingAvg(nn.Module):
    """Moving average along dim 1, with edge values replicated as padding."""

    def __init__(self, kernel_size, stride):
        super().__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        """Smooth ``x`` along dim 1; the pooling itself runs over the last axis after a swap."""
        # Replicate the first and last step (k - 1) // 2 times on each side.
        pad_len = (self.kernel_size - 1) // 2
        head = x[:, :1, :].repeat(1, pad_len, 1)
        tail = x[:, -1:, :].repeat(1, pad_len, 1)
        padded = torch.cat((head, x, tail), dim=1)
        # AvgPool1d pools the last axis, so swap dims 1 and 2 around the call.
        pooled = self.avg(padded.transpose(1, 2))
        return pooled.transpose(1, 2)


class SeriesDecomposition(nn.Module):
    """Split a series into a moving-average trend and the residual around it."""

    def __init__(self, kernel_size):
        super().__init__()
        self.moving_avg = MovingAvg(kernel_size, stride=1)

    def forward(self, x):
        """Return ``(residual, trend)`` where ``trend`` is the moving average of ``x``."""
        trend = self.moving_avg(x)
        return x - trend, trend


class RevIN(nn.Module):
    """Reversible instance normalization with optional learnable affine terms.

    ``forward(x, "norm")`` standardizes ``x`` with per-instance statistics and
    remembers them; ``forward(x, "denorm")`` applies the inverse transform
    using the statistics stored by the most recent "norm" call.
    """

    def __init__(self, num_features: int, eps=1e-5, affine=True):
        """
        :param num_features: the number of features or channels
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        """
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str, update_stats=True):
        """Normalize (``mode="norm"``) or de-normalize (``mode="denorm"``) a 3-D tensor.

        Raises NotImplementedError for any other mode.
        """
        assert x.ndim == 3
        if mode == "denorm":
            return self._denormalize(x)
        if mode != "norm":
            raise NotImplementedError
        if update_stats:
            self._get_statistics(x)
        return self._normalize(x)

    def _init_params(self):
        # Per-feature scale and shift, shape (C,), starting as the identity.
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        # Reduce over every axis between the first and the last one.
        reduce_dims = tuple(range(1, x.ndim - 1))
        self.mean = torch.mean(x, dim=reduce_dims, keepdim=True).detach()
        variance = torch.var(x, dim=reduce_dims, keepdim=True, unbiased=False)
        self.stdev = torch.sqrt(variance + self.eps).detach()

    def _normalize(self, x):
        x = (x - self.mean) / self.stdev
        if not self.affine:
            return x
        return x * self.affine_weight + self.affine_bias

    def _denormalize(self, x):
        if self.affine:
            # eps ** 2 guards against division by a zero weight.
            x = (x - self.affine_bias) / (self.affine_weight + self.eps * self.eps)
        return x * self.stdev + self.mean

# Time2Vec

In [3]:
import torch
from torch import nn


class Time2Vec(nn.Module):
    """Learned time embedding: one linear channel plus activated channels per input feature.

    Each of the ``input_dim`` features gets ``embed_dim // input_dim`` outputs;
    the first is a plain affine map, the rest pass through ``act_function``.
    With ``embed_dim == 0`` the module is a pass-through.
    """

    def __init__(self, input_dim=6, embed_dim=512, act_function=torch.sin):
        assert embed_dim % input_dim == 0
        super().__init__()
        self.enabled = embed_dim > 0
        if not self.enabled:
            return
        per_feature_dim = embed_dim // input_dim
        self.embed_dim = per_feature_dim
        self.input_dim = input_dim
        # Weight is drawn before bias so seeded initialization is reproducible.
        self.embed_weight = nn.Parameter(torch.randn(input_dim, per_feature_dim))
        self.embed_bias = nn.Parameter(torch.randn(input_dim, per_feature_dim))
        self.act_function = act_function

    def forward(self, x):
        """Embed ``x``; the trailing feature axis grows to ``input_dim * per-feature dim``."""
        if not self.enabled:
            return x
        # Put each feature on its own diagonal entry so the matmul scales row i
        # of the weight by feature i: (..., input_dim, input_dim).
        diagonal = torch.diag_embed(x)
        affine = torch.matmul(diagonal, self.embed_weight) + self.embed_bias
        # First channel stays linear; the remaining ones are activated.
        linear_part = affine[..., :1]
        periodic_part = self.act_function(affine[..., 1:])
        combined = torch.cat([linear_part, periodic_part], dim=-1)
        # Merge the (feature, channel) axes into one embedding axis.
        return combined.reshape(combined.size(0), combined.size(1), -1)

# Eval Stats

In [4]:
import numpy as np

EPSILON = 1e-7


def r_squared(actual: np.ndarray, predicted: np.ndarray):
    """Mean coefficient of determination, computed along axis 1 and then averaged."""
    residuals = _error(actual, predicted)
    deviations = _error(actual, actual.mean(1, keepdims=True))
    residual_ss = (residuals ** 2).sum(1)
    total_ss = (deviations ** 2).sum(1)
    # EPSILON keeps the ratio finite when a series is constant.
    scores = 1.0 - residual_ss / (total_ss + EPSILON)
    return scores.mean()


def _error(actual: np.ndarray, predicted: np.ndarray):
    """Simple error"""
    return actual - predicted


def _percentage_error(actual: np.ndarray, predicted: np.ndarray):
    """
    Signed error relative to ``actual`` (offset by EPSILON to avoid dividing by zero).

    Note: result is NOT multiplied by 100
    """
    denominator = actual + EPSILON
    return _error(actual, predicted) / denominator


def mse(actual: np.ndarray, predicted: np.ndarray):
    """Mean of the squared element-wise errors."""
    squared_errors = np.square(_error(actual, predicted))
    return np.mean(squared_errors)


def mae(actual: np.ndarray, predicted: np.ndarray):
    """Mean of the absolute element-wise errors."""
    absolute_errors = np.abs(_error(actual, predicted))
    return np.mean(absolute_errors)


def mape(actual: np.ndarray, predicted: np.ndarray):
    """Mean of the absolute percentage errors (as a fraction, not multiplied by 100)."""
    absolute_ratios = np.abs(_percentage_error(actual, predicted))
    return np.mean(absolute_ratios)


def smape(actual: np.ndarray, predicted: np.ndarray):
    """Symmetric mean absolute percentage error (as a fraction, in [0, 2])."""
    absolute_gap = np.abs(actual - predicted)
    # EPSILON prevents 0 / 0 when both values are zero.
    combined_magnitude = (np.abs(actual) + np.abs(predicted)) + EPSILON
    return np.mean(2.0 * absolute_gap / combined_magnitude)

# Callbacks

In [None]:
import torch

class TeacherForcingAnnealCallback:
    """Linearly lower ``model.teacher_forcing_prob`` by a fixed step per batch, floored at ``end``."""

    def __init__(self, start: float, end: float, steps: int):
        assert start >= end, "start must be >= end"
        self.start, self.end, self.steps = start, end, steps
        # Amount subtracted after every training batch.
        self.slope = float(start - end) / steps

    def on_train_batch_end(self, model: torch.nn.Module):
        """Decrease the model's teacher-forcing probability by one step and report it."""
        candidate = model.teacher_forcing_prob - self.slope
        # Never drop below the configured floor.
        new_teacher_forcing_prob = candidate if candidate > self.end else self.end
        model.teacher_forcing_prob = new_teacher_forcing_prob
        print("Teacher Forcing Prob:", new_teacher_forcing_prob)


class TimeMaskedLossCallback:
    """Linearly raise ``model.time_masked_idx`` by a fixed step per batch, capped at ``end``."""

    def __init__(self, start, end, steps):
        assert start <= end, "end must be >= start"
        self.start, self.end, self.steps = start, end, steps
        # Amount added after every training batch.
        self.slope = float(end - start) / steps
        # Fractional position; exposed rounded through ``time_mask``.
        self._time_mask = self.start

    @property
    def time_mask(self):
        """Current mask index, rounded to an integer."""
        return round(self._time_mask)

    def on_train_start(self, model: torch.nn.Module):
        """Seed the model's mask index unless it already has one."""
        if model.time_masked_idx is not None:
            return
        model.time_masked_idx = self.time_mask

    def on_train_batch_end(self, model: torch.nn.Module):
        """Advance the mask index by one step and push it to the model."""
        advanced = self._time_mask + self.slope
        # Never move past the configured ceiling.
        self._time_mask = advanced if advanced < self.end else self.end
        current_index = self.time_mask
        model.time_masked_idx = current_index
        print("Time Masked Index:", current_index)

class ModelCheckpoint:
    """Save a model's ``state_dict`` whenever the monitored metric improves."""

    def __init__(
        self,
        dirpath: str,
        monitor: str = "val/loss",
        mode: str = "min",
        filename: str = "",
        save_top_k: int = 1,
        auto_insert_metric_name: bool = True,
    ):
        """
        Args:
            dirpath (str): Directory where checkpoints are saved (created if missing).
            monitor (str): The key to monitor (e.g., "val/loss").
            mode (str): If "min", lower monitored values are better; any other value means higher is better.
            filename (str): A filename prefix. With auto_insert_metric_name true, the metric name is added automatically.
            save_top_k (int): How many top models to save. (Stored only; this simple version writes a file on every improvement.)
            auto_insert_metric_name (bool): Whether to automatically insert the monitored metric name in the filename.
        """
        self.dirpath = dirpath
        os.makedirs(self.dirpath, exist_ok=True)
        self.monitor = monitor
        self.mode = mode
        self.filename_prefix = filename
        self.save_top_k = save_top_k  # Kept for API parity; not consulted below.
        self.auto_insert_metric_name = auto_insert_metric_name

        # Best monitored value observed so far; None until the first save.
        self.best_score = None

    def _is_improvement(self, current: float) -> bool:
        """Return True when ``current`` beats the best score (or none exists yet)."""
        best = self.best_score
        if best is None:
            return True
        return current < best if self.mode == "min" else current > best

    def on_validation_epoch_end(self, epoch: int, metrics: dict, model: torch.nn.Module):
        """
        Checkpoint ``model`` if the monitored metric improved this epoch.

        Args:
            epoch (int): Current epoch number.
            metrics (dict): Dictionary that must contain the monitored metric (e.g. {"val/loss": value}).
            model (torch.nn.Module): The model to be checkpointed.
        """
        current = metrics.get(self.monitor)
        if current is None:
            print(f"Metric '{self.monitor}' not found in metrics; skipping checkpoint.")
            return
        if not self._is_improvement(current):
            return

        self.best_score = current
        # Slashes in the metric key would otherwise create sub-directories.
        if self.auto_insert_metric_name:
            metric_str = f"_{self.monitor.replace('/', '_')}={current:.4f}"
        else:
            metric_str = ""
        filepath = os.path.join(
            self.dirpath, f"{self.filename_prefix}{epoch:02d}{metric_str}.pt"
        )
        torch.save(model.state_dict(), filepath)
        print(f"[Epoch {epoch:02d}] Checkpoint saved to: {filepath}")

# LinearAR Model

In [6]:
import math
import torch
from torch import nn
from torch.optim import Adam
from einops import rearrange
import pandas as pd
import numpy as np

In [7]:
class LinearModel(nn.Module):
    """Autoregressive linear forecaster over a fixed-length trailing window.

    Each output channel is a weighted sum of its own last ``context_points``
    values plus a bias. With ``shared_weights`` a single weight column and
    bias are reused for every channel.
    """

    def __init__(self, context_points: int, shared_weights: bool = False, d_yt: int = 7):
        super().__init__()

        if shared_weights:
            layer_count = 1
        else:
            assert d_yt is not None
            layer_count = d_yt

        self.weights = nn.Parameter(torch.ones((context_points, layer_count)), requires_grad=True)
        self.bias = nn.Parameter(torch.ones((layer_count)), requires_grad=True)

        # Uniform init in [-1/sqrt(window), 1/sqrt(window)]; weights are drawn
        # before the bias so seeded runs stay reproducible.
        bound = math.sqrt(1.0 / context_points)
        with torch.no_grad():
            self.weights.uniform_(-bound, bound)
            self.bias.uniform_(-bound, bound)

        self.window = context_points
        self.shared_weights = shared_weights
        self.d_yt = d_yt

    def forward(self, y_c: torch.Tensor, pred_len: int, d_yt: int = None):
        """Roll the model forward ``pred_len`` steps; returns ``(batch, pred_len, d_yt)``."""
        batch, _length, _d_yc = y_c.shape
        d_yt = d_yt or self.d_yt

        output = torch.zeros(batch, pred_len, d_yt, device=y_c.device)

        for step in range(pred_len):
            # History = remaining context followed by the predictions made so far.
            history = torch.cat((y_c[:, step:, :d_yt], output[:, :step]), dim=1)
            output[:, step, :] = self._inner_forward(history)
        return output

    def _inner_forward(self, inp):
        """Predict one step from the last ``window`` rows of ``inp``."""
        batch = inp.shape[0]
        if self.shared_weights:
            # Fold channels into the batch axis so one weight column serves all.
            inp = rearrange(inp, "batch length dy -> (batch dy) length 1")
        recent = inp[:, -self.window :, :]
        baseline = (self.weights * recent).sum(1) + self.bias
        if not self.shared_weights:
            return baseline
        return rearrange(baseline, "(batch dy) 1 -> batch dy", batch=batch)

In [8]:
def train(model, optimizer, criterion, y_c, target, epochs=1000):
    """Fit ``model`` on one full batch for ``epochs`` steps, logging the loss every 100 epochs.

    The model is called as ``model(y_c, pred_len)`` with ``pred_len`` taken
    from ``target.shape[1]``.
    """
    model.train()
    horizon = target.shape[1]
    for epoch in range(epochs):
        optimizer.zero_grad()
        loss = criterion(model(y_c, horizon), target)
        loss.backward()
        optimizer.step()
        if epoch % 100:
            continue
        print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

In [9]:
# ETTm1
import numpy as np
import pandas as pd

def create_sliding_windows(data: np.ndarray, context_points: int, pred_len: int):
    """
    Cut ``data`` into overlapping (context, target) pairs, advancing one row at a time.

    data: numpy array of shape (T, d_yt)
    returns: y_context (batch, context_points, d_yt) and
             y_target (batch, pred_len, d_yt)
    """
    # Number of start rows that leave room for a full context + target.
    window_count = data.shape[0] - context_points - pred_len + 1
    starts = range(window_count)
    contexts = [data[s : s + context_points] for s in starts]
    targets = [
        data[s + context_points : s + context_points + pred_len] for s in starts
    ]
    return np.array(contexts), np.array(targets)

In [9]:
# Standalone experiment: fit LinearModel on sliding windows cut from the ETTm1 CSV.
# NOTE(review): the path below is an absolute, machine-specific Windows path.
csv_path = "S:\\spatiotemporal-analysis\\ETTm1_modified.csv"
df = pd.read_csv(csv_path, parse_dates=["date"])

# Sort by date (if not already sorted) so windows are chronological
df.sort_values("date", inplace=True)

# Use the numeric columns as features (d_yt should equal the number of features: 7)
feature_columns = ["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"]
data = df[feature_columns].values  # shape (T, 7)

# Define context and prediction lengths (in rows)
context_points = 10
pred_len = 5

# Create sliding windows from the data (one window per start row)
y_context_np, y_target_np = create_sliding_windows(data, context_points, pred_len)

# Convert to float32 torch tensors
y_context = torch.tensor(y_context_np, dtype=torch.float32)
y_target = torch.tensor(y_target_np, dtype=torch.float32)

print("y_context shape:", y_context.shape)  # (batch, context_points, 7)
print("y_target shape:", y_target.shape)    # (batch, pred_len, 7)

# Create the model, optimizer, and loss function.
# shared_weights=False gives every feature its own weight column and bias.
d_yt = len(feature_columns)
model = LinearModel(context_points=context_points, shared_weights=False, d_yt=d_yt)
optimizer = Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# Train the model using all training examples as one batch.
# The data is neither split nor normalized here, so the printed loss is a
# training loss in raw units.
train(model, optimizer, criterion, y_context, y_target, epochs=1000)

y_context shape: torch.Size([69666, 10, 7])
y_target shape: torch.Size([69666, 5, 7])
Epoch 0, Loss: 70.6213
Epoch 100, Loss: 7.1133
Epoch 200, Loss: 4.4918
Epoch 300, Loss: 4.1479
Epoch 400, Loss: 3.7757
Epoch 500, Loss: 3.4020


KeyboardInterrupt: 

# Forecaster - The parent class model

In [10]:
from abc import ABC, abstractmethod
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class Forecaster(nn.Module, ABC):
    """Abstract base class for forecasting models.

    Subclasses implement ``forward_model_pass`` and the two
    ``*_step_forward_kwargs`` properties. This class wraps that core model
    with optional reversible instance normalization (``use_revin``), optional
    seasonal/trend decomposition (``use_seasonal_decomp``) and an optional
    linear autoregressive baseline on the trend (``linear_window``), and
    provides masked losses, evaluation statistics and an optimizer recipe.

    Args:
        d_x: Number of time-feature channels (stored only).
        d_yc: Number of context value channels.
        d_yt: Number of target value channels.
        learning_rate: Learning rate used by ``configure_optimizers``.
        l2_coeff: Weight decay used by ``configure_optimizers``.
        loss: One of ``"mse"``, ``"mae"``, ``"smape"``.
        linear_window: Window length of the linear baseline; 0 disables it.
        linear_shared_weights: Share baseline weights across channels.
        use_revin: Wrap the model in ``RevIN`` (requires ``d_yc == d_yt``).
        use_seasonal_decomp: Split the context into seasonal and trend parts.
        verbose: Print the configuration on construction when truthy.
    """

    def __init__(
        self,
        d_x: int,
        d_yc: int,
        d_yt: int,
        learning_rate: float = 1e-3,
        l2_coeff: float = 0,
        loss: str = "mse",
        linear_window: int = 0,
        linear_shared_weights: bool = False,
        use_revin: bool = False,
        use_seasonal_decomp: bool = False,
        verbose: int = True,
    ):
        super().__init__()

        qprint = lambda msg: print(msg) if verbose else None
        qprint("Forecaster")
        qprint(f"\tL2: {l2_coeff}")
        qprint(f"\tLinear Window: {linear_window}")
        qprint(f"\tLinear Shared Weights: {linear_shared_weights}")
        qprint(f"\tRevIN: {use_revin}")
        qprint(f"\tDecomposition: {use_seasonal_decomp}")

        # Both scalers default to the identity so predict() and
        # _compute_stats() work before set_scaler()/set_inv_scaler() is called.
        self._inv_scaler = lambda x: x
        self._scaler = lambda x: x
        self.l2_coeff = l2_coeff
        self.learning_rate = learning_rate
        self.time_masked_idx = None
        self.null_value = None
        self.loss = loss

        if linear_window:
            self.linear_model = LinearModel(
                linear_window, shared_weights=linear_shared_weights, d_yt=d_yt
            )
        else:
            # Disabled baseline contributes nothing to the output.
            self.linear_model = lambda x, *args, **kwargs: 0.0

        self.use_revin = use_revin
        if use_revin:
            assert d_yc == d_yt, "TODO: figure out exo case for revin"
            self.revin = RevIN(num_features=d_yc)
        else:
            self.revin = lambda x, **kwargs: x

        self.use_seasonal_decomp = use_seasonal_decomp
        if use_seasonal_decomp:
            self.seasonal_decomp = SeriesDecomposition(kernel_size=25)
        else:
            # Without decomposition both "seasonal" and "trend" are the input.
            self.seasonal_decomp = lambda x: (x, x.clone())

        self.d_x = d_x
        self.d_yc = d_yc
        self.d_yt = d_yt

    def set_null_value(self, val: float) -> None:
        """Set the target value that marks a missing entry (excluded from the loss)."""
        self.null_value = val

    def set_inv_scaler(self, scaler) -> None:
        """Set the callable mapping normalized numpy arrays back to original units."""
        self._inv_scaler = scaler

    def set_scaler(self, scaler) -> None:
        """Set the callable that normalizes raw numpy context values in ``predict``."""
        self._scaler = scaler

    @property
    @abstractmethod
    def train_step_forward_kwargs(self):
        """Extra keyword arguments for ``forward`` during training steps."""
        return {}

    @property
    @abstractmethod
    def eval_step_forward_kwargs(self):
        """Extra keyword arguments for ``forward`` during evaluation steps."""
        return {}

    def loss_fn(self, true: torch.Tensor, preds: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Return the masked loss selected by ``self.loss``, averaged over unmasked entries.

        Raises:
            ValueError: If ``self.loss`` is not "mse", "mae" or "smape".
        """
        true = torch.nan_to_num(true)
        # max(..., 1) avoids dividing by zero when everything is masked.
        if self.loss == "mse":
            loss = (mask * (true - preds)).square().sum() / max(mask.sum(), 1)
        elif self.loss == "mae":
            loss = torch.abs(mask * (true - preds)).sum() / max(mask.sum(), 1)
        elif self.loss == "smape":
            num = 2.0 * torch.abs(preds - true)
            # The denominator's prediction term is detached from the graph.
            den = torch.abs(preds.detach()) + torch.abs(true) + 1e-5
            loss = 100.0 * (mask * (num / den)).sum() / max(mask.sum(), 1)
        else:
            raise ValueError(f"Unrecognized Loss Function: {self.loss}")
        return loss

    def forecasting_loss(self, outputs: torch.Tensor, y_t: torch.Tensor, time_mask: int) -> Tuple[torch.Tensor]:
        """Return ``(loss, mask)`` where the mask drops null, NaN and time-masked targets.

        ``time_mask`` (if not None) is an index along dim 1; positions from
        that index onward are excluded from the loss.
        """
        if self.null_value is not None:
            null_mask_mat = y_t != self.null_value
        else:
            null_mask_mat = torch.ones_like(y_t)
        null_mask_mat *= ~torch.isnan(y_t)

        time_mask_mat = torch.ones_like(y_t)
        if time_mask is not None:
            time_mask_mat[:, time_mask:] = False

        full_mask = time_mask_mat * null_mask_mat
        forecasting_loss = self.loss_fn(y_t, outputs, full_mask)
        return forecasting_loss, full_mask

    def compute_loss(self, batch: Tuple[torch.Tensor], time_mask: int = None, forward_kwargs: dict = {}) -> Tuple[torch.Tensor]:
        """Run ``forward`` on ``batch`` and return ``(loss, outputs, mask)``."""
        # forward_kwargs is only unpacked, never mutated, so the shared
        # default dict is safe.
        x_c, y_c, x_t, y_t = batch
        outputs, *_ = self.forward(x_c, y_c, x_t, y_t, **forward_kwargs)
        loss, mask = self.forecasting_loss(outputs=outputs, y_t=y_t, time_mask=time_mask)
        return loss, outputs, mask

    def predict(self, x_c: torch.Tensor, y_c: torch.Tensor, x_t: torch.Tensor, sample_preds: bool = False) -> torch.Tensor:
        """Forecast in original units from raw (un-normalized) context values.

        ``y_c`` is normalized with the scaler, the model is run without
        gradients against an all-zero target placeholder, and the result is
        inverse-scaled and returned on ``y_c``'s original device.
        ``sample_preds`` is accepted but unused here.
        """
        og_device = y_c.device
        # Ensure tensors are on the same device as the model.
        model_device = next(self.parameters()).device
        x_c = x_c.to(model_device).float()
        x_t = x_t.to(model_device).float()
        y_c = torch.from_numpy(self._scaler(y_c.cpu().numpy())).to(model_device).float()
        y_t = torch.zeros((x_t.shape[0], x_t.shape[1], self.d_yt), device=model_device).float()

        with torch.no_grad():
            normalized_preds, *_ = self.forward(x_c, y_c, x_t, y_t, **self.eval_step_forward_kwargs)
        preds = torch.from_numpy(self._inv_scaler(normalized_preds.cpu().numpy())).to(og_device).float()
        return preds

    @abstractmethod
    def forward_model_pass(
        self,
        x_c: torch.Tensor,
        y_c: torch.Tensor,
        x_t: torch.Tensor,
        y_t: torch.Tensor,
        **forward_kwargs,
    ) -> Tuple[torch.Tensor]:
        """Core model call; must return a tuple whose first element is the prediction."""
        return NotImplemented

    def nan_to_num(self, *inps):
        """Yield each input with NaNs replaced via ``torch.nan_to_num``."""
        return (torch.nan_to_num(i) for i in inps)

    def forward(self, x_c: torch.Tensor, y_c: torch.Tensor, x_t: torch.Tensor, y_t: torch.Tensor, **forward_kwargs) -> Tuple[torch.Tensor]:
        """Full forward pass: normalize, decompose, model + baseline, de-normalize.

        Returns a tuple whose first element is the forecast, followed by any
        extra outputs of ``forward_model_pass``.
        """
        x_c, y_c, x_t, y_t = self.nan_to_num(x_c, y_c, x_t, y_t)
        _, pred_len, d_yt = y_t.shape

        y_c = self.revin(y_c, mode="norm")
        seasonal_yc, trend_yc = self.seasonal_decomp(y_c)
        # The core model sees the seasonal part, the linear baseline the trend.
        preds, *extra = self.forward_model_pass(x_c, seasonal_yc, x_t, y_t, **forward_kwargs)
        baseline = self.linear_model(trend_yc, pred_len=pred_len, d_yt=d_yt)
        output = self.revin(preds + baseline, mode="denorm")

        if extra:
            return (output,) + tuple(extra)
        return (output,)

    def _compute_stats(self, pred: torch.Tensor, true: torch.Tensor, mask: torch.Tensor):
        """Return a dict of error metrics in original and normalized units.

        Masked entries are zeroed in both tensors, and each mean is divided by
        the unmasked fraction to compensate for those zeros.
        """
        pred = pred * mask
        true = torch.nan_to_num(true) * mask

        adj = mask.mean().cpu().numpy() + 1e-5
        pred = pred.detach().cpu().numpy()
        true = true.detach().cpu().numpy()
        scaled_pred = self._inv_scaler(pred)
        scaled_true = self._inv_scaler(true)
        stats = {
            "mape": mape(scaled_true, scaled_pred) / adj,
            "mae": mae(scaled_true, scaled_pred) / adj,
            "mse": mse(scaled_true, scaled_pred) / adj,
            "smape": smape(scaled_true, scaled_pred) / adj,
            "norm_mae": mae(true, pred) / adj,
            "norm_mse": mse(true, pred) / adj,
        }
        return stats

    def step(self, batch: Tuple[torch.Tensor], train: bool = False):
        """Compute loss and metrics for one batch; the time mask applies only when training."""
        kwargs = self.train_step_forward_kwargs if train else self.eval_step_forward_kwargs
        time_mask = self.time_masked_idx if train else None
        loss, output, mask = self.compute_loss(batch=batch, time_mask=time_mask, forward_kwargs=kwargs)
        *_, y_t = batch
        stats = self._compute_stats(output, y_t, mask)
        stats["loss"] = loss
        return stats

    def training_step(self, batch, batch_idx):
        """Return the stats dict (including ``"loss"``) for a training batch."""
        return self.step(batch, train=True)

    def validation_step(self, batch, batch_idx):
        """Return the stats dict for a validation batch and cache it on the instance."""
        stats = self.step(batch, train=False)
        self.current_val_stats = stats
        return stats

    def test_step(self, batch, batch_idx):
        """Return the stats dict for a test batch."""
        return self.step(batch, train=False)

    def _log_stats(self, section, outs):
        """Log every entry of ``outs`` under ``"<section>/<key>"``.

        NOTE(review): ``self.log`` is not defined in this class and plain
        ``nn.Module`` does not provide it - this looks like a leftover from a
        PyTorch Lightning base class; confirm before calling the ``*_step_end``
        hooks.
        """
        for key in outs.keys():
            stat = outs[key]
            if isinstance(stat, np.ndarray) or isinstance(stat, torch.Tensor):
                stat = stat.mean()
            self.log(f"{section}/{key}", stat, sync_dist=True)

    def training_step_end(self, outs):
        """Log training stats and return the mean loss."""
        self._log_stats("train", outs)
        return {"loss": outs["loss"].mean()}

    def validation_step_end(self, outs):
        """Log validation stats and pass them through unchanged."""
        self._log_stats("val", outs)
        return outs

    def test_step_end(self, outs):
        """Log test stats and return the mean loss."""
        self._log_stats("test", outs)
        return {"loss": outs["loss"].mean()}

    def predict_step(self, batch, batch_idx):
        """Run ``forward`` on an unpacked batch with the evaluation kwargs."""
        return self(*batch, **self.eval_step_forward_kwargs)

    def configure_optimizers(self):
        """Return Adam plus a ReduceLROnPlateau scheduler monitoring ``"val/loss"``."""
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.learning_rate, weight_decay=self.l2_coeff
        )
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            patience=3,
            factor=0.2,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val/loss",
            },
        }

# LinearAR Forecaster

In [11]:
import torch
from torch import nn
import torch.nn.functional as F

class Linear_Forecaster(Forecaster):
    """Forecaster that delegates its predictions to a ``LinearModel``.

    Only the context values ``y_c`` are fed to the model; the remaining
    forward inputs are accepted for interface compatibility and, apart from
    the shape of ``y_t``, are not used.
    """

    def __init__(
        self,
        d_x: int,
        d_yc: int,
        d_yt: int,
        context_points: int,
        learning_rate: float = 1e-3,
        l2_coeff: float = 0,
        loss: str = "mse",
        linear_window: int = 0,
        linear_shared_weights: bool = False,
        use_revin: bool = False,
        use_seasonal_decomp: bool = False,
    ):
        """Configure the base forecaster and build the linear model.

        Args:
            d_x: Forwarded to the base ``Forecaster``.
            d_yc: Forwarded to the base ``Forecaster``.
            d_yt: Forwarded to the base ``Forecaster`` and to ``LinearModel``.
            context_points: First positional argument of ``LinearModel``.
            learning_rate: Forwarded to the base ``Forecaster``.
            l2_coeff: Forwarded to the base ``Forecaster``.
            loss: Forwarded to the base ``Forecaster``.
            linear_window: Forwarded to the base ``Forecaster``.
            linear_shared_weights: Forwarded to the base ``Forecaster`` and
                passed to ``LinearModel`` as ``shared_weights``.
            use_revin: Forwarded to the base ``Forecaster``.
            use_seasonal_decomp: Forwarded to the base ``Forecaster``.
        """
        super().__init__(
            d_x=d_x,
            d_yc=d_yc,
            d_yt=d_yt,
            l2_coeff=l2_coeff,
            learning_rate=learning_rate,
            loss=loss,
            linear_window=linear_window,
            linear_shared_weights=linear_shared_weights,
            use_revin=use_revin,
            use_seasonal_decomp=use_seasonal_decomp,
        )

        self.model = LinearModel(
            context_points, shared_weights=linear_shared_weights, d_yt=d_yt
        )

    @property
    def eval_step_forward_kwargs(self):
        """Extra keyword arguments for evaluation forward passes (none)."""
        return {}

    @property
    def train_step_forward_kwargs(self):
        """Extra keyword arguments for training forward passes (none)."""
        return {}

    def forward_model_pass(self, x_c, y_c, x_t, y_t):
        """Predict the target window from the context values.

        ``y_t`` is only inspected for its shape, which must unpack into
        three dimensions; the last two are passed to the model as
        ``pred_len`` and ``d_yt``.  ``x_c`` and ``x_t`` are not used.

        Returns:
            A one-element tuple holding the model output.
        """
        _, pred_len, d_yt = y_t.shape
        output = self.model(y_c, pred_len=pred_len, d_yt=d_yt)
        return (output,)

    @classmethod
    def add_cli(cls, parser):
        """Register this forecaster's command-line options on ``parser``.

        Only the options of the base class are added.
        """
        # First parameter renamed from ``self`` to ``cls``: this is a
        # classmethod, so it is bound to the class, not to an instance.
        super().add_cli(parser)


# Train LinearAR Model

In [30]:
from types import SimpleNamespace

# Settings for the LinearAR run on ETTm1, exposed as attributes
# (``config.dset``, ``config.batch_size``, ...).
config = SimpleNamespace(
    # Data
    dset="ettm1",
    data_path="S:\\spatiotemporal-analysis\\ETTm1_modified.csv",
    context_points=10,
    target_points=10,
    time_resolution=10,
    # Loading
    batch_size=32,
    workers=1,
    overfit=False,
    # Model and optimisation
    model="linear",
    learning_rate=1e-3,
    l2_coeff=0,
    loss="mse",
    linear_window=10,
    linear_shared_weights=False,
    use_revin=False,
    use_seasonal_decomp=False,
    # Bookkeeping
    run_name="linear-ettm1",
)


In [31]:
# Model names known to this notebook.
_MODELS = [
    "spacetimeformer",
    "mtgnn",
    "heuristic",
    "lstm",
    "lstnet",
    "linear",
    "s4",
]

# Dataset names known to this notebook.
_DSETS = [
    "asos", "metr-la", "pems-bay", "exchange", "precip",
    "toy2", "solar_energy", "syn", "mnist", "cifar",
    "copy", "cont_copy", "m4", "wiki", "ettm1",
    "weather", "monash", "hangzhou", "traffic",
]

In [35]:
import random
import sys
import warnings
import os
import uuid

import torch

def create_dset(config):
    """Build the data module and scaling helpers for ``config.dset``.

    Only the ``"ettm1"`` dataset is wired up in this notebook.

    Args:
        config: Namespace-like object providing ``dset``, ``data_path``,
            ``context_points``, ``target_points``, ``time_resolution``,
            ``batch_size``, ``workers`` and ``overfit``.

    Returns:
        The tuple ``(data_module, inv_scaler, scaler, null_val,
        plot_var_idxs, plot_var_names, pad_val)``.

    Raises:
        ValueError: If ``config.dset`` names a dataset that is not handled
            here.
    """
    if config.dset != "ettm1":
        # The original fell through to the return statement with the data
        # module unbound, which surfaced as an opaque UnboundLocalError.
        raise ValueError(
            f"Unsupported dataset {config.dset!r}; only 'ettm1' is implemented."
        )

    target_cols = ["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"]
    dset = CSVTimeSeries(
        data_path=config.data_path,
        target_cols=target_cols,
        ignore_cols=[],
        val_split=4.0 / 20,  # from informer
        test_split=4.0 / 20,  # from informer
        time_col_name="date",
        time_features=["month", "day", "weekday", "hour"],
    )
    data_module = DataModule(
        datasetCls=CSVTorchDset,
        dataset_kwargs={
            "csv_time_series": dset,
            "context_points": config.context_points,
            "target_points": config.target_points,
            "time_resolution": config.time_resolution,
        },
        batch_size=config.batch_size,
        workers=config.workers,
        overfit=config.overfit,
    )

    null_val = None
    pad_val = None  # a value of -32.0 is left disabled in the original code
    plot_var_names = target_cols
    plot_var_idxs = list(range(len(target_cols)))

    return (
        data_module,
        dset.reverse_scaling,
        dset.apply_scaling,
        null_val,
        plot_var_idxs,
        plot_var_names,
        pad_val,
    )

def create_model(config):
    """Instantiate the forecaster selected by ``config.model``.

    Only the ``"linear"`` model on the ``"ettm1"`` dataset is wired up in
    this notebook.

    Args:
        config: Namespace-like object providing ``dset``, ``model``,
            ``context_points``, ``learning_rate``, ``l2_coeff``, ``loss``,
            ``linear_window``, ``linear_shared_weights``, ``use_revin`` and
            ``use_seasonal_decomp``.

    Returns:
        A configured ``Linear_Forecaster``.

    Raises:
        AssertionError: If the input dimensions for ``config.dset`` are
            unknown.
        ValueError: If ``config.model`` is not a supported model name.
    """
    x_dim, yc_dim, yt_dim = None, None, None

    if config.dset == "ettm1":
        # 4 and 7 equal the number of time features and target columns that
        # create_dset requests for this dataset.
        x_dim = 4
        yc_dim = 7
        yt_dim = 7

    assert x_dim is not None, f"no input dimensions known for dataset {config.dset!r}"
    assert yc_dim is not None
    assert yt_dim is not None

    if config.model != "linear":
        # The original fell through to ``return forecaster`` with the name
        # unbound, which surfaced as an opaque UnboundLocalError.
        raise ValueError(
            f"Unsupported model {config.model!r}; only 'linear' is implemented."
        )

    return Linear_Forecaster(
        d_x=x_dim,
        d_yc=yc_dim,
        d_yt=yt_dim,
        context_points=config.context_points,
        learning_rate=config.learning_rate,
        l2_coeff=config.l2_coeff,
        loss=config.loss,
        linear_window=config.linear_window,
        linear_shared_weights=config.linear_shared_weights,
        use_revin=config.use_revin,
        use_seasonal_decomp=config.use_seasonal_decomp,
    )

def create_callbacks(config, save_dir):
    """Create a per-run checkpoint directory and an (empty) callback list.

    The directory is named ``<run_name>_<first uuid1 segment>`` inside
    ``save_dir`` and is also stored on ``config.model_ckpt_dir``.

    Returns:
        ``(callbacks, model_ckpt_dir)``; ``callbacks`` is currently empty.
    """
    run_tag = str(uuid.uuid1()).split("-")[0]
    ckpt_dir = os.path.join(save_dir, f"{config.run_name}_{run_tag}")
    os.makedirs(ckpt_dir, exist_ok=True)
    config.model_ckpt_dir = ckpt_dir

    # Custom callbacks may be appended here; the training loop looks for an
    # ``on_train_batch_end`` method on each of them.
    registered = []
    return registered, ckpt_dir

In [38]:
def run_validation(model, val_loader, device, epoch, callbacks, criterion=None):
    """Compute the mean validation loss and notify the callbacks.

    The model is evaluated in eval mode without gradient tracking and is
    put back into the mode it was in on entry, so that a validation pass in
    the middle of an epoch does not leave training running in eval mode.

    Args:
        model: Object exposing ``eval()``, ``train(mode)``,
            ``compute_loss`` and ``eval_step_forward_kwargs``.
        val_loader: Iterable of batches; each batch is an iterable of
            tensors that are moved to ``device``.
        device: Target device for the batch tensors.
        epoch: Current epoch, forwarded to the callbacks.
        callbacks: Objects whose optional ``on_validation_epoch_end(epoch,
            metrics, model)`` hook is called with ``{"val/loss": avg_loss}``.
        criterion: Unused; kept for interface compatibility.

    Returns:
        The loss averaged over batches, or ``inf`` when ``val_loader``
        yields no batch.
    """
    was_training = getattr(model, "training", False)
    model.eval()
    total_loss = 0.0
    n_batches = 0
    try:
        with torch.no_grad():
            for batch in val_loader:
                batch = [b.to(device) for b in batch]
                loss, _, _ = model.compute_loss(
                    batch,
                    time_mask=None,
                    forward_kwargs=model.eval_step_forward_kwargs,
                )
                total_loss += loss.item()
                n_batches += 1
        avg_loss = total_loss / n_batches if n_batches > 0 else float("inf")

        # Validation hooks (checkpointing, early stopping, ...).
        for cb in callbacks:
            if hasattr(cb, "on_validation_epoch_end"):
                metrics = {"val/loss": avg_loss}
                cb.on_validation_epoch_end(epoch, metrics, model)
    finally:
        # Restore the caller's mode even if a batch or a callback raised.
        model.train(was_training)

    return avg_loss

def train(model, optimizer, train_loader, val_loader, device, args, callbacks):
    """Train ``model`` with periodic validation, checkpointing and early stopping.

    Args:
        model: Object exposing ``train()``, ``parameters()``,
            ``state_dict()``, ``compute_loss`` and
            ``train_step_forward_kwargs``.
        optimizer: Optimizer stepping the model parameters.
        train_loader: Sized iterable of training batches (``len()`` is
            used to know when an epoch ends).
        val_loader: Iterable of validation batches for ``run_validation``.
        device: Target device for the batch tensors.
        args: Namespace with ``epochs``, ``run_name`` and
            ``model_ckpt_dir``.  Optional attributes: ``accumulate``
            (gradient-accumulation steps, default 1), ``grad_clip_norm``
            (default: no clipping), ``patience`` (non-improving
            validations tolerated before stopping, default: never stop
            early) and ``val_check_freq`` (validate every that many
            batches, default: once at the end of each epoch).
        callbacks: Objects whose optional ``on_train_batch_end(trainer,
            model)`` hook is called after every backward pass.

    Returns:
        ``(best_epoch, best_val_loss)`` of the best validation seen.
    """
    accumulate_steps = max(1, getattr(args, "accumulate", 1))
    gradient_clip = getattr(args, "grad_clip_norm", None)
    patience = getattr(args, "patience", None)

    # These two were referenced but never defined in the original loop.
    total_train_batches = len(train_loader)
    val_check_freq = getattr(args, "val_check_freq", None)
    if val_check_freq is None:
        # Larger than an epoch: only the epoch-end validation runs.
        val_check_freq = total_train_batches + 1
    val_check_freq = max(1, int(val_check_freq))

    best_val_loss = float("inf")
    best_epoch = 0
    early_stop_counter = 0

    def validate(epoch, label):
        """Validate, checkpoint on improvement; return True to stop training."""
        nonlocal best_val_loss, best_epoch, early_stop_counter
        val_loss = run_validation(model, val_loader, device, epoch, callbacks)
        print(f"{label} Validation Loss: {val_loss:.6f}")
        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_epoch = epoch
            early_stop_counter = 0
            ckpt_name = f"{args.run_name}_{epoch:02d}_val_loss={val_loss:.4f}.pt"
            ckpt_path = os.path.join(args.model_ckpt_dir, ckpt_name)
            os.makedirs(args.model_ckpt_dir, exist_ok=True)
            torch.save(model.state_dict(), ckpt_path)
            print(f"Checkpoint saved: {ckpt_path}")
            return False
        early_stop_counter += 1
        if patience is not None and early_stop_counter >= patience:
            print("Early stopping triggered.")
            return True
        return False

    for epoch in range(args.epochs):
        model.train()
        running_loss = 0.0
        optimizer.zero_grad()
        for batch_idx, batch in enumerate(train_loader):
            step = batch_idx + 1
            batch = [b.to(device) for b in batch]
            loss, _, _ = model.compute_loss(
                batch,
                time_mask=None,
                forward_kwargs=model.train_step_forward_kwargs,
            )
            # Scale so that accumulated gradients average over the window.
            loss = loss / accumulate_steps
            loss.backward()

            for cb in callbacks:
                if hasattr(cb, "on_train_batch_end"):
                    # No trainer object exists in this loop, hence None.
                    cb.on_train_batch_end(None, model)

            # Also step on the last batch so that a partial accumulation
            # window is applied instead of being zeroed at the next epoch.
            if step % accumulate_steps == 0 or step == total_train_batches:
                if gradient_clip is not None:
                    torch.nn.utils.clip_grad_norm_(model.parameters(), gradient_clip)
                optimizer.step()
                optimizer.zero_grad()
            running_loss += loss.item() * accumulate_steps  # undo scaling for logging

            if step % val_check_freq == 0:
                label = f"[Epoch {epoch:02d}, Batch {step}/{total_train_batches}]"
                if validate(epoch, label):
                    return best_epoch, best_val_loss
                # Validation runs in eval mode; resume training mode.
                model.train()

        epoch_loss = running_loss / max(total_train_batches, 1)
        print(f"[Epoch {epoch:02d}] Training Loss: {epoch_loss:.6f}")

        # Epoch-end validation when none can have happened inside the loop.
        if val_check_freq > total_train_batches:
            if validate(epoch, f"[Epoch {epoch:02d}]"):
                return best_epoch, best_val_loss

    return best_epoch, best_val_loss

def test(model, test_loader, device, criterion=None):
    model.eval()
    total_loss = 0.0
    n_batches = 0
    with torch.no_grad():
        for batch in test_loader:
            batch = [b.to(device) for b in batch]
            loss, outputs, _ = model.compute_loss(batch, time_mask=None, forward_kwargs=model.eval_step_forward_kwargs)
            total_loss += loss.item()
            n_batches += 1
    avg_loss = total_loss / n_batches if n_batches > 0 else float('inf')
    print(f"Test Loss: {avg_loss:.6f}")
    return avg_loss

In [None]:
def main(args):
    """Train a forecaster described by ``args``, then test the best checkpoint.

    Steps: resolve the log directory (``STF_LOG_DIR`` or a default), build
    the dataset and model, create the checkpoint directory, train, reload
    the best checkpoint when one was written, and report the test loss.

    Args:
        args: Namespace consumed by ``create_dset``, ``create_model``,
            ``create_callbacks`` and ``train``.  ``null_value``,
            ``pad_value`` and ``model_ckpt_dir`` are set on it here, and
            ``epochs`` defaults to 50 when absent.
    """
    import time

    # Setup log directory.
    log_dir = os.getenv("STF_LOG_DIR")
    if log_dir is None:
        log_dir = "./data/STF_LOG_DIR"
        print("Using default log dir: ./data/STF_LOG_DIR")
    os.makedirs(log_dir, exist_ok=True)

    # Create dataset; the plotting metadata is not used in this loop.
    (data_module, inv_scaler, scaler, null_val, _, _, pad_val) = create_dset(args)

    # Create model.
    args.null_value = null_val
    args.pad_value = pad_val
    model = create_model(args)
    model.set_inv_scaler(inv_scaler)
    model.set_scaler(scaler)
    model.set_null_value(null_val)

    # create_callbacks returns (callbacks, checkpoint_dir) and creates the
    # directory itself.  The original bound the whole tuple to ``callbacks``
    # and additionally created a second, unused checkpoint directory.
    callbacks, model_ckpt_dir = create_callbacks(args, save_dir=log_dir)
    args.model_ckpt_dir = model_ckpt_dir

    args.epochs = getattr(args, "epochs", 50)

    # Use GPU if available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)

    # Create dataloaders.
    train_loader = data_module.train_dataloader(shuffle=True)
    val_loader = data_module.val_dataloader(shuffle=False)
    test_loader = data_module.test_dataloader(shuffle=False)

    # Setup optimizer. Adjust LR as needed.
    optimizer = torch.optim.Adam(model.parameters(), lr=args.learning_rate)

    start_time = time.time()
    best_epoch, best_val_loss = train(model, optimizer, train_loader, val_loader, device, args, callbacks)
    duration = time.time() - start_time
    print(f"Training completed in {duration:.1f} seconds; Best epoch: {best_epoch}, Best val loss: {best_val_loss:.6f}")

    # The checkpoint name is rebuilt from the naming convention used by
    # ``train``.  No file exists when validation never improved, so fall
    # back to the final weights rather than crashing on the load.
    ckpt_name = f"{args.run_name}_{best_epoch:02d}_val_loss={best_val_loss:.4f}.pt"
    ckpt_path = os.path.join(args.model_ckpt_dir, ckpt_name)
    if os.path.exists(ckpt_path):
        print(f"Loading best checkpoint: {ckpt_path}")
        model.load_state_dict(torch.load(ckpt_path, map_location=device))
    else:
        print(f"No checkpoint found at {ckpt_path}; testing with the final weights.")

    # Run test.
    test_loss = test(model, test_loader, device)
    print(f"Test Loss: {test_loss:.6f}")

# Number of training epochs; ``main``/``train`` read ``epochs``, nothing
# reads the ``args`` attribute that was set here before.
config.epochs = 50
main(config)

Using default log dir: ./data/STF_LOG_DIR


  main_df["Month"] = dates.apply(
  main_df["Day"] = dates.apply(lambda row: 2.0 * ((row.day - 1) / 30.0) - 1.0, 1)
  main_df["Weekday"] = dates.apply(
  main_df["Hour"] = dates.apply(lambda row: 2.0 * ((row.hour) / 23.0) - 1.0, 1)


Forecaster
	L2: 0
	Linear Window: 10
	Linear Shared Weights: False
	RevIN: False
	Decomposition: False


In [None]:
def main(config):
    """Draft entry point: build the dataset and forecaster, without training.

    NOTE: this redefines the ``main`` from the earlier cell, and callback
    handling is not implemented yet.

    Args:
        config: Namespace-like object accepted by ``create_dset`` and
            ``create_model``; ``null_value`` and ``pad_value`` are set on it.
    """
    (
        data_module,
        inv_scaler,
        scaler,
        null_val,
        plot_var_idxs,
        plot_var_names,
        pad_val,
    ) = create_dset(config)

    # Model
    # ``config`` is read by attribute everywhere else, so item assignment
    # (``config["null_val"] = ...``) fails on a SimpleNamespace.  The
    # attribute names follow the earlier ``main``.
    config.null_value = null_val
    config.pad_value = pad_val

    forecaster = create_model(config)
    forecaster.set_inv_scaler(inv_scaler)
    forecaster.set_scaler(scaler)
    forecaster.set_null_value(null_val)

    callbacks = []  # Need to implement

    

In [None]:
# During training, the callbacks' on_train_start and on_train_batch_end hooks should be invoked as well.
# Callback support has been skipped for now and still needs to be implemented.

# Running the Linear Model

In [9]:
# ETTm1
import numpy as np
import pandas as pd

def create_sliding_windows(data: np.ndarray, context_points: int, pred_len: int):
    """Slice a series into overlapping (context, target) window pairs.

    Window ``i`` takes rows ``i`` to ``i + context_points - 1`` as context
    and the ``pred_len`` rows that follow as target.

    Args:
        data: Array whose first axis is time, e.g. shape ``(T, d_yt)``.
        context_points: Number of rows in each context window.
        pred_len: Number of rows in each target window.

    Returns:
        ``(y_context, y_target)`` with shapes ``(n, context_points, ...)``
        and ``(n, pred_len, ...)``, where
        ``n = max(T - context_points - pred_len + 1, 0)`` and ``...`` are
        the trailing dimensions of ``data``.  When the series is too short
        for a single window, both arrays are empty but keep those shapes.
    """
    n_windows = data.shape[0] - context_points - pred_len + 1
    trailing_shape = data.shape[1:]

    if n_windows <= 0:
        # Keep the window dimensions so downstream code can still rely on
        # the rank of the result (the loop version returned shape (0,)).
        return (
            np.empty((0, context_points) + trailing_shape, dtype=data.dtype),
            np.empty((0, pred_len) + trailing_shape, dtype=data.dtype),
        )

    # Row indices of every window, built by broadcasting start offsets
    # against the within-window offsets; indexing gathers them in one pass.
    starts = np.arange(n_windows)[:, None]
    context_idx = starts + np.arange(context_points)
    target_idx = starts + context_points + np.arange(pred_len)
    return data[context_idx], data[target_idx]

In [11]:
import pandas as pd
import torch
import torch.nn as nn
from torch.optim import Adam

# Load the ETTm1 CSV, build sliding windows over its seven value columns and
# set up the forecaster, optimizer and loss used by the training call below.
# The raw column values are used as-is (no scaling) and every window goes
# into training (no validation/test split).
csv_path = "S:\\spatiotemporal-analysis\\ETTm1_modified.csv"
df = pd.read_csv(csv_path, parse_dates=["date"])

# Sort by date (if not already sorted)
df.sort_values("date", inplace=True)

# Use the numeric columns as features (d_yt should equal the number of features: 7)
feature_columns = ["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"]
data = df[feature_columns].values  # shape (T, 7)

# Define context and prediction lengths
context_points = 10
pred_len = 5

# Create sliding windows from the data
y_context_np, y_target_np = create_sliding_windows(data, context_points, pred_len)

# Convert to torch tensors
y_context = torch.tensor(y_context_np, dtype=torch.float32)  # shape: (batch, context_points, 7)
y_target = torch.tensor(y_target_np, dtype=torch.float32)    # shape: (batch, pred_len, 7)

print("y_context shape:", y_context.shape)
print("y_target shape:", y_target.shape)

# Create placeholder inputs for x_c and x_t (assuming no exogenous features):
# tensors with a zero-sized last dimension.
# If you have exogenous features, replace these with proper tensors.
batch = y_context.shape[0]  # number of windows
x_context = torch.empty(batch, context_points, 0)  # no features
x_target = torch.empty(batch, pred_len, 0)           # no features

# d_yt is the number of features (7 in this example)
d_yt = len(feature_columns)

# Instantiate the forecaster.
# NOTE(review): the original comment states that RevIN needs d_yc == d_yt;
# that holds here since both are set to d_yt.
# RevIN and seasonal decomposition are both ENABLED for this run.
model = Linear_Forecaster(
    d_x=0,            # No exogenous input in this example
    d_yc=d_yt,
    d_yt=d_yt,
    context_points=context_points,
    use_revin=True,  # set True only if your use case requires and d_yc==d_yt holds
    use_seasonal_decomp=True,
    linear_window=0
)

optimizer = Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# A simple training loop
def train(model, optimizer, criterion, x_c, y_c, x_t, y_t, epochs=1000):
    """Fit ``model`` with one full-batch gradient step per epoch.

    Args:
        model: Callable taking ``(x_c, y_c, x_t, y_t)``; when it returns a
            tuple, its first element is taken as the prediction.
        optimizer: Optimizer stepping the model parameters.
        criterion: Loss called as ``criterion(preds, y_t)``.
        x_c, y_c, x_t, y_t: Inputs passed unchanged to ``model`` on every
            epoch; ``y_t`` is also the training target.
        epochs: Number of optimisation steps.  The loss is printed every
            100 epochs.
    """
    model.train()
    for step in range(1, epochs + 1):
        optimizer.zero_grad()
        outputs = model(x_c, y_c, x_t, y_t)
        preds = outputs[0] if isinstance(outputs, tuple) else outputs
        loss = criterion(preds, y_t)
        loss.backward()
        optimizer.step()
        if step % 100 == 0:
            print(f"Epoch {step}, Loss: {loss.item():.6f}")

# Each epoch is a single forward/backward pass over all windows at once (no mini-batching).
train(model, optimizer, criterion, x_context, y_context, x_target, y_target, epochs=1000)

y_context shape: torch.Size([69666, 10, 7])
y_target shape: torch.Size([69666, 5, 7])
Forecaster
	L2: 0
	Linear Window: 0
	Linear Shared Weights: False
	RevIN: True
	Decomposition: True
Epoch 100, Loss: 4.400672
Epoch 200, Loss: 3.183222
Epoch 300, Loss: 2.524292
Epoch 400, Loss: 2.418489
Epoch 500, Loss: 2.339239
Epoch 600, Loss: 2.266816


KeyboardInterrupt: 