In [None]:

!pip -q install -U tensorflow-datasets scikit-learn opencv-python-headless tqdm

# Imports
import random, numpy as np, matplotlib.pyplot as plt, torch, torch.nn as nn, cv2, tensorflow as tf, tensorflow_datasets as tfds
from torchvision.models import resnet18, ResNet18_Weights
from torchvision.transforms import ColorJitter
from sklearn.metrics import roc_auc_score, roc_curve
from sklearn.linear_model import LogisticRegression
from tqdm.auto import tqdm
# Hide GPUs from TensorFlow. In this script TF only feeds the tf.data/TFDS
# input pipeline; the model itself runs in PyTorch.
try:
    tf.config.set_visible_devices([], "GPU")
except (RuntimeError, ValueError) as exc:
    # RuntimeError: the TF runtime was already initialised (e.g. the cell was
    # re-run); ValueError: argument validation failed. This is best-effort, so
    # report the problem and carry on instead of swallowing it silently.
    print(f"Warning: could not hide GPUs from TensorFlow ({exc}); continuing.")
# Experiment configuration: optimisation settings, evaluation budgets,
# TSI sampling parameters, ColorJitter strengths and visualisation size.
CFG = {
    "seed": 42,
    "epochs": 2,
    "steps_per_epoch": 350,
    "lr": 1e-4,
    "weight_decay": 1e-4,
    "val_batches": 120,
    "test_batches": 120,
    "tsi_N": 2000,
    "tsi_m": 3,
    "tile_sizes": (8, 12, 16),
    "jitter": {"brightness": 0.4, "contrast": 0.4, "saturation": 0.2, "hue": 0.02},
    "viz_k": 5,
}

# Use CUDA when PyTorch can see it, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print("Device:", device)

# Smaller batches on CPU, larger ones on GPU.
bs = 64 if device.type != "cuda" else 256
print("Batch size:", bs)
# Seeds
def seed_everything(s):
    """Seed the Python, NumPy and PyTorch RNGs and request deterministic cuDNN.

    Args:
        s: Integer seed applied to every generator.
    """
    random.seed(s)
    np.random.seed(s)
    torch.manual_seed(s)
    torch.cuda.manual_seed_all(s)
    # cudnn.benchmark=True lets cuDNN auto-tune and pick kernels per run, which
    # can change results between runs and defeats the seeding above. Turn it off
    # and ask for deterministic kernels instead.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
seed_everything(CFG["seed"])
# ----------------------------
# Load PCam from TFDS
# ----------------------------
print("Loading PCam from TFDS...")
# One (image, label) dataset per split; only the training split shuffles its files.
train_tf, val_tf, test_tf = [
    tfds.load("patch_camelyon", split=split_name, as_supervised=True, shuffle_files=do_shuffle)
    for split_name, do_shuffle in (("train", True), ("validation", False), ("test", False))
]
AUTOTUNE = tf.data.AUTOTUNE
# Training stream: seeded sample-level shuffle (re-drawn on every pass), then batch.
train_tf = train_tf.shuffle(4096, seed=CFG["seed"], reshuffle_each_iteration=True)
train_tf = train_tf.batch(bs).prefetch(AUTOTUNE)
# Validation/test streams are only batched and prefetched (no shuffling).
val_tf, test_tf = (ds.batch(bs).prefetch(AUTOTUNE) for ds in (val_tf, test_tf))
# ----------------------------
# Preprocess (ResNet with ImageNet normalization)
# ----------------------------
# Per-channel ImageNet statistics, shaped (1, 3, 1, 1) so they broadcast over NCHW batches.
IMNET_MEAN = torch.tensor([0.485, 0.456, 0.406], device=device).reshape(1, 3, 1, 1)
IMNET_STD = torch.tensor([0.229, 0.224, 0.225], device=device).reshape(1, 3, 1, 1)


def preprocess_uint8_batch(b):
    """Convert a channels-last NumPy image batch into a normalised NCHW tensor.

    Args:
        b: 4-D NumPy array with channels on the last axis (values are divided
            by 255, so 0..255 pixel intensities are expected).

    Returns:
        float32 tensor on `device`, laid out as (N, C, H, W) and standardised
        with the ImageNet mean/std.
    """
    unit_range = torch.from_numpy(b).to(device=device, dtype=torch.float32).div(255.0)
    nchw = unit_range.permute(0, 3, 1, 2).contiguous()
    return nchw.sub(IMNET_MEAN).div(IMNET_STD)
