# Spread Simulation with XAI Annotations (TFT + SEIR)
This notebook implements a Temporal Fusion Transformer (TFT) model trained on COVID-19 data (OWID) enriched with SEIR model states and external factors.

## Objectives
1. **Data Preprocessing**: Load OWID data, clean, and engineer features.
2. **SEIR Modeling**: Fit an Enhanced SEIR model to each country's data to generate features.
3. **TFT Training**: Train using specific data splits (Primary, Rising 3rd Wave) and target variables (Active Cases or New Cases).

In [None]:
!pip install pytorch-forecasting pytorch-lightning pandas numpy scipy matplotlib scikit-learn seaborn optuna

In [None]:
import os
import warnings
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
from scipy.integrate import odeint
from scipy.optimize import minimize
from sklearn.preprocessing import StandardScaler, MinMaxScaler, LabelEncoder
import torch
import pytorch_lightning as pl
from pytorch_lightning.callbacks import EarlyStopping, LearningRateMonitor
from pytorch_forecasting import TimeSeriesDataSet, TemporalFusionTransformer, Baseline, GroupNormalizer
from pytorch_forecasting.data import GroupNormalizer
from pytorch_forecasting.metrics import MSE, QuantileLoss, MAE, SMAPE
import optuna
from optuna.integration import PyTorchLightningPruningCallback

warnings.filterwarnings("ignore")

## 1. Data Loading (Steps 2-4)

In [None]:
def load_and_preprocess_data():
    """Load the OWID COVID-19 table and derive SEIR-style and calendar features.

    The CSV is read from the first candidate path that loads successfully
    (local/Drive copies first, then the public URL). Only the columns listed
    in ``cols`` that are actually present are kept.

    Returns:
        pandas.DataFrame sorted by ``location`` and ``date`` with the added
        columns ``I``, ``R``, ``S``, ``E`` (population fractions derived from
        rolling/cumulative case counts), cyclical day-of-week/day-of-year
        encodings, ``is_weekend`` and ``*_log`` (``log1p``) columns for
        whichever of ``new_cases``/``new_deaths``/``new_tests`` exist.

    Raises:
        FileNotFoundError: If none of the candidate sources could be read.
    """
    paths = [
        "/content/drive/MyDrive/SEIR/Dataset/cleaned_owid-covid-data.csv",
        "owid-covid-data.csv",
        "https://covid.ourworldindata.org/data/owid-covid-data.csv"
    ]
    df = None
    for path in paths:
        if not (os.path.exists(path) or path.startswith("http")):
            continue
        try:
            print(f"Attempting to load data from: {path}")
            df = pd.read_csv(path)
        except Exception as e:
            # Best-effort fallback chain: report and try the next source.
            print(f"Failed to load {path}: {e}")
            continue
        print("Success!")
        break

    if df is None:
        raise FileNotFoundError("Could not load OWID dataset from any source.")

    # Step 2: Column Selection
    cols = ['location', 'date', 'new_cases', 'new_deaths', 'new_tests',
            'positive_rate', 'stringency_index',
            'aged_65_older', 'aged_70_older', 'diabetes_prevalence', 'cardiovasc_death_rate',
            'population']

    # Filter to available columns
    df = df[[c for c in cols if c in df.columns]].copy()
    df['date'] = pd.to_datetime(df['date'])

    # Sort (pre-requisite for the per-location rolling/cumulative windows)
    df = df.sort_values(['location', 'date']).reset_index(drop=True)

    # Step 3: SEIR State Estimation
    # Missing daily counts are treated as 0 for the window sums only.
    # fillna is element-wise, so no groupby is needed; a groupby(...).apply(...)
    # result carries a (location, row) index on recent pandas releases and
    # cannot be assigned back to the frame.
    df['new_cases_filled'] = df['new_cases'].fillna(0)
    cases_by_location = df.groupby('location')['new_cases_filled']

    # I: 14-day rolling case sum / population
    # transform keeps the original row index, so assignment aligns directly.
    df['rolling_14_cases'] = cases_by_location.transform(
        lambda s: s.rolling(14, min_periods=1).sum()
    )
    df['I'] = df['rolling_14_cases'] / df['population']

    # R: (cumulative cases - 14-day sum) / population
    df['cumulative_cases'] = cases_by_location.cumsum()
    df['R'] = (df['cumulative_cases'] - df['rolling_14_cases']) / df['population']

    # S: 1 - R - I, floored at 0
    df['S'] = (1 - df['R'] - df['I']).clip(lower=0)

    # E: 5-day rolling case sum * 2.5 / population
    df['rolling_5_cases'] = cases_by_location.transform(
        lambda s: s.rolling(5, min_periods=1).sum()
    )
    df['E'] = (df['rolling_5_cases'] * 2.5) / df['population']

    # Verify SEIR Sum (E is added on top of S + I + R, so this flags large E)
    seir_sum = df['S'] + df['E'] + df['I'] + df['R']
    if (seir_sum - 1.0).abs().max() > 0.05:
        print("WARNING: SEIR sum deviates by > 0.05 in some rows")

    # Step 4: Feature Engineering
    df['day_of_week'] = df['date'].dt.dayofweek
    df['day_of_year'] = df['date'].dt.dayofyear

    df['day_of_week_sin'] = np.sin(2 * np.pi * df['day_of_week'] / 7)
    df['day_of_week_cos'] = np.cos(2 * np.pi * df['day_of_week'] / 7)
    df['day_of_year_sin'] = np.sin(2 * np.pi * df['day_of_year'] / 365)
    df['day_of_year_cos'] = np.cos(2 * np.pi * df['day_of_year'] / 365)
    df['is_weekend'] = (df['day_of_week'] >= 5).astype(int)

    # Log transforms (log1p) on the raw (unfilled) counts. Optional columns are
    # skipped when the source file does not provide them, matching the column
    # filter above.
    for raw_col in ('new_cases', 'new_deaths', 'new_tests'):
        if raw_col in df.columns:
            df[f'{raw_col}_log'] = np.log1p(df[raw_col])

    # Cleanup temp columns
    df = df.drop(columns=['new_cases_filled', 'rolling_14_cases', 'rolling_5_cases',
                          'cumulative_cases', 'day_of_week', 'day_of_year'])

    return df

