# 4. Batch Inference Pipeline

## 4.1. Setup

### 4.1.1. Import Libraries

In [None]:
# Standard imports
import os
from pathlib import Path
import sys
import json
import time
from datetime import date, datetime, timedelta
from dotenv import load_dotenv
import warnings

# warnings.filterwarnings("ignore", module="IPython")

#  Establish project root directory
def find_project_root(start: Path):
    """Return the nearest directory at or above ``start`` holding a ``pyproject.toml``.

    The search walks from ``start`` upward through its parents; when no directory
    contains the marker file, ``start`` itself is returned.
    """
    matches = (
        candidate
        for candidate in (start, *start.parents)
        if (candidate / "pyproject.toml").exists()
    )
    return next(matches, start)

# Resolve the project root: the nearest ancestor of the current working directory that contains
# pyproject.toml, or the working directory itself when none does.
root_dir = find_project_root(Path().absolute())
print("Project root dir:", root_dir)

# Put the project root on sys.path (once) — presumably so the project-local `utils` package
# imported below resolves regardless of where the notebook is started.
if str(root_dir) not in sys.path:
    sys.path.append(str(root_dir))

# Third-party imports
import requests
import pandas as pd
import numpy as np
import great_expectations as gx
import hopsworks
from urllib3.exceptions import ProtocolError
from requests.exceptions import ConnectionError, Timeout, RequestException
from confluent_kafka import KafkaException
from hsfs.client.exceptions import RestAPIError
from collections import defaultdict
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from xgboost import XGBRegressor
from xgboost import plot_importance
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score
import joblib
from scipy.spatial.distance import cdist

#  Project imports
from utils import cleaning, config, feature_engineering, fetchers, hopsworks_admin, incremental, metadata, visualization

# Local calendar date of this run; every inference window and file name below is derived from it.
today = date.today()

### 4.1.2. Initialize Hopsworks Connection

In [None]:
def detect_environment():
    """Classify where this notebook is running.

    Returns:
        ``"job"`` when any Hopsworks job environment variable is set,
        ``"jupyter"`` when the working directory lies under ``/hopsfs/Jupyter``,
        otherwise ``"local"``.
    """
    job_markers = ("HOPSWORKS_JOB_ID", "HOPSWORKS_PROJECT_ID", "HOPSWORKS_JOB_NAME")
    if any(marker in os.environ for marker in job_markers):
        return "job"
    return "jupyter" if os.getcwd().startswith("/hopsfs/Jupyter") else "local"

env = detect_environment()
print(f"Detected environment: {env}")

# Load secrets based on environment
# On Hopsworks (job or Jupyter) the secrets are read from the Hopsworks secrets store and exported as
# environment variables; locally they come from a .env file loaded by load_dotenv().
if env in ("job", "jupyter"):
    project = hopsworks.login()
    secrets_api = hopsworks.get_secrets_api()

    for key in ["HOPSWORKS_API_KEY", "AQICN_API_KEY", "GH_PAT", "GH_USERNAME"]:
        os.environ[key] = secrets_api.get_secret(key).value

else:
    load_dotenv()

# Load Pydantic settings
# NOTE(review): presumably HopsworksSettings reads the environment variables populated above — confirm in utils/config.
settings = config.HopsworksSettings()

HOPSWORKS_API_KEY = settings.HOPSWORKS_API_KEY.get_secret_value()
AQICN_API_KEY = settings.AQICN_API_KEY.get_secret_value()
GITHUB_USERNAME = settings.GH_USERNAME.get_secret_value()

# Login to Hopsworks using the API key
# In job/Jupyter environments this rebinds the `project` handle obtained by the key-less login above.
project = hopsworks.login(api_key_value=HOPSWORKS_API_KEY)
fs = project.get_feature_store()

print("Environment initialized and Hopsworks connected!")


### 4.1.3. Repository Management

In [None]:
# Clone or update the project repository for GITHUB_USERNAME (per the helper's name) and make it the
# working directory. Note: later file paths in this notebook are built from root_dir, which is not
# recomputed after this chdir.
repo_dir = hopsworks_admin.clone_or_update_repo(GITHUB_USERNAME)
os.chdir(repo_dir)

### 4.1.4. Configure API Keys and Secrets

In [None]:
secrets = hopsworks.get_secrets_api()

# Store the AQICN API key as a Hopsworks secret unless it already exists.
# A failed lookup is treated as "secret missing". `except Exception` (instead of a bare `except:`)
# keeps KeyboardInterrupt/SystemExit propagating. The explicit None check covers client versions
# that presumably return None rather than raising for an unknown secret — confirm against the
# installed hopsworks client.
try:
    existing_aqicn_secret = secrets.get_secret("AQICN_API_KEY")
