In [1]:
from datasets import load_dataset

# Load the "regular" configuration of the SWAG dataset (all of its splits).
swag = load_dataset("swag", "regular")

Found cached dataset swag (C:/Users/arifa/.cache/huggingface/datasets/swag/regular/0.0.0/9640de08cdba6a1469ed3834fcab4b8ad8e38caf5d1ba5e7436d8b1fd067ad4c)


  0%|          | 0/3 [00:00<?, ?it/s]

In [2]:
# Peek at the first training example to see the raw fields of one record.
swag["train"][0]

{'video-id': 'anetv_jkn6uvmqwh4',
 'fold-ind': '3416',
 'startphrase': 'Members of the procession walk down the street holding small horn brass instruments. A drum line',
 'sent1': 'Members of the procession walk down the street holding small horn brass instruments.',
 'sent2': 'A drum line',
 'gold-source': 'gold',
 'ending0': 'passes by walking down the street playing their instruments.',
 'ending1': 'has heard approaching them.',
 'ending2': "arrives and they're outside dancing and asleep.",
 'ending3': 'turns the lead singer watches the performance.',
 'label': 0}

In [3]:
from transformers import AutoTokenizer

# Tokenizer for the "bert-base-uncased" checkpoint.
tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")

In [4]:
# Columns holding the candidate endings; their count is the number of choices per example.
ending_names = ["ending0", "ending1", "ending2", "ending3"]


def preprocess_function(examples):
    """Tokenize a batch of examples as (context, candidate continuation) pairs.

    For every example the context (``sent1``) is paired with each candidate
    continuation, which is the header (``sent2``) followed by a space and one
    of the ending columns listed in ``ending_names``. All pairs are tokenized in
    a single call and then regrouped per example.

    Args:
        examples: A batch as a dict of equal-length lists, containing at least
            ``"sent1"``, ``"sent2"`` and every column named in ``ending_names``.

    Returns:
        A dict with the same keys as the tokenizer output. Each value is a list
        with one entry per example, and each entry is a list holding
        ``len(ending_names)`` tokenized choices.
    """
    num_choices = len(ending_names)
    question_headers = examples["sent2"]

    # Build the flat, pair-aligned lists directly. The previous sum(list, [])
    # flattening re-copied the accumulator on every step (quadratic time).
    first_sentences = [context for context in examples["sent1"] for _ in range(num_choices)]
    second_sentences = [
        f"{header} {examples[end][i]}"
        for i, header in enumerate(question_headers)
        for end in ending_names
    ]

    tokenized_examples = tokenizer(first_sentences, second_sentences, truncation=True)

    # Regroup the flat tokenizer output into chunks of num_choices per example.
    return {
        k: [v[i : i + num_choices] for i in range(0, len(v), num_choices)]
        for k, v in tokenized_examples.items()
    }

In [5]:
# Run preprocess_function over the dataset; batched=True hands it dicts of column lists
# (the function indexes examples["sent1"] etc. as lists).
tokenized_swag = swag.map(preprocess_function, batched=True)

Map:   0%|          | 0/73546 [00:00<?, ? examples/s]

Map:   0%|          | 0/20006 [00:00<?, ? examples/s]

Map:   0%|          | 0/20005 [00:00<?, ? examples/s]

In [6]:
# Run the preprocessing on a one-example slice to inspect the nested structure it returns.
temp = preprocess_function(swag["train"][:1])
temp

