In [None]:
# %% [setup]
!pip -q install torch torchvision torchaudio transformers datasets sentence-transformers opacus scikit-learn matplotlib numpy pandas umap-learn shap

import os, math, random, json, numpy as np, pandas as pd, torch, torch.nn as nn, torch.nn.functional as F
import matplotlib.pyplot as plt
from datasets import load_dataset
from sentence_transformers import SentenceTransformer
from sklearn.cluster import KMeans
from sklearn.metrics import pairwise_distances, roc_auc_score, precision_recall_curve, average_precision_score, confusion_matrix
from sklearn.calibration import calibration_curve
from sklearn.model_selection import train_test_split
from umap import UMAP
from opacus import PrivacyEngine
import shap

plt.rcParams["figure.figsize"] = (6,4)

# Run on the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

# Seed every RNG source the notebook touches: the shared numpy Generator `rng`,
# torch, the legacy numpy global state, and Python's `random`.
rng = np.random.default_rng(42)
torch.manual_seed(42)
np.random.seed(42)
random.seed(42)

# %% [helpers]
def set_seed(s=42):
    """Reseed every random source used by the notebook.

    Rebinds the module-level numpy ``Generator`` ``rng`` (used by the user
    simulator, candidate sampler and randomized response) and seeds torch,
    the legacy ``np.random`` state and Python's ``random``.

    Args:
        s: integer seed.

    Returns:
        The freshly created module-level ``rng`` (returned for convenience;
        callers that ignore the return value are unaffected).
    """
    # Without `global`, the assignment below would create a local that is
    # discarded on return, leaving the shared generator un-reseeded.
    global rng
    rng = np.random.default_rng(s)
    torch.manual_seed(s)
    np.random.seed(s)
    random.seed(s)
    return rng
set_seed(42)

def text_from_dataset(ds, prefer=("text","title","headline","description","content")):
    """Extract the non-blank strings from the most text-like column of ``ds``.

    Preferred columns are tried in the order given; the first one that yields
    at least one non-blank string wins. If none does, every other column whose
    value in the first row is a ``str`` is tried in column order.

    Args:
        ds: a dataset-like object exposing ``column_names``, ``len(ds)``,
            ``ds[column]`` (the column's values) and ``ds[0]`` (a row mapping).
        prefer: column names to try first, in priority order.

    Returns:
        List of non-blank strings from the chosen column (non-strings and
        whitespace-only values are dropped).

    Raises:
        ValueError: if no column yields any non-blank string (including an
            empty dataset).
    """
    def _clean(values):
        return [v for v in values if isinstance(v, str) and v.strip()]

    cols = list(ds.column_names)
    for c in prefer:
        if c in cols:
            texts = _clean(ds[c])
            # An all-empty preferred column is not a usable answer; keep looking.
            if texts:
                return texts

    # Guard the row probe: ds[0] on an empty dataset would raise IndexError.
    if len(ds) > 0:
        sample = ds[0]
        for c in cols:
            if c in prefer or not isinstance(sample.get(c), str):
                continue
            texts = _clean(ds[c])
            if texts:
                return texts
    raise ValueError(f"No text-like columns. Columns: {ds.column_names}")

def plot_tradeoff(x, y, xlab, ylab, title):
    """Open a new figure, draw y against x as a marked line, label it and show it."""
    plt.figure()
    plt.plot(x, y, marker='o')
    for apply_label, text in ((plt.xlabel, xlab), (plt.ylabel, ylab), (plt.title, title)):
        apply_label(text)
    plt.show()

# %% [module A: robust news load + embeddings + echo auditor / reranker + visuals]
# Prefer MIND-small; if loading it or finding a text column fails for any
# reason, fall back to AG News so the rest of the notebook can still run.
try:
    news = load_dataset("mteb/mind_small")["train"]
    titles = text_from_dataset(news)
except Exception as e:
    print("mteb/mind_small unavailable → using AG News:", e)
    news = load_dataset("wangrongsheng/ag_news")["train"]
    titles = text_from_dataset(news, prefer=("text","title","description","content","headline"))
    SRC = "wangrongsheng/ag_news"
else:
    SRC = "mteb/mind_small"
print("News source:", SRC, "| items:", len(titles))