except Exception:
    existing_aqicn_secret = None

if existing_aqicn_secret is None:
    secrets.create_secret("AQICN_API_KEY", settings.AQICN_API_KEY.get_secret_value())

### 4.1.5. Get Model Registry

In [None]:
# Handle to the Hopsworks model registry; section 4.4 downloads the per-sensor models from it.
mr = project.get_model_registry()

## 4.2. Get Feature Groups and Sensor Locations

In [None]:
# Obtain the air-quality and weather feature groups (via the project helper) and read the full
# air-quality data once.
air_quality_fg, weather_fg = hopsworks_admin.create_feature_groups(fs)

aq_data = air_quality_fg.read()

# Without any historical measurements there is nothing to build features from: stop with exit code 1.
if len(aq_data) == 0:
    print("‚ö†Ô∏è No air quality data found. Run pipeline 1 (backfill) first.")
    sys.exit(1)

# sensor_id -> location metadata; later cells read its "city", "street", "latitude" and "longitude" keys.
sensor_locations = metadata.get_sensor_locations_dict(air_quality_fg)
print(f"üìç Loaded locations for {len(sensor_locations)} sensors")

## 4.3. Load Data from Feature Store

### 4.3.1. Set Inference Dates

In [None]:
# Inference window: one week back (history needed for lag/rolling features) and
# one week ahead (horizon covered by the weather forecasts).
past_date, future_date = today - timedelta(days=7), today + timedelta(days=7)
# ISO "YYYY-MM-DD" form of today, used in comparisons and upload file names.
today_short = today.isoformat()

print(f"Inference period: {past_date} to {future_date}")
print(f"Today: {today_short}")

### 4.3.2. Load Weather Features

In [None]:
# Weather rows for [past_date, future_date]: past days feed the features, future days carry the forecasts.
try:
    batch_weather = weather_fg.filter(
        (weather_fg.date >= past_date) & (weather_fg.date <= future_date)
    ).read()
except Exception:
    # Server-side filtering failed: read everything and filter locally. The date column is normalised
    # to naive datetime64 first, because pandas 2.x rejects ordering comparisons between a datetime64
    # (or tz-aware) column and a plain datetime.date.
    batch_weather = weather_fg.read()
    batch_weather["date"] = pd.to_datetime(batch_weather["date"]).dt.tz_localize(None)
    batch_weather = batch_weather[
        (batch_weather["date"] >= pd.Timestamp(past_date))
        & (batch_weather["date"] <= pd.Timestamp(future_date))
    ].copy()

# Timezone-naive datetimes so the merge with the air-quality frame aligns
# (a no-op if the fallback above already converted the column).
batch_weather["date"] = pd.to_datetime(batch_weather["date"]).dt.tz_localize(None)

print(f"Retrieved {len(batch_weather)} weather records from {past_date} to {future_date}")

### 4.3.3. Load Air Quality Features

In [None]:
# Air-quality rows from past_date onward (measurements plus stored engineered features).
try:
    batch_airquality = air_quality_fg.filter(
        air_quality_fg.date >= past_date
    ).read()
except Exception:
    # Server-side filtering failed: filter locally. The full feature group was already read into
    # aq_data, so it is reused instead of being read a second time. The date column is normalised
    # first, because pandas 2.x rejects ordering comparisons between datetime64/tz-aware values and
    # a plain datetime.date.
    batch_airquality = aq_data.copy()
    batch_airquality["date"] = pd.to_datetime(batch_airquality["date"]).dt.tz_localize(None)
    batch_airquality = batch_airquality[
        batch_airquality["date"] >= pd.Timestamp(past_date)
    ].copy()

# Timezone-naive datetimes so the merge with the weather frame aligns
# (a no-op if the fallback above already converted the column).
batch_airquality["date"] = pd.to_datetime(batch_airquality["date"]).dt.tz_localize(None)

print(f"Retrieved {len(batch_airquality)} air quality records from {past_date} to {today}")

## 4.4. Model Retrieval
Download the latest trained XGBoost model for each sensor from the Hopsworks model registry and extract its input feature names.

In [None]:
MODEL_NAME_TEMPLATE = "air_quality_xgboost_model_{sensor_id}"

# sensor_id -> (registry model entry, loaded XGBRegressor, booster feature names)
retrieved_models = {}

