## Imports

In [11]:
import ast
import os

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
from datasets import load_dataset
from sklearn.metrics import f1_score, roc_auc_score
from sklearn.model_selection import train_test_split
from torch.nn import BCEWithLogitsLoss
from torch.optim import AdamW
from torch.utils.data import TensorDataset, DataLoader
from tqdm import tqdm
from transformers import AutoModel
from transformers import AutoTokenizer

## Process data
1. **SemEval-2021 Task 5: Toxic Spans Detection**, to detect the offensive part of the message
2. **Jigsaw**, to detect toxic messages

In [15]:
# Helper functions for the SemEval-2021 Task 5: Toxic Spans Detection dataset
def extract_text_and_positions(dataset_split):
    """Split a toxic-spans dataset split into post texts and toxic character offsets.

    Args:
        dataset_split: iterable of records with a ``"text_of_post"`` string and a
            ``"position"`` field holding a Python-literal list of character
            indices as text (e.g. ``"[8, 9, 10]"``).

    Returns:
        ``(X, y)`` where ``X`` is the list of post texts and ``y`` is a parallel
        list of sets of toxic character indices (sets give O(1) membership checks
        when labelling tokens). A record whose ``"position"`` is missing or cannot
        be parsed into an iterable gets an empty set.
    """
    X = []
    y = []  # each entry = set of toxic char indices
    for sample in dataset_split:
        X.append(sample["text_of_post"])
        try:
            toxic_positions = set(ast.literal_eval(sample.get("position")))
        except (ValueError, SyntaxError, TypeError):
            # Missing/malformed annotation -> treat the post as having no toxic span.
            # Narrowed from a bare `except:` so interrupts and genuine bugs still surface.
            toxic_positions = set()
        y.append(toxic_positions)
    return X, y

def encode_with_labels(texts, toxic_positions_list, tokenizer, max_length):
    """Tokenize texts and derive per-token toxic-span labels from character offsets.

    Each text is truncated/padded to ``max_length`` and tokenized with offset
    mappings. A token whose offset span is empty (``start == end``) gets label
    -100, the value ignored by ``nn.CrossEntropyLoss(ignore_index=-100)``.
    Every other token gets 1 if any character it covers is in the matching
    entry of ``toxic_positions_list``, otherwise 0.

    Args:
        texts: sequence of strings.
        toxic_positions_list: parallel sequence of containers of toxic char indices.
        tokenizer: callable tokenizer that supports ``return_offsets_mapping``.
        max_length: fixed sequence length for truncation and padding.

    Returns:
        ``(input_ids, attention_masks, labels)`` as ``torch.long`` tensors with
        one row per text.
    """
    ids_rows, mask_rows, label_rows = [], [], []
    for text, toxic_chars in zip(texts, toxic_positions_list):
        enc = tokenizer(
            text,
            truncation=True,
            padding="max_length",
            max_length=max_length,
            return_offsets_mapping=True,
        )
        # A token is toxic if it overlaps at least one toxic character.
        label_rows.append([
            -100 if start == end
            else int(any(ch in toxic_chars for ch in range(start, end)))
            for start, end in enc["offset_mapping"]
        ])
        ids_rows.append(enc["input_ids"])
        mask_rows.append(enc["attention_mask"])

    def as_long(rows):
        return torch.tensor(rows, dtype=torch.long)

    return as_long(ids_rows), as_long(mask_rows), as_long(label_rows)

In [14]:
# ---- SemEval-2021 Task 5: Toxic Spans Detection ----
# Each split becomes (post texts, sets of toxic character offsets).
dataset = load_dataset("heegyu/toxic-spans")
train, test = dataset["train"], dataset["test"]
X_train_span, y_train_span = extract_text_and_positions(train)
X_test_span, y_test_span = extract_text_and_positions(test)
print(len(X_train_span))
print(len(X_test_span))
# Hold out 15% of the official train split as a validation set.
(X_train_span, X_val_span,
 y_train_span, y_val_span) = train_test_split(
    X_train_span, y_train_span, test_size=0.15, random_state=2025
)

