# Branched UHI Model Training

In [None]:
# Your code here

In [7]:
# Imports
import sys
import os

from pathlib import Path
import yaml
from box import Box

import torch
import torch.optim as optim
import torch.nn.functional as F # Added for interpolation
from torch.utils.data import DataLoader, random_split

import logging
import argparse
from pathlib import Path
import sys

import numpy as np
import pandas as pd # Needed for loading bounds from csv
from tqdm.notebook import tqdm # Use notebook tqdm
import math

import json
import os
from datetime import datetime
import shutil

# --- WANDB --- #
import wandb
# ------------ #

# --- Metrics --- 
from sklearn.metrics import mean_squared_error, r2_score
# ------------ #

# Add src directory to path to import modules
project_root = Path(os.getcwd()).parent  # Assumes notebook is in 'notebooks' subdir
sys.path.insert(0, str(project_root))

# --- UPDATED IMPORTS for Branched Model --- #
from src.branched_uhi_model import BranchedUHIModel 
from src.ingest.dataloader_branched import CityDataSetBranched # NEW DATALOADER
# ------------------------------------------ #

from src.train.loss import masked_mae_loss, masked_mse_loss

In [8]:
# %% Setup & Imports

import sys
import os
import pandas as pd
import numpy as np
import json
import matplotlib.pyplot as plt
from pathlib import Path
from datetime import datetime
import logging
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, Subset
from tqdm import tqdm
import math
import shutil # For checkpoint saving

# Add the project root to the Python path
project_root = Path(os.getcwd()).parent # Assumes notebook is in 'notebooks' subdir
sys.path.insert(0, str(project_root))

# --- Import Model Components ---
from src.branched_uhi_model import BranchedUHIModel # Import the branched model
from src.ingest.dataloader_branched import CityDataSetBranched # Import the corresponding dataloader

# --- Import Training Utilities & Loss ---
from src.train.loss import masked_mse_loss, masked_mae_loss # Import loss functions
import src.train.train_utils as train_utils # Import the new utility module

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Optionally import wandb if needed
try:
    import wandb
except ImportError:
    print("wandb not installed, skipping W&B logging.")
    wandb = None

# Import necessary metrics
from sklearn.metrics import mean_squared_error, r2_score

In [13]:
# %% Configuration / Hyperparameters for BranchedUHIModel (ConvLSTM + HighRes Elev)

# --- Import utils ---
# check_path moved to train_utils
from src.train.train_utils import check_path
# -------------------

# --- Paths & Basic Info ---
project_root_str = str(project_root) # Store as string for config
data_dir_base = project_root / "data"
city_name = "NYC"
output_dir_base = project_root / "training_runs"

# --- WANDB Config ---
wandb_project_name = "MLC_UHI_Proj"
wander_run_name_prefix = f"{city_name}_BranchedUHI_HiResElev" # Modified prefix

# --- Data Loading Config ---
resolution_m = 10 # Target resolution for *low-res* features (UHI, Weather, LST, Clay)
weather_seq_length = 60

# Input Data Paths (relative to project root for portability in config)
relative_data_dir = Path("data")
relative_uhi_csv = relative_data_dir / city_name / "uhi.csv"
relative_bronx_weather_csv = relative_data_dir / city_name / "bronx_weather.csv"
relative_manhattan_weather_csv = relative_data_dir / city_name / "manhattan_weather.csv"

# --- HIGH-RESOLUTION DEM/DSM Paths (Relative) --- #
# Use updated filenames from download_data
relative_dem_path_high_res = relative_data_dir / city_name / "sat_files" / "nyc_dem_1m_pc.tif"
relative_dsm_path_high_res = relative_data_dir / city_name / "sat_files" / "nyc_dsm_1m_pc.tif"
high_res_nodata = -9999.0 # Match the nodata used in DSM/DEM download
# Define low-res paths as None since we aren't using them
relative_dem_path_low_res = None
relative_dsm_path_low_res = None
low_res_elevation_nodata = None

# --- Cloudless Mosaic / LST Paths (Only needed if flags are True) ---
relative_cloudless_mosaic_path = relative_data_dir / city_name / "sat_files" / f"sentinel_{city_name}_20210601_to_20210901_cloudless_mosaic.npy"
relative_single_lst_median_path = relative_data_dir / city_name / "sat_files" / f"lst_{city_name}_median_20210601_to_20210901.npy"


# --- Feature Selection Flags (UPDATED KEYS) --- #
feature_flags = {
    "use_dem_high_res": True,
    "use_dsm_high_res": True,
    "use_dem_low_res": False,
    "use_dsm_low_res": False,
    "use_clay": True,
    "use_sentinel_composite": False,
    "use_lst": False,
    "use_ndvi": False,
    "use_ndbi": False,
    "use_ndwi": False,
}

# --- Bands for Sentinel Composite (if use_sentinel_composite is True) --- #
sentinel_bands_to_load = ["blue", "green", "red", "nir", "swir16", "swir22"]

# --- Model Config (BranchedUHIModel with ConvLSTM & HighRes Elev) ---

