# Branched UHI Model Training

In [7]:
# Your code here 

In [41]:
# %% Configuration / Hyperparameters for BranchedUHIModel (ConvLSTM + Common Resampling)

# --- Import utils ---
from src.train.train_utils import check_path
# -------------------

# --- Paths & Basic Info ---
project_root_str = str(project_root) # Store as string for config
data_dir_base = project_root / "data"
city_name = "NYC"
output_dir_base = project_root / "training_runs"

# --- WANDB Config ---
wandb_project_name = "MLC_UHI_Proj"
wander_run_name_prefix = f"{city_name}_BranchedUHI" # Modified prefix
run_dir = output_dir / wander_run_name_prefix

# --- Data Loading Config ---
# NEW: Define the common resolution for spatial features entering the model
feature_resolution_m = 30 # Start with 10m (matches Clay/UHI grid)

# Define UHI grid resolution separately (used for final target matching)
# This assumes UHI data corresponds to 10m grid
uhi_grid_resolution_m = 10

weather_seq_length = 60

# Input Data Paths (relative)
relative_data_dir = Path("data")
relative_uhi_csv = relative_data_dir / city_name / "uhi.csv"
relative_bronx_weather_csv = relative_data_dir / city_name / "bronx_weather.csv"
relative_manhattan_weather_csv = relative_data_dir / city_name / "manhattan_weather.csv"
relative_dem_path = relative_data_dir / city_name / "sat_files" / "nyc_dem_10m_pc.tif"
relative_dsm_path = relative_data_dir / city_name / "sat_files" / "nyc_dsm_10m_pc.tif"
relative_cloudless_mosaic_path = relative_data_dir / city_name / "sat_files" / f"sentinel_{city_name}_20210601_to_20210901_cloudless_mosaic.npy"
relative_single_lst_median_path = relative_data_dir / city_name / "sat_files" / f"lst_{city_name}_median_20210601_to_20210901.npy"

# Nodata values
elevation_nodata = -9999.0
lst_nodata = 0.0 # Or appropriate value for LST median file

# --- Feature Selection Flags ---
# IMPORTANT: Only enable the features you want to use
# The Model will expect exactly these features
feature_flags = {
    "use_dem": True,              # Digital Elevation Model
    "use_dsm": True,              # Digital Surface Model
    "use_clay": True,             # Clay feature extractor
    "use_sentinel_composite": False, # Raw Sentinel-2 bands
    "use_lst": False,             # Land Surface Temperature
    "use_ndvi": False,            # Normalized Difference Vegetation Index
    "use_ndbi": False,            # Normalized Difference Built-up Index
    "use_ndwi": False,            # Normalized Difference Water Index
}

# --- Bands for Sentinel Composite (if use_sentinel_composite is True) --- #
sentinel_bands_to_load = ["blue", "green", "red", "nir", "swir16", "swir22"]

# --- Model Config (BranchedUHIModel with ConvLSTM, No separate Elev branches) ---

# Clay Backbone (if feature_flags["use_clay"] is True)
clay_model_size = "large"
clay_bands = ["blue", "green", "red", "nir"]
clay_platform = "sentinel-2-l2a"
clay_gsd = 10
freeze_backbone = True
relative_clay_checkpoint_path = "notebooks/clay-v1.5.ckpt"
relative_clay_metadata_path = Path("src") / "Clay" / "configs" / "metadata.yaml"

# Temporal Weather Processor (ConvLSTM)
weather_input_channels = 6
convlstm_hidden_dims = [16, 8]
convlstm_kernel_sizes = [(3,3), (3,3)]
convlstm_num_layers = len(convlstm_hidden_dims)

# --- REMOVED High-Res Elevation Branch Config ---

# U-Net Head
unet_base_channels = 16
unet_depth = 2 # <<< REDUCED from 3

# Projection Layer Channels
proj_static_ch = 8 # For projecting ALL static feats (Clay, LST, DEM, DSM, Indices)
proj_temporal_ch = 8 # For projecting ConvLSTM output

# --- Training Hyperparameters ---
num_workers = 1
epochs = 500
lr = 5e-5
weight_decay = 0.01
loss_type = 'mse'
patience = 50
cpu = False

# V100 Suggestion: Target batch size 2 => n_train_batches = 47 / 1 = 47
n_train_batches = 47