for sensor_id in sensor_locations.keys():
    model_name = MODEL_NAME_TEMPLATE.format(sensor_id=sensor_id)

    try:
        available_models = mr.get_models(name=model_name)
        if not available_models:
            print(f"‚ö†Ô∏è No model found for sensor {sensor_id}, skipping...")
            continue

        # Use the highest registered version of this sensor's model.
        retrieved_model = max(available_models, key=lambda model: model.version)
        saved_model_dir = retrieved_model.download()

        # The downloaded artifact directory is expected to contain the model as model.json.
        xgb_model = XGBRegressor()
        xgb_model.load_model(saved_model_dir + "/model.json")
        booster = xgb_model.get_booster()

        # The booster's feature names are used later to select and order each row's model inputs.
        retrieved_models[sensor_id] = retrieved_model, xgb_model, booster.feature_names

    except Exception as e:
        # Best effort: a failure for one sensor is reported and only that sensor is skipped.
        print(f"‚ùå Error loading model for sensor {sensor_id}: {e}")
        continue

print(f"‚úÖ Retrieved {len(retrieved_models)} models from registry")
print(f"   Total sensors in feature store: {len(sensor_locations)}")

In [None]:
# Short summary of how many per-sensor models were loaded.
print(f"Retrieved {len(retrieved_models)} models.")

## 4.5. Batch Prediction
Merge the weather and air quality data, iteratively predict PM2.5 values for the forecast days, update the engineered features after each prediction, and store the results.

### 4.5.1. Batch Prediction Loop

In [None]:
# Bounds for a plausible PM2.5 prediction, in the same unit as the pm25 column.
PREDICTION_CAP_MAX = 150.0  # Maximum reasonable PM2.5 value
PREDICTION_CAP_MIN = 0.0    # Minimum reasonable PM2.5 value

In [None]:
# Merge historical data with weather data
batch_data = pd.merge(batch_weather, batch_airquality, on=["date", "sensor_id"], how="left")
batch_data = batch_data.sort_values(["sensor_id", "date"])

In [None]:
# Engineered PM2.5 features that the prediction loop recomputes from the (partly predicted) pm25
# history; each is also recorded as a predicted_<name> column next to the prediction.
feature_cols = [
    "pm25_rolling_3d",
    "pm25_lag_1d",
    "pm25_lag_2d",
    "pm25_lag_3d",
    "pm25_nearby_avg",
]

In [None]:
# Create a new columns, fill with NaN for now
# Create the output columns, initialised to NaN: the prediction itself, its forecast horizon,
# and one predicted_<feature> snapshot per engineered feature.
batch_data["predicted_pm25"] = np.nan
batch_data["days_before_forecast_day"] = np.nan
for col in feature_cols:
    batch_data[f"predicted_{col}"] = np.nan

# Forecast days: sorted unique dates on or after today on which at least one sensor has no pm25 value.
# (The date column is compared with the "YYYY-MM-DD" string of today, which pandas parses as a date.)
forecast_days = (
    batch_data.loc[batch_data["pm25"].isna() & 
                   (batch_data["date"] >= today.strftime("%Y-%m-%d")), "date"]
    .dropna()
    .sort_values()
    .unique()
)

In [None]:
def _recompute_pm25_features(history: pd.DataFrame) -> pd.DataFrame:
    """Return a copy of ``history`` with the engineered PM2.5 features recomputed.

    Applies the same feature_engineering helpers this loop always used (3-day rolling window,
    1/2/3-day lags, nearby-sensor average of the 1-day lag) to the pm25 column, which holds
    measurements plus any predictions already written back.
    """
    engineered = feature_engineering.add_rolling_window_feature(
        history.copy(), window_days=3, column="pm25", new_column="pm25_rolling_3d"
    )
    engineered = feature_engineering.add_lagged_features(engineered, column="pm25", lags=[1, 2, 3])
    return feature_engineering.add_nearby_sensor_feature(
        engineered,
        sensor_locations,
        column="pm25_lag_1d",
        n_closest=3,
        new_column="pm25_nearby_avg",
    )


def _features_by_sensor(engineered: pd.DataFrame, day) -> pd.DataFrame:
    """Return the ``feature_cols`` of ``engineered`` on ``day``, indexed by sensor_id (last row wins)."""
    day_rows = engineered[engineered["date"] == day]
    return day_rows.drop_duplicates("sensor_id", keep="last").set_index("sensor_id")[feature_cols]


# Model-input columns must exist so they can be refreshed below.
for col in feature_cols:
    if col not in batch_data.columns:
        batch_data[col] = np.nan

