In [None]:
# Load library imports
import os
import sys
import torch
import random
import logging
import datetime
import numpy as np
import pandas as pd
import geopandas as gpd
from collections import Counter

# Load project Imports
from src.utils.config_loader import load_project_config
from src.graph_building.graph_construction import build_mesh, \
    define_catchment_polygon, define_graph_adjacency
from src.preprocessing.data_partitioning import define_station_id_splits, \
    load_graph_tensors, build_pyg_object
from src.preprocessing.model_feature_engineering import preprocess_gwl_features, \
    preprocess_shared_features

In [None]:
# Configure root logging for this notebook session.
# force=True (Python 3.8+) replaces any handlers already attached to the root
# logger. Without it, basicConfig is a silent no-op whenever handlers already
# exist - e.g. when this cell is re-run, or when a module imported earlier
# configured logging first - so the format/level below would be ignored.
logging.basicConfig(
    level=logging.INFO,
    format='%(levelname)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)],
    force=True,
)

# Logger for this notebook, plus the project config holding paths and params
logger = logging.getLogger(__name__)
config = load_project_config(config_path="config/project_config.yaml")
notebook = True

# Every setting read below lives under the same sub-section; look it up once
pipeline_settings = config["global"]["pipeline_settings"]

# Seed every RNG in use (stdlib random, NumPy, PyTorch CPU and CUDA) and make
# cuDNN deterministic so repeated runs start from the same global state
random_seed = pipeline_settings["random_seed"]
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Define notebook demo catchment: the first entry configured for processing.
# Fail with an explicit message rather than a bare IndexError when none is set.
catchments_to_process = pipeline_settings["catchments_to_process"]
if not catchments_to_process:
    raise ValueError(
        "config['global']['pipeline_settings']['catchments_to_process'] is empty; "
        "at least one catchment is required to run this notebook."
    )
catchment = catchments_to_process[0]
run_defra_API_calls = pipeline_settings["run_defra_api"]

logger.info(f"Show Notebook Outputs: {notebook}")
logger.info(f"Notebook Demo Catchment: {catchment.capitalize()}")

In [None]:
# Load the per-timestep objects saved at this catchment's configured path
# (presumably PyG Data objects, judging by the config key - confirm against
# the pipeline step that writes this file).
#
# weights_only=False is passed explicitly (the argument exists from
# PyTorch 1.13): since PyTorch 2.6 the default is weights_only=True, which
# refuses to unpickle arbitrary Python objects and would make this load fail.
#
# SECURITY: with weights_only=False, torch.load runs the full pickle machinery
# and can execute arbitrary code. Only load files produced by this project's
# own pipeline, never files from an untrusted source.
pyg_object_path = config[catchment]["paths"]["pyg_object_path"]
all_timesteps_list = torch.load(pyg_object_path, weights_only=False)

# Report the size and show a short preview instead of dumping every timestep
logger.info(f"Loaded {len(all_timesteps_list)} timestep objects from {pyg_object_path}")
all_timesteps_list[:3]

In [None]:
# Sanity checks: inspect the first timestep object, then total each split
# mask over every timestep.
example_data_obj = all_timesteps_list[0]
x_tensor = example_data_obj.x

# Feature tensor shape, the three mask sums, and the timestep of the first object
print(x_tensor.shape)
print(*(
    getattr(example_data_obj, mask_name).sum()
    for mask_name in ("train_mask", "val_mask", "test_mask")
))
print(example_data_obj.timestep)

# For each mask, add up its per-timestep sum (converted to a Python number
# via .item()) across the whole list
train_counts, val_counts, test_counts = (
    sum(getattr(snapshot, mask_name).sum().item() for snapshot in all_timesteps_list)
    for mask_name in ("train_mask", "val_mask", "test_mask")
)

print(f"\nTotal (across all timesteps) — Train: {train_counts}, Val: {val_counts}, Test: {test_counts}\n")

# Display the first object's feature tensor as the cell output
x_tensor