# Setup

In [None]:
!pip install -U "fsspec<2024"

In [None]:
!pip uninstall -y dataset banal alembic
!pip install -U "SQLAlchemy>=2.0"
!pip install -U datasets "fsspec<2024"


**Models and Dataset**

In [None]:
from datasets import load_dataset
from transformers import (
    AutoTokenizer,
    AutoModelForSequenceClassification,
    DataCollatorWithPadding,
)
from torch.utils.data import DataLoader
import torch
from datasets import load_dataset, concatenate_datasets

# Shared configuration for every cell below.
MODEL_ID = "philschmid/DistilBERT-tweet-eval-emotion"
DEVICE = "cuda"
BATCH_SIZE = 64
MAX_LEN = 512

# dataset
# Each split is loaded exactly once; the analysis set `ds_train` is the
# concatenation of the train and validation splits. Building it from an
# explicitly named train split keeps this cell idempotent on re-run.
ds_train_split = load_dataset("tweet_eval", "emotion", split="train")
ds_val = load_dataset("tweet_eval", "emotion", split="validation")

ds_train = concatenate_datasets([ds_train_split, ds_val])


# tokenize
tok = AutoTokenizer.from_pretrained(MODEL_ID)

def collate_fn(batch):
    """Tokenize a list of dataset rows into one padded batch.

    Each row must provide a ``"text"`` and a ``"label"`` entry. The texts are
    tokenized with the module-level ``tok``, truncated and padded to
    ``MAX_LEN`` tokens, and returned as PyTorch tensors with the labels
    attached under the ``"labels"`` key.
    """
    texts = [row["text"] for row in batch]
    # NOTE(review): padding="max_length" pads every example to MAX_LEN, not
    # to the longest example in the batch.
    encoded = tok(
        texts,
        padding="max_length",
        truncation=True,
        max_length=MAX_LEN,
        return_tensors="pt",
    )
    label_ids = [row["label"] for row in batch]
    encoded["labels"] = torch.tensor(label_ids)
    return encoded

