# Branched UHI Model Training

In [1]:
# Your code here 

In [2]:
# %% Setup & Imports

import sys
import os
import pandas as pd
import numpy as np
import json
import matplotlib.pyplot as plt
from pathlib import Path
from datetime import datetime
import logging
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, Subset
from tqdm import tqdm
import math
import shutil # For checkpoint saving

os.environ["WANDB_NOTEBOOK_NAME"]="Train_UHI_batched" 

# Add the project root to the Python path
project_root = Path(os.getcwd()).parent # Assumes notebook is in 'notebooks' subdir
sys.path.insert(0, str(project_root))

# --- Import Model Components ---
from src.branched_uhi_model import BranchedUHIModel # Import the branched model
from src.ingest.dataloader_branched import CityDataSetBranched # Import the corresponding dataloader

# --- Import Training Utilities & Loss ---
from src.train.loss import masked_mse_loss, masked_mae_loss # Import loss functions
import src.train.train_utils as train_utils # Import the new utility module

# Setup logging
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Optionally import wandb if needed
try:
    import wandb
except ImportError:
    print("wandb not installed, skipping W&B logging.")
    wandb = None

# Import necessary metrics
from sklearn.metrics import mean_squared_error, r2_score

In [3]:
# %% Configuration / Hyperparameters for BranchedUHIModel (ConvLSTM + Common Resampling)

# --- Import utils ---
from src.train.train_utils import check_path
# -------------------

# --- Paths & Basic Info ---
project_root_str = str(project_root) # Store as string for config
data_dir_base = project_root / "data"
city_name = "NYC"
output_dir_base = project_root / "training_runs"

# --- WANDB Config ---
wandb_project_name = "MLC_UHI_Proj"
wander_run_name_prefix = f"{city_name}_BranchedUHI_CommonRes" # Modified prefix

# --- Data Loading Config ---
# NEW: Define the common resolution for spatial features entering the model
feature_resolution_m = 30 # Start with 10m (matches Clay/UHI grid)

# Define UHI grid resolution separately (used for final target matching)
# This assumes UHI data corresponds to 10m grid
uhi_grid_resolution_m = 10

weather_seq_length = 60

# Input Data Paths (relative)
relative_data_dir = Path("data")
relative_uhi_csv = relative_data_dir / city_name / "uhi.csv"
relative_bronx_weather_csv = relative_data_dir / city_name / "bronx_weather.csv"
relative_manhattan_weather_csv = relative_data_dir / city_name / "manhattan_weather.csv"
relative_dem_path = relative_data_dir / city_name / "sat_files" / "nyc_dem_1m_pc.tif"
relative_dsm_path = relative_data_dir / city_name / "sat_files" / "nyc_dsm_1m_pc.tif"
relative_cloudless_mosaic_path = relative_data_dir / city_name / "sat_files" / f"sentinel_{city_name}_20210601_to_20210901_cloudless_mosaic.npy"
relative_single_lst_median_path = relative_data_dir / city_name / "sat_files" / f"lst_{city_name}_median_20210601_to_20210901.npy"

# Nodata values
elevation_nodata = -9999.0
lst_nodata = 0.0 # Or appropriate value for LST median file

# --- Feature Selection Flags ---
# IMPORTANT: Only enable the features you want to use
# The Model will expect exactly these features
feature_flags = {
    "use_dem": True,              # Digital Elevation Model
    "use_dsm": True,              # Digital Surface Model
    "use_clay": True,             # Clay feature extractor
    "use_sentinel_composite": False, # Raw Sentinel-2 bands
    "use_lst": False,             # Land Surface Temperature
    "use_ndvi": False,            # Normalized Difference Vegetation Index
    "use_ndbi": False,            # Normalized Difference Built-up Index
    "use_ndwi": False,            # Normalized Difference Water Index
}

# --- Bands for Sentinel Composite (if use_sentinel_composite is True) --- #
sentinel_bands_to_load = ["blue", "green", "red", "nir", "swir16", "swir22"]

# --- Model Config (BranchedUHIModel with ConvLSTM, No separate Elev branches) ---

# Clay Backbone (if feature_flags["use_clay"] is True)
clay_model_size = "large"
clay_bands = ["blue", "green", "red", "nir"]
clay_platform = "sentinel-2-l2a"
clay_gsd = 10
freeze_backbone = True
relative_clay_checkpoint_path = "notebooks/clay-v1.5.ckpt"
relative_clay_metadata_path = Path("src") / "Clay" / "configs" / "metadata.yaml"

