# ALS Grid Search Visualization

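This notebook loads the most recent `als_grid_*.json` file (the last one in filename order) from `../artifacts/metrics/` and visualizes RMSE and Precision@10 across the ALS hyperparameter grid (`rank` × `regParam`). It also shows the best configuration recorded in that file.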



In [None]:
from pathlib import Path
import json
import re

import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt

# Grid-search outputs are read from here, relative to the current working directory.
metrics_dir = Path("../artifacts/metrics")

# Schema of the summary frame. Passing it explicitly to pd.DataFrame keeps these columns
# even when the grid has no entries, so later cells can still look them up by name.
GRID_COLUMNS = ["rank", "regParam", "rmse", "precision_at_10", "ndcg_at_10"]

# "Latest" = last file in lexicographic filename order. That equals chronological order
# only if the als_grid_* suffix sorts that way (e.g. a zero-padded timestamp) — TODO confirm
# against how train_als names its output.
candidates = sorted(metrics_dir.glob("als_grid_*.json"))
if not candidates:
    raise FileNotFoundError("No als_grid_*.json found in ../artifacts/metrics. Run train_als first.")

latest = candidates[-1]
# JSON is UTF-8; be explicit instead of depending on the platform's default encoding.
with open(latest, encoding="utf-8") as f:
    grid = json.load(f)

# Flatten each grid entry ({"params": {...}, "metrics": {...}}) into one summary row.
# Keys missing inside params/metrics become missing values; a missing "params" or
# "metrics" key itself raises KeyError.
# NOTE(review): the JSON stores generic "precision@k"/"ndcg@k" keys; the "_at_10" column
# names assume training evaluated at k=10 — confirm.
grid_results = grid["grid_results"]
rows = [
    {
        "rank": entry["params"].get("rank"),
        "regParam": entry["params"].get("regParam"),
        "rmse": entry["metrics"].get("rmse"),
        "precision_at_10": entry["metrics"].get("precision@k"),
        "ndcg_at_10": entry["metrics"].get("ndcg@k"),
    }
    for entry in grid_results
]

df = pd.DataFrame(rows, columns=GRID_COLUMNS)
print(f"Loaded {len(df)} candidates from {latest.name}")
df


In [None]:
sns.set_theme(style="whitegrid")


def plot_metric_heatmap(frame: pd.DataFrame, metric: str, title: str, cmap: str) -> None:
    """Draw and show a rank × regParam heatmap of one metric column.

    Parameters
    ----------
    frame : DataFrame with "rank", "regParam" and `metric` columns.
    metric : name of the column whose values fill the heatmap cells.
    title : figure title.
    cmap : colormap name passed to seaborn.

    If `metric` is absent or has no non-missing values, a message is printed and no
    figure is drawn; otherwise pivot_table would return an empty pivot and
    sns.heatmap would raise.
    """
    values = frame.get(metric)
    if values is None or values.isna().all():
        print(f"No {metric} values to plot; skipping heatmap.")
        return

    # pivot_table averages rows sharing a (rank, regParam) pair (its default aggfunc).
    # NOTE(review): if the grid also varies other params, those get silently averaged here.
    pivot = frame.pivot_table(index="rank", columns="regParam", values=metric)
    fig, ax = plt.subplots()
    sns.heatmap(pivot, annot=True, fmt=".3f", cmap=cmap, ax=ax)
    ax.set(title=title, xlabel="regParam", ylabel="rank")
    plt.show()


plot_metric_heatmap(df, "rmse", "RMSE across grid", "viridis")
plot_metric_heatmap(df, "precision_at_10", "Precision@10 across grid", "magma")



In [None]:
# Scatter: each grid point placed by (RMSE, Precision@10); colour encodes rank and
# marker shape encodes regParam. Use an explicit figure so nothing depends on
# pyplot's implicit "current figure" state.
fig, ax = plt.subplots()
sns.scatterplot(
    data=df,
    x="rmse",
    y="precision_at_10",
    hue="rank",
    style="regParam",
    s=120,
    ax=ax,
)
ax.set(title="ALS grid: RMSE vs Precision@10", xlabel="RMSE", ylabel="Precision@10")
plt.show()

# The "best" entry as stored in the grid JSON (an empty dict if the key is absent).
best = grid.get("best", {})
best
