In [4]:
import os
import csv
import torch
import pickle
import numpy as np
import pandas as pd
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from tqdm.notebook import tqdm
from torch.utils.data import Dataset, DataLoader
from sklearn.model_selection import train_test_split
from sklearn.feature_extraction.text import CountVectorizer
from torchmetrics import AUROC, F1Score
from transformers import BertTokenizer, BertModel, BertConfig

In [5]:
# Run on the first CUDA device whenever PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Intended number of mini-batches whose gradients are combined per optimizer step.
# NOTE(review): the training code shown here does not read this value yet.
accumulation_steps = 4
def worst_group_accuracy(prediction, y):
    """
        Compute the worst group accuracy over the groups defined by ['male', 'female', 'LGBTQ',
        'christian', 'muslim', 'other_religions', 'black', 'white'].

        Each category column splits the rows into two groups (category == 0 and category == 1).
        Within each group, `pred` is thresholded at 0.5 and compared with `y['y']`. The smallest
        per-group accuracy is returned. Groups with no rows are skipped.

        arguments:
            prediction [pandas.DataFrame]: frame with a 'pred' column of probabilities; rows are
                matched to `y` through the DataFrame index
            y [pandas.DataFrame]: metadata with a 'y' target column and one column per category
                holding 0/1 membership; it is NOT modified
        returns:
            wga [float]: worst group accuracy over the non-empty groups (NaN if every group is empty)
    """
    # Work on a copy. Writing 'pred' into the caller's frame would leave a stale column behind,
    # and a later call with a different prediction index would only partially overwrite it.
    scored = y.copy()
    scored.loc[prediction.index, 'pred'] = prediction['pred']

    # Rows without a prediction hold NaN, and NaN > 0.5 is False, so they count as negative.
    correct = scored['y'] == (scored['pred'] > 0.5)

    categories = ['male', 'female', 'LGBTQ', 'christian', 'muslim', 'other_religions', 'black', 'white']
    accuracies = []
    for category in categories:
        for label in [0, 1]:
            in_group = scored[category] == label
            # An empty group has no accuracy; its NaN mean would otherwise make np.min return NaN.
            if in_group.any():
                accuracies.append(correct[in_group].mean())

    if not accuracies:
        return float('nan')
    wga = float(np.min(accuracies))
    return wga

# Use mixed precision training

from torch.cuda.amp import GradScaler, autocast
import time
# AMP loss scaler used by train_BERT_model. Loss scaling only applies on CUDA, so enable it
# explicitly there. On CPU-only machines this avoids the warning GradScaler emits before disabling itself.
scaler = GradScaler(enabled=torch.cuda.is_available())

def train_BERT_model(model, dataloader, criterion, optimizer, checkpoint_folder=None, accumulation_steps=1):
    """
        Run one training epoch using the module-level AMP `scaler` and `autocast`.

        arguments:
            model: classifier whose forward output exposes `.logits`
            dataloader: yields dicts with 'input_ids', 'attention_mask' and 'labels' tensors
            criterion: loss called as criterion(logits, float_labels), e.g. nn.BCEWithLogitsLoss
            optimizer: torch optimizer over the model's parameters
            checkpoint_folder [str | None]: if given, model.state_dict() is saved there roughly every
                15 minutes; the folder is created if it does not exist
            accumulation_steps [int]: number of mini-batches whose gradients are summed before each
                optimizer step. The default of 1 steps after every batch.
        returns:
            average_loss [float]: mean of the per-batch (unscaled) losses, NaN for an empty dataloader
        raises:
            ValueError: if accumulation_steps < 1
    """
    if accumulation_steps < 1:
        raise ValueError(f'accumulation_steps must be >= 1, got {accumulation_steps}')
    if checkpoint_folder is not None:
        # torch.save does not create missing directories.
        os.makedirs(checkpoint_folder, exist_ok=True)

    def apply_update():
        # Unscale + step (skipped by the scaler on inf/NaN grads), adapt the scale, reset gradients.
        scaler.step(optimizer)
        scaler.update()
        optimizer.zero_grad()

    model.train()
    total_loss = 0.0
    num_batches = 0
    last_checkpoint = time.time()
    checkpoint_interval = 900  # 15 minutes in seconds

    optimizer.zero_grad()
    for batch in tqdm(dataloader, leave=False):
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        # The criterion receives float targets; the dataset yields integer tensors.
        labels = batch['labels'].to(device).float()

        with autocast():
            outputs = model(input_ids, attention_mask=attention_mask)
            loss = criterion(outputs.logits, labels)

        # Dividing keeps the accumulated gradient equal to the average over the accumulated batches.
        scaler.scale(loss / accumulation_steps).backward()
        num_batches += 1
        total_loss += loss.item()

        if num_batches % accumulation_steps == 0:
            apply_update()

        # Save checkpoint every 15 minutes
        if checkpoint_folder is not None and time.time() - last_checkpoint >= checkpoint_interval:
            checkpoint_path = os.path.join(checkpoint_folder, f'checkpoint_{time.time()}.pt')
            torch.save(model.state_dict(), checkpoint_path)
            last_checkpoint = time.time()

    # Apply gradients left over from a final, partial accumulation window.
    if num_batches % accumulation_steps != 0:
        apply_update()

    if num_batches == 0:
        return float('nan')
    average_loss = total_loss / num_batches
    return average_loss