# Clay Backbone (if feature_flags["use_clay"] is True)
clay_model_size = "large"
clay_bands = ["blue", "green", "red", "nir"]
clay_platform = "sentinel-2-l2a"
clay_gsd = 10
freeze_backbone = True
relative_clay_checkpoint_path = "notebooks/clay-v1.5.ckpt"
relative_clay_metadata_path = Path("src") / "Clay" / "configs" / "metadata.yaml"

# Temporal Weather Processor (ConvLSTM)
weather_input_channels = 6
convlstm_hidden_dims = [64, 32]
convlstm_kernel_sizes = [(3,3), (3,3)]
convlstm_num_layers = len(convlstm_hidden_dims)

# High-Resolution Elevation Branches (NEW)
include_dem_branch = True # Enable DEM branch in the MODEL
include_dsm_branch = True # Enable DSM branch in the MODEL
elevation_branch_start_channels = 16
elevation_branch_out_channels = 32 # Output channels PER branch
elevation_branch_downsample_layers = 4 # Adjust based on resolution difference
elevation_branch_kernel_size = 3

# U-Net Head
unet_base_channels = 64
unet_depth = 4

# Projection Layer Channels
proj_static_ch = 32 # For projecting NON-ELEVATION static feats (Clay + Indices + LST + Sentinel)
proj_temporal_ch = 32 # For projecting ConvLSTM output

# --- Training Hyperparameters ---
num_workers = 4
epochs = 500
lr = 5e-5
weight_decay = 0.01
loss_type = 'mse'
patience = 50
cpu = False
n_train_batches = 8

# --- Device Setup ---
device = torch.device("cuda" if torch.cuda.is_available() and not cpu else "cpu")
print(f"Using device: {device}")

# --- REMOVED check_path definition (now in train_utils) ---

# --- Resolve Paths using check_path from train_utils ---
# --- Sanity Checks and Absolute Paths ---
absolute_uhi_csv = check_path(relative_uhi_csv, project_root, "UHI CSV")
absolute_bronx_weather_csv = check_path(relative_bronx_weather_csv, project_root, "Bronx Weather CSV")
absolute_manhattan_weather_csv = check_path(relative_manhattan_weather_csv, project_root, "Manhattan Weather CSV")
# Get paths based on feature flags (using correct flags)
absolute_dem_path_high_res = check_path(relative_dem_path_high_res, project_root, "High-Res DEM TIF") if feature_flags["use_dem_high_res"] else None
absolute_dsm_path_high_res = check_path(relative_dsm_path_high_res, project_root, "High-Res DSM TIF") if feature_flags["use_dsm_high_res"] else None
absolute_dem_path_low_res = check_path(relative_dem_path_low_res, project_root, "Low-Res DEM", should_exist=False) # Allow None
absolute_dsm_path_low_res = check_path(relative_dsm_path_low_res, project_root, "Low-Res DSM", should_exist=False) # Allow None
absolute_clay_checkpoint_path = check_path(relative_clay_checkpoint_path, project_root, "Clay Checkpoint") if feature_flags["use_clay"] else None
absolute_clay_metadata_path = check_path(relative_clay_metadata_path, project_root, "Clay Metadata", should_exist=feature_flags["use_clay"]) if feature_flags["use_clay"] else None
needs_mosaic = feature_flags["use_sentinel_composite"] or feature_flags["use_clay"] or feature_flags["use_ndvi"] or feature_flags["use_ndbi"] or feature_flags["use_ndwi"]
absolute_cloudless_mosaic_path = check_path(relative_cloudless_mosaic_path, project_root, "Cloudless Mosaic") if needs_mosaic else None
absolute_single_lst_median_path = check_path(relative_single_lst_median_path, project_root, "Single LST Median") if feature_flags["use_lst"] else None

# --- Calculate Bounds ---
uhi_df = pd.read_csv(absolute_uhi_csv)
required_cols = ['Longitude', 'Latitude']
if not all(col in uhi_df.columns for col in required_cols):
    raise ValueError(f"UHI CSV must contain columns: {required_cols}")
bounds = [
    uhi_df['Longitude'].min(),
    uhi_df['Latitude'].min(),
    uhi_df['Longitude'].max(),
    uhi_df['Latitude'].max()
]
print(f"Loaded bounds from {absolute_uhi_csv.name}: {bounds}")

