# Branched UHI Model Training

In [3]:
# %% Imports and Setup

# --- Standard Libraries ---
import os
import sys
from pathlib import Path
import json
import logging
import warnings
import time
from datetime import datetime

# --- Data Handling ---
import numpy as np
import pandas as pd
import xarray as xr
import rioxarray # For geospatial data handling with xarray

# --- PyTorch ---
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Subset # Added Subset


# --- Visualization & Progress ---
# Optional: Import matplotlib or other plotting libs if needed for checks
from tqdm.notebook import tqdm # Use notebook version if running interactively
import wandb

# --- Custom Modules ---
# The project root is taken to be the directory one level above the current
# working directory; `src` is the source folder inside it.
project_root = Path.cwd().parent
src = project_root.joinpath("src")

# Put the project root (not `src` itself) on the import path so the
# `from src....` imports below resolve. The path list is echoed only when it
# was actually modified.
_root_entry = str(project_root)
if _root_entry not in sys.path:
    sys.path.append(_root_entry)
    print(sys.path)

# Import custom modules
from src.ingest.dataloader_branched import CityDataSetBranched
from src.branched_uhi_model import BranchedUHIModel
from src.train.loss import masked_mse_loss, masked_mae_loss # Import loss functions
import src.train.train_utils as train_utils # Import the utility module

# --- Environment Setup ---
# Root logger: INFO and above, each record prefixed with timestamp and level.
logging.basicConfig(
    format='%(asctime)s - %(levelname)s - %(message)s',
    level=logging.INFO,
)

# Optional: silence specific warnings here if they become noisy, e.g.
# warnings.filterwarnings('ignore', category=UserWarning, message='.*TypedStorage is deprecated.*')

# Report where we are running and what hardware PyTorch can see.
_cuda_ok = torch.cuda.is_available()
print(f"Project Root: {project_root}")
print(f"PyTorch Version: {torch.__version__}")
print(f"CUDA Available: {_cuda_ok}")
if _cuda_ok:
    print(f"CUDA Device Name: {torch.cuda.get_device_name(0)}")

['/opt/conda/lib/python310.zip', '/opt/conda/lib/python3.10', '/opt/conda/lib/python3.10/lib-dynload', '', '/opt/conda/lib/python3.10/site-packages', '/home/jupyter/MLC-Project/src', '/home/jupyter/MLC-Project']
Project Root: /home/jupyter/MLC-Project
PyTorch Version: 2.6.0+cu124
CUDA Available: True
CUDA Device Name: Tesla V100-SXM2-16GB


In [None]:
# %% Configuration / Hyperparameters for BranchedUHIModel (ConvLSTM + Common Resampling)

# --- Import utils ---
from src.train.train_utils import check_path # For path validation
from src.ingest.data_utils import calculate_actual_weather_channels # For dynamic weather channels
# -------------------

# --- Paths & Basic Info ---
# `project_root` comes from the first cell; a string copy is kept alongside it.
city_name = "NYC"  # Hard-coded here; could instead be loaded from elsewhere
project_root_str = str(project_root)
data_dir = project_root.joinpath("data")
output_dir_base = project_root.joinpath("training_runs")

# --- WANDB Config ---
# NOTE(review): "wander" looks like a typo for "wandb", but this name and the
# config key built from it are read by later cells, so it is left unchanged.
wandb_project_name = "MLC_UHI_Proj"
wander_run_name_prefix = f"{city_name}_BranchedUHI"

# --- Data Loading Config ---
weather_seq_length = 60
feature_resolution_m = 30
uhi_grid_resolution_m = 10  # Resolution of the UHI target grid

# --- Weather Feature Selection ---
# Per the original note, data_utils converts "wind_dir" into sin/cos components.
enabled_weather_features = ["air_temp", "rel_humidity", "avg_windspeed", "wind_dir", "solar_flux"]

# Number of weather channels the dataloader will produce for the list above.
actual_dataloader_weather_channels = calculate_actual_weather_channels(enabled_weather_features)

# --- Input Data Paths ---
# Every input lives under <data_dir>/<city_name>; raster and array files sit
# in its "sat_files" sub-folder.
_city_dir = data_dir / city_name
_sat_dir = _city_dir / "sat_files"

# Tabular inputs
uhi_csv_path = _city_dir / "uhi.csv"
bronx_weather_csv_path = _city_dir / "bronx_weather.csv"
manhattan_weather_csv_path = _city_dir / "manhattan_weather.csv"