# Cap the corpus at 20k items; everything downstream indexes into these N rows.
N = min(len(titles), 20000)
titles = titles[:N]

# Unit-normalised MiniLM sentence embeddings, one row per item.
embed_model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2", device=device)
X = embed_model.encode(
    titles,
    batch_size=256,
    convert_to_numpy=True,
    show_progress_bar=True,
    normalize_embeddings=True,
)

# Partition the item embeddings into K topical clusters; `global_dist` is the
# share of the corpus that falls into each cluster.
K = 12
km = KMeans(n_clusters=K, n_init=10, random_state=42)
km.fit(X)
clusters = km.labels_
global_dist = np.bincount(clusters, minlength=K) / clusters.size

def simulate_users(U=300, hist_len=50, alpha=0.7):
    """Simulate U users with Dirichlet cluster preferences and reading histories.

    Each user's preference vector over the K clusters is drawn from
    Dirichlet(alpha); the history is ``hist_len`` distinct item indices sampled
    with probability proportional to the user's preference for each item's
    cluster (floored at 1e-12 so every item stays eligible).

    Returns:
        (prefs, histories): a (U, K) preference array and a list of U index arrays.
    """
    prefs = rng.dirichlet(alpha * np.ones(K), size=U)
    all_items = np.arange(N)

    def _draw_history(pref_row):
        weights = np.maximum(1e-12, pref_row[clusters])
        weights /= weights.sum()
        return rng.choice(all_items, size=hist_len, replace=False, p=weights)

    histories = [_draw_history(row) for row in prefs]
    return prefs, histories

prefs, histories = simulate_users(U=300, hist_len=50, alpha=0.7)

def user_vec(hist_idx):
    """Return the mean embedding of the items a user has read."""
    return X[hist_idx].mean(axis=0)

def candidates(u_hist, C=120):
    """Sample a sorted, unique candidate set of at most C item indices for a user.

    Up to ``int(0.6 * C)`` slots come from items in up to 3 *distinct* "hot"
    clusters, picked from a 20-item sample of the user's history with weight
    proportional to how often each cluster appears there. The remaining slots
    are filled from items outside those clusters (so the set reaches C items
    whenever the corpus is large enough).

    Args:
        u_hist: array of item indices in the user's history.
        C: maximum candidate-set size.

    Returns:
        Sorted numpy array of unique item indices.
    """
    empty = np.array([], dtype=int)
    hot = rng.choice(u_hist, size=min(len(u_hist), 20), replace=False)
    # Deduplicate clusters before sampling: drawing from the raw per-item list
    # can return the same cluster more than once, shrinking the focus set.
    hot_cl, hot_cnt = np.unique(clusters[hot], return_counts=True)
    if len(hot_cl):
        focus = rng.choice(hot_cl, size=min(3, len(hot_cl)), replace=False,
                           p=hot_cnt / hot_cnt.sum())
    else:
        focus = empty
    pool = np.flatnonzero(np.isin(clusters, focus))
    rest = np.setdiff1d(np.arange(N), pool)
    n_pool = min(len(pool), int(0.6 * C))
    # Top up from outside the focus clusters when the hot pool is small.
    n_rest = min(len(rest), C - n_pool)
    picked_pool = rng.choice(pool, size=n_pool, replace=False) if n_pool else empty
    picked_rest = rng.choice(rest, size=n_rest, replace=False) if n_rest else empty
    return np.unique(np.concatenate([picked_pool, picked_rest]))[:C]

def relevance(u_v, idx):
    """Score items ``idx`` by the dot product of their embeddings with user vector ``u_v``."""
    item_vecs = X[idx]
    return item_vecs @ u_v

