# PM2.5 Forecast – Prediction Analysis

Notebook dùng để đọc file CSV kết quả dự báo, trực quan hóa sai số giữa `label` (thực tế) và `prediction`, đồng thời đối chiếu lại với dữ liệu gốc trên Iceberg.

In [None]:
from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Apply the seaborn-flavoured matplotlib style sheet to every figure below.
plt.style.use("seaborn-v0_8")
# IPython magic: render figures inline in the notebook output cells.
%matplotlib inline


In [None]:
# CSV holding the forecast results. The cells below read the columns
# `location_key`, `ts_utc`, `label` and `prediction` from it.
PREDICTIONS_CSV = Path("/home/dlhnhom2/dlh-aqi/data/ml_outputs/pm25_predictions/full_predictions.csv")

# Fail fast with an explicit exception instead of `assert`: assertions are
# stripped when Python runs with -O, which would silently skip this check.
if not PREDICTIONS_CSV.exists():
    raise FileNotFoundError(f"Không tìm thấy file: {PREDICTIONS_CSV}")

pred_df = pd.read_csv(PREDICTIONS_CSV)
# Convert the timestamp column from text to datetime so it sorts and plots
# chronologically.
pred_df["ts_utc"] = pd.to_datetime(pred_df["ts_utc"])
# Order rows per location, then by time, for the per-location plots later on.
pred_df.sort_values(["location_key", "ts_utc"], inplace=True)

# Preview the first rows as the cell output.
pred_df.head()


## Tổng quan sai số

In [None]:
from sklearn.metrics import mean_absolute_error, mean_squared_error, r2_score

# Ground truth and model output as plain NumPy arrays.
y_true = pred_df["label"].to_numpy()
y_pred = pred_df["prediction"].to_numpy()

# Global error metrics over every row of the prediction file.
metrics = {
    "MAE": mean_absolute_error(y_true, y_pred),
    # RMSE is computed as sqrt(MSE) rather than via `squared=False`: that
    # keyword was deprecated in scikit-learn 1.4 and removed in 1.6, where it
    # raises TypeError. Taking the square root works on every version.
    "RMSE": np.sqrt(mean_squared_error(y_true, y_pred)),
    "R2": r2_score(y_true, y_pred),
    # The denominator is clipped at 1e-6 so rows whose actual value is 0 do
    # not cause a division by zero.
    "MAPE (%)": np.mean(np.abs((y_true - y_pred) / np.clip(np.abs(y_true), 1e-6, None))) * 100.0,
}
metrics


In [None]:
# Scatter of actual vs predicted values, coloured per location (legend hidden).
fig, ax = plt.subplots(figsize=(6, 6))
sns.scatterplot(
    data=pred_df,
    x="label",
    y="prediction",
    hue="location_key",
    s=20,
    legend=False,
    ax=ax,
)

# Draw the y = x reference line across the combined extent of both axes.
x_low, x_high = ax.get_xlim()
y_low, y_high = ax.get_ylim()
lims = [min(x_low, y_low), max(x_high, y_high)]
ax.plot(lims, lims, "k--", linewidth=1)

ax.set(
    title="Actual vs Prediction",
    xlabel="Actual PM2.5",
    ylabel="Predicted PM2.5",
)
plt.show()


In [None]:
# Signed error per row: positive means the model over-predicted.
pred_df["residual"] = pred_df["prediction"].sub(pred_df["label"])

# Histogram (40 bins) of the residuals with a KDE curve overlaid.
fig, ax = plt.subplots(figsize=(8, 4))
sns.histplot(pred_df["residual"], ax=ax, bins=40, kde=True)
ax.set(
    title="Residual Distribution (Prediction - Actual)",
    xlabel="Residual",
)
plt.show()


## Đối chiếu thêm với dữ liệu Iceberg

Tạo SparkSession, đọc bảng silver trên Iceberg để kiểm tra chéo dữ liệu/đặc trưng.

In [None]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import col, to_timestamp

