# Branched UHI Model Training

In [None]:
# %% Imports and Setup

# --- Standard Libraries ---
import os
import sys
from pathlib import Path
import json
import logging
import warnings
import time
from datetime import datetime

# --- Data Handling ---
import numpy as np
import pandas as pd
import xarray as xr
import rioxarray # For geospatial data handling with xarray

# --- PyTorch ---
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset # Added Subset


# --- Visualization & Progress ---
# Optional: Import matplotlib or other plotting libs if needed for checks
from tqdm.notebook import tqdm # Use notebook version if running interactively
import wandb

# --- Custom Modules ---
# Project root is the parent directory of the current working directory
# (the notebook is expected to be launched from a direct subfolder of the repo).
project_root = Path.cwd().parent
src = project_root / "src"  # Location of the project source tree

# Make `import src.<...>` resolve to this repository. Insert at the FRONT of
# sys.path (rather than appending) so an unrelated top-level `src` package that
# happens to be installed elsewhere on the path cannot shadow the project code.
_project_root_entry = str(project_root)
if _project_root_entry not in sys.path:
    sys.path.insert(0, _project_root_entry)
    print(f"Added project root to sys.path: {_project_root_entry}")

# Import custom modules
from src.ingest.dataloader_branched import CityDataSetBranched
from src.branched_uhi_model import BranchedUHIModel
from src.train.loss import masked_mse_loss, masked_mae_loss # Import loss functions
import src.train.train_utils as train_utils # Import the utility module

# --- Environment Setup ---
# Configure root logging (INFO level, timestamped messages)
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# Optional: Ignore specific warnings if needed
# warnings.filterwarnings('ignore', category=UserWarning, message='.*TypedStorage is deprecated.*')

print(f"Project Root: {project_root}")
print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA Device Name: {torch.cuda.get_device_name(0)}")

['/opt/conda/lib/python310.zip', '/opt/conda/lib/python3.10', '/opt/conda/lib/python3.10/lib-dynload', '', '/opt/conda/lib/python3.10/site-packages', '/home/jupyter/MLC-Project']
Project Root: /home/jupyter/MLC-Project
PyTorch Version: 2.7.0+cu126
CUDA Available: True
CUDA Device Name: NVIDIA L4


In [2]:
# %% Configuration / Hyperparameters for BranchedUHIModel (ConvLSTM + Common Resampling)

# --- Import utils ---
from src.train.train_utils import check_path # For path validation
from src.ingest.data_utils import calculate_actual_weather_channels # For dynamic weather channels
# -------------------

# --- Paths & Basic Info ---
# project_root is defined in the first cell
project_root_str = str(project_root)
data_dir = project_root / "data"
city_name = "NYC"
output_dir_base = project_root / "training_runs"

# --- WANDB Config ---
wandb_project_name = "MLC_UHI_Proj"
# NOTE: the "wander_" spelling is kept on purpose: the training cell reads this exact config key.
wander_run_name_prefix = f"{city_name}_BranchedUHI"

# --- Data Loading Config ---
feature_resolution_m = 50
uhi_grid_resolution_m = 50 # UHI target grid
temporal_seq_len = 60

# --- Weather Feature Selection ---
enabled_weather_features = [
    "air_temp",
    "rel_humidity",
    "avg_windspeed",
    "wind_direction",
    "solar_flux"
]
# Number of weather channels the dataloader will actually produce for these features.
# This can differ from len(enabled_weather_features): the last run logged 6 channels for 5 features.
actual_dataloader_weather_channels = calculate_actual_weather_channels(enabled_weather_features)
print(f"Dataloader weather channels for {enabled_weather_features}: {actual_dataloader_weather_channels}")
# ------------------------------------

# --- Define Absolute Input Data Paths Directly ---
uhi_csv_path = data_dir / city_name / "uhi.csv"
bronx_weather_csv_path = data_dir / city_name / "bronx_weather.csv"
manhattan_weather_csv_path = data_dir / city_name / "manhattan_weather.csv"

dem_path = data_dir / city_name / "sat_files" / f"{city_name.lower()}_dem_nasadem_native-resolution_pc.tif"
dsm_path = data_dir / city_name / "sat_files" / f"{city_name.lower()}_dsm_cop-dem-glo-30_native-resolution_pc.tif"
cloudless_mosaic_path = data_dir / city_name / "sat_files" / f"sentinel_{city_name}_20210601_to_20210901_cloudless_mosaic.npy"
# LST filename structure must match the one produced by download_data.ipynb
lst_time_window_str_for_filename = "20210601_to_20210901"
single_lst_median_path = data_dir / city_name / "sat_files" / f"lst_{city_name}_median_{lst_time_window_str_for_filename}.npy"


# Nodata values for the elevation (DEM/DSM) and LST rasters
elevation_nodata = np.nan
lst_nodata = np.nan

# --- Feature Selection Flags ---
feature_flags = {
    "use_dem": False,
    "use_dsm": True,
    "use_clay": True,
    "use_sentinel_composite": False,
    "use_lst": False, # Set to True if LST is intended to be used
    "use_ndvi": False,
    "use_ndbi": False,
    "use_ndwi": False,
}

# --- Bands for Sentinel Composite (if use_sentinel_composite is True) ---
sentinel_bands_to_load = []

# --- Model Config (BranchedUHIModel with ConvLSTM, No separate Elev branches) ---
# Clay Backbone
clay_model_size = "large"
clay_bands = ["blue", "green", "red", "nir"]
clay_platform = "sentinel-2-l2a"
clay_gsd = 10
freeze_backbone = True # Keep Clay backbone frozen for BranchedUHIModel typically
clay_checkpoint_path = project_root / "notebooks" / "clay-v1.5.ckpt"
clay_metadata_path = project_root / "src" / "Clay" / "configs" / "metadata.yaml"

# Temporal Weather Processor (ConvLSTM)
convlstm_hidden_dims = [32, 16]
convlstm_kernel_sizes = [(3,3), (3,3)]
convlstm_num_layers = len(convlstm_hidden_dims)

# Projection Layer Channels
clay_proj_channels = 16
proj_static_ch = 2
proj_temporal_ch = 16
projection_dropout_rate = 0.1

# --- Head Configuration ---
head_type = "unet" # Options: "unet" or "simple_cnn"
# U-Net specific
unet_base_channels = 32
unet_depth = 3
unet_dropout_rate = 0.3 # Dropout for U-Net blocks
# SimpleCNN specific
simple_cnn_hidden_dims = [32, 16]
simple_cnn_kernel_size = 3
simple_cnn_dropout_rate = 0.1 # Dropout for SimpleCNN Head