# --- Central Config Dictionary (Updated for ConvLSTM + HighRes Elev) --- #
config = {
    # Paths & Info
    "model_type": "BranchedUHIModel_HiResElev",
    "project_root": project_root_str,
    "city_name": city_name,
    "wandb_project_name": wandb_project_name,
    "wander_run_name_prefix": wander_run_name_prefix,
    # Data Loading
    "resolution_m": resolution_m,
    "weather_seq_length": weather_seq_length,
    "uhi_csv": str(relative_uhi_csv),
    "bronx_weather_csv": str(absolute_bronx_weather_csv),
    "manhattan_weather_csv": str(absolute_manhattan_weather_csv),
    "bounds": bounds,
    "feature_flags": feature_flags, # Pass updated flags
    "sentinel_bands_to_load": sentinel_bands_to_load,
    "dem_path_high_res": str(absolute_dem_path_high_res) if absolute_dem_path_high_res else None,
    "dsm_path_high_res": str(absolute_dsm_path_high_res) if absolute_dsm_path_high_res else None,
    "high_res_nodata": high_res_nodata, # Pass high-res nodata
    "dem_path_low_res": str(absolute_dem_path_low_res) if absolute_dem_path_low_res else None,
    "dsm_path_low_res": str(absolute_dsm_path_low_res) if absolute_dsm_path_low_res else None,
    "low_res_elevation_nodata": low_res_elevation_nodata,
    "cloudless_mosaic_path": str(absolute_cloudless_mosaic_path) if absolute_cloudless_mosaic_path else None,
    "single_lst_median_path": str(absolute_single_lst_median_path) if absolute_single_lst_median_path else None,
    # Model Config
    "weather_input_channels": weather_input_channels,
    "convlstm_hidden_dims": convlstm_hidden_dims,
    "convlstm_kernel_sizes": convlstm_kernel_sizes,
    "convlstm_num_layers": convlstm_num_layers,
    "proj_static_ch": proj_static_ch,
    "proj_temporal_ch": proj_temporal_ch,
    "unet_base_channels": unet_base_channels,
    "unet_depth": unet_depth,
    # Clay specific
    "clay_model_size": clay_model_size,
    "clay_bands": clay_bands,
    "clay_platform": clay_platform,
    "clay_gsd": clay_gsd,
    "freeze_backbone": freeze_backbone,
    "clay_checkpoint_path": str(absolute_clay_checkpoint_path) if absolute_clay_checkpoint_path else None,
    "clay_metadata_path": str(absolute_clay_metadata_path) if absolute_clay_metadata_path else None,
    # High-Res Elevation Branch specific (for MODEL init)
    "include_dem_branch": include_dem_branch,
    "include_dsm_branch": include_dsm_branch,
    "elevation_branch_start_channels": elevation_branch_start_channels,
    "elevation_branch_out_channels": elevation_branch_out_channels,
    "elevation_branch_downsample_layers": elevation_branch_downsample_layers,
    "elevation_branch_kernel_size": elevation_branch_kernel_size,
    # Training Hyperparameters
    "n_train_batches": n_train_batches,
    "num_workers": num_workers,
    "epochs": epochs,
    "lr": lr,
    "weight_decay": weight_decay,
    "loss_type": loss_type,
    "patience": patience,
    "device": str(device)
}

print("\nBranched Model (HiRes Elev) Configuration dictionary created:")
# Use a default function to handle non-serializable types like Path
print(json.dumps(config, indent=2, default=lambda x: str(x) if isinstance(x, (Path, torch.device)) else x))


Using device: cuda
Loaded bounds from uhi.csv: [np.float64(-73.99445667), np.float64(40.75879167), np.float64(-73.87945833), np.float64(40.85949667)]

