In [None]:
# --- Cell 1: Imports & paths ---
import json
import joblib
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt

from sklearn.metrics import (
    confusion_matrix, accuracy_score, balanced_accuracy_score,
    f1_score, mean_squared_error, r2_score, classification_report
)

from texas_gerrymandering_hb4.config import IMAGES_DIR

# Directory that holds the evaluation inputs read below and receives the
# outputs written below. It must be a Path (not a str): `.mkdir` is a Path
# method and the later cells build file paths with the `/` operator.
ART_DIR = Path("artifacts")
ART_DIR.mkdir(parents=True, exist_ok=True)

In [None]:
# --- Cell 2: Load test split, model & threshold ---
# Test features, plus the "party" column of the stored target frame.
X_test = pd.read_parquet(ART_DIR / "X_test.parquet")
y_test = pd.read_parquet(ART_DIR / "y_test.parquet")["party"]

# Model object persisted with joblib under a fixed file name.
pipeline = joblib.load(ART_DIR / "active_model.pkl")

# Threshold metadata: "threshold" is required; "variant" is optional and
# falls back to "unknown" when the key is absent.
thr_info = json.loads((ART_DIR / "train_threshold.json").read_text())
variant = thr_info.get("variant", "unknown")
threshold = float(thr_info["threshold"])
print(f"Evaluating model variant: {variant} with threshold={threshold:.3f}")

In [None]:
# --- Cell 3: Predict & classify ---
# Bound the raw model output to [0, 1], then label every score that reaches
# the loaded threshold as class 1 (everything else becomes 0).
y_score = np.clip(pipeline.predict(X_test), 0, 1)
y_pred = (y_score >= threshold).astype(int)

In [None]:
# --- Cell 4: Metrics & report ---
# Pin the class ids so the confusion matrix is always 2x2 and the report always
# has one row per entry in `target_names`. Without `labels=`, scikit-learn sizes
# both from the classes it actually observes, so a test split (or a set of
# predictions) containing a single class would yield a 1x1 matrix and make
# `classification_report` reject the two target names.
class_ids = [0, 1]

# Label-based metrics use the thresholded predictions (`y_pred`); MSE and R²
# are computed on the continuous scores (`y_score`).
acc = accuracy_score(y_test, y_pred)
bal = balanced_accuracy_score(y_test, y_pred)
f1  = f1_score(y_test, y_pred, zero_division=0)
mse = mean_squared_error(y_test, y_score)
r2  = r2_score(y_test, y_score)
cm  = confusion_matrix(y_test, y_pred, labels=class_ids)

print(f"Accuracy          : {acc:.3f}")
print(f"Balanced Accuracy : {bal:.3f}")
print(f"F1 (Dem=1)        : {f1:.3f}")
print(f"MSE               : {mse:.4f}")
print(f"R²                : {r2:.4f}")
print("Confusion Matrix:\n", cm)

# Persist the scalar metrics; cast to built-in floats so the JSON content does
# not depend on the NumPy scalar types returned by scikit-learn.
with open(ART_DIR / "metrics.json", "w", encoding="utf-8") as f:
    json.dump({
        "variant": variant,
        "threshold": threshold,
        "accuracy": float(acc),
        "balanced_accuracy": float(bal),
        "f1": float(f1),
        "mse": float(mse),
        "r2": float(r2),
    }, f, indent=2)

# Per-class precision/recall/F1 table, saved as CSV (one column per class or
# aggregate, one row per metric).
report = classification_report(
    y_test,
    y_pred,
    labels=class_ids,
    target_names=["Republican(0)", "Democrat(1)"],
    output_dict=True,
    zero_division=0,
)
pd.DataFrame(report).to_csv(ART_DIR / "classification_report.csv")


## Confusion Matrix

In [None]:
def save_confusion_matrix(cm, path, labels=("Rep(0)", "Dem(1)")):
    """Render a confusion matrix as an annotated heatmap and save it to ``path``.

    Parameters
    ----------
    cm : array-like of shape (n_classes, n_classes)
        Confusion-matrix counts. Rows are drawn along the y-axis and columns
        along the x-axis; the axes are labelled following scikit-learn's
        ``confusion_matrix`` layout (rows = true class, columns = predicted).
    path : str or path-like
        Destination handed to ``Figure.savefig`` (written at 200 dpi).
    labels : sequence of str, optional
        Tick labels, one per class, used on both axes.

    Raises
    ------
    ValueError
        If ``cm`` is not a square 2-D array, or if ``labels`` does not have
        exactly one entry per class.
    """
    cm = np.asarray(cm)
    if cm.ndim != 2 or cm.shape[0] != cm.shape[1]:
        raise ValueError(f"cm must be a square 2-D array, got shape {cm.shape}")
    n_classes = cm.shape[0]
    if len(labels) != n_classes:
        raise ValueError(
            f"expected {n_classes} labels to match cm, got {len(labels)}"
        )

    tick_positions = list(range(n_classes))
    fig, ax = plt.subplots()
    try:
        ax.imshow(cm, interpolation="nearest")
        ax.set_title("Confusion Matrix (Test)")
        ax.set_xlabel("Predicted label")
        ax.set_ylabel("True label")
        ax.set_xticks(tick_positions)
        ax.set_xticklabels(labels)
        ax.set_yticks(tick_positions)
        ax.set_yticklabels(labels)
        # Write each count in the centre of its cell; x is the column index
        # and y is the row index.
        for (row, col), count in np.ndenumerate(cm):
            ax.text(col, row, str(count), ha="center", va="center")
        fig.tight_layout()
        fig.savefig(path, dpi=200)
    finally:
        # Always release the figure, even if rendering or saving fails.
        plt.close(fig)

# Write the confusion-matrix figure into the artifacts directory.
confusion_png = ART_DIR / "confusion_matrix.png"
save_confusion_matrix(cm, confusion_png)

print("Evaluation complete. Saved metrics, report, and confusion matrix.")