# Elevation rasters
dem_path = _sat_dir / f"{city_name.lower()}_dem_nasadem_native-resolution_pc.tif"
dsm_path = _sat_dir / f"{city_name.lower()}_dsm_cop-dem-glo-30_native-resolution_pc.tif"

# Sentinel mosaic and LST median arrays. The LST time-window string is meant
# to match the filename structure used by download_data.ipynb.
cloudless_mosaic_path = _sat_dir / f"sentinel_{city_name}_20210601_to_20210901_cloudless_mosaic.npy"
lst_time_window_str_for_filename = "20210601_to_20210901"
single_lst_median_path = _sat_dir / f"lst_{city_name}_median_{lst_time_window_str_for_filename}.npy"

# Nodata values (np.nan would be the alternative for either one)
elevation_nodata, lst_nodata = -9999.0, 0.0

# --- Feature Selection Flags ---
# Every known flag, in the order it should appear in the config ...
_all_feature_flags = (
    "use_dem",
    "use_dsm",
    "use_clay",
    "use_sentinel_composite",
    "use_lst",  # add to the enabled set below if LST is intended to be used
    "use_ndvi",
    "use_ndbi",
    "use_ndwi",
)
# ... and the subset that is switched on for this run.
_enabled_feature_flags = {"use_dem", "use_dsm", "use_clay"}
feature_flags = {flag: flag in _enabled_feature_flags for flag in _all_feature_flags}

# --- Bands for Sentinel Composite (only relevant when use_sentinel_composite is True) ---
sentinel_bands_to_load = []

# --- Model Config (BranchedUHIModel with ConvLSTM, no separate elevation branches) ---
# Clay backbone settings
clay_model_size = "large"
clay_platform = "sentinel-2-l2a"
clay_gsd = 10
clay_bands = ["blue", "green", "red", "nir"]
freeze_backbone = True
# Checkpoint and metadata files, both located relative to the project root.
clay_metadata_path = project_root.joinpath("src", "Clay", "configs", "metadata.yaml")
clay_checkpoint_path = project_root.joinpath("notebooks", "clay-v1.5.ckpt")

# Temporal Weather Processor (ConvLSTM)
# The weather input channel count is no longer set here; it follows from the
# enabled weather feature list.
convlstm_hidden_dims = [32, 16]
convlstm_kernel_sizes = [(3, 3) for _ in convlstm_hidden_dims]  # one 3x3 kernel per layer
convlstm_num_layers = len(convlstm_hidden_dims)

# Projection Layer Channels
proj_static_ch = proj_temporal_ch = 8

# --- Head Configuration ---
# Settings for both head variants are defined; `head_type` picks the one used.
head_type = "unet"
unet_base_channels, unet_depth = 16, 2
simple_cnn_head_channels = [32, 16]
simple_cnn_head_kernels = [3, 3]
simple_cnn_dropout_rate = 0.1

# --- Training Hyperparameters ---
epochs = 500
patience = 50
warmup_epochs = 5
n_train_batches = 47
num_workers = 1
loss_type = 'mse'
lr = 1e-4
weight_decay = 1e-4
max_grad_norm = 1.0
cpu = False  # set True to force CPU even when CUDA is available

# --- Device Setup ---
_use_cuda = (not cpu) and torch.cuda.is_available()
device = torch.device("cuda") if _use_cuda else torch.device("cpu")
print(f"Using device: {device}")

# --- Validate Paths ---
# Each required input is passed through the project helper `check_path`, and
# the helper's return value replaces the path variable. Optional inputs are
# only checked when the feature that needs them is switched on.
_needs_clay = feature_flags["use_clay"]
_needs_mosaic = feature_flags["use_clay"] or feature_flags["use_sentinel_composite"]

# Always required: UHI targets and the two weather station files.
uhi_csv_path = check_path(uhi_csv_path, project_root, "UHI CSV")
bronx_weather_csv_path = check_path(bronx_weather_csv_path, project_root, "Bronx Weather CSV")
manhattan_weather_csv_path = check_path(manhattan_weather_csv_path, project_root, "Manhattan Weather CSV")

# Elevation rasters
if feature_flags["use_dem"]:
    dem_path = check_path(dem_path, project_root, "DEM TIF")
if feature_flags["use_dsm"]:
    dsm_path = check_path(dsm_path, project_root, "DSM TIF")

