In [26]:
import os, numpy as np, pandas as pd, torch, torch.nn as nn
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, balanced_accuracy_score, precision_recall_fscore_support, confusion_matrix

# -------------------
# CONFIG
# -------------------
X_PATH = "/content/X_real_256.csv"
Y_PATH = "/content/y_real_256.csv"
MODEL_PATH = "/content/ann_anisotropy_classifier.pth"
# optional normalization stats (if available)
STATS_CANDIDATES = [
    "/mnt/data/train_stats_256.npz",
    "/mnt/data/train_stats_real_256.npz",
]

VAL_TEST_SPLIT_SEED = 42

# -------------------
# Load data
# -------------------
X = pd.read_csv(X_PATH, header=None).to_numpy(np.float32)
y = pd.read_csv(Y_PATH, header=None).to_numpy(np.int64).ravel()
assert X.shape[1] == 256, f"Expected 256 features, got {X.shape[1]}"

# -------------------
# Apply training-time normalization if available
# -------------------
mean, std = None, None
for stats_path in STATS_CANDIDATES:
    if os.path.exists(stats_path):
        try:
            stats = np.load(stats_path)
            if "mean" in stats and "std" in stats:
                mean = stats["mean"].astype(np.float32).reshape(1, -1)
                std  = stats["std"].astype(np.float32).reshape(1, -1)
                print(f"[STATS] Loaded normalization from {stats_path}")
                break
        except Exception as e:
            print(f"[WARN] Could not read {stats_path}: {e}")

def apply_normalization(Xarr):
    if mean is not None and std is not None:
        safe_std = np.where(std == 0, 1.0, std)
        return (Xarr - mean) / safe_std
    # fallback: scale to zero-mean/unit-var using dataset stats (inference-only)
    m = Xarr.mean(axis=0, keepdims=True)
    s = Xarr.std(axis=0, keepdims=True)
    s[s == 0] = 1.0
    return (Xarr - m) / s

Xn = apply_normalization(X)

# -------------------
# Split
# -------------------
X_train, X_temp, y_train, y_temp = train_test_split(
    Xn, y, test_size=0.30, random_state=VAL_TEST_SPLIT_SEED, stratify=y
)
X_val, X_test, y_val, y_test = train_test_split(
    X_temp, y_temp, test_size=0.50, random_state=VAL_TEST_SPLIT_SEED, stratify=y_temp
)

# -------------------
# Helpers: generic MLP to reconstruct shapes if needed
# -------------------
class GenericMLP(nn.Module):
    def __init__(self, sizes, use_bn=False, p=0.0):
        """sizes: [in, h1, h2, ..., out]"""
        super().__init__()
        layers = []
        for i in range(len(sizes)-1):
            layers += [nn.Linear(sizes[i], sizes[i+1])]
            if i < len(sizes)-2:  # hidden layer
                if use_bn: layers += [nn.BatchNorm1d(sizes[i+1])]
                layers += [nn.ReLU()]
                if p > 0: layers += [nn.Dropout(p)]
        self.net = nn.Sequential(*layers)
    def forward(self, x): return self.net(x)

def reconstruct_from_state_dict(sd, input_dim=256, default_out=2):
    """Infer a chain of Linear layers from weight shapes (out_features, in_features)."""
    lin_weights = {k: v for k, v in sd.items() if k.endswith("weight") and v.ndim == 2}
    # Build mapping in_features -> (key, out_features)
    fmap = {}
    for k, w in lin_weights.items():
        out_f, in_f = w.shape
        fmap.setdefault(in_f, []).append((k, out_f))
    # Greedy chain starting from input_dim
    chain = [input_dim]
    cur_in = input_dim
    visited = set()
    while cur_in in fmap and len(chain) < 6:  # cap depth for safety
        # choose the first candidate whose out isn't visited
        candidates = sorted(fmap[cur_in], key=lambda t: t[1], reverse=True)
        chosen = None
        for k, out_f in candidates:
            if (cur_in, out_f) not in visited:
                chosen = (k, out_f)
                break
        if chosen is None: break
        _, out_f = chosen
        chain.append(out_f)
        visited.add((cur_in, out_f))
        cur_in = out_f
        if out_f == default_out: break
    # ensure last is default_out
    if chain[-1] != default_out:
        chain.append(default_out)
    # create model
    model = GenericMLP(chain, use_bn=False, p=0.0)
    missing = model.load_state_dict(sd, strict=False)
    print(f"[REBUILD] Arch inferred: {chain} | missing keys: {len(missing.missing_keys)} | unexpected: {len(missing.unexpected_keys)}")
    return model

