In [1]:
# Ask the `datasets` library to switch its caching off for this session.
# NOTE(review): newer `datasets` releases appear to provide `disable_caching()` in place of
# `set_caching_enabled` — confirm against the installed version.
from datasets import set_caching_enabled
set_caching_enabled(False)

import pprint
# Pretty-printer used for all inspection output in this notebook.
pp = pprint.PrettyPrinter(depth=6, compact=True)
# Caution: this rebinds the name `print` for every following cell. From here on `print(obj)`
# pretty-prints exactly one object and does not accept the built-in's `sep`/`end`/`file` arguments.
print = pp.pprint

In [2]:
from datasets import Dataset, DatasetDict

In [3]:
# Load the ROCStories CSV into a `Dataset`; the path is relative to the notebook's working directory.
dataset = Dataset.from_csv('../scripts/data/ROCStories_winter2017 - ROCStories_winter2017.csv')
# Bare expression so the notebook displays the dataset summary (columns and row count).
dataset

Using custom data configuration default-c9b127b229452d3c


Reusing dataset csv (/Users/lennartkeller/.cache/huggingface/datasets/csv/default-c9b127b229452d3c/0.0.0)


Dataset({
    features: ['storyid', 'storytitle', 'sentence1', 'sentence2', 'sentence3', 'sentence4', 'sentence5'],
    num_rows: 52665
})

In [4]:
# Number of rows (one row per story) in the loaded dataset.
len(dataset)

52665

In [5]:
# Inspect the first raw story; `print` is the pretty-printer bound in the first cell.
print(dataset[0])

{'sentence1': 'David noticed he had put on a lot of weight recently.',
 'sentence2': 'He examined his habits to try and figure out the reason.',
 'sentence3': "He realized he'd been eating too much fast food lately.",
 'sentence4': 'He stopped going to burger places and started a vegetarian '
              'diet.',
 'sentence5': 'After a few weeks, he started to feel much better.',
 'storyid': '8bbe6d11-1e2e-413c-bf81-eaea05f4f1bd',
 'storytitle': 'David Drops the Weight'}


In [6]:
from random import shuffle
from random import seed as set_seed

def make_shuffle_func(sep_token):
    """
    Build a batched ``Dataset.map`` function that shuffles the sentences of each story.

    Args:
        sep_token: Marker string placed in front of every sentence in the generated text.

    Returns:
        A function ``shuffle_stories(entries, seed=42)`` that takes a batch in
        column format (``{column_name: [value_per_row, ...]}``) and returns a dict
        with two new columns:

        * ``'text'``: the shuffled sentences joined as
          ``"<sep> s_a <sep> s_b ..."`` (every sentence is preceded by ``sep_token``).
        * ``'so_targets'``: for each position in ``text``, the 0-based index the
          sentence had in the original story.

    Notes:
        * Sentence columns are those named ``sentence<number>``; they are ordered by
          the full numeric suffix, so ``sentence10`` sorts after ``sentence9``.
          Columns that merely start with ``sentence`` but have no numeric suffix
          are ignored.
        * The random generator is re-created from ``seed`` on every call. When used
          with ``batched=True`` this means every batch goes through the same
          sequence of permutations (row i of each batch gets the same shuffle).
    """
    # Imported locally so the function is self-contained; a private generator
    # yields the same permutations as seeding the global one, without
    # resetting the global random state as a side effect.
    from random import Random

    prefix = 'sentence'

    def _is_sentence_column(name):
        # Accept only "sentence<digits>" so the numeric sort below cannot fail.
        return name.startswith(prefix) and name[len(prefix):].isdigit()

    def _sentence_number(name):
        # Use the whole numeric suffix, not just the last character.
        return int(name[len(prefix):])

    def shuffle_stories(entries, seed=42):
        rng = Random(seed)
        columns = list(entries)
        sentence_columns = sorted(
            (name for name in columns if _is_sentence_column(name)),
            key=_sentence_number,
        )

        texts = []
        so_targets = []
        # Walk over the batch row by row (the batch arrives column-wise).
        for values in zip(*entries.values()):
            row = dict(zip(columns, values))
            # Pair every sentence with its original position before shuffling.
            indexed_sents = list(
                enumerate(row[name] for name in sentence_columns)
            )
            rng.shuffle(indexed_sents)
            texts.append(
                f'{sep_token} ' + f' {sep_token} '.join(
                    sent for _, sent in indexed_sents
                )
            )
            so_targets.append([idx for idx, _ in indexed_sents])

        # An empty batch simply produces empty columns.
        return {'text': texts, 'so_targets': so_targets}
    return shuffle_stories

