# Data

## Time Features

In [None]:
import pandas as pd


def time_features_util(
    dates,
    main_df=None,
    use_features=("year", "month", "day", "weekday", "hour", "minute"),
    time_col_name="Datetime",
):
    """Add normalized calendar features derived from ``dates`` to a DataFrame.

    Args:
        dates: pandas Series whose elements expose ``year``, ``month``,
            ``day``, ``weekday()``, ``hour`` and ``minute`` (e.g. Timestamps).
        main_df: optional DataFrame to extend. It is copied, never modified
            in place. When ``None`` an empty DataFrame is used.
        use_features: names of the features to generate; any subset of
            ``year``, ``month``, ``day``, ``weekday``, ``hour``, ``minute``.
        time_col_name: name of the column that receives ``dates`` itself.

    Returns:
        A new DataFrame with the requested feature columns (``Year`` scaled
        to [0, 1] over the observed year range, the others scaled to
        [-1, 1]) followed by the ``time_col_name`` column.
    """
    main_df = pd.DataFrame({}) if main_df is None else main_df.copy()

    if "year" in use_features:
        years = dates.apply(lambda stamp: stamp.year)
        min_year = years.min()
        # Guard against a zero span when all dates share one year.
        year_span = max(1.0, (years.max() - min_year))
        main_df["Year"] = (years - min_year) / year_span

    # feature name -> (output column, per-timestamp scaling to [-1, 1])
    scaled_features = {
        "month": ("Month", lambda stamp: 2.0 * ((stamp.month - 1) / 11.0) - 1.0),
        "day": ("Day", lambda stamp: 2.0 * ((stamp.day - 1) / 30.0) - 1.0),
        "weekday": ("Weekday", lambda stamp: 2.0 * (stamp.weekday() / 6.0) - 1.0),
        "hour": ("Hour", lambda stamp: 2.0 * ((stamp.hour) / 23.0) - 1.0),
        "minute": ("Minute", lambda stamp: 2.0 * ((stamp.minute) / 59.0) - 1.0),
    }
    for feature, (column, scale) in scaled_features.items():
        if feature in use_features:
            # NOTE: Series.apply takes no axis argument; the function is
            # applied element-wise.
            main_df[column] = dates.apply(scale)

    main_df[time_col_name] = dates
    return main_df

## DataModule

In [22]:
import warnings
import torch
from torch.utils.data import DataLoader


class DataModule:
    """Build train/val/test ``DataLoader`` objects from one dataset class.

    The dataset class is instantiated on demand as
    ``datasetCls(**dataset_kwargs, split=<split>)``.
    """

    def __init__(
        self,
        datasetCls,
        dataset_kwargs: dict,
        batch_size: int,
        workers: int,
        collate_fn=None,
        overfit: bool = False,
    ):
        """
        Args:
            datasetCls: callable returning a dataset; must accept ``split``.
            dataset_kwargs: keyword arguments forwarded to ``datasetCls``.
                A ``"split"`` entry is ignored because the split is chosen
                per dataloader. The caller's dict is not modified.
            batch_size: batch size passed to every ``DataLoader``.
            workers: ``num_workers`` passed to every ``DataLoader``.
            collate_fn: optional collate function for the ``DataLoader``.
            overfit: when True, every dataloader serves the shuffled train
                split.
        """
        self.datasetCls = datasetCls
        self.batch_size = batch_size
        # Store a filtered copy so the caller's dict keeps its 'split' entry.
        self.dataset_kwargs = {
            key: value for key, value in dataset_kwargs.items() if key != "split"
        }
        self.workers = workers
        self.collate_fn = collate_fn
        if overfit:
            warnings.warn("Overriding val and test dataloaders to use train set!")
        self.overfit = overfit

    def train_dataloader(self, shuffle=True):
        """Return a ``DataLoader`` over the train split."""
        return self._make_dloader("train", shuffle=shuffle)

    def val_dataloader(self, shuffle=False):
        """Return a ``DataLoader`` over the val split (train if overfitting)."""
        return self._make_dloader("val", shuffle=shuffle)

    def test_dataloader(self, shuffle=False):
        """Return a ``DataLoader`` over the test split (train if overfitting)."""
        return self._make_dloader("test", shuffle=shuffle)

    def _make_dloader(self, split, shuffle=False):
        """Instantiate the dataset for ``split`` and wrap it in a ``DataLoader``."""
        if self.overfit:
            # Overfit mode always reads the shuffled training data.
            split = "train"
            shuffle = True
        dataset = self.datasetCls(**self.dataset_kwargs, split=split)
        return DataLoader(
            dataset,
            batch_size=self.batch_size,
            shuffle=shuffle,
            num_workers=self.workers,
            collate_fn=self.collate_fn,
        )

    @classmethod
    def add_cli(cls, parser):
        """Register ``--batch_size``, ``--workers`` and ``--overfit``; return the parser."""
        parser.add_argument("--batch_size", type=int, default=128)
        parser.add_argument(
            "--workers",
            type=int,
            default=6,
            help="number of parallel workers for pytorch dataloader",
        )
        parser.add_argument("--overfit", action="store_true")
        return parser

## CSVDataset

In [None]:
import random
from typing import List
import os
import tqdm

import torch
from torch.utils.data import Dataset, DataLoader
import pandas as pd
import numpy as np
from sklearn.preprocessing import StandardScaler

import matplotlib.pyplot as plt


