In [None]:
import os, glob, json
import numpy as np
import matplotlib.pyplot as plt

from sklearn.metrics import (
    ConfusionMatrixDisplay,
    roc_curve, auc,
    precision_recall_curve,
    average_precision_score,
    f1_score
)

def latest_file(pattern: str):
    """Return the most recently modified path matching a glob pattern.

    Args:
        pattern: Glob expression passed to ``glob.glob``.

    Returns:
        The matching path with the greatest modification time, or ``None``
        when nothing matches.
    """
    matches = glob.glob(pattern)
    if matches:
        # Ordered oldest -> newest by mtime; the final entry is the newest.
        oldest_first = sorted(matches, key=os.path.getmtime)
        return oldest_first[-1]
    return None

def show_with_explanation(title: str, explanation: str, save_path: str = None):
    """Title, optionally save, and display the current pyplot figure, then print a caption.

    Args:
        title: Text applied with ``plt.title``.
        explanation: Caption printed after the figure has been shown.
        save_path: When truthy, the figure is written to this path (dpi=200,
            tight bounding box) before ``plt.show()`` is called.
    """
    plt.title(title)
    plt.tight_layout()

    wants_file = bool(save_path)
    if wants_file:
        plt.savefig(save_path, dpi=200, bbox_inches="tight")

    plt.show()

    # Caption followed by a 90-character rule separating consecutive figures.
    rule = "-" * 90
    print(explanation, rule, sep="\n")

def safe_bar_labels(ax):
    """Annotate every bar on ``ax`` with its height, formatted to 3 decimals.

    The label is anchored at the bar's free end and drawn outside the bar:
    above the top of a non-negative bar, below the bottom of a negative one
    (e.g. a negative R²). Bars whose height is NaN or infinite are left
    unlabeled because there is no finite position to anchor the text to.

    Args:
        ax: Object exposing ``patches`` (each with ``get_height``, ``get_x``
            and ``get_width``) and an ``annotate`` method.
    """
    for bar in ax.patches:
        height = bar.get_height()
        if not np.isfinite(height):
            continue
        x_center = bar.get_x() + bar.get_width() / 2
        # For a negative bar, "bottom" alignment would put the text inside the bar.
        valign = "bottom" if height >= 0 else "top"
        ax.annotate(f"{height:.3f}", (x_center, height), ha="center", va=valign)


# Locate the newest training artifacts matching each pattern.
# NOTE(review): the two files are picked independently by modification time,
# so they could come from different runs — confirm they belong together.
RESULTS_JSON = latest_file("results_torch_*.json")
PREDS_NPZ    = latest_file("preds_torch_*.npz")

if RESULTS_JSON is None:
    raise FileNotFoundError("No results_torch_*.json found. Run the training cell first.")
if PREDS_NPZ is None:
    raise FileNotFoundError("No preds_torch_*.npz found. Run training with: python -m src.train_torch --save_pred")

with open(RESULTS_JSON, "r", encoding="utf-8") as f:
    R = json.load(f)

# np.load on an .npz returns an NpzFile that keeps the archive open until it
# is closed. Copy every array out inside a context manager so the file handle
# is released; P stays a mapping with the same keys as before.
with np.load(PREDS_NPZ) as _npz:
    P = {key: _npz[key] for key in _npz.files}

# Classification arrays saved by the training run.
y_test = P["y_test"].astype(int)
proba  = P["proba_test"].astype(float)
pred   = P["pred_test"].astype(int)

# Regression targets and predictions.
Y_test = P["Y_test"].astype(float)      # (n,3)
Y_pred = P["Y_pred_test"].astype(float) # (n,3)

# Per-task sections of the results JSON used by the plots below.
cls = R["results"]["classification"]
reg = R["results"]["regression"]

print("Loaded:", RESULTS_JSON, "and", PREDS_NPZ)

# Output folder
OUT_DIR = "figures"
os.makedirs(OUT_DIR, exist_ok=True)


# Figure 01: classifier loss per epoch, training vs. validation.
plt.figure()
for _hist_key, _legend in (("train_loss", "Train loss"), ("val_loss", "Validation loss")):
    plt.plot(cls["train_history"][_hist_key], label=_legend)
plt.xlabel("Epoch")
plt.ylabel("Loss")
plt.legend()

