In [59]:
import torch
from torch import nn
from datasets import load_dataset
from transformers import BertTokenizerFast
import functools

In [56]:
class Head(nn.Module):
    """One head of causal (masked) self-attention.

    Reads the notebook-level globals ``n_embd`` (input feature size),
    ``block_size`` (side length of the pre-built causal mask) and ``dropout``
    (drop probability applied to the attention weights) at construction time.

    Args:
        head_size: output feature size of the key/query/value projections.
    """

    def __init__(self, head_size):
        super().__init__()
        self.key = nn.Linear(n_embd, head_size, bias=False)
        self.query = nn.Linear(n_embd, head_size, bias=False)
        self.value = nn.Linear(n_embd, head_size, bias=False)
        # Lower-triangular mask; registered as a buffer so it follows the
        # module across devices without becoming a trainable parameter.
        self.register_buffer('tril', torch.tril(torch.ones(block_size, block_size)))

        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        """Apply masked self-attention.

        Args:
            x: tensor of shape (B, T, n_embd). T must not exceed ``block_size``,
               otherwise the sliced mask no longer matches the score matrix.

        Returns:
            Tensor of shape (B, T, head_size).
        """
        _, T, _ = x.shape
        k = self.key(x)   # (B, T, head_size)
        q = self.query(x) # (B, T, head_size)
        # Attention scores ("affinities"), scaled by 1/sqrt(d_k) where d_k is the
        # key dimension (head_size) -- not the input embedding width.
        wei = (q @ k.transpose(-2, -1)) * k.shape[-1] ** -0.5 # (B, T, hs) @ (B, hs, T) -> (B, T, T)
        # Block attention to future positions.
        wei = wei.masked_fill(self.tril[:T, :T] == 0, float('-inf')) # (B, T, T)
        # torch.softmax avoids depending on an `F` alias that the notebook never imports.
        wei = torch.softmax(wei, dim=-1) # (B, T, T)
        wei = self.dropout(wei)
        # Weighted aggregation of the values.
        v = self.value(x) # (B, T, head_size)
        out = wei @ v # (B, T, T) @ (B, T, hs) -> (B, T, hs)
        return out

def tokenize_function(tokenizer, examples):
    """Tokenize a batch and attach per-example bookkeeping columns.

    Args:
        tokenizer: callable applied to ``examples["text"]``; its ``is_fast``
            attribute decides whether word ids are collected.
        examples: mapping with a ``"text"`` entry.

    Returns:
        The tokenizer output, extended with ``"num_tokens"`` (length of each
        ``input_ids`` entry) and, for fast tokenizers, ``"word_ids"``.
    """
    encoded = tokenizer(examples["text"])

    if tokenizer.is_fast:
        # word_ids maps each token to the index of the source word it came from
        example_count = len(encoded["input_ids"])
        encoded["word_ids"] = list(map(encoded.word_ids, range(example_count)))

    encoded["num_tokens"] = list(map(len, encoded["input_ids"]))
    return encoded


def tokenize_and_chunk(examples, max_length=23):
    """Tokenize a batch, splitting long texts into several fixed-length rows.

    Uses the notebook-level ``tokenizer``. Because overflowing tokens are
    returned as extra rows, the output can have more rows than the input; every
    input column is therefore repeated so it lines up with the new rows.

    NOTE: this name is defined again further down the notebook (with
    max_length=32 and without word ids); the later definition wins once its
    cell has run.

    Args:
        examples: mapping of column name -> list of values, with a ``"text"`` column.
        max_length: maximum number of tokens per output row (default 23, the
            value previously hard-coded here).

    Returns:
        The tokenizer output with ``"overflow_to_sample_mapping"`` removed, the
        input columns re-aligned to the chunked rows, and ``"word_ids"`` added
        when the tokenizer is a fast one.
    """
    result = tokenizer(
        examples["text"],
        truncation=True,
        max_length=max_length,
        padding=True,
        return_overflowing_tokens=True,
    )
    if tokenizer.is_fast:
        # word_ids maps each token to the index of the source word it came from
        result["word_ids"] = [result.word_ids(i) for i in range(len(result["input_ids"]))]

    # For each output row, the index of the input example it was cut from.
    sample_map = result.pop("overflow_to_sample_mapping")
    for key, values in examples.items():
        result[key] = [values[i] for i in sample_map]
    # The original fell off the end here and implicitly returned None.
    return result
        
        