def evaluate_BERT_model(model, dataloader, criterion, target_column='y'):
    """
        Evaluate a classifier: mean batch loss plus worst group accuracy (WGA).

        Label vectors are collected together with the predictions, so rows stay aligned whatever
        order the dataloader yields them in. No per-sample 'index' is needed; BertDataset does
        not provide one. The column names for those vectors come from `dataloader.dataset.labels`.
        BertDataset keeps that label DataFrame for 'train'/'val', and the names let
        worst_group_accuracy find the 'y' and category columns.

        arguments:
            model: classifier whose forward output exposes `.logits`
            dataloader: DataLoader over a dataset with a `labels` DataFrame; batches are dicts with
                'input_ids', 'attention_mask' and 'labels'
            criterion: loss called as criterion(logits, float_labels)
            target_column [str]: label column whose logit is scored as the prediction
        returns:
            dataset_loss [float]: mean of the per-batch losses
            dataset_metric [float]: worst group accuracy of sigmoid(target logit) thresholded at 0.5
        raises:
            ValueError: if the dataset exposes no `labels` DataFrame, target_column is not one of its
                columns, or the dataloader is empty
    """
    label_frame = getattr(getattr(dataloader, 'dataset', None), 'labels', None)
    if label_frame is None:
        raise ValueError('evaluate_BERT_model needs dataloader.dataset.labels to name the label columns')
    column_names = list(label_frame.columns)
    if target_column not in column_names:
        raise ValueError(f'target column {target_column!r} not found in label columns {column_names}')

    model.eval()
    batch_losses = []
    prob_chunks = []
    label_chunks = []

    with torch.no_grad():
        for batch in tqdm(dataloader, leave=False):
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['labels'].to(device).float()

            logits = model(input_ids, attention_mask=attention_mask).logits
            batch_losses.append(criterion(logits, labels).item())
            # Convert logits to probabilities; .float() guards against half-precision outputs.
            prob_chunks.append(torch.sigmoid(logits).float().cpu())
            label_chunks.append(labels.cpu())

    if not batch_losses:
        raise ValueError('evaluate_BERT_model received an empty dataloader')

    probs = torch.cat(prob_chunks).numpy()
    label_df = pd.DataFrame(torch.cat(label_chunks).numpy(), columns=column_names)

    # Use one probability per row: the logit aligned with target_column, or the only logit of a
    # single-output model.
    target_idx = column_names.index(target_column) if probs.shape[1] > 1 else 0
    pred_df = pd.DataFrame({'pred': probs[:, target_idx]}, index=label_df.index)

    dataset_loss = float(np.mean(batch_losses))
    dataset_metric = worst_group_accuracy(pred_df, label_df)
    return dataset_loss, dataset_metric



In [6]:
class BertDataset(Dataset):
    """
        Map-style dataset yielding BERT-tokenised rows of `<mode>_x.csv`.

        For 'train' and 'val', the matching `<mode>_y.csv` is loaded into `self.labels`, and every
        item also carries that row's label vector. For 'test', no `labels` attribute is created
        and items hold only the token tensors.
    """

    def __init__(self, data_dir, mode, max_length=512):
        super().__init__()
        assert mode in ('train', 'val', 'test')
        self.data = pd.read_csv(os.path.join(data_dir, f'{mode}_x.csv'), index_col=0)

        if mode != 'test':
            # NOTE(review): the x file is read with index_col=0 but this one is not. If the y CSV also
            # starts with an index/id column, it becomes part of every label vector. Confirm.
            self.labels = pd.read_csv(os.path.join(data_dir, f'{mode}_y.csv'))

        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.max_length = max_length

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Text is taken from the first column; every sequence is padded/truncated to max_length.
        encoded = self.tokenizer.encode_plus(
            str(self.data.iloc[idx, 0]),
            add_special_tokens=True,
            max_length=self.max_length,
            padding='max_length',
            return_attention_mask=True,
            truncation=True,
        )
        item = {
            field: torch.tensor(encoded[field], dtype=torch.long)
            for field in ('input_ids', 'attention_mask')
        }
        # Only labelled splits define `labels`; the whole row becomes the target vector.
        if hasattr(self, 'labels'):
            item['labels'] = torch.tensor(self.labels.iloc[idx].tolist(), dtype=torch.long)
        return item

data_dir = 'kaggle_data'

# One labelled BertDataset per split; each reads <split>_x.csv and <split>_y.csv under data_dir.
train_dataset, val_dataset = [BertDataset(data_dir, split) for split in ('train', 'val')]

# Both loaders use batches of 32; only the training split is reshuffled each epoch.
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=32)
val_dataloader = DataLoader(val_dataset, shuffle=False, batch_size=32)

In [7]:
from transformers import BertForSequenceClassification
from transformers import BertTokenizer
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
# One output logit per label column that BertDataset feeds as the BCE target (previously hard-coded
# as 17). Deriving it keeps the model width in sync with the label CSV.
model_bert = BertForSequenceClassification.from_pretrained(
    'bert-base-uncased', num_labels=train_dataset.labels.shape[1]
)

Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [8]:
# `device` comes from the setup cell; move the model there and build optimizer and loss.
model_bert.to(device)
optimizer = optim.AdamW(model_bert.parameters(), lr=0.0001, weight_decay=0.1)
# Multi-label BCE over every label column produced by BertDataset.
criterion = nn.BCEWithLogitsLoss()

# train_BERT_model returns only the average epoch loss. Unpacking it into two names raised
# TypeError after a full epoch.
train_loss = train_BERT_model(model_bert, train_dataloader, criterion, optimizer)
# NOTE(review): these hold BERT results despite the `rnn_` prefix; rename if nothing downstream reads them.
rnn_val_loss, rnn_val_metric = evaluate_BERT_model(model_bert, val_dataloader, criterion)

print(f'bert classifier train loss {train_loss:.4f}')
print(f'bert classifier validation loss {rnn_val_loss:.4f} WGA {rnn_val_metric:.4f}')

  0%|          | 0/10762 [00:00<?, ?it/s]



KeyboardInterrupt: 