In [79]:
import os
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import DebertaV2Tokenizer, AutoModel, RobertaTokenizer
from sklearn.preprocessing import MultiLabelBinarizer
from tqdm import tqdm
from sklearn.metrics import f1_score, classification_report
import numpy as np
import shap
from captum.attr import IntegratedGradients
from transformers import AutoTokenizer
import torch.nn.functional as F
import hf_xet

# Single-Task Learning

Below is the current pipeline for single-task learning. The code chunk bellow allows to switch between the two tasks: "narrative_classification" and "entity_framing".

We can also specify the training and test domains.

In [81]:
# ==========================
# CONTROL PANEL
# ==========================

# choose a task for the pipeline below: "narrative_classification" or "entity_framing"
TASK = "narrative_classification"

# select domains for training and testing: "UA"; "CC"; "UA", "CC";
TRAIN_DOMAIN = ["UA","CC"]
TEST_DOMAIN = ["UA", "CC"] # The test data comes from a separate dataset. 
# The test data is always the same regardless of the domain we choose to train on. This is for consistency.

"""
Note that all articles are now in English, but if we wanted to control for e.g. certain cultural variations of a specific language,
we could exclude articles that were originally written in that language.

Not to use the functionality, 'ALL' should be selected.

"""
# select languages for training and testing: "ALL";"EN";"HI";"BG";"RU";"PT"
TRAIN_LANGUAGES = ["ALL"] 
TEST_LANGUAGES = ["ALL"]

# debug mode -- reduced samples
DEBUG_MODE = False

# change the training hyperparameters here
MODEL_NAME = "roberta-base" # OR "deberta-v3-base"
MAX_LEN = 512
BATCH_SIZE = 8
EPOCHS = 3
LEARNING_RATE = 2e-5
MODEL_PATH = f"{TASK}_STL_{'-'.join(TRAIN_DOMAIN)}_to_{'-'.join(TEST_DOMAIN)}.pt" # -- to save the model later


In [83]:
# ==========================
# LOAD AND MERGE DATA
# ==========================

articles = pd.read_csv("train-all-articles.csv")
s1 = pd.read_csv("train-S1-labels.csv")
s2 = pd.read_csv("train-S2-labels.csv")

test_s1_articles = pd.read_csv("test-S1-articles.csv")
test_s1_labels = pd.read_csv("test-S1-labels.csv")
test_s2_articles = pd.read_csv("test-S2-articles.csv")
test_s2_labels = pd.read_csv("test-S2-labels.csv")

# ==========================
# STANDARDISE TEST SET COLUMNS
# ==========================

if TASK == "entity_framing":
    test_s1_labels.rename(columns={"Translated_Entity": "Entity"}, inplace=True)
elif TASK == "narrative_classification":
    test_s2_labels.columns = ["Filename", "Narrative", "Subnarrative"]

# ==========================
# FILTER + SPLIT TRAIN/VAL
# ==========================

# filter domains/languages for train/val
filtered_articles = articles[articles["Domain"].isin(TRAIN_DOMAIN)]
if "ALL" not in TRAIN_LANGUAGES:
    filtered_articles = filtered_articles[filtered_articles["Language"].isin(TRAIN_LANGUAGES)]

# 80/20 train/val split
filtered_articles = filtered_articles.sample(frac=1, random_state=42).reset_index(drop=True)
split_idx = int(0.8 * len(filtered_articles))
train_articles = filtered_articles.iloc[:split_idx].copy()
val_articles = filtered_articles.iloc[split_idx:].copy()

# debug subsampling if needed -- off by default
if DEBUG_MODE:
    train_articles = train_articles.sample(100)
    val_articles = val_articles.sample(100)
    test_s1_articles = test_s1_articles.sample(100)
    test_s2_articles = test_s2_articles.sample(100)


In [84]:
# ==========================
# TASK-SPECIFIC MERGE + PROCESSING
# ==========================

