# Fine-tuning RoBERTa - efficient approach

In [None]:
import enlighten
import glob
import matplotlib.pyplot as plt
import numpy as np
import os
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F

from sklearn.metrics import accuracy_score
from torch.utils.data import DataLoader, TensorDataset
from tqdm import tqdm
from transformers import get_scheduler, RobertaModel, RobertaTokenizer

In [None]:
# Pick the best available compute backend: CUDA GPU first, then Apple MPS,
# and fall back to the CPU when neither accelerator is available.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")

In [None]:
# Make sure the folders used for cached embeddings and saved predictions exist.
for directory in ("data/embeddings/", "output/preds/"):
    os.makedirs(directory, exist_ok=True)

# Preprocessing
To improve efficiency, run a single forward pass through the frozen base model and store the resulting embeddings, so that only the classification head has to be trained.

In [None]:
class DataPreprocessor:
    """Turn a dataframe of reviews into a dataloader of precomputed embeddings.

    The base model is moved to the module-level ``device`` and switched to
    evaluation mode; it is only ever used for inference here.
    """

    def __init__(self, tokenizer, base_model, label2id, batch_size):
        """Store the tokenizer, the base model, the label mapping and the batch size."""
        self.tokenizer = tokenizer
        self.base_model = base_model.to(device)
        self.base_model.eval()
        self.label2id = label2id
        self.batch_size = batch_size

    def extract_embeddings(self, texts):
        """Embed ``texts`` batch by batch (with a progress bar).

        Returns a single CPU tensor holding, for every text and in input
        order, the hidden state at position 0 of the base model's last layer.
        """
        text_batches = DataLoader(texts, batch_size=self.batch_size, shuffle=False)
        progress = tqdm(text_batches, desc="Extracting Embeddings", leave=True)
        tokenizer_options = {
            "padding": "longest",
            "truncation": True,
            "return_token_type_ids": False,
            "return_tensors": "pt",
        }

        chunks = []
        with torch.no_grad():
            for text_batch in progress:
                encoded = self.tokenizer(text_batch, **tokenizer_options).to(device)
                hidden_states = self.base_model(**encoded).last_hidden_state
                # Keep only the first-token (CLS) vector, and park it on the
                # CPU so accelerator memory is freed between batches.
                chunks.append(hidden_states[:, 0, :].cpu())

        return torch.cat(chunks, dim=0)

    def prepare(self, data):
        """Build a shuffled dataloader of (embedding, label, review id) triples.

        ``data`` must expose the columns "text", "sentiment" and "review_id".
        """
        embeddings = self.extract_embeddings(data["text"].tolist())

        # Map the textual sentiment to its numeric class id
        labels = torch.tensor(
            [self.label2id[sentiment] for sentiment in data["sentiment"]],
            dtype=torch.long,
        )

        # Review ids travel with each sample so predictions can be joined back
        ids = torch.tensor(data["review_id"].tolist(), dtype=torch.long)

        return DataLoader(
            TensorDataset(embeddings, labels, ids),
            batch_size=self.batch_size,
            shuffle=True,
        )

In [None]:
# Pretrained checkpoint shared by the tokenizer and the base model
model_name = "roberta-large"
tokenizer = RobertaTokenizer.from_pretrained(model_name)
base_model = RobertaModel.from_pretrained(model_name)

# Two-way mapping between sentiment names and class ids
label2id = {"negative": 0, "positive": 1}
id2label = {class_id: label for label, class_id in label2id.items()}

batch_size = 250

preprocessor = DataPreprocessor(
    tokenizer=tokenizer,
    base_model=base_model,
    label2id=label2id,
    batch_size=batch_size,
)

In [None]:
train_savefile = "data/embeddings/train.pt"

if not os.path.exists(train_savefile):
    # No cache yet: embed the training set once and store the tensors on disk
    train = pd.read_csv("data/processed/train.csv")
    train_dataloader = preprocessor.prepare(train)
    train_dataset = train_dataloader.dataset
    torch.save(train_dataset.tensors, train_savefile)