# Run the loading/feature-engineering step once; later cells read `raw_df`.
raw_df = load_and_preprocess_data()
print(f"Data Loaded & Feature Engineered: {raw_df.shape}")

## 1.5 Exploratory Data Analysis (EDA)
Analyzing distributions, missing values, and trends before strict cleaning.

In [None]:
# Exploratory pass over the un-cleaned frame: head, schema, summary statistics,
# missingness, marginal distributions, top-location trends and correlations.
print("Top 5 Rows:")
print(raw_df.head())

# 1. Overview
print("\nDataset Info:")
# DataFrame.info() prints its report itself and returns None, so it must not be
# wrapped in print() (that would emit a stray "None" line).
raw_df.info()
print("\nSummary Statistics:")
print(raw_df.describe())

# 2. Missing Values
plt.figure(figsize=(12, 6))
sns.heatmap(raw_df.isnull(), cbar=False, cmap='viridis')
plt.title("Missing Values Heatmap")
plt.show()

# 3. Distributions
fig, axes = plt.subplots(2, 2, figsize=(15, 10))
sns.histplot(raw_df['new_cases'], bins=50, ax=axes[0, 0], log_scale=(False, True))
axes[0, 0].set_title("New Cases Distribution (Log Scale Y)")

# The remaining panels depend on optional columns and are left empty when the
# column is absent from the loaded file.
if 'new_deaths' in raw_df.columns:
    sns.histplot(raw_df['new_deaths'], bins=50, ax=axes[0, 1], log_scale=(False, True))
    axes[0, 1].set_title("New Deaths Distribution (Log Scale Y)")

if 'positive_rate' in raw_df.columns:
    sns.boxplot(x=raw_df['positive_rate'], ax=axes[1, 0])
    axes[1, 0].set_title("Positive Rate Boxplot")

if 'stringency_index' in raw_df.columns:
    sns.histplot(raw_df['stringency_index'], bins=50, ax=axes[1, 1])
    axes[1, 1].set_title("Stringency Index Distribution")
plt.tight_layout()
plt.show()