@torch.no_grad()
def logits_from_uint8_batch(b, sub_bs=1024):
    """Run the global `model` on an image array and return one logit per image.

    Args:
        b: NumPy array indexed by image along axis 0; slices of it are passed
            to `preprocess_uint8_batch`.
        sub_bs: Maximum number of images per forward pass.

    Returns:
        1-D NumPy array of logits, one per image. An empty float32 array is
        returned when `b` contains no images.

    Note:
        Switches `model` to eval mode and leaves it there.
    """
    n_images = b.shape[0]
    if n_images == 0:
        # np.concatenate([]) raises ValueError, so answer the empty case directly.
        return np.zeros((0,), dtype=np.float32)
    model.eval()
    pieces = []
    for start in range(0, n_images, sub_bs):
        chunk = preprocess_uint8_batch(b[start:start + sub_bs])
        pieces.append(model(chunk).detach().float().cpu().numpy().reshape(-1))
    return np.concatenate(pieces, axis=0)
def prob_from_logit(z):
    """Map logits to probabilities with a numerically stable sigmoid.

    Args:
        z: Scalar or array-like of logits.

    Returns:
        Sigmoid of `z` with the same shape as the input (a NumPy scalar for a
        scalar input). Floating-point input keeps its dtype.
    """
    z = np.asarray(z)
    # exp(-|z|) always lies in [0, 1], so it cannot overflow the way exp(-z)
    # does for large negative logits.
    e = np.exp(-np.abs(z))
    # sigmoid(z) = 1/(1+e^-z) for z >= 0 and e^z/(1+e^z) for z < 0.
    # [()] unwraps a 0-d result back into a scalar.
    return np.where(z >= 0, 1.0 / (1.0 + e), e / (1.0 + e))[()]
# ----------------------------
# Model: ResNet18 pretrained -> binary head
# ----------------------------
print("Building model...")
# Start from the IMAGENET1K_V1 ResNet-18 weights and replace the final fully
# connected layer with a single-output linear head (one logit per image).
model = resnet18(weights=ResNet18_Weights.IMAGENET1K_V1)
model.fc = nn.Linear(in_features=model.fc.in_features, out_features=1)
model = model.to(device)

# Binary cross-entropy computed directly on the raw logit.
criterion = nn.BCEWithLogitsLoss()
opt = torch.optim.AdamW(
    params=model.parameters(),
    lr=CFG["lr"],
    weight_decay=CFG["weight_decay"],
)
# ----------------------------
#  eval helper
# ----------------------------
@torch.no_grad()
def evaluate(tf_dataset, max_batches):
    """Score the global `model` on the first `max_batches` batches of a dataset.

    Args:
        tf_dataset: Batched dataset yielding (image, label) pairs; it is
            iterated through `tfds.as_numpy`.
        max_batches: Upper bound on the number of batches that are scored.

    Returns:
        dict with `acc` (accuracy at a 0.5 probability threshold), `auroc`
        (NaN when `roc_auc_score` rejects the input, e.g. a single class) and
        `n` (number of examples scored). When nothing was scored the result is
        `acc=nan, auroc=nan, n=0`.

    Note:
        Switches `model` to eval mode and leaves it there.
    """
    model.eval()
    logits_all, y_all = [], []
    for b, (img, y) in enumerate(tfds.as_numpy(tf_dataset)):
        if b >= max_batches:
            break
        # Labels stay on the host: only the metrics below need them, so no
        # device copy of `y` is made.
        batch_logits = model(preprocess_uint8_batch(img))
        logits_all.append(batch_logits.detach().cpu().numpy().reshape(-1))
        y_all.append(y.astype(np.int32).reshape(-1))
    logits = np.concatenate(logits_all) if logits_all else np.array([])
    y_true = np.concatenate(y_all) if y_all else np.array([])
    if len(y_true) == 0:
        return dict(acc=np.nan, auroc=np.nan, n=0)
    p = prob_from_logit(logits)
    pred = (p >= 0.5).astype(np.int32)
    acc = (pred == y_true).mean()
    try:
        auroc = roc_auc_score(y_true, p)
    except ValueError:
        # Raised by scikit-learn for invalid input such as a single class in y_true.
        auroc = np.nan
    return dict(acc=float(acc), auroc=float(auroc), n=int(len(y_true)))