# --- Training Hyperparameters ---
num_workers = 1
epochs = 500
lr = 5e-5 # From previous successful run
weight_decay = 5e-4 # From previous successful run
loss_type = 'mse'
patience = 50 # Early stopping patience
reduce_lr_patience = 20 # Patience for ReduceLROnPlateau scheduler
cpu = False
max_grad_norm = 1.0
# Number of training batches per epoch. The per-batch size is derived from this by
# create_dataloaders (with 47 training samples the last run logged a train batch size of 3,
# not 47 // 11 = 4) — check create_dataloaders for the exact rule.
n_train_batches = 11

# --- Device Setup ---
device = torch.device("cuda" if torch.cuda.is_available() and not cpu else "cpu")
print(f"Using device: {device}")

# --- Validate Paths (using check_path for files that *must* exist) ---
uhi_csv_path = check_path(uhi_csv_path, project_root, "UHI CSV")
bronx_weather_csv_path = check_path(bronx_weather_csv_path, project_root, "Bronx Weather CSV")
manhattan_weather_csv_path = check_path(manhattan_weather_csv_path, project_root, "Manhattan Weather CSV")

if feature_flags["use_dem"]:
    dem_path = check_path(dem_path, project_root, "DEM TIF")
if feature_flags["use_dsm"]:
    dsm_path = check_path(dsm_path, project_root, "DSM TIF")
if feature_flags["use_clay"]:
    clay_checkpoint_path = check_path(clay_checkpoint_path, project_root, "Clay Checkpoint")
    clay_metadata_path = check_path(clay_metadata_path, project_root, "Clay Metadata")
if feature_flags["use_clay"] or feature_flags["use_sentinel_composite"]:
    cloudless_mosaic_path = check_path(cloudless_mosaic_path, project_root, "Cloudless Mosaic")
if feature_flags["use_lst"]:
    single_lst_median_path = check_path(single_lst_median_path, project_root, "Single LST Median", should_exist=True)


# --- Calculate Bounds ---
# Bounds are [min_lon, min_lat, max_lon, max_lat] of the UHI measurement points.
uhi_df = pd.read_csv(uhi_csv_path)
required_cols = ['Longitude', 'Latitude']
if not all(col in uhi_df.columns for col in required_cols):
    raise ValueError(f"UHI CSV must contain columns: {required_cols}")
# An empty file or an all-NaN coordinate column would otherwise yield NaN bounds silently.
if uhi_df[required_cols].isna().all().any():
    raise ValueError(f"UHI CSV {uhi_csv_path} has no valid {required_cols} values; cannot compute bounds.")
# Plain Python floats (pandas returns numpy scalars) keep logs and config.json readable.
bounds = [
    float(uhi_df['Longitude'].min()),
    float(uhi_df['Latitude'].min()),
    float(uhi_df['Longitude'].max()),
    float(uhi_df['Latitude'].max()),
]
print(f"Loaded bounds from {uhi_csv_path.name}: {bounds}")


# --- Central Config Dictionary --- #
config = {
    # Paths & Info
    "model_type": "BranchedUHIModel",
    "project_root": project_root_str,
    "city_name": city_name,
    "wandb_project_name": wandb_project_name,
    "wander_run_name_prefix": wander_run_name_prefix,
    # Data Loading
    "feature_resolution_m": feature_resolution_m,
    "uhi_grid_resolution_m": uhi_grid_resolution_m,
    "temporal_seq_len": temporal_seq_len,
    "clay_proj_channels": clay_proj_channels,
    "enabled_weather_features": enabled_weather_features,
    "uhi_csv": str(uhi_csv_path),
    "bronx_weather_csv": str(bronx_weather_csv_path),
    "manhattan_weather_csv": str(manhattan_weather_csv_path),
    "bounds": bounds,
    "feature_flags": feature_flags,
    "sentinel_bands_to_load": sentinel_bands_to_load,
    "dem_path": str(dem_path) if feature_flags["use_dem"] else None,
    "dsm_path": str(dsm_path) if feature_flags["use_dsm"] else None,
    "elevation_nodata": elevation_nodata,
    "cloudless_mosaic_path": str(cloudless_mosaic_path) if feature_flags["use_clay"] or feature_flags["use_sentinel_composite"] else None,
    "single_lst_median_path": str(single_lst_median_path) if feature_flags["use_lst"] else None,
    "lst_nodata": lst_nodata,
    # Model Config
    "convlstm_hidden_dims": convlstm_hidden_dims,
    "convlstm_kernel_sizes": convlstm_kernel_sizes,
    "convlstm_num_layers": convlstm_num_layers,
    "proj_static_ch": proj_static_ch,
    "proj_temporal_ch": proj_temporal_ch,
    "projection_dropout_rate": projection_dropout_rate,
    "head_type": head_type,
    "unet_base_channels": unet_base_channels if head_type == "unet" else None,
    "unet_depth": unet_depth if head_type == "unet" else None,
    "unet_dropout_rate": unet_dropout_rate if head_type == "unet" else None,
    "simple_cnn_hidden_dims": simple_cnn_hidden_dims if head_type == "simple_cnn" else None,
    "simple_cnn_kernel_size": simple_cnn_kernel_size if head_type == "simple_cnn" else None,
    "simple_cnn_dropout_rate": simple_cnn_dropout_rate if head_type == "simple_cnn" else None,
    # Clay specific
    "clay_model_size": clay_model_size,
    "clay_bands": clay_bands,
    "clay_platform": clay_platform,
    "clay_gsd": clay_gsd,
    "freeze_backbone": freeze_backbone,
    "clay_checkpoint_path": str(clay_checkpoint_path) if feature_flags["use_clay"] else None,
    "clay_metadata_path": str(clay_metadata_path) if feature_flags["use_clay"] else None,
    # Training Hyperparameters
    "n_train_batches": n_train_batches,
    "num_workers": num_workers,
    "epochs": epochs,
    "lr": lr,
    "weight_decay": weight_decay,
    "loss_type": loss_type,
    "patience": patience, # For EarlyStopping
    "reduce_lr_patience": reduce_lr_patience, # For ReduceLROnPlateau
    "max_grad_norm": max_grad_norm,
    "device": str(device)
}

# --- Create Run Directory & Update Config ---
run_timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
run_name_suffix = f"{config['model_type']}_{city_name}_{config['head_type']}_{run_timestamp}"
run_dir = output_dir_base / run_name_suffix
run_dir.mkdir(parents=True, exist_ok=True)
config["run_dir"] = str(run_dir)


def _json_default(obj):
    """Serialize the non-JSON types that can appear in the config.

    Paths and torch devices become strings; numpy scalars/arrays become Python
    numbers/lists. Anything else raises TypeError (json's standard behavior)
    rather than being returned unchanged, which json would report as a
    misleading "Circular reference detected" error.
    """
    if isinstance(obj, (Path, torch.device)):
        return str(obj)
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


print(f"Run directory: {run_dir}")
print("\nBranched Model Configuration dictionary created:")
print(json.dumps(config, indent=2, default=_json_default))