# dataloader
# Both loaders share the same batching options; order is kept fixed
# (shuffle=False).
_loader_options = dict(batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

train_loader = DataLoader(ds_train, **_loader_options)

ds_test = load_dataset("tweet_eval", "emotion", split="test")
test_loader = DataLoader(ds_test, **_loader_options)

# model
# Load the classifier, move it to DEVICE and switch it to evaluation mode.
model = AutoModelForSequenceClassification.from_pretrained(MODEL_ID)
model = model.to(DEVICE)
model = model.eval()


In [None]:
!pip install alibi

In [None]:
!pip install "spacy<3.6"

In [None]:
!pip install -U "spacy>=3.8,<3.9" "alibi==0.9.6"

# Step 1 (Purity)

**Activation Visualization**

In [None]:
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

@torch.no_grad()
def plot_neuron_hist(
    model,
    dataloader: DataLoader,
    hook_module: torch.nn.Module,
    neuron_idx: int,
    class_id: int,
    threshold: float,
    *,
    apply_gelu: bool = False,
    device: str = "cuda",
    bins: int = 50,
    layer: int = 5,
):
    """Plot the activation distribution of one neuron for a class vs. the rest.

    The model is run over ``dataloader`` with a forward hook on
    ``hook_module``; for every example the value at sequence position 0 and
    feature ``neuron_idx`` of the hooked output is recorded. Two normalised
    histograms are drawn (examples labelled ``class_id`` vs. all others)
    together with a vertical line at ``threshold``.

    Args:
        model: Callable module; invoked as ``model(**batch)``.
        dataloader: Yields dict batches that contain a ``"labels"`` tensor
            (it is popped from the batch before the forward pass).
        hook_module: Sub-module whose forward output is captured; the output
            is indexed as ``out[:, 0, neuron_idx]``.
        neuron_idx: Feature index to plot.
        class_id: Label treated as the positive class.
        threshold: Position of the dashed vertical line.
        apply_gelu: Apply GELU to the hooked output before reading it.
        device: Device the batch tensors are moved to.
        bins: Number of histogram bins.
        layer: Only used in the plot title.
    """
    activ = []
    y_all = []

    def save_hook(_, __, out):
        out = F.gelu(out) if apply_gelu else out
        activ.append(out[:, 0, neuron_idx].cpu())

    h = hook_module.register_forward_hook(save_hook)
    try:
        model.eval()
        for batch in dataloader:
            # for every batch get labels and run inference
            labels = batch.pop("labels")
            y_all.append(labels.cpu())
            batch = {k: v.to(device) for k, v in batch.items()}
            model(**batch)
    finally:
        # Always detach the hook, even if a forward pass fails; otherwise it
        # stays registered on the module and fires on every later call.
        h.remove()

    y_all = torch.cat(y_all)
    activ = torch.cat(activ)

    pos = activ[y_all == class_id].numpy()
    neg = activ[y_all != class_id].numpy()

    plt.figure(figsize=(6, 4))
    plt.axvline(threshold, color="k", linestyle="--", linewidth=1.5,
              label=f"Threshold = {threshold:.3f}")
    plt.hist(pos, bins=bins, alpha=0.9, label=f"Class {class_id}", density=True)
    plt.hist(neg, bins=bins, alpha=0.9, label="Others", density=True)
    plt.title(f"Activation values for Neuron({neuron_idx}, {layer})")
    plt.xlabel("Activation Value")
    # density=True normalises the histograms, so the axis shows a density.
    plt.ylabel("Density")
    plt.legend()
    plt.tight_layout()
    plt.show()

**Purity (Step 1)**

In [None]:
import torch
from torch.utils.data import DataLoader
from typing import Dict, List, Tuple, Callable


def extract_predicates_rules_cls_purity_any(
    model,
    dataloader: DataLoader,
    target_layer: torch.nn.Module,
    apply_act: bool = False,
    k: int = 15,
    search_steps: int = 30,
    device: str = "cuda",
) -> Tuple[Dict[int, List[Tuple[int, float, int]]], Tuple[int, int, float, float, int]]:
    """
    Pick, per class, the k neurons with the highest purity score.

    The model is run over ``dataloader`` with a forward hook on
    ``target_layer``. For 3-D hook outputs only sequence position 0 is kept
    (2-D outputs are used as they are). For every neuron and class, each
    observed activation is tried as a threshold and scored as
    ``TP-rate + TN-rate`` where "positive" means ``activation >= threshold``.

    Args:
        model: Module with ``config.num_labels``; invoked as ``model(**batch)``.
        dataloader: Yields dict batches containing a ``"labels"`` tensor.
        target_layer: Sub-module whose output is analysed.
        apply_act: Apply an activation to the captured values first
            (``model.config.activation_fn`` when it is callable, else GELU).
        k: Number of neurons returned per class.
        search_steps: Unused; kept for interface compatibility.
        device: Device the batch tensors are moved to.

    Returns:
        ``(rules, best_neuron)`` where ``rules[c]`` is a list of
        ``(neuron_idx, threshold, support)`` tuples sorted by decreasing
        purity, and ``best_neuron`` is
        ``(class, neuron_idx, threshold, purity, support)`` for the single
        highest-purity entry. ``support`` is the rank of the threshold in the
        descending sort, i.e. the number of examples counted as positive.
    """
    z_list, y_list = [], []

    # collect activations
    def hook_fn(_, __, out):
        out = out.detach()
        # Keep only position 0 right away: storing the full
        # (batch, seq, hidden) output for the whole dataset would need
        # seq-times more host memory and nothing else is ever read.
        if out.dim() == 3:
            out = out[:, 0, :]
        z_list.append(out.cpu())

    h = target_layer.register_forward_hook(hook_fn)
    try:
        model.eval()
        with torch.no_grad():
            for batch in dataloader:
                labels = batch.pop("labels")
                y_list.append(labels.cpu())
                batch = {name: t.to(device) for name, t in batch.items()}
                model(**batch)
    finally:
        # Detach the hook even when a forward pass fails.
        h.remove()

    z_cls = torch.cat(z_list)
    if apply_act:
        act = getattr(model.config, "activation_fn", None)
        if not callable(act):
            # Configs may store the activation as a name or not at all.
            act = torch.nn.functional.gelu
        # NOTE(review): applied after slicing position 0; equivalent for
        # element-wise activations such as GELU.
        z_cls = act(z_cls)

    y_all = torch.cat(y_list)

    num_classes, hidden_size = model.config.num_labels, z_cls.shape[1]
    n_samples = z_cls.size(0)
    purity   = torch.empty(num_classes, hidden_size)
    thr_mat  = torch.empty(num_classes, hidden_size)
    supp_mat = torch.empty(num_classes, hidden_size, dtype=torch.long)

    # per class counts (avoid the recomputation)
    class_counts = torch.bincount(y_all, minlength=num_classes)
    total_seen = torch.arange(1, n_samples + 1, dtype=torch.long)

    for j in range(hidden_size):
        a = z_cls[:, j]
        idx = torch.argsort(a, descending=True)
        a_sorted = a[idx]
        y_sorted = y_all[idx]

        # pre compute cumulative sums for every class
        one_hot = torch.nn.functional.one_hot(y_sorted, num_classes=num_classes).cumsum(0)

        for c in range(num_classes):
            tp = one_hot[:, c]  # positives >= threshold
            fp = total_seen - tp  # negatives >= threshold
            tn = (n_samples - class_counts[c]) - fp

            tp_rate = tp.float() / class_counts[c].clamp_min(1)
            tn_rate = tn.float() / (class_counts.sum() - class_counts[c]).clamp_min(1)
            p_scores = tp_rate + tn_rate # get purity score
            best_idx = torch.argmax(p_scores)
            purity[c, j]  = p_scores[best_idx]
            thr_mat[c, j] = a_sorted[best_idx].item()
            supp_mat[c, j] = total_seen[best_idx].item()

    # single best (class, neuron) pair over the whole purity matrix
    best_c, best_j = divmod(purity.argmax().item(), hidden_size)
    best_val = purity[best_c, best_j].item()
    best_neuron = (
        best_c,
        best_j,
        thr_mat[best_c, best_j].item(),
        best_val,
        supp_mat[best_c, best_j].item(),
    )

    # sort and get best
    rules: Dict[int, List[Tuple[int, float, int]]] = {}
    for c in range(num_classes):
        topk = torch.topk(purity[c], k=min(k, hidden_size))
        indices = topk.indices.tolist()
        rules[c] = [(j, thr_mat[c, j].item(), supp_mat[c, j].item()) for j in indices]

    return rules, best_neuron


# Step 2 (Distillation)

**Distillation from tree**

In [None]:
# helpers
import itertools
from typing import Dict, List, Tuple, Union

import numpy as np
import torch
from sklearn.tree import DecisionTreeClassifier, _tree

# Type aliases used in the annotations of the distillation helpers below.
Predicate = Tuple[Union[int, float], ...]  # (neuron_idx, threshold, ...optional extras)
Literal = Tuple[int, float, int]           # (neuron_idx, threshold, sign)
DNFRuleSet = Dict[int, List[List[Literal]]]
def _dedup_preserve_order(seq):
    seen = set()
    out  = []
    for x in seq:
        if x not in seen:
            seen.add(x)
            out.append(x)
    return out


def _tree_to_dnf_multiclass(tree: DecisionTreeClassifier,
                            preds: List[Predicate]) -> DNFRuleSet:
    """
    Convert a multiclass decision‑tree into per‑class DNF rule‑sets.
    Each clause is a list[Literal] (neuron‑idx, threshold, sign).
    """
    T = tree.tree_
    dnfs: DNFRuleSet = {c: [] for c in range(tree.n_classes_)}

    def walk(node: int, path: List[Literal]) -> None:
        if T.feature[node] != _tree.TREE_UNDEFINED:
            f_idx   = T.feature[node]
            neuron, thr, *_ = preds[f_idx]
            walk(T.children_left[node],  path + [(neuron, thr, 0)]) # <= thr
            walk(T.children_right[node], path + [(neuron, thr, 1)]) # > thr
        else: # leaf => add clause
            cls = np.argmax(T.value[node])
            dnfs[cls].append(path)

    walk(0, [])
    return dnfs


def _eval_predicates_matrix(X: np.ndarray, preds: List[Predicate]) -> np.ndarray:
    """Evaluate every predicate on every row of ``X``.

    Column ``i`` of the result is ``X[:, preds[i][0]] > preds[i][1]``. With
    no predicates an empty boolean matrix with ``X.shape[0]`` rows is
    returned.

    NOTE(review): the comparison here is strict (``>``), while the threshold
    search and the pruning code in this notebook use ``>=`` — confirm which
    boundary behaviour is intended.
    """
    if not preds:
        return np.empty((X.shape[0], 0), dtype=bool)

    neuron_cols = np.fromiter((pred[0] for pred in preds), int)
    cutoffs = np.fromiter((pred[1] for pred in preds), float)
    selected = X[:, neuron_cols]
    return selected > cutoffs


# main functions

def distil_to_dnfs_from_loader(
    model: torch.nn.Module,
    dataloader,
    predicates: Dict[int, List[Predicate]],
    target_layer: torch.nn.Module,
    *,
    apply_act: bool = False,
    device: str = "cuda",
    max_depth: int | None = None,
    min_samples_leaf: int = 1,
) -> DNFRuleSet:
    """
    Distil a single multiclass decision tree into per‑class DNF clauses.

    The model is run over ``dataloader`` with a forward hook on
    ``target_layer``; for 3-D outputs only sequence position 0 is kept. The
    activations are binarised with the de-duplicated union of all predicates,
    a decision tree is fitted on those binary features against the labels,
    and the tree is converted to clauses with ``_tree_to_dnf_multiclass``.

    Args:
        model: Invoked as ``model(**batch)``.
        dataloader: Yields dict batches containing a ``"labels"`` tensor.
        predicates: Per-class predicate lists; all classes are merged.
        target_layer: Sub-module whose output is captured.
        apply_act: Apply GELU to the captured output first.
        device: Device the batch tensors are moved to.
        max_depth: Passed to ``DecisionTreeClassifier``.
        min_samples_leaf: Passed to ``DecisionTreeClassifier``.

    Returns:
        The per-class rule set produced by ``_tree_to_dnf_multiclass``.
    """
    # collect hidden layer activations + labels
    buf, acts_list, labels_list = [], [], []

    def _hook(_, __, out):
        buf.append(out.detach())

    act_fn = torch.nn.functional.gelu if apply_act else (lambda x: x)

    h = target_layer.register_forward_hook(_hook)
    try:
        model.eval()
        with torch.no_grad():
            for batch in dataloader:
                labels = batch.pop("labels").to("cpu")
                batch  = {k: v.to(device) for k, v in batch.items()}

                buf.clear()
                model(**batch)
                acts = act_fn(buf.pop().to("cpu"))

                if acts.dim() == 3: # take CLS token
                    acts = acts[:, 0, :]

                acts_list.append(acts.numpy())
                labels_list.append(labels.numpy())
    finally:
        # Detach the hook even when a forward pass fails, so it does not
        # keep firing on later calls to the model.
        h.remove()

    X = np.concatenate(acts_list, axis=0)
    y = np.concatenate(labels_list, axis=0)

    # build global predicate set
    all_preds = _dedup_preserve_order(
        list(itertools.chain.from_iterable(predicates.values()))
    )

    X_bin = _eval_predicates_matrix(X, all_preds).astype(int)

    # fit multiclass decision tree
    tree = DecisionTreeClassifier(
        max_depth=max_depth,
        min_samples_leaf=min_samples_leaf,
        random_state=0,
    ).fit(X_bin, y)

    # convert tree => per‑class DNF
    dnfs = _tree_to_dnf_multiclass(tree, all_preds)
    return dnfs

**Simple positive distillation**

In [None]:
from collections import defaultdict, Counter
from typing import Dict, List, Tuple, FrozenSet, Callable, Set
import torch

def build_pruned_dnf_rules_cls(
    model: torch.nn.Module,
    dataloader,
    rules: Dict[int, List[Tuple[int, float]]],
    target_layer: torch.nn.Module,
    *,
    apply_act: bool = False,
    min_predicates: int = 3,          # paper’s threshold
    min_support: int = 0,
    device: str = "cuda",
) -> Dict[int, List[List[Tuple[int, float]]]]:
    """
    Returns a pruned DNFs rule set using simple positive pruning:
        {class_id: [[(idx,thr), …], …]}

    For every example, the predicates of its *label's* class that hold
    (``activation >= thr`` at sequence position 0 of the hooked output) form
    a candidate clause. Clauses with fewer than ``min_predicates`` literals
    are ignored, clauses seen fewer than ``min_support`` times are dropped,
    and any clause that is a proper superset of a kept clause is removed.

    Args:
        model: Invoked as ``model(**batch)``.
        dataloader: Yields dict batches containing a ``"labels"`` tensor.
        rules: Per-class predicates; only the first two fields
            (neuron index, threshold) of each entry are used.
        target_layer: Sub-module whose output is captured.
        apply_act: Apply GELU to the captured output first.
        min_predicates: Minimum number of active predicates per clause.
        min_support: Minimum number of examples exhibiting a clause.
        device: Device the batch tensors are moved to.

    Returns:
        Dict with an entry only for classes that kept at least one clause;
        each clause is a sorted list of ``(idx, thr)`` tuples.
    """
    # capture layer‑(l‑1) activations
    buf: List[torch.Tensor] = []

    def _hook(_, __, out):
        buf.append(out.detach().cpu())

    act_fn: Callable = torch.nn.functional.gelu if apply_act else (lambda x: x)

    pattern_counts: Dict[int, Counter] = defaultdict(Counter)

    h = target_layer.register_forward_hook(_hook)
    try:
        model.eval()
        with torch.no_grad():
            for batch in dataloader:
                labels = batch.pop("labels").cpu()
                batch  = {k: v.to(device) for k, v in batch.items()}

                buf.clear()
                model(**batch)
                acts = act_fn(buf.pop())
                if acts.dim() == 3: # CLS position
                    acts = acts[:, 0, :]

                for i, c in enumerate(labels.tolist()):
                    active = tuple(
                        (idx, thr)
                        for idx, thr, *_ in rules.get(c, [])
                        if acts[i, idx] >= thr
                    )
                    if len(active) >= min_predicates:
                        pattern_counts[c][frozenset(active)] += 1
    finally:
        # Detach the hook even when a forward pass fails.
        h.remove()

    # prune by support
    kept: Dict[int, Set[FrozenSet[Tuple[int, float]]]] = defaultdict(set)
    for c, counter in pattern_counts.items():
        for clause, cnt in counter.items():
            if cnt >= min_support:
                kept[c].add(clause)

    # drop supersets (prune)
    pruned_rules: Dict[int, List[List[Tuple[int, float]]]] = {}
    for c, clauses in kept.items():
        minimal = []
        for cl in sorted(clauses, key=len):          # shortest first
            # `cl > m` is the proper-superset test on frozensets
            if not any(cl > m for m in minimal):
                minimal.append(cl)
        pruned_rules[c] = [sorted(cl) for cl in minimal]

    return pruned_rules


# Grounding (Step 3)

In [None]:
# Hook targets: "layer1" … "layer6" map to the output LayerNorm of the first
# six transformer blocks of the DistilBERT encoder, in order.
_transformer_blocks = model.distilbert.transformer.layer
layers = {
    f"layer{position}": _transformer_blocks[position - 1].output_layer_norm
    for position in range(1, 7)
}

In [None]:

from scipy import stats
import pandas as pd

def _compute_support_table(
    *,
    flip_counter: dict,
    total_counter: dict,
    df_counter: dict,
    rule_neurons_counter: dict,
    n_docs: int,
    min_df: int = 3,
    alpha: float = 0.05,
    min_flips: int = 3
) -> pd.DataFrame:
    """Return a DataFrame with one‑tailed significance testing.
    """

    rows = []

    global_flips = sum(
        cnt for cls_dict in flip_counter.values()
        for pred_dict in cls_dict.values()
        for cnt in pred_dict.values()
    )
    global_total = sum(
        cnt for cls_tot in total_counter.values()
        for cnt in cls_tot.values()
    )
    p0 = (global_flips + 1e-6) / (global_total + 1e-6)

    # iterate over all predicates flips
    for cls, tok_dict in flip_counter.items():
        for token, pred_dict in tok_dict.items():
            df = df_counter[token]
            if df < min_df:
                continue
            idf = math.log((n_docs + 1) / (df + 1))

            for ptype, flips in pred_dict.items():
                total = total_counter[cls][(token, ptype)]
                if total < min_flips:
                    continue

                p_hat = flips / total

                # one tailed z test for significant flips
                se = math.sqrt(max(p0 * (1 - p0), 1e-9) / total)
                if se == 0:
                    continue

                z = (p_hat - p0) / se
                if z <= 0:
                    continue

                p_val = 1 - stats.norm.cdf(z)  # one tail
                if p_val >= alpha:
                    continue

                # log‑odds in bits, IDF‑weighted
                support_score = 0.5 * idf * math.log2((p_hat + 1e-6) / (p0 + 1e-6))
                cnt = rule_neurons_counter.get((cls, token, ptype), {})
                for neuron_id, n_flips in cnt.items():
                    rows.append(
                        dict(
                            cls        = cls,
                            neuron_id  = neuron_id, # id of the neuron that flipped
                            token      = token,
                            pred_type  = ptype,
                            flips      = n_flips, # neuron‑specific count
                            total      = total,
                            rate       = p_hat,
                            idf        = idf,
                            z          = z,
                            p_val      = p_val,
                            support_score = 0.5 * idf * math.log2((p_hat + 1e-6) /
                                                                  (p0 + 1e-6)),
                        )
                    )

    df_out = pd.DataFrame(rows)
    # High score => strong enrichment and rarity.
    return df_out


In [None]:
import math, string
from collections import defaultdict, Counter
from typing import Dict, List, Tuple
import gc
import torch, torch.nn.functional as F
import pandas as pd
import spacy
import numpy as np

nlp = spacy.load("en_core_web_sm", disable=["ner"])  # load once
# Number of masked variants buffered before they are evaluated and the
# buffers are emptied.
FLUSH_INTERVAL = 1024


@torch.no_grad()
def causal_word_lexical_batched(
    model,
    dataloader,
    predicates: Dict[int, List[Tuple[int, float, float]]],
    module: torch.nn.Module,
    tokenizer,
    use_gelu: bool,
    *,
    support_thr: float = 0.03,
    min_df: int = 3,
    device: str = "cuda",
    chunk_size: int = 32,
    pos_threshold: float = 0.2,
    max_after_tokens: int = 6, # distance cap after/before subj/verb
) -> pd.DataFrame:
    """Find words whose masking switches off predicate neurons.

    For every example the predicates of the *predicted* class that hold on
    the unmodified input (``activation >= thr`` at sequence position 0 of the
    hooked ``module`` output) are collected. Each word is then replaced by
    the mask token and the model is re-run; a "flip" is counted when at least
    one of those predicates no longer holds. Every word is counted once per
    lexical predicate type it satisfies (``exists``, ``at_start``,
    ``at_end``, ``before/after_subject``, ``before/after_verb``,
    ``is_hashtag``), derived from a spaCy parse of the decoded text.

    Args:
        model: Classifier with ``config.num_labels`` whose output has
            ``.logits``. Side effects: it is put in eval mode, gradient
            checkpointing is disabled and ``config.use_cache`` is set False.
        dataloader: Yields dict batches with ``"input_ids"`` and
            ``"attention_mask"``; all entries are passed to the model.
        predicates: Per-class predicates; only the first two fields
            (neuron index, threshold) of each entry are used.
        module: Sub-module whose output is hooked.
        tokenizer: Tokenizer used for decoding and for special-token ids.
        use_gelu: Apply exact GELU to the hooked output first.
        support_thr: Unused; kept for interface compatibility.
        min_df: Forwarded to ``_compute_support_table``.
        device: Device the tensors are moved to.
        chunk_size: Number of masked variants per forward pass.
        pos_threshold: Fraction of the sentence counted as start / end.
        max_after_tokens: Maximum distance to the subject / verb.

    Returns:
        The table built by ``_compute_support_table`` (called with
        ``alpha=0.05`` and ``min_flips=1``).
    """
    n_classes = model.config.num_labels
    # `is not None` (rather than `or`) so that a mask id of 0 is honoured.
    mask_id = (
        tokenizer.mask_token_id
        if tokenizer.mask_token_id is not None
        else tokenizer.unk_token_id
    )

    # NOTE(review): keyword_counts and poly_counts are filled but not part
    # of the returned table.
    keyword_counts = defaultdict(Counter)
    poly_counts = defaultdict(int)
    rule_neuron_counter = defaultdict(Counter)

    flip_counter  = {c: defaultdict(Counter) for c in range(n_classes)}
    total_counter = {c: defaultdict(int)      for c in range(n_classes)}
    df_counter    = Counter()

    acts_holder: List[torch.Tensor] = []

    def _hook(_, __, out):
        if use_gelu:
            out = F.gelu(out, approximate="none")
        acts_holder.append(out[:, 0, :].detach().cpu())

    def _flush(seqs, attns, meta):
        """Evaluate buffered masked variants in chunks and update the counters."""
        if not seqs:
            return

        for start in range(0, len(seqs), chunk_size):
            stop = start + chunk_size
            inputs = torch.stack(seqs[start:stop]).to(device, non_blocking=True)
            masks  = torch.stack(attns[start:stop]).to(device, non_blocking=True)

            acts_holder.clear()
            model(input_ids=inputs, attention_mask=masks)
            cls_acts = acts_holder.pop()

            for j, (c_hat, word, ptype, active) in enumerate(meta[start:stop]):
                total_counter[c_hat][(word, ptype)] += 1

                dropped = [i for i, thr in active if cls_acts[j, i] < thr]
                if dropped:  # causal flip happened
                    flip_counter[c_hat][word][ptype] += 1
                    for i in dropped:
                        poly_counts[(i, ptype)] += 1
                        keyword_counts[i][word] += 1  # count keyword flips
                        rule_neuron_counter[(c_hat, word, ptype)][i] += 1

            del inputs, masks, cls_acts
        torch.cuda.empty_cache()

    N_docs = 0
    h = module.register_forward_hook(_hook)
    try:
        model.eval()
        model.gradient_checkpointing_disable()
        model.config.use_cache = False

        for batch in dataloader:
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}

            acts_holder.clear()
            logits   = model(**batch).logits
            base_cls = acts_holder.pop()
            preds    = logits.argmax(-1)
            B, T     = batch["input_ids"].shape
            attn     = batch["attention_mask"]

            dec_texts = tokenizer.batch_decode(batch["input_ids"], skip_special_tokens=True)

            masked_inputs_cpu, masked_attn_cpu, meta = [], [], []
            metas_for_df = []

            for b in range(B):
                c_hat  = int(preds[b])
                tokseq = batch["input_ids"][b].cpu()
                actvec = base_cls[b]

                active_neurons = [
                    (idx, thr) for idx, thr, *_ in predicates.get(c_hat, [])
                    if actvec[idx] >= thr
                ]
                if not active_neurons:
                    continue

                doc = nlp(dec_texts[b])

                word_ids = tokenizer.word_ids(batch_index=b) if hasattr(tokenizer, "word_ids") else None
                tokens   = tokenizer.convert_ids_to_tokens(tokseq.tolist(), skip_special_tokens=False)

                seen_tokens_in_doc = set()
                word_start = 1  # position 0 is skipped
                while word_start < T:
                    # go one by one
                    tok_id = int(tokseq[word_start])

                    if tok_id == tokenizer.unk_token_id:
                        word_start += 1
                        continue
                    if tok_id in (
                        tokenizer.pad_token_id,
                        tokenizer.sep_token_id,
                        getattr(tokenizer, "eos_token_id", -1),
                    ):
                        break

                    # find the last sub-token of the word starting here
                    if word_ids is not None:
                        wid = word_ids[word_start]
                        if wid is None:
                            word_start += 1
                            continue
                        word_end = word_start
                        while word_end + 1 < T and word_ids[word_end + 1] == wid:
                            word_end += 1
                        word_idx = wid
                    else:
                        word_end = word_start
                        while word_end + 1 < T and tokens[word_end + 1].startswith("##"):
                            word_end += 1
                        # index of this word = number of word-initial pieces before it
                        word_idx = len([t for t in tokens[1:word_start] if not t.startswith("##")])

                    sub_toks = tokens[word_start : word_end + 1]
                    word_str = "".join(
                        (
                            t[2:] if t.startswith("##")
                            else t[1:] if t and t[0] in {"Ġ", "▁"}
                            else t
                        )
                        for t in sub_toks
                    ).lower()
                    if (
                        not word_str
                        or word_str == tokenizer.unk_token
                        or all(ch in string.punctuation for ch in word_str)
                    ):
                        word_start = word_end + 1
                        continue

                    # NOTE(review): word_idx is used directly as a spaCy token
                    # index; this assumes both tokenizations split the text
                    # into the same words — confirm.
                    try:
                        tok_spacy = doc[word_idx]
                    except IndexError:
                        word_start = word_end + 1
                        continue

                    sent        = tok_spacy.sent
                    sent_len    = len(sent)
                    sent_start  = sent.start
                    pos_in_sent = word_idx - sent_start
                    frac        = pos_in_sent / sent_len if sent_len else 0.0

                    subj_pos = min(
                        (t.i for t in sent if t.dep_ in {"nsubj", "nsubjpass"}), default=None
                    )
                    verb_pos = min((t.i for t in sent if t.pos_ == "VERB"), default=None)
                    pred_types = set()
                    pred_types.add("exists")
                    if frac <= pos_threshold:
                        pred_types.add("at_start")
                    if frac >= 1.0 - pos_threshold:
                        pred_types.add("at_end")

                    # before / after subject
                    if subj_pos is not None:
                        if word_idx > subj_pos and word_idx - subj_pos <= max_after_tokens:
                            pred_types.add("after_subject")
                        if word_idx < subj_pos and subj_pos - word_idx <= max_after_tokens:
                            pred_types.add("before_subject")

                    # before / after verb
                    if verb_pos is not None:
                        if word_idx > verb_pos and word_idx - verb_pos <= max_after_tokens:
                            pred_types.add("after_verb")
                        if word_idx < verb_pos and verb_pos - word_idx <= max_after_tokens:
                            pred_types.add("before_verb")

                    # hashtag
                    if word_str.startswith("#"):
                        pred_types.add("is_hashtag")

                    for ptype in pred_types:
                        # mask
                        masked_seq = tokseq.clone()
                        masked_seq[word_start : word_end + 1] = mask_id
                        masked_inputs_cpu.append(masked_seq)
                        masked_attn_cpu.append(attn[b].cpu())
                        meta.append((c_hat, word_str, ptype, active_neurons))
                        if len(masked_inputs_cpu) >= FLUSH_INTERVAL:
                            # bound the number of buffered variants
                            _flush(masked_inputs_cpu, masked_attn_cpu, meta)
                            masked_inputs_cpu.clear()
                            masked_attn_cpu.clear()
                            meta.clear()

                    seen_tokens_in_doc.add(word_str)
                    word_start = word_end + 1

                metas_for_df.append(seen_tokens_in_doc)

            for s in metas_for_df:
                df_counter.update(s)
            N_docs += len(metas_for_df)

            # evaluate whatever is still buffered for this batch
            _flush(masked_inputs_cpu, masked_attn_cpu, meta)

            del logits, base_cls
            gc.collect()
            torch.cuda.empty_cache()
    finally:
        # Detach the hook even when a forward pass fails.
        h.remove()

    results_df = _compute_support_table(
        flip_counter=flip_counter,
        total_counter=total_counter,
        df_counter=df_counter,
        rule_neurons_counter=rule_neuron_counter,
        n_docs=N_docs,
        min_df=min_df,
        alpha=0.05,
        min_flips=1,
    )

    return results_df