if TASK == "narrative_classification":
    # Merge articles with S2 labels
    df_train = pd.merge(train_articles, s2, on="Filename")
    df_val   = pd.merge(val_articles, s2, on="Filename")
    df_test  = pd.merge(test_s2_articles, test_s2_labels, on="Filename")

    TEXT_COL = "Translated_Text"
    LABEL_COL = "Narrative"

    for df in [df_train, df_val, df_test]:
        df.dropna(subset=[TEXT_COL, LABEL_COL], inplace=True)
        df[LABEL_COL] = df[LABEL_COL].apply(
            lambda x: [s.strip() for s in str(x).split(";") if s.strip().lower() != "nan"]
        )

    # Create shared label space from all available narrative data (zero-shot setup)
    full_set = pd.concat([df_train, df_val, df_test])
    mlb = MultiLabelBinarizer()
    mlb.fit(full_set[LABEL_COL])

    y_train = mlb.transform(df_train[LABEL_COL])
    y_val   = mlb.transform(df_val[LABEL_COL])
    y_test  = mlb.transform(df_test[LABEL_COL])
    num_classes = len(mlb.classes_)

elif TASK == "entity_framing":
    # Merge entity labels with articles
    df_train = pd.merge(s1, train_articles, on="Filename")
    df_val   = pd.merge(s1, val_articles, on="Filename")
    df_test  = pd.merge(test_s1_labels, test_s1_articles, on="Filename")

    TEXT_COL = "Translated_Text"
    LABEL_COL = "Label"

    def insert_entity_marker(text, start, end):
        try:
            start, end = int(start), int(end)
            return text[:start] + "[ENTITY]" + text[start:end] + "[/ENTITY]" + text[end:]
        except:
            return text

    for df in [df_train, df_val, df_test]:
        df.dropna(subset=[TEXT_COL, "Entity", LABEL_COL, "Start", "End"], inplace=True)
        df["Start"] = df["Start"].astype(int)
        df["End"] = df["End"].astype(int)
        df["Input_Text"] = df.apply(lambda row: insert_entity_marker(row[TEXT_COL], row["Start"], row["End"]), axis=1)
        df[LABEL_COL] = df[LABEL_COL].apply(lambda x: [s.strip() for s in str(x).split(",") if s.strip().lower() != "nan"])

    # For entity framing, create a separate label binarizer
    mlb = MultiLabelBinarizer()
    mlb.fit(df_train[LABEL_COL] + df_val[LABEL_COL] + df_test[LABEL_COL])  # full entity label space

    y_train = mlb.transform(df_train[LABEL_COL])
    y_val   = mlb.transform(df_val[LABEL_COL])
    y_test  = mlb.transform(df_test[LABEL_COL])
    num_classes = len(mlb.classes_)

else:
    raise ValueError("Unknown TASK specified.")


In [86]:
# ==========================
# TOKENISATION and DATASET CLASS
# ==========================

tokenizer = RobertaTokenizer.from_pretrained(MODEL_NAME)
#tokenizer = DebertaV2Tokenizer.from_pretrained(MODEL_NAME)

class MultiLabelDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        encoding = self.tokenizer(
            self.texts[idx],
            truncation=True,
            padding="max_length",
            max_length=self.max_len,
            return_tensors="pt"
        )

        return {
            "input_ids": encoding["input_ids"].squeeze(),
            "attention_mask": encoding["attention_mask"].squeeze(),
            "labels": torch.tensor(self.labels[idx], dtype=torch.float)
        }

train_dataset = MultiLabelDataset(df_train[TEXT_COL].tolist(), y_train, tokenizer, MAX_LEN)
val_dataset   = MultiLabelDataset(df_val[TEXT_COL].tolist(), y_val, tokenizer, MAX_LEN)
test_dataset  = MultiLabelDataset(df_test[TEXT_COL].tolist(), y_test, tokenizer, MAX_LEN)

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader   = DataLoader(val_dataset, batch_size=BATCH_SIZE)
test_loader  = DataLoader(test_dataset, batch_size=BATCH_SIZE)




