# 4. Batch Inference Pipeline

## 4.1. Setup

### 4.1.1. Import Libraries

In [None]:
# Standard imports
import os
from pathlib import Path
import sys
import json
import time
from datetime import date, datetime, timedelta
from dotenv import load_dotenv

#  Establish project root directory
def find_project_root(start: Path):
    """Walk upward from *start* and return the first directory holding a pyproject.toml.

    *start* itself is checked first, then each of its ancestors from nearest to
    farthest. If none of them contains ``pyproject.toml``, *start* is returned.
    """
    search_path = (start, *start.parents)
    return next(
        (candidate for candidate in search_path if (candidate / "pyproject.toml").exists()),
        start,
    )

root_dir = find_project_root(Path.cwd())
print("Project root dir:", root_dir)

# Put the project root on sys.path (once) so the project-local `utils`
# package imported further down can be resolved.
if str(root_dir) not in sys.path:
    sys.path.append(str(root_dir))

# Third-party imports
import requests
import pandas as pd
import numpy as np
import great_expectations as gx
import hopsworks
from urllib3.exceptions import ProtocolError
from requests.exceptions import ConnectionError, Timeout, RequestException
from confluent_kafka import KafkaException
from hsfs.client.exceptions import RestAPIError
from collections import defaultdict
import matplotlib.pyplot as plt
import matplotlib.colors as mcolors
from xgboost import XGBRegressor
from xgboost import plot_importance
from sklearn.ensemble import RandomForestRegressor
from sklearn.metrics import mean_squared_error, r2_score
import joblib
from scipy.spatial.distance import cdist
import traceback
import subprocess
import glob
import shutil

#  Project imports
from utils import cleaning, config, feature_engineering, fetchers, hopsworks_admin, incremental, metadata, visualization

# Reference calendar date (local time) used throughout this run.
today = date.today()

### 4.1.2. Initialize Hopsworks Connection

In [None]:
def detect_environment():
    """Classify the runtime as ``"job"``, ``"jupyter"`` or ``"local"``.

    ``"job"`` is reported when any of HOPSWORKS_JOB_ID, HOPSWORKS_PROJECT_ID or
    HOPSWORKS_JOB_NAME is set in the environment. Otherwise ``"jupyter"`` is
    reported when the working directory is under ``/hopsfs/Jupyter``, and
    ``"local"`` in every other case.
    """
    job_markers = ("HOPSWORKS_JOB_ID", "HOPSWORKS_PROJECT_ID", "HOPSWORKS_JOB_NAME")
    if any(marker in os.environ for marker in job_markers):
        return "job"
    return "jupyter" if os.getcwd().startswith("/hopsfs/Jupyter") else "local"

env = detect_environment()
print(f"Detected environment: {env}")

# Secrets: inside Hopsworks (job or Jupyter) they are pulled from the project's
# secrets store into os.environ; anywhere else they come from a local .env file.
if env not in ("job", "jupyter"):
    load_dotenv()
else:
    project = hopsworks.login()
    secrets_api = hopsworks.get_secrets_api()
    for secret_name in ("HOPSWORKS_API_KEY", "AQICN_API_KEY", "GH_PAT", "GH_USERNAME"):
        os.environ[secret_name] = secrets_api.get_secret(secret_name).value

# Load Pydantic settings (presumably from the environment variables populated above).
settings = config.HopsworksSettings()

HOPSWORKS_API_KEY, AQICN_API_KEY, GITHUB_USERNAME = (
    settings.HOPSWORKS_API_KEY.get_secret_value(),
    settings.AQICN_API_KEY.get_secret_value(),
    settings.GH_USERNAME.get_secret_value(),
)

# Log in explicitly with the API key; inside Hopsworks this replaces the
# `project` obtained by the argument-less login above.
project = hopsworks.login(api_key_value=HOPSWORKS_API_KEY)
fs = project.get_feature_store()

print("Environment initialized and Hopsworks connected!")


### 4.1.3. Repository Management

In [None]:
# Fetch the project repository for GITHUB_USERNAME (per its name, the helper
# presumably clones it or refreshes an existing clone) and make the returned
# directory the working directory.
# NOTE(review): later cells write frontend artifacts and run git with
# cwd=root_dir rather than repo_dir — confirm both refer to the same checkout.
repo_dir = hopsworks_admin.clone_or_update_repo(GITHUB_USERNAME)
os.chdir(repo_dir)

### 4.1.4. Configure API Keys and Secrets

In [None]:
# Ensure the AQICN API key is stored as a Hopsworks secret; in job/Jupyter runs
# the environment-setup cell reads AQICN_API_KEY from the secrets store.
secrets = hopsworks.get_secrets_api()

# Depending on the hopsworks client version, looking up a missing secret may
# raise or may return None. Treat both as "not stored yet". Catch Exception
# rather than using a bare except so KeyboardInterrupt/SystemExit still propagate.
try:
    existing_secret = secrets.get_secret("AQICN_API_KEY")
except Exception:
    existing_secret = None

if existing_secret is None:
    secrets.create_secret("AQICN_API_KEY", settings.AQICN_API_KEY.get_secret_value())

### 4.1.5. Get Model Registry

In [None]:
# Handle to the project's model registry; used below to look up and download the per-sensor models.
mr = project.get_model_registry()

## 4.2. Get Feature Groups and Sensor Locations

In [None]:
air_quality_fg, weather_fg = hopsworks_admin.create_feature_groups(fs)