for target_day in forecast_days:
    day_mask = batch_data["date"] == target_day
    pending_mask = day_mask & batch_data["pm25"].isna()

    # 1) Refresh the model inputs of the rows to predict from the history up to target_day
    #    (measurements + predictions of earlier days). Previously the features recomputed after each
    #    prediction were only written to predicted_* columns, so future-day lag/rolling/nearby inputs
    #    stayed NaN from the left merge. Only NaN inputs are filled; stored values are kept.
    pre_features = _features_by_sensor(
        _recompute_pm25_features(batch_data.loc[batch_data["date"] <= target_day]), target_day
    )
    for col in feature_cols:
        refreshed = batch_data.loc[pending_mask, "sensor_id"].map(pre_features[col])
        batch_data.loc[pending_mask, col] = batch_data.loc[pending_mask, col].fillna(refreshed)

    # 2) Predict every sensor still missing pm25 on target_day and write the value back, so it
    #    feeds the features of the following days.
    for idx, row in batch_data.loc[pending_mask].iterrows():
        sensor_id = row["sensor_id"]
        try:
            _, xgb_model, model_features = retrieved_models[sensor_id]
        except KeyError:
            print(f"No model for sensor {sensor_id}, skipping prediction for {target_day}.")
            continue
        # Select/order the inputs the booster was trained with; missing ones become NaN.
        features = row.reindex(model_features).to_frame().T.apply(pd.to_numeric, errors="coerce")
        # Clamp to the plausible range so an outlier cannot compound through the recursive features.
        y_hat = float(np.clip(xgb_model.predict(features)[0], PREDICTION_CAP_MIN, PREDICTION_CAP_MAX))

        batch_data.at[idx, "pm25"] = y_hat
        batch_data.at[idx, "predicted_pm25"] = y_hat
        batch_data.at[idx, "days_before_forecast_day"] = (target_day - pd.Timestamp(today)).days + 1

    # 3) Recompute the features now that target_day is filled and record them for monitoring.
    post_features = _features_by_sensor(
        _recompute_pm25_features(batch_data.loc[batch_data["date"] <= target_day]), target_day
    )
    for col in feature_cols:
        batch_data.loc[day_mask, f"predicted_{col}"] = batch_data.loc[day_mask, "sensor_id"].map(
            post_features[col]
        )

predictions = batch_data.loc[
    batch_data["predicted_pm25"].notna(),
    ["date", "sensor_id", "predicted_pm25", "days_before_forecast_day"]
    + [f"predicted_{col}" for col in feature_cols],
].reset_index(drop=True)
# Undo the recursive fill for future days so pm25 again holds measurements only after today.
batch_data.loc[batch_data["date"] > pd.Timestamp(today), "pm25"] = np.nan

### 4.5.2. Assemble Prediction Results

In [None]:
# Independent copy for reporting; the insert cell below casts sensor_id on `predictions` itself.
predictions_df = predictions.copy()

print(f"‚úÖ Generated {len(predictions_df)} prediction rows")
print(f"   Date range: {predictions_df['date'].min()} to {predictions_df['date'].max()}")
print(f"   Sensors: {predictions_df['sensor_id'].nunique()}")
# Horizons present: 1 = the forecast for today, 2 = tomorrow, and so on.
print(f"   Forecast days: {sorted(predictions_df['days_before_forecast_day'].unique())}")

## 4.6. Save Predictions

### 4.6.1. Save Predictions to Feature Store / Model Registry / Dataset

In [None]:


# print("üîç Diagnostic Info:")
# print(f"\nPredictions DataFrame columns: {predictions.columns.tolist()}")
# print(f"\nPredictions DataFrame shape: {predictions.shape}")
# print(f"\nSample data:")
# print(predictions.head())

# # Check if the feature group already exists and what schema it has
# try:
#     existing_fg = fs.get_feature_group("air_quality_predictions", version=1)
#     print(f"\nüìã Existing feature group schema:")
#     for feat in existing_fg.features:
#         print(f"  - {feat.name} ({feat.type})")
# except:
#     print("\n‚úÖ Feature group doesn't exist yet - will be created fresh")

In [None]:
## TEMPORARY fix for type mismatch issue
# NOTE(review): sensor_id is cast to int64 on Hopsworks (job/Jupyter) but to int32 locally; presumably
# the two feature stores declare different schemas for the same feature group — align and drop the branch.

# Create or get predictions feature group (same as in training pipeline)
predictions_fg = fs.get_or_create_feature_group(
    name="aq_predictions",
    version=1,
    primary_key=["sensor_id", "date", "days_before_forecast_day"],
    description="Air Quality prediction monitoring",
    event_time="date"
)

# Insert predictions
print(f"üìä Inserting {len(predictions)} prediction rows to {predictions_fg.name}")
print(f"   Primary keys: {predictions_fg.primary_key}")
print(f"   Columns: {list(predictions.columns)}")
print(f"   Date range: {predictions['date'].min()} to {predictions['date'].max()}")