In [66]:
# Optional oversampling

# Get training texts based on task type
if TASK == "narrative_classification":
    train_texts_raw = df_train["Translated_Text"].tolist()
elif TASK == "entity_framing":
    train_texts_raw = df_train["Input_Text"].tolist()
else:
    raise ValueError("TASK must be either 'narrative_classification' or 'entity_framing'.")

# 1. Compute label frequency
label_counts = np.sum(y_train, axis=0)

# 2. Score each sample by rarity of its labels
sample_weights = (y_train * (1 / (label_counts + 1e-6))).sum(axis=1)
sample_weights = sample_weights / sample_weights.sum()

# 3. Resample indices with replacement
indices = np.arange(len(y_train))
resampled_indices = np.random.choice(indices, size=len(indices), replace=True, p=sample_weights)

# 4. Create oversampled dataset
train_texts_resampled = [train_texts_raw[i] for i in resampled_indices]
y_train_resampled = y_train[resampled_indices]

# 5. Rebuild Dataset & Dataloader
train_dataset = MultiLabelDataset(train_texts_resampled, y_train_resampled, tokenizer, MAX_LEN)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)


In [68]:
# ==========================
# MODEL CLASS
# ==========================

class TransformerClassifier(nn.Module):
    def __init__(self, model_name, num_classes):
        super().__init__()
        self.encoder = AutoModel.from_pretrained(model_name)
        self.dropout = nn.Dropout(0.3)
        self.classifier = nn.Linear(self.encoder.config.hidden_size, num_classes)

    def forward(self, input_ids, attention_mask):
        outputs = self.encoder(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0]  # CLS token
        pooled_output = self.dropout(pooled_output)
        return self.classifier(pooled_output)


In [70]:
# ==========================
# TRAINING UTILS
# ==========================

def predict_proba(model, loader, device):
    model.eval()
    probs = []
    with torch.no_grad():
        for batch in loader:
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            outputs = model(input_ids, attention_mask)
            probs.extend(torch.sigmoid(outputs).cpu().numpy())
    return np.array(probs)

def evaluate_threshold_sweep(y_true, y_pred, thresholds=np.arange(0.1, 0.9, 0.05)):
    best_thresh = 0.5
    best_f1 = 0
    results = []

    for thresh in thresholds:
        y_pred_bin = (y_pred > thresh).astype(int)
        macro = f1_score(y_true, y_pred_bin, average='macro', zero_division=0)
        micro = f1_score(y_true, y_pred_bin, average='micro', zero_division=0)
        exact = (y_pred_bin == y_true).all(axis=1).mean()

        results.append((thresh, macro, micro, exact))
        if macro > best_f1:
            best_f1 = macro
            best_thresh = thresh

    print("Threshold sweep results:")
    for t, macro, micro, exact in results:
        print(f"Thresh {t:.2f} | Macro F1: {macro:.3f} | Micro F1: {micro:.3f} | Exact Match: {exact:.3f}")

    print(f"\n Best threshold = {best_thresh:.2f} with Macro F1 = {best_f1:.3f}")
    return best_thresh


In [73]:
# ==========================
# TRAINING LOOP
# ==========================

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = TransformerClassifier(MODEL_NAME, num_classes).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
criterion = nn.BCEWithLogitsLoss()
best_macro_f1 = 0.0 