class CSVTimeSeries:
    """Time-indexed table with calendar features, a holdout split and scaling.

    The data comes from a CSV file or a DataFrame. Calendar features are
    appended with ``time_features_util``. The rows are split into
    train / val / test by position (the last rows form the test set, the
    rows before them the val set), and the target and exogenous columns are
    optionally standardized with statistics fit on the train split.
    """

    def __init__(
        self,
        data_path: str = None,
        raw_df: pd.DataFrame = None,
        target_cols: List[str] = None,
        ignore_cols: List[str] = None,
        remove_target_from_context_cols: List[str] = None,
        time_col_name: str = "Datetime",
        read_csv_kwargs=None,
        val_split: float = 0.15,
        test_split: float = 0.15,
        normalize: bool = True,
        drop_all_nan: bool = False,
        time_features: List[str] = None,
    ):
        """
        Args:
            data_path: CSV file to read when ``raw_df`` is not given.
            raw_df: already-loaded data; takes precedence over ``data_path``.
                It is never modified in place.
            target_cols: columns to forecast. Empty/None means every column
                except the time column.
            ignore_cols: columns to drop, or the string ``"all"`` to drop
                every column that is neither a target nor the time column.
            remove_target_from_context_cols: target columns that consumers
                should drop from the context; each must be in ``target_cols``.
            time_col_name: name of the timestamp column.
            read_csv_kwargs: extra keyword arguments for ``pd.read_csv``.
            val_split: fraction of rows used for validation.
            test_split: fraction of rows used for testing; ``0.0`` reuses the
                val set as the test set.
            normalize: standardize target and exogenous columns when True.
            drop_all_nan: drop every row containing a NaN when True.
            time_features: calendar features to generate. None selects
                year, month, day, weekday, hour and minute.
        """
        assert data_path is not None or raw_df is not None

        self.data_path = data_path
        if raw_df is None:
            assert os.path.exists(self.data_path)
            raw_df = pd.read_csv(
                self.data_path,
                **(read_csv_kwargs or {}),
            )

        if drop_all_nan:
            # Not in place: a caller-supplied frame must stay untouched.
            raw_df = raw_df.dropna(axis=0, how="any")

        self.time_col_name = time_col_name
        assert self.time_col_name in raw_df.columns

        if target_cols:
            target_cols = list(target_cols)
        else:
            target_cols = raw_df.columns.tolist()
            target_cols.remove(time_col_name)

        if ignore_cols:
            if ignore_cols == "all":
                ignore_cols = raw_df.columns.difference(target_cols).tolist()
                ignore_cols.remove(self.time_col_name)
            # Not in place, for the same reason as the dropna above.
            raw_df = raw_df.drop(columns=ignore_cols)

        if time_features is None:
            time_features = ["year", "month", "day", "weekday", "hour", "minute"]

        time_df = pd.to_datetime(raw_df[self.time_col_name], format="%Y-%m-%d %H:%M")
        df = time_features_util(
            time_df,
            raw_df,
            time_col_name=self.time_col_name,
            use_features=time_features,
        )
        # Columns added by time_features_util are the time feature columns.
        self.time_cols = df.columns.difference(raw_df.columns)

        # Train/Val/Test Split using holdout approach #

        def mask_intervals(mask, intervals, cond):
            """Set ``mask`` to ``cond`` for rows whose timestamp lies in any interval."""
            for (interval_low, interval_high) in intervals:
                # An open bound falls back to the first / last timestamp so
                # the comparison below stays timestamp-vs-timestamp.
                if interval_low is None:
                    interval_low = df[self.time_col_name].iloc[0]
                if interval_high is None:
                    interval_high = df[self.time_col_name].iloc[-1]
                mask[
                    (df[self.time_col_name] >= interval_low)
                    & (df[self.time_col_name] <= interval_high)
                ] = cond
            return mask

        # At least one row is always reserved for the test interval.
        test_cutoff = len(time_df) - max(round(test_split * len(time_df)), 1)
        val_cutoff = test_cutoff - round(val_split * len(time_df))

        val_interval_low = time_df.iloc[val_cutoff]
        val_interval_high = time_df.iloc[test_cutoff - 1]
        val_intervals = [(val_interval_low, val_interval_high)]

        test_interval_low = time_df.iloc[test_cutoff]
        test_interval_high = time_df.iloc[-1]
        test_intervals = [(test_interval_low, test_interval_high)]

        # Train starts as "everything", val/test start as "nothing".
        train_mask = df[self.time_col_name] > pd.Timestamp.min
        val_mask = df[self.time_col_name] > pd.Timestamp.max
        test_mask = df[self.time_col_name] > pd.Timestamp.max
        train_mask = mask_intervals(train_mask, test_intervals, False)
        train_mask = mask_intervals(train_mask, val_intervals, False)
        val_mask = mask_intervals(val_mask, val_intervals, True)
        test_mask = mask_intervals(test_mask, test_intervals, True)

        if not train_mask.any():
            print(f"No training data detected for file {data_path}")

        self._scaler = StandardScaler()

        self.target_cols = target_cols
        remove_target_from_context_cols = list(remove_target_from_context_cols or [])
        for col in remove_target_from_context_cols:
            assert (
                col in self.target_cols
            ), "`remove_target_from_context_cols` should be target cols that you want to remove from the context"

        self.remove_target_from_context_cols = remove_target_from_context_cols
        # Exogenous columns: everything that is not a target, a time
        # feature or the raw time column.
        not_exo_cols = self.time_cols.tolist() + target_cols
        self.exo_cols = df.columns.difference(not_exo_cols).tolist()
        self.exo_cols.remove(self.time_col_name)

        self._train_data = df[train_mask]
        self._val_data = df[val_mask]
        if test_split == 0.0:
            print("`test_split` set to 0. Using Val set as Test set.")
            self._test_data = df[val_mask]
        else:
            self._test_data = df[test_mask]

        self.normalize = normalize
        if normalize:
            # Statistics come from the train split only.
            self._scaler = self._scaler.fit(
                self._train_data[target_cols + self.exo_cols].values
            )
        self._train_data = self.apply_scaling_df(self._train_data)
        self._val_data = self.apply_scaling_df(self._val_data)
        self._test_data = self.apply_scaling_df(self._test_data)

    def make_hists(self):
        """Save a train-vs-test histogram ``<col>-hist.png`` for each target/exogenous column."""
        for col in self.target_cols + self.exo_cols:
            train = self._train_data[col]
            test = self._test_data[col]
            bins = np.linspace(-5, 5, 80)  # warning: edit bucket limits
            plt.hist(train, bins, alpha=0.5, label="Train", density=True)
            plt.hist(test, bins, alpha=0.5, label="Test", density=True)
            plt.legend(loc="upper right")
            plt.title(col)
            plt.tight_layout()
            plt.savefig(f"{col}-hist.png")
            plt.clf()

    def get_slice(self, split, start, stop, skip):
        """Return rows ``start:stop:skip`` (positional) of the given split."""
        assert split in ["train", "val", "test"]
        if split == "train":
            return self.train_data.iloc[start:stop:skip]
        elif split == "val":
            return self.val_data.iloc[start:stop:skip]
        else:
            return self.test_data.iloc[start:stop:skip]

    def apply_scaling(self, array):
        """Standardize ``array`` using the first ``array.shape[-1]`` scaler statistics."""
        if not self.normalize:
            return array
        dim = array.shape[-1]
        return (array - self._scaler.mean_[:dim]) / self._scaler.scale_[:dim]

    def apply_scaling_df(self, df):
        """Return a copy of ``df`` with target and exogenous columns standardized."""
        if not self.normalize:
            return df
        scaled = df.copy(deep=True)
        cols = self.target_cols + self.exo_cols
        dtype = df[cols].values.dtype
        scaled[cols] = (
            df[cols].values - self._scaler.mean_.astype(dtype)
        ) / self._scaler.scale_.astype(dtype)
        return scaled

    def reverse_scaling_df(self, df):
        """Return a copy of ``df`` with the standardization of its columns undone."""
        if not self.normalize:
            return df
        scaled = df.copy(deep=True)
        cols = self.target_cols + self.exo_cols
        dtype = df[cols].values.dtype
        scaled[cols] = (
            df[cols].values * self._scaler.scale_.astype(dtype)
        ) + self._scaler.mean_.astype(dtype)
        return scaled

    def reverse_scaling(self, array):
        """Undo standardization on an array whose last axis indexes columns."""
        if not self.normalize:
            return array
        # self._scaler is fit for target_cols + exo_cols
        # if the array dim is less than this length we start
        # slicing from the target cols
        dim = array.shape[-1]
        return (array * self._scaler.scale_[:dim]) + self._scaler.mean_[:dim]

    @property
    def train_data(self):
        """The (possibly scaled) train split."""
        return self._train_data

    @property
    def val_data(self):
        """The (possibly scaled) val split."""
        return self._val_data

    @property
    def test_data(self):
        """The (possibly scaled) test split."""
        return self._test_data

    def length(self, split):
        """Return the number of rows in ``split`` (``train``/``val``/``test``)."""
        return {
            "train": len(self.train_data),
            "val": len(self.val_data),
            "test": len(self.test_data),
        }[split]

    @classmethod
    def add_cli(cls, parser):
        """Register ``--data_path`` on ``parser`` and return the parser."""
        parser.add_argument("--data_path", type=str, default="auto")
        return parser


class CSVTorchDset(Dataset):
    """Sliding-window torch ``Dataset`` over one split of a time series.

    Each item is a tuple ``(ctxt_x, ctxt_y, trgt_x, trgt_y)`` of float
    tensors: time features and values for the context window, then time
    features and target values for the window to predict.
    """

    def __init__(
        self,
        csv_time_series: "CSVTimeSeries",
        split: str = "train",
        context_points: int = 128,
        target_points: int = 32,
        time_resolution: int = 1,
    ):
        """
        Args:
            csv_time_series: series object providing ``length``,
                ``get_slice``, ``time_col_name``, ``time_cols``,
                ``target_cols``, ``exo_cols`` and
                ``remove_target_from_context_cols``.
            split: one of ``train``, ``val`` or ``test``.
            context_points: number of rows in the context window.
            target_points: number of rows in the target window.
            time_resolution: stride between the rows taken for a window.
        """
        assert split in ["train", "val", "test"]
        self.split = split
        self.series = csv_time_series
        self.context_points = context_points
        self.target_points = target_points
        self.time_resolution = time_resolution

        # One window per start row that leaves room for a full
        # context + target span; empty when the split is too short.
        window_count = (
            self.series.length(split)
            + time_resolution * (-target_points - context_points)
            + 1
        )
        self._slice_start_points = list(range(0, window_count))

    def __len__(self):
        """Number of available windows."""
        return len(self._slice_start_points)

    def _torch(self, *dfs):
        """Convert each DataFrame to a float tensor."""
        return tuple(torch.from_numpy(x.values).float() for x in dfs)

    def __getitem__(self, i):
        """Return the ``i``-th window as ``(ctxt_x, ctxt_y, trgt_x, trgt_y)``."""
        start = self._slice_start_points[i]
        series_slice = self.series.get_slice(
            self.split,
            start=start,
            stop=start
            + self.time_resolution * (self.context_points + self.target_points),
            skip=self.time_resolution,
        )
        series_slice = series_slice.drop(columns=[self.series.time_col_name])
        ctxt_slice, trgt_slice = (
            series_slice.iloc[: self.context_points],
            series_slice.iloc[self.context_points :],
        )

        ctxt_x = ctxt_slice[self.series.time_cols]
        trgt_x = trgt_slice[self.series.time_cols]

        # The context sees targets and exogenous columns, minus the targets
        # that were explicitly removed from the context.
        ctxt_y = ctxt_slice[self.series.target_cols + self.series.exo_cols]
        ctxt_y = ctxt_y.drop(columns=self.series.remove_target_from_context_cols)

        trgt_y = trgt_slice[self.series.target_cols]

        return self._torch(ctxt_x, ctxt_y, trgt_x, trgt_y)

    @classmethod
    def add_cli(cls, parser):
        """Register the window-size arguments on ``parser`` and return the parser."""
        parser.add_argument(
            "--context_points",
            type=int,
            default=128,
            help="number of previous timesteps given to the model in order to make predictions",
        )
        parser.add_argument(
            "--target_points",
            type=int,
            default=32,
            help="number of future timesteps to predict",
        )
        parser.add_argument(
            "--time_resolution",
            type=int,
            default=1,
        )
        return parser

# RevIN

In [11]:
"""
Reversible Instance Normalization from 
https://github.com/ts-kim/RevIN
"""

import torch
import torch.nn as nn