Using device: cuda
Loaded bounds from uhi.csv: [np.float64(-73.99445667), np.float64(40.75879167), np.float64(-73.87945833), np.float64(40.85949667)]
Run directory: /home/jupyter/MLC-Project/training_runs/BranchedUHIModel_NYC_unet_20250507_172717
\nBranched Model Configuration dictionary created:
{
  "model_type": "BranchedUHIModel",
  "project_root": "/home/jupyter/MLC-Project",
  "city_name": "NYC",
  "wandb_project_name": "MLC_UHI_Proj",
  "wander_run_name_prefix": "NYC_BranchedUHI",
  "feature_resolution_m": 50,
  "uhi_grid_resolution_m": 50,
  "temporal_seq_len": 60,
  "clay_proj_channels": 16,
  "enabled_weather_features": [
    "air_temp",
    "rel_humidity",
    "avg_windspeed",
    "wind_direction",
    "solar_flux"
  ],
  "uhi_csv": "/home/jupyter/MLC-Project/data/NYC/uhi.csv",
  "bronx_weather_csv": "/home/jupyter/MLC-Project/data/NYC/bronx_weather.csv",
  "manhattan_weather_csv": "/home/jupyter/MLC-Project/data/NYC/manhattan_weather.csv",
  "bounds": [
    -73.99445667,
   

In [3]:
# %% Data Loading and Preprocessing (Branched Model + Common Resampling)

# --- Import utils ---
from src.train.train_utils import (
    calculate_uhi_stats,
    create_dataloaders,
)
from torch.utils.data import Subset
# -------------------

print("Initializing BranchedCityDataSet...")
try:
    dataset = CityDataSetBranched(
        bounds=config["bounds"],
        feature_resolution_m=config["feature_resolution_m"],
        uhi_grid_resolution_m=config["uhi_grid_resolution_m"],
        uhi_csv=config["uhi_csv"],
        bronx_weather_csv=config["bronx_weather_csv"],
        manhattan_weather_csv=config["manhattan_weather_csv"],
        # NOTE(review): this passes the project root, not the `data_dir` defined in the
        # config cell — confirm which one CityDataSetBranched expects.
        data_dir=project_root_str,
        city_name=config["city_name"],
        feature_flags=config["feature_flags"],
        enabled_weather_features=config["enabled_weather_features"],
        sentinel_bands_to_load=config["sentinel_bands_to_load"],
        dem_path=config["dem_path"],
        dsm_path=config["dsm_path"],
        elevation_nodata=config["elevation_nodata"],
        cloudless_mosaic_path=config["cloudless_mosaic_path"],
        single_lst_median_path=config["single_lst_median_path"],
        lst_nodata=config["lst_nodata"],
        temporal_seq_len=config["temporal_seq_len"],
        target_crs_str=config.get("target_crs_str", "EPSG:4326")
        # The previous-UHI autoregressive input is always produced by the dataloader.
    )
except FileNotFoundError as e:
    print(f"Dataset initialization failed: {e}")
    print("Ensure required data files (DEM, DSM, weather, UHI, potentially mosaic/LST) exist.")
    print("Run `notebooks/download_data.ipynb` first.")
    raise
except Exception as e:
    print(f"Unexpected error during dataset initialization: {e}")
    raise

# --- Sequential Train/Val Split --- #
# No shuffling: the first (1 - val_percent) of the samples train, the tail validates.
# NOTE(review): presumably samples are in chronological order so this avoids using
# later samples for training — confirm against CityDataSetBranched.
val_percent = 0.20
num_samples = len(dataset)
if num_samples < 2: # Need at least one for train and one for val
    raise ValueError(f"Dataset has only {num_samples} samples, cannot perform train/val split.")

n_train = int(num_samples * (1 - val_percent))
n_val = num_samples - n_train

if n_train == 0 or n_val == 0:
    raise ValueError(f"Split resulted in zero samples for train ({n_train}) or validation ({n_val}). Adjust val_percent or check dataset size.")

train_indices = list(range(n_train))
val_indices = list(range(n_train, num_samples))

train_ds = Subset(dataset, train_indices)
val_ds = Subset(dataset, val_indices)

print(f"Sequential dataset split: {len(train_ds)} training (indices 0-{n_train-1}), {len(val_ds)} validation (indices {n_train}-{num_samples-1}) samples.")

# --- Calculate UHI Mean and Std from Training Data ONLY --- #
uhi_mean, uhi_std = calculate_uhi_stats(train_ds)
# Plain floats keep config.json / W&B config clean regardless of the helper's numeric type.
uhi_mean, uhi_std = float(uhi_mean), float(uhi_std)
# These stats are handed to the train/val loops; a zero or non-finite std would
# (presumably, via target normalization) produce inf/NaN losses, so fail early.
if not (np.isfinite(uhi_mean) and np.isfinite(uhi_std) and uhi_std > 0):
    raise ValueError(f"Invalid training UHI statistics: mean={uhi_mean}, std={uhi_std}.")
config['uhi_mean'] = uhi_mean
config['uhi_std'] = uhi_std

# --- Create DataLoaders --- #
train_loader, val_loader = create_dataloaders(
    train_ds,
    val_ds,
    n_train_batches=config['n_train_batches'],
    num_workers=config['num_workers'],
    device=device # Device selected in the config cell
)


2025-05-07 17:27:17,797 - INFO - Dataloader configured to include previous UHI grid for autoregression (always active).
2025-05-07 17:27:17,798 - INFO - Dataloader will produce 6 weather channels based on enabled features: ['air_temp', 'rel_humidity', 'avg_windspeed', 'wind_direction', 'solar_flux']
2025-05-07 17:27:17,799 - INFO - Target FEATURE grid size (H, W): (224, 194) @ 50m, CRS: EPSG:4326
2025-05-07 17:27:17,799 - INFO - Target UHI grid size (H, W): (224, 194) @ 50m


Initializing BranchedCityDataSet...


Precomputing UHI grids: 100%|██████████| 59/59 [00:00<00:00, 3704.05it/s]
2025-05-07 17:27:17,854 - INFO - Loading DSM from: /home/jupyter/MLC-Project/data/NYC/sat_files/nyc_dsm_cop-dem-glo-30_native-resolution_pc.tif
2025-05-07 17:27:17,866 - INFO - DSM loaded raw shape: (1, 364, 415)
2025-05-07 17:27:17,878 - INFO - Clipping DSM to bounds: [np.float64(-73.99445667), np.float64(40.75879167), np.float64(-73.87945833), np.float64(40.85949667)]
2025-05-07 17:27:17,879 - INFO - Opened DSM (lazy load). Native shape (approx): (1, 364, 415)
2025-05-07 17:27:17,880 - INFO - Calculating global DSM 2nd/98th percentiles...
2025-05-07 17:27:17,885 - INFO - Global DSM p2: 0.00, p98: 94.01
2025-05-07 17:27:17,886 - INFO - Loading cloudless mosaic from /home/jupyter/MLC-Project/data/NYC/sat_files/sentinel_NYC_20210601_to_20210901_cloudless_mosaic.npy with memory mapping
2025-05-07 17:27:17,887 - INFO - Loaded mosaic shape (native res): (5, 1119, 1278)
2025-05-07 17:27:17,895 - INFO - Loaded Bronx we

Sequential dataset split: 47 training (indices 0-46), 12 validation (indices 47-58) samples.


Calculating stats: 100%|██████████| 47/47 [00:00<00:00, 9596.55it/s]
2025-05-07 17:27:17,911 - INFO - Training UHI Mean: 1.0004, Std Dev: 0.0169
2025-05-07 17:27:17,912 - INFO - Creating dataloaders...
2025-05-07 17:27:17,913 - INFO - Using Train Batch Size: 3
2025-05-07 17:27:17,915 - INFO - Using Validation Batch Size: 1
2025-05-07 17:27:17,915 - INFO - Data loading setup complete.


In [4]:
# %% Model Initialization (Branched Model + Common Resampling)

# --- Import necessary components ---
from src.branched_uhi_model import BranchedUHIModel
from src.train.loss import masked_mse_loss, masked_mae_loss
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import logging

# Instantiate the BranchedUHIModel from the central config dictionary
print(f"Initializing {config['model_type']}...")
try:
    model = BranchedUHIModel(
        # --- Weather Branch Config --- #
        enabled_weather_features=config["enabled_weather_features"],
        convlstm_hidden_dims=config["convlstm_hidden_dims"],
        convlstm_kernel_sizes=config["convlstm_kernel_sizes"],
        convlstm_num_layers=config["convlstm_num_layers"],
        temporal_seq_len=config["temporal_seq_len"],
        # --- Static Feature Config --- #
        feature_flags=config["feature_flags"],
        clay_proj_channels=config["clay_proj_channels"],
        sentinel_bands_to_load=config.get("sentinel_bands_to_load"),
        # Clay Specific
        clay_model_size=config.get("clay_model_size"),
        clay_bands=config.get("clay_bands"),
        clay_platform=config.get("clay_platform"),
        clay_gsd=config.get("clay_gsd"),
        freeze_backbone=config.get("freeze_backbone", True),
        clay_checkpoint_path=config.get("clay_checkpoint_path"),
        clay_metadata_path=config.get("clay_metadata_path"),
        # --- Projection Config --- #
        proj_static_ch=config["proj_static_ch"],
        proj_temporal_ch=config["proj_temporal_ch"],
        projection_dropout_rate=config.get("projection_dropout_rate", 0.0),
        # --- Head Config --- #
        head_type=config["head_type"],
        unet_base_channels=config.get("unet_base_channels"),
        unet_depth=config.get("unet_depth"),
        unet_dropout_rate=config.get("unet_dropout_rate"),
        simple_cnn_hidden_dims=config.get("simple_cnn_hidden_dims"),
        simple_cnn_kernel_size=config.get("simple_cnn_kernel_size"),
        simple_cnn_dropout_rate=config.get("simple_cnn_dropout_rate", 0.1),
        # --- Target Grid Info --- #
        uhi_grid_resolution_m=config["uhi_grid_resolution_m"],
        bounds=config["bounds"]
    )
    model.to(config["device"])
    print(f"{config['model_type']} initialized successfully.")

except Exception as e:
    logging.error(f"Error initializing BranchedUHIModel: {e}", exc_info=True)
    raise # Re-raise the exception after logging

# --- Optimizer --- #
try:
    optimizer = optim.AdamW(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])
    print("Optimizer (AdamW) initialized.")
except Exception as e:
    logging.error(f"Error initializing optimizer: {e}", exc_info=True)
    raise

# --- Loss Function --- #
if config["loss_type"] == 'mse':
    loss_fn = masked_mse_loss
    print("Loss function set to masked_mse_loss.")
elif config["loss_type"] == 'mae':
    loss_fn = masked_mae_loss
    print("Loss function set to masked_mae_loss.")
else:
    raise ValueError(f"Unsupported loss type: {config['loss_type']}")

# --- LR Scheduler --- #
# Best-effort: training can still proceed with a fixed LR if the scheduler cannot be built.
try:
    # Halve the LR when the monitored (validation) loss plateaus for reduce_lr_patience epochs
    scheduler = ReduceLROnPlateau(optimizer, 'min', patience=config['reduce_lr_patience'], factor=0.5)
    print(f"Initialized ReduceLROnPlateau scheduler with patience={config['reduce_lr_patience']}.")
except Exception as e:
    logging.error(f"Error initializing scheduler: {e}", exc_info=True)
    scheduler = None
    print("Proceeding without LR scheduler due to initialization error.")


print("\nModel, optimizer, loss function, and scheduler setup complete.")


2025-05-07 17:27:17,928 - INFO - BranchedModel configured for target output UHI grid: (224, 194)
2025-05-07 17:27:17,939 - INFO - ConvLSTM input dimension set to 7 (Weather: 6 + Prev UHI: 1)


Initializing BranchedUHIModel...
Manually loading checkpoint: /home/jupyter/MLC-Project/notebooks/clay-v1.5.ckpt
Instantiating ClayMAEModule manually...


2025-05-07 17:27:25,129 - INFO - Loading pretrained weights from Hugging Face hub (timm/vit_large_patch14_reg4_dinov2.lvd142m)
2025-05-07 17:27:25,247 - INFO - [timm/vit_large_patch14_reg4_dinov2.lvd142m] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.


Loading state_dict manually into self.model.model...


2025-05-07 17:27:28,978 - INFO - Identified final encoder layer as self.model.model.proj
2025-05-07 17:27:28,979 - INFO - Keeping Clay backbone frozen.
2025-05-07 17:27:28,984 - INFO - ClayFeatureExtractor output channels set to: 1024
2025-05-07 17:27:29,052 - INFO - Added BatchNorm2d before Clay projection for 1024 channels
2025-05-07 17:27:29,054 - INFO - Added Clay projection Conv1x1: 1024 -> 16 channels
2025-05-07 17:27:29,055 - INFO - Total input channels for STATIC projection: 1
2025-05-07 17:27:29,055 - INFO - ConvLSTM input channels (actual_weather_input_channels): 6
2025-05-07 17:27:29,056 - INFO - Static projection: 1 -> 2 channels
2025-05-07 17:27:29,057 - INFO - Temporal projection: 16 -> 16 channels
2025-05-07 17:27:29,058 - INFO - Adding clay_output_channels (16) to head input channels
2025-05-07 17:27:29,071 - INFO - Initialized UNetDecoder. In channels: 34, Base channels: 32, Depth: 3
2025-05-07 17:27:29,072 - INFO - BranchedModel using UNetDecoder head. Output channels

Clay model properties: model_size=large, embed_dim=1024, patch_size=16 (patch_size OVERRIDDEN)
Normalization prepared for bands: ['blue', 'green', 'red', 'nir']
BranchedUHIModel initialized successfully.
Optimizer (AdamW) initialized.
Loss function set to masked_mse_loss.
Initialized ReduceLROnPlateau scheduler.
\nModel, optimizer, loss function, and scheduler setup complete.


In [5]:
# %% Training Loop (Branched Model)

# --- Helper functions ---
# Number of initial epochs during which checkpointing and early stopping are
# skipped (overridable via config["warmup_epochs"], default 5).
warmup_epochs = config.get("warmup_epochs", 5)


def _json_default(obj):
    """Serialize the non-JSON types that can appear in the config.

    Paths and torch devices become strings; numpy scalars/arrays become Python
    numbers/lists. Anything else raises TypeError (json's standard behavior)
    rather than being returned unchanged, which json would report as a
    misleading "Circular reference detected" error.
    """
    if isinstance(obj, (Path, torch.device)):
        return str(obj)
    if isinstance(obj, np.generic):
        return obj.item()
    if isinstance(obj, np.ndarray):
        return obj.tolist()
    raise TypeError(f"Object of type {type(obj).__name__} is not JSON serializable")


def _log_epoch(metrics):
    """Record one epoch's metrics locally and, if enabled, in W&B with a single log call.

    One wandb.log call per epoch keeps train, validation and LR metrics on the
    same W&B step instead of spreading them over several steps.
    """
    training_log.append(metrics)
    if wandb_available:
        wandb.log(metrics)


# --- Training state ---
best_val_loss = float('inf')
best_val_r2 = -float('inf')
patience_counter = 0
training_log = []
epoch = -1  # so `epoch + 1` is 0 in the final checkpoint if the loop never runs

# Checkpoint directory inside the run directory created by the config cell
model_save_dir = Path(config['run_dir']) / "checkpoints"
model_save_dir.mkdir(parents=True, exist_ok=True)

# Save config to run directory
config_path = Path(config['run_dir']) / "config.json"
with open(config_path, 'w') as f:
    json.dump(config, f, indent=2, default=_json_default)

# Initialize wandb (optional dependency)
try:
    import wandb
    wandb_available = True
    print("Weights & Biases (wandb) available for logging.")
except ImportError:
    wandb_available = False
    wandb = None
    print("Weights & Biases (wandb) not available. Skipping wandb logging.")

if wandb_available:
    wandb.init(
        project=config['wandb_project_name'],
        name=f"{config['wander_run_name_prefix']}_{datetime.now().strftime('%Y%m%d_%H%M')}",
        config=config
    )

    # Optional: watch model parameters
    wandb.watch(model)

print(f"Starting training for {config['epochs']} epochs with patience {config['patience']}")

try:
    for epoch in range(config['epochs']):
        epoch_start = time.time()
        val_loss = None  # stays None when there is no validation loader
        stop_early = False

        # --- Train --- #
        if train_loader:
            train_loss, train_rmse, train_r2 = train_utils.train_epoch_generic(
                model, train_loader, optimizer, loss_fn, device,
                uhi_mean=config['uhi_mean'],
                uhi_std=config['uhi_std'],
                feature_flags=config['feature_flags'],
                max_grad_norm=config.get("max_grad_norm", 1.0)
            )
            print(f"Train Loss: {train_loss:.4f}, Train RMSE: {train_rmse:.4f}, Train R2: {train_r2:.4f}")
            if np.isnan(train_loss):
                print("Warning: Training loss is NaN. Stopping training.")
                break
        else:
            print("Skipping training: train_loader is None.")
            train_loss, train_rmse, train_r2 = float('nan'), float('nan'), float('nan')
        log_metrics = {"epoch": epoch + 1, "train_loss": train_loss, "train_rmse": train_rmse, "train_r2": train_r2}

        # --- Validate --- #
        if val_loader:
            val_loss, val_rmse, val_r2 = train_utils.validate_epoch_generic(
                model, val_loader, loss_fn, device,
                uhi_mean=config['uhi_mean'],
                uhi_std=config['uhi_std'],
                feature_flags=config['feature_flags']
            )
            print(f"Val Loss:   {val_loss:.4f}, Val RMSE:   {val_rmse:.4f}, Val R2:   {val_r2:.4f}")
            log_metrics.update({"val_loss": val_loss, "val_rmse": val_rmse, "val_r2": val_r2})
            if np.isnan(val_loss):
                print("Warning: Validation Loss is NaN. Stopping training.")
                _log_epoch(log_metrics)
                break

            # Warmup period: don't save or check early stopping until after warmup_epochs
            if epoch >= warmup_epochs:
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_val_r2 = val_r2
                    patience_counter = 0

                    # Save best model
                    train_utils.save_checkpoint({
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict() if scheduler else None,
                        'loss': val_loss,
                        'val_rmse': val_rmse,
                        'val_r2': val_r2,
                        'config': config
                    }, is_best=True, output_dir=model_save_dir)
                    print(f"New best model saved at epoch {epoch+1} with val_loss {val_loss:.4f}")
                else:
                    patience_counter += 1
                    print(f"No improvement. Patience: {patience_counter}/{config['patience']}")

                    # Early stopping check
                    if patience_counter >= config['patience']:
                        print(f"Early stopping triggered after {epoch+1} epochs")
                        stop_early = True
            else:
                print(f"Warmup epoch {epoch+1}/{warmup_epochs}. Skipping checkpointing and early stopping.")
        else:
            print("Skipping validation: val_loader is None.")

        if stop_early:
            _log_epoch(log_metrics)
            break

        # Step the scheduler after validation (if it exists)
        if scheduler is not None:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                # ReduceLROnPlateau needs a monitored value; skip when validation did not run
                if val_loss is not None:
                    scheduler.step(val_loss)
            else:
                scheduler.step()

        log_metrics["lr"] = optimizer.param_groups[0]['lr']
        _log_epoch(log_metrics)

        # Print epoch summary
        epoch_time = time.time() - epoch_start
        print(f"Epoch {epoch+1}/{config['epochs']} completed in {epoch_time:.2f}s")
        print(f"Current LR: {optimizer.param_groups[0]['lr']:.2e}")
        print("-" * 80)

    print("Training complete!")
    print(f"Best validation loss: {best_val_loss:.4f}, Best R2: {best_val_r2:.4f}")

    # Save the final model.
    # NOTE: key names differ from the best-model checkpoint written by
    # train_utils.save_checkpoint ('state_dict'/'optimizer'/'scheduler'); loaders must handle both.
    final_model_path = model_save_dir / "final_model.pt"
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'config': config
    }, final_model_path)
    print(f"Final model saved to {final_model_path}")