# 4. Temporal Trends (Top 5 Locations by total reported cases)
top_locs = raw_df.groupby('location')['new_cases'].sum().nlargest(5).index
plt.figure(figsize=(15, 6))
sns.lineplot(data=raw_df[raw_df['location'].isin(top_locs)], x='date', y='new_cases', hue='location')
plt.title("New Cases over Time (Top 5 Locations)")
plt.show()

# 5. Correlation Matrix (numeric columns only)
numeric_df = raw_df.select_dtypes(include=[np.number])
plt.figure(figsize=(12, 10))
sns.heatmap(numeric_df.corr(), annot=False, cmap='coolwarm', vmin=-1, vmax=1)
plt.title("Feature Correlation Matrix")
plt.show()

## 2. Cleaning & Scaling (Steps 5-7)

In [None]:
def clean_and_scale_data(df, train_cutoff="2021-12-31"):
    """Clean the engineered frame, encode static covariates and standardise features.

    Args:
        df: Frame with ``location``, ``date``, ``new_cases``, ``population``,
            the ``S``/``E``/``I``/``R`` columns and the four static covariates
            (``aged_65_older``, ``aged_70_older``, ``diabetes_prevalence``,
            ``cardiovasc_death_rate``). The input is not modified.
        train_cutoff: Last date (inclusive) whose rows are used to *fit* the
            temporal scalers; all rows are then transformed with them. The
            default matches the "Rising 3rd Wave" training end. Pass the
            ``train_end`` of the split being trained so that later rows do not
            influence the scaling statistics.

    Returns:
        Tuple ``(df, temporal_scalers, static_scaler, location_map)`` where
        ``temporal_scalers`` maps column name to its fitted ``StandardScaler``,
        ``static_scaler`` is the scaler fitted on the static covariates and
        ``location_map`` maps location name to its integer ``location_idx``.
    """
    # Step 5: Data Cleaning
    # 1. Sort
    df = df.sort_values(['location', 'date']).reset_index(drop=True)

    # 2. Fill (forward then backward) strictly within each location.
    # Chaining .bfill() after groupby(...).ffill() would run on the whole
    # column and could pull values across the boundary between two locations.
    fill_cols = ['stringency_index', 'positive_rate', 'new_tests', 'new_deaths']
    for col in fill_cols:
        if col in df.columns:
            df[col] = df.groupby('location')[col].transform(lambda s: s.ffill().bfill())
            # Locations with no observation at all fall back to the column's
            # global median (same strategy as the static covariates below).
            df[col] = df[col].fillna(df[col].median())

    # Fill new_cases NaNs with 0 (keep genuine 0s)
    df['new_cases'] = df['new_cases'].fillna(0)

    # Re-compute logs after fill to be consistent
    df['new_cases_log'] = np.log1p(df['new_cases'])
    if 'new_deaths' in df.columns:
        df['new_deaths_log'] = np.log1p(df['new_deaths'])
    if 'new_tests' in df.columns:
        df['new_tests_log'] = np.log1p(df['new_tests'])

    # 3. Drop locations < 60 rows
    counts = df['location'].value_counts()
    valid_locs = counts[counts >= 60].index
    df = df[df['location'].isin(valid_locs)].copy()

    # 4. Drop null population
    df = df.dropna(subset=['population'])

    # 5. Drop rows where new_cases_log is null
    df = df.dropna(subset=['new_cases_log'])

    # 6. Clip positive_rate
    if 'positive_rate' in df.columns:
        df['positive_rate'] = df['positive_rate'].clip(0, 1)

    # 7. Clip SEIR fractions to [0, 1]
    for col in ['S', 'E', 'I', 'R']:
        df[col] = df[col].clip(0, 1)

    # Step 6: Static Covariate Encoding
    static_vars = ['aged_65_older', 'aged_70_older', 'diabetes_prevalence', 'cardiovasc_death_rate']
    # One row per location holding the first non-null value of each covariate
    static_df = df.groupby('location')[static_vars].first().reset_index()
    # Impute Global Median
    for col in static_vars:
        static_df[col] = static_df[col].fillna(static_df[col].median())

    # Label Encode Location
    le = LabelEncoder()
    static_df['location_idx'] = le.fit_transform(static_df['location'])
    location_map = dict(zip(le.classes_, le.transform(le.classes_)))

    # Scale Static (StandardScaler)
    static_scaler = StandardScaler()
    static_df[static_vars] = static_scaler.fit_transform(static_df[static_vars])

    # Merge back (replaces the raw static columns and adds location_idx)
    df = df.drop(columns=static_vars)
    df = df.merge(static_df, on='location', how='left')

    # Step 7: Temporal Scaling
    temporal_vars = ['S', 'E', 'I', 'R', 'new_cases_log', 'new_deaths_log', 'new_tests_log', 'positive_rate', 'stringency_index']
    temporal_scalers = {}

    # Scalers are fitted on rows up to the training cutoff only, then applied
    # to every row.
    train_mask = df['date'] <= pd.to_datetime(train_cutoff)

    for col in temporal_vars:
        if col in df.columns:
            scaler = StandardScaler()
            scaler.fit(df.loc[train_mask, col].values.reshape(-1, 1))
            df[col] = scaler.transform(df[col].values.reshape(-1, 1)).flatten()
            temporal_scalers[col] = scaler

    # Add time_idx for TFT.
    # NOTE: this is a per-location row counter, so the same time_idx can refer
    # to different calendar dates in different locations.
    df['time_idx'] = df.groupby('location').cumcount()

    return df, temporal_scalers, static_scaler, location_map

