In [None]:
# Load library imports
import os
import sys
import torch
import random
import logging
import datetime
import numpy as np
import pandas as pd
import geopandas as gpd
from collections import Counter

# Load project Imports
from src.utils.config_loader import load_project_config
from src.model.model_building import build_data_loader, instantiate_model_and_associated
from src.training.model_training import run_training_and_validation, save_train_val_losses

In [None]:
# Set up logger config.
# `force=True` (Python 3.8+) removes any handlers already attached to the root logger
# before configuring it. Without it, `basicConfig` silently does nothing when the kernel
# or an imported module has already configured logging, and this notebook's INFO
# messages may never be shown. It also keeps re-running this cell idempotent.
logging.basicConfig(
    level=logging.INFO,
    format='%(levelname)s - %(message)s',
    # More verbose alternative format:
    # '%(asctime)s - %(levelname)s - %(name)s - %(message)s'
    handlers=[logging.StreamHandler(sys.stdout)],
    force=True,
)

# Set up the logger for this notebook and load the project config (paths and params)
logger = logging.getLogger(__name__)
config = load_project_config(config_path="config/project_config.yaml")
notebook = True

# Seed every RNG source used here (Python, NumPy, PyTorch CPU and all CUDA devices).
# Also request deterministic cuDNN algorithms and disable cuDNN autotuning (benchmark
# mode), which could otherwise select different kernels from run to run.
random_seed = config["global"]["pipeline_settings"]["random_seed"]
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Define notebook demo catchment: the first entry of the configured catchment list.
catchments_to_process = config["global"]["pipeline_settings"]["catchments_to_process"]
if not catchments_to_process:
    # Fail with a clear message instead of an opaque IndexError on [0] below.
    raise ValueError(
        "config['global']['pipeline_settings']['catchments_to_process'] is empty; "
        "at least one catchment is required for the notebook demo."
    )
catchment = catchments_to_process[0]
# NOTE(review): not referenced in the visible cells — presumably kept for parity with other pipeline steps.
run_defra_API_calls = config["global"]["pipeline_settings"]["run_defra_api"]

logger.info(f"Show Notebook Outputs: {notebook}")
logger.info(f"Notebook Demo Catchment: {catchment.capitalize()}")

In [None]:
# Load the saved per-timestep graph snapshots (presumably PyG `Data` objects, given the
# `pyg_object_path` config key and the `.x` / `*_mask` attributes used below).
# These are pickled Python objects, not plain tensors. PyTorch >= 2.6 defaults to
# `weights_only=True`, which refuses to unpickle such classes, so opt out explicitly.
# SECURITY: `weights_only=False` runs arbitrary unpickling — only load files produced
# by this project's own pipeline, never untrusted ones.
all_timesteps_list = torch.load(
    config[catchment]["paths"]["pyg_object_path"],
    weights_only=False,
)
logger.info(f"Loaded {len(all_timesteps_list)} timestep snapshots for {catchment} catchment.")

# Preview only the first few snapshots rather than rendering the whole list.
all_timesteps_list[:3]

In [None]:
# Sanity Checks
# Inspect the first snapshot, then total the train/val/test masks across all snapshots.
if len(all_timesteps_list) == 0:
    raise ValueError("all_timesteps_list is empty; nothing was loaded from pyg_object_path.")

example_data_obj = all_timesteps_list[0]
logger.info(f"Example snapshot node feature shape (x): {tuple(example_data_obj.x.shape)}")
# Format the counts into a single message string. logger.info(msg, *args) treats extra
# positional args as %-format arguments for `msg`, so passing the three counts as
# separate arguments produces a logging error instead of printing them.
logger.info(
    f"Example snapshot mask counts — "
    f"Train: {example_data_obj.train_mask.sum().item()}, "
    f"Val: {example_data_obj.val_mask.sum().item()}, "
    f"Test: {example_data_obj.test_mask.sum().item()}"
)
logger.info(f"Example snapshot timestep: {example_data_obj.timestep}")

train_counts = sum(d.train_mask.sum().item() for d in all_timesteps_list)
val_counts = sum(d.val_mask.sum().item() for d in all_timesteps_list)
test_counts = sum(d.test_mask.sum().item() for d in all_timesteps_list)

logger.info(f"\nTotal (across all timesteps) — Train: {train_counts}, Val: {val_counts}, Test: {test_counts}\n")

x_tensor = example_data_obj.x
x_tensor

In [None]:
# --- 7a. Build Data Loaders by Timestep ---
# Batch size and shuffling are read from the shared (global) model section of the config.
global_model_cfg = config["global"]["model"]
loader_batch_size = global_model_cfg["data_loader_batch_size"]
loader_shuffle = global_model_cfg["data_loader_shuffle"]

full_dataset_loader = build_data_loader(
    all_timesteps_list=all_timesteps_list,
    batch_size=loader_batch_size,
    shuffle=loader_shuffle,
    catchment=catchment,
)

logger.info(f"Pipeline Step 'Create PyG DataLoaders' complete for {catchment} catchment.")

In [None]:
# --- 7b. Define Graph Neural Network Architecture ---
# The project helper returns, in this order, the model, the device, the optimizer and
# the loss criterion used by the training step below.
model, device, optimizer, criterion = instantiate_model_and_associated(
    config=config,
    catchment=catchment,
    all_timesteps_list=all_timesteps_list,
)
logger.info(f"Pipeline Step 'Instantiate GAT-LSTM Model' complete for {catchment} catchment.")

In [None]:
# --- 8a. Implement Training Loop ---
# Read the per-catchment training hyperparameters and paths once, then pass them to
# the training/validation routine along with the model objects built above.
training_params = config[catchment]["training"]
catchment_paths = config[catchment]["paths"]
model_output_dir = catchment_paths["model_dir"]

train_losses, val_losses = run_training_and_validation(
    model=model,
    device=device,
    optimizer=optimizer,
    criterion=criterion,
    all_timesteps_list=all_timesteps_list,
    catchment=catchment,
    config=config,
    num_epochs=training_params["num_epochs"],
    early_stopping_patience=training_params["early_stopping_patience"],
    lr_scheduler_factor=training_params["lr_scheduler_factor"],
    lr_scheduler_patience=training_params["lr_scheduler_patience"],
    min_lr=training_params["min_lr"],
    gradient_clip_max_norm=training_params["gradient_clip_max_norm"],
    loss_delta=training_params["loss_delta"],
    verbose=training_params["verbose"],
    model_save_dir=model_output_dir,
    scalers_dir=catchment_paths["scalers_dir"],
)
logger.info(f"Pipeline Step 'Train and Validate Model' complete for {catchment} catchment.")

# Persist the returned loss histories into the same model_dir the model is saved to.
save_train_val_losses(
    output_analysis_dir=model_output_dir,
    train_losses=train_losses,
    val_losses=val_losses,
)
logger.info(f"Pipeline Step 'Save Training and Validation Losses' complete for {catchment} catchment.")

In [None]:
# --- Next Steps: Final Model Evaluation on Test Set (using all_timesteps_list and data.test_mask) ---
