In [None]:
!pip install torchmetrics --quiet

In [None]:
import os
import csv
import torch
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm.notebook import tqdm
from transformers import DistilBertModel, DistilBertTokenizer
from torch.utils.data import Dataset, DataLoader
from sklearn.feature_extraction.text import CountVectorizer

In [None]:
# Run-wide settings: location of the CSV files, compute device, and the torch
# RNG seed (torch.manual_seed seeds the generators of all devices).
data_dir = '/content/kaggle_data'
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
torch.manual_seed(0)

<torch._C.Generator at 0x7f91404ba6b0>

In [None]:
# Pretrained uncased DistilBERT: the tokenizer ImprovedDataset calls and the
# encoder BERTFairClassifier wraps, placed on `device`.
pretrained_name = 'distilbert-base-uncased'
tokenizer = DistilBertTokenizer.from_pretrained(pretrained_name)
bert_model = DistilBertModel.from_pretrained(pretrained_name)
bert_model = bert_model.to(device)

In [None]:
class ImprovedDataset(Dataset):
    """Tokenising dataset over ``{mode}_x.csv`` (plus ``{mode}_y.csv`` unless mode is 'test').

    Text is tokenised on the fly, padded/truncated to ``max_length``.
    Test items are ``(input_ids, attention_mask, idx)``; train/val items are
    ``(input_ids, attention_mask, y, groups, idx)``, where ``y`` is the
    second-to-last label column and ``groups`` the first eight label columns.

    Args:
        data_dir: directory holding the CSV files.
        mode: split name used to build the file names ('train', 'val', 'test').
        max_length: token sequence length after padding/truncation.
        tokenizer: callable with the Hugging Face tokenizer call signature;
            when None, the notebook-global ``tokenizer`` is used.
    """

    def __init__(self, data_dir, mode, max_length=128, tokenizer=None):
        super().__init__()
        self.mode = mode
        self.max_length = max_length
        self.tokenizer = tokenizer
        self.data = pd.read_csv(os.path.join(data_dir, f'{mode}_x.csv'), index_col=0)
        # Missing comments become empty strings so the tokenizer always gets text.
        self.data['string'] = self.data['string'].fillna("")

        if self.mode != 'test':
            # NOTE(review): read without index_col (unlike the x file); the
            # positional columns used in __getitem__ (:8 and -2) assume this layout.
            self.label = pd.read_csv(os.path.join(data_dir, f'{mode}_y.csv'))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Resolved per call so the global tokenizer may be defined after construction.
        tok = self.tokenizer if self.tokenizer is not None else tokenizer
        # Read the column that __init__ cleaned, by name rather than position.
        text = self.data['string'].iloc[idx]
        inputs = tok(
            text,
            max_length=self.max_length,
            padding='max_length',
            truncation=True,
            return_tensors='pt'
        )
        # return_tensors='pt' adds a leading batch dim of 1; drop only that dim.
        input_ids = inputs['input_ids'].squeeze(0)
        attention_mask = inputs['attention_mask'].squeeze(0)

        if self.mode == 'test':
            return input_ids, attention_mask, idx

        y = torch.tensor(self.label.iloc[idx, -2]).float()
        # Convert via to_numpy(): handing the Series to torch.tensor makes it index
        # the Series by integer key, which pandas deprecates (likely the warning
        # raised at this line).
        groups = torch.tensor(self.label.iloc[idx, :8].to_numpy(dtype=np.float32))
        return input_ids, attention_mask, y, groups, idx

In [None]:
class GradientReversal(nn.Module):
    def forward(self, x):
        return x
    def backward(self, grad_output):
        return -grad_output

class BERTFairClassifier(nn.Module):
    def __init__(self, num_groups=8, hidden_dim=128):
        super().__init__()
        self.bert = bert_model
        self.grl = GradientReversal()

        # Main toxicity classifier
        self.toxicity_head = nn.Sequential(
            nn.Linear(768, hidden_dim),
            nn.ReLU(),
            nn.Linear(hidden_dim, 1)
        )

        # Adversarial branch
        self.adversary = nn.Sequential(
            nn.Linear(768, 64),
            nn.ReLU(),
            nn.Linear(64, num_groups)
        )

    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = outputs.last_hidden_state[:, 0, :]

        # Toxicity prediction
        toxicity = torch.sigmoid(self.toxicity_head(pooled_output)).squeeze()

        # Adversarial prediction
        rev_output = self.grl(pooled_output)
        group_pred = torch.sigmoid(self.adversary(rev_output))

        return toxicity, group_pred