In [7]:
# Use the literal string '[CLS]' as the marker placed in front of every shuffled sentence.
map_func = make_shuffle_func('[CLS]')

In [8]:
# Apply the shuffling batch-wise; the returned 'text' and 'so_targets' columns are added
# next to the original columns (see the printed row in the next cell).
dataset = dataset.map(map_func, batched=True)

  0%|          | 0/53 [00:00<?, ?ba/s]

In [9]:
# Inspect the first story again to check the newly added 'text' and 'so_targets' columns.
print(dataset[0])

{'sentence1': 'David noticed he had put on a lot of weight recently.',
 'sentence2': 'He examined his habits to try and figure out the reason.',
 'sentence3': "He realized he'd been eating too much fast food lately.",
 'sentence4': 'He stopped going to burger places and started a vegetarian '
              'diet.',
 'sentence5': 'After a few weeks, he started to feel much better.',
 'so_targets': [3, 1, 2, 4, 0],
 'storyid': '8bbe6d11-1e2e-413c-bf81-eaea05f4f1bd',
 'storytitle': 'David Drops the Weight',
 'text': '[CLS] He stopped going to burger places and started a vegetarian '
         'diet. [CLS] He examined his habits to try and figure out the reason. '
         "[CLS] He realized he'd been eating too much fast food lately. [CLS] "
         'After a few weeks, he started to feel much better. [CLS] David '
         'noticed he had put on a lot of weight recently.'}


In [10]:
# First split: hold out 20% of the stories; the remaining 80% form the training set.
train_test = dataset.train_test_split(test_size=0.2, seed=42)

# Second split: divide the held-out part again (70% / 30%).
test_validation = train_test['test'].train_test_split(test_size=0.3, seed=42)

# Mind the renaming: the larger ('train') part of the second split becomes the final
# 'test' set and its smaller ('test') part becomes the final 'val' set.
dataset = DatasetDict({
    'train': train_test['train'],
    'test': test_validation['train'],
    'val': test_validation['test']})
dataset

DatasetDict({
    train: Dataset({
        features: ['sentence1', 'sentence2', 'sentence3', 'sentence4', 'sentence5', 'so_targets', 'storyid', 'storytitle', 'text'],
        num_rows: 42132
    })
    test: Dataset({
        features: ['sentence1', 'sentence2', 'sentence3', 'sentence4', 'sentence5', 'so_targets', 'storyid', 'storytitle', 'text'],
        num_rows: 7373
    })
    val: Dataset({
        features: ['sentence1', 'sentence2', 'sentence3', 'sentence4', 'sentence5', 'so_targets', 'storyid', 'storytitle', 'text'],
        num_rows: 3160
    })
})

In [11]:
# Write the split dataset to a 'rocstories' directory relative to the working directory.
# NOTE(review): the later `load_from_disk` call reads '../scripts/data/rocstories', not this
# path, and the next cell deletes this directory — confirm where the loaded copy comes from,
# otherwise a fresh "Restart & Run All" is not reproducible.
dataset.save_to_disk('rocstories')

In [12]:
# Shell command: recursively delete the local 'rocstories' directory (requires a POSIX `rm`).
! rm -r rocstories

In [13]:
from datasets import load_from_disk

# Reload a saved copy of the dataset and rebind `dataset` to it.
# NOTE(review): this path is not the one written by `dataset.save_to_disk('rocstories')`
# earlier in the notebook — confirm this copy was produced by the same pipeline.
dataset = load_from_disk('../scripts/data/rocstories')
# Show the column schema of the training split.
print(dataset['train'].features)

{'sentence1': Value(dtype='string', id=None),
 'sentence2': Value(dtype='string', id=None),
 'sentence3': Value(dtype='string', id=None),
 'sentence4': Value(dtype='string', id=None),
 'sentence5': Value(dtype='string', id=None),
 'so_targets': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None),
 'storyid': Value(dtype='string', id=None),
 'storytitle': Value(dtype='string', id=None),
 'text': Value(dtype='string', id=None)}


In [14]:
# Inspect the first training example of the reloaded dataset.
print(dataset['train'][0])