else:
    # Cache hit: rebuild the dataset straight from the saved tensors
    print("Loading embeddings for train samples")
    train_tensors = torch.load(train_savefile)
    train_dataset = TensorDataset(*train_tensors)
    train_dataloader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [None]:
test_savefile = "data/embeddings/test.pt"

if os.path.exists(test_savefile):
    # Cache hit: rebuild the dataset straight from the saved tensors
    print("Loading embeddings for validation samples")
    test_tensors = torch.load(test_savefile)
    test_dataset = TensorDataset(*test_tensors)
else:
    # No cache yet: embed the test set once and store the tensors on disk
    test = pd.read_csv("data/processed/test.csv")
    test_dataset = preprocessor.prepare(test).dataset
    torch.save(test_dataset.tensors, test_savefile)

# Build the evaluation loader the same way in both branches:
# - use the shared `batch_size` rather than a hard-coded 250;
# - do not shuffle, so every epoch sees identical batches and the
#   per-batch-averaged test metrics are comparable across epochs.
test_dataloader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Classification head

In [None]:
class RoBERTaClassifier(nn.Module):
    """Classification head applied to precomputed sentence embeddings.

    ``mod`` selects one of several increasingly rich architectures; each
    level adds one ingredient on top of the previous one:

    * 1: input projection -> output projection
    * 2: adds an intermediate dense layer (wider hidden size)
    * 3: adds a SiLU activation
    * 4: adds layer normalization
    * 5: adds dropout (after the input and the intermediate projections)
    * 6: adds a residual connection around dropout/norm/activation

    Args:
        embedding_dim: Size of the incoming embedding vectors.
        mod: Architecture level, an integer >= 1.

    Raises:
        ValueError: If ``mod`` is lower than 1 (the resulting layer sizes
            would not match and ``forward`` could not run).
    """

    def __init__(self, embedding_dim=1024, mod=1):
        super().__init__()
        if mod < 1:
            raise ValueError(f"mod must be >= 1, got {mod!r}")
      ### Parameters
        self.mod = mod
        self.hidden_size = 1024 if mod == 1 else 2048
        self.inter_size = 1024 if mod == 1 else 512
      ### Layers
      ### Optional layers are only created when the architecture uses them,
      ### so that the trainable parameters count is exact
        self.in_proj = nn.Linear(embedding_dim, self.hidden_size)           # Input layer
        self.dropout = nn.Dropout(0.1)                                      # Dropout layer
        self.silu = nn.SiLU()                                               # Activation
        if mod >= 4:
            self.layer_norm = nn.LayerNorm(self.hidden_size, eps=1e-5)      # Normalization
        if mod >= 2:
            self.inter_proj = nn.Linear(self.hidden_size, self.inter_size)  # Intermediate dense layer
        self.out_proj = nn.Linear(self.inter_size, 2)                       # Output layer

    def forward(self, embeddings):
        """Return the two-class logits for a batch of embeddings."""
        # Project once and reuse the result for the residual connection
        # instead of running the input layer a second time.
        projected = self.in_proj(embeddings)
        x = projected
        if self.mod >= 5:
            x = self.dropout(x)
        if self.mod >= 4:
            x = self.layer_norm(x)
        if self.mod >= 3:
            x = self.silu(x)
        if self.mod >= 6:
            x = x + projected
        if self.mod >= 2:
            x = self.inter_proj(x)
        if self.mod >= 5:
            x = self.dropout(x)
        return self.out_proj(x)

# Training

In [None]:
# Cross-entropy over the two-class logits, averaged over the batch
loss_fn = nn.CrossEntropyLoss(reduction="mean")