# Clean/scale the raw frame once; later cells use `processed_df`, and the
# fitted scalers and location mapping are kept alongside it.
processed_df, temporal_scalers, static_scaler, loc_map = clean_and_scale_data(raw_df)
print(f"Processed Data Shape: {processed_df.shape}")

## 3. Creating TimeSeriesDataSet

In [None]:
# Named evaluation scenarios. Every entry carries the inclusive ISO dates that
# bound the training period and the validation/test windows that follow it.
SPLIT_CONFIGS = {
    "Primary": dict(
        train_end="2021-11-29",
        val_start="2021-11-30",
        val_end="2021-12-14",
        test_start="2021-12-15",
        test_end="2021-12-29",
    ),
    "Rising 3rd Wave": dict(
        train_end="2021-12-31",
        val_start="2022-01-01",
        val_end="2022-01-15",
        test_start="2022-01-16",
        test_end="2022-01-30",
    ),
    "Falling 3rd Wave": dict(
        train_end="2022-01-31",
        val_start="2022-02-01",
        val_end="2022-02-15",
        test_start="2022-02-16",
        test_end="2022-03-02",
    ),
    "Post 3rd Wave": dict(
        train_end="2022-02-28",
        val_start="2022-03-01",
        val_end="2022-03-15",
        test_start="2022-03-16",
        test_end="2022-03-30",
    ),
}

