In [1]:
# %% Setup & Imports

import sys
import os
import pandas as pd
import numpy as np
import json
import matplotlib.pyplot as plt
from pathlib import Path
from datetime import datetime
import logging
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split, Subset # MODIFIED: Added Subset
from tqdm import tqdm
import math
import shutil # For checkpoint saving

# Make the repository root importable so the `src.*` imports below resolve.
# The root is taken to be the parent of the current working directory, i.e. the
# notebook is expected to run from a first-level subdirectory such as `notebooks/`.
project_root = Path.cwd().parent
sys.path.insert(0, str(project_root))

# --- Import Model Components ---
from src.model import UHINetCNN # Import the CNN model
from src.ingest.dataloader_cnn import CityDataSet # MODIFIED: Import the updated dataloader

# --- Import Training Utilities & Loss --- ### MODIFIED ###
from src.train.loss import masked_mse_loss, masked_mae_loss # Import loss functions
import src.train.train_utils as train_utils # Import the new utility module

# Root logger: INFO level, "<timestamp> - <level> - <message>" lines.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Weights & Biases is optional. `wandb` stays None when the package is missing,
# and later cells test `if wandb:` before logging.
wandb = None
try:
    import wandb
except ImportError:
    print("wandb not installed, skipping W&B logging.")

# Import necessary metrics
from sklearn.metrics import mean_squared_error, r2_score

In [6]:
### UNCOMMENT ON FIRST RUN IF USING Clay
#!wget -q https://huggingface.co/made-with-clay/Clay/resolve/main/v1.5/clay-v1.5.ckpt

# %% Configuration / Hyperparameters (CNN Model + Common Resampling)

# Path validation helper used further down in this cell.
from src.train.train_utils import check_path

# --- Paths & Basic Info ---
# `project_root` is created by the setup cell at the top of the notebook.
project_root_str = str(project_root)
city_name = "NYC"
data_dir = project_root / "data"
output_dir_base = project_root / "training_runs"

# --- WANDB Config ---
wandb_project_name = "MLC_UHI_Proj"
# NOTE(review): "wander" looks like a typo for "wandb". The spelling is kept because
# the same name is used as a config key and read back by the training cell.
wander_run_name_prefix = f"{city_name}_UHINetCNN"

# --- Data Loading Config ---
# Resolutions in metres (per the `_m` suffix) for the feature grid and the UHI target grid.
feature_resolution_m = 50
uhi_grid_resolution_m = 50

# --- Input Data Paths ---
# All inputs live under data/<city>/; rasters and arrays sit in its sat_files/ subfolder.
_city_dir = data_dir / city_name
_sat_dir = _city_dir / "sat_files"
_city_slug = city_name.lower()

uhi_csv_path = _city_dir / "uhi.csv"
bronx_weather_csv_path = _city_dir / "bronx_weather.csv"
manhattan_weather_csv_path = _city_dir / "manhattan_weather.csv"  # must match the actual filename

dem_path = _sat_dir / f"{_city_slug}_dem_nasadem_native-resolution_pc.tif"
dsm_path = _sat_dir / f"{_city_slug}_dsm_cop-dem-glo-30_native-resolution_pc.tif"
cloudless_mosaic_path = _sat_dir / f"sentinel_{city_name}_20210601_to_20210901_cloudless_mosaic.npy"
# The LST filename embeds the time window; keep it in sync with whatever produced the file
# (presumably download_data.ipynb -- verify).
lst_time_window_str_for_filename = "20210601_to_20210901"
single_lst_median_path = _sat_dir / f"lst_{city_name}_median_{lst_time_window_str_for_filename}.npy"

# Nodata sentinels (switch to np.nan if that is what the files use).
elevation_nodata = -9999.0
lst_nodata = 0.0

# --- Feature Selection Flags --- #
# Each flag switches one input feature family on or off; the same dict is handed to
# both the dataset and the model.
feature_flags = dict(
    use_dem=False,
    use_dsm=True,
    use_clay=True,
    use_sentinel_composite=False,
    use_lst=False,  # set to True if you intend to use LST
    use_ndvi=False,
    use_ndbi=False,
    use_ndwi=False,
)

# --- Bands for Sentinel Composite --- #
sentinel_bands_to_load = [
    "blue",
    "green",
    "red",
    "nir",
    "swir16",
    "swir22",
]