# ----------------------------
# Train
# ----------------------------
history = {"train_loss": [], "val_acc": [], "val_auroc": []}
print("\nTraining...")
# Endless stream of training batches; an "epoch" below is a fixed number of steps.
train_iter = iter(tfds.as_numpy(train_tf.repeat()))
for ep in range(1, CFG["epochs"] + 1):
    model.train()
    losses = []
    progress = tqdm(range(CFG["steps_per_epoch"]), desc=f"epoch {ep}/{CFG['epochs']}", leave=False)
    for _ in progress:
        img, y = next(train_iter)
        inputs = preprocess_uint8_batch(img)
        # Targets as a float column vector to match the (batch, 1) logit output.
        targets = torch.from_numpy(y).to(device=device, dtype=torch.float32).view(-1, 1)
        opt.zero_grad(set_to_none=True)
        loss = criterion(model(inputs), targets)
        loss.backward()
        opt.step()
        losses.append(loss.item())
    # End of epoch: validate on a capped number of batches and log everything.
    val = evaluate(val_tf, CFG["val_batches"])
    for key, value in (
        ("train_loss", float(np.mean(losses))),
        ("val_acc", val["acc"]),
        ("val_auroc", val["auroc"]),
    ):
        history[key].append(value)
    print(f"Epoch {ep:02d} | train_loss={history['train_loss'][-1]:.4f} | val_acc={val['acc']:.4f} | val_AUROC={val['auroc']:.4f} (n={val['n']})")

# Final held-out check, also capped to a fixed number of batches.
test = evaluate(test_tf, CFG["test_batches"])
print(f"\nTest (limited to {CFG['test_batches']} batches) | acc={test['acc']:.4f} | AUROC={test['auroc']:.4f} (n={test['n']})")
# ----------------------------
# Plots: training curves
# ----------------------------
_epochs = range(1, CFG["epochs"] + 1)

# Figure 1: mean training loss per epoch.
plt.figure()
plt.plot(_epochs, history["train_loss"], marker="o")
plt.title("Training loss")
plt.xlabel("Epoch")
plt.ylabel("Train loss (BCEWithLogits)")
plt.grid(True)
plt.show()

# Figure 2: validation accuracy and AUROC on shared axes.
plt.figure()
for _key, _label in (("val_acc", "Val Acc"), ("val_auroc", "Val AUROC")):
    plt.plot(_epochs, history[_key], marker="o", label=_label)
plt.title("Validation metrics")
plt.xlabel("Epoch")
plt.ylabel("Metric")
plt.grid(True)
plt.legend()
plt.show()
# ============================================================
# TSI bits: S(x), T(x), score, and TSI = score(S)-score(T)
# ============================================================
def shape_view_uint8(img, rng):
    """Build the "shape-ish" view S(x): bilateral smoothing plus an edge overlay.

    Args:
        img: Image array accepted by `cv2.bilateralFilter` and convertible with
            `cv2.COLOR_RGB2GRAY`.
        rng: NumPy random Generator supplying the filter and blend parameters.

    Returns:
        uint8 array: the smoothed image with a scaled Sobel edge map added to
        every channel, clipped to [0, 255].
    """
    # Draw order is part of the behaviour: diameter, colour sigma, space sigma,
    # and (after filtering) the blend weight.
    diameter = int(rng.integers(5, 10))
    sigma_color = float(rng.uniform(40.0, 90.0))
    sigma_space = float(rng.uniform(40.0, 90.0))
    smooth = cv2.bilateralFilter(img, d=diameter, sigmaColor=sigma_color, sigmaSpace=sigma_space)

    # Sobel gradient magnitude of the smoothed image, scaled to at most 1.
    gray = cv2.cvtColor(smooth, cv2.COLOR_RGB2GRAY)
    grad_x = cv2.Sobel(gray, cv2.CV_32F, 1, 0, ksize=3)
    grad_y = cv2.Sobel(gray, cv2.CV_32F, 0, 1, ksize=3)
    edges = cv2.magnitude(grad_x, grad_y)
    edges = edges / (edges.max() + 1e-6)  # epsilon guards flat images

    weight = float(rng.uniform(0.15, 0.35))
    overlay = (edges[..., None] * 255.0).astype(np.float32)
    blended = smooth.astype(np.float32) + weight * overlay
    return np.clip(blended, 0, 255).astype(np.uint8)
