In [1]:
# %% Setup & Imports
# Shared imports and environment setup for the whole notebook. Run this cell first:
# later cells rely on `project_root`, `wandb`, and the `src.` imports defined here.

import sys
import os
import pandas as pd
import numpy as np
import json
import matplotlib.pyplot as plt
from pathlib import Path
from datetime import datetime
import logging
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, Subset
from tqdm import tqdm
import math
import shutil # NOTE(review): not used in the cells shown here; presumably kept for checkpoint handling

# Add the project root to the Python path so the `src` package is importable.
# Order matters: this must run before any `from src...` import below.
project_root = Path(os.getcwd()).parent # Assumes the notebook is run from the 'notebooks' subdir
sys.path.insert(0, str(project_root))

# --- Import Model Components ---
from src.model import UHINetCNN # CNN model (Clay backbone + configurable head)
from src.ingest.dataloader_cnn import CityDataSet # Dataset that builds model inputs and UHI targets

# --- Import Training Utilities & Loss ---
from src.train.loss import masked_mse_loss, masked_mae_loss # Selected later via config["loss_type"]
import src.train.train_utils as train_utils # Train/validate/checkpoint helpers

# Root logger at INFO with timestamps, so INFO messages from the src modules show in cell output.
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

# W&B is optional: if it is not installed, `wandb` is bound to None and
# later cells check `if wandb:` before logging anything.
try:
    import wandb
except ImportError:
    print("wandb not installed, skipping W&B logging.")
    wandb = None

# NOTE(review): these metrics are not used in the cells shown here.
from sklearn.metrics import mean_squared_error, r2_score

In [2]:
### UNCOMMENT ON FIRST RUN IF USING Clay
#!wget -q https://huggingface.co/made-with-clay/Clay/resolve/main/v1.5/clay-v1.5.ckpt

In [3]:
# %% Configuration / Hyperparameters (CNN Model + Common Resampling)
# Every tunable setting for this run is defined here and collected into the
# `config` dict further down in this cell.

# --- Import utils ---
from src.train.train_utils import check_path # For path validation
from src.ingest.data_utils import calculate_actual_weather_channels # For dynamic weather channels
# -------------------

# --- Paths & Basic Info ---
# project_root is defined in the first cell
project_root_str = str(project_root)
data_dir = project_root / "data"
city_name = "NYC"
output_dir_base = project_root / "training_runs"

# --- WANDB Config ---
wandb_project_name = "MLC_UHI_Proj"
# The "wander_" spelling is kept on purpose: the training cell reads this name via config.
wander_run_name_prefix = f"{city_name}_UHINetCNN"

# --- Data Loading Config ---
feature_resolution_m = 30 # Feature grid resolution in metres; keep at 30 or coarser
uhi_grid_resolution_m = 20 # Resolution (metres) of the UHI target grid
clay_proj_channels = 32 # Number of channels to project Clay features to

# --- Input Data Paths ---
city_data_dir = data_dir / city_name
sat_files_dir = city_data_dir / "sat_files"
# Date windows embedded in the satellite filenames (must match download_data.ipynb).
sentinel_time_window_str = "20210601_to_20210901"
lst_time_window_str_for_filename = "20210601_to_20210901"

uhi_csv_path = city_data_dir / "uhi.csv"
bronx_weather_csv_path = city_data_dir / "bronx_weather.csv"
manhattan_weather_csv_path = city_data_dir / "manhattan_weather.csv" # Must match the actual filename

dem_path = sat_files_dir / f"{city_name.lower()}_dem_nasadem_native-resolution_pc.tif"
dsm_path = sat_files_dir / f"{city_name.lower()}_dsm_cop-dem-glo-30_native-resolution_pc.tif"
cloudless_mosaic_path = sat_files_dir / f"sentinel_{city_name}_{sentinel_time_window_str}_cloudless_mosaic.npy"
single_lst_median_path = sat_files_dir / f"lst_{city_name}_median_{lst_time_window_str_for_filename}.npy"

# Nodata values handed to the dataset. NaN means missing pixels are expected to
# already be NaN; replace with the files' numeric fill value if they use one.
elevation_nodata = np.nan
lst_nodata = np.nan

# --- Weather Feature Selection ---
enabled_weather_features = [
    "rel_humidity",
    "avg_windspeed",
    "wind_direction",       # This will be converted to sin/cos components by data_utils
    "solar_flux",
    "air_temp"
]

# Number of weather channels the dataloader will actually produce
# (wind_direction expands into more than one channel).
actual_dataloader_weather_channels = calculate_actual_weather_channels(enabled_weather_features)

# --- Feature Selection Flags --- #
feature_flags = {
    "use_dem": False,
    "use_dsm": False,
    "use_clay": True,
    "use_sentinel_composite": False,
    "use_lst": False,
    "use_ndvi": False,
    "use_ndbi": False,
    "use_ndwi": False,
}

# --- Bands for Sentinel Composite --- #
sentinel_bands_to_load = ["blue", "green", "red", "nir", "swir16", "swir22"] # Don't change

# --- Model Config (UHINetCNN) ---
clay_model_size = "large"
clay_bands = ["blue", "green", "red", "nir"]
clay_platform = "sentinel-2-l2a"
clay_gsd = 10
freeze_backbone = False # False allows finetuning Clay's last layer + projection
clay_checkpoint_path = project_root / "notebooks" / "clay-v1.5.ckpt"
clay_metadata_path = project_root / "src" / "Clay" / "configs" / "metadata.yaml"

# Head Configuration
head_type = "unet"  # Options: "unet" or "simple_cnn"
# U-Net specific
unet_base_channels = 64
unet_depth = 4
unet_dropout_rate = 0.1 # Dropout for U-Net blocks
# SimpleCNN specific
simple_cnn_hidden_dims = [64, 32]
simple_cnn_kernel_size = 3
simple_cnn_dropout_rate = 0.1 # Dropout for SimpleCNN Head