**Positional Bucketing (Heat Map)**

In [None]:
import math
import string
from collections import defaultdict, Counter
from typing import Dict, List, Tuple

import torch
import torch.nn.functional as F
import pandas as pd


@torch.no_grad()
def causal_word_buckets_batched(
    model,
    dataloader,
    predicates: Dict[int, List[Tuple[int, float]]],
    module: torch.nn.Module,
    tokenizer,
    use_gelu: bool,
    *,
    n_buckets: int = 10,
    support_thr: float = 0.02,
    min_df: int = 1,
    device: str = "cuda",
    chunk_size: int = 24,
) -> pd.DataFrame:
    """
    Measure, per word and relative position, how often masking the word
    switches off predicate neurons.

    For every example the predicates of the *predicted* class that hold on
    the unmodified input (``activation >= thr`` at sequence position 0 of the
    hooked ``module`` output) are collected. Each word is replaced by the
    mask token and the model is re-run; a "flip" is counted when at least one
    of those predicates no longer holds. The word's position is bucketed as
    ``floor(token_position * n_buckets / attended_length)``.

    Args:
        model: Classifier with ``config.num_labels`` whose output has
            ``.logits``. Side effects: it is put in eval mode, gradient
            checkpointing is disabled and ``config.use_cache`` is set False.
        dataloader: Yields dict batches with ``"input_ids"`` and
            ``"attention_mask"``; all entries are passed to the model.
        predicates: Per-class predicates; only the first two fields
            (neuron index, threshold) of each entry are used.
        module: Sub-module whose output is hooked.
        tokenizer: Tokenizer providing special-token ids and
            ``convert_ids_to_tokens``.
        use_gelu: Apply exact GELU to the hooked output first.
        n_buckets: Number of positional buckets.
        support_thr: Minimum ``idf * rate`` for a row to be reported.
        min_df: Minimum document frequency of a word.
        device: Device the tensors are moved to.
        chunk_size: Number of masked variants per forward pass.

    Returns a DataFrame with columns:
        cls, token, bucket, flips, total, rate, idf
    (the columns are present even when there are no rows).
    """
    columns = ["cls", "token", "bucket", "flips", "total", "rate", "idf"]

    mask_id = (
        tokenizer.mask_token_id
        if tokenizer.mask_token_id is not None
        else tokenizer.unk_token_id
    )
    n_classes = model.config.num_labels

    flip_counter  = {c: defaultdict(Counter) for c in range(n_classes)}
    total_counter = {c: defaultdict(int)      for c in range(n_classes)}
    df_counter    = Counter()

    # hooks
    acts_holder: List[torch.Tensor] = []

    def _hook(_, __, out):
        if use_gelu:
            out = F.gelu(out, approximate="none")
        acts_holder.append(out[:, 0, :].detach().cpu())

    def _bucket(pos, seq_l):
        """Map a token position to its bucket for a sequence of length seq_l."""
        return int(math.floor(pos * n_buckets / seq_l))

    def _merge_wordpieces(tok_list):
        """Join the sub-tokens of one word into a lower-cased string."""
        pieces = []
        for t in tok_list:
            if t.startswith("##"):
                # continuation; tolerate a word that starts with one
                if pieces:
                    pieces[-1] += t[2:]
                else:
                    pieces.append(t[2:])
            elif t and t[0] in {"Ġ", "▁"}:
                pieces.append(t[1:]) # strip the word-boundary marker
            else:
                pieces.append(t) # regular token
        return "".join(pieces).lower()

    N_docs = 0
    h = module.register_forward_hook(_hook)
    try:
        model.eval()
        model.gradient_checkpointing_disable()
        model.config.use_cache = False

        for batch in dataloader:
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}

            # forward on original text
            acts_holder.clear()
            logits = model(**batch).logits
            base_cls = acts_holder.pop()

            preds = logits.argmax(-1)
            B, T  = batch["input_ids"].shape
            attn  = batch["attention_mask"]
            seq_len = attn.sum(-1)

            # build masked variants
            masked_inputs_cpu, masked_attn_cpu, meta = [], [], []
            metas_for_df = []

            for b in range(B):
                c_hat  = int(preds[b])
                tokseq = batch["input_ids"][b].cpu()
                actvec = base_cls[b]
                slen   = int(seq_len[b])

                class_preds = predicates.get(c_hat, [])
                if not class_preds:
                    continue
                active_neurons = [
                    (idx, thr)
                    for idx, thr, *_ in class_preds
                    if actvec[idx] >= thr
                ]
                if not active_neurons:
                    continue

                if hasattr(tokenizer, "word_ids"):
                    word_id_seq = tokenizer.word_ids(batch_index=b)
                else:
                    word_id_seq = None
                tokens = tokenizer.convert_ids_to_tokens(
                    tokseq.tolist(), skip_special_tokens=False,
                )

                seen_tokens_in_doc = set()
                word_start = 1  # position 0 is skipped
                while word_start < T:
                    tok_id = int(tokseq[word_start])
                    if tok_id in (
                        tokenizer.pad_token_id,
                        tokenizer.sep_token_id,
                        getattr(tokenizer, "eos_token_id", -1),
                    ):
                        break

                    # find the last sub-token of the word starting here
                    if word_id_seq is not None:
                        wid = word_id_seq[word_start]
                        if wid is None:
                            word_start += 1
                            continue
                        word_end = word_start
                        while word_end + 1 < T and word_id_seq[word_end + 1] == wid:
                            word_end += 1
                    else:
                        word_end = word_start
                        while (
                            word_end + 1 < T and tokens[word_end + 1].startswith("##")
                        ):
                            word_end += 1

                    word_str = _merge_wordpieces(tokens[word_start : word_end + 1])
                    if (
                        not word_str
                        or all(ch in string.punctuation for ch in word_str)
                        or word_str == tokenizer.unk_token
                    ):
                        word_start = word_end + 1
                        continue

                    masked_seq = tokseq.clone()
                    masked_seq[word_start : word_end + 1] = mask_id
                    masked_inputs_cpu.append(masked_seq)
                    masked_attn_cpu.append(attn[b].cpu())

                    bucket_idx = _bucket(word_start, slen)
                    meta.append((c_hat, word_str, bucket_idx, active_neurons))
                    seen_tokens_in_doc.add(word_str)

                    word_start = word_end + 1

                metas_for_df.append(seen_tokens_in_doc)

            for s in metas_for_df:
                df_counter.update(s)
            N_docs += len(metas_for_df)

            if not masked_inputs_cpu:
                torch.cuda.empty_cache()
                continue

            masked_inputs_cpu = torch.stack(masked_inputs_cpu)
            masked_attn_cpu   = torch.stack(masked_attn_cpu)

            # evaluate masked variants in chunks
            for i in range(0, masked_inputs_cpu.size(0), chunk_size):
                slc = slice(i, i + chunk_size)
                inputs = masked_inputs_cpu[slc].to(device, non_blocking=True)
                attns  = masked_attn_cpu[slc].to(device, non_blocking=True)

                acts_holder.clear()
                model(input_ids=inputs, attention_mask=attns)
                chunk_cls = acts_holder.pop()

                for j, (c_hat, word_str, bucket_idx, active_neurons) in enumerate(
                    meta[slc]
                ):
                    total_counter[c_hat][(word_str, bucket_idx)] += 1
                    if any(chunk_cls[j, idx] < thr for idx, thr in active_neurons):
                        flip_counter[c_hat][word_str][bucket_idx] += 1

                del inputs, attns, chunk_cls
                torch.cuda.empty_cache()

            del logits, base_cls
            torch.cuda.empty_cache()
    finally:
        # Detach the hook even when a forward pass fails.
        h.remove()

    # aggregate
    rows = []
    for c in range(n_classes):
        for word_str, buckets in flip_counter[c].items():
            df = df_counter[word_str]
            if df < min_df:
                continue
            idf = math.log((N_docs + 1) / (df + 1))
            for bucket_idx, flips in buckets.items():
                total = total_counter[c][(word_str, bucket_idx)]
                if total == 0:
                    continue
                rate = flips / total
                weighted_rate = idf * rate
                if weighted_rate < support_thr:
                    continue
                rows.append(
                    dict(
                        cls=c,
                        token=word_str,
                        bucket=bucket_idx,
                        flips=flips,
                        total=total,
                        rate=rate,
                        idf=idf,
                    )
                )

    return pd.DataFrame(rows, columns=columns)