# --- Model Config (UHINetCNN) ---
clay_model_size = "large"
clay_bands = [
    "blue",
    "green",
    "red",
    "nir",
]
clay_platform = "sentinel-2-l2a"
clay_gsd = 10
freeze_backbone = True
clay_checkpoint_path = project_root / "notebooks" / "clay-v1.5.ckpt"
clay_metadata_path = project_root / "src" / "Clay" / "configs" / "metadata.yaml"

weather_channels = 6
# U-Net-like part of UHINetCNN
unet_base_channels = 64
unet_depth = 4

# --- Training Hyperparameters ---
num_workers = 4
epochs = 500
lr = 5e-5
weight_decay = 0.01
loss_type = 'mse'   # the model cell accepts 'mse' or 'mae'
patience = 50
cpu = False         # True forces CPU even when CUDA is available
n_train_batches = 9
max_grad_norm = 1.0
warmup_epochs = 5

# --- Device Setup ---
_use_cuda = torch.cuda.is_available() and not cpu
device = torch.device("cuda" if _use_cuda else "cpu")
print(f"Using device: {device}")

# --- Validate Paths ---
# Every input that the enabled features need is passed through `check_path`, and the
# returned value replaces the original variable. Paths for disabled features are not checked.


def _checked(path, label, **extra):
    """Forward `path` to `check_path` together with `project_root` and a display label."""
    return check_path(path, project_root, label, **extra)


# Always required.
uhi_csv_path = _checked(uhi_csv_path, "UHI CSV")
bronx_weather_csv_path = _checked(bronx_weather_csv_path, "Bronx Weather CSV")
manhattan_weather_csv_path = _checked(manhattan_weather_csv_path, "Manhattan Weather CSV")

# Required only for the features that are switched on.
if feature_flags["use_dem"]:
    dem_path = _checked(dem_path, "DEM TIF")
if feature_flags["use_dsm"]:
    dsm_path = _checked(dsm_path, "DSM TIF")
if feature_flags["use_clay"]:
    clay_checkpoint_path = _checked(clay_checkpoint_path, "Clay Checkpoint")
    clay_metadata_path = _checked(clay_metadata_path, "Clay Metadata")

# The cloudless mosaic feeds both Clay and the direct Sentinel composite.
_needs_mosaic = feature_flags["use_clay"] or feature_flags["use_sentinel_composite"]
if _needs_mosaic:
    cloudless_mosaic_path = _checked(cloudless_mosaic_path, "Cloudless Mosaic")
if feature_flags["use_lst"]:
    single_lst_median_path = _checked(single_lst_median_path, "Single LST Median", should_exist=True)


# --- Calculate Bounds --- #
# The bounding box is derived from the UHI observations themselves:
# [min longitude, min latitude, max longitude, max latitude].
uhi_df = pd.read_csv(uhi_csv_path)  # path validated above
required_cols = ['Longitude', 'Latitude']
if set(required_cols) - set(uhi_df.columns):
    raise ValueError(f"UHI CSV must contain columns: {required_cols}")

_lon = uhi_df['Longitude']
_lat = uhi_df['Latitude']
# Named `bounds_list` to avoid clashing with a global `bounds` from another notebook.
bounds_list = [_lon.min(), _lat.min(), _lon.max(), _lat.max()]
print(f"Loaded bounds from {uhi_csv_path.name}: {bounds_list}")

