In [1]:
# Install the third-party packages imported by the cells below
# (PyTorch, Hugging Face transformers and datasets) into the kernel's environment.
# NOTE(review): versions are unpinned, so a fresh run may resolve different releases.
%pip install torch transformers datasets
#!mkdir checkpoints

Collecting datasets
  Downloading datasets-3.3.2-py3-none-any.whl.metadata (19 kB)
Collecting nvidia-cuda-nvrtc-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-runtime-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cuda-cupti-cu12==12.4.127 (from torch)
  Downloading nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cudnn-cu12==9.1.0.70 (from torch)
  Downloading nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl.metadata (1.6 kB)
Collecting nvidia-cublas-cu12==12.4.5.8 (from torch)
  Downloading nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting nvidia-cufft-cu12==11.2.1.3 (from torch)
  Downloading nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl.metadata (1.5 kB)
Collecting 

In [10]:
from datasets import load_dataset
from torch.utils.data import DataLoader
import torch
from transformers import AutoTokenizer, AutoModelForMultipleChoice, PreTrainedTokenizerFast
from dataclasses import dataclass

In [6]:
@dataclass
class OpenBookQAExample:
    """A single multiple-choice question with its options and gold answer."""
    question_stem: str
    choices: list  # option texts, kept in the order the raw record lists them
    correct_idx: int  # 0-based position of the gold answer ('A' -> 0 ... 'D' -> 3)

    @staticmethod
    def from_dict(data: dict):
        """Build an example from one raw record.

        Args:
            data: mapping providing ``'question_stem'``, ``'answerKey'`` (one of
                ``'A'``, ``'B'``, ``'C'``, ``'D'``) and ``'choices'``, where
                ``data['choices']['text']`` is an iterable of option texts.

        Returns:
            An ``OpenBookQAExample`` whose ``correct_idx`` is the answer letter
            converted to a 0-based index.

        Raises:
            KeyError: if a required key is missing or the answer letter is not
                one of ``'A'``-``'D'``.
        """
        stem = data['question_stem']
        answer_letter = data['answerKey']
        # Letter -> 0-based position; an unknown letter raises KeyError here.
        gold_position = dict(zip('ABCD', range(4)))[answer_letter]
        # Copy the option texts so the example does not alias the record's list.
        option_texts = list(data['choices']['text'])
        return OpenBookQAExample(stem, option_texts, gold_position)

In [11]:
from torch.utils.data import Dataset

class OpenBookQADataset(torch.utils.data.Dataset):
    """Map-style dataset of ``OpenBookQAExample`` items with a batch collator.

    Items are returned untokenized; ``collate_fn`` tokenizes a whole batch at
    once and shapes it as ``[batch_size, num_choices, seq_len]``.
    """
    # Stored on the class (not the instance) because `collate_fn` is static and
    # has no `self`. Constructing another dataset with a different tokenizer
    # therefore replaces the tokenizer used by every instance.
    # The annotation is quoted so it is not evaluated when the class is created.
    tokenizer: "PreTrainedTokenizerFast" = None

    def __init__(self, tokenizer, raw_data_list):
        """Store the shared tokenizer and parse every raw record.

        Args:
            tokenizer: callable tokenizer used later by ``collate_fn``.
            raw_data_list: iterable of raw record dicts accepted by
                ``OpenBookQAExample.from_dict``.
        """
        OpenBookQADataset.tokenizer = tokenizer
        self.sample_list = [OpenBookQAExample.from_dict(d) for d in raw_data_list]

    def __len__(self):
        """Return the number of parsed examples."""
        return len(self.sample_list)

    def __getitem__(self, idx):
        """Return the (untokenized) example stored at position ``idx``."""
        return self.sample_list[idx]

    @staticmethod
    def collate_fn(batch_samples):
        """Tokenize a batch of examples for a multiple-choice model.

        Every option is joined to its question as ``stem + " " + option``, all
        strings are tokenized together (padded to the longest, truncated to 128
        tokens), and each returned tensor is reshaped to
        ``[batch_size, num_choices, seq_len]``.

        Args:
            batch_samples: sequence of objects exposing ``question_stem``,
                ``choices`` and ``correct_idx``.

        Returns:
            The tokenizer's output mapping with every tensor reshaped as above
            and an added ``"labels"`` entry: a long tensor of shape
            ``[batch_size]`` holding the correct option indices.

        Raises:
            RuntimeError: if no tokenizer has been set on the class yet.
            ValueError: if the batch is empty, an example has no options, or
                the examples do not all have the same number of options.
        """
        tokenizer = OpenBookQADataset.tokenizer
        if tokenizer is None:
            raise RuntimeError(
                "OpenBookQADataset.tokenizer is not set; construct an "
                "OpenBookQADataset before using collate_fn."
            )

        # The reshape below needs one option count for the whole batch. Taking
        # it from the first sample only could regroup rows under the wrong
        # question when counts differ, so check it up front.
        choice_counts = {len(ex.choices) for ex in batch_samples}
        if len(choice_counts) != 1 or 0 in choice_counts:
            raise ValueError(
                "collate_fn needs a non-empty batch whose examples all have the "
                f"same non-zero number of choices; got counts {sorted(choice_counts)}."
            )
        (num_choices,) = choice_counts
        batch_size = len(batch_samples)

        # Flatten to one "question option" string per (example, option) pair,
        # example-major so that rows i*num_choices .. belong to example i.
        flattened_inputs = [
            ex.question_stem + " " + choice
            for ex in batch_samples
            for choice in ex.choices
        ]

        tokenized = tokenizer(
            flattened_inputs,
            padding=True,
            truncation=True,
            max_length=128,
            return_tensors="pt"
        )

        # [batch_size * num_choices, seq_len] -> [batch_size, num_choices, seq_len]
        for key in list(tokenized.keys()):
            tokenized[key] = tokenized[key].view(batch_size, num_choices, -1)

        tokenized["labels"] = torch.tensor(
            [ex.correct_idx for ex in batch_samples], dtype=torch.long
        )
        return tokenized