{'sentence1': 'Bob took his daughter canoeing on the river.',
 'sentence2': 'The entire trip was about three miles.',
 'sentence3': 'The current was very fast through the shoals.',
 'sentence4': 'They got hung up on a rock near the end.',
 'sentence5': 'They had a good time and pledged to go again.',
 'so_targets': [3, 4, 2, 0, 1],
 'storyid': '2696704e-60a8-44de-99ea-bed2dada2a68',
 'storytitle': 'Canoeing',
 'text': '[CLS] They got hung up on a rock near the end. [CLS] They had a good '
         'time and pledged to go again. [CLS] The current was very fast '
         'through the shoals. [CLS] Bob took his daughter canoeing on the '
         'river. [CLS] The entire trip was about three miles.'}


In [15]:
from transformers import AutoTokenizer

# Load the cased BERT tokenizer. `return_dict` is a model/config option and has no meaning
# for a tokenizer, so it is intentionally not passed here.
tokenizer = AutoTokenizer.from_pretrained('bert-base-cased')
# Quick sanity check on a single sentence with the tokenizer's default settings.
tokenized_text = tokenizer("Jimmy went down the road.")
print(tokenized_text)

{'attention_mask': [1, 1, 1, 1, 1, 1, 1, 1],
 'input_ids': [101, 4479, 1355, 1205, 1103, 1812, 119, 102],
 'token_type_ids': [0, 0, 0, 0, 0, 0, 0, 0]}


In [16]:
def make_tokenization_func(tokenizer, text_column, *args, **kwargs):
    """
    Create a ``Dataset.map``-compatible callable that tokenizes one column.

    Args:
        tokenizer: Callable invoked as ``tokenizer(texts, *args, **kwargs)``.
        text_column: Key of the entry/batch whose value is handed to the tokenizer.
        *args: Extra positional arguments forwarded to every tokenizer call.
        **kwargs: Extra keyword arguments forwarded to every tokenizer call.

    Returns:
        A one-argument function that looks up ``text_column`` in the given
        entry and returns whatever the tokenizer returns for it.
    """
    def tokenization(entry):
        # Pick the configured column, then delegate with the frozen extra arguments.
        texts = entry[text_column]
        encoded = tokenizer(texts, *args, **kwargs)
        return encoded

    return tokenization

# Configure the tokenization step applied to the shuffled 'text' column.
tokenization = make_tokenization_func(
    tokenizer=tokenizer,
    text_column="text",
    # Pad every example to one fixed length so all rows have equally long sequences.
    padding="max_length",
    # Over-long stories are cut off; this can drop trailing sentence markers.
    truncation=True,
    # The text already carries a literal '[CLS]' in front of every sentence
    # (inserted by the shuffle step), so no special tokens are added automatically.
    add_special_tokens=False,
    return_tensors='np'
)

# Tokenize every split batch-wise; the tokenizer outputs are added as new columns.
dataset = dataset.map(tokenization, batched=True)
print(dataset['train'][0].keys())

  0%|          | 0/43 [00:00<?, ?ba/s]

  0%|          | 0/8 [00:00<?, ?ba/s]

  0%|          | 0/4 [00:00<?, ?ba/s]

dict_keys(['attention_mask', 'input_ids', 'sentence1', 'sentence2', 'sentence3', 'sentence4', 'sentence5', 'so_targets', 'storyid', 'storytitle', 'text', 'token_type_ids'])


In [17]:
from torch.utils.data import DataLoader

def identity(batch):
    """Collate function that hands the batch back unchanged (no stacking, no padding)."""
    return batch

# Smoke test: with the pass-through collate function a batch is just the list of the
# individual example dicts, which lets us look at what the loader yields before collation.
data_loader = DataLoader(dataset['train'], batch_size=2, collate_fn=identity)
batch = next(iter(data_loader))
print(len(batch))
print(type(batch))
print(batch[0].keys())

2
<class 'list'>
dict_keys(['attention_mask', 'input_ids', 'sentence1', 'sentence2', 'sentence3', 'sentence4', 'sentence5', 'so_targets', 'storyid', 'storytitle', 'text', 'token_type_ids'])


In [18]:
# Drop the raw string columns; only the tokenizer outputs and 'so_targets' remain.
# The range is hard-coded for the five sentence columns (sentence1 .. sentence5).
dataset = dataset.remove_columns(
    ["text", "storyid", "storytitle"] + [f"sentence{i}" for i in range(1, 6)]
)
# Have the dataset return its values as torch tensors from now on.
dataset.set_format("torch")
print(dataset["train"].features)

{'attention_mask': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None),
 'input_ids': Sequence(feature=Value(dtype='int32', id=None), length=-1, id=None),
 'so_targets': Sequence(feature=Value(dtype='int64', id=None), length=-1, id=None),
 'token_type_ids': Sequence(feature=Value(dtype='int8', id=None), length=-1, id=None)}


