In [1]:
import pandas as pd
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from transformers import DistilBertTokenizer, DistilBertForSequenceClassification, AdamW, get_linear_schedule_with_warmup
from sklearn.model_selection import train_test_split
from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

# Load dataset
df = pd.read_csv('/home/k64769/OnionOrNot.csv')

# Clean data
df = df.dropna(subset=['text', 'label']).reset_index(drop=True)

# Define constants
MAX_LEN = 128
BATCH_SIZE = 32
EPOCHS = 10  # Increased training epochs for better convergence
LEARNING_RATE = 2e-5
WEIGHT_DECAY = 0.01

# Dataset class
class NewsDataset(Dataset):
    def __init__(self, texts, labels, tokenizer, max_len):
        self.texts = texts
        self.labels = labels
        self.tokenizer = tokenizer
        self.max_len = max_len

    def __len__(self):
        return len(self.texts)

    def __getitem__(self, idx):
        text = str(self.texts[idx])
        label = self.labels[idx]

        encoding = self.tokenizer.encode_plus(
            text,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            return_tensors='pt'
        )

        return {
            'input_ids': encoding['input_ids'].flatten(),
            'attention_mask': encoding['attention_mask'].flatten(),
            'label': torch.tensor(label, dtype=torch.long)
        }

# Prepare data
train_text, val_text, train_labels, val_labels = train_test_split(
    df['text'].tolist(), df['label'].tolist(), test_size=0.2, stratify=df['label'], random_state=42
)

# Initialize tokenizer
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-uncased')

# Create datasets
train_dataset = NewsDataset(train_text, train_labels, tokenizer, MAX_LEN)
val_dataset = NewsDataset(val_text, val_labels, tokenizer, MAX_LEN)

# Data loaders
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=BATCH_SIZE)

# Set device
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Initialize model
model = DistilBertForSequenceClassification.from_pretrained(
    'distilbert-base-uncased', num_labels=2
).to(device)

# Optimizer with weight decay
no_decay = ['bias', 'LayerNorm.bias', 'LayerNorm.weight']
optimizer_grouped_parameters = [
    {'params': [p for n, p in model.named_parameters() if not any(nd in n for nd in no_decay)],
     'weight_decay': WEIGHT_DECAY},
    {'params': [p for n, p in model.named_parameters() if any(nd in n for nd in no_decay)],
     'weight_decay': 0.0}
]
optimizer = AdamW(optimizer_grouped_parameters, lr=LEARNING_RATE)

# Scheduler
total_steps = len(train_loader) * EPOCHS
scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=0,
    num_training_steps=total_steps
)

# Training loop
best_accuracy = 0.0
for epoch in range(EPOCHS):
    model.train()
    total_loss = 0.0

    for batch in train_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        optimizer.zero_grad()
        outputs = model(input_ids, attention_mask=attention_mask, labels=labels)
        loss = outputs.loss
        total_loss += loss.item()

        loss.backward()
        optimizer.step()
        scheduler.step()

    # Evaluation
    model.eval()
    true_labels, predictions = [], []

    with torch.no_grad():
        for batch in val_loader:
            input_ids = batch['input_ids'].to(device)
            attention_mask = batch['attention_mask'].to(device)
            labels = batch['label'].to(device)

            outputs = model(input_ids, attention_mask=attention_mask)
            logits = outputs.logits
            preds = torch.argmax(logits, dim=1)

            true_labels.extend(labels.cpu().numpy())
            predictions.extend(preds.cpu().numpy())

    # Compute metrics
    accuracy = accuracy_score(true_labels, predictions)
    f1 = f1_score(true_labels, predictions)
    precision = precision_score(true_labels, predictions)
    recall = recall_score(true_labels, predictions)

    print(f"Epoch {epoch + 1}/{EPOCHS} | Loss: {total_loss / len(train_loader):.4f}")
    print(f"Validation: Acc={accuracy:.4f} | F1={f1:.4f} | Prec={precision:.4f} | Rec={recall:.4f}")

    # Save best model
    if accuracy > best_accuracy:
        best_accuracy = accuracy
        torch.save(model.state_dict(), 'best_model.pth')

# Load best model
model.load_state_dict(torch.load('best_model.pth'))
model.eval()

# Final evaluation
with torch.no_grad():
    for batch in val_loader:
        input_ids = batch['input_ids'].to(device)
        attention_mask = batch['attention_mask'].to(device)
        labels = batch['label'].to(device)

        outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs.logits
        preds = torch.argmax(logits, dim=1)

        true_labels.extend(labels.cpu().numpy())
        predictions.extend(preds.cpu().numpy())

# Print final metrics
print("Final Metrics:")
print(f"Accuracy: {accuracy_score(true_labels, predictions):.4f}")
print(f"F1 Score: {f1_score(true_labels, predictions):.4f}")
print(f"Precision: {precision_score(true_labels, predictions):.4f}")
print(f"Recall: {recall_score(true_labels, predictions):.4f}")

Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Epoch 1/10 | Loss: 0.3085
Validation: Acc=0.9152 | F1=0.8827 | Prec=0.9168 | Rec=0.8511
Epoch 2/10 | Loss: 0.1510
Validation: Acc=0.9206 | F1=0.8897 | Prec=0.9287 | Rec=0.8539
Epoch 3/10 | Loss: 0.0717
Validation: Acc=0.9181 | F1=0.8898 | Prec=0.8981 | Rec=0.8817
Epoch 4/10 | Loss: 0.0366
Validation: Acc=0.9196 | F1=0.8924 | Prec=0.8959 | Rec=0.8889
Epoch 5/10 | Loss: 0.0230
Validation: Acc=0.9158 | F1=0.8838 | Prec=0.9165 | Rec=0.8533
Epoch 6/10 | Loss: 0.0158
Validation: Acc=0.9154 | F1=0.8849 | Prec=0.9038 | Rec=0.8667
Epoch 7/10 | Loss: 0.0130
Validation: Acc=0.9156 | F1=0.8862 | Prec=0.8965 | Rec=0.8761
Epoch 8/10 | Loss: 0.0079
Validation: Acc=0.9175 | F1=0.8868 | Prec=0.9134 | Rec=0.8617
Epoch 9/10 | Loss: 0.0047
Validation: Acc=0.9165 | F1=0.8852 | Prec=0.9132 | Rec=0.8589
Epoch 10/10 | Loss: 0.0046
Validation: Acc=0.9192 | F1=0.8905 | Prec=0.9053 | Rec=0.8761


  model.load_state_dict(torch.load('best_model.pth'))


Final Metrics:
Accuracy: 0.9199
F1 Score: 0.8901
Precision: 0.9167
Recall: 0.8650
