## Setup and Imports
Python environment setup and necessary library imports.

In [None]:
from __future__ import annotations

import logging
import time
from pathlib import Path

import matplotlib.pyplot as plt
import rioxarray as rxr
import geopandas as gpd # For type hinting if needed
import xarray as xr # For type hinting if needed
import numpy as np # For potential transformations

# Assuming functions.py is in the same directory or PYTHONPATH
from functions import (
    load_aq,
    load_covariates,
    linear_aq_model,
    predict_linear_model,
    krige_aq_residuals,
    combine_results,
    plot_aq_prediction,
    plot_aq_se,
    find_poll # if needed directly, though usually internal
)

# Notebook-wide logging: everything goes through the root logger at INFO level.
LOG = logging.getLogger()  # no name given -> the root logger
logging.basicConfig(
    format="%(asctime)s %(levelname)s %(message)s",
    level=logging.INFO,
)
# Set the level explicitly as well, so INFO applies regardless of what
# basicConfig() ended up doing.
LOG.setLevel(logging.INFO)

## Parameters
Define parameters for the analysis, mirroring the R script.

In [None]:
# Analysis configuration. For the temporal selection, 0 means "not set".
pollutant: str = "O3"
stat: str = "mean"
year: int = 2020
month: int = 0  # 0 -> annual data; a value > 0 selects monthly data
day_of_year: int = 0  # 0 -> rely on `month`; a value > 0 selects daily data
# station_area_type = "RB"  # Placeholder only: filter_area_type is not directly used in AQ_Demo.py

## Read AQ Data
Daily, monthly, or annual aggregates of AQ measurement data are imported using `load_aq()`.
Required inputs are the `pollutant`, the aggregation statistic (`stat`) and a temporal selection.
Specifying only the year (`y`) returns annual data.
Additionally, specifying **either** the month (`m`) **or** the day of the year (`d`) returns monthly or daily data, respectively.

In [None]:
# Fetch the aggregated AQ measurements for the configured pollutant,
# statistic and time selection, and report how many records came back.
LOG.info(f"Loading AQ measurements for {pollutant}, {stat}, year={year}, month={month}, day={day_of_year}...")
aq = load_aq(
    y=year,
    m=month,
    d=day_of_year,
    pollutant=pollutant,
    stat=stat,
)
LOG.info(f"Loaded {len(aq)} AQ measurements.")
# Uncomment for a quick look at the data and its metadata:
# print(aq.head())
# print(aq.attrs)

## Add Covariates
Covariates are then loaded using `load_covariates()`. The temporal extent is extracted from the AQ measurement data object attributes. Each covariate raster layer is warped to match the spatial grid of the DEM.

The Python version of `load_covariates` uses a DEM raster as the target grid.
The list of covariates loaded depends on the pollutant, as defined in `functions.py`.

Example pollutant-specific covariates (as per `functions.py` logic):
-   PM10: log_CAMS_PM10, Elevation, WindSpeed, RelativeHumidity
-   PM2.5: log_CAMS_PM2.5, Elevation, WindSpeed, RelativeHumidity, BoundaryLayerHeight
-   O3: CAMS_O3, Elevation, WindSpeed, SolarRadiation
-   NO2: CAMS_NO2, Elevation, WindSpeed, TROPOMI_NO2

In [None]:
# Load the DEM raster that is passed to `load_covariates`, load the
# covariates for the AQ measurements, and log a short inspection summary.
#
# Define DEM path (ensure this file exists or update the path).
# Original R script used: "supplementary/static/COP-DEM/COP_DEM_Europe_mainland_10km_mask_epsg3035.rds"
# Python script uses a .tif file.
DEM_PATH = Path("Data/aq_cov_o3_2020_dem.tiff")  # Path from AQ_Demo.py
# DEM_PATH = Path("supplementary/static/COP-DEM/COP_DEM_Europe_mainland_10km_mask_epsg3035.tif")  # Alternative path from QMD, if converted to TIF

