# Branched UHI Model Training

In [None]:
# Your code here

In [None]:
# Imports
import sys
import os

from pathlib import Path
import yaml
from box import Box

import torch
import torch.optim as optim
import torch.nn.functional as F # Added for interpolation
from torch.utils.data import DataLoader, random_split

import logging
import argparse
from pathlib import Path
import sys

import numpy as np
import pandas as pd # Needed for loading bounds from csv
from tqdm.notebook import tqdm # Use notebook tqdm
import math

import json
import os
from datetime import datetime
import shutil

# --- WANDB --- #
import wandb
# ------------ #

# --- Metrics --- 
from sklearn.metrics import mean_squared_error, r2_score
# ------------ #

# Add src directory to path to import modules
project_root = Path(os.getcwd()).parent  # Assumes notebook is in 'notebooks' subdir
sys.path.insert(0, str(project_root))

# --- UPDATED IMPORTS for Branched Model --- #
from src.branched_uhi_model import BranchedUHIModel # NEW MODEL
from src.ingest.dataloader_branched import BranchedCityDataSet # NEW DATALOADER
# ------------------------------------------ #

from src.train.loss import masked_mae_loss, masked_mse_loss

In [None]:
### UNCOMMENT ON FIRST RUN
#!wget -q https://huggingface.co/made-with-clay/Clay/resolve/main/v1.5/clay-v1.5.ckpt

# %% Configuration / Hyperparameters for BranchedUHIModel (ConvLSTM)
# Everything assigned in this cell is a notebook-level global. The values are
# gathered into the `config` dictionary further down and are also read
# directly by the data-loading and model-initialisation cells.

# --- Paths & Basic Info ---
project_root_str = str(project_root) # Store as string for config
data_dir_base = project_root / "data"
city_name = "NYC"
output_dir_base = project_root / "training_runs"

# --- WANDB Config ---
wandb_project_name = "MLC_UHI_Proj"
# NOTE(review): "wander" looks like a typo for "wandb". The same spelling is
# used as a key in `config`, so rename both together if this is corrected.
wander_run_name_prefix = f"{city_name}_BranchedUHI_ConvLSTM" # Modified prefix

# --- Data Loading Config ---
# Passed to BranchedCityDataSet as `resolution_m`; presumably metres per grid cell (TODO confirm).
resolution_m = 10
# Length of the time axis of each weather sequence; the dataloader check
# expects batch['weather_seq'] to have this size in dimension 1.
weather_seq_length = 60 # <<< Explicitly set to 60

# Input Data Paths (relative to project root for portability in config)
relative_data_dir = Path("data")
relative_uhi_csv = relative_data_dir / city_name / "uhi.csv"
relative_bronx_weather_csv = relative_data_dir / city_name / "bronx_weather.csv"
relative_manhattan_weather_csv = relative_data_dir / city_name / "manhattan_weather.csv"

# --- DEM/DSM Paths (relative to project root) --- #
relative_dem_path = relative_data_dir / city_name / "sat_files" / "nyc_dem_1ft_2017.tif"
relative_dsm_path = relative_data_dir / city_name / "sat_files" / "nyc_dsm_1ft_2017.tif"

# --- Cloudless Mosaic / LST Paths (Example) ---
# The date range is hard-coded into the file names; the files are only
# required when a feature flag that needs them is enabled (see the path checks).
relative_cloudless_mosaic_path = relative_data_dir / city_name / "sat_files" / f"sentinel_{city_name}_20210601_to_20210901_cloudless_mosaic.npy"
relative_single_lst_median_path = relative_data_dir / city_name / "sat_files" / f"lst_{city_name}_median_20210601_to_20210901.npy"


# --- Feature Selection Flags --- #
# Switches for the optional inputs. They decide which files the path checks
# require, what is passed to the dataset constructor, and how many static
# input channels the model is built with.
feature_flags = {
    "use_dsm": True,
    "use_dem": True,
    "use_clay": True,
    "use_sentinel_composite": False,
    "use_lst": False,
    "use_ndvi": False,
    "use_ndbi": False,
    "use_ndwi": False,
}

# --- Bands for Sentinel Composite (if use_sentinel_composite is True) --- #
# The number of entries is added to the static channel count when the
# composite is enabled.
sentinel_bands_to_load = ["blue", "green", "red", "nir", "swir16", "swir22"]

# --- Model Config (BranchedUHIModel with ConvLSTM) ---

# Clay Backbone (if feature_flags["use_clay"] is True)
clay_model_size = "large"
clay_bands = ["blue", "green", "red", "nir"]
clay_platform = "sentinel-2-l2a"
clay_gsd = 10
freeze_backbone = True
# Plain string rather than a Path (unlike the other relative paths); it is
# joined onto project_root by check_path. The file name matches the wget
# command at the top of this cell.
relative_clay_checkpoint_path = "notebooks/clay-v1.5.ckpt"
relative_clay_metadata_path = Path("src") / "Clay" / "configs" / "metadata.yaml"

# Temporal Weather Processor (ConvLSTM) - NEW PARAMS
weather_input_channels = 6 # From dataloader (temp, humidity, wind_speed, wind_sin, wind_cos, solar)
convlstm_hidden_dims = [64, 32]  # Example: 2 layers with 64 then 32 hidden channels
convlstm_kernel_sizes = [(3,3), (3,3)] # Kernel size for each layer (must match len hidden_dims)
# Derived from the hidden-dims list so the two cannot drift apart; the kernel
# size list is NOT validated against it here.
convlstm_num_layers = len(convlstm_hidden_dims) # Number of layers

# U-Net Head (No changes here)
unet_base_channels = 64
unet_depth = 4

# Projection Layer Channels (for fusion)
proj_static_ch = 32
proj_temporal_ch = 32 # For projecting ConvLSTM output

# --- Training Hyperparameters ---
num_workers = 4
epochs = 500
lr = 5e-5
weight_decay = 0.01
# presumably selects between the imported masked_mse_loss / masked_mae_loss; the consumer is not in this cell
loss_type = 'mse'
# presumably the early-stopping patience in epochs; the consumer is not in this cell
patience = 50
cpu = False  # Set to True to force CPU even when CUDA is available
# The training batch size is derived later as len(train_ds) // n_train_batches.
n_train_batches = 8

# --- Device Setup ---
device = torch.device("cuda" if torch.cuda.is_available() and not cpu else "cpu")
print(f"Using device: {device}")

# --- Sanity Checks and Absolute Paths ---
# (Check path function is unchanged)
def check_path(relative_path, description, should_exist=True):
    """Resolve ``relative_path`` against the notebook's ``project_root``.

    Args:
        relative_path: Path (or string) relative to ``project_root``.
        description: Human-readable name of the file, used in the error text.
        should_exist: When truthy, the resolved path must already exist.

    Returns:
        The joined path ``project_root / relative_path``.

    Raises:
        FileNotFoundError: If ``should_exist`` is truthy and the path is missing.
    """
    resolved = project_root / relative_path
    # The existence check is skipped entirely when the caller does not require it.
    if not should_exist or resolved.exists():
        return resolved
    raise FileNotFoundError(f"{description} not found at {resolved}. Ensure download_data.ipynb was run.")

# Resolve every input file to an absolute path. Files that belong to a
# disabled feature are not checked and their path is set to None.
absolute_uhi_csv = check_path(relative_uhi_csv, "UHI CSV")
absolute_bronx_weather_csv = check_path(relative_bronx_weather_csv, "Bronx Weather CSV")
absolute_manhattan_weather_csv = check_path(relative_manhattan_weather_csv, "Manhattan Weather CSV")

if feature_flags["use_dem"]:
    absolute_dem_path = check_path(relative_dem_path, "DEM TIF")
else:
    absolute_dem_path = None

if feature_flags["use_dsm"]:
    absolute_dsm_path = check_path(relative_dsm_path, "DSM TIF")
else:
    absolute_dsm_path = None

if feature_flags["use_clay"]:
    absolute_clay_checkpoint_path = check_path(relative_clay_checkpoint_path, "Clay Checkpoint")
else:
    absolute_clay_checkpoint_path = None

if feature_flags["use_clay"]:
    absolute_clay_metadata_path = check_path(relative_clay_metadata_path, "Clay Metadata", should_exist=True)
else:
    absolute_clay_metadata_path = None

# The mosaic is required by the Sentinel composite, by Clay, and by each of
# the spectral indices.
if any(
    feature_flags[flag_name]
    for flag_name in ("use_sentinel_composite", "use_clay", "use_ndvi", "use_ndbi", "use_ndwi")
):
    absolute_cloudless_mosaic_path = check_path(relative_cloudless_mosaic_path, "Cloudless Mosaic")
else:
    absolute_cloudless_mosaic_path = None

if feature_flags["use_lst"]:
    absolute_single_lst_median_path = check_path(relative_single_lst_median_path, "Single LST Median")
else:
    absolute_single_lst_median_path = None

# --- Calculate Bounds ---
# The bounding box is [min lon, min lat, max lon, max lat] over all rows of
# the UHI CSV.
uhi_df = pd.read_csv(absolute_uhi_csv)
required_cols = ['Longitude', 'Latitude']
if any(col not in uhi_df.columns for col in required_cols):
    raise ValueError(f"UHI CSV must contain columns: {required_cols}")
bounds = [uhi_df['Longitude'].min(), uhi_df['Latitude'].min(), uhi_df['Longitude'].max(), uhi_df['Latitude'].max()]
print(f"Loaded bounds from {absolute_uhi_csv.name}: {bounds}")

# --- Central Config Dictionary (Updated for ConvLSTM) --- #
# Single record of the run's settings. Paths are stored as strings (or None
# when the corresponding feature is disabled) so the dictionary is JSON-friendly.
# NOTE(review): "uhi_csv" stores the project-relative path while the two
# weather CSV entries store absolute paths -- confirm which form consumers expect.
config = {
    # Paths & Info
    "model_type": "BranchedUHIModel_ConvLSTM",
    "project_root": project_root_str,
    "city_name": city_name,
    "wandb_project_name": wandb_project_name,
    "wander_run_name_prefix": wander_run_name_prefix,
    # Data Loading
    "resolution_m": resolution_m,
    "weather_seq_length": weather_seq_length,
    "uhi_csv": str(relative_uhi_csv),
    "bronx_weather_csv": str(absolute_bronx_weather_csv),
    "manhattan_weather_csv": str(absolute_manhattan_weather_csv),
    "bounds": bounds,
    "feature_flags": feature_flags,
    "sentinel_bands_to_load": sentinel_bands_to_load,
    "dem_path": str(absolute_dem_path) if absolute_dem_path else None,
    "dsm_path": str(absolute_dsm_path) if absolute_dsm_path else None,
    "cloudless_mosaic_path": str(absolute_cloudless_mosaic_path) if absolute_cloudless_mosaic_path else None,
    "single_lst_median_path": str(absolute_single_lst_median_path) if absolute_single_lst_median_path else None,
    # "elevation_nodata": None, # Removed, handled by data_utils
    # Model Config (ConvLSTM)
    "weather_input_channels": weather_input_channels,
    "convlstm_hidden_dims": convlstm_hidden_dims,
    "convlstm_kernel_sizes": convlstm_kernel_sizes,
    "convlstm_num_layers": convlstm_num_layers,
    "proj_static_ch": proj_static_ch,
    "proj_temporal_ch": proj_temporal_ch,
    "unet_base_channels": unet_base_channels,
    "unet_depth": unet_depth,
    # Clay specific
    "clay_model_size": clay_model_size,
    "clay_bands": clay_bands,
    "clay_platform": clay_platform,
    "clay_gsd": clay_gsd,
    "freeze_backbone": freeze_backbone,
    "clay_checkpoint_path": str(absolute_clay_checkpoint_path) if absolute_clay_checkpoint_path else None,
    "clay_metadata_path": str(absolute_clay_metadata_path) if absolute_clay_metadata_path else None,
    # Training Hyperparameters
    "n_train_batches": n_train_batches,
    "num_workers": num_workers,
    "epochs": epochs,
    "lr": lr,
    "weight_decay": weight_decay,
    "loss_type": loss_type,
    "patience": patience,
    "device": str(device)
}


def _json_default(obj):
    """Fallback serialiser for json.dumps: stringify Path objects only.

    The `default` hook must return a serialisable value or raise TypeError.
    Handing the object back unchanged makes the encoder fail with a misleading
    "Circular reference detected" ValueError, so other types are rejected here.
    """
    if isinstance(obj, Path):
        return str(obj)
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


print("\nBranched Model (ConvLSTM) Configuration dictionary created:")
print(json.dumps(config, indent=2, default=_json_default))


In [None]:
# %% Setup DataLoader using BranchedCityDataSet

from torch.utils.data import Subset

# Build the full dataset once; the train/validation subsets below are views onto it.
print("Initializing BranchedCityDataSet...")
try:
    # Use BranchedCityDataSet
    # NOTE(review): `uhi_csv` is passed as a Path while the weather CSVs are
    # passed as strings -- confirm the constructor accepts both forms.
    dataset = BranchedCityDataSet(
        bounds=config['bounds'],
        resolution_m=config['resolution_m'],
        uhi_csv=absolute_uhi_csv,
        bronx_weather_csv=str(absolute_bronx_weather_csv),
        manhattan_weather_csv=str(absolute_manhattan_weather_csv),
        data_dir=data_dir_base,
        city_name=config['city_name'],
        # Pass feature flags and related paths from config
        feature_flags=config['feature_flags'],
        sentinel_bands_to_load=config['sentinel_bands_to_load'],
        # Optional rasters are passed as None when their feature flag is off.
        dem_path=str(absolute_dem_path) if config['feature_flags']['use_dem'] else None,
        dsm_path=str(absolute_dsm_path) if config['feature_flags']['use_dsm'] else None,
        cloudless_mosaic_path=str(absolute_cloudless_mosaic_path) if absolute_cloudless_mosaic_path else None,
        single_lst_median_path=str(absolute_single_lst_median_path) if config['feature_flags']['use_lst'] else None,
        # Length of the weather sequence requested from the dataset
        weather_seq_length=config['weather_seq_length'],
        # target_crs can be added if needed, using default from dataloader
    )
except FileNotFoundError as e:
    # Print a hint about the missing inputs, then re-raise so the cell still fails.
    print(f"Dataset initialization failed: {e}")
    print("Ensure required data files (DEM, DSM, weather, UHI, potentially mosaic/LST) exist.")
    print("Run `notebooks/download_data.ipynb` first.")
    raise
except Exception as e:
    # Any other failure is reported and propagated unchanged.
    print(f"Unexpected error during dataset initialization: {e}")
    raise

# --- Train/Val Split (Random) ---
# 60/40 random split with a fixed seed so the partition is reproducible.
# Datasets with fewer than 10 samples are not split: everything is used for
# training and there is no validation set.
val_percent = 0.40
n_samples = len(dataset)
if n_samples < 10:
    print(f"Warning: Dataset size ({n_samples}) is very small. Using all data for training.")
    n_val = 0
    n_train = n_samples
    # Previously `train_ds` was never assigned on this branch, so the summary
    # print below (and every later cell) failed with a NameError.
    train_ds = dataset
    val_ds = None
else:
    n_val = int(n_samples * val_percent)
    n_train = n_samples - n_val
    train_ds, val_ds = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(42))
print(f"Random dataset split: {len(train_ds)} training, {len(val_ds) if val_ds else 0} validation samples.")

# --- Calculate UHI Mean and Std from Training Data ONLY ---
# The statistics are taken over the masked target values of the training
# subset only, then stored in `config` for use when normalising targets.
print("Calculating UHI statistics from training data...")
all_train_targets = []
# Up to 64 samples per batch; the fallback of 1 only avoids a zero batch size.
calc_batch_size = min(64, len(train_ds)) if len(train_ds) > 0 else 1
if len(train_ds) > 0:
    # Single-process, unshuffled loader: this pass only reads targets and masks.
    temp_loader = DataLoader(train_ds, batch_size=calc_batch_size, shuffle=False, num_workers=0)
    for batch in tqdm(temp_loader, desc="Calculating stats"):
        target_tensor = batch.get('target')
        mask_tensor = batch.get('mask')
        if target_tensor is None or mask_tensor is None:
            logging.warning("Skipping batch in stats calculation due to missing target or mask.")
            continue
        # Keep only the target values at positions where the mask is truthy.
        valid_targets = target_tensor[mask_tensor.bool()]
        if valid_targets.numel() > 0:
            all_train_targets.append(valid_targets.cpu())
else:
    print("Training dataset is empty, skipping UHI statistics calculation.")
if not all_train_targets:
     # Nothing was collected: an error if training data exists, otherwise
     # fall back to identity normalisation (mean 0, std 1).
     if len(train_ds) > 0:
         raise ValueError("No valid training targets found to calculate UHI statistics.")
     else:
         print("Warning: No training data, setting UHI stats to 0 and 1.")
         uhi_mean = 0.0
         uhi_std = 1.0
else:
    all_train_targets_tensor = torch.cat(all_train_targets)
    uhi_mean = all_train_targets_tensor.mean().item()
    uhi_std = all_train_targets_tensor.std().item()
    # Guard against division by a (near-)zero std; a NaN std also fails the
    # comparison and is replaced by 1.0.
    uhi_std = uhi_std if uhi_std > 1e-6 else 1.0
print(f"Training UHI Mean: {uhi_mean:.4f}, Std Dev: {uhi_std:.4f}")
config['uhi_mean'] = uhi_mean
config['uhi_std'] = uhi_std

# --- Create DataLoaders --- #
print("Creating dataloaders...")
# The batch size is chosen so that an epoch has roughly `n_train_batches`
# steps; it is never smaller than 1.
train_batch_size = max(1, len(train_ds) // config['n_train_batches']) if len(train_ds) > 0 else 1
# The whole validation subset is evaluated as a single batch.
val_batch_size = len(val_ds) if val_ds and len(val_ds) > 0 else 1
print(f"Using Train Batch Size: {train_batch_size}")
# drop_last=True discards the final incomplete training batch.
# Both loaders are None when their subset is empty or missing.
train_loader = DataLoader(train_ds, batch_size=train_batch_size, shuffle=True, num_workers=config['num_workers'], pin_memory=True, drop_last=True) if len(train_ds) > 0 else None
val_loader = DataLoader(val_ds, batch_size=val_batch_size, shuffle=False, num_workers=config['num_workers'], pin_memory=True) if val_ds and len(val_ds) > 0 else None
print("Data loading setup complete.")

# --- Verify Dataloader Output --- #
# Diagnostic only: pull one training batch, print what it contains and log
# warnings for shape mismatches. Nothing here stops the notebook; every
# error is caught and printed.
if train_loader:
    try:
        first_batch = next(iter(train_loader))
        print("\nFirst training batch keys:", list(first_batch.keys()))
        # Print shapes to verify
        for k, v in first_batch.items():
            if isinstance(v, torch.Tensor):
                print(f"  {k}: {v.shape}, dtype={v.dtype}")
            else:
                 print(f"  {k}: type={type(v)}")
        # Check weather seq shape: (batch, time, channels, H, W), with H and W
        # taken from the dataset's sat_H / sat_W attributes.
        expected_weather_shape = (train_batch_size, config['weather_seq_length'], config['weather_input_channels'], dataset.sat_H, dataset.sat_W)
        actual_weather_shape = first_batch['weather_seq'].shape
        if actual_weather_shape != expected_weather_shape:
             logging.warning(f"Weather sequence shape mismatch! Expected {expected_weather_shape}, got {actual_weather_shape}")
        # Check static features shape (if present): the channel dimension of
        # the batch is compared with the dataset's combined static feature stack.
        if 'static_features' in first_batch:
            print(f"  Static features channels: {first_batch['static_features'].shape[1]}")
            expected_static_channels = dataset.combined_static_features.shape[0]
            if first_batch['static_features'].shape[1] != expected_static_channels:
                 logging.warning(f"Static features channel mismatch! Dataloader has {expected_static_channels}, batch has {first_batch['static_features'].shape[1]}")
        # Check clay mosaic shape (if present): one channel per configured Clay band.
        if 'cloudless_mosaic' in first_batch:
             expected_clay_channels = len(config['clay_bands']) # 4 with the current clay_bands list
             if first_batch['cloudless_mosaic'].shape[1] != expected_clay_channels:
                 logging.warning(f"Cloudless mosaic channel mismatch! Expected {expected_clay_channels} for Clay, got {first_batch['cloudless_mosaic'].shape[1]}")

    except StopIteration:
        # The loader yielded nothing (possible with drop_last=True).
        print("\nCould not get first batch, training loader might be empty.")
    except Exception as e:
        # Includes a KeyError when 'weather_seq' is missing from the batch.
        print(f"\nError inspecting first batch: {e}")
else:
    print("\nTraining loader is not available (likely no training data).")


In [None]:
# %% Model Initialization (BranchedUHIModel with ConvLSTM)

print("Initializing BranchedUHIModel with ConvLSTM...")

# --- Determine Static Input Channels for Model ---
# Each enabled non-Clay feature contributes a fixed number of channels: one
# per raster, except the Sentinel composite, which adds one per loaded band.
static_channels_model = sum(
    n_channels
    for flag_name, n_channels in (
        ('use_dsm', 1),
        ('use_dem', 1),
        ('use_sentinel_composite', len(config['sentinel_bands_to_load'])),
        ('use_lst', 1),
        ('use_ndvi', 1),
        ('use_ndbi', 1),
        ('use_ndwi', 1),
    )
    if config['feature_flags'][flag_name]
)
print(f"Calculated non-Clay static input channels for model: {static_channels_model}")

# --- Instantiate the BranchedUHIModel with ConvLSTM --- #
# The Clay-related arguments are grouped: real values when Clay is enabled,
# otherwise None (and an unfrozen-backbone flag of False).
model = BranchedUHIModel(
    weather_input_channels=config['weather_input_channels'],
    convlstm_hidden_dims=config['convlstm_hidden_dims'],
    convlstm_kernel_sizes=config['convlstm_kernel_sizes'],
    convlstm_num_layers=config['convlstm_num_layers'],
    static_channels=static_channels_model,
    unet_base_channels=config['unet_base_channels'],
    unet_depth=config['unet_depth'],
    include_clay_features=config['feature_flags']['use_clay'],
    clay_embed_dim=1024, # Assuming ViT-Large from Clay
    proj_static_ch=config['proj_static_ch'],
    proj_temporal_ch=config['proj_temporal_ch'],
    **(
        dict(
            clay_checkpoint_path=str(absolute_clay_checkpoint_path),
            clay_metadata_path=str(absolute_clay_metadata_path),
            freeze_clay_backbone=config['freeze_backbone'],
            model_size=config['clay_model_size'],
            bands=config['clay_bands'],
            platform=config['clay_platform'],
            gsd=config['clay_gsd'],
        )
        if config['feature_flags']['use_clay']
        else dict(
            clay_checkpoint_path=None,
            clay_metadata_path=None,
            freeze_clay_backbone=False,
            model_size=None,
            bands=None,
            platform=None,
            gsd=None,
        )
    ),
).to(device)

print("BranchedUHIModel (ConvLSTM) initialized.")
# Optional: Print model summary
# try:
#     from torchinfo import summary
#     # Determine dummy input shapes (use H, W from dataloader if available)
#     H, W = (dataset.sat_H, dataset.sat_W) if 'dataset' in locals() else (224, 224) # Default guess
#     T = config['weather_seq_length']
#     B = 2 # Dummy batch size
#     dummy_weather = torch.randn(B, T, config['weather_input_channels'], H, W)
#     dummy_static = torch.randn(B, static_channels_model, H, W) if static_channels_model > 0 else None
#     dummy_clay_mosaic = torch.randn(B, len(config['clay_bands']), H, W) if config['feature_flags']['use_clay'] else None
#     dummy_norm_time = torch.randn(B, 4) if config['feature_flags']['use_clay'] else None
#     dummy_norm_latlon = torch.randn(B, 4) if config['feature_flags']['use_clay'] else None
#     summary(model, input_data=[dummy_weather, dummy_static, dummy_clay_mosaic, dummy_norm_time, dummy_norm_latlon], device=str(device))
# except ImportError:
#     print("Install torchinfo (`pip install torchinfo`) for model summary.")
# except Exception as e:
#      print(f"Could not print model summary: {e}")


In [None]:
# %% Helper Functions (Train/Validate Epochs for ConvLSTM Model)

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
    """Save a training checkpoint and optionally promote it to "best".

    Args:
        state: Object handed to ``torch.save`` (typically a dict of state dicts).
        is_best: When truthy, the checkpoint is also copied to ``best_filename``.
        filename: Destination of the regular checkpoint.
        best_filename: Destination of the best-model copy.

    Both destination directories are created if missing. Previously only the
    directory of ``filename`` was created, so the copy failed whenever
    ``best_filename`` pointed into a different, not-yet-existing directory.
    """
    Path(filename).parent.mkdir(parents=True, exist_ok=True)
    torch.save(state, filename)
    if is_best:
        Path(best_filename).parent.mkdir(parents=True, exist_ok=True)
        shutil.copyfile(filename, best_filename)
        print(f"Saved new best model to {best_filename}")

def train_epoch(model, dataloader, optimizer, loss_fn, device, uhi_mean, uhi_std, feature_flags):
    """Run one training pass of the BranchedUHIModel (ConvLSTM) over ``dataloader``.

    Each batch must provide 'weather_seq', 'target' and 'mask'. The entries
    'static_features', 'cloudless_mosaic', 'norm_latlon' and 'norm_timestamp'
    are forwarded to the model when present and passed as ``None`` otherwise.
    The loss is computed against targets standardised with ``uhi_mean`` and
    ``uhi_std``; RMSE and R2 are computed on de-standardised values at the
    positions where the mask is truthy.

    Batches that produce a NaN loss, or that raise during the forward or
    backward step, are logged and skipped.

    Args:
        model: Model called with keyword arguments (see the call below).
        dataloader: Iterable of batch mappings.
        optimizer: Optimizer stepped once per successful batch.
        loss_fn: Callable ``loss_fn(prediction, target, mask)``.
        device: Device the batch tensors are moved to.
        uhi_mean: Mean used to standardise the targets.
        uhi_std: Standard deviation used to standardise the targets.
        feature_flags: Accepted for signature compatibility; not read here.

    Returns:
        Tuple ``(avg_loss, rmse, r2)``. ``avg_loss`` is the mean loss over the
        batches that completed (0.0 if none did); ``rmse`` and ``r2`` are NaN
        when no usable values were collected.
    """
    def _to_device_if_present(source, key):
        # Optional entries stay None when the batch does not carry them.
        value = source.get(key)
        return value if value is None else value.to(device)

    model.train()
    loss_sum = 0.0
    completed = 0
    pred_chunks = []
    target_chunks = []
    progress_bar = tqdm(dataloader, desc='Training', leave=False)

    for batch in progress_bar:
        # Mandatory inputs
        weather_seq = batch.get('weather_seq').to(device)
        target_unnorm = batch.get('target').to(device)
        mask = batch.get('mask').to(device)

        # Optional inputs (non-Clay static features, Clay mosaic and metadata)
        static_features = _to_device_if_present(batch, 'static_features')
        cloudless_mosaic = _to_device_if_present(batch, 'cloudless_mosaic')
        norm_latlon = _to_device_if_present(batch, 'norm_latlon')
        norm_time = _to_device_if_present(batch, 'norm_timestamp')

        optimizer.zero_grad()
        try:
            # The model is asked to produce output at the target's spatial size.
            out_size = (target_unnorm.shape[2], target_unnorm.shape[3])
            prediction_norm = model(
                weather_seq=weather_seq,
                static_features=static_features,
                cloudless_mosaic=cloudless_mosaic,
                norm_latlon_tensor=norm_latlon,
                norm_time_tensor=norm_time,
                target_h_w=out_size,
            )

            # Drop dimension 1 before computing the loss.
            pred_norm_2d = prediction_norm.squeeze(1)
            target_2d = target_unnorm.squeeze(1)
            mask_2d = mask.squeeze(1)
            loss = loss_fn(pred_norm_2d, (target_2d - uhi_mean) / uhi_std, mask_2d)

            if torch.isnan(loss):
                logging.warning("NaN loss detected, skipping backward pass.")
                continue

            loss.backward()
            optimizer.step()

            # Metrics are accumulated in the original (de-standardised) scale.
            keep = mask_2d.bool()
            pred_chunks.append((pred_norm_2d * uhi_std + uhi_mean)[keep].detach().cpu())
            target_chunks.append(target_2d[keep].detach().cpu())

            batch_loss = loss.item()
            loss_sum += batch_loss
            completed += 1
            progress_bar.set_postfix(loss=batch_loss)

        except RuntimeError as e:
            logging.error(f"Runtime error during training: {e}")
            if "out of memory" in str(e):
                logging.error("CUDA out of memory. Try reducing batch size (or n_train_batches).")
        except Exception as e:
            logging.error(f"Unexpected error during training step: {e}", exc_info=True)

    avg_loss = loss_sum / completed if completed > 0 else 0.0

    if not target_chunks:
        logging.warning("No targets accumulated for calculating metrics in training epoch.")
        return avg_loss, float('nan'), float('nan')

    y_true = torch.cat(target_chunks).numpy()
    y_pred = torch.cat(pred_chunks).numpy()
    usable = ~np.isnan(y_true) & ~np.isnan(y_pred)
    if np.sum(usable) > 0:
        rmse = math.sqrt(mean_squared_error(y_true[usable], y_pred[usable]))
        r2 = r2_score(y_true[usable], y_pred[usable])
        return avg_loss, rmse, r2

    logging.warning("No valid pixels found for calculating metrics in training epoch.")
    return avg_loss, float('nan'), float('nan')

def validate_epoch(model, dataloader, loss_fn, device, uhi_mean, uhi_std, feature_flags):
    """Evaluate the BranchedUHIModel (ConvLSTM) over ``dataloader`` without gradients.

    Batches are unpacked exactly as in training: 'weather_seq', 'target' and
    'mask' are required, while 'static_features', 'cloudless_mosaic',
    'norm_latlon' and 'norm_timestamp' are passed as ``None`` when absent.
    Batches that produce a NaN loss or raise are logged and skipped.

    Args:
        model: Model called with keyword arguments (see the call below).
        dataloader: Iterable of batch mappings.
        loss_fn: Callable ``loss_fn(prediction, target, mask)``.
        device: Device the batch tensors are moved to.
        uhi_mean: Mean used to standardise the targets.
        uhi_std: Standard deviation used to standardise the targets.
        feature_flags: Accepted for signature compatibility; not read here.

    Returns:
        Tuple ``(avg_loss, rmse, r2)``. ``avg_loss`` is the mean loss over the
        batches that completed (0.0 if none did); ``rmse`` and ``r2`` are NaN
        when no usable values were collected.
    """
    def _to_device_if_present(source, key):
        # Optional entries stay None when the batch does not carry them.
        value = source.get(key)
        return value if value is None else value.to(device)

    model.eval()
    loss_sum = 0.0
    completed = 0
    pred_chunks = []
    target_chunks = []
    progress_bar = tqdm(dataloader, desc='Validation', leave=False)

    with torch.no_grad():
        for batch in progress_bar:
            weather_seq = batch.get('weather_seq').to(device)
            target_unnorm = batch.get('target').to(device)
            mask = batch.get('mask').to(device)
            static_features = _to_device_if_present(batch, 'static_features')
            cloudless_mosaic = _to_device_if_present(batch, 'cloudless_mosaic')
            norm_latlon = _to_device_if_present(batch, 'norm_latlon')
            norm_time = _to_device_if_present(batch, 'norm_timestamp')

            try:
                # The model is asked to produce output at the target's spatial size.
                out_size = (target_unnorm.shape[2], target_unnorm.shape[3])
                prediction_norm = model(
                    weather_seq=weather_seq,
                    static_features=static_features,
                    cloudless_mosaic=cloudless_mosaic,
                    norm_latlon_tensor=norm_latlon,
                    norm_time_tensor=norm_time,
                    target_h_w=out_size,
                )

                pred_norm_2d = prediction_norm.squeeze(1)
                target_2d = target_unnorm.squeeze(1)
                mask_2d = mask.squeeze(1)
                loss = loss_fn(pred_norm_2d, (target_2d - uhi_mean) / uhi_std, mask_2d)

                if torch.isnan(loss):
                    logging.warning("NaN validation loss detected, skipping batch.")
                    continue

                # Metrics are accumulated in the original (de-standardised) scale.
                keep = mask_2d.bool()
                pred_chunks.append((pred_norm_2d * uhi_std + uhi_mean)[keep].detach().cpu())
                target_chunks.append(target_2d[keep].detach().cpu())

                batch_loss = loss.item()
                loss_sum += batch_loss
                completed += 1
                progress_bar.set_postfix(loss=batch_loss)

            except Exception as e:
                logging.error(f"Error during validation step: {e}", exc_info=True)

    avg_loss = loss_sum / completed if completed > 0 else 0.0

    if not target_chunks:
        logging.warning("No targets accumulated for calculating metrics in validation epoch.")
        return avg_loss, float('nan'), float('nan')

    y_true = torch.cat(target_chunks).numpy()
    y_pred = torch.cat(pred_chunks).numpy()
    usable = ~np.isnan(y_true) & ~np.isnan(y_pred)
    if np.sum(usable) > 0:
        rmse = math.sqrt(mean_squared_error(y_true[usable], y_pred[usable]))
        r2 = r2_score(y_true[usable], y_pred[usable])
        return avg_loss, rmse, r2

    logging.warning("No valid pixels found for calculating metrics in validation epoch.")
    return avg_loss, float('nan'), float('nan')


In [None]:
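# %% Helper Functions (Train/Validate Epochs) -- older, per-feature model signature
# NOTE(review): this cell redefines save_checkpoint, train_epoch and validate_epoch
# from the ConvLSTM helper cell above. If both cells are run in order, the
# definitions below are the ones in effect, and they call the model with
# different keyword arguments (dsm=, dem=, clay_patch=, ...). Confirm which
# set is intended before training.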

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
    """Save a training checkpoint and optionally promote it to "best".

    Args:
        state: Object handed to ``torch.save`` (typically a dict of state dicts).
        is_best: When truthy, the checkpoint is also copied to ``best_filename``.
        filename: Destination of the regular checkpoint.
        best_filename: Destination of the best-model copy.

    Both destination directories are created if missing. Previously only the
    directory of ``filename`` was created, so the copy failed whenever
    ``best_filename`` pointed into a different, not-yet-existing directory.
    """
    Path(filename).parent.mkdir(parents=True, exist_ok=True)
    torch.save(state, filename)
    if is_best:
        Path(best_filename).parent.mkdir(parents=True, exist_ok=True)
        shutil.copyfile(filename, best_filename)
        print(f"Saved new best model to {best_filename}")

def train_epoch(model, dataloader, optimizer, loss_fn, device, uhi_mean, uhi_std, feature_flags):
    """Run one training pass using the per-feature model call signature.

    Each batch must provide 'weather_seq', 'target' and 'mask'. Every optional
    input ('dsm', 'dem', 'clay_patch', 'sentinel_composite', 'lst', 'ndvi',
    'ndbi', 'ndwi', 'norm_latlon', 'norm_time') is moved to ``device`` only
    when its feature flag is truthy and the key is in the batch; otherwise
    ``None`` is passed to the model.

    NOTE(review): this file also defines a ConvLSTM-specific ``train_epoch``
    that calls the model with different keyword arguments (static_features,
    cloudless_mosaic, target_h_w, ...). Whichever definition is executed last
    is the one in effect -- confirm this version is still wanted.

    Args:
        model: Model called with keyword arguments (see the call below).
        dataloader: Iterable of batch mappings.
        optimizer: Optimizer stepped once per successful batch.
        loss_fn: Callable ``loss_fn(prediction, target, mask)``.
        device: Device the batch tensors are moved to.
        uhi_mean: Mean used to standardise the targets.
        uhi_std: Standard deviation used to standardise the targets.
        feature_flags: Mapping with the ``use_*`` switches read below.

    Returns:
        Tuple ``(avg_loss, rmse, r2)``. ``avg_loss`` is the mean loss over the
        batches that completed (0.0 if none did); ``rmse`` and ``r2`` are NaN
        when no usable values were collected.
    """
    def _select(source, key, flag_name):
        # An input is used only if its feature is enabled AND the batch has it.
        if feature_flags[flag_name] and key in source:
            return source.get(key).to(device)
        return None

    model.train()
    loss_sum = 0.0
    completed = 0
    pred_chunks = []
    target_chunks = []
    progress_bar = tqdm(dataloader, desc='Training', leave=False)

    for batch in progress_bar:
        # Mandatory inputs
        weather_seq = batch.get('weather_seq').to(device)
        target_unnorm = batch.get('target').to(device)
        mask = batch.get('mask').to(device)

        # Optional inputs, gated by feature flag and presence in the batch
        dsm = _select(batch, 'dsm', 'use_dsm')
        dem = _select(batch, 'dem', 'use_dem')
        clay_input = _select(batch, 'clay_patch', 'use_clay')
        sentinel_composite = _select(batch, 'sentinel_composite', 'use_sentinel_composite')
        lst = _select(batch, 'lst', 'use_lst')
        ndvi = _select(batch, 'ndvi', 'use_ndvi')
        ndbi = _select(batch, 'ndbi', 'use_ndbi')
        ndwi = _select(batch, 'ndwi', 'use_ndwi')
        # Position / time metadata travel with the Clay flag
        norm_latlon = _select(batch, 'norm_latlon', 'use_clay')
        norm_time = _select(batch, 'norm_time', 'use_clay')

        optimizer.zero_grad()
        try:
            # Every potential input is passed; disabled ones are None.
            prediction_norm = model(
                weather_seq=weather_seq,
                dsm=dsm,
                dem=dem,
                clay_patch=clay_input,
                norm_latlon=norm_latlon,
                norm_time=norm_time,
                sentinel_composite=sentinel_composite,
                lst=lst,
                ndvi=ndvi,
                ndbi=ndbi,
                ndwi=ndwi,
            )

            pred_norm = prediction_norm.squeeze(1)
            # The loss is computed on standardised targets.
            # NOTE(review): target and mask are used without squeeze(1) here,
            # unlike the prediction -- confirm the shapes this version expects.
            loss = loss_fn(pred_norm, (target_unnorm - uhi_mean) / uhi_std, mask)

            if torch.isnan(loss):
                logging.warning("NaN loss detected, skipping backward pass.")
                continue

            loss.backward()
            # Optional: Gradient clipping
            # torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
            optimizer.step()

            # Metrics are accumulated in the original (de-standardised) scale.
            pred_chunks.append((pred_norm * uhi_std + uhi_mean)[mask.bool()].detach().cpu())
            target_chunks.append(target_unnorm[mask.bool()].detach().cpu())

            batch_loss = loss.item()
            loss_sum += batch_loss
            completed += 1
            progress_bar.set_postfix(loss=batch_loss)

        except RuntimeError as e:
            logging.error(f"Runtime error during training: {e}")
            if "out of memory" in str(e):
                logging.error("CUDA out of memory. Try reducing n_train_batches (increases batch size).")
        except Exception as e:
            logging.error(f"Unexpected error during training step: {e}", exc_info=True)

    avg_loss = loss_sum / completed if completed > 0 else 0.0

    # Epoch-level metrics on the masked, de-standardised values
    if not target_chunks:
        logging.warning("No targets accumulated for calculating metrics in training epoch.")
        return avg_loss, float('nan'), float('nan')

    y_true = torch.cat(target_chunks).numpy()
    y_pred = torch.cat(pred_chunks).numpy()
    usable = ~np.isnan(y_true) & ~np.isnan(y_pred)
    if np.sum(usable) > 0:
        rmse = math.sqrt(mean_squared_error(y_true[usable], y_pred[usable]))
        r2 = r2_score(y_true[usable], y_pred[usable])
        return avg_loss, rmse, r2

    logging.warning("No valid pixels found for calculating metrics in training epoch.")
    return avg_loss, float('nan'), float('nan')

def validate_epoch(model, dataloader, loss_fn, device, uhi_mean, uhi_std, feature_flags):
    """Run one gradient-free pass of the model over the validation loader.

    Args:
        model: Model invoked with keyword tensors (weather_seq, dsm, dem,
            clay_patch, norm_latlon, norm_time, sentinel_composite, lst,
            ndvi, ndbi, ndwi); optional inputs are passed as None when unused.
        dataloader: Iterable of batch mappings providing at least
            'weather_seq', 'target' and 'mask'.
        loss_fn: Callable used as ``loss_fn(prediction, target, mask)`` on
            normalized values.
        device: Device every batch tensor is moved to.
        uhi_mean: Offset used to normalize targets / un-normalize predictions.
        uhi_std: Scale used to normalize targets / un-normalize predictions.
        feature_flags: Mapping with the 'use_*' switches that decide which
            optional batch entries are forwarded to the model.

    Returns:
        Tuple ``(avg_loss, rmse, r2)``. ``avg_loss`` is the mean batch loss
        (0.0 if no batch was counted); ``rmse`` and ``r2`` are computed on
        un-normalized masked pixels and are NaN when nothing usable was
        accumulated.
    """
    model.eval()

    def optional_input(batch, flag_name, batch_key):
        # Forward a feature only if it is enabled AND present in this batch.
        if feature_flags[flag_name] and batch_key in batch:
            return batch.get(batch_key).to(device)
        return None

    running_loss = 0.0
    counted_batches = 0
    target_chunks = []
    pred_chunks = []
    progress_bar = tqdm(dataloader, desc='Validation', leave=False)

    with torch.no_grad():
        for batch in progress_bar:
            # Mandatory entries first, then optional ones in a fixed order.
            weather_seq = batch.get('weather_seq').to(device)
            target_unnorm = batch.get('target').to(device)
            mask = batch.get('mask').to(device)
            dsm = optional_input(batch, 'use_dsm', 'dsm')
            dem = optional_input(batch, 'use_dem', 'dem')
            clay_input = optional_input(batch, 'use_clay', 'clay_patch')
            sentinel_composite = optional_input(batch, 'use_sentinel_composite', 'sentinel_composite')
            lst = optional_input(batch, 'use_lst', 'lst')
            ndvi = optional_input(batch, 'use_ndvi', 'ndvi')
            ndbi = optional_input(batch, 'use_ndbi', 'ndbi')
            ndwi = optional_input(batch, 'use_ndwi', 'ndwi')
            # The lat/lon and time encodings are gated by the Clay flag as well.
            norm_latlon = optional_input(batch, 'use_clay', 'norm_latlon')
            norm_time = optional_input(batch, 'use_clay', 'norm_time')

            try:
                raw_output = model(
                    weather_seq=weather_seq,
                    dsm=dsm,
                    dem=dem,
                    clay_patch=clay_input,
                    norm_latlon=norm_latlon,
                    norm_time=norm_time,
                    sentinel_composite=sentinel_composite,
                    lst=lst,
                    ndvi=ndvi,
                    ndbi=ndbi,
                    ndwi=ndwi
                )

                # The loss is evaluated in normalized space.
                pred_norm = raw_output.squeeze(1)
                target_norm = (target_unnorm - uhi_mean) / uhi_std
                loss = loss_fn(pred_norm, target_norm, mask)

                if torch.isnan(loss):
                    logging.warning("NaN validation loss detected, skipping batch.")
                    continue

                # Metrics are accumulated in original units on masked pixels only.
                keep = mask.bool()
                pred_unnorm = pred_norm * uhi_std + uhi_mean
                pred_chunks.append(pred_unnorm[keep].detach().cpu())
                target_chunks.append(target_unnorm[keep].detach().cpu())

                batch_loss = loss.item()
                running_loss += batch_loss
                counted_batches += 1
                progress_bar.set_postfix(loss=batch_loss)

            except Exception as e:
                 logging.error(f"Error during validation step: {e}", exc_info=True)
                 continue # Skip batch on errors

    avg_loss = running_loss / counted_batches if counted_batches > 0 else 0.0

    # Nothing accumulated at all -> metrics are undefined.
    if not target_chunks:
        logging.warning("No targets accumulated for calculating metrics in validation epoch.")
        return avg_loss, float('nan'), float('nan')

    targets_flat = torch.cat(target_chunks).numpy()
    preds_flat = torch.cat(pred_chunks).numpy()
    # Keep only positions where neither the target nor the prediction is NaN.
    usable = ~(np.isnan(targets_flat) | np.isnan(preds_flat))

    if not usable.any():
        logging.warning("No valid pixels found for calculating metrics in validation epoch.")
        return avg_loss, float('nan'), float('nan')

    usable_targets = targets_flat[usable]
    usable_preds = preds_flat[usable]
    rmse = math.sqrt(mean_squared_error(usable_targets, usable_preds))
    r2 = r2_score(usable_targets, usable_preds)

    return avg_loss, rmse, r2



# %% Training Loop Execution (ConvLSTM Model)
#
# Training driver: builds the optimizer and loss, creates a timestamped output
# directory, starts W&B (optional), saves the config, then runs the epoch loop
# with validation-R^2 based checkpointing and early stopping.
#
# NOTE(review): this file contains a second "Training Loop Execution" cell with
# the same logic. Running both trains the same `model` object twice (the second
# run continues from the first run's weights with a fresh optimizer) - confirm
# which cell is meant to be executed.


def _config_json_default(value):
    """Encode config values the `json` module cannot handle natively.

    Handles `Path` and NumPy scalars/arrays; raises TypeError for anything
    else, which is the contract `json.dump(..., default=...)` expects.
    (Returning the object unchanged makes the encoder report a misleading
    "Circular reference detected" error.)
    """
    if isinstance(value, Path):
        return str(value)
    if isinstance(value, np.generic):  # NumPy scalars, e.g. np.float32
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


# Ensure model and dataloaders are initialized from previous cells
if 'model' not in locals() or model is None:
    raise NameError("Model is not initialized. Run the model initialization cell first.")
if 'train_loader' not in locals():
    raise NameError("Train loader is not initialized. Run the DataLoader setup cell first.")
if 'val_loader' not in locals():
    val_loader = None
    print("Validation loader not found, proceeding without validation.")

print(f"Model {config['model_type']} initialized on {device}") # Updated model type in log

# --- Optimizer and Loss ---
best_val_r2 = -float('inf')
epochs_no_improve = 0
optimizer = optim.AdamW(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])
loss_fn = masked_mae_loss if config["loss_type"] == "mae" else masked_mse_loss

# --- Output Directory & Run Name ---
run_timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
run_name = f"{config['wander_run_name_prefix']}_{run_timestamp}" # Prefix updated in config cell
output_dir = Path(output_dir_base) / run_name
output_dir.mkdir(parents=True, exist_ok=True)
config["output_dir"] = str(output_dir)
print(f"Checkpoints and logs will be saved to: {output_dir}")

# --- Retrieve UHI Stats ---
uhi_mean = config.get('uhi_mean')
uhi_std = config.get('uhi_std')
if uhi_mean is None or uhi_std is None:
    raise ValueError("uhi_mean and uhi_std not found in config. Ensure they were calculated.")
print(f"Using Training UHI Mean: {uhi_mean:.4f}, Std Dev: {uhi_std:.4f} for normalization.")

# --- Feature Flags for Training/Validation functions ---
feature_flags_from_config = config['feature_flags']

# --- Initialize WANDB ---
# Re-bind the name from sys.modules: a previous failed run of this cell sets
# `wandb = None`, which would otherwise disable W&B until the kernel restarts.
wandb = sys.modules.get('wandb')
if wandb is None:
    print("Wandb not imported, skipping W&B logging.")
else:
    try:
        if wandb.run is not None:
            print("Finishing previous W&B run...")
            wandb.finish()
        wandb.init(project=config["wandb_project_name"], name=run_name, config=config)
        print(f"Wandb initialized for run: {run_name}")
    except Exception as e:
        print(f"Wandb initialization failed: {e}")
        wandb = None  # Downstream code treats a falsy `wandb` as "logging disabled"

# --- Save configuration locally ---
# Built outside the try so the checkpoints below never pick up a stale or
# undefined `config_serializable`.
config_serializable = {k: str(v) if isinstance(v, Path) else v for k, v in config.items()}
try:
    # Serialize fully before opening the file so a failure cannot leave a
    # truncated config.json behind.
    config_text = json.dumps(config_serializable, indent=2, default=_config_json_default)
    with open(output_dir / "config.json", 'w') as f:
        f.write(config_text)
    print("Saved local configuration to config.json")
except Exception as e:
    print(f"Warning: Failed to save local configuration: {e}")

# --- Training Loop --- #
print(f"Starting {config['model_type']} training...") # Updated model type in log
training_log = []
last_saved_epoch = -1        # Epoch of the best model so far (-1: none yet)
checkpoint_written = False   # True once save_checkpoint has been called

for epoch in range(config["epochs"]):
    print(f"--- Epoch {epoch+1}/{config['epochs']} ---")

    # --- Train Epoch ---
    if train_loader:
        train_loss, train_rmse, train_r2 = train_epoch(model, train_loader, optimizer, loss_fn, device,
                                                       uhi_mean, uhi_std, feature_flags_from_config)
    else:
        print("Skipping training epoch as train_loader is not available.")
        train_loss, train_rmse, train_r2 = float('nan'), float('nan'), float('nan')

    log_metrics = {
        "epoch": epoch + 1,
        "train_loss": train_loss,
        "train_rmse": train_rmse,
        "train_r2": train_r2
    }

    # --- Validation and stop decision ---
    # The decision is only recorded here; the actual `break` happens after the
    # metrics are logged so the final epoch is never missing from the logs.
    is_best = False
    stop_training = False
    skip_checkpoint = False  # Set on NaN so a good checkpoint is not overwritten
    if val_loader:
        val_loss, val_rmse, val_r2 = validate_epoch(model, val_loader, loss_fn, device,
                                                    uhi_mean, uhi_std, feature_flags_from_config)
        print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f} RMSE={train_rmse:.4f} R2={train_r2:.4f} | Val Loss={val_loss:.4f} RMSE={val_rmse:.4f} R2={val_r2:.4f}")
        log_metrics.update({
            "val_loss": val_loss,
            "val_rmse": val_rmse,
            "val_r2": val_r2
        })

        if np.isnan(val_r2):
            print("Warning: Validation R^2 is NaN. Stopping training.")
            stop_training = True
            skip_checkpoint = True
        else:
            is_best = val_r2 > best_val_r2
            if is_best:
                best_val_r2 = val_r2
                epochs_no_improve = 0
                print(f"New best validation R^2: {best_val_r2:.4f}")
                if wandb and wandb.run:
                    wandb.run.summary["best_val_r2"] = best_val_r2
            else:
                epochs_no_improve += 1
                print(f"No improvement in validation R^2 for {epochs_no_improve} epochs.")

            if epochs_no_improve >= config["patience"]:
                print(f"Early stopping triggered after {config['patience']} epochs with no improvement.")
                stop_training = True
    else:
        print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f} RMSE={train_rmse:.4f} R2={train_r2:.4f} (No validation)")
        if np.isnan(train_loss):
            print("Warning: Training loss is NaN. Stopping training.")
            stop_training = True
            skip_checkpoint = True

    # --- Log Metrics (always, including the epoch that triggers a stop) ---
    if wandb and wandb.run:
        wandb.log(log_metrics)
    training_log.append(log_metrics)

    # --- Save Checkpoint ---
    if not skip_checkpoint:
        save_checkpoint(
            {'epoch': epoch + 1,
             'state_dict': model.state_dict(),
             'best_val_r2': best_val_r2,
             'optimizer' : optimizer.state_dict(),
             'config': config_serializable
             },
            is_best,
            filename=output_dir / 'checkpoint_last.pth.tar',
            best_filename=output_dir / 'model_best.pth.tar'
        )
        checkpoint_written = True
        if is_best:
            last_saved_epoch = epoch + 1

    if stop_training:
        break

# --- Final Steps ---
try:
    log_df = pd.DataFrame(training_log)
    log_df.to_csv(output_dir / 'training_log.csv', index=False)
    print(f"Saved local training log to {output_dir / 'training_log.csv'}")
except Exception as e:
    print(f"Warning: Failed to save local training log: {e}")

print("Training finished.")
if val_loader:
    if np.isinf(best_val_r2):
        print("Best validation R^2: N/A")
    else:
        print(f"Best validation R^2 recorded: {best_val_r2:.4f} (epoch {last_saved_epoch if last_saved_epoch > 0 else 'N/A'})")
if checkpoint_written:
    print(f"Final checkpoint saved in: {output_dir / 'checkpoint_last.pth.tar'}")
else:
    print("No checkpoint was saved during this run.")
if last_saved_epoch > 0:
    print(f"Best model saved in: {output_dir / 'model_best.pth.tar'}")
if wandb and wandb.run:
    wandb.finish()
    print("Wandb run finished.")


In [None]:
# %% Training Loop Execution
#
# Training driver: builds the optimizer and loss, creates a timestamped output
# directory, starts W&B (optional), saves the config, then runs the epoch loop
# with validation-R^2 based checkpointing and early stopping.
#
# NOTE(review): this file contains another "Training Loop Execution" cell with
# the same logic. Running both trains the same `model` object twice (the second
# run continues from the first run's weights with a fresh optimizer) - confirm
# which cell is meant to be executed.


def _config_json_default(value):
    """Encode config values the `json` module cannot handle natively.

    Handles `Path` and NumPy scalars/arrays; raises TypeError for anything
    else, which is the contract `json.dump(..., default=...)` expects.
    (Returning the object unchanged makes the encoder report a misleading
    "Circular reference detected" error.)
    """
    if isinstance(value, Path):
        return str(value)
    if isinstance(value, np.generic):  # NumPy scalars, e.g. np.float32
        return value.item()
    if isinstance(value, np.ndarray):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


# Ensure model and dataloaders are initialized from previous cells
if 'model' not in locals() or model is None:
    raise NameError("Model is not initialized. Run the model initialization cell first.")
if 'train_loader' not in locals():
    raise NameError("Train loader is not initialized. Run the DataLoader setup cell first.")
# val_loader can be None if val_percent was 0 or dataset too small
if 'val_loader' not in locals():
    val_loader = None
    print("Validation loader not found, proceeding without validation.")

print(f"Model {config['model_type']} initialized on {device}")

# --- Optimizer and Loss ---
best_val_r2 = -float('inf') # Initialize best R^2
epochs_no_improve = 0
optimizer = optim.AdamW(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])
loss_fn = masked_mae_loss if config["loss_type"] == "mae" else masked_mse_loss

# --- Output Directory & Run Name ---
run_timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
run_name = f"{config['wander_run_name_prefix']}_{run_timestamp}"
output_dir = Path(output_dir_base) / run_name
output_dir.mkdir(parents=True, exist_ok=True)
config["output_dir"] = str(output_dir) # Update config with actual output dir
print(f"Checkpoints and logs will be saved to: {output_dir}")

# --- Retrieve UHI Stats ---
uhi_mean = config.get('uhi_mean')
uhi_std = config.get('uhi_std')
if uhi_mean is None or uhi_std is None:
    raise ValueError("uhi_mean and uhi_std not found in config. Ensure they were calculated.")
print(f"Using Training UHI Mean: {uhi_mean:.4f}, Std Dev: {uhi_std:.4f} for normalization.")

# --- Feature Flags for Training/Validation functions ---
feature_flags_from_config = config['feature_flags']

# --- Initialize WANDB ---
# Re-bind the name from sys.modules: a previous failed run of this cell sets
# `wandb = None`, which would otherwise disable W&B until the kernel restarts.
wandb = sys.modules.get('wandb')
if wandb is None:
    print("Wandb not imported, skipping W&B logging.")
else:
    try:
        if wandb.run is not None:
            print("Finishing previous W&B run...")
            wandb.finish()
        wandb.init(
            project=config["wandb_project_name"],
            name=run_name,
            config=config # Log the entire config dictionary
        )
        print(f"Wandb initialized for run: {run_name}")
    except Exception as e:
        print(f"Wandb initialization failed: {e}")
        wandb = None  # Downstream code treats a falsy `wandb` as "logging disabled"

# --- Save configuration used for this run locally ---
# Built outside the try so the checkpoints below never pick up a stale or
# undefined `config_serializable`.
config_serializable = {k: str(v) if isinstance(v, Path) else v for k, v in config.items()}
try:
    # Serialize fully before opening the file so a failure cannot leave a
    # truncated config.json behind.
    config_text = json.dumps(config_serializable, indent=2, default=_config_json_default)
    with open(output_dir / "config.json", 'w') as f:
        f.write(config_text)
    print("Saved local configuration to config.json")
except Exception as e:
    print(f"Warning: Failed to save local configuration: {e}")

# --- Training Loop ---
print(f"Starting {config['model_type']} training...")
training_log = []
last_saved_epoch = -1        # Epoch of the best model so far (-1: none yet)
checkpoint_written = False   # True once save_checkpoint has been called

for epoch in range(config["epochs"]):
    print(f"--- Epoch {epoch+1}/{config['epochs']} ---")

    if train_loader:
        train_loss, train_rmse, train_r2 = train_epoch(model, train_loader, optimizer, loss_fn, device,
                                                       uhi_mean, uhi_std, feature_flags_from_config)
    else:
        print("Skipping training epoch as train_loader is not available.")
        train_loss, train_rmse, train_r2 = float('nan'), float('nan'), float('nan')

    log_metrics = {
        "epoch": epoch + 1,
        "train_loss": train_loss,
        "train_rmse": train_rmse,
        "train_r2": train_r2
    }

    # The stop decision is only recorded here; the actual `break` happens after
    # the metrics are logged so the final epoch is never missing from the logs.
    is_best = False
    stop_training = False
    skip_checkpoint = False  # Set on NaN so a good checkpoint is not overwritten
    if val_loader:
        val_loss, val_rmse, val_r2 = validate_epoch(model, val_loader, loss_fn, device,
                                                    uhi_mean, uhi_std, feature_flags_from_config)
        print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f} RMSE={train_rmse:.4f} R2={train_r2:.4f} | Val Loss={val_loss:.4f} RMSE={val_rmse:.4f} R2={val_r2:.4f}")
        log_metrics.update({
            "val_loss": val_loss,
            "val_rmse": val_rmse,
            "val_r2": val_r2
        })

        if np.isnan(val_r2):
            print("Warning: Validation R^2 is NaN. Cannot determine improvement. Stopping training.")
            stop_training = True
            skip_checkpoint = True
        else:
            is_best = val_r2 > best_val_r2
            if is_best:
                best_val_r2 = val_r2
                epochs_no_improve = 0
                print(f"New best validation R^2: {best_val_r2:.4f}")
                if wandb and wandb.run:
                    wandb.run.summary["best_val_r2"] = best_val_r2
            else:
                epochs_no_improve += 1
                print(f"No improvement in validation R^2 for {epochs_no_improve} epochs.")

            if epochs_no_improve >= config["patience"]:
                print(f"Early stopping triggered after {config['patience']} epochs with no improvement.")
                stop_training = True
    else: # No validation
        print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f} RMSE={train_rmse:.4f} R2={train_r2:.4f} (No validation)")
        if np.isnan(train_loss):
            print("Warning: Training loss is NaN. Stopping training.")
            stop_training = True
            skip_checkpoint = True

    # Log always, including the epoch that triggers a stop.
    if wandb and wandb.run:
        wandb.log(log_metrics)
    training_log.append(log_metrics)

    if not skip_checkpoint:
        save_checkpoint(
            {'epoch': epoch + 1,
             'state_dict': model.state_dict(),
             'best_val_r2': best_val_r2,
             'optimizer' : optimizer.state_dict(),
             'config': config_serializable
             },
            is_best,
            filename=output_dir / 'checkpoint_last.pth.tar',
            best_filename=output_dir / 'model_best.pth.tar'
        )
        checkpoint_written = True
        if is_best:
            last_saved_epoch = epoch + 1

    if stop_training:
        break

# --- Final Steps ---
try:
    log_df = pd.DataFrame(training_log)
    log_df.to_csv(output_dir / 'training_log.csv', index=False)
    print(f"Saved local training log to {output_dir / 'training_log.csv'}")
except Exception as e:
    print(f"Warning: Failed to save local training log: {e}")

print("Training finished.")
if val_loader:
    if np.isinf(best_val_r2):
        print("Best validation R^2: N/A")
    else:
        print(f"Best validation R^2 recorded: {best_val_r2:.4f} (achieved at epoch {last_saved_epoch if last_saved_epoch > 0 else 'N/A'})")
if checkpoint_written:
    print(f"Final checkpoint saved in: {output_dir / 'checkpoint_last.pth.tar'}")
else:
    print("No checkpoint was saved during this run.")
if last_saved_epoch > 0:
     print(f"Best model saved in: {output_dir / 'model_best.pth.tar'}")

if wandb and wandb.run:
    wandb.finish()
    print("Wandb run finished.")