def texture_view_uint8(img, rng, tile_sizes=(8,12,16)):
    """Build the "texture-ish" view T(x) by shuffling square tiles of the image.

    Args:
        img: Array of shape (H, W, C); it is not modified.
        rng: NumPy random Generator; used once to pick the tile size and once
            to permute the tiles.
        tile_sizes: Candidate tile edge lengths.

    Returns:
        Copy of `img` whose top-left region (the largest area evenly divisible
        by the chosen tile size) has its tiles permuted; any remaining border
        is left as in the input.
    """
    height, width, channels = img.shape
    tile = int(rng.choice(np.array(tile_sizes)))
    n_rows, n_cols = height // tile, width // tile
    used_h, used_w = n_rows * tile, n_cols * tile

    # Cut the usable region into a flat stack of (tile, tile, C) patches.
    grid = img[:used_h, :used_w].reshape(n_rows, tile, n_cols, tile, channels)
    stack = grid.swapaxes(1, 2).reshape(n_rows * n_cols, tile, tile, channels)

    order = rng.permutation(n_rows * n_cols)
    stack = stack[order]

    # Lay the shuffled patches back out on the same grid.
    mosaic = stack.reshape(n_rows, n_cols, tile, tile, channels)
    mosaic = mosaic.swapaxes(1, 2).reshape(used_h, used_w, channels)

    out = img.copy()
    out[:used_h, :used_w] = mosaic
    return out
def true_class_score(logit, y01):
    """Return the logit signed in favour of the true class.

    Elements where `y01 == 1` keep `logit`; all other elements get `-logit`.
    """
    is_positive = (y01 == 1)
    return np.where(is_positive, logit, np.negative(logit))
# ----------------------------
# Pull a validation subset for per-image analysis
# ----------------------------
print("\nSampling validation subset for TSI...")
# First CFG["tsi_N"] individual (image, label) examples of the validation stream.
_subset = list(tfds.as_numpy(val_tf.unbatch().take(CFG["tsi_N"])))
raw_imgs = np.stack([image for image, _ in _subset], axis=0)
raw_y = np.array([int(label) for _, label in _subset], dtype=np.int32)

# Baseline behaviour of the model on the untouched subset.
base_logit = logits_from_uint8_batch(raw_imgs)
base_prob = prob_from_logit(base_logit)
base_pred = (base_prob >= 0.5).astype(np.int32)          # hard decision at p = 0.5
base_corr = np.equal(base_pred, raw_y).astype(np.int32)  # 1 = correct, 0 = wrong
err = 1 - base_corr                                      # 1 = wrong
conf = np.maximum(base_prob, 1.0 - base_prob)            # probability of the predicted class
print(f"Subset base accuracy: {base_corr.mean():.4f} (N={len(raw_y)})")
# ----------------------------
# Tiny sanity checks
# ----------------------------
rng_test = np.random.default_rng(CFG["seed"])
img0 = raw_imgs[0]
# Same generator for both views; S(x) is drawn before T(x).
s0, t0 = shape_view_uint8(img0, rng_test), texture_view_uint8(img0, rng_test, CFG["tile_sizes"])
_views = (s0, t0)
# Both transforms must hand back an image that can replace the original.
assert all(v.shape == img0.shape for v in _views), "Transform shape mismatch"
assert all(v.dtype == np.uint8 for v in _views), "Transforms must return uint8"
assert all(0 <= v.min() and v.max() <= 255 for v in _views), "Transforms out of [0,255]"
def hf_energy(img):
    """Return the mean squared Laplacian response of the grayscale image.

    Used below as a scalar diagnostic of high-frequency content.
    """
    gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    response = cv2.Laplacian(gray, cv2.CV_32F)
    return float(np.mean(np.square(response)))
# Diagnostics on the first few subset images. The count is capped by the subset
# size so a small CFG["tsi_N"] (< 32) no longer indexes past the end of raw_imgs.
n_diag = min(32, len(raw_imgs))