In [None]:
import numpy as np
import matplotlib.pyplot as plt

def plot_flip_heatmaps(
        df: pd.DataFrame,
        per_class: bool = True,
        top_tokens: int = 40,
        cmap: str = "viridis",
        figsize: tuple = (12, 6),
        position_col: str = "offset",
):
    """Draw token x position heat-maps of the flip rate.

    Parameters
    ----------
    df : pd.DataFrame
        Long-format table with at least the columns ``cls``, ``token``,
        ``total``, ``rate`` and a position column (see ``position_col``).
    per_class : bool
        If True, one figure is drawn per distinct value of ``cls``;
        otherwise a single figure pools all rows.
    top_tokens : int
        Only the ``top_tokens`` tokens with the largest summed ``total``
        are shown in each figure.
    cmap : str
        Matplotlib colormap name.
    figsize : tuple
        Figure size handed to ``plt.figure``.
    position_col : str
        Name of the column used for the x-axis. If that column is missing
        but a ``bucket`` column exists, ``bucket`` is used instead.

    Notes
    -----
    Relies on a notebook-level ``label_mapping`` (class id -> display name)
    when ``per_class`` is True; it is expected to be defined elsewhere in
    the notebook.
    """
    # Accept tables that name the position column "bucket" instead of "offset".
    if position_col not in df.columns and "bucket" in df.columns:
        position_col = "bucket"

    groups = [g for _, g in df.groupby("cls")] if per_class else [df]

    for sub in groups:
        if sub.empty:
            # nothing to draw (e.g. an empty table with per_class=False)
            continue

        c = int(sub.cls.iloc[0]) if per_class else None

        # keep only the most frequently evaluated tokens
        tok_counts = sub.groupby("token")["total"].sum().sort_values(ascending=False)
        kept = tok_counts.head(top_tokens).index
        sub = sub[sub.token.isin(kept)]

        # pivot_table (unlike pivot) tolerates repeated (token, position)
        # pairs, which happen when several classes are pooled; repeated
        # cells are averaged. For unique pairs the result equals pivot().
        heat = (
            sub.pivot_table(
                index="token", columns=position_col, values="rate", aggfunc="mean"
            )
               .fillna(0.0)
               .loc[lambda x: x.mean(axis=1).sort_values(ascending=False).index]
               .sort_index(axis=1)
        )

        plt.figure(figsize=figsize)
        plt.imshow(heat.values, aspect="auto", cmap=cmap, vmin=0.0, vmax=1.0)
        plt.colorbar(label="flip rate")
        plt.xticks(np.arange(heat.shape[1]), heat.columns, rotation=90)
        plt.yticks(np.arange(heat.shape[0]), heat.index)
        title = f"class {label_mapping[c]}" if per_class else "all classes"
        plt.title(f"Flip-rate heat-map: {title}")
        plt.xlabel(f"token position ({position_col})")
        plt.ylabel("token")
        plt.tight_layout()
        plt.show()