for epoch in range(EPOCHS):
    model.train()
    total_loss = 0
    for batch in tqdm(train_loader, desc=f"Epoch {epoch+1}"):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    print(f"\nEpoch {epoch+1}: Loss = {total_loss / len(train_loader):.4f}")

    # validation
    val_probs = predict_proba(model, val_loader, device)
    threshold = evaluate_threshold_sweep(y_val, val_probs)
    y_val_pred = (val_probs > threshold).astype(int)
    macro_f1 = f1_score(y_val, y_val_pred, average="macro", zero_division=0)
    print(f"Validation Macro F1 (Epoch {epoch+1}): {macro_f1:.4f}")

    if macro_f1 > best_macro_f1:
        best_macro_f1 = macro_f1
        torch.save(model.state_dict(), MODEL_PATH)
        print(f"Saved best model (Epoch {epoch+1}) to {MODEL_PATH}")


Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Epoch 1: 100%|██████████| 192/192 [01:20<00:00,  2.40it/s]

Epoch 1: Loss = 0.3231
Threshold sweep results:
Thresh 0.10 | Macro F1: 0.296 | Micro F1: 0.304 | Exact Match: 0.000
Thresh 0.15 | Macro F1: 0.320 | Micro F1: 0.349 | Exact Match: 0.010
Thresh 0.20 | Macro F1: 0.349 | Micro F1: 0.407 | Exact Match: 0.045
Thresh 0.25 | Macro F1: 0.342 | Micro F1: 0.437 | Exact Match: 0.068
Thresh 0.30 | Macro F1: 0.348 | Micro F1: 0.451 | Exact Match: 0.084
Thresh 0.35 | Macro F1: 0.351 | Micro F1: 0.469 | Exact Match: 0.108
Thresh 0.40 | Macro F1: 0.340 | Micro F1: 0.478 | Exact Match: 0.113
Thresh 0.45 | Macro F1: 0.306 | Micro F1: 0.479 | Exact Match: 0.113
Thresh 0.50 | Macro F1: 0.272 | Micro F1: 0.479 |

In [None]:
torch.cuda.empty_cache()

In [75]:
# ==========================
# FIXED THRESHOLD EVALUATION
# ==========================

def evaluate(loader, df_source, mlb, label="TEST", threshold=0.25): 

    """
    Evaluates a multi-label classification model using a fixed probability threshold.

    Args:
        loader (DataLoader): A PyTorch DataLoader yielding batches of tokenised input data
        df_source (pd.DataFrame): Source dataframe containing metadata for each example, including domain info.
        mlb (MultiLabelBinarizer): The fitted multi-label binarizer used for encoding and decoding labels.
        label (str, optional): Label for the dataset (e.g., 'TEST', 'VALIDATION'). Used for logging. Defaults to "TEST".
        threshold (float, optional): Probability threshold to convert predicted probabilities into binary labels. Defaults to 0.25.

    Returns:
        dict: A dictionary containing overall macro F1, micro F1, exact match score, 
              the threshold used, and the list of labels used after filtering.
              Also prints per-domain breakdowns of these metrics.

    Notes:
        - Filters out labels that are completely unseen in both predictions and ground truths 
          to avoid skewed metric calculations.
        - Performs evaluation on the entire dataset as well as broken down by domain.
    """
    model.eval()
    y_true, y_pred, domains = [], [], []

    with torch.no_grad():
        for i, batch in enumerate(tqdm(loader, desc=f"Evaluating {label}")):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].cpu().numpy()

            outputs = model(input_ids, attention_mask)
            probs = torch.sigmoid(outputs).cpu().numpy()

            y_pred.extend(probs)
            y_true.extend(labels)

            start = i * loader.batch_size
            end = start + len(labels)
            domains.extend(df_source["Domain"].iloc[start:end].tolist())

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    domains = np.array(domains)

    y_pred_bin = (y_pred > threshold).astype(int)

    # filter columns where y_true or y_pred has no samples (i.e., unseen label)
    mask = (y_true.sum(axis=0) + y_pred_bin.sum(axis=0)) > 0
    y_true = y_true[:, mask]
    y_pred_bin = y_pred_bin[:, mask]
    filtered_labels = np.array(mlb.classes_)[mask]

    macro = f1_score(y_true, y_pred_bin, average="macro", zero_division=0)
    micro = f1_score(y_true, y_pred_bin, average="micro", zero_division=0)
    exact = (y_pred_bin == y_true).all(axis=1).mean()

    print(f"\n {label} (Fixed Threshold={threshold:.2f}):")
    print(f"Macro F1: {macro:.3f}")
    print(f"Micro F1: {micro:.3f}")
    print(f"Exact Match: {exact:.3f}")

    print("\n----------------------------")
    print("Per-Domain Breakdown")
    print("----------------------------")
    for domain in np.unique(domains):
        idx = np.where(domains == domain)[0]
        y_true_d = y_true[idx]
        y_pred_d = y_pred_bin[idx]

        macro_d = f1_score(y_true_d, y_pred_d, average="macro", zero_division=0)
        micro_d = f1_score(y_true_d, y_pred_d, average="micro", zero_division=0)
        exact_d = (y_pred_d == y_true_d).all(axis=1).mean()

        print(f"\n Domain: {domain}")
        print(f"Macro F1: {macro_d:.3f}")
        print(f"Micro F1: {micro_d:.3f}")
        print(f"Exact Match: {exact_d:.3f}")

    return {
        "macro": macro,
        "micro": micro,
        "exact": exact,
        "threshold": threshold,
        "labels_used": filtered_labels.tolist()
    }


