# In [None]:
# ==========================================
# LexFaith-HierBERT — One-File Reference Code
# ==========================================
import numpy as np
import pandas as pd
import math, random, warnings
warnings.filterwarnings("ignore")

# -------------------
# Config (edit here)
# -------------------
CFG = {
    # reproducibility
    "seed": 13,
    # sequence / segment geometry
    "max_seq_len": 512,
    "segment_len": 512,
    "segment_stride": 64,
    "max_segments": 12,
    # optimisation
    "batch_size": 4,
    "epochs_A": 3,              # kept small for the demo
    "epochs_B": 3,
    "lr_encoder": 2e-5,
    "lr_head": 5e-5,
    "dropout": 0.1,
    "early_stop_patience": 2,
    # number of classes for the multi-label task (Task B)
    "num_labels": 10,
    # feature flag: set False to run the BiLSTM baseline only
    "use_transformers": True,
}

# ==============
# Reproducibility
# ==============
def set_seed(seed=13):
    """Seed Python's ``random`` module and NumPy's global RNG with *seed*."""
    for seeder in (random.seed, np.random.seed):
        seeder(seed)
try:
    import torch

    def _torch_seed(s):
        """Seed torch's default generator and, when CUDA is available, all CUDA devices."""
        torch.manual_seed(s)
        if not torch.cuda.is_available():
            return
        torch.cuda.manual_seed_all(s)

    set_seed(CFG["seed"])
    _torch_seed(CFG["seed"])
except Exception:
    # Anything in the torch path failed (e.g. torch is not installed):
    # seed only Python's random module and NumPy.
    set_seed(CFG["seed"])

# ==========================
# Tokenizer (flexible import)
# ==========================
class FallbackWS:
    """Minimal whitespace tokenizer used when no transformers tokenizer can be built.

    Each distinct whitespace-separated token receives an integer id the first
    time it is seen, starting at 200; the mapping is stored in ``self.vocab``
    and keeps growing as new tokens appear.
    """

    def __init__(self):
        self.vocab = {}
        self._next = 200            # next id to hand out
        # fixed ids for the special tokens
        self.pad_token_id = 0
        self.cls_token_id = 101
        self.sep_token_id = 102

    def __call__(self, text, add_special_tokens=False, padding=False, truncation=False,
                 max_length=None, return_tensors=None):
        """Encode one string and return ``{"input_ids": [...]}``.

        Only ``truncation`` and ``max_length`` have an effect; the remaining
        keyword arguments are accepted and ignored.
        """
        ids = []
        for word in text.split():
            token_id = self.vocab.get(word)
            if token_id is None:
                # unseen token: allocate the next free id
                token_id = self._next
                self.vocab[word] = token_id
                self._next = token_id + 1
            ids.append(token_id)
        if truncation and max_length:
            ids = ids[:max_length]
        return {"input_ids": ids}

    def encode_plus(self, text, **kw):
        """Alias of ``__call__``."""
        return self(text, **kw)

def build_tokenizer(name="bert-base-uncased", use_transformers=True):
    """Return a tokenizer for *name*, or a ``FallbackWS`` instance.

    With ``use_transformers`` true, a fast ``AutoTokenizer`` is requested; if
    importing transformers or loading the tokenizer raises anything, the
    whitespace fallback is returned instead (best effort, no error surfaced).
    """
    if not use_transformers:
        return FallbackWS()
    try:
        from transformers import AutoTokenizer
        tokenizer = AutoTokenizer.from_pretrained(name, use_fast=True)
    except Exception:
        tokenizer = FallbackWS()
    return tokenizer


# Module-wide tokenizer shared by the collate functions.
TOK = build_tokenizer(use_transformers=CFG["use_transformers"])

# =================
# Segment utilities
# =================
def segment_text_ids(ids, segment_len=512, stride=64, cls_id=101, sep_id=102):
    """Split a token-id sequence into overlapping segments wrapped in CLS/SEP.

    Each segment is ``[cls_id] + window + [sep_id]`` where ``window`` holds at
    most ``segment_len - 2`` ids, so a segment never exceeds ``segment_len``.
    Consecutive windows share ``stride`` ids.

    The advance between windows is derived from the window size (not from
    ``segment_len``), and is clamped to at least 1.  This guarantees that no
    id is skipped between windows when ``stride`` is small and that the loop
    terminates when ``stride`` is as large as the window.

    Args:
        ids: sequence of token ids (without special tokens).
        segment_len: maximum length of a produced segment, special ids included.
        stride: number of ids shared by two consecutive windows.
        cls_id: id placed at the start of every segment.
        sep_id: id placed at the end of every segment.

    Returns:
        A list of segments (lists of ids); empty when ``ids`` is empty.
    """
    content_len = max(1, segment_len - 2)          # room left after CLS and SEP
    step = max(1, content_len - stride)            # always move forward
    segments = []
    for start in range(0, len(ids), step):
        window = [cls_id] + list(ids[start:start + content_len]) + [sep_id]
        segments.append(window[:segment_len])
        if start + content_len >= len(ids):
            # this window already reached the end of the sequence
            break
    return segments