In [12]:
def initialize_openbookqa_datasets(tokenizer):
    """Load the OpenBookQA "main" configuration and wrap each split.

    Args:
        tokenizer: tokenizer handed to every ``OpenBookQADataset``.

    Returns:
        dict mapping each split name returned by ``load_dataset`` to an
        ``OpenBookQADataset`` built from that split's records.
    """
    hub_splits = load_dataset("openbookqa", "main")
    # list(...) materializes each split into plain records before parsing.
    return {
        name: OpenBookQADataset(tokenizer, list(hub_splits[name]))
        for name in hub_splits.keys()
    }

In [13]:
@torch.no_grad()
def evaluate(model, dataloader, split="Val"):
    """Compute and print multiple-choice accuracy of ``model`` over ``dataloader``.

    Inputs are moved to whatever device the model's parameters live on, so the
    function works on CPU as well as GPU. The model is switched to eval mode
    and is left in eval mode on return.

    Args:
        model: module called as ``model(input_ids=..., attention_mask=...)``
            whose output exposes ``.logits`` of shape ``[batch_size, num_choices]``.
        dataloader: iterable of batches with ``"input_ids"``,
            ``"attention_mask"`` and ``"labels"`` tensors.
        split: label used only in the printed message.

    Returns:
        float: fraction of examples whose highest-scoring choice equals the
        label (``nan`` when the dataloader yields no examples).
    """
    model.eval()

    # Follow the model's device instead of assuming CUDA; fall back to the
    # previous behaviour (GPU when present) for a model with no parameters.
    first_param = next(model.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    all_preds, all_labels = [], []
    for batch in dataloader:
        input_ids = batch["input_ids"].to(device)
        attn_mask = batch["attention_mask"].to(device)

        outputs = model(input_ids=input_ids, attention_mask=attn_mask)
        logits = outputs.logits  # shape [batch_size, num_choices]
        preds = torch.argmax(logits, dim=1)
        all_preds.extend(preds.tolist())
        # Labels are only compared on the host, so they never need to leave it.
        all_labels.extend(batch["labels"].tolist())

    accuracy = (torch.tensor(all_preds) == torch.tensor(all_labels)).float().mean().item()
    print(f"{split} Accuracy: {accuracy:.4f}")
    return accuracy


In [17]:
# Seed torch's global RNG before the model is built: part of the model is
# freshly initialized at load time (see the docstring below), so the score
# presumably depends on this seed.
torch.manual_seed(64)

def baseline_no_finetune():
    """Evaluate an un-finetuned ``roberta-base`` multiple-choice model on the test split.

    Loading ``roberta-base`` into a multiple-choice model creates the
    classification head from scratch (transformers warns about this), so the
    score is a sanity-check baseline and is expected to sit near chance for
    four options.

    Returns:
        float: test accuracy as reported by ``evaluate``.
    """
    model_name = "roberta-base"

    # Load tokenizer and model
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    model = AutoModelForMultipleChoice.from_pretrained(model_name)

    # Use the GPU only when one is actually available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    model.eval()  # set in eval mode

    # Prepare the OpenBookQA dataset
    split_datasets = initialize_openbookqa_datasets(tokenizer)

    # Create test dataloader
    test_loader = DataLoader(
        split_datasets["test"],
        batch_size=4,
        shuffle=False,
        collate_fn=OpenBookQADataset.collate_fn
    )

    # Evaluate with evaluation function
    test_acc = evaluate(model, test_loader, split="Test")
    # Name the model that was actually evaluated (the old message said "BERT").
    print(f"Zero-shot {model_name} baseline test accuracy:", test_acc)
    return test_acc

if __name__ == "__main__":
    baseline_no_finetune()

tokenizer_config.json:   0%|          | 0.00/25.0 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/481 [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/899k [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/456k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/1.36M [00:00<?, ?B/s]

model.safetensors:   0%|          | 0.00/499M [00:00<?, ?B/s]

Some weights of RobertaForMultipleChoice were not initialized from the model checkpoint at roberta-base and are newly initialized: ['classifier.bias', 'classifier.weight', 'roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


Test Accuracy: 0.2680
Zero-shot BERT baseline test accuracy: 0.2680000066757202
