# SwinUnet Wildfire Training & Test (Colab)

This notebook trains SwinUnet on a single fold of the WildfireSpreadTS HDF5 dataset, then evaluates on the held-out test split and reports Average Precision (AP) and F1.

**Prerequisites:**
- HDF5 dataset already on Google Drive (from `download_and_convert_dataset.ipynb`)
- **Runtime → Change runtime type → GPU** (T4 or better recommended)
- A [GitHub Personal Access Token](https://github.com/settings/tokens) stored as a Colab secret named `GITHUB_TOKEN` (Colab sidebar → key icon → Add new secret)

## 1. Mount Google Drive

In [19]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## 2. Configuration (user-editable)

In [20]:
REPO_ORG   = "amindell11"   # Replace with your GitHub username or organisation
REPO_NAME  = "wildfire-ts-swin"
HDF5_DIR   = "/content/drive/MyDrive/wildfire_dataset/hdf5"
OUTPUT_DIR = "/content/drive/MyDrive/wildfire_runs/fold0"

DATA_FOLD_ID             = 0      # 0–11; which train/val/test year split
N_LEADING_OBSERVATIONS   = 1      # 1 or 5
MAX_EPOCHS               = 100
BATCH_SIZE               = 16
BASE_LR                  = 1e-4
FOCAL_GAMMA              = 2.0
CROP_SIDE_LENGTH         = 128
SEED                     = 42
NUM_WORKERS              = 2      # keep low on Colab

## 3. Clone repo and install dependencies

In [None]:
from google.colab import userdata
_repo_url = f"https://github.com/amindell11/wildfire-ts-swin.git"
!rm -rf /content/wildfire-ts-swin
!git clone $_repo_url /content/wildfire-ts-swin
!pip install -q -r /content/wildfire-ts-swin/requirements.txt

Cloning into '/content/wildfire-ts-swin'...
remote: Enumerating objects: 64, done.[K
remote: Counting objects: 100% (64/64), done.[K
remote: Compressing objects: 100% (48/48), done.[K
remote: Total 64 (delta 14), reused 64 (delta 14), pack-reused 0 (from 0)[K
Receiving objects: 100% (64/64), 51.91 KiB | 1.79 MiB/s, done.
Resolving deltas: 100% (14/14), done.


In [22]:
import sys
sys.path.insert(0, f'/content/{REPO_NAME}')

## 4. Verify GPU

In [23]:
!nvidia-smi

/bin/bash: line 1: nvidia-smi: command not found


In [24]:
import torch
if not torch.cuda.is_available():
    print("WARNING: No CUDA device detected. Training will be slow or fail.")
else:
    print(f"GPU: {torch.cuda.get_device_name(0)}")



## 5. Train

In [25]:
import os
import random
import types
import numpy as np
import torch
import torch.backends.cudnn as cudnn

from config import get_config
from networks.vision_transformer import SwinUnet
from trainer_wildfire import trainer_wildfire
from datasets.wildfire import N_FEATURES_PER_TIMESTEP

cudnn.benchmark = True
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)

in_chans = N_LEADING_OBSERVATIONS * N_FEATURES_PER_TIMESTEP
extra_opts = [
    'MODEL.SWIN.IN_CHANS', str(in_chans),
    'MODEL.PRETRAIN_CKPT', 'None',
]

args = types.SimpleNamespace(
    data_dir=HDF5_DIR,
    output_dir=OUTPUT_DIR,
    n_leading_observations=N_LEADING_OBSERVATIONS,
    n_leading_observations_test_adjustment=N_LEADING_OBSERVATIONS,
    crop_side_length=CROP_SIDE_LENGTH,
    load_from_hdf5=True,
    data_fold_id=DATA_FOLD_ID,
    max_epochs=MAX_EPOCHS,
    batch_size=BATCH_SIZE,
    base_lr=BASE_LR,
    num_workers=NUM_WORKERS,
    eval_interval=1,
    seed=SEED,
    n_gpu=1,
    focal_gamma=FOCAL_GAMMA,
    cfg=f'/content/{REPO_NAME}/configs/swin_tiny_patch4_window4_128_wildfire.yaml',
    opts=extra_opts,
    zip=False,
    cache_mode='part',
    resume=None,
    accumulation_steps=None,
    use_checkpoint=False,
    amp_opt_level='O1',
    tag=None,
    eval=False,
    throughput=False,
)

config = get_config(args)
os.makedirs(OUTPUT_DIR, exist_ok=True)

net = SwinUnet(config, img_size=config.DATA.IMG_SIZE, num_classes=2).cuda()
print(f"Model in_chans={in_chans}  (n_leading_observations={N_LEADING_OBSERVATIONS} × 40 features)")
print(f"Model parameters: {sum(p.numel() for p in net.parameters()) / 1e6:.1f}M")

trainer_wildfire(args, net, OUTPUT_DIR)

ModuleNotFoundError: No module named 'datasets.wildfire'

## 6. Evaluate on test split

In [None]:
import numpy as np
from torch.utils.data import DataLoader
from tqdm import tqdm
from datasets.wildfire import WildfireDataset, get_year_split
from utils import compute_binary_metrics, compute_ap

train_years, val_years, test_years = get_year_split(DATA_FOLD_ID)

ckpt_path = f"{OUTPUT_DIR}/best_model.pth"
state_dict = torch.load(ckpt_path, map_location='cuda' if torch.cuda.is_available() else 'cpu')
if list(state_dict.keys())[0].startswith('module.'):
    state_dict = {k.replace('module.', ''): v for k, v in state_dict.items()}
net.load_state_dict(state_dict)
net.eval()

db_test = WildfireDataset(
    data_dir=HDF5_DIR,
    included_fire_years=test_years,
    is_train=False,
    stats_years=train_years,
    n_leading_observations=N_LEADING_OBSERVATIONS,
    n_leading_observations_test_adjustment=N_LEADING_OBSERVATIONS,
    crop_side_length=CROP_SIDE_LENGTH,
    load_from_hdf5=True,
)

test_loader = DataLoader(
    db_test, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=True,
)

all_probs, all_preds, all_gts = [], [], []
with torch.no_grad():
    for x_batch, y_batch in tqdm(test_loader, desc="Test"):
        x_batch = x_batch.cuda()
        logits = net(x_batch)
        probs = torch.softmax(logits, dim=1)[:, 1].cpu().numpy()
        preds = (probs >= 0.5).astype(np.int64)
        gts = y_batch.numpy()
        all_probs.append(probs.flatten())
        all_preds.append(preds.flatten())
        all_gts.append(gts.flatten())

all_probs = np.concatenate(all_probs)
all_preds = np.concatenate(all_preds)
all_gts = np.concatenate(all_gts)

metrics = compute_binary_metrics(all_preds, all_gts)
ap = compute_ap(all_probs, all_gts)

print("\n" + "="*40)
print("Test Results")
print("="*40)
print(f"Test AP   : {ap:.4f}")
print(f"Test F1   : {metrics['f1']:.4f}")
print(f"Precision : {metrics['precision']:.4f}")
print(f"Recall    : {metrics['recall']:.4f}")
print("="*40)

## 7. Note on full 12-fold evaluation

To reproduce the paper's mean ± std AP, run this notebook **12 times** with `DATA_FOLD_ID` set to 0, 1, 2, … 11 (or run a loop in a script). Each fold uses a different train/val/test year split. Average the 12 test AP values to get the reported metric.