show_with_explanation(
    "Classification — Learning Curves (Train vs Validation Loss)",
    (
        "Purpose: Diagnose optimization and overfitting/underfitting for the classifier.\n"
        "How to read: If train loss keeps decreasing while validation loss stops improving or increases, "
        "the model is overfitting. If both plateau early at high values, the model may be underfitting.\n"
        "Idea: Early stopping typically triggers when validation stops improving, preventing overfit."
    ),
    os.path.join(OUT_DIR, "01_classification_learning_curves.png"),
)


# Figure 02: regressor loss per epoch, training vs. validation.
plt.figure()
for _hist_key, _legend in (("train_loss", "Train loss"), ("val_loss", "Validation loss")):
    plt.plot(reg["train_history"][_hist_key], label=_legend)
plt.xlabel("Epoch")
plt.ylabel("Loss (MSE on standardized targets)")
plt.legend()

show_with_explanation(
    "Regression — Learning Curves (Train vs Validation Loss)",
    (
        "Purpose: Diagnose optimization and generalization for the multi-output regressor.\n"
        "How to read: Loss here is MSE computed on standardized targets, so it is comparable across outputs.\n"
        "Idea: A stable gap (train < val) is normal. A growing gap indicates overfitting; "
        "both flat high indicates underfitting or insufficient features."
    ),
    os.path.join(OUT_DIR, "02_regression_learning_curves.png"),
)

# Figure 03: confusion matrix read from the stored classifier results.
cm = np.array(cls["mlp"]["confusion_matrix"])
fig, ax = plt.subplots()
_cm_display = ConfusionMatrixDisplay(cm, display_labels=["Non-viral (0)", "Viral (1)"])
_cm_display.plot(ax=ax, values_format="d")

show_with_explanation(
    "Classification — Confusion Matrix (Test Set)",
    (
        "Purpose: Show the types of classification errors.\n"
        "How to read: Rows = true class, Columns = predicted class.\n"
        " - Top-left (TN): correctly predicted non-viral.\n"
        " - Bottom-right (TP): correctly predicted viral.\n"
        " - Top-right (FP): predicted viral but actually non-viral.\n"
        " - Bottom-left (FN): predicted non-viral but actually viral.\n"
        "Idea: For imbalanced data, this is more informative than accuracy alone."
    ),
    os.path.join(OUT_DIR, "03_confusion_matrix.png"),
)

# Figure 04: histogram (30 bins) of the predicted 'viral' probabilities.
plt.figure()
plt.hist(proba, bins=30)
_hist_axes = plt.gca()
_hist_axes.set_xlabel("Predicted probability of 'viral'")
_hist_axes.set_ylabel("Number of samples")

show_with_explanation(
    "Classification — Distribution of Predicted Probabilities (Test)",
    (
        "Purpose: Check model confidence and calibration behavior.\n"
        "How to read: A strong model often separates probabilities toward 0 and 1. "
        "If most values cluster near 0.5, the classifier is uncertain.\n"
        "Idea: This helps justify threshold tuning (choosing a decision cutoff different from 0.5)."
    ),
    os.path.join(OUT_DIR, "04_probability_histogram.png"),
)

# Figure 05: ROC curve with its area, plus the chance diagonal for reference.
fpr, tpr, _ = roc_curve(y_test, proba)
roc_auc = auc(fpr, tpr)

plt.figure()
plt.plot(fpr, tpr, label="AUC = {:.3f}".format(roc_auc))
plt.plot((0, 1), (0, 1), linestyle="--", label="Random baseline")
_roc_axes = plt.gca()
_roc_axes.set_xlabel("False Positive Rate")
_roc_axes.set_ylabel("True Positive Rate")
plt.legend()

show_with_explanation(
    "Classification — ROC Curve (Test)",
    (
        "Purpose: Evaluate ranking quality across all thresholds.\n"
        "How to read: The curve shows TPR vs FPR for every possible threshold. "
        "AUC summarizes the curve: 0.5 = random, 1.0 = perfect ranking.\n"
        "Idea: ROC is useful, but with class imbalance, Precision–Recall can be more informative."
    ),
    os.path.join(OUT_DIR, "05_ROC_curve.png"),
)

# Figure 06: precision-recall curve, labelled with average precision.
prec, rec, _ = precision_recall_curve(y_test, proba)
ap = average_precision_score(y_test, proba)

plt.figure()
plt.plot(rec, prec, label="Average Precision (AP) = {:.3f}".format(ap))
_pr_axes = plt.gca()
_pr_axes.set_xlabel("Recall")
_pr_axes.set_ylabel("Precision")
plt.legend()