Branched Model (HiRes Elev) Configuration dictionary created:
{
  "model_type": "BranchedUHIModel_HiResElev",
  "project_root": "/home/jupyter/MLC-Project",
  "city_name": "NYC",
  "wandb_project_name": "MLC_UHI_Proj",
  "wander_run_name_prefix": "NYC_BranchedUHI_HiResElev",
  "resolution_m": 10,
  "weather_seq_length": 60,
  "uhi_csv": "data/NYC/uhi.csv",
  "bronx_weather_csv": "/home/jupyter/MLC-Project/data/NYC/bronx_weather.csv",
  "manhattan_weather_csv": "/home/jupyter/MLC-Project/data/NYC/manhattan_weather.csv",
  "bounds": [
    -73.99445667,
    40.75879167,
    -73.87945833,
    40.85949667
  ],
  "feature_flags": {
    "use_dsm": true,
    "use_dem": true,
    "use_clay": true,
    "use_sentinel_composite": false,
    "use_lst": false,
    "use_ndvi": false,
    "use_ndbi": false,
    "use_ndwi": false
  },
  "sentinel_bands_t

In [None]:
# %% Data Loading and Preprocessing (Branched Model + HighRes Elev)

# --- Import utils ---
from src.train.train_utils import (
    split_data,
    calculate_uhi_stats,
    create_dataloaders
)
# -------------------

print("Initializing BranchedCityDataSet...")
try:
    dataset = BranchedCityDataSet(
        bounds=config["bounds"],
        resolution_m=config["resolution_m"],
        uhi_csv=absolute_uhi_csv, # Use absolute path resolved earlier
        bronx_weather_csv=absolute_bronx_weather_csv,
        manhattan_weather_csv=absolute_manhattan_weather_csv,
        data_dir=project_root_str,
        city_name=config["city_name"],
        feature_flags=config["feature_flags"],
        sentinel_bands_to_load=config["sentinel_bands_to_load"],
        dem_path_high_res=config["dem_path_high_res"],
        dsm_path_high_res=config["dsm_path_high_res"],
        high_res_nodata=config["high_res_nodata"],
        dem_path_low_res=config["dem_path_low_res"],
        dsm_path_low_res=config["dsm_path_low_res"],
        low_res_elevation_nodata=config["low_res_elevation_nodata"],
        cloudless_mosaic_path=config["cloudless_mosaic_path"],
        single_lst_median_path=config["single_lst_median_path"],
        weather_seq_length=config["weather_seq_length"]
    )
except FileNotFoundError as e:
    print(f"Dataset initialization failed: {e}")
    print("Ensure required data files (DEM, DSM, weather, UHI, potentially mosaic/LST) exist.")
    print("Run `notebooks/download_data.ipynb` first.")
    raise
except Exception as e:
    print(f"Unexpected error during dataset initialization: {e}")
    raise

# --- Train/Val Split --- #
train_ds, val_ds = split_data(dataset, val_percent=0.40, seed=42)

# --- Calculate UHI Mean and Std from Training Data ONLY --- #
uhi_mean, uhi_std = calculate_uhi_stats(train_ds)
config['uhi_mean'] = uhi_mean
config['uhi_std'] = uhi_std

# --- Create DataLoaders --- #
train_loader, val_loader = create_dataloaders(
    train_ds,
    val_ds,
    n_train_batches=config['n_train_batches'],
    num_workers=config['num_workers'],
    device=device # Pass device from config cell
)


2025-05-05 07:45:40,103 - INFO - Loading cloudless mosaic from /home/jupyter/MLC-Project/data/NYC/sat_files/sentinel_NYC_20210601_to_20210901_cloudless_mosaic.npy
2025-05-05 07:45:40,189 - INFO - Resizing full mosaic from (1119, 1278) to (1118, 969)


Initializing BranchedCityDataSet...


2025-05-05 07:45:40,517 - INFO - Loaded Bronx weather data: 169 records
2025-05-05 07:45:40,517 - INFO - Loaded Manhattan weather data: 169 records
2025-05-05 07:45:40,534 - INFO - Computed grid cell center coordinates.
Precomputing UHI grids: 100%|██████████| 59/59 [00:00<00:00, 367.09it/s]
2025-05-05 07:45:40,706 - INFO - Precomputing weather grids for all unique timestamps...
Precomputing weather grids: 100%|██████████| 59/59 [00:06<00:00,  9.11it/s]
2025-05-05 07:45:47,183 - INFO - Finished precomputing weather grids.
2025-05-05 07:45:47,183 - INFO - Dataset initialized for NYC with 59 unique timestamps.
2025-05-05 07:45:47,184 - INFO - Target grid size (H, W): (1118, 969), CRS: EPSG:4326
2025-05-05 07:45:47,184 - INFO - Weather sequence length T = 60
2025-05-05 07:45:47,185 - INFO - Enabled features (flags): {"use_dsm": true, "use_dem": true, "use_clay": true, "use_sentinel_composite": false, "use_lst": false, "use_ndvi": false, "use_ndbi": false, "use_ndwi": false}
2025-05-05 07:

Random dataset split: 36 training, 23 validation samples.
Calculating UHI statistics from training data...


Calculating stats:   0%|          | 0/1 [00:00<?, ?it/s]

In [None]:
# %% Model Initialization (BranchedUHIModel with ConvLSTM + HighRes Elev)

print("Initializing BranchedUHIModel with HighRes Elevation Branches...")

# --- Determine Static Input Channels for Model (NON-ELEVATION) --- #
# This depends on which *non-Clay*, *non-Elevation* static features are enabled in the dataloader
static_channels_model = 0
# Check flags used by dataloader to add features to 'static_features' output
if config['feature_flags']['use_sentinel_composite']: static_channels_model += len(config['sentinel_bands_to_load'])
if config['feature_flags']['use_lst']: static_channels_model += 1
if config['feature_flags']['use_ndvi']: static_channels_model += 1
if config['feature_flags']['use_ndbi']: static_channels_model += 1
if config['feature_flags']['use_ndwi']: static_channels_model += 1
print(f"Calculated non-Clay, non-Elevation static input channels for model: {static_channels_model}")

# --- Instantiate the BranchedUHIModel with ConvLSTM & HighRes Elev --- #
model = BranchedUHIModel(
    # Non-default args first
    weather_input_channels=config['weather_input_channels'],
    convlstm_hidden_dims=config['convlstm_hidden_dims'],
    convlstm_kernel_sizes=config['convlstm_kernel_sizes'],
    convlstm_num_layers=config['convlstm_num_layers'],
    static_channels=static_channels_model, # Pass calculated NON-ELEVATION static channels
    unet_base_channels=config['unet_base_channels'],
    unet_depth=config['unet_depth'],
    # Default args next
    include_clay_features=config['feature_flags']['use_clay'],
    clay_checkpoint_path=str(absolute_clay_checkpoint_path) if config['feature_flags']['use_clay'] else None,
    clay_metadata_path=str(absolute_clay_metadata_path) if config['feature_flags']['use_clay'] else None,
    freeze_clay_backbone=config['freeze_backbone'] if config['feature_flags']['use_clay'] else False,
    clay_embed_dim=1024, # Assuming ViT-Large from Clay
    proj_static_ch=config['proj_static_ch'],
    proj_temporal_ch=config['proj_temporal_ch'],
    # --- NEW: Pass High-Res Elevation Args --- #
    include_dem_branch=config['include_dem_branch'],
    include_dsm_branch=config['include_dsm_branch'],
    elevation_branch_start_channels=config['elevation_branch_start_channels'],
    elevation_branch_out_channels=config['elevation_branch_out_channels'],
    elevation_branch_downsample_layers=config['elevation_branch_downsample_layers'],
    elevation_branch_kernel_size=config['elevation_branch_kernel_size'],
    # ----------------------------------------- #
    # Clay kwargs (match constructor)
    model_size=config['clay_model_size'] if config['feature_flags']['use_clay'] else None,
    bands=config['clay_bands'] if config['feature_flags']['use_clay'] else None,
    platform=config['clay_platform'] if config['feature_flags']['use_clay'] else None,
    gsd=config['clay_gsd'] if config['feature_flags']['use_clay'] else None
).to(device)

print("BranchedUHIModel (HighRes Elev) initialized.")
# Optional: Print model summary (might need adjustment for new inputs)
# try:
#     from torchinfo import summary
#     H, W = (dataset.sat_H, dataset.sat_W) if 'dataset' in locals() else (224, 224)
#     T = config['weather_seq_length']
#     B = 2
#     dummy_weather = torch.randn(B, T, config['weather_input_channels'], H, W)
#     dummy_static = torch.randn(B, static_channels_model, H, W) if static_channels_model > 0 else None
#     dummy_clay_mosaic = torch.randn(B, len(config['clay_bands']), H, W) if config['feature_flags']['use_clay'] else None
#     dummy_norm_time = torch.randn(B, 4) if config['feature_flags']['use_clay'] else None
#     dummy_norm_latlon = torch.randn(B, 4) if config['feature_flags']['use_clay'] else None
#     # Add dummy high-res inputs (need to guess shape or get from loaded data)
#     dummy_high_res_h, dummy_high_res_w = (H * 10, W * 10) # Example guess for 1m vs 10m
#     dummy_dem = torch.randn(B, 1, dummy_high_res_h, dummy_high_res_w) if config['include_dem_branch'] else None
#     dummy_dsm = torch.randn(B, 1, dummy_high_res_h, dummy_high_res_w) if config['include_dsm_branch'] else None
#     # Need to match forward signature order
#     summary(model, input_data=[dummy_weather, dummy_static, dummy_clay_mosaic, dummy_norm_time, dummy_norm_latlon, dummy_dem, dummy_dsm], device=str(device))
# except ImportError:
#     print("Install torchinfo (`pip install torchinfo`) for model summary.")
# except Exception as e:
#      print(f"Could not print model summary: {e}")


In [None]:
# %% Helper Functions (Train/Validate Epochs for ConvLSTM Model + HighRes Elev)

# --- REMOVED save_checkpoint definition (now in train_utils) ---

def train_epoch(model, dataloader, optimizer, loss_fn, device, uhi_mean, uhi_std, config):
    """Trains the BranchedUHIModel (ConvLSTM + HighRes Elev) for one epoch."""
    model.train()
    total_loss = 0.0
    all_targets_unnorm = []
    all_preds_unnorm = []
    num_batches = 0
    progress_bar = tqdm(dataloader, desc='Training', leave=False)

    # Get feature flags for convenience
    feature_flags = config.get('feature_flags', {}) # Flags used by dataloader
    include_clay_flag = feature_flags.get('use_clay', False)
    include_dem_branch_flag = config.get('include_dem_branch', False) # Flags used by model
    include_dsm_branch_flag = config.get('include_dsm_branch', False) # Flags used by model

    for batch in progress_bar:
        try:
            optimizer.zero_grad()

            # --- Unpack batch (BranchedCityDataSet + HighRes Elev) ---
            weather_seq = batch.get('weather_seq').to(device)
            target_unnorm = batch.get('target').to(device)
            mask = batch.get('mask').to(device)

            static_features = batch.get('static_features')
            if static_features is not None: static_features = static_features.to(device)

            cloudless_mosaic = batch.get('cloudless_mosaic')
            if include_clay_flag and cloudless_mosaic is not None: cloudless_mosaic = cloudless_mosaic.to(device)
            else: cloudless_mosaic = None

            norm_latlon = batch.get('norm_latlon')
            if include_clay_flag and norm_latlon is not None: norm_latlon = norm_latlon.to(device)
            else: norm_latlon = None

            norm_timestamp = batch.get('norm_timestamp')
            if include_clay_flag and norm_timestamp is not None: norm_timestamp = norm_timestamp.to(device)
            else: norm_timestamp = None

            high_res_dem = batch.get('high_res_dem')
            if include_dem_branch_flag and high_res_dem is not None: high_res_dem = high_res_dem.to(device)
            else: high_res_dem = None

            high_res_dsm = batch.get('high_res_dsm')
            if include_dsm_branch_flag and high_res_dsm is not None: high_res_dsm = high_res_dsm.to(device)
            else: high_res_dsm = None

            # Get target H, W for the forward pass (low-res grid)
            target_h, target_w = target_unnorm.shape[2], target_unnorm.shape[3]
            target_h_w_tuple = (target_h, target_w)

            # --- Forward Pass --- #
            prediction_norm = model(
                weather_seq=weather_seq,
                static_features=static_features,
                cloudless_mosaic=cloudless_mosaic,
                norm_latlon=norm_latlon,
                norm_timestamp=norm_timestamp,
                high_res_dem=high_res_dem,
                high_res_dsm=high_res_dsm,
                target_h_w=target_h_w_tuple
            )

            # --- Loss Calculation ---
            prediction_norm_final = prediction_norm.squeeze(1)
            target_norm = (target_unnorm.squeeze(1) - uhi_mean) / uhi_std
            loss = loss_fn(prediction_norm_final, target_norm, mask.squeeze(1))

            if torch.isnan(loss):
                logging.warning("NaN loss detected, skipping backward pass.")
                continue

            loss.backward()
            optimizer.step()

            # --- Accumulate for Metrics (using unnormalized values) ---
            prediction_unnorm = prediction_norm_final * uhi_std + uhi_mean
            mask_bool = mask.squeeze(1).bool()
            all_preds_unnorm.append(prediction_unnorm[mask_bool].detach().cpu())
            all_targets_unnorm.append(target_unnorm.squeeze(1)[mask_bool].detach().cpu())

            total_loss += loss.item()
            num_batches += 1
            progress_bar.set_postfix(loss=loss.item())

        except RuntimeError as e:
            logging.error(f"Runtime error during training: {e}")
            if "out of memory" in str(e).lower():
                logging.error("CUDA out of memory. Try reducing batch size (or n_train_batches), image size, or model complexity.")
                # Optionally break or raise here if OOM is critical
            # Log batch details for debugging other runtime errors
            logging.error(f"Batch keys: {batch.keys()}")
            for k, v_ in batch.items():
                 if isinstance(v_, torch.Tensor):
                      logging.error(f"  {k} shape: {v_.shape}, dtype: {v_.dtype}, device: {v_.device}")
                 else:
                      logging.error(f"  {k} type: {type(v_)}")
            continue # Continue to next batch
        except Exception as e:
            logging.error(f"Unexpected error during training step: {e}", exc_info=True)
            continue

    # --- Epoch Metrics Calculation ---
    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
    rmse, r2 = float('nan'), float('nan')
    if all_targets_unnorm:
        try:
            all_targets_unnorm_flat = torch.cat(all_targets_unnorm).numpy()
            all_preds_unnorm_flat = torch.cat(all_preds_unnorm).numpy()
            valid_idx = ~np.isnan(all_targets_unnorm_flat) & ~np.isnan(all_preds_unnorm_flat)
            if np.sum(valid_idx) > 0:
                rmse = math.sqrt(mean_squared_error(all_targets_unnorm_flat[valid_idx], all_preds_unnorm_flat[valid_idx]))
                r2 = r2_score(all_targets_unnorm_flat[valid_idx], all_preds_unnorm_flat[valid_idx])
            else:
                 logging.warning("No valid overlapping pixels found for calculating metrics in training epoch.")
        except Exception as metric_e:
             logging.error(f"Error calculating epoch metrics: {metric_e}")
    else:
        logging.warning("No targets accumulated for calculating metrics in training epoch.")
    return avg_loss, rmse, r2

def validate_epoch(model, dataloader, loss_fn, device, uhi_mean, uhi_std, config):
    """Evaluates the BranchedUHIModel (ConvLSTM + HighRes Elev) on the validation set."""
    model.eval()
    total_loss = 0.0
    all_targets_unnorm = []
    all_preds_unnorm = []
    num_batches = 0
    progress_bar = tqdm(dataloader, desc='Validation', leave=False)

    # Get feature flags for convenience
    feature_flags = config.get('feature_flags', {})
    include_clay_flag = feature_flags.get('use_clay', False)
    include_dem_branch_flag = config.get('include_dem_branch', False)
    include_dsm_branch_flag = config.get('include_dsm_branch', False)

    with torch.no_grad():
        for batch in progress_bar:
            try:
                # --- Unpack batch (same as train_epoch) ---
                weather_seq = batch.get('weather_seq').to(device)
                target_unnorm = batch.get('target').to(device)
                mask = batch.get('mask').to(device)

                static_features = batch.get('static_features')
                if static_features is not None: static_features = static_features.to(device)

                cloudless_mosaic = batch.get('cloudless_mosaic')
                if include_clay_flag and cloudless_mosaic is not None: cloudless_mosaic = cloudless_mosaic.to(device)
                else: cloudless_mosaic = None

                norm_latlon = batch.get('norm_latlon')
                if include_clay_flag and norm_latlon is not None: norm_latlon = norm_latlon.to(device)
                else: norm_latlon = None

                norm_timestamp = batch.get('norm_timestamp')
                if include_clay_flag and norm_timestamp is not None: norm_timestamp = norm_timestamp.to(device)
                else: norm_timestamp = None

                high_res_dem = batch.get('high_res_dem')
                if include_dem_branch_flag and high_res_dem is not None: high_res_dem = high_res_dem.to(device)
                else: high_res_dem = None

                high_res_dsm = batch.get('high_res_dsm')
                if include_dsm_branch_flag and high_res_dsm is not None: high_res_dsm = high_res_dsm.to(device)
                else: high_res_dsm = None

                target_h, target_w = target_unnorm.shape[2], target_unnorm.shape[3]
                target_h_w_tuple = (target_h, target_w)

                # --- Forward Pass --- #
                prediction_norm = model(
                    weather_seq=weather_seq,
                    static_features=static_features,
                    cloudless_mosaic=cloudless_mosaic,
                    norm_latlon=norm_latlon,
                    norm_timestamp=norm_timestamp,
                    high_res_dem=high_res_dem,
                    high_res_dsm=high_res_dsm,
                    target_h_w=target_h_w_tuple
                )

                # --- Loss Calculation ---
                prediction_norm_final = prediction_norm.squeeze(1)
                target_norm = (target_unnorm.squeeze(1) - uhi_mean) / uhi_std
                loss = loss_fn(prediction_norm_final, target_norm, mask.squeeze(1))

                if torch.isnan(loss):
                    logging.warning("NaN validation loss detected, skipping batch.")
                    continue

                # --- Accumulate for Metrics ---
                prediction_unnorm = prediction_norm_final * uhi_std + uhi_mean
                mask_bool = mask.squeeze(1).bool()
                all_preds_unnorm.append(prediction_unnorm[mask_bool].detach().cpu())
                all_targets_unnorm.append(target_unnorm.squeeze(1)[mask_bool].detach().cpu())

                total_loss += loss.item()
                num_batches += 1
                progress_bar.set_postfix(loss=loss.item())

            except Exception as e:
                 logging.error(f"Error during validation step: {e}", exc_info=True)
                 # Log batch details for debugging
                 logging.error(f"Batch keys: {batch.keys()}")
                 for k, v_ in batch.items():
                     if isinstance(v_, torch.Tensor):
                          logging.error(f"  {k} shape: {v_.shape}, dtype: {v_.dtype}, device: {v_.device}")
                     else:
                          logging.error(f"  {k} type: {type(v_)}")
                 continue # Continue to next batch

    # --- Epoch Metrics Calculation ---
    avg_loss = total_loss / num_batches if num_batches > 0 else 0.0
    rmse, r2 = float('nan'), float('nan')
    if all_targets_unnorm:
        try:
            all_targets_unnorm_flat = torch.cat(all_targets_unnorm).numpy()
            all_preds_unnorm_flat = torch.cat(all_preds_unnorm).numpy()
            valid_idx = ~np.isnan(all_targets_unnorm_flat) & ~np.isnan(all_preds_unnorm_flat)
            if np.sum(valid_idx) > 0:
                rmse = math.sqrt(mean_squared_error(all_targets_unnorm_flat[valid_idx], all_preds_unnorm_flat[valid_idx]))
                r2 = r2_score(all_targets_unnorm_flat[valid_idx], all_preds_unnorm_flat[valid_idx])
            else:
                logging.warning("No valid overlapping pixels found for calculating metrics in validation epoch.")
        except Exception as metric_e:
            logging.error(f"Error calculating epoch metrics: {metric_e}")
    else:
        logging.warning("No targets accumulated for calculating metrics in validation epoch.")
    return avg_loss, rmse, r2


In [None]:
# %% Training Loop Execution

# Ensure model and dataloaders are initialized from previous cells
if 'model' not in locals() or model is None:
    raise NameError("Model is not initialized. Run the model initialization cell first.")
if 'train_loader' not in locals():
    raise NameError("Train loader is not initialized. Run the DataLoader setup cell first.")
# val_loader can be None if val_percent was 0 or dataset too small
if 'val_loader' not in locals():
    val_loader = None 
    print("Validation loader not found, proceeding without validation.")

print(f"Model {config['model_type']} initialized on {device}")

# --- Optimizer and Loss ---
best_val_r2 = -float('inf') # Initialize best R^2
epochs_no_improve = 0
optimizer = optim.AdamW(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])
loss_fn = masked_mae_loss if config["loss_type"] == "mae" else masked_mse_loss