# --- Central Config Dictionary --- #
# Single source of truth for the rest of the notebook: the dataset, model and training
# cells read from `config`, and it is also what gets written to config.json and sent to
# W&B. Paths are stored as strings; paths belonging to disabled features are stored as None.
config = {
    # Paths & Info
    "model_type": "UHINetCNN",  # specific to this notebook
    "project_root": project_root_str,
    # "run_dir" is added below, once the directory exists
    "city_name": city_name,
    "wandb_project_name": wandb_project_name,
    "wander_run_name_prefix": wander_run_name_prefix,
    # Data Loading
    "feature_resolution_m": feature_resolution_m,
    "uhi_grid_resolution_m": uhi_grid_resolution_m,
    "uhi_csv": str(uhi_csv_path),
    "bronx_weather_csv": str(bronx_weather_csv_path),
    "manhattan_weather_csv": str(manhattan_weather_csv_path),
    "bounds": bounds_list,
    "feature_flags": feature_flags,
    "sentinel_bands_to_load": sentinel_bands_to_load,
    "dem_path": str(dem_path) if feature_flags["use_dem"] else None,
    "dsm_path": str(dsm_path) if feature_flags["use_dsm"] else None,
    "elevation_nodata": elevation_nodata,
    "cloudless_mosaic_path": str(cloudless_mosaic_path) if feature_flags.get("use_clay") or feature_flags.get("use_sentinel_composite") else None,
    "single_lst_median_path": str(single_lst_median_path) if feature_flags["use_lst"] else None,
    "lst_nodata": lst_nodata,
    # Model Config (UHINetCNN specific)
    "weather_channels": weather_channels,  # for the single weather grid input
    "unet_base_channels": unet_base_channels,
    "unet_depth": unet_depth,
    # Clay specific
    "clay_model_size": clay_model_size,
    "clay_bands": clay_bands,
    "clay_platform": clay_platform,
    "clay_gsd": clay_gsd,
    "freeze_backbone": freeze_backbone,
    "clay_checkpoint_path": str(clay_checkpoint_path) if feature_flags["use_clay"] else None,
    "clay_metadata_path": str(clay_metadata_path) if feature_flags["use_clay"] else None,
    # Training Hyperparameters
    "n_train_batches": n_train_batches,
    "num_workers": num_workers,
    "epochs": epochs,
    "lr": lr,
    "weight_decay": weight_decay,
    "loss_type": loss_type,
    "patience": patience,
    "max_grad_norm": max_grad_norm,
    "warmup_epochs": warmup_epochs,
    "device": str(device)
}

# --- Create Run Directory & Update Config ---
# One directory per execution of this cell: <model>_<city>_<YYYYmmdd_HHMMSS>.
run_timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
run_name_suffix = f"{config['model_type']}_{city_name}_{run_timestamp}"
run_dir = output_dir_base / run_name_suffix
run_dir.mkdir(parents=True, exist_ok=True)
config["run_dir"] = str(run_dir)  # read back by the training cell

print(f"\nRun directory: {run_dir}")
print("UHINetCNN Configuration dictionary created:")
# `default=str` stringifies anything json cannot encode natively (Path, torch.device, ...).
# The previous fallback returned unsupported objects unchanged, which json reports as a
# misleading "Circular reference detected" error instead of serialising or failing clearly.
print(json.dumps(config, indent=2, default=str))

In [10]:
# %% Data Loading and Preprocessing (CNN Model + Common Resampling)

# --- Import utils ---
# No random `split_data` helper here: this cell does a sequential split by hand.
from src.train.train_utils import (
    calculate_uhi_stats,
    create_dataloaders
)
from torch.utils.data import Subset
# -------------------

print("Initializing CityDataSet (for CNN model)...")
try:
    # Config entries whose key matches the CityDataSet keyword they feed are copied 1:1.
    # The lookups stay inside the `try` so a missing key is reported by the handlers below.
    _dataset_kwargs = {
        key: config[key]
        for key in (
            "bounds",
            "feature_resolution_m",
            "uhi_grid_resolution_m",
            "uhi_csv",
            "bronx_weather_csv",
            "manhattan_weather_csv",
            "city_name",
            "feature_flags",
            "sentinel_bands_to_load",
            "dem_path",
            "dsm_path",
            "elevation_nodata",
            "cloudless_mosaic_path",
            "single_lst_median_path",
            "lst_nodata",
        )
    }
    # For this loader `data_dir` is the project root, not the data/ folder.
    _dataset_kwargs["data_dir"] = project_root_str
    # Optional setting; falls back to EPSG:4326 when the config does not define it.
    _dataset_kwargs["target_crs_str"] = config.get("target_crs_str", "EPSG:4326")

    dataset = CityDataSet(**_dataset_kwargs)
except FileNotFoundError as e:
    # Missing inputs are the expected failure mode: point the user at the download notebook.
    print(f"Dataset initialization failed: {e}")
    print("Ensure required data files (DEM, DSM, weather, UHI, potentially mosaic/LST) exist.")
    print("Run `notebooks/download_data.ipynb` first.")
    raise
except Exception as e:
    print(f"Unexpected error during dataset initialization: {e}")
    raise

# --- Sequential Train/Val Split ---
# The first (1 - val_percent) share of samples, in dataset order, is used for training
# and the remaining tail for validation. No shuffling is done, so this is only a
# chronological split if the dataset itself is ordered by time.
val_percent = 0.20
num_samples = len(dataset)

# A split needs at least one sample on each side.
if num_samples < 2:
    raise ValueError(f"Dataset has only {num_samples} samples, cannot perform train/val split.")