# --- Training Hyperparameters ---
num_workers = 2
epochs = 500
lr = 5e-5 # Initial: 5e-5. Consider 1e-3 or 5e-4 if using mean loss
weight_decay = 0.01
loss_type = 'mse' # Options: 'mse' or 'mae'. Ensure this matches your loss function implementation (sum vs mean)
patience = 50
cpu = False # Set True to force CPU even when CUDA is available
n_train_batches = 10 # Passed to create_dataloaders (presumably batches per epoch — confirm)
max_grad_norm = 1.0
warmup_epochs = 0 # NOTE(review): read by the training cell but no LR warm-up is implemented there

# --- Device Setup ---
device = torch.device("cuda" if torch.cuda.is_available() and not cpu else "cpu")
print(f"Using device: {device}")

# --- Validate Paths ---
# Required inputs are always checked; each path variable is rebound to the value
# check_path returns.
uhi_csv_path = check_path(uhi_csv_path, project_root, "UHI CSV")
bronx_weather_csv_path = check_path(bronx_weather_csv_path, project_root, "Bronx Weather CSV")
manhattan_weather_csv_path = check_path(manhattan_weather_csv_path, project_root, "Manhattan Weather CSV")

# Optional inputs are checked only when a feature that consumes them is enabled.
if feature_flags["use_dem"]:
    dem_path = check_path(dem_path, project_root, "DEM TIF")
if feature_flags["use_dsm"]:
    dsm_path = check_path(dsm_path, project_root, "DSM TIF")
if feature_flags["use_clay"]:
    clay_checkpoint_path = check_path(clay_checkpoint_path, project_root, "Clay Checkpoint")
    clay_metadata_path = check_path(clay_metadata_path, project_root, "Clay Metadata")

# Clay, the raw Sentinel composite and the spectral indices all need the cloudless mosaic.
if any(feature_flags[flag] for flag in (
        "use_clay", "use_sentinel_composite", "use_ndvi", "use_ndbi", "use_ndwi")):
    cloudless_mosaic_path = check_path(cloudless_mosaic_path, project_root, "Cloudless Mosaic")

if feature_flags["use_lst"]:
    single_lst_median_path = check_path(
        single_lst_median_path, project_root, "Single LST Median", should_exist=True
    )

# --- Calculate Bounds --- #
# Bounding box [min_lon, min_lat, max_lon, max_lat] of the UHI measurement points
# (presumably WGS84 degrees, matching the dataset's default EPSG:4326 — confirm).
uhi_df = pd.read_csv(uhi_csv_path)
required_cols = ['Longitude', 'Latitude']
if not all(col in uhi_df.columns for col in required_cols):
    raise ValueError(f"UHI CSV must contain columns: {required_cols}")
uhi_lon = uhi_df['Longitude']
uhi_lat = uhi_df['Latitude']
# min()/max() skip NaN, so an empty or all-NaN column would silently yield NaN bounds.
if uhi_lon.isna().all() or uhi_lat.isna().all():
    raise ValueError(f"UHI CSV {uhi_csv_path} has no valid Longitude/Latitude values to compute bounds from.")
# Built-in floats print and serialise cleanly (no np.float64(...) reprs).
bounds_list = [
    float(uhi_lon.min()),
    float(uhi_lat.min()),
    float(uhi_lon.max()),
    float(uhi_lat.max()),
]
print(f"Loaded bounds from {uhi_csv_path.name}: {bounds_list}")

# --- Central Config Dictionary --- #
# Single source of truth for the run: read by the dataset/model cells and saved to
# <run_dir>/config.json by the training cell. Optional input paths are stored as
# strings only when a feature that uses them is enabled, otherwise None.
config = dict(
    # Paths & Info
    model_type="UHINetCNN",
    project_root=project_root_str,
    city_name=city_name,
    wandb_project_name=wandb_project_name,
    wander_run_name_prefix=wander_run_name_prefix,
    # Data Loading
    feature_resolution_m=feature_resolution_m,
    uhi_grid_resolution_m=uhi_grid_resolution_m,
    clay_proj_channels=clay_proj_channels,
    enabled_weather_features=enabled_weather_features,
    uhi_csv=str(uhi_csv_path),
    bronx_weather_csv=str(bronx_weather_csv_path),
    manhattan_weather_csv=str(manhattan_weather_csv_path),
    bounds=bounds_list,
    feature_flags=feature_flags,
    sentinel_bands_to_load=sentinel_bands_to_load,
    dem_path=(str(dem_path) if feature_flags["use_dem"] else None),
    dsm_path=(str(dsm_path) if feature_flags["use_dsm"] else None),
    elevation_nodata=elevation_nodata,
    cloudless_mosaic_path=(
        str(cloudless_mosaic_path)
        if any(feature_flags.get(flag) for flag in (
            "use_clay", "use_sentinel_composite", "use_ndvi", "use_ndbi", "use_ndwi"))
        else None
    ),
    single_lst_median_path=(str(single_lst_median_path) if feature_flags["use_lst"] else None),
    lst_nodata=lst_nodata,
    # Model Config
    head_type=head_type,
    unet_base_channels=unet_base_channels,
    unet_depth=unet_depth,
    unet_dropout_rate=unet_dropout_rate,
    simple_cnn_hidden_dims=simple_cnn_hidden_dims,
    simple_cnn_kernel_size=simple_cnn_kernel_size,
    simple_cnn_dropout_rate=simple_cnn_dropout_rate,
    # Clay specific
    clay_model_size=clay_model_size,
    clay_bands=clay_bands,
    clay_platform=clay_platform,
    clay_gsd=clay_gsd,
    freeze_backbone=freeze_backbone,
    clay_checkpoint_path=(str(clay_checkpoint_path) if feature_flags["use_clay"] else None),
    clay_metadata_path=(str(clay_metadata_path) if feature_flags["use_clay"] else None),
    # Training Hyperparameters
    n_train_batches=n_train_batches,
    num_workers=num_workers,
    epochs=epochs,
    lr=lr,
    weight_decay=weight_decay,
    loss_type=loss_type,
    patience=patience,
    max_grad_norm=max_grad_norm,
    warmup_epochs=warmup_epochs,
    device=str(device),
)

