# SwinUnet Wildfire Training & Test (Colab)

This notebook trains SwinUnet on a single fold of the WildfireSpreadTS HDF5 dataset, then evaluates on the held-out test split and reports Average Precision (AP) and F1.

**Prerequisites:**
- HDF5 dataset already on Google Drive (from `download_and_convert_dataset.ipynb`)
- **Runtime → Change runtime type → GPU or TPU** (T4/V100 GPU or TPU v2/v3 recommended)
- A [GitHub Personal Access Token](https://github.com/settings/tokens) stored as a Colab secret named `GITHUB_TOKEN` (Colab sidebar → key icon → Add new secret)

## 1. Mount Google Drive

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


## 2. Configuration (user-editable)

In [2]:
REPO_ORG   = "amindell11"   # Replace with your GitHub username or organisation
REPO_NAME  = "wildfire-ts-swin"
HDF5_DIR   = "/content/drive/MyDrive/wildfire_dataset/hdf5"
OUTPUT_DIR = "/content/drive/MyDrive/wildfire_runs/fold0"

DATA_FOLD_ID             = 0      # 0–11; which train/val/test year split
N_LEADING_OBSERVATIONS   = 1      # 1 or 5
MAX_EPOCHS               = 100
BATCH_SIZE               = 16
BASE_LR                  = 1e-4
FOCAL_GAMMA              = 2.0
CROP_SIDE_LENGTH         = 128
SEED                     = 42
NUM_WORKERS              = 4      # parallel data workers; 4 is safe on Colab

# Checkpointing
CHECKPOINT_INTERVAL = 10          # save a full training state every N epochs
RESUME_CHECKPOINT   = None        # set to a path to resume, e.g.:
                                  #   f"{OUTPUT_DIR}/checkpoint_epoch049.pth"
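
`FOCAL_GAMMA` is the focusing parameter of the focal loss. As a rough illustration only, and assuming the trainer uses the standard binary form FL(p_t) = -(1 - p_t)^γ · log(p_t) without class weighting (the repository's implementation may differ), the hypothetical helper `focal_loss_single` below shows how γ down-weights pixels the model already classifies confidently:

```python
import math

def focal_loss_single(p_fire: float, label: int, gamma: float = 2.0) -> float:
    """Focal loss for one pixel (illustrative helper, not the repo's code).

    p_fire: predicted probability of the "fire" class.
    label:  ground truth, 1 for fire and 0 for no fire.
    """
    # p_t is the probability the model assigned to the true class.
    p_t = p_fire if label == 1 else 1.0 - p_fire
    p_t = min(max(p_t, 1e-12), 1.0)  # avoid log(0)
    return -((1.0 - p_t) ** gamma) * math.log(p_t)

# A confidently correct pixel contributes far less than a badly wrong one.
easy = focal_loss_single(0.9, 1, gamma=2.0)
hard = focal_loss_single(0.1, 1, gamma=2.0)
print(f"easy={easy:.5f}  hard={hard:.5f}")
```

With `gamma=0.0` the expression reduces to ordinary cross-entropy, so raising γ shifts the loss toward hard pixels.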

## 2b. Copy HDF5 files to local disk

Google Drive I/O is ~100× slower than local NVMe. Copying the files once here keeps the GPU busy instead of waiting on Drive reads. The copy takes 1–3 min depending on dataset size and skips files that are already present.

In [3]:
import os, shutil, time
from concurrent.futures import ThreadPoolExecutor, as_completed

_LOCAL_HDF5 = "/content/hdf5"
_COPY_WORKERS = 8  # parallel Drive→local transfers

# Collect all (src, dst) pairs
_pairs = []
for _root, _dirs, _files in os.walk(HDF5_DIR):
    _rel = os.path.relpath(_root, HDF5_DIR)
    _dst_dir = os.path.join(_LOCAL_HDF5, _rel)
    os.makedirs(_dst_dir, exist_ok=True)
    for _f in _files:
        _src = os.path.join(_root, _f)
        _dst = os.path.join(_dst_dir, _f)
        if not os.path.exists(_dst):
            _pairs.append((_src, _dst))

if not _pairs:
    print(f"All files already present at {_LOCAL_HDF5}")
else:
    print(f"Copying {len(_pairs)} files with {_COPY_WORKERS} parallel workers...")
    _t0 = time.time()
    _done = 0
    with ThreadPoolExecutor(max_workers=_COPY_WORKERS) as _pool:
        _futs = {_pool.submit(shutil.copy2, s, d): (s, d) for s, d in _pairs}
        for _fut in as_completed(_futs):
            _fut.result()  # re-raises any exception
            _done += 1
            if _done % 20 == 0 or _done == len(_pairs):
                print(f"  {_done}/{len(_pairs)} files copied  ({time.time()-_t0:.0f}s elapsed)")
    print(f"Done in {time.time() - _t0:.0f}s")

HDF5_DIR = _LOCAL_HDF5  # all subsequent cells read from local disk
print(f"HDF5_DIR → {HDF5_DIR}")

  copied 2018/fire_21458798.hdf5
  copied 2018/fire_21458801.hdf5
  copied 2018/fire_21458806.hdf5
  copied 2018/fire_21458836.hdf5
  copied 2018/fire_21458848.hdf5
  copied 2018/fire_21459234.hdf5
  copied 2018/fire_21459239.hdf5
  copied 2018/fire_21459242.hdf5
  copied 2018/fire_21459249.hdf5
  copied 2018/fire_21459253.hdf5
  copied 2018/fire_21538827.hdf5
  copied 2018/fire_21615465.hdf5
  copied 2018/fire_21615469.hdf5
  copied 2018/fire_21617464.hdf5
  copied 2018/fire_21688910.hdf5
  copied 2018/fire_21688916.hdf5
  copied 2018/fire_21690064.hdf5
  copied 2018/fire_21690071.hdf5
  copied 2018/fire_21690073.hdf5
  copied 2018/fire_21690102.hdf5
  copied 2018/fire_21693353.hdf5
  copied 2018/fire_21748766.hdf5
  copied 2018/fire_21748783.hdf5
  copied 2018/fire_21748798.hdf5
  copied 2018/fire_21748801.hdf5
  copied 2018/fire_21748813.hdf5
  copied 2018/fire_21751303.hdf5
  copied 2018/fire_21751305.hdf5
  copied 2018/fire_21751315.hdf5
  copied 2018/fire_21804572.hdf5
  copied 2

## 3. Clone repo and install dependencies

In [4]:
from google.colab import userdata

# Authenticate the clone with the GITHUB_TOKEN Colab secret listed under Prerequisites.
_token = userdata.get('GITHUB_TOKEN')
_repo_url = f"https://{_token}@github.com/{REPO_ORG}/{REPO_NAME}.git"
!rm -rf /content/{REPO_NAME}
!git clone $_repo_url /content/{REPO_NAME}
!pip install -q -r /content/{REPO_NAME}/requirements.txt
!git -C /content/{REPO_NAME} log --oneline

Cloning into '/content/wildfire-ts-swin'...
remote: Enumerating objects: 100, done.
remote: Counting objects: 100% (100/100), done.
remote: Compressing objects: 100% (72/72), done.
remote: Total 100 (delta 39), reused 87 (delta 26), pack-reused 0 (from 0)
Receiving objects: 100% (100/100), 145.37 KiB | 5.38 MiB/s, done.
Resolving deltas: 100% (39/39), done.
ba2df2e (HEAD -> main, origin/main, origin/HEAD) fix HDF5 copy to handle year subdirectories
91fc34b copy HDF5 to local disk before training to avoid Drive I/O bottleneck
4400505 use tqdm.auto for proper Jupyter/Colab rendering; device-agnostic notebook
d3df338 update log clarity
62ac090 fixes
fc59794 add support for TPU / XLA
894aeda add notebooks
b9fc63a add init.py to recognize datasets as a packages
2e898e2 feat: add AP val and Focal loss for swin training
4aaa44f fea

In [5]:
import sys
sys.path.insert(0, f'/content/{REPO_NAME}')

## 4. Detect accelerator (GPU / TPU / CPU)

In [6]:
!nvidia-smi

Wed Feb 25 19:55:08 2026       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 580.82.07              Driver Version: 580.82.07      CUDA Version: 13.0     |
+-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  NVIDIA A100-SXM4-80GB          Off |   00000000:00:05.0 Off |                    0 |
| N/A   35C    P0             60W /  400W |    2384MiB /  81920MiB |      0%      Default |
|                                         |                        |             Disabled |
+-----------------------------------------+------------------------+----------------------+

+----------------------------------------------

In [7]:
import torch

# Detect the best available device: TPU (XLA) > GPU (CUDA) > CPU
try:
    import torch_xla.core.xla_model as xm
    DEVICE = xm.xla_device()
    print(f"TPU device detected: {DEVICE}")
except ImportError:
    if torch.cuda.is_available():
        DEVICE = torch.device('cuda')
        print(f"GPU: {torch.cuda.get_device_name(0)}")
    else:
        DEVICE = torch.device('cpu')
        print("WARNING: No GPU or TPU detected — running on CPU (will be slow).")

GPU: NVIDIA A100-SXM4-80GB


In [8]:
# Launch TensorBoard before training; refresh the board while training runs.
import os
os.makedirs(f"{OUTPUT_DIR}/log", exist_ok=True)
%load_ext tensorboard
%tensorboard --logdir {OUTPUT_DIR}/log

Reusing TensorBoard on port 6006 (pid 2149), started 1:18:55 ago. (Use '!kill 2149' to kill it.)

<IPython.core.display.Javascript object>

## 5. Train

In [9]:
import os
import random
import types
import numpy as np
import torch
import torch.backends.cudnn as cudnn

from config import get_config
from networks.vision_transformer import SwinUnet
from trainer_wildfire import trainer_wildfire
from datasets.wildfire import N_FEATURES_PER_TIMESTEP

# Seed every RNG the training loop touches so runs are repeatable.
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
    cudnn.benchmark = True          # let cuDNN auto-tune kernels for the fixed crop size
    torch.cuda.manual_seed(SEED)

in_chans = N_LEADING_OBSERVATIONS * N_FEATURES_PER_TIMESTEP
extra_opts = [
    'MODEL.SWIN.IN_CHANS', str(in_chans),
    'MODEL.PRETRAIN_CKPT', 'None',
]

args = types.SimpleNamespace(
    data_dir=HDF5_DIR,
    output_dir=OUTPUT_DIR,
    n_leading_observations=N_LEADING_OBSERVATIONS,
    n_leading_observations_test_adjustment=N_LEADING_OBSERVATIONS,
    crop_side_length=CROP_SIDE_LENGTH,
    load_from_hdf5=True,
    data_fold_id=DATA_FOLD_ID,
    max_epochs=MAX_EPOCHS,
    batch_size=BATCH_SIZE,
    base_lr=BASE_LR,
    num_workers=NUM_WORKERS,
    eval_interval=1,
    seed=SEED,
    n_gpu=1,
    focal_gamma=FOCAL_GAMMA,
    cfg=f'/content/{REPO_NAME}/configs/swin_tiny_patch4_window4_128_wildfire.yaml',
    opts=extra_opts,
    zip=False,
    cache_mode='part',
    resume=RESUME_CHECKPOINT,          # set in config cell to restart from a checkpoint
    checkpoint_interval=CHECKPOINT_INTERVAL,
    accumulation_steps=None,
    use_checkpoint=False,
    amp_opt_level='O1',
    tag=None,
    eval=False,
    throughput=False,
)

config = get_config(args)
os.makedirs(OUTPUT_DIR, exist_ok=True)

net = SwinUnet(config, img_size=config.DATA.IMG_SIZE, num_classes=2).to(DEVICE)
print(f"Model in_chans={in_chans}  (n_leading_observations={N_LEADING_OBSERVATIONS} × {N_FEATURES_PER_TIMESTEP} features)")
print(f"Model parameters: {sum(p.numel() for p in net.parameters()) / 1e6:.1f}M")



=> merge config from /content/wildfire-ts-swin/configs/swin_tiny_patch4_window4_128_wildfire.yaml
SwinTransformerSys expand initial----depths:[2, 2, 2, 2];depths_decoder:[1, 2, 2, 2];drop_path_rate:0.1;num_classes:2


  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]


---final upsample expand_first---
Model in_chans=40  (n_leading_observations=1 × 40 features)
Model parameters: 27.2M


Epochs:   0%|                                                     | 0/100 [00:00<?, ?ep/s]

  Train:   0%|                                                 | 0/309 [00:00<?, ?batch/s]

  Val  :   0%|                                                 | 0/256 [00:00<?, ?batch/s]

Exception ignored in: <function _MultiProcessingDataLoaderIter.__del__ at 0x7b7ee7f15da0>
Traceback (most recent call last):
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1707, in __del__
    self._shutdown_workers()
  File "/usr/local/lib/python3.12/dist-packages/torch/utils/data/dataloader.py", line 1690, in _shutdown_workers
    if w.is_alive():
       ^^^^^^^^^^^^
  File "/usr/lib/python3.12/multiprocessing/process.py", line 160, in is_alive
    assert self._parent_pid == os.getpid(), 'can only test a child process'
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
AssertionError: can only test a child process

## 5a. Load Weights from Drive (optional)

Run this cell to manually load model weights into `net` from any saved file on Drive.

- Use **`best_model.pth`** or a **`checkpoint_epoch*.pth`** path.
- A full `checkpoint_epoch*.pth` file contains `model_state`, `optimizer_state`, `scheduler_state`, `iter_num`, and `best_val_ap`. The trainer restores all of these automatically when `RESUME_CHECKPOINT` is set in the config cell above.
- Use this cell instead when you only want to swap the model weights (e.g. for evaluation) without touching the training state.

In [None]:
LOAD_WEIGHTS_FROM_PATH = None  # e.g. f"{OUTPUT_DIR}/best_model.pth"
                               #      f"{OUTPUT_DIR}/checkpoint_epoch049.pth"

if LOAD_WEIGHTS_FROM_PATH is not None:
    state = torch.load(LOAD_WEIGHTS_FROM_PATH, map_location=DEVICE)
    # Full checkpoint dicts store weights under 'model_state'; bare state dicts work too.
    if isinstance(state, dict) and 'model_state' in state:
        state = state['model_state']
    if next(iter(state)).startswith('module.'):
        state = {k[len('module.'):]: v for k, v in state.items()}
    net.load_state_dict(state)
    print(f"Weights loaded from: {LOAD_WEIGHTS_FROM_PATH}")
else:
    print("Skipped (LOAD_WEIGHTS_FROM_PATH is None).")

## 5b. Run Training

Runs the training loop.  
- **Fresh run:** leave `RESUME_CHECKPOINT = None` in the config cell and just execute this cell.  
- **Resume after interruption:** set `RESUME_CHECKPOINT` to a `checkpoint_epoch*.pth` path in the config cell, re-run the model-init cell above, then run this cell.  
- Periodic checkpoints (`checkpoint_epoch{N:03d}.pth`) are written to `OUTPUT_DIR` every `CHECKPOINT_INTERVAL` epochs and include full optimizer + scheduler state.
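
When resuming, the newest periodic checkpoint can be located programmatically rather than typed by hand. The sketch below assumes only the `checkpoint_epoch{N:03d}.pth` naming described above; `find_latest_checkpoint` is a hypothetical helper, not part of the repository, and the demo runs against a temporary directory with empty stand-in files.

```python
import re
import tempfile
from pathlib import Path
from typing import Optional

_CKPT_RE = re.compile(r"checkpoint_epoch(\d+)\.pth$")

def find_latest_checkpoint(output_dir: str) -> Optional[str]:
    """Return the checkpoint_epoch*.pth path with the highest epoch, or None."""
    root = Path(output_dir)
    if not root.is_dir():
        return None
    best_epoch, best_path = -1, None
    for path in root.glob("checkpoint_epoch*.pth"):
        match = _CKPT_RE.search(path.name)
        if match and int(match.group(1)) > best_epoch:
            best_epoch, best_path = int(match.group(1)), str(path)
    return best_path

# Demo on a throwaway directory containing empty stand-in files.
with tempfile.TemporaryDirectory() as tmp:
    for name in ("checkpoint_epoch009.pth", "checkpoint_epoch049.pth", "best_model.pth"):
        (Path(tmp) / name).touch()
    latest_name = Path(find_latest_checkpoint(tmp)).name
    print(latest_name)
```

In the config cell this could be used as `RESUME_CHECKPOINT = find_latest_checkpoint(OUTPUT_DIR)`, which leaves the value at `None` for a fresh run.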

In [None]:
trainer_wildfire(args, net, OUTPUT_DIR, device=DEVICE)

## 6. Evaluate on test split

In [1]:
import numpy as np
import torch
from torch.utils.data import DataLoader
from tqdm.auto import tqdm
from datasets.wildfire import WildfireDataset, get_year_split
from utils import compute_binary_metrics, compute_ap

train_years, val_years, test_years = get_year_split(DATA_FOLD_ID)

# Load the best-validation weights written by the trainer.
ckpt_path = f"{OUTPUT_DIR}/best_model.pth"
state_dict = torch.load(ckpt_path, map_location=DEVICE)
# Strip only the leading 'module.' prefix that nn.DataParallel adds to parameter names.
if next(iter(state_dict)).startswith('module.'):
    state_dict = {k[len('module.'):]: v for k, v in state_dict.items()}
net.load_state_dict(state_dict)
net.eval()
net.eval()

db_test = WildfireDataset(
    data_dir=HDF5_DIR,
    included_fire_years=test_years,
    is_train=False,
    stats_years=train_years,
    n_leading_observations=N_LEADING_OBSERVATIONS,
    n_leading_observations_test_adjustment=N_LEADING_OBSERVATIONS,
    crop_side_length=CROP_SIDE_LENGTH,
    load_from_hdf5=True,
)

test_loader = DataLoader(
    db_test, batch_size=BATCH_SIZE, shuffle=False,
    num_workers=NUM_WORKERS, pin_memory=(DEVICE.type == 'cuda'),  # also true for 'cuda:0'
)

all_probs, all_preds, all_gts = [], [], []
with torch.no_grad():
    for x_batch, y_batch in tqdm(test_loader, desc="Test"):
        x_batch = x_batch.to(DEVICE)
        logits = net(x_batch)
        probs = torch.softmax(logits, dim=1)[:, 1].cpu().numpy()
        preds = (probs >= 0.5).astype(np.int64)
        gts = y_batch.numpy()
        all_probs.append(probs.flatten())
        all_preds.append(preds.flatten())
        all_gts.append(gts.flatten())

all_probs = np.concatenate(all_probs)
all_preds = np.concatenate(all_preds)
all_gts = np.concatenate(all_gts)

metrics = compute_binary_metrics(all_preds, all_gts)
ap = compute_ap(all_probs, all_gts)

print("\n" + "="*40)
print("Test Results")
print("="*40)
print(f"Test AP   : {ap:.4f}")
print(f"Test F1   : {metrics['f1']:.4f}")
print(f"Precision : {metrics['precision']:.4f}")
print(f"Recall    : {metrics['recall']:.4f}")
print("="*40)

ModuleNotFoundError: No module named 'datasets.wildfire'
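
For readers unfamiliar with the reported metrics, here is a minimal pure-Python sketch of precision, recall, F1, and Average Precision on flattened pixel arrays. It is an illustration under stated assumptions, not the repository's `compute_binary_metrics` or `compute_ap`: the helper names `binary_prf` and `average_precision` are hypothetical, and the AP version assumes no tied scores.

```python
from typing import Dict, Sequence

def binary_prf(preds: Sequence[int], gts: Sequence[int]) -> Dict[str, float]:
    """Precision, recall and F1 for the positive (fire) class."""
    tp = sum(1 for p, g in zip(preds, gts) if p == 1 and g == 1)
    fp = sum(1 for p, g in zip(preds, gts) if p == 1 and g == 0)
    fn = sum(1 for p, g in zip(preds, gts) if p == 0 and g == 1)
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    denom = precision + recall
    f1 = 2 * precision * recall / denom if denom else 0.0
    return {"precision": precision, "recall": recall, "f1": f1}

def average_precision(probs: Sequence[float], gts: Sequence[int]) -> float:
    """AP as the mean precision-at-k over the ranks k of the positives.

    Assumes all scores are distinct; tied scores would need to be grouped.
    """
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    hits, total = 0, 0.0
    for rank, i in enumerate(order, start=1):
        if gts[i] == 1:
            hits += 1
            total += hits / rank
    return total / hits if hits else 0.0

# Toy example with four pixels (made-up numbers, not model output).
toy_probs = [0.9, 0.8, 0.7, 0.6]
toy_gts = [1, 0, 1, 0]
toy_preds = [int(p >= 0.5) for p in toy_probs]
toy_metrics = binary_prf(toy_preds, toy_gts)
toy_ap = average_precision(toy_probs, toy_gts)
```

F1 depends on the 0.5 threshold applied to the probabilities, whereas AP is computed from the ranking of the probabilities and is threshold-free.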

## 7. Note on full 12-fold evaluation

To reproduce the paper's mean ± std AP, run this notebook **12 times** with `DATA_FOLD_ID` set to 0, 1, 2, …, 11 (or loop over the folds in a script). Each fold uses a different train/val/test year split. The reported metric is the mean and standard deviation of the 12 test AP values.
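
Aggregating the per-fold results might look like the sketch below. `summarize_folds` is a hypothetical helper, the input values are placeholders rather than real results, and the use of the sample standard deviation is an assumption (the paper may use the population form).

```python
import statistics
from typing import Sequence, Tuple

def summarize_folds(fold_aps: Sequence[float]) -> Tuple[float, float]:
    """Return (mean, sample standard deviation) of per-fold test AP values."""
    mean = statistics.fmean(fold_aps)
    # stdev needs at least two values; a single fold has no spread to report.
    std = statistics.stdev(fold_aps) if len(fold_aps) > 1 else 0.0
    return mean, std

# Placeholder numbers only; substitute the 12 test AP values you measured.
example_aps = [0.30, 0.40, 0.50]
mean_ap, std_ap = summarize_folds(example_aps)
print(f"Test AP = {mean_ap:.3f} ± {std_ap:.3f}")
```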