show_with_explanation(
    "Classification — Precision–Recall Curve (Test)",
    (
        "Purpose: Evaluate performance on the positive (viral) class under imbalance.\n"
        "How to read: Precision measures how many predicted virals are correct; Recall measures "
        "how many true virals are found.\n"
        "Idea: PR curves are often preferred for imbalanced datasets because they focus on the positive class."
    ),
    os.path.join(OUT_DIR, "06_PR_curve.png"),
)

# Figure 07: macro-F1 on the test set for 19 evenly spaced cutoffs in [0.05, 0.95].
thresholds = np.linspace(0.05, 0.95, 19)
f1s = []
for _cutoff in thresholds:
    # Hard labels obtained by thresholding the predicted probabilities.
    _hard_labels = (proba >= _cutoff).astype(int)
    f1s.append(f1_score(y_test, _hard_labels, average="macro"))

plt.figure()
plt.plot(thresholds, f1s, marker="o")
_f1_axes = plt.gca()
_f1_axes.set_xlabel("Decision threshold")
_f1_axes.set_ylabel("Macro F1")

show_with_explanation(
    "Classification — Macro-F1 as a Function of the Threshold (Test)",
    (
        "Purpose: Visualize why the chosen threshold may not be 0.5.\n"
        "How to read: The peak indicates the threshold that best balances classes under macro-F1.\n"
        "Idea: This supports your project decision to tune the threshold on validation (and then apply it to test)."
    ),
    os.path.join(OUT_DIR, "07_F1_vs_threshold.png"),
)

# Figures 08/09: one scatter plot and one residual histogram per regression target.
target_names = ["y1 (log1p(shares))", "y2 (log1p(capped shares))", "y3 (percentile score)"]

for j, name in enumerate(target_names):
    # Column j of each matrix is paired with the j-th display name.
    yt, yp = Y_test[:, j], Y_pred[:, j]
    _fig_no = j + 1  # 1-based suffix used in the saved file names

    # True vs Pred
    plt.figure()
    plt.scatter(yt, yp, s=10)
    mn = min(yt.min(), yp.min())
    mx = max(yt.max(), yp.max())
    # Reference diagonal spanning the joint range of true and predicted values.
    _diagonal = [mn, mx]
    plt.plot(_diagonal, _diagonal, linestyle="--", label="Ideal: y = x")
    plt.xlabel("True value")
    plt.ylabel("Predicted value")
    plt.legend()
    show_with_explanation(
        f"Regression — {name} — True vs Predicted (Test)",
        (
            "Purpose: Check how close predictions are to ground truth.\n"
            "How to read: Points close to the diagonal mean good predictions. Systematic deviations "
            "indicate bias (e.g., underpredicting high values).\n"
            "Idea: For heavy-tailed original shares, using log targets makes this plot much more learnable."
        ),
        os.path.join(OUT_DIR, f"08_reg_true_vs_pred_{_fig_no}.png"),
    )

    # Residuals
    res = yt - yp
    plt.figure()
    plt.hist(res, bins=40)
    plt.xlabel("Residual (true − predicted)")
    plt.ylabel("Frequency")

    show_with_explanation(
        f"Regression — {name} — Residual Distribution (Test)",
        (
            "Purpose: Diagnose error structure.\n"
            "How to read: Ideally residuals center around 0 and are roughly symmetric. "
            "Skew or long tails indicate that some ranges are harder.\n"
            "Idea: Residual analysis helps explain limitations: popularity depends on external factors not in features."
        ),
        os.path.join(OUT_DIR, f"09_reg_residuals_{_fig_no}.png"),
    )

# Figure 10: bar chart of the stored R² value for each regression output.
r2_y1, r2_y2, r2_y3 = (reg["mlp"][_r2_key] for _r2_key in ("r2_y1", "r2_y2", "r2_y3"))

plt.figure()
ax = plt.gca()
_bar_names = ["y1", "y2", "y3"]
_bar_values = [r2_y1, r2_y2, r2_y3]
ax.bar(_bar_names, _bar_values)
ax.set_ylabel("R² (Test)")
safe_bar_labels(ax)

show_with_explanation(
    "Regression — R² per Output (Test)",
    (
        "Purpose: Summarize how much variance is explained for each regression target.\n"
        "How to read: R² = 1 is perfect; 0 means no better than predicting the mean; negative means worse than mean.\n"
        "Idea: With social popularity data, modest R² is common; showing improvements from target engineering is key."
    ),
    os.path.join(OUT_DIR, "10_reg_R2_per_output.png"),
)

print(f"\nAll figures saved to: {OUT_DIR}/")