if len(predictions) > 0:
    # Cast sensor_id to the integer width the target feature-group schema expects in this environment.
    if env in ("job", "jupyter"):
        predictions["sensor_id"] = predictions["sensor_id"].astype("int64")
    else:
        predictions["sensor_id"] = predictions["sensor_id"].astype("int32")

    max_retries = 5
    delay = 2  # base back-off in seconds; doubled after every failed attempt

    for attempt in range(1, max_retries + 1):
        try:
            print(f"üü¶ Insert attempt {attempt}/{max_retries}...")
            predictions_fg.insert(predictions, write_options={"wait_for_job": False})
            # The success message is printed only here, after an insert call has actually returned.
            print("‚úÖ Insert successful!")
            print(f"   Total predictions: {len(predictions)}")
            break

        except Exception as e:
            print(f"‚ö†Ô∏è Insert failed on attempt {attempt}: {e}")

            if attempt == max_retries:
                print("‚ùå Max retries reached. Insert failed permanently.")
                raise

            sleep_time = delay * (2 ** (attempt - 1))
            print(f"‚è≥ Retrying in {sleep_time} seconds...")
            time.sleep(sleep_time)
else:
    print("‚ö†Ô∏è No predictions to insert")


In [None]:
# # Create or get predictions feature group (same as in training pipeline)
# predictions_fg = fs.get_or_create_feature_group(
#     name="aq_predictions",
#     version=1,
#     primary_key=["sensor_id", "date", "days_before_forecast_day"],
#     description="Air Quality prediction monitoring",
#     event_time="date"
# )

# # Insert predictions
# print(f"üìä Inserting {len(predictions)} prediction rows to {predictions_fg.name}")
# print(f"   Primary keys: {predictions_fg.primary_key}")
# print(f"   Columns: {list(predictions.columns)}")
# print(f"   Date range: {predictions['date'].min()} to {predictions['date'].max()}")

# if len(predictions) > 0:
#     # Ensure sensor_id is int64 to match feature group schema (bigint)
#     predictions["sensor_id"] = predictions["sensor_id"].astype("int64")
    

#     ## DIFFERENCE IN TYPE BETWEEN FEATURE STORES???


#     predictions_fg.insert(predictions, write_options={"wait_for_job": False})
#     print("‚úÖ Insert successful!")
#     print(f"   Total predictions: {len(predictions)}")
# else:
#     print("‚ö†Ô∏è No predictions to insert")

## 4.7. Configuration

In [None]:
# Configure visualization for production vs development
# Configure visualization for production vs development
# Per-sensor forecast/hindcast plots are produced only outside Hopsworks jobs; the heatmap images
# needed by the UI are generated in every environment.
SKIP_SENSOR_PLOTS = env == "job"  # Skip individual sensor plots when running as Hopsworks job

if SKIP_SENSOR_PLOTS:
    print("‚è≠Ô∏è Skipping individual sensor plots (running as Hopsworks job)")
    print("   Heatmap interpolations will still be generated for UI")
else:
    print("üìä Full visualization enabled (running locally/Jupyter)")


## 4.8. Analysis & Visualization

### 4.8.1. Generate Forecast Plots

In [None]:
# Individual sensor plots are skipped in jobs
# Individual sensor plots are skipped in jobs
if SKIP_SENSOR_PLOTS:
    print("‚è≠Ô∏è Skipping forecast plot generation (200+ files)")
else:
    # Note: dataset_api is also used by the hindcast cell below (guarded by the same flag).
    dataset_api = project.get_dataset_api()
    forecast_paths = []

    # 1) Render one forecast PNG per sensor under root_dir/models/<sensor_id>/images/.
    for sensor_id, location in sensor_locations.items():
        # NOTE(review): sensors without a model have no prediction rows, so this may be empty.
        sensor_forecast = predictions[predictions["sensor_id"] == sensor_id].copy()

        city, street = location["city"], location["street"]
        forecast_path = f"{root_dir}/models/{sensor_id}/images/forecast.png"
        Path(forecast_path).parent.mkdir(parents=True, exist_ok=True)

        fig = visualization.plot_air_quality_forecast(
            location["city"],
            location["street"],
            sensor_forecast,
            forecast_path,
            hindcast=False,
        )
        plt.close(fig)
        forecast_paths.append((sensor_id, forecast_path))

    if not dataset_api.exists("Resources/airquality"):
        dataset_api.mkdir("Resources/airquality")

    # 2) Upload each PNG to Resources/airquality, retrying network errors with exponential back-off.
    # Upload with retry logic and error handling
    upload_success = 0
    upload_failed = 0
    
    for i, (sensor_id, forecast_path) in enumerate(forecast_paths):
        max_retries = 3
        retry_delay = 2  # seconds
        
        for attempt in range(max_retries):
            try:
                dataset_api.upload(
                    forecast_path,
                    f"Resources/airquality/{sensor_id}_{today_short}_forecast.png",
                    overwrite=True,
                )
                upload_success += 1
                if (i + 1) % 20 == 0:  # Progress update every 20 uploads
                    print(f"   Uploaded {i + 1}/{len(forecast_paths)} plots...")
                break  # Success, exit retry loop
                
            except (ConnectionError, ProtocolError, Timeout, RequestException) as e:
                # Network-level failures are retried; the last attempt counts the file as failed.
                if attempt < max_retries - 1:
                    print(f"‚ö†Ô∏è Upload failed for sensor {sensor_id} (attempt {attempt + 1}/{max_retries}), retrying in {retry_delay}s...")
                    time.sleep(retry_delay)
                    retry_delay *= 2  # Exponential backoff
                else:
                    print(f"‚ùå Failed to upload for sensor {sensor_id} after {max_retries} attempts: {e}")
                    upload_failed += 1
            except Exception as e:
                # Any other error is not retried.
                print(f"‚ùå Unexpected error uploading for sensor {sensor_id}: {e}")
                upload_failed += 1
                break
        
        # Small delay between uploads to avoid overwhelming the connection
        if i < len(forecast_paths) - 1:
            time.sleep(0.1)
    
    print(f"‚úÖ Upload complete: {upload_success} successful, {upload_failed} failed")
    if upload_success > 0:
        print(f"   Forecast plots available in Hopsworks under {project.get_url()}/settings/fb/path/Resources/airquality")