n_train = int(num_samples * (1 - val_percent))
n_val = num_samples - n_train

if not (n_train and n_val):
    raise ValueError(f"Split resulted in zero samples for train ({n_train}) or validation ({n_val}). Adjust val_percent or check dataset size.")

# Cut one contiguous index list at `n_train`.
_all_indices = list(range(num_samples))
train_indices = _all_indices[:n_train]
val_indices = _all_indices[n_train:]

train_ds = Subset(dataset, train_indices)
val_ds = Subset(dataset, val_indices)

print(f"Sequential dataset split: {len(train_ds)} training (indices 0-{n_train-1}), {len(val_ds)} validation (indices {n_train}-{num_samples-1}) samples.")

# --- UHI normalisation statistics --- #
# Computed on the training subset only, so validation data does not leak into the stats.
# Both values are stored in `config`, where the training cell reads them back.
uhi_mean, uhi_std = calculate_uhi_stats(train_ds)
config.update(uhi_mean=uhi_mean, uhi_std=uhi_std)

# --- Create DataLoaders --- #
_loader_settings = {
    "n_train_batches": config['n_train_batches'],
    "num_workers": config['num_workers'],
    "device": device,  # defined in the configuration cell
}
train_loader, val_loader = create_dataloaders(train_ds, val_ds, **_loader_settings)
print("Data loading and preprocessing for CNN model complete.")

Using device: cuda


TypeError: check_path() got an unexpected keyword argument 'is_absolute'

In [8]:
# %% Model Initialization (CNN Model + Common Resampling)

# --- Import necessary components ---
from src.model import UHINetCNN
from src.train.loss import masked_mse_loss, masked_mae_loss
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import logging

# Build the network from the central config and move it to the configured device.
print(f"Initializing {config['model_type']}...")
try:
    # Required settings use config[...]; optional ones use .get() and pass None when absent.
    # The dict is built inside the `try` so a missing key is logged like any other failure.
    _model_kwargs = {
        "feature_flags": config["feature_flags"],
        "weather_channels": config["weather_channels"],  # single weather grid
        "sentinel_bands_to_load": config.get("sentinel_bands_to_load"),
        "clay_model_size": config.get("clay_model_size"),
        "clay_bands": config.get("clay_bands"),
        "clay_platform": config.get("clay_platform"),
        "clay_gsd": config.get("clay_gsd"),
        "freeze_backbone": config.get("freeze_backbone", True),
        "clay_checkpoint_path": config.get("clay_checkpoint_path"),
        "clay_metadata_path": config.get("clay_metadata_path"),
        "base_channels": config["unet_base_channels"],
        "depth": config["unet_depth"],
    }
    model = UHINetCNN(**_model_kwargs)
    model.to(config["device"])
    print(f"{config['model_type']} initialized successfully.")

except Exception as e:
    # Log the full traceback, then let the failure propagate.
    logging.error(f"Error initializing UHINetCNN: {e}", exc_info=True)
    raise

# --- Optimizer --- #
# AdamW over all model parameters, with learning rate and weight decay from the config.
_optimizer_settings = {"lr": config["lr"], "weight_decay": config["weight_decay"]}
optimizer = optim.AdamW(model.parameters(), **_optimizer_settings)
print("Optimizer (AdamW) initialized.")

# --- Loss Function --- #
# config["loss_type"] selects one of the masked losses; anything else is rejected.
_LOSS_CHOICES = {
    'mse': (masked_mse_loss, "masked_mse_loss"),
    'mae': (masked_mae_loss, "masked_mae_loss"),
}
if config["loss_type"] not in _LOSS_CHOICES:
    raise ValueError(f"Unsupported loss type: {config['loss_type']}")
loss_fn, _loss_label = _LOSS_CHOICES[config["loss_type"]]
print(f"Loss function set to {_loss_label}.")

# --- LR Scheduler --- #
# `scheduler` is always bound (possibly to None) because the training loop tests it.
# NOTE(review): the scheduler is enabled by the early-stopping "patience" entry, but its
# own patience comes from "scheduler_patience" (default 10) -- confirm this is intended.
scheduler = None
if not config.get("patience"):
    print("Patience not set in config, proceeding without LR scheduler.")