In [None]:
def model_train(batch, classifier, optimizer, metrics):
    """Run one optimisation step on ``batch`` and record its loss and accuracy.

    The batch's logits, labels and ids are appended (detached, on CPU) to the
    module-level ``all_logits``, ``all_labels`` and ``all_ids`` lists; the
    batch loss and accuracy are appended to ``metrics``. Returns the loss.
    """
    # Unpack the batch onto the compute device
    features, targets, sample_ids = (tensor.to(device) for tensor in batch)

    # Clear gradients left over from the previous step, then run the forward pass
    classifier.zero_grad()
    logits = classifier(features)

    # Keep a CPU copy of everything needed for later analysis
    for store, tensor in ((all_logits, logits), (all_labels, targets), (all_ids, sample_ids)):
        store.append(tensor.detach().cpu())

    # Batch loss
    loss = loss_fn(logits, targets)
    metrics['batch_train_losses'].append(loss.item())

    # Batch accuracy
    n_correct = (logits.argmax(dim=1) == targets).sum().item()
    metrics['batch_train_accuracy'].append(n_correct / targets.size(0))

    # Backward pass and parameter update
    loss.backward()
    optimizer.step()

    return loss

In [None]:
def model_eval(batch, classifier, optimizer, metrics):
    """Evaluate ``classifier`` on ``batch`` and record its loss and accuracy.

    The batch's logits, labels and ids are appended (on CPU) to the
    module-level ``all_logits``, ``all_labels`` and ``all_ids`` lists; the
    batch loss and accuracy are appended to ``metrics``.

    ``optimizer`` is unused; it is kept so the signature mirrors
    ``model_train``. Returns the batch loss (not attached to any graph).
    """
    # No parameter update happens here, so skip autograd bookkeeping entirely
    with torch.no_grad():
        # Unpack the batch and move tensors to the device
        embeddings, b_labels, b_ids = [t.to(device) for t in batch]
        # Forward pass using embeddings as input
        logits = classifier(embeddings)
        # Calculate loss
        loss = loss_fn(logits, b_labels)

    # Store results for later analysis
    all_logits.append(logits.detach().cpu())
    all_labels.append(b_labels.detach().cpu())
    all_ids.append(b_ids.detach().cpu())
    metrics['batch_test_losses'].append(loss.item())

    # Calculate accuracy
    preds = torch.argmax(logits, dim=1)
    accuracy = (preds == b_labels).sum().item() / b_labels.size(0)
    metrics['batch_test_accuracy'].append(accuracy)

    return loss

## Performance metrics

In [None]:
def get_epoch_metrics(metrics_dict, all_logits, dataloader_length, phase):
    """Aggregate the batch-level records of one epoch into epoch-level metrics.

    Writes three entries into ``metrics_dict``:

    * ``epoch_{phase}_loss``: plain mean of the last ``dataloader_length`` batch losses
    * ``epoch_{phase}_accuracy``: plain mean of the last ``dataloader_length`` batch accuracies
    * ``{phase}_classif_error``: mean over samples of (1 - highest class probability)

    Returns the softmax probabilities of all samples, one row per sample.
    """
    # Loss: unweighted mean over the most recent batches
    recent_losses = metrics_dict[f'batch_{phase}_losses'][-dataloader_length:]
    metrics_dict[f'epoch_{phase}_loss'] = float(np.mean(recent_losses))

    # Accuracy: unweighted mean over the most recent batches
    recent_accuracies = metrics_dict[f'batch_{phase}_accuracy'][-dataloader_length:]
    metrics_dict[f'epoch_{phase}_accuracy'] = float(np.mean(recent_accuracies))

    # Classification error: how far the winning probability is from certainty
    probs = F.softmax(torch.cat(all_logits, dim=0), dim=1).detach()
    top_prob = torch.max(probs[:, 0], probs[:, 1])
    metrics_dict[f'{phase}_classif_error'] = float((1 - top_prob).mean().item())

    return probs

## Training parameters

In [None]:
# Each model is trained for a fixed number of passes over the training set;
# one training step corresponds to one batch.
num_epochs = 300
num_training_steps = len(train_dataloader) * num_epochs
print(f"Number of training steps per model: {num_training_steps}")