except Exception as e:
    print(f"Error during training: {str(e)}")
    raise
finally:
    try:
        # Persist the per-epoch log even if training crashed or was interrupted.
        log_df = pd.DataFrame(training_log)
        log_path = Path(config['run_dir']) / "training_log.csv"
        log_df.to_csv(log_path, index=False)
        print(f"Training log saved to {log_path}")
    finally:
        # Finish wandb run
        if wandb_available:
            wandb.finish()

Weights & Biases (wandb) available for logging.


[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33marnava1304[0m ([33marnava1304-columbia-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Starting training for 500 epochs with patience 50


                                                                    

Train Loss: 1.0722, Train RMSE: 0.0176, Train R2: -0.0866


                                                                           

Val Loss:   0.7241, Val RMSE:   0.0144, Val R2:   -0.0276
Warmup epoch 1/5. Skipping checkpointing and early stopping.
Epoch 1/500 completed in 51.86s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.9987, Train RMSE: 0.0169, Train R2: -0.0072


                                                                           

Val Loss:   0.7100, Val RMSE:   0.0142, Val R2:   -0.0067
Warmup epoch 2/5. Skipping checkpointing and early stopping.
Epoch 2/500 completed in 50.57s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.9760, Train RMSE: 0.0169, Train R2: 0.0011


                                                                           

Val Loss:   0.6941, Val RMSE:   0.0140, Val R2:   0.0199
Warmup epoch 3/5. Skipping checkpointing and early stopping.
Epoch 3/500 completed in 50.79s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 1.0049, Train RMSE: 0.0170, Train R2: -0.0125


                                                                           

Val Loss:   0.6675, Val RMSE:   0.0137, Val R2:   0.0607
Warmup epoch 4/5. Skipping checkpointing and early stopping.
Epoch 4/500 completed in 50.81s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.9637, Train RMSE: 0.0165, Train R2: 0.0458


                                                                           

Val Loss:   0.6554, Val RMSE:   0.0136, Val R2:   0.0739
Warmup epoch 5/5. Skipping checkpointing and early stopping.
Epoch 5/500 completed in 50.53s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.9776, Train RMSE: 0.0166, Train R2: 0.0287


                                                                           

Val Loss:   0.6663, Val RMSE:   0.0138, Val R2:   0.0596


2025-05-07 17:32:46,487 - INFO - Saved current checkpoint to /home/jupyter/MLC-Project/training_runs/BranchedUHIModel_NYC_unet_20250507_172717/checkpoints/checkpoint.pth.tar
2025-05-07 17:32:48,347 - INFO - Saved new best model to /home/jupyter/MLC-Project/training_runs/BranchedUHIModel_NYC_unet_20250507_172717/checkpoints/model_best.pth.tar


New best model saved at epoch 6 with val_loss 0.6663
Epoch 6/500 completed in 63.18s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.9241, Train RMSE: 0.0163, Train R2: 0.0683


                                                                           

Val Loss:   0.6842, Val RMSE:   0.0139, Val R2:   0.0387
No improvement. Patience: 1/50
Epoch 7/500 completed in 50.76s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.9447, Train RMSE: 0.0165, Train R2: 0.0419


                                                                           

Val Loss:   0.6707, Val RMSE:   0.0138, Val R2:   0.0586
No improvement. Patience: 2/50
Epoch 8/500 completed in 50.60s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.9222, Train RMSE: 0.0162, Train R2: 0.0765


                                                                           

Val Loss:   0.6588, Val RMSE:   0.0136, Val R2:   0.0747


2025-05-07 17:35:31,116 - INFO - Saved current checkpoint to /home/jupyter/MLC-Project/training_runs/BranchedUHIModel_NYC_unet_20250507_172717/checkpoints/checkpoint.pth.tar
2025-05-07 17:35:40,211 - INFO - Saved new best model to /home/jupyter/MLC-Project/training_runs/BranchedUHIModel_NYC_unet_20250507_172717/checkpoints/model_best.pth.tar


New best model saved at epoch 9 with val_loss 0.6588
Epoch 9/500 completed in 70.51s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.9008, Train RMSE: 0.0162, Train R2: 0.0759


                                                                           

Val Loss:   0.6683, Val RMSE:   0.0138, Val R2:   0.0558
No improvement. Patience: 1/50
Epoch 10/500 completed in 50.53s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.8849, Train RMSE: 0.0159, Train R2: 0.1151


                                                                           

Val Loss:   0.6706, Val RMSE:   0.0138, Val R2:   0.0568
No improvement. Patience: 2/50
Epoch 11/500 completed in 50.61s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.8534, Train RMSE: 0.0157, Train R2: 0.1321


                                                                           

Val Loss:   0.6453, Val RMSE:   0.0135, Val R2:   0.0915


2025-05-07 17:38:22,886 - INFO - Saved current checkpoint to /home/jupyter/MLC-Project/training_runs/BranchedUHIModel_NYC_unet_20250507_172717/checkpoints/checkpoint.pth.tar
2025-05-07 17:38:32,054 - INFO - Saved new best model to /home/jupyter/MLC-Project/training_runs/BranchedUHIModel_NYC_unet_20250507_172717/checkpoints/model_best.pth.tar


New best model saved at epoch 12 with val_loss 0.6453
Epoch 12/500 completed in 70.71s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.8827, Train RMSE: 0.0159, Train R2: 0.1133


                                                                           

Val Loss:   0.6691, Val RMSE:   0.0138, Val R2:   0.0587
No improvement. Patience: 1/50
Epoch 13/500 completed in 50.73s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.8502, Train RMSE: 0.0156, Train R2: 0.1493


                                                                           

Val Loss:   0.6685, Val RMSE:   0.0138, Val R2:   0.0574
No improvement. Patience: 2/50
Epoch 14/500 completed in 50.61s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.8784, Train RMSE: 0.0159, Train R2: 0.1164


                                                                           

Val Loss:   0.6486, Val RMSE:   0.0136, Val R2:   0.0859
No improvement. Patience: 3/50
Epoch 15/500 completed in 50.47s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.8263, Train RMSE: 0.0155, Train R2: 0.1541


                                                                           

Val Loss:   0.6629, Val RMSE:   0.0137, Val R2:   0.0607
No improvement. Patience: 4/50
Epoch 16/500 completed in 50.65s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.8756, Train RMSE: 0.0160, Train R2: 0.1014


                                                                           

Val Loss:   0.6736, Val RMSE:   0.0138, Val R2:   0.0513
No improvement. Patience: 5/50
Epoch 17/500 completed in 50.54s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.8779, Train RMSE: 0.0160, Train R2: 0.1063


                                                                           

Val Loss:   0.6594, Val RMSE:   0.0137, Val R2:   0.0706
No improvement. Patience: 6/50
Epoch 18/500 completed in 50.47s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.8474, Train RMSE: 0.0154, Train R2: 0.1625


                                                                           

Val Loss:   0.6515, Val RMSE:   0.0136, Val R2:   0.0809
No improvement. Patience: 7/50
Epoch 19/500 completed in 50.71s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.8668, Train RMSE: 0.0156, Train R2: 0.1447


                                                                           

Val Loss:   0.6539, Val RMSE:   0.0136, Val R2:   0.0774
No improvement. Patience: 8/50
Epoch 20/500 completed in 50.76s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.8501, Train RMSE: 0.0157, Train R2: 0.1329


                                                                           

Val Loss:   0.6717, Val RMSE:   0.0138, Val R2:   0.0560
No improvement. Patience: 9/50
Epoch 21/500 completed in 50.84s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.8473, Train RMSE: 0.0157, Train R2: 0.1333


                                                                           

Val Loss:   0.6683, Val RMSE:   0.0137, Val R2:   0.0612
No improvement. Patience: 10/50
Epoch 22/500 completed in 50.81s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.8600, Train RMSE: 0.0157, Train R2: 0.1301


                                                                           

Val Loss:   0.6572, Val RMSE:   0.0136, Val R2:   0.0773
No improvement. Patience: 11/50
Epoch 23/500 completed in 50.61s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.8649, Train RMSE: 0.0156, Train R2: 0.1471


                                                                           

Val Loss:   0.6869, Val RMSE:   0.0139, Val R2:   0.0349
No improvement. Patience: 12/50
Epoch 24/500 completed in 50.64s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.8568, Train RMSE: 0.0156, Train R2: 0.1437


                                                                           

Val Loss:   0.6744, Val RMSE:   0.0138, Val R2:   0.0532
No improvement. Patience: 13/50
Epoch 25/500 completed in 50.62s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.8285, Train RMSE: 0.0155, Train R2: 0.1574


                                                                           

Val Loss:   0.6756, Val RMSE:   0.0138, Val R2:   0.0485
No improvement. Patience: 14/50
Epoch 26/500 completed in 50.57s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.7909, Train RMSE: 0.0149, Train R2: 0.2186


                                                                           

Val Loss:   0.6763, Val RMSE:   0.0138, Val R2:   0.0478
No improvement. Patience: 15/50
Epoch 27/500 completed in 50.48s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.8305, Train RMSE: 0.0154, Train R2: 0.1702


                                                                           

Val Loss:   0.6667, Val RMSE:   0.0137, Val R2:   0.0626
No improvement. Patience: 16/50
Epoch 28/500 completed in 50.66s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.7970, Train RMSE: 0.0153, Train R2: 0.1806


                                                                           

Val Loss:   0.6790, Val RMSE:   0.0139, Val R2:   0.0444
No improvement. Patience: 17/50
Epoch 29/500 completed in 50.68s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.8242, Train RMSE: 0.0154, Train R2: 0.1713


                                                                           

Val Loss:   0.6710, Val RMSE:   0.0138, Val R2:   0.0571
No improvement. Patience: 18/50
Epoch 30/500 completed in 50.54s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.8178, Train RMSE: 0.0154, Train R2: 0.1694


                                                                           

Val Loss:   0.6770, Val RMSE:   0.0139, Val R2:   0.0453
No improvement. Patience: 19/50
Epoch 31/500 completed in 50.74s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.8330, Train RMSE: 0.0154, Train R2: 0.1635


                                                                           

Val Loss:   0.6868, Val RMSE:   0.0139, Val R2:   0.0333
No improvement. Patience: 20/50
Epoch 32/500 completed in 50.68s
Current LR: 5.00e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.8304, Train RMSE: 0.0153, Train R2: 0.1786


                                                                           

Val Loss:   0.6812, Val RMSE:   0.0139, Val R2:   0.0407
No improvement. Patience: 21/50
Epoch 33/500 completed in 50.63s
Current LR: 2.50e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.7572, Train RMSE: 0.0148, Train R2: 0.2355


                                                                           

Val Loss:   0.6766, Val RMSE:   0.0138, Val R2:   0.0480
No improvement. Patience: 22/50
Epoch 34/500 completed in 50.51s
Current LR: 2.50e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.7891, Train RMSE: 0.0149, Train R2: 0.2230


                                                                           

Val Loss:   0.6842, Val RMSE:   0.0139, Val R2:   0.0398
No improvement. Patience: 23/50
Epoch 35/500 completed in 50.58s
Current LR: 2.50e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.7616, Train RMSE: 0.0148, Train R2: 0.2275


                                                                           

Val Loss:   0.6756, Val RMSE:   0.0138, Val R2:   0.0500
No improvement. Patience: 24/50
Epoch 36/500 completed in 50.68s
Current LR: 2.50e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.8078, Train RMSE: 0.0152, Train R2: 0.1886


                                                                           

Val Loss:   0.6709, Val RMSE:   0.0138, Val R2:   0.0548
No improvement. Patience: 25/50
Epoch 37/500 completed in 50.50s
Current LR: 2.50e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.8220, Train RMSE: 0.0153, Train R2: 0.1782


                                                                           

Val Loss:   0.6625, Val RMSE:   0.0137, Val R2:   0.0679
No improvement. Patience: 26/50
Epoch 38/500 completed in 50.62s
Current LR: 2.50e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.7697, Train RMSE: 0.0149, Train R2: 0.2232


                                                                           

Val Loss:   0.6711, Val RMSE:   0.0138, Val R2:   0.0547
No improvement. Patience: 27/50
Epoch 39/500 completed in 50.54s
Current LR: 2.50e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.7403, Train RMSE: 0.0147, Train R2: 0.2420


                                                                           

Val Loss:   0.6818, Val RMSE:   0.0139, Val R2:   0.0429
No improvement. Patience: 28/50
Epoch 40/500 completed in 50.71s
Current LR: 2.50e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.7708, Train RMSE: 0.0148, Train R2: 0.2335


                                                                           

Val Loss:   0.6715, Val RMSE:   0.0138, Val R2:   0.0565
No improvement. Patience: 29/50
Epoch 41/500 completed in 50.61s
Current LR: 2.50e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.7251, Train RMSE: 0.0146, Train R2: 0.2468


                                                                           

Val Loss:   0.6835, Val RMSE:   0.0139, Val R2:   0.0365
No improvement. Patience: 30/50
Epoch 42/500 completed in 50.59s
Current LR: 2.50e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.7656, Train RMSE: 0.0145, Train R2: 0.2584


                                                                           

Val Loss:   0.6670, Val RMSE:   0.0137, Val R2:   0.0613
No improvement. Patience: 31/50
Epoch 43/500 completed in 50.68s
Current LR: 2.50e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.7674, Train RMSE: 0.0146, Train R2: 0.2491


                                                                           

Val Loss:   0.6675, Val RMSE:   0.0138, Val R2:   0.0597
No improvement. Patience: 32/50
Epoch 44/500 completed in 50.62s
Current LR: 2.50e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.7313, Train RMSE: 0.0144, Train R2: 0.2726


                                                                           

Val Loss:   0.6737, Val RMSE:   0.0138, Val R2:   0.0512
No improvement. Patience: 33/50
Epoch 45/500 completed in 50.57s
Current LR: 2.50e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.7060, Train RMSE: 0.0145, Train R2: 0.2662


                                                                           

Val Loss:   0.6756, Val RMSE:   0.0139, Val R2:   0.0421
No improvement. Patience: 34/50
Epoch 46/500 completed in 50.53s
Current LR: 2.50e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.7343, Train RMSE: 0.0147, Train R2: 0.2461


                                                                           

Val Loss:   0.6609, Val RMSE:   0.0137, Val R2:   0.0671
No improvement. Patience: 35/50
Epoch 47/500 completed in 50.48s
Current LR: 2.50e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.8020, Train RMSE: 0.0149, Train R2: 0.2254


                                                                           

Val Loss:   0.6873, Val RMSE:   0.0140, Val R2:   0.0310
No improvement. Patience: 36/50
Epoch 48/500 completed in 50.51s
Current LR: 2.50e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.7290, Train RMSE: 0.0143, Train R2: 0.2814


                                                                           

Val Loss:   0.6650, Val RMSE:   0.0137, Val R2:   0.0626
No improvement. Patience: 37/50
Epoch 49/500 completed in 50.50s
Current LR: 2.50e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.7411, Train RMSE: 0.0147, Train R2: 0.2366


                                                                           

Val Loss:   0.6724, Val RMSE:   0.0138, Val R2:   0.0543
No improvement. Patience: 38/50
Epoch 50/500 completed in 50.74s
Current LR: 2.50e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.6958, Train RMSE: 0.0141, Train R2: 0.3014


                                                                           

Val Loss:   0.6729, Val RMSE:   0.0138, Val R2:   0.0511
No improvement. Patience: 39/50
Epoch 51/500 completed in 50.79s
Current LR: 2.50e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.6786, Train RMSE: 0.0140, Train R2: 0.3122


                                                                           

Val Loss:   0.6586, Val RMSE:   0.0137, Val R2:   0.0710
No improvement. Patience: 40/50
Epoch 52/500 completed in 50.54s
Current LR: 2.50e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.7029, Train RMSE: 0.0144, Train R2: 0.2675


                                                                           

Val Loss:   0.6783, Val RMSE:   0.0139, Val R2:   0.0448
No improvement. Patience: 41/50
Epoch 53/500 completed in 50.66s
Current LR: 2.50e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.7204, Train RMSE: 0.0144, Train R2: 0.2753


                                                                           

Val Loss:   0.6550, Val RMSE:   0.0136, Val R2:   0.0774
No improvement. Patience: 42/50
Epoch 54/500 completed in 50.74s
Current LR: 1.25e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.7118, Train RMSE: 0.0143, Train R2: 0.2845


                                                                           

Val Loss:   0.6713, Val RMSE:   0.0138, Val R2:   0.0553
No improvement. Patience: 43/50
Epoch 55/500 completed in 50.63s
Current LR: 1.25e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.7024, Train RMSE: 0.0142, Train R2: 0.2879


                                                                           

Val Loss:   0.6732, Val RMSE:   0.0138, Val R2:   0.0512
No improvement. Patience: 44/50
Epoch 56/500 completed in 50.48s
Current LR: 1.25e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.7145, Train RMSE: 0.0141, Train R2: 0.2971


                                                                           

Val Loss:   0.6858, Val RMSE:   0.0139, Val R2:   0.0335
No improvement. Patience: 45/50
Epoch 57/500 completed in 50.66s
Current LR: 1.25e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.7473, Train RMSE: 0.0145, Train R2: 0.2634


                                                                           

Val Loss:   0.6675, Val RMSE:   0.0137, Val R2:   0.0612
No improvement. Patience: 46/50
Epoch 58/500 completed in 50.59s
Current LR: 1.25e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.7183, Train RMSE: 0.0141, Train R2: 0.3007


                                                                           

Val Loss:   0.6574, Val RMSE:   0.0137, Val R2:   0.0709
No improvement. Patience: 47/50
Epoch 59/500 completed in 50.64s
Current LR: 1.25e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.7177, Train RMSE: 0.0142, Train R2: 0.2935


                                                                           

Val Loss:   0.6894, Val RMSE:   0.0140, Val R2:   0.0274
No improvement. Patience: 48/50
Epoch 60/500 completed in 50.76s
Current LR: 1.25e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.6520, Train RMSE: 0.0138, Train R2: 0.3348


                                                                           

Val Loss:   0.6734, Val RMSE:   0.0138, Val R2:   0.0486
No improvement. Patience: 49/50
Epoch 61/500 completed in 50.64s
Current LR: 1.25e-05
--------------------------------------------------------------------------------


                                                                     

Train Loss: 0.6721, Train RMSE: 0.0138, Train R2: 0.3313


                                                                           

Val Loss:   0.6631, Val RMSE:   0.0137, Val R2:   0.0663
No improvement. Patience: 50/50
Early stopping triggered after 62 epochs
Training complete!
Best validation loss: 0.6453, Best R2: 0.0915
Final model saved to /home/jupyter/MLC-Project/training_runs/BranchedUHIModel_NYC_unet_20250507_172717/checkpoints/final_model.pt
Training log saved to /home/jupyter/MLC-Project/training_runs/BranchedUHIModel_NYC_unet_20250507_172717/training_log.csv


0,1
epoch,▁▁▁▁▁▂▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▆▆▆▆▆▆▆▇▇▇▇█
lr,███████████████████████▃▃▃▃▃▃▃▃▃▃▃▃▁▁▁▁▁
train_loss,█▇▆▆▆▆▅▅▄▅▅▅▄▄▄▄▄▄▃▄▄▄▂▃▃▂▂▃▂▂▃▂▂▁▂▂▂▂▂▁
train_r2,▁▂▂▃▃▃▄▄▄▅▅▅▄▄▅▅▅▅▆▅▅▅▅▆▆▆▆▇▇▇▆▇▆▇▇▇▇▇▇█
train_rmse,█▇▇▇▆▆▅▅▅▄▄▅▄▅▅▄▄▄▃▄▄▄▄▄▃▃▃▂▃▂▂▃▂▁▂▂▂▂▂▁
val_loss,█▇▅▃▂▃▂▃▃▁▃▂▁▁▃▂▃▄▄▄▅▄▃▂▃▃▄▃▄▂▃▃▂▄▂▄▃▂▅▂
val_r2,▁▂▄▆▇▅▆▇▆▆▆█▇▆▆▅▆▆▆▅▅▅▆▅▆▇▆▅▆▅▆▆▇▇▆▇▅▆▅▆
val_rmse,█▇▃▂▃▃▂▃▃▁▃▂▁▂▃▂▄▃▃▄▄▄▃▄▃▄▃▄▃▃▄▄▂▃▃▃▄▃▂▂

0,1
epoch,62.0
lr,1e-05
train_loss,0.67213
train_r2,0.33128
train_rmse,0.0138
val_loss,0.66309
val_r2,0.06632
val_rmse,0.01371


## Training