# Clay backbone files
if _needs_clay:
    clay_checkpoint_path = check_path(clay_checkpoint_path, project_root, "Clay Checkpoint")
    clay_metadata_path = check_path(clay_metadata_path, project_root, "Clay Metadata")

# The cloudless mosaic feeds both Clay and the direct Sentinel composite.
if _needs_mosaic:
    cloudless_mosaic_path = check_path(cloudless_mosaic_path, project_root, "Cloudless Mosaic")

# LST median array
if feature_flags["use_lst"]:
    single_lst_median_path = check_path(single_lst_median_path, project_root, "Single LST Median", should_exist=True)


# --- Calculate Bounds ---
# The bounding box is [min lon, min lat, max lon, max lat] over all rows of
# the validated UHI CSV.
uhi_df = pd.read_csv(uhi_csv_path)
required_cols = ['Longitude', 'Latitude']
if any(col not in uhi_df.columns for col in required_cols):
    raise ValueError(f"UHI CSV must contain columns: {required_cols}")

_lon, _lat = uhi_df['Longitude'], uhi_df['Latitude']
bounds = [_lon.min(), _lat.min(), _lon.max(), _lat.max()]
print(f"Loaded bounds from {uhi_csv_path.name}: {bounds}")  # .name -> filename only


# --- Central Config Dictionary --- #
# Switches that decide which optional entries carry a value and which are None.
_cfg_use_clay = feature_flags["use_clay"]
_cfg_use_mosaic = feature_flags["use_clay"] or feature_flags["use_sentinel_composite"]
_cfg_unet_head = head_type == "unet"
_cfg_simple_cnn_head = head_type == "simple_cnn"


def _str_if(enabled, path):
    """Return ``str(path)`` when *enabled* is truthy, otherwise ``None``."""
    return str(path) if enabled else None


def _value_if(enabled, value):
    """Return *value* when *enabled* is truthy, otherwise ``None``."""
    return value if enabled else None


# The dictionary is assembled section by section; insertion order is the order
# in which keys appear when the config is printed or saved.
config = {}

# Paths & Info ("run_dir" is added once the run directory has been created)
config.update({
    "model_type": "BranchedUHIModel",
    "project_root": project_root_str,
    "city_name": city_name,
    "wandb_project_name": wandb_project_name,
    "wander_run_name_prefix": wander_run_name_prefix,
})

# Data Loading
config.update({
    "feature_resolution_m": feature_resolution_m,
    "uhi_grid_resolution_m": uhi_grid_resolution_m,
    "weather_seq_length": weather_seq_length,
    "enabled_weather_features": enabled_weather_features,
    "uhi_csv": str(uhi_csv_path),
    "bronx_weather_csv": str(bronx_weather_csv_path),
    "manhattan_weather_csv": str(manhattan_weather_csv_path),
    "bounds": bounds,
    "feature_flags": feature_flags,
    "sentinel_bands_to_load": sentinel_bands_to_load,
    "dem_path": _str_if(feature_flags["use_dem"], dem_path),
    "dsm_path": _str_if(feature_flags["use_dsm"], dsm_path),
    "elevation_nodata": elevation_nodata,
    "cloudless_mosaic_path": _str_if(_cfg_use_mosaic, cloudless_mosaic_path),
    "single_lst_median_path": _str_if(feature_flags["use_lst"], single_lst_median_path),
    "lst_nodata": lst_nodata,
})

# Model Config (no explicit weather channel count: the model receives the
# enabled weather feature list instead)
config.update({
    "convlstm_hidden_dims": convlstm_hidden_dims,
    "convlstm_kernel_sizes": convlstm_kernel_sizes,
    "convlstm_num_layers": convlstm_num_layers,
    "proj_static_ch": proj_static_ch,
    "proj_temporal_ch": proj_temporal_ch,
    "head_type": head_type,
    "unet_base_channels": _value_if(_cfg_unet_head, unet_base_channels),
    "unet_depth": _value_if(_cfg_unet_head, unet_depth),
    "simple_cnn_head_channels": _value_if(_cfg_simple_cnn_head, simple_cnn_head_channels),
    "simple_cnn_head_kernels": _value_if(_cfg_simple_cnn_head, simple_cnn_head_kernels),
    "simple_cnn_dropout_rate": _value_if(_cfg_simple_cnn_head, simple_cnn_dropout_rate),
})

