# In [None]:  (cell marker left over from the Jupyter notebook export)
#!/usr/bin/env python3
"""
Breast Cancer (Wisconsin) — 1D CNN + Bar-Chart Race of Permutation Importances
Author: <your name>

What this script does:
- Trains a 1D CNN (not an MLP) on the Breast Cancer Wisconsin dataset.
- Every few epochs it computes the *permutation importance* of all 30 features
  on the validation set, using ROC AUC as the metric (custom implementation for PyTorch).
- Builds one horizontal bar chart (top-10 features) per checkpoint and combines them into a GIF,
  so you can watch the importances shift during training: a "bar-chart race".

Outputs (in outputs_bc_cnn_barrace/):
  - feature_importance_race.gif
  - metrics.txt (summary for your LinkedIn post)
  - _fi_<epoch>.png (intermediate frame images the GIF is assembled from)

Requirements: numpy, pandas, scikit-learn, matplotlib, torch, imageio
"""
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import imageio.v2 as imageio

from sklearn.datasets import load_breast_cancer
from sklearn.model_selection import train_test_split
from sklearn.preprocessing import StandardScaler
from sklearn.metrics import roc_auc_score, accuracy_score, confusion_matrix, classification_report

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import TensorDataset, DataLoader

# ---------------------------
# Config
# ---------------------------
RANDOM_STATE = 42         # single seed shared by the data splits, the permutations and torch
TEST_SIZE = 0.2           # fraction of all samples held out as the test set
VAL_SIZE = 0.2            # fraction of the remaining (non-test) samples used for validation
BATCH_SIZE = 64
EPOCHS = 80
LR = 1e-3
WEIGHT_DECAY = 1e-4
FRAME_EVERY = 2           # render one frame every N epochs
PERM_REPEATS = 5          # shuffles per feature when computing permutation importance
TOP_K = 10                # number of features shown in each bar chart

OUT_DIR = "outputs_bc_cnn_barrace"
os.makedirs(OUT_DIR, exist_ok=True)

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Seed torch's global RNG: without this, weight initialisation and the shuffling
# order of the training DataLoader change on every run, so the reported metrics
# and the generated GIF are not reproducible even though RANDOM_STATE is fixed.
torch.manual_seed(RANDOM_STATE)

# ---------------------------
# Data
# ---------------------------
data = load_breast_cancer()
X = pd.DataFrame(data=data.data, columns=data.feature_names)
y = pd.Series(data=data.target, name="target")  # 1=benign, 0=malignant

# First carve off the test set, then split what is left into train/validation.
# Both splits are stratified on the label and use the same seed.
X_temp, X_test, y_temp, y_test = train_test_split(
    X,
    y,
    test_size=TEST_SIZE,
    random_state=RANDOM_STATE,
    stratify=y,
)
X_train, X_val, y_train, y_val = train_test_split(
    X_temp,
    y_temp,
    test_size=VAL_SIZE,
    random_state=RANDOM_STATE,
    stratify=y_temp,
)

# Standardise with statistics estimated on the training split only, so that
# nothing from the validation/test data leaks into the preprocessing.
scaler = StandardScaler().fit(X_train)
X_train_s, X_val_s, X_test_s = (
    scaler.transform(split) for split in (X_train, X_val, X_test)
)

# Torch tensors
def to_tensor(x):
    """Copy *x* into a float32 tensor and insert a channel axis at position 1.

    A 2-D input of shape (N, F) therefore becomes (N, 1, F), which is the
    (batch, channels, length) layout expected by ``nn.Conv1d``.
    """
    as_float = torch.tensor(x, dtype=torch.float32)
    return as_float[:, None]

# Features: float32 tensors with an added channel axis (see to_tensor).
X_train_t, X_val_t, X_test_t = (
    to_tensor(split) for split in (X_train_s, X_val_s, X_test_s)
)