else:
    try:
        # Halve the LR when the monitored value (mode 'min') stops decreasing.
        _plateau_patience = config.get("scheduler_patience", 10)
        scheduler = ReduceLROnPlateau(optimizer, 'min', patience=_plateau_patience, factor=0.5)
        print("Initialized ReduceLROnPlateau scheduler.")
    except Exception as e:
        # Best effort: a scheduler failure is logged and training continues without one.
        logging.error(f"Error initializing scheduler: {e}", exc_info=True)
        print("Proceeding without LR scheduler due to initialization error.")

print("\nModel, optimizer, loss function, and scheduler setup complete.")

2025-05-06 23:48:14,353 - INFO - Target FEATURE grid size (H, W): (224, 194) @ 50m, CRS: EPSG:4326
2025-05-06 23:48:14,354 - INFO - Target UHI grid size (H, W): (224, 194) @ 50m


Initializing CityDataSet (for CNN model)...


Precomputing UHI grids: 100%|██████████| 59/59 [00:00<00:00, 3740.39it/s]
2025-05-06 23:48:14,422 - INFO - Loading DSM from: /home/jupyter/MLC-Project/data/NYC/sat_files/nyc_dsm_cop-dem-glo-30_native-resolution_pc.tif
2025-05-06 23:48:14,445 - INFO - Clipping DSM to bounds: [np.float64(-73.99445667), np.float64(40.75879167), np.float64(-73.87945833), np.float64(40.85949667)]
2025-05-06 23:48:14,446 - INFO - Opened DSM (lazy load). Native shape (approx): (1, 364, 415)
2025-05-06 23:48:14,447 - INFO - Calculating global DSM min/max...
2025-05-06 23:48:14,448 - INFO - Global DSM Min: -14.068702697753906, Max: 186.3260040283203
2025-05-06 23:48:14,448 - INFO - Loading cloudless mosaic from /home/jupyter/MLC-Project/data/NYC/sat_files/sentinel_NYC_20210601_to_20210901_cloudless_mosaic.npy with memory mapping
2025-05-06 23:48:14,449 - INFO - Loaded mosaic shape (native res): (4, 1119, 1278)
2025-05-06 23:48:14,458 - INFO - Loaded Bronx weather data: 169 records
2025-05-06 23:48:14,458 - INFO

Sequential dataset split: 47 training (indices 0-46), 12 validation (indices 47-58) samples.


Calculating stats: 100%|██████████| 47/47 [00:00<00:00, 8484.28it/s]
2025-05-06 23:48:14,473 - INFO - Training UHI Mean: 1.0004, Std Dev: 0.0169
2025-05-06 23:48:14,475 - INFO - Creating dataloaders...
2025-05-06 23:48:14,475 - INFO - Using Train Batch Size: 5
2025-05-06 23:48:14,476 - INFO - Using Validation Batch Size: 1
2025-05-06 23:48:14,477 - INFO - Data loading setup complete.


Data loading and preprocessing for CNN model complete.


In [11]:
# %% Training Loop (CNN Model)

# --- Imports ---
import time
from datetime import datetime
import json
from pathlib import Path
import src.train.train_utils as train_utils
import numpy as np  # isnan checks
import pandas as pd  # saving the training log


def _to_jsonable(value):
    """Fallback encoder for `json.dump`: turn a non-JSON value into something serialisable.

    Objects exposing a callable `.item()` (e.g. NumPy scalars, single-element tensors) are
    converted to plain Python scalars so numbers stay numbers in the file. Everything else
    (Path, torch.device, ...) is stringified.
    """
    item = getattr(value, "item", None)
    if callable(item):
        try:
            return item()
        except (TypeError, ValueError, RuntimeError):
            # `.item()` exists but is not applicable (e.g. multi-element array).
            pass
    return str(value)


# --- Setup ---
print(f"Model {config['model_type']} training starting on {device}")

# --- Tracking Variables --- #
best_val_r2 = -float('inf')   # model selection and early stopping use validation R2
patience_counter = 0
training_log_list = []        # per-epoch metrics, written to CSV after training

# --- Output Directory & Run Name --- #
run_dir = Path(config["run_dir"])  # created by the configuration cell
model_save_dir = run_dir / "checkpoints"
model_save_dir.mkdir(parents=True, exist_ok=True)
print(f"Checkpoints and logs will be saved to: {run_dir}")