In [2]:
import os, glob, json
from pathlib import Path
import numpy as np
import matplotlib.pyplot as plt

# Folder that receives every comparison figure; created if it does not exist.
OUT = Path("figures")
OUT.mkdir(exist_ok=True)

def latest(pattern):
    """Return the newest (by modification time) path matching ``pattern``, or ``None``."""
    candidates = glob.glob(pattern)
    if not candidates:
        return None
    # In-place stable sort, oldest first; the last element is the newest file.
    candidates.sort(key=os.path.getmtime)
    return candidates[-1]

# Prefer each framework's seed-42 results file so the comparison uses matching
# seeds. If that exact file is absent, fall back to the framework's newest
# results file instead of silently dropping it from the comparison.
paths = {
    name: latest(f"results_{name}_42.json") or latest(f"results_{name}_*.json")
    for name in ("sklearn", "tf", "torch")
}

# Load available results
R = {}
for k, p in paths.items():
    if p is None:
        continue
    with open(p, "r", encoding="utf-8") as f:
        R[k] = json.load(f)

if not R:
    # Report the directory that was searched: a wrong working directory is
    # indistinguishable from "training never ran" without it.
    raise FileNotFoundError(
        f"No results_*.json found in {os.getcwd()!r}. "
        "Run training scripts first (torch/tf/sklearn)."
    )

print("Loaded files:")
for k, p in paths.items():
    if p:
        print(f" - {k}: {p}")
print("Saving figures to:", OUT.resolve())

def get(d, *keys, default=None):
    """Walk nested dicts along ``keys`` and return the value found there.

    Args:
        d: Starting object, normally a dict.
        *keys: Keys to follow, outermost first.
        default: Value returned as soon as a step cannot be taken.

    Returns:
        The value at the end of the key path. ``default`` is returned if any
        intermediate object is not a ``dict`` or lacks the requested key.
        With no keys, ``d`` itself is returned.
    """
    node = d
    for key in keys:
        if isinstance(node, dict) and key in node:
            node = node[key]
        else:
            return default
    return node

def save_show(filename, title, xlabel, ylabel, explanation):
    """Label, save into ``OUT``, and display the current pyplot figure.

    Args:
        filename: File name (joined onto ``OUT``) the figure is saved under.
        title: Figure title.
        xlabel: X-axis label; skipped when falsy.
        ylabel: Y-axis label; skipped when falsy.
        explanation: Caption printed after the figure is shown.
    """
    plt.title(title)
    # Apply each axis label only when a non-empty one was supplied.
    for apply_label, text in ((plt.xlabel, xlabel), (plt.ylabel, ylabel)):
        if text:
            apply_label(text)
    plt.tight_layout()

    target = OUT / filename
    plt.savefig(target, dpi=200, bbox_inches="tight")
    plt.show()

    print(explanation)
    print("Saved:", target)
    print("-" * 90)

# -------- Collect metrics --------
# One entry per loaded framework, in the order the frameworks were loaded.
frameworks = list(R)

acc, f1m, cms = [], [], []
r2 = {_key: [] for _key in ("y1", "y2", "y3", "mean")}
rmse = {_key: [] for _key in ("y1", "y2", "y3")}

for fw in frameworks:
    cls = get(R[fw], "results", "classification", default={})
    reg = get(R[fw], "results", "regression", default={})

    # Classification: scalar scores default to NaN, the confusion matrix to None.
    for _bucket, _metric, _missing in (
        (acc, "accuracy", np.nan),
        (f1m, "f1_macro", np.nan),
        (cms, "confusion_matrix", None),
    ):
        _bucket.append(get(cls, "mlp", _metric, default=_missing))

    # Regression: result keys are "r2_<name>" / "rmse_<name>"; missing values become NaN.
    for _key, _series in r2.items():
        _series.append(get(reg, "mlp", f"r2_{_key}", default=np.nan))
    for _key, _series in rmse.items():
        _series.append(get(reg, "mlp", f"rmse_{_key}", default=np.nan))

def _score_bars(values, filename, title, ylabel, explanation):
    """Draw one bar per framework on a fixed 0-1 y-axis, then save and show it."""
    plt.figure()
    plt.bar(frameworks, values)
    plt.ylim(0, 1)
    save_show(filename, title, "Framework", ylabel, explanation)


# -------- 1) Accuracy comparison --------
_score_bars(
    acc,
    "CMP_01_accuracy.png",
    "Classification Accuracy (Test) — Framework Comparison",
    "Accuracy",
    "Purpose: overall % correct on the test set.\n"
    "Idea: accuracy can look good even when the dataset is imbalanced, so we also report Macro-F1.",
)