def xquad(cand_idx, rel, target_dist, lam=0.75, topk=50, cluster_of=None):
    """Greedy coverage-aware re-ranking toward a target cluster mix.

    At each step the remaining candidate with the highest
    ``lam * rel[i] + (1 - lam) * cover[cluster(i)]`` is appended, where
    ``cover[k] = 1 - picked_k / quota_k`` and ``quota_k = target_dist[k] * n``
    (n = number of items to pick). Under-represented clusters get a bonus and
    clusters past their quota get a penalty. Clusters with zero target get no
    coverage bonus (cover 0).

    Args:
        cand_idx: candidate item indices (duplicates are ignored).
        rel: relevance scores indexed by item index (``rel[i]`` for item i).
        target_dist: desired share per cluster; its length defines the
            number of clusters.
        lam: relevance/coverage trade-off in [0, 1].
        topk: maximum list length.
        cluster_of: optional array mapping item index -> cluster id; defaults
            to the module-level ``clusters``.

    Returns:
        Integer numpy array of selected item indices, in selection order.
    """
    cl_of = clusters if cluster_of is None else np.asarray(cluster_of)
    cand = np.unique(np.asarray(cand_idx, dtype=int))
    n_pick = min(topk, len(cand))
    if n_pick <= 0:
        return np.array([], dtype=int)

    target = np.asarray(target_dist, dtype=float)
    # Compare picked counts against a count quota; dividing by a raw share
    # (or by max(1, share), which is always 1) would ignore the target.
    quota = target * n_pick
    has_quota = quota > 0
    safe_quota = np.where(has_quota, quota, 1.0)

    cand_cl = cl_of[cand]
    cand_rel = np.asarray(rel, dtype=float)[cand]
    cov = np.zeros(len(target))
    available = np.ones(len(cand), dtype=bool)
    sel = []
    for _ in range(n_pick):
        cover = np.where(has_quota, 1.0 - cov / safe_quota, 0.0)
        score = lam * cand_rel + (1 - lam) * cover[cand_cl]
        score[~available] = -np.inf
        j = int(np.argmax(score))
        sel.append(cand[j])
        available[j] = False
        cov[cand_cl[j]] += 1
    return np.array(sel, dtype=int)

def list_metrics(lst, u_v):
    """Diversity, calibration and engagement-proxy metrics for a ranked list.

    Args:
        lst: item indices in rank order.
        u_v: user vector used for the click proxy.

    Returns:
        (ild, kl_g, clicks, exposure):
          ild      mean pairwise cosine distance between list items
                   (0.0 when the list has fewer than 2 items);
          kl_g     KL(list cluster mix || global_dist), 0.0 for an empty list;
          clicks   position-discounted sum of sigmoid(4 * X[i] @ u_v), using
                   DCG-style 1/log2(rank + 1) weights;
          exposure per-cluster item counts (length K).
    """
    lst = np.asarray(lst, dtype=int)
    emb = X[lst]
    if len(lst) >= 2:
        d = pairwise_distances(emb, metric="cosine")
        ild = float(np.mean(d[np.triu_indices_from(d, 1)]))
    else:
        ild = 0.0  # no pairs -> no measurable diversity (avoids mean of empty)

    exposure = np.bincount(clusters[lst], minlength=K)
    total = exposure.sum()
    p = exposure / total if total else np.zeros(K)
    nz = p > 0  # restrict to non-zero shares so log(0) is never evaluated
    kl_g = float(np.sum(p[nz] * np.log(p[nz] / np.maximum(1e-12, global_dist[nz]))))

    pr = 1 / (1 + np.exp(-4 * (emb @ u_v)))
    w = 1 / np.log2(np.arange(1, len(lst) + 1) + 1)
    clicks = float(np.sum(pr * w))
    return ild, kl_g, clicks, exposure

# Per-user candidate cache: u -> (history array, candidate indices). Reusing a
# single candidate pool per user keeps the baseline/rerank comparison and the
# λ sweep on identical inputs instead of re-sampling candidates on every call.
_CANDIDATE_CACHE = {}

def eval_user(u, lam=0.7, cand_idx=None):
    """Compare a relevance-only top-50 list against the xQuAD re-rank for user u.

    Args:
        u: user index into ``histories``.
        lam: xQuAD relevance/coverage trade-off.
        cand_idx: optional explicit candidate pool; by default a pool is drawn
            once per user (per history object) and reused on later calls.

    Returns:
        (baseline_metrics, rerank_metrics, baseline_list, rerank_list), where
        the metrics tuples come from ``list_metrics``.
    """
    h = histories[u]
    u_v = user_vec(h)
    if cand_idx is None:
        cached = _CANDIDATE_CACHE.get(u)
        # Identity check invalidates the cache if histories were regenerated.
        if cached is None or cached[0] is not h:
            cached = (h, candidates(h))
            _CANDIDATE_CACHE[u] = cached
        cand_idx = cached[1]
    rel = np.zeros(N)
    rel[cand_idx] = relevance(u_v, cand_idx)
    base = cand_idx[np.argsort(rel[cand_idx])[::-1]][:50]
    hist_dist = np.bincount(clusters[h], minlength=K)
    hist_dist = hist_dist / np.sum(hist_dist)
    # Target mix: halfway between the user's own history and the global corpus.
    target = 0.5 * hist_dist + 0.5 * global_dist
    rer = xquad(cand_idx, rel, target, lam=lam, topk=50)
    m_b = list_metrics(base, u_v)
    m_r = list_metrics(rer, u_v)
    return m_b, m_r, base, rer