# Clay specific
config.update({
    "clay_model_size": clay_model_size,
    "clay_bands": clay_bands,
    "clay_platform": clay_platform,
    "clay_gsd": clay_gsd,
    "freeze_backbone": freeze_backbone,
    "clay_checkpoint_path": _str_if(_cfg_use_clay, clay_checkpoint_path),
    "clay_metadata_path": _str_if(_cfg_use_clay, clay_metadata_path),
})

# Training Hyperparameters
config.update({
    "n_train_batches": n_train_batches,
    "num_workers": num_workers,
    "epochs": epochs,
    "lr": lr,
    "weight_decay": weight_decay,
    "loss_type": loss_type,
    "patience": patience,
    "max_grad_norm": max_grad_norm,
    "warmup_epochs": warmup_epochs,
    "device": str(device),
})

# --- Create Run Directory & Update Config ---
# Each run gets its own timestamped folder under `output_dir_base`.
run_timestamp = datetime.now().strftime('%Y%m%d_%H%M%S')
run_name_suffix = f"{config['model_type']}_{city_name}_{run_timestamp}"
run_dir = output_dir_base.joinpath(run_name_suffix)
run_dir.mkdir(parents=True, exist_ok=True)
config["run_dir"] = str(run_dir)  # later cells read the location from here


def _config_json_fallback(obj):
    """Stringify Path / torch.device objects; hand anything else back unchanged."""
    if isinstance(obj, (Path, torch.device)):
        return str(obj)
    return obj


print(f"Run directory: {run_dir}")
print("\nBranched Model Configuration dictionary created:")
print(json.dumps(config, indent=2, default=_config_json_fallback))


In [None]:
# %% Data Loading and Preprocessing (Branched Model + Common Resampling)

# --- Import utils ---
from src.train.train_utils import (
    calculate_uhi_stats, # Removed split_data
    create_dataloaders
)
from torch.utils.data import Subset # Import Subset
# -------------------

print("Initializing BranchedCityDataSet...")

# Constructor arguments whose keyword name is identical to their config key.
_dataset_config_keys = (
    "bounds",
    "feature_resolution_m",
    "uhi_grid_resolution_m",
    "uhi_csv",
    "bronx_weather_csv",
    "manhattan_weather_csv",
    "city_name",
    "feature_flags",
    "enabled_weather_features",
    "sentinel_bands_to_load",
    "dem_path",
    "dsm_path",
    "elevation_nodata",
    "cloudless_mosaic_path",
    "single_lst_median_path",
    "lst_nodata",
    "weather_seq_length",
)

try:
    # Config lookups stay inside the try so a missing key is reported through
    # the generic handler below, exactly like any other construction failure.
    _dataset_kwargs = {key: config[key] for key in _dataset_config_keys}
    _dataset_kwargs["data_dir"] = project_root_str
    # Optional setting: falls back to EPSG:4326 when the config does not set it.
    _dataset_kwargs["target_crs_str"] = config.get("target_crs_str", "EPSG:4326")
    dataset = CityDataSetBranched(**_dataset_kwargs)
except FileNotFoundError as e:
    # Missing input files get an actionable hint before the error propagates.
    print(f"Dataset initialization failed: {e}")
    print("Ensure required data files (DEM, DSM, weather, UHI, potentially mosaic/LST) exist.")
    print("Run `notebooks/download_data.ipynb` first.")
    raise
except Exception as e:
    print(f"Unexpected error during dataset initialization: {e}")
    raise

# --- Sequential Train/Val Split --- #
# The first (1 - val_percent) share of the samples, in dataset order, is used
# for training and the remaining tail for validation; nothing is shuffled.
val_percent = 0.20
num_samples = len(dataset)
if num_samples < 2:  # one sample each for train and val is the bare minimum
    raise ValueError(f"Dataset has only {num_samples} samples, cannot perform train/val split.")

n_train = int(num_samples * (1 - val_percent))
n_val = num_samples - n_train
if not (n_train and n_val):
    raise ValueError(f"Split resulted in zero samples for train ({n_train}) or validation ({n_val}). Adjust val_percent or check dataset size.")

_all_indices = list(range(num_samples))
train_indices = _all_indices[:n_train]
val_indices = _all_indices[n_train:]

train_ds, val_ds = Subset(dataset, train_indices), Subset(dataset, val_indices)

print(f"Sequential dataset split: {len(train_ds)} training (indices 0-{n_train-1}), {len(val_ds)} validation (indices {n_train}-{num_samples-1}) samples.")