class MovingAvg(nn.Module):
    """Moving average over the time axis with edge-replication padding.

    Input is indexed as ``(batch, time, channels)``; the average pool runs
    along ``time``.
    """

    def __init__(self, kernel_size, stride):
        """
        Args:
            kernel_size: number of time steps averaged per output step.
            stride: stride of the underlying ``nn.AvgPool1d``.
        """
        super().__init__()
        self.kernel_size = kernel_size
        self.avg = nn.AvgPool1d(kernel_size=kernel_size, stride=stride, padding=0)

    def forward(self, x):
        """Return the moving average of ``x`` along dim 1.

        The first and last time steps are repeated so that ``kernel_size - 1``
        padding steps are added in total. With stride 1 the output then has
        the same length as the input for both odd and even kernel sizes.
        """
        # padding on the both ends of time series; for even kernels the
        # extra step goes to the end so the total is always kernel_size - 1
        total_pad = self.kernel_size - 1
        front_pad = total_pad // 2
        end_pad = total_pad - front_pad
        front = x[:, 0:1, :].repeat(1, front_pad, 1)
        end = x[:, -1:, :].repeat(1, end_pad, 1)
        x = torch.cat([front, x, end], dim=1)
        # AvgPool1d pools over the last axis, so move time there and back.
        x = self.avg(x.permute(0, 2, 1))
        x = x.permute(0, 2, 1)
        return x


class SeriesDecomposition(nn.Module):
    """Split a series into a moving-average component and the remainder."""

    def __init__(self, kernel_size):
        """Create the stride-1 moving-average filter of width ``kernel_size``."""
        super().__init__()
        self.moving_avg = MovingAvg(kernel_size, stride=1)

    def forward(self, x):
        """Return ``(x - moving_average, moving_average)``."""
        trend = self.moving_avg(x)
        return x - trend, trend


class RevIN(nn.Module):
    """Reversible instance normalization with an optional learnable affine map.

    ``mode="norm"`` standardizes the input along dim 1 and remembers the
    statistics; ``mode="denorm"`` maps a tensor back using the statistics
    stored by the last ``norm`` call that updated them.
    """

    def __init__(self, num_features: int, eps=1e-5, affine=True):
        """
        :param num_features: the number of features or channels
        :param eps: a value added for numerical stability
        :param affine: if True, RevIN has learnable affine parameters
        """
        super().__init__()
        self.num_features = num_features
        self.eps = eps
        self.affine = affine
        if self.affine:
            self._init_params()

    def forward(self, x, mode: str, update_stats=True):
        """Normalize (``"norm"``) or de-normalize (``"denorm"``) a 3-D tensor.

        :param update_stats: in ``norm`` mode, recompute mean/stdev from ``x``
        :raises NotImplementedError: for any other ``mode``
        """
        assert x.ndim == 3
        if mode == "denorm":
            return self._denormalize(x)
        if mode != "norm":
            raise NotImplementedError
        if update_stats:
            self._get_statistics(x)
        return self._normalize(x)

    def _init_params(self):
        """Create the per-feature scale (ones) and shift (zeros) parameters."""
        # initialize RevIN params: (C,)
        self.affine_weight = nn.Parameter(torch.ones(self.num_features))
        self.affine_bias = nn.Parameter(torch.zeros(self.num_features))

    def _get_statistics(self, x):
        """Store the detached mean and eps-stabilized stdev over the middle dims."""
        reduce_dims = tuple(range(1, x.ndim - 1))
        self.mean = torch.mean(x, dim=reduce_dims, keepdim=True).detach()
        variance = torch.var(x, dim=reduce_dims, keepdim=True, unbiased=False)
        self.stdev = torch.sqrt(variance + self.eps).detach()

    def _normalize(self, x):
        """Standardize ``x`` with the stored statistics, then apply the affine map."""
        x = (x - self.mean) / self.stdev
        if not self.affine:
            return x
        return x * self.affine_weight + self.affine_bias

    def _denormalize(self, x):
        """Invert the affine map (if any) and the standardization."""
        if self.affine:
            # eps**2 keeps the division finite if a weight reaches zero.
            x = (x - self.affine_bias) / (self.affine_weight + self.eps * self.eps)
        return x * self.stdev + self.mean

# Time2Vec

In [13]:
import torch
from torch import nn


class Time2Vec(nn.Module):
    """Learned time embedding: one linear term plus activated terms per input.

    Each of the ``input_dim`` time features gets ``embed_dim // input_dim``
    output values. The first is an affine function of the feature and the
    rest pass through ``act_function``. With ``embed_dim == 0`` the layer is
    a pass-through.
    """

    def __init__(self, input_dim=6, embed_dim=512, act_function=torch.sin):
        assert embed_dim % input_dim == 0
        super().__init__()
        self.enabled = embed_dim > 0
        if not self.enabled:
            return
        self.embed_dim = embed_dim // input_dim
        self.input_dim = input_dim
        param_shape = (self.input_dim, self.embed_dim)
        self.embed_weight = nn.Parameter(torch.randn(*param_shape))
        self.embed_bias = nn.Parameter(torch.randn(*param_shape))
        self.act_function = act_function

    def forward(self, x):
        """Embed ``x``; returns ``x`` unchanged when the layer is disabled."""
        if not self.enabled:
            return x
        # diag_embed puts each feature on its own row so the matmul scales
        # row i of the weight by feature i.
        # diagonal.shape = (bs, sequence_length, input_dim, input_dim)
        diagonal = torch.diag_embed(x)
        # affine.shape = (bs, sequence_length, input_dim, time_embed_dim)
        affine = torch.matmul(diagonal, self.embed_weight) + self.embed_bias
        linear_part = affine[..., :1]
        activated_part = self.act_function(affine[..., 1:])
        merged = torch.cat([linear_part, activated_part], dim=-1)
        # result.shape = (bs, sequence_length, input_dim * time_embed_dim)
        return merged.view(merged.size(0), merged.size(1), -1)

# Callbacks

In [23]:
import argparse


class TeacherForcingAnnealCallback:
    """Linearly lower a model's ``teacher_forcing_prob`` from ``start`` to ``end``."""

    def __init__(self, start, end, steps):
        assert start >= end, "teacher_forcing_start must be >= teacher_forcing_end"
        self.start, self.end, self.steps = start, end, steps
        # Amount subtracted from the probability after every batch.
        self.slope = float(start - end) / steps

    def on_train_batch_end(self, model):
        """
        Decrease ``model.teacher_forcing_prob`` by one step, never going
        below ``end``, and print the new value.
        """
        annealed = max(self.end, model.teacher_forcing_prob - self.slope)
        model.teacher_forcing_prob = annealed
        # Replace model.log with any custom logging or simply print
        print(f"teacher_forcing_prob: {annealed}")

    @classmethod
    def add_cli(cls, parser: argparse.ArgumentParser):
        """Register the teacher-forcing options on ``parser`` and return it."""
        float_options = (
            ("--teacher_forcing_start", 0.8),
            ("--teacher_forcing_end", 0.0),
        )
        for flag, default in float_options:
            parser.add_argument(flag, type=float, default=default)
        parser.add_argument("--teacher_forcing_anneal_steps", type=int, default=8000)
        return parser


class TimeMaskedLossCallback:
    """Linearly raise a model's ``time_masked_idx`` from ``start`` to ``end``."""

    def __init__(self, start, end, steps):
        assert start <= end, "time_mask_start must be <= time_mask_end"
        self.start, self.end, self.steps = start, end, steps
        # Amount added to the (fractional) mask position after every batch.
        self.slope = float(end - start) / steps
        self._time_mask = self.start

    @property
    def time_mask(self):
        """Current mask position rounded to an integer."""
        return round(self._time_mask)

    def on_train_start(self, model):
        """
        Initialise ``model.time_masked_idx`` unless the model already has a
        non-None value for it.
        """
        if getattr(model, "time_masked_idx", None) is not None:
            return
        model.time_masked_idx = self.time_mask
        print(f"time_masked_idx set to: {self.time_mask}")

    def on_train_batch_end(self, model):
        """
        Advance the mask position by one step (capped at ``end``) and write
        the rounded value to ``model.time_masked_idx``.
        """
        advanced = self._time_mask + self.slope
        self._time_mask = min(self.end, advanced)
        model.time_masked_idx = self.time_mask
        print(f"time_masked_idx: {self.time_mask}")

    @classmethod
    def add_cli(cls, parser: argparse.ArgumentParser):
        """Register the time-mask options on ``parser`` and return it."""
        int_options = (
            ("--time_mask_start", 1),
            ("--time_mask_end", 12),
            ("--time_mask_anneal_steps", 1000),
        )
        for flag, default in int_options:
            parser.add_argument(flag, type=int, default=default)
        parser.add_argument("--time_mask_loss", action="store_true")
        return parser

# Eval Stats

In [14]:
import numpy as np

# Small constant added to denominators in the metrics of this module so that
# a zero denominator does not produce a division by zero.
EPSILON = 1e-7