# High-frequency energy of each original image and of its S(x) / T(x) views.
E_x, E_s, E_t = [], [], []
for i in range(n_diag):
    # Fresh per-image generator seeded from rng_test; S(x) is drawn before T(x).
    rng = np.random.default_rng(int(rng_test.integers(0, 2**32-1)))
    xi = raw_imgs[i]
    E_x.append(hf_energy(xi))
    E_s.append(hf_energy(shape_view_uint8(xi, rng)))
    E_t.append(hf_energy(texture_view_uint8(xi, rng, CFG["tile_sizes"])))
print(f"[Diagnostic] mean HF energy: orig={np.mean(E_x):.2f}, S(x)={np.mean(E_s):.2f}, T(x)={np.mean(E_t):.2f}")

# T(x) only rearranges tiles, so per-channel means should stay essentially unchanged.
orig_means = raw_imgs[:n_diag].mean((1, 2))
tex_means = np.stack([
    texture_view_uint8(raw_imgs[i], np.random.default_rng(i), CFG["tile_sizes"]).mean((0, 1))
    for i in range(n_diag)
])
print(f"[Diagnostic] per-channel mean diff (orig vs T): {np.mean(np.abs(orig_means - tex_means), axis=0)}")
# ----------------------------
# Compute TSI
# ----------------------------
print("\nComputing TSI...")
# Per-image outputs: mean true-class score over the shape views (scoreS), over
# the texture views (scoreT), and their difference (tsi = scoreS - scoreT).
N = len(raw_imgs)
tsi = np.zeros(N, dtype=np.float32)
scoreS = np.zeros(N, dtype=np.float32)
scoreT = np.zeros(N, dtype=np.float32)
# Master generator: hands out one child seed per (image, repeat) pair.
rng_master = np.random.default_rng(CFG["seed"])
# Images processed per outer iteration; each yields chunk * tsi_m views of each kind.
chunk = 64
for i0 in tqdm(range(0, N, chunk), desc="TSI", leave=False):
    i1 = min(N, i0 + chunk)
    imgs, ys = raw_imgs[i0:i1], raw_y[i0:i1]
    S_views, T_views = [], []
    for j in range(i1 - i0):
        for _ in range(CFG["tsi_m"]):
            # The same child generator feeds S(x) first and then T(x), so the
            # order of these two calls determines the views that are produced.
            rng = np.random.default_rng(int(rng_master.integers(0, 2**32-1)))
            S_views.append(shape_view_uint8(imgs[j], rng))
            T_views.append(texture_view_uint8(imgs[j], rng, CFG["tile_sizes"]))
    S_views, T_views = np.stack(S_views, axis=0), np.stack(T_views, axis=0)
    logS, logT = logits_from_uint8_batch(S_views), logits_from_uint8_batch(T_views)
    # Views are laid out image-major (tsi_m consecutive views per image), so
    # labels are repeated to match and the scores reshaped to (images, tsi_m).
    ys_rep = np.repeat(ys, CFG["tsi_m"])
    scS = true_class_score(logS, ys_rep).reshape(-1, CFG["tsi_m"]).mean(axis=1)
    scT = true_class_score(logT, ys_rep).reshape(-1, CFG["tsi_m"]).mean(axis=1)
    scoreS[i0:i1], scoreT[i0:i1], tsi[i0:i1] = scS.astype(np.float32), scT.astype(np.float32), (scS-scT).astype(np.float32)
# Summary statistics and histogram of the per-image TSI values.
print("\nTSI summary:"); print(f"  N={N}")
print(f"  mean={tsi.mean():.4f} | std={tsi.std():.4f} | min={tsi.min():.4f} | max={tsi.max():.4f}")
plt.figure(); plt.hist(tsi, bins=50)
plt.title("TSI distribution (validation subset)")
plt.xlabel("TSI = score(S(x)) - score(T(x))"); plt.ylabel("Count")
plt.grid(True); plt.show()
# ----------------------------
# Deciles (TSI and |TSI|)
# ----------------------------
def decile_bins(values, n_bins=10):
    """Assign every value to one of `n_bins` rank-based bins of near-equal size.

    Args:
        values: 1-D sequence of numbers to rank.
        n_bins: Number of bins (10 gives deciles).

    Returns:
        int32 array, same length as `values`, holding the bin index of each
        element (0 = smallest values, n_bins - 1 = largest).
    """
    total = len(values)
    ranked = np.argsort(values)
    bins = np.zeros(total, dtype=np.int32)
    for b in range(n_bins):
        # Rounded cut points in rank space: bin b covers ranks [lo, hi).
        lo, hi = (int(round(edge * total / n_bins)) for edge in (b, b + 1))
        bins[ranked[lo:hi]] = b
    return bins