In [None]:
class GroupDRO:
    def __init__(self, num_groups=8, eta=0.1):
        self.eta = eta
        self.group_weights = torch.ones(num_groups, device=device)/num_groups
        self.group_losses = torch.zeros(num_groups, device=device)

    def update(self, group_losses):
        self.group_weights *= torch.exp(self.eta * group_losses)
        self.group_weights /= self.group_weights.sum()

    def get_weighted_loss(self, losses, group_membership):
        group_losses = (group_membership.T @ losses) / (group_membership.sum(dim=0) + 1e-8)
        self.update(group_losses.detach())
        return torch.sum(self.group_weights * group_losses)

In [None]:
def collate_fn(batch):
    if len(batch[0]) == 3:  # Test batch
        input_ids, attention_masks, indices = zip(*batch)
        return (
            torch.stack(input_ids).to(device),
            torch.stack(attention_masks).to(device),
            torch.tensor(indices).to(device)
        )
    else:  # Train/val batch
        input_ids, attention_masks, ys, groups, indices = zip(*batch)
        return (
            torch.stack(input_ids).to(device),
            torch.stack(attention_masks).to(device),
            torch.tensor(ys).float().to(device),
            torch.stack(groups).float().to(device),
            torch.tensor(indices).to(device)
        )

In [None]:
def train_epoch(model, optimizer, dro, train_loader):
    model.train()
    total_loss = 0
    dro = GroupDRO(eta=0.2)

    for batch in tqdm(train_loader, desc="Training"):
        input_ids, attention_mask, y, groups, _ = batch
        optimizer.zero_grad()

        toxicity_pred, group_pred = model(input_ids, attention_mask)

        # Compute losses
        toxicity_loss = F.binary_cross_entropy(toxicity_pred, y)
        adversary_loss = F.binary_cross_entropy(group_pred, groups)

        # DRO component
        with torch.no_grad():
            per_sample_loss = F.binary_cross_entropy(toxicity_pred, y, reduction='none')
        dro_loss = dro.get_weighted_loss(per_sample_loss, groups)

        # Combined loss
        loss = dro_loss + toxicity_loss - 0.5 * adversary_loss

        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    return total_loss / len(train_loader)

In [None]:
def evaluate(model, dataloader):
    """Predict toxicity probabilities for every example in ``dataloader``.

    Accepts both batch layouts produced by ``collate_fn``: test batches
    (input_ids, attention_mask, idx) and train/val batches
    (input_ids, attention_mask, y, groups, idx). Labels are ignored.

    Returns:
        DataFrame with columns ``index`` (dataset position of the example) and
        ``pred`` (predicted probability), in loader order.
    """
    model.eval()
    predictions = []
    indices = []

    with torch.no_grad():
        for batch in tqdm(dataloader, desc="Evaluating"):
            # Inputs are always first, the dataset index always last.
            input_ids, attention_mask, idx = batch[0], batch[1], batch[-1]
            toxicity_pred, _ = model(input_ids, attention_mask)
            # reshape(-1): a model that squeeze()s its output returns a 0-d
            # tensor for a one-row batch, which cannot be extended into a list.
            predictions.extend(toxicity_pred.reshape(-1).cpu().numpy())
            indices.extend(idx.reshape(-1).cpu().numpy())

    return pd.DataFrame({'index': indices, 'pred': predictions})