def r_squared(actual: np.ndarray, predicted: np.ndarray):
    """Mean coefficient of determination, computed along axis 1.

    For each slice along axis 1 the residual and total sums of squares are
    formed; EPSILON keeps the ratio finite when ``actual`` is constant.
    """
    residuals = _error(actual, predicted)
    deviations = _error(actual, actual.mean(1, keepdims=True))
    residual_ss = (residuals ** 2).sum(1)
    total_ss = (deviations ** 2).sum(1)
    scores = 1.0 - residual_ss / (total_ss + EPSILON)
    return scores.mean()


def _error(actual: np.ndarray, predicted: np.ndarray):
    """Simple error"""
    return actual - predicted


def _percentage_error(actual: np.ndarray, predicted: np.ndarray):
    """
    Error relative to ``actual`` (with EPSILON added to the denominator).

    Note: result is NOT multiplied by 100
    """
    denominator = actual + EPSILON
    return _error(actual, predicted) / denominator


def mse(actual: np.ndarray, predicted: np.ndarray):
    """Mean of the squared element-wise errors."""
    squared = np.square(_error(actual, predicted))
    return np.mean(squared)


def mae(actual: np.ndarray, predicted: np.ndarray):
    """Mean of the absolute element-wise errors."""
    magnitudes = np.abs(_error(actual, predicted))
    return np.mean(magnitudes)


def mape(actual: np.ndarray, predicted: np.ndarray):
    """Mean of the absolute percentage errors (not multiplied by 100)."""
    relative = np.abs(_percentage_error(actual, predicted))
    return np.mean(relative)


def smape(actual: np.ndarray, predicted: np.ndarray):
    """Symmetric mean absolute percentage error (not multiplied by 100).

    Each term is ``2 * |actual - predicted| / (|actual| + |predicted|)``,
    with EPSILON added to the denominator.
    """
    numerator = 2.0 * np.abs(actual - predicted)
    denominator = (np.abs(actual) + np.abs(predicted)) + EPSILON
    return np.mean(numerator / denominator)

# Forecaster

In [17]:
from abc import ABC, abstractmethod
from typing import Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

class Forecaster(nn.Module, ABC):
    """Abstract base class for sequence forecasters.

    Subclasses implement ``forward_model_pass`` and the two
    ``*_step_forward_kwargs`` properties. This class adds optional RevIN
    normalization, optional seasonal/trend decomposition, an optional
    linear baseline on the trend, masked losses and evaluation statistics.
    """

    def __init__(
        self,
        d_x: int,
        d_yc: int,
        d_yt: int,
        learning_rate: float = 1e-3,
        l2_coeff: float = 0,
        loss: str = "mse",
        linear_window: int = 0,
        linear_shared_weights: bool = False,
        use_revin: bool = False,
        use_seasonal_decomp: bool = False,
        verbose: int = True,
    ):
        """
        Args:
            d_x: stored as ``self.d_x`` (not otherwise used in this class).
            d_yc: number of context value channels; used as the RevIN
                feature count.
            d_yt: number of target value channels.
            learning_rate: learning rate for ``configure_optimizers``.
            l2_coeff: weight decay for ``configure_optimizers``.
            loss: one of ``mse``, ``mae`` or ``smape``.
            linear_window: window of the linear baseline; 0 disables it.
            linear_shared_weights: share the baseline weights across channels.
            use_revin: wrap the model in reversible instance normalization.
            use_seasonal_decomp: feed the model the seasonal part and the
                linear baseline the trend part of the context.
            verbose: print the configuration when truthy.
        """
        super().__init__()

        qprint = lambda msg: print(msg) if verbose else None
        qprint("Forecaster")
        qprint(f"\tL2: {l2_coeff}")
        qprint(f"\tLinear Window: {linear_window}")
        qprint(f"\tLinear Shared Weights: {linear_shared_weights}")
        qprint(f"\tRevIN: {use_revin}")
        qprint(f"\tDecomposition: {use_seasonal_decomp}")

        # Both scalers default to the identity so predict() and
        # _compute_stats() work before set_scaler / set_inv_scaler are called.
        self._scaler = lambda x: x
        self._inv_scaler = lambda x: x
        self.l2_coeff = l2_coeff
        self.learning_rate = learning_rate
        self.time_masked_idx = None
        self.null_value = None
        self.loss = loss

        if linear_window:
            self.linear_model = LinearModel(
                linear_window, shared_weights=linear_shared_weights, d_yt=d_yt
            )
        else:
            # No baseline: contributes 0.0 to the prediction.
            self.linear_model = lambda x, *args, **kwargs: 0.0

        self.use_revin = use_revin
        if use_revin:
            assert d_yc == d_yt, "TODO: figure out exo case for revin"
            self.revin = RevIN(num_features=d_yc)
        else:
            self.revin = lambda x, **kwargs: x

        self.use_seasonal_decomp = use_seasonal_decomp
        if use_seasonal_decomp:
            self.seasonal_decomp = SeriesDecomposition(kernel_size=25)
        else:
            # Without decomposition both "seasonal" and "trend" are the input.
            self.seasonal_decomp = lambda x: (x, x.clone())

        self.d_x = d_x
        self.d_yc = d_yc
        self.d_yt = d_yt

    def set_null_value(self, val: float) -> None:
        """Set the target value that is excluded from the loss."""
        self.null_value = val

    def set_inv_scaler(self, scaler) -> None:
        """Set the callable mapping normalized arrays back to data units."""
        self._inv_scaler = scaler

    def set_scaler(self, scaler) -> None:
        """Set the callable mapping arrays in data units to normalized values."""
        self._scaler = scaler

    @property
    @abstractmethod
    def train_step_forward_kwargs(self):
        """Extra keyword arguments passed to ``forward`` during training."""
        return {}

    @property
    @abstractmethod
    def eval_step_forward_kwargs(self):
        """Extra keyword arguments passed to ``forward`` during evaluation."""
        return {}

    def loss_fn(self, true: torch.Tensor, preds: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
        """Masked loss selected by ``self.loss``, averaged over unmasked entries.

        NaNs in ``true`` are replaced by zero first. The normalizer is
        ``max(mask.sum(), 1)`` so an all-zero mask gives a zero loss.

        Raises:
            ValueError: if ``self.loss`` is not ``mse``, ``mae`` or ``smape``.
        """
        true = torch.nan_to_num(true)
        denom = max(mask.sum(), 1)
        if self.loss == "mse":
            loss = (mask * (true - preds)).square().sum() / denom
        elif self.loss == "mae":
            loss = torch.abs(mask * (true - preds)).sum() / denom
        elif self.loss == "smape":
            num = 2.0 * torch.abs(preds - true)
            # Predictions are detached in the denominator only.
            den = torch.abs(preds.detach()) + torch.abs(true) + 1e-5
            loss = 100.0 * (mask * (num / den)).sum() / denom
        else:
            raise ValueError(f"Unrecognized Loss Function: {self.loss}")
        return loss

    def forecasting_loss(self, outputs: torch.Tensor, y_t: torch.Tensor, time_mask: int) -> Tuple[torch.Tensor]:
        """Compute the loss over valid target entries.

        An entry is valid when it is not NaN, differs from ``null_value``
        (if one is set) and, if ``time_mask`` is given, its index along
        dim 1 is below ``time_mask``.

        Returns:
            ``(loss, full_mask)`` where ``full_mask`` has ``y_t``'s dtype.
        """
        valid = ~torch.isnan(y_t)
        if self.null_value is not None:
            valid &= y_t != self.null_value
        if time_mask is not None:
            valid[:, time_mask:] = False

        full_mask = valid.to(y_t.dtype)
        forecasting_loss = self.loss_fn(y_t, outputs, full_mask)
        return forecasting_loss, full_mask

    def compute_loss(self, batch: Tuple[torch.Tensor], time_mask: int = None, forward_kwargs: dict = None) -> Tuple[torch.Tensor]:
        """Run ``forward`` on a ``(x_c, y_c, x_t, y_t)`` batch and score it.

        Returns:
            ``(loss, outputs, mask)``.
        """
        x_c, y_c, x_t, y_t = batch
        outputs, *_ = self.forward(x_c, y_c, x_t, y_t, **(forward_kwargs or {}))
        loss, mask = self.forecasting_loss(outputs=outputs, y_t=y_t, time_mask=time_mask)
        return loss, outputs, mask

    def predict(self, x_c: torch.Tensor, y_c: torch.Tensor, x_t: torch.Tensor, sample_preds: bool = False) -> torch.Tensor:
        """Forecast in data units from an un-normalized context.

        ``y_c`` is normalized with the scaler, the model runs without
        gradients on a zero target placeholder, and the output is passed
        through the inverse scaler. ``sample_preds`` is accepted but unused.

        Returns:
            Float tensor of predictions on ``y_c``'s original device.
        """
        og_device = y_c.device
        # Ensure tensors are on the same device as the model; a model
        # without parameters stays on the input's device.
        first_param = next(self.parameters(), None)
        device = og_device if first_param is None else first_param.device
        x_c = x_c.to(device).float()
        x_t = x_t.to(device).float()
        y_c = torch.from_numpy(self._scaler(y_c.cpu().numpy())).to(device).float()
        y_t = torch.zeros((x_t.shape[0], x_t.shape[1], self.d_yt), device=device).float()

        with torch.no_grad():
            normalized_preds, *_ = self.forward(x_c, y_c, x_t, y_t, **self.eval_step_forward_kwargs)
        preds = torch.from_numpy(self._inv_scaler(normalized_preds.cpu().numpy())).to(og_device).float()
        return preds

    @abstractmethod
    def forward_model_pass(
        self,
        x_c: torch.Tensor,
        y_c: torch.Tensor,
        x_t: torch.Tensor,
        y_t: torch.Tensor,
        **forward_kwargs,
    ) -> Tuple[torch.Tensor]:
        """Model-specific forward pass; must return a tuple whose first item is the prediction."""
        return NotImplemented

    def nan_to_num(self, *inps):
        """Lazily replace NaNs in every input tensor."""
        return (torch.nan_to_num(i) for i in inps)

    def forward(self, x_c: torch.Tensor, y_c: torch.Tensor, x_t: torch.Tensor, y_t: torch.Tensor, **forward_kwargs) -> Tuple[torch.Tensor]:
        """Full forward pass: normalize, decompose, predict, add baseline, de-normalize.

        Returns:
            Tuple whose first element is the output; any extra values
            returned by ``forward_model_pass`` follow.
        """
        x_c, y_c, x_t, y_t = self.nan_to_num(x_c, y_c, x_t, y_t)
        _, pred_len, d_yt = y_t.shape

        y_c = self.revin(y_c, mode="norm")
        seasonal_yc, trend_yc = self.seasonal_decomp(y_c)
        preds, *extra = self.forward_model_pass(x_c, seasonal_yc, x_t, y_t, **forward_kwargs)
        baseline = self.linear_model(trend_yc, pred_len=pred_len, d_yt=d_yt)
        output = self.revin(preds + baseline, mode="denorm")

        if extra:
            return (output,) + tuple(extra)
        return (output,)

    def _compute_stats(self, pred: torch.Tensor, true: torch.Tensor, mask: torch.Tensor):
        """Compute masked error metrics in data units and in normalized units.

        Masked entries are zeroed in both tensors, so each mean is divided
        by the mask's mean to compensate for the zeroed entries.
        """
        pred = pred * mask
        true = torch.nan_to_num(true) * mask

        adj = mask.mean().cpu().numpy() + 1e-5
        pred = pred.detach().cpu().numpy()
        true = true.detach().cpu().numpy()
        scaled_pred = self._inv_scaler(pred)
        scaled_true = self._inv_scaler(true)
        stats = {
            "mape": mape(scaled_true, scaled_pred) / adj,
            "mae": mae(scaled_true, scaled_pred) / adj,
            "mse": mse(scaled_true, scaled_pred) / adj,
            "smape": smape(scaled_true, scaled_pred) / adj,
            "norm_mae": mae(true, pred) / adj,
            "norm_mse": mse(true, pred) / adj,
        }
        return stats

    def step(self, batch: Tuple[torch.Tensor], train: bool = False):
        """Run one batch and return its statistics including ``loss``.

        The time mask (``self.time_masked_idx``) is only applied in training.
        """
        kwargs = self.train_step_forward_kwargs if train else self.eval_step_forward_kwargs
        time_mask = self.time_masked_idx if train else None
        loss, output, mask = self.compute_loss(batch=batch, time_mask=time_mask, forward_kwargs=kwargs)
        *_, y_t = batch
        stats = self._compute_stats(output, y_t, mask)
        stats["loss"] = loss
        return stats

    def training_step(self, batch, batch_idx):
        """Training step; ``batch_idx`` is unused."""
        return self.step(batch, train=True)

    def validation_step(self, batch, batch_idx):
        """Validation step; also caches the stats in ``current_val_stats``."""
        stats = self.step(batch, train=False)
        self.current_val_stats = stats
        return stats

    def test_step(self, batch, batch_idx):
        """Test step; ``batch_idx`` is unused."""
        return self.step(batch, train=False)

    def _log_stats(self, section, outs):
        """Report every statistic as ``<section>/<key>``.

        Uses ``self.log`` when the object provides one and prints otherwise,
        since a plain ``nn.Module`` has no ``log`` method.
        """
        log_fn = getattr(self, "log", None)
        for key, stat in outs.items():
            if isinstance(stat, (np.ndarray, torch.Tensor)):
                stat = stat.mean()
            if callable(log_fn):
                log_fn(f"{section}/{key}", stat, sync_dist=True)
            else:
                print(f"{section}/{key}: {stat}")

    def training_step_end(self, outs):
        """Log the training stats and return the mean loss."""
        self._log_stats("train", outs)
        return {"loss": outs["loss"].mean()}

    def validation_step_end(self, outs):
        """Log the validation stats and return them unchanged."""
        self._log_stats("val", outs)
        return outs

    def test_step_end(self, outs):
        """Log the test stats and return the mean loss."""
        self._log_stats("test", outs)
        return {"loss": outs["loss"].mean()}

    def predict_step(self, batch, batch_idx):
        """Call ``forward`` on the unpacked batch with the evaluation kwargs."""
        return self(*batch, **self.eval_step_forward_kwargs)

    def configure_optimizers(self):
        """Return Adam plus a ``ReduceLROnPlateau`` scheduler monitoring ``val/loss``."""
        optimizer = torch.optim.Adam(
            self.parameters(), lr=self.learning_rate, weight_decay=self.l2_coeff
        )
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            patience=3,
            factor=0.2,
        )
        return {
            "optimizer": optimizer,
            "lr_scheduler": {
                "scheduler": scheduler,
                "monitor": "val/loss",
            },
        }

    @classmethod
    def add_cli(cls, parser):
        """Register the forecaster options on ``parser`` and return it."""
        parser.add_argument("--gpus", type=int, nargs="+")
        parser.add_argument("--l2_coeff", type=float, default=1e-6)
        parser.add_argument("--learning_rate", type=float, default=1e-4)
        parser.add_argument("--grad_clip_norm", type=float, default=0)
        parser.add_argument("--linear_window", type=int, default=0)
        parser.add_argument("--use_revin", action="store_true")
        parser.add_argument(
            "--loss", type=str, default="mse", choices=["mse", "mae", "smape"]
        )
        parser.add_argument("--linear_shared_weights", action="store_true")
        parser.add_argument("--use_seasonal_decomp", action="store_true")
        return parser