# Baselines Eval

**Evaluate Baseline Scores**

In [None]:
# Top keywords from each baseline method (ten per method). They are used
# below as plain keyword detectors: a text counts as a hit when it contains
# any of the method's keywords.
anchors_tokens = [
    "sad", "depression", "sadness", "depressing", "sadly",
    "depressed", "heartbreaking", "mourn", "restless", "gloomy"
]

emolex_tokens = [
    "depression", "bad", "lost", "terrorism", "sadness",
    "awful", "anxiety", "depressed", "feeling", "offended"
]

neurologic_tokens = [
    "sad", "depression", "lost", "depressing", "sadness",
    "sadly", "mourn", "nightmare", "anxiety", "never"
]

# display name -> keyword list; the keys label the rows of the score table
rule_sets = {
    "Anchors": anchors_tokens,
    "EmoLex":  emolex_tokens,
    "NeuroLogic": neurologic_tokens,
}

# Label id treated as the positive class when the gold labels are binarised
# in the evaluation loop. NOTE(review): presumably the index of "sadness" in
# the TweetEval emotion label set -- confirm against the dataset features.
SADNESS_ID = 3

# helpers
from sklearn.metrics import precision_recall_fscore_support, accuracy_score

def make_predict_fn(tokens):
    """Build a binary keyword detector from ``tokens``.

    The returned callable takes a text and returns 1 when any keyword
    (compared case-insensitively, as a substring) occurs in it, else 0.
    """
    keywords = {kw.lower() for kw in tokens}

    def predict(txt):
        hit = False
        for kw in keywords:
            if kw in txt.lower():
                hit = True
                break
        return int(hit)   # 1 if hits

    return predict