# --- UHI Mean and Std (training subset only) --- #
# The statistics come from `train_ds` alone and are stored in the config so
# the training cell can pass them to the train/validate helpers.
uhi_mean, uhi_std = calculate_uhi_stats(train_ds)
config.update(uhi_mean=uhi_mean, uhi_std=uhi_std)

# --- Create DataLoaders --- #
# `device` is the one selected in the configuration cell.
train_loader, val_loader = create_dataloaders(
    train_ds,
    val_ds,
    device=device,
    num_workers=config['num_workers'],
    n_train_batches=config['n_train_batches'],
)


In [None]:
# %% Model Initialization (Branched Model + Common Resampling)

# --- Import necessary components ---
from src.branched_uhi_model import BranchedUHIModel
from src.train.loss import masked_mse_loss, masked_mae_loss
import torch.optim as optim
from torch.optim.lr_scheduler import ReduceLROnPlateau
import logging # Ensure logging is imported if not already

# Build the BranchedUHIModel from the central config.
print(f"Initializing {config['model_type']}...")

# Settings that must be present in the config (a missing key is an error).
_model_required_keys = (
    # weather branch (channel count is derived from the enabled feature list)
    "enabled_weather_features",
    "convlstm_hidden_dims",
    "convlstm_kernel_sizes",
    "convlstm_num_layers",
    "weather_seq_length",
    # static features
    "feature_flags",
    # projection / head
    "proj_static_ch",
    "proj_temporal_ch",
    "head_type",
    # target grid
    "uhi_grid_resolution_m",
    "bounds",
)
# Settings that are passed as None when the config does not define them.
_model_optional_keys = (
    "sentinel_bands_to_load",
    "clay_model_size",
    "clay_bands",
    "clay_platform",
    "clay_gsd",
    "clay_checkpoint_path",
    "clay_metadata_path",
    "unet_base_channels",
    "unet_depth",
    "simple_cnn_head_channels",
    "simple_cnn_head_kernels",
)

try:
    _model_kwargs = {key: config[key] for key in _model_required_keys}
    _model_kwargs.update({key: config.get(key) for key in _model_optional_keys})
    # These two fall back to explicit defaults when the key is absent.
    _model_kwargs["freeze_backbone"] = config.get("freeze_backbone", True)
    _model_kwargs["simple_cnn_dropout_rate"] = config.get("simple_cnn_dropout_rate", 0.1)

    model = BranchedUHIModel(**_model_kwargs)
    model.to(config["device"])
    print(f"{config['model_type']} initialized successfully.")

except Exception as e:
    # Log with traceback, then let the failure propagate.
    logging.error(f"Error initializing BranchedUHIModel: {e}", exc_info=True)
    raise

# --- Optimizer --- #
# AdamW over all model parameters, with learning rate and weight decay taken
# from the config. Failures are logged with a traceback and re-raised.
try:
    optimizer = optim.AdamW(
        params=model.parameters(),
        weight_decay=config["weight_decay"],
        lr=config["lr"],
    )
    print("Optimizer (AdamW) initialized.")
except Exception as e:
    logging.error(f"Error initializing optimizer: {e}", exc_info=True)
    raise

# --- Loss Function --- #
# Only 'mse' and 'mae' are supported; anything else is rejected up front.
_loss_choice = config["loss_type"]
if _loss_choice not in ('mse', 'mae'):
    raise ValueError(f"Unsupported loss type: {config['loss_type']}")

loss_fn, _loss_message = (
    (masked_mse_loss, "Loss function set to masked_mse_loss.")
    if _loss_choice == 'mse'
    else (masked_mae_loss, "Loss function set to masked_mae_loss.")
)
print(_loss_message)

# --- LR Scheduler --- #
# ReduceLROnPlateau in 'min' mode: the learning rate is multiplied by 0.5 when
# the value passed to `scheduler.step(...)` has stopped decreasing for longer
# than the patience of 10 steps.
#
# `verbose` is deliberately not passed. The argument is deprecated in recent
# PyTorch releases, and because a construction error is swallowed below
# (best-effort: training may proceed without a scheduler), a release that
# rejects the argument would silently disable LR scheduling altogether. The
# training loop already prints the current learning rate after every epoch.
try:
    scheduler = ReduceLROnPlateau(optimizer, mode='min', factor=0.5, patience=10)
    print("Initialized ReduceLROnPlateau scheduler.")