# LinearAR

In [None]:
import math
import torch
from torch import nn
from torch.optim import Adam
from einops import rearrange

In [None]:
class LinearModel(nn.Module):
    """Autoregressive linear forecaster over a fixed-length window.

    Each prediction is a weighted sum of the last ``context_points`` values
    plus a bias. Predictions are appended to the history so that later
    steps can consume earlier ones.
    """

    def __init__(self, context_points: int, shared_weights: bool = False, d_yt: int = 7):
        super().__init__()

        # One weight column per target channel, or a single shared column.
        if shared_weights:
            n_columns = 1
        else:
            assert d_yt is not None
            n_columns = d_yt

        self.weights = nn.Parameter(torch.ones((context_points, n_columns)), requires_grad=True)
        self.bias = nn.Parameter(torch.ones((n_columns)), requires_grad=True)

        # Uniform init in [-sqrt(1/window), sqrt(1/window)].
        bound = math.sqrt(1.0 / context_points)
        for param in (self.weights, self.bias):
            param.data.uniform_(-bound, bound)

        self.window = context_points
        self.shared_weights = shared_weights
        self.d_yt = d_yt

    def forward(self, y_c: torch.Tensor, pred_len: int, d_yt: int = None):
        """Roll the linear model forward for ``pred_len`` steps.

        Only the first ``d_yt`` channels of ``y_c`` are used; a falsy
        ``d_yt`` falls back to the value given at construction.
        """
        batch, _, _ = y_c.shape
        if not d_yt:
            d_yt = self.d_yt

        forecast = torch.zeros(batch, pred_len, d_yt, device=y_c.device)

        for step in range(pred_len):
            # Drop the oldest `step` context rows and append what has been
            # predicted so far.
            history = y_c[:, step:, :d_yt]
            generated = forecast[:, :step]
            forecast[:, step, :] = self._inner_forward(
                torch.cat((history, generated), dim=1)
            )
        return forecast

    def _inner_forward(self, inp):
        """Apply the weights to the last ``window`` rows of ``inp``."""
        n_batch = inp.shape[0]
        if self.shared_weights:
            # Fold channels into the batch so one weight column serves all.
            inp = rearrange(inp, "batch length dy -> (batch dy) length 1")
        recent = inp[:, -self.window :, :]
        result = (self.weights * recent).sum(1) + self.bias
        if self.shared_weights:
            result = rearrange(result, "(batch dy) 1 -> batch dy", batch=n_batch)
        return result