# one keyword detector per baseline, keyed by the baseline's display name
predict_fns = {}
for rule_name, rule_tokens in rule_sets.items():
    predict_fns[rule_name] = make_predict_fn(rule_tokens)

# evaluate scores
y_true = []
y_pred = dict((rule_name, []) for rule_name in rule_sets)

for batch in test_loader:
    # recover the text from the token ids produced by the collate function
    texts = tok.batch_decode(batch["input_ids"], skip_special_tokens=True)

    # binary gold labels: 1 where the label equals SADNESS_ID, else 0
    y_true += (batch["labels"] == SADNESS_ID).int().tolist()

    for rule_name, predict in predict_fns.items():
        y_pred[rule_name] += [predict(text) for text in texts]

# one row of precision / recall / F1 / accuracy per baseline
print(f"{'Model':<10}  P      R      F1     Acc")
for rule_name in y_pred:
    preds = y_pred[rule_name]
    p, r, f1, _ = precision_recall_fscore_support(
        y_true,
        preds,
        average="binary",
        pos_label=1,
        zero_division=0,
    )
    acc = accuracy_score(y_true, preds)
    print(f"{rule_name:<10}  {p:.3f}  {r:.3f}  {f1:.3f}  {acc:.3f}")


**Benchmark EmoLex**

In [None]:
import re
import torch, torch.nn.functional as F, numpy as np, pandas as pd, tqdm
from datasets import load_dataset, concatenate_datasets
from transformers import AutoTokenizer, AutoModelForSequenceClassification
from torch.utils.data import DataLoader
from collections import defaultdict, Counter
from sklearn.metrics import f1_score

# Setup for the EmoLex benchmark. These assignments overwrite the
# notebook-level MODEL_ID / BATCH_SIZE / MAX_LEN / DEVICE from the setup
# cell; note that MAX_LEN is 128 here whereas the setup cell used 512.
MODEL_ID   = "philschmid/DistilBERT-tweet-eval-emotion"
BATCH_SIZE = 64
MAX_LEN    = 128
DEVICE     = "cuda" if torch.cuda.is_available() else "cpu"

# Data: train and validation splits are concatenated into ds_train;
# the test split is kept separately as ds_test.
ds_train = concatenate_datasets([
    load_dataset("tweet_eval", "emotion", split="train"),
    load_dataset("tweet_eval", "emotion", split="validation")
])
ds_test  = load_dataset("tweet_eval", "emotion", split="test")

# Tokenizer of the classifier checkpoint; used by `collate` and for
# decoding batches back to text in `evaluate`.
tok = AutoTokenizer.from_pretrained(MODEL_ID)

def collate(batch):
    """Tokenize a list of ``{"text", "label"}`` examples into one batch.

    Texts are padded/truncated to ``MAX_LEN`` tokens with the notebook-level
    tokenizer ``tok``; the integer labels are attached under ``"labels"``.
    """
    texts = [example["text"] for example in batch]
    encoded = tok(
        texts,
        padding="max_length",
        truncation=True,
        max_length=MAX_LEN,
        return_tensors="pt",
    )
    labels = [example["label"] for example in batch]
    encoded["labels"] = torch.tensor(labels)
    return encoded

# Test-split loader built with the `collate` function above. This rebinds
# the notebook-level `test_loader`; order is kept fixed (shuffle=False).
test_loader = DataLoader(ds_test,
                         batch_size=BATCH_SIZE,
                         shuffle=False,
                         collate_fn=collate)

# label mapping
# Only the label names are needed here, so read them from the checkpoint's
# configuration instead of instantiating (and then deleting) the full model.
from transformers import AutoConfig

id2label   = {
    int(k): v.lower()
    for k, v in AutoConfig.from_pretrained(MODEL_ID).id2label.items()
}
label2id   = {v: k for k, v in id2label.items()}
n_labels   = len(id2label)

# emolex baseline
# Word-level NRC emotion lexicon: one (word, emotion, flag) triple per line.
URL = (
    "https://archive.org/download/nrc-emotion-lexicon-v0.92/"
    "NRC-emotion-lexicon-wordlevel-alphabetized-v0.92.txt"
)
df = pd.read_csv(URL, sep="\t", names=["word", "emotion", "flag"])