U = len(histories)
results = [eval_user(u, lam=0.7) for u in range(U)]

# results[u] = (baseline_metrics, rerank_metrics, baseline_list, rerank_list);
# each metrics tuple starts with (ILD, KL, clicks).
baseline_metrics = [res[0] for res in results]
rerank_metrics = [res[1] for res in results]
ILD_b, KL_b, C_b = (np.array([m[i] for m in baseline_metrics]) for i in range(3))
ILD_r, KL_r, C_r = (np.array([m[i] for m in rerank_metrics]) for i in range(3))

print("Echo Auditor")
print("Baseline  : ILD=%.3f KL=%.3f Clicks=%.3f"%(ILD_b.mean(), KL_b.mean(), C_b.mean()))
print("Reranked  : ILD=%.3f KL=%.3f Clicks=%.3f"%(ILD_r.mean(), KL_r.mean(), C_r.mean()))
delta_ild = (ILD_r - ILD_b).mean()
delta_kl = (KL_r - KL_b).mean()
delta_clicks = (C_r - C_b).mean()
print("ΔILD=%.3f  ΔKL=%.3f  ΔClicks=%.3f"%(delta_ild, delta_kl, delta_clicks))

# Per-user KL to the global mix: points below the diagonal are users whose
# reranked list is closer to the global cluster distribution.
plt.figure()
plt.scatter(KL_b, KL_r, s=6)
mx = float(max(KL_b.max(), KL_r.max()))
plt.plot([0, mx], [0, mx])
plt.xlabel("Baseline KL")
plt.ylabel("Reranked KL")
plt.title("KL to Global Dist per User")
plt.show()

# 2-D UMAP projection of all item embeddings, coloured by KMeans cluster.
umap = UMAP(n_neighbors=15, min_dist=0.1, random_state=42)
emb2d = umap.fit_transform(X)
plt.figure()
plt.scatter(emb2d[:, 0], emb2d[:, 1], c=clusters, s=3)
plt.title("UMAP of News Embeddings (clusters)")
plt.show()

# Cluster share of the baseline vs reranked list for one example user.
sample_u = 0
base_s, rer_s = results[sample_u][2:]
exp_b = np.bincount(clusters[base_s], minlength=K)
exp_r = np.bincount(clusters[rer_s], minlength=K)
plt.figure()
plt.plot(exp_b / exp_b.sum(), marker='o', label="Baseline")
plt.plot(exp_r / exp_r.sum(), marker='o', label="Reranked")
plt.legend()
plt.title("Cluster Exposure (sample user)")
plt.xlabel("Cluster")
plt.ylabel("Share")
plt.show()

# Sweep the xQuAD trade-off and record the mean reranked ILD and click proxy.
lams = [0.3,0.5,0.7,0.9]
trade_ild, trade_clicks = [], []
for L in lams:
    vals = [eval_user(u, lam=L) for u in range(U)]
    reranked = [v[1] for v in vals]
    ild = np.mean([m[0] for m in reranked])
    clicks = np.mean([m[2] for m in reranked])
    trade_ild.append(ild)
    trade_clicks.append(clicks)

plot_tradeoff(lams, trade_ild, "λ", "ILD", "Diversity vs λ")
plot_tradeoff(lams, trade_clicks, "λ", "Clicks (proxy)", "Relevance vs λ")

# %% [module B: DP phishing detector + conformal abstention + explainability]
sms = load_dataset("ucirvine/sms_spam")