### 4.8.2. Hindcast Analysis

In [None]:
# Individual sensor plots are skipped in jobs
# Individual sensor plots are skipped in jobs
if SKIP_SENSOR_PLOTS:
    print("‚è≠Ô∏è Skipping hindcast plot generation (200+ files)")
else:
    # Hindcast: compare the stored 1-day-ahead predictions (days_before_forecast_day == 1) in
    # predictions_fg with the measured pm25 values and plot them per sensor.
    # Obtained here so the cell does not depend on the forecast cell having defined it.
    dataset_api = project.get_dataset_api()

    try:
        monitoring_df = predictions_fg.filter(predictions_fg.days_before_forecast_day == 1).read()
        monitoring_df["date"] = pd.to_datetime(monitoring_df["date"]).dt.tz_localize(None)
        print(f"‚úÖ Successfully read {len(monitoring_df)} hindcast predictions")
    except Exception as e:
        print(f"‚ö†Ô∏è Could not read monitoring data: {e}")
        print("Skipping hindcast analysis...")
        monitoring_df = pd.DataFrame()  # Empty dataframe to prevent further errors

    if not monitoring_df.empty:
        # Reuse the air-quality frame read at the start of the notebook (aq_data) instead of
        # reading the whole feature group again; .copy() so the date conversion writes to a real frame.
        air_quality_df = aq_data[["date", "sensor_id", "pm25"]].copy()
        air_quality_df["date"] = pd.to_datetime(air_quality_df["date"]).dt.tz_localize(None)

        for sensor_id, location in sensor_locations.items():
            try:
                sensor_preds = monitoring_df[monitoring_df["sensor_id"] == sensor_id][["date", "predicted_pm25"]]

                if sensor_preds.empty:
                    continue

                # Only dates with both a prediction and a measurement are compared.
                merged = sensor_preds.merge(
                    air_quality_df[air_quality_df["sensor_id"] == sensor_id][["date", "pm25"]],
                    on="date",
                    how="inner",
                ).sort_values("date")

                city, street = location["city"], location["street"]
                hindcast_path = f"{root_dir}/models/{sensor_id}/images/hindcast_prediction.png"
                Path(hindcast_path).parent.mkdir(parents=True, exist_ok=True)

                # Without any overlap, plot the predictions alone (with an all-NaN pm25 column).
                fig = visualization.plot_air_quality_forecast(
                    city,
                    street,
                    merged if not merged.empty else sensor_preds.assign(pm25=np.nan),
                    hindcast_path,
                    hindcast=True,
                )
                if fig is not None:
                    plt.close(fig)

                dataset_api.upload(
                    hindcast_path,
                    f"Resources/airquality/{sensor_id}_{today:%Y-%m-%d}_hindcast.png",
                    overwrite=True,
                )

            except Exception as e:
                # Best effort per sensor: report and continue with the next one.
                print(f"‚ö†Ô∏è Error processing hindcast for sensor {sensor_id}: {e}")


### 4.8.3. IDW Heatmap

#### 4.8.3.1. IDW Interpolation Function