# --- Save Config --- #
# By now the config also holds uhi_mean/uhi_std, which may be NumPy or torch scalars.
# The previous fallback returned unsupported values unchanged, which json reports as
# "Circular reference detected"; `_to_jsonable` converts them instead.
config_path = run_dir / "config.json"
with open(config_path, 'w') as f:
    json.dump(config, f, indent=2, default=_to_jsonable)
print(f"Saved configuration to {config_path}")

# --- Retrieve UHI Stats from Config --- #
uhi_mean = config.get('uhi_mean')
uhi_std = config.get('uhi_std')
if uhi_mean is None or uhi_std is None:
    raise ValueError("uhi_mean/uhi_std not in config. Run data loading cell.")
print(f"Using Training UHI Mean: {uhi_mean:.4f}, Std Dev: {uhi_std:.4f}")

# --- WANDB Init --- #
# Start a fresh W&B run when the package is available. Any failure disables W&B for the
# rest of the notebook by resetting `wandb` to None.
if not wandb:
    print("Wandb not available, skipping logging.")
else:
    try:
        # Close a run left open by an earlier execution of this cell.
        if wandb.run is not None:
            wandb.finish()

        _wandb_project = config["wandb_project_name"]
        _wandb_run_name = f"{config['wander_run_name_prefix']}_{datetime.now().strftime('%Y%m%d_%H%M')}"
        wandb.init(project=_wandb_project, name=_wandb_run_name, config=config)

        # Optional: watch model parameters
        wandb.watch(model)
        print(f"Wandb initialized for run: {run_dir.name}")
    except Exception as e:
        print(f"Wandb initialization failed: {e}")
        wandb = None

print(f"Starting training for {config['epochs']} epochs with patience {config['patience']}")

# Number of initial epochs exempt from checkpointing and early stopping (0 if unset).
warmup_epochs = config.get("warmup_epochs", 0)

def _record_epoch(metrics, lr=None):
    """Store one epoch's metrics in the local log and send them to W&B in one call.

    A single `wandb.log` per epoch keeps train metrics, validation metrics and the learning
    rate on the same W&B step. `lr` is added to the W&B payload only; the local log keeps
    the metric columns unchanged.
    """
    training_log_list.append(metrics)
    if wandb:
        payload = dict(metrics)
        if lr is not None:
            payload["lr"] = lr
        wandb.log(payload)


# Number of the last epoch that was started. Stays 0 if the loop body never runs, so the
# final save below does not depend on the loop variable being bound.
completed_epochs = 0