except Exception as e:
    logging.error(f"Error initializing scheduler: {e}", exc_info=True)
    scheduler = None  # Allow training without scheduler if init fails
    print("Proceeding without LR scheduler due to initialization error.")


print("\nModel, optimizer, loss function, and scheduler setup complete.")


In [None]:
# %% Training Loop (Branched Model)

# --- Training state and run bookkeeping ---
# Number of initial epochs during which checkpointing and early stopping are
# skipped (defaults to 5 when the config does not define it).
warmup_epochs = config.get("warmup_epochs", 5)

# Metrics tracking
best_val_loss = float('inf')
best_val_r2 = -float('inf')
patience_counter = 0
training_log = []

# Checkpoints live in a sub-folder of the run directory.
model_save_dir = Path(config['run_dir']) / "checkpoints"
model_save_dir.mkdir(parents=True, exist_ok=True)


def _json_default(value):
    """Convert config values that ``json`` cannot encode natively.

    Paths and torch devices are written as strings. NumPy scalars/arrays and
    torch tensors are converted to plain Python numbers/lists; this matters
    because ``uhi_mean``/``uhi_std`` are added to the config by an earlier
    cell and may be such types (assumption: depends on what
    ``calculate_uhi_stats`` returns).

    Raises:
        TypeError: for any other type, as the ``json`` ``default`` hook
            contract requires. Returning the object unchanged instead makes
            ``json`` fail with a misleading "Circular reference detected".
    """
    if isinstance(value, (Path, torch.device)):
        return str(value)
    if isinstance(value, np.generic):
        return value.item()
    if isinstance(value, (np.ndarray, torch.Tensor)):
        return value.tolist()
    raise TypeError(f"Object of type {type(value).__name__} is not JSON serializable")


# Save config to run directory
config_path = Path(config['run_dir']) / "config.json"
with open(config_path, 'w', encoding="utf-8") as f:
    json.dump(config, f, indent=2, default=_json_default)

# Weights & Biases is optional: when the package cannot be imported, `wandb`
# is set to None so later `if wandb:` checks skip all logging.
try:
    import wandb
except ImportError:
    wandb_available = False
    wandb = None
    print("Weights & Biases (wandb) not available. Skipping wandb logging.")
else:
    wandb_available = True
    print("Weights & Biases (wandb) available for logging.")

if wandb_available:
    # Start a run named "<prefix>_<timestamp>" and attach the full config.
    _wandb_run_name = f"{config['wander_run_name_prefix']}_{datetime.now().strftime('%Y%m%d_%H%M')}"
    wandb.init(
        project=config['wandb_project_name'],
        name=_wandb_run_name,
        config=config,
    )

    # Optional: watch model parameters
    wandb.watch(model)

print(f"Starting training for {config['epochs']} epochs with patience {config['patience']}")


def _log_to_wandb(metrics):
    """Send one epoch's metrics to Weights & Biases when it is in use."""
    if wandb:
        wandb.log(metrics)


# Number of epochs actually started. It stays 0 when the loop body never runs
# (e.g. epochs == 0), so the final checkpoint can always record a value.
epochs_run = 0