def evaluate_and_compare_fixed_thresh(val_loader, df_val, test_loader, df_test, mlb, threshold=0.25):
    print("\n=========================")
    print("Validation (Fixed Threshold)")
    print("=========================")
    val_results = evaluate(val_loader, df_val.reset_index(drop=True), mlb, label="VALIDATION", threshold=threshold)

    print("\n=========================")
    print("Test (Fixed Threshold)")
    print("=========================")
    test_results = evaluate(test_loader, df_test.reset_index(drop=True), mlb, label="TEST", threshold=threshold)

    print("\n=========================")
    print("OOD Generalization (Fixed Threshold)")
    print("=========================")
    macro_drop = val_results["macro"] - test_results["macro"]
    print(f"Δ Macro F1 (val - test): {macro_drop:.3f}")

    return {
        "val": val_results,
        "test": test_results,
        "ood_gap_macro": macro_drop
    }



In [77]:
results = evaluate_and_compare_fixed_thresh(
    val_loader, df_val,
    test_loader, df_test,
    mlb
)


Validation (Fixed Threshold)
Evaluating VALIDATION: 100%|██████████| 48/48 [00:07<00:00,  6.82it/s]

 VALIDATION (Fixed Threshold=0.25):
Macro F1: 0.450
Micro F1: 0.557
Exact Match: 0.115

----------------------------
Per-Domain Breakdown
----------------------------

 Domain: CC
Macro F1: 0.264
Micro F1: 0.584
Exact Match: 0.239

 Domain: UA
Macro F1: 0.235
Micro F1: 0.546
Exact Match: 0.063

Test (Fixed Threshold)
Evaluating TEST: 100%|██████████| 23/23 [00:03<00:00,  6.84it/s]

 TEST (Fixed Threshold=0.25):
Macro F1: 0.396
Micro F1: 0.526
Exact Match: 0.129

----------------------------
Per-Domain Breakdown
----------------------------

 Domain: CC
Macro F1: 0.226
Micro F1: 0.573
Exact Match: 0.260

 Domain: UA
Macro F1: 0.197
Micro F1: 0.499
Exact Match: 0.038

OOD Generalization (Fixed Threshold)
Δ Macro F1 (val - test): 0.054


In [88]:
## Alternative Training Loop: Ensemble


# ==========================
# ENSEMBLE TRAINING LOOP
# ==========================

