In [None]:
from datasets import load_from_disk, load_dataset, DatasetDict
from transformers import TrainingArguments, Trainer
from transformers import DistilBertTokenizerFast, BertTokenizerFast
import contextGenerator
import numpy as np
import utils
import torch
import re
import evaluate


tokenizer = BertTokenizerFast.from_pretrained('distilbert-base-cased-distilled-squad')
model = DistilBertTokenizerFast.from_pretrained('distilbert-base-cased-distilled-squad')
contextGen = contextGenerator.LuceneRetrieval()
try:
    ds = load_from_disk('../res/data/qanta')
except:
    ds = load_dataset("community-datasets/qanta", "mode=first,char_skip=25")


The tokenizer class you load from this checkpoint is not the same type as the class this function is called from. It may result in unexpected tokenization. 
The tokenizer class you load from this checkpoint is 'DistilBertTokenizer'. 
The class this function is called from is 'BertTokenizerFast'.


# Preprocessing the data 
Given how BERT is a extractive model it will attempt to highlight its prediction in the provided context. In other words our task is to fine tune the model to predict the start and end positions of the answer in the context.  
#### 1. Retreive context
For each question we will need a relevent document where the answer may exist. 

In [8]:
ds = ds.map(lambda x: {'context':  contextGen(x['full_question'], 1)[0]})

Map: 100%|██████████| 96221/96221 [4:18:24<00:00,  6.21 examples/s]     
Map: 100%|██████████| 16706/16706 [14:50<00:00, 18.76 examples/s]
Map: 100%|██████████| 1055/1055 [01:13<00:00, 14.31 examples/s]
Map: 100%|██████████| 1161/1161 [01:21<00:00, 14.30 examples/s]
Map: 100%|██████████| 2151/2151 [02:44<00:00, 13.05 examples/s]
Map: 100%|██████████| 1953/1953 [02:21<00:00, 13.76 examples/s]
Map: 100%|██████████| 1145/1145 [01:07<00:00, 17.06 examples/s]


#### 2. Find the start and end postions
The contexts and questions are just strings to so we need to find the positions for the answers in the context. 

In [9]:
ds = ds.map(lambda x: {'char_pos':  utils.term_char_index(x['answer'], x['context']['contents'])})

Map: 100%|██████████| 96221/96221 [00:18<00:00, 5223.65 examples/s]
Map: 100%|██████████| 16706/16706 [00:03<00:00, 5426.48 examples/s]
Map: 100%|██████████| 1055/1055 [00:00<00:00, 5698.17 examples/s]
Map: 100%|██████████| 1161/1161 [00:00<00:00, 3495.97 examples/s]
Map: 100%|██████████| 2151/2151 [00:00<00:00, 5880.18 examples/s]
Map: 100%|██████████| 1953/1953 [00:00<00:00, 5915.87 examples/s]
Map: 100%|██████████| 1145/1145 [00:00<00:00, 3488.00 examples/s]


### 2. Tokenize context/question pair and find the token positions
Ensure the context comes first in the pair to align the character index with the token index. BERT limits the combined token count of context and question to 512. Since the context is capped at 400 words, this won’t cause issues, but we’ll use padding and truncation for consistency and edge cases.

In [68]:
unpack = lambda x, y, z: {"start_positions": x, "end_positions": y, "encodings": z}

def tokenize_row(row: dict, tokenizer):
    try: 
        encoding =  tokenizer(
            text = row['context']['contents'], 
            text_pair = row['full_question'], 
            padding = 'max_length', 
            truncation = 'only_first', 
            max_length = 512, 
            return_tensors = 'pt', 
            padding_side = 'right',
            return_length = True
            )
    except:
        cleaned = utils.clean_text(row['full_question'])
        encoding =  tokenizer(
            text = row['context']['contents'], 
            text_pair = cleaned, 
            padding = 'max_length', 
            truncation = 'only_first', 
            max_length = 512, 
            return_tensors = 'pt', 
            padding_side = 'right',
            return_length = True
            )
    start_pos = []
    end_pos = []
    # Convert the dictionary to a BatchEncoding object
    for (x, y) in row['char_pos']:
        start_pos.append(encoding.char_to_token(x))
        end_pos.append(encoding.char_to_token(y - 1))
    return start_pos, end_pos, encoding



ds = ds.map(lambda x: unpack(*tokenize_row(x, tokenizer)))

Map: 100%|██████████| 96221/96221 [02:51<00:00, 561.98 examples/s]
Map: 100%|██████████| 16706/16706 [00:29<00:00, 559.08 examples/s]
Map: 100%|██████████| 1055/1055 [00:01<00:00, 540.50 examples/s]
Map: 100%|██████████| 1161/1161 [00:02<00:00, 472.54 examples/s]
Map: 100%|██████████| 2151/2151 [00:03<00:00, 538.36 examples/s]
Map: 100%|██████████| 1953/1953 [00:03<00:00, 542.76 examples/s]
Map: 100%|██████████| 1145/1145 [00:01<00:00, 578.59 examples/s]


### Verify consistent length

In [None]:
# tokenizer.decode(ds[0]['encodings']["input_ids"][0][ds[0]['start_positions'][0]:ds[0]['end_positions'][0]+1], skip_special_tokens=True)
splits = ['guesstrain', 'guessdev', 'guesstrain']
count = []

for y in splits:
    for x in ds[y]:
        if x['encodings']['length'][0] != 512:
            count.append((x, y))
len(count)

0

In [87]:
guessTrain = DatasetDict({
    'train': ds['guesstrain'],
    'val': ds['guessdev'],
    'test': ds['guesstest'],

})
guessTrain = guessTrain.remove_columns(['qanta_id', 'proto_id', 'qdb_id', 'dataset', 'text', 'char_idx', 'sentence_idx', 'tokenizations', 'fold']) 

guessTrain.save_to_disk('../res/data/guessTrain')

Saving the dataset (4/4 shards): 100%|██████████| 96221/96221 [00:00<00:00, 109545.66 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 1055/1055 [00:00<00:00, 110052.50 examples/s]
Saving the dataset (1/1 shards): 100%|██████████| 2151/2151 [00:00<00:00, 119300.06 examples/s]


# Training

In [None]:
training_args = TrainingArguments(
    output_dir="your-model",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    num_train_epochs=2,
    weight_decay=0.01,
    eval_strategy="epoch",
    save_strategy="epoch",
    load_best_model_at_end=True,
    push_to_hub=True,
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=guessTrain["train"],
    eval_dataset=guessTrain["test"],
    processing_class=tokenizer,
    data_collator=data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()

NameError: name 'dataset' is not defined