# --- Device Setup ---
device = torch.device("cuda" if torch.cuda.is_available() and not cpu else "cpu")
print(f"Using device: {device}")

# --- Resolve Paths using check_path from train_utils ---
absolute_uhi_csv = check_path(relative_uhi_csv, project_root, "UHI CSV")
absolute_bronx_weather_csv = check_path(relative_bronx_weather_csv, project_root, "Bronx Weather CSV")
absolute_manhattan_weather_csv = check_path(relative_manhattan_weather_csv, project_root, "Manhattan Weather CSV")

# Always check paths needed by dataloader, flags control usage inside
absolute_dem_path = check_path(relative_dem_path, project_root, "DEM TIF")
absolute_dsm_path = check_path(relative_dsm_path, project_root, "DSM TIF")
absolute_clay_checkpoint_path = check_path(relative_clay_checkpoint_path, project_root, "Clay Checkpoint")
absolute_clay_metadata_path = check_path(relative_clay_metadata_path, project_root, "Clay Metadata")
absolute_cloudless_mosaic_path = check_path(relative_cloudless_mosaic_path, project_root, "Cloudless Mosaic")
absolute_single_lst_median_path = check_path(relative_single_lst_median_path, project_root, "Single LST Median", should_exist=feature_flags["use_lst"]) # Check only if flag is true

# --- Calculate Bounds ---
uhi_df = pd.read_csv(absolute_uhi_csv)
required_cols = ['Longitude', 'Latitude']
if not all(col in uhi_df.columns for col in required_cols):
    raise ValueError(f"UHI CSV must contain columns: {required_cols}")
bounds = [
    uhi_df['Longitude'].min(),
    uhi_df['Latitude'].min(),
    uhi_df['Longitude'].max(),
    uhi_df['Latitude'].max()
]
print(f"Loaded bounds from {absolute_uhi_csv.name}: {bounds}")

# --- Central Config Dictionary --- #
config = {
    # Paths & Info
    "model_type": "BranchedUHIModel_CommonRes",
    "project_root": project_root_str,
    "run_dir": run_dir,
    "city_name": city_name,
    "wandb_project_name": wandb_project_name,
    "wander_run_name_prefix": wander_run_name_prefix,
    # Data Loading
    "feature_resolution_m": feature_resolution_m, # NEW
    "uhi_grid_resolution_m": uhi_grid_resolution_m, # For reference
    "weather_seq_length": weather_seq_length,
    "uhi_csv": str(relative_uhi_csv),
    "bronx_weather_csv": str(absolute_bronx_weather_csv),
    "manhattan_weather_csv": str(absolute_manhattan_weather_csv),
    "bounds": bounds,
    "feature_flags": feature_flags,
    "sentinel_bands_to_load": sentinel_bands_to_load,
    "dem_path": str(absolute_dem_path) if absolute_dem_path else None,
    "dsm_path": str(absolute_dsm_path) if absolute_dsm_path else None,
    "elevation_nodata": elevation_nodata,
    "cloudless_mosaic_path": str(absolute_cloudless_mosaic_path) if absolute_cloudless_mosaic_path else None,
    "single_lst_median_path": str(absolute_single_lst_median_path) if absolute_single_lst_median_path else None,
    "lst_nodata": lst_nodata,
    # Model Config
    "weather_input_channels": weather_input_channels,
    "convlstm_hidden_dims": convlstm_hidden_dims,
    "convlstm_kernel_sizes": convlstm_kernel_sizes,
    "convlstm_num_layers": convlstm_num_layers,
    "proj_static_ch": proj_static_ch,
    "proj_temporal_ch": proj_temporal_ch,
    "unet_base_channels": unet_base_channels,
    "unet_depth": unet_depth, # Updated
    # Clay specific
    "clay_model_size": clay_model_size,
    "clay_bands": clay_bands,
    "clay_platform": clay_platform,
    "clay_gsd": clay_gsd,
    "freeze_backbone": freeze_backbone,
    "clay_checkpoint_path": str(absolute_clay_checkpoint_path) if feature_flags["use_clay"] else None,
    "clay_metadata_path": str(absolute_clay_metadata_path) if feature_flags["use_clay"] else None,
    # Training Hyperparameters
    "n_train_batches": n_train_batches,
    "num_workers": num_workers,
    "epochs": epochs,
    "lr": lr,
    "weight_decay": weight_decay,
    "loss_type": loss_type,
    "patience": patience,
    "device": str(device)
}