# word -> list of class ids, for every association flagged 1 whose emotion
# name is also one of the classifier's label names
lex = defaultdict(list)
for row in df.itertuples(index=False):
    word, emotion, flag = row
    if flag != 1:
        continue
    emotion_key = emotion.lower()
    if emotion_key in label2id:
        lex[word].append(label2id[emotion_key])

def emolex_predict(texts):
    """Predict one class id per text by lexicon majority vote.

    Each lower-cased alphabetic word votes for every class id listed for it
    in the notebook-level ``lex``. The most voted class is returned for the
    text, or ``None`` when no word of the text is in the lexicon.
    """
    preds = []
    for text in texts:
        votes = Counter()
        # keep letters only
        for word in re.findall(r"[a-z]+", text.lower()):
            for emotion_id in lex.get(word, []):
                votes[emotion_id] += 1
        if votes:
            preds.append(votes.most_common(1)[0][0])
        else:
            preds.append(None)
    return preds

# evaluate

def evaluate(loader, predict_fn):
    """Score a text-level predictor on a loader of tokenized batches.

    Parameters
    ----------
    loader : iterable
        Yields batches with ``"input_ids"`` and ``"labels"``; the ids are
        decoded back to text with the notebook-level tokenizer ``tok``.
    predict_fn : callable
        Maps a list of texts to a list of predictions, where ``None``
        means "no prediction" for that text.

    Returns
    -------
    (accuracy, coverage) : tuple of float
        ``accuracy`` is computed over the covered examples only (those with
        a non-``None`` prediction); ``coverage`` is the fraction of examples
        that received a prediction. Both are 0.0 when their denominator is
        zero, e.g. for an empty loader.
    """
    correct, total, covered = 0, 0, 0
    for batch in tqdm.tqdm(loader, desc="eval", leave=False):
        texts   = tok.batch_decode(batch["input_ids"],
                                   skip_special_tokens=True)
        labels  = batch["labels"].tolist()
        outputs = predict_fn(texts)
        for g, p in zip(labels, outputs):
            total += 1
            if p is None:
                continue  # abstained: counts against coverage only
            covered += 1
            if p == g:
                correct += 1
    accuracy = correct / covered if covered else 0.0
    # guard the empty-loader case the same way as the accuracy above
    coverage = covered / total if total else 0.0
    return accuracy, coverage

In [None]:
# baselines to score: display name -> batch prediction function
baseline_fns = {"EmoLex": emolex_predict}
for name, fn in baseline_fns.items():
    acc, cov = evaluate(test_loader, fn)
    print(f"{name:8s} | accuracy {acc:.3f} | coverage {cov*100:.1f}%")

**EmoLex**

In [None]:
import re
from collections import defaultdict, Counter
# lower-cased class names of the dataset and the two id <-> name lookups
class_names = [label.lower() for label in ds_train.features["label"].names]
label2id    = dict(zip(class_names, range(len(class_names))))
id2label    = dict((idx, label) for label, idx in label2id.items())

# map nrc emotions to tweeteval
# (anticipation / trust are treated as optimism)
_nrc_to_tweeteval = {
    "anger":        "anger",
    "joy":          "joy",
    "sadness":      "sadness",
    "anticipation": "optimism",
    "trust":        "optimism",
}
emo_map = {nrc: label2id[target] for nrc, target in _nrc_to_tweeteval.items()}

# word -> list of class ids for every lexicon association flagged 1 whose
# emotion is covered by emo_map
lex = defaultdict(list)
for row in df.itertuples(index=False):
    word, emotion, flag = row
    if flag != 1:
        continue
    emotion_key = emotion.lower()
    if emotion_key in emo_map:
        lex[word].append(emo_map[emotion_key])

_token_re = re.compile(r"[a-z]+")


def emolex_predict(texts):
    """Return the lexicon majority-vote class id for each text.

    Words are extracted with the notebook-level ``_token_re`` from the
    lower-cased text (which strips punctuation); each word votes for the
    class ids stored for it in ``lex``. A text with no votes maps to
    ``None``.
    """
    def _vote(text):
        tally = Counter()
        for word in _token_re.findall(text.lower()):
            tally.update(lex.get(word, []))
        if not tally:
            return None
        # first class reaching the highest count wins
        return max(tally, key=tally.get)

    return [_vote(text) for text in texts]


In [None]:
# top-k words
import re
from collections import Counter, defaultdict

def emolex_top_rules(texts, lex, id2label, top_k=20):
    """Print, per class, the lexicon words that occur most often in ``texts``.

    Every lower-cased alphabetic word found in ``lex`` is counted once per
    occurrence for each class id listed for it. For each class seen, the
    ``top_k`` most frequent words are printed with their counts.
    """
    tokenizer = re.compile(r"[a-z]+")
    per_class = defaultdict(Counter)

    for text in texts:
        words = tokenizer.findall(text.lower())
        for word in words:
            class_ids = lex.get(word, [])
            for class_id in class_ids:
                per_class[class_id][word] += 1

    for class_id in per_class:
        print(f"\nClass {class_id} ({id2label[class_id]}):")
        ranking = per_class[class_id].most_common(top_k)
        for word, freq in ranking:
            print(f"{freq:6d}  {word}")

# list the most frequent lexicon words per class on the TweetEval test texts
test_texts = [example["text"] for example in ds_test]
emolex_top_rules(test_texts, lex=lex, id2label=id2label, top_k=20)


# Evaluation of DNF Clauses

In [None]:
from typing import Dict, List, Tuple
import torch
from sklearn.metrics import accuracy_score

# rules = {class_id: [[(idx, thr), …], …]}

def eval_clauses_individually(
    model: torch.nn.Module,
    dataloader,
    rules: Dict[int, List[List[Tuple[int, float]]]],
    target_layer: torch.nn.Module,
    *,
    apply_act: bool = False,
    device: str = "cuda",
    top_k: int = 1,
) -> Dict[int, List[Tuple[List[Tuple[int, float]], float]]]:
    """Score every clause of every class on its own and keep the best ones.

    Parameters
    ----------
    model : torch.nn.Module
        Model to run; it is switched to eval mode.
    dataloader : iterable
        Yields mapping-like batches holding ``"labels"`` plus the keyword
        inputs of ``model``. Batches are not modified.
    rules : dict
        ``{class_id: [clause, ...]}`` where a clause is a list of
        ``(neuron_idx, threshold)`` predicates; a clause fires on a sample
        when every listed activation is strictly above its threshold.
    target_layer : torch.nn.Module
        Layer whose forward output is captured. If the output is 3-D, only
        index 0 along the second dimension is used.
    apply_act : bool
        Apply GELU to the captured output before thresholding.
    device : str
        Device the batch inputs are moved to.
    top_k : int
        Number of best clauses kept per class.

    Returns
    -------
    dict
        ``{class_id: [(clause, score), ...]}`` with at most ``top_k``
        entries, sorted by score (descending) and then by clause length
        (shorter first). The score is the fraction of samples labelled
        ``class_id`` on which the clause fires. Classes without samples in
        the loader are omitted; an empty loader yields ``{}``.
    """
    buf = []

    def _hook(_, __, out):
        buf.append(out.detach())

    act_fn = torch.nn.functional.gelu if apply_act else (lambda x: x)
    model.eval()

    # cache all activations
    acts_all, labels_all = [], []
    h = target_layer.register_forward_hook(_hook)
    try:
        with torch.no_grad():
            for batch in dataloader:
                labels = batch["labels"].cpu()
                # leave the caller's batch untouched: build the model inputs
                # without the labels instead of popping them
                inputs = {
                    k: v.to(device) for k, v in batch.items() if k != "labels"
                }

                buf.clear()
                model(**inputs)
                acts = act_fn(buf.pop()).cpu()
                if acts.dim() == 3:
                    acts = acts[:, 0, :]
                acts_all.append(acts)
                labels_all.append(labels)
    finally:
        # always detach the hook, even when the forward pass or the loader
        # raises, so it cannot linger on the model
        h.remove()

    results: Dict[int, List[Tuple[List[Tuple[int, float]], float]]] = {}
    if not acts_all:
        return results  # empty loader: nothing to score

    acts_all   = torch.cat(acts_all)
    labels_all = torch.cat(labels_all)

    for class_id, clauses in rules.items():
        class_mask = labels_all == class_id
        total = class_mask.sum().item()
        if total == 0:
            continue  # no examples of this class in loader

        scores = []
        for clause in clauses:
            # a clause fires on a sample only if all its predicates hold
            fires = torch.ones_like(class_mask, dtype=torch.bool)
            for idx, thr in clause:
                fires &= acts_all[:, idx] > thr

            correct = (fires & class_mask).sum().item()
            scores.append((clause, correct / total))

        # sort by score (desc), then by clause length with shorter preferred
        scores.sort(key=lambda x: (-x[1], len(x[0])))
        results[class_id] = scores[:top_k]

    return results