In [None]:
def idw_interpolation(points, values, grid_points, lon_mesh, power=2):
    """Interpolate scattered sample values onto a grid with inverse distance weighting (IDW).

    Args:
        points: (P, 2) sample coordinates, in the same axis order as ``grid_points``
            (a single 1-D coordinate pair is accepted and treated as one sample).
        values: (P,) sample values. Non-finite entries (NaN/inf) are ignored, so one missing
            sensor value no longer turns the whole grid into NaN.
        grid_points: (G, 2) coordinates to interpolate at.
        lon_mesh: Array whose shape (with ``G`` elements in total) is the output shape.
        power: Distance exponent; larger values make each sample's influence more local.

    Returns:
        Array with ``lon_mesh.shape``. All NaN when no finite sample value is available.
    """
    points = np.atleast_2d(np.asarray(points, dtype=float))
    values = np.asarray(values, dtype=float)

    finite = np.isfinite(values)
    if not finite.any():
        return np.full(np.shape(lon_mesh), np.nan)
    points, values = points[finite], values[finite]

    # Distances between every grid point and every known sample: shape (G, P).
    distances = cdist(grid_points, points)
    # Grid points that coincide with a sample get a tiny distance instead of 0 (no division by zero);
    # the resulting huge weight makes them reproduce that sample's value.
    distances = np.where(distances == 0, 1e-10, distances)
    weights = 1.0 / (distances ** power)
    # Weighted average of the sample values for each grid point.
    interpolated = np.sum(weights * values, axis=1) / np.sum(weights, axis=1)
    return interpolated.reshape(lon_mesh.shape)

#### 4.8.3.2. Generate Heatmap Images 

In [None]:
print("üó∫Ô∏è Generating heatmap interpolation images (required for UI)")
grid_bounds = tuple(list(json.load(open(f"{root_dir}/frontend/coordinates.json")).values())[:4])
# grid_bounds = map_bounds[1], map_bounds[0], map_bounds[3], map_bounds[2]  # lat_min, lat_max, lon_min, lon_max
print(grid_bounds)


In [None]:
# Local output directory for the heatmap PNGs.
interpolation_dir = f"{root_dir}/models/interpolation"
os.makedirs(interpolation_dir, exist_ok=True)

# Working copy of the predictions that feeds the heatmaps (today's rows from batch_data may be prepended later).
interpolation_df = predictions.copy()


In [None]:
def plot_pm25_idw_heatmap(
    predictions: pd.DataFrame,
    sensor_locations: dict,
    forecast_date: datetime,
    path: str,
    grid_bounds=(11.4, 57.15, 12.5, 58.25),
    grid_resolution=800,
    power=2,
):
    """Render an IDW-interpolated PM2.5 heatmap for one date and save it as a transparent PNG.

    Args:
        predictions: Rows with ``date`` and ``sensor_id`` plus ``predicted_pm25`` and/or ``pm25``.
        sensor_locations: Mapping sensor_id -> dict with ``longitude`` and ``latitude``.
        forecast_date: Date whose rows are plotted (compared with ``predictions["date"]``).
        path: Output PNG path.
        grid_bounds: ``(min_lon, min_lat, max_lon, max_lat)`` of the rendered area.
        grid_resolution: Number of grid samples along each axis.
        power: IDW distance exponent.

    Returns:
        None. Nothing is written when no sensor on ``forecast_date`` has a usable PM2.5 value.
    """
    df_day = predictions[predictions["date"] == forecast_date].copy()

    # Column preference is unchanged: predictions when the day has any, otherwise measurements.
    # Per sensor, the first non-NaN value of the preferred column is used, falling back to the other
    # column. Previously a single NaN (e.g. a sensor with a measurement but no prediction on a day
    # where other sensors were predicted) was passed on and turned the whole IDW grid into NaN.
    has_predictions = "predicted_pm25" in df_day.columns and df_day["predicted_pm25"].notna().any()
    preferred = ("predicted_pm25", "pm25") if has_predictions else ("pm25", "predicted_pm25")
    value_columns = [col for col in preferred if col in df_day.columns]

    # Build sensor coordinates and PM2.5 values
    sensor_coords_list = []
    pm25_values_list = []

    for sid in df_day["sensor_id"].unique():
        if sid not in sensor_locations:
            continue
        sensor_rows = df_day[df_day["sensor_id"] == sid]
        pm25_val = np.nan
        for col in value_columns:
            candidates = sensor_rows[col].dropna()
            if not candidates.empty:
                pm25_val = candidates.iloc[0]
                break
        if pd.isna(pm25_val):
            continue  # no usable value for this sensor on this date
        sensor_coords_list.append([sensor_locations[sid]["longitude"], sensor_locations[sid]["latitude"]])
        pm25_values_list.append(pm25_val)

    # Convert to numpy arrays
    sensor_coords = np.array(sensor_coords_list)
    pm25_values = np.array(pm25_values_list, dtype=float)

    # Safety check: need at least 1 sensor with data
    if len(sensor_coords) == 0 or len(pm25_values) == 0:
        print(f"‚ö†Ô∏è No sensor data available for {forecast_date}, skipping heatmap generation")
        return

    # Ensure sensor_coords is 2D (required by cdist)
    if sensor_coords.ndim == 1:
        sensor_coords = sensor_coords.reshape(1, -1)

    min_lon, min_lat, max_lon, max_lat = grid_bounds

    # Regular lon/lat grid; points are (lon, lat) pairs, matching sensor_coords.
    lon_grid = np.linspace(min_lon, max_lon, grid_resolution)
    lat_grid = np.linspace(min_lat, max_lat, grid_resolution)
    lon_mesh, lat_mesh = np.meshgrid(lon_grid, lat_grid)
    grid_points = np.column_stack([lon_mesh.ravel(), lat_mesh.ravel()])

    idw_result = idw_interpolation(sensor_coords, pm25_values, grid_points, lon_mesh, power=power)

    # Colour scale spans 0-150 (values outside are clipped); only the first level is used as vmin.
    default_levels = np.array([0, 12, 35, 55, 150, 250, 500])
    category_colors = ["#00e400", "#7de400", "#ffff00", "#ffb000", "#ff7e00", "#ff4000", "#ff0000", "#c0007f", "#8f3f97", "#7e0023"]
    vmin, vmax = default_levels[0], 150

    clipped = np.clip(idw_result, vmin, vmax)
    fig, ax = plt.subplots(figsize=(10, 10))
    im = ax.imshow(
        clipped,
        extent=(min_lon, max_lon, min_lat, max_lat),
        origin="lower",
        cmap=mcolors.LinearSegmentedColormap.from_list("aqi", category_colors, N=512),
        vmin=vmin,
        vmax=vmax,
        alpha=0.5,
    )
    ax.set_xlim(min_lon, max_lon)
    ax.set_ylim(min_lat, max_lat)
    ax.axis("off")

    # Borderless, transparent image intended as a map overlay.
    fig.savefig(path, dpi=300, bbox_inches="tight", pad_inches=0, transparent=True)
    plt.close(fig)