print("\nBranched Model Configuration dictionary created:")
print(json.dumps(config, indent=2, default=lambda x: str(x) if isinstance(x, (Path, torch.device)) else x))


Using device: cuda
Loaded bounds from uhi.csv: [np.float64(-73.99445667), np.float64(40.75879167), np.float64(-73.87945833), np.float64(40.85949667)]

Branched Model Configuration dictionary created:
{
  "model_type": "BranchedUHIModel_CommonRes",
  "project_root": "/home/jupyter/MLC-Project",
  "run_dir": "/home/jupyter/MLC-Project/training_runs/train_20250506_054808/NYC_BranchedUHI",
  "city_name": "NYC",
  "wandb_project_name": "MLC_UHI_Proj",
  "wander_run_name_prefix": "NYC_BranchedUHI",
  "feature_resolution_m": 30,
  "uhi_grid_resolution_m": 10,
  "weather_seq_length": 60,
  "uhi_csv": "data/NYC/uhi.csv",
  "bronx_weather_csv": "/home/jupyter/MLC-Project/data/NYC/bronx_weather.csv",
  "manhattan_weather_csv": "/home/jupyter/MLC-Project/data/NYC/manhattan_weather.csv",
  "bounds": [
    -73.99445667,
    40.75879167,
    -73.87945833,
    40.85949667
  ],
  "feature_flags": {
    "use_dem": true,
    "use_dsm": true,
    "use_clay": true,
    "use_sentinel_composite": false,
    

In [42]:
# %% Data Loading and Preprocessing (Branched Model + Common Resampling)

# --- Import utils ---
from src.train.train_utils import (
    calculate_uhi_stats, # Removed split_data
    create_dataloaders
)
from torch.utils.data import Subset # Import Subset
# -------------------

print("Initializing BranchedCityDataSet...")
try:
    dataset = CityDataSetBranched(
        bounds=config["bounds"],
        feature_resolution_m=config["feature_resolution_m"], # Corrected param name
        uhi_grid_resolution_m=config["uhi_grid_resolution_m"], # Corrected param name
        uhi_csv=absolute_uhi_csv, # Use absolute path resolved earlier
        bronx_weather_csv=absolute_bronx_weather_csv,
        manhattan_weather_csv=absolute_manhattan_weather_csv,
        data_dir=project_root_str,
        city_name=config["city_name"],
        feature_flags=config["feature_flags"],
        sentinel_bands_to_load=config["sentinel_bands_to_load"],
        dem_path=config["dem_path"], # Corrected param name
        dsm_path=config["dsm_path"], # Corrected param name
        elevation_nodata=config["elevation_nodata"], # Corrected param name
        cloudless_mosaic_path=config["cloudless_mosaic_path"],
        single_lst_median_path=config["single_lst_median_path"],
        lst_nodata=config["lst_nodata"], # Added missing param
        weather_seq_length=config["weather_seq_length"],
        target_crs_str=config.get("target_crs_str", "EPSG:4326") # Added optional param
    )
except FileNotFoundError as e:
    print(f"Dataset initialization failed: {e}")
    print("Ensure required data files (DEM, DSM, weather, UHI, potentially mosaic/LST) exist.")
    print("Run `notebooks/download_data.ipynb` first.")
    raise
except Exception as e:
    print(f"Unexpected error during dataset initialization: {e}")
    raise

# --- Sequential Train/Val Split --- #
val_percent = 0.20 # Keep the percentage definition
num_samples = len(dataset)
if num_samples < 2: # Need at least one for train and one for val
    raise ValueError(f"Dataset has only {num_samples} samples, cannot perform train/val split.")

n_train = int(num_samples * (1 - val_percent))
n_val = num_samples - n_train

if n_train == 0 or n_val == 0:
    raise ValueError(f"Split resulted in zero samples for train ({n_train}) or validation ({n_val}). Adjust val_percent or check dataset size.")

train_indices = list(range(n_train))
val_indices = list(range(n_train, num_samples))

train_ds = Subset(dataset, train_indices)
val_ds = Subset(dataset, val_indices)

print(f"Sequential dataset split: {len(train_ds)} training (indices 0-{n_train-1}), {len(val_ds)} validation (indices {n_train}-{num_samples-1}) samples.")

# --- Calculate UHI Mean and Std from Training Data ONLY --- #
uhi_mean, uhi_std = calculate_uhi_stats(train_ds)
config['uhi_mean'] = uhi_mean
config['uhi_std'] = uhi_std

# --- Create DataLoaders --- #
train_loader, val_loader = create_dataloaders(
    train_ds,
    val_ds,
    n_train_batches=config['n_train_batches'],
    num_workers=config['num_workers'],
    device=device # Pass device from config cell
)


2025-05-06 05:57:08,724 - INFO - Target FEATURE grid size (H, W): (373, 323) @ 30m, CRS: EPSG:4326
2025-05-06 05:57:08,725 - INFO - Target UHI grid size (H, W): (1118, 969) @ 10m


Initializing BranchedCityDataSet...


Precomputing UHI grids: 100%|██████████| 59/59 [00:00<00:00, 1141.13it/s]
2025-05-06 05:57:08,798 - INFO - Loading DEM from: /home/jupyter/MLC-Project/data/NYC/sat_files/nyc_dem_10m_pc.tif
2025-05-06 05:57:08,848 - INFO - DEM loaded raw shape: (1, 1120, 1279)
2025-05-06 05:57:08,909 - INFO - Clipping DEM to bounds: [np.float64(-73.99445667), np.float64(40.75879167), np.float64(-73.87945833), np.float64(40.85949667)]
2025-05-06 05:57:08,911 - INFO - Opened DEM (lazy load). Native shape (approx): (1, 1120, 1279)
2025-05-06 05:57:08,911 - INFO - Loading DSM from: /home/jupyter/MLC-Project/data/NYC/sat_files/nyc_dsm_10m_pc.tif
2025-05-06 05:57:08,971 - INFO - DSM loaded raw shape: (1, 1120, 1279)
2025-05-06 05:57:09,044 - INFO - Clipping DSM to bounds: [np.float64(-73.99445667), np.float64(40.75879167), np.float64(-73.87945833), np.float64(40.85949667)]
2025-05-06 05:57:09,045 - INFO - Opened DSM (lazy load). Native shape (approx): (1, 1120, 1279)
2025-05-06 05:57:09,046 - INFO - Calculati

Sequential dataset split: 47 training (indices 0-46), 12 validation (indices 47-58) samples.


Calculating stats: 100%|██████████| 47/47 [00:00<00:00, 789.31it/s]
2025-05-06 05:57:09,134 - INFO - Training UHI Mean: 1.0006, Std Dev: 0.0168
2025-05-06 05:57:09,135 - INFO - Creating dataloaders...
2025-05-06 05:57:09,136 - INFO - Using Train Batch Size: 1
2025-05-06 05:57:09,137 - INFO - Using Validation Batch Size: 1
2025-05-06 05:57:09,137 - INFO - Data loading setup complete.


In [45]:
# --- Helper functions ---
# Get the warmup epochs from config or default to 5
warmup_epochs = config.get("warmup_epochs", 5)

# --- Training Loop ---
# Initialize metrics tracking
best_val_loss = float('inf')
best_val_rmse = float('inf')
best_val_r2 = -float('inf')
patience_counter = 0
training_log = []

# Create run directory
model_save_dir = Path(config['run_dir']) / "checkpoints"

# Save config to run directory
config_path = Path(config['run_dir']) / "config.json"
with open(config_path, 'w') as f:
    json.dump(config, f)

# Initialize wandb
try:
    import wandb
    wandb_available = True
    print("Weights & Biases (wandb) available for logging.")
except ImportError:
    wandb_available = False
    wandb = None
    print("Weights & Biases (wandb) not available. Skipping wandb logging.")

if wandb_available:
    # Configure wandb
    wandb.init(
        project=config['wandb_project_name'],
        name=f"{config['run_name_prefix']}_{datetime.now().strftime('%Y%m%d_%H%M')}",
        config=config
    )
    
    # Optional: watch model parameters
    wandb.watch(model)

print(f"Starting training for {config['epochs']} epochs with patience {config['patience']}")

try:
    for epoch in range(config['epochs']):
        epoch_start = time.time()
        
        # --- Train --- #
        if train_loader:
            # Use generic train function from train_utils
            train_loss, train_rmse, train_r2 = train_utils.train_epoch_generic(
                model, train_loader, optimizer, loss_fn, device, uhi_mean, uhi_std,
                max_grad_norm=config.get("max_grad_norm", 1.0)  # Use max_grad_norm parameter
            )
            print(f"Train Loss: {train_loss:.4f}, Train RMSE: {train_rmse:.4f}, Train R2: {train_r2:.4f}")
            if np.isnan(train_loss):
                print("Warning: Training loss is NaN. Stopping training.")
                break
            log_metrics = {"epoch": epoch + 1, "train_loss": train_loss, "train_rmse": train_rmse, "train_r2": train_r2}
        else:
            print("Skipping training: train_loader is None.")
            train_loss, train_rmse, train_r2 = float('nan'), float('nan'), float('nan')
            log_metrics = {"epoch": epoch + 1, "train_loss": train_loss, "train_rmse": train_rmse, "train_r2": train_r2}
        
        # Log train metrics AFTER checking for NaN
        if wandb:
            wandb.log(log_metrics)
        training_log.append(log_metrics) # Append to local log regardless of W&B


        # --- Validate --- #
        if val_loader:
            # Use generic validate function from train_utils
            val_loss, val_rmse, val_r2 = train_utils.validate_epoch_generic(
                model, val_loader, loss_fn, device, uhi_mean, uhi_std
            )
            print(f"Val Loss:   {val_loss:.4f}, Val RMSE:   {val_rmse:.4f}, Val R2:   {val_r2:.4f}")
            if np.isnan(val_loss):
                print("Warning: Validation Loss is NaN. Stopping training.")
                break
            val_metrics = {"val_loss": val_loss, "val_rmse": val_rmse, "val_r2": val_r2}
            log_metrics.update(val_metrics)
            
            # Log validation metrics
            if wandb:
                wandb.log(val_metrics)
            
            # Warmup period: don't save or check early stopping until after warmup_epochs
            if epoch >= warmup_epochs:
                # Check for improvement (using validation loss now)
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_val_rmse = val_rmse
                    best_val_r2 = val_r2
                    patience_counter = 0
                    
                    # Save best model
                    model_save_path = model_save_dir / f"best_model_epoch{epoch+1}.pt"
                    torch.save({
                        'epoch': epoch + 1,
                        'model_state_dict': model.state_dict(),
                        'optimizer_state_dict': optimizer.state_dict(),
                        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
                        'loss': val_loss,
                        'val_rmse': val_rmse,
                        'val_r2': val_r2,
                        'config': config
                    }, model_save_path)
                    print(f"New best model saved at epoch {epoch+1} with val_loss {val_loss:.4f}")
                else:
                    patience_counter += 1
                    print(f"No improvement. Patience: {patience_counter}/{config['patience']}")
                    
                    # Early stopping check
                    if patience_counter >= config['patience']:
                        print(f"Early stopping triggered after {epoch+1} epochs")
                        break
        else:
            print("Skipping validation: val_loader is None.")
            
        # Step the scheduler after validation (if it exists)
        if scheduler:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                scheduler.step(val_loss)
            else:
                scheduler.step()
            if wandb:
                wandb.log({"lr": optimizer.param_groups[0]['lr']})
                
        # Print epoch summary
        epoch_time = time.time() - epoch_start
        print(f"Epoch {epoch+1}/{config['epochs']} completed in {epoch_time:.2f}s")
        print(f"Current LR: {optimizer.param_groups[0]['lr']:.2e}")
        print("-" * 80)
            
    print("Training complete!")
    print(f"Best validation loss: {best_val_loss:.4f}, RMSE: {best_val_rmse:.4f}, R2: {best_val_r2:.4f}")
    
    # Save the final model
    final_model_path = model_save_dir / "final_model.pt"
    torch.save({
        'epoch': epoch + 1,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'config': config
    }, final_model_path)
    print(f"Final model saved to {final_model_path}")
    
    # Save training log
    log_df = pd.DataFrame(training_log)
    log_path = Path(config['run_dir']) / "training_log.csv"
    log_df.to_csv(log_path, index=False)
    print(f"Training log saved to {log_path}")
    
except Exception as e:
    print(f"Error during training: {str(e)}")
    raise
finally:
    # Finish wandb run
    if wandb:
        wandb.finish()

FileNotFoundError: [Errno 2] No such file or directory: '/home/jupyter/MLC-Project/training_runs/train_20250506_054808/NYC_BranchedUHI/config.json'