def text_to_segments(text, tokenizer, cfg):
    """Tokenize *text* without special tokens and cut the ids into segments.

    Segment length and stride come from ``cfg["segment_len"]`` and
    ``cfg["segment_stride"]``; CLS/SEP ids are read from the tokenizer, with
    101/102 used when the tokenizer has no such attributes.
    """
    token_ids = tokenizer(text, add_special_tokens=False)["input_ids"]
    cls_id = getattr(tokenizer, "cls_token_id", 101)
    sep_id = getattr(tokenizer, "sep_token_id", 102)
    return segment_text_ids(token_ids, cfg["segment_len"], cfg["segment_stride"], cls_id, sep_id)

def pad_segments(segs, max_segments, pad_id, seg_len):
    """Pack segments into a fixed ``(max_segments, seg_len)`` int64 array.

    Segments beyond ``max_segments`` are dropped, short segments are
    right-padded with ``pad_id`` and missing segments are filled entirely with
    ``pad_id``.  Segments longer than ``seg_len`` are clipped so the result is
    always rectangular.

    Args:
        segs: list of id sequences.
        max_segments: number of rows in the output.
        pad_id: fill value for padding.
        seg_len: number of columns in the output.

    Returns:
        ``np.ndarray`` of dtype int64 and shape ``(max_segments, seg_len)``.
    """
    out = np.full((max_segments, seg_len), pad_id, dtype=np.int64)
    # zip stops at the shorter of the two, which enforces the max_segments cap
    for row, seg in zip(out, segs):
        clipped = seg[:seg_len]
        row[:len(clipped)] = clipped
    return out

# =========
# Metrics
# =========
from sklearn.metrics import (accuracy_score, f1_score, roc_auc_score,
                             precision_recall_fscore_support, hamming_loss)

def metrics_taskA(y_true, y_prob, thr=0.5):
    """Binary-classification metrics from positive-class probabilities.

    Probabilities are thresholded at ``thr`` (``>=`` counts as positive).
    ROC-AUC is reported as NaN whenever scikit-learn raises while computing it.

    Returns:
        dict with keys ``accuracy``, ``precision``, ``recall``, ``f1`` and
        ``roc_auc``, all plain floats.
    """
    y_pred = (y_prob >= thr).astype(int)
    accuracy = float(accuracy_score(y_true, y_pred))
    precision, recall, f1, _support = precision_recall_fscore_support(
        y_true, y_pred, average="binary", zero_division=0
    )
    try:
        auc = float(roc_auc_score(y_true, y_prob))
    except Exception:
        auc = float("nan")
    return {
        "accuracy": accuracy,
        "precision": float(precision),
        "recall": float(recall),
        "f1": float(f1),
        "roc_auc": auc,
    }

def metrics_taskB(Y_true, Y_prob, thr=0.5):
    """Multi-label metrics from per-label probabilities thresholded at ``thr``.

    Returns:
        dict with ``micro_f1``, ``macro_f1`` and ``hamming_loss`` as floats.
    """
    Y_pred = (Y_prob >= thr).astype(int)
    scores = {}
    for key, averaging in (("micro_f1", "micro"), ("macro_f1", "macro")):
        scores[key] = float(f1_score(Y_true, Y_pred, average=averaging, zero_division=0))
    scores["hamming_loss"] = float(hamming_loss(Y_true, Y_pred))
    return scores

# ======================
# Baselines (brief forms)
# ======================
from sklearn.pipeline import Pipeline
from sklearn.feature_extraction.text import TfidfVectorizer
from sklearn.linear_model import LogisticRegression

def run_bow_logreg_taskA(train_df, val_df):
    """Fit a TF-IDF (uni+bigram) + logistic-regression baseline and score it.

    Both frames are read through their ``"text"`` and ``"label"`` columns.
    Returns the Task A metric dict computed on ``val_df``.
    """
    vectorizer = TfidfVectorizer(ngram_range=(1, 2), min_df=3, max_features=60000)
    classifier = LogisticRegression(max_iter=300)
    model = Pipeline([("tfidf", vectorizer), ("clf", classifier)])
    model.fit(train_df["text"], train_df["label"])
    # column 1 of predict_proba is the probability of the second class
    positive_prob = model.predict_proba(val_df["text"])[:, 1]
    return metrics_taskA(val_df["label"].values, positive_prob)