# --- Create Run Directory & Update Config ---
# Each run gets its own timestamped directory under training_runs/.
run_timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
run_name_suffix = f"{config['model_type']}_{city_name}_{config['head_type']}_{run_timestamp}"
run_dir = output_dir_base / run_name_suffix
run_dir.mkdir(parents=True, exist_ok=True)
config["run_dir"] = str(run_dir)

print(f"\nRun directory: {run_dir}")
print("UHINetCNN Configuration dictionary created:")
# default=str stringifies anything json cannot encode natively (Path, torch.device,
# numpy scalars, ...). Returning such an object unchanged from `default` would make
# json raise "Circular reference detected" instead.
print(json.dumps(config, indent=2, default=str))


Using device: cuda
Loaded bounds from uhi.csv: [np.float64(-73.99445667), np.float64(40.75879167), np.float64(-73.87945833), np.float64(40.85949667)]

Run directory: /home/jupyter/MLC-Project/training_runs/UHINetCNN_NYC_20250507_061047
UHINetCNN Configuration dictionary created:
{
  "model_type": "UHINetCNN",
  "project_root": "/home/jupyter/MLC-Project",
  "city_name": "NYC",
  "wandb_project_name": "MLC_UHI_Proj",
  "wander_run_name_prefix": "NYC_UHINetCNN",
  "feature_resolution_m": 30,
  "uhi_grid_resolution_m": 20,
  "enabled_weather_features": [
    "rel_humidity",
    "avg_windspeed",
    "wind_direction",
    "solar_flux",
    "air_temp"
  ],
  "uhi_csv": "/home/jupyter/MLC-Project/data/NYC/uhi.csv",
  "bronx_weather_csv": "/home/jupyter/MLC-Project/data/NYC/bronx_weather.csv",
  "manhattan_weather_csv": "/home/jupyter/MLC-Project/data/NYC/manhattan_weather.csv",
  "bounds": [
    -73.99445667,
    40.75879167,
    -73.87945833,
    40.85949667
  ],
  "feature_flags": {
    "us

In [4]:
# %% Data Loading and Preprocessing (CNN Model + Common Resampling)

# --- Import utils ---
from src.train.train_utils import (
    calculate_uhi_stats,
    create_dataloaders
)
from torch.utils.data import random_split
# -------------------

print("Initializing CityDataSet (for CNN model)...")
try:
    # Every dataset parameter comes from the config dict built in the configuration cell.
    dataset = CityDataSet(
        bounds=config["bounds"],
        feature_resolution_m=config["feature_resolution_m"],
        uhi_grid_resolution_m=config["uhi_grid_resolution_m"],
        uhi_csv=config["uhi_csv"],
        bronx_weather_csv=config["bronx_weather_csv"],
        manhattan_weather_csv=config["manhattan_weather_csv"],
        data_dir=project_root_str, # data_dir should be project_root for this loader
        city_name=config["city_name"],
        feature_flags=config["feature_flags"],
        enabled_weather_features=config["enabled_weather_features"],
        sentinel_bands_to_load=config["sentinel_bands_to_load"],
        dem_path=config["dem_path"],
        dsm_path=config["dsm_path"],
        elevation_nodata=config["elevation_nodata"],
        cloudless_mosaic_path=config["cloudless_mosaic_path"],
        single_lst_median_path=config["single_lst_median_path"],
        lst_nodata=config["lst_nodata"],
        target_crs_str=config.get("target_crs_str", "EPSG:4326") # Optional; defaults to EPSG:4326
    )
except FileNotFoundError as e:
    print(f"Dataset initialization failed: {e}")
    print("Ensure required data files (DEM, DSM, weather, UHI, potentially mosaic/LST) exist.")
    print("Run `notebooks/download_data.ipynb` first.")
    raise
except Exception as e:
    print(f"Unexpected error during dataset initialization: {e}")
    raise

# --- Train/Val Split --- #
# Split settings come from config when present (defaults = the previous hard-coded
# values) and are written back so they are saved with the run's config.json.
val_percent = config.setdefault("val_percent", 0.40)
split_seed = config.setdefault("split_seed", 42)
num_samples = len(dataset)

if num_samples < 2:
    raise ValueError(f"Dataset has only {num_samples} samples, cannot perform train/val split.")

n_val = int(num_samples * val_percent)
n_train = num_samples - n_val

if n_train == 0 or n_val == 0:
    raise ValueError(f"Split resulted in zero samples for train ({n_train}) or validation ({n_val}). Adjust val_percent or check dataset size.")

# NOTE(review): samples appear to be individual timestamps; a random split interleaves
# neighbouring timestamps between train and val, which may make validation scores
# optimistic. Consider a temporal split.
train_ds, val_ds = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(split_seed))

print(f"Random dataset split: {len(train_ds)} training, {len(val_ds)} validation samples (using seed {split_seed}).")

# --- Calculate UHI Mean and Std from Training Data ONLY --- #
# Statistics come from the training subset only, so validation data cannot leak into them.
uhi_mean, uhi_std = calculate_uhi_stats(train_ds)
if uhi_mean is None or uhi_std is None:
    raise ValueError("calculate_uhi_stats returned no statistics for the training split.")
# Built-in floats keep config.json serialisation independent of numpy/torch scalar types.
uhi_mean, uhi_std = float(uhi_mean), float(uhi_std)
# The training utilities receive uhi_std (presumably to standardise targets), so it must be usable.
if not (np.isfinite(uhi_std) and uhi_std > 0):
    raise ValueError(f"Training UHI std must be a positive finite number, got {uhi_std}.")
config['uhi_mean'] = uhi_mean
config['uhi_std'] = uhi_std

# --- Create DataLoaders --- #
train_loader, val_loader = create_dataloaders(
    train_ds,
    val_ds,
    n_train_batches=config['n_train_batches'],
    num_workers=config['num_workers'],
    device=device # Pass device from config cell
)
print("Data loading and preprocessing for CNN model complete.")


2025-05-07 06:10:47,593 - INFO - Dataloader will produce 6 weather channels based on enabled features: ['rel_humidity', 'avg_windspeed', 'wind_direction', 'solar_flux', 'air_temp']
2025-05-07 06:10:47,594 - INFO - Target FEATURE grid size (H, W): (373, 323) @ 30m, CRS: EPSG:4326
2025-05-07 06:10:47,595 - INFO - Target UHI grid size (H, W): (559, 485) @ 20m


Initializing CityDataSet (for CNN model)...


Precomputing UHI grids: 100%|██████████| 59/59 [00:00<00:00, 1175.56it/s]
2025-05-07 06:10:47,682 - INFO - Loading cloudless mosaic from /home/jupyter/MLC-Project/data/NYC/sat_files/sentinel_NYC_20210601_to_20210901_cloudless_mosaic.npy with memory mapping
2025-05-07 06:10:47,683 - INFO - Loaded mosaic shape (native res): (5, 1119, 1278)
2025-05-07 06:10:47,692 - INFO - Loaded Bronx weather data: 169 records
2025-05-07 06:10:47,692 - INFO - Loaded Manhattan weather data: 169 records
2025-05-07 06:10:47,694 - INFO - Computed grid cell center coordinates for CRS: EPSG:4326.
2025-05-07 06:10:47,696 - INFO - Computed grid cell center coordinates for weather grid at feature resolution.
2025-05-07 06:10:47,697 - INFO - Dataset initialized for NYC with 59 unique timestamps.
2025-05-07 06:10:47,697 - INFO - Enabled features (flags): {"use_dem": false, "use_dsm": false, "use_clay": true, "use_sentinel_composite": false, "use_lst": false, "use_ndvi": false, "use_ndbi": false, "use_ndwi": false}


Sequential dataset split: 47 training (indices 0-46), 12 validation (indices 47-58) samples.


Calculating stats: 100%|██████████| 47/47 [00:00<00:00, 4765.56it/s]
2025-05-07 06:10:47,714 - INFO - Training UHI Mean: 1.0002, Std Dev: 0.0168
2025-05-07 06:10:47,715 - INFO - Creating dataloaders...
2025-05-07 06:10:47,716 - INFO - Using Train Batch Size: 4
2025-05-07 06:10:47,717 - INFO - Using Validation Batch Size: 1
2025-05-07 06:10:47,718 - INFO - Data loading setup complete.


Data loading and preprocessing for CNN model complete.


In [5]:
# %% Model Initialization (CNN Model + Common Resampling)

# --- Import necessary components ---
from src.model import UHINetCNN # Ensure this is the correct model for the CNN pipeline
from src.train.loss import masked_mse_loss, masked_mae_loss
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import logging

# Build the model; every constructor argument is taken from `config`.
print(f"Initializing {config['model_type']}...")
try:
    model = UHINetCNN(
        feature_flags=config["feature_flags"],
        enabled_weather_features=config["enabled_weather_features"],
        bounds=config["bounds"],
        uhi_grid_resolution_m=config["uhi_grid_resolution_m"],
        clay_proj_channels=config["clay_proj_channels"],
        sentinel_bands_to_load=config.get("sentinel_bands_to_load"),
        clay_model_size=config.get("clay_model_size"),
        clay_bands=config.get("clay_bands"),
        clay_platform=config.get("clay_platform"),
        clay_gsd=config.get("clay_gsd"),
        freeze_backbone=config.get("freeze_backbone", True),
        clay_checkpoint_path=config.get("clay_checkpoint_path"),
        clay_metadata_path=config.get("clay_metadata_path"),
        # Head specific parameters
        head_type=config["head_type"],
        unet_base_channels=config["unet_base_channels"],
        unet_depth=config["unet_depth"],
        unet_dropout_rate=config["unet_dropout_rate"],
        simple_cnn_hidden_dims=config["simple_cnn_hidden_dims"],
        simple_cnn_kernel_size=config["simple_cnn_kernel_size"],
        simple_cnn_dropout_rate=config["simple_cnn_dropout_rate"]
    )
    model.to(config["device"])
    print(f"{config['model_type']} initialized successfully.")

except Exception as e:
    logging.error(f"Error initializing UHINetCNN: {e}", exc_info=True)
    raise # Re-raise the exception after logging

# --- Optimizer --- #
optimizer = optim.AdamW(model.parameters(), lr=config["lr"], weight_decay=config["weight_decay"])
print("Optimizer (AdamW) initialized.")

# --- Loss Function --- #
if config["loss_type"] == 'mse':
    loss_fn = masked_mse_loss
    print("Loss function set to masked_mse_loss.")
elif config["loss_type"] == 'mae':
    loss_fn = masked_mae_loss
    print("Loss function set to masked_mae_loss.")
else:
    raise ValueError(f"Unsupported loss type: {config['loss_type']}")

# --- LR Scheduler --- #
# Off by default (scheduler = None, as before). Set config["use_scheduler"] = True to
# enable ReduceLROnPlateau on validation loss; it is tuned by config["scheduler_patience"]
# (default 10) and config["scheduler_factor"] (default 0.5).
scheduler = None
if config.get("use_scheduler", False):
    scheduler = ReduceLROnPlateau(
        optimizer,
        mode='min',
        patience=config.get("scheduler_patience", 10),
        factor=config.get("scheduler_factor", 0.5),
    )
    print("Initialized ReduceLROnPlateau scheduler.")
else:
    print("LR scheduler disabled (set config['use_scheduler'] = True to enable).")
print("\nModel, optimizer, loss function, and scheduler setup complete.")


Initializing UHINetCNN...
Manually loading checkpoint: /home/jupyter/MLC-Project/notebooks/clay-v1.5.ckpt
Instantiating ClayMAEModule manually...


2025-05-07 06:10:54,369 - INFO - Loading pretrained weights from Hugging Face hub (timm/vit_large_patch14_reg4_dinov2.lvd142m)
2025-05-07 06:10:54,497 - INFO - [timm/vit_large_patch14_reg4_dinov2.lvd142m] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.


Loading state_dict manually into self.model.model...


2025-05-07 06:10:58,065 - INFO - Identified final encoder layer as self.model.model.proj
2025-05-07 06:10:58,066 - INFO - Unfreezing the final encoder layer (self.model.model.proj) of the Clay backbone.
2025-05-07 06:10:58,071 - INFO - ClayFeatureExtractor output channels set to: 1024
2025-05-07 06:10:58,113 - INFO - Initialized Clay model (large), output channels: 1024
2025-05-07 06:10:58,114 - INFO - Total input channels for feature head: 1030
2025-05-07 06:10:58,243 - INFO - Initialized UNetDecoder. In channels: 1030, Base channels: 64, Depth: 4
2025-05-07 06:10:58,244 - INFO - UHINetCNN using UNetDecoder head. Output channels: 64
2025-05-07 06:10:58,245 - INFO - UHINetCNN final processor target UHI grid: (559, 485)
2025-05-07 06:10:58,246 - INFO - Initialized FinalUpsamplerAndProjection: Bicubic upsampling. InCh=64, Target=(559,485).
2025-05-07 06:10:58,247 - INFO - UHINetCNN initialized completely with unet head.


Clay model properties: model_size=large, embed_dim=1024, patch_size=16 (patch_size OVERRIDDEN)
Normalization prepared for bands: ['blue', 'green', 'red', 'nir']
UHINetCNN initialized successfully.
Optimizer (AdamW) initialized.
Loss function set to masked_mse_loss.

Model, optimizer, loss function, and scheduler setup complete.


In [None]:
# %% Training Loop (CNN Model)

# --- Imports ---
import time
from datetime import datetime
import json
from pathlib import Path
import src.train.train_utils as train_utils # Import the full module
import numpy as np # For the NaN checks
import pandas as pd # For saving the training log

# --- Setup ---
print(f"Model {config['model_type']} training starting on {device}")

# --- Tracking Variables --- #
best_val_r2 = -float('inf')
patience_counter = 0
training_log_list = [] # One metrics dict per epoch; written to training_log.csv
epoch = -1 # Pre-set so the final-model save works even if the loop body never runs

# Checkpointing and early stopping are only active from this (0-based) epoch on.
# Previously hard-coded to 100; now overridable and recorded in the saved config.
checkpoint_start_epoch = config.setdefault("checkpoint_start_epoch", 100)

# --- Output Directory & Run Name --- #
run_dir = Path(config["run_dir"]) # Get from config
model_save_dir = run_dir / "checkpoints"
model_save_dir.mkdir(parents=True, exist_ok=True)
print(f"Checkpoints and logs will be saved to: {run_dir}")

# --- Save Config --- #
config_path = run_dir / "config.json"
with open(config_path, 'w') as f:
    # default=str writes Path, torch.device, numpy scalars, etc. as strings. Returning the
    # object unchanged from `default` (the previous lambda) makes json raise instead.
    json.dump(config, f, indent=2, default=str)
print(f"Saved configuration to {config_path}")

# --- Retrieve UHI Stats from Config --- #
uhi_mean = config.get('uhi_mean')
uhi_std = config.get('uhi_std')
if uhi_mean is None or uhi_std is None:
    raise ValueError("uhi_mean/uhi_std not in config. Run data loading cell.")
print(f"Using Training UHI Mean: {uhi_mean:.4f}, Std Dev: {uhi_std:.4f}")

# --- WANDB Init --- #
if wandb:
    try:
        if wandb.run is not None: wandb.finish() # Finish previous run if any
        wandb.init(
            project=config["wandb_project_name"],
            name=f"{config['wander_run_name_prefix']}_{datetime.now().strftime('%Y%m%d_%H%M')}",
            config=config
        )

        # Optional: watch model parameters
        wandb.watch(model)
        print(f"Wandb initialized for run: {run_dir.name}")
    except Exception as e:
        # Best effort: training continues without W&B if initialisation fails.
        print(f"Wandb initialization failed: {e}")
        wandb = None
else:
    print("Wandb not available, skipping logging.")

print(f"Starting training for {config['epochs']} epochs with patience {config['patience']}")

# NOTE(review): warmup_epochs is read but no LR warm-up is implemented in this loop.
warmup_epochs = config.get("warmup_epochs", 0)

try:
    for epoch in range(config['epochs']):
        epoch_start = time.time()
        print(f"--- Epoch {epoch+1}/{config['epochs']} ---")
        # All W&B calls for this epoch share one explicit step, so train, val and lr
        # values line up (without `step`, every wandb.log call advances the step).
        wandb_step = epoch + 1

        # --- Train --- #
        if train_loader:
            train_loss, train_rmse, train_r2 = train_utils.train_epoch_generic(
                model, train_loader, optimizer, loss_fn, device, uhi_mean, uhi_std,
                feature_flags=config["feature_flags"],
                max_grad_norm=config.get("max_grad_norm", 1.0)
            )
            print(f"Train Loss: {train_loss:.4f}, Train RMSE: {train_rmse:.4f}, Train R2: {train_r2:.4f}")
            if np.isnan(train_loss):
                print("Warning: Training loss is NaN. Stopping training.")
                break
        else:
            print("Skipping training: train_loader is None.")
            train_loss, train_rmse, train_r2 = float('nan'), float('nan'), float('nan')
        current_metrics = {"epoch": epoch + 1, "train_loss": train_loss, "train_rmse": train_rmse, "train_r2": train_r2}

        # Log train metrics only after the NaN check above
        if wandb:
            wandb.log(current_metrics, step=wandb_step)
        # The same dict is updated in place with val metrics below, so the local log gets them too.
        training_log_list.append(current_metrics)

        # --- Validate --- #
        val_loss = None # Stays None when there is no val_loader (checked by the scheduler step)
        if val_loader:
            val_loss, val_rmse, val_r2 = train_utils.validate_epoch_generic(
                model, val_loader, loss_fn, device, uhi_mean, uhi_std,
                feature_flags=config["feature_flags"]
            )
            print(f"Val Loss:   {val_loss:.4f}, Val RMSE:   {val_rmse:.4f}, Val R2:   {val_r2:.4f}")
            if np.isnan(val_loss):
                print("Warning: Validation Loss is NaN. Stopping training.")
                break
            val_metrics = {"val_loss": val_loss, "val_rmse": val_rmse, "val_r2": val_r2}
            current_metrics.update(val_metrics)

            if wandb:
                wandb.log(val_metrics, step=wandb_step)

            # Before checkpoint_start_epoch, epochs neither produce checkpoints nor count toward patience.
            if epoch >= checkpoint_start_epoch:
                # Model selection uses validation R2 (higher is better).
                if val_r2 > best_val_r2:
                    best_val_r2 = val_r2
                    patience_counter = 0

                    # Save best model
                    train_utils.save_checkpoint({
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict() if scheduler else None,
                        'best_val_r2': best_val_r2,
                        'config': config
                    }, is_best=True, output_dir=model_save_dir)
                    print(f"New best model saved at epoch {epoch+1} with val_r2 {val_r2:.4f}")
                else:
                    patience_counter += 1
                    print(f"No improvement. Patience: {patience_counter}/{config['patience']}")

                    # Early stopping check
                    if patience_counter >= config['patience']:
                        print(f"Early stopping triggered after {epoch+1} epochs")
                        break
            else:
                print(f"Early epoch {epoch+1}/{checkpoint_start_epoch}. Skipping checkpointing and early stopping.")
        else:
            print("Skipping validation: val_loader is None.")

        # Step the scheduler after validation (if it exists)
        if scheduler:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                # ReduceLROnPlateau needs a validation loss; skip when validation did not run.
                if val_loss is not None:
                    scheduler.step(val_loss)
            else:
                scheduler.step()
            if wandb:
                wandb.log({"lr": optimizer.param_groups[0]['lr']}, step=wandb_step)

        # Print epoch summary
        epoch_time = time.time() - epoch_start
        print(f"Epoch {epoch+1}/{config['epochs']} completed in {epoch_time:.2f}s")
        print(f"Current LR: {optimizer.param_groups[0]['lr']:.2e}")
        print("-" * 80)

    print("Training complete!")
    print(f"Best validation R2: {best_val_r2:.4f}")

    # Save the final model (last epoch's weights; the best-R2 weights are saved separately above)
    final_model_path = model_save_dir / "final_model.pt"
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'config': config
    }, final_model_path)
    print(f"Final model saved to {final_model_path}")

