In [1]:
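# Implementing the DEDIER distillation framework for LLMs: a BERT-large teacher is distilled into a DistilBERT student that has an early-readout head.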

In [2]:
!pip install "transformers>=4.39" "datasets>=2.19" "evaluate" torch accelerate
!pip install sentencepiece

Collecting evaluate
  Using cached evaluate-0.4.3-py3-none-any.whl.metadata (9.2 kB)
Collecting accelerate
  Downloading accelerate-1.6.0-py3-none-any.whl.metadata (19 kB)
Using cached evaluate-0.4.3-py3-none-any.whl (84 kB)
Downloading accelerate-1.6.0-py3-none-any.whl (354 kB)
Installing collected packages: accelerate, evaluate
Successfully installed accelerate-1.6.0 evaluate-0.4.3

[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m A new release of pip is available: [0m[31;49m25.0[0m[39;49m -> [0m[32;49m25.0.1[0m
[1m[[0m[34;49mnotice[0m[1;39;49m][0m[39;49m To update, run: [0m[32;49mpip install --upgrade pip[0m
Collecting sentencepiece
  Downloading sentencepiece-0.2.0-cp39-cp39-macosx_10_9_x86_64.whl.metadata (7.7 kB)
Downloading sentencepiece-0.2.0-cp39-cp39-macosx_10_9_x86_64.whl (1.2 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.2/1.2 MB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: sentencepiece
Successfull

In [None]:
from datasets import load_dataset, DatasetDict

# The Hugging Face google/civil_comments release provides separate
# train / validation / test splits (see its dataset card), so load all three
# directly. Loading only split="train" and filtering it on a per-row field can
# never recover the validation/test rows.
raw_splits = load_dataset("google/civil_comments")
splits = DatasetDict({
    "train": raw_splits["train"],
    "validation": raw_splits["validation"],
    "test": raw_splits["test"],
})

In [None]:
from transformers import AutoTokenizer, AutoModelForSequenceClassification, TrainingArguments, Trainer

tok_t = AutoTokenizer.from_pretrained("bert-large-uncased")

def tokenize(ex):
    """Tokenize a batch of comments (the "text" column) with the teacher's tokenizer.

    Every sequence is truncated/padded to exactly 256 tokens, so rows have equal
    length and can later be stacked directly with ``torch.tensor``.
    """
    comments = ex["text"]
    encoded = tok_t(comments, padding="max_length", max_length=256, truncation=True)
    return encoded

def tokenize_and_label(batch):
    """Tokenize a batch and attach integer class labels.

    `toxicity` is a continuous score (the evaluation cell thresholds it at 0.5).
    Both 2-way classifiers need integer class ids. With float labels, HF's
    sequence classifier switches to a BCE loss whose shape does not match
    num_labels=2, and `cross_entropy` rejects them. So the score is binarized
    at 0.5 into `labels`.
    """
    encoded = dict(tokenize(batch))
    encoded["labels"] = [int(score >= 0.5) for score in batch["toxicity"]]
    return encoded

splits_tok = splits.map(tokenize_and_label, batched=True)
teacher = AutoModelForSequenceClassification.from_pretrained(
    "bert-large-uncased", num_labels=2
)

args = TrainingArguments(
    output_dir="teacher_out",
    per_device_train_batch_size=8,
    per_device_eval_batch_size=32,
    # Plain ASCII "-" in the exponent: a Unicode hyphen inside a float literal
    # is a SyntaxError. Value reported as the authors' best for Civil Comments
    # (arXiv:2310.18590).
    learning_rate=2e-5,
    num_train_epochs=3,
    weight_decay=0.01,
    eval_strategy="epoch",  # `evaluation_strategy` was renamed/removed in transformers
    # Checkpoint once per epoch and keep only the latest two; the default
    # (every 500 steps, unlimited) would write hundreds of BERT-large checkpoints.
    save_strategy="epoch",
    save_total_limit=2,
)

trainer_t = Trainer(
    model=teacher,
    args=args,
    train_dataset=splits_tok["train"],
    eval_dataset=splits_tok["validation"],
)
trainer_t.train()
teacher.save_pretrained("teacher_ckpt")
# Store the tokenizer next to the weights so the checkpoint reloads standalone.
tok_t.save_pretrained("teacher_ckpt")

In [None]:
import torch
from torch import nn
from transformers import AutoConfig, AutoModel, AutoModelForSequenceClassification

class DEDIERStudent(nn.Module):
    """Student classifier for DEDIER-style distillation with an early-readout head.

    Two heads sit on top of a pretrained encoder:

    * ``classifier`` reads the final-layer vector at position 0 ([CLS]) and
      produces the prediction that is evaluated.
    * ``aux_head`` (the "early readout") reads the position-0 vector from
      ``hidden_states[aux_layer]``. Samples whose early readout is wrong get
      their distillation loss up-weighted by the readout's confidence margin.

    Args:
        base_ckpt: Hugging Face checkpoint for the encoder.
        num_labels: number of classes (the margin uses the top-2 probabilities,
            so at least 2 are required).
        aux_layer: index into ``hidden_states``; 1 = after the first transformer block.
        temperature: knowledge-distillation softmax temperature.
        alpha: weight of the re-weighted KD term in the total loss.
        beta: exponent applied to the early-readout confidence margin.
    """

    def __init__(self, base_ckpt="distilbert-base-uncased", num_labels=2, aux_layer=1,
                 temperature=2.0, alpha=0.05, beta=3.0):
        super().__init__()
        cfg = AutoConfig.from_pretrained(base_ckpt, output_hidden_states=True)
        self.encoder = AutoModel.from_pretrained(base_ckpt, config=cfg)
        self.classifier = nn.Linear(cfg.hidden_size, num_labels)
        # Early-readout head: two-layer MLP (reported to work best in arXiv:2310.18590).
        self.aux_layer = aux_layer
        self.aux_head = nn.Sequential(
            nn.Linear(cfg.hidden_size, cfg.hidden_size),
            nn.GELU(),
            nn.Linear(cfg.hidden_size, num_labels),
        )
        self.temperature = temperature
        self.alpha = alpha
        self.beta = beta

    @staticmethod
    def per_sample_kd(student_logits, teacher_logits, temperature):
        """Return KL(teacher || student) on temperature-softened distributions, one value per row.

        The result is scaled by ``temperature ** 2``, as in standard knowledge distillation.
        """
        log_p_student = nn.functional.log_softmax(student_logits / temperature, dim=-1)
        p_teacher = nn.functional.softmax(teacher_logits / temperature, dim=-1)
        kl = nn.functional.kl_div(log_p_student, p_teacher, reduction="none")
        return kl.sum(dim=-1) * temperature ** 2

    @staticmethod
    def readout_weights(logits_aux, labels, beta):
        """Return per-sample KD weights derived from the early readout.

        A sample the early readout gets wrong has weight ``1 + margin ** beta``,
        where ``margin`` is the gap between its top-2 class probabilities. Every
        other sample has weight 0. No gradient flows through the weights.
        """
        logits_aux = logits_aux.detach()
        top2 = nn.functional.softmax(logits_aux, dim=-1).topk(2, dim=-1).values
        margin = top2[:, 0] - top2[:, 1]
        wrong = (logits_aux.argmax(dim=-1) != labels).float()
        return (1 + margin ** beta) * wrong

    def forward(self, input_ids, attention_mask, labels=None, teacher_logits=None):
        """Run the student.

        Returns:
            ``(logits, logits_aux)`` when ``labels`` is None (inference).
            Otherwise ``(loss, logits)``, with the loss first as the HF Trainer expects.
            When ``teacher_logits`` is None, the loss omits the distillation term.
        """
        out = self.encoder(input_ids, attention_mask=attention_mask, return_dict=True)
        logits = self.classifier(out.last_hidden_state[:, 0])  # final-layer [CLS]
        # Early readout on the layer-k [CLS] vector. Detached so training the
        # readout does not reshape the encoder.
        # NOTE(review): confirm against the paper whether the readout loss
        # should also backpropagate into the encoder.
        h_aux = out.hidden_states[self.aux_layer][:, 0].detach()
        logits_aux = self.aux_head(h_aux)

        if labels is None:
            return logits, logits_aux

        ce = nn.functional.cross_entropy(logits, labels)
        # The readout head must itself be trained. Otherwise its "confidently
        # wrong" flags come from a randomly initialised head.
        ce_aux = nn.functional.cross_entropy(logits_aux, labels)
        loss = ce + ce_aux

        if teacher_logits is not None:
            kd = self.per_sample_kd(logits, teacher_logits, self.temperature)
            weights = self.readout_weights(logits_aux, labels, self.beta)
            # Weight each sample's own KD term before averaging (a batch-mean KD
            # would collapse this to mean(weights) * mean(kd)).
            loss = loss + self.alpha * (weights * kd).mean()

        return loss, logits

In [None]:
tok_s = tok_t               # same vocab

# From here on the teacher is only a frozen source of soft targets: eval()
# disables dropout so its logits are deterministic, and freezing its
# parameters avoids tracking gradients for them.
teacher.eval()
teacher.requires_grad_(False)

def collate(batch):
    """Stack pre-tokenized rows into tensors and attach teacher logits for KD.

    Each item needs equal-length "input_ids"/"attention_mask" lists plus a
    "labels" entry. Returns a dict of CPU tensors with keys input_ids,
    attention_mask, labels and teacher_logits. The Trainer moves the whole
    batch to its device.
    """
    input_ids = torch.tensor([item["input_ids"] for item in batch])
    attention_mask = torch.tensor([item["attention_mask"] for item in batch])
    labels = torch.tensor([item["labels"] for item in batch])

    # After its Trainer run the teacher may live on a GPU. Feed it inputs on
    # its own device and bring the logits back to the CPU.
    device = next(teacher.parameters()).device
    with torch.no_grad():
        teacher_logits = teacher(
            input_ids.to(device),
            attention_mask=attention_mask.to(device),
        ).logits.cpu()

    return {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "labels": labels,
        "teacher_logits": teacher_logits,
    }

student = DEDIERStudent()

args_s = TrainingArguments(
    output_dir="student_out",
    per_device_train_batch_size=16,
    per_device_eval_batch_size=64,
    learning_rate=2e-5,  # ASCII "-": a Unicode hyphen in a float literal is a SyntaxError
    num_train_epochs=4,
    weight_decay=0.01,
    eval_strategy="epoch",  # `evaluation_strategy` was renamed/removed in transformers
    # One checkpoint per epoch, keep the latest two (default: every 500 steps, unlimited).
    save_strategy="epoch",
    save_total_limit=2,
    # The student is a plain nn.Module, so name its label key explicitly.
    label_names=["labels"],
)

trainer_s = Trainer(
    model=student,
    args=args_s,
    train_dataset=splits_tok["train"],
    eval_dataset=splits_tok["validation"],
    data_collator=collate,
)
trainer_s.train()

In [None]:
import evaluate
roc = evaluate.load("roc_auc")

# Subgroup columns: every test-split column whose name ends in "_identity".
# NOTE(review): confirm the loaded dataset actually has such columns; if none
# match, this list is empty and no subgroup AUROC gets reported.
id_cols = list(filter(lambda name: name.endswith("_identity"), splits["test"].column_names))

def predict(ds, model, batch_size=128):
    """Return the model's positive-class probability for every row of `ds`.

    Args:
        ds: dataset with a "text" column; it is tokenized here with `tokenize`.
        model: module whose forward, called without labels, returns a
            ``(logits, aux_logits)`` pair (as DEDIERStudent does).
        batch_size: rows per forward pass.

    Returns:
        1-D CPU tensor with one probability (softmax column 1) per row.
    """
    was_training = model.training
    model.eval()  # disable dropout so predictions are deterministic
    # The Trainer may have moved the model to a GPU; build inputs on that device.
    device = next(model.parameters()).device
    probs = []
    try:
        with torch.no_grad():
            for chunk in ds.map(tokenize, batched=True).iter(batch_size=batch_size):
                logits, _ = model(
                    torch.tensor(chunk["input_ids"], device=device),
                    attention_mask=torch.tensor(chunk["attention_mask"], device=device),
                )
                probs.append(logits.softmax(-1)[:, 1].cpu())
    finally:
        model.train(was_training)  # restore the caller's mode
    if not probs:
        return torch.empty(0)
    return torch.cat(probs)

import numpy as np

# Positive-class scores as a NumPy array so boolean subgroup masks can index it.
y_hat = predict(splits["test"], student).numpy()
# Dataset columns come back as Python lists, which do not compare elementwise
# against a scalar; convert first. A comment is "toxic" when toxicity >= 0.5.
y_true = np.asarray(splits["test"]["toxicity"], dtype=float) >= 0.5

def auroc(scores, refs):
    """AUROC of `scores` against boolean `refs`; NaN when only one class is present."""
    if np.unique(refs).size < 2:
        # AUROC is undefined here (the underlying roc_auc_score raises).
        return float("nan")
    return roc.compute(prediction_scores=scores, references=refs)["roc_auc"]

# overall AUROC
print("overall:", auroc(y_hat, y_true))

# subgroup AUROC
if not id_cols:
    print("no identity columns found in the test split; subgroup AUROC skipped")
for g in id_cols:
    # Missing annotations (None) become NaN and fall outside the subgroup.
    # NOTE(review): assumes identity columns hold scores in [0, 1]; >= 0.5
    # mirrors the toxicity cut and equals `== 1` for 0/1 columns -- confirm.
    mask = np.asarray(splits["test"][g], dtype=float) >= 0.5
    if mask.any():
        print(g, auroc(y_hat[mask], y_true[mask]))