# Stop early if the air-quality feature group has no rows: inference needs the
# historical data written by the backfill pipeline.
aq_data = air_quality_fg.read()
if not len(aq_data):
    print("‚ö†Ô∏è No air quality data found. Run pipeline 1 (backfill) first.")
    sys.exit(1)

# Mapping of sensor_id -> location details, derived from the air-quality feature group.
sensor_locations = metadata.get_sensor_locations_dict(air_quality_fg)
print(f"üìç Loaded locations for {len(sensor_locations)} sensors")

## 4.3. Load Data from Feature Store

### 4.3.1. Set Inference Dates

In [None]:
# Run window: one week of history (for lag/rolling features) and one week of
# weather forecasts ahead of today.
past_date, future_date = today - timedelta(days=7), today + timedelta(days=7)
today_short = f"{today:%Y-%m-%d}"

print(f"Inference period: {past_date} to {future_date}")
print(f"Today: {today_short}")

### 4.3.2. Load Weather Features

In [None]:
# Read weather rows for [past_date, future_date] (history plus forecasts).
# The server-side filtered read is retried with exponential backoff. If every
# attempt fails, fall back to a full read and filter in pandas.
max_retries = 3
for attempt in range(max_retries):
    try:
        batch_weather = weather_fg.filter(
            (weather_fg.date >= past_date) & (weather_fg.date <= future_date)
        ).read()
        break
    except Exception as e:
        if attempt < max_retries - 1:
            print(f"‚ö†Ô∏è Weather read attempt {attempt + 1} failed, retrying...")
            time.sleep(2 ** attempt)
        else:
            print(f"‚ùå Failed to read weather data after {max_retries} attempts")
            print(f"   Last error: {type(e).__name__}: {e}")
            batch_weather = weather_fg.read()
            # Normalise the column before comparing. Recent pandas versions raise
            # TypeError when a datetime64 column is compared with a datetime.date,
            # so filtering the raw column could fail on this fallback path.
            fallback_dates = pd.to_datetime(batch_weather["date"]).dt.tz_localize(None)
            batch_weather = batch_weather[
                (fallback_dates >= pd.Timestamp(past_date))
                & (fallback_dates <= pd.Timestamp(future_date))
            ].copy()

# Drop any timezone so the date column is tz-naive for the merges and comparisons below.
batch_weather["date"] = pd.to_datetime(batch_weather["date"]).dt.tz_localize(None)

print(f"Retrieved {len(batch_weather)} weather records from {past_date} to {future_date}")

### 4.3.3. Load Air Quality Features

In [None]:
# Read air-quality rows from past_date onwards. The server-side filtered read
# is retried with exponential backoff. If every attempt fails, fall back to a
# full read and filter in pandas.
max_retries = 3
for attempt in range(max_retries):
    try:
        batch_airquality = air_quality_fg.filter(
            air_quality_fg.date >= past_date
        ).read()
        break
    except Exception as e:
        if attempt < max_retries - 1:
            print(f"‚ö†Ô∏è Air quality read attempt {attempt + 1} failed, retrying...")
            time.sleep(2 ** attempt)
        else:
            print(f"‚ùå Failed to read air quality data after {max_retries} attempts")
            print(f"   Last error: {type(e).__name__}: {e}")
            batch_airquality = air_quality_fg.read()
            # Normalise the column before comparing. Recent pandas versions raise
            # TypeError when a datetime64 column is compared with a datetime.date.
            fallback_dates = pd.to_datetime(batch_airquality["date"]).dt.tz_localize(None)
            batch_airquality = batch_airquality[
                fallback_dates >= pd.Timestamp(past_date)
            ].copy()

# Drop any timezone so the date column is tz-naive for the merge with weather data.
batch_airquality["date"] = pd.to_datetime(batch_airquality["date"]).dt.tz_localize(None)

print(f"Retrieved {len(batch_airquality)} air quality records from {past_date} to {today}")

## 4.4. Model Retrieval
Download the latest trained XGBoost model for each sensor from the Hopsworks model registry and extract its feature names.

In [None]:
MODEL_NAME_TEMPLATE = "air_quality_xgboost_model_{sensor_id}"

# sensor_id -> (registry model entry, loaded XGBRegressor, booster feature names)
retrieved_models = {}

for sensor_id in sensor_locations:
    model_name = MODEL_NAME_TEMPLATE.format(sensor_id=sensor_id)

    try:
        model_versions = mr.get_models(name=model_name)
        if not model_versions:
            print(f"‚ö†Ô∏è No model found for sensor {sensor_id}, skipping...")
            continue

        # Use the highest registered version.
        latest_version = max(model_versions, key=lambda candidate: candidate.version)
        download_dir = latest_version.download()

        xgb_model = XGBRegressor()
        xgb_model.load_model(download_dir + "/model.json")

        retrieved_models[sensor_id] = (
            latest_version,
            xgb_model,
            xgb_model.get_booster().feature_names,
        )
    except Exception as e:
        # Best effort: one broken model must not stop inference for the other sensors.
        print(f"‚ùå Error loading model for sensor {sensor_id}: {e}")

print(f"‚úÖ Retrieved {len(retrieved_models)} models from registry")
print(f"   Total sensors in feature store: {len(sensor_locations)}")

In [None]:
print("Retrieved %d models." % len(retrieved_models))

## 4.5. Batch Prediction
Merge weather and air quality data, iteratively predict PM2.5 values for the forecast days, update the engineered features after each prediction, and store the results.