# The label column is presumably a ClassLabel stored as integer ids, so
# comparing raw values to the string "spam" would mark every message as ham.
# Resolve the id of "spam" from the feature's names when available; otherwise
# fall back to matching string labels.
label_feature = sms["train"].features["label"]
label_names = getattr(label_feature, "names", None)
raw_labels = list(sms["train"]["label"])
if label_names is not None:
    spam_id = label_names.index("spam")
    labels = [int(v == spam_id) for v in raw_labels]
else:
    labels = [int(str(v).strip().lower() == "spam") for v in raw_labels]
if len(set(labels)) < 2:
    raise ValueError(f"SMS labels collapsed to one class; sample raw values: {sorted(set(map(str, raw_labels)))[:5]}")

df = pd.DataFrame({"text": sms["train"]["sms"], "label": labels})
# 80/20 train/test, then 10% of train held out for conformal calibration.
tr, te = train_test_split(df, test_size=0.2, random_state=42, stratify=df.label)
tr, ca = train_test_split(tr, test_size=0.1, random_state=42, stratify=tr.label)

def embed_texts(txts, bs=256):
    """Encode texts with the shared sentence model into L2-normalised numpy vectors."""
    return embed_model.encode(
        txts,
        batch_size=bs,
        convert_to_numpy=True,
        normalize_embeddings=True,
        show_progress_bar=False,
    )

# Embed each split and pull the matching 0/1 label arrays.
X_tr = embed_texts(tr["text"].tolist()); y_tr = tr["label"].values
X_ca = embed_texts(ca["text"].tolist()); y_ca = ca["label"].values
X_te = embed_texts(te["text"].tolist()); y_te = te["label"].values

class MLP(nn.Module):
    """Two-layer classifier head: d-dim input -> ReLU hidden layer of width h -> 2 logits."""

    def __init__(self, d=384, h=256):
        super().__init__()
        layers = [nn.Linear(d, h), nn.ReLU(), nn.Linear(h, 2)]
        # A single Sequential named `f` keeps parameter names as f.0.* / f.2.*.
        self.f = nn.Sequential(*layers)

    def forward(self, x):
        logits = self.f(x)
        return logits

# Single source of truth for the DP budget: the accountant plans noise for
# exactly EPOCHS passes, so the training loop must run the same number.
EPOCHS, TARGET_EPS, DELTA = 6, 8.0, 1e-5

model = MLP().to(device)
opt = torch.optim.Adam(model.parameters(), lr=3e-4)
privacy_engine = PrivacyEngine()
train_ds = torch.utils.data.TensorDataset(
    torch.tensor(X_tr, dtype=torch.float32),
    torch.as_tensor(y_tr, dtype=torch.long),
)
train_loader = torch.utils.data.DataLoader(train_ds, batch_size=256, shuffle=True, drop_last=True)
model, opt, train_loader = privacy_engine.make_private_with_epsilon(
    module=model, optimizer=opt, data_loader=train_loader,
    target_epsilon=TARGET_EPS, target_delta=DELTA, epochs=EPOCHS, max_grad_norm=1.0,
)

model.train()
for _ in range(EPOCHS):
    for xb, yb in train_loader:
        xb, yb = xb.to(device), yb.to(device)
        opt.zero_grad()
        loss = F.cross_entropy(model(xb), yb)
        loss.backward()
        opt.step()
model.eval()

# Report the privacy actually spent, not just the target.
print("DP-SGD: spent ε=%.2f (target %.1f) at δ=%g" % (privacy_engine.get_epsilon(DELTA), TARGET_EPS, DELTA))

def probs(X):
    """Return class probabilities (n, 2) from the trained model for array-like inputs X."""
    with torch.no_grad():
        inputs = torch.tensor(X, dtype=torch.float32, device=device)
        logits = model(inputs)
        return F.softmax(logits, dim=-1).cpu().numpy()

# Held-out evaluation: ranking metrics on P(spam), then the 0.5-threshold confusion matrix.
p_te = probs(X_te)
y_pred = p_te.argmax(1)
spam_prob = p_te[:, 1]
auc = roc_auc_score(y_te, spam_prob)
ap = average_precision_score(y_te, spam_prob)
print("DP Phishing: AUC=%.3f  AP=%.3f"%(auc, ap))
cm = confusion_matrix(y_te, (spam_prob >= 0.5).astype(int))
print("Confusion matrix @0.5:\n", cm)

