# TimeGrad vs Conditional TimeGrad Comparison

This notebook trains a lightweight vanilla TimeGrad model and a conditioning-aware TimeGrad variant on the same data, then compares their forecasts on a held-out batch using a sample-based CRPS score and a side-by-side plot.

## Setup
The path setup below assumes this notebook lives in `notebooks/` directly under the project root, so that the `src` package can be imported.

In [None]:
import math
import os
import sys
from pathlib import Path
from typing import Tuple

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from torch import nn
from torch.utils.data import DataLoader

# Resolve the project root so the `src.*` imports below work both when this code
# runs as a script (``__file__`` is defined) and inside Jupyter (``__file__`` is
# undefined, so fall back to the parent of the working directory, i.e. assume the
# notebook is executed from `notebooks/`).
try:
    project_root = Path(__file__).resolve().parents[1]
except NameError:
    project_root = Path.cwd().resolve().parent

# Use the public `sys` module; `os.sys` only exists because `os` happens to
# import `sys` internally and is not part of its documented API.
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

from src.data_loader import TimeGradDataModule
from src.models.timegrad_core.timegrad_base import TimeGradBase
from src.predictor import ConditionalTimeGradPredictionNetwork
from src.training import ConditionalTimeGradTrainingNetwork

# Plot styling for every figure in this notebook.
plt.style.use('seaborn-v0_8-darkgrid')

# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda')
else:
    device = torch.device('cpu')
print(f'Using device: {device}')

## Load & preprocess data
This matches the preprocessing flow used in the conditional demo notebook.

In [None]:
# Raw parquet inputs live under <project_root>/data/raw.
data_base = project_root / 'data'
raw_base = data_base / 'raw'

raw_paths = dict(
    target=raw_base / 'target.parquet',
    market=raw_base / 'market.parquet',
    daily_macro=raw_base / 'daily_macro.parquet',
    monthly_macro=raw_base / 'monthly_macro.parquet',
    quarterly_macro=raw_base / 'quarterly_macro.parquet',
)
data_dict = {}
for source_name, source_path in raw_paths.items():
    data_dict[source_name] = pd.read_parquet(source_path)
print({source: frame.shape for source, frame in data_dict.items()})

# Window and batch settings shared by the data module and both models.
context_length = 64
prediction_length = 24
batch_size = 4

dm = TimeGradDataModule(
    data_dict=data_dict,
    seq_len=context_length,
    forecast_horizon=prediction_length,
    batch_size=batch_size,
    device=str(device),
)
# Preprocess/split the raw frames, then build the train/val/test datasets
# (exact split semantics live in TimeGradDataModule).
dm.preprocess_and_split()
dm.build_datasets()

# Conditioning groups: daily + monthly columns form the dynamic covariates,
# regime columns form the static covariates.
feature_cols = dm.get_feature_columns_by_type()
dynamic_cols = feature_cols['daily'] + feature_cols['monthly']
feature_cols['cond_dynamic'] = dynamic_cols
feature_cols['cond_static'] = feature_cols['regime']

print('Feature columns by type:')
for col_group, group_cols in feature_cols.items():
    print(f"  {col_group}: {len(group_cols)} cols")


## Minimal vanilla TimeGrad wrapper
The helper below mirrors the conditional wrapper's normalization but omits conditioning. It allows quick training and sampling for the baseline model.

