In [None]:
# Load library imports
import sys
import torch
import random
import joblib
import logging
import importlib
import numpy as np
import pandas as pd

# Load project Imports
from src.utils.config_loader import load_project_config
from src.data_ingestion.gwl_data_ingestion import process_station_coordinates, \
    fetch_and_process_station_data, download_and_save_station_readings
from src.preprocessing.gwl_preprocessing import load_timeseries_to_dict, outlier_detection, \
    resample_daily_average, remove_spurious_data, interpolate_short_gaps
from src.preprocessing.gap_imputation import handle_large_gaps

In [None]:
# Set up logging config: INFO-level messages are written to stdout so they
# appear in the cell output.
logging.basicConfig(
    level=logging.INFO,
    format='%(levelname)s - %(message)s',
#    format='%(asctime)s - %(levelname)s - %(name)s - %(message)s',
    handlers=[logging.StreamHandler(sys.stdout)]
)

# Set up logger for file and load config file for paths and params
logger = logging.getLogger(__name__)
config = load_project_config(config_path="config/project_config.yaml")
notebook = True  # forwarded as the `notebook` argument of the preprocessing helpers

# Set up seeding to define global states (Python, NumPy and PyTorch CPU/CUDA)
pipeline_settings = config["global"]["pipeline_settings"]
random_seed = pipeline_settings["random_seed"]
random.seed(random_seed)
np.random.seed(random_seed)
torch.manual_seed(random_seed)
torch.cuda.manual_seed_all(random_seed)
# Request deterministic cuDNN kernels and switch off its benchmark auto-tuner
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Define notebook demo catchment: the first configured catchment is used
catchments_to_process = pipeline_settings["catchments_to_process"]
if not catchments_to_process:
    # Fail with an actionable message rather than a bare IndexError below
    raise ValueError(
        "config['global']['pipeline_settings']['catchments_to_process'] is empty; "
        "at least one catchment is required to run this notebook."
    )
catchment = catchments_to_process[0]
run_defra_API_calls = pipeline_settings["run_defra_api"]  # True to run API calls

# Use the notebook's own logger (not the root logger) like every other cell
logger.info(f"Show Notebook Outputs: {notebook}")
logger.info(f"Notebook Demo Catchment: {catchment.capitalize()}")

### DATA INGESTION ###

Load the groundwater level (GWL) station list with grid references and convert the grid references to easting/northing and longitude/latitude for plotting and data alignment.

In [None]:
# --- Process Catchment Stations List ----
# Resolve the configured file locations first so the call itself stays short.
grid_squares_path = config["global"]["paths"]["gis_os_grid_squares"]
catchment_paths = config[catchment]["paths"]

stations_with_coords_df = process_station_coordinates(
    os_grid_squares=grid_squares_path,
    station_list_input=catchment_paths["gwl_station_list"],
    station_list_output=catchment_paths["gwl_station_list_with_coords"],
    catchment=catchment,
)

logger.info(f"Pipeline step 'Process Station Coordinates for {catchment}' complete.\n")

**API Documentation notes:**

1. The API calls that return readings data have a soft limit of 100,000 rows per-call which can be overridden by setting a _limit parameter. There is a hard limit of 2,000,000 rows, which cannot be overridden.
2. The primary identifier for most stations uses a GUID style identifier called an SUID. These are used in the URL for the station and given as the value of the notation property in the station metadata.  
    a. Wiski identifier (wiskiID) is also available for my subset of stations and data type  
3. All monitoring stations can be filtered by name, location and other parameters. See https://environment.data.gov.uk/hydrology/doc/reference#stations-summary for full metadata details

In [None]:
if run_defra_API_calls:
    # Retrieve gwl monitoring station metadata and measures from DEFRA API
    stations_with_metadata_measures = fetch_and_process_station_data(
        stations_df=stations_with_coords_df,
        base_url=config["global"]["paths"]["defra_station_base_url"],
        output_path=config[catchment]["paths"]["gwl_station_metadata_measures"]
    )

    logger.info(f"Pipeline step 'Pull Hydrological Station Metadata for {catchment}' complete.\n")

# Preview the fetched metadata. Jupyter only renders the value of a top-level
# trailing expression, so the preview must live outside the `if` body; when the
# API calls are disabled the expression evaluates to None and nothing is shown.
stations_with_metadata_measures.head() if run_defra_API_calls else None

In [None]:
if not run_defra_API_calls:
    # API calls disabled: read the station metadata/measures table from the
    # configured CSV path instead.
    loaded_csv_path = config[catchment]["paths"]["gwl_station_metadata_measures"]
    stations_with_metadata_measures = pd.read_csv(loaded_csv_path)

else:
    # API calls enabled: download the readings for every station over the
    # configured date window and write them to the catchment's output dir.
    ingestion_cfg = config["global"]["data_ingestion"]

    download_and_save_station_readings(
        stations_df=stations_with_metadata_measures,
        start_date=ingestion_cfg["api_start_date"],
        end_date=ingestion_cfg["api_end_date"],
        gwl_data_output_dir=config[catchment]["paths"]["gwl_data_output_dir"],
    )

    logger.info(f"All timeseries groundwater level data saved for {catchment} catchment.")

