# 🌦️ Training Weather Forecasting Models

This notebook demonstrates the complete training pipeline for **Graph Neural Network (GNN)** and **Spherical Fourier Neural Operator (SFNO)** models using the NCEP weather dataset. It includes:

- Data loading and preprocessing
- Configuration setup
- Training for both GNN and SFNO models
- Performance visualization
- Saving final model checkpoints

In [None]:
# 📦 Import Dependencies
import os
import sys
sys.path.append("../")  # Ensure access to parent-level modules

import matplotlib.pyplot as plt
import torch

# Custom modules
from data.ncep_dataloader import create_ncep_dataloader
from src.model import *

## ⚙️ Configuration
Set up all configuration parameters for data, models, training, and logging.

In [None]:
# Define configuration dictionary
config = {
    # Dataset
    "data_dir": "../data/raw",
    "variables": ['air.2m.gauss.2024', 'uwnd.10m.gauss.2024', 'vwnd.10m.gauss.2024', 'slp.2024'],  # 2 m air temperature, 10 m wind (u, v), sea level pressure
    "history_steps": 3,
    "forecast_steps": 1,
    "batch_size": 32,
    "use_graph": True,
    "num_workers": 4,

    # Model
    "hidden_channels": 128,
    "num_layers": 3,
    "K": 5,
    "lmax": 16,

    # Training
    "epochs": 50,
    "learning_rate": 1e-3,
    "weight_decay": 1e-5,

    # Logging / Saving
    "checkpoint_dir": "../results/checkpoints",
    "logs_dir": "../results/logs"
}

# Ensure directories exist
os.makedirs(config["checkpoint_dir"], exist_ok=True)
os.makedirs(config["logs_dir"], exist_ok=True)
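
A quick sanity check on the configuration can catch typos before a long training run starts. The sketch below is illustrative only: `validate_config` is a hypothetical helper, not part of this project, and the rules it enforces are assumptions about what sensible values look like.

```python
def validate_config(cfg):
    """Return a list of human-readable problems found in a config dict."""
    problems = []
    # Keys that must be positive integers for windowing and batching to make sense.
    for key in ("history_steps", "forecast_steps", "batch_size", "epochs"):
        value = cfg.get(key)
        if not isinstance(value, int) or isinstance(value, bool) or value <= 0:
            problems.append(f"{key} must be a positive integer, got {value!r}")
    # The learning rate must be a positive number.
    lr = cfg.get("learning_rate")
    if not isinstance(lr, (int, float)) or isinstance(lr, bool) or lr <= 0:
        problems.append(f"learning_rate must be positive, got {lr!r}")
    # At least one variable is needed to build a sample.
    if not cfg.get("variables"):
        problems.append("variables must be a non-empty list")
    return problems


# Stand-in config with the same key names as the dictionary above.
example = {"history_steps": 3, "forecast_steps": 1, "batch_size": 32,
           "epochs": 50, "learning_rate": 1e-3, "variables": ["slp.2024"]}
print(validate_config(example))  # an empty list means no problems were found
```

In the notebook, the same check could be applied to `config` directly, for example by raising an error when the returned list is not empty.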

## 📥 Load NCEP Weather Data
Load training, validation, and test dataloaders using the provided `create_ncep_dataloader()` utility.

In [None]:
train_loader, val_loader, test_loader = create_ncep_dataloader(
    config["data_dir"],
    batch_size=config["batch_size"],
    variables=config["variables"],
    history_steps=config["history_steps"],
    forecast_steps=config["forecast_steps"],
    use_graph=config["use_graph"],
    num_workers=config["num_workers"]
)

print(f"Train samples: {len(train_loader.dataset)}")
print(f"Validation samples: {len(val_loader.dataset)}")
print(f"Test samples: {len(test_loader.dataset)}")
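
To make the roles of `history_steps` and `forecast_steps` concrete, here is a minimal sliding-window sketch in plain Python. It is an assumption about how samples are typically built from a time series, not the actual logic inside `create_ncep_dataloader()`; `make_windows` is a hypothetical name used only for illustration.

```python
def make_windows(series, history_steps, forecast_steps):
    """Split a time-ordered sequence into (history, target) pairs."""
    samples = []
    window = history_steps + forecast_steps
    # Slide one step at a time; stop when a full window no longer fits.
    for start in range(len(series) - window + 1):
        history = series[start:start + history_steps]
        target = series[start + history_steps:start + window]
        samples.append((history, target))
    return samples


time_steps = list(range(6))  # stand-in for six consecutive time steps
pairs = make_windows(time_steps, history_steps=3, forecast_steps=1)
for history, target in pairs:
    print(history, "->", target)
```

With the values used in this notebook (three history steps, one forecast step), each sample pairs three consecutive time steps with the single step that follows them.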

## 🧠 Train GNN Model
Run the training script for the GNN model. Ensure that the model architecture and training logic are implemented in `src/train/train_gnn.py`.

In [None]:
%run ../src/train/train_gnn.py

## 🌀 Train SFNO Model
Execute the SFNO model training using its respective script.

In [None]:
%run ../src/train/train_sfno.py

## 📊 Compare Training Performance
Extract TensorBoard logs and visualize loss curves for both GNN and SFNO models to assess convergence and generalization.

In [None]:
import tensorflow as tf
import pandas as pd

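def get_tensorboard_data(log_path, tag):
    """Collect (step, value) pairs for a scalar tag from all event files in log_path."""
    data = []
    # Sort so that event files are read in a deterministic order
    for event_file in sorted(os.listdir(log_path)):
        if event_file.startswith('events.'):
            full_path = os.path.join(log_path, event_file)
            for e in tf.compat.v1.train.summary_iterator(full_path):
                for v in e.summary.value:
                    if v.tag == tag:
                        data.append((e.step, v.simple_value))
    # Order by step so that curves plot left to right
    df = pd.DataFrame(data, columns=['step', 'value'])
    return df.sort_values('step').reset_index(drop=True)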

In [None]:
# Load training/validation losses for both models
gnn_train_loss = get_tensorboard_data(os.path.join(config["logs_dir"], "gnn_model"), "Loss/train")
gnn_val_loss = get_tensorboard_data(os.path.join(config["logs_dir"], "gnn_model"), "Loss/val")

sfno_train_loss = get_tensorboard_data(os.path.join(config["logs_dir"], "sfno_model"), "Loss/train")
sfno_val_loss = get_tensorboard_data(os.path.join(config["logs_dir"], "sfno_model"), "Loss/val")
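
Logged losses can be noisy, which makes trends harder to read. One common option is to smooth them with an exponential moving average before plotting. The sketch below works on a plain list of values; `ema_smooth` and the sample numbers are hypothetical and not taken from the training logs.

```python
def ema_smooth(values, alpha=0.3):
    """Exponential moving average; a higher alpha tracks the raw values more closely."""
    smoothed = []
    for value in values:
        if not smoothed:
            smoothed.append(value)  # seed with the first observation
        else:
            smoothed.append(alpha * value + (1 - alpha) * smoothed[-1])
    return smoothed


raw = [1.0, 0.5, 0.9, 0.4, 0.6]  # made-up loss values
print(ema_smooth(raw, alpha=0.5))
```

For the DataFrames loaded above, the input could be obtained with something like `gnn_train_loss['value'].tolist()`.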

In [None]:
# 📉 Plot training/validation loss curves
plt.figure(figsize=(12, 8))

plt.subplot(2, 1, 1)
plt.plot(gnn_train_loss['step'], gnn_train_loss['value'], label='GNN Train')
plt.plot(gnn_val_loss['step'], gnn_val_loss['value'], label='GNN Validation')
plt.title('GNN Model Training Curves')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

plt.subplot(2, 1, 2)
plt.plot(sfno_train_loss['step'], sfno_train_loss['value'], label='SFNO Train')
plt.plot(sfno_val_loss['step'], sfno_val_loss['value'], label='SFNO Validation')
plt.title('SFNO Model Training Curves')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)

plt.tight_layout()
plt.savefig("../results/model_training_comparison.png")
plt.show()
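
Alongside the plots, a short numeric summary can help when assessing convergence and generalization. The following sketch finds the step with the lowest validation loss and the gap to the training loss at that step. `summarize_run` is a hypothetical helper and the loss values are invented for illustration.

```python
def summarize_run(train, val):
    """Summarize a run from lists of (step, loss) pairs.

    Returns the step with the lowest validation loss, that loss, and the
    validation-minus-training gap at that step (None if the step has no
    training entry).
    """
    best_step, best_val = min(val, key=lambda pair: pair[1])
    train_by_step = dict(train)
    gap = None
    if best_step in train_by_step:
        gap = best_val - train_by_step[best_step]
    return {"best_step": best_step, "best_val": best_val, "gap": gap}


train = [(0, 1.00), (1, 0.60), (2, 0.40), (3, 0.30)]  # made-up values
val = [(0, 1.10), (1, 0.70), (2, 0.55), (3, 0.60)]    # made-up values
print(summarize_run(train, val))
```

A validation loss that rises while the training loss keeps falling, as in the last step of this toy data, is a typical sign of overfitting. The DataFrames from the earlier cell could be converted with `list(zip(df['step'], df['value']))`.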

## 💾 Save Final Models
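Persist the trained weights of both models for later use in evaluation or deployment. This cell assumes that the training scripts define `gnn_model` and `sfno_model` as global variables, since `%run` copies a script's globals into the notebook namespace once the script finishes.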

In [None]:
torch.save(gnn_model.state_dict(), os.path.join(config["checkpoint_dir"], "final_gnn_model.pt"))
torch.save(sfno_model.state_dict(), os.path.join(config["checkpoint_dir"], "final_sfno_model.pt"))
print("✅ Final models saved successfully.")