for run_id in range(1, 4):  # Train 3 models for the ensemble
    print(f"\n=== Training Model {run_id}/3 ===")
    
    torch.manual_seed(42 + run_id)
    np.random.seed(42 + run_id)

    model = TransformerClassifier(MODEL_NAME, num_classes).to(device)
    optimizer = torch.optim.AdamW(model.parameters(), lr=LEARNING_RATE)
    criterion = nn.BCEWithLogitsLoss()
    best_macro_f1 = 0.0 
    MODEL_PATH_RUN = MODEL_PATH.replace(".pt", f"_{run_id}.pt")

    for epoch in range(EPOCHS):
        model.train()
        total_loss = 0
        for batch in tqdm(train_loader, desc=f"[Model {run_id}] Epoch {epoch+1}"):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].to(device)

            optimizer.zero_grad()
            outputs = model(input_ids, attention_mask)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"\n[Model {run_id}] Epoch {epoch+1}: Loss = {total_loss / len(train_loader):.4f}")

        val_probs = predict_proba(model, val_loader, device)
        threshold = evaluate_threshold_sweep(y_val, val_probs)
        y_val_pred = (val_probs > threshold).astype(int)
        macro_f1 = f1_score(y_val, y_val_pred, average="macro", zero_division=0)
        print(f"[Model {run_id}] Validation Macro F1 (Epoch {epoch+1}): {macro_f1:.4f}")

        if macro_f1 > best_macro_f1:
            best_macro_f1 = macro_f1
            torch.save(model.state_dict(), MODEL_PATH_RUN)
            print(f" Saved best Model {run_id} (Epoch {epoch+1}) to {MODEL_PATH_RUN}")



=== Training Model 1/3 ===
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
[Model 1] Epoch 1: 100%|██████████| 192/192 [01:19<00:00,  2.41it/s]

[Model 1] Epoch 1: Loss = 0.2728
Threshold sweep results:
Thresh 0.10 | Macro F1: 0.239 | Micro F1: 0.394 | Exact Match: 0.000
Thresh 0.15 | Macro F1: 0.236 | Micro F1: 0.462 | Exact Match: 0.013
Thresh 0.20 | Macro F1: 0.206 | Micro F1: 0.506 | Exact Match: 0.150
Thresh 0.25 | Macro F1: 0.178 | Micro F1: 0.483 | Exact Match: 0.197
Thresh 0.30 | Macro F1: 0.125 | Micro F1: 0.387 | Exact Match: 0.157
Thresh 0.35 | Macro F1: 0.096 | Micro F1: 0.342 | Exact Match: 0.142
Thresh 0.40 | Macro F1: 0.078 | Micro F1: 0.287 | Exact Match: 0.129
Thresh 0.45 | Macro F1: 0.067 | Micro F1: 0.235 | Exact Match: 0.123
T