### PREPROCESSING ###

Remove stations with insufficient data and clean the time series data of outliers and incorrect measurements. Interpolate across small data gaps using a rational spline.

1. Load each station's time series into a dict of DataFrames, dropping stations with insufficient data

In [None]:
# Build the station -> time series reference dict from the saved API CSVs
catchment_cfg = config[catchment]

gwl_time_series_dict = load_timeseries_to_dict(
    stations_df=stations_with_metadata_measures,
    col_order=config["global"]["data_ingestion"]["col_order"],
    data_dir=catchment_cfg["paths"]["gwl_data_output_dir"],
    inclusion_threshold=catchment_cfg["preprocessing"]["inclusion_threshold"],
)

logger.info(f"All timeseries data converted to dict for {catchment} catchment.\n")

2. Remove outlying and incorrect data points

In [None]:
# Pass every station's frame through remove_spurious_data and store the result
# back under the same key (keys are only overwritten, never added or removed).
ts_plot_output = config[catchment]["visualisations"]["ts_plots"]["time_series_gwl_output"]

for station_name, raw_df in gwl_time_series_dict.items():
    cleaned_df = remove_spurious_data(
        target_df=raw_df,
        station_name=station_name,
        path=ts_plot_output,
        notebook=notebook,
    )
    gwl_time_series_dict[station_name] = cleaned_df

In [None]:
# Config switch deciding whether outlier detection is recomputed in this run
run_outlier_processing = config["global"]["pipeline_settings"]["run_outlier_detection"]

if run_outlier_processing:
    ts_plot_cfg = config[catchment]["visualisations"]["ts_plots"]

    # Run outlier detection and processing over every station's series
    processed_gwl_time_series_dict = outlier_detection(
        gwl_time_series_dict=gwl_time_series_dict,
        output_path=ts_plot_cfg["time_series_gwl_output"],
        dpi=ts_plot_cfg["dpi_save"],
        dict_output=config[catchment]["paths"]["gwl_outlier_dict"],
        notebook=notebook,
    )

3. Aggregate to daily time steps

In [None]:
# Outlier detection was skipped in this run: load the dict from the configured
# `gwl_outlier_dict` path instead.
# NOTE: joblib.load unpickles the file, so only load files you trust.
if not run_outlier_processing:
    input_dict = config[catchment]["paths"]["gwl_outlier_dict"]
    processed_gwl_time_series_dict = joblib.load(input_dict)

# Aggregate each station's series to daily averages over the configured window
ingestion_cfg = config["global"]["data_ingestion"]
ts_plot_output = config[catchment]["visualisations"]["ts_plots"]["time_series_gwl_output"]

daily_data = resample_daily_average(
    dict=processed_gwl_time_series_dict,
    start_date=ingestion_cfg["api_start_date"],
    end_date=ingestion_cfg["api_end_date"],
    path=ts_plot_output,
    notebook=notebook,
)

4. Interpolate across small gaps in the time series data using a rational spline or PCHIP (try both), defining a threshold on the number of missing time steps eligible for interpolation, and add a binary interpolation flag column

In [None]:
# Index every station frame that still carries a 'dateTime' column by that
# column (parsed to datetimes, unparseable values coerced to NaT) and sort it.
# Frames without a 'dateTime' column are left untouched.
for station_name, df_data in daily_data.items():
    if 'dateTime' in df_data.columns:
        # assign() builds a new frame, so the frame currently stored in
        # daily_data is not mutated in place before being replaced.
        daily_data[station_name] = (
            df_data
            .assign(dateTime=pd.to_datetime(df_data['dateTime'], errors='coerce'))
            .set_index('dateTime')
            .sort_index()
        )

gaps_list = []  # stations flagged as still needing large-gap interpolation
station_max_gap_lengths_calculated = {}  # station -> longest remaining gap length

# Loop-invariant settings, looked up once
ts_plot_output = config[catchment]["visualisations"]["ts_plots"]["time_series_gwl_output"]
max_interp_steps = config["global"]["data_ingestion"]["max_interp_length"]

for station_name, df in daily_data.items():
    gap_status_for_large_interp, updated_df, max_gap_len_for_this_station = interpolate_short_gaps(
        df=df,
        station_name=station_name,
        path=ts_plot_output,
        max_steps=max_interp_steps,
        notebook=notebook
    )

    # Update daily_data with the processed (interpolated) DataFrame
    daily_data[station_name] = updated_df

    if gap_status_for_large_interp:  # If the station still needs large gap interp
        gaps_list.append(station_name)
        if max_gap_len_for_this_station > 0:  # Only store if there was an actual large gap
            station_max_gap_lengths_calculated[station_name] = max_gap_len_for_this_station

# Report through the notebook's logger, consistent with the other cells
logger.info(f"Stations still needing interpolation: {gaps_list}\n")
logger.info(f"Max uninterpolated gap lengths per station:\n{station_max_gap_lengths_calculated}\n")