In [None]:
def train(model, optimizer, criterion, y_c, target, epochs=1000):
    """Fit ``model`` on a single (context, target) batch for ``epochs`` steps.

    Every step calls ``model(y_c, pred_len)`` with ``pred_len`` taken from
    ``target.shape[1]``, scores the result with ``criterion``, and applies
    one optimizer update. The loss is printed every 100th epoch (including
    epoch 0). Nothing is returned; ``model`` is updated in place.
    """
    model.train()
    for step in range(epochs):
        optimizer.zero_grad()
        horizon = target.shape[1]
        loss = criterion(model(y_c, horizon), target)
        loss.backward()
        optimizer.step()
        if not step % 100:
            print(f"Epoch {step}, Loss: {loss.item():.4f}")

In [19]:
import torch
from torch import nn
import torch.nn.functional as F

class Linear_Forecaster(Forecaster):
    """``Forecaster`` wrapper around the autoregressive ``LinearModel``.

    All training hyperparameters are forwarded to the ``Forecaster`` base
    class; this class only adds the ``LinearModel`` instance and the glue
    that turns a batch into a model call.
    """

    def __init__(
        self,
        d_x: int,
        d_yc: int,
        d_yt: int,
        context_points: int,
        learning_rate: float = 1e-3,
        l2_coeff: float = 0,
        loss: str = "mse",
        linear_window: int = 0,
        linear_shared_weights: bool = False,
        use_revin: bool = False,
        use_seasonal_decomp: bool = False,
    ):
        super().__init__(
            d_x=d_x,
            d_yc=d_yc,
            d_yt=d_yt,
            l2_coeff=l2_coeff,
            learning_rate=learning_rate,
            loss=loss,
            linear_window=linear_window,
            linear_shared_weights=linear_shared_weights,
            use_revin=use_revin,
            use_seasonal_decomp=use_seasonal_decomp,
        )

        # The model's look-back window is `context_points`; `linear_window`
        # is only handed to the base class.
        self.model = LinearModel(
            context_points, shared_weights=linear_shared_weights, d_yt=d_yt
        )

    @property
    def eval_step_forward_kwargs(self):
        """Extra keyword arguments for the eval-time forward pass (none)."""
        return {}

    @property
    def train_step_forward_kwargs(self):
        """Extra keyword arguments for the train-time forward pass (none)."""
        return {}

    def forward_model_pass(self, x_c, y_c, x_t, y_t):
        """Forecast as many steps and variables as ``y_t`` has.

        Only ``y_c`` is fed to the model; ``x_c`` and ``x_t`` are unused and
        ``y_t`` is read for its shape alone. Returns a 1-tuple holding the
        prediction tensor.
        """
        _, pred_len, d_yt = y_t.shape
        output = self.model(y_c, pred_len=pred_len, d_yt=d_yt)
        return (output,)

    @classmethod
    def add_cli(cls, parser):
        """Register the base-class CLI options on ``parser`` and return it.

        This class adds no options of its own. The parser is returned so the
        method can be used both for its side effect and in
        ``parser = X.add_cli(parser)`` style.
        """
        super().add_cli(parser)
        return parser


In [None]:
# Example synthetic data: smoke-test LinearModel + train() on random tensors.
batch = 16
context_points = 10
pred_len = 5
d_yt = 7

# y_c: context data, target: future data to forecast
# Both are independent standard-normal draws (no seed is set, so the printed
# losses differ from run to run).
y_c = torch.randn(batch, context_points, d_yt)
target = torch.randn(batch, pred_len, d_yt)

# One weight column per variable (shared_weights=False).
model = LinearModel(context_points=context_points, shared_weights=False, d_yt=d_yt)
optimizer = Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# Uses train()'s default of 1000 epochs; loss is printed every 100 epochs.
train(model, optimizer, criterion, y_c, target)

Epoch 0, Loss: 1.3034
Epoch 100, Loss: 1.0491
Epoch 200, Loss: 0.9557
Epoch 300, Loss: 0.9149
Epoch 400, Loss: 0.8962
Epoch 500, Loss: 0.8877
Epoch 600, Loss: 0.8837
Epoch 700, Loss: 0.8816
Epoch 800, Loss: 0.8803
Epoch 900, Loss: 0.8795


In [None]:
# ETTm1
import numpy as np
import pandas as pd

def create_sliding_windows(data: np.ndarray, context_points: int, pred_len: int):
    """Cut a time series into overlapping (context, target) pairs.

    Window ``i`` uses rows ``[i, i + context_points)`` as context and the
    following ``pred_len`` rows as target; consecutive windows are shifted
    by one row.

    Args:
        data: array-like of shape (T, d_yt), time along the first axis.
        context_points: number of rows in each context window.
        pred_len: number of rows in each target window.

    Returns:
        ``(y_context, y_target)`` with shapes
        ``(n, context_points, d_yt)`` and ``(n, pred_len, d_yt)``, where
        ``n = max(0, T - context_points - pred_len + 1)``. When the series
        is too short for a single window, ``n`` is 0 but the trailing
        dimensions are still present.

    Raises:
        ValueError: if ``context_points`` or ``pred_len`` is negative.
    """
    if context_points < 0 or pred_len < 0:
        raise ValueError(
            "context_points and pred_len must be non-negative, got "
            f"{context_points} and {pred_len}"
        )

    data = np.asarray(data)
    feature_shape = data.shape[1:]
    n_windows = data.shape[0] - context_points - pred_len + 1

    if n_windows <= 0:
        # Keep the window/feature axes so downstream shape checks still work.
        return (
            np.empty((0, context_points) + feature_shape, dtype=data.dtype),
            np.empty((0, pred_len) + feature_shape, dtype=data.dtype),
        )

    contexts = [data[i : i + context_points] for i in range(n_windows)]
    targets = [
        data[i + context_points : i + context_points + pred_len]
        for i in range(n_windows)
    ]
    return np.stack(contexts), np.stack(targets)

In [None]:
# NOTE(review): hard-coded absolute Windows path; change it for your machine.
csv_path = "S:\\spatiotemporal-analysis\\ETTm1_modified.csv"
df = pd.read_csv(csv_path, parse_dates=["date"])

# Sort by date (if not already sorted)
df.sort_values("date", inplace=True)

# Use the numeric columns as features (d_yt should equal the number of features: 7)
feature_columns = ["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"]
data = df[feature_columns].values  # shape (T, 7)

# Define context and prediction lengths
context_points = 10
pred_len = 5

# Create sliding windows from the data
# (raw values are used as-is; no scaling is applied in this cell)
y_context_np, y_target_np = create_sliding_windows(data, context_points, pred_len)

# Convert to torch tensors
y_context = torch.tensor(y_context_np, dtype=torch.float32)
y_target = torch.tensor(y_target_np, dtype=torch.float32)

print("y_context shape:", y_context.shape)  # (batch, context_points, 7)
print("y_target shape:", y_target.shape)    # (batch, pred_len, 7)

# Create the model, optimizer, and loss function
d_yt = len(feature_columns)
model = LinearModel(context_points=context_points, shared_weights=False, d_yt=d_yt)
optimizer = Adam(model.parameters(), lr=1e-3)
criterion = nn.MSELoss()

# Train the model using all training examples as one batch.
# Every window is passed to the model at once on each epoch, and no
# train/validation split is made here.
train(model, optimizer, criterion, y_context, y_target, epochs=1000)

y_context shape: torch.Size([69666, 10, 7])
y_target shape: torch.Size([69666, 5, 7])
Epoch 0, Loss: 49.0206
Epoch 100, Loss: 19.7358
Epoch 200, Loss: 6.5635


KeyboardInterrupt: 

# Train.py

In [None]:
from argparse import ArgumentParser
import random
import sys
import warnings
import os
import uuid

import torch


# Model names accepted as the first positional CLI argument (checked in
# create_parser). Only "linear" is actually constructed by create_model.
_MODELS = ["spacetimeformer", "mtgnn", "heuristic", "lstm", "lstnet", "linear", "s4"]
# Dataset names accepted as the second positional CLI argument (checked in
# create_parser).
_DSETS = [
    "asos",
    "metr-la",
    "pems-bay",
    "exchange",
    "precip",
    "toy2",
    "solar_energy",
    "syn",
    "mnist",
    "cifar",
    "copy",
    "cont_copy",
    "m4",
    "wiki",
    "ettm1",
    "weather",
    "monash",
    "hangzhou",
    "traffic",
]