prec, rec, thr = precision_recall_curve(y_te, spam_prob)
plt.figure()
plt.plot(rec, prec)
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("Precision-Recall")
plt.show()

# Reliability diagram: observed spam frequency per predicted-probability bin.
prob_true, prob_pred = calibration_curve(y_te, spam_prob, n_bins=10)
plt.figure()
plt.plot(prob_pred, prob_true, marker='o')
plt.plot([0, 1], [0, 1], '--')
plt.xlabel("Predicted")
plt.ylabel("True")
plt.title("Reliability Diagram")
plt.show()

def conformal_threshold(scores, alpha=0.05):
    """Split-conformal threshold for nonconformity ``scores`` at miscoverage ``alpha``.

    Returns the ceil((n+1)(1-alpha))-th smallest score, which gives marginal
    coverage >= 1 - alpha for exchangeable data. When that rank exceeds n (too
    few calibration points for the requested alpha), or there are no scores,
    returns ``inf`` (every label must be included).
    """
    sc = np.sort(np.asarray(scores, dtype=float).ravel())
    n = sc.size
    if n == 0:
        return float("inf")
    # Small epsilon guards against float fuzz such as 10 * 0.30000000000000004.
    k = math.ceil((n + 1) * (1 - alpha) - 1e-9)
    if k > n:
        return float("inf")
    return float(sc[max(k, 1) - 1])


def conformal_q(Xc, yc, alpha=0.05):
    """Calibrate the conformal threshold with scores 1 - p(true label) on (Xc, yc)."""
    pc = probs(Xc)
    yc = np.asarray(yc, dtype=int)
    sc = 1 - pc[np.arange(len(yc)), yc]
    return conformal_threshold(sc, alpha)
# Conformal abstention sweep. For each alpha, the prediction set is the
# standard split-conformal set {y : 1 - p_y <= q}; the model answers only when
# that set is a singleton and abstains otherwise (both labels, or none).
alphas = [0.20,0.10,0.05,0.02]
abst, sel_err, covg = [], [], []
test_rows = np.arange(len(y_te))
for a in alphas:
    q = conformal_q(X_ca, y_ca, alpha=a)
    in_set = (1.0 - p_te) <= q                  # (n_test, n_classes) membership
    singleton = in_set.sum(axis=1) == 1
    covg.append(float(in_set[test_rows, y_te].mean()))
    if singleton.any():
        answered = in_set[singleton].argmax(axis=1)  # the single label in each set
        sel_err.append(float((answered != y_te[singleton]).mean()))
    else:
        sel_err.append(np.nan)
    abst.append(float((~singleton).mean()))
plt.figure(); plt.plot(alphas, abst, marker='o'); plt.gca().invert_xaxis(); plt.xlabel("α"); plt.ylabel("Abstention rate"); plt.title("Abstention vs α"); plt.show()
plt.figure(); plt.plot(alphas, sel_err, marker='o'); plt.gca().invert_xaxis(); plt.xlabel("α"); plt.ylabel("Error | non-abstain"); plt.title("Selective Risk vs α"); plt.show()

# SHAP (permutation algorithm) attributions of P(spam) to the D embedding
# dimensions. max_evals is set to 2*D + 1, the minimum the permutation
# algorithm accepts for D features; batch_size limits how many masked rows are
# passed to probs() per call.
subset = min(100, len(X_te))  # explain at most the first 100 test messages
D = X_te.shape[1]             # number of embedding features
explainer = shap.Explainer(
    lambda x: probs(x)[:,1],  # model output explained: column 1 = P(label 1, i.e. spam)
    X_te[:subset],            # background data; NOTE(review): same rows as those explained below
    algorithm="permutation",
    max_evals=int(2*D + 1),
    batch_size=50
)
shap_values = explainer(X_te[:subset])
# Features are anonymous embedding dimensions, so they are labelled f0..f{D-1}.
shap.summary_plot(shap_values, X_te[:subset], feature_names=[f"f{i}" for i in range(D)], show=False)
plt.title("SHAP Summary (embedding features)"); plt.show()

