In [1]:
# %% Setup & Imports

import sys
import os
import pandas as pd
import numpy as np
import json
import matplotlib.pyplot as plt
from pathlib import Path
from datetime import datetime
import logging
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, Subset
from tqdm import tqdm
import math
import shutil # For checkpoint saving

# Add the project root to the Python path so that `src.*` imports resolve.
# Assumes the notebook's working directory is the project's 'notebooks' subdir.
project_root = Path(os.getcwd()).parent
# Keep the root as the first sys.path entry, dropping any copy left behind by an
# earlier run of this cell so re-execution does not accumulate duplicates.
# Slice assignment mutates the existing list instead of rebinding sys.path.
sys.path[:] = [str(project_root)] + [
    entry for entry in sys.path if entry != str(project_root)
]

# --- Import Model Components ---
from src.model import UHINetCNN # Import the CNN model
from src.ingest.dataloader_cnn import CityDataSet # Import the corresponding dataloader

# --- Import Training Utilities & Loss --- ### MODIFIED ###
from src.train.loss import masked_mse_loss, masked_mae_loss # Import loss functions
import src.train.train_utils as train_utils # Import the new utility module

# Configure logging: INFO and above, each record stamped with time and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Weights & Biases is optional. Start from None and let a successful import
# rebind the name, so later cells can simply test `if wandb:`.
wandb = None
try:
    import wandb
except ImportError:
    print("wandb not installed, skipping W&B logging.")

# Import necessary metrics
from sklearn.metrics import mean_squared_error, r2_score

In [2]:
### UNCOMMENT ON FIRST RUN
#!wget -q https://huggingface.co/made-with-clay/Clay/resolve/main/v1.5/clay-v1.5.ckpt

## Configuration

Set up paths and hyperparameters.

In [3]:
# %% Configuration / Hyperparameters (CNN Model + Common Resampling)

# --- Import utils ---
from src.train.train_utils import check_path
# -------------------

# --- Paths & Basic Info ---
city_name = "NYC"
project_root_str = os.fspath(project_root)  # plain string form for the config dict
data_dir_base = project_root.joinpath("data")
output_dir_base = project_root.joinpath("training_runs")

# --- WANDB Config ---
wandb_project_name = "MLC_UHI_Proj"
wander_run_name_prefix = f"{city_name}_UHINetCNN_CommonRes"  # prefix for run names

# --- Data Loading Config ---
# Common resolution (metres) for the spatial features entering the model.
# 10 m is the starting point (matches Clay/UHI grid).
feature_resolution_m = 10

# Resolution (metres) of the UHI grid, kept separate because it is used for the
# final target matching. Assumed to be 10 m.
uhi_grid_resolution_m = 10

# Input data locations, all expressed relative to the project root.
relative_data_dir = Path("data")
relative_uhi_csv = relative_data_dir.joinpath(city_name, "uhi.csv")
relative_bronx_weather_csv = relative_data_dir.joinpath(city_name, "bronx_weather.csv")
relative_manhattan_weather_csv = relative_data_dir.joinpath(city_name, "manhattan_weather.csv")
relative_dem_path = relative_data_dir.joinpath(city_name, "sat_files", "nyc_dem_1m_pc.tif")
relative_dsm_path = relative_data_dir.joinpath(city_name, "sat_files", "nyc_dsm_1m_pc.tif")
relative_cloudless_mosaic_path = relative_data_dir.joinpath(
    city_name, "sat_files", f"sentinel_{city_name}_20210601_to_20210901_cloudless_mosaic.npy"
)
relative_single_lst_median_path = relative_data_dir.joinpath(
    city_name, "sat_files", f"lst_{city_name}_median_20210601_to_20210901.npy"
)

# Sentinel values marking "no data" in the elevation and LST rasters.
elevation_nodata = -9999.0
lst_nodata = 0.0

# --- Feature Selection Flags --- #
# One switch per optional input feature; flip a value to include or exclude it.
feature_flags = dict(
    use_dem=True,
    use_dsm=True,
    use_clay=True,
    use_sentinel_composite=False,
    use_lst=False,
    use_ndvi=False,
    use_ndbi=False,
    use_ndwi=False,
)

# --- Bands for Sentinel Composite --- #
sentinel_bands_to_load = [
    "blue",
    "green",
    "red",
    "nir",
    "swir16",
    "swir22",
]

# --- Model Config (UHINetCNN) ---
# Clay Backbone (only relevant when feature_flags["use_clay"] is enabled)
clay_model_size = "large"
clay_bands = ["blue", "green", "red", "nir"]
clay_platform = "sentinel-2-l2a"
clay_gsd = 10
freeze_backbone = True
# Kept as a Path, like every other relative_* location in this cell, so that
# check_path receives a consistent argument type. The checkpoint itself is the
# file fetched by the commented-out wget cell near the top of the notebook.
relative_clay_checkpoint_path = Path("notebooks") / "clay-v1.5.ckpt"
relative_clay_metadata_path = Path("src") / "Clay" / "configs" / "metadata.yaml"

# Weather input channels (even though it's not a sequence)
weather_channels = 6

# U-Net Backbone Config
unet_base_channels = 64
unet_depth = 4

# --- REMOVED Elevation Branch Config ---

# --- Training Hyperparameters ---
# Optimisation settings
lr = 0.00005
weight_decay = 0.01
loss_type = 'mse'

# Schedule: upper bound on epochs and the early-stopping patience
epochs = 500
patience = 50

# Data loading
num_workers = 4
# V100 Suggestion: Target batch size 4 => n_train_batches = 36 / 4 = 9
n_train_batches = 9

# Set to True to force CPU execution even when CUDA is available.
cpu = False

# --- Device Setup ---
if not torch.cuda.is_available() or cpu:
    device = torch.device("cpu")
else:
    device = torch.device("cuda")
print(f"Using device: {device}")