# Labels: float32 column vectors of shape (N, 1) so they line up with the
# single-logit output that BCEWithLogitsLoss is applied to.
y_train_t, y_val_t, y_test_t = (
    torch.tensor(labels.values, dtype=torch.float32).unsqueeze(1)
    for labels in (y_train, y_val, y_test)
)

# Only the training loader shuffles; evaluation loaders keep a fixed order.
train_loader = DataLoader(
    TensorDataset(X_train_t, y_train_t),
    batch_size=BATCH_SIZE,
    shuffle=True,
)
val_loader = DataLoader(
    TensorDataset(X_val_t, y_val_t),
    batch_size=BATCH_SIZE,
    shuffle=False,
)
test_loader = DataLoader(
    TensorDataset(X_test_t, y_test_t),
    batch_size=BATCH_SIZE,
    shuffle=False,
)

# ---------------------------
# Model (1D CNN)
# ---------------------------
class CNN1D(nn.Module):
    """Small 1D convolutional binary classifier.

    Input is a tensor of shape (batch, 1, length). Two same-padded
    convolutions (1 -> 16 -> 32 channels, kernel size 3) are followed by a
    global average pool over the length axis, so any length is accepted.
    The output has shape (batch, 1) and holds raw logits (no sigmoid).
    """

    def __init__(self):
        super().__init__()
        # Convolutional feature extractor ending in a global average pool.
        conv_stack = [
            nn.Conv1d(in_channels=1, out_channels=16, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv1d(in_channels=16, out_channels=32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.AdaptiveAvgPool1d(output_size=1),
        ]
        self.features = nn.Sequential(*conv_stack)

        # (batch, 32, 1) -> (batch, 32) -> (batch, 1)
        head = [
            nn.Flatten(),
            nn.Linear(in_features=32, out_features=1),
        ]
        self.classifier = nn.Sequential(*head)

    def forward(self, x):
        """Return one logit per sample for input *x* of shape (batch, 1, length)."""
        return self.classifier(self.features(x))

# Instantiate the network and move its parameters to the selected device
# (nn.Module.to works in place and returns the same module).
model = CNN1D()
model.to(DEVICE)

# The network emits raw logits, so the sigmoid is folded into the loss.
criterion = nn.BCEWithLogitsLoss()
optimizer = optim.Adam(
    params=model.parameters(),
    lr=LR,
    weight_decay=WEIGHT_DECAY,
)

# ---------------------------
# Helpers
# ---------------------------
@torch.no_grad()
def predict_proba_tensor(X_t):
    """Return the model's sigmoid outputs for every row of *X_t* as a flat array.

    The tensor is processed in slices of ``BATCH_SIZE`` rows; each slice is
    moved to ``DEVICE``, scored, and brought back to the CPU as NumPy. The
    module-level ``model`` is switched to eval mode as a side effect.
    """
    model.eval()
    chunks = [
        torch.sigmoid(model(X_t[start:start + BATCH_SIZE].to(DEVICE)))
        .cpu()
        .numpy()
        .ravel()
        for start in range(0, X_t.size(0), BATCH_SIZE)
    ]
    return np.concatenate(chunks)

@torch.no_grad()
def predict_proba(loader):
    """Score every batch yielded by *loader* with the module-level ``model``.

    Returns a pair ``(probabilities, labels)`` of flat NumPy arrays in the
    order the loader produced them. The model is switched to eval mode as a
    side effect.
    """
    model.eval()
    prob_chunks = []
    label_chunks = []
    for features, targets in loader:
        scores = torch.sigmoid(model(features.to(DEVICE)))
        prob_chunks.append(scores.cpu().numpy().ravel())
        label_chunks.append(targets.numpy().ravel())
    return np.concatenate(prob_chunks), np.concatenate(label_chunks)

def train_one_epoch():
    """Run one optimisation pass over ``train_loader`` and return the mean loss.

    Each batch loss is weighted by its batch size, so the returned value is
    the per-sample average over the whole training set even when the last
    batch is smaller than the others.
    """
    model.train()
    weighted_loss_sum = 0.0
    for features, targets in train_loader:
        features = features.to(DEVICE)
        targets = targets.to(DEVICE)

        optimizer.zero_grad()
        batch_loss = criterion(model(features), targets)
        batch_loss.backward()
        optimizer.step()

        weighted_loss_sum += batch_loss.item() * features.size(0)
    return weighted_loss_sum / len(train_loader.dataset)

def permutation_importance_nn(X_val_s_np, y_val_np, repeats=5):
    """
    Custom permutation importance for the PyTorch model.

    For every feature column, the column is shuffled ``repeats`` times, the
    model is re-scored, and the mean ROC AUC of those runs is subtracted from
    the unshuffled baseline AUC.

    Args:
        X_val_s_np: 2-D array-like of shape (n_samples, n_features) holding the
            scaled features to evaluate on.
        y_val_np: binary labels aligned with the rows of ``X_val_s_np``.
        repeats: number of shuffles per feature.

    Returns:
        ``(drops, base_auc)`` where ``drops`` is a float array of length
        n_features (positive value = the feature matters) and ``base_auc`` is
        the unshuffled ROC AUC. If the baseline AUC cannot be computed, every
        entry of ``drops`` and ``base_auc`` itself are NaN.

    The shuffling RNG is re-created from ``RANDOM_STATE`` on every call, so the
    same permutations are applied at every checkpoint.
    """
    X_val_s_np = np.asarray(X_val_s_np)
    _, n_features = X_val_s_np.shape

    # Baseline score on the untouched data.
    try:
        base_auc = roc_auc_score(y_val_np, predict_proba_tensor(to_tensor(X_val_s_np)))
    except ValueError:
        # AUC is undefined (roc_auc_score rejected the inputs). Every drop would
        # be NaN anyway, so skip the n_features * repeats forward passes.
        return np.full(n_features, np.nan, dtype=float), float("nan")

    drops = np.zeros(n_features, dtype=float)
    rng = np.random.default_rng(RANDOM_STATE)

    # One working copy is enough: only column j is overwritten, and it is
    # restored before moving on to the next feature.
    X_perm = X_val_s_np.copy()
    for j in range(n_features):
        original_col = X_val_s_np[:, j]
        aucs = []
        for _ in range(repeats):
            X_perm[:, j] = rng.permutation(original_col)
            probs = predict_proba_tensor(to_tensor(X_perm))
            try:
                aucs.append(roc_auc_score(y_val_np, probs))
            except ValueError:
                aucs.append(float("nan"))
        X_perm[:, j] = original_col

        # np.nanmean emits a RuntimeWarning on empty / all-NaN input; return
        # NaN directly in that case instead.
        mean_auc = np.nanmean(aucs) if np.isfinite(aucs).any() else float("nan")
        drops[j] = base_auc - mean_auc  # positive drop = more important
    return drops, base_auc

# ---------------------------
# Training + frames
# ---------------------------
frames = []                   # rendered images, filled when the charts are drawn
all_epochs_importances = []   # (epoch, drops, base_auc) per checkpoint
max_importance_seen = 0.0     # largest finite drop so far; fixes the x-axis later

for epoch in range(1, EPOCHS + 1):
    train_one_epoch()

    # Checkpoints: every FRAME_EVERY-th epoch plus the first and the last one.
    is_checkpoint = epoch % FRAME_EVERY == 0 or epoch in (1, EPOCHS)
    if not is_checkpoint:
        continue

    # Permutation importance on the validation set.
    drops, base_auc = permutation_importance_nn(X_val_s, y_val.values, repeats=PERM_REPEATS)
    all_epochs_importances.append((epoch, drops, base_auc))

    # Keep a running maximum so all frames can share one x-axis scale.
    finite_drops = drops[np.isfinite(drops)]
    if finite_drops.size:
        max_importance_seen = max(max_importance_seen, float(finite_drops.max()))

# Build frames for bar-chart-race
feature_names = np.array(data.feature_names)

# Every frame shares the same x-range (0 .. 105% of the largest drop seen) so
# the animation does not rescale between frames. Bars with a negative drop
# fall left of 0 and are therefore not visible when the fixed range is used.
right_lim = max_importance_seen * 1.05 if max_importance_seen > 0 else None
use_fixed_xlim = right_lim is not None and np.isfinite(right_lim)

for epoch, drops, base_auc in all_epochs_importances:
    # Indices of the TOP_K largest drops, in descending order.
    ranked = np.argsort(drops)[::-1][:TOP_K]
    # Re-order them ascending: barh draws position 0 at the bottom, so the
    # largest bar ends up on top.
    ascending = np.argsort(drops[ranked])
    bar_idx = ranked[ascending]
    top_vals = drops[bar_idx]
    top_names = feature_names[bar_idx]

    fig, ax = plt.subplots(figsize=(8, 5.5))
    positions = np.arange(len(top_vals))
    ax.barh(positions, top_vals)
    ax.set_yticks(positions)
    ax.set_yticklabels(top_names)
    ax.set_xlabel("AUC drop bij permutatie (belangrijkheid)")
    ax.set_title(f"Permutation Importances — Top {TOP_K} | Epoch {epoch} | Val AUC={base_auc:.3f}")
    if use_fixed_xlim:
        ax.set_xlim(0, right_lim)
    fig.tight_layout()

    # Render to a PNG on disk and read it back as an image array for the GIF.
    frame_path = os.path.join(OUT_DIR, f"_fi_{epoch:03d}.png")
    fig.savefig(frame_path, dpi=160)
    plt.close(fig)
    frames.append(imageio.imread(frame_path))

# Save GIF
gif_path = os.path.join(OUT_DIR, "feature_importance_race.gif")
imageio.mimsave(gif_path, frames, duration=0.18)

# ---------------------------
# Final test metrics
# ---------------------------
test_probs, test_labels = predict_proba(test_loader)

# Hard predictions at the 0.5 probability threshold.
test_pred = np.where(test_probs >= 0.5, 1, 0)

test_acc = accuracy_score(test_labels, test_pred)
cm_test = confusion_matrix(test_labels, test_pred)
report = classification_report(
    test_labels,
    test_pred,
    target_names=["malignant (0)", "benign (1)"],
)

# roc_auc_score raises ValueError when the AUC is undefined; report NaN then.
try:
    test_auc = roc_auc_score(test_labels, test_probs)
except ValueError:
    test_auc = float("nan")

caption = f"""
Borstkankermaand 🎗️ — 1D CNN + Bar-Chart-Race Feature Importances
- Dataset: Breast Cancer Wisconsin (569 cases, 30 features)
- Model: 1D CNN (geen MLP), permutation importance op validatieset met ROC AUC
- In de GIF springen de TOP-{TOP_K} features in rangorde naarmate training vordert.

Setup: epochs={EPOCHS}, batch_size={BATCH_SIZE}, lr={LR}, weight_decay={WEIGHT_DECAY}, device={DEVICE}
Test: accuracy={test_acc:.3f}, ROC AUC={test_auc:.3f}
Confusion matrix (test):\n{cm_test}

Reflectie:
Elke verschuiving vertelt welke signalen het model oppikt. Modellen ondersteunen; mensen beslissen.
"""

# Summary text followed by the full per-class report.
metrics_path = os.path.join(OUT_DIR, "metrics.txt")
with open(metrics_path, "w", encoding="utf-8") as f:
    f.write(caption)
    f.write("\n\nVolledig classification report:\n")
    f.write(report)

print("Done. Files saved to:", os.path.abspath(OUT_DIR))
print("GIF:", gif_path)
