# GridFM Quick Start Guide

This notebook demonstrates how to use GridFM for multi-task energy forecasting.

In [None]:
# Install GridFM in editable mode from the parent directory (if not already installed).
# Uncomment the line below to run it; the `%pip` magic installs into the
# environment of the running kernel.
# %pip install -e ..

In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from gridfm import GridFM, GridFMConfig
from gridfm.data.nyiso import NYISODataLoader

## 1. Load Data

First, let's download the NYISO data (the download is skipped if the data is already present) and load it.

In [None]:
# Point the loader at the local NYISO data directory and the date window to use
data_loader = NYISODataLoader(data_dir="../data/nyiso", start_date="2023-01-01", end_date="2023-03-31")

# Fetch the raw files; force=False asks the loader not to force a re-download
data_loader.download_data(force=False)

# Read the data and report which keys were loaded
data = data_loader.load_data()
print(f"Loaded data keys: {data.keys()}")

In [None]:
# Plot the first week of load for a few zones (only when load data is available)
if 'load' in data and len(data['load']) > 0:
    load_frame = data['load']
    fig, ax = plt.subplots(figsize=(12, 4))
    for zone in ['N.Y.C.', 'LONGIL', 'WEST']:
        # Skip zones that are absent from the loaded columns
        if zone not in load_frame.columns:
            continue
        # 288 five-minute intervals per day, 7 days
        ax.plot(load_frame[zone][:288 * 7], label=zone, alpha=0.7)
    ax.set(
        xlabel='Time (5-min intervals)',
        ylabel='Load (MW)',
        title='NYISO Load by Zone (1 Week)',
    )
    ax.legend()
    plt.show()

## 2. Create Dataset

In [None]:
# Windowing parameters, expressed in 5-minute steps
window_spec = {
    "sequence_length": 288,  # 24 hours of input history
    "forecast_horizon": 24,  # 2 hours to predict
    "stride": 12,  # consecutive samples start 1 hour apart
}

# Build the PyTorch dataset from the loaded data
dataset = data_loader.create_dataset(**window_spec)
print(f"Dataset size: {len(dataset)} samples")

# Inspect the first sample to see the tensor shapes
sample = dataset[0]
print(f"Input shape: {sample['input'].shape}")
print(f"Load target shape: {sample['load'].shape}")

## 3. Initialize Model

In [None]:
# Describe the model: architecture first, then tasks, then the physics term
config = GridFMConfig(
    # Architecture ("simple" backbone keeps the demo lightweight)
    backbone="simple",
    hidden_dim=128,
    num_freq_components=32,
    num_zones=11,
    # Forecasting tasks and horizon
    tasks=["load", "lbmp"],
    forecast_horizon=24,
    # Physics-related settings
    enable_power_balance=True,
    physics_weight=0.1,
)

# Build the model from the configuration and report its total parameter count
model = GridFM(config)
print(f"Model parameters: {sum(p.numel() for p in model.parameters()):,}")

## 4. Make Predictions

In [None]:
# Adjacency matrix supplied by the data loader
adjacency = data_loader.get_adjacency_matrix()

# Turn the single sample into a model input: add a leading batch dimension,
# then a trailing feature dimension
x = sample['input'].unsqueeze(0)
x = x.unsqueeze(-1)
print(f"Input shape: {x.shape}")

# Run inference in eval mode with gradient tracking disabled
model.eval()
with torch.no_grad():
    predictions = model(x, adjacency)

# Report the shape of every tensor-valued output; other entries are skipped
print("\nPrediction shapes:")
for task, pred in predictions.items():
    if not isinstance(pred, torch.Tensor):
        continue
    print(f"  {task}: {pred.shape}")

## 5. Visualize Predictions

In [None]:
# Plot predictions vs targets for a single zone, plus the error distribution
if 'load' in predictions:
    fig, axes = plt.subplots(1, 2, figsize=(14, 4))

    # Load prediction for one zone
    zone_idx = 0
    # detach() and cpu() make the NumPy conversion safe even when the tensors
    # carry autograd history (e.g. this cell is re-run after the training
    # cell) or live on an accelerator. For a CPU tensor produced under
    # torch.no_grad() they leave the values unchanged.
    pred_load = predictions['load'][0, :, zone_idx].detach().cpu().numpy()
    target_load = sample['load'][:, zone_idx].detach().cpu().numpy()

    # Left panel: target and prediction over the forecast horizon
    axes[0].plot(target_load, label='Target', linewidth=2)
    axes[0].plot(pred_load, label='Prediction', linewidth=2, linestyle='--')
    axes[0].set_xlabel('Time Step (5 min)')
    axes[0].set_ylabel('Normalized Load')
    axes[0].set_title(f'Load Forecast - Zone {zone_idx}')
    axes[0].legend()

    # Right panel: error distribution, with a reference line at zero error
    error = pred_load - target_load
    axes[1].hist(error, bins=20, edgecolor='black')
    axes[1].set_xlabel('Prediction Error')
    axes[1].set_ylabel('Frequency')
    axes[1].set_title('Error Distribution')
    axes[1].axvline(x=0, color='r', linestyle='--')

    plt.tight_layout()
    plt.show()

## 6. Training (Mini Example)

In [None]:
from torch.utils.data import DataLoader
from torch.optim import Adam

# Settings for this mini training run, collected here so they are easy to tune
NUM_EPOCHS = 3
BATCH_SIZE = 8
LEARNING_RATE = 1e-4
SHUFFLE_SEED = 0

# Create dataloader. A dedicated, seeded generator makes the shuffle order
# repeatable across Restart & Run All without touching the global RNG state.
train_dataloader = DataLoader(
    dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)

# Fail with a clear message instead of a ZeroDivisionError when averaging below
if len(train_dataloader) == 0:
    raise ValueError("Training dataloader is empty; nothing to train on.")

# Setup optimizer
optimizer = Adam(model.parameters(), lr=LEARNING_RATE)

# Training loop (mini example)
model.train()
losses = []

for epoch in range(NUM_EPOCHS):
    epoch_loss = 0.0
    for batch in train_dataloader:
        # Batch-specific names are used so that the notebook-level `x` and
        # `predictions` from the inference cell are not overwritten.
        batch_x = batch['input'].unsqueeze(-1)  # add trailing feature dim
        targets = {k: batch[k] for k in ['load', 'lbmp'] if k in batch}

        optimizer.zero_grad()
        batch_predictions = model(batch_x, adjacency)
        # compute_loss returns a pair; only the first element (the loss to
        # backpropagate) is used here.
        loss, _ = model.compute_loss(batch_predictions, targets, adjacency)
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()

    # Mean loss over the batches of this epoch
    avg_loss = epoch_loss / len(train_dataloader)
    losses.append(avg_loss)
    print(f"Epoch {epoch+1}/{NUM_EPOCHS}, Loss: {avg_loss:.4f}")

# Plot loss per epoch (1-based epochs, matching the printed log)
fig, ax = plt.subplots()
ax.plot(range(1, len(losses) + 1), losses, marker='o')
ax.set_xlabel('Epoch')
ax.set_ylabel('Loss')
ax.set_title('Training Loss')
plt.show()

## Next Steps

- Try loading a pre-trained model with `GridFM.from_pretrained()`
- Experiment with different configurations
- Train on the full dataset using `scripts/train.py`
- Explore the explainability features with SHAP