In [None]:
def get_optimizer_and_scheduler(classifier):
    """Create the AdamW optimizer and cosine learning-rate schedule for ``classifier``.

    The schedule length is read from the module-level ``num_training_steps``;
    the first 10% of those steps are used for warm-up.

    Returns:
        A ``(optimizer, lr_scheduler)`` tuple.
    """
    # Optimizer
    optimizer = torch.optim.AdamW(
        classifier.parameters(),
        lr=1e-3,
        weight_decay=0.01,
        eps=1e-8)

    # Scheduler
    # The warm-up length is a step count, so pass an integer rather than
    # the float produced by `0.1 * num_training_steps`.
    lr_scheduler = get_scheduler(
        "cosine",
        optimizer=optimizer,
        num_warmup_steps=int(0.1 * num_training_steps),
        num_training_steps=num_training_steps)

    return optimizer, lr_scheduler

## Training loop

In [None]:
# Train and evaluate the six classifier variants one after the other.
# `metrics[mod_type]` holds the parameter count under 'parameters' and one
# dictionary of batch/epoch metrics per epoch number.
metrics = {mod_type: {} for mod_type in range(1, 7)}

manager = enlighten.get_manager()
model_progress = manager.counter(total=6, desc="Models  ->", unit="model", color="forestgreen")

try:
    for mod_type in range(1, 7):
        ### Create model
        classifier = RoBERTaClassifier(embedding_dim=1024, mod=mod_type).to(device)
        optimizer, lr_scheduler = get_optimizer_and_scheduler(classifier)

        metrics[mod_type] = {
            'parameters': sum(p.numel() for p in classifier.parameters() if p.requires_grad)}

        # Best test accuracy seen so far for this model
        best_test_accuracy = float('-inf')

        # Loop over epochs
        epoch_progress = manager.counter(total=num_epochs, desc=f"Model {mod_type} ->", unit="epoch", color="darkgrey")
        for epoch in range(0, num_epochs):
            metrics[mod_type][epoch] = {'batch_train_losses': [],
                                        'batch_train_accuracy': [],
                                        'batch_test_losses': [],
                                        'batch_test_accuracy': []}

            ### Training
            classifier.train()
            # These lists are filled by model_train / model_eval
            all_logits, all_labels, all_ids = [], [], []
            for batch in train_dataloader:
                model_train(batch, classifier, optimizer, metrics[mod_type][epoch])
                # The schedule spans num_epochs * len(train_dataloader) steps,
                # so it has to advance once per batch. Stepping once per epoch
                # would leave the learning rate stuck in the warm-up phase.
                lr_scheduler.step()
            _ = get_epoch_metrics(metrics[mod_type][epoch], all_logits, len(train_dataloader), phase='train')

            ### Validating
            classifier.eval()
            all_logits, all_labels, all_ids = [], [], []
            for batch in test_dataloader:
                model_eval(batch, classifier, optimizer, metrics[mod_type][epoch])
            probs = get_epoch_metrics(metrics[mod_type][epoch], all_logits, len(test_dataloader), phase='test')

            ### Saving predictions
            # Save whenever this epoch matches or beats every previous one
            epoch_accuracy = metrics[mod_type][epoch]['epoch_test_accuracy']
            if epoch_accuracy >= best_test_accuracy:
                best_test_accuracy = epoch_accuracy
                probs_array = probs.cpu().numpy()
                labels_array = torch.cat(all_labels, dim=0).cpu().numpy()
                results = pd.DataFrame(probs_array, columns=[f"prob_class_{i}" for i in range(probs_array.shape[1])])
                results['true_label'] = [id2label[label] for label in labels_array]
                results['review_id'] = torch.cat(all_ids, dim=0).detach().cpu().numpy()
                results.to_csv(f"output/preds/mod_{mod_type}_epoch_{epoch}.csv", index=False)

            epoch_progress.update()
        model_progress.update()