try:
    for epoch in range(config['epochs']):
        completed_epochs = epoch + 1
        epoch_start = time.time()
        print(f"--- Epoch {epoch+1}/{config['epochs']} ---")

        # Reset every epoch so the scheduler never sees a missing or stale validation loss.
        val_loss = None
        stop_early = False

        # --- Train --- #
        if train_loader:
            train_loss, train_rmse, train_r2 = train_utils.train_epoch_generic(
                model, train_loader, optimizer, loss_fn, device, uhi_mean, uhi_std,
                feature_flags=config["feature_flags"],
                max_grad_norm=config.get("max_grad_norm", 1.0)
            )
            print(f"Train Loss: {train_loss:.4f}, Train RMSE: {train_rmse:.4f}, Train R2: {train_r2:.4f}")
            if np.isnan(train_loss):
                # Abort before recording anything for this epoch.
                print("Warning: Training loss is NaN. Stopping training.")
                break
        else:
            print("Skipping training: train_loader is None.")
            train_loss, train_rmse, train_r2 = float('nan'), float('nan'), float('nan')

        epoch_metrics = {"epoch": epoch + 1, "train_loss": train_loss, "train_rmse": train_rmse, "train_r2": train_r2}

        # --- Validate --- #
        if val_loader:
            val_loss, val_rmse, val_r2 = train_utils.validate_epoch_generic(
                model, val_loader, loss_fn, device, uhi_mean, uhi_std,
                feature_flags=config["feature_flags"]
            )
            print(f"Val Loss:   {val_loss:.4f}, Val RMSE:   {val_rmse:.4f}, Val R2:   {val_r2:.4f}")
            if np.isnan(val_loss):
                print("Warning: Validation Loss is NaN. Stopping training.")
                # Keep the (valid) training metrics of this epoch in the logs.
                _record_epoch(epoch_metrics)
                break
            epoch_metrics.update({"val_loss": val_loss, "val_rmse": val_rmse, "val_r2": val_r2})

            # Warmup period: no checkpointing or early stopping until after warmup_epochs.
            if epoch >= warmup_epochs:
                if val_r2 > best_val_r2:
                    # Improvement in validation R2: reset patience and checkpoint.
                    best_val_r2 = val_r2
                    patience_counter = 0

                    train_utils.save_checkpoint({
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict() if scheduler else None,
                        'best_val_r2': best_val_r2,
                        'config': config
                    }, is_best=True, output_dir=model_save_dir)
                    print(f"New best model saved at epoch {epoch+1} with val_r2 {val_r2:.4f}")
                else:
                    patience_counter += 1
                    print(f"No improvement. Patience: {patience_counter}/{config['patience']}")

                    if patience_counter >= config['patience']:
                        print(f"Early stopping triggered after {epoch+1} epochs")
                        stop_early = True
            else:
                print(f"Warmup epoch {epoch+1}/{warmup_epochs}. Skipping checkpointing and early stopping.")
        else:
            print("Skipping validation: val_loader is None.")

        if stop_early:
            # Record the final epoch, then leave without stepping the scheduler.
            _record_epoch(epoch_metrics)
            break

        # --- Scheduler --- #
        logged_lr = None
        if scheduler:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                # ReduceLROnPlateau needs a metric; skip the step when validation did not run.
                if val_loss is not None:
                    scheduler.step(val_loss)
            else:
                scheduler.step()
            logged_lr = optimizer.param_groups[0]['lr']

        _record_epoch(epoch_metrics, lr=logged_lr)

        # --- Epoch summary --- #
        epoch_time = time.time() - epoch_start
        print(f"Epoch {epoch+1}/{config['epochs']} completed in {epoch_time:.2f}s")
        print(f"Current LR: {optimizer.param_groups[0]['lr']:.2e}")
        print("-" * 80)

    print("Training complete!")
    print(f"Best validation R2: {best_val_r2:.4f}")

    # Save the final (last-epoch) model. The key names differ from the best-model
    # checkpoint above and are kept as they are for existing loaders.
    final_model_path = model_save_dir / "final_model.pt"
    torch.save({
        'epoch': completed_epochs,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'config': config
    }, final_model_path)
    print(f"Final model saved to {final_model_path}")

    # Save training log
    log_df = pd.DataFrame(training_log_list)
    log_path = Path(config['run_dir']) / "training_log.csv"
    log_df.to_csv(log_path, index=False)
    print(f"Training log saved to {log_path}")

except Exception as e:
    print(f"Error during training: {str(e)}")
    raise
finally:
    # Always close the W&B run, including on errors and interrupts.
    if wandb:
        wandb.finish()

Initializing UHINetCNN_CommonRes (CNN variant)...
Manually loading checkpoint: /home/jupyter/MLC-Project/notebooks/clay-v1.5.ckpt
Instantiating ClayMAEModule manually...


2025-05-06 23:52:14,411 - INFO - Loading pretrained weights from Hugging Face hub (timm/vit_large_patch14_reg4_dinov2.lvd142m)
2025-05-06 23:52:14,520 - INFO - [timm/vit_large_patch14_reg4_dinov2.lvd142m] Safe alternative available for 'pytorch_model.bin' (as 'model.safetensors'). Loading weights using safetensors.


Loading state_dict manually into self.model.model...


2025-05-06 23:52:17,655 - INFO - Identified final encoder layer as self.model.model.proj
2025-05-06 23:52:17,656 - INFO - Keeping Clay backbone frozen.
2025-05-06 23:52:17,660 - INFO - ClayFeatureExtractor output channels set to: 1024
2025-05-06 23:52:17,664 - INFO - Initialized Clay model (large), output channels: 1024
2025-05-06 23:52:17,664 - INFO - Total calculated input channels for U-Net: 1031


Clay model properties: model_size=large, embed_dim=1024, patch_size=16 (patch_size OVERRIDDEN)
Normalization prepared for bands: ['blue', 'green', 'red', 'nir']


2025-05-06 23:52:17,919 - INFO - Initialized UNetDecoder. In channels: 1031, Base channels: 64, Depth: 4
2025-05-06 23:52:17,921 - INFO - UHINetCNN initialized. U-Net Input Ch: 1031, Base Ch: 64, Depth: 4


UHINetCNN_CommonRes (CNN variant) initialized successfully.
Optimizer (AdamW) initialized.
Loss function set to masked_mse_loss.
Initialized ReduceLROnPlateau scheduler.

CNN Model, optimizer, loss function, and scheduler setup complete.
