In [None]:
import pandas as pd
import math
import numpy as np
import torch
from torch import nn
from torch.utils.data import Dataset, DataLoader
import matplotlib.pyplot as plt
import seaborn as sns

# 0. Core model + dataset (keep as-is)
# -------- Sliding-window dataset --------

class SlidingWindowDataset(Dataset):
    """
    Turn a time series (T, N) into input/target windows for forecasting.

    series: np.ndarray, shape (T,) or (T, N)
    input_len: length of history window
    pred_len:  length of prediction horizon
    """
    def __init__(self, series, input_len, pred_len, stride=1):
        if series.ndim == 1:
            series = series[:, None]  # (T,) -> (T,1)
        T, N = series.shape

        X_list, y_list = [], []
        for start in range(0, T - input_len - pred_len + 1, stride):
            end = start + input_len
            target_end = end + pred_len
            X_list.append(series[start:end])        # (input_len, N)
            y_list.append(series[end:target_end])   # (pred_len, N)

        self.X = torch.from_numpy(np.stack(X_list)).float()  # (B, L, N)
        self.y = torch.from_numpy(np.stack(y_list)).float()  # (B, H, N)

    def __len__(self):
        return self.X.shape[0]

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]


# -------- Positional encodings --------

class SinusoidalPositionalEncoding(nn.Module):
    """Standard Transformer sinusoidal positional encoding."""
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        self.register_buffer("pe", pe.unsqueeze(0))  # (1, max_len, d_model)

    def forward(self, x):
        seq_len = x.size(1)
        return x + self.pe[:, :seq_len]


class LearnablePositionalEncoding(nn.Module):
    """Learnable positional embeddings (one vector per index)."""
    def __init__(self, d_model, max_len=5000):
        super().__init__()
        self.pos_embed = nn.Embedding(max_len, d_model)

    def forward(self, x):
        b, seq_len, _ = x.size()
        positions = torch.arange(seq_len, device=x.device).unsqueeze(0).expand(b, -1)
        return x + self.pos_embed(positions)


# -------- Time-series Transformer --------

class TimeSeriesTransformer(nn.Module):
    """
    Simple encoder-only Transformer for time-series forecasting.

    - Supports sinusoidal or learnable positional encoding.
    - Predicts 'pred_len' future steps for each input series.
    """
    def __init__(
        self,
        input_dim=1,
        d_model=64,
        n_heads=4,
        num_layers=2,
        dim_feedforward=128,
        dropout=0.1,
        pred_len=1,
        pos_encoding_type="sin",  # "sin" or "learned"
        max_len=500,
    ):
        super().__init__()
        self.input_dim = input_dim
        self.pred_len = pred_len

        self.input_proj = nn.Linear(input_dim, d_model)

        if pos_encoding_type == "sin":
            self.pos_encoding = SinusoidalPositionalEncoding(d_model, max_len)
        elif pos_encoding_type == "learned":
            self.pos_encoding = LearnablePositionalEncoding(d_model, max_len)
        else:
            self.pos_encoding = None  # no PE (for ablation)

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=n_heads,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            batch_first=True,   # (B, L, D)
        )
        self.encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)

        self.fc_out = nn.Linear(d_model, pred_len * input_dim)
        self.attn_weights = []  # Store attention weights


        def custom_forward(module, src, src_mask=None, src_key_padding_mask=None):
            attn_output, attn_weight = module.self_attn(
                src, src, src, attn_mask=src_mask, key_padding_mask=src_key_padding_mask, need_weights=True, average_attn_weights=False
            )
            module.attn_weight = attn_weight  # save for extraction
            src = module.norm1(src + module.dropout1(attn_output))
            src = module.norm2(src + module.dropout2(module.linear2(module.dropout(module.activation(module.linear1(src))))))
            return src

    
        for layer in self.encoder.layers:
            def layer_forward(src, *args, **kwargs):
                return custom_forward(layer, src, src_mask=kwargs.get("src_mask"), src_key_padding_mask=kwargs.get("src_key_padding_mask"))
            layer.forward = layer_forward


    def forward(self, x):
        """
        x: (batch, seq_len, input_dim)
        returns: (batch, pred_len, input_dim)
        """
        h = self.input_proj(x)  # (B, L, d_model)

        if self.pos_encoding is not None:
            h = self.pos_encoding(h)

        h = self.encoder(h)     # (B, L, d_model)
        h_last = h[:, -1, :]    # (B, d_model)

        out = self.fc_out(h_last)                     # (B, pred_len * input_dim)
        out = out.view(-1, self.pred_len, self.input_dim)  # (B, pred_len, input_dim)
        return out



