# MLflow Multi-Label Classification Validation Report

This notebook automatically:
1. connects to MLflow;
2. grabs the **latest** run in the experiment *MultilabelPhotoTagPipeline* (fallback = *Default* experiment);
3. downloads metrics & artifacts; and
4. visualises them 👉 precision / recall / F1 per label, confusion matrix, ROC curve (micro-average), loss curves.

> **Tip for CI:** add a pipeline step that runs `jupyter nbconvert --to html --execute validation_report.ipynb`, then publish the generated `validation_report.html` as a build artifact so every pipeline run keeps a rendered copy of this report.

In [None]:
import mlflow
from mlflow.tracking import MlflowClient
from mlflow.entities import ViewType
import pandas as pd, numpy as np, json, os, matplotlib.pyplot as plt, seaborn as sns

EXPERIMENT_NAME = "MultilabelPhotoTagPipeline"  # edit if needed

# No tracking URI is passed here, so MLflow's own default resolution applies.
client = MlflowClient()

# Prefer the named experiment; fall back to "Default" when it does not exist.
exp = client.get_experiment_by_name(EXPERIMENT_NAME) or client.get_experiment_by_name("Default")
if exp is None:
    # Both lookups came back empty: fail with a clear message instead of an
    # AttributeError on `exp.experiment_id` in the search below.
    raise RuntimeError(
        f"Neither experiment {EXPERIMENT_NAME!r} nor 'Default' was found on the tracking server."
    )

# Ask only for the single most recently started active run.
runs = client.search_runs(
    [exp.experiment_id],
    run_view_type=ViewType.ACTIVE_ONLY,
    order_by=["attributes.start_time DESC"],
    max_results=1,
)
if not runs:
    raise RuntimeError(f"No runs in experiment {exp.name}.")

# `client` and `run_id` are reused by every later cell to fetch artifacts.
run = runs[0]
run_id = run.info.run_id
print("Using run_id:", run_id)

In [None]:
# ---------- 1  Classification metrics ----------
# Download the per-label report logged with the run into the working directory.
report_path = client.download_artifacts(run_id, "classification_report.json", ".")

# Use a context manager so the file handle is closed deterministically, and
# decode explicitly as UTF-8 rather than relying on the locale default.
with open(report_path, encoding="utf-8") as report_file:
    report = json.load(report_file)

# Top-level JSON keys become columns; transpose so each of them is a row.
metrics_df = pd.DataFrame(report).T
metrics_df.index.name = "Label"
display(metrics_df)

In [None]:
# ---------- 2  Confusion matrix heat-map ----------
# Best-effort cell: the image is an optional artifact, so any failure is
# reported as a message instead of aborting the notebook.
try:
    confusion_png = client.download_artifacts(run_id, "confusion_matrix.png", ".")
    # Decode the image before opening a figure so a failed read leaves no empty plot.
    pixels = plt.imread(confusion_png)
    plt.figure(figsize=(6, 6))
    plt.imshow(pixels)
    plt.title("Confusion Matrix")
    # The PNG is already a rendered chart; hide the outer axes around it.
    plt.axis("off")
    plt.show()
except Exception as err:
    print("Confusion matrix not logged:", err)

In [None]:
# ---------- 3  ROC curve (micro-average) ----------
# Best-effort cell: the ROC data is an optional artifact, so any failure is
# reported as a message instead of aborting the notebook.
try:
    roc_path = client.download_artifacts(run_id, "roc_curve.json", ".")
    # Context manager closes the handle even if JSON parsing fails.
    with open(roc_path, encoding="utf-8") as roc_file:
        roc_data = json.load(roc_file)
    fpr, tpr = np.array(roc_data["fpr"]), np.array(roc_data["tpr"])

    # NumPy 2.0 renamed `trapz` to `trapezoid`; prefer the new name when it
    # exists and fall back to the old one on earlier NumPy versions.
    integrate = getattr(np, "trapezoid", None) or np.trapz
    # Area under the curve: integrate TPR over FPR with the trapezoidal rule.
    auc_score = integrate(tpr, fpr)

    plt.figure(figsize=(6, 5))
    plt.plot(fpr, tpr, label=f"Micro ROC (AUC={auc_score:.2f})")
    # Diagonal reference line, labelled "Chance" in the legend.
    plt.plot([0, 1], [0, 1], 'k--', label="Chance")
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title("ROC Curve (micro-average)")
    plt.legend()
    plt.grid(True)
    plt.show()
except Exception as e:
    print("ROC curve not logged:", e)

In [None]:
# ---------- 4  Loss curves ----------
# Best-effort cell: the training history is an optional artifact, so any
# failure is reported as a message instead of aborting the notebook.
try:
    hist_path = client.download_artifacts(run_id, "history.json", ".")
    # Context manager closes the handle even if JSON parsing fails.
    with open(hist_path, encoding="utf-8") as hist_file:
        history = json.load(hist_file)

    # One x-axis tick per recorded training-loss entry, numbered from 1.
    epochs = range(1, len(history["loss"]) + 1)
    plt.figure(figsize=(6, 4))
    plt.plot(epochs, history["loss"], label="Train Loss")
    # Validation loss is optional; plot it only when it was recorded.
    if "val_loss" in history:
        plt.plot(epochs, history["val_loss"], label="Val Loss")
    plt.xlabel("Epoch")
    plt.ylabel("Loss")
    plt.title("Training vs Validation Loss")
    plt.legend()
    plt.grid(True)
    plt.show()
except Exception as e:
    print("History not logged:", e)

In [None]:
# ---------- 5  Micro / Macro / Weighted averages ----------
# Show each aggregate row that is present in the metrics table, one
# single-row frame at a time; rows missing from the table are skipped.
present_averages = [
    row_name
    for row_name in ("micro avg", "macro avg", "weighted avg")
    if row_name in metrics_df.index
]
for row_name in present_averages:
    # A list selector keeps the result a DataFrame rather than a Series.
    display(metrics_df.loc[[row_name]])