{'input_ids': [[[101,
    2372,
    1997,
    1996,
    14385,
    3328,
    2091,
    1996,
    2395,
    3173,
    2235,
    7109,
    8782,
    5693,
    1012,
    102,
    1037,
    6943,
    2240,
    5235,
    2011,
    3788,
    2091,
    1996,
    2395,
    2652,
    2037,
    5693,
    1012,
    102],
   [101,
    2372,
    1997,
    1996,
    14385,
    3328,
    2091,
    1996,
    2395,
    3173,
    2235,
    7109,
    8782,
    5693,
    1012,
    102,
    1037,
    6943,
    2240,
    2038,
    2657,
    8455,
    2068,
    1012,
    102],
   [101,
    2372,
    1997,
    1996,
    14385,
    3328,
    2091,
    1996,
    2395,
    3173,
    2235,
    7109,
    8782,
    5693,
    1012,
    102,
    1037,
    6943,
    2240,
    8480,
    1998,
    2027,
    1005,
    2128,
    2648,
    5613,
    1998,
    6680,
    1012,
    102],
   [101,
    2372,
    1997,
    1996,
    14385,
    3328,
    2091,
    1996,
    2395,
    3173,
    2235,
    7109,
    8782,
    5693,


In [7]:
# The same one-example slice before preprocessing: slicing yields a dict of length-1 lists.
swag["train"][:1]

{'video-id': ['anetv_jkn6uvmqwh4'],
 'fold-ind': ['3416'],
 'startphrase': ['Members of the procession walk down the street holding small horn brass instruments. A drum line'],
 'sent1': ['Members of the procession walk down the street holding small horn brass instruments.'],
 'sent2': ['A drum line'],
 'gold-source': ['gold'],
 'ending0': ['passes by walking down the street playing their instruments.'],
 'ending1': ['has heard approaching them.'],
 'ending2': ["arrives and they're outside dancing and asleep."],
 'ending3': ['turns the lead singer watches the performance.'],
 'label': [0]}

In [8]:
# Decode each tokenized choice of the first example back to text to check the pairing.
for chunk in temp["input_ids"][0]:
    decoded = tokenizer.decode(chunk)
    print(decoded)

[CLS] members of the procession walk down the street holding small horn brass instruments. [SEP] a drum line passes by walking down the street playing their instruments. [SEP]
[CLS] members of the procession walk down the street holding small horn brass instruments. [SEP] a drum line has heard approaching them. [SEP]
[CLS] members of the procession walk down the street holding small horn brass instruments. [SEP] a drum line arrives and they're outside dancing and asleep. [SEP]
[CLS] members of the procession walk down the street holding small horn brass instruments. [SEP] a drum line turns the lead singer watches the performance. [SEP]


In [9]:
# First training example after mapping: the original fields plus the tokenizer outputs.
tokenized_swag["train"][0]

{'video-id': 'anetv_jkn6uvmqwh4',
 'fold-ind': '3416',
 'startphrase': 'Members of the procession walk down the street holding small horn brass instruments. A drum line',
 'sent1': 'Members of the procession walk down the street holding small horn brass instruments.',
 'sent2': 'A drum line',
 'gold-source': 'gold',
 'ending0': 'passes by walking down the street playing their instruments.',
 'ending1': 'has heard approaching them.',
 'ending2': "arrives and they're outside dancing and asleep.",
 'ending3': 'turns the lead singer watches the performance.',
 'label': 0,
 'input_ids': [[101,
   2372,
   1997,
   1996,
   14385,
   3328,
   2091,
   1996,
   2395,
   3173,
   2235,
   7109,
   8782,
   5693,
   1012,
   102,
   1037,
   6943,
   2240,
   5235,
   2011,
   3788,
   2091,
   1996,
   2395,
   2652,
   2037,
   5693,
   1012,
   102],
  [101,
   2372,
   1997,
   1996,
   14385,
   3328,
   2091,
   1996,
   2395,
   3173,
   2235,
   7109,
   8782,
   5693,
   1012,
   102,


In [10]:
from dataclasses import dataclass
from transformers.tokenization_utils_base import PreTrainedTokenizerBase, PaddingStrategy
from typing import Optional, Union
import torch


@dataclass
class DataCollatorForMultipleChoice:
    """
    Data collator that dynamically pads multiple-choice inputs.

    Each incoming feature is a dict whose non-label values hold one entry per
    answer choice (``input_ids`` is indexed per choice to determine the number
    of choices). The choices of all features are flattened into one list,
    padded together by ``tokenizer.pad`` and reshaped to
    ``(batch_size, num_choices, -1)``.

    Attributes:
        tokenizer: Tokenizer whose ``pad`` method performs the padding.
        padding: Forwarded to ``tokenizer.pad`` as ``padding``.
        max_length: Forwarded to ``tokenizer.pad`` as ``max_length``.
        pad_to_multiple_of: Forwarded to ``tokenizer.pad`` as ``pad_to_multiple_of``.
    """

    tokenizer: PreTrainedTokenizerBase
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None

    def __call__(self, features):
        """Collate a list of feature dicts into one padded batch.

        Args:
            features: Non-empty list of dicts. Each must contain ``"input_ids"``
                and a label under ``"label"`` or ``"labels"``; every other value
                must be indexable by choice position.

        Returns:
            Dict of tensors: each padded field viewed as
            ``(batch_size, num_choices, -1)``, plus ``"labels"`` as an int64
            tensor built from the per-feature labels.
        """
        label_name = "label" if "label" in features[0].keys() else "labels"
        # Read the labels instead of popping them, so the caller's dicts are left
        # intact and the same list of features can be collated more than once.
        labels = [feature[label_name] for feature in features]
        batch_size = len(features)
        num_choices = len(features[0]["input_ids"])

        # One dict per (feature, choice) pair, built in a single pass; the label is
        # skipped here because it is per-feature, not per-choice.
        flattened_features = [
            {k: v[i] for k, v in feature.items() if k != label_name}
            for feature in features
            for i in range(num_choices)
        ]

        batch = self.tokenizer.pad(
            flattened_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors="pt",
        )

        # Restore the (batch, choice) leading dimensions that flattening removed.
        batch = {k: v.view(batch_size, num_choices, -1) for k, v in batch.items()}
        batch["labels"] = torch.tensor(labels, dtype=torch.int64)
        return batch

In [11]:
# Collator instance that pads with the notebook's tokenizer (other options keep their defaults).
data_collator = DataCollatorForMultipleChoice(tokenizer=tokenizer)

In [12]:
# Show the splits, columns and row counts after tokenization.
tokenized_swag

DatasetDict({
    train: Dataset({
        features: ['video-id', 'fold-ind', 'startphrase', 'sent1', 'sent2', 'gold-source', 'ending0', 'ending1', 'ending2', 'ending3', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 73546
    })
    validation: Dataset({
        features: ['video-id', 'fold-ind', 'startphrase', 'sent1', 'sent2', 'gold-source', 'ending0', 'ending1', 'ending2', 'ending3', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 20006
    })
    test: Dataset({
        features: ['video-id', 'fold-ind', 'startphrase', 'sent1', 'sent2', 'gold-source', 'ending0', 'ending1', 'ending2', 'ending3', 'label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 20005
    })
})

In [21]:
# NOTE(review): the removal below is commented out, yet the recorded output of this cell lists
# only 'label' and the tokenizer columns, so it was evidently executed once before being disabled
# (re-running it would fail once the columns are gone). On a fresh kernel this cell removes
# nothing and tokenized_swag keeps its text columns.
# tokenized_swag = tokenized_swag.remove_columns(['video-id', 'fold-ind', 'startphrase', 'sent1', 'sent2', 'gold-source', 'ending0', 'ending1', 'ending2', 'ending3'])
tokenized_swag

DatasetDict({
    train: Dataset({
        features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 73546
    })
    validation: Dataset({
        features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 20006
    })
    test: Dataset({
        features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
        num_rows: 20005
    })
})

In [23]:
# Switch the dataset's output format to "torch"; the return value is not used, so this
# presumably changes tokenized_swag in place (confirm against the datasets docs).
tokenized_swag.set_format("torch")

In [13]:
# Training split without the raw text/metadata columns, used to try out the collator.
temp_data = tokenized_swag["train"].remove_columns(
    [
        "video-id",
        "fold-ind",
        "startphrase",
        "sent1",
        "sent2",
        "gold-source",
        "ending0",
        "ending1",
        "ending2",
        "ending3",
    ]
)
temp_data

Dataset({
    features: ['label', 'input_ids', 'token_type_ids', 'attention_mask'],
    num_rows: 73546
})

In [14]:
# Collate a one-example batch to inspect the padded tensors.
samples = [temp_data[0]]
data_collator(samples)

{'input_ids': tensor([[[  101,  2372,  1997,  1996, 14385,  3328,  2091,  1996,  2395,  3173,
            2235,  7109,  8782,  5693,  1012,   102,  1037,  6943,  2240,  5235,
            2011,  3788,  2091,  1996,  2395,  2652,  2037,  5693,  1012,   102],
          [  101,  2372,  1997,  1996, 14385,  3328,  2091,  1996,  2395,  3173,
            2235,  7109,  8782,  5693,  1012,   102,  1037,  6943,  2240,  2038,
            2657,  8455,  2068,  1012,   102,     0,     0,     0,     0,     0],
          [  101,  2372,  1997,  1996, 14385,  3328,  2091,  1996,  2395,  3173,
            2235,  7109,  8782,  5693,  1012,   102,  1037,  6943,  2240,  8480,
            1998,  2027,  1005,  2128,  2648,  5613,  1998,  6680,  1012,   102],
          [  101,  2372,  1997,  1996, 14385,  3328,  2091,  1996,  2395,  3173,
            2235,  7109,  8782,  5693,  1012,   102,  1037,  6943,  2240,  4332,
            1996,  2599,  3220, 12197,  1996,  2836,  1012,   102,     0,     0]]]),
 'token_

In [15]:
from transformers import AutoModelForMultipleChoice, TrainingArguments, Trainer

# Load "bert-base-uncased" as a multiple-choice model; the log printed below lists the
# checkpoint weights that are unused and the ones that are newly initialized.
model = AutoModelForMultipleChoice.from_pretrained("bert-base-uncased")

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMultipleChoice: ['cls.predictions.transform.dense.weight', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.seq_relationship.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.bias', 'cls.predictions.bias']
- This IS expected if you are initializing BertForMultipleChoice from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMultipleChoice from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForMultipleChoice were not initialized from the model checkpoint at bert-base-uncased and are newly

In [24]:
from torch.utils.data import DataLoader
batch_size = 32

# Columns the collator and model consume: the label plus the tokenizer outputs.
model_columns = {"label", "input_ids", "token_type_ids", "attention_mask"}


def _model_inputs_only(split):
    """Return `split` without any column that is not in `model_columns`.

    The collator indexes every non-label field per choice and asks the tokenizer
    to build tensors from the result, so leftover text columns (video-id, sent1,
    ...) must not reach it. Filtering here keeps this cell working on a fresh
    kernel, whether or not an earlier cell already removed those columns.
    """
    extra = [name for name in split.column_names if name not in model_columns]
    return split.remove_columns(extra) if extra else split


train_dataloader = DataLoader(
    _model_inputs_only(tokenized_swag["train"]),
    shuffle=True,
    batch_size=batch_size,
    collate_fn=data_collator,
)
# Evaluate on the validation split. NOTE(review): the SWAG test split is, as far as I know,
# distributed without gold labels (label = -1), which would break loss/accuracy computation;
# confirm against the dataset card.
eval_dataloader = DataLoader(
    _model_inputs_only(tokenized_swag["validation"]),
    batch_size=batch_size,
    collate_fn=data_collator,
)

In [25]:
# Take the first batch from the (shuffled) training loader and print each tensor's shape.
# The loop variable `batch` remains defined afterwards and is reused by the forward-pass cell.
for batch in train_dataloader:
    print({k: v.shape for k, v in batch.items()})
    break

{'input_ids': torch.Size([32, 4, 51]), 'token_type_ids': torch.Size([32, 4, 51]), 'attention_mask': torch.Size([32, 4, 51]), 'labels': torch.Size([32])}


In [26]:
import torch
# Sanity-check forward pass on the sampled batch; gradients are not needed for this.
with torch.no_grad():
    outputs = model(**batch)

print(outputs.loss, outputs.logits.shape)

tensor(1.4142) torch.Size([32, 4])