In [None]:
def worst_group_accuracy(prediction, y, threshold=0.4):
    """Return the lowest accuracy over all (identity column, value) subgroups.

    For each of the eight identity columns and each value in {0, 1}, the
    subgroup accuracy is the fraction of its rows whose ``y`` label equals
    ``pred > threshold``. Subgroups with no rows are skipped.

    Args:
        prediction: DataFrame with a ``pred`` column of probabilities. If it
            also has an ``index`` column (as returned by ``evaluate``), that
            column gives each prediction's row position in ``y``; otherwise
            rows are matched on ``prediction``'s own index labels.
        y: label DataFrame with the identity columns and a ``y`` column.
            It is not modified.
        threshold: probability above which a prediction counts as positive.

    Returns:
        The minimum subgroup accuracy, or nan if every subgroup is empty.
    """
    categories = ['male', 'female', 'LGBTQ', 'christian', 'muslim', 'other_religions', 'black', 'white']
    # Work on a copy: adding a 'pred' column to the caller's frame would shift
    # the positional label columns (e.g. iloc[:, -2]) read from the same table.
    scored = y.copy()
    if 'index' in prediction.columns:
        # evaluate() reports dataset positions, so place predictions by position.
        aligned = pd.Series(np.nan, index=scored.index)
        aligned.iloc[prediction['index'].to_numpy()] = prediction['pred'].to_numpy()
        scored['pred'] = aligned
    else:
        scored.loc[prediction.index, 'pred'] = prediction['pred']

    accuracies = []
    for category in categories:
        for label in (0, 1):
            group = scored.loc[scored[category] == label]
            if group.empty:
                continue  # an empty subgroup has no accuracy (mean would be nan)
            predicted = (group['pred'] > threshold).astype(int)
            accuracies.append((group['y'] == predicted).mean())
    return np.min(accuracies) if accuracies else float('nan')

In [None]:
# Datasets and loaders. Training batches are shuffled; validation keeps file
# order so prediction positions line up with the label rows.
BATCH_SIZE = 16

train_dataset = ImprovedDataset(data_dir, 'train')
val_dataset = ImprovedDataset(data_dir, 'val')

train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collate_fn)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False, collate_fn=collate_fn)

In [None]:
model = BERTFairClassifier()
model = model.to(device)

# Per-module learning rates: the pretrained encoder is fine-tuned gently, the
# two freshly initialised heads learn faster. Weight decay applies to all groups.
param_groups = [
    dict(params=model.bert.parameters(), lr=2e-5),
    dict(params=model.toxicity_head.parameters(), lr=1e-3),
    dict(params=model.adversary.parameters(), lr=1e-3),
]
optimizer = optim.AdamW(param_groups, weight_decay=0.01)

In [None]:
NUM_EPOCHS = 3
CHECKPOINT_PATH = 'best_bert_model.pth'

# Start at -inf so the first epoch always writes a checkpoint (the test cell
# loads it); starting at 0 leaves no file if every epoch scores 0.
best_wga = float('-inf')
# A single GroupDRO for the whole run, so its group weights can carry across
# epochs when train_epoch uses the instance it is given (eta=0.2 as in train_epoch).
dro = GroupDRO(eta=0.2)

for epoch in range(NUM_EPOCHS):
    print(f"\nEpoch {epoch+1}/{NUM_EPOCHS}")
    train_loss = train_epoch(model, optimizer, dro, train_loader)

    # Validation: model selection by worst-group accuracy.
    val_preds = evaluate(model, val_loader)
    wga = worst_group_accuracy(val_preds, val_dataset.label)
    print(f"Train Loss: {train_loss:.4f} | Val WGA: {wga:.4f}")

    if wga > best_wga:
        best_wga = wga
        torch.save(model.state_dict(), CHECKPOINT_PATH)

In [None]:
# Re-score the in-memory (last-epoch) weights on the validation split;
# train_loss is the value from the final training epoch.
val_preds = evaluate(model, val_loader)
wga = worst_group_accuracy(val_preds, val_dataset.label)
print("Train Loss: {:.4f} | Val WGA: {:.4f}".format(train_loss, wga))

Evaluating:   0%|          | 0/2824 [00:00<?, ?it/s]

  groups = torch.tensor(self.label.iloc[idx, :8]).float()


In [None]:
# Restore the best checkpoint. map_location lets a GPU-saved file load on a
# CPU-only runtime; weights_only avoids unpickling arbitrary objects.
state_dict = torch.load('best_bert_model.pth', map_location=device, weights_only=True)
model.load_state_dict(state_dict)

test_dataset = ImprovedDataset(data_dir, 'test')
test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False, collate_fn=collate_fn)

test_preds = evaluate(model, test_loader)
# Binarise with the cut-off worst_group_accuracy uses on validation (0.4), so
# the submitted labels follow the criterion the checkpoint was selected on.
SUBMISSION_THRESHOLD = 0.4
test_preds['pred'] = (test_preds['pred'] > SUBMISSION_THRESHOLD).astype(int)
# NOTE(review): 'ID' is the row position within test_x.csv, not that file's
# index column — confirm this matches what the submission format expects.
test_preds.rename(columns={'index': 'ID'}).to_csv('bert_submission.csv', index=False)

Evaluating:   0%|          | 0/8362 [00:00<?, ?it/s]