# ===========================
# Torch models (if available)
# ===========================
# HAS_TORCH records whether torch (and torch.nn as ``nn``) could be imported.
try:
    import torch
    import torch.nn as nn
except Exception:
    HAS_TORCH = False
else:
    HAS_TORCH = True

class BiLSTMAttn(nn.Module):
    """Bidirectional LSTM with a single attention-pooling layer.

    Args:
        vocab_size: size of the embedding table.
        emb_dim: embedding width.
        hidden: LSTM hidden size per direction.
        num_classes: width of the output logits.
        pad_idx: id treated as padding by the embedding (and by the attention
            mask in ``forward``); ``None`` disables both.
        dropout: dropout probability applied to the pooled vector.
    """

    def __init__(self, vocab_size=40000, emb_dim=128, hidden=128, num_classes=1, pad_idx=0, dropout=0.1):
        super().__init__()
        self.emb = nn.Embedding(vocab_size, emb_dim, padding_idx=pad_idx)
        self.lstm = nn.LSTM(emb_dim, hidden, batch_first=True, bidirectional=True)
        self.attn = nn.Linear(2*hidden, 1)
        self.drop = nn.Dropout(dropout)
        self.fc   = nn.Linear(2*hidden, num_classes)

    def forward(self, x):
        """Return ``(logits [B,C], attention [B,T])`` for token ids ``x`` of shape [B,T].

        Positions whose id equals the embedding's ``padding_idx`` are excluded
        from the attention softmax so padding cannot receive weight.  A row
        made only of padding falls back to uniform weights (finite, no NaN).
        """
        E, _ = self.lstm(self.emb(x))                                  # [B,T,2H]
        scores = self.attn(E).squeeze(-1)                              # [B,T]
        pad_idx = self.emb.padding_idx
        if pad_idx is not None:
            # a large finite negative (not -inf) keeps all-pad rows NaN-free
            scores = scores.masked_fill(x == pad_idx, torch.finfo(scores.dtype).min)
        a = torch.softmax(scores, dim=-1)                              # [B,T]
        ctx = torch.einsum("bt,btd->bd", a, E)                         # [B,2H]
        logits = self.fc(self.drop(ctx))
        return logits, a

# Transformers models (flat) + Hierarchical
# TRANS_OK is True only when transformers is requested in CFG, torch imported,
# and ``AutoModel`` could be imported.
TRANS_OK = False
if HAS_TORCH and CFG["use_transformers"]:
    try:
        from transformers import AutoModel
    except Exception:
        # transformers unusable: leave TRANS_OK at False
        pass
    else:
        TRANS_OK = True