if DEM_PATH.exists():
    # Keep only band 1 of the raster.
    dem = rxr.open_rasterio(DEM_PATH, masked=True).sel(band=1)
    LOG.info(f"DEM loaded from {DEM_PATH} with shape {dem.shape}")

    LOG.info("Loading covariates …")
    t0 = time.time()
    # `aq` is expected to carry the attributes set by load_aq
    # ('pollutant', 'stat', 'y', 'm', 'd') — TODO confirm against functions.py.
    aq_cov = load_covariates(aq, dem)
    LOG.info("Covariate load finished in %.1fs", time.time() - t0)

    # Inspection of loaded covariates.
    # Read the month attribute once and treat a missing/None value like 0, so
    # the label below cannot fail on a `None > 0` comparison.
    month_attr = aq_cov.attrs.get("m")
    month_label = month_attr if month_attr is not None and month_attr > 0 else "Annual"
    LOG.info(f"--- Covariate Inspection (Pollutant: {aq_cov.attrs.get('pollutant')}, Year: {aq_cov.attrs.get('y')}, Month: {month_label}) ---")
    LOG.info(f"Loaded covariate layers: {list(aq_cov.data_vars)}")
    LOG.info(f"Covariate dataset attributes: {aq_cov.attrs}")
    for var_name in aq_cov.data_vars:
        LOG.info(f"  Layer '{var_name}': shape {aq_cov[var_name].shape}, coords {list(aq_cov[var_name].coords)}")
    LOG.info("--- End Covariate Inspection ---")
    # print(aq_cov)
else:
    # Deliberately not raising: the notebook keeps running and later cells
    # check for `aq_cov` themselves. NOTE: an `aq_cov` left over from an
    # earlier run is not cleared here.
    LOG.error(f"Cannot find DEM raster: {DEM_PATH}. Update `DEM_PATH` to point to a valid raster file.")
    # raise FileNotFoundError(f"Cannot find DEM raster: {DEM_PATH}.")

## Linear Models
The R script mentions filtering by station area type (e.g., "RB" - rural background) before fitting the linear model. This step is noted as a placeholder in the Python `AQ_Demo.py` script and `filter_area_type` is not directly applied.
The `linear_aq_model` function fits an ordinary least-squares model: pollutant ~ covariates.

In [None]:
# Fit the linear model on the AQ measurements and covariates; this needs
# `aq_cov` from the covariate-loading cell.
if 'aq_cov' not in locals():
    LOG.error("Covariates (aq_cov) not loaded. Skipping linear model fitting.")
else:
    # No station-area-type filtering happens here (AQ_Demo.py does not apply
    # it either), so the model is fitted on the unfiltered `aq` data:
    #   station_area_type = "RB"
    #   aq_filtered = filter_area_type(aq, area_type=station_area_type)
    LOG.info("Fitting linear model …")
    linmod = linear_aq_model(aq, aq_cov)
    # linear_aq_model logs the R² score itself. To inspect the fit further:
    #   print(f"Model coefficients: {linmod.coef_}")
    #   print(f"Model intercept: {linmod.intercept_}")
    #   print(f"Feature names: {linmod.feature_names_in_}")

## Predict Linear Model
The fitted linear model is then used to predict pollutant values over the covariate raster grid.
The R script includes steps for masking based on CLC data and back-transformation for log-transformed variables.
The Python `predict_linear_model` function handles the prediction. Back-transformation, if needed, would typically be part of `combine_results` or a subsequent step.

In [None]:
# Apply the fitted linear model to the covariate grid and store the result
# as the 'lm_pred' layer of `aq_cov`.
if 'linmod' not in locals() or 'aq_cov' not in locals():
    LOG.error("Linear model or covariates not available. Skipping prediction.")
else:
    LOG.info("Predicting linear model over covariate grid …")
    aq_cov["lm_pred"] = predict_linear_model(linmod, aq_cov)
    LOG.info("Linear model prediction added to aq_cov as 'lm_pred'.")
    # print(aq_cov["lm_pred"])

    # The explicit exp() back-transformation of lm_pred and se from the R
    # script is not mirrored here. For PM10/PM2.5 the CAMS covariate is
    # log-transformed, so the prediction stays in log-space unless
    # `combine_results` or a later step handles it. Only a hint is logged.
    if pollutant in ["PM10", "PM2.5"]:
        if f"log_CAMS_{pollutant.upper()}" in linmod.feature_names_in_:
            LOG.info(f"Predictions for {pollutant} are in log-space due to log_CAMS input. Consider back-transformation (exp).")
            # For a direct back-transformation here:
            #   aq_cov["lm_pred"] = np.exp(aq_cov["lm_pred"])