# --- Resolve Paths --- #
# Turn every relative input path into its absolute_* counterpart via check_path,
# passing project_root and a human-readable label for each file.
# NOTE(review): check_path is defined in src/train/train_utils.py and is not
# visible here. Presumably it joins the path onto project_root and raises when a
# required file is missing; the config cell guards several results with
# `... if absolute_x else None`, so it may also return None — confirm.
absolute_uhi_csv = check_path(relative_uhi_csv, project_root, "UHI CSV")
absolute_bronx_weather_csv = check_path(relative_bronx_weather_csv, project_root, "Bronx Weather CSV")
absolute_manhattan_weather_csv = check_path(relative_manhattan_weather_csv, project_root, "Manhattan Weather CSV")
absolute_dem_path = check_path(relative_dem_path, project_root, "DEM TIF")
absolute_dsm_path = check_path(relative_dsm_path, project_root, "DSM TIF")
# NOTE(review): the Clay checkpoint/metadata and the cloudless mosaic are checked
# unconditionally, whereas the LST path below ties should_exist to its feature
# flag. Verify whether these should likewise depend on feature_flags (e.g.
# "use_clay") so that disabled features do not demand their files.
absolute_clay_checkpoint_path = check_path(relative_clay_checkpoint_path, project_root, "Clay Checkpoint")
absolute_clay_metadata_path = check_path(relative_clay_metadata_path, project_root, "Clay Metadata")
absolute_cloudless_mosaic_path = check_path(relative_cloudless_mosaic_path, project_root, "Cloudless Mosaic")
# The LST median is only flagged as required to exist when the LST feature is enabled.
absolute_single_lst_median_path = check_path(relative_single_lst_median_path, project_root, "Single LST Median", should_exist=feature_flags["use_lst"])

# --- Calculate Bounds --- #
# The bounding box of all UHI observations, laid out as
# [min Longitude, min Latitude, max Longitude, max Latitude].
required_cols = ['Longitude', 'Latitude']
uhi_df = pd.read_csv(absolute_uhi_csv)
if any(col not in uhi_df.columns for col in required_cols):
    raise ValueError(f"UHI CSV must contain columns: {required_cols}")
# Outer loop picks the extreme (min first, then max); inner loop walks the
# coordinate columns, which yields the four values in the order described above.
bounds = [
    getattr(uhi_df[col], extreme)()
    for extreme in ("min", "max")
    for col in required_cols
]
print(f"Loaded bounds from {absolute_uhi_csv.name}: {bounds}")

# --- Central Config Dictionary --- #
# Single source of truth for the run: later cells read their settings from here
# (and add to it, e.g. the UHI statistics computed from the training split).
# Path-like values are stored as strings so the dictionary can be dumped as JSON.
config = {
    # Paths & Info
    "model_type": "UHINetCNN_CommonRes",
    "project_root": project_root_str,
    "city_name": city_name,
    "wandb_project_name": wandb_project_name,
    "wander_run_name_prefix": wander_run_name_prefix,
    # Data Loading
    "feature_resolution_m": feature_resolution_m,
    "uhi_grid_resolution_m": uhi_grid_resolution_m,
    # NOTE: recorded as the relative path, unlike the absolute weather paths below.
    "uhi_csv": str(relative_uhi_csv),
    "bronx_weather_csv": str(absolute_bronx_weather_csv),
    "manhattan_weather_csv": str(absolute_manhattan_weather_csv),
    "bounds": bounds,
    "feature_flags": feature_flags,
    "sentinel_bands_to_load": sentinel_bands_to_load,
    "dem_path": str(absolute_dem_path) if absolute_dem_path else None,
    "dsm_path": str(absolute_dsm_path) if absolute_dsm_path else None,
    "elevation_nodata": elevation_nodata,
    "cloudless_mosaic_path": str(absolute_cloudless_mosaic_path) if absolute_cloudless_mosaic_path else None,
    "single_lst_median_path": str(absolute_single_lst_median_path) if absolute_single_lst_median_path else None,
    "lst_nodata": lst_nodata,
    # Model Config
    "weather_channels": weather_channels,
    "unet_base_channels": unet_base_channels,
    "unet_depth": unet_depth,
    # Clay specific (paths are only recorded when the Clay feature is enabled)
    "clay_model_size": clay_model_size,
    "clay_bands": clay_bands,
    "clay_platform": clay_platform,
    "clay_gsd": clay_gsd,
    "freeze_backbone": freeze_backbone,
    "clay_checkpoint_path": str(absolute_clay_checkpoint_path) if feature_flags["use_clay"] else None,
    "clay_metadata_path": str(absolute_clay_metadata_path) if feature_flags["use_clay"] else None,
    # Training Hyperparameters
    "n_train_batches": n_train_batches,
    "num_workers": num_workers,
    "epochs": epochs,
    "lr": lr,
    "weight_decay": weight_decay,
    "loss_type": loss_type,
    "patience": patience,
    "device": str(device)
}

print("\nUHINetCNN (Common Res) Configuration dictionary created:")
# default=str stringifies any value json cannot encode natively (Path,
# torch.device, ...). A handler that hands unknown objects back unchanged makes
# json abort with a misleading "Circular reference detected" error instead.
print(json.dumps(config, indent=2, default=str))

Using device: cuda
Loaded bounds from uhi.csv: [-73.99445667, 40.75879167, -73.87945833, 40.85949667]