### 4.5.1. Batch Prediction Loop

In [None]:
# Left-join air-quality columns onto every weather row (past and forecast days);
# where no air-quality row matches, those columns are NaN. Then order the rows
# per sensor chronologically.
batch_data = (
    batch_weather.merge(batch_airquality, on=["date", "sensor_id"], how="left")
    .sort_values(["sensor_id", "date"])
)

In [None]:
# Engineered pm25 features that are recomputed after each forecast day and
# stored as predicted_<name> columns.
feature_cols = (
    ["pm25_rolling_3d"]
    + [f"pm25_lag_{lag}d" for lag in (1, 2, 3)]
    + ["pm25_nearby_avg"]
)

In [None]:
# Output columns filled by the prediction loop; they stay NaN for rows that
# never receive a prediction.
batch_data["predicted_pm25"] = np.nan
batch_data["days_before_forecast_day"] = np.nan
for col in feature_cols:
    batch_data[f"predicted_{col}"] = np.nan

# Forecast horizon: today plus the next six days, as pd.Timestamp values at midnight.
forecast_days = [pd.Timestamp(today) + pd.Timedelta(days=i) for i in range(7)]

# Ensure today is always included for UI display, even if we have some actual data.
# Compare Timestamp with Timestamp: the previous check compared a "YYYY-MM-DD"
# string against Timestamps, which never matched. Today was therefore prepended
# a second time and predicted twice.
if pd.Timestamp(today) not in forecast_days:
    forecast_days.insert(0, pd.Timestamp(today))
    print(f"‚ÑπÔ∏è  Added today ({today}) to forecast_days for UI completeness")

In [None]:
# Iterative multi-day forecast. For each forecast day in order:
#   1. Predict PM2.5 per sensor with that sensor's model and write the result
#      into batch_data (pm25 is back-filled only where it was NaN).
#   2. Recompute the engineered pm25 features on data up to that day and store
#      them in the predicted_<feature> columns.
# Afterwards, every row that received a prediction is collected into `predictions`.
for target_day in forecast_days:
    # context with all sensors up to current day (only day_rows is used below)
    window = batch_data.loc[batch_data["date"] <= target_day].copy()
    day_rows = window[window["date"] == target_day]

    for _, row in day_rows.iterrows():
        sensor_id = row["sensor_id"]
        try:
            _, xgb_model, model_features = retrieved_models[sensor_id]
        except KeyError:
            print(f"No model for sensor {sensor_id}, skipping prediction for {target_day}.")
            continue
        # One-row frame holding exactly the model's feature columns (absent
        # columns become NaN), coerced to numeric dtypes.
        features = (row.reindex(model_features).to_frame().T.apply(pd.to_numeric, errors="coerce"))
        y_hat = xgb_model.predict(features)[0]

        # First batch_data row for this (sensor, day). At least one exists because
        # `row` came from a filtered copy of batch_data.
        idx = batch_data.index[(batch_data["sensor_id"] == sensor_id) & (batch_data["date"] == target_day)][0]
       
        # Back-fill pm25 only where there is no measurement, so the feature
        # recomputation below can use it. predicted_pm25 always records the model output.
        if pd.isna(row["pm25"]):
            batch_data.at[idx, "pm25"] = y_hat
        batch_data.at[idx, "predicted_pm25"] = y_hat
        batch_data.at[idx, "days_before_forecast_day"] = (target_day - pd.Timestamp(today)).days

    # Recompute features for after filling this day
    # NOTE(review): the recomputed features are written only to the
    # predicted_<feature> columns, and temp_df stops at target_day. The columns
    # read by row.reindex(model_features) for the NEXT day are not updated here.
    # If the models use pm25_lag_* / pm25_rolling_3d / pm25_nearby_avg, future
    # days may be predicted from NaN lags — confirm this is intended.
    temp_df = batch_data.loc[batch_data["date"] <= target_day].copy()
    temp_df = feature_engineering.add_rolling_window_feature(
        temp_df, window_days=3, column="pm25", new_column="pm25_rolling_3d"
    )
    temp_df = feature_engineering.add_lagged_features(temp_df, column="pm25", lags=[1, 2, 3])
    temp_df = feature_engineering.add_nearby_sensor_feature(
        temp_df,
        sensor_locations,
        column="pm25_lag_1d",
        n_closest=3,
        new_column="pm25_nearby_avg",
    )

    # Copy this day's recomputed features into the predicted_<feature> columns.
    current_rows = temp_df[temp_df["date"] == target_day]
    for _, row in current_rows.iterrows():
        sensor_id = row["sensor_id"]
        mask = (batch_data["sensor_id"] == sensor_id) & (batch_data["date"] == target_day)
        if mask.any():
            for col in feature_cols:
                batch_data.loc[mask, f"predicted_{col}"] = row[col]

# Rows that received a prediction, restricted to the columns stored downstream.
predictions = batch_data.loc[
    batch_data["predicted_pm25"].notna(),
    ["date", "sensor_id", "predicted_pm25", "days_before_forecast_day"]
    + [f"predicted_{col}" for col in feature_cols],
].reset_index(drop=True)


### 4.5.2. Assemble Prediction Results

In [None]:
predictions_df = predictions.copy()