## Residual Kriging
The residuals of the linear model are interpolated using `krige_aq_residuals`.
This function requires the AQ measurements, covariates (including the linear model prediction), and the linear model itself.
The R script mentions parallelization for kriging; the Python version in `functions.py`, as called in `AQ_Demo.py`, does not currently expose parallel options.

In [None]:
# Krige the linear-model residuals; this needs the model, the covariates
# and the 'lm_pred' layer added by the prediction cell.
if 'linmod' not in locals() or 'aq_cov' not in locals() or "lm_pred" not in aq_cov:
    LOG.error("Linear model prediction or covariates not available. Skipping residual kriging.")
else:
    LOG.info("Kriging residuals … this may take a while …")
    t0 = time.time()
    # n_max is the counterpart of n.max in R; the R options CV and show.vario
    # are not part of the AQ_Demo.py call.
    krige_res = krige_aq_residuals(
        aq,
        aq_cov,
        linmod,
        n_max=10,
    )
    LOG.info("Kriging finished in %.1fs", time.time() - t0)
    # print(krige_res)  # holds 'pred' and 'pred_se' for the residuals

## Combine Results & Plot
Model prediction and Kriging output are merged using `combine_results()`.
The final prediction and its standard error are then plotted.

In [None]:
# Merge the linear-model prediction with the kriged residuals, then plot the
# combined prediction and its standard error side by side.
#
# This cell relies on the imports and the LOG logger from the setup cell
# (plt, xr, combine_results, plot_aq_prediction, plot_aq_se); it no longer
# re-imports them or rebinds LOG.


def _plot_panel(plot_fn, dataset, ax, what, placeholder, error_title):
    """Draw one map panel with ``plot_fn(dataset, ax=ax)``.

    If the plot function raises, the error is logged with its traceback and
    the axes show ``placeholder`` under ``error_title`` with the ticks
    removed, so the other panel can still be rendered.
    """
    try:
        # Title and colorbar label are handled inside the plot function.
        plot_fn(dataset, ax=ax)
    except Exception as e:
        LOG.error(f"Error plotting {what}: {e}", exc_info=True)
        ax.text(0.5, 0.5, placeholder, ha='center', va='center', transform=ax.transAxes)
        ax.set_title(error_title)
        ax.set_xticks([])
        ax.set_yticks([])


LOG.info("Combining LM prediction and kriged residuals ...")
# Reset first: if this run fails, a `result` left over from an earlier run
# of the notebook must not be plotted as if it were current.
result = None
try:
    result = combine_results(aq_cov, krige_res)
except Exception as e:
    # Also reached (as NameError) when earlier cells were skipped and
    # `aq_cov` / `krige_res` do not exist.
    LOG.error(f"Error during combine_results: {e}", exc_info=True)
else:
    LOG.info("Results combined.")
    # Optionally inspect the result right after creation:
    # print(result)
    # print(type(result))

if isinstance(result, xr.Dataset):
    LOG.info("Plotting final AQ prediction map and Standard Error map …")

    fig, axes = plt.subplots(1, 2, figsize=(15, 6))  # Adjust figsize as needed

    _plot_panel(
        plot_aq_prediction, result, axes[0],
        "AQ prediction", "Error plotting prediction", "Predicted AQ (Error)",
    )
    _plot_panel(
        plot_aq_se, result, axes[1],
        "prediction standard error", "Error plotting SE", "Prediction Std. Error (Error)",
    )

    plt.tight_layout()
    plt.show()

    LOG.info("Plotting finished.")
else:
    LOG.error("'result' dataset not found or not an xarray.Dataset. Skipping plotting.")