In [19]:
from transformers import default_data_collator
from torch.nn.utils.rnn import pad_sequence

def so_data_collator(batch_entries, label_key='so_targets'):
    """
    Collate function that pads the (variable-length) label sequences of a batch.

    Args:
        batch_entries: List of per-example dicts. The entries are NOT modified.
        label_key: Every key that contains this string is treated as a label
            column (substring match, so e.g. ``'so_targets_extra'`` also counts).
            Label values must be 1-D tensors.

    Returns:
        The dict produced by ``default_data_collator`` for all non-label keys,
        plus one entry per label key holding the labels stacked batch-first and
        right-padded with ``-100``.
    """
    label_dicts = []
    feature_dicts = []

    # Split the labels from the rest to process them independently. Build new
    # dicts instead of popping from the inputs, so a dataset that returns the
    # same dict objects again (e.g. a plain list) still has its labels in the
    # next epoch.
    for entry in batch_entries:
        label_dicts.append(
            {key: value for key, value in entry.items() if label_key in key}
        )
        feature_dicts.append(
            {key: value for key, value in entry.items() if label_key not in key}
        )

    # Everything except the labels can be handled by the default collator.
    batch = default_data_collator(feature_dicts)

    # Labels may differ in length between examples, so pad them manually.
    # The label names are taken from the first example of the batch.
    for label in label_dicts[0]:
        batch[label] = pad_sequence(
            [label_dict[label] for label_dict in label_dicts],
            batch_first=True,
            padding_value=-100,
        )
    return batch

In [20]:
# Same smoke test as before, now with the label-padding collator: the batch is a dict of
# stacked tensors, with 'so_targets' collated alongside the tokenizer outputs.
data_loader = DataLoader(dataset['train'], batch_size=2, collate_fn=so_data_collator)
batch = next(iter(data_loader))
print(batch)

{'attention_mask': tensor([[1, 1, 1,  ..., 0, 0, 0],
        [1, 1, 1,  ..., 0, 0, 0]]),
 'input_ids': tensor([[ 101, 1220, 1400,  ...,    0,    0,    0],
        [ 101, 1109, 3336,  ...,    0,    0,    0]]),
 'so_targets': tensor([[3, 4, 2, 0, 1],
        [1, 0, 2, 3, 4]]),
 'token_type_ids': tensor([[0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])}


In [None]:
import torch
from torch import nn

def sentence_ordering_loss(batch_logits, batch_targets, batch_input_ids, target_token_id) -> torch.Tensor:
    """
    Regression loss for sentence ordering.

    For every instance, the logits at the positions of ``target_token_id`` are
    compared (summed squared error) with the original sentence indices; the
    per-instance sums are then averaged over the batch.

    Args:
        batch_logits: Per-token scores, indexed ``[instance, token]``.
            Assumes shape (batch, seq_len) or (batch, seq_len, 1) — TODO confirm
            against the model head.
        batch_targets: Sentence indices per instance, padded with ``-100``.
        batch_input_ids: Token ids, same (batch, seq_len) layout as the logits.
        target_token_id: Id of the token whose logit carries the prediction
            for the sentence that follows it.

    Returns:
        Scalar tensor with the batch loss, on the device and in the dtype of
        ``batch_logits``.
    """
    # Since we have a varying number of labels per instance, we need to compute the loss manually for each one.
    loss_fn = nn.MSELoss(reduction="sum")
    # Plain zero accumulator matching the logits' dtype and device; gradients flow
    # through the per-instance terms added below, so it needs no requires_grad.
    batch_loss = batch_logits.new_zeros(())
    for labels, logits, input_ids in zip(
        batch_targets, batch_logits, batch_input_ids
    ):
        # Firstly, we need to get the logits of each target token in the input sequence.
        target_logits = logits[input_ids == target_token_id].reshape(-1)

        # Secondly, we convert the sentence indices to regression targets
        # and remove the padding entries (-100).
        true_labels = labels[labels != -100].reshape(-1)
        targets = true_labels.to(device=target_logits.device, dtype=target_logits.dtype)

        # Sometimes, we will have fewer target_logits than targets due to truncation of the input.
        # In this case, we just consider as many targets as we have logits.
        if target_logits.size(0) < targets.size(0):
            targets = targets[: target_logits.size(0)]

        # Finally we compute the loss for the current instance (prediction first,
        # target second) and add it to the batch loss.
        batch_loss = batch_loss + loss_fn(target_logits, targets)

    # The final loss is obtained by averaging over the number of instances per batch.
    loss = batch_loss / batch_logits.size(0)

    return loss