# Sanity summary of what the prediction loop produced (one line per item).
print(
    f"‚úÖ Generated {len(predictions_df)} prediction rows",
    f"   Date range: {predictions_df['date'].min()} to {predictions_df['date'].max()}",
    f"   Sensors: {predictions_df['sensor_id'].nunique()}",
    f"   Forecast days: {sorted(predictions_df['days_before_forecast_day'].unique())}",
    sep="\n",
)

### 4.5.3. Save Predictions

In [None]:
# Create or get predictions feature group (same as in training pipeline).
# Rows are keyed by (sensor_id, date, days_before_forecast_day), so forecasts
# for the same date made with different lead times are stored separately.
predictions_fg = fs.get_or_create_feature_group(
    name="aq_predictions",
    version=1,
    primary_key=["sensor_id", "date", "days_before_forecast_day"],
    description="Air Quality prediction monitoring",
    event_time="date",
)

if not len(predictions):
    print("‚ö†Ô∏è No predictions to insert")
else:
    # Difference in types between feature stores: the Hopsworks-hosted runs
    # use int64 sensor ids, local runs use int32.
    predictions["sensor_id"] = predictions["sensor_id"].astype(
        "int64" if env in ("job", "jupyter") else "int32"
    )

    max_retries = 5
    delay = 2  # base back-off in seconds, doubled after each failed attempt

    for attempt in range(1, max_retries + 1):
        try:
            predictions_fg.insert(predictions, write_options={"wait_for_job": False})
        except Exception as e:
            print(f"‚ö†Ô∏è Insert failed on attempt {attempt}: {type(e).__name__}: {e}")
            if attempt == max_retries:
                print("‚ùå Max retries reached. Insert failed permanently.")
                raise
            sleep_time = delay * 2 ** (attempt - 1)
            print(f"‚è≥ Retrying in {sleep_time} seconds...")
            time.sleep(sleep_time)
        else:
            print(f"‚úÖ Insert successful on attempt {attempt}")
            break

    print(f"‚úÖ Inserted {len(predictions)} predictions to {predictions_fg.name}")

## 4.6. Analysis & Visualization

### 4.6.1. Configuration

In [None]:
# Configure visualization for production vs development.
# Per-sensor plots are skipped when running as a Hopsworks job; the heatmap
# cell later in the notebook does not check this flag.
SKIP_SENSOR_PLOTS = env == "job"

if not SKIP_SENSOR_PLOTS:
    print("üìä Full visualization enabled (running locally/Jupyter)")
else:
    print("‚è≠Ô∏è Skipping individual sensor plots (running as Hopsworks job)")
    print("   Heatmap interpolations will still be generated for UI")


### 4.6.2. Generate Forecast Plots

In [None]:
# Individual sensor plots are skipped in jobs
if SKIP_SENSOR_PLOTS:
    print("‚è≠Ô∏è Skipping forecast plot generation (200+ files)")
else:
    dataset_api = project.get_dataset_api()
    forecast_paths = []

    # Render one forecast plot per sensor under <root_dir>/models/<sensor_id>/images/.
    # NOTE(review): predictions["sensor_id"] was cast to int64/int32 before the
    # insert. If sensor_locations keys are not ints, this filter matches no rows.
    for sensor_id, location in sensor_locations.items():
        sensor_forecast = predictions[predictions["sensor_id"] == sensor_id].copy()

        city, street = location["city"], location["street"]
        forecast_path = f"{root_dir}/models/{sensor_id}/images/forecast.png"
        Path(forecast_path).parent.mkdir(parents=True, exist_ok=True)

        fig = visualization.plot_air_quality_forecast(
            city,
            street,
            sensor_forecast,
            forecast_path,
            hindcast=False,
        )
        plt.close(fig)
        forecast_paths.append((sensor_id, forecast_path))

    if not dataset_api.exists("Resources/airquality"):
        dataset_api.mkdir("Resources/airquality")

    # Upload every plot. Network errors are retried up to max_retries times with
    # exponential backoff; any other error fails that upload immediately.
    upload_success = 0
    upload_failed = 0

    for position, (sensor_id, forecast_path) in enumerate(forecast_paths, start=1):
        max_retries = 3
        retry_delay = 2  # seconds, doubled after each network failure
        remote_path = f"Resources/airquality/{sensor_id}_{today_short}_forecast.png"

        for attempt in range(max_retries):
            try:
                dataset_api.upload(forecast_path, remote_path, overwrite=True)
            except (ConnectionError, ProtocolError, Timeout, RequestException) as e:
                if attempt == max_retries - 1:
                    print(f"‚ùå Failed to upload for sensor {sensor_id} after {max_retries} attempts: {e}")
                    upload_failed += 1
                else:
                    print(f"‚ö†Ô∏è Upload failed for sensor {sensor_id} (attempt {attempt + 1}/{max_retries}), retrying in {retry_delay}s...")
                    time.sleep(retry_delay)
                    retry_delay *= 2
            except Exception as e:
                print(f"‚ùå Unexpected error uploading for sensor {sensor_id}: {e}")
                upload_failed += 1
                break
            else:
                upload_success += 1
                if position % 20 == 0:  # Progress update every 20 uploads
                    print(f"   Uploaded {position}/{len(forecast_paths)} plots...")
                break

        # Small delay between uploads to avoid overwhelming the connection
        if position < len(forecast_paths):
            time.sleep(0.1)

    print(f"‚úÖ Upload complete: {upload_success} successful, {upload_failed} failed")
    if upload_success > 0:
        print(f"   Forecast plots available in Hopsworks under {project.get_url()}/settings/fb/path/Resources/airquality")