In [90]:
def evaluate(loader, df_source, mlb, label="TEST", threshold=0.25): 
    """
    Ensemble evaluation using 3 trained models with soft voting.
    """
    model_paths = [
        MODEL_PATH.replace(".pt", f"_{i}.pt") for i in range(1, 4)
    ]
    
    models = []
    for path in model_paths:
        m = TransformerClassifier(MODEL_NAME, num_classes).to(device)
        m.load_state_dict(torch.load(path, map_location=device))
        m.eval()
        models.append(m)

    y_true, y_pred, domains = [], [], []

    with torch.no_grad():
        for i, batch in enumerate(tqdm(loader, desc=f"Ensembling {label}")):
            input_ids = batch["input_ids"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            labels = batch["labels"].cpu().numpy()

            logits_sum = torch.zeros((input_ids.size(0), num_classes)).to(device)
            for model in models:
                logits_sum += model(input_ids, attention_mask)
            probs = torch.sigmoid(logits_sum / len(models)).cpu().numpy()

            y_pred.extend(probs)
            y_true.extend(labels)

            start = i * loader.batch_size
            end = start + len(labels)
            domains.extend(df_source["Domain"].iloc[start:end].tolist())

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)
    domains = np.array(domains)

    y_pred_bin = (y_pred > threshold).astype(int)

    # Filter unused labels
    mask = (y_true.sum(axis=0) + y_pred_bin.sum(axis=0)) > 0
    y_true = y_true[:, mask]
    y_pred_bin = y_pred_bin[:, mask]
    filtered_labels = np.array(mlb.classes_)[mask]

    macro = f1_score(y_true, y_pred_bin, average="macro", zero_division=0)
    micro = f1_score(y_true, y_pred_bin, average="micro", zero_division=0)
    exact = (y_pred_bin == y_true).all(axis=1).mean()

    print(f"\n {label} (Fixed Threshold={threshold:.2f}):")
    print(f"Macro F1: {macro:.3f}")
    print(f"Micro F1: {micro:.3f}")
    print(f"Exact Match: {exact:.3f}")

    print("\n----------------------------")
    print("Per-Domain Breakdown")
    print("----------------------------")
    for domain in np.unique(domains):
        idx = np.where(domains == domain)[0]
        y_true_d = y_true[idx]
        y_pred_d = y_pred_bin[idx]

        macro_d = f1_score(y_true_d, y_pred_d, average="macro", zero_division=0)
        micro_d = f1_score(y_true_d, y_pred_d, average="micro", zero_division=0)
        exact_d = (y_pred_d == y_true_d).all(axis=1).mean()

        print(f"\n Domain: {domain}")
        print(f"Macro F1: {macro_d:.3f}")
        print(f"Micro F1: {micro_d:.3f}")
        print(f"Exact Match: {exact_d:.3f}")

    return {
        "macro": macro,
        "micro": micro,
        "exact": exact,
        "threshold": threshold,
        "labels_used": filtered_labels.tolist()
    }


In [92]:
results = evaluate_and_compare_fixed_thresh(
    val_loader, df_val,
    test_loader, df_test,
    mlb
)