except Exception as e:
    print(f"Error during training: {str(e)}")
    raise
finally:
    try:
        # Persist the per-epoch log even if training raised or was interrupted,
        # so partial runs still leave a training_log.csv behind.
        log_df = pd.DataFrame(training_log_list)
        log_path = Path(config['run_dir']) / "training_log.csv"
        log_df.to_csv(log_path, index=False)
        print(f"Training log saved to {log_path}")
    finally:
        # Finish wandb run even if saving the log failed
        if wandb:
            wandb.finish()

Model UHINetCNN training starting on cuda
Checkpoints and logs will be saved to: /home/jupyter/MLC-Project/training_runs/UHINetCNN_NYC_20250507_061047
Saved configuration to /home/jupyter/MLC-Project/training_runs/UHINetCNN_NYC_20250507_061047/config.json
Using Training UHI Mean: 1.0002, Std Dev: 0.0168


[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33marnava1304[0m ([33marnava1304-columbia-university[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Wandb initialized for run: UHINetCNN_NYC_20250507_061047
Starting training for 500 epochs with patience 50
--- Epoch 1/500 ---


                                                                    

Train Loss: 1.1681, Train RMSE: 0.0181, Train R2: -0.1678


                                                                           

Val Loss:   0.7209, Val RMSE:   0.0142, Val R2:   0.0045
Early epoch 1/100. Skipping checkpointing and early stopping.
Epoch 1/500 completed in 15.57s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 2/500 ---


                                                                    

Train Loss: 1.1438, Train RMSE: 0.0180, Train R2: -0.1544


                                                                           

Val Loss:   0.7358, Val RMSE:   0.0144, Val R2:   -0.0152
Early epoch 2/100. Skipping checkpointing and early stopping.
Epoch 2/500 completed in 14.56s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 3/500 ---


                                                                    

Train Loss: 0.9999, Train RMSE: 0.0168, Train R2: -0.0002


                                                                           

Val Loss:   0.7129, Val RMSE:   0.0141, Val R2:   0.0196
Early epoch 3/100. Skipping checkpointing and early stopping.
Epoch 3/500 completed in 14.60s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 4/500 ---


                                                                     

Train Loss: 0.9272, Train RMSE: 0.0163, Train R2: 0.0573


                                                                           

Val Loss:   0.6983, Val RMSE:   0.0140, Val R2:   0.0396
Early epoch 4/100. Skipping checkpointing and early stopping.
Epoch 4/500 completed in 14.70s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 5/500 ---


                                                                    

Train Loss: 1.0197, Train RMSE: 0.0170, Train R2: -0.0249


                                                                           

Val Loss:   0.6947, Val RMSE:   0.0139, Val R2:   0.0487
Early epoch 5/100. Skipping checkpointing and early stopping.
Epoch 5/500 completed in 14.78s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 6/500 ---


                                                                     

Train Loss: 1.0006, Train RMSE: 0.0167, Train R2: 0.0030


                                                                           

Val Loss:   0.7794, Val RMSE:   0.0148, Val R2:   -0.0677
Early epoch 6/100. Skipping checkpointing and early stopping.
Epoch 6/500 completed in 14.67s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 7/500 ---


                                                                     

Train Loss: 0.9700, Train RMSE: 0.0166, Train R2: 0.0181


                                                                           

Val Loss:   0.7284, Val RMSE:   0.0142, Val R2:   0.0079
Early epoch 7/100. Skipping checkpointing and early stopping.
Epoch 7/500 completed in 14.66s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 8/500 ---


                                                                     

Train Loss: 0.9721, Train RMSE: 0.0167, Train R2: 0.0136


                                                                           

Val Loss:   0.8013, Val RMSE:   0.0150, Val R2:   -0.0962
Early epoch 8/100. Skipping checkpointing and early stopping.
Epoch 8/500 completed in 14.55s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 9/500 ---


                                                                     

Train Loss: 0.9816, Train RMSE: 0.0166, Train R2: 0.0245


                                                                           

Val Loss:   0.6828, Val RMSE:   0.0138, Val R2:   0.0633
Early epoch 9/100. Skipping checkpointing and early stopping.
Epoch 9/500 completed in 14.60s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 10/500 ---


                                                                     

Train Loss: 0.9326, Train RMSE: 0.0162, Train R2: 0.0610


                                                                           

Val Loss:   0.6723, Val RMSE:   0.0137, Val R2:   0.0810
Early epoch 10/100. Skipping checkpointing and early stopping.
Epoch 10/500 completed in 14.59s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 11/500 ---


                                                                     

Train Loss: 0.9540, Train RMSE: 0.0164, Train R2: 0.0423


                                                                           

Val Loss:   0.7185, Val RMSE:   0.0142, Val R2:   0.0173
Early epoch 11/100. Skipping checkpointing and early stopping.
Epoch 11/500 completed in 14.57s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 12/500 ---


                                                                     

Train Loss: 0.9455, Train RMSE: 0.0162, Train R2: 0.0644


                                                                           

Val Loss:   0.6983, Val RMSE:   0.0140, Val R2:   0.0406
Early epoch 12/100. Skipping checkpointing and early stopping.
Epoch 12/500 completed in 14.63s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 13/500 ---


                                                                     

Train Loss: 0.8979, Train RMSE: 0.0161, Train R2: 0.0789


                                                                           

Val Loss:   0.7284, Val RMSE:   0.0143, Val R2:   -0.0021
Early epoch 13/100. Skipping checkpointing and early stopping.
Epoch 13/500 completed in 14.66s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 14/500 ---


                                                                     

Train Loss: 0.8682, Train RMSE: 0.0157, Train R2: 0.1216


                                                                           

Val Loss:   0.6973, Val RMSE:   0.0139, Val R2:   0.0489
Early epoch 14/100. Skipping checkpointing and early stopping.
Epoch 14/500 completed in 14.62s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 15/500 ---


                                                                     

Train Loss: 0.9056, Train RMSE: 0.0157, Train R2: 0.1218


                                                                           

Val Loss:   0.6729, Val RMSE:   0.0137, Val R2:   0.0769
Early epoch 15/100. Skipping checkpointing and early stopping.
Epoch 15/500 completed in 14.62s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 16/500 ---


                                                                     

Train Loss: 0.8276, Train RMSE: 0.0152, Train R2: 0.1740


                                                                           

Val Loss:   0.6930, Val RMSE:   0.0139, Val R2:   0.0568
Early epoch 16/100. Skipping checkpointing and early stopping.
Epoch 16/500 completed in 14.67s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 17/500 ---


                                                                     

Train Loss: 0.8463, Train RMSE: 0.0154, Train R2: 0.1575


                                                                           

Val Loss:   0.7577, Val RMSE:   0.0145, Val R2:   -0.0358
Early epoch 17/100. Skipping checkpointing and early stopping.
Epoch 17/500 completed in 14.63s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 18/500 ---


                                                                     

Train Loss: 0.8258, Train RMSE: 0.0153, Train R2: 0.1681


                                                                           

Val Loss:   0.6837, Val RMSE:   0.0138, Val R2:   0.0705
Early epoch 18/100. Skipping checkpointing and early stopping.
Epoch 18/500 completed in 14.61s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 19/500 ---


                                                                     

Train Loss: 0.7871, Train RMSE: 0.0149, Train R2: 0.2138


                                                                           

Val Loss:   0.7003, Val RMSE:   0.0140, Val R2:   0.0412
Early epoch 19/100. Skipping checkpointing and early stopping.
Epoch 19/500 completed in 14.61s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 20/500 ---


                                                                     

Train Loss: 0.8284, Train RMSE: 0.0152, Train R2: 0.1793


                                                                           

Val Loss:   0.7451, Val RMSE:   0.0144, Val R2:   -0.0113
Early epoch 20/100. Skipping checkpointing and early stopping.
Epoch 20/500 completed in 14.62s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 21/500 ---


                                                                     

Train Loss: 0.7494, Train RMSE: 0.0146, Train R2: 0.2377


                                                                           

Val Loss:   0.7214, Val RMSE:   0.0142, Val R2:   0.0165
Early epoch 21/100. Skipping checkpointing and early stopping.
Epoch 21/500 completed in 14.70s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 22/500 ---


                                                                     

Train Loss: 0.7717, Train RMSE: 0.0148, Train R2: 0.2246


                                                                           

Val Loss:   0.7876, Val RMSE:   0.0147, Val R2:   -0.0648
Early epoch 22/100. Skipping checkpointing and early stopping.
Epoch 22/500 completed in 14.65s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 23/500 ---


                                                                     

Train Loss: 0.7962, Train RMSE: 0.0149, Train R2: 0.2147


                                                                           

Val Loss:   0.7539, Val RMSE:   0.0145, Val R2:   -0.0285
Early epoch 23/100. Skipping checkpointing and early stopping.
Epoch 23/500 completed in 14.69s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 24/500 ---


                                                                     

Train Loss: 0.7197, Train RMSE: 0.0142, Train R2: 0.2832


                                                                           

Val Loss:   0.7422, Val RMSE:   0.0144, Val R2:   -0.0157
Early epoch 24/100. Skipping checkpointing and early stopping.
Epoch 24/500 completed in 14.66s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 25/500 ---


                                                                     

Train Loss: 0.6941, Train RMSE: 0.0141, Train R2: 0.2972


                                                                           

Val Loss:   0.6937, Val RMSE:   0.0139, Val R2:   0.0573
Early epoch 25/100. Skipping checkpointing and early stopping.
Epoch 25/500 completed in 14.62s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 26/500 ---


                                                                     

Train Loss: 0.7325, Train RMSE: 0.0142, Train R2: 0.2779


                                                                           

Val Loss:   0.7320, Val RMSE:   0.0143, Val R2:   0.0037
Early epoch 26/100. Skipping checkpointing and early stopping.
Epoch 26/500 completed in 14.63s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 27/500 ---


                                                                     

Train Loss: 0.6819, Train RMSE: 0.0138, Train R2: 0.3192


                                                                           

Val Loss:   0.8284, Val RMSE:   0.0151, Val R2:   -0.1220
Early epoch 27/100. Skipping checkpointing and early stopping.
Epoch 27/500 completed in 14.66s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 28/500 ---


                                                                     

Train Loss: 0.6475, Train RMSE: 0.0135, Train R2: 0.3491


                                                                           

Val Loss:   0.8891, Val RMSE:   0.0157, Val R2:   -0.2151
Early epoch 28/100. Skipping checkpointing and early stopping.
Epoch 28/500 completed in 14.70s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 29/500 ---


                                                                     

Train Loss: 0.6540, Train RMSE: 0.0135, Train R2: 0.3497


                                                                           

Val Loss:   0.8553, Val RMSE:   0.0155, Val R2:   -0.1739
Early epoch 29/100. Skipping checkpointing and early stopping.
Epoch 29/500 completed in 14.69s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 30/500 ---


                                                                     

Train Loss: 0.6446, Train RMSE: 0.0135, Train R2: 0.3548


                                                                           

Val Loss:   0.8066, Val RMSE:   0.0150, Val R2:   -0.0993
Early epoch 30/100. Skipping checkpointing and early stopping.
Epoch 30/500 completed in 14.62s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 31/500 ---


                                                                     

Train Loss: 0.6484, Train RMSE: 0.0134, Train R2: 0.3588


                                                                           

Val Loss:   0.7604, Val RMSE:   0.0145, Val R2:   -0.0321
Early epoch 31/100. Skipping checkpointing and early stopping.
Epoch 31/500 completed in 14.62s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 32/500 ---


                                                                     

Train Loss: 0.5851, Train RMSE: 0.0130, Train R2: 0.4013


                                                                           

Val Loss:   0.8326, Val RMSE:   0.0152, Val R2:   -0.1345
Early epoch 32/100. Skipping checkpointing and early stopping.
Epoch 32/500 completed in 14.60s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 33/500 ---


                                                                     

Train Loss: 0.6359, Train RMSE: 0.0133, Train R2: 0.3695


                                                                           

Val Loss:   0.8184, Val RMSE:   0.0151, Val R2:   -0.1145
Early epoch 33/100. Skipping checkpointing and early stopping.
Epoch 33/500 completed in 14.65s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 34/500 ---


                                                                     

Train Loss: 0.6100, Train RMSE: 0.0132, Train R2: 0.3821


                                                                           

Val Loss:   0.8012, Val RMSE:   0.0150, Val R2:   -0.0983
Early epoch 34/100. Skipping checkpointing and early stopping.
Epoch 34/500 completed in 14.60s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 35/500 ---


                                                                     

Train Loss: 0.6325, Train RMSE: 0.0132, Train R2: 0.3808


                                                                           

Val Loss:   0.8967, Val RMSE:   0.0158, Val R2:   -0.2297
Early epoch 35/100. Skipping checkpointing and early stopping.
Epoch 35/500 completed in 14.70s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 36/500 ---


                                                                     

Train Loss: 0.5838, Train RMSE: 0.0129, Train R2: 0.4094


                                                                           

Val Loss:   0.8821, Val RMSE:   0.0157, Val R2:   -0.2021
Early epoch 36/100. Skipping checkpointing and early stopping.
Epoch 36/500 completed in 14.65s
Current LR: 5.00e-05
--------------------------------------------------------------------------------
--- Epoch 37/500 ---


Training:   0%|          | 0/12 [00:00<?, ?it/s]