# ---- Jigsaw: train / val ----
subcategories = ["severe_toxic", "obscene", "threat", "insult", "identity_hate"]
train_data = pd.read_csv("jigsaw-toxic-comment-data/train.csv")
# Collapse the multi-label columns into one binary target: a comment counts as
# toxic when "toxic" or any subcategory column is 1 (row-wise max).
train_data["toxic"] = train_data[["toxic", *subcategories]].max(axis=1)
# Stratify on the binary label so train and val keep the same toxic ratio.
X_train_toxic, X_val_toxic, y_train_toxic, y_val_toxic = train_test_split(
    train_data["comment_text"],
    train_data["toxic"],
    test_size=0.15,
    stratify=train_data["toxic"],
    random_state=2025,
)

# ---- Jigsaw: test ----
test_text = pd.read_csv("jigsaw-toxic-comment-data/test.csv")
test_labels = pd.read_csv("jigsaw-toxic-comment-data/test_labels.csv")
# Rows labelled -1 are dropped from both frames. The same boolean mask is applied to
# both, which assumes the two CSVs list their rows in the same order — TODO confirm.
mask = test_labels["toxic"] != -1
test_text = test_text.loc[mask].reset_index(drop=True)
test_labels = test_labels.loc[mask].reset_index(drop=True)
test_labels["toxic"] = test_labels[["toxic", *subcategories]].max(axis=1)
X_test_toxic = test_text["comment_text"]
y_test_toxic = test_labels["toxic"]

10006
1000


In [6]:
# Inspect one raw record of the toxic-spans train split: "position" stores the toxic
# character offsets as a string and "text_of_post" stores the comment text.
train[0]

{'probability': '{(86, 92): 0.6666666666666666, (8, 13): 0.6666666666666666}',
 'position': '[8, 9, 10, 11, 12, 86, 87, 88, 89, 90, 91]',
 'text': "{'stupid': 0.6666666666666666, 'clown': 0.6666666666666666}",
 'type': "{'insult': 1.0}",
 'support': 3,
 'text_of_post': 'Another clown in favour of more tax in this country. Blows my mind people can be this stupid.',
 'position_probability': '{86: 0.6666666666666666, 87: 0.6666666666666666, 88: 0.6666666666666666, 89: 0.6666666666666666, 90: 0.6666666666666666, 91: 0.6666666666666666, 8: 0.6666666666666666, 9: 0.6666666666666666, 10: 0.6666666666666666, 11: 0.6666666666666666, 12: 0.6666666666666666}',
 'toxic': 1}

In [17]:
# train_test_split shuffles by default, so the first training text after the split
# is not the same post as train[0].
X_train_span[0]

'your an idiot life for 10k smh'

In [18]:
# Load the DeBERTa-v3-base tokenizer shared by both datasets. encode_with_labels asks
# for offset mappings, which Hugging Face only provides for *fast* tokenizers.
tokenizer = AutoTokenizer.from_pretrained("microsoft/deberta-v3-base")

# NOTE: despite the name, `avg_tokens` is the fixed length every sequence is
# truncated / padded to (it is passed as `max_length` everywhere below).
avg_tokens = 90
batch_size = 8
# Seeded generators make the shuffled training order reproducible across runs.
shuffle_seed = 2025

# ---- Toxic Spans Detection: token-level labels ----
# NOTE: this rebinds `test_labels`, which previously held the Jigsaw label DataFrame.
# y_test_toxic was already extracted from it, so nothing downstream is lost.
train_input_ids, train_attention_masks, train_labels = encode_with_labels(
    X_train_span, y_train_span, tokenizer, avg_tokens
)
val_input_ids, val_attention_masks, val_labels = encode_with_labels(
    X_val_span, y_val_span, tokenizer, avg_tokens
)
test_input_ids, test_attention_masks, test_labels = encode_with_labels(
    X_test_span, y_test_span, tokenizer, avg_tokens
)

