In [1]:
from datasets import load_dataset
import torch
from transformers import AutoTokenizer, AutoModelForMultipleChoice, TrainingArguments, PreTrainedTokenizerFast
import torch.nn as nn
import numpy as np
from dataclasses import dataclass
from torch.utils.data import DataLoader
from torch.optim import AdamW
from torch.optim.lr_scheduler import CosineAnnealingLR
from torch.optim.swa_utils import AveragedModel, update_bn
from tqdm import tqdm

In [2]:
# --- Model / batching ---
MODEL_NAME = "roberta-base"  # checkpoint id handed to from_pretrained() in main()
BATCH_SIZE = 32

# --- Optimisation ---
NUM_EPOCHS = 10
LEARNING_RATE = 1e-5
MAX_GRAD_NORM = 1.0  # passed to train_one_epoch() as the gradient-clipping norm
L2_REG = 1e-4  # L2 Regularization strength
WARM_UP_STEPS = 1_000  # Warm-up steps (not referenced by the training code in this notebook)

# --- SWA / early stopping ---
SWA_START = 5  # first epoch number at which the SWA average is updated
PATIENCE = 3  # non-improving validation epochs tolerated before stopping

In [3]:
# Load OpenBookQA for ad-hoc inspection. The "main" configuration is named
# explicitly so this cell reads the same data as
# initialize_openbookqa_datasets(), which is what training actually uses.
dataset = load_dataset("openbookqa", "main")
train_data = dataset["train"]
dev_data = dataset["validation"]
test_data = dataset["test"]

In [4]:
@dataclass
class OpenBookQAExample:
    """One multiple-choice question: stem, candidate answers, and gold index."""
    question_stem: str
    choices: list
    correct_idx: int

    @staticmethod
    def from_dict(data: dict):
        """Build an example from a raw record.

        Args:
            data: mapping with ``'question_stem'`` (str), ``'answerKey'`` (str)
                and ``'choices'``, itself a mapping holding a ``'text'`` list
                and, optionally, a parallel ``'label'`` list.

        Returns:
            OpenBookQAExample whose ``correct_idx`` is the position of the
            gold answer inside ``choices``.

        Raises:
            KeyError: if a required field is missing, or if ``answerKey`` is
                neither in the record's ``'label'`` list nor one of A-D.
        """
        # Used only when the record does not carry its own label list.
        fallback_label_to_idx = {'A': 0, 'B': 1, 'C': 2, 'D': 3}

        choice_field = data['choices']
        choices = list(choice_field['text'])
        answer_key = data['answerKey']

        # Prefer the record's own labels: they are positionally aligned with
        # 'text', so this stays correct for reordered or non-letter labels.
        labels = list(choice_field['label']) if 'label' in choice_field else []
        if answer_key in labels:
            correct_idx = labels.index(answer_key)
        else:
            correct_idx = fallback_label_to_idx[answer_key]

        return OpenBookQAExample(
            question_stem=data['question_stem'],
            choices=choices,
            correct_idx=correct_idx,
        )


In [5]:
class OpenBookQADataset(torch.utils.data.Dataset):
    """Map-style dataset of OpenBookQAExample items plus a batch collator.

    The tokenizer is stored on the class rather than the instance because
    ``collate_fn`` is a staticmethod handed directly to ``DataLoader``.
    Constructing a dataset therefore replaces the tokenizer used by every
    instance of this class.
    """
    # Shared by all instances; assigned in __init__ and read by collate_fn.
    tokenizer: "PreTrainedTokenizerFast" = None

    def __init__(self, tokenizer, raw_data_list):
        """Store the tokenizer class-wide and parse each raw record.

        Args:
            tokenizer: callable tokenizer later used by ``collate_fn``.
            raw_data_list: iterable of raw records accepted by
                ``OpenBookQAExample.from_dict``.
        """
        OpenBookQADataset.tokenizer = tokenizer
        self.sample_list = [OpenBookQAExample.from_dict(d) for d in raw_data_list]

    def __len__(self):
        """Return the number of parsed examples."""
        return len(self.sample_list)

    def __getitem__(self, idx):
        """Return the (untokenized) example at position ``idx``."""
        return self.sample_list[idx]

    @staticmethod
    def collate_fn(batch_samples):
        """Tokenize a batch of examples into multiple-choice model inputs.

        Each (stem, choice) pair is joined with a single space and tokenized
        as one sequence; every tensor the tokenizer returns is then reshaped
        to ``(batch_size, num_choices, -1)``.

        Args:
            batch_samples: non-empty sequence of objects exposing
                ``question_stem``, ``choices`` and ``correct_idx``.

        Returns:
            The tokenizer's output mapping with reshaped tensors and an added
            ``"labels"`` entry (int64 tensor of length ``batch_size``).

        Raises:
            ValueError: if the batch is empty or the examples do not all have
                the same number of choices.
            RuntimeError: if no tokenizer has been registered yet.
        """
        if not batch_samples:
            raise ValueError("collate_fn received an empty batch")

        tokenizer = OpenBookQADataset.tokenizer
        if tokenizer is None:
            raise RuntimeError(
                "OpenBookQADataset.tokenizer is not set; construct a dataset first"
            )

        batch_size = len(batch_samples)
        num_choices = len(batch_samples[0].choices)
        # A ragged batch can otherwise be reshaped without error whenever the
        # element counts happen to divide, silently scrambling the choices.
        if any(len(ex.choices) != num_choices for ex in batch_samples):
            raise ValueError(
                "all examples in a batch must have the same number of choices"
            )

        flattened_inputs = [
            ex.question_stem + " " + choice
            for ex in batch_samples
            for choice in ex.choices
        ]

        tokenized = tokenizer(
            flattened_inputs,
            padding=True,
            truncation=True,
            max_length=128,
            return_tensors="pt"
        )

        # Snapshot the keys so the mapping is not mutated while iterated.
        for key in list(tokenized.keys()):
            tokenized[key] = tokenized[key].view(batch_size, num_choices, -1)

        tokenized["labels"] = torch.tensor(
            [ex.correct_idx for ex in batch_samples], dtype=torch.long
        )
        return tokenized


