In [None]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from transformers import DistilBertTokenizer, DistilBertModel
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, classification_report
from tqdm import tqdm
import pandas as pd
import os

# Load datasets
# Both frames are kept as module-level globals: the 'string' column of
# train_x and the 'y' column of train_y are read further down when the
# training texts and labels are built, and train_y is also referenced from
# inside train_model.
train_x = pd.read_csv('kaggle_data/train_x.csv')
train_y = pd.read_csv('kaggle_data/train_y.csv')

def predict_and_save(test_path, model_path, output_path):
    """Score a test CSV with the trained classifier and write the predictions.

    Relies on the module-level ``model``, ``tokenizer`` and ``max_len``
    objects, so it must be called after they have been created.

    Args:
        test_path: CSV file containing a ``string`` column (the texts to
            score) and an ``index`` column (copied to the output).
        model_path: Checkpoint holding a ``state_dict`` saved with
            ``torch.save``.
        output_path: Destination CSV; receives the ``index`` and
            ``prediction`` (0/1) columns.
    """
    test_x = pd.read_csv(test_path)
    test_texts = test_x['string'].tolist()

    # The dataset class requires labels; the zeros are placeholders that are
    # never read during prediction.
    test_dataset = ToxicCommentDataset(test_texts, [0] * len(test_texts), tokenizer, max_len)
    test_loader = DataLoader(test_dataset, batch_size=16, shuffle=False)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = device.type == "cuda"

    # map_location lets a checkpoint saved on GPU be restored on a CPU-only
    # machine; weights_only restricts unpickling to tensors/containers.
    state_dict = torch.load(model_path, map_location=device, weights_only=True)
    model.load_state_dict(state_dict)
    # nn.Module.to() moves parameters in place. Deliberately NOT rebinding
    # `model` here: an assignment would make the name local to this function
    # and the load_state_dict call above would raise UnboundLocalError.
    model.to(device)
    model.eval()

    # Predict
    test_preds = []
    with torch.no_grad():
        for batch in tqdm(test_loader, desc="Predicting", leave=False):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)

            # Same autocast API as the training loop; disabled on CPU.
            with torch.amp.autocast('cuda', enabled=use_amp):
                logits = model(input_ids, attention_mask)

            # (batch, 1) -> (batch,), then threshold the probabilities.
            probs = torch.sigmoid(logits.squeeze(-1))
            test_preds.extend((probs > 0.5).long().cpu().tolist())

    # Save predictions
    test_x['prediction'] = test_preds
    test_x[['index', 'prediction']].to_csv(output_path, index=False)
    print(f"Predictions saved to {output_path}")

# Define the Focal Loss class
class FocalLoss(nn.Module):
    """Binary focal loss computed from raw logits.

    For every element the loss is ``alpha * (1 - pt) ** gamma * bce``, where
    ``bce`` is the binary cross-entropy with logits and ``pt = exp(-bce)``.

    Args:
        alpha: Constant scaling factor applied to every element.
        gamma: Exponent of the ``(1 - pt)`` modulating term.
        reduction: ``'mean'`` or ``'sum'`` to reduce to a scalar; any other
            value returns the unreduced per-element loss.
    """

    def __init__(self, alpha=1, gamma=2, reduction='mean'):
        super().__init__()
        self.reduction = reduction
        self.gamma = gamma
        self.alpha = alpha

    def forward(self, inputs, targets):
        """Return the focal loss of logits ``inputs`` against ``targets``."""
        per_element_bce = nn.functional.binary_cross_entropy_with_logits(
            inputs, targets, reduction='none'
        )
        # exp(-bce) recovers the probability assigned to the true class
        # without ever taking the log of a zero probability.
        true_class_prob = torch.exp(-per_element_bce)
        focal = self.alpha * (1 - true_class_prob) ** self.gamma * per_element_bce

        if self.reduction == 'sum':
            return focal.sum()
        if self.reduction == 'mean':
            return focal.mean()
        return focal