train_span_dataset = TensorDataset(train_input_ids, train_attention_masks, train_labels)
val_span_dataset = TensorDataset(val_input_ids, val_attention_masks, val_labels)
test_span_dataset = TensorDataset(test_input_ids, test_attention_masks, test_labels)

train_span_loader = DataLoader(
    train_span_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(shuffle_seed),
)
val_span_loader = DataLoader(val_span_dataset, batch_size=batch_size, shuffle=False)
test_span_loader = DataLoader(test_span_dataset, batch_size=batch_size, shuffle=False)

# ---- Jigsaw: sequence-level binary labels ----
# pandas Series -> numpy arrays: texts as str, labels as float32 for BCEWithLogitsLoss.
X_train_toxic = np.array(X_train_toxic, dtype=str)
X_val_toxic = np.array(X_val_toxic, dtype=str)
X_test_toxic = np.array(X_test_toxic, dtype=str)

y_train_toxic = np.array(y_train_toxic, dtype=np.float32)
y_val_toxic = np.array(y_val_toxic, dtype=np.float32)
y_test_toxic = np.array(y_test_toxic, dtype=np.float32)

# One set of tokenizer settings for all three Jigsaw splits (same length as the span data).
jigsaw_tok_kwargs = dict(
    truncation=True,
    padding="max_length",
    max_length=avg_tokens,
    return_tensors="pt",
)
train_encodings = tokenizer(X_train_toxic.tolist(), **jigsaw_tok_kwargs)
val_encodings = tokenizer(X_val_toxic.tolist(), **jigsaw_tok_kwargs)
test_encodings = tokenizer(X_test_toxic.tolist(), **jigsaw_tok_kwargs)

y_train_toxic_tensor = torch.tensor(y_train_toxic, dtype=torch.float32)
y_val_toxic_tensor = torch.tensor(y_val_toxic, dtype=torch.float32)
y_test_toxic_tensor = torch.tensor(y_test_toxic, dtype=torch.float32)

train_toxic_dataset = TensorDataset(
    train_encodings["input_ids"], train_encodings["attention_mask"], y_train_toxic_tensor
)
val_toxic_dataset = TensorDataset(
    val_encodings["input_ids"], val_encodings["attention_mask"], y_val_toxic_tensor
)
test_toxic_dataset = TensorDataset(
    test_encodings["input_ids"], test_encodings["attention_mask"], y_test_toxic_tensor
)

train_toxic_loader = DataLoader(
    train_toxic_dataset,
    batch_size=batch_size,
    shuffle=True,
    generator=torch.Generator().manual_seed(shuffle_seed),
)
val_toxic_loader = DataLoader(val_toxic_dataset, batch_size=batch_size, shuffle=False)
test_toxic_loader = DataLoader(test_toxic_dataset, batch_size=batch_size, shuffle=False)

Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


## Model architecture and training/evaluation

In [None]:
# Multi-task model: one shared backbone with two heads, one for sequence-level
# classification and one for token-level span detection.
class ToxicityModel(nn.Module):
    """Transformer encoder with a sequence-toxicity head and a token-span head.

    * ``seq_head`` produces one logit per sequence, read from the first token's
      hidden state (intended for ``BCEWithLogitsLoss``).
    * ``tok_head`` produces two logits per token (non-toxic / toxic), intended
      for ``nn.CrossEntropyLoss``.

    Args:
        model_name: model id or path passed to ``AutoModel.from_pretrained``.
    """

    def __init__(self, model_name="microsoft/deberta-v3-base"):
        super().__init__()
        self.backbone = AutoModel.from_pretrained(model_name)
        hidden = self.backbone.config.hidden_size
        self.dropout = nn.Dropout(0.1)

        # Heads
        self.seq_head = nn.Linear(hidden, 1)   # [B, 1] sequence logit
        self.tok_head = nn.Linear(hidden, 2)   # [B, T, 2] per-token class logits

    def forward(self, input_ids, attention_mask, token_type_ids=None):
        """Return ``(seq_logits [B, 1], tok_logits [B, T, 2])``."""
        out = self.backbone(
            input_ids=input_ids,
            attention_mask=attention_mask,
            token_type_ids=token_type_ids,
            return_dict=True
        )
        last_hidden = self.dropout(out.last_hidden_state)   # [B, T, H]
        # First-token (CLS) pooling. Dropout was already applied to `last_hidden`;
        # applying it a second time would zero ~19% of the features during training.
        cls_pooled = last_hidden[:, 0]                      # [B, H]

        seq_logits = self.seq_head(cls_pooled)              # [B, 1]
        tok_logits = self.tok_head(last_hidden)             # [B, T, 2]
        return seq_logits, tok_logits