def create_tft_dataset(df, split_name="Rising 3rd Wave", target_variable="new_cases_log", max_prediction_length=7, max_encoder_length=14, batch_size=64):
    """Build the training dataset and train/validation dataloaders for one split.

    Args:
        df: Cleaned and scaled frame with ``date``, ``time_idx``, the group
            column (``location_idx`` if present, otherwise ``location``), the
            static covariates and the temporal features listed below.
        split_name: Key into ``SPLIT_CONFIGS``.
        target_variable: Column to forecast. It is placed first among the
            time-varying unknown reals and not listed twice if it is also one
            of the default features.
        max_prediction_length: Maximum decoder (forecast) length.
        max_encoder_length: Maximum encoder (history) length; the minimum is
            half of it.
        batch_size: Training batch size; validation uses twice this value.

    Returns:
        Tuple ``(training, train_dataloader, val_dataloader)``.

    Raises:
        KeyError: If ``split_name`` is not in ``SPLIT_CONFIGS`` or
            ``target_variable`` is not a column of ``df``.
        ValueError: If no rows fall on or before the split's ``train_end``.
    """
    print(f"Creating Dataset for split: {split_name}")
    config = SPLIT_CONFIGS[split_name]

    train_end = pd.to_datetime(config['train_end'])
    val_end = pd.to_datetime(config['val_end'])

    target_col = target_variable
    if target_col not in df.columns:
        raise KeyError(f"Target column not found in data: {target_col}")
    print(f"Target Variable: {target_col}")

    # Select training rows by calendar date. time_idx counts rows per location,
    # so filtering on its global maximum would admit rows dated after
    # train_end for locations whose series starts later than others.
    training_data = df[df['date'] <= train_end]
    if training_data.empty:
        raise ValueError(f"No rows on or before train_end ({train_end}) for split: {split_name}")
    training_cutoff = training_data['time_idx'].max()

    print(f"Training Cutoff Index: {training_cutoff} (Date: {train_end})")

    group_col = "location_idx" if "location_idx" in df.columns else "location"

    default_unknown_reals = ["S", "E", "I", "R", "positive_rate", "new_deaths_log", "new_tests_log"]
    unknown_reals = [target_col] + [c for c in default_unknown_reals if c != target_col]

    training = TimeSeriesDataSet(
        training_data,
        time_idx="time_idx",
        target=target_col,
        group_ids=[group_col],
        min_encoder_length=max_encoder_length // 2,
        max_encoder_length=max_encoder_length,
        min_prediction_length=1,
        max_prediction_length=max_prediction_length,
        static_reals=["aged_65_older", "aged_70_older", "diabetes_prevalence", "cardiovasc_death_rate"],
        time_varying_known_reals=["time_idx", "day_of_week_sin", "day_of_week_cos", "day_of_year_sin", "day_of_year_cos", "is_weekend", "stringency_index"],
        time_varying_unknown_reals=unknown_reals,
        target_normalizer=None,  # Already scaled
        add_relative_time_idx=True,
        add_target_scales=True,
        add_encoder_length=True,
    )

    # predict=True forecasts the last max_prediction_length steps of each
    # series, so the frame must end at val_end for those steps to lie in the
    # validation window rather than the test window. Locations with no
    # training rows are excluded because the training dataset has not seen
    # their group id.
    seen_groups = training_data[group_col].unique()
    validation_data = df[(df['date'] <= val_end) & df[group_col].isin(seen_groups)]
    validation = TimeSeriesDataSet.from_dataset(
        training, validation_data, predict=True, stop_randomization=True
    )

    train_dataloader = training.to_dataloader(train=True, batch_size=batch_size, num_workers=0)
    val_dataloader = validation.to_dataloader(train=False, batch_size=batch_size * 2, num_workers=0)

    return training, train_dataloader, val_dataloader

# Build the dataset and loaders used by the tuning and training cells below.
training_dataset, train_loader, val_loader = create_tft_dataset(
    processed_df, 
    split_name="Rising 3rd Wave", 
    target_variable="new_cases_log"
)

## 4. Hyperparameter Tuning with Optuna
Optimizing TFT hyperparameters (Hidden Size, Dropout, Attention Heads, Learning Rate) for both MSE and Quantile Loss.

In [None]:
import optuna
from optuna.integration import PyTorchLightningPruningCallback

def optimize_tft(trial, loss_metric, train_dl, val_dl, max_epochs=10):
    """Optuna objective: train one TFT configuration and return its validation loss.

    Args:
        trial: Optuna trial used to sample hyperparameters and report pruning.
        loss_metric: Loss instance the model is trained with (e.g. ``MSE()``).
        train_dl: Training dataloader; the model is configured from the
            dataset it wraps (``train_dl.dataset``).
        val_dl: Validation dataloader providing ``val_loss``.
        max_epochs: Upper bound on training epochs for this trial.

    Returns:
        The final ``val_loss`` logged by the trainer, as a float.
    """
    # Hyperparameters
    gradient_clip_val = trial.suggest_float("gradient_clip_val", 0.01, 1.0)
    hidden_size = trial.suggest_categorical("hidden_size", [16, 32, 64, 128])
    dropout = trial.suggest_float("dropout", 0.1, 0.5)
    hidden_continuous_size = trial.suggest_categorical("hidden_continuous_size", [8, 16, 32])
    attention_head_size = trial.suggest_categorical("attention_head_size", [1, 2, 4])
    learning_rate = trial.suggest_float("learning_rate", 1e-4, 1e-2, log=True)

    # Model: derive the architecture from the dataset behind the loader that is
    # actually passed in, rather than from a notebook-level global, so the
    # objective stays correct when called with loaders for another split.
    tft = TemporalFusionTransformer.from_dataset(
        train_dl.dataset,
        learning_rate=learning_rate,
        hidden_size=hidden_size,
        attention_head_size=attention_head_size,
        dropout=dropout,
        hidden_continuous_size=hidden_continuous_size,
        loss=loss_metric,
        log_interval=10,
        reduce_on_plateau_patience=4,
    )

    # Trainer
    trainer = pl.Trainer(
        max_epochs=max_epochs,
        accelerator="auto",
        gradient_clip_val=gradient_clip_val,
        callbacks=[
            EarlyStopping(monitor="val_loss", min_delta=1e-4, patience=3, verbose=False, mode="min"),
            PyTorchLightningPruningCallback(trial, monitor="val_loss")
        ],
        enable_checkpointing=False,  # Disable for tuning
        logger=False
    )

    trainer.fit(
        tft,
        train_dataloaders=train_dl,
        val_dataloaders=val_dl,
    )

    return trainer.callback_metrics["val_loss"].item()