In [None]:
from typing import Dict, List, Tuple
from collections import Counter
import torch
from sklearn.metrics import accuracy_score, f1_score
def eval_pruned_dnf_rules_pos(
    model: torch.nn.Module,
    dataloader,
    rules: Dict[int, List[List[Tuple[int, float]]]],
    target_layer: torch.nn.Module,
    *,
    apply_act: bool = False,
    device: str = "cuda",
    average: str = "macro",
):
    """Evaluate pruned DNF rule-sets with only positive predicates.

    Parameters
    ----------
    model : torch.nn.Module
        Model to run; it is switched to eval mode.
    dataloader : iterable
        Yields mapping-like batches holding ``"labels"`` plus the keyword
        inputs of ``model``. Batches are not modified.
    rules : dict
        ``{class_id: [clause, ...]}``; a clause is a list of
        ``(neuron_idx, threshold)`` pairs and fires when every listed
        activation is strictly above its threshold.
    target_layer : torch.nn.Module
        Layer whose forward output is captured. If the output is 3-D, only
        index 0 along the second dimension is used.
    apply_act : bool
        Apply GELU to the captured output before thresholding.
    device : str
        Device the batch inputs are moved to.
    average : str
        Averaging mode forwarded to ``f1_score``.

    Returns
    -------
    (acc, f1) : tuple
        Accuracy and F1 of the rule-based predictions; both are also printed.

    Notes
    -----
    For each sample the predicted class is the one with the largest fraction
    of fired clauses; ties are broken by the larger summed margin
    (activation minus threshold over all fired clauses). A sample on which
    no clause fires is predicted as ``-1``. NOTE(review): with sklearn's
    default label handling that ``-1`` takes part in the averaged F1 as an
    extra label -- confirm this is intended.
    """
    buf = []

    def _hook(_, __, out):
        buf.append(out.detach())

    act_fn = torch.nn.functional.gelu if apply_act else (lambda x: x)

    y_true, y_pred = [], []
    model.eval()

    h = target_layer.register_forward_hook(_hook)
    try:
        with torch.no_grad():
            for batch in dataloader:
                labels = batch["labels"].cpu()
                # leave the caller's batch untouched: build the model inputs
                # without the labels instead of popping them
                inputs = {
                    k: v.to(device) for k, v in batch.items() if k != "labels"
                }

                buf.clear()
                model(**inputs)
                acts = act_fn(buf.pop()).cpu()

                if acts.dim() == 3:
                    acts = acts[:, 0, :]

                for a, gold in zip(acts, labels):
                    best_score, best_margin, pred = -1.0, -1.0, -1

                    for c, clauses in rules.items():
                        total = len(clauses)
                        fired = [
                            clause
                            for clause in clauses
                            if all(a[idx] > thr for idx, thr in clause)
                        ]

                        if not fired:
                            continue

                        hit_rate = len(fired) / total
                        margin   = sum(  # tie break
                            sum(a[idx] - thr for idx, thr in clause)
                            for clause in fired
                        )

                        if hit_rate > best_score or (
                            hit_rate == best_score and margin > best_margin
                        ):
                            best_score, best_margin, pred = hit_rate, margin, c

                    y_true.append(gold.item())
                    y_pred.append(pred)
    finally:
        # always detach the hook, even when the forward pass or the loader
        # raises, so it cannot linger on the model
        h.remove()

    acc = accuracy_score(y_true, y_pred)
    f1  = f1_score(y_true, y_pred, average=average, zero_division=0)
    print(f"Accuracy: {acc:.4f}   Macro‑F1: {f1:.4f}")
    return acc, f1

In [None]:
from typing import Dict, List, Tuple
from collections import Counter
import torch
from sklearn.metrics import accuracy_score, f1_score

def eval_pruned_dnf_rules(
    model: torch.nn.Module,
    dataloader,
    rules,
    target_layer: torch.nn.Module,
    *,
    apply_act: bool = False,
    device: str = "cuda",
    average: str = "macro",
):
    """Evaluate pruned general DNF rule-sets.

    Parameters
    ----------
    model : torch.nn.Module
        Model to run; it is switched to eval mode.
    dataloader : iterable
        Yields mapping-like batches holding ``"labels"`` plus the keyword
        inputs of ``model``. Batches are not modified.
    rules : dict
        ``{class_id: [clause, ...]}``; a clause is a list of
        ``(neuron_idx, threshold, sign)`` predicates. A predicate holds when
        the activation is strictly above the threshold if ``sign`` is truthy,
        and at or below it otherwise. A clause fires when all of its
        predicates hold.
    target_layer : torch.nn.Module
        Layer whose forward output is captured. If the output is 3-D, only
        index 0 along the second dimension is used.
    apply_act : bool
        Apply GELU to the captured output before thresholding.
    device : str
        Device the batch inputs are moved to.
    average : str
        Averaging mode forwarded to ``f1_score``.

    Returns
    -------
    (acc, f1) : tuple
        Accuracy and F1 of the rule-based predictions; both are also printed.

    Notes
    -----
    For each sample the predicted class is the one with the largest fraction
    of fired clauses; ties are broken by the larger summed margin (distance
    of each activation from its threshold, on the satisfied side, over all
    fired clauses). A sample on which no clause fires is predicted as
    ``-1``. NOTE(review): with sklearn's default label handling that ``-1``
    takes part in the averaged F1 as an extra label -- confirm this is
    intended.
    """
    buf = []

    def _hook(_, __, out):
        buf.append(out.detach())

    act_fn = torch.nn.functional.gelu if apply_act else (lambda x: x)

    y_true, y_pred = [], []
    model.eval()

    h = target_layer.register_forward_hook(_hook)
    try:
        with torch.no_grad():
            for batch in dataloader:
                labels = batch["labels"].to("cpu")
                # leave the caller's batch untouched: build the model inputs
                # without the labels instead of popping them
                inputs = {
                    k: v.to(device) for k, v in batch.items() if k != "labels"
                }

                buf.clear()
                model(**inputs)
                acts = act_fn(buf.pop().to("cpu"))

                if acts.dim() == 3:
                    acts = acts[:, 0, :]

                for a, gold in zip(acts, labels):
                    best_score, best_margin, pred = -1.0, -1.0, -1

                    for c, clauses in rules.items():
                        total = len(clauses)
                        # a clause fires when all of its predicates hold
                        fired_clauses = [
                            clause
                            for clause in clauses
                            if all(
                                (a[idx] > thr) if sign else (a[idx] <= thr)
                                for idx, thr, sign in clause
                            )
                        ]

                        if not fired_clauses:
                            continue

                        hit_rate = len(fired_clauses) / total
                        margin   = sum(  # tie break
                            sum(
                                (a[idx] - thr) if sign else (thr - a[idx])
                                for idx, thr, sign in clause
                            )
                            for clause in fired_clauses
                        )

                        if (
                            hit_rate > best_score
                            or (hit_rate == best_score and margin > best_margin)
                        ):
                            best_score, best_margin, pred = hit_rate, margin, c

                    y_true.append(gold.item())
                    y_pred.append(pred)
    finally:
        # always detach the hook, even when the forward pass or the loader
        # raises, so it cannot linger on the model
        h.remove()

    acc = accuracy_score(y_true, y_pred)
    f1  = f1_score(y_true, y_pred, average=average, zero_division=0)

    print(f"Accuracy: {acc:.4f}   Macro‑F1: {f1:.4f}")
    return acc, f1