# %% [module C: local DP telemetry for exposure audits + visuals]
def randomized_response(cat, K, eps=2.0):
    """K-ary randomized response for one category report.

    Keeps ``cat`` with probability e^eps / (e^eps + K - 1); otherwise reports
    one of the other K - 1 categories uniformly at random.
    """
    keep_prob = math.exp(eps) / (math.exp(eps) + K - 1)
    if rng.random() < keep_prob:
        return cat
    others = [c for c in range(K) if c != cat]
    return rng.choice(others)

def _exposure(hist):
    """Normalised histogram of cluster ids over one user's history."""
    counts = np.bincount(clusters[hist], minlength=K).astype(float)
    return counts / counts.sum()

# (U, K) matrix of per-user exposure shares; true_pop is their renormalised mean.
user_exposures = np.array([_exposure(h) for h in histories])
true_pop = user_exposures.mean(axis=0)
true_pop = true_pop / true_pop.sum()

# LDP telemetry sweep. Each user reports ONE cluster sampled from their own
# exposure distribution, so the expected report histogram equals true_pop
# (reporting the argmax would estimate the dominant-cluster distribution
# instead, leaving an error floor unrelated to privacy).
eps_list = [0.5,1.0,2.0,4.0]
l1_err = []
for eps in eps_list:
    reported = [int(rng.choice(K, p=d)) for d in user_exposures]
    noisy = [randomized_response(c, K, eps) for c in reported]
    freq = np.bincount(noisy, minlength=K)/len(noisy)
    p = math.exp(eps)/(math.exp(eps)+K-1)   # P(report = true cluster)
    q_other = (1-p)/(K-1)                    # P(report = one specific other cluster)
    # Unbiased k-RR inversion: E[freq_j] = pi_j * p + (1 - pi_j) * q_other.
    est = (freq - q_other)/(p - q_other)
    est = np.clip(est,0,1); est/=est.sum()
    l1_err.append(float(np.abs(true_pop-est).sum()))
plot_tradeoff(eps_list, l1_err, "ε (privacy budget)", "L1 error", "LDP: Privacy vs Accuracy")

# One LDP round at a fixed ε for display: each user reports a cluster sampled
# from their exposure distribution, perturbed with k-ary randomized response,
# and the aggregator inverts the known noise to estimate population shares.
eps_show = 2.0
reported = [int(rng.choice(K, p=d)) for d in user_exposures]
noisy = [randomized_response(c, K, eps_show) for c in reported]
freq = np.bincount(noisy, minlength=K)/len(noisy)
p = math.exp(eps_show)/(math.exp(eps_show)+K-1)   # P(report = true cluster)
q_other = (1-p)/(K-1)                              # P(report = one specific other cluster)
est = (freq - q_other)/(p - q_other)               # unbiased k-RR inversion
est = np.clip(est,0,1); est/=est.sum()
plt.figure(); plt.plot(true_pop, marker='o', label="True"); plt.plot(est, marker='o', label=f"LDP est (ε={eps_show})"); plt.legend(); plt.xlabel("Cluster"); plt.ylabel("Population share"); plt.title("Population Exposure (Private Telemetry)"); plt.show()

# %% [summary]
def _round3(x):
    """Round to 3 decimals as a plain Python float (None for NaN).

    Casting to float avoids NumPy 2 printing list entries as np.float64(...).
    """
    x = float(x)
    return None if math.isnan(x) else round(x, 3)

print("\n=== SUMMARY ===")
print(f"News source: {SRC}, Items: {N}, Clusters: {K}, Users: {U}")
print("Echo Auditor: ILD (base→rerank) = %.3f → %.3f | KL = %.3f → %.3f | Clicks = %.3f → %.3f" %
      (ILD_b.mean(), ILD_r.mean(), KL_b.mean(), KL_r.mean(), C_b.mean(), C_r.mean()))
print("Phishing (DP-SGD): AUC=%.3f, AP=%.3f" % (auc, ap))
print("Conformal: α", alphas, " | Abstention", [_round3(x) for x in abst], " | SelErr", [_round3(x) for x in sel_err], " | Coverage", [_round3(x) for x in covg])
print("LDP telemetry: ε list", eps_list, " | L1 errors", [_round3(x) for x in l1_err])