In [None]:
class VanillaTimeGradWrapper(nn.Module):
    """Baseline TimeGrad wrapper with per-series normalisation and no conditioning.

    The wrapped ``TimeGradBase`` only ever sees the normalised forecast window:
    ``forward`` passes ``x_future`` alone to ``self.model``, and ``sample`` draws
    from the diffusion model without any history input. The history ``x_hist``
    therefore affects this baseline only through the location/scale used to
    normalise (training) and denormalise (sampling).

    Args:
        target_dim: Number of target series (channels).
        prediction_length: Forecast horizon passed to ``TimeGradBase``.
        scale_eps: Lower bound on the per-series standard deviation so that a
            constant history does not cause a division by zero.
    """

    def __init__(self, target_dim: int, prediction_length: int, scale_eps: float = 1e-5):
        super().__init__()
        self.prediction_length = prediction_length
        self.scale_eps = scale_eps
        self.model = TimeGradBase(target_dim=target_dim, prediction_length=prediction_length)

    def _history_stats(self, x_hist: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        """Return per-series ``(loc, scale)`` over dim 1 of ``x_hist``, keeping that dim.

        Assumes ``x_hist`` is laid out as [B, T, C] (time on dim 1) — TODO confirm
        against the data module.
        """
        loc = x_hist.mean(dim=1, keepdim=True)
        scale = x_hist.std(dim=1, keepdim=True).clamp_min(self.scale_eps)
        return loc, scale

    def _normalize(self, x_hist: torch.Tensor, x_future: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Normalise ``x_future`` using statistics computed from ``x_hist`` only.

        History-only statistics keep training consistent with :meth:`sample`
        (which cannot see the future window) and keep the forecast targets from
        leaking into their own normalisation.

        Returns:
            ``(loc, scale, x_future_norm)``; ``loc`` and ``scale`` keep a
            singleton time dimension so they broadcast over ``x_future``.
        """
        loc, scale = self._history_stats(x_hist)
        return loc, scale, (x_future - loc) / scale

    def forward(self, x_hist: torch.Tensor, x_future: torch.Tensor) -> torch.Tensor:
        """Return the diffusion training loss on the normalised forecast window."""
        _, _, x_future_norm = self._normalize(x_hist, x_future)
        return self.model(x_future_norm)

    @torch.no_grad()
    def sample(self, x_hist: torch.Tensor, num_samples: int = 100) -> Tuple[torch.Tensor, torch.Tensor]:
        """Draw ``num_samples`` forecasts and map them back to the data scale.

        Args:
            x_hist: History window; only its per-series mean/std are used.
            num_samples: Number of independent diffusion draws.

        Returns:
            ``(samples, loc)``: the draws stacked on a new leading sample axis,
            and the per-series history mean with the time dimension squeezed.
        """
        loc, scale = self._history_stats(x_hist)
        draws = []
        for _ in range(num_samples):
            raw = self.model.diffusion.sample(
                batch_size=x_hist.shape[0],
                horizon=self.prediction_length,
            )
            # presumably `raw` is [B, C, horizon]; transpose to [B, horizon, C]
            # before denormalising — TODO confirm against TimeGradBase.
            draws.append(raw.transpose(1, 2) * scale + loc)
        stacked = torch.stack(draws, dim=0)  # [S, B, horizon, C]
        return stacked, loc.squeeze(1)


## Train both models (small demo epochs)
The training loop runs a small, fixed number of epochs (`num_epochs`) for a quick comparison; increase it for higher-fidelity forecasts.

In [None]:
vanilla = VanillaTimeGradWrapper(target_dim=len(feature_cols['target']), prediction_length=prediction_length).to(device)
conditional_train = ConditionalTimeGradTrainingNetwork(
    target_dim=len(feature_cols['target']),
    context_length=context_length,
    prediction_length=prediction_length,
    cond_dynamic_dim=len(feature_cols['cond_dynamic']),
    cond_static_dim=len(feature_cols['cond_static']),
).to(device)

vanilla_opt = torch.optim.Adam(vanilla.parameters(), lr=1e-3)
conditional_opt = torch.optim.Adam(conditional_train.parameters(), lr=1e-3)

# Build the loader once; with shuffle=True each new pass over it still draws a
# fresh permutation, so this matches re-creating it every epoch.
train_loader = DataLoader(dm.train_set, batch_size=batch_size, shuffle=True)

num_epochs = 50
for epoch in range(num_epochs):
    vanilla.train()
    conditional_train.train()
    vanilla_loss_sum = 0.0
    cond_loss_sum = 0.0
    num_batches = 0

    for batch in train_loader:
        x_hist = batch['x_hist'].to(device)
        x_future = batch['x_future'].to(device)
        cond_dynamic = batch['cond_dynamic'].to(device)
        cond_static = batch['cond_static'].to(device)

        # Both models see the identical batch so their losses are comparable.
        vanilla_opt.zero_grad()
        vanilla_loss = vanilla(x_hist, x_future).mean()
        vanilla_loss.backward()
        vanilla_opt.step()

        conditional_opt.zero_grad()
        # `.mean()` is a no-op on a scalar loss and reduces per-sample losses
        # otherwise, mirroring how the vanilla loss is reduced above.
        cond_loss = conditional_train(x_hist, x_future, cond_dynamic, cond_static).mean()
        cond_loss.backward()
        conditional_opt.step()

        vanilla_loss_sum += vanilla_loss.item()
        cond_loss_sum += cond_loss.item()
        num_batches += 1

    # Report per-epoch mean losses so the training run is not a black box.
    if num_batches:
        print(
            f'epoch {epoch + 1}/{num_epochs}: '
            f'vanilla loss {vanilla_loss_sum / num_batches:.4f}, '
            f'conditional loss {cond_loss_sum / num_batches:.4f}'
        )


## Build conditional predictor for sampling
The predictor reuses the trained conditional backbone and supports masked or full-horizon sampling.

In [None]:
# Build the prediction network with the same dimensions as the training network
# so the trained backbone weights can be copied across.
predictor_kwargs = dict(
    target_dim=len(feature_cols['target']),
    context_length=context_length,
    prediction_length=prediction_length,
    cond_dynamic_dim=len(feature_cols['cond_dynamic']),
    cond_static_dim=len(feature_cols['cond_static']),
)
conditional_pred = ConditionalTimeGradPredictionNetwork(**predictor_kwargs).to(device)

# Only the `.model` submodule's state is transferred; load_state_dict's default
# strict mode raises if the two key sets differ.
trained_state = conditional_train.model.state_dict()
conditional_pred.model.load_state_dict(trained_state)


## Compare forecasts on a held-out batch
We draw one batch from the test split, generate 200 samples from each model, score both sample sets with CRPS, and plot the mean forecasts for the first series in the batch (first target dimension) against the ground-truth target window.

In [None]:
# Pull one held-out batch and move every field the models need onto `device`.
test_batch = next(iter(dm.test_dataloader()))
x_hist = test_batch['x_hist'].to(device)
x_future = test_batch['x_future'].to(device)
cond_dynamic = test_batch['cond_dynamic'].to(device)
cond_static = test_batch['cond_static'].to(device)

num_eval_samples = 200

# Sampling only needs forward passes. Disabling autograd stops large sample
# ensembles from retaining computation graphs; whether `sample_autoregressive`
# already does this itself is not visible from this notebook.
with torch.no_grad():
    # Vanilla samples
    vanilla.eval()
    vanilla_samples, _ = vanilla.sample(x_hist, num_samples=num_eval_samples)

    # Conditional samples
    conditional_pred.eval()
    cond_samples = conditional_pred.sample_autoregressive(
        x_hist=x_hist,
        cond_dynamic=cond_dynamic,
        cond_static=cond_static,
        num_samples=num_eval_samples,
        sampling_strategy='masked_step',
    )

# Sample-mean forecast for the first series in the batch, as NumPy for plotting.
vanilla_mean = vanilla_samples.mean(dim=0).detach().cpu().numpy()[0]
cond_mean = cond_samples.mean(dim=0).detach().cpu().numpy()[0]

truth = x_future[0].cpu().numpy()


In [None]:
def sample_crps(samples: torch.Tensor, target: torch.Tensor) -> float:
    """Estimate the mean CRPS of a sample ensemble against observed targets.

    Uses the energy form ``CRPS ~= E|X - y| - 0.5 * E|X - X'|``, with both
    expectations replaced by averages over the ensemble. The pairwise term
    averages over all ``S * S`` ordered pairs (including ``i == j``).

    Args:
        samples: Forecast samples with the ensemble on dim 0, e.g. [S, B, T, D].
        target: Observations broadcastable against ``samples[0]``, e.g. [B, T, D].

    Returns:
        The CRPS averaged over every non-ensemble element, as a Python float.
    """
    num_samples = samples.shape[0]

    # E|X - y|: mean absolute error of each ensemble member against the truth.
    abs_err = (samples - target).abs().mean(dim=0)

    # E|X - X'| without materialising an [S, S, ...] tensor (O(S^2) memory):
    # for values sorted ascending x_(1) <= ... <= x_(S),
    #   sum_{i,j} |x_i - x_j| = 2 * sum_k (2k - S - 1) * x_(k).
    sorted_samples, _ = samples.sort(dim=0)
    weights = (
        2 * torch.arange(1, num_samples + 1, device=samples.device, dtype=samples.dtype)
        - num_samples
        - 1
    )
    weights = weights.view(-1, *([1] * (samples.dim() - 1)))
    spread = 2.0 * (weights * sorted_samples).sum(dim=0) / (num_samples * num_samples)

    return (abs_err - 0.5 * spread).mean().item()

# Score both ensembles against the same held-out targets (lower CRPS is better).
vanilla_crps = sample_crps(vanilla_samples, x_future)
cond_crps = sample_crps(cond_samples, x_future)
for model_label, crps_value in (
    ('vanilla TimeGrad', vanilla_crps),
    ('conditional TimeGrad', cond_crps),
):
    print(f'CRPS ({model_label}): {crps_value:.4f}')

In [None]:
# Mean forecasts vs. ground truth, first target dimension only.
fig, ax = plt.subplots(figsize=(10, 5))
curves = [
    (truth, 'Ground truth', {'linewidth': 2}),
    (vanilla_mean, 'Vanilla TimeGrad mean', {}),
    (cond_mean, 'Conditional TimeGrad mean', {}),
]
for values, curve_label, line_style in curves:
    ax.plot(values[:, 0], label=curve_label, **line_style)
ax.set_xlabel('Horizon step')
ax.set_ylabel('Target value')
ax.set_title('Forecast comparison')
ax.legend()
fig.tight_layout()
plt.show()