# Spark settings that register the Iceberg SQL extensions and a Hadoop-type
# Iceberg catalog named `hadoop_catalog` pointing at the HDFS warehouse.
iceberg_conf = {
    "spark.sql.extensions": "org.apache.iceberg.spark.extensions.IcebergSparkSessionExtensions",
    "spark.sql.catalog.hadoop_catalog": "org.apache.iceberg.spark.SparkCatalog",
    "spark.sql.catalog.hadoop_catalog.type": "hadoop",
    "spark.sql.catalog.hadoop_catalog.warehouse": "hdfs://khoa-master:9000/lakehouse/iceberg-warehouse",
}

# Apply each setting to the builder in order, then create (or reuse) the session.
session_builder = SparkSession.builder.appName("pm25_predictions_analysis")
for conf_key, conf_value in iceberg_conf.items():
    session_builder = session_builder.config(conf_key, conf_value)

spark = session_builder.getOrCreate()
spark


In [None]:
# Move the prediction columns from pandas into Spark.
prediction_columns = ["location_key", "ts_utc", "label", "prediction", "residual"]
pred_spark_df = spark.createDataFrame(pred_df[prediction_columns])

# Source table the predictions are cross-checked against.
silver_df = spark.table("hadoop_catalog.lh.silver.air_quality_hourly_clean")

# Inner join on location + timestamp: only rows present on both sides survive.
# NOTE(review): `ts_utc` in pred_df carries no explicit timezone; presumably
# Spark reads it in the session time zone - verify that this matches the
# silver table's `ts_utc`, otherwise the join may silently drop rows.
join_keys = ["location_key", "ts_utc"]
predictions_for_join = pred_spark_df.select(*join_keys, "prediction", "residual")
joined_df = silver_df.join(predictions_for_join, on=join_keys, how="inner").select(
    "location_key", "ts_utc", "pm25", "prediction", "residual", "pm10", "no2", "o3"
)

# Pull a five-row sample back to pandas for display.
joined_df.limit(5).toPandas()


### Diễn biến theo thời gian cho từng location

In [None]:
# Use the location of the first row of pred_df as the example series.
sample_location = pred_df["location_key"].iloc[0]
location_mask = pred_df["location_key"] == sample_location
sample_pdf = pred_df.loc[location_mask].sort_values("ts_utc")

fig, ax = plt.subplots(figsize=(12, 4))
# (column, legend label, line style) for each curve, drawn in this order.
curves = (
    ("label", "Actual", "-"),
    ("prediction", "Prediction", "--"),
)
for column, legend_label, line_style in curves:
    ax.plot(
        sample_pdf["ts_utc"],
        sample_pdf[column],
        label=legend_label,
        marker="o",
        linestyle=line_style,
    )

ax.set_title(f"PM2.5 – Actual vs Prediction ({sample_location})")
ax.set_xlabel("UTC time")
ax.set_ylabel("PM2.5")
ax.legend()
# Tilt the tick labels so the timestamps do not overlap.
plt.xticks(rotation=45)
plt.tight_layout()
plt.show()


### Sai số theo từng khu vực

In [None]:
def _mean_abs_error(residuals):
    """Return the mean of the absolute residuals of one location."""
    return np.mean(np.abs(residuals))


def _root_mean_sq_error(residuals):
    """Return the square root of the mean squared residual of one location."""
    return np.sqrt(np.mean(np.square(residuals)))


# One row per location: number of rows, MAE and RMSE of its residuals.
per_location = pred_df.groupby("location_key")
location_summary = per_location.agg(
    count=pd.NamedAgg(column="prediction", aggfunc="size"),
    mae=pd.NamedAgg(column="residual", aggfunc=_mean_abs_error),
    rmse=pd.NamedAgg(column="residual", aggfunc=_root_mean_sq_error),
).reset_index()
location_summary


In [None]:
# One bar per location showing its MAE; barplot returns the Axes it drew on.
mae_ax = sns.barplot(x="location_key", y="mae", data=location_summary)
mae_ax.set_title("MAE theo location")
mae_ax.set_ylabel("MAE")
plt.show()


In [None]:
# Shut down the SparkSession created earlier in this notebook.
spark.stop()