def plot_attention_map(attn, input_len, title="", head=0):
    """
    attn: tensor of shape (num_heads, B, T_q, T_k)
    """
    head_attn = attn[head, 0]  # choose head 0, batch 0
    plt.figure(figsize=(8, 6))
    sns.heatmap(head_attn[:, :input_len], cmap="viridis")
    plt.title(title)
    plt.xlabel("Key Time Step")
    plt.ylabel("Query Time Step")
    plt.tight_layout()
    plt.show()



In [18]:
# 1. Single training wrapper (reuse everywhere)
def train_transformer_on_series(
    series,
    input_len=64,
    pred_len=1,
    batch_size=32,
    n_epochs=10,
    lr=1e-3,
    pos_encoding_type="sin",
    device=None,
):
    """
    Train a TimeSeriesTransformer on one time series.

    Returns
    -------
    model : nn.Module
    history : dict with keys 'train_loss', 'val_loss'
    test_loss : float (MSE on held-out test set)
    splits : (train_set, val_set, test_set)
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Ensure shape (T, N)
    if isinstance(series, np.ndarray) and series.ndim == 1:
        series = series[:, None]
    elif torch.is_tensor(series) and series.ndim == 1:
        series = series.unsqueeze(-1).cpu().numpy()

    dataset = SlidingWindowDataset(series, input_len, pred_len)
    n_total = len(dataset)
    n_train = int(0.7 * n_total)
    n_val = int(0.15 * n_total)
    n_test = n_total - n_train - n_val

    train_set, val_set, test_set = torch.utils.data.random_split(
        dataset,
        [n_train, n_val, n_test],
        generator=torch.Generator().manual_seed(42),
    )

    train_loader = DataLoader(train_set, batch_size=batch_size, shuffle=True)
    val_loader   = DataLoader(val_set,   batch_size=batch_size)
    test_loader  = DataLoader(test_set,  batch_size=batch_size)

    model = TimeSeriesTransformer(
        input_dim=dataset.X.shape[-1],
        d_model=32,
        n_heads=4,
        num_layers=1,
        dim_feedforward=64,
        dropout=0.1,
        pred_len=pred_len,
        pos_encoding_type=pos_encoding_type,
        max_len=input_len,
    ).to(device)

    criterion = nn.MSELoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=lr)

    history = {"train_loss": [], "val_loss": []}

    for epoch in range(n_epochs):
        # --- train ---
        model.train()
        train_loss = 0.0
        for X_batch, y_batch in train_loader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)

            optimizer.zero_grad()
            y_pred = model(X_batch)
            loss = criterion(y_pred, y_batch)
            loss.backward()
            optimizer.step()
            train_loss += loss.item() * X_batch.size(0)

        train_loss /= len(train_loader.dataset)

        # --- validation ---
        model.eval()
        val_loss = 0.0
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch = X_batch.to(device)
                y_batch = y_batch.to(device)
                y_pred = model(X_batch)
                loss = criterion(y_pred, y_batch)
                val_loss += loss.item() * X_batch.size(0)

        val_loss /= len(val_loader.dataset)

        history["train_loss"].append(train_loss)
        history["val_loss"].append(val_loss)

        print(
            f"Epoch {epoch+1}/{n_epochs} "
            f"| train={train_loss:.4f} | val={val_loss:.4f}"
        )

    # --- test ---
    test_loss = 0.0
    model.eval()
    with torch.no_grad():
        for X_batch, y_batch in test_loader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)
            y_pred = model(X_batch)
            loss = criterion(y_pred, y_batch)
            test_loss += loss.item() * X_batch.size(0)

    test_loss /= len(test_loader.dataset)
    print(f"Test MSE: {test_loss:.4f}")

    return model, history, test_loss, (train_set, val_set, test_set)


In [19]:
def run_experiment_suite(
    series_dict,
    mode_label,
    horizons=(1, 10),
    n_epochs=10,
):
    """
    Run Transformer on all families, all pos_enc variants, and all horizons.

    Parameters
    ----------
    series_dict : dict
        Maps family name to config:
            {
                "series": np.array (T,),
                "input_len": int,
                "pos_encs": ["sin", "learned", ...],
            }
    mode_label : str
        E.g., "raw" or "stationary"
    horizons : tuple of int
        Forecast steps to evaluate.
    n_epochs : int
        Training epochs for each run.

    Returns
    -------
    results_df : pd.DataFrame
        One row per (family, pos_enc, horizon).
    histories : dict
        Keys = (mode, family, pos_enc, horizon) → training history.
    models : dict
        Keys = same as above → trained model.
    attn_outputs : dict
        Keys = same as above → list of attention maps (one per layer).
    """
    results = []
    histories = {}
    models = {}
    attn_outputs = {}

    for family, cfg in series_dict.items():
        series = cfg["series"]
        input_len = cfg["input_len"]
        pos_encs = cfg["pos_encs"]

        for horizon in horizons:
            for pos_enc in pos_encs:
                print(f"\n=== {mode_label} | {family} | pos_enc={pos_enc} | horizon={horizon} ===")
                model, history, test_loss, splits = train_transformer_on_series(
                    series,
                    input_len=input_len,
                    pred_len=horizon,
                    n_epochs=n_epochs,
                    pos_encoding_type=pos_enc,
                )

                key = (mode_label, family, pos_enc, horizon)
                histories[key] = history
                models[key] = model

                # Extract one attention map from test set
                with torch.no_grad():
                    test_loader = DataLoader(splits[2], batch_size=1)
                    X_sample, _ = next(iter(test_loader))
                    model.eval()
                    _ = model(X_sample)  # forward to trigger .attn_weights
                    attn_outputs[key] = model.attn_weights  # list of (H, B, L, L)

                results.append({
                    "mode": mode_label,
                    "family": family,
                    "pos_enc": pos_enc,
                    "horizon": horizon,
                    "test_MSE": float(test_loss),
                })

    results_df = pd.DataFrame(results)
    return results_df, histories, models, attn_outputs


In [20]:
from synthetic_data import (
    simulate_heavy_t_ar1,
    simulate_garch_11,
    simulate_regime_switching_mean,
    simulate_1_over_f_noise,
    simulate_season_trend_outliers,
    simulate_random_walk,
    simulate_jump_diffusion,
    simulate_multi_seasonality,
    simulate_trend_breaks,  
)


def zscore(X, axis=0, eps=1e-8):
    mean = X.mean(axis=axis, keepdims=True)
    std = X.std(axis=axis, keepdims=True)
    return (X - mean) / (std + eps)

# Fix functions
def make_stationary_ar1(X): return zscore(X)
def make_stationary_garch(r): return zscore(r)
def make_stationary_regime(X): return zscore(np.diff(X, axis=0))
def make_stationary_long_memory(X): return zscore(X)
def make_stationary_season_trend(Y, season_period):
    dY = np.diff(Y, axis=0)
    if dY.shape[0] <= season_period:
        raise ValueError("Too short for season differencing")
    return zscore(dY[season_period:] - dY[:-season_period])
def make_stationary_random_walk(X): return zscore(np.diff(X, axis=0))
def make_stationary_jump_diffusion(X):
    X = np.clip(X, np.percentile(X, 1), np.percentile(X, 99))
    return zscore(X)
def make_stationary_multi_season(X, period1, period2):
    X1 = X[period1:] - X[:-period1]
    X2 = X1[period2:] - X1[:-period2]
    return zscore(X2)
def make_stationary_trend_break(X): return zscore(np.diff(X, axis=0))


def fix_distribution(family, series):
    if family == "Heavy-tailed AR(1)":
        return zscore(series)
    elif family == "GARCH(1,1)":
        return zscore(series)
    elif family == "Regime-switching":
        return zscore(np.diff(series, axis=0))
    elif family == "1/f noise":
        return zscore(series)
    elif family == "Season+Trend+Outliers":
        dY = np.diff(series, axis=0)
        if dY.shape[0] > 50:
            season_adj = dY[50:] - dY[:-50]
            return zscore(season_adj)
        else:
            return zscore(dY)
    elif family in ["Random Walk", "Jump Diffusion", "Trend Breaks"]:
        return zscore(np.diff(series, axis=0))
    elif family == "Multi-Seasonality":
        return zscore(series)
    else:
        raise ValueError(f"No fix defined for: {family}")


In [21]:

# Synthetic generation
T = 5000
n_epochs = 10
X_A = simulate_heavy_t_ar1(T, N=1, seed=0)
r_B, _ = simulate_garch_11(T, N=1, seed=0)
X_C, _ = simulate_regime_switching_mean(T, N=1, seed=0)
X_D = simulate_1_over_f_noise(T, N=1, seed=0)
Y_E, _ = simulate_season_trend_outliers(T=2000, N=1, seed=0)
RW = simulate_random_walk(T, N=1, seed=0)
JD = simulate_jump_diffusion(T, N=1, seed=0)
MS = simulate_multi_seasonality(T, N=1, freqs=(50, 150), amps=(1.0, 0.8), seed=0)
TB = simulate_trend_breaks(T, N=1, seed=0)

# Raw dictionary
raw_dict = {
    "Heavy-tailed AR(1)": {"series": X_A[:, 0], "input_len": 64, "pos_encs": ["sin", "learned"]},
    "GARCH(1,1)":         {"series": r_B[:, 0], "input_len": 64, "pos_encs": ["sin"]},
    "Regime-switching":   {"series": X_C[:, 0], "input_len": 64, "pos_encs": ["sin"]},
    "1/f noise":          {"series": X_D[:, 0], "input_len": 128, "pos_encs": ["sin"]},
    "Season+Trend+Outliers": {"series": Y_E[:, 0], "input_len": 128, "pos_encs": ["sin"]},
    "Random Walk":        {"series": RW[:, 0], "input_len": 128, "pos_encs": ["sin"]},
    "Jump Diffusion":     {"series": JD[:, 0], "input_len": 128, "pos_encs": ["sin"]},
    "Multi-Seasonality":  {"series": MS[:, 0], "input_len": 128, "pos_encs": ["sin"]},
    "Trend Breaks":       {"series": TB[:, 0], "input_len": 128, "pos_encs": ["sin"]},
}

# Stationary versions
stationary_dict = {
    "Heavy-tailed AR(1)": (X_A[:, 0], make_stationary_ar1(X_A)[:, 0], 64),
    "GARCH(1,1)":         (r_B[:, 0], make_stationary_garch(r_B)[:, 0], 64),
    "Regime-switching":   (X_C[:, 0], make_stationary_regime(X_C)[:, 0], 64),
    "1/f noise":          (X_D[:, 0], make_stationary_long_memory(X_D)[:, 0], 128),
    "Season+Trend+Outliers": (Y_E[:, 0], make_stationary_season_trend(Y_E, 50)[:, 0], 128),
    "Random Walk":        (RW[:, 0], make_stationary_random_walk(RW)[:, 0], 128),
    "Jump Diffusion":     (JD[:, 0], make_stationary_jump_diffusion(JD)[:, 0], 128),
    "Multi-Seasonality":  (MS[:, 0], make_stationary_multi_season(MS, 50, 150)[:, 0], 128),
    "Trend Breaks":       (TB[:, 0], make_stationary_trend_break(TB)[:, 0], 128),
}

# Convert to fixed_dict format
fixed_dict = {
    name: {
        "series": fixed,
        "input_len": input_len,
        "pos_encs": raw_dict[name]["pos_encs"],
    }
    for name, (_, fixed, input_len) in stationary_dict.items()
}

In [22]:
def run_experiment_suite(
    series_dict,
    mode_label,
    horizons=(1, 10),
    n_epochs=10,
    preprocessing_fn=None,  # Optional preprocessing applied to each series
):
    """
    Run Transformer on all families, all pos_enc variants, and all horizons.
    
    series_dict[family] = {
        "series": np.array (T,),
        "input_len": int,
        "pos_encs": ["sin", "learned", ...],
    }
    """
    results = []
    histories = {}
    models = {}

    for family, cfg in series_dict.items():
        series = cfg["series"]
        input_len = cfg["input_len"]
        pos_encs = cfg["pos_encs"]

        # Apply fix or transformation if provided
        if preprocessing_fn is not None:
            try:
                series = preprocessing_fn(family, series)
            except Exception as e:
                print(f"[Warning] Preprocessing failed for {family}: {e}")
                continue

        for horizon in horizons:
            for pos_enc in pos_encs:
                print(f"\n=== {mode_label} | {family} | pos_enc={pos_enc} | horizon={horizon} ===")
                model, history, test_loss, splits = train_transformer_on_series(
                    series,
                    input_len=input_len,
                    pred_len=horizon,
                    n_epochs=n_epochs,
                    pos_encoding_type=pos_enc,
                )

                key = (mode_label, family, pos_enc, horizon)
                histories[key] = history
                models[key] = model

                results.append({
                    "mode": mode_label,
                    "family": family,
                    "pos_enc": pos_enc,
                    "horizon": horizon,
                    "test_MSE": float(test_loss),
                })

    results_df = pd.DataFrame(results)
    return results_df, histories, models


In [23]:
# Run raw series experiments
raw_results_df, raw_histories, raw_models = run_experiment_suite(
    raw_dict,
    mode_label="raw",
    horizons=(1, 10),
    n_epochs=n_epochs,
)

# Run fixed (stationary) series experiments
fixed_results_df, fixed_histories, fixed_models = run_experiment_suite(
    fixed_dict,
    mode_label="fixed",
    horizons=(1, 10),
    n_epochs=n_epochs,
)

# Combine results
combined_df = pd.concat([raw_results_df, fixed_results_df], ignore_index=True)
combined_df


=== raw | Heavy-tailed AR(1) | pos_enc=sin | horizon=1 ===
Epoch 1/10 | train=3.1493 | val=2.9273
Epoch 2/10 | train=3.0621 | val=2.8827
Epoch 3/10 | train=3.0233 | val=2.8156
Epoch 4/10 | train=2.9972 | val=2.8169
Epoch 5/10 | train=2.9865 | val=2.8427
Epoch 6/10 | train=2.9709 | val=2.8099
Epoch 7/10 | train=2.9779 | val=2.7903
Epoch 8/10 | train=2.9618 | val=2.7563
Epoch 9/10 | train=2.9570 | val=2.7918
Epoch 10/10 | train=2.9557 | val=2.8024
Test MSE: 2.5840

=== raw | Heavy-tailed AR(1) | pos_enc=learned | horizon=1 ===
Epoch 1/10 | train=3.2868 | val=2.9425
Epoch 2/10 | train=3.0808 | val=2.8487
Epoch 3/10 | train=3.0159 | val=2.8571
Epoch 4/10 | train=2.9984 | val=2.8137
Epoch 5/10 | train=2.9948 | val=2.8384
Epoch 6/10 | train=2.9686 | val=2.8328
Epoch 7/10 | train=2.9619 | val=2.8182
Epoch 8/10 | train=2.9479 | val=2.8048
Epoch 9/10 | train=2.9462 | val=2.8109
Epoch 10/10 | train=2.9464 | val=2.8424
Test MSE: 2.5665

=== raw | Heavy-tailed AR(1) | pos_enc=sin | horizon=10 ===

KeyboardInterrupt: 

In [None]:
# 4. Visualizations
def plot_mse_raw_vs_stationary(combined_df):
    for horizon in sorted(combined_df["horizon"].unique()):
        df_h = combined_df[combined_df["horizon"] == horizon]

        plt.figure(figsize=(8, 4))
        # pivot: index=family+pos_enc, columns=mode, values=test_MSE
        df_pivot = df_h.pivot_table(
            index=["family", "pos_enc"],
            columns="mode",
            values="test_MSE",
        )

        df_pivot.plot(kind="bar", ax=plt.gca())
        plt.title(f"Test MSE – horizon={horizon} (raw vs stationary)")
        plt.ylabel("Test MSE")
        plt.xlabel("Family / positional encoding")
        plt.xticks(rotation=45, ha="right")
        plt.tight_layout()
        plt.show()

plot_mse_raw_vs_stationary(combined_df)