def group_texts(examples):
    """Concatenate every column of a batch and re-split it into equal chunks.

    Reads the notebook-level global ``chunk_size``. Each column is flattened
    into one long list and cut into consecutive pieces of ``chunk_size`` items;
    a trailing remainder shorter than ``chunk_size`` is dropped.

    Args:
        examples: mapping of column name -> list of lists. Must contain an
            ``"input_ids"`` column.

    Returns:
        Dict with the same columns, chunked, plus a ``"labels"`` column holding
        an independent copy of the chunked ``"input_ids"``.
    """
    # Flatten each column in a single pass; sum(lists, []) re-copies the
    # accumulated list on every addition and is quadratic in the batch size.
    concatenated_examples = {
        k: [item for sequence in examples[k] for item in sequence]
        for k in examples.keys()
    }
    # Length of the concatenated texts, taken from the first column.
    total_length = len(concatenated_examples[list(examples.keys())[0]])
    # We drop the last chunk if it's smaller than chunk_size
    total_length = (total_length // chunk_size) * chunk_size
    # Split by chunks of chunk_size
    result = {
        k: [t[i : i + chunk_size] for i in range(0, total_length, chunk_size)]
        for k, t in concatenated_examples.items()
    }
    # New labels column. Copy each chunk so that editing labels in place
    # cannot silently change input_ids (a plain .copy() shares the inner lists).
    result["labels"] = [list(chunk) for chunk in result["input_ids"]]
    return result



In [121]:
# Work on a 100k-row slice of the BookCorpus train split.
dataset = load_dataset("bookcorpus", split="train").select(range(100000,200000))
# Add a column with the character length of every text.
dataset = dataset.map(lambda samples: {"text_length": [len(text) for text in samples["text"]]}, batched=True)
# Longest texts first, so the first rows are long examples.
dataset = dataset.sort("text_length", reverse=True)

# The cells below call `tokenizer` directly; it must be defined here so the
# notebook works on a fresh kernel instead of relying on leftover state.
tokenizer = BertTokenizerFast.from_pretrained("bert-base-uncased")
# Full-dataset tokenization is left disabled; nothing below reads `tokenized_dataset`.
#tokenized_dataset = dataset.map(functools.partial(tokenize_function, tokenizer), batched=True)


In [122]:
def tokenize_and_chunk(examples):
    """Tokenize a batch into rows of at most 32 tokens, keeping columns aligned.

    Relies on the notebook-level ``tokenizer``. Texts longer than the limit
    spill over into additional rows, so each input column is repeated to match
    the row it was cut from.

    Args:
        examples: mapping of column name -> list of values, with a ``"text"`` column.

    Returns:
        The tokenizer output without ``"overflow_to_sample_mapping"`` and with
        every input column expanded to one value per output row.
    """
    encoded = tokenizer(
        examples["text"],
        return_overflowing_tokens=True,
        padding=True,
        max_length=32,
        truncation=True,
    )

    # Index of the source example for every produced row.
    origin = encoded.pop("overflow_to_sample_mapping")

    for column, column_values in examples.items():
        encoded[column] = list(map(column_values.__getitem__, origin))

    return encoded
    
    
# Try the chunking helper on the first two rows of `dataset`.
result = tokenize_and_chunk(dataset[0:2])

In [123]:
# Raw text of the same two rows, for comparison with the tokenized output.
dataset[0:2]["text"]

["they were also generous with their wealth , but that was all done secretly , making sure that the people who had taken care of them as kids were now well cared for , including the cop who had released dominic for stealing some bread one day , or the teacher who had smacked the back of zayn 's head when he was n't concentrating on his math , or the neighborhood lady who would sit on her front stoop knitting hats for each of them so they were a little warmer on those freezing cold , winter days .",
 "third , that girl is desperate to talk about her business with someone , and since she 's coming to me that tells me that either you 're not her fiance and she does n't want to talk to you about anything , or you are her fiance but you have no business sense , which is probably the case because you have n't put a ring on her finger , or you 're a complete fuck-up who does n't care about the thing that 's most important to his future wife ."]

In [125]:
# Display every entry of `result` via an open-ended slice.
result[0:]

{'input_ids': [[101, 2027, 2020, 2036, 12382, 2007, 2037, 7177, 1010, 2021, 2008, 2001, 2035, 2589, 10082, 1010, 2437, 2469, 2008, 1996, 2111, 2040, 2018, 2579, 2729, 1997, 2068, 2004, 4268, 2020, 2085, 102], [101, 2092, 8725, 2005, 1010, 2164, 1996, 8872, 2040, 2018, 2207, 11282, 2005, 11065, 2070, 7852, 2028, 2154, 1010, 2030, 1996, 3836, 2040, 2018, 19203, 1996, 2067, 1997, 23564, 6038, 1005, 102], [101, 1055, 2132, 2043, 2002, 2001, 1050, 1005, 1056, 16966, 2006, 2010, 8785, 1010, 2030, 1996, 5101, 3203, 2040, 2052, 4133, 2006, 2014, 2392, 2358, 18589, 26098, 16717, 2005, 2169, 1997, 102], [101, 2068, 2061, 2027, 2020, 1037, 2210, 16676, 2006, 2216, 12809, 3147, 1010, 3467, 2420, 1012, 102, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [101, 2353, 1010, 2008, 2611, 2003, 7143, 2000, 2831, 2055, 2014, 2449, 2007, 2619, 1010, 1998, 2144, 2016, 1005, 1055, 2746, 2000, 2033, 2008, 4136, 2033, 2008, 2593, 2017, 1005, 2128, 102], [101, 2025, 2014, 19154, 1998, 2016, 2515, 1050, 1005, 105

In [100]:
# Number of token ids in the second row of `result`.
# NOTE(review): the saved output (23) matches max_length=23 from the earlier
# tokenize_and_chunk, not the later max_length=32 version -- presumably stale; re-run.
len(result["input_ids"][1])

23

In [101]:
# Turn the second row of token ids back into text; the saved output shows
# that [CLS]/[SEP]/[PAD] markers are kept in the decoded string.
tokenizer.decode(result["input_ids"][1])

"[CLS] another man ' s child? [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]"

In [57]:
# chunk_size is read as a global by group_texts.
chunk_size = 128 # BERT is trained with 128 long sequences for the first 90% steps, and then with 512
# Tokenize the first two rows with tokenize_function (no truncation/padding arguments are passed there).
result = tokenize_function(tokenizer, dataset[0:2])

In [58]:
# Inspect the full encoding, including the word_ids and num_tokens columns added by tokenize_function.
result

{'input_ids': [[101, 2788, 1010, 2002, 2052, 2022, 13311, 2105, 1996, 2542, 2282, 1010, 2652, 2007, 2010, 10899, 1012, 102], [101, 2021, 2074, 2028, 2298, 2012, 1037, 7163, 2239, 2741, 2032, 8134, 4937, 22436, 2594, 1012, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], 'word_ids': [[None, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, None], [None, 0, 1, 2, 3, 4, 5, 6, 6, 7, 8, 9, 10, 10, 10, 11, None]], 'num_tokens': [18, 17]}

In [55]:
# Summary of the dataset object (features and row count).
# NOTE(review): the saved output reports 74004228 rows and only a 'text' feature,
# which does not match the 100k-row selection with 'text_length' built above --
# presumably from an earlier run; re-run to refresh.
dataset

Dataset({
    features: ['text'],
    num_rows: 74004228
})

In [52]:
# Flatten the per-sentence id lists into one list -- the same concatenation
# group_texts applies to every column.
sum(result["input_ids"], [])

[101,
 2788,
 1010,
 2002,
 2052,
 2022,
 13311,
 2105,
 1996,
 2542,
 2282,
 1010,
 2652,
 2007,
 2010,
 10899,
 1012,
 102,
 101,
 2021,
 2074,
 2028,
 2298,
 2012,
 1037,
 7163,
 2239,
 2741,
 2032,
 8134,
 4937,
 22436,
 2594,
 1012,
 102]

In [53]:
# The unflattened id lists, one list per sentence.
result["input_ids"]

[[101,
  2788,
  1010,
  2002,
  2052,
  2022,
  13311,
  2105,
  1996,
  2542,
  2282,
  1010,
  2652,
  2007,
  2010,
  10899,
  1012,
  102],
 [101,
  2021,
  2074,
  2028,
  2298,
  2012,
  1037,
  7163,
  2239,
  2741,
  2032,
  8134,
  4937,
  22436,
  2594,
  1012,
  102]]

In [13]:
# Call the tokenizer directly on the text of the first two rows, with default arguments.
result = tokenizer(dataset[0:2]["text"])

In [14]:
# Token ids only.
# NOTE(review): the saved output below shows the whole encoding (token_type_ids,
# attention_mask too), so it was presumably produced by a different expression; re-run.
result["input_ids"]

{'input_ids': [[101, 2788, 1010, 2002, 2052, 2022, 13311, 2105, 1996, 2542, 2282, 1010, 2652, 2007, 2010, 10899, 1012, 102], [101, 2021, 2074, 2028, 2298, 2012, 1037, 7163, 2239, 2741, 2032, 8134, 4937, 22436, 2594, 1012, 102]], 'token_type_ids': [[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]], 'attention_mask': [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1]]}

In [36]:
# Word index of every token in the second sentence. In the saved output, None
# sits at the [CLS]/[SEP] positions and repeated indices mark pieces of one word.
result.word_ids(1)

[None, 0, 1, 2, 3, 4, 5, 6, 6, 7, 8, 9, 10, 10, 10, 11, None]

In [35]:
# Decode the second sentence's ids back into a string.
tokenizer.decode(result["input_ids"][1])

'[CLS] but just one look at a minion sent him practically catatonic. [SEP]'

In [33]:
# Show the individual word-piece tokens behind the second sentence's ids.
tokenizer.convert_ids_to_tokens(result["input_ids"][1])

['[CLS]',
 'but',
 'just',
 'one',
 'look',
 'at',
 'a',
 'mini',
 '##on',
 'sent',
 'him',
 'practically',
 'cat',
 '##aton',
 '##ic',
 '.',
 '[SEP]']

In [34]:
# Tokenize the raw text of row 1 directly; per the saved output, no
# [CLS]/[SEP] markers are added on this path.
tokenizer.tokenize(dataset[1]["text"])

['but',
 'just',
 'one',
 'look',
 'at',
 'a',
 'mini',
 '##on',
 'sent',
 'him',
 'practically',
 'cat',
 '##aton',
 '##ic',
 '.']

In [None]:
tokenizer.