def create_parser():
    """Build the command-line parser for a training run.

    The model and dataset names are read directly from ``sys.argv[1]`` and
    ``sys.argv[2]`` and validated against ``_MODELS`` / ``_DSETS`` before
    any parser is built. The data, model and callback classes then register
    their own options, followed by the run-level options defined here.

    If the fourth command-line token is exactly ``-h``, the help text is
    printed and the process exits with status 0.

    Returns:
        The configured ``ArgumentParser`` (arguments are not parsed here).
    """
    model_name = sys.argv[1]
    dset_name = sys.argv[2]

    # Throw error now before we get confusing parser issues
    assert (
        model_name in _MODELS
    ), f"Unrecognized model (`{model_name}`). Options include: {_MODELS}"
    assert (
        dset_name in _DSETS
    ), f"Unrecognized dset (`{dset_name}`). Options include: {_DSETS}"

    parser = ArgumentParser()
    for positional in ("model", "dset"):
        parser.add_argument(positional)

    # Only the linear model / ETTm1 case is wired up: data classes first,
    # then the forecaster, then the loss-masking callback.
    cli_providers = (
        CSVTimeSeries,
        CSVTorchDset,
        DataModule,
        Linear_Forecaster,
        TimeMaskedLossCallback,
    )
    for provider in cli_providers:
        provider.add_cli(parser)

    # Run-level options, registered in this order.
    run_options = (
        ("--wandb", {"action": "store_true"}),
        ("--plot", {"action": "store_true"}),
        ("--plot_samples", {"type": int, "default": 8}),
        ("--attn_plot", {"action": "store_true"}),
        ("--debug", {"action": "store_true"}),
        ("--run_name", {"type": str, "required": True}),
        ("--accumulate", {"type": int, "default": 1}),
        ("--val_check_interval", {"type": float, "default": 1.0}),
        ("--limit_val_batches", {"type": float, "default": 1.0}),
        ("--no_earlystopping", {"action": "store_true"}),
        ("--patience", {"type": int, "default": 5}),
        (
            "--trials",
            {
                "type": int,
                "default": 1,
                "help": "How many consecutive trials to run",
            },
        ),
    )
    for flag, options in run_options:
        parser.add_argument(flag, **options)

    help_requested = len(sys.argv) > 3 and sys.argv[3] == "-h"
    if help_requested:
        parser.print_help()
        sys.exit(0)

    return parser

def create_model(config):
    """Instantiate the forecaster selected by ``config.model``.

    The input dimensions ``(d_x, d_yc, d_yt)`` are looked up from
    ``config.dset``. For the ``copy`` / ``cont_copy`` tasks the two ``y``
    dimensions come from ``config.copy_vars`` instead of a fixed table.

    Args:
        config: parsed CLI namespace. Reads ``dset`` and ``model`` and, for
            the linear model, ``context_points``, ``learning_rate``,
            ``l2_coeff``, ``loss``, ``linear_window``,
            ``linear_shared_weights``, ``use_revin`` and
            ``use_seasonal_decomp``.

    Returns:
        A ``Linear_Forecaster`` instance.

    Raises:
        AssertionError: if no dimensions are known for ``config.dset``.
        NotImplementedError: if ``config.model`` is anything but "linear".
    """
    # (x_dim, yc_dim, yt_dim) for every dataset with fixed dimensions.
    fixed_dims = {
        "metr-la": (2, 207, 207),
        "pems-bay": (2, 325, 325),
        "precip": (2, 49, 49),
        "asos": (6, 6, 6),
        "solar_energy": (6, 137, 137),
        "exchange": (6, 8, 8),
        "toy2": (6, 20, 20),
        "syn": (5, 20, 20),
        "mnist": (1, 28, 28),
        "cifar": (1, 3, 3),
        "m4": (4, 1, 1),
        "wiki": (2, 1, 1),
        "monash": (4, 1, 1),
        "ettm1": (4, 7, 7),
        "weather": (3, 21, 21),
        "hangzhou": (4, 160, 160),
        "traffic": (2, 862, 862),
    }

    if config.dset == "copy" or config.dset == "cont_copy":
        # Variable count is a CLI option for the copy tasks.
        x_dim, yc_dim, yt_dim = 1, config.copy_vars, config.copy_vars
    else:
        x_dim, yc_dim, yt_dim = fixed_dims.get(config.dset, (None, None, None))

    assert x_dim is not None, f"No input dims known for dset `{config.dset}`"
    assert yc_dim is not None, f"No context dims known for dset `{config.dset}`"
    assert yt_dim is not None, f"No target dims known for dset `{config.dset}`"

    # Only the linear model is available here; fail loudly for anything else
    # instead of falling through to an unbound return value.
    if config.model != "linear":
        raise NotImplementedError(
            f"Model `{config.model}` is not supported here; only `linear` is."
        )

    return Linear_Forecaster(
        d_x=x_dim,
        d_yc=yc_dim,
        d_yt=yt_dim,
        context_points=config.context_points,
        learning_rate=config.learning_rate,
        l2_coeff=config.l2_coeff,
        loss=config.loss,
        linear_window=config.linear_window,
        linear_shared_weights=config.linear_shared_weights,
        use_revin=config.use_revin,
        use_seasonal_decomp=config.use_seasonal_decomp,
    )