# Two independent tuning runs (MSE and quantile loss), each followed by a
# longer final fit with the best hyperparameters found.
#
# study.best_params contains every sampled value, including
# "gradient_clip_val". That one is a Trainer option, not a model
# hyperparameter, so it is removed before the remaining values are forwarded
# to TemporalFusionTransformer.from_dataset and is passed to pl.Trainer only.

# --- Study 1: MSE ---
print("Starting Optuna Study for MSE Loss...")
study_mse = optuna.create_study(direction="minimize", pruner=optuna.pruners.MedianPruner())
study_mse.optimize(lambda trial: optimize_tft(trial, MSE(), train_loader, val_loader), n_trials=10)

print("Best MSE Params:", study_mse.best_params)

# Train Final MSE Model with Best Params
best_params_mse = study_mse.best_params
model_params_mse = {k: v for k, v in best_params_mse.items() if k != "gradient_clip_val"}
tft_mse = TemporalFusionTransformer.from_dataset(
    training_dataset,
    loss=MSE(),
    **model_params_mse
)
trainer_mse = pl.Trainer(max_epochs=15, accelerator="auto", gradient_clip_val=best_params_mse["gradient_clip_val"])
trainer_mse.fit(tft_mse, train_dataloaders=train_loader, val_dataloaders=val_loader)


# --- Study 2: Quantile ---
print("Starting Optuna Study for Quantile Loss...")
study_quantile = optuna.create_study(direction="minimize", pruner=optuna.pruners.MedianPruner())
study_quantile.optimize(lambda trial: optimize_tft(trial, QuantileLoss([0.1, 0.5, 0.9]), train_loader, val_loader), n_trials=10)

print("Best Quantile Params:", study_quantile.best_params)

# Train Final Quantile Model with Best Params
best_params_quantile = study_quantile.best_params
model_params_quantile = {k: v for k, v in best_params_quantile.items() if k != "gradient_clip_val"}
tft_quantile = TemporalFusionTransformer.from_dataset(
    training_dataset,
    loss=QuantileLoss([0.1, 0.5, 0.9]),
    **model_params_quantile
)
trainer_quantile = pl.Trainer(max_epochs=15, accelerator="auto", gradient_clip_val=best_params_quantile["gradient_clip_val"])
trainer_quantile.fit(tft_quantile, train_dataloaders=train_loader, val_dataloaders=val_loader)

In [None]:
# Plot the quantile model's raw forecasts for the first three validation series.
# The result is indexed by position instead of being unpacked into two names:
# position 0 is the raw output and position 1 the network input, which holds
# both for a plain (output, x) tuple and for a longer named-tuple result
# (NOTE(review): the exact return type depends on the installed
# pytorch-forecasting release — confirm against its predict() documentation).
prediction = tft_quantile.predict(val_loader, mode="raw", return_x=True)
raw_predictions, x = prediction[0], prediction[1]
for idx in range(3):
    tft_quantile.plot_prediction(x, raw_predictions, idx=idx, add_loss_to_title=True)
    plt.show()