### 4.6.3. Hindcast Analysis

In [None]:
# Individual sensor plots are skipped in jobs
if SKIP_SENSOR_PLOTS:
    print("‚è≠Ô∏è Skipping hindcast plot generation (200+ files)")
else:
    # Set up the dataset API and the upload folder here too, so this cell no
    # longer depends on the forecast-plot cell having run first.
    dataset_api = project.get_dataset_api()
    if not dataset_api.exists("Resources/airquality"):
        dataset_api.mkdir("Resources/airquality")

    # Use predictions_fg (same variable name as in training pipeline).
    # Hindcast = predictions made one day ahead (days_before_forecast_day == 1).
    try:
        monitoring_df = predictions_fg.filter(predictions_fg.days_before_forecast_day == 1).read()
        monitoring_df["date"] = pd.to_datetime(monitoring_df["date"]).dt.tz_localize(None)
        print(f"‚úÖ Successfully read {len(monitoring_df)} hindcast predictions")
    except Exception as e:
        print(f"‚ö†Ô∏è Could not read monitoring data: {e}")
        print("Skipping hindcast analysis...")
        monitoring_df = pd.DataFrame()  # Empty dataframe to prevent further errors

    if not monitoring_df.empty:
        # Reuse the air-quality frame already read in section 4.2 instead of
        # reading the whole feature group again. Copy it so the date
        # normalisation below neither mutates aq_data nor assigns into a slice.
        air_quality_df = aq_data[["date", "sensor_id", "pm25"]].copy()
        air_quality_df["date"] = pd.to_datetime(air_quality_df["date"]).dt.tz_localize(None)

        for sensor_id, location in sensor_locations.items():
            try:
                sensor_preds = monitoring_df[monitoring_df["sensor_id"] == sensor_id][["date", "predicted_pm25"]]

                if sensor_preds.empty:
                    continue

                # Pair each hindcast prediction with the observed pm25 for that date.
                merged = sensor_preds.merge(
                    air_quality_df[air_quality_df["sensor_id"] == sensor_id][["date", "pm25"]],
                    on="date",
                    how="inner",
                ).sort_values("date")

                city, street = location["city"], location["street"]
                hindcast_path = f"{root_dir}/models/{sensor_id}/images/hindcast_prediction.png"
                Path(hindcast_path).parent.mkdir(parents=True, exist_ok=True)

                # With no observed values, still plot the predictions (pm25 = NaN).
                fig = visualization.plot_air_quality_forecast(
                    city,
                    street,
                    merged if not merged.empty else sensor_preds.assign(pm25=np.nan),
                    hindcast_path,
                    hindcast=True,
                )
                if fig is not None:
                    plt.close(fig)

                dataset_api.upload(
                    hindcast_path,
                    f"Resources/airquality/{sensor_id}_{today:%Y-%m-%d}_hindcast.png",
                    overwrite=True,
                )

            except Exception as e:
                print(f"‚ö†Ô∏è Error processing hindcast for sensor {sensor_id}: {e}")


### 4.6.4. IDW Heatmap

#### 4.6.4.1. Generate Heatmap Images

In [None]:
print("\nüó∫Ô∏è Generating heatmap interpolation images (required for UI)")

# Heatmap extent as (min_lon, min_lat, max_lon, max_lat). Prefer the bounds
# shipped with the frontend; otherwise derive them from the sensor positions.
try:
    with open(f"{root_dir}/frontend/coordinates.json") as f:
        coords_data = json.load(f)
    # CRITICAL: Skip REGION_NAME and get only the numeric coordinate values
    grid_bounds = tuple(
        float(coords_data[bound_key])
        for bound_key in ("MIN_LONGITUDE", "MIN_LATITUDE", "MAX_LONGITUDE", "MAX_LATITUDE")
    )
    print(f"   Grid bounds from coordinates.json: {grid_bounds}")
    print(f"   (min_lon={grid_bounds[0]}, min_lat={grid_bounds[1]}, max_lon={grid_bounds[2]}, max_lat={grid_bounds[3]})")
except Exception as e:
    print(f"‚ö†Ô∏è Could not load coordinates.json: {e}")
    print("   Calculating bounds from sensor locations...")

    lons = [loc['longitude'] for loc in sensor_locations.values()]
    lats = [loc['latitude'] for loc in sensor_locations.values()]
    min_lon, max_lon = min(lons), max(lons)
    min_lat, max_lat = min(lats), max(lats)

    # Pad the sensor bounding box by 10% of its span on every side.
    lon_padding = (max_lon - min_lon) * 0.1
    lat_padding = (max_lat - min_lat) * 0.1

    grid_bounds = tuple(
        float(edge)
        for edge in (
            min_lon - lon_padding,
            min_lat - lat_padding,
            max_lon + lon_padding,
            max_lat + lat_padding,
        )
    )
    print(f"   Calculated grid bounds: {grid_bounds}")


In [None]:
# Local output directory for the heatmap PNGs.
interpolation_dir = f"{root_dir}/models/interpolation"
Path(interpolation_dir).mkdir(parents=True, exist_ok=True)

# Working copy of the predictions with predicted_pm25 coerced to a numeric
# dtype (unparseable values become NaN).
interpolation_df = predictions.copy()
interpolation_df["predicted_pm25"] = pd.to_numeric(interpolation_df["predicted_pm25"], errors="coerce")


In [None]:
# For day 0 heatmap: Use actual sensor measurements if available, otherwise fall back to predictions
# For days 1-6 heatmaps: Always use predictions