finally:
    # Always release the progress-bar manager, even if training is interrupted
    manager.stop()

# Results

In [None]:
# Display the best epoch for each model
summary_rows = []

for mod_type, model_metrics in metrics.items():
    # (epoch, test accuracy) pairs; the integer check skips the 'parameters' entry
    epoch_accuracies = [
        (epoch, data['epoch_test_accuracy'])
        for epoch, data in model_metrics.items()
        if isinstance(epoch, int) and 'epoch_test_accuracy' in data]

    if not epoch_accuracies:
        continue

    # On ties, the earliest epoch wins
    best_epoch, best_accuracy = max(epoch_accuracies, key=lambda pair: pair[1])
    best_error = model_metrics[best_epoch]['test_classif_error']
    summary_rows.append({
        'mod_type': mod_type,
        'n_parameters': model_metrics['parameters'],
        'best_epoch': best_epoch,
        'best_test_accuracy': best_accuracy,
        'test_classif_error': best_error})

summary = pd.DataFrame(summary_rows).sort_values(by='mod_type')
summary.style.hide(axis="index").format({
    "best_test_accuracy": "{:.4f}",
    "test_classif_error": "{:.4f}"})

In [None]:
# Identify the best model: the row of the summary with the highest test accuracy
best_model_index = summary['best_test_accuracy'].idxmax()
best_model_info = summary.loc[best_model_index]

best_mod_type, best_epoch = (
    int(best_model_info[column]) for column in ('mod_type', 'best_epoch'))

In [None]:
# Prepare data for plotting performance curves
## Convert results from dictionary to df (one row per model and epoch)
rows = []
for mod_type, subdict in metrics.items():
    for epoch_key, metric in subdict.items():
        # Skip the parameter count and anything that is not a per-epoch record
        if epoch_key == 'parameters' or not isinstance(metric, dict):
            continue
        rows.append({
            "mod_type": mod_type,
            "epoch": epoch_key,
            "epoch_train_loss": metric.get("epoch_train_loss"),
            "epoch_test_loss": metric.get("epoch_test_loss"),
            "epoch_train_accuracy": metric.get("epoch_train_accuracy"),
            "epoch_test_accuracy": metric.get("epoch_test_accuracy"),
            "train_classif_errors": metric.get("train_classif_error"),
            "test_classif_errors": metric.get("test_classif_error")
        })
df = pd.DataFrame(rows)

## Extract values for best model
df_subset = df[df["mod_type"] == best_mod_type].sort_values(by="epoch")
epoch_train_losses = df_subset["epoch_train_loss"].tolist()
epoch_test_losses = df_subset["epoch_test_loss"].tolist()
epoch_train_accuracy = df_subset["epoch_train_accuracy"].tolist()
epoch_test_accuracy = df_subset["epoch_test_accuracy"].tolist()
train_classif_errors = df_subset["train_classif_errors"].tolist()
test_classif_errors = df_subset["test_classif_errors"].tolist()

## Identify min / max values (on the right part of the graph)
# Only epochs shown in the zoomed panels (epoch >= 21) are searched, and each
# value is paired with its real epoch number so the marker lines fall on the
# curves' actual extrema.
zoom_start = 21
zoomed = df_subset[df_subset["epoch"] >= zoom_start]
if zoomed.empty:
    # Fewer epochs than the zoom offset: fall back to the whole curve
    zoomed = df_subset
epochs = zoomed["epoch"].to_numpy()
min_epoch_loss, min_loss = min(zip(epochs, zoomed["epoch_test_loss"]), key=lambda x: x[1])
max_epoch_acc, max_accuracy = max(zip(epochs, zoomed["epoch_test_accuracy"]), key=lambda x: x[1])
min_epoch_err, min_error = min(zip(epochs, zoomed["test_classif_errors"]), key=lambda x: x[1])

In [None]:
# Plot
# Each metric gets one row: the full curve on the left, and a zoom on the
# later epochs (with the best test value marked) on the right.
zoom_start = 21