In [None]:
# Helper functions
def freeze_backbone(model, freeze=True):
    """Enable or disable gradient tracking for every parameter of ``model.backbone``.

    ``freeze=True`` stops the backbone from training (``requires_grad=False``);
    ``freeze=False`` makes it trainable again. Parameters outside ``model.backbone``
    (e.g. the heads) are left untouched.
    """
    trainable = not freeze
    for param in model.backbone.parameters():
        param.requires_grad_(trainable)

@torch.no_grad()
def bin_acc_from_logits(logits, labels, threshold=0.5):
    """Binary accuracy of thresholded sigmoid probabilities against 0/1 labels.

    Args:
        logits: raw scores; they are passed through a sigmoid.
        labels: 0/1 targets with the same number of elements as ``logits``
            (reshaped to ``logits``' shape before comparing).
        threshold: probability at or above which a prediction counts as 1.
            Defaults to 0.5 so two-argument calls work.

    Returns:
        Fraction of correct predictions as a Python float.
    """
    preds = (torch.sigmoid(logits) >= threshold).long()
    labs  = labels.view_as(preds).long()
    return (preds == labs).float().mean().item()

In [None]:
# Stage 1: train the whole model on Jigsaw for sequence-level toxicity detection.

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model  = ToxicityModel().to(device)

# Backbone and heads are all trainable in Stage 1.
freeze_backbone(model, freeze=False)

clf_criterion = BCEWithLogitsLoss()
optimizer = AdamW(model.parameters(), lr=2e-5)

num_epochs_stage1 = 4
# Fixed cut-off used only for the running accuracy metric; the decision threshold
# actually reported is chosen by the grid search after training.
acc_threshold = 0.5
# The model-loading cell reads this file; it is written once training finishes.
final_checkpoint_path = os.path.join("models", "toxicity_model_final.pth")
tr_losses, va_losses = [], []
tr_accs, va_accs     = [], []

for epoch in range(1, num_epochs_stage1 + 1):
    # ---- train ----
    model.train()
    total_loss = 0.0
    total_acc  = 0.0
    total_n    = 0

    for batch in train_toxic_loader:
        input_ids, attention_mask, labels = batch
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.to(device).float().unsqueeze(1)   # [B] -> [B, 1] to match seq_logits

        optimizer.zero_grad()
        seq_logits, _ = model(input_ids=input_ids, attention_mask=attention_mask)

        loss = clf_criterion(seq_logits, labels)
        loss.backward()
        optimizer.step()

        # Weight by batch size so the epoch mean stays exact when the last batch is smaller.
        bs = labels.size(0)
        total_loss += loss.item() * bs
        total_acc  += bin_acc_from_logits(seq_logits.detach(), labels, acc_threshold) * bs
        total_n    += bs

    tr_losses.append(total_loss / total_n)
    tr_accs.append(total_acc / total_n)

    print("Train epoch done")

    # ---- validate ----
    model.eval()
    val_loss = 0.0
    val_acc  = 0.0
    val_n    = 0
    with torch.no_grad():
        for batch in val_toxic_loader:
            input_ids, attention_mask, labels = batch
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            labels = labels.to(device).float().unsqueeze(1)

            seq_logits, _ = model(input_ids=input_ids, attention_mask=attention_mask)
            loss = clf_criterion(seq_logits, labels)

            bs = labels.size(0)
            val_loss += loss.item() * bs
            val_acc  += bin_acc_from_logits(seq_logits, labels, acc_threshold) * bs
            val_n    += bs

    va_losses.append(val_loss / val_n)
    va_accs.append(val_acc / val_n)

    print(f"[Epoch {epoch}] TrainLoss {tr_losses[-1]:.4f} | TrainAcc {tr_accs[-1]:.4f} | "
          f"ValLoss {va_losses[-1]:.4f} | ValAcc {va_accs[-1]:.4f}")

    torch.save(model.state_dict(), f"toxicity_model_epoch{epoch}.pth")