# -------- 2) Macro-F1 comparison --------
_score_bars(
    f1m,
    "CMP_02_macro_f1.png",
    "Classification Macro-F1 (Test) — Framework Comparison",
    "Macro-F1",
    "Purpose: balanced classification quality across both classes (viral and non-viral).\n"
    "Idea: Macro-F1 penalizes models that ignore the minority class; it’s more meaningful than accuracy here.",
)

# -------- 3) Confusion matrix per framework --------
for fw, cm in zip(frameworks, cms):
    # Frameworks whose results carry no confusion matrix are skipped.
    if cm is None:
        continue
    cm = np.array(cm, dtype=int)
    plt.figure()
    plt.imshow(cm, interpolation="nearest")
    plt.xticks([0, 1], ["Pred 0", "Pred 1"])
    plt.yticks([0, 1], ["True 0", "True 1"])
    # Write the count into each of the four cells (row i = true, column j = predicted).
    for i, j in np.ndindex(2, 2):
        plt.text(j, i, str(cm[i, j]), ha="center", va="center")
    plt.colorbar()
    save_show(
        filename=f"CMP_03_confusion_{fw}.png",
        title=f"Confusion Matrix (Test) — {fw}",
        xlabel="",
        ylabel="",
        explanation=(
            "Purpose: shows exactly where the classifier makes mistakes.\n"
            "Idea: Top-right = false positives (false viral alerts); bottom-left = false negatives (missed viral)."
        ),
    )

# -------- 4) Regression R² per output --------
# Grouped bars: one group per framework, three bars of width w side by side.
x = np.arange(len(frameworks))
w = 0.25
plt.figure()
for _pos, _key, _legend in ((x - w, "y1", "R² y1"), (x, "y2", "R² y2"), (x + w, "y3", "R² y3")):
    plt.bar(_pos, r2[_key], width=w, label=_legend)
plt.xticks(x, frameworks)
plt.legend()
save_show(
    filename="CMP_04_r2_per_output.png",
    title="Regression R² per Output (Test) — Framework Comparison",
    xlabel="Framework",
    ylabel="R²",
    explanation=(
        "Purpose: measures how much variance the model explains (higher is better).\n"
        "Idea: popularity is noisy, so modest R² is expected; compare frameworks mainly for consistency."
    ),
)

# -------- 5) Regression RMSE per output --------
plt.figure()
for _pos, _key, _legend in ((x - w, "y1", "RMSE y1"), (x, "y2", "RMSE y2"), (x + w, "y3", "RMSE y3")):
    plt.bar(_pos, rmse[_key], width=w, label=_legend)
plt.xticks(x, frameworks)
plt.legend()
save_show(
    filename="CMP_05_rmse_per_output.png",
    title="Regression RMSE per Output (Test) — Framework Comparison",
    xlabel="Framework",
    ylabel="RMSE",
    explanation=(
        "Purpose: shows error magnitude in the target units.\n"
        "Idea: after your y2 fix (log+cap), RMSE across outputs becomes comparable and training is more stable."
    ),
)

# -------- 6) Learning curves when present (TF/Torch) --------
for fw in frameworks:
    # Classification first, then regression; a task is skipped when its
    # history is missing or lacks either loss series.
    _tasks = (
        (
            "classification",
            f"LC_01_class_{fw}.png",
            f"Classification Learning Curves — {fw}",
            "Purpose: verifies learning dynamics and overfitting.\n"
            "Idea: if validation loss stops improving, early stopping prevents over-training.",
        ),
        (
            "regression",
            f"LC_02_reg_{fw}.png",
            f"Regression Learning Curves — {fw}",
            "Purpose: checks stability of training for the 3-output regressor.\n"
            "Idea: a growing gap (train much lower than val) indicates overfitting.",
        ),
    )
    for _task, _filename, _title, _caption in _tasks:
        _hist = get(R[fw], "results", _task, "train_history", default=None)
        if not _hist or "train_loss" not in _hist or "val_loss" not in _hist:
            continue
        plt.figure()
        plt.plot(_hist["train_loss"], label="Train loss")
        plt.plot(_hist["val_loss"], label="Val loss")
        plt.legend()
        save_show(_filename, _title, "Epoch", "Loss", _caption)

print("Done. All figures saved in:", OUT.resolve())


FileNotFoundError: No results_*.json found. Run training scripts first (torch/tf/sklearn).