Validation (Fixed Threshold)
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.weight', 'roberta.pooler.dense.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
Ensembling VALIDATION: 100%|██████████| 48/48 [00:20<00:00,  2.29it/s]

 VALIDATION (Fixed Thresh

# Post Hoc Interpretation -- Not Working -- Never mind for now

In [22]:
import shap
from captum.attr import IntegratedGradients
from transformers import AutoTokenizer
import torch.nn.functional as F



In [24]:
# ==========================
# LOAD BEST MODEL
# ==========================

model = TransformerClassifier(MODEL_NAME, num_classes)
model.load_state_dict(torch.load(MODEL_PATH))
model.to(device)
model.eval()


TransformerClassifier(
  (encoder): DebertaV2Model(
    (embeddings): DebertaV2Embeddings(
      (word_embeddings): Embedding(128100, 768, padding_idx=0)
      (LayerNorm): LayerNorm((768,), eps=1e-07, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): DebertaV2Encoder(
      (layer): ModuleList(
        (0-11): 12 x DebertaV2Layer(
          (attention): DebertaV2Attention(
            (self): DisentangledSelfAttention(
              (query_proj): Linear(in_features=768, out_features=768, bias=True)
              (key_proj): Linear(in_features=768, out_features=768, bias=True)
              (value_proj): Linear(in_features=768, out_features=768, bias=True)
              (pos_dropout): Dropout(p=0.1, inplace=False)
              (dropout): Dropout(p=0.1, inplace=False)
            )
            (output): DebertaV2SelfOutput(
              (dense): Linear(in_features=768, out_features=768, bias=True)
              (LayerNorm): LayerNorm((768,), 

In [28]:
# ==========================
# SHAP WRAPPER + EXPLAINER
# ==========================

import shap
import matplotlib.pyplot as plt
import numpy as np

class TextClassifierWrapper:
    def __init__(self, model, tokenizer, device):
        self.model = model
        self.tokenizer = tokenizer
        self.device = device

    def __call__(self, texts):
        texts = [str(t) for t in texts]
        encodings = self.tokenizer(
            texts,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=MAX_LEN
        )
        input_ids = encodings["input_ids"].to(self.device)
        attention_mask = encodings["attention_mask"].to(self.device)
        with torch.no_grad():
            logits = self.model(input_ids=input_ids, attention_mask=attention_mask)
            probs = torch.sigmoid(logits)
        return probs.cpu().numpy()

wrapped_model = TextClassifierWrapper(model, tokenizer, device)
masker = shap.maskers.Text(tokenizer)
explainer = shap.Explainer(wrapped_model, masker)


In [30]:
# ==========================
# GLOBAL SHAP SUMMARY PLOT
# ==========================

def shap_global_summary(texts, filename="shap_global_summary.png", top_n=20):
    shap_values = explainer(texts)

    all_scores = np.abs(shap_values.values)  # (samples, tokens, outputs)
    all_tokens = shap_values.data

    if all_scores.ndim == 3:
        all_scores = all_scores[:, :, 0]  # collapse output label dim

    token_contributions = {}
    for doc_tokens, doc_scores in zip(all_tokens, all_scores):
        for token, score in zip(doc_tokens, doc_scores):
            token = str(token)
            token_contributions[token] = token_contributions.get(token, []) + [abs(score)]

    token_avg_scores = {tok: np.mean(vals) for tok, vals in token_contributions.items()}
    top_items = sorted(token_avg_scores.items(), key=lambda x: x[1], reverse=True)[:top_n]
    top_tokens, top_scores = zip(*top_items)

    plt.figure(figsize=(10, 6))
    y_pos = np.arange(len(top_tokens))
    plt.barh(y_pos, top_scores, align='center')
    plt.yticks(y_pos, top_tokens)
    plt.xlabel('Mean |SHAP Value|')
    plt.title(f'Top {top_n} Most Influential Tokens Globally')
    plt.gca().invert_yaxis()
    plt.tight_layout()
    plt.savefig(filename)
    plt.close()


In [32]:
# ==========================
# RUN GLOBAL ANALYSIS
# ==========================

sample_texts = df_test[TEXT_COL].tolist()[:10]
shap_global_summary(sample_texts, filename="shap_global_summary.png", top_n=25)


PartitionExplainer explainer:  10%|█         | 1/10 [00:00<?, ?it/s]
  0%|          | 0/498 [00:00<?, ?it/s][A
 62%|██████▏   | 308/498 [00:00<00:00, 2946.00it/s][A
PartitionExplainer explainer:  30%|███       | 3/10 [00:23<00:30,  4.34s/it]
  0%|          | 0/498 [00:00<?, ?it/s][A
 24%|██▍       | 122/498 [00:00<00:00, 433.03it/s][A
 34%|███▍      | 170/498 [00:02<00:06, 54.17it/s] [A
 39%|███▉      | 194/498 [00:03<00:07, 41.11it/s][A
 41%|████▏     | 206/498 [00:04<00:07, 36.69it/s][A
 44%|████▍     | 218/498 [00:04<00:08, 32.98it/s][A
 45%|████▍     | 224/498 [00:05<00:08, 31.21it/s][A
 46%|████▌     | 230/498 [00:05<00:09, 29.49it/s][A
 47%|████▋     | 236/498 [00:05<00:09, 27.85it/s][A
 49%|████▊     | 242/498 [00:05<00:09, 26.32it/s][A
 50%|████▉     | 248/498 [00:06<00:09, 25.10it/s][A
 51%|█████     | 254/498 [00:06<00:10, 24.07it/s][A
 52%|█████▏    | 260/498 [00:06<00:10, 23.31it/s][A
 53%|█████▎    | 266/498 [00:07<00:10, 22.74it/s][A
 55%|█████▍    | 272/4

<a style='text-decoration:none;line-height:16px;display:flex;color:#5B5B62;padding:10px;justify-content:end;' href='https://deepnote.com?utm_source=created-in-deepnote-cell&projectId=49d39932-ba1f-4621-a036-ab99ade88496' target="_blank">
 </img>
Created in <span style='font-weight:600;margin-left:4px;'>Deepnote</span></a>