In [6]:
def initialize_openbookqa_datasets(tokenizer):
    """Load the OpenBookQA "main" configuration and wrap every split.

    Args:
        tokenizer: tokenizer handed to each OpenBookQADataset.

    Returns:
        dict mapping each split name found in the loaded data to an
        OpenBookQADataset built from that split's rows.
    """
    raw_splits = load_dataset("openbookqa", "main")
    return {
        name: OpenBookQADataset(tokenizer, list(rows))
        for name, rows in raw_splits.items()
    }

In [7]:
def compute_accuracy(preds, labels):
    """Return the fraction of positions where ``preds`` equals ``labels``.

    Both arguments are compared element-wise; the result is the mean of the
    matches as a 0-dim float tensor.
    """
    is_correct = preds == labels
    return is_correct.float().mean()

In [8]:
def reinitialize_layers(model):
    """Reset the parameters of every Linear and LayerNorm module in ``model``.

    Walks ``model.named_modules()``, calls ``reset_parameters()`` on each
    matching module and prints its qualified name. Any previously learned
    weights in those modules are overwritten in place.
    """
    resettable_types = (nn.Linear, nn.LayerNorm)
    for layer_name, layer in model.named_modules():
        if not isinstance(layer, resettable_types):
            continue
        layer.reset_parameters()
        print(f"Reinitialized {layer_name}")

In [9]:
def train_one_epoch(model, dataloader, optimizer, epoch, swa_model=None, swa_start=5, scheduler=None, clip_grad=1.0):
    """Run one training epoch and return the epoch's training accuracy.

    Args:
        model: module whose forward accepts ``input_ids``, ``attention_mask``
            and ``labels`` and returns an object with ``loss`` and ``logits``.
        dataloader: iterable of batches holding ``"input_ids"``,
            ``"attention_mask"`` and ``"labels"`` tensors.
        optimizer: optimizer over ``model``'s parameters.
        epoch: epoch number, used for logging and compared with ``swa_start``.
        swa_model: optional averaged model; its ``update_parameters(model)``
            is called after every optimizer step once ``epoch >= swa_start``.
        swa_start: first epoch number at which the average is updated.
        scheduler: optional LR scheduler. It is stepped once per batch, so
            its horizon must be expressed in optimizer steps, not epochs.
        clip_grad: maximum gradient norm passed to ``clip_grad_norm_``.

    Returns:
        0-dim float tensor: fraction of training examples predicted correctly.
    """
    model.train()
    # Follow the model's device instead of assuming CUDA is available.
    device = next(model.parameters()).device

    all_preds, all_labels = [], []
    total_loss, num_batches = 0.0, 0
    progress_bar = tqdm(dataloader, desc=f"Train Epoch {epoch}", leave=True)

    for batch in progress_bar:
        optimizer.zero_grad()

        input_ids = batch["input_ids"].to(device)
        attn_mask = batch["attention_mask"].to(device)
        labels = batch["labels"].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attn_mask, labels=labels)
        logits = outputs.logits

        # Penalty: L2_REG times the sum of each parameter tensor's (unsquared)
        # L2 norm, added on top of the model's own loss.
        loss = outputs.loss + L2_REG * sum(p.norm(2) for p in model.parameters())
        loss.backward()

        torch.nn.utils.clip_grad_norm_(model.parameters(), clip_grad)

        optimizer.step()
        if scheduler is not None:
            scheduler.step()

        if swa_model is not None and epoch >= swa_start:
            swa_model.update_parameters(model)

        all_preds.extend(torch.argmax(logits, dim=1).detach().cpu().tolist())
        all_labels.extend(labels.detach().cpu().tolist())

        batch_loss = loss.item()
        total_loss += batch_loss
        num_batches += 1
        progress_bar.set_postfix({"loss": f"{batch_loss:.4f}"})

    # Report the mean over the epoch rather than whichever batch came last;
    # this also keeps an empty dataloader from failing on an unbound loss.
    mean_loss = total_loss / num_batches if num_batches else float("nan")
    accuracy = compute_accuracy(torch.tensor(all_preds), torch.tensor(all_labels))
    print(f"Train Epoch {epoch} - Loss: {mean_loss:.4f} - Accuracy: {accuracy:.4f}")
    return accuracy