# Define a custom Dataset
class ToxicCommentDataset(Dataset):
    """Map-style dataset pairing raw texts with labels.

    Each item is tokenized on access and returned as a dict with the keys
    ``input_ids``, ``attention_mask`` and ``labels``.

    Args:
        texts: Indexable collection of input texts.
        labels: Indexable collection of labels, aligned with ``texts``.
        tokenizer: Callable invoked as ``tokenizer(text, max_length=...,
            padding='max_length', truncation=True, return_tensors='pt')``.
        max_len: Value forwarded to the tokenizer as ``max_length``.
    """

    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts, self.labels = texts, labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        """Number of examples, taken from ``texts``."""
        return len(self.texts)

    def __getitem__(self, idx):
        """Tokenize example ``idx`` and bundle it with its float label."""
        sample_text, sample_label = self.texts[idx], self.labels[idx]

        tokenizer_kwargs = {
            'max_length': self.max_len,
            'padding': 'max_length',
            'truncation': True,
            'return_tensors': 'pt',
        }
        encoded = self.tokenizer(sample_text, **tokenizer_kwargs)

        # squeeze(0) drops the leading dimension of the returned tensors so
        # the DataLoader can stack items into a batch.
        item = {key: encoded[key].squeeze(0) for key in ('input_ids', 'attention_mask')}
        item['labels'] = torch.tensor(sample_label, dtype=torch.float)
        return item

# Define the DistilBERT-based model
class ToxicCommentClassifier(nn.Module):
    """DistilBERT encoder followed by dropout and a linear output layer.

    Args:
        n_classes: Number of output units of the final linear layer.
    """

    def __init__(self, n_classes):
        super().__init__()
        # Attribute names are part of the checkpoint format (state_dict keys).
        self.distilbert = DistilBertModel.from_pretrained('distilbert-base-uncased')
        self.dropout = nn.Dropout(0.3)
        encoder_width = self.distilbert.config.hidden_size
        self.out = nn.Linear(encoder_width, n_classes)

    def forward(self, input_ids, attention_mask):
        """Return the output-layer activations for a batch of token ids."""
        encoder_output = self.distilbert(attention_mask=attention_mask, input_ids=input_ids)
        # Position 0 of the sequence axis holds the [CLS] token's representation.
        cls_vector = encoder_output.last_hidden_state[:, 0]
        return self.out(self.dropout(cls_vector))

# Load data
# Inputs come from the 'string' column of train_x; targets come from the
# 'y' column of train_y.
texts = train_x['string'].tolist()
labels = train_y['y'].values

# Train-test split
# 80/20 split with a fixed seed so the validation set is the same on every run.
train_texts, val_texts, train_labels, val_labels = train_test_split(
    texts, labels, test_size=0.2, random_state=42
)

# Tokenizer and dataset
# `tokenizer` and `max_len` are also read as globals by predict_and_save.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')
max_len = 128

train_dataset = ToxicCommentDataset(train_texts, train_labels, tokenizer, max_len)
val_dataset = ToxicCommentDataset(val_texts, val_labels, tokenizer, max_len)

# Only the training loader shuffles; validation keeps dataset order.
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=16, shuffle=False)

# Model, loss, and optimizer
# A single output unit: the training loop squeezes it and feeds it to the
# logits-based FocalLoss together with the float labels.
model = ToxicCommentClassifier(n_classes=1)
loss_fn = FocalLoss()
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
# `scaler` is not passed as an argument; train_model reads it as a global.
scaler = torch.amp.GradScaler('cuda')  # For mixed precision training

# Define the worst group accuracy function
def worst_group_accuracy(targets, preds, groups):
    """Find the group on which the predictions are least accurate.

    Args:
        targets: 1-D true labels (list, NumPy array, pandas Series or CPU
            tensor).
        preds: 1-D predicted labels, aligned with ``targets``.
        groups: DataFrame with one row per example and one column per group;
            a value of 1 marks membership of that example in the group.

    Returns:
        ``(group_name, accuracy)`` for the group with the lowest accuracy, or
        ``(None, None)`` when no group contains any example.

    Raises:
        ValueError: If ``targets``, ``preds`` and ``groups`` do not all have
            the same number of rows.
    """
    import numpy as np

    # Plain arrays make boolean masking behave the same for every accepted
    # input type and avoid pandas index alignment.
    targets = np.asarray(targets)
    preds = np.asarray(preds)

    if not (len(targets) == len(preds) == len(groups)):
        raise ValueError(
            "targets, preds and groups must have the same length, got "
            f"{len(targets)}, {len(preds)} and {len(groups)}"
        )

    group_accuracies = {}
    for group in groups.columns:
        group_mask = (groups[group] == 1).to_numpy()
        # Groups without members have no defined accuracy and are skipped.
        if group_mask.any():
            group_accuracies[group] = accuracy_score(targets[group_mask], preds[group_mask])

    if not group_accuracies:
        return None, None

    worst_group = min(group_accuracies, key=group_accuracies.get)
    return worst_group, group_accuracies[worst_group]