In [None]:
# Build the complete daily calendar spanning the configured API date window
ingestion_cfg = config["global"]["data_ingestion"]
start_date = pd.to_datetime(ingestion_cfg["api_start_date"])
end_date = pd.to_datetime(ingestion_cfg["api_end_date"])
full_date_range = pd.date_range(start=start_date, end=end_date, freq='D')

for station_name, df_data in daily_data.items():
    if 'dateTime' not in df_data.columns:
        # NOTE(review): frames that are already indexed by dateTime (no such
        # column) are skipped and therefore NOT reindexed to the full range.
        # Confirm this is intended given what interpolate_short_gaps returns.
        continue

    # Parse the timestamps, index and sort by them, then align to the full
    # calendar so that days without a row become NaN rows.
    df_data['dateTime'] = pd.to_datetime(df_data['dateTime'], errors='coerce')
    df_data = df_data.set_index('dateTime').sort_index().reindex(full_date_range)
    daily_data[station_name] = df_data  # store the reindexed frame back

Handle large gaps

In [None]:
# Group the config sections read below so each keyword argument stays short
catchment_cfg = config[catchment]
catchment_prep_cfg = catchment_cfg["preprocessing"]
global_prep_cfg = config["global"]["preprocessing"]
ingestion_cfg = config['global']['data_ingestion']

# Large-gap handling; returns a pair that is unpacked into the two names below
synthetic_imputation_performace, trimmed_df_dict = handle_large_gaps(
    df_dict=daily_data,
    gaps_list=gaps_list,
    catchment=catchment,
    spatial_path=catchment_cfg["paths"]["gwl_station_list_with_coords"],
    path=catchment_cfg["visualisations"]["ts_plots"]["time_series_gwl_output"],
    threshold_m=catchment_prep_cfg["large_catchment_threshold_m"],
    radius=global_prep_cfg["radius"],
    output_path=catchment_cfg["visualisations"]["corr_dist_score_scatters"],
    threshold=catchment_prep_cfg["dist_corr_score_threshold"],
    predefined_large_gap_lengths=global_prep_cfg["gap_lengths_days"],
    max_imputation_length_threshold=global_prep_cfg["max_imputation_threshold"],
    min_around=global_prep_cfg["min_data_points_around_gap"],
    station_max_gap_lengths=station_max_gap_lengths_calculated,
    model_start_date=ingestion_cfg['model_start_date'],
    model_end_date=ingestion_cfg['model_end_date'],
    k_decay=catchment_prep_cfg["dist_corr_score_k_decay"],
    notebook=notebook,
    random_seed=config["global"]["pipeline_settings"]["random_seed"],
)

In [None]:
# Report how many entries of the 'value' column are still NaN for each station
# in the output of handle_large_gaps. The loop variable is deliberately not
# called `dict`, which would shadow the builtin for the rest of the session.
for station, station_df in trimmed_df_dict.items():
    nan_count = station_df['value'].isna().sum()
    print(f"{station}: {nan_count} NaNs")

**NEXT: ADD IMPUTATION FLAGS**

5. Lagged: Add lagged features (by timestep across 7 days?) + potentially rolling averages (3-day/7-day?)

6. Temporal Encoding: Define sinusoidal features for seasonality (both sine and cosine for performance)

To zoom in on an area of a graph (for checking imputation etc during code dev):

In [None]:
# Development-only helper, currently disabled (every line is commented out).
# When uncommented it plots one station's daily 'value' series between
# start_date and end_date with fine minor grid lines, for eyeballing imputation.
# NOTE(review): the station key is 'Ainstable' but the variable names and plot
# title say "Renwick" - confirm which station is intended before re-enabling.
# NOTE(review): this reads a 'dateTime' column, whereas earlier cells move
# dateTime into the index - verify the frame layout before re-enabling.
# import matplotlib.pyplot as plt
# import matplotlib.dates as mdates
# import pandas as pd

# start_date = '2022-12-01'
# end_date = '2023-12-31'

# # Filter the DataFrame
# df_renwick = daily_data['Ainstable']
# df_renwick['dateTime'] = pd.to_datetime(df_renwick['dateTime'], errors='coerce')
# filtered_df_renwick = df_renwick[(df_renwick['dateTime'] >= start_date) & (df_renwick['dateTime'] <= end_date)]

# # print(filtered_df_renwick.head(20))

# fig, ax = plt.subplots(figsize=(12, 6)) # Use fig, ax for more control

# ax.plot(filtered_df_renwick['dateTime'], filtered_df_renwick['value'])
# ax.set_title(f"Renwick Groundwater Level: {start_date} to {end_date}")
# ax.set_xlabel("Date")
# ax.set_ylabel("Groundwater Level (mAOD)")

# ax.xaxis.set_minor_locator(mdates.WeekdayLocator(interval=1))
# ax.yaxis.set_minor_locator(plt.MultipleLocator(0.1))
# ax.grid(True, which='major', linestyle='-', linewidth=0.5)
# ax.grid(True, which='minor', linestyle=':', linewidth=0.2) # Finer, dashed minor grid


# fig.autofmt_xdate() # Auto-formats date labels for readability
# plt.show()

# # print(filtered_df.head())