# Temporal Weather Processor (ConvLSTM)
weather_input_channels = 6
convlstm_hidden_dims = [16, 8]
convlstm_kernel_sizes = [(3,3), (3,3)]
convlstm_num_layers = len(convlstm_hidden_dims)

# --- REMOVED High-Res Elevation Branch Config ---

# U-Net Head
unet_base_channels = 16
unet_depth = 2 # <<< REDUCED from 3

# Projection Layer Channels
proj_static_ch = 8 # For projecting ALL static feats (Clay, LST, DEM, DSM, Indices)
proj_temporal_ch = 8 # For projecting ConvLSTM output

# --- Training Hyperparameters ---
num_workers = 1
epochs = 500
lr = 5e-5
weight_decay = 0.01
loss_type = 'mse'
patience = 50
cpu = False

# V100 Suggestion: Target batch size 2 => n_train_batches = 47 / 1 = 47
n_train_batches = 47

# --- Device Setup ---
device = torch.device("cuda" if torch.cuda.is_available() and not cpu else "cpu")
print(f"Using device: {device}")

# --- Resolve Paths using check_path from train_utils ---
absolute_uhi_csv = check_path(relative_uhi_csv, project_root, "UHI CSV")
absolute_bronx_weather_csv = check_path(relative_bronx_weather_csv, project_root, "Bronx Weather CSV")
absolute_manhattan_weather_csv = check_path(relative_manhattan_weather_csv, project_root, "Manhattan Weather CSV")

# Always check paths needed by dataloader, flags control usage inside
absolute_dem_path = check_path(relative_dem_path, project_root, "DEM TIF")
absolute_dsm_path = check_path(relative_dsm_path, project_root, "DSM TIF")
absolute_clay_checkpoint_path = check_path(relative_clay_checkpoint_path, project_root, "Clay Checkpoint")
absolute_clay_metadata_path = check_path(relative_clay_metadata_path, project_root, "Clay Metadata")
absolute_cloudless_mosaic_path = check_path(relative_cloudless_mosaic_path, project_root, "Cloudless Mosaic")
absolute_single_lst_median_path = check_path(relative_single_lst_median_path, project_root, "Single LST Median", should_exist=feature_flags["use_lst"]) # Check only if flag is true

# --- Calculate Bounds ---
uhi_df = pd.read_csv(absolute_uhi_csv)
required_cols = ['Longitude', 'Latitude']
if not all(col in uhi_df.columns for col in required_cols):
    raise ValueError(f"UHI CSV must contain columns: {required_cols}")
bounds = [
    uhi_df['Longitude'].min(),
    uhi_df['Latitude'].min(),
    uhi_df['Longitude'].max(),
    uhi_df['Latitude'].max()
]
print(f"Loaded bounds from {absolute_uhi_csv.name}: {bounds}")

# --- Central Config Dictionary --- #
config = {
    # Paths & Info
    "model_type": "BranchedUHIModel_CommonRes",
    "project_root": project_root_str,
    "city_name": city_name,
    "wandb_project_name": wandb_project_name,
    "wander_run_name_prefix": wander_run_name_prefix,
    # Data Loading
    "feature_resolution_m": feature_resolution_m, # NEW
    "uhi_grid_resolution_m": uhi_grid_resolution_m, # For reference
    "weather_seq_length": weather_seq_length,
    "uhi_csv": str(relative_uhi_csv),
    "bronx_weather_csv": str(absolute_bronx_weather_csv),
    "manhattan_weather_csv": str(absolute_manhattan_weather_csv),
    "bounds": bounds,
    "feature_flags": feature_flags,
    "sentinel_bands_to_load": sentinel_bands_to_load,
    "dem_path": str(absolute_dem_path) if absolute_dem_path else None,
    "dsm_path": str(absolute_dsm_path) if absolute_dsm_path else None,
    "elevation_nodata": elevation_nodata,
    "cloudless_mosaic_path": str(absolute_cloudless_mosaic_path) if absolute_cloudless_mosaic_path else None,
    "single_lst_median_path": str(absolute_single_lst_median_path) if absolute_single_lst_median_path else None,
    "lst_nodata": lst_nodata,
    # Model Config
    "weather_input_channels": weather_input_channels,
    "convlstm_hidden_dims": convlstm_hidden_dims,
    "convlstm_kernel_sizes": convlstm_kernel_sizes,
    "convlstm_num_layers": convlstm_num_layers,
    "proj_static_ch": proj_static_ch,
    "proj_temporal_ch": proj_temporal_ch,
    "unet_base_channels": unet_base_channels,
    "unet_depth": unet_depth, # Updated
    # Clay specific
    "clay_model_size": clay_model_size,
    "clay_bands": clay_bands,
    "clay_platform": clay_platform,
    "clay_gsd": clay_gsd,
    "freeze_backbone": freeze_backbone,
    "clay_checkpoint_path": str(absolute_clay_checkpoint_path) if feature_flags["use_clay"] else None,
    "clay_metadata_path": str(absolute_clay_metadata_path) if feature_flags["use_clay"] else None,
    # Training Hyperparameters
    "n_train_batches": n_train_batches,
    "num_workers": num_workers,
    "epochs": epochs,
    "lr": lr,
    "weight_decay": weight_decay,
    "loss_type": loss_type,
    "patience": patience,
    "device": str(device)
}

