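# Distilling Justice

Fine-tune a `bert-base-uncased` teacher and five smaller student models to predict whether the first party won (`first_party_winner`) from a Supreme Court case's `facts`. Each student is compared with the teacher in two ways: validation metrics, and the cosine similarity of Integrated Gradients token attributions on the first 50 validation cases. Every student is fine-tuned directly on the gold labels; no distillation loss is used.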

## Imports

In [None]:
!pip install captum



In [None]:
import html

import numpy as np
import pandas as pd
import torch
from captum.attr import IntegratedGradients
from sklearn.metrics import accuracy_score, f1_score, precision_recall_fscore_support
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.model_selection import train_test_split
from tqdm import tqdm
from transformers import (
    BertConfig,
    BertForSequenceClassification,
    BertTokenizer,
    DistilBertForSequenceClassification,
    DistilBertTokenizer,
    Trainer,
    TrainingArguments,
    set_seed,
)

## Load Data

In [None]:
# Raw export, kept untouched so the cleaned frame below can always be rebuilt.
df_raw = pd.read_csv("justice.csv")

df = (
    df_raw
    .dropna(subset=["facts", "first_party_winner"])
    .assign(
        label=lambda d: d["first_party_winner"].astype(int),
        # `facts` is an HTML snippet ("<p>...</p>"). Strip tags, decode entities
        # and collapse whitespace so markup does not eat into the 128-token budget.
        facts_text=lambda d: (
            d["facts"]
            .str.replace(r"<[^>]+>", " ", regex=True)
            .map(html.unescape)
            .str.replace(r"\s+", " ", regex=True)
            .str.strip()
        ),
    )
)
assert set(df["label"].unique()) <= {0, 1}, "expected a binary target"

# Stratified 80/20 split so both sets keep the (imbalanced) class ratio.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    df["facts_text"].tolist(),
    df["label"].tolist(),
    test_size=0.2,
    stratify=df["label"],
    random_state=42,
)

In [None]:
display(df.head())
# Class balance: always predicting the majority class already scores this rate
# as accuracy, so it is the baseline every model must beat.
df["label"].value_counts(normalize=True)

Unnamed: 0.1,Unnamed: 0,ID,name,href,docket,term,first_party,second_party,facts,facts_len,majority_vote,minority_vote,first_party_winner,decision_type,disposition,issue_area,label
0,0,50606,Roe v. Wade,https://api.oyez.org/cases/1971/70-18,70-18,1971,Jane Roe,Henry Wade,"<p>In 1970, Jane Roe (a fictional name used in...",501,7,2,True,majority opinion,reversed,,1
1,1,50613,Stanley v. Illinois,https://api.oyez.org/cases/1971/70-5014,70-5014,1971,"Peter Stanley, Sr.",Illinois,<p>Joan Stanley had three children with Peter ...,757,5,2,True,majority opinion,reversed/remanded,Civil Rights,1
2,2,50623,Giglio v. United States,https://api.oyez.org/cases/1971/70-29,70-29,1971,John Giglio,United States,<p>John Giglio was convicted of passing forged...,495,7,0,True,majority opinion,reversed/remanded,Due Process,1
3,3,50632,Reed v. Reed,https://api.oyez.org/cases/1971/70-4,70-4,1971,Sally Reed,Cecil Reed,"<p>The Idaho Probate Code specified that ""male...",378,7,0,True,majority opinion,reversed/remanded,Civil Rights,1
4,4,50643,Miller v. California,https://api.oyez.org/cases/1971/70-73,70-73,1971,Marvin Miller,California,"<p>Miller, after conducting a mass mailing cam...",305,5,4,True,majority opinion,vacated/remanded,First Amendment,1
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
3298,3298,63324,United States v. Palomar-Santiago,https://api.oyez.org/cases/2020/20-437,20-437,2020,United States,Refugio Palomar-Santiago,"<p>Refugio Palomar-Santiago, a Mexican nationa...",2054,9,0,True,majority opinion,reversed/remanded,Criminal Procedure,1
3299,3299,63323,Terry v. United States,https://api.oyez.org/cases/2020/20-5904,20-5904,2020,Tarahrick Terry,United States,<p>Tarahrick Terry pleaded guilty to one count...,1027,9,0,False,majority opinion,affirmed,Criminal Procedure,0
3300,3300,63331,United States v. Cooley,https://api.oyez.org/cases/2020/19-1414,19-1414,2020,United States,Joshua James Cooley,<p>Joshua James Cooley was parked in his picku...,1309,9,0,True,majority opinion,vacated/remanded,Civil Rights,1
3301,3301,63332,Florida v. Georgia,https://api.oyez.org/cases/2020/142-orig,142-orig,2020,Florida,Georgia,<p>This is an ongoing case of original jurisdi...,297,9,0,False,majority opinion,none,,0


In [None]:
# Dataset helper
class JusticeDataset(torch.utils.data.Dataset):
    """Map-style dataset over pre-tokenized encodings and integer labels.

    Args:
        encodings: mapping from input name ("input_ids", "attention_mask",
            optionally "token_type_ids") to a per-example sequence, as returned
            by calling a Hugging Face tokenizer on a list of texts.
        labels: per-example class ids, aligned with ``encodings``.
        model_type: key from ``model_configs``; decides whether
            ``token_type_ids`` is dropped from each example.
    """

    # DistilBertForSequenceClassification.forward() has no token_type_ids
    # argument, so passing it raises a TypeError. TinyBERT and MiniLM are loaded
    # as BertForSequenceClassification and do accept it. For single-sentence
    # inputs the ids are all zero, which is also what BERT uses when the argument
    # is omitted, so dropping them for those models is harmless.
    DROP_TOKEN_TYPE_IDS = frozenset({"Step1_DistilBERT", "Step4_MiniLM", "Step3_TinyBERT"})

    def __init__(self, encodings, labels, model_type=None):
        self.encodings = encodings
        self.labels = labels
        self.model_type = model_type

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, idx):
        """Return one example as a dict of tensors, including a ``labels`` entry."""
        drop_types = self.model_type in self.DROP_TOKEN_TYPE_IDS
        item = {
            key: torch.tensor(values[idx])
            for key, values in self.encodings.items()
            if not (drop_types and key == "token_type_ids")
        }
        item["labels"] = torch.tensor(self.labels[idx])
        return item

## Metrics

In [None]:
def compute_metrics(pred):
    """Accuracy plus per-class and macro precision/recall/F1 for binary labels.

    Args:
        pred: ``EvalPrediction``-like object from ``Trainer`` with ``label_ids``
            (gold class ids) and ``predictions`` (logits; last axis = classes).

    Returns:
        dict of metric name -> value, with separate entries for class 0 and 1.
    """
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    # Pin the label set so the per-class arrays always have two entries, even
    # when one class is absent from both `labels` and `preds` (otherwise
    # indexing [1] raises IndexError). zero_division=0 returns 0 for undefined
    # precision/recall, as sklearn already does, without the warning spam that
    # appears whenever a model predicts only one class.
    class_ids = [0, 1]
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, labels=class_ids, average=None, zero_division=0
    )
    macro_f1 = f1_score(labels, preds, labels=class_ids, average="macro", zero_division=0)
    return {
        "accuracy": accuracy_score(labels, preds),
        "precision_class0": precision[0],
        "recall_class0": recall[0],
        "f1_class0": f1[0],
        "precision_class1": precision[1],
        "recall_class1": recall[1],
        "f1_class1": f1[1],
        "macro_f1": macro_f1,
    }

## Training Arguments

In [None]:
!pip install 'accelerate>=0.26.0'

training_args = TrainingArguments(
    output_dir="./results",
    eval_strategy="epoch",
    save_strategy="no",
    learning_rate=2e-5,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
    num_train_epochs=3,
    weight_decay=0.01,
    logging_steps=50
)

## Explanation Function

In [None]:
def explain_prediction(model, tokenizer, text, label=None, max_length=128, n_steps=50):
    """Per-token Integrated Gradients attributions for one text.

    Args:
        model: Hugging Face sequence-classification model (BERT or DistilBERT).
        tokenizer: tokenizer whose vocabulary matches ``model``.
        text: raw input string.
        label: class index whose softmax probability is explained; if None,
            the model's own predicted class is used.
        max_length: length the text is truncated/padded to. A fixed length
            keeps vectors from different models the same size, so they can be
            compared with cosine similarity.
        n_steps: number of Integrated Gradients interpolation steps.

    Returns:
        1-D numpy array of length ``max_length``: one attribution per token
        position, summed over the embedding dimension.
    """
    model.eval()
    # Trainer may have moved the model to GPU; inputs must live on the same device.
    device = next(model.parameters()).device
    encoded = tokenizer(
        text,
        return_tensors="pt",
        truncation=True,
        max_length=max_length,
        padding="max_length",
    )
    # Pass only what every model here accepts. BertTokenizer also emits
    # token_type_ids, which DistilBERT's forward() rejects with a TypeError.
    input_ids = encoded["input_ids"].to(device)
    attention_mask = encoded["attention_mask"].to(device)

    if label is None:
        with torch.no_grad():
            logits = model(input_ids=input_ids, attention_mask=attention_mask).logits
        label = logits.argmax(-1).item()

    def forward_func(embeddings, mask=None):
        outputs = model(inputs_embeds=embeddings, attention_mask=mask)
        return torch.softmax(outputs.logits, dim=1)[:, label]

    # Word-embedding lookup only. Given inputs_embeds, the model adds its own
    # position (and, for BERT, token-type) embeddings. The IG baseline is
    # captum's default of all-zero embeddings.
    embeddings = model.get_input_embeddings()(input_ids)
    ig = IntegratedGradients(forward_func)
    attributions = ig.attribute(
        embeddings,
        additional_forward_args=(attention_mask,),
        n_steps=n_steps,
    )
    # .cpu() is required before .numpy() when the model ran on GPU.
    return attributions.sum(dim=-1).squeeze(0).detach().cpu().numpy()


## Model Configurations (Teacher + 5 Student Models)

In [None]:
def _pretrained(model_cls, checkpoint):
    """Zero-arg factory: load `checkpoint` into `model_cls` with a fresh 2-class head."""
    return lambda: model_cls.from_pretrained(checkpoint, num_labels=2)


def _from_scratch(hidden_size, num_hidden_layers, num_attention_heads, intermediate_size):
    """Zero-arg factory: a randomly initialised BERT classifier of the given size.

    No pretrained weights are loaded, so these students learn only from the
    fine-tuning data.
    """
    def build():
        config = BertConfig(
            vocab_size=30522,  # bert-base-uncased vocabulary, matching the shared tokenizer
            hidden_size=hidden_size,
            num_hidden_layers=num_hidden_layers,
            num_attention_heads=num_attention_heads,
            intermediate_size=intermediate_size,
            num_labels=2,
        )
        return BertForSequenceClassification(config)

    return build


# name -> {"tokenizer": checkpoint for the tokenizer, "model": zero-arg factory}.
# Factories (not instances) so each model is only built when the loop reaches it.
# Order matters: the teacher is listed first.
model_configs = {
    "Teacher_BERT": {
        "tokenizer": "bert-base-uncased",
        "model": _pretrained(BertForSequenceClassification, "bert-base-uncased"),
    },
    "Step1_DistilBERT": {
        "tokenizer": "distilbert-base-uncased",
        "model": _pretrained(DistilBertForSequenceClassification, "distilbert-base-uncased"),
    },
    "Step2_Custom6x512": {
        "tokenizer": "bert-base-uncased",
        "model": _from_scratch(hidden_size=512, num_hidden_layers=6, num_attention_heads=8, intermediate_size=2048),
    },
    "Step3_TinyBERT": {
        "tokenizer": "bert-base-uncased",
        # NOTE(review): verify this checkpoint id resolves on the Hugging Face Hub.
        "model": _pretrained(BertForSequenceClassification, "huawei-noah/TinyBERT_General_4L_512D"),
    },
    "Step4_MiniLM": {
        "tokenizer": "bert-base-uncased",
        "model": _pretrained(BertForSequenceClassification, "microsoft/MiniLM-L12-H384-uncased"),
    },
    "Step5_Custom2x384": {
        "tokenizer": "bert-base-uncased",
        "model": _from_scratch(hidden_size=384, num_hidden_layers=2, num_attention_heads=6, intermediate_size=1024),
    },
}


## Train, Evaluate, and Explain Each Model

In [None]:
N_EXPLAIN = 50  # validation cases used for the attribution comparison (IG is slow)

results = {}
teacher_vectors = []

for name, cfg in model_configs.items():
    print(f"\n🔹 Training {name}...")

    # "distilbert-base-uncased" also contains "bert", so a substring test always
    # picked BertTokenizer. Match on the prefix instead.
    tokenizer_cls = DistilBertTokenizer if cfg["tokenizer"].startswith("distilbert") else BertTokenizer
    tokenizer = tokenizer_cls.from_pretrained(cfg["tokenizer"])

    train_enc = tokenizer(train_texts, truncation=True, padding=True, max_length=128)
    val_enc = tokenizer(val_texts, truncation=True, padding=True, max_length=128)
    train_ds = JusticeDataset(train_enc, train_labels, model_type=name)
    val_ds = JusticeDataset(val_enc, val_labels, model_type=name)

    # Seed before building the model. Trainer seeds only in its constructor,
    # which runs after the classification heads and the from-scratch students
    # have already been randomly initialised.
    set_seed(training_args.seed)
    model = cfg["model"]()

    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )
    trainer.train()
    results[name] = trainer.evaluate()

    # Explain every model w.r.t. the same target (the gold label). With a
    # 2-class softmax, p0 = 1 - p1, so attributions for opposite predicted
    # classes are exact negatives. Using each model's own prediction would
    # fold prediction disagreement into the similarity score as a sign flip.
    explain_texts = val_texts[:N_EXPLAIN]
    explain_labels = val_labels[:N_EXPLAIN]
    vectors = [
        explain_prediction(model, tokenizer, text, label=label)
        for text, label in tqdm(zip(explain_texts, explain_labels), total=len(explain_texts), desc=f"IG {name}")
    ]

    if name == "Teacher_BERT":
        teacher_vectors = vectors
    else:
        if not teacher_vectors:
            raise RuntimeError("Teacher_BERT must come first in model_configs")
        cos_sims = [
            cosine_similarity([t_vec], [s_vec])[0][0]
            for t_vec, s_vec in zip(teacher_vectors, vectors)
        ]
        results[name]["cosine_similarity"] = np.mean(cos_sims)


🔹 Training Teacher_BERT...


Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Epoch,Training Loss,Validation Loss


Epoch,Training Loss,Validation Loss,Accuracy,Precision Class0,Recall Class0,F1 Class0,Precision Class1,Recall Class1,F1 Class1,Macro F1
1,0.6827,0.646403,0.650456,0.0,0.0,0.0,0.650456,1.0,0.788214,0.394107
2,0.6471,0.642228,0.650456,0.0,0.0,0.0,0.650456,1.0,0.788214,0.394107
3,0.5777,0.66564,0.629179,0.402778,0.126087,0.192053,0.656997,0.899533,0.759369,0.475711


  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))
  _warn_prf(average, modifier, f"{metric.capitalize()} is", len(result))



🔹 Training Step1_DistilBERT...


tokenizer_config.json:   0%|          | 0.00/48.0 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

config.json:   0%|          | 0.00/483 [00:00<?, ?B/s]

The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DistilBertTokenizer'. 
The class this function is called from is 'BertTokenizer'.


model.safetensors:   0%|          | 0.00/268M [00:00<?, ?B/s]

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
  trainer = Trainer(


Epoch,Training Loss,Validation Loss


## Results Summary

In [None]:
# One row per model. The teacher has no cosine_similarity (NaN) because it is
# the reference the students are compared against.
results_df = pd.DataFrame(results).T
results_df