# TimeGrad: Training and Inference on S&P500 Data

This notebook demonstrates how to run the TimeGrad model on S&P500 time-series data. It covers:

1.  **Data Fetching**: Downloading S&P500 data using `yfinance`.
2.  **Data Preparation**: Normalizing and converting data into a GluonTS-compatible format.
3.  **Model Training**: Training the TimeGrad model using a simplified data loader.
4.  **Inference**: Generating synthetic time-series samples using the trained model.
5.  **Visualization**: Comparing the synthetic data with the real data.

In [None]:
import sys
import os
from typing import List
import torch
import matplotlib.pyplot as plt
import numpy as np

# Add the project's root directory to the Python path.
# This allows us to import our custom modules like `data_fetch` and `src.pts`.
# Assumes the notebook is run from the project root (the 'Grok Assisted' directory).
module_path = os.path.abspath(os.getcwd())
if module_path not in sys.path:
    sys.path.insert(0, module_path)

from gluonts.dataset.split import split

# Import from the modules directly
from src.pts.model.time_grad.diffusion import Diffusion, DiffusionConfig
from src.pts.model.time_grad.time_grad_network import TimeGradNetwork, NetworkConfig
from src.pts.model.time_grad.time_grad_estimator import TimeGradEstimator, EstimatorConfig
from src.pts.model.time_grad.time_grad_predictor import TimeGradPredictor
from src.pts.trainer import Trainer
from data_fetch import fetch_sp500_data, prepare_gluonts_dataset, DataConfig

print("Imports successful and path is set up.")

## 1. Configuration

We define all the configurations for data, network, diffusion, and the estimator in one place.
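
For intuition about the `DiffusionConfig` values used in this section, here is a minimal sketch of the noise schedule they would imply if the betas are spaced linearly between `beta_start` and `beta_end`. The linear spacing is an assumption; the project's `Diffusion` class may use a different schedule. The helper names are hypothetical, and the code is plain Python so it runs without the project modules.

```python
def linear_beta_schedule(num_steps, beta_start, beta_end):
    """Linearly spaced noise variances beta_1..beta_T (assumed schedule)."""
    if num_steps == 1:
        return [beta_start]
    step = (beta_end - beta_start) / (num_steps - 1)
    return [beta_start + i * step for i in range(num_steps)]


def cumulative_alphas(betas):
    """alpha_bar_t = prod_{s<=t} (1 - beta_s): signal variance left at step t."""
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


betas = linear_beta_schedule(num_steps=100, beta_start=1e-4, beta_end=0.1)
alpha_bars = cumulative_alphas(betas)

print(f"alpha_bar after step 1:   {alpha_bars[0]:.4f}")
print(f"alpha_bar after step 100: {alpha_bars[-1]:.4f}")
```

When the final `alpha_bar` is close to zero, a fully noised series is close to pure Gaussian noise, which is what the reverse process uses as its starting point.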

In [None]:
# Data configuration
data_config = DataConfig(start_date='2024-01-01', context_length=24, prediction_length=24)

# Network configuration
net_config = NetworkConfig(input_dim=1, hidden_dim=40, num_layers=2)

# Diffusion process configuration
diff_config = DiffusionConfig(num_steps=100, beta_start=1e-4, beta_end=0.1)

# Estimator/Training configuration
est_config = EstimatorConfig(learning_rate=1e-3, batch_size=64)

print("Configurations loaded.")

## 2. Data Fetching and Preparation

We fetch the S&P 500 data, normalize it, and prepare it as a GluonTS dataset.
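
The normalization applied here is a z-score transform, and the visualization step later inverts it. Below is a minimal sketch of that round trip on made-up prices, in plain Python. `statistics.stdev` is the sample standard deviation, which matches the default (`ddof=1`) of pandas `Series.std()`.

```python
import statistics

# Made-up closing prices, for illustration only
prices = [4700.0, 4725.5, 4698.2, 4760.1, 4781.9, 4755.3]

mean = statistics.mean(prices)
std = statistics.stdev(prices)  # sample standard deviation (ddof=1)

# Forward transform: zero mean, unit sample standard deviation
normalized = [(p - mean) / std for p in prices]

# Inverse transform, as applied to generated samples before plotting
restored = [z * std + mean for z in normalized]

print(f"normalized mean: {statistics.mean(normalized):.6f}")
print(f"normalized std:  {statistics.stdev(normalized):.6f}")
```

Because the statistics are computed over the whole series, the held-out tail also influences the normalization. Computing them on the training portion only would avoid that leakage.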

In [None]:
try:
    print("Fetching S&P500 data...")
    # Fetch data
    df = fetch_sp500_data(data_config)
    print(f"Fetched {len(df)} data points")

    # Simple normalization
    data_mean = df['close'].mean()
    data_std = df['close'].std()
    df_normalized = df.copy()
    df_normalized['close'] = (df['close'] - data_mean) / data_std
    print(f"Data normalized - Mean: {data_mean:.2f}, Std: {data_std:.2f}")

    # Create dataset
    dataset = prepare_gluonts_dataset(df_normalized, data_config)
    print("Dataset prepared successfully")
except Exception as e:
    print(f"An error occurred during data fetching/preparation: {e}")
    # Re-raise: later cells depend on `df`, `data_mean`, `data_std`, and `dataset`
    raise

## 3. Data Splitting and Dataloader

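The data is split into training and test sets. We then use a simple, custom data loader for training. This loader yields one batch per series, containing the target data and zero-initialized hidden states for the RNN. Each batch holds a single window (batch size 1), so the loader does not use the `batch_size` set in `est_config`.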
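
To make the indexing concrete, the sketch below mimics the split and the window selection on a toy series using plain Python lists. It assumes, as in this notebook's configuration, `context_length = prediction_length = 24`. It does not call GluonTS, and the variable names are illustrative.

```python
context_length = 24
prediction_length = 24

# Toy series of 100 points standing in for the normalized closing prices
series = list(range(100))

# A negative offset keeps everything except the last `prediction_length`
# points for training; the tail is held out for evaluation.
train_part = series[:-prediction_length]
held_out = series[-prediction_length:]

# The loader then uses only the last `context_length` training points,
# so each series contributes exactly one training window.
window = train_part[-context_length:]

print(len(train_part), len(held_out), len(window))
print(window[0], window[-1])
```

If the dataset holds a single series, the loader therefore produces one window per epoch.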

In [None]:
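# Split data: hold out the last `prediction_length` points of each series.
# GluonTS `split` returns the training dataset and a test template
# (from which test instances can be generated); only the former is used here.
train_dataset, test_template = split(dataset, offset=-data_config.prediction_length)
train_data = list(train_dataset)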

print(f"Training data: {len(train_data)} series")

# Simple dataloader for our MVP
class SimpleLoader:
    def __init__(self, ds, context_length, num_layers, hidden_dim):
        self.ds = ds
        self.context_length = context_length
        self.num_layers = num_layers
        self.hidden_dim = hidden_dim
        
    def __iter__(self):
        for item in self.ds:
            target_data = item['target']
            
            # For simplicity, we take the last `context_length` as one sample
            if len(target_data) >= self.context_length:
                context_data = target_data[-self.context_length:]
                # Reshape to (batch_size=1, seq_len, input_dim=1)
                target_tensor = torch.tensor(
                    context_data.reshape(1, -1, 1), 
                    dtype=torch.float32
                )
                
                # Initial hidden state for the RNN
                hidden = (
                    torch.zeros(self.num_layers, 1, self.hidden_dim),
                    torch.zeros(self.num_layers, 1, self.hidden_dim)
                )
                
                yield {
                    'target': target_tensor,
                    'hidden': hidden
                }
                
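    def __len__(self):
        # Count only the series long enough to yield a window in __iter__
        return sum(
            1 for item in self.ds if len(item['target']) >= self.context_length
        )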

train_loader = SimpleLoader(
    train_data, 
    data_config.context_length, 
    net_config.num_layers, 
    net_config.hidden_dim
)

print("Dataloader created.")

## 4. Model Initialization

We initialize the core components of TimeGrad: the diffusion process, the noise prediction network, the estimator (which handles the loss), and the trainer.
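
As background on what such a loss typically looks like in a denoising diffusion model, here is a sketch of the standard DDPM noise-prediction objective: noise a clean window with a known amount of Gaussian noise, ask a model to predict that noise, and score it with mean squared error. This is the textbook formulation, assumed rather than read from `TimeGradEstimator`. `q_sample`, `noise_prediction_loss`, and the two toy models are hypothetical names.

```python
import math
import random


def q_sample(x0, alpha_bar, noise):
    """Forward diffusion: x_t = sqrt(alpha_bar) * x0 + sqrt(1 - alpha_bar) * noise."""
    return [
        math.sqrt(alpha_bar) * x + math.sqrt(1.0 - alpha_bar) * e
        for x, e in zip(x0, noise)
    ]


def noise_prediction_loss(predict_noise, x0, alpha_bar, rng):
    """Mean squared error between the true noise and the model's prediction."""
    noise = [rng.gauss(0.0, 1.0) for _ in x0]
    x_t = q_sample(x0, alpha_bar, noise)
    predicted = predict_noise(x_t)
    return sum((p - e) ** 2 for p, e in zip(predicted, noise)) / len(x0)


x0 = [0.1 * i for i in range(24)]  # toy "clean" window
alpha_bar = 0.5                    # signal fraction at some diffusion step


def zero_model(x_t):
    """Always predicts zero noise, so the loss is the mean squared noise."""
    return [0.0 for _ in x_t]


def oracle(x_t):
    """Inverts q_sample using the true x0, so the loss is numerically zero."""
    return [
        (xt - math.sqrt(alpha_bar) * x) / math.sqrt(1.0 - alpha_bar)
        for xt, x in zip(x_t, x0)
    ]


rng = random.Random(0)
print(f"zero model loss: {noise_prediction_loss(zero_model, x0, alpha_bar, rng):.4f}")
print(f"oracle loss:     {noise_prediction_loss(oracle, x0, alpha_bar, rng):.4f}")
```

A trained network sits between these two extremes: the better it recovers the injected noise, the closer the loss gets to zero.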

In [None]:
print("Initializing models...")
# Initialize models
diffusion = Diffusion(diff_config)
network = TimeGradNetwork(net_config)
estimator = TimeGradEstimator(est_config, network, diffusion)
trainer = Trainer(estimator, epochs=10)

print("Models initialized.")

## 5. Training

Now, we train the model. The progress bar will show the status for each epoch.

In [None]:
print("Starting training...")
losses = trainer.train(train_loader)
print(f"Training completed. Final loss: {losses[-1]:.4f}")

## 6. Inference: Generating Synthetic Samples

With the trained network, we use the `TimeGradPredictor` to generate new time-series samples through the reverse diffusion process.
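
The reverse process can be sketched as DDPM ancestral sampling: start from Gaussian noise and repeatedly remove the noise that the network predicts. The loop below is a generic illustration in plain Python with a stand-in noise predictor. It assumes a linear beta schedule and the `sigma_t**2 = beta_t` variance choice. It is not the code inside `TimeGradPredictor`, which may also condition each step on an RNN hidden state.

```python
import math
import random


def ancestral_sample(predict_noise, betas, length, rng):
    """Draw one series of `length` values by running the reverse diffusion chain."""
    # Cumulative products alpha_bar_t = prod_{s<=t} (1 - beta_s)
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)

    # Start from pure Gaussian noise x_T
    x = [rng.gauss(0.0, 1.0) for _ in range(length)]

    for t in reversed(range(len(betas))):
        beta, alpha_bar = betas[t], alpha_bars[t]
        eps = predict_noise(x, t)
        coef = beta / math.sqrt(1.0 - alpha_bar)
        # Posterior mean: remove the predicted noise and rescale
        mean = [(xi - coef * ei) / math.sqrt(1.0 - beta) for xi, ei in zip(x, eps)]
        if t > 0:
            # Add fresh noise at every step except the final one
            sigma = math.sqrt(beta)
            x = [m + sigma * rng.gauss(0.0, 1.0) for m in mean]
        else:
            x = mean
    return x


def dummy_predict_noise(x_t, t):
    """Stand-in for the trained network: treats the whole input as noise."""
    return list(x_t)


num_steps = 100
toy_betas = [1e-4 + i * (0.1 - 1e-4) / (num_steps - 1) for i in range(num_steps)]
rng = random.Random(0)
toy_samples = [ancestral_sample(dummy_predict_noise, toy_betas, 24, rng) for _ in range(5)]
print(len(toy_samples), len(toy_samples[0]))
```

Each call produces one independent series, so drawing many of them gives the sample cloud from which quantiles are computed in the visualization step.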

In [None]:
print("Generating synthetic samples...")
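# Generate samples
predictor = TimeGradPredictor(network, diffusion)
# Request `prediction_length` steps per sample (passed via the `context_length`
# argument) so the samples can be compared with the held-out window.
samples = predictor.predict(
    context_length=data_config.prediction_length, 
    num_samples=100
)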
print(f"Generated samples shape: {samples.shape}")

## 7. Visualization and Analysis

Finally, we visualize the results. We denormalize the generated samples and plot them against the original data to qualitatively assess the model's performance.
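
The plotting helper turns central prediction intervals into the percentiles that bound them: an interval covering `c` percent of the samples runs from percentile `50 - c/2` to `50 + c/2`. The snippet below reproduces that mapping for the default intervals so the shaded bands are easier to interpret. `interval_percentiles` is an illustrative name.

```python
def interval_percentiles(prediction_intervals):
    """Return the sorted percentiles needed for the median and each central interval."""
    ps = [50.0]
    for coverage in prediction_intervals:
        ps.append(50.0 - coverage / 2.0)  # lower bound of the interval
        ps.append(50.0 + coverage / 2.0)  # upper bound of the interval
    return sorted(set(ps))


percentiles = interval_percentiles((50.0, 90.0))
print(percentiles)  # [5.0, 25.0, 50.0, 75.0, 95.0]

# Pair the outermost percentiles first, then work inwards
half = len(percentiles) // 2
for i in range(half):
    lower, upper = percentiles[i], percentiles[-i - 1]
    print(f"{upper - lower:.0f}% interval: {lower:.0f}th to {upper:.0f}th percentile")
```

The wider interval is drawn first and the narrower one on top of it, so the innermost band appears darkest.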

In [None]:
import pandas as pd

# Helper class to mimic GluonTS/PyTorchTS Forecast object
class SimpleForecast:
    """
    A simplified forecast object that holds samples and can compute quantiles.
    This is to make the generated samples compatible with plotting utilities
    that expect a forecast-like object.
    """
    def __init__(self, samples, start_date, freq):
        # samples are expected to be denormalized, shape: (num_samples, prediction_length)
        self.samples = samples
        self.start_date = start_date
        self.freq = freq
        self.prediction_length = samples.shape[1]
        
        # Create a pandas DatetimeIndex for the forecast horizon
        self.index = pd.date_range(
            start=self.start_date, 
            periods=self.prediction_length, 
            freq=self.freq
        )

    def quantile(self, q: float) -> np.ndarray:
        """
        Compute a quantile of the samples.
        
        Args:
            q: The quantile to compute (e.g., 0.5 for median).
            
        Returns:
            An array of shape (prediction_length,).
        """
        return np.quantile(self.samples, q, axis=0)

def plot_forecast(
    target: pd.Series, 
    forecast: SimpleForecast, 
    prediction_length: int, 
    prediction_intervals: tuple = (50.0, 90.0), 
    color: str = 'g'
):
    """
    Plots a univariate time series forecast against the target data.
    
    This is an adaptation of a more general plotting utility to fit the
    univariate case of this notebook. It shows the median forecast and
    prediction intervals.
    """
    print("Creating forecast visualization...")
    plt.figure(figsize=(15, 7))
    ax = plt.gca()

    # Plot the target data, showing some history before the forecast
    # We plot the last 2 * prediction_length to give context
    target_plot_range = target.iloc[-2 * prediction_length:]
    target_plot_range.plot(ax=ax, label="Observations", color="blue")

    # Define the percentiles to calculate for the prediction intervals
    ps = [50.0] + [
        50.0 + f * c / 2.0 for c in prediction_intervals for f in [-1.0, 1.0]
    ]
    percentiles_sorted = sorted(set(ps))

    def alpha_for_percentile(p):
        return (p / 100.0) ** 0.3

    # Get quantile data from the forecast object
    ps_data = [forecast.quantile(p / 100.0) for p in percentiles_sorted]
    i_p50 = len(percentiles_sorted) // 2

    # Plot the median forecast
    p50_data = ps_data[i_p50]
    p50_series = pd.Series(data=p50_data, index=forecast.index)
    p50_series.plot(color=color, ls="-", label="Median Forecast", ax=ax)

    # Plot the prediction intervals as shaded areas
    for i in range(len(percentiles_sorted) // 2):
        ptile = percentiles_sorted[i]
        alpha = alpha_for_percentile(ptile)
        ax.fill_between(
            forecast.index,
            ps_data[i],
            ps_data[-i - 1],
            facecolor=color,
            alpha=alpha,
            interpolate=True,
        )
        # Use a hack to create labels for the prediction intervals in the legend
        pd.Series(data=p50_data[:1], index=forecast.index[:1]).plot(
            color=color,
            alpha=alpha,
            linewidth=10,
            label=f"{100 - ptile * 2:.0f}% Interval",
            ax=ax,
        )

    ax.legend(loc="upper left")
    plt.title("S&P500 Forecast vs. Real Data")
    plt.ylabel("Price")
    plt.xlabel("Date")
    plt.grid(which="both")
    plt.tight_layout()
    plt.show()

# --- Execution part ---

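# 1. Denormalize the generated samples (move to CPU first in case the model ran on a GPU)
denormalized_samples = samples.detach().cpu().numpy() * data_std + data_mean
# Drop a trailing feature dimension of size 1, if present, so the shape is
# (num_samples, prediction_length) as `SimpleForecast` expects
if denormalized_samples.ndim == 3:
    denormalized_samples = denormalized_samples.squeeze(-1)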

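# 2. Determine the forecast start date and frequency
# The forecast window covers the last `prediction_length` points of the dataset,
# i.e. the part held out from training, so it starts at the first of those points.
forecast_start_date = df.index[-data_config.prediction_length]
# Daily S&P500 data has no weekend rows, so use business days ('B') rather than
# calendar days ('D') to keep the forecast index close to the real trading dates.
# Exchange holidays can still cause small offsets.
freq = 'h' if data_config.interval == '1h' else 'B'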

# 3. Create the forecast object
forecast_obj = SimpleForecast(
    samples=denormalized_samples,
    start_date=forecast_start_date,
    freq=freq
)

# 4. Plot the forecast
# The `target` is the full historical data series.
plot_forecast(
    target=df['close'], 
    forecast=forecast_obj, 
    prediction_length=data_config.prediction_length,
    color='r' # Use red for the forecast
)

# 5. Print statistics for comparison
original_data = df['close'].iloc[-data_config.prediction_length:].values
median_forecast = forecast_obj.quantile(0.5)

print("\n=== Results Summary ===")
print(f"Real data      - Mean: {np.mean(original_data):.2f}, Std: {np.std(original_data):.2f}")
print(f"Median forecast - Mean: {np.mean(median_forecast):.2f}, Std: {np.std(median_forecast):.2f}")