def create_dset(config):
    """Build the data module and plotting/scaling helpers for ``config.dset``.

    Every option is read from ``config`` (no module-level globals), so the
    function works wherever it is called from.

    Args:
        config: parsed CLI namespace. ``dset`` selects the branch; each
            branch reads the dataset-specific options it needs plus
            ``batch_size``, ``workers`` and ``overfit``. The image and copy
            branches also write ``context_points`` / ``target_points`` back
            onto ``config``.

    Returns:
        A 7-tuple ``(DATA_MODULE, INV_SCALER, SCALER, NULL_VAL,
        PLOT_VAR_IDXS, PLOT_VAR_NAMES, PAD_VAL)``. The scalers default to
        the identity function and the remaining extras to ``None`` when a
        branch does not set them.

    Raises:
        ValueError: if ``config.dset`` matches no known dataset.
    """
    # NOTE(review): `stf` is presumably `import spacetimeformer as stf`; no
    # such import is visible near this function -- confirm it is in scope.
    INV_SCALER = lambda x: x
    SCALER = lambda x: x
    NULL_VAL = None
    PLOT_VAR_IDXS = None
    PLOT_VAR_NAMES = None
    PAD_VAL = None

    if config.dset == "metr-la" or config.dset == "pems-bay":
        if config.dset == "pems-bay":
            assert (
                "pems_bay" in config.data_path
            ), "Make sure to switch to the pems-bay file!"
        data = stf.data.metr_la.METR_LA_Data(config.data_path)
        DATA_MODULE = stf.data.DataModule(
            datasetCls=stf.data.metr_la.METR_LA_Torch,
            dataset_kwargs={"data": data},
            batch_size=config.batch_size,
            workers=config.workers,
            overfit=config.overfit,
        )
        INV_SCALER = data.inverse_scale
        SCALER = data.scale
        NULL_VAL = 0.0

    elif config.dset == "hangzhou":
        data = stf.data.metro.MetroData(config.data_path)
        DATA_MODULE = stf.data.DataModule(
            datasetCls=stf.data.metro.MetroTorch,
            dataset_kwargs={"data": data},
            batch_size=config.batch_size,
            workers=config.workers,
            overfit=config.overfit,
        )
        INV_SCALER = data.inverse_scale
        SCALER = data.scale
        NULL_VAL = 0.0

    elif config.dset == "precip":
        dset = stf.data.precip.GeoDset(dset_dir=config.dset_dir, var="precip")
        DATA_MODULE = stf.data.DataModule(
            datasetCls=stf.data.precip.CONUS_Precip,
            dataset_kwargs={
                "dset": dset,
                "context_points": config.context_points,
                "target_points": config.target_points,
            },
            batch_size=config.batch_size,
            workers=config.workers,
            overfit=config.overfit,
        )
        NULL_VAL = -1.0
    elif config.dset == "syn":
        dset = stf.data.synthetic.SyntheticData(config.data_path)
        DATA_MODULE = stf.data.DataModule(
            datasetCls=stf.data.CSVTorchDset,
            dataset_kwargs={
                "csv_time_series": dset,
                "context_points": config.context_points,
                "target_points": config.target_points,
                "time_resolution": config.time_resolution,
            },
            batch_size=config.batch_size,
            workers=config.workers,
            overfit=config.overfit,
        )
        INV_SCALER = dset.reverse_scaling
        SCALER = dset.apply_scaling
    elif config.dset in ["mnist", "cifar"]:
        # Image completion: the target is whatever part of the image the
        # context does not cover.
        if config.dset == "mnist":
            config.target_points = 28 - config.context_points
            datasetCls = stf.data.image_completion.MNISTDset
            PLOT_VAR_IDXS = [18, 24]
            PLOT_VAR_NAMES = ["18th row", "24th row"]
        else:
            config.target_points = 32 * 32 - config.context_points
            datasetCls = stf.data.image_completion.CIFARDset
            PLOT_VAR_IDXS = [0]
            PLOT_VAR_NAMES = ["Reds"]
        DATA_MODULE = stf.data.DataModule(
            datasetCls=datasetCls,
            dataset_kwargs={"context_points": config.context_points},
            batch_size=config.batch_size,
            workers=config.workers,
            overfit=config.overfit,
        )
    elif config.dset == "copy":
        # set these manually in case the model needs them
        config.context_points = config.copy_length + int(
            config.copy_include_lags
        )  # seq + lags
        config.target_points = config.copy_length
        DATA_MODULE = stf.data.DataModule(
            datasetCls=stf.data.copy_task.CopyTaskDset,
            dataset_kwargs={
                "length": config.copy_length,
                "copy_vars": config.copy_vars,
                "lags": config.copy_lags,
                "mask_prob": config.copy_mask_prob,
                "include_lags": config.copy_include_lags,
            },
            batch_size=config.batch_size,
            workers=config.workers,
            overfit=config.overfit,
        )
    elif config.dset == "cont_copy":
        # set these manually in case the model needs them
        config.context_points = config.copy_length + int(
            config.copy_include_lags
        )  # seq + lags
        config.target_points = config.copy_length
        DATA_MODULE = stf.data.DataModule(
            datasetCls=stf.data.cont_copy_task.ContCopyTaskDset,
            dataset_kwargs={
                "length": config.copy_length,
                "copy_vars": config.copy_vars,
                "lags": config.copy_lags,
                "include_lags": config.copy_include_lags,
                "magnitude_matters": config.copy_mag_matters,
                "freq_shift": config.copy_freq_shift,
            },
            batch_size=config.batch_size,
            workers=config.workers,
            overfit=config.overfit,
        )
    elif config.dset == "m4":
        DATA_MODULE = stf.data.DataModule(
            datasetCls=stf.data.m4.M4TorchDset,
            dataset_kwargs={
                "data_path": config.data_path,
                "resolutions": config.resolutions,
                "max_len": config.max_len,
            },
            batch_size=config.batch_size,
            workers=config.workers,
            collate_fn=stf.data.m4.pad_m4_collate,
            overfit=config.overfit,
        )
        NULL_VAL = -1.0
        PAD_VAL = -1.0

    elif config.dset == "wiki":
        DATA_MODULE = stf.data.DataModule(
            stf.data.wiki.WikipediaTorchDset,
            dataset_kwargs={
                "data_path": config.data_path,
                "forecast_duration": config.forecast_duration,
                "max_len": config.max_len,
            },
            batch_size=config.batch_size,
            workers=config.workers,
            collate_fn=stf.data.wiki.pad_wiki_collate,
            overfit=config.overfit,
        )
        NULL_VAL = -1.0
        PAD_VAL = -1.0
        SCALER = stf.data.wiki.WikipediaTorchDset.scale
        INV_SCALER = stf.data.wiki.WikipediaTorchDset.inverse_scale
    elif config.dset == "monash":
        root_dir = config.root_dir
        DATA_MODULE = stf.data.monash.monash_dloader.make_monash_dmodule(
            root_dir=root_dir,
            max_len=config.max_len,
            include=config.include,
            batch_size=config.batch_size,
            workers=config.workers,
            overfit=config.overfit,
        )
        NULL_VAL = -64.0
        PAD_VAL = -64.0
    elif config.dset == "ettm1":
        target_cols = ["HUFL", "HULL", "MUFL", "MULL", "LUFL", "LULL", "OT"]
        dset = stf.data.CSVTimeSeries(
            data_path=config.data_path,
            target_cols=target_cols,
            ignore_cols=[],
            val_split=4.0 / 20,  # from informer
            test_split=4.0 / 20,  # from informer
            time_col_name="date",
            time_features=["month", "day", "weekday", "hour"],
        )
        DATA_MODULE = stf.data.DataModule(
            datasetCls=stf.data.CSVTorchDset,
            dataset_kwargs={
                "csv_time_series": dset,
                "context_points": config.context_points,
                "target_points": config.target_points,
                "time_resolution": config.time_resolution,
            },
            batch_size=config.batch_size,
            workers=config.workers,
            overfit=config.overfit,
        )
        INV_SCALER = dset.reverse_scaling
        SCALER = dset.apply_scaling
        NULL_VAL = None
        # PAD_VAL = -32.0
        PLOT_VAR_NAMES = target_cols
        PLOT_VAR_IDXS = list(range(len(target_cols)))
    elif config.dset == "weather":
        dset = stf.data.CSVTimeSeries(
            data_path=config.data_path,
            target_cols=[],
            ignore_cols=[],
            # paper says 7:1:2 split
            val_split=1.0 / 10,
            test_split=2.0 / 10,
            time_col_name="date",
            time_features=["day", "hour", "minute"],
        )
        DATA_MODULE = stf.data.DataModule(
            datasetCls=stf.data.CSVTorchDset,
            dataset_kwargs={
                "csv_time_series": dset,
                "context_points": config.context_points,
                "target_points": config.target_points,
                "time_resolution": config.time_resolution,
            },
            batch_size=config.batch_size,
            workers=config.workers,
            overfit=config.overfit,
        )
        INV_SCALER = dset.reverse_scaling
        SCALER = dset.apply_scaling
        NULL_VAL = None
        PLOT_VAR_NAMES = ["OT", "p (mbar)", "raining (s)"]
        PLOT_VAR_IDXS = [20, 0, 15]
    else:
        # Generic CSV datasets that differ only in file path, target
        # columns and (for traffic) the time column / time features.
        time_col_name = "Datetime"
        data_path = config.data_path
        time_features = ["year", "month", "day", "weekday", "hour", "minute"]
        if config.dset == "asos":
            if data_path == "auto":
                data_path = "./data/temperature-v1.csv"
            target_cols = ["ABI", "AMA", "ACT", "ALB", "JFK", "LGA"]
        elif config.dset == "solar_energy":
            if data_path == "auto":
                data_path = "./data/solar_AL_converted.csv"
            target_cols = [str(i) for i in range(137)]
        elif "toy" in config.dset:
            if data_path == "auto":
                if config.dset == "toy2":
                    data_path = "./data/toy_dset2.csv"
                else:
                    raise ValueError(f"Unrecognized toy dataset {config.dset}")
            target_cols = [f"D{i}" for i in range(1, 21)]
        elif config.dset == "exchange":
            if data_path == "auto":
                data_path = "./data/exchange_rate_converted.csv"
            target_cols = [
                "Australia",
                "United Kingdom",
                "Canada",
                "Switzerland",
                "China",
                "Japan",
                "New Zealand",
                "Singapore",
            ]
        elif config.dset == "traffic":
            if data_path == "auto":
                data_path = "./data/traffic.csv"
            target_cols = [f"Lane {i}" for i in range(862)]
            time_col_name = "FakeTime"
            time_features = ["month", "day"]
        else:
            # Without this, `target_cols` would be unbound below.
            raise ValueError(f"Unrecognized dataset {config.dset}")

        dset = stf.data.CSVTimeSeries(
            data_path=data_path,
            target_cols=target_cols,
            ignore_cols="all",
            time_col_name=time_col_name,
            time_features=time_features,
            val_split=0.2,
            test_split=0.2,
        )
        DATA_MODULE = stf.data.DataModule(
            datasetCls=stf.data.CSVTorchDset,
            dataset_kwargs={
                "csv_time_series": dset,
                "context_points": config.context_points,
                "target_points": config.target_points,
                "time_resolution": config.time_resolution,
            },
            batch_size=config.batch_size,
            workers=config.workers,
            overfit=config.overfit,
        )
        INV_SCALER = dset.reverse_scaling
        SCALER = dset.apply_scaling
        NULL_VAL = None

    return (
        DATA_MODULE,
        INV_SCALER,
        SCALER,
        NULL_VAL,
        PLOT_VAR_IDXS,
        PLOT_VAR_NAMES,
        PAD_VAL,
    )


def create_callbacks(config, save_dir):
    """Assemble the list of trainer callbacks for one run.

    A checkpoint callback that keeps the single best ``val/loss`` model is
    always included. Early stopping, learning-rate monitoring, teacher-forcing
    annealing and time-masked loss are added according to ``config``.

    Side effect: stores the checkpoint directory on ``config.model_ckpt_dir``.
    The directory name is ``<run_name>_<first uuid1 field>`` under
    ``save_dir``.

    Returns:
        List of callback objects, checkpointing first.
    """
    run_dir_name = f"{config.run_name}_{str(uuid.uuid1()).split('-')[0]}"
    ckpt_dir = os.path.join(save_dir, run_dir_name)
    config.model_ckpt_dir = ckpt_dir

    callbacks = [
        pl.callbacks.ModelCheckpoint(
            dirpath=ckpt_dir,
            monitor="val/loss",
            mode="min",
            filename=f"{config.run_name}" + "{epoch:02d}",
            save_top_k=1,
            auto_insert_metric_name=True,
        )
    ]

    early_stopping_enabled = not config.no_earlystopping
    if early_stopping_enabled:
        stopper = pl.callbacks.early_stopping.EarlyStopping(
            monitor="val/loss",
            patience=config.patience,
        )
        callbacks.append(stopper)

    if config.wandb:
        callbacks.append(pl.callbacks.LearningRateMonitor())

    if config.model == "lstm":
        teacher_forcing = stf.callbacks.TeacherForcingAnnealCallback(
            start=config.teacher_forcing_start,
            end=config.teacher_forcing_end,
            steps=config.teacher_forcing_anneal_steps,
        )
        callbacks.append(teacher_forcing)

    if config.time_mask_loss:
        time_mask = stf.callbacks.TimeMaskedLossCallback(
            start=config.time_mask_start,
            end=config.time_mask_end,
            steps=config.time_mask_anneal_steps,
        )
        callbacks.append(time_mask)

    return callbacks
