# Train UHI Prediction Model

In [None]:
# Your code here

# Train UHI Prediction Model

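This notebook trains the `UHINet` model using Sentinel-2 mosaics, weather data, LST, and time embeddings. The first part below is an earlier PyTorch Lightning draft whose data-loading and trainer steps are not implemented. The complete training pipeline starts at the **Train UHI Net Model** section.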

In [None]:
# Imports
import sys
import os
from pathlib import Path
import yaml
import torch
import lightning as L
from box import Box

# Add the project root to sys.path so `src.*` modules can be imported.
# Avoid a machine-specific absolute path: honour the MLC_PROJECT_ROOT environment
# variable if it is set; otherwise use the current directory when it contains `src/`
# (kernel started at the project root), else its parent (notebook inside `notebooks/`).
_cwd = Path.cwd().resolve()
workspace_dir = Path(
    os.environ.get("MLC_PROJECT_ROOT")
    or (_cwd if (_cwd / "src").is_dir() else _cwd.parent)
).resolve()
# Add the entry only once so re-running this cell does not keep growing sys.path.
if str(workspace_dir) not in sys.path:
    sys.path.append(str(workspace_dir))

from src.model import UHINet
# Assume dataloaders are defined in src/dataloader.py (or adjust import)
# from src.dataloader import UHIDataModule

## Configuration

Set up paths and hyperparameters.

In [None]:
# --- Configuration ---

# Paths
CLAY_CHECKPOINT_DIR = workspace_dir / "models" / "clay_checkpoints"
CLAY_CHECKPOINT_NAME = "clay-v1.5.ckpt" # Name of the downloaded checkpoint
CLAY_CHECKPOINT_PATH = CLAY_CHECKPOINT_DIR / CLAY_CHECKPOINT_NAME
CLAY_METADATA_PATH = workspace_dir / "src" / "models" / "Clay" / "configs" / "metadata.yaml"
DATA_DIR = workspace_dir / "data" # Adjust based on your data structure
LOG_DIR = workspace_dir / "lightning_logs"

# Ensure checkpoint directory exists
CLAY_CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

# Model Hyperparameters (Example values, adjust as needed)
CLAY_MODEL_SIZE = "large"
CLAY_BANDS = ["blue", "green", "red", "nir"] # Must match bands used in src/model.py
CLAY_PLATFORM = "sentinel-2-l2a"
CLAY_GSD = 10

WEATHER_CHANNELS = 3 # Example: Tmax, Tmin, Precip
LST_CHANNELS = 1
USE_LST = True
# The time-embedding width must match the time features the data pipeline produces.
# The main configuration cell later in this notebook assumes a 4-D embedding
# (sin/cos week, sin/cos hour); keep the two values in sync.
TIME_EMBED_DIM = 4

GRU_HIDDEN_DIM = 64
GRU_KERNEL_SIZE = 3
PROJ_CH = 32 # From UHINet init default
OUTPUT_CHANNELS = 1

# Training Hyperparameters (Example values)
BATCH_SIZE = 4
LEARNING_RATE = 1e-4
MAX_EPOCHS = 50
ACCELERATOR = "gpu" if torch.cuda.is_available() else "cpu"
DEVICES = 1

## Download Clay Checkpoint

Download the pre-trained Clay checkpoint if it doesn't exist.

In [None]:
# Download the Clay checkpoint once.
# The file is streamed into a temporary ``.part`` file and renamed only after the
# download completes. A failed or interrupted download therefore never leaves a
# truncated checkpoint that the ``is_file()`` guard would later accept as valid.
# Download errors propagate, so "Run All" stops here instead of failing later on a
# missing or corrupt checkpoint.
clay_url = "https://huggingface.co/made-with-clay/Clay/resolve/main/v1.5/clay-v1.5.ckpt"

if not CLAY_CHECKPOINT_PATH.is_file():
    import shutil
    import urllib.request

    print(f"Downloading Clay checkpoint to {CLAY_CHECKPOINT_PATH}...")
    partial_path = CLAY_CHECKPOINT_PATH.with_name(CLAY_CHECKPOINT_PATH.name + ".part")
    try:
        with urllib.request.urlopen(clay_url) as response, open(partial_path, "wb") as out_file:
            shutil.copyfileobj(response, out_file)
        partial_path.replace(CLAY_CHECKPOINT_PATH)
    finally:
        # No-op after a successful rename; removes the partial file otherwise.
        partial_path.unlink(missing_ok=True)
    print("Download complete.")
else:
    print(f"Clay checkpoint already exists at {CLAY_CHECKPOINT_PATH}")

## Initialize Model

In [None]:
# Build UHINet from the hyperparameters defined in the configuration cell.
# Keyword arguments are grouped by what they configure; order does not matter.
model = UHINet(
    # Clay backbone: pre-trained checkpoint plus band/platform metadata
    clay_checkpoint_path=str(CLAY_CHECKPOINT_PATH),
    clay_metadata_path=str(CLAY_METADATA_PATH),
    clay_model_size=CLAY_MODEL_SIZE,
    clay_bands=CLAY_BANDS,
    clay_platform=CLAY_PLATFORM,
    clay_gsd=CLAY_GSD,
    # Projection width (the config notes this mirrors UHINet's default)
    proj_ch=PROJ_CH,
    # Auxiliary inputs: weather, land-surface temperature, time embedding
    weather_channels=WEATHER_CHANNELS,
    lst_channels=LST_CHANNELS,
    use_lst=USE_LST,
    time_embed_dim=TIME_EMBED_DIM,
    # ConvGRU and output settings
    gru_hidden_dim=GRU_HIDDEN_DIM,
    gru_kernel_size=GRU_KERNEL_SIZE,
    output_channels=OUTPUT_CHANNELS,
)

print("UHINet model initialized.")

## Prepare Data

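Set up the data loaders. This section is still a placeholder and must be implemented for your data structure, e.g. with a `UHIDataModule`. The working data pipeline (`CityDataSet` with a train/validation split) is set up in the **Train UHI Net Model** section below.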

In [None]:
# --- Data Preparation (Placeholder) ---
# NOTE(review): This draft cell is not implemented and defines no data objects.
# `data_module` below exists only as a commented-out sketch. The working data
# pipeline (CityDataSet + DataLoader with a train/validation split) is built in the
# "Initialize Dataset and Dataloaders" cell later in this notebook.

# TODO: Implement or import your UHIDataModule
# data_module = UHIDataModule(
#     data_dir=DATA_DIR,
#     batch_size=BATCH_SIZE,
#     num_workers=os.cpu_count(),
#     clay_bands=CLAY_BANDS, # Ensure datamodule provides correct bands
#     use_lst=USE_LST,
#     # ... other datamodule args ...
# )

# data_module.setup() # Prepare data, download etc.

# Only emits a reminder; no data is prepared by this cell.
print("Data preparation placeholder. Needs implementation.")

## Training

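Configure the PyTorch Lightning Trainer and start training. No Lightning `Trainer` cell has been written yet; training is currently done by the manual PyTorch loop in the **Train UHI Net Model** section below.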

# Train UHI Net Model

This notebook trains the UHI prediction model (`UHINet`) using pre-processed data.

**Steps:**
1. Install necessary packages (if not already installed). This notebook has no install cell, so set up the environment beforehand.
2. Download the Clay foundation model checkpoint.
3. Define configuration parameters (paths, hyperparameters).
4. Load the dataset.
5. Initialize the `UHINet` model, loss function, and optimizer.
6. Run the training and validation loop.
7. Save the best model checkpoint.

In [None]:
# %% Imports
import torch
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
import logging
import argparse
from pathlib import Path
import sys
import numpy as np
import pandas as pd # Needed for loading bounds from csv
from tqdm.notebook import tqdm # Use notebook tqdm
import json
import os
from datetime import datetime
import shutil

# Add project root to path to import custom modules.
# The notebook normally lives in `<root>/notebooks`. If the kernel was started from the
# project root itself (a `src/` directory is in the CWD), use the CWD instead of its parent.
_cwd = Path.cwd().resolve()
project_root = _cwd if (_cwd / "src").is_dir() else _cwd.parent
# Insert only once so re-running the cell does not keep prepending duplicates.
if str(project_root) not in sys.path:
    sys.path.insert(0, str(project_root))

from src.ingest.dataloader import CityDataSet
# Import the *refactored* UHINet model
from src.model import UHINet 
from src.train.loss import masked_mae_loss, masked_mse_loss

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')

In [None]:
# %% Download Clay Checkpoint
# Define where to save the checkpoint
clay_checkpoint_dir = project_root / "models" / "clay_checkpoints"
clay_checkpoint_dir.mkdir(parents=True, exist_ok=True)
clay_checkpoint_filename = "clay-v1.5.ckpt"
clay_checkpoint_path_local = clay_checkpoint_dir / clay_checkpoint_filename
clay_checkpoint_url = "https://huggingface.co/made-with-clay/Clay/resolve/main/v1.5/clay-v1.5.ckpt"

# Download only if it doesn't exist.
# wget writes into a temporary ``.part`` file that is renamed only on success.
# `wget -O` creates its output file before downloading, so writing straight to the
# final path would leave an empty or truncated checkpoint behind on failure. The
# `exists()` checks here and in the configuration cell would then accept that file.
if not clay_checkpoint_path_local.exists():
    import subprocess
    print(f"Downloading Clay checkpoint from {clay_checkpoint_url}...")
    clay_checkpoint_partial = clay_checkpoint_path_local.with_name(clay_checkpoint_filename + ".part")
    try:
        subprocess.run(["wget", "-q", "-O", str(clay_checkpoint_partial), clay_checkpoint_url], check=True)
        clay_checkpoint_partial.replace(clay_checkpoint_path_local)
        print(f"Clay checkpoint saved to {clay_checkpoint_path_local}")
    except FileNotFoundError:
        print("Error: 'wget' command not found. Please install wget or download the file manually.")
        print(f"Manual download: {clay_checkpoint_url} -> {clay_checkpoint_path_local}")
    except subprocess.CalledProcessError as e:
        print(f"Error downloading file: {e}")
    finally:
        # Remove any leftover partial download (no-op after a successful rename).
        clay_checkpoint_partial.unlink(missing_ok=True)
else:
    print(f"Clay checkpoint already exists at {clay_checkpoint_path_local}")

# --- Define path to Clay metadata ---
# Assuming it's within the cloned Clay repo structure inside src/models/Clay
clay_metadata_path_local = project_root / "src" / "models" / "Clay" / "configs" / "metadata.yaml"

# Only reports here; the configuration cell raises if the file is still missing.
if not clay_metadata_path_local.exists():
    print(f"Error: Clay metadata file not found at {clay_metadata_path_local}. Ensure the Clay submodule/repository structure is correct.")
else:
    print(f"Using Clay metadata from: {clay_metadata_path_local}")


In [None]:
# %% Configuration / Hyperparameters

# --- Paths ---
data_dir = project_root / "data" # Base directory for processed city data
city_name = "your_city_name" # <<< CHANGE THIS to the city you processed
output_dir_base = project_root / "training_runs"

# --- Reproducibility ---
# The train/val random_split is seeded separately. This seeds the remaining randomness
# (model weight initialisation, DataLoader shuffling) so reruns are comparable.
seed = 42
torch.manual_seed(seed)
np.random.seed(seed)

# --- Data Loading ---
# Bounds: Either specify manually [lon_min, lat_min, lon_max, lat_max] or leave as None to load from bbox.csv
bounds = None
resolution_m = 30 # Spatial resolution used during data processing
include_lst = True # Set to True if you processed LST and want to include it
# Path to the *single* LST median file (created by create_sat_tensor_files)
# Required if include_lst=True. Example: data/your_city_name/lst_median_2021-01-01_2021-12-31.tif
single_lst_median_path = data_dir / city_name / "lst_median_YYYY-MM-DD_YYYY-MM-DD.tif" # <<< CHANGE THIS if using LST

# --- Model Config ---
# Clay Feature Extractor Args (Must match downloaded checkpoint and data)
clay_model_size = "large" # Should match the checkpoint, e.g., v1.5 is large
# Bands MUST match the order/names used when creating the cloudless mosaic!
clay_bands = ["blue", "green", "red", "nir"]
clay_platform = "sentinel-2-l2a"
clay_gsd = 10 # GSD of the input Sentinel-2 mosaic

# UHINet Args (ConvGRU part)
weather_channels = 3 # Assuming 3 weather variables (e.g., tmin, tmax, precip)
time_embed_dim = 4 # Assuming 4D time embedding (sin/cos week, sin/cos hour)
lst_channels = 1 if include_lst else 0 # Number of channels in the LST input
gru_hidden_dim = 64
gru_kernel_size = 3
output_channels = 1 # Predicting single UHI value

# --- Training Hyperparameters ---
batch_size = 16
num_workers = 4
epochs = 50
lr = 1e-4
weight_decay = 0.01
loss_type = 'mae' # 'mae' or 'mse'
patience = 10 # Early stopping patience
cpu = False # Set to True to force CPU use

# --- Derived Paths & Sanity Checks ---
data_dir_path = Path(data_dir)
city_data_dir = data_dir_path / city_name
uhi_csv = city_data_dir / "uhi_data.csv"
bbox_csv = city_data_dir / "bbox.csv"
weather_csv = city_data_dir / "weather_grid.csv"
cloudless_mosaic_path = city_data_dir / "cloudless_mosaic.tif"

if include_lst and (not single_lst_median_path or not Path(single_lst_median_path).exists()):
    raise FileNotFoundError(f"Error: LST is included (include_lst=True) but single_lst_median_path '{single_lst_median_path}' is invalid or not found.")

# Check essential data files exist
required_files = [uhi_csv, bbox_csv, weather_csv, cloudless_mosaic_path]
for f in required_files:
    if not f.exists():
        raise FileNotFoundError(f"Required data file not found: {f}")

# Check Clay files exist (paths defined in previous cell). A zero-byte checkpoint is what
# an interrupted or failed download leaves behind, so treat it as missing too.
if (not clay_checkpoint_path_local
        or not clay_checkpoint_path_local.is_file()
        or clay_checkpoint_path_local.stat().st_size == 0):
    raise FileNotFoundError("Clay checkpoint path is not valid. Check download step.")
if not clay_metadata_path_local or not clay_metadata_path_local.exists():
    raise FileNotFoundError("Clay metadata path is not valid. Check path definition.")

# Load bounds from CSV if not specified. An explicit None/length check is used rather than
# truthiness, so bounds given as a NumPy array do not raise "truth value is ambiguous".
if bounds is None or len(bounds) == 0:
    print("Bounds not provided, loading from bbox.csv")
    try:
        bbox_df = pd.read_csv(bbox_csv)
        # Cast to plain Python floats: pandas reductions return NumPy scalars, which are
        # not JSON-native and are awkward to store in checkpoints.
        bounds = [
            float(bbox_df['longitudes'].min()), float(bbox_df['latitudes'].min()),
            float(bbox_df['longitudes'].max()), float(bbox_df['latitudes'].max()),
        ]
        print(f"Loaded bounds from {bbox_csv}: {bounds}")
    except Exception as e:
        # Chain the original error so the real cause (missing column, parse error...) is kept.
        raise ValueError(f"Failed to load bounds from {bbox_csv}: {e}. Provide manually in config.") from e

In [None]:
# %% Setup Device
# Fall back to CPU when CUDA is unavailable or the `cpu` flag forces it; otherwise use the GPU.
device = torch.device("cpu" if not torch.cuda.is_available() or cpu else "cuda")
print(f"Using device: {device}")

In [None]:
# %% Initialize Dataset and Dataloaders

print("Initializing dataset...")
try:
    dataset = CityDataSet(
        bounds=bounds,
        resolution_m=resolution_m,
        uhi_csv=str(uhi_csv),
        bbox_csv=str(bbox_csv),
        weather_csv=str(weather_csv),
        cloudless_mosaic_path=str(cloudless_mosaic_path),
        data_dir=str(data_dir_path),
        city_name=city_name,
        include_lst=include_lst,
        single_lst_median_path=str(single_lst_median_path) if include_lst else None
    )
except FileNotFoundError as e:
    print(f"Dataset initialization failed: {e}")
    # Stop execution or handle error
    raise
except Exception as e:
    print(f"Unexpected error during dataset initialization: {e}")
    raise

# --- Train/Val Split ---
val_percent = 0.15
n_samples = len(dataset)
if n_samples < 10: # Handle very small datasets
    print(f"Warning: Dataset size ({n_samples}) is very small. Validation split disabled.")
    n_val = 0
    n_train = n_samples
    train_ds = dataset
    val_ds = None # No validation set
else:
    n_val = int(n_samples * val_percent)
    n_train = n_samples - n_val
    train_ds, val_ds = random_split(dataset, [n_train, n_val], generator=torch.Generator().manual_seed(42))

print(f"Dataset split: {n_train} training, {n_val or 0} validation samples.")

print("Creating dataloaders...")
# drop_last avoids a ragged final batch. With fewer training samples than one batch,
# though, it drops *every* sample and each epoch runs zero optimisation steps.
# Only enable it when at least one full batch exists.
drop_last_train = n_train >= batch_size
# Pinned host memory only speeds up host->CUDA copies; on CPU it just emits a warning.
pin_memory = device.type == "cuda"
train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=pin_memory, drop_last=drop_last_train)
val_loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False, num_workers=num_workers, pin_memory=pin_memory) if val_ds else None

print("Data loading setup complete.")

In [None]:
# %% Initialize Model
print("Initializing UHINet model...")

# Re-check the Clay inputs right before they are loaded: the config cell may have been
# edited or skipped since it last ran.
if not (clay_checkpoint_path_local and Path(clay_checkpoint_path_local).exists()):
    raise FileNotFoundError(f"Clay checkpoint not found at expected path: {clay_checkpoint_path_local}")
if not (clay_metadata_path_local and Path(clay_metadata_path_local).exists()):
    raise FileNotFoundError(f"Clay metadata not found at expected path: {clay_metadata_path_local}")

try:
    # Keyword order is irrelevant; arguments are grouped by the component they configure.
    model = UHINet(
        # Clay backbone
        clay_checkpoint_path=str(clay_checkpoint_path_local),
        clay_metadata_path=str(clay_metadata_path_local),
        clay_model_size=clay_model_size,
        clay_bands=clay_bands,
        clay_platform=clay_platform,
        clay_gsd=clay_gsd,
        # Auxiliary inputs: weather, LST, time embedding
        weather_channels=weather_channels,
        lst_channels=lst_channels,
        use_lst=include_lst,
        time_embed_dim=time_embed_dim,
        # ConvGRU and output
        gru_hidden_dim=gru_hidden_dim,
        gru_kernel_size=gru_kernel_size,
        output_channels=output_channels,
    ).to(device)

    print(f"Model initialized: {type(model).__name__}")
    num_params_total = sum(param.numel() for param in model.parameters())
    num_params_trainable = sum(param.numel() for param in model.parameters() if param.requires_grad)
    print(f"Total parameters: {num_params_total / 1e6:.2f} M")
    print(f"Trainable parameters: {num_params_trainable / 1e6:.2f} M")
except RuntimeError as e:
    print(f"Failed to initialize model: {e}")
    raise
except Exception as e:
    print(f"Unexpected error initializing model: {e}")
    raise


In [None]:
# %% Initialize Loss and Optimizer

# Map the configured loss name (case-insensitive) to its masked implementation.
# An unknown name raises instead of silently falling back to MSE.
loss_functions = {'mae': masked_mae_loss, 'mse': masked_mse_loss}
loss_key = str(loss_type).lower()
if loss_key not in loss_functions:
    raise ValueError(f"Unknown loss_type {loss_type!r}; expected one of {sorted(loss_functions)}.")
loss_fn = loss_functions[loss_key]
print(f"Using loss function: {loss_type}")

# Optimize only parameters that require gradients (any frozen weights are skipped).
optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=lr, weight_decay=weight_decay)
print(f"Using optimizer: AdamW (lr={lr}, wd={weight_decay})")

In [None]:
# %% Helper Functions (Checkpointing, Train/Val Loops)

def save_checkpoint(state, is_best, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
    """Saves model checkpoint.

    Each file is first written to a sibling ``.tmp`` file and then renamed over its
    target (``Path.replace``, an atomic rename within one filesystem). Interrupting
    the save, e.g. with a kernel interrupt, therefore leaves the previous checkpoint
    intact instead of a truncated one.

    Args:
        state (dict): Contains model's state_dict, optimizer state, epoch, etc.
        is_best (bool): True if this is the best model seen so far.
        filename (str | Path): Path to save the latest checkpoint.
        best_filename (str | Path): Path to save the best checkpoint (a copy of ``filename``).
    """
    filename = Path(filename)
    best_filename = Path(best_filename)

    filename.parent.mkdir(parents=True, exist_ok=True) # Ensure dir exists
    tmp_filename = filename.with_name(filename.name + '.tmp')
    torch.save(state, tmp_filename)
    tmp_filename.replace(filename)

    if is_best:
        best_filename.parent.mkdir(parents=True, exist_ok=True)
        tmp_best = best_filename.with_name(best_filename.name + '.tmp')
        shutil.copyfile(filename, tmp_best)
        tmp_best.replace(best_filename)
        print(f"Saved new best model to {best_filename}")

def train_epoch(model, dataloader, optimizer, loss_fn, device, include_lst=None):
    """Run one training epoch over ``dataloader`` and return the mean batch loss.

    Training is best-effort: a batch is reported and skipped if it lacks a required key,
    cannot be moved to ``device``, produces a NaN/inf loss, or raises during the
    forward/backward pass.

    Args:
        model: Module called as ``model(batch)`` on a dict of device tensors.
        dataloader: Iterable of dict batches containing 'sentinel_mosaic', 'weather_seq',
            'time_emb_seq', 'target' and 'mask' (plus 'lst_seq' when LST is used).
            Non-tensor entries are dropped before the batch reaches the model.
        optimizer: Optimizer that updates ``model``'s parameters.
        loss_fn: Callable ``loss_fn(predictions, target, mask)`` returning a scalar tensor.
        device: Device that each batch's tensors are moved to.
        include_lst: Whether 'lst_seq' is required. ``None`` (default) falls back to the
            notebook-level ``include_lst`` setting, or False if it is not defined.

    Returns:
        float: Mean loss over the batches actually trained on, or ``nan`` if no batch
        could be processed. A 0.0 in that case would look like a perfect loss and make an
        untrained model the "best" checkpoint; ``nan`` trips the training loop's NaN check.
    """
    if include_lst is None:
        include_lst = globals().get('include_lst', False)
    # Loop-invariant: the keys every batch must provide.
    required_keys = ['sentinel_mosaic', 'weather_seq', 'time_emb_seq', 'target', 'mask']
    if include_lst:
        required_keys.append('lst_seq')

    model.train()
    total_loss = 0.0
    num_batches = 0
    # Use tqdm.notebook for better integration
    progress_bar = tqdm(dataloader, desc='Training', leave=False)
    for batch in progress_bar:
        missing = [key for key in required_keys if key not in batch]
        if missing:
            print(f"Warning: Skipping batch due to missing keys: {missing}")
            continue

        # Move batch to device
        try:
            batch_device = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
        except Exception as e:
            print(f"Error moving batch to device: {e}")
            continue # Skip batch if moving fails

        optimizer.zero_grad()
        try:
            predictions = model(batch_device)
            loss = loss_fn(predictions, batch_device['target'], batch_device['mask'])

            # Skip non-finite losses: back-propagating NaN or inf would corrupt the weights.
            if not torch.isfinite(loss):
                print("Warning: non-finite (NaN/inf) loss detected, skipping batch.")
                continue

            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            total_loss += loss_value
            num_batches += 1
            progress_bar.set_postfix(loss=loss_value)

        except RuntimeError as e:
            # Catch CUDA OOM or other runtime errors
            print(f"Runtime error during training: {e}")
            if "out of memory" in str(e):
                print("CUDA out of memory. Try reducing batch size.")
            continue # Skip this batch
        except Exception as e:
            print(f"Unexpected error during training step: {e}")
            continue # Skip this batch

    if num_batches == 0:
        print("Warning: no training batches were processed this epoch.")
        return float('nan')
    return total_loss / num_batches

def validate_epoch(model, dataloader, loss_fn, device, include_lst=None):
    """Evaluate ``model`` on ``dataloader`` without gradients and return the mean batch loss.

    A batch is reported and skipped if it lacks a required key, yields a NaN loss, or
    raises during evaluation.

    Args:
        model: Module called as ``model(batch)`` on a dict of device tensors.
        dataloader: Iterable of dict batches containing 'sentinel_mosaic', 'weather_seq',
            'time_emb_seq', 'target' and 'mask' (plus 'lst_seq' when LST is used).
        loss_fn: Callable ``loss_fn(predictions, target, mask)`` returning a scalar tensor.
        device: Device that each batch's tensors are moved to.
        include_lst: Whether 'lst_seq' is required. ``None`` (default) falls back to the
            notebook-level ``include_lst`` setting, or False if it is not defined.

    Returns:
        float: Mean loss over successfully evaluated batches, or ``nan`` if none were
        evaluated. Returning 0.0 there would look like a perfect validation loss and save
        a broken model as the best checkpoint.
    """
    if include_lst is None:
        include_lst = globals().get('include_lst', False)
    # Loop-invariant: the keys every batch must provide.
    required_keys = ['sentinel_mosaic', 'weather_seq', 'time_emb_seq', 'target', 'mask']
    if include_lst:
        required_keys.append('lst_seq')

    model.eval()
    total_loss = 0.0
    num_batches = 0
    progress_bar = tqdm(dataloader, desc='Validation', leave=False)
    with torch.no_grad():
        for batch in progress_bar:
            missing = [key for key in required_keys if key not in batch]
            if missing:
                print(f"Warning: Skipping validation batch due to missing keys: {missing}")
                continue

            try:
                batch_device = {k: v.to(device) for k, v in batch.items() if isinstance(v, torch.Tensor)}
                predictions = model(batch_device)
                loss = loss_fn(predictions, batch_device['target'], batch_device['mask'])

                if torch.isnan(loss):
                    print("Warning: NaN validation loss detected, skipping batch.")
                    continue

                loss_value = loss.item()
                total_loss += loss_value
                num_batches += 1
                progress_bar.set_postfix(loss=loss_value)

            except Exception as e:
                print(f"Error during validation step: {e}")
                continue # Skip batch on error

    if num_batches == 0:
        print("Warning: no validation batches were evaluated this epoch.")
        return float('nan')
    return total_loss / num_batches

In [None]:
# %% Training Loop

best_val_loss = float('inf')
epochs_no_improve = 0
# Name of the loss that drives checkpointing and early stopping (used in progress messages).
monitored_split = "validation" if val_loader is not None else "training"

# Create output directory for this run
run_name = f"{city_name}_run_{datetime.now().strftime('%Y%m%d_%H%M%S')}"
output_dir = Path(output_dir_base) / run_name
output_dir.mkdir(parents=True, exist_ok=True)
print(f"Checkpoints and logs will be saved to: {output_dir}")

# Configuration used for this run. It is built outside the try below because every
# checkpoint embeds it. If it were only defined inside the try, a failure there would
# surface later as a confusing NameError inside the loop.
# Bounds are cast to plain floats: values read via pandas are NumPy scalars, which
# json cannot always encode and which torch.load(weights_only=True) can refuse to
# unpickle from a checkpoint.
config_dict = {
    "data_dir": str(data_dir),
    "city_name": city_name,
    "output_dir": str(output_dir),
    "bounds": [float(b) for b in bounds] if bounds is not None else None,
    "resolution_m": resolution_m,
    "include_lst": include_lst,
    "single_lst_median_path": str(single_lst_median_path) if include_lst else None,
    "clay_model_size": clay_model_size,
    "clay_bands": clay_bands,
    "clay_platform": clay_platform,
    "clay_gsd": clay_gsd,
    "clay_checkpoint_path_local": str(clay_checkpoint_path_local),
    "clay_metadata_path_local": str(clay_metadata_path_local),
    "weather_channels": weather_channels,
    "time_embed_dim": time_embed_dim,
    "lst_channels": lst_channels,
    "gru_hidden_dim": gru_hidden_dim,
    "gru_kernel_size": gru_kernel_size,
    "output_channels": output_channels,
    "batch_size": batch_size,
    "num_workers": num_workers,
    "epochs": epochs,
    "lr": lr,
    "weight_decay": weight_decay,
    "loss_type": loss_type,
    "patience": patience,
    "device": str(device)
}
try:
    with open(output_dir / "config.json", 'w') as f:
        json.dump(config_dict, f, indent=2)
    print("Saved configuration to config.json")
except Exception as e:
    print(f"Warning: Failed to save configuration: {e}")


print("Starting training...")
for epoch in range(epochs):
    print(f"--- Epoch {epoch+1}/{epochs} ---")
    train_loss = train_epoch(model, train_loader, optimizer, loss_fn, device)

    if val_loader is not None:
        # A NaN training loss means the epoch failed; don't keep training on top of it.
        if np.isnan(train_loss):
            print("Warning: Training loss is NaN. Stopping training.")
            break
        val_loss = validate_epoch(model, val_loader, loss_fn, device)
        print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f}, Val Loss={val_loss:.4f}")
        current_loss = val_loss
        if np.isnan(current_loss):
            print("Warning: Validation loss is NaN. Stopping training.")
            break
    else:
        print(f"Epoch {epoch+1}: Train Loss={train_loss:.4f} (No validation set)")
        current_loss = train_loss # Use train loss for checkpointing if no val set
        if np.isnan(current_loss):
            print("Warning: Training loss is NaN. Stopping training.")
            break

    is_best = current_loss < best_val_loss
    if is_best:
        best_val_loss = current_loss
        epochs_no_improve = 0
        print(f"New best loss: {best_val_loss:.4f}")
    else:
        epochs_no_improve += 1
        print(f"No improvement in {monitored_split} loss for {epochs_no_improve} epochs.")

    # Save checkpoint
    save_checkpoint(
        {'epoch': epoch + 1,
         'state_dict': model.state_dict(),
         'best_val_loss': best_val_loss,
         'optimizer': optimizer.state_dict(),
         'config': config_dict # Save config with checkpoint
         },
        is_best,
        filename=output_dir / 'checkpoint_last.pth.tar',
        best_filename=output_dir / 'model_best.pth.tar'
    )

    # Early stopping check
    if epochs_no_improve >= patience:
        print(f"Early stopping triggered after {patience} epochs with no improvement.")
        break

print("Training finished.")
print(f"Best loss recorded: {best_val_loss:.4f}")
print(f"Checkpoints saved in: {output_dir}")