In [10]:
@torch.no_grad()
def evaluate(model, dataloader, split="Val"):
    """Compute and print the accuracy of ``model`` over ``dataloader``.

    Args:
        model: module whose forward accepts ``input_ids`` and
            ``attention_mask`` and returns an object with ``logits``.
            It is switched to eval mode and left in that mode.
        dataloader: iterable of batches holding ``"input_ids"``,
            ``"attention_mask"`` and ``"labels"`` tensors.
        split: label used as the prefix of the printed line.

    Returns:
        Accuracy as a Python float.
    """
    model.eval()
    # Follow the model's device instead of assuming CUDA is available.
    device = next(model.parameters()).device

    all_preds, all_labels = [], []
    for batch in dataloader:
        input_ids = batch["input_ids"].to(device)
        attn_mask = batch["attention_mask"].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attn_mask)
        preds = torch.argmax(outputs.logits, dim=1)
        all_preds.extend(preds.cpu().tolist())
        # Labels are only compared on the host, so they never need to be
        # moved to the model's device.
        all_labels.extend(batch["labels"].tolist())

    accuracy = compute_accuracy(torch.tensor(all_preds), torch.tensor(all_labels))
    print(f"{split} Accuracy: {accuracy:.4f}")
    return accuracy.item()


In [11]:
def main(seed=42):
    """Fine-tune MODEL_NAME on OpenBookQA and report test accuracy.

    Trains for up to NUM_EPOCHS epochs, saving the model to ``./checkpoints``
    whenever validation accuracy improves and stopping after PATIENCE
    consecutive epochs without improvement. The best checkpoint is then
    evaluated on the test split. If the SWA average received at least one
    update, its test accuracy is printed as well.

    Args:
        seed: value passed to ``torch.manual_seed`` before any model or
            DataLoader is created.
    """
    # The classification head's initial weights and DataLoader shuffling both
    # draw from torch's RNG, so seed it before creating either.
    torch.manual_seed(seed)

    # Model Setup
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForMultipleChoice.from_pretrained(MODEL_NAME).cuda()

    # Load Datasets
    datasets = initialize_openbookqa_datasets(tokenizer)
    train_loader = DataLoader(datasets['train'], batch_size=BATCH_SIZE, shuffle=True, collate_fn=OpenBookQADataset.collate_fn)
    val_loader = DataLoader(datasets["validation"], batch_size=BATCH_SIZE, shuffle=False, collate_fn=OpenBookQADataset.collate_fn)
    test_loader = DataLoader(datasets["test"], batch_size=BATCH_SIZE, shuffle=False, collate_fn=OpenBookQADataset.collate_fn)

    # Optimizer and Scheduler
    optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)
    # train_one_epoch() steps the scheduler once per batch, so the cosine
    # horizon must be the total number of optimizer steps. With T_max set to
    # the epoch count the learning rate would oscillate every few batches.
    total_steps = max(1, NUM_EPOCHS * len(train_loader))
    scheduler = CosineAnnealingLR(optimizer, T_max=total_steps)
    # NOTE(review): WARM_UP_STEPS is defined in the config cell but no
    # warm-up phase is applied here.

    # SWA Model Initialization
    swa_model = AveragedModel(model)

    # Start below any attainable accuracy so the first epoch always writes a
    # checkpoint for the final evaluation to load.
    best_val_acc = float("-inf")
    curr_patience = 0

    # Training Loop
    for epoch in range(1, NUM_EPOCHS+1):
        train_one_epoch(model, train_loader, optimizer, epoch, swa_model=swa_model, swa_start=SWA_START, scheduler=scheduler, clip_grad=MAX_GRAD_NORM)
        val_acc = evaluate(model, val_loader, split="Val")

        if val_acc > best_val_acc:
            best_val_acc = val_acc
            model.save_pretrained("./checkpoints")
            curr_patience = 0
        else:
            curr_patience += 1
            if curr_patience >= PATIENCE:
                break

    # Evaluate Best Model
    best_model = AutoModelForMultipleChoice.from_pretrained("./checkpoints").cuda()
    test_acc = evaluate(best_model, test_loader, split="Test")
    print("Final Test Acc:", test_acc)

    # The averaged weights are otherwise accumulated and then discarded.
    # Report them only if training reached SWA_START and updated the average.
    if int(swa_model.n_averaged) > 0:
        evaluate(swa_model, test_loader, split="SWA Test")


In [None]:
# Run the full pipeline defined in main(): fine-tune with early stopping on
# validation accuracy, then print the test accuracy of the best checkpoint.
main()