# Training loop
def train_model(model, train_loader, val_loader, loss_fn, optimizer, epochs=3,
                model_path="model/best_model.pth", patience=3, val_groups=None):
    """Fine-tune ``model`` with checkpointing and early stopping.

    After every epoch the model is evaluated on ``val_loader``. The weights
    are saved to ``model_path`` whenever the validation loss improves, and
    training stops once it has failed to improve for ``patience`` consecutive
    epochs. Gradient scaling uses the module-level ``scaler``.

    Args:
        model: Network called as ``model(input_ids, attention_mask)`` and
            returning one logit per example.
        train_loader: Yields dicts with ``input_ids``, ``attention_mask`` and
            ``labels``.
        val_loader: Same batch format as ``train_loader``.
        loss_fn: Callable ``loss_fn(logits, labels)`` returning a scalar.
        optimizer: Optimizer over ``model``'s parameters.
        epochs: Maximum number of epochs.
        model_path: Where the best ``state_dict`` is written.
        patience: Non-improving epochs tolerated before stopping.
        val_groups: Optional DataFrame of 0/1 group-membership columns with
            one row per validation example, in ``val_loader`` order. When
            omitted, worst-group accuracy is not reported.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    use_amp = device.type == "cuda"
    model = model.to(device)

    # torch.save does not create missing directories; make sure the target
    # exists up front instead of failing after a full epoch of training.
    checkpoint_dir = os.path.dirname(model_path)
    if checkpoint_dir:
        os.makedirs(checkpoint_dir, exist_ok=True)

    best_val_loss = float('inf')
    no_improve_epochs = 0

    for epoch in range(epochs):
        model.train()
        total_loss = 0.0

        print(f"Epoch {epoch+1}/{epochs}")
        for batch in tqdm(train_loader, desc="Training", leave=False):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device)

            optimizer.zero_grad()

            with torch.amp.autocast('cuda', enabled=use_amp):
                outputs = model(input_ids, attention_mask)
                loss = loss_fn(outputs.squeeze(-1), labels)

            scaler.scale(loss).backward()
            scaler.step(optimizer)
            scaler.update()

            total_loss += loss.item()

        print(f"Loss: {total_loss / len(train_loader):.4f}")

        # Validation: keep the raw logits so the loss is computed on what the
        # loss function expects, not on thresholded 0/1 predictions.
        model.eval()
        logit_chunks = []
        label_chunks = []

        with torch.no_grad():
            for batch in tqdm(val_loader, desc="Validating", leave=False):
                input_ids = batch['input_ids'].to(device)
                attention_mask = batch['attention_mask'].to(device)

                with torch.amp.autocast('cuda', enabled=use_amp):
                    outputs = model(input_ids, attention_mask)

                # Cast back to float32: autocast may have produced half precision.
                logit_chunks.append(outputs.squeeze(-1).float().cpu())
                label_chunks.append(batch['labels'].float().cpu())

        val_logits = torch.cat(logit_chunks)
        val_labels = torch.cat(label_chunks)
        val_loss = loss_fn(val_logits, val_labels).item()

        val_preds = (torch.sigmoid(val_logits) > 0.5).long().numpy()
        val_targets = val_labels.long().numpy()

        # Report metrics before the early-stopping decision so the final
        # epoch is always logged.
        print(f"Validation Loss: {val_loss:.4f}")
        print("Validation Accuracy:", accuracy_score(val_targets, val_preds))
        print(classification_report(val_targets, val_preds))

        # Group labels must be aligned row-for-row with the validation set,
        # so they have to be supplied by the caller.
        if val_groups is not None:
            worst_group, worst_accuracy = worst_group_accuracy(val_targets, val_preds, val_groups)
            print(f"Worst group: {worst_group}, Accuracy: {worst_accuracy}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            no_improve_epochs = 0
            torch.save(model.state_dict(), model_path)
            print(f"Model saved to {model_path}")
        else:
            no_improve_epochs += 1
            if no_improve_epochs >= patience:
                print("Early stopping triggered.")
                break

# Run training; the weights with the lowest validation loss are written to
# model/best_model.pth.
train_model(model, train_loader, val_loader, loss_fn, optimizer, epochs=3, model_path="model/best_model.pth", patience=3)

# Predict and save predictions
# Reloads the checkpoint written above and scores the Kaggle test set.
predict_and_save(test_path='kaggle_data/test_x.csv', model_path="model/best_model.pth", output_path="predictions.csv")


Epoch 1/3


Training:   2%|▏         | 311/13452 [02:56<15:24, 14.21it/s]   