panels = [
    # (title, train curve, test curve, train color, test color, test alpha,
    #  marker epoch, marker label, y-limits of the zoomed panel)
    ('Loss', epoch_train_losses, epoch_test_losses, '#97BC62FF', '#2C5F2D', 0.8,
     min_epoch_loss, f'≥ {min_loss:.4f}', (0.14, 0.33)),
    ('Accuracy', epoch_train_accuracy, epoch_test_accuracy, '#9CC3D5FF', '#0063B2FF', None,
     max_epoch_acc, f'≤ {max_accuracy:.4f}', (0.87, 0.95)),
    ('Classification Error', train_classif_errors, test_classif_errors, '#F5C7B8FF', '#FFA177FF', None,
     min_epoch_err, f'≥ {min_error:.4f}', (0.058, 0.088)),
]

fig = plt.figure(figsize=(12, 8))
gs = fig.add_gridspec(3, 2, width_ratios=[1, 3])

overview_ticks = np.linspace(0, num_epochs, 6)
# Guard against a zero step when there are only a few epochs
zoom_ticks = np.arange(zoom_start, num_epochs + 1, max(1, (num_epochs - zoom_start) // 8))

for row, (title, train_curve, test_curve, train_color, test_color, test_alpha,
          marker_epoch, marker_label, zoom_ylim) in enumerate(panels):
    # Only the bottom row carries the x-axis label
    xlabel = 'Epochs' if row == len(panels) - 1 else ''

    # Left: whole training run
    overview = fig.add_subplot(gs[row, 0])
    overview.plot(train_curve, label='Train', color=train_color)
    overview.plot(test_curve, label='Test', color=test_color, alpha=test_alpha)
    overview.set_xticks(overview_ticks)
    overview.set_xlabel(xlabel)
    overview.set_ylabel('')
    overview.set_title(title)
    overview.legend()

    # Right: zoom on the later epochs, with the best test value marked
    zoom = fig.add_subplot(gs[row, 1])
    zoom.plot(train_curve, color=train_color)
    zoom.plot(test_curve, color=test_color, alpha=test_alpha)
    zoom.axvline(marker_epoch, color='darkorchid', linestyle='--', alpha=0.6, label=marker_label)
    zoom.set_xticks(zoom_ticks)
    # Set the limits on the axes explicitly rather than through pyplot's
    # implicit "current axes"
    zoom.set_xlim(zoom_start, num_epochs)
    zoom.set_ylim(*zoom_ylim)
    zoom.set_xlabel(xlabel)
    zoom.set_ylabel('')
    zoom.set_title(title)
    zoom.legend()

plt.tight_layout()
# Name the file after the model actually plotted (the best one), not after
# whatever value a previous loop left in `mod_type`.
plt.savefig(f"output/mod_type_{best_mod_type}_learning_curves.png", dpi=300, bbox_inches='tight')
plt.show()

In [None]:
# Merge with original dataset: attach the best model's saved predictions
# to the test reviews through their shared review id
test = pd.read_csv("data/processed/test.csv")
best_preds_file = f"output/preds/mod_{best_mod_type}_epoch_{best_epoch}.csv"
results = test.merge(pd.read_csv(best_preds_file), on='review_id')

In [None]:
# Check for consistency: the label stored with each prediction must equal
# the sentiment recorded in the original dataset
print("Do the true labels returned by the model match the original sentiments?")
labels_match = (results['sentiment'] == results['true_label']).all()
print("Yes!" if labels_match else "No :'(")

In [None]:
# Get predicted sentiments and save
# A review is labelled positive when its positive-class probability reaches 0.5
is_positive = results['prob_class_1'] >= 0.5
results['RoBERTa_ft'] = np.where(is_positive, 'positive', 'negative')
results.to_csv("output/RoBERTa_ft.csv", columns=['review_id', 'RoBERTa_ft'], index=False)