# Imports
import sys
import os

from pathlib import Path
import yaml
from box import Box

import torch
import torch.optim as optim
import torch.nn.functional as F # Added for interpolation
from torch.utils.data import DataLoader, random_split

import logging
import argparse
from pathlib import Path
import sys

import numpy as np
import pandas as pd # Needed for loading bounds from csv
from tqdm.notebook import tqdm # Use notebook tqdm

import json
import os
from datetime import datetime
import shutil

# --- Metrics --- 
from sklearn.metrics import mean_squared_error, r2_score

# Make the repository root importable so `src.*` modules resolve.
# Assumes the notebook is executed from the 'notebooks' subdirectory.
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

from src.model import UHINet

In [1]:
# Imports
import sys
import os

from pathlib import Path
import yaml
from box import Box

import torch
import torch.optim as optim
import torch.nn.functional as F # Added for interpolation
from torch.utils.data import DataLoader, random_split

import logging
import argparse
from pathlib import Path
import sys

import numpy as np
import pandas as pd # Needed for loading bounds from csv
from tqdm.notebook import tqdm # Use notebook tqdm

import json
import os
from datetime import datetime
import shutil

# Make the repository root importable so `src.*` modules resolve.
# Assumes the notebook is executed from the 'notebooks' subdirectory.
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

from src.model import UHINet

## Configuration

Set up paths and hyperparameters.

In [2]:

from src.ingest.dataloader import CityDataSet
from src.model import UHINet 
from src.train.loss import masked_mae_loss, masked_mse_loss

# Root-logger configuration shared by the dataset and training helpers.
_LOG_FORMAT = '%(asctime)s - %(levelname)s - %(message)s'
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

In [3]:
# %% Configuration / Hyperparameters

# --- Paths ---
data_dir = project_root / "data" # Base directory for processed city data
city_name = "NYC"
output_dir_base = project_root / "training_runs"

# --- Data Loading ---
resolution_m = 10 # Spatial resolution used for the mosaic/grids (10m for Sentinel)
include_lst = False # Set to True if you processed LST and want to include it

# Path to the LST median file (created by download_data notebook)
single_lst_median_path = data_dir / city_name / "sat_files" / f"lst_{city_name}_median_20210601_to_20210901.npy"

# Path to cloudless mosaic (created by download_data notebook)
cloudless_mosaic_path = data_dir / city_name / "sat_files" / f"sentinel_{city_name}_20210601_to_20210901_cloudless_mosaic.npy"

# --- Model Config ---
clay_model_size = "large" # Should match the checkpoint, e.g., v1.5 is large
# Bands MUST match the order used when creating the cloudless mosaic!
clay_bands = ["blue", "green", "red", "nir"] # ["B02", "B03", "B04", "B08"] from Sentinel data
clay_platform = "sentinel-2-l2a"
clay_gsd = 10 # GSD of the input Sentinel-2 mosaic

# UHINet Args (ConvGRU part)
# 6 weather channels: air_temp, rel_humidity, avg_windspeed,
# wind_direction_cos, wind_direction_sin, solar_flux
weather_channels = 6
time_embed_dim = 2 # sin/cos of minute_of_day
lst_channels = 1 if include_lst else 0 # Number of channels in the LST input
proj_ch = 15 # Channels after projecting Clay features
gru_hidden_dim = 64
gru_kernel_size = 3

# --- Training Hyperparameters ---
batch_size = 1 # Small dataset
num_workers = 4
epochs = 50
lr = 1e-4
weight_decay = 0.01
loss_type = 'mse' # 'mae' or 'mse'
patience = 10 # Early stopping: epochs without improvement before training stops
cpu = False # Set True to force CPU even when CUDA is available

# Compute device, honouring the `cpu` override; the training cell reads `device`.
device = torch.device("cuda" if torch.cuda.is_available() and not cpu else "cpu")

# --- Derived Paths & Sanity Checks ---
data_dir_path = Path(data_dir)
city_data_dir = data_dir_path / city_name
uhi_csv = city_data_dir / "uhi.csv"
uhi_df = pd.read_csv(uhi_csv)
# Paths to station weather data
bronx_weather_csv = city_data_dir / "bronx_weather.csv"
# NOTE: "mahattan" (sic). The dataset cell successfully loads Manhattan weather from
# this exact path, so the file on disk appears to carry the same spelling.
# Rename the file and this literal together.
manhattan_weather_csv = city_data_dir / "mahattan_weather.csv"

# ---- Grid Bound ----
missing_cols = {'Longitude', 'Latitude'} - set(uhi_df.columns)
if missing_cols:
    raise ValueError(f"{uhi_csv} is missing required columns: {sorted(missing_cols)}")

# [min_lon, min_lat, max_lon, max_lat] as plain Python floats.
bounds = [
    float(uhi_df['Longitude'].min()),
    float(uhi_df['Latitude'].min()),
    float(uhi_df['Longitude'].max()),
    float(uhi_df['Latitude'].max()),
]
# An empty (or all-NaN) CSV yields NaN bounds; fail here with a clear message
# rather than deep inside the dataset constructor.
if not np.all(np.isfinite(bounds)):
    raise ValueError(f"Could not derive finite grid bounds from {uhi_csv}: {bounds}")
print(f"Loaded bounds from {uhi_csv}: {bounds}")

Loaded bounds from /home/jupyter/UHI/MLC-Project/data/NYC/uhi.csv: [np.float64(-73.99445667), np.float64(40.75879167), np.float64(-73.87945833), np.float64(40.85949667)]


## Setup DataLoader

In [5]:
from torch.utils.data import Subset

## IGNORE TIMEZONE WARNING: Timezone is first incorrectly loaded by pandas but then fixed in our dataloader.

print("Initializing dataset...")
# `averaging_window` is a required constructor argument. It might not be used internally
# when `single_lst_median_path` is set, so a placeholder value is passed.
placeholder_avg_window = 30

try:
    dataset = CityDataSet(
        bounds=bounds,
        averaging_window=placeholder_avg_window,
        resolution_m=resolution_m,
        uhi_csv=str(uhi_csv),
        # Station weather CSVs
        bronx_weather_csv=str(bronx_weather_csv),
        manhattan_weather_csv=str(manhattan_weather_csv),
        cloudless_mosaic_path=str(cloudless_mosaic_path),
        data_dir=str(data_dir_path),
        city_name=city_name,
        include_lst=include_lst,
        single_lst_median_path=str(single_lst_median_path) if include_lst else None
    )
except FileNotFoundError as e:
    print(f"Dataset initialization failed: {e}")
    raise
except Exception as e:
    print(f"Unexpected error during dataset initialization: {e}")
    raise

# --- Train/Val Split (Sequential) ---
# Sequential (not random) so validation timestamps come strictly after training ones.
val_percent = 0.40
n_samples = len(dataset)
if n_samples == 0:
    raise ValueError("Dataset is empty; check the UHI CSV, bounds and timestamps.")

if n_samples < 10: # Handle very small datasets
    print(f"Warning: Dataset size ({n_samples}) is very small. Using all data for training.")
    n_val = 0
    n_train = n_samples
else:
    n_val = int(n_samples * val_percent)
    n_train = n_samples - n_val

train_ds = Subset(dataset, list(range(n_train)))
val_ds = Subset(dataset, list(range(n_train, n_samples))) if n_val > 0 else None

print(f"Sequential dataset split: {len(train_ds)} training, {len(val_ds) if val_ds else 0} validation samples.")

print("Creating dataloaders...")
# Pinned memory only helps host->GPU copies; skip it when training on CPU.
use_pin_memory = torch.cuda.is_available() and not cpu
# With drop_last=True and fewer samples than batch_size, the loader would yield no
# batches at all and every epoch would be silently empty.
drop_last_train = len(train_ds) >= batch_size

# Shuffle training data loader, but not validation
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers,
                          pin_memory=use_pin_memory, drop_last=drop_last_train)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers,
                        pin_memory=use_pin_memory) if val_ds else None

print("Data loading setup complete.")

2025-04-28 00:43:59,146 - INFO - Loading cloudless mosaic from /home/jupyter/UHI/MLC-Project/data/NYC/sat_files/sentinel_NYC_20210601_to_20210901_cloudless_mosaic.npy
2025-04-28 00:43:59,159 - INFO - Loaded mosaic with 4 bands and shape (1122, 1281)
  dt_naive_or_aware = pd.to_datetime(self.bronx_weather['datetime'], errors='raise')
  dt_naive_or_aware = pd.to_datetime(self.manhattan_weather['datetime'], errors='raise')
2025-04-28 00:43:59,185 - INFO - Loaded Bronx weather data: 169 records
2025-04-28 00:43:59,186 - INFO - Loaded Manhattan weather data: 169 records


Initializing dataset...


2025-04-28 00:44:01,072 - INFO - Computed grid cell coordinates and closest station map
2025-04-28 00:44:01,075 - INFO - Grid cells assigned to Bronx: 370359
2025-04-28 00:44:01,077 - INFO - Grid cells assigned to Manhattan: 712983
Precomputing UHI grids: 100%|██████████| 59/59 [00:00<00:00, 333.49it/s]
2025-04-28 00:44:01,264 - INFO - Dataset initialized for NYC with 59 unique timestamps. LST included: False
2025-04-28 00:44:01,265 - INFO - Target grid size (H, W): (1118, 969)


Sequential dataset split: 36 training, 23 validation samples.
Creating dataloaders...
Data loading setup complete.


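## Initialize Model

> Note: this export does not appear to contain the cell that defines `model`, `device`, `clay_checkpoint_path` and `clay_metadata_path`, which the training cell requires. The cell below defines training helpers that also report RMSE/R². It does not initialize the model.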

In [5]:
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
    """Write a training checkpoint and, if it is the best so far, copy it aside.

    Args:
        state (dict): Serializable training state (state dicts, epoch, config, ...).
        is_best (bool): When True, the freshly written checkpoint is also copied
            to `best_filename`.
        filename (str | Path): Destination of the latest checkpoint.
        best_filename (str | Path): Destination of the best checkpoint.
    """
    latest = Path(filename)
    # Create the checkpoint directory on demand.
    latest.parent.mkdir(parents=True, exist_ok=True)
    torch.save(state, filename)
    if not is_best:
        return
    shutil.copyfile(filename, best_filename)
    print(f"Saved new best model to {best_filename}")

def _missing_keys(batch):
    """Return the required batch keys absent from `batch` ('lst_seq' only when include_lst)."""
    required = ['cloudless_mosaic', 'weather_seq', 'time_emb_seq', 'target', 'mask']
    if include_lst:
        required.append('lst_seq')
    return [key for key in required if key not in batch]


def _resize_seq(seq, size):
    """Bilinearly resize a (B, T, C, H, W) sequence to spatial `size`; no-op if already that size."""
    if tuple(seq.shape[-2:]) == tuple(size):
        return seq
    B, T, C, H, W = seq.shape
    # Fold time into batch so F.interpolate sees (N, C, H, W). Use this tensor's own
    # H/W rather than assuming it shares another input's grid; reshape (unlike view)
    # also tolerates non-contiguous inputs.
    flat = F.interpolate(seq.reshape(B * T, C, H, W), size=tuple(size), mode='bilinear', align_corners=False)
    return flat.reshape(B, T, C, *flat.shape[-2:])


def _forward_sequence(model, cloudless_mosaic, weather_seq, time_emb_seq, lst_seq, target_hw, device):
    """Encode static inputs once, roll the ConvGRU over time, and predict.

    Args:
        model: Module exposing `encode_and_project_static`, `step`, `predict`
            and `gru_hidden_dim` (UHINet).
        cloudless_mosaic (torch.Tensor): Mosaic batch already on `device`.
        weather_seq (torch.Tensor): (B, T, C_weather, H, W) weather sequence.
        time_emb_seq (torch.Tensor): (B, T, C_time, H, W) time-embedding sequence.
        lst_seq (torch.Tensor | None): (B, T, C_lst, H, W) LST sequence, or None.
        target_hw (tuple[int, int]): Spatial size of the target grid.
        device: Device for the initial hidden state.

    Returns:
        torch.Tensor: Prediction of shape (B, *target_hw).
    """
    B, T = weather_seq.shape[:2]

    # Only the first LST frame is used, as a static map.
    static_lst_map = lst_seq[:, 0] if lst_seq is not None else None
    # no_grad keeps the Clay backbone frozen.
    # NOTE(review): it also blocks gradients to any trainable projection inside
    # encode_and_project_static; confirm that is intended.
    with torch.no_grad():
        static_features = model.encode_and_project_static(cloudless_mosaic, static_lst_map)
    feat_hw = tuple(static_features.shape[-2:])

    h = torch.zeros(B, model.gru_hidden_dim, *feat_hw, device=device)
    weather_rs = _resize_seq(weather_seq, feat_hw)
    time_rs = _resize_seq(time_emb_seq, feat_hw)

    for t in range(T):
        x_t_combined = torch.cat([static_features, weather_rs[:, t], time_rs[:, t]], dim=1)
        h = model.step(x_t_combined, h)

    prediction = model.predict(h)  # (B, 1, H', W')
    if tuple(prediction.shape[2:]) != tuple(target_hw):
        prediction = F.interpolate(prediction, size=tuple(target_hw), mode='bilinear', align_corners=False)
    return prediction.squeeze(1)


def _epoch_metrics(epoch_preds, epoch_targets, label):
    """Compute (rmse, r2) over all collected valid pixels; NaN when nothing was collected."""
    rmse = np.nan
    r2 = np.nan
    if epoch_preds:
        all_preds = np.concatenate(epoch_preds)
        all_targets = np.concatenate(epoch_targets)
        if all_preds.size > 0 and all_targets.size > 0:
            try:
                rmse = np.sqrt(mean_squared_error(all_targets, all_preds))
                r2 = r2_score(all_targets, all_preds)
            except Exception as metric_e:
                logging.error(f"Error calculating {label}: {metric_e}")
    return rmse, r2


def train_epoch(model, dataloader, optimizer, loss_fn, device):
    """Run one training epoch.

    Args:
        model: UHINet-style model (see `_forward_sequence`).
        dataloader: Iterable of batch dicts.
        optimizer: Optimizer stepping the model's parameters.
        loss_fn: Callable `(prediction, target, mask) -> scalar loss tensor`.
        device: Device the batch is moved to.

    Returns:
        tuple[float, float, float]: (avg_loss, rmse, r2) where
            avg_loss: mean per-batch loss, or NaN if no batch completed, so an
                empty or failed epoch is never mistaken for a perfect one;
            rmse, r2: computed over masked pixels, NaN if none were seen.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    epoch_preds = []
    epoch_targets = []
    progress_bar = tqdm(dataloader, desc='Training', leave=False)

    for batch in progress_bar:
        missing = _missing_keys(batch)
        if missing:
            logging.warning(f"Skipping batch due to missing keys: {missing}")
            continue

        try:
            cloudless_mosaic = batch["cloudless_mosaic"].to(device)
            weather_seq = batch["weather_seq"].to(device)
            lst_seq = batch["lst_seq"].to(device) if include_lst else None
            time_emb_seq = batch["time_emb_seq"].to(device)
            target = batch["target"].to(device)                # (B, H, W)
            mask = batch["mask"].to(device, dtype=torch.bool)  # Ensure mask is boolean
        except Exception as e:
            logging.error(f"Error moving batch to device: {e}")
            continue

        optimizer.zero_grad()
        try:
            prediction_final = _forward_sequence(
                model, cloudless_mosaic, weather_seq, time_emb_seq, lst_seq, target.shape[1:], device
            )
            loss = loss_fn(prediction_final, target, mask)

            # Back-propagating a NaN *or* Inf loss would corrupt the weights.
            if not torch.isfinite(loss):
                logging.warning("Non-finite (NaN/Inf) loss detected, skipping backward pass.")
                continue

            loss.backward()
            optimizer.step()

            # Keep only valid (masked) pixels for epoch-level metrics.
            with torch.no_grad():
                valid_preds = prediction_final[mask].cpu().numpy()
                valid_targets = target[mask].cpu().numpy()
            if valid_preds.size > 0:
                epoch_preds.append(valid_preds)
                epoch_targets.append(valid_targets)

            loss_value = loss.item()
            total_loss += loss_value
            num_batches += 1
            progress_bar.set_postfix(loss=loss_value)

        except RuntimeError as e:
            logging.error(f"Runtime error during training: {e}")
            if "out of memory" in str(e):
                logging.error("CUDA out of memory. Try reducing batch size.")
            continue
        except Exception as e:
            logging.error(f"Unexpected error during training step: {e}", exc_info=True)
            continue

    if num_batches == 0:
        logging.warning("No training batches completed this epoch; returning NaN loss.")
        avg_loss = float('nan')
    else:
        avg_loss = total_loss / num_batches

    rmse, r2 = _epoch_metrics(epoch_preds, epoch_targets, "epoch metrics")
    return avg_loss, rmse, r2


def validate_epoch(model, dataloader, loss_fn, device):
    """Evaluate the model for one epoch without gradient tracking.

    Args:
        model: UHINet-style model (see `_forward_sequence`).
        dataloader: Iterable of batch dicts.
        loss_fn: Callable `(prediction, target, mask) -> scalar loss tensor`.
        device: Device the batch is moved to.

    Returns:
        tuple[float, float, float]: (avg_loss, rmse, r2). avg_loss is NaN when no
        batch completed; rmse/r2 are NaN when no valid pixels were seen.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    epoch_preds = []
    epoch_targets = []
    progress_bar = tqdm(dataloader, desc='Validation', leave=False)
    with torch.no_grad():
        for batch in progress_bar:
            missing = _missing_keys(batch)
            if missing:
                logging.warning(f"Skipping validation batch due to missing keys: {missing}")
                continue

            try:
                cloudless_mosaic = batch["cloudless_mosaic"].to(device)
                weather_seq = batch["weather_seq"].to(device)
                lst_seq = batch["lst_seq"].to(device) if include_lst else None
                time_emb_seq = batch["time_emb_seq"].to(device)
                target = batch["target"].to(device)
                mask = batch["mask"].to(device, dtype=torch.bool)  # Ensure mask is boolean

                prediction_final = _forward_sequence(
                    model, cloudless_mosaic, weather_seq, time_emb_seq, lst_seq, target.shape[1:], device
                )
                loss = loss_fn(prediction_final, target, mask)

                if not torch.isfinite(loss):
                    logging.warning("Non-finite (NaN/Inf) validation loss detected, skipping batch.")
                    continue

                valid_preds = prediction_final[mask].cpu().numpy()
                valid_targets = target[mask].cpu().numpy()
                if valid_preds.size > 0:
                    epoch_preds.append(valid_preds)
                    epoch_targets.append(valid_targets)

                loss_value = loss.item()
                total_loss += loss_value
                num_batches += 1
                progress_bar.set_postfix(loss=loss_value)

            except Exception as e:
                logging.error(f"Error during validation step: {e}", exc_info=True)
                continue

    if num_batches == 0:
        logging.warning("No validation batches completed this epoch; returning NaN loss.")
        avg_loss = float('nan')
    else:
        avg_loss = total_loss / num_batches

    rmse, r2 = _epoch_metrics(epoch_preds, epoch_targets, "validation epoch metrics")
    return avg_loss, rmse, r2

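## Helper Functions for training

These definitions replace the same-named functions from the previous cell, which additionally return RMSE and R².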

In [6]:
def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
    """Write a training checkpoint and, if it is the best so far, copy it aside.

    Args:
        state (dict): Serializable training state (state dicts, epoch, config, ...).
        is_best (bool): When True, the freshly written checkpoint is also copied
            to `best_filename`.
        filename (str | Path): Destination of the latest checkpoint.
        best_filename (str | Path): Destination of the best checkpoint.
    """
    latest = Path(filename)
    # Create the checkpoint directory on demand.
    latest.parent.mkdir(parents=True, exist_ok=True)
    torch.save(state, filename)
    if not is_best:
        return
    shutil.copyfile(filename, best_filename)
    print(f"Saved new best model to {best_filename}")

def _missing_keys(batch):
    """Return the required batch keys absent from `batch` ('lst_seq' only when include_lst)."""
    required = ['cloudless_mosaic', 'weather_seq', 'time_emb_seq', 'target', 'mask']
    if include_lst:
        required.append('lst_seq')
    return [key for key in required if key not in batch]


def _resize_seq(seq, size):
    """Bilinearly resize a (B, T, C, H, W) sequence to spatial `size`; no-op if already that size."""
    if tuple(seq.shape[-2:]) == tuple(size):
        return seq
    B, T, C, H, W = seq.shape
    # Fold time into batch so F.interpolate sees (N, C, H, W). Use this tensor's own
    # H/W rather than assuming it shares another input's grid; reshape (unlike view)
    # also tolerates non-contiguous inputs.
    flat = F.interpolate(seq.reshape(B * T, C, H, W), size=tuple(size), mode='bilinear', align_corners=False)
    return flat.reshape(B, T, C, *flat.shape[-2:])


def _forward_sequence(model, cloudless_mosaic, weather_seq, time_emb_seq, lst_seq, target_hw, device):
    """Encode static inputs once, roll the ConvGRU over time, and predict.

    Args:
        model: Module exposing `encode_and_project_static`, `step`, `predict`
            and `gru_hidden_dim` (UHINet).
        cloudless_mosaic (torch.Tensor): Mosaic batch already on `device`.
        weather_seq (torch.Tensor): (B, T, C_weather, H, W) weather sequence.
        time_emb_seq (torch.Tensor): (B, T, C_time, H, W) time-embedding sequence.
        lst_seq (torch.Tensor | None): (B, T, C_lst, H, W) LST sequence, or None.
        target_hw (tuple[int, int]): Spatial size of the target grid.
        device: Device for the initial hidden state.

    Returns:
        torch.Tensor: Prediction of shape (B, *target_hw).
    """
    B, T = weather_seq.shape[:2]

    # Only the first LST frame is used, as a static map.
    static_lst_map = lst_seq[:, 0] if lst_seq is not None else None
    # no_grad keeps the Clay backbone frozen.
    # NOTE(review): it also blocks gradients to any trainable projection inside
    # encode_and_project_static; confirm that is intended.
    with torch.no_grad():
        static_features = model.encode_and_project_static(cloudless_mosaic, static_lst_map)
    feat_hw = tuple(static_features.shape[-2:])

    h = torch.zeros(B, model.gru_hidden_dim, *feat_hw, device=device)
    weather_rs = _resize_seq(weather_seq, feat_hw)
    time_rs = _resize_seq(time_emb_seq, feat_hw)

    for t in range(T):
        x_t_combined = torch.cat([static_features, weather_rs[:, t], time_rs[:, t]], dim=1)
        h = model.step(x_t_combined, h)

    prediction = model.predict(h)  # (B, 1, H', W')
    if tuple(prediction.shape[2:]) != tuple(target_hw):
        prediction = F.interpolate(prediction, size=tuple(target_hw), mode='bilinear', align_corners=False)
    return prediction.squeeze(1)


def train_epoch(model, dataloader, optimizer, loss_fn, device):
    """Run one training epoch and return the mean per-batch loss.

    Args:
        model: UHINet-style model (see `_forward_sequence`).
        dataloader: Iterable of batch dicts.
        optimizer: Optimizer stepping the model's parameters.
        loss_fn: Callable `(prediction, target, mask) -> scalar loss tensor`.
        device: Device the batch is moved to.

    Returns:
        float: Mean loss over completed batches, or NaN if none completed, so an
        empty or failed epoch is never mistaken for a perfect (0.0) one.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0
    progress_bar = tqdm(dataloader, desc='Training', leave=False)

    for batch in progress_bar:
        missing = _missing_keys(batch)
        if missing:
            logging.warning(f"Skipping batch due to missing keys: {missing}")
            continue

        try:
            cloudless_mosaic = batch["cloudless_mosaic"].to(device)
            weather_seq = batch["weather_seq"].to(device)
            lst_seq = batch["lst_seq"].to(device) if include_lst else None
            time_emb_seq = batch["time_emb_seq"].to(device)
            target = batch["target"].to(device)  # (B, H, W)
            mask = batch["mask"].to(device)      # (B, H, W); dtype as produced by the dataset
        except Exception as e:
            logging.error(f"Error moving batch to device: {e}")
            continue

        optimizer.zero_grad()
        try:
            prediction = _forward_sequence(
                model, cloudless_mosaic, weather_seq, time_emb_seq, lst_seq, target.shape[1:], device
            )
            loss = loss_fn(prediction, target, mask)

            # Back-propagating a NaN *or* Inf loss would corrupt the weights.
            if not torch.isfinite(loss):
                logging.warning("Non-finite (NaN/Inf) loss detected, skipping backward pass.")
                continue

            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            total_loss += loss_value
            num_batches += 1
            progress_bar.set_postfix(loss=loss_value)

        except RuntimeError as e:
            logging.error(f"Runtime error during training: {e}")
            if "out of memory" in str(e):
                logging.error("CUDA out of memory. Try reducing batch size.")
            continue
        except Exception as e:
            logging.error(f"Unexpected error during training step: {e}", exc_info=True)
            continue

    if num_batches == 0:
        logging.warning("No training batches completed this epoch; returning NaN loss.")
        return float('nan')
    return total_loss / num_batches


def validate_epoch(model, dataloader, loss_fn, device):
    """Evaluate the model for one epoch without gradient tracking.

    Args:
        model: UHINet-style model (see `_forward_sequence`).
        dataloader: Iterable of batch dicts.
        loss_fn: Callable `(prediction, target, mask) -> scalar loss tensor`.
        device: Device the batch is moved to.

    Returns:
        float: Mean loss over completed batches, or NaN if none completed.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    progress_bar = tqdm(dataloader, desc='Validation', leave=False)
    with torch.no_grad():
        for batch in progress_bar:
            missing = _missing_keys(batch)
            if missing:
                logging.warning(f"Skipping validation batch due to missing keys: {missing}")
                continue

            try:
                cloudless_mosaic = batch["cloudless_mosaic"].to(device)
                weather_seq = batch["weather_seq"].to(device)
                lst_seq = batch["lst_seq"].to(device) if include_lst else None
                time_emb_seq = batch["time_emb_seq"].to(device)
                target = batch["target"].to(device)
                mask = batch["mask"].to(device)

                prediction = _forward_sequence(
                    model, cloudless_mosaic, weather_seq, time_emb_seq, lst_seq, target.shape[1:], device
                )
                loss = loss_fn(prediction, target, mask)

                if not torch.isfinite(loss):
                    logging.warning("Non-finite (NaN/Inf) validation loss detected, skipping batch.")
                    continue

                loss_value = loss.item()
                total_loss += loss_value
                num_batches += 1
                progress_bar.set_postfix(loss=loss_value)

            except Exception as e:
                logging.error(f"Error during validation step: {e}", exc_info=True)
                continue

    if num_batches == 0:
        logging.warning("No validation batches completed this epoch; returning NaN loss.")
        return float('nan')
    return total_loss / num_batches

## Training

In [7]:
# NOTE(review): `model`, `device`, `clay_checkpoint_path` and `clay_metadata_path` are
# not defined in any visible cell; a model-initialization cell must run before this one.
if loss_type not in ("mae", "mse"):
    raise ValueError(f"loss_type must be 'mae' or 'mse', got {loss_type!r}")

best_val_loss = float('inf')
epochs_no_improve = 0
# Use the configured weight decay rather than a hard-coded copy of it.
optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
loss_fn = masked_mae_loss if loss_type == "mae" else masked_mse_loss


def _unpack_epoch_result(result):
    """Return (loss, rmse, r2) from either a scalar loss or a (loss, rmse, r2) tuple.

    Two versions of train_epoch/validate_epoch exist in this notebook: one returns
    a scalar, the other (loss, rmse, r2). This accepts either.
    """
    if isinstance(result, tuple):
        loss, rmse, r2 = result
        return loss, rmse, r2
    return result, float('nan'), float('nan')


def _print_metrics(split, rmse, r2):
    """Print RMSE/R2 for `split` when they were computed (i.e. not both NaN)."""
    if not (np.isnan(rmse) and np.isnan(r2)):
        print(f"  {split} RMSE={rmse:.4f}, {split} R2={r2:.4f}")


# Create output directory for this run
run_name = f"{city_name}_run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
output_dir = Path(output_dir_base) / run_name
output_dir.mkdir(parents=True, exist_ok=True)
print(f"Checkpoints and logs will be saved to: {output_dir}")

# Save configuration used for this run. `config_dict` is pre-bound so checkpointing
# still works (with config=None) if building the dict itself fails.
config_dict = None
try:
    config_dict = {
        "data_dir": str(data_dir),
        "city_name": city_name,
        "output_dir": str(output_dir),
        "bounds": bounds,
        "resolution_m": resolution_m,
        "include_lst": include_lst,
        "single_lst_median_path": str(single_lst_median_path) if include_lst else None,
        "uhi_csv": str(uhi_csv),
        "bronx_weather_csv": str(bronx_weather_csv),
        "manhattan_weather_csv": str(manhattan_weather_csv),
        "cloudless_mosaic_path": str(cloudless_mosaic_path),
        "clay_model_size": clay_model_size,
        "clay_bands": clay_bands,
        "clay_platform": clay_platform,
        "clay_gsd": clay_gsd,
        "clay_checkpoint_path_local": clay_checkpoint_path,
        "clay_metadata_path_local": clay_metadata_path,
        "weather_channels": weather_channels,
        "time_embed_dim": time_embed_dim,
        "lst_channels": lst_channels,
        "proj_ch": proj_ch,
        "gru_hidden_dim": gru_hidden_dim,
        "gru_kernel_size": gru_kernel_size,
        "batch_size": batch_size,
        "num_workers": num_workers,
        "epochs": epochs,
        "lr": lr,
        "weight_decay": weight_decay,
        "loss_type": loss_type,
        "patience": patience,
        "device": str(device)
    }
    with open(output_dir / "config.json", 'w') as f:
        json.dump(config_dict, f, indent=2)
    print("Saved configuration to config.json")
except Exception as e:
    print(f"Warning: Failed to save configuration: {e}")


print("Starting training...")
for epoch in range(epochs):
    print(f"--- Epoch {epoch+1}/{epochs} ---")
    train_loss, train_rmse, train_r2 = _unpack_epoch_result(
        train_epoch(model, train_loader, optimizer, loss_fn, device)
    )

    if val_loader:
        val_loss, val_rmse, val_r2 = _unpack_epoch_result(
            validate_epoch(model, val_loader, loss_fn, device)
        )
        print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
        _print_metrics("Train", train_rmse, train_r2)
        _print_metrics("Val", val_rmse, val_r2)
        current_loss = val_loss
        if np.isnan(current_loss):
            print("Warning: Validation loss is NaN. Stopping training.")
            break
    else:
        print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f} (No validation set)")
        _print_metrics("Train", train_rmse, train_r2)
        current_loss = train_loss  # Use train loss for checkpointing if no val set
        if np.isnan(current_loss):
            print("Warning: Training loss is NaN. Stopping training.")
            break

    is_best = current_loss < best_val_loss
    if is_best:
        best_val_loss = current_loss
        epochs_no_improve = 0
        print(f"New best loss: {best_val_loss:.4f}")
    else:
        epochs_no_improve += 1
        print(f"No improvement in validation loss for {epochs_no_improve} epochs.")

    save_checkpoint(
        {'epoch': epoch + 1,
         'state_dict': model.state_dict(),
         'best_val_loss': best_val_loss,
         'optimizer': optimizer.state_dict(),
         'config': config_dict  # None if the configuration could not be built
         },
        is_best,
        filename=output_dir / 'checkpoint_last.pth.tar',
        best_filename=output_dir / 'model_best.pth.tar'
    )

    # Early stopping check
    if epochs_no_improve >= patience:
        print(f"Early stopping triggered after {patience} epochs with no improvement.")
        break

print("Training finished.")
print(f"Best loss recorded: {best_val_loss:.4f}")
print(f"Checkpoints saved in: {output_dir}")

Checkpoints and logs will be saved to: /home/jupyter/UHI/MLC-Project/training_runs/NYC_run_20250428_003814
Saved configuration to config.json
Starting training...
--- Epoch 1/50 ---


Training:   0%|          | 0/54 [00:00<?, ?it/s]

Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 1: Train Loss=0.5585, Val Loss=0.2308
New best loss: 0.2308
Saved new best model to /home/jupyter/UHI/MLC-Project/training_runs/NYC_run_20250428_003814/model_best.pth.tar
--- Epoch 2/50 ---


Training:   0%|          | 0/54 [00:00<?, ?it/s]

Validation:   0%|          | 0/5 [00:00<?, ?it/s]

Epoch 2: Train Loss=0.0951, Val Loss=0.0117
New best loss: 0.0117


KeyboardInterrupt: 