# --- Output Directory & Run Name ---
run_timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
run_name = f"{config['wander_run_name_prefix']}_{run_timestamp}"
output_dir = Path(output_dir_base) / run_name
output_dir.mkdir(parents=True, exist_ok=True)
config["output_dir"] = str(output_dir) # Update config with actual output dir
print(f"Checkpoints and logs will be saved to: {output_dir}")

# --- Retrieve UHI Stats ---
uhi_mean = config.get('uhi_mean')
uhi_std = config.get('uhi_std')
if uhi_mean is None or uhi_std is None:
    raise ValueError("uhi_mean and uhi_std not found in config. Ensure they were calculated.")
print(f"Using Training UHI Mean: {uhi_mean:.4f}, Std Dev: {uhi_std:.4f} for normalization.")

# --- Feature Flags for Training/Validation functions ---
feature_flags_from_config = config['feature_flags']

# --- Initialize WANDB ---
if 'wandb' in sys.modules:
    try:
        if wandb.run is not None:
            print("Finishing previous W&B run...")
            wandb.finish()
        wandb.init(
            project=config["wandb_project_name"],
            name=run_name,
            config=config # Log the entire config dictionary
        )
        print(f"Wandb initialized for run: {run_name}")
    except Exception as e:
        print(f"Wandb initialization failed: {e}")
        wandb = None