# Decile membership by signed TSI and by its magnitude.
bins_tsi, bins_abs = decile_bins(tsi, 10), decile_bins(np.abs(tsi), 10)
# Base accuracy inside each decile, for both binnings.
acc_by, acc_abs = (
    [float(base_corr[bins == b].mean()) for b in range(10)]
    for bins in (bins_tsi, bins_abs)
)
print("\nAccuracy by TSI decile (0=lowest → 9=highest):")
print("\n".join(f"  decile {b}: acc={a:.4f}" for b, a in enumerate(acc_by)))

# One accuracy-vs-decile figure per binning, drawn in the same order as before.
for _acc, _title, _xlabel in (
    (acc_by, "Base accuracy by TSI decile (can be non-monotone)", "TSI decile (0 low → 9 high)"),
    (acc_abs, "Base accuracy by |TSI| decile (diagnostic for U-shape)", "|TSI| decile (0 small → 9 large)"),
):
    plt.figure()
    plt.plot(range(10), _acc, marker="o")
    plt.title(_title)
    plt.xlabel(_xlabel)
    plt.ylabel("Accuracy")
    plt.ylim(0, 1.0)
    plt.grid(True)
    plt.show()

# Linear association between TSI and per-image correctness.
print(f"\nPearson corr(TSI, correctness): {np.corrcoef(tsi, base_corr.astype(np.float32))[0,1]:.4f}")
# ----------------------------
# Error detection ROC: TSI vs -|TSI| vs confidence (+ tiny LR baseline)
# ----------------------------
# Risk scores: for each one, a larger value is meant to flag a likely base error.
score_tsi  = tsi.astype(np.float64)
score_abs  = (-np.abs(tsi)).astype(np.float64)     # small |TSI| => "meh/ambiguous" => riskier
score_conf = (-conf).astype(np.float64)            # low confidence => risky
# AUROC and ROC curves need both outcomes (errors and correct predictions) in
# the subset; otherwise scikit-learn cannot compute them. Guard the section the
# same way the robustness-failure analysis further down guards its labels.
if 0 < int(err.sum()) < N:
    au_tsi  = roc_auc_score(err, score_tsi)
    au_abs  = roc_auc_score(err, score_abs)
    au_conf = roc_auc_score(err, score_conf)
    # Tiny logistic-regression baseline on (TSI, confidence): fit on a random
    # 70% of the subset and score on the remaining 30%.
    X = np.stack([tsi, conf], axis=1)
    perm = np.random.default_rng(CFG["seed"]).permutation(N)
    split = int(0.7 * N); tr, te = perm[:split], perm[split:]
    # Fitting needs both classes in the train part, scoring needs both in the test part.
    lr_ok = len(np.unique(err[tr])) == 2 and len(np.unique(err[te])) == 2
    if lr_ok:
        lr = LogisticRegression(max_iter=2000); lr.fit(X[tr], err[tr])
        score_lr = lr.predict_proba(X[te])[:, 1]
        au_lr = roc_auc_score(err[te], score_lr)
    print("\nError-detection AUROC (higher is better):")
    print(f"  TSI:        {au_tsi:.4f}")
    print(f"  -|TSI|:     {au_abs:.4f}")
    print(f"  Confidence: {au_conf:.4f}")
    if lr_ok:
        print(f"  LR(TSI,conf) on held-out split: {au_lr:.4f}")
    else:
        print("  LR(TSI,conf) on held-out split: n/a (a split contains a single class)")
    fpr1, tpr1, _ = roc_curve(err, score_tsi)
    fpr2, tpr2, _ = roc_curve(err, score_abs)
    fpr3, tpr3, _ = roc_curve(err, score_conf)
    plt.figure()
    plt.plot(fpr1, tpr1, label=f"TSI (AUROC={au_tsi:.3f})")
    plt.plot(fpr2, tpr2, label=f"-|TSI| (AUROC={au_abs:.3f})")
    plt.plot(fpr3, tpr3, label=f"Confidence (AUROC={au_conf:.3f})")
    if lr_ok:
        fpr4, tpr4, _ = roc_curve(err[te], score_lr)
        plt.plot(fpr4, tpr4, label=f"LR(TSI,conf) held-out (AUROC={au_lr:.3f})")
    plt.plot([0,1],[0,1], linestyle="--")  # chance diagonal
    plt.title("Error detection ROC (validation subset)")
    plt.xlabel("False positive rate"); plt.ylabel("True positive rate")
    plt.grid(True); plt.legend(); plt.show()