def _with_predictions_as_pm25(df, day):
    """Return *df* with the rows dated *day* given ``pm25 = predicted_pm25``.

    This is the day-0 fallback when no measured pm25 is available. The rows for
    *day* are moved to the front of the result. *df* is returned unchanged when
    it has no rows for *day*.
    """
    day_rows = df[df["date"] == day].copy()
    if day_rows.empty:
        return df
    day_rows["pm25"] = day_rows["predicted_pm25"]
    return pd.concat([day_rows, df[df["date"] != day]], ignore_index=True)


def _show_day0_preview(png_path):
    """Best-effort inline display of a heatmap PNG; silently skipped without IPython."""
    try:
        from IPython.display import Image, display
    except ImportError:
        return
    print("\nüñºÔ∏è DEBUG: Showing Day 0 heatmap image")
    display(Image(filename=png_path))


today_ts = pd.Timestamp(today)

interpolation_df = predictions.copy()
interpolation_df["predicted_pm25"] = pd.to_numeric(interpolation_df["predicted_pm25"], errors='coerce')

# Get today's rows from batch_data.
# NOTE: the prediction loop writes the model output into batch_data["pm25"]
# wherever no measurement existed. For sensors with a model, these values are
# therefore measurements where available and predictions otherwise.
today_actual = batch_data[batch_data["date"] == today_short].copy()

if today_actual.empty:
    # No data at all for today in batch_data - use predictions
    print(f"   ‚ö†Ô∏è No batch data for today - using predictions for day 0 heatmap")
    interpolation_df = _with_predictions_as_pm25(interpolation_df, today_ts)
else:
    # Keep only the columns needed for the heatmap, with a numeric pm25
    today_actual = today_actual[[col for col in ["date", "sensor_id", "pm25"] if col in today_actual.columns]].copy()
    today_actual["pm25"] = pd.to_numeric(today_actual["pm25"], errors='coerce')
    today_actual = today_actual[today_actual["pm25"].notna()]

    if len(today_actual) > 0:
        # Replace today's prediction rows with the pm25 rows gathered above
        interpolation_df = interpolation_df[interpolation_df["date"] != today_ts]
        interpolation_df = pd.concat([today_actual, interpolation_df], ignore_index=True)
        print(f"   Using {len(today_actual)} actual sensor measurements for day 0 heatmap")
    else:
        print(f"   ‚ö†Ô∏è No valid actual sensor data for today - using predictions for day 0 heatmap")
        interpolation_df = _with_predictions_as_pm25(interpolation_df, today_ts)

# DIAGNOSTIC: Print dtypes to debug type conversion issues
print(f"\nüîç Interpolation DataFrame dtypes:")
print(f"   predicted_pm25: {interpolation_df['predicted_pm25'].dtype if 'predicted_pm25' in interpolation_df.columns else 'N/A'}")
if "pm25" in interpolation_df.columns:
    print(f"   pm25: {interpolation_df['pm25'].dtype}")
    print(f"   Sample pm25 values: {interpolation_df[interpolation_df['pm25'].notna()]['pm25'].head(3).tolist()}")

# Convert to pd.Timestamp so strftime and date arithmetic below work whatever
# element type Series.unique yields in this pandas version.
unique_dates = [pd.Timestamp(d) for d in sorted(interpolation_df["date"].unique())]
print(f"\nüìÖ Generating {len(unique_dates)} heatmap images...")
print(f"   Today (reference date): {today}")
print(f"   Unique forecast dates: {[str(d)[:10] for d in unique_dates]}")

# Create frontend interpolation directory
frontend_interpolation_dir = f"{root_dir}/frontend/interpolation"
os.makedirs(frontend_interpolation_dir, exist_ok=True)

successful_images = 0
failed_images = 0

for i, forecast_date in enumerate(unique_dates):
    forecast_date_short = forecast_date.strftime("%Y-%m-%d")
    days_ahead = (forecast_date - today_ts).days
    output_png = f"{interpolation_dir}/forecast_interpolation_{days_ahead}d.png"
    frontend_png = f"{frontend_interpolation_dir}/forecast_interpolation_{days_ahead}d.png"

    # Diagnostic logging for day calculation
    if days_ahead <= 1:
        print(f"\n   üîç Day {days_ahead} diagnostics:")
        print(f"      forecast_date: {forecast_date} (type: {type(forecast_date).__name__})")
        print(f"      pd.Timestamp(today): {today_ts}")
        print(f"      days_ahead = ({forecast_date} - {today_ts}).days = {days_ahead}")
        print(f"      Filename: {os.path.basename(frontend_png)}")

    print(f"   [{i+1}/{len(unique_dates)}] {forecast_date_short} (+{days_ahead}d)...", end=" ")

    try:
        visualization.plot_pm25_idw_heatmap(
            interpolation_df,
            sensor_locations,
            forecast_date,
            output_png,
            grid_bounds=grid_bounds,
            today=today,
        )

        # Copy to frontend directory for Netlify deployment
        shutil.copy2(output_png, frontend_png)

        # Verify the copy succeeded
        if not os.path.exists(frontend_png):
            raise FileNotFoundError(f"Failed to copy to {frontend_png}")
    except Exception as e:
        print(f"‚ùå {type(e).__name__}: {str(e)[:100]}")
        failed_images += 1
        continue

    print("‚úÖ")
    successful_images += 1

    # Inline preview is a debugging aid for interactive runs only. It sits
    # outside the try so a missing IPython can never mark a written image as failed.
    if days_ahead == 0 and env != "job":
        _show_day0_preview(output_png)