try:
    for epoch in range(config['epochs']):
        epoch_start = time.time()
        epochs_run = epoch + 1
        # Reset every epoch so that a skipped validation can never hand an
        # undefined or stale loss to the plateau scheduler further down.
        val_loss = None

        # --- Train --- #
        if train_loader:
            # Use generic train function from train_utils
            train_loss, train_rmse, train_r2 = train_utils.train_epoch_generic(
                model, train_loader, optimizer, loss_fn, device,
                uhi_mean=config['uhi_mean'],
                uhi_std=config['uhi_std'],
                feature_flags=config['feature_flags'],
                max_grad_norm=config.get("max_grad_norm", 1.0)
            )
            print(f"Train Loss: {train_loss:.4f}, Train RMSE: {train_rmse:.4f}, Train R2: {train_r2:.4f}")
            if np.isnan(train_loss):
                # Nothing is logged for an epoch whose training loss is NaN.
                print("Warning: Training loss is NaN. Stopping training.")
                break
        else:
            print("Skipping training: train_loader is None.")
            train_loss, train_rmse, train_r2 = float('nan'), float('nan'), float('nan')

        log_metrics = {"epoch": epoch + 1, "train_loss": train_loss, "train_rmse": train_rmse, "train_r2": train_r2}
        # The local log holds this same dict object, so validation metrics
        # added to it below also end up in the saved CSV.
        training_log.append(log_metrics)

        # --- Validate --- #
        if val_loader:
            # Use generic validate function from train_utils
            val_loss, val_rmse, val_r2 = train_utils.validate_epoch_generic(
                model, val_loader, loss_fn, device,
                uhi_mean=config['uhi_mean'],
                uhi_std=config['uhi_std'],
                feature_flags=config['feature_flags']
            )
            print(f"Val Loss:   {val_loss:.4f}, Val RMSE:   {val_rmse:.4f}, Val R2:   {val_r2:.4f}")
            if np.isnan(val_loss):
                print("Warning: Validation Loss is NaN. Stopping training.")
                # Keep the train metrics of the aborted epoch in W&B.
                _log_to_wandb(log_metrics)
                break
            log_metrics.update({"val_loss": val_loss, "val_rmse": val_rmse, "val_r2": val_r2})

            # Warmup period: don't save or check early stopping until after warmup_epochs
            if epoch >= warmup_epochs:
                if val_loss < best_val_loss:
                    best_val_loss = val_loss
                    best_val_r2 = val_r2
                    patience_counter = 0

                    # Save best model.
                    # NOTE(review): this checkpoint uses 'state_dict' /
                    # 'optimizer' / 'scheduler' keys while the final save below
                    # uses '*_state_dict' keys; left as-is for existing loaders.
                    train_utils.save_checkpoint({
                        'epoch': epoch + 1,
                        'state_dict': model.state_dict(),
                        'optimizer': optimizer.state_dict(),
                        'scheduler': scheduler.state_dict() if scheduler else None,
                        'loss': val_loss,
                        'val_rmse': val_rmse,
                        'val_r2': val_r2,
                        'config': config
                    }, is_best=True, output_dir=model_save_dir)
                    print(f"New best model saved at epoch {epoch+1} with val_loss {val_loss:.4f}")
                else:
                    patience_counter += 1
                    print(f"No improvement. Patience: {patience_counter}/{config['patience']}")

                    # Early stopping check
                    if patience_counter >= config['patience']:
                        print(f"Early stopping triggered after {epoch+1} epochs")
                        _log_to_wandb(log_metrics)
                        break
            else:
                print(f"Warmup epoch {epoch+1}/{warmup_epochs}. Skipping checkpointing and early stopping.")
        else:
            print("Skipping validation: val_loader is None.")

        # Everything for this epoch goes to W&B in ONE call: wandb.log advances
        # the W&B step on each call, so logging train, validation and LR values
        # separately would spread a single epoch over three steps. A copy is
        # used so the LR does not become an extra column in the CSV log.
        wandb_payload = dict(log_metrics)

        # Step the scheduler after validation (if it exists)
        if scheduler:
            if isinstance(scheduler, torch.optim.lr_scheduler.ReduceLROnPlateau):
                # The plateau scheduler needs a monitored value; without a
                # validation pass this epoch there is none, so it is not stepped.
                if val_loss is not None:
                    scheduler.step(val_loss)
            else:
                scheduler.step()
            wandb_payload["lr"] = optimizer.param_groups[0]['lr']

        _log_to_wandb(wandb_payload)

        # Print epoch summary
        epoch_time = time.time() - epoch_start
        print(f"Epoch {epoch+1}/{config['epochs']} completed in {epoch_time:.2f}s")
        print(f"Current LR: {optimizer.param_groups[0]['lr']:.2e}")
        print("-" * 80)

    print("Training complete!")
    print(f"Best validation loss: {best_val_loss:.4f}, Best R2: {best_val_r2:.4f}")

    # Save the final model (state after the last epoch that was started)
    final_model_path = model_save_dir / "final_model.pt"
    torch.save({
        'epoch': epochs_run,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
        'scheduler_state_dict': scheduler.state_dict() if scheduler else None,
        'config': config
    }, final_model_path)
    print(f"Final model saved to {final_model_path}")

    # Save training log
    log_df = pd.DataFrame(training_log)
    log_path = Path(config['run_dir']) / "training_log.csv"
    log_df.to_csv(log_path, index=False)
    print(f"Training log saved to {log_path}")

except Exception as e:
    print(f"Error during training: {str(e)}")
    raise
finally:
    # Finish wandb run, whether training succeeded or failed
    if wandb:
        wandb.finish()

## Training