else:
    print("\nError-detection AUROC skipped: the subset contains only one outcome (all correct or all wrong).")
# ----------------------------
# Risk–coverage curves (base errors)
# ----------------------------
def risk_coverage(error01, keep_order):
    """Compute a risk–coverage curve for a given keep order.

    Args:
        error01: Array of 0/1 error indicators, one per example.
        keep_order: Index array giving the order in which examples are kept.

    Returns:
        (coverage, risk): after keeping the first k examples, coverage is
        k / n and risk is the mean of `error01` over those k examples.
    """
    kept_counts = np.arange(1, len(error01) + 1)
    errors_so_far = np.cumsum(error01[keep_order])
    coverage = kept_counts / len(error01)
    risk = errors_so_far / kept_counts
    return coverage, risk
# Keep orders: most confident first, and largest |TSI| first.
order_conf = np.argsort(-conf)
order_abs = np.argsort(-np.abs(tsi))
(cov_c, risk_c), (cov_a, risk_a) = (risk_coverage(err, order) for order in (order_conf, order_abs))

plt.figure()
for _cov, _risk, _label in (
    (cov_c, risk_c, "Keep-high-confidence first"),
    (cov_a, risk_a, "Keep-high-|TSI| first"),
):
    plt.plot(_cov, _risk, label=_label)
plt.title("Risk–Coverage (base errors)")
plt.xlabel("Coverage (fraction kept)")
plt.ylabel("Risk (error rate on kept)")
plt.grid(True)
plt.legend()
plt.show()
# ============================================================
# Robustness probe: ColorJitter + stratify drop by TSI decile
# ============================================================
print("\nRobustness probe: ColorJitter on the same subset...")
# Colour perturbation with the strengths listed in CFG["jitter"].
jitter = ColorJitter(**CFG["jitter"])


def apply_jitter_uint8(img):
    """Apply the module-level `jitter` to one channels-last image array.

    Args:
        img: NumPy image with channels on the last axis and 0..255 intensities.

    Returns:
        uint8 array with the same axis order as the input.
    """
    # torchvision jitter wants torch CHW float in [0,1]
    chw = torch.from_numpy(img).float().div(255.0).permute(2, 0, 1)
    jittered = jitter(chw).clamp(0.0, 1.0)
    hwc = jittered.permute(1, 2, 0).numpy()
    # Adding 0.5 before the truncating cast rounds to the nearest integer.
    return (hwc * 255.0 + 0.5).astype(np.uint8)
# Jittered copy of every subset image, scored with the same model.
j_imgs = np.stack([apply_jitter_uint8(image) for image in tqdm(raw_imgs, desc="jitter", leave=False)], axis=0)
j_logit = logits_from_uint8_batch(j_imgs)
j_prob = prob_from_logit(j_logit)
j_pred = (j_prob >= 0.5).astype(np.int32)
j_corr = (j_pred == raw_y).astype(np.int32)

# Accuracy before/after jitter inside each TSI decile, and the resulting drop.
acc_orig_by = [float(base_corr[bins_tsi == b].mean()) for b in range(10)]
acc_jit_by = [float(j_corr[bins_tsi == b].mean()) for b in range(10)]
drop_by = [before - after for before, after in zip(acc_orig_by, acc_jit_by)]
print("\nAccuracy by TSI decile under style-only perturbation:")
for b in range(10):
    print(f"  decile {b}: orig={acc_orig_by[b]:.4f} | jitter={acc_jit_by[b]:.4f} | drop={drop_by[b]:+.4f}")

plt.figure()
for _series, _label in ((acc_orig_by, "Original"), (acc_jit_by, "ColorJitter")):
    plt.plot(range(10), _series, marker="o", label=_label)
plt.title("Accuracy by TSI decile under style-only perturbation")
plt.xlabel("TSI decile (0 low → 9 high)")
plt.ylabel("Accuracy")
plt.ylim(0, 1.0)
plt.grid(True)
plt.legend()
plt.show()