print("\nBranched Model (Common Res) Configuration dictionary created:")
print(json.dumps(config, indent=2, default=lambda x: str(x) if isinstance(x, (Path, torch.device)) else x))


Using device: cuda
Loaded bounds from uhi.csv: [np.float64(-73.99445667), np.float64(40.75879167), np.float64(-73.87945833), np.float64(40.85949667)]

Branched Model (Common Res) Configuration dictionary created:
{
  "model_type": "BranchedUHIModel_CommonRes",
  "project_root": "/home/jupyter/MLC-Project",
  "city_name": "NYC",
  "wandb_project_name": "MLC_UHI_Proj",
  "wander_run_name_prefix": "NYC_BranchedUHI_CommonRes",
  "feature_resolution_m": 30,
  "uhi_grid_resolution_m": 10,
  "weather_seq_length": 60,
  "uhi_csv": "data/NYC/uhi.csv",
  "bronx_weather_csv": "/home/jupyter/MLC-Project/data/NYC/bronx_weather.csv",
  "manhattan_weather_csv": "/home/jupyter/MLC-Project/data/NYC/manhattan_weather.csv",
  "bounds": [
    -73.99445667,
    40.75879167,
    -73.87945833,
    40.85949667
  ],
  "feature_flags": {
    "use_dem": true,
    "use_dsm": true,
    "use_clay": true,
    "use_sentinel_composite": false,
    "use_lst": false,
    "use_ndvi": false,
    "use_ndbi": false,
    "us

In [None]:
# %% Data Loading and Preprocessing (Branched Model + Common Resampling)

# --- Import utils ---
from src.train.train_utils import (
    calculate_uhi_stats, # Removed split_data
    create_dataloaders
)
from torch.utils.data import Subset # Import Subset
# -------------------

print("Initializing BranchedCityDataSet...")
try:
    dataset = CityDataSetBranched(
        bounds=config["bounds"],
        feature_resolution_m=config["feature_resolution_m"], # Corrected param name
        uhi_grid_resolution_m=config["uhi_grid_resolution_m"], # Corrected param name
        uhi_csv=absolute_uhi_csv, # Use absolute path resolved earlier
        bronx_weather_csv=absolute_bronx_weather_csv,
        manhattan_weather_csv=absolute_manhattan_weather_csv,
        data_dir=project_root_str,
        city_name=config["city_name"],
        feature_flags=config["feature_flags"],
        sentinel_bands_to_load=config["sentinel_bands_to_load"],
        dem_path=config["dem_path"], # Corrected param name
        dsm_path=config["dsm_path"], # Corrected param name
        elevation_nodata=config["elevation_nodata"], # Corrected param name
        cloudless_mosaic_path=config["cloudless_mosaic_path"],
        single_lst_median_path=config["single_lst_median_path"],
        lst_nodata=config["lst_nodata"], # Added missing param
        weather_seq_length=config["weather_seq_length"],
        target_crs_str=config.get("target_crs_str", "EPSG:4326") # Added optional param
    )
except FileNotFoundError as e:
    print(f"Dataset initialization failed: {e}")
    print("Ensure required data files (DEM, DSM, weather, UHI, potentially mosaic/LST) exist.")
    print("Run `notebooks/download_data.ipynb` first.")
    raise
except Exception as e:
    print(f"Unexpected error during dataset initialization: {e}")
    raise

# --- Sequential Train/Val Split --- #
val_percent = 0.20 # Keep the percentage definition
num_samples = len(dataset)
if num_samples < 2: # Need at least one for train and one for val
    raise ValueError(f"Dataset has only {num_samples} samples, cannot perform train/val split.")

n_train = int(num_samples * (1 - val_percent))
n_val = num_samples - n_train

if n_train == 0 or n_val == 0:
    raise ValueError(f"Split resulted in zero samples for train ({n_train}) or validation ({n_val}). Adjust val_percent or check dataset size.")

train_indices = list(range(n_train))
val_indices = list(range(n_train, num_samples))

train_ds = Subset(dataset, train_indices)
val_ds = Subset(dataset, val_indices)

print(f"Sequential dataset split: {len(train_ds)} training (indices 0-{n_train-1}), {len(val_ds)} validation (indices {n_train}-{num_samples-1}) samples.")

# --- Calculate UHI Mean and Std from Training Data ONLY --- #
uhi_mean, uhi_std = calculate_uhi_stats(train_ds)
config['uhi_mean'] = uhi_mean
config['uhi_std'] = uhi_std

# --- Create DataLoaders --- #
train_loader, val_loader = create_dataloaders(
    train_ds,
    val_ds,
    n_train_batches=config['n_train_batches'],
    num_workers=config['num_workers'],
    device=device # Pass device from config cell
)


2025-05-06 01:27:59,987 - INFO - Target FEATURE grid size (H, W): (373, 323) @ 30m, CRS: EPSG:4326
2025-05-06 01:27:59,988 - INFO - Target UHI grid size (H, W): (1118, 969) @ 10m


Initializing BranchedCityDataSet...


Precomputing UHI grids: 100%|██████████| 59/59 [00:00<00:00, 369.89it/s]
2025-05-06 01:28:00,189 - INFO - Loading DEM from: /home/jupyter/MLC-Project/data/NYC/sat_files/nyc_dem_1m_pc.tif
2025-05-06 01:28:03,616 - INFO - Clipping DEM to bounds: [np.float64(-73.99445667), np.float64(40.75879167), np.float64(-73.87945833), np.float64(40.85949667)]
2025-05-06 01:28:03,617 - INFO - Opened DEM (lazy load). Native shape (approx): (1, 10071, 11501)
2025-05-06 01:28:03,618 - INFO - Loading DSM from: /home/jupyter/MLC-Project/data/NYC/sat_files/nyc_dsm_1m_pc.tif


In [None]:
# %% Model Initialization (Branched Model + Common Resampling)

# Instantiate the BranchedUHIModel
print(f"Initializing {config['model_type']}...")
model = BranchedUHIModel(
    # --- Weather Branch Config --- #
    weather_input_channels=config["weather_input_channels"],
    convlstm_hidden_dims=config["convlstm_hidden_dims"],
    convlstm_kernel_sizes=config["convlstm_kernel_sizes"],
    convlstm_num_layers=config["convlstm_num_layers"],
    # --- Static Feature Config --- #
    feature_flags=config["feature_flags"],
    sentinel_bands_to_load=config.get("sentinel_bands_to_load"), # Pass if needed
    # Clay Specific
    clay_model_size=config.get("clay_model_size"),
    clay_bands=config.get("clay_bands"),
    clay_platform=config.get("clay_platform"),
    clay_gsd=config.get("clay_gsd"),
    freeze_backbone=config.get("freeze_backbone", True),
    clay_checkpoint_path=config.get("clay_checkpoint_path"),
    clay_metadata_path=config.get("clay_metadata_path"),
    # --- Head Config --- #
    proj_static_ch=config["proj_static_ch"],
    proj_temporal_ch=config["proj_temporal_ch"],
    unet_base_channels=config["unet_base_channels"],
    unet_depth=config["unet_depth"],
    # --- Target Grid Info (NEW) --- #
    uhi_grid_resolution_m=config["uhi_grid_resolution_m"],
    bounds=config["bounds"]
)

model.to(config["device"])

# --- Optimizer --- #
optimizer = optim.AdamW(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])

# --- Loss Function --- #
if config["loss_type"] == 'mse':
    loss_fn = masked_mse_loss
elif config["loss_type"] == 'mae':
    loss_fn = masked_mae_loss
else:
    raise ValueError(f"Unsupported loss type: {config['loss_type']}")

# --- LR Scheduler --- #
# Add ReduceLROnPlateau scheduler
from torch.optim.lr_scheduler import ReduceLROnPlateau
scheduler = ReduceLROnPlateau(optimizer, 'min', patience=10, factor=0.5, verbose=True)
print("Initialized ReduceLROnPlateau scheduler.")

print("Model, optimizer, loss function, and scheduler initialized.")
# print(model) # Optional: Print model summary


## Configuration

Set up paths and hyperparameters.

## Setup DataLoader

## Training

In [None]:
# %% Training Loop (Generic - Use for both CNN and Branched)

# --- Imports ---
import time
from datetime import datetime
import json
from pathlib import Path
# from torch.cuda.amp import GradScaler # Removed GradScaler import
import src.train.train_utils as train_utils # Import the full module
import numpy as np # Added for isnan check
import pandas as pd # For saving log

# --- Setup ---
print(f"Model {config['model_type']} initialized on {device}")

# --- Optimizer and Loss (should be initialized in model setup cell) ---
# Ensure optimizer and loss_fn are accessible from the previous cell's scope
if 'optimizer' not in locals() or 'loss_fn' not in locals():
    raise NameError("Optimizer or loss_fn not defined. Run the model initialization cell.")
# Ensure scheduler is accessible
if 'scheduler' not in locals():
    raise NameError("Scheduler not defined. Run the model initialization cell.")

# --- AMP GradScaler --- #
# scaler = GradScaler() # Removed scaler initialization

# --- Tracking Variables --- #
best_val_loss = float('inf') # Using validation loss for checkpointing
epochs_no_improve = 0
last_saved_epoch = -1 # Track last saved epoch

# --- Output Directory & Run Name (should be set in config cell) --- #
# Ensure output_dir is a Path object if loaded from config str
run_timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
run_name = f"{config.get('wander_run_name_prefix', 'train')}_{run_timestamp}"
output_dir_base = Path(config.get('project_root', '.')) / "training_runs"
output_dir = output_dir_base / run_name
output_dir = Path(output_dir)
config["output_dir"] = output_dir
print(f"Checkpoints and logs will be saved to: {output_dir}")

# --- Retrieve UHI Stats from Config --- #
uhi_mean = config.get('uhi_mean')
uhi_std = config.get('uhi_std')
if uhi_mean is None or uhi_std is None:
    raise ValueError("uhi_mean/uhi_std not in config. Run data loading cell.")
print(f"Using Training UHI Mean: {uhi_mean:.4f}, Std Dev: {uhi_std:.4f}")

# --- WANDB Init --- #
if wandb:
    try:
        if wandb.run is not None: wandb.finish() # Finish previous run if any
        wandb.init(
            project=config["wandb_project_name"],
            name=run_name,
            config=config
        )
        print(f"Wandb initialized for run: {run_name}")
    except Exception as e:
        print(f"Wandb initialization failed: {e}")
        wandb = None
else:
    print("Wandb not available, skipping logging.")

# --- Training Loop --- # 
print(f"Starting {config['model_type']} training...")
training_start_time = time.time()
training_log = [] # Local log
epoch = -1 # Initialize epoch counter before the loop
start_checkpointing_epoch = 50 # <<< NEW: Epoch to start saving best model and early stopping

try:
    # Use epoch range from config
    for epoch in range(config["epochs"]):
        epoch_start_time = time.time()
        print(f"--- Epoch {epoch+1}/{config['epochs']} ---")

        # --- Train --- #
        if train_loader:
            # Use generic train function from train_utils (without scaler)
            train_loss, train_rmse, train_r2 = train_utils.train_epoch_generic(
                model, train_loader, optimizer, loss_fn, device, uhi_mean, uhi_std # Removed scaler
            )
            print(f"Train Loss: {train_loss:.4f}, Train RMSE: {train_rmse:.4f}, Train R2: {train_r2:.4f}")
            if np.isnan(train_loss):
                print("Warning: Training loss is NaN. Stopping training.")
                break
            log_metrics = {"epoch": epoch + 1, "train_loss": train_loss, "train_rmse": train_rmse, "train_r2": train_r2}
        else:
            print("Skipping training: train_loader is None.")
            train_loss, train_rmse, train_r2 = float('nan'), float('nan'), float('nan')
            log_metrics = {"epoch": epoch + 1, "train_loss": train_loss, "train_rmse": train_rmse, "train_r2": train_r2}
        
        # Log train metrics AFTER checking for NaN
        if wandb:
            wandb.log(log_metrics)
        training_log.append(log_metrics) # Append to local log regardless of W&B


        # --- Validate --- #
        if val_loader:
            # Use generic validate function from train_utils
            val_loss, val_rmse, val_r2 = train_utils.validate_epoch_generic(
                model, val_loader, loss_fn, device, uhi_mean, uhi_std
            )
            print(f"Val Loss:   {val_loss:.4f}, Val RMSE:   {val_rmse:.4f}, Val R2:   {val_r2:.4f}")
            if np.isnan(val_loss):
                print("Warning: Validation Loss is NaN. Stopping training.")
                break
            val_metrics = {"val_loss": val_loss, "val_rmse": val_rmse, "val_r2": val_r2}
            log_metrics.update(val_metrics) # Add val metrics for local log
            if wandb:
                wandb.log({"epoch": epoch + 1, **val_metrics}) # Log validation metrics too

            # --- Step the scheduler --- #
            scheduler.step(val_loss)

            # --- Checkpointing & Early Stopping (after start_checkpointing_epoch) --- #
            if epoch + 1 > start_checkpointing_epoch:
                is_best = val_loss < best_val_loss
                if is_best:
                    print(f"Validation Loss improved from {best_val_loss:.4f} to {val_loss:.4f}")
                    best_val_loss = val_loss
                    epochs_no_improve = 0
                    last_saved_epoch = epoch + 1
                    # Use save_checkpoint from train_utils
                    train_utils.save_checkpoint({
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'best_val_loss': best_val_loss, # Save best loss
                        'config': config # Save full config dict
                    }, is_best=True, output_dir=output_dir)
                else:
                    epochs_no_improve += 1
                    print(f"Validation Loss did not improve ({val_loss:.4f}). Best: {best_val_loss:.4f}. No improvement for {epochs_no_improve} epochs.")

                if epochs_no_improve >= config['patience']:
                    print(f"Early stopping triggered after {epochs_no_improve} epochs without validation loss improvement.")
                    break
            else:
                print(f"Epoch {epoch+1} <= {start_checkpointing_epoch}, skipping best model check/save.")

        else:
            print("Skipping validation/checkpointing: val_loader is None.")
            # Always save last checkpoint if no validation
            train_utils.save_checkpoint({
                'epoch': epoch + 1,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'best_val_loss': best_val_loss, # Save current best loss seen
                'config': config
            }, is_best=False, output_dir=output_dir, filename='checkpoint_last.pth.tar')
            last_saved_epoch = epoch + 1 # Update tracker

        # --- Epoch Timing --- #
        epoch_duration = time.time() - epoch_start_time
        print(f"Epoch {epoch+1} duration: {epoch_duration:.2f} seconds")
        if wandb:
            wandb.log({"epoch": epoch + 1, "epoch_duration_sec": epoch_duration})

finally:
    # --- End Training Actions (Executed even if loop breaks early) --- #
    training_duration = time.time() - training_start_time
    print(f"\nTotal training time: {training_duration / 60:.2f} minutes")

    # --- Save Final Checkpoint --- #
    # Use the state from the last *completed* epoch (before potential break)
    final_epoch_num = epoch + 1 # This will be correct whether the loop finished or broke
    print(f"Saving final model state from epoch {final_epoch_num}...")
    try:
        train_utils.save_checkpoint({
            'epoch': final_epoch_num,
            'state_dict': model.state_dict(),
            'optimizer': optimizer.state_dict(),
            'best_val_loss': best_val_loss,
            'config': config
        }, is_best=False, output_dir=output_dir, filename='checkpoint_final.pth.tar')
        print(f"Final checkpoint saved to {output_dir / 'checkpoint_final.pth.tar'}")
    except Exception as e:
        print(f"Error saving final checkpoint: {e}")

    # --- Save Local Training Log --- #
    if training_log:
        try:
            log_df = pd.DataFrame(training_log)
            log_df.to_csv(output_dir / 'training_log.csv', index=False)
            print(f"Saved local training log to {output_dir / 'training_log.csv'}")
        except Exception as e:
            print(f"Warning: Failed to save local training log: {e}")
    else:
        print("No training log data to save.")


    if wandb and wandb.run: # Check if wandb run exists before logging/finishing
        wandb.log({"total_training_time_min": training_duration / 60})
        wandb.finish()
        print("W&B run finished.")

    print("Training loop finished.")
    if val_loader:
        print(f"Best validation Loss recorded: {best_val_loss:.4f} (achieved at epoch {last_saved_epoch if last_saved_epoch > 0 else 'N/A'})")
    if last_saved_epoch > 0:
         print(f"Best model checkpoint saved in: {output_dir / 'model_best.pth.tar'}")
