# 4. Batch Inference Pipeline

## 4.1. Environment Setup
Detect if running in Google Colab or local environment, handle repository cloning, dependency installation, numpy compatibility fixes, and set up Python path.

In [None]:
import sys
from pathlib import Path
import hopsworks
import warnings

warnings.filterwarnings("ignore", module="IPython")

def clone_repository() -> None:
    repo_dir = Path("pm25-forecast-openmeteo-aqicn")
    if repo_dir.exists():
        print(f"Repository already exists at {repo_dir.absolute()}")
        %cd pm25-forecast-openmeteo-aqicn
    else:
        print("Cloning repository...")
        !git clone https://github.com/KristinaPalmquist/pm25-forecast-openmeteo-aqicn.git
        %cd pm25-forecast-openmeteo-aqicn

def install_dependencies() -> None:
    """Install the project's dependencies into the system environment.

    Upgrades ``uv`` through pip, then uses it to install everything listed
    in ``pyproject.toml`` with all extras. Relies on IPython ``!`` shell
    escapes, so it only works inside a notebook kernel.
    """
    !pip install --upgrade uv
    !uv pip install --all-extras --system --requirement pyproject.toml


# Resolve the project root. When the notebook is started from inside one of
# the known sub-folders, step up to its parent (checked in this order).
root_dir = Path().absolute()
for folder in ("src", "airquality", "notebooks"):
    if root_dir.name == folder:
        root_dir = root_dir.parent
root_dir = str(root_dir)

# Make project-level packages such as `utils` importable.
if root_dir not in sys.path:
    sys.path.append(root_dir)

from utils import config

# Load the Hopsworks settings from the .env file in the project root.
settings = config.HopsworksSettings(_env_file=f"{root_dir}/.env")
# The key is held in a wrapper object (presumably a pydantic SecretStr --
# confirm in utils.config); unwrap it to the plain string for login.
HOPSWORKS_API_KEY = settings.HOPSWORKS_API_KEY.get_secret_value()
# Log in with the Python engine and keep a handle to the feature store.
project = hopsworks.login(engine="python", api_key_value=HOPSWORKS_API_KEY)
fs = project.get_feature_store()

## 4.2. Imports

In [None]:
import datetime
import pandas as pd
import numpy as np
from xgboost import XGBRegressor
import hopsworks
import json
from utils import airquality
from scipy.spatial.distance import cdist
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import os

warnings.filterwarnings("ignore")

## 4.3. Hopsworks Configuration
Establish connection to Hopsworks, retrieve API keys, connect to feature store, and get air quality and weather feature groups.

In [None]:
# Unwrap the API key when the settings object holds it in a secret wrapper.
# hasattr() is False for None, so a missing key is passed through unchanged.
HOPSWORKS_API_KEY = getattr(settings, "HOPSWORKS_API_KEY", None)
if hasattr(HOPSWORKS_API_KEY, "get_secret_value"):
    HOPSWORKS_API_KEY = HOPSWORKS_API_KEY.get_secret_value()

project = hopsworks.login(api_key_value=HOPSWORKS_API_KEY, engine="python")
fs = project.get_feature_store()

secrets = hopsworks.get_secrets_api()
AQICN_API_KEY = secrets.get_secret("AQICN_API_KEY").value

# Reference dates: today and the start of the four-day look-back window.
today = datetime.date.today()
past_date = today - datetime.timedelta(days=4)

# Feature groups with the air quality and weather history.
air_quality_fg = fs.get_feature_group(name="air_quality_all", version=1)
weather_fg = fs.get_feature_group(name="weather_all", version=1)

## 4.4. Sensor Location Loading 
Load sensor location metadata from Hopsworks secrets for all sensors.

In [None]:
# Sensor metadata lives in Hopsworks secrets named
# SENSOR_LOCATION_JSON_<sensor_id>; each non-empty value is parsed as JSON.
location_prefix = "SENSOR_LOCATION_JSON_"
all_secrets = secrets.get_secrets()
locations = {}
for secret in all_secrets:
    if not secret.name.startswith(location_prefix):
        continue
    # Slice the prefix off rather than str.replace(), which would also remove
    # any later occurrence of the same text inside the sensor id.
    sensor_id = secret.name[len(location_prefix):]
    location_str = secrets.get_secret(secret.name).value
    if location_str:
        locations[sensor_id] = json.loads(location_str)

## 4.5. Weather Data Loading
Fetch recent weather data from the feature store and normalize the date column to timezone-naive datetimes.

In [None]:
# Prefer a filtered read; if that fails, read the whole feature group and
# filter locally instead.
try:
    batch_weather = weather_fg.filter(weather_fg.date >= past_date).read()
except Exception as exc:
    print(f"Filtered weather read failed ({exc}); reading the full feature group instead.")
    batch_weather = weather_fg.read()

# Normalise to timezone-naive datetimes *before* filtering locally: recent
# pandas versions refuse to order a datetime64 column against a plain
# datetime.date, so the cut-off is converted to a Timestamp as well.
batch_weather["date"] = pd.to_datetime(batch_weather["date"]).dt.tz_localize(None)
batch_weather = batch_weather[batch_weather["date"] >= pd.Timestamp(past_date)].copy()

## 4.6. Air Quality Data Loading
Fetch recent air quality with error handling for missing data.

In [None]:
try:
    batch_airquality = air_quality_fg.filter(air_quality_fg.date >= past_date).read()
    batch_airquality["date"] = pd.to_datetime(batch_airquality["date"]).dt.tz_localize(None)
except Exception as exc:
    # Best effort: the pipeline continues without recent measurements, but
    # the reason is reported instead of being silently discarded.
    print(f"Could not load recent air quality data ({exc}); continuing without it.")
    batch_airquality = pd.DataFrame()
print(f"Retrieved {len(batch_airquality)} air quality records from Hopsworks Feature Store.")

## 4.7. Model Retrieval
Download trained XGBoost models from Hopsworks model registry for each sensor and extract feature names.

In [None]:
# Imported once here rather than on every loop iteration.
import xgboost as xgb

mr = project.get_model_registry()

MODEL_NAME_TEMPLATE = "air_quality_xgboost_model_{sensor_id}"

# sensor_id -> (registry model entry, XGBRegressor wrapper, booster feature names)
retrieved_models = {}

for sensor_id in locations:
    model_name = MODEL_NAME_TEMPLATE.format(sensor_id=sensor_id)

    available_models = mr.get_models(name=model_name)
    if not available_models:
        print(f"No model found for sensor {sensor_id}, skipping...")
        continue

    # Use the highest registered version of this sensor's model.
    retrieved_model = max(available_models, key=lambda model: model.version)
    saved_model_dir = retrieved_model.download()

    # Load the raw booster from the downloaded JSON and attach it to a
    # scikit-learn style wrapper so that .predict() accepts a DataFrame.
    booster = xgb.Booster()
    booster.load_model(os.path.join(saved_model_dir, "model.json"))
    xgb_model = XGBRegressor()
    xgb_model._Booster = booster

    retrieved_models[sensor_id] = retrieved_model, xgb_model, booster.feature_names

In [None]:
# Report how many sensors ended up with a usable model.
print(f"Retrieved {len(retrieved_models)} models.")

## 4.8. Batch Prediction Loop
Merge weather and air quality data, iteratively predict PM2.5 values for each forecast day, recompute the engineered features after each day's predictions, and store the results.

In [None]:
# Engineered PM2.5 features. The prediction cells below recompute them for
# every forecast day and store each one in a matching "predicted_<name>" column.
feature_cols = [
    "pm25_rolling_3d",
    "pm25_lag_1d",
    "pm25_lag_2d",
    "pm25_lag_3d",
    "pm25_nearby_avg",
]

In [None]:
# Merge historical air quality data onto the weather rows.
merge_keys = ["date", "sensor_id"]
if set(merge_keys).issubset(batch_airquality.columns):
    batch_data = pd.merge(batch_weather, batch_airquality, on=merge_keys, how="left")
else:
    # The air quality load falls back to a column-less DataFrame when it
    # fails; merging that would raise KeyError, so continue with weather only.
    batch_data = batch_weather.copy()
if "pm25" not in batch_data.columns:
    # Later cells select rows to forecast by looking for missing pm25 values.
    batch_data["pm25"] = np.nan
batch_data = batch_data.sort_values(["sensor_id", "date"])

In [None]:
# Placeholder output columns; the prediction loop fills them in.
placeholder_columns = ["predicted_pm25", "days_before_forecast_day"]
placeholder_columns += [f"predicted_{name}" for name in feature_cols]
for placeholder in placeholder_columns:
    batch_data[placeholder] = np.nan

# Days to forecast: dates from today onwards that have no PM2.5 value yet,
# de-duplicated and in chronological order.
needs_forecast = batch_data["pm25"].isna() & (batch_data["date"] >= today.strftime("%Y-%m-%d"))
candidate_dates = batch_data.loc[needs_forecast, "date"]
forecast_days = candidate_dates.dropna().sort_values().unique()

In [None]:
# Sensors present in the data but without a downloaded model; reported once.
sensors_without_model = set()

for target_day in forecast_days:
    # Depending on the pandas version, iterating the result of .unique()
    # yields numpy datetime64 values, which have no .date(); normalise first.
    target_day = pd.Timestamp(target_day)

    # Rows for this day that still lack a PM2.5 value.
    on_target_day = batch_data["date"] == target_day
    day_rows = batch_data[on_target_day & batch_data["pm25"].isna()]

    for _, row in day_rows.iterrows():
        sensor_id = row["sensor_id"]

        # Model retrieval skips sensors that have no registered model, so a
        # sensor can appear in the data without an entry in retrieved_models.
        if sensor_id not in retrieved_models:
            if sensor_id not in sensors_without_model:
                print(f"No model available for sensor {sensor_id}, skipping predictions...")
                sensors_without_model.add(sensor_id)
            continue

        _, xgb_model, model_features = retrieved_models[sensor_id]
        # One-row frame holding exactly the model's features, coerced to numeric.
        # NOTE(review): the inputs are read from the row as merged; the features
        # recomputed below are only written to the predicted_* columns --
        # confirm that this is intended.
        features = row.reindex(model_features).to_frame().T.apply(pd.to_numeric, errors="coerce")
        y_hat = xgb_model.predict(features)[0]

        row_mask = (batch_data["sensor_id"] == sensor_id) & (batch_data["date"] == target_day)
        idx = batch_data.index[row_mask][0]
        # Write the prediction into pm25 as well so that it feeds the feature
        # recomputation for this and later days.
        batch_data.at[idx, "pm25"] = y_hat
        batch_data.at[idx, "predicted_pm25"] = y_hat
        batch_data.at[idx, "days_before_forecast_day"] = (target_day.date() - today).days + 1

    # Recompute the engineered features on everything up to and including
    # the target day, now that this day's predictions are in place.
    temp_df = batch_data.loc[batch_data["date"] <= target_day].copy()
    temp_df = airquality.add_rolling_window_feature(
        temp_df, window_days=3, column="pm25", new_column="pm25_rolling_3d"
    )
    temp_df = airquality.add_lagged_features(temp_df, column="pm25", lags=[1, 2, 3])
    temp_df = airquality.add_nearby_sensor_feature(
        temp_df,
        locations,
        column="pm25",
        n_closest=3,
        new_column="pm25_nearby_avg",
    )

    # Copy the recomputed values for the target day into the predicted_* columns.
    current_rows = temp_df[temp_df["date"] == target_day]
    for _, row in current_rows.iterrows():
        sensor_id = row["sensor_id"]
        mask = (batch_data["sensor_id"] == sensor_id) & (batch_data["date"] == target_day)
        if mask.any():
            for col in feature_cols:
                batch_data.loc[mask, f"predicted_{col}"] = row[col]

# Keep only the rows that received a prediction, restricted to the
# identifying columns plus every predicted_* column.
prediction_columns = ["date", "sensor_id", "predicted_pm25", "days_before_forecast_day"]
prediction_columns.extend(f"predicted_{col}" for col in feature_cols)
has_prediction = batch_data["predicted_pm25"].notna()
predictions = batch_data.loc[has_prediction, prediction_columns].reset_index(drop=True)

# Clear pm25 again on rows dated after today.
is_future = batch_data["date"].dt.date > today
batch_data.loc[is_future, "pm25"] = np.nan

## 4.9. Save Predictions
Export prediction results to CSV file in models directory.

In [None]:
# Create the target folder if needed so that the export does not fail on a
# fresh checkout; to_csv writes all columns by default.
predictions_csv = f"{root_dir}/models/predictions.csv"
Path(predictions_csv).parent.mkdir(parents=True, exist_ok=True)
batch_data.to_csv(predictions_csv, index=False)

## 4.10. Generate Forecast Plots
Create forecast visualization plots for each sensor and upload them to Hopsworks dataset storage.

In [None]:
# (sensor_id, local image path) for every forecast plot that was written.
forecast_paths = []

for sensor_id, location in locations.items():
    sensor_forecast = predictions[predictions["sensor_id"] == sensor_id].copy()
    if sensor_forecast.empty:
        # Nothing to draw for sensors that received no predictions.
        print(f"No predictions for sensor {sensor_id}, skipping forecast plot...")
        continue

    city, street = location["city"], location["street"]
    forecast_path = f"{root_dir}/models/{sensor_id}/images/forecast.png"
    Path(forecast_path).parent.mkdir(parents=True, exist_ok=True)

    # Keep the helper's return value under its own name: rebinding `plt`
    # would shadow the matplotlib.pyplot import that later cells rely on.
    forecast_plot = airquality.plot_air_quality_forecast(
        city,
        street,
        sensor_forecast,
        forecast_path,
        hindcast=False,
    )
    forecast_plot.close()
    forecast_paths.append((sensor_id, forecast_path))

dataset_api = project.get_dataset_api()
today_short = today.strftime("%Y-%m-%d")
if not dataset_api.exists("Resources/airquality"):
    dataset_api.mkdir("Resources/airquality")

# Upload each plot under a name that carries the sensor id and today's date.
for sensor_id, forecast_path in forecast_paths:
    dataset_api.upload(
        forecast_path,
        f"Resources/airquality/{sensor_id}_{today_short}_forecast.png",
        overwrite=True,
    )
print(f"Forecast plots available in Hopsworks under {project.get_url()}/settings/fb/path/Resources/airquality")

## 4.11. Insert Monitoring Data
Save predictions to monitoring feature group in Hopsworks for tracking.

In [None]:
# Fetch the monitoring feature group, creating it if it does not exist yet.
# Rows are keyed by sensor, date and forecast horizon; "date" is the event time.
monitor_fg = fs.get_or_create_feature_group(
    name="aq_predictions",
    description="Air Quality prediction monitoring",
    version=1,
    primary_key=["sensor_id", "date", "days_before_forecast_day"],
    event_time="date",
)
# Store this run's predictions. NOTE(review): wait=True presumably blocks
# until the insert job has finished -- confirm against the Hopsworks docs.
monitor_fg.insert(predictions, wait=True)


## 4.12. Hindcast Analysis
Compare the PM2.5 values that were forecast one day ahead with the values that were actually measured.

In [None]:
# Stored forecasts with days_before_forecast_day == 1, with dates normalised
# to timezone-naive datetimes.
monitoring_df = monitor_fg.filter(monitor_fg.days_before_forecast_day == 1).read()
monitoring_df["date"] = pd.to_datetime(monitoring_df["date"]).dt.tz_localize(None)

# Measured PM2.5 values to compare the forecasts against.
air_quality_df = air_quality_fg.read()[["date", "sensor_id", "pm25"]]
air_quality_df["date"] = pd.to_datetime(air_quality_df["date"]).dt.tz_localize(None)

for sensor_id, location in locations.items():
    # Best effort per sensor: a failure is reported and the loop moves on.
    try:
        sensor_preds = monitoring_df[monitoring_df["sensor_id"] == sensor_id][["date", "predicted_pm25"]]
        merged = sensor_preds.merge(
            air_quality_df[air_quality_df["sensor_id"] == sensor_id][["date", "pm25"]],
            on="date",
            how="inner",
        ).sort_values("date")

        city, street = location["city"], location["street"]
        hindcast_path = f"{root_dir}/models/{sensor_id}/images/hindcast_prediction.png"
        Path(hindcast_path).parent.mkdir(parents=True, exist_ok=True)

        # Keep the helper's return value under its own name: rebinding `plt`
        # would shadow the matplotlib.pyplot import that later cells rely on.
        hindcast_plot = airquality.plot_air_quality_forecast(
            city,
            street,
            # Without any overlapping measurements, plot the forecasts alone.
            merged if not merged.empty else sensor_preds.assign(pm25=np.nan),
            hindcast_path,
            hindcast=True,
        )
        hindcast_plot.close()

        dataset_api.upload(
            hindcast_path,
            f"Resources/airquality/{sensor_id}_{today:%Y-%m-%d}_hindcast.png",
            overwrite=True,
        )

    except Exception as e:
        print(f"⚠️  Error processing hindcast for sensor {sensor_id}: {e}")

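## 4.13. IDW Heatmap
IDW (Inverse Distance Weighting) estimates the value at each grid point as a weighted average of the sensor values, where the weights decrease with distance from the sensor.

### 4.13.1. IDW interpolation function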

In [None]:
def idw_interpolation(points, values, grid_points, lon_mesh, power=2):
    """Interpolate sampled values onto a grid by inverse distance weighting.

    Args:
        points: Coordinates of the samples, one row per sample.
        values: Value observed at each sample, in the same order as ``points``.
        grid_points: Coordinates to estimate, one row per grid node.
        lon_mesh: Mesh array; only its shape is used, to shape the result.
        power: Exponent applied to the distance when computing the weights.

    Returns:
        Array with the shape of ``lon_mesh`` holding, for every grid node,
        the weighted mean of ``values`` with weights ``1 / distance**power``.
    """
    dist = cdist(grid_points, points)
    # A node that coincides with a sample would divide by zero; use a tiny
    # distance instead so that sample dominates the weighted mean.
    dist[dist == 0] = 1e-10
    inverse = 1.0 / (dist ** power)
    estimate = (inverse * values).sum(axis=1) / inverse.sum(axis=1)
    return estimate.reshape(lon_mesh.shape)

In [None]:
def plot_pm25_idw_heatmap(
    predictions: pd.DataFrame,
    locations: dict,
    forecast_date: datetime.datetime,
    path: str,
    grid_bounds=(-7.602536,50.862218,36.738284,69.923179),
    grid_resolution=800,
    power=2,
):
    """Render an IDW-interpolated PM2.5 heatmap for one date to an image file.

    Args:
        predictions: Frame with ``date``, ``sensor_id`` and ``predicted_pm25``
            columns (plus ``pm25`` when any prediction for the date is missing).
        locations: Mapping of sensor id to a dict with ``longitude`` and
            ``latitude`` entries.
        forecast_date: Date whose rows are plotted.
        path: Destination of the image file.
        grid_bounds: ``(min_lon, min_lat, max_lon, max_lat)`` of the grid.
        grid_resolution: Number of grid nodes along each axis.
        power: Distance exponent passed on to ``idw_interpolation``.

    Raises:
        ValueError: If no sensor with a known location has a row for
            ``forecast_date``.
    """
    pm25_cap = 150  # upper bound for both the sensor values and the colour scale

    df_day = predictions[predictions["date"] == forecast_date].copy()

    # Resolve the usable sensors once so coordinates and values stay aligned.
    sensor_ids = [sid for sid in df_day["sensor_id"].unique() if sid in locations]
    if not sensor_ids:
        raise ValueError(f"No sensors with known locations have data for {forecast_date}")

    sensor_coords = np.array(
        [[locations[sid]["longitude"], locations[sid]["latitude"]] for sid in sensor_ids]
    )

    # Fall back to the pm25 column when any prediction for the day is missing.
    pm25_column = "predicted_pm25"
    if df_day["predicted_pm25"].isna().any():
        pm25_column = "pm25"

    # First row per sensor, in the same order as sensor_coords.
    pm25_values = np.array(
        [df_day.loc[df_day["sensor_id"] == sid, pm25_column].iloc[0] for sid in sensor_ids]
    )

    # Cap extreme values to prevent unrealistic interpolation
    pm25_values = np.clip(pm25_values, 0, pm25_cap)

    min_lon, min_lat, max_lon, max_lat = grid_bounds

    lon_grid = np.linspace(min_lon, max_lon, grid_resolution)
    lat_grid = np.linspace(min_lat, max_lat, grid_resolution)
    lon_mesh, lat_mesh = np.meshgrid(lon_grid, lat_grid)
    grid_points = np.column_stack([lon_mesh.ravel(), lat_mesh.ravel()])

    idw_result = idw_interpolation(sensor_coords, pm25_values, grid_points, lon_mesh, power=power)

    category_colors = [
        "#00e400", "#7de400", "#ffff00", "#ffb000",
        "#ff7e00", "#ff4000", "#ff0000", "#d00050",
        "#8f3f97", "#7e0023",
    ]
    vmin, vmax = 0, pm25_cap

    clipped = np.clip(idw_result, vmin, vmax)
    fig, ax = plt.subplots(figsize=(10, 10))
    try:
        ax.imshow(
            clipped,
            extent=(min_lon, max_lon, min_lat, max_lat),
            origin="lower",
            cmap=mcolors.LinearSegmentedColormap.from_list("aqi", category_colors, N=512),
            vmin=vmin,
            vmax=vmax,
            alpha=0.5,
        )
        ax.set_xlim(min_lon, max_lon)
        ax.set_ylim(min_lat, max_lat)
        ax.axis("off")

        fig.savefig(path, dpi=300, bbox_inches="tight", pad_inches=0, transparent=True)
    finally:
        # Release the figure even when drawing or saving fails.
        plt.close(fig)

In [None]:
interpolation_dir = f"{root_dir}/models/interpolation"
# makedirs also creates a missing "models" parent and tolerates an existing
# directory, which os.mkdir does not.
os.makedirs(interpolation_dir, exist_ok=True)

today_short = today.strftime("%Y-%m-%d")
interpolation_df = predictions.copy()
# One heatmap per forecast date, in chronological order; `i` is the position
# of the date in that order and names the local file.
for i, forecast_date in enumerate(sorted(interpolation_df["date"].unique())):
    # Depending on the pandas version, .unique() yields numpy datetime64
    # values, which have no .strftime(); normalise to Timestamp first.
    forecast_date = pd.Timestamp(forecast_date)
    forecast_date_short = forecast_date.strftime("%Y-%m-%d")
    output_png = f"{interpolation_dir}/forecast_interpolation_{i}d.png"

    plot_pm25_idw_heatmap(
        interpolation_df,
        locations,
        forecast_date,
        output_png,
    )
    dataset_api.upload(
        output_png,
        f"Resources/airquality/interpolation_{today_short}_{forecast_date_short}.png",
        overwrite=True,
    )