# Save the final Stage-1 weights where the model-loading cell expects them, so Stage 2
# can resume from a fresh kernel without copying files by hand.
os.makedirs(os.path.dirname(final_checkpoint_path), exist_ok=True)
torch.save(model.state_dict(), final_checkpoint_path)

# ---- Grid search on the validation set for the decision threshold ----
model.eval()
all_preds, all_labels = [], []
with torch.no_grad():
    for batch in tqdm(val_toxic_loader, desc="Collecting predictions for threshold search"):
        input_ids, attention_mask, labels = batch
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        labels = labels.cpu().numpy()

        seq_logits, _ = model(input_ids=input_ids, attention_mask=attention_mask)
        probs = torch.sigmoid(seq_logits).cpu().numpy()

        all_preds.extend(probs)
        all_labels.extend(labels)

all_preds = np.array(all_preds).flatten()
all_labels = np.array(all_labels).flatten()

# Pick the threshold in [0.1, 0.9] (step 0.01) that maximises validation F1.
thresholds = np.linspace(0.1, 0.9, 81)
best_f1, best_thr = 0, 0.5
for thr in thresholds:
    f1 = f1_score(all_labels, (all_preds >= thr).astype(int))
    if f1 > best_f1:
        best_f1, best_thr = f1, thr

roc_auc = roc_auc_score(all_labels, all_preds)

print(f"\nBest threshold: {best_thr:.2f}")
print(f"Validation F1 at best threshold: {best_f1:.4f}")
print(f"Validation ROC-AUC: {roc_auc:.4f}")

# ---- Plot Stage 1 curves (x-axis is the 1-based epoch number) ----
epochs = range(1, num_epochs_stage1 + 1)
fig, (ax_loss, ax_acc) = plt.subplots(1, 2, figsize=(11, 4))
ax_loss.plot(epochs, tr_losses, label="Train")
ax_loss.plot(epochs, va_losses, label="Val")
ax_loss.set(title="Stage 1: Loss", xlabel="Epoch", ylabel="BCE loss")
ax_loss.legend()
ax_acc.plot(epochs, tr_accs, label="Train")
ax_acc.plot(epochs, va_accs, label="Val")
ax_acc.set(title="Stage 1: Accuracy", xlabel="Epoch", ylabel="Accuracy")
ax_acc.legend()
plt.show()


In [None]:
# Reload the Stage-1 weights so Stage 2 can start without re-running Stage 1.
# `device` is (re)defined here because this cell may run in a kernel where Stage 1 was skipped.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = ToxicityModel(model_name="microsoft/deberta-v3-base")
checkpoint_path = "models/toxicity_model_final.pth"
# Load onto the CPU first so a GPU-saved checkpoint also opens on CPU-only machines.
# NOTE(review): torch.load unpickles the file; only load checkpoints you trust.
checkpoint = torch.load(checkpoint_path, map_location="cpu")
model.load_state_dict(checkpoint)
# Training batches are moved to `device`, so the model must live there as well.
model = model.to(device)