class LegalBERTFlat(nn.Module):
    """Flat (single-sequence) classifier: pretrained encoder + dropout + linear head.

    The representation fed to the head is the encoder's hidden state at
    sequence position 0.
    """

    def __init__(self, name="bert-base-uncased", num_classes=1, dropout=0.1):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(name)
        self.drop = nn.Dropout(dropout)
        self.fc = nn.Linear(self.encoder.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return logits of shape [B, num_classes]."""
        encoded = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        first_token = encoded.last_hidden_state[:, 0, :]
        pooled = self.drop(first_token)
        return self.fc(pooled)

class LongformerFlat(nn.Module):
    """Flat long-document classifier: pretrained encoder + dropout + linear head.

    The head reads the encoder's hidden state at position 0.  When the loaded
    encoder is a Longformer, that position is given global attention so it can
    attend to (and be attended by) the whole sequence rather than only its
    local window.

    Args:
        name: checkpoint passed to ``AutoModel.from_pretrained``.
        num_classes: width of the output logits.
        dropout: dropout probability applied before the head.
    """

    def __init__(self, name="allenai/longformer-base-4096", num_classes=1, dropout=0.1):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(name)
        d = self.encoder.config.hidden_size
        self.drop = nn.Dropout(dropout)
        self.fc   = nn.Linear(d, num_classes)

    def forward(self, input_ids, attention_mask):
        """Return logits of shape [B, num_classes]."""
        extra = {}
        # Only Longformer encoders accept ``global_attention_mask``; other
        # checkpoints passed via ``name`` are called exactly as before.
        if getattr(self.encoder.config, "model_type", None) == "longformer":
            global_mask = torch.zeros_like(input_ids)
            global_mask[:, 0] = 1                      # global attention on the first token
            extra["global_attention_mask"] = global_mask
        o = self.encoder(input_ids=input_ids, attention_mask=attention_mask, **extra)
        cls = o.last_hidden_state[:, 0, :]
        return self.fc(self.drop(cls))

class AdditiveAttention(nn.Module):
    """Additive (tanh) attention pooling over the second axis of a [B,S,D] tensor.

    Note: ``self.drop`` is constructed from the ``dropout`` argument but is not
    applied inside ``forward``.
    """

    def __init__(self, dim, dropout=0.1):
        super().__init__()
        self.w = nn.Linear(dim, dim)
        self.v = nn.Linear(dim, 1, bias=False)
        self.drop = nn.Dropout(dropout)

    def forward(self, H, mask=None):              # H: [B,S,D]
        """Pool ``H`` into one vector per batch row.

        Args:
            H: tensor of shape [B,S,D].
            mask: optional [B,S] tensor; nonzero/True marks items that may be
                attended to.  ``None`` (default) attends over every item,
                which is the original behaviour.

        Returns:
            ``(ctx [B,D], a [B,S])``.  With a mask, masked items get weight
            exactly 0; a row with nothing to attend to gets all-zero weights
            and therefore a zero context vector.
        """
        s = self.v(torch.tanh(self.w(H))).squeeze(-1)   # [B,S]
        if mask is None:
            a = torch.softmax(s, dim=-1)
        else:
            keep = mask.to(torch.bool)
            # finite fill value keeps fully-masked rows free of NaN
            s = s.masked_fill(~keep, torch.finfo(s.dtype).min)
            a = torch.softmax(s, dim=-1) * keep.to(s.dtype)
        ctx = torch.einsum("bs,bsd->bd", a, H)
        return ctx, a

class RationaleHead(nn.Module):
    """Per-token scorer: one linear unit followed by a sigmoid."""

    def __init__(self, dim):
        super().__init__()
        self.lin = nn.Linear(dim, 1)

    def forward(self, H):
        """Map hidden states [N,T,D] to scores in (0,1) of shape [N,T]."""
        token_logits = self.lin(H)                 # [N,T,1]
        return token_logits.sigmoid().squeeze(-1)  # [N,T]

def faithfulness_loss(attn_seg, rat_seg, margin=0.03):
    """Hinge penalty pushing attention on rationale items above the rest.

    Items with ``rat_seg > 0.5`` form the positive set.  The loss is
    ``relu(margin - (mean_attn_pos - mean_attn_neg))`` with both means taken
    over all entries of the inputs; it is a zero tensor (still attached to
    ``attn_seg``) when the positive set is empty.
    """
    pos = (rat_seg > 0.5).float()
    n_pos = pos.sum()
    if n_pos == 0:
        return attn_seg.sum() * 0.0
    neg = 1 - pos
    mean_pos = (attn_seg * pos).sum() / (n_pos + 1e-8)
    mean_neg = (attn_seg * neg).sum() / (neg.sum() + 1e-8)
    gap = mean_pos - mean_neg
    return torch.relu(margin - gap)

class LexFaithHierBERT(nn.Module):
    """Legal Faithfulness-Aware Hierarchical BERT.

    A shared encoder embeds every segment of a document, segment vectors
    (hidden state at position 0) are pooled with additive attention, and a
    linear layer classifies the pooled vector.  A token-level rationale head
    and a margin loss tie segment attention to rationale scores.

    Args:
        name: checkpoint passed to ``AutoModel.from_pretrained``.
        num_classes: width of the output logits.
        dropout: dropout probability for segment vectors and the pooled vector.
        margin: margin used by the faithfulness loss.
    """

    def __init__(self, name="bert-base-uncased", num_classes=1, dropout=0.1, margin=0.03):
        super().__init__()
        self.segment_encoder = AutoModel.from_pretrained(name)
        d = self.segment_encoder.config.hidden_size
        self.seg_attn   = AdditiveAttention(d, dropout)
        self.rationale  = RationaleHead(d)
        self.classifier = nn.Linear(d, num_classes)
        self.drop = nn.Dropout(dropout)
        self.margin = margin

    def forward(self, input_ids, attention_mask):    # [B,S,T]
        """Run the hierarchical model.

        Args:
            input_ids: [B,S,T] token ids (S segments of T tokens per document).
            attention_mask: [B,S,T], nonzero on real tokens.  A segment whose
                mask is entirely zero is treated as padding.

        Returns:
            ``(logits [B,C], a_seg [B,S], rat_tok [B*S,T], f_loss scalar)``.
            Padding segments receive attention weight 0 and are left out of
            the faithfulness loss; per-segment rationale scores average over
            real tokens only.
        """
        B,S,T = input_ids.shape
        x = input_ids.reshape(B*S, T)
        m = attention_mask.reshape(B*S, T)
        out = self.segment_encoder(input_ids=x, attention_mask=m)
        H = out.last_hidden_state                          # [B*S,T,D]
        rat_tok = self.rationale(H)                        # [B*S,T]

        tok_keep = (m != 0).to(H.dtype)                    # [B*S,T], 1.0 on real tokens
        tok_count = tok_keep.sum(dim=-1)                   # [B*S]
        seg_keep = (tok_count > 0).reshape(B, S)           # [B,S], True on real segments

        cls = self.drop(H[:,0,:].reshape(B, S, -1))        # [B,S,D]
        _, a_raw = self.seg_attn(cls)                      # [B,S], softmax over all S
        # Drop padding segments and renormalise: equivalent to a masked softmax.
        a_seg = a_raw * seg_keep.to(a_raw.dtype)
        a_seg = a_seg / a_seg.sum(dim=-1, keepdim=True).clamp_min(1e-12)
        seg_ctx = torch.einsum("bs,bsd->bd", a_seg, cls)   # [B,D]
        logits = self.classifier(self.drop(seg_ctx))       # [B,C]

        # Mean rationale score per segment over real tokens only.
        rat_seg = ((rat_tok * tok_keep).sum(dim=-1) / tok_count.clamp_min(1.0)).reshape(B, S)
        # The loss only uses global sums, so it can take the flattened real segments.
        f_loss = faithfulness_loss(a_seg[seg_keep], rat_seg[seg_keep], self.margin)
        return logits, a_seg, rat_tok, f_loss

# ===========================
# Collate functions (hier/flat)
# ===========================
def collate_hier_texts(batch_texts, labels=None, cfg=CFG, tokenizer=TOK):
    """Turn raw texts into hierarchical model inputs.

    Every text is cast to ``str``, segmented, and padded to
    ``cfg["max_segments"]`` segments of ``cfg["segment_len"]`` ids.

    Returns:
        ``(X, M, Y)`` where ``X`` holds the ids, ``M`` is 1 wherever ``X``
        differs from the pad id, and ``Y`` is a float tensor of ``labels``
        (``None`` when no labels are given).
    """
    pad_id = getattr(tokenizer, "pad_token_id", 0)
    seg_arrays = [
        pad_segments(text_to_segments(str(t), tokenizer, cfg),
                     cfg["max_segments"], pad_id, cfg["segment_len"])
        for t in batch_texts
    ]
    mask_arrays = [(arr != pad_id).astype(int) for arr in seg_arrays]
    X = torch.tensor(np.asarray(seg_arrays))
    M = torch.tensor(np.asarray(mask_arrays))
    # 1-D and 2-D label arrays are converted the same way
    Y = None if labels is None else torch.tensor(np.asarray(labels)).float()
    return X, M, Y

def collate_flat_texts(batch_texts, labels=None, max_len=512, tokenizer=TOK):
    """Turn raw texts into flat (single-sequence) model inputs.

    The tokenizer is first called transformers-style on the whole batch.  If
    that raises for any reason, each text is tokenized on its own and the id
    lists are truncated/padded by hand to a common width of at most
    ``max_len``.

    Returns:
        ``(X, M, Y)``: ids, attention mask, and a float tensor of ``labels``
        (``None`` when no labels are given).
    """
    try:
        enc = tokenizer(batch_texts, padding=True, truncation=True,
                        max_length=max_len, return_tensors='pt')
        X = enc["input_ids"]
        M = enc["attention_mask"]
    except Exception:
        # manual path: pad every sequence to the same length
        pad_id = getattr(tokenizer, "pad_token_id", 0)
        ids = [tokenizer(t)["input_ids"] for t in batch_texts]
        width = min(max_len, max(len(seq) for seq in ids))
        rows = [seq[:width] + [pad_id] * (width - len(seq)) for seq in ids]
        X = torch.tensor(np.asarray(rows))
        M = (X != pad_id).long()
    # 1-D and 2-D label arrays are converted the same way
    Y = None if labels is None else torch.tensor(np.asarray(labels)).float()
    return X, M, Y

# ==========================
# Training loops (concise)
# ==========================
def train_taskA_proposed(train_df, val_df, cfg=CFG):
    """Train LexFaith-HierBERT on the binary task and return the best validation metrics.

    Both frames are read through their ``"text"`` and ``"label"`` columns.
    The encoder and the heads use separate learning rates
    (``cfg["lr_encoder"]`` / ``cfg["lr_head"]``).  After each epoch the model
    is scored on ``val_df``; training stops once F1 has failed to improve for
    ``cfg["early_stop_patience"]`` consecutive epochs.

    Returns:
        The metric dict of the epoch with the highest validation F1, or
        ``None`` when no epoch was run (``cfg["epochs_A"] == 0``).

    Raises:
        RuntimeError: if torch or transformers is unavailable.
    """
    device = "cuda" if (HAS_TORCH and torch.cuda.is_available()) else "cpu"
    if not (HAS_TORCH and TRANS_OK):
        raise RuntimeError("Transformers not available; cannot run proposed model.")
    model = LexFaithHierBERT(num_classes=1, dropout=cfg["dropout"]).to(device)
    head_params = (list(model.seg_attn.parameters())
                   + list(model.rationale.parameters())
                   + list(model.classifier.parameters()))
    opt = torch.optim.AdamW([
        {"params": model.segment_encoder.parameters(), "lr": cfg["lr_encoder"]},
        {"params": head_params, "lr": cfg["lr_head"]}
    ], weight_decay=0.01)
    crit = nn.BCEWithLogitsLoss()
    best, wait = -1.0, 0
    best_metrics = None   # stays None if the epoch loop never executes
    for ep in range(cfg["epochs_A"]):
        model.train()
        # mini-batches over a fresh permutation of the training rows
        idx = np.arange(len(train_df)); np.random.shuffle(idx)
        for i in range(0, len(idx), cfg["batch_size"]):
            j = idx[i:i+cfg["batch_size"]]
            bt = train_df.iloc[j]
            X, M, Y = collate_hier_texts(bt["text"].tolist(), bt["label"].tolist(), cfg, TOK)
            X, M, Y = X.to(device), M.to(device), Y.to(device)
            logits, _a_seg, _rat_tok, f_loss = model(X, M)
            # classification loss plus the down-weighted faithfulness term
            loss = crit(logits.squeeze(-1), Y) + 0.1*f_loss
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
        # validate
        model.eval(); all_prob = []; all_y = []
        with torch.no_grad():
            for i in range(0, len(val_df), cfg["batch_size"]):
                bt = val_df.iloc[i:i+cfg["batch_size"]]
                X, M, _ = collate_hier_texts(bt["text"].tolist(), None, cfg, TOK)
                X, M = X.to(device), M.to(device)
                logits, _, _, _ = model(X, M)
                p = torch.sigmoid(logits).squeeze(-1).cpu().numpy().tolist()
                all_prob += p; all_y += bt["label"].tolist()
        res = metrics_taskA(np.array(all_y), np.array(all_prob))
        print(f"[TaskA][Ep {ep+1}] ACC={res['accuracy']:.3f} F1={res['f1']:.3f} AUC={res['roc_auc']:.3f}")
        if res["f1"] > best:
            best = res["f1"]; wait = 0
            best_metrics = res
        else:
            wait += 1
            if wait >= cfg["early_stop_patience"]: break
    return best_metrics

def train_taskB_proposed(train_df, val_df, cfg=CFG):
    """Train LexFaith-HierBERT on the multi-label task and return the best validation metrics.

    Texts come from the ``"text"`` column and targets from the columns
    ``y0 .. y{num_labels-1}``.  The encoder and the heads use separate
    learning rates.  After each epoch the model is scored on ``val_df``;
    training stops once micro-F1 has failed to improve for
    ``cfg["early_stop_patience"]`` consecutive epochs.

    Returns:
        The metric dict of the epoch with the highest validation micro-F1, or
        ``None`` when no epoch was run (``cfg["epochs_B"] == 0``).

    Raises:
        RuntimeError: if torch or transformers is unavailable.
    """
    device = "cuda" if (HAS_TORCH and torch.cuda.is_available()) else "cpu"
    if not (HAS_TORCH and TRANS_OK):
        raise RuntimeError("Transformers not available; cannot run proposed model.")
    model = LexFaithHierBERT(num_classes=cfg["num_labels"], dropout=cfg["dropout"]).to(device)
    head_params = (list(model.seg_attn.parameters())
                   + list(model.rationale.parameters())
                   + list(model.classifier.parameters()))
    opt = torch.optim.AdamW([
        {"params": model.segment_encoder.parameters(), "lr": cfg["lr_encoder"]},
        {"params": head_params, "lr": cfg["lr_head"]}
    ], weight_decay=0.01)
    crit = nn.BCEWithLogitsLoss()
    best, wait = -1.0, 0
    best_metrics = None   # stays None if the epoch loop never executes
    label_cols = [f"y{i}" for i in range(cfg["num_labels"])]
    for ep in range(cfg["epochs_B"]):
        model.train()
        # mini-batches over a fresh permutation of the training rows
        idx = np.arange(len(train_df)); np.random.shuffle(idx)
        for i in range(0, len(idx), cfg["batch_size"]):
            j = idx[i:i+cfg["batch_size"]]
            bt = train_df.iloc[j]
            Y = bt[label_cols].values
            X, M, Y = collate_hier_texts(bt["text"].tolist(), Y, cfg, TOK)
            X, M, Y = X.to(device), M.to(device), Y.to(device)
            logits, _a_seg, _rat_tok, f_loss = model(X, M)
            # classification loss plus the down-weighted faithfulness term
            loss = crit(logits, Y) + 0.1*f_loss
            opt.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
            opt.step()
        # validate
        model.eval(); probs = []; Yt = []
        with torch.no_grad():
            for i in range(0, len(val_df), cfg["batch_size"]):
                bt = val_df.iloc[i:i+cfg["batch_size"]]
                X, M, _ = collate_hier_texts(bt["text"].tolist(), None, cfg, TOK)
                X, M = X.to(device), M.to(device)
                logits, _, _, _ = model(X, M)
                p = torch.sigmoid(logits).cpu().numpy()
                probs.append(p); Yt.append(bt[label_cols].values)
        P = np.vstack(probs); YT = np.vstack(Yt)
        res = metrics_taskB(YT, P, thr=0.5)
        print(f"[TaskB][Ep {ep+1}] microF1={res['micro_f1']:.3f} macroF1={res['macro_f1']:.3f} H={res['hamming_loss']:.3f}")
        if res["micro_f1"] > best:
            best = res["micro_f1"]; wait = 0; best_metrics = res
        else:
            wait += 1
            if wait >= cfg["early_stop_patience"]: break
    return best_metrics

# =====================
# Simple XAI interfaces
# =====================
def attention_tokens_preview(text):
    """Pair each whitespace token of *text* with a placeholder weight.

    Demo only: weights are random draws scaled so the largest equals 1
    (replace with real model attentions).

    Returns:
        A list of ``(token, weight)`` tuples; empty when *text* contains no
        tokens.
    """
    toks = text.split()
    if not toks:
        # nothing to weight; also avoids taking max() of an empty array
        return []
    w = np.random.rand(len(toks))
    peak = w.max()
    if peak > 0:
        w = w / peak
    return list(zip(toks, w))

def lime_stub(text):
    """Placeholder for a LIME pipeline (a predict_fn wrapper around the model).

    The argument is ignored; a fixed demo payload is returned.
    """
    payload = {"explanation": "LIME scores (stub for demo)"}
    return payload

def shap_stub(texts):
    """Placeholder for a SHAP pipeline (an Explainer over text).

    The argument is ignored; a fixed demo payload is returned.
    """
    payload = {"explanation": "SHAP values (stub for demo)"}
    return payload

# ========================
# Statistical test helpers
# ========================
from scipy.stats import ttest_rel, f_oneway, chi2, norm

def paired_t(values_a, values_b):
    """One-sided paired t-test with alternative "mean(values_a - values_b) > 0".

    Returns:
        ``(t_statistic, p_value)`` as plain floats.
    """
    outcome = ttest_rel(values_a, values_b, alternative="greater")
    return float(outcome.statistic), float(outcome.pvalue)

def anova_all(*args):
    """One-way ANOVA across the given groups of scores.

    Returns:
        ``(F_statistic, p_value)`` as plain floats.
    """
    outcome = f_oneway(*args)
    return float(outcome.statistic), float(outcome.pvalue)

def z_test_acc(acc_a, acc_b, n_a, n_b):
    """Two-sided two-proportion z-test on a pair of accuracies.

    The standard error uses the pooled accuracy of both samples.

    Args:
        acc_a, acc_b: accuracies (proportions in [0, 1]).
        n_a, n_b: number of evaluated examples behind each accuracy.

    Returns:
        ``(z, p_value)`` as plain floats.
    """
    pooled = (acc_a*n_a + acc_b*n_b) / (n_a + n_b)
    se = math.sqrt(pooled*(1 - pooled)*(1/n_a + 1/n_b))
    z = (acc_a - acc_b) / (se + 1e-12)     # epsilon guards se == 0
    # Survival function instead of 1 - cdf: the subtraction cancels to
    # exactly 0.0 once |z| is large, hiding how small the p-value really is.
    pval = 2*norm.sf(abs(z))
    return float(z), float(pval)

def chi_square_from_counts(tp, fp, fn, tn):
    """Chi-square goodness-of-fit of four counts against a uniform split.

    The expected value of every cell is the mean of the four counts, and the
    statistic is referred to a chi-square distribution with 3 degrees of
    freedom.

    NOTE(review): this is not a 2x2 test of independence on the confusion
    matrix; confirm that the uniform-expectation test is what is intended.

    Returns:
        ``(statistic, p_value)`` as plain floats.
    """
    obs = np.array([tp, fp, fn, tn], dtype=float)
    exp = np.full_like(obs, obs.mean())
    stat = ((obs - exp)**2 / (exp + 1e-12)).sum()   # epsilon guards all-zero counts
    # Survival function instead of 1 - cdf: avoids cancellation to exactly
    # 0.0 for large statistics.
    p = chi2.sf(stat, df=len(obs) - 1)
    return float(stat), float(p)

# =========================
# Demo runner (edit or drop)
# =========================
# Demo entry point: builds synthetic data, runs the BoW baseline, trains the
# proposed model when torch + transformers are usable, and prints placeholder
# XAI and statistics output.  Statement order matters because the synthetic
# labels and the training shuffles draw from the global NumPy RNG.
if __name__ == "__main__":
    # -----------------------------
    # FAKE DATA (replace with real)
    # -----------------------------
    # Task A: two alternating sentences; even rows are labelled 1, odd rows 0
    n_tr, n_va = 80, 20
    dfA_tr = pd.DataFrame({
        "text": ["The applicant alleges unlawful detention and lack of counsel." if i%2==0 else
                 "The state argues due process was followed with timely review." for i in range(n_tr)],
        "label": [1 if i%2==0 else 0 for i in range(n_tr)]
    })
    dfA_va = pd.DataFrame({
        "text": ["Prolonged detention without cause; judicial oversight absent." if i%2==0 else
                 "Evidence indicates no breach and adequate procedural safeguards." for i in range(n_va)],
        "label": [1 if i%2==0 else 0 for i in range(n_va)]
    })

    # Task B (10 labels y0..y9)
    n_trB, n_vaB = 60, 20
    def rand_multi_lab(n, L=10):
        # Random binary label matrix of shape (n, L): each row gets between
        # one and three distinct active labels.
        Y = np.zeros((n,L), dtype=int)
        for i in range(n):
            k = np.random.randint(1,4)             # 1-3 labels active
            idx = np.random.choice(L, k, replace=False)
            Y[i, idx] = 1
        return Y
    Ytr = rand_multi_lab(n_trB, CFG["num_labels"])
    Yva = rand_multi_lab(n_vaB, CFG["num_labels"])
    # one text column plus one 0/1 column per label (y0, y1, ...)
    dfB_tr = pd.DataFrame({"text":[f"Case {i}: facts and legal analysis on multiple articles." for i in range(n_trB)]})
    for j in range(CFG["num_labels"]): dfB_tr[f"y{j}"]=Ytr[:,j]
    dfB_va = pd.DataFrame({"text":[f"Validation case {i}: complex facts for articles." for i in range(n_vaB)]})
    for j in range(CFG["num_labels"]): dfB_va[f"y{j}"]=Yva[:,j]

    # -----------------------------
    # Baseline example (Task A)
    # -----------------------------
    bowA = run_bow_logreg_taskA(dfA_tr, dfA_va)
    print("[BoW+LogReg][TaskA]", bowA)

    # -----------------------------
    # Proposed (Task A & Task B)
    # -----------------------------
    # Only attempted when both torch and transformers imported successfully.
    if HAS_TORCH and TRANS_OK:
        bestA = train_taskA_proposed(dfA_tr, dfA_va, CFG)
        print("[Proposed LexFaith-HierBERT][TaskA]", bestA)

        bestB = train_taskB_proposed(dfB_tr, dfB_va, CFG)
        print("[Proposed LexFaith-HierBERT][TaskB]", bestB)
    else:
        print("Transformers not available; run with BiLSTM or enable transformers to train proposed model.")

    # -----------------------------
    # XAI quick preview (token-level)
    # -----------------------------
    # Weights come from the random placeholder, not from a trained model.
    xai_demo = attention_tokens_preview("The applicant was unlawfully detained and denied access to counsel.")
    print("Attention preview:", xai_demo[:8], "...")

    # -----------------------------
    # Statistical tests demo
    # (use arrays from repeated runs/folds; here random placeholders)
    # -----------------------------
    rng = np.random.default_rng(7)
    prop_scores = rng.normal(0.88, 0.01, 5)  # pretend 5-fold accuracy for Task A
    bow_scores  = rng.normal(0.75, 0.02, 5)

    t, p = paired_t(prop_scores, bow_scores)
    f, pa = anova_all(prop_scores, bow_scores)
    z, pz = z_test_acc(prop_scores.mean(), bow_scores.mean(), 2200, 2200)  # example test sizes
    chi, pc = chi_square_from_counts(1500, 200, 250, 250)                   # dummy counts

    # Only the p-values of the ANOVA and chi-square tests are printed.
    print(f"[Stats][TaskA] t={t:.2f}, p={p:.4f}; ANOVA p={pa:.4f}; z={z:.2f}, p={pz:.4f}; chi2 p={pc:.4f}")