else:
    print("Wandb not imported, skipping W&B logging.")
    wandb = None

# Save configuration used for this run locally
try:
    # Use the same serialization helper as before
    config_serializable = {k: str(v) if isinstance(v, Path) else v for k, v in config.items()}
    with open(output_dir / "config.json", 'w') as f:
        json.dump(config_serializable, f, indent=2, default=lambda x: str(x) if isinstance(x, Path) else x)
    print("Saved local configuration to config.json")
except Exception as e:
    print(f"Warning: Failed to save local configuration: {e}")

# --- Training Loop ---
print(f"Starting {config['model_type']} training...")
training_log = []
last_saved_epoch = -1

for epoch in range(config["epochs"]):
    print(f"--- Epoch {epoch+1}/{config['epochs']} ---")

    if train_loader:
        train_loss, train_rmse, train_r2 = train_epoch(model, train_loader, optimizer, loss_fn, device,
                                                       uhi_mean, uhi_std, feature_flags_from_config)
    else:
        print("Skipping training epoch as train_loader is not available.")
        train_loss, train_rmse, train_r2 = float('nan'), float('nan'), float('nan')

    log_metrics = {
        "epoch": epoch + 1,
        "train_loss": train_loss,
        "train_rmse": train_rmse,
        "train_r2": train_r2
    }

    is_best = False
    if val_loader:
        val_loss, val_rmse, val_r2 = validate_epoch(model, val_loader, loss_fn, device,
                                                    uhi_mean, uhi_std, feature_flags_from_config)
        print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f} RMSE={train_rmse:.4f} R2={train_r2:.4f} | Val Loss={val_loss:.4f} RMSE={val_rmse:.4f} R2={val_r2:.4f}")
        log_metrics.update({
            "val_loss": val_loss,
            "val_rmse": val_rmse,
            "val_r2": val_r2
        })

        if np.isnan(val_r2):
             print("Warning: Validation R^2 is NaN. Cannot determine improvement. Stopping training.")
             break

        is_best = val_r2 > best_val_r2
        if is_best:
            best_val_r2 = val_r2
            epochs_no_improve = 0
            print(f"New best validation R^2: {best_val_r2:.4f}")
            if wandb and wandb.run:
                wandb.run.summary["best_val_r2"] = best_val_r2
        else:
            epochs_no_improve += 1
            print(f"No improvement in validation R^2 for {epochs_no_improve} epochs.")

        if epochs_no_improve >= config["patience"]:
            print(f"Early stopping triggered after {config['patience']} epochs with no improvement.")
            break
    else: # No validation
        print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f} RMSE={train_rmse:.4f} R2={train_r2:.4f} (No validation)")
        if np.isnan(train_loss):
             print("Warning: Training loss is NaN. Stopping training.")
             break

    if wandb:
        wandb.log(log_metrics)
    training_log.append(log_metrics)

    save_checkpoint(
        {'epoch': epoch + 1,
         'state_dict': model.state_dict(),
         'best_val_r2': best_val_r2,
         'optimizer' : optimizer.state_dict(),
         'config': config_serializable
         },
        is_best,
        filename=output_dir / 'checkpoint_last.pth.tar',
        best_filename=output_dir / 'model_best.pth.tar'
    )
    if is_best:
        last_saved_epoch = epoch + 1

# --- Final Steps ---
try:
    log_df = pd.DataFrame(training_log)
    log_df.to_csv(output_dir / 'training_log.csv', index=False)
    print(f"Saved local training log to {output_dir / 'training_log.csv'}")
except Exception as e:
    print(f"Warning: Failed to save local training log: {e}")

print("Training finished.")
if val_loader:
    print(f"Best validation R^2 recorded: {best_val_r2:.4f} (achieved at epoch {last_saved_epoch if last_saved_epoch > 0 else 'N/A'})" if not np.isinf(best_val_r2) else "Best validation R^2: N/A")
print(f"Final checkpoint saved in: {output_dir / 'checkpoint_last.pth.tar'}")
if last_saved_epoch > 0:
     print(f"Best model saved in: {output_dir / 'model_best.pth.tar'}")

if wandb and wandb.run:
    wandb.finish()
    print("Wandb run finished.")