In [None]:
# Stage 2: fine-tune further on the SemEval toxic-spans data. It is much smaller than
# Jigsaw, so it is used as a second stage rather than mixed into Stage 1.

# Every batch below is moved to `device`, so the model has to be there too (the
# loading cell may have built it on the CPU). Done before creating the optimizers.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

# Stage 2a: backbone frozen, only the non-backbone parameters stay trainable.
freeze_backbone(model, freeze=True)

# Tokens labelled -100 by encode_with_labels (empty offset spans) are ignored by the loss.
tok_criterion = nn.CrossEntropyLoss(ignore_index=-100)

# Stage 2a optimizer: higher LR over the trainable head parameters. Only tok_head
# actually receives gradients, because the loss below uses tok_logits alone.
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=5e-4, weight_decay=0.01)

num_epochs_stage2 = 8
unfreeze_epoch = 4  # first epoch (1-based) trained with the backbone unfrozen (Stage 2b)

tr_losses, va_losses = [], []

for epoch in range(1, num_epochs_stage2 + 1):
    # ---- Unfreeze backbone at Stage 2b ----
    if epoch == unfreeze_epoch:
        freeze_backbone(model, freeze=False)
        # Stage 2b: fresh optimizer over all parameters with a lower LR for full fine-tuning.
        optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5, weight_decay=0.01)

    # ---- train ----
    model.train()
    total_loss = 0.0
    total_n    = 0

    for batch in train_span_loader:
        input_ids, attention_mask, tok_labels = batch
        input_ids = input_ids.to(device)
        attention_mask = attention_mask.to(device)
        tok_labels = tok_labels.to(device)   # [B, T], values in {0,1} or -100 for ignore

        optimizer.zero_grad()
        _, tok_logits = model(input_ids=input_ids, attention_mask=attention_mask)
        # tok_logits: [B, T, 2] -> flatten tokens for CrossEntropyLoss

        loss = tok_criterion(tok_logits.view(-1, 2), tok_labels.view(-1))
        loss.backward()

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()

        # Batch-size-weighted mean of the per-batch (token-averaged) losses.
        bs = input_ids.size(0)
        total_loss += loss.item() * bs
        total_n    += bs

    tr_losses.append(total_loss / total_n)
    print(f"Train epoch {epoch} done")

    # ---- validate ----
    model.eval()
    val_loss = 0.0
    val_n    = 0
    with torch.no_grad():
        for batch in val_span_loader:
            input_ids, attention_mask, tok_labels = batch
            input_ids = input_ids.to(device)
            attention_mask = attention_mask.to(device)
            tok_labels = tok_labels.to(device)

            _, tok_logits = model(input_ids=input_ids, attention_mask=attention_mask)
            loss = tok_criterion(tok_logits.view(-1, 2), tok_labels.view(-1))

            bs = input_ids.size(0)
            val_loss += loss.item() * bs
            val_n    += bs

    va_losses.append(val_loss / val_n)

    print(f"[Epoch {epoch}] TrainLoss {tr_losses[-1]:.4f} | ValLoss {va_losses[-1]:.4f}")

    # Save checkpoint
    torch.save(model.state_dict(), f"toxicity_span_epoch{epoch}.pth")

# ---- Plot Stage 2 curves (x-axis is the 1-based epoch number) ----
epochs = range(1, num_epochs_stage2 + 1)
fig, ax = plt.subplots(figsize=(8, 4))
ax.plot(epochs, tr_losses, label="Train")
ax.plot(epochs, va_losses, label="Val")
ax.axvline(x=unfreeze_epoch, color='r', linestyle='--', alpha=0.7, label='Unfreeze')
ax.set(title="Stage 2: Span Loss", xlabel="Epoch", ylabel="Loss")
ax.legend()
plt.show()


## Model testing

In [None]:
# Finally, we evaluate the model on the toxicity-detection and span-detection test sets
# Toxicity

# Span