In [None]:
import json
import joblib
import numpy as np
import pandas as pd
from pathlib import Path
import matplotlib.pyplot as plt

from sklearn.metrics import (
    confusion_matrix, accuracy_score, balanced_accuracy_score,
    f1_score, mean_squared_error, r2_score, classification_report
)

from texas_gerrymandering_hb4.config import LINEAR_REGRESSION_ARTIFACTS, IMAGES_DIR


In [None]:
# Model pipeline and the decision threshold stored next to it in the artifacts directory.
pipeline = joblib.load(LINEAR_REGRESSION_ARTIFACTS / "active_model.pkl")
threshold_path = LINEAR_REGRESSION_ARTIFACTS / "train_threshold.json"
with open(threshold_path) as threshold_file:
    thr_info = json.load(threshold_file)
threshold = thr_info["threshold"]

# Test features, plus the "party" column of the test targets as a Series.
X_test = pd.read_parquet(LINEAR_REGRESSION_ARTIFACTS / "X_test.parquet")
y_test = pd.read_parquet(LINEAR_REGRESSION_ARTIFACTS / "y_test.parquet")["party"]

# Bare expression: the notebook echoes the loaded threshold as this cell's output.
threshold


In [None]:
# Raw model output, bounded to [0, 1] before it is compared with the threshold.
raw_output = pipeline.predict(X_test)
y_score = raw_output.clip(0, 1)

# Scores at or above the threshold become class 1, everything else class 0.
meets_threshold = y_score >= threshold
y_pred = meets_threshold.astype(int)


In [None]:
# Pin the class order explicitly. Without `labels=`, scikit-learn infers the
# classes from the values present in y_test/y_pred, so a split where only one
# class shows up would yield a 1x1 matrix (breaking the 2x2 plot below) and a
# report whose rows no longer match the two `target_names`.
class_labels = [0, 1]

# Classification metrics use the thresholded labels ...
acc = accuracy_score(y_test, y_pred)
bal = balanced_accuracy_score(y_test, y_pred)
f1  = f1_score(y_test, y_pred, zero_division=0)
# ... while the regression metrics use the clipped continuous scores.
mse = mean_squared_error(y_test, y_score)
r2  = r2_score(y_test, y_score)
cm  = confusion_matrix(y_test, y_pred, labels=class_labels)

print(f"Accuracy          : {acc:.3f}")
print(f"Balanced Accuracy : {bal:.3f}")
print(f"F1 (Dem=1)        : {f1:.3f}")
print(f"MSE               : {mse:.4f}")
# Label was mojibake ("RÂ²"); restored to the intended superscript-two.
print(f"R²                : {r2:.4f}")
print("Confusion Matrix:\n", cm)

# Save metrics
with open(LINEAR_REGRESSION_ARTIFACTS / "metrics.json", "w") as f:
    json.dump({"accuracy":acc, "balanced_accuracy":bal, "f1":f1, "mse":mse, "r2":r2}, f, indent=2)

# Classification report CSV (same pinned class order as the confusion matrix)
pd.DataFrame(classification_report(
    y_test, y_pred, labels=class_labels,
    target_names=["Republican(0)", "Democrat(1)"], output_dict=True, zero_division=0
)).to_csv(LINEAR_REGRESSION_ARTIFACTS / "classification_report.csv")


In [None]:
def save_confusion_matrix(cm, path, labels=("Rep(0)", "Dem(1)")):
    """Draw a confusion matrix as an annotated heatmap and save it to ``path``.

    Parameters
    ----------
    cm : array-like, 2-D
        Matrix of counts. Rows are drawn along the y-axis and columns along
        the x-axis; the axis titles assume the scikit-learn layout
        (rows = true class, columns = predicted class).
    path : path-like or file-like
        Destination passed straight to ``Figure.savefig``.
    labels : sequence of str, optional
        Tick labels for the classes. They are used on an axis only when their
        count matches that axis' length; otherwise the numeric class index is
        shown, so matrices of any size can be plotted.

    Raises
    ------
    ValueError
        If ``cm`` is not two-dimensional.
    """
    matrix = np.asarray(cm)
    if matrix.ndim != 2:
        raise ValueError(f"cm must be 2-D, got shape {matrix.shape}")
    n_rows, n_cols = matrix.shape

    def tick_names(count):
        # Use the caller's labels only when they line up with the axis length.
        names = list(labels)
        return names if len(names) == count else [str(k) for k in range(count)]

    fig, ax = plt.subplots()
    try:
        ax.imshow(matrix, interpolation="nearest")
        ax.set_title("Confusion Matrix (Test)")
        ax.set_xlabel("Predicted")
        ax.set_ylabel("True")
        ax.set_xticks(list(range(n_cols)))
        ax.set_xticklabels(tick_names(n_cols))
        ax.set_yticks(list(range(n_rows)))
        ax.set_yticklabels(tick_names(n_rows))
        # Write each count in the centre of its cell (x = column, y = row).
        for i in range(n_rows):
            for j in range(n_cols):
                ax.text(j, i, str(matrix[i, j]), ha="center", va="center")
        fig.tight_layout()
        fig.savefig(path, dpi=200)
    finally:
        # Always release the figure, even if drawing or saving failed.
        plt.close(fig)

# Write the confusion-matrix figure into the project's images directory.
confusion_matrix_png = IMAGES_DIR / "confusion_matrix.png"
save_confusion_matrix(cm, confusion_matrix_png)