Configuration dictionary created:
{
  "model_type": "UHINetCNN",
  "project_root": "/home/jupyter/MLC-Project",
  "city_name": "NYC",
  "wandb_project_name": "MLC_UHI_Proj",
  "wander_run_name_prefix": "NYC_UHINetCNN",
  "resolution_m": 10,
  "include_lst": true,
  "uhi_csv": "data/NYC/uhi.csv",
  "bronx_weather_csv": "/home/jupyter/MLC-Project/data/NYC/bronx_weather.csv",
  "manhattan_weather_csv": "/home/jupyter/MLC-Project/data/NYC/manhattan_weather.csv",
  "cloudless_mosaic_path": "/home/jupyter/MLC-Project/data/NYC/sat_files/sentinel_NYC_20210601_to_20210901_cloudless_mosaic.npy",
  "single_lst_median_path": "/home/jupyter/MLC-Project/data/NYC/sat_files/lst_NYC_median_20210601_to_20210901.npy",
  "bounds": [
    -73.99445667,
    40.75879167,
    -73.87945833,
    40.85949667
  ],
  "clay_model_size": "large",
  "clay_bands": [
    "blue",
    "green",
    "red",
    "nir"
  ],
 

## Setup DataLoader

In [4]:
# %% Data Loading and Preprocessing (CNN Model + Common Resampling)

# --- Import utils ---
from src.train.train_utils import (
    split_data,
    calculate_uhi_stats,
    create_dataloaders
)
# -------------------

print("Initializing CityDataSet...")
try:
    dataset = CityDataSet(
        # Absolute input paths resolved in the configuration cell
        uhi_csv=absolute_uhi_csv,
        bronx_weather_csv=absolute_bronx_weather_csv,
        manhattan_weather_csv=absolute_manhattan_weather_csv,
        data_dir=project_root_str,
        # Remaining arguments are forwarded from `config` under the same name
        **{
            key: config[key]
            for key in (
                "bounds",
                "feature_resolution_m",
                "uhi_grid_resolution_m",
                "city_name",
                "feature_flags",
                "sentinel_bands_to_load",
                "dem_path",
                "dsm_path",
                "elevation_nodata",
                "cloudless_mosaic_path",
                "single_lst_median_path",
                "lst_nodata",
            )
        },
    )
except Exception as e:
    # Report what went wrong, then re-raise so the notebook stops here.
    if isinstance(e, FileNotFoundError):
        print(f"Dataset initialization failed: {e}")
        print("Ensure required data files (DEM, DSM, weather, UHI, potentially mosaic/LST) exist.")
        print("Run `notebooks/download_data.ipynb` first.")
    else:
        print(f"Unexpected error during dataset initialization: {e}")
    raise

# --- Train/Val Split --- #
# 40% of the samples are held out for validation; the seed is fixed at 42.
train_ds, val_ds = split_data(dataset, val_percent=0.4, seed=42)

# --- UHI Mean and Std from Training Data ONLY --- #
# Computed on the training split alone and stored in the config for later cells.
uhi_mean, uhi_std = calculate_uhi_stats(train_ds)
config.update(uhi_mean=uhi_mean, uhi_std=uhi_std)

# --- Create DataLoaders --- #
train_loader, val_loader = create_dataloaders(
    train_ds,
    val_ds,
    device=device,  # device chosen in the configuration cell
    n_train_batches=config['n_train_batches'],
    num_workers=config['num_workers'],
)


2025-04-30 02:20:01,176 - INFO - Loading cloudless mosaic from /home/jupyter/MLC-Project/data/NYC/sat_files/sentinel_NYC_20210601_to_20210901_cloudless_mosaic.npy
2025-04-30 02:20:01,254 - INFO - Loaded mosaic with 4 bands and shape (1119, 1278)
2025-04-30 02:20:01,290 - INFO - Loading single LST median from: /home/jupyter/MLC-Project/data/NYC/sat_files/lst_NYC_median_20210601_to_20210901.npy


Initializing dataset...


2025-04-30 02:20:01,391 - INFO - Loaded and normalized single LST median with shape (1, 1118, 969)
  dt_naive_or_aware = pd.to_datetime(self.bronx_weather['datetime'], errors='raise')
  dt_naive_or_aware = pd.to_datetime(self.manhattan_weather['datetime'], errors='raise')
2025-04-30 02:20:01,401 - INFO - Loaded Bronx weather data: 169 records
2025-04-30 02:20:01,402 - INFO - Loaded Manhattan weather data: 169 records
2025-04-30 02:20:03,252 - INFO - Computed grid cell coordinates and closest station map
2025-04-30 02:20:03,254 - INFO - Grid cells assigned to Bronx: 370359
2025-04-30 02:20:03,256 - INFO - Grid cells assigned to Manhattan: 712983
Precomputing UHI grids: 100%|██████████| 59/59 [00:00<00:00, 397.53it/s]
2025-04-30 02:20:03,416 - INFO - Dataset initialized for NYC with 59 unique timestamps. LST included: True
2025-04-30 02:20:03,417 - INFO - Target grid size (H, W): (1118, 969)


Random dataset split: 36 training, 23 validation samples.
Calculating UHI statistics from training data...
Training UHI Mean: 1.0004, Std Dev: 0.0162
Creating dataloaders...
Data loading setup complete.


In [None]:
# %% Model Initialization (CNN Model + Common Resampling)

# Build the UHINetCNN from the central config.
print(f"Initializing {config['model_type']}...")
model = UHINetCNN(
    # --- Required settings: a missing key should fail loudly --- #
    feature_flags=config["feature_flags"],
    weather_channels=config["weather_channels"],
    base_channels=config["unet_base_channels"],
    depth=config["unet_depth"],
    # --- Optional settings: absent keys fall back to a default / None --- #
    freeze_backbone=config.get("freeze_backbone", True),
    **{
        name: config.get(name)
        for name in (
            "sentinel_bands_to_load",
            "clay_model_size",
            "clay_bands",
            "clay_platform",
            "clay_gsd",
            "clay_checkpoint_path",
            "clay_metadata_path",
        )
    },
)

model.to(config["device"])

# --- Optimizer ---
optimizer = optim.AdamW(
    model.parameters(),
    weight_decay=config["weight_decay"],
    lr=config["lr"],
)

# --- Loss Function ---
# Only the masked MSE and masked MAE losses are supported.
if config["loss_type"] not in ('mse', 'mae'):
    raise ValueError(f"Unsupported loss type: {config['loss_type']}")
loss_fn = masked_mse_loss if config["loss_type"] == 'mse' else masked_mae_loss

print("Model, optimizer, and loss function initialized.")
# Uncomment to print a model summary:
# print(model)


Model UHINetCNN initialized on cuda
Checkpoints and logs will be saved to: /home/jupyter/MLC-Project/training_runs/NYC_UHINetCNN_20250430_022023
Using Training UHI Mean: 1.0004, Std Dev: 0.0162 for normalization.


[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33marnava1304[0m ([33marnava1304-columbia-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Wandb initialized for run: NYC_UHINetCNN_20250430_022023
Saved local configuration to config.json
Starting CNN training...
--- Epoch 1/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 1: Train Loss=645.8097 RMSE=0.0169 R2=-0.0923 | Val Loss=3226.9849 RMSE=0.0165 R2=-0.0316
New best validation R^2: -0.0316
Saved new best model to /home/jupyter/MLC-Project/training_runs/NYC_UHINetCNN_20250430_022023/model_best.pth.tar
--- Epoch 2/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 2: Train Loss=619.4320 RMSE=0.0166 R2=-0.0477 | Val Loss=3201.6125 RMSE=0.0164 R2=-0.0235
New best validation R^2: -0.0235
Saved new best model to /home/jupyter/MLC-Project/training_runs/NYC_UHINetCNN_20250430_022023/model_best.pth.tar
--- Epoch 3/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 3: Train Loss=592.4688 RMSE=0.0162 R2=-0.0021 | Val Loss=3103.2617 RMSE=0.0162 R2=0.0080
New best validation R^2: 0.0080
Saved new best model to /home/jupyter/MLC-Project/training_runs/NYC_UHINetCNN_20250430_022023/model_best.pth.tar
--- Epoch 4/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 4: Train Loss=608.3326 RMSE=0.0164 R2=-0.0289 | Val Loss=2994.3579 RMSE=0.0159 R2=0.0428
New best validation R^2: 0.0428
Saved new best model to /home/jupyter/MLC-Project/training_runs/NYC_UHINetCNN_20250430_022023/model_best.pth.tar
--- Epoch 5/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 5: Train Loss=555.4505 RMSE=0.0157 R2=0.0605 | Val Loss=2895.1689 RMSE=0.0156 R2=0.0745
New best validation R^2: 0.0745
Saved new best model to /home/jupyter/MLC-Project/training_runs/NYC_UHINetCNN_20250430_022023/model_best.pth.tar
--- Epoch 6/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 6: Train Loss=520.1677 RMSE=0.0152 R2=0.1202 | Val Loss=2928.8545 RMSE=0.0157 R2=0.0637
No improvement in validation R^2 for 1 epochs.
--- Epoch 7/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 7: Train Loss=596.9733 RMSE=0.0163 R2=-0.0097 | Val Loss=2853.7822 RMSE=0.0155 R2=0.0877
New best validation R^2: 0.0877
Saved new best model to /home/jupyter/MLC-Project/training_runs/NYC_UHINetCNN_20250430_022023/model_best.pth.tar
--- Epoch 8/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 8: Train Loss=532.3227 RMSE=0.0154 R2=0.0996 | Val Loss=2936.8237 RMSE=0.0157 R2=0.0612
No improvement in validation R^2 for 1 epochs.
--- Epoch 9/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 9: Train Loss=566.1273 RMSE=0.0158 R2=0.0424 | Val Loss=2793.0508 RMSE=0.0154 R2=0.1071
New best validation R^2: 0.1071
Saved new best model to /home/jupyter/MLC-Project/training_runs/NYC_UHINetCNN_20250430_022023/model_best.pth.tar
--- Epoch 10/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 10: Train Loss=534.8112 RMSE=0.0154 R2=0.0954 | Val Loss=2801.7456 RMSE=0.0154 R2=0.1043
No improvement in validation R^2 for 1 epochs.
--- Epoch 11/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 11: Train Loss=562.8188 RMSE=0.0158 R2=0.0480 | Val Loss=2649.6572 RMSE=0.0150 R2=0.1530
New best validation R^2: 0.1530
Saved new best model to /home/jupyter/MLC-Project/training_runs/NYC_UHINetCNN_20250430_022023/model_best.pth.tar
--- Epoch 12/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 12: Train Loss=535.5470 RMSE=0.0154 R2=0.0942 | Val Loss=2979.5686 RMSE=0.0159 R2=0.0475
No improvement in validation R^2 for 1 epochs.
--- Epoch 13/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 13: Train Loss=554.7834 RMSE=0.0157 R2=0.0616 | Val Loss=2787.3311 RMSE=0.0153 R2=0.1090
No improvement in validation R^2 for 2 epochs.
--- Epoch 14/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 14: Train Loss=505.4470 RMSE=0.0150 R2=0.1451 | Val Loss=2786.1643 RMSE=0.0153 R2=0.1093
No improvement in validation R^2 for 3 epochs.
--- Epoch 15/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 15: Train Loss=554.4417 RMSE=0.0157 R2=0.0622 | Val Loss=2901.9517 RMSE=0.0157 R2=0.0723
No improvement in validation R^2 for 4 epochs.
--- Epoch 16/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 16: Train Loss=508.4681 RMSE=0.0150 R2=0.1400 | Val Loss=2693.0142 RMSE=0.0151 R2=0.1391
No improvement in validation R^2 for 5 epochs.
--- Epoch 17/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 17: Train Loss=518.6981 RMSE=0.0152 R2=0.1227 | Val Loss=2603.2144 RMSE=0.0148 R2=0.1678
New best validation R^2: 0.1678
Saved new best model to /home/jupyter/MLC-Project/training_runs/NYC_UHINetCNN_20250430_022023/model_best.pth.tar
--- Epoch 18/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 18: Train Loss=544.8151 RMSE=0.0155 R2=0.0785 | Val Loss=2575.1655 RMSE=0.0147 R2=0.1768
New best validation R^2: 0.1768
Saved new best model to /home/jupyter/MLC-Project/training_runs/NYC_UHINetCNN_20250430_022023/model_best.pth.tar
--- Epoch 19/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 19: Train Loss=507.7743 RMSE=0.0150 R2=0.1411 | Val Loss=2620.6755 RMSE=0.0149 R2=0.1622
No improvement in validation R^2 for 1 epochs.
--- Epoch 20/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 20: Train Loss=503.7513 RMSE=0.0149 R2=0.1479 | Val Loss=2610.5540 RMSE=0.0148 R2=0.1655
No improvement in validation R^2 for 2 epochs.
--- Epoch 21/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 21: Train Loss=527.9740 RMSE=0.0153 R2=0.1070 | Val Loss=2656.1699 RMSE=0.0150 R2=0.1509
No improvement in validation R^2 for 3 epochs.
--- Epoch 22/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 22: Train Loss=468.0458 RMSE=0.0144 R2=0.2083 | Val Loss=2524.9971 RMSE=0.0146 R2=0.1928
New best validation R^2: 0.1928
Saved new best model to /home/jupyter/MLC-Project/training_runs/NYC_UHINetCNN_20250430_022023/model_best.pth.tar
--- Epoch 23/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 23: Train Loss=470.5060 RMSE=0.0144 R2=0.2042 | Val Loss=2646.7505 RMSE=0.0150 R2=0.1539
No improvement in validation R^2 for 1 epochs.
--- Epoch 24/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 24: Train Loss=470.8963 RMSE=0.0144 R2=0.2035 | Val Loss=2551.9546 RMSE=0.0147 R2=0.1842
No improvement in validation R^2 for 2 epochs.
--- Epoch 25/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 25: Train Loss=428.7389 RMSE=0.0138 R2=0.2748 | Val Loss=2549.7131 RMSE=0.0147 R2=0.1849
No improvement in validation R^2 for 3 epochs.
--- Epoch 26/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 26: Train Loss=479.1599 RMSE=0.0146 R2=0.1895 | Val Loss=2494.6724 RMSE=0.0145 R2=0.2025
New best validation R^2: 0.2025
Saved new best model to /home/jupyter/MLC-Project/training_runs/NYC_UHINetCNN_20250430_022023/model_best.pth.tar
--- Epoch 27/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 27: Train Loss=462.3014 RMSE=0.0143 R2=0.2181 | Val Loss=2327.6934 RMSE=0.0140 R2=0.2559
New best validation R^2: 0.2559
Saved new best model to /home/jupyter/MLC-Project/training_runs/NYC_UHINetCNN_20250430_022023/model_best.pth.tar
--- Epoch 28/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 28: Train Loss=441.2404 RMSE=0.0140 R2=0.2537 | Val Loss=2539.5195 RMSE=0.0146 R2=0.1882
No improvement in validation R^2 for 1 epochs.
--- Epoch 29/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 29: Train Loss=445.1794 RMSE=0.0140 R2=0.2470 | Val Loss=2350.2019 RMSE=0.0141 R2=0.2487
No improvement in validation R^2 for 2 epochs.
--- Epoch 30/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 30: Train Loss=471.6626 RMSE=0.0145 R2=0.2022 | Val Loss=2426.2075 RMSE=0.0143 R2=0.2244
No improvement in validation R^2 for 3 epochs.
--- Epoch 31/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 31: Train Loss=444.9828 RMSE=0.0140 R2=0.2474 | Val Loss=2549.3318 RMSE=0.0147 R2=0.1850
No improvement in validation R^2 for 4 epochs.
--- Epoch 32/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 32: Train Loss=454.2806 RMSE=0.0142 R2=0.2316 | Val Loss=2284.2578 RMSE=0.0139 R2=0.2698
New best validation R^2: 0.2698
Saved new best model to /home/jupyter/MLC-Project/training_runs/NYC_UHINetCNN_20250430_022023/model_best.pth.tar
--- Epoch 33/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 33: Train Loss=442.3483 RMSE=0.0140 R2=0.2518 | Val Loss=2414.1567 RMSE=0.0143 R2=0.2282
No improvement in validation R^2 for 1 epochs.
--- Epoch 34/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 34: Train Loss=478.2115 RMSE=0.0146 R2=0.1911 | Val Loss=2426.7026 RMSE=0.0143 R2=0.2242
No improvement in validation R^2 for 2 epochs.
--- Epoch 35/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 35: Train Loss=457.2233 RMSE=0.0142 R2=0.2266 | Val Loss=2224.4580 RMSE=0.0137 R2=0.2889
New best validation R^2: 0.2889
Saved new best model to /home/jupyter/MLC-Project/training_runs/NYC_UHINetCNN_20250430_022023/model_best.pth.tar
--- Epoch 36/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 36: Train Loss=430.4526 RMSE=0.0138 R2=0.2719 | Val Loss=2388.2051 RMSE=0.0142 R2=0.2365
No improvement in validation R^2 for 1 epochs.
--- Epoch 37/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 37: Train Loss=424.3486 RMSE=0.0137 R2=0.2823 | Val Loss=2329.2456 RMSE=0.0140 R2=0.2554
No improvement in validation R^2 for 2 epochs.
--- Epoch 38/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 38: Train Loss=444.7784 RMSE=0.0140 R2=0.2477 | Val Loss=2322.1597 RMSE=0.0140 R2=0.2577
No improvement in validation R^2 for 3 epochs.
--- Epoch 39/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 39: Train Loss=442.9476 RMSE=0.0140 R2=0.2508 | Val Loss=2306.8652 RMSE=0.0140 R2=0.2625
No improvement in validation R^2 for 4 epochs.
--- Epoch 40/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 40: Train Loss=463.1776 RMSE=0.0143 R2=0.2166 | Val Loss=2428.4136 RMSE=0.0143 R2=0.2237
No improvement in validation R^2 for 5 epochs.
--- Epoch 41/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 41: Train Loss=449.8522 RMSE=0.0141 R2=0.2391 | Val Loss=2418.5312 RMSE=0.0143 R2=0.2269
No improvement in validation R^2 for 6 epochs.
--- Epoch 42/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 42: Train Loss=411.7442 RMSE=0.0135 R2=0.3036 | Val Loss=2499.1460 RMSE=0.0145 R2=0.2011
No improvement in validation R^2 for 7 epochs.
--- Epoch 43/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 43: Train Loss=399.8607 RMSE=0.0133 R2=0.3237 | Val Loss=2479.8770 RMSE=0.0145 R2=0.2072
No improvement in validation R^2 for 8 epochs.
--- Epoch 44/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 44: Train Loss=410.2854 RMSE=0.0135 R2=0.3060 | Val Loss=2315.1528 RMSE=0.0140 R2=0.2599
No improvement in validation R^2 for 9 epochs.
--- Epoch 45/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 45: Train Loss=414.6050 RMSE=0.0136 R2=0.2987 | Val Loss=2449.3984 RMSE=0.0144 R2=0.2170
No improvement in validation R^2 for 10 epochs.
--- Epoch 46/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 46: Train Loss=407.1415 RMSE=0.0134 R2=0.3114 | Val Loss=2324.5986 RMSE=0.0140 R2=0.2569
No improvement in validation R^2 for 11 epochs.
--- Epoch 47/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 47: Train Loss=418.0928 RMSE=0.0136 R2=0.2928 | Val Loss=2296.6802 RMSE=0.0139 R2=0.2658
No improvement in validation R^2 for 12 epochs.
--- Epoch 48/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 48: Train Loss=408.4152 RMSE=0.0135 R2=0.3092 | Val Loss=2232.5806 RMSE=0.0137 R2=0.2863
No improvement in validation R^2 for 13 epochs.
--- Epoch 49/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 49: Train Loss=435.4211 RMSE=0.0139 R2=0.2635 | Val Loss=2421.5137 RMSE=0.0143 R2=0.2259
No improvement in validation R^2 for 14 epochs.
--- Epoch 50/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 50: Train Loss=410.9728 RMSE=0.0135 R2=0.3049 | Val Loss=2342.7922 RMSE=0.0141 R2=0.2511
No improvement in validation R^2 for 15 epochs.
--- Epoch 51/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 51: Train Loss=396.1234 RMSE=0.0132 R2=0.3300 | Val Loss=2319.4458 RMSE=0.0140 R2=0.2585
No improvement in validation R^2 for 16 epochs.
--- Epoch 52/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 52: Train Loss=408.7001 RMSE=0.0135 R2=0.3087 | Val Loss=2257.1326 RMSE=0.0138 R2=0.2784
No improvement in validation R^2 for 17 epochs.
--- Epoch 53/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 53: Train Loss=404.9494 RMSE=0.0134 R2=0.3151 | Val Loss=2173.5132 RMSE=0.0135 R2=0.3052
New best validation R^2: 0.3052
Saved new best model to /home/jupyter/MLC-Project/training_runs/NYC_UHINetCNN_20250430_022023/model_best.pth.tar
--- Epoch 54/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 54: Train Loss=401.8049 RMSE=0.0133 R2=0.3204 | Val Loss=2168.5771 RMSE=0.0135 R2=0.3068
New best validation R^2: 0.3068
Saved new best model to /home/jupyter/MLC-Project/training_runs/NYC_UHINetCNN_20250430_022023/model_best.pth.tar
--- Epoch 55/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 55: Train Loss=405.4725 RMSE=0.0134 R2=0.3142 | Val Loss=2231.9050 RMSE=0.0137 R2=0.2865
No improvement in validation R^2 for 1 epochs.
--- Epoch 56/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 56: Train Loss=403.3150 RMSE=0.0134 R2=0.3178 | Val Loss=2223.2856 RMSE=0.0137 R2=0.2893
No improvement in validation R^2 for 2 epochs.
--- Epoch 57/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 57: Train Loss=400.8580 RMSE=0.0133 R2=0.3220 | Val Loss=2148.3550 RMSE=0.0135 R2=0.3132
New best validation R^2: 0.3132
Saved new best model to /home/jupyter/MLC-Project/training_runs/NYC_UHINetCNN_20250430_022023/model_best.pth.tar
--- Epoch 58/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 58: Train Loss=409.2723 RMSE=0.0135 R2=0.3078 | Val Loss=2235.0488 RMSE=0.0137 R2=0.2855
No improvement in validation R^2 for 1 epochs.
--- Epoch 59/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 59: Train Loss=384.8003 RMSE=0.0131 R2=0.3491 | Val Loss=2198.5828 RMSE=0.0136 R2=0.2972
No improvement in validation R^2 for 2 epochs.
--- Epoch 60/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 60: Train Loss=389.9614 RMSE=0.0131 R2=0.3404 | Val Loss=2279.2603 RMSE=0.0139 R2=0.2714
No improvement in validation R^2 for 3 epochs.
--- Epoch 61/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 61: Train Loss=372.8433 RMSE=0.0129 R2=0.3694 | Val Loss=2147.8916 RMSE=0.0135 R2=0.3134
New best validation R^2: 0.3134
Saved new best model to /home/jupyter/MLC-Project/training_runs/NYC_UHINetCNN_20250430_022023/model_best.pth.tar
--- Epoch 62/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 62: Train Loss=395.7443 RMSE=0.0132 R2=0.3306 | Val Loss=2265.1760 RMSE=0.0138 R2=0.2759
No improvement in validation R^2 for 1 epochs.
--- Epoch 63/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 63: Train Loss=400.1942 RMSE=0.0133 R2=0.3231 | Val Loss=2058.2993 RMSE=0.0132 R2=0.3420
New best validation R^2: 0.3420
Saved new best model to /home/jupyter/MLC-Project/training_runs/NYC_UHINetCNN_20250430_022023/model_best.pth.tar
--- Epoch 64/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 64: Train Loss=398.6981 RMSE=0.0133 R2=0.3256 | Val Loss=2196.8345 RMSE=0.0136 R2=0.2977
No improvement in validation R^2 for 1 epochs.
--- Epoch 65/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 65: Train Loss=398.6484 RMSE=0.0133 R2=0.3257 | Val Loss=2213.2070 RMSE=0.0137 R2=0.2925
No improvement in validation R^2 for 2 epochs.
--- Epoch 66/500 ---


Training:   0%|          | 0/9 [00:00<?, ?it/s]

Validation:   0%|          | 0/1 [00:00<?, ?it/s]

Epoch 66: Train Loss=405.5837 RMSE=0.0134 R2=0.3140 | Val Loss=2122.1089 RMSE=0.0134 R2=0.3216
No improvement in validation R^2 for 3 epochs.


In [None]:
# %% Training Loop (Generic - Use for both CNN and Branched)

# --- Imports ---
import time
from datetime import datetime
import json
from pathlib import Path
# from torch.cuda.amp import GradScaler # Removed GradScaler import
import src.train.train_utils as train_utils # Import the full module
import numpy as np # Added for isnan check
import pandas as pd # For saving log

# --- Setup ---
# This cell relies on names created by earlier cells: `config`, `device`,
# `model`, `optimizer`, `loss_fn`, `train_loader` and `val_loader`.
print(f"Model {config['model_type']} initialized on {device}")

# --- Optimizer and Loss (should be initialized in model setup cell) ---
# Fail fast with an actionable message instead of hitting a NameError in the
# middle of the first epoch. At notebook top level locals() is the cell namespace.
if 'optimizer' not in locals() or 'loss_fn' not in locals():
    raise NameError("Optimizer or loss_fn not defined. Run the model initialization cell.")

# --- Tracking Variables --- #
# Reset on every execution of this cell so a re-run starts from a clean slate.
best_val_r2 = -float('inf')  # Validation R^2 drives checkpointing; any finite value beats -inf
epochs_no_improve = 0        # Consecutive epochs without a new best validation R^2
last_saved_epoch = -1        # 1-based epoch of the most recent checkpoint written by the loop

# --- Output Directory & Run Name (should be set in config cell) --- #
if 'output_dir_base' not in locals() or 'run_name' not in locals():
    # The config cell did not provide a run name / base directory: derive both here.
    run_timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
    run_name = f"{config.get('wander_run_name_prefix', 'train')}_{run_timestamp}"
    output_dir_base = Path(config.get('project_root', '.')) / "training_runs"
    output_dir = output_dir_base / run_name
    config["output_dir"] = str(output_dir)  # Update config so it is stored with checkpoints
else:
    # Prefer the directory recorded in config (stored as str, hence the Path()
    # conversion). If the key is missing, rebuild it from the base dir and run
    # name rather than failing with a bare KeyError.
    configured_output_dir = config.get("output_dir")
    if configured_output_dir is not None:
        output_dir = Path(configured_output_dir)
    else:
        output_dir = Path(output_dir_base) / run_name
        config["output_dir"] = str(output_dir)
# The training log is written straight into this directory later on, so make
# sure it exists whichever branch produced the path.
output_dir.mkdir(parents=True, exist_ok=True)
print(f"Checkpoints and logs will be saved to: {output_dir}")

# --- Retrieve UHI Stats from Config --- #
# Both values are forwarded to the train/validate helpers in the loop below.
uhi_mean = config.get('uhi_mean')
uhi_std = config.get('uhi_std')
if uhi_mean is None or uhi_std is None:
    raise ValueError("uhi_mean/uhi_std not in config. Run data loading cell.")
print(f"Using Training UHI Mean: {uhi_mean:.4f}, Std Dev: {uhi_std:.4f}")

# --- WANDB Init --- #
# `wandb` is either the imported module or None (set by the import cell when
# the package is missing). The rest of this cell only checks its truthiness,
# so a failed init downgrades it to None and training continues without W&B.
if not wandb:
    print("Wandb not available, skipping logging.")
else:
    try:
        # Close a run left open by a previous execution of this cell.
        if wandb.run is not None:
            wandb.finish()
        wandb.init(project=config["wandb_project_name"], name=run_name, config=config)
        print(f"Wandb initialized for run: {run_name}")
    except Exception as e:
        print(f"Wandb initialization failed: {e}")
        wandb = None

# --- Training Loop --- #
# Per epoch: train, validate, checkpoint on a new best validation R^2, and stop
# early after `config['patience']` epochs without improvement. The `finally`
# block always writes a final checkpoint and the local CSV log, including when
# the loop breaks early or an exception / interrupt propagates.
print(f"Starting {config['model_type']} training...")
training_start_time = time.time()
training_log = []   # One metrics dict per epoch that got past the training step
epoch = -1          # Stays -1 if the loop never starts (final checkpoint then reports epoch 0)
best_epoch = None   # 1-based epoch of the best validation R^2; None until one is recorded

try:
    # Use epoch range from config
    for epoch in range(config["epochs"]):
        epoch_start_time = time.time()
        print(f"--- Epoch {epoch+1}/{config['epochs']} ---")

        # --- Train --- #
        if train_loader:
            # Generic train function from train_utils (no AMP scaler); returns (loss, rmse, r2)
            train_loss, train_rmse, train_r2 = train_utils.train_epoch_generic(
                model, train_loader, optimizer, loss_fn, device, uhi_mean, uhi_std
            )
            print(f"Train Loss: {train_loss:.4f}, Train RMSE: {train_rmse:.4f}, Train R2: {train_r2:.4f}")
            if np.isnan(train_loss):
                # Abort before logging: a NaN epoch is not recorded in either log.
                print("Warning: Training loss is NaN. Stopping training.")
                break
        else:
            print("Skipping training: train_loader is None.")
            train_loss, train_rmse, train_r2 = float('nan'), float('nan'), float('nan')

        # Log train metrics AFTER checking for NaN
        log_metrics = {"epoch": epoch + 1, "train_loss": train_loss, "train_rmse": train_rmse, "train_r2": train_r2}
        if wandb:
            wandb.log(log_metrics)
        # The same dict object is extended with validation metrics below, so the
        # local log row ends up containing both train and val columns.
        training_log.append(log_metrics)

        # --- Validate --- #
        if val_loader:
            # Generic validate function from train_utils; returns (loss, rmse, r2)
            val_loss, val_rmse, val_r2 = train_utils.validate_epoch_generic(
                model, val_loader, loss_fn, device, uhi_mean, uhi_std
            )
            print(f"Val Loss:   {val_loss:.4f}, Val RMSE:   {val_rmse:.4f}, Val R2:   {val_r2:.4f}")
            if np.isnan(val_r2):
                print("Warning: Validation R^2 is NaN. Cannot determine improvement. Stopping training.")
                break
            val_metrics = {"val_loss": val_loss, "val_rmse": val_rmse, "val_r2": val_r2}
            log_metrics.update(val_metrics)  # Add val metrics for local log
            if wandb:
                wandb.log({"epoch": epoch + 1, **val_metrics})  # Log validation metrics too

            # --- Checkpointing & Early Stopping (based on Validation R2) --- #
            is_best = val_r2 > best_val_r2
            if is_best:
                print(f"Validation R2 improved from {best_val_r2:.4f} to {val_r2:.4f}")
                best_val_r2 = val_r2
                epochs_no_improve = 0
                best_epoch = epoch + 1
                last_saved_epoch = epoch + 1
                train_utils.save_checkpoint({
                    'epoch': epoch + 1,
                    'state_dict': model.state_dict(),
                    'optimizer': optimizer.state_dict(),
                    'best_val_r2': best_val_r2,
                    'config': config  # Save full config dict
                }, is_best=True, output_dir=output_dir)
            else:
                epochs_no_improve += 1
                print(f"Validation R2 did not improve ({val_r2:.4f}). Best: {best_val_r2:.4f}. No improvement for {epochs_no_improve} epochs.")

            if epochs_no_improve >= config['patience']:
                print(f"Early stopping triggered after {epochs_no_improve} epochs.")
                break
        else:
            print("Skipping validation/checkpointing: val_loader is None.")
            # Without validation there is no notion of "best": overwrite a rolling
            # last-epoch checkpoint instead.
            train_utils.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_val_r2': best_val_r2,  # Still -inf here, since no validation ran
                'config': config
            }, is_best=False, output_dir=output_dir, filename='checkpoint_last.pth.tar')
            last_saved_epoch = epoch + 1  # Update tracker

        # --- Epoch Timing --- #
        epoch_duration = time.time() - epoch_start_time
        print(f"Epoch {epoch+1} duration: {epoch_duration:.2f} seconds")
        if wandb:
            wandb.log({"epoch": epoch + 1, "epoch_duration_sec": epoch_duration})

finally:
    # --- End Training Actions (Executed even if loop breaks early) --- #
    training_duration = time.time() - training_start_time
    print(f"\nTotal training time: {training_duration / 60:.2f} minutes")

    # --- Save Final Checkpoint --- #
    # `epoch` is the index of the last epoch the loop entered, whether it ran to
    # completion or exited through a break / exception.
    final_epoch_num = epoch + 1
    print(f"Saving final model state from epoch {final_epoch_num}...")
    try:
        train_utils.save_checkpoint({
            'epoch': final_epoch_num,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_val_r2': best_val_r2,
            'config': config
        }, is_best=False, output_dir=output_dir, filename='checkpoint_final.pth.tar')
        print(f"Final checkpoint saved to {output_dir / 'checkpoint_final.pth.tar'}")
    except Exception as e:
        # Best effort: a failed save must not hide the log / W&B cleanup below.
        print(f"Error saving final checkpoint: {e}")

    # --- Save Local Training Log --- #
    if training_log:
        try:
            log_df = pd.DataFrame(training_log)
            log_df.to_csv(output_dir / 'training_log.csv', index=False)
            print(f"Saved local training log to {output_dir / 'training_log.csv'}")
        except Exception as e:
            print(f"Warning: Failed to save local training log: {e}")
    else:
        print("No training log data to save.")

    if wandb and wandb.run:  # Check if wandb run exists before logging/finishing
        wandb.log({"total_training_time_min": training_duration / 60})
        wandb.finish()
        print("W&B run finished.")

    # --- Summary --- #
    # Report only checkpoints this run actually wrote: model_best exists only if
    # some epoch set a new best validation R^2; runs without validation write
    # checkpoint_last instead.
    print("Training loop finished.")
    if best_epoch is not None:
        print(f"Best validation R^2 recorded: {best_val_r2:.4f} (achieved at epoch {best_epoch})")
        print(f"Best model checkpoint saved in: {output_dir / 'model_best.pth.tar'}")
    elif val_loader:
        print(f"Best validation R^2 recorded: {best_val_r2:.4f} (achieved at epoch N/A)")
    elif last_saved_epoch > 0:
        print(f"Last checkpoint (epoch {last_saved_epoch}) saved in: {output_dir / 'checkpoint_last.pth.tar'}")