In [None]:
# Add any actual PM2.5 data from today if available
today_actual = batch_data[batch_data["date"] == today_short].copy()

if not today_actual.empty:
    # Keep only the columns the heatmap needs (those that are present)
    today_actual = today_actual[[col for col in ["date", "sensor_id", "pm25", "predicted_pm25"] if col in today_actual.columns]]
    interpolation_df = pd.concat([today_actual, interpolation_df], ignore_index=True)

dataset_api = project.get_dataset_api()
generated_images = 0

for forecast_date in sorted(interpolation_df["date"].unique()):
    forecast_date_short = forecast_date.strftime("%Y-%m-%d")
    days_ahead = (forecast_date - pd.Timestamp(today)).days
    output_png = f"{interpolation_dir}/forecast_interpolation_{days_ahead}d.png"

    # Remove an image left over from an earlier run, so that a day the plot function skips
    # (no usable sensor values) is never uploaded with stale content.
    Path(output_png).unlink(missing_ok=True)

    plot_pm25_idw_heatmap(
        interpolation_df,
        sensor_locations,
        forecast_date,
        output_png,
    )

    # The plot function returns without writing a file when there is no data for the date.
    if not Path(output_png).exists():
        continue

    dataset_api.upload(
        output_png,
        f"Resources/airquality/interpolation_{today_short}_{forecast_date_short}.png",
        overwrite=True,
    )
    generated_images += 1

print(f"‚úÖ Generated {generated_images} heatmap interpolation images")


## 4.9. Pipeline Completion

In [None]:
# Final run summary: prediction statistics and where the outputs were written.
print("=" * 80)
print("‚úÖ BATCH INFERENCE PIPELINE COMPLETED SUCCESSFULLY")
print("=" * 80)
print(f"\nüìä Summary:")
print(f"   - Predictions generated: {len(predictions)}")
print(f"   - Sensors processed: {predictions['sensor_id'].nunique()}")
print(f"   - Forecast days: {sorted(predictions['days_before_forecast_day'].unique())}")
print(f"   - Date range: {predictions['date'].min()} to {predictions['date'].max()}")
print(f"\nüíæ Data saved to:")
print(f"   - Feature Group: {predictions_fg.name} (version {predictions_fg.version})")
# Per-sensor plots exist only when they were not skipped (see the configuration cell).
if not SKIP_SENSOR_PLOTS:
    print(f"   - Sensor plots uploaded to Hopsworks Resources/airquality/")
print(f"   - Heatmap images uploaded to Hopsworks Resources/airquality/")
print("\n" + "=" * 80)