print(f"\nüìä Heatmap generation complete: {successful_images} successful, {failed_images} failed")

if failed_images > 0:
    print(f"‚ö†Ô∏è Warning: {failed_images} images failed to generate")
if successful_images == 0:
    raise RuntimeError("‚ùå All heatmap images failed - pipeline cannot complete without UI data")

## 4.7. Export Predictions for Frontend

### 4.7.1. Prepare Predictions

In [None]:
print("\nüì¶ Preparing predictions for frontend export...")

# Debug: show available forecast days and prediction count before merging with metadata
print(f"\nüîç Debug - Available days_before_forecast_day values: {sorted(predictions['days_before_forecast_day'].unique())}")
print(f"   Total predictions: {len(predictions)}")

frontend_predictions = predictions.copy()

# One metadata row per sensor; explicit casts give the join key and the
# coordinates predictable dtypes.
sensor_metadata_df = pd.DataFrame([
    {
        'sensor_id': int(sid),  # force int
        'latitude': float(loc['latitude']),
        'longitude': float(loc['longitude']),
        'city': loc.get('city', ''),
        'street': loc.get('street', ''),
    }
    for sid, loc in sensor_locations.items()
])

# Ensure consistent dtypes on both sides of the join key
frontend_predictions['sensor_id'] = frontend_predictions['sensor_id'].astype(int)
sensor_metadata_df['sensor_id'] = sensor_metadata_df['sensor_id'].astype(int)

print(f"   Metadata sensors: {sensor_metadata_df['sensor_id'].nunique()}")
print(f"   Prediction sensors: {frontend_predictions['sensor_id'].nunique()}")

# Left join keeps every prediction row; sensors without metadata get NaN coordinates.
merged = frontend_predictions.merge(sensor_metadata_df, on='sensor_id', how='left')

# Refuse to export rows that would have no position on the map.
missing_coords = merged[merged[['latitude', 'longitude']].isna().any(axis=1)]
if not missing_coords.empty:
    print("\n‚ùå ERROR: Missing coordinates for these sensor IDs:")
    print(missing_coords['sensor_id'].unique())
    raise ValueError("Some prediction rows have no matching sensor metadata. Fix required.")

frontend_predictions = merged
print(f"   ‚úÖ Coordinates added for {len(frontend_predictions)} rows")
print(frontend_predictions.head())

# JSON export needs plain strings rather than Timestamps.
frontend_predictions["date"] = frontend_predictions["date"].astype(str)

### 4.7.2. Export Predictions JSON

In [None]:
# Export predictions as JSON for frontend: one JSON object per row, written to
# frontend/predictions.json under the project root.
print("\nüì¶ Exporting predictions for frontend...")

predictions_json_path = f"{root_dir}/frontend/predictions.json"
frontend_predictions.to_json(predictions_json_path, orient="records", indent=2)

print(f"‚úÖ Exported {len(frontend_predictions)} predictions to frontend/predictions.json")
print(f"   Sensors: {frontend_predictions['sensor_id'].nunique()}")
if len(frontend_predictions):
    print(f"   Date range: {frontend_predictions['date'].min()} to {frontend_predictions['date'].max()}")

### 4.7.3. Commit Frontend Artifacts to GitHub

In [None]:
print("\nüì§ Committing predictions and interpolation images to git...")

# Commit and push the generated frontend artifacts (predictions.json and the
# heatmap PNGs). This is best effort: any failure is reported and the notebook
# continues, because the files have already been written locally.
# NOTE(review): git runs with cwd=root_dir, while section 4.1.3 chdir'd into
# repo_dir — confirm both refer to the same checkout.
try:
    # List PNGs for debugging
    frontend_pngs = glob.glob(f"{root_dir}/frontend/interpolation/*.png")
    print(f"   Found {len(frontend_pngs)} PNGs in frontend/interpolation/")
    if frontend_pngs:
        print(f"   Example: {os.path.basename(frontend_pngs[0])}")

    # Configure git identity
    subprocess.run(["git", "config", "user.name", "Hopsworks Bot"], cwd=root_dir, check=True)
    subprocess.run(["git", "config", "user.email", "bot@hopsworks.ai"], cwd=root_dir, check=True)

    # Stage only the generated artifacts
    print("   Staging frontend artifacts...")
    subprocess.run(["git", "add", "-f", "frontend/predictions.json"], cwd=root_dir, check=True)
    subprocess.run(["git", "add", "-f", "frontend/interpolation/"], cwd=root_dir, check=True)

    # Check only what is staged. `git status --porcelain` also lists unstaged and
    # untracked files elsewhere in the checkout. With unchanged artifacts but any
    # other dirty file, the old check attempted a commit that could only fail.
    status = subprocess.run(
        ["git", "diff", "--cached", "--name-status"],
        cwd=root_dir,
        capture_output=True,
        text=True,
        check=True,
    )

    if not status.stdout.strip():
        print("‚ÑπÔ∏è  No changes to commit ‚Äî predictions and heatmaps already up to date")
    else:
        print("   Changes detected:")
        print(status.stdout)
        commit = subprocess.run(
            ["git", "commit", "-m", f"Update predictions and heatmaps - {today_short}"],
            cwd=root_dir,
            capture_output=True,
            text=True
        )

        if commit.returncode == 0:
            print("‚úÖ Commit successful")

            push = subprocess.run(
                ["git", "push"],
                cwd=root_dir,
                capture_output=True,
                text=True
            )

            if push.returncode == 0:
                print("‚úÖ Successfully pushed updates to GitHub")
                print("   Netlify will rebuild automatically")

            else:
                print("‚ùå Git push failed")
                print(push.stderr)
        else:
            print("‚ö†Ô∏è Commit failed")
            print(commit.stderr)

