from https://huggingface.co/docs/transformers/tasks/multiple_choice

In [1]:
from datasets import load_dataset

In [16]:
swag = load_dataset("swag", split="train[:1000]")

In [17]:
swag[0]

{'video-id': 'anetv_jkn6uvmqwh4',
 'fold-ind': '3416',
 'startphrase': 'Members of the procession walk down the street holding small horn brass instruments. A drum line',
 'sent1': 'Members of the procession walk down the street holding small horn brass instruments.',
 'sent2': 'A drum line',
 'gold-source': 'gold',
 'ending0': 'passes by walking down the street playing their instruments.',
 'ending1': 'has heard approaching them.',
 'ending2': "arrives and they're outside dancing and asleep.",
 'ending3': 'turns the lead singer watches the performance.',
 'label': 0}

In [18]:
from transformers import AutoTokenizer

In [19]:
tokenizer = AutoTokenizer.from_pretrained("distilbert-base-uncased")

In [20]:
ending_names = [
    f"ending{x}"
    for x in range(4)
]

def preprocess(batch):

    # sent1 sample:
    # sent1 = ['Members of the procession walk down the street holding small horn brass instruments.',
    # 'A drum line passes by walking down the street playing their instruments.',
    # 'A group of members in green uniforms walks waving flags.',
    # 'A drum line passes by walking down the street playing their instruments.']
    
    # sent2 = ['A drum line',
    # 'Members of the procession',
    # 'Members of the procession',
    # 'Members of the procession']
            
    # repeat the context as many as the choices, in this case 4.
     # first_sentences =   ['Members of the procession walk down the street holding small horn brass instruments.',
     # 'Members of the procession walk down the street holding small horn brass instruments.',
     # 'Members of the procession walk down the street holding small horn brass instruments.',
     # 'Members of the procession walk down the street holding small horn brass instruments.',
     # 'A drum line passes by walking down the street playing their instruments.',
     # 'A drum line passes by walking down the street playing their instruments.',
     # 'A drum line passes by walking down the street playing their instruments.',
     # 'A drum line passes by walking down the street playing their instruments.',
     # 'A group of members in green uniforms walks waving flags.',
     # 'A group of members in green uniforms walks waving flags.',
     # 'A group of members in green uniforms walks waving flags.',
     # 'A group of members in green uniforms walks waving flags.',
     # 'A drum line passes by walking down the street playing their instruments.',
     # 'A drum line passes by walking down the street playing their instruments.',
     # 'A drum line passes by walking down the street playing their instruments.',
     # 'A drum line passes by walking down the street playing their instruments.']
    first_sentences = [
        [context] * 4
        for context in batch["sent1"]
    ]
    
    question_headers = batch["sent2"]
    # add sent2 to each choice. e.g: sent2='A durm line'
    # second_sentences = ['A drum line passes by walking down the street playing their instruments.',
    #  'A drum line has heard approaching them.',
    #  "A drum line arrives and they're outside dancing and asleep.",
    #  'A drum line turns the lead singer watches the performance.',
    #  'Members of the procession are playing ping pong and celebrating one left each in quick.',
    #  'Members of the procession wait slowly towards the cadets.',
    #  'Members of the procession continues to play as well along the crowd along with the band being interviewed.',
    #  'Members of the procession continue to play marching, interspersed.',
    #  'Members of the procession pay the other coaches to cheer as people this chatter dips in lawn sheets.',
    #  'Members of the procession walk down the street holding small horn brass instruments.',
    #  'Members of the procession is seen in the background.',
    #  'Members of the procession are talking a couple of people playing a game of tug of war.',
    #  'Members of the procession are playing ping pong and celebrating one left each in quick.',
    #  'Members of the procession wait slowly towards the cadets.',
    #  'Members of the procession makes a square call and ends by jumping down into snowy streets where fans begin to take their positions.',
    #  'Members of the procession play and go back and forth hitting the drums while the audience claps for them.']
    second_sentences = [
        [f"{header} {batch[end][i]}" for end in ending_names]
        for i, header in enumerate(question_headers)
    ]
    
    # flattened the list [[s1, s2, s3, s4]] -> [s1, s2, s3, s4]
    first_sentences = sum(first_sentences, [])
    second_sentences = sum(second_sentences, [])

    # this will concatenate 2 sentences, from first + second, into 1 tokenized sentence
    # sample: 
    #   sent1: ['Members of the procession walk down the street holding small horn brass instruments.']
    #   sent2: ['A drum line passes by walking down the street playing their instruments.']
    #   tokenized1: [101, 2372, 1997, 1996, 14385, 3328, 2091, 1996, 2395, 3173, 2235, 7109, 8782, 5693, 1012, 102] 
    #   tokenized2: [101, 1037, 6943, 2240, 5235, 2011, 3788, 2091, 1996, 2395, 2652, 2037, 5693, 1012, 102]
    #   tokenized: [101, 2372, 1997, 1996, 14385, 3328, 2091, 1996, 2395, 3173, 2235, 7109, 8782, 5693, 1012, 102, 
    #                    1037, 6943, 2240, 5235, 2011, 3788, 2091, 1996, 2395, 2652, 2037, 5693, 1012, 102]
    #   tokenized: Dict[str, array] = {"input_ids": [], "attention_mask": []}        
    tokenized = tokenizer(
        text=first_sentences, 
        text_pair=second_sentences, 
        truncation=True)
    
    res = {
        k: [
            v[i:i + 4]
            for i in range(0, len(v), 4)
           ]
        for k, v in tokenized.items()
    }
    
    return res

In [21]:
tokenized_swag = swag.map(preprocess, batched=True, batch_size=4)

In [22]:
from dataclasses import dataclass
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
from typing import Optional, Union
import torch


@dataclass
class DataCollatorForMultipleChoice:
    """
    Data collator that will dynamically pad the inputs for multiple choice received.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature.pop(label_name) for feature in features]
        batch_size = len(features)
        num_choices = len(features[0]["input_ids"])
        flattened_features = [
            [{k: v[i] for k, v in feature.items()} for i in range(num_choices)] for feature in features
        ]
        flattened_features = sum(flattened_features, [])

        batch = self.tokenizer.pad(
            flattened_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
        batch["labels"] = torch.tensor(labels, dtype=torch.int64)
        return batch

In [23]:
import evaluate

accuracy = evaluate.load("accuracy")

In [24]:
import numpy as np


def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return accuracy.compute(predictions=predictions, references=labels)

In [25]:
from transformers import AutoModelForMultipleChoice, TrainingArguments, Trainer

model = AutoModelForMultipleChoice.from_pretrained("bert-base-uncased")

Some weights of BertForMultipleChoice were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [29]:
training_args = TrainingArguments(
    output_dir="my_awesome_swag_model",
    evaluation_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    learning_rate=5e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=3,
    weight_decay=0.01,
    push_to_hub=False,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_swag,
    eval_dataset=tokenized_swag,
    tokenizer=tokenizer,
    data_collator=DataCollatorForMultipleChoice(tokenizer=tokenizer),
    compute_metrics=compute_metrics,
)

trainer.train()

You're using a DistilBertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


Epoch,Training Loss,Validation Loss


RuntimeError: MPS backend out of memory (MPS allocated: 6.02 GB, other allocations: 12.10 GB, max allowed: 18.13 GB). Tried to allocate 14.44 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).

In [27]:
tokenized_swag

Dataset({
    features: ['video-id', 'fold-ind', 'startphrase', 'sent1', 'sent2', 'gold-source', 'ending0', 'ending1', 'ending2', 'ending3', 'label', 'input_ids', 'attention_mask'],
    num_rows: 1000
})