def try_load_model(path, input_dim=256):
    device = torch.device("cpu")
    # 1) Try TorchScript
    try:
        m = torch.jit.load(path, map_location=device)
        m.eval()
        print("[LOAD] Loaded TorchScript model.")
        return m
    except Exception as e:
        print(f"[LOAD] TorchScript load failed: {e}")
    # 2) Try full pickled nn.Module
    try:
        obj = torch.load(path, map_location=device)
        if isinstance(obj, nn.Module):
            obj.eval()
            print("[LOAD] Loaded pickled nn.Module.")
            return obj
        # 3) state_dict path
        if isinstance(obj, dict):
            sd = obj.get("state_dict", obj)
            if not isinstance(sd, dict):
                raise ValueError("Unrecognized checkpoint format.")
            # Try to reconstruct a compatible MLP
            model = reconstruct_from_state_dict(sd, input_dim=input_dim, default_out=2)
            model.eval()
            return model
        raise ValueError("Unsupported checkpoint object type.")
    except Exception as e:
        raise RuntimeError(f"Could not load model from {path}: {e}")

model = try_load_model(MODEL_PATH, input_dim=256)

# -------------------
# Inference helpers
# -------------------
@torch.no_grad()
def predict_proba(Xnp):
    X_t = torch.from_numpy(Xnp.astype(np.float32))
    logits = model(X_t)
    if logits.ndim == 1 or logits.shape[1] == 1:
        # single-logit convention -> probability of class 1 via sigmoid
        p1 = torch.sigmoid(logits.view(-1)).numpy()
    else:
        p = torch.softmax(logits, dim=1).numpy()
        p1 = p[:, 1]
    return p1  # probability of class 1

def metrics_from_preds(y_true, y_prob, threshold=None):
    if threshold is None:
        # default argmax equivalent: threshold=0.5 for binary probs
        threshold = 0.5
    y_pred = (y_prob >= threshold).astype(int)
    acc  = accuracy_score(y_true, y_pred)
    bacc = balanced_accuracy_score(y_true, y_pred)
    p, r, f1, _ = precision_recall_fscore_support(y_true, y_pred, labels=[0,1], average=None)
    cm = confusion_matrix(y_true, y_pred, labels=[0,1])
    return acc, bacc, p, r, f1, cm

# -------------------
# Threshold tuning on validation
# -------------------
p_val = predict_proba(X_val)

best_t, best_val_acc = 0.5, -np.inf
candidates = np.linspace(0.05, 0.95, 181)  # fine sweep
val_accs = []
for t in candidates:
    acc, _, _, _, _, _ = metrics_from_preds(y_val, p_val, threshold=t)
    val_accs.append(acc)
    if acc > best_val_acc:
        best_val_acc, best_t = acc, t

print(f"[TUNE] Best val acc {best_val_acc:.3f} at t={best_t:.3f}")
chosen_t = best_t

# -------------------
# Final test evaluation
# -------------------
p_test = predict_proba(X_test)
test_acc, test_bacc, p, r, f1, cm = metrics_from_preds(y_test, p_test, threshold=chosen_t)

print("\n=== TEST RESULTS (no training) ===")
print(f"Threshold used      : {chosen_t:.3f}")
print(f"Accuracy            : {test_acc:.3f}")
print(f"Balanced Accuracy   : {test_bacc:.3f}")
print(f"Precision (c0,c1)   : {p[0]:.3f}, {p[1]:.3f}")
print(f"Recall    (c0,c1)   : {r[0]:.3f}, {r[1]:.3f}")
print(f"F1        (c0,c1)   : {f1[0]:.3f}, {f1[1]:.3f}")
print("Confusion Matrix [rows=true, cols=pred] (labels: 0,1):")
print(cm)

# -------------------
# Save predictions (
# -------------------
p_all = predict_proba(Xn)
y_pred_all = (p_all >= chosen_t).astype(int)
out = pd.DataFrame({
    "y_true": y,
    "p_class1": p_all,
    "y_pred": y_pred_all
})
out.to_csv("preds_balanced.csv", index=False)
print("\n[SAVED] Predictions -> preds_balanced.csv")


[LOAD] TorchScript load failed: PytorchStreamReader failed locating file constants.pkl: file not found
[REBUILD] Arch inferred: [256, 16, 2] | missing keys: 4 | unexpected: 4


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize

[TUNE] Best val acc 0.723 at t=0.485 | Targeted ~80% -> pick t=0.485 (val_acc=0.723)

=== TEST RESULTS (no training) ===
Threshold used      : 0.485
Accuracy            : 0.723
Balanced Accuracy   : 0.723
Precision (c0,c1)   : 0.924, 0.652
Recall    (c0,c1)   : 0.487, 0.960
F1        (c0,c1)   : 0.638, 0.776
Confusion Matrix [rows=true, cols=pred] (labels: 0,1):
[[ 73  77]
 [  6 144]]

[SAVED] Predictions -> preds_balanced.csv


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