except Exception as e:
    print(f"‚ùå Git operation error: {e}")
    print("   Files exported locally but not pushed to git")

In [None]:


# # Commit and push predictions and interpolation images to git
# print("\nüì§ Committing predictions and interpolation images to git...")
# try:    
#     # Verify files exist before committing
#     frontend_pngs = glob.glob(f"{root_dir}/frontend/interpolation/*.png")
#     print(f"   Found {len(frontend_pngs)} PNG files in frontend/interpolation/")
#     if len(frontend_pngs) > 0:
#         print(f"   Example: {os.path.basename(frontend_pngs[0])}")
    
#     # Configure git
#     subprocess.run(["git", "config", "user.name", "Hopsworks Bot"], cwd=root_dir, check=True)
#     subprocess.run(["git", "config", "user.email", "bot@hopsworks.ai"], cwd=root_dir, check=True)
    
#     # Add frontend files
#     print("   Staging frontend files...")
#     subprocess.run(["git", "add", "-f", "frontend/predictions.json"], cwd=root_dir, check=True)
#     subprocess.run(["git", "add", "-f", "frontend/interpolation/"], cwd=root_dir, check=True)
#     print(f"   Staged predictions.json and {len(frontend_pngs)} interpolation images")
    
#     # Check if there are actually changes to commit
#     status_result = subprocess.run(
#         ["git", "status", "--porcelain"],
#         cwd=root_dir,
#         capture_output=True,
#         text=True
#     )
    
#     if not status_result.stdout.strip():
#         print("‚ÑπÔ∏è  No changes detected - predictions and heatmaps are already up to date")
#     else:
#         print(f"   Changes detected:\n{status_result.stdout}")
        
#         commit_result = subprocess.run(
#             ["git", "commit", "-m", f"Update predictions and heatmaps - {today_short}"],
#             cwd=root_dir,
#             capture_output=True,
#             text=True
#         )
        
#         if commit_result.returncode == 0:
#             print(f"‚úÖ Commit successful!")
            
#             # Push with error handling
#             push_result = subprocess.run(
#                 ["git", "push"],
#                 cwd=root_dir,
#                 capture_output=True,
#                 text=True
#             )
            
#             if push_result.returncode == 0:
#                 print("‚úÖ Pushed predictions + interpolation images to repository")
#                 print(f"   This will trigger a Netlify rebuild with updated data")
#             else:
#                 print(f"‚ùå Git push failed (exit code {push_result.returncode})")
#                 print(f"   stdout: {push_result.stdout}")
#                 print(f"   stderr: {push_result.stderr}")
#                 print("   Files are committed locally but not pushed to GitHub")
#         else:
#             print(f"‚ö†Ô∏è  Commit failed: {commit_result.stderr}")
            
# except Exception as e:
#     print(f"‚ùå Error in git operations: {e}")
#     print(traceback.format_exc())
#     print("   Files exported locally but not pushed to repository")


In [None]:

# print("   Ensuring we're on main/master branch...")

# # Always fetch latest refs
# subprocess.run(["git", "fetch", "origin"], cwd=root_dir, check=True)

# # Detect the default branch name (main or master)
# branch_result = subprocess.run(
#     ["git", "branch", "-r"],
#     cwd=root_dir,
#     capture_output=True,
#     text=True,
#     check=True
# )

# # Determine which default branch exists
# if "origin/main" in branch_result.stdout:
#     default_branch = "main"
# elif "origin/master" in branch_result.stdout:
#     default_branch = "master"
# else:
#     raise RuntimeError("Could not find origin/main or origin/master branch")

# print(f"   Detected default branch: {default_branch}")

# # Get current branch
# current_branch_result = subprocess.run(
#     ["git", "rev-parse", "--abbrev-ref", "HEAD"],
#     cwd=root_dir,
#     capture_output=True,
#     text=True,
#     check=True
# )
# current_branch = current_branch_result.stdout.strip()

# # If we're in detached HEAD or on wrong branch, checkout the default branch
# if current_branch != default_branch:
#     print(f"   Switching from {current_branch} to {default_branch}")
#     subprocess.run(["git", "checkout", default_branch], cwd=root_dir, check=True)

# # Reset to remote state (discarding local notebook changes)
# print(f"   Resetting to origin/{default_branch}")
# subprocess.run(
#     ["git", "reset", "--hard", f"origin/{default_branch}"],
#     cwd=root_dir,
#     check=True
# )


### 4.7.4. Inspection

## 4.8. Pipeline Completion

In [None]:
print("=" * 80)
print("‚úÖ BATCH INFERENCE PIPELINE COMPLETED SUCCESSFULLY")
print("=" * 80)
print("\nüìä Summary:")
print(f"   - Predictions generated: {len(predictions)}")
# Report the heatmaps that were actually written, not just the number of dates attempted.
print(f"   - Heatmap images created: {successful_images}")
print(f"   - Feature group: {predictions_fg.name} (v{predictions_fg.version})")
print("\n" + "=" * 80)