plt.figure()
plt.plot(range(10), drop_by, marker="o")
plt.title("Accuracy drop under ColorJitter vs TSI decile")
plt.xlabel("TSI decile (0 low → 9 high)")
plt.ylabel("Accuracy drop (orig - jitter)")
plt.grid(True)
plt.show()

# "Fragile" = predicted correctly on the original but wrongly after jitter.
fragile = ((base_corr == 1) & (j_corr == 0)).astype(np.int32)  # correct -> wrong under jitter
if not (fragile.mean() > 0):
    print("\nRobustness-failure label has zero positives in this subset (increase N or jitter strength).")
else:
    # How well do TSI and (negated) confidence rank the fragile images?
    au_frag_tsi = roc_auc_score(fragile, tsi)
    au_frag_conf = roc_auc_score(fragile, -conf)
    print(f"\nRobustness-failure AUROC (predict correct→wrong under jitter):")
    print(f"  TSI:        {au_frag_tsi:.4f}")
    print(f"  Confidence: {au_frag_conf:.4f}")
    fprf1, tprf1, _ = roc_curve(fragile, tsi)
    fprf2, tprf2, _ = roc_curve(fragile, -conf)
    plt.figure()
    plt.plot(fprf1, tprf1, label=f"TSI (AUROC={au_frag_tsi:.3f})")
    plt.plot(fprf2, tprf2, label=f"Confidence (AUROC={au_frag_conf:.3f})")
    plt.plot([0,1],[0,1], linestyle="--")
    plt.title("Robustness-failure ROC (correct→wrong under ColorJitter)")
    plt.xlabel("False positive rate")
    plt.ylabel("True positive rate")
    plt.grid(True)
    plt.legend()
    plt.show()
# ----------------------------
# Visual check: show some examples (low/mid/high TSI)
# ----------------------------
def show_examples(idxs, title):
    """Plot one row per index: the original image, S(x), T(x) and its jittered copy.

    Args:
        idxs: Indices into the module-level subset arrays (`raw_imgs`, `raw_y`,
            `tsi`, `base_prob`, `j_imgs`, ...).
        title: Figure-level title.
    """
    n_rows = len(idxs)
    plt.figure(figsize=(14, 3.0 * n_rows))
    for row, idx in enumerate(idxs):
        image, y = raw_imgs[idx], raw_y[idx]
        # Seed tied to the index so the same views are drawn on every call.
        view_rng = np.random.default_rng(1234 + idx)
        # S(x) is generated before T(x) from the shared generator.
        panels = [
            ("x", image),
            ("S(x)", shape_view_uint8(image, view_rng)),
            ("T(x)", texture_view_uint8(image, view_rng, CFG["tile_sizes"])),
            ("jitter(x)", j_imgs[idx]),
        ]
        p0, pj = base_prob[idx], j_prob[idx]
        info = f"idx={idx} y={y} TSI={tsi[idx]:+.2f}\norig p={p0:.2f} corr={base_corr[idx]} | jitter p={pj:.2f} corr={j_corr[idx]}"
        for col, (name, im) in enumerate(panels):
            ax = plt.subplot(n_rows, 4, row * 4 + col + 1)
            ax.imshow(im)
            ax.axis("off")
            # Only the first panel of a row carries the per-image details.
            caption = name + "\n" + info if col == 0 else name
            ax.set_title(caption, fontsize=9)
    plt.suptitle(title, y=1.02)
    plt.tight_layout()
    plt.show()
# Pick k examples from the bottom, middle and top of the TSI ranking.
ord_tsi = np.argsort(tsi)
k = min(CFG["viz_k"], N//5)
low  = ord_tsi[:k]
# Slice from N - k rather than -k: with k == 0, ord_tsi[-0:] is the WHOLE
# array, which would try to plot every image in the subset.
high = ord_tsi[N - k:][::-1]
mid_start = N//2 - k//2
mid  = ord_tsi[mid_start : mid_start + k]
if k > 0:
    show_examples(low,  "Lowest TSI (more texture-favored by this index)")
    show_examples(mid,  "Middle TSI (often where base accuracy can dip)")
    show_examples(high, "Highest TSI (often where style fragility can be high)")
else:
    # An empty selection would create zero-height figures, so skip the grids.
    print("\nSkipping example grids: no examples selected (viz_k or subset too small).")
print("\nDone.")
