In [1]:
!pip install datasets transformers

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting datasets
  Downloading datasets-2.2.2-py3-none-any.whl (346 kB)
[K     |████████████████████████████████| 346 kB 16.1 MB/s 
[?25hCollecting transformers
  Downloading transformers-4.19.4-py3-none-any.whl (4.2 MB)
[K     |████████████████████████████████| 4.2 MB 47.8 MB/s 
Collecting xxhash
  Downloading xxhash-3.0.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (212 kB)
[K     |████████████████████████████████| 212 kB 55.3 MB/s 
Collecting fsspec[http]>=2021.05.0
  Downloading fsspec-2022.5.0-py3-none-any.whl (140 kB)
[K     |████████████████████████████████| 140 kB 59.0 MB/s 
[?25hCollecting responses<0.19
  Downloading responses-0.18.0-py3-none-any.whl (38 kB)
Collecting dill<0.3.5
  Downloading dill-0.3.4-py2.py3-none-any.whl (86 kB)
[K     |████████████████████████████████| 86 kB 5.3 MB/s 
Collecting aiohttp
  Downloading aiohttp-3.8.1-cp37-cp37m-manylinu

In [2]:
from datasets import load_dataset
from transformers import BertTokenizerFast
import torch

import pandas as pd
from tqdm import tqdm

# Create datasets

In [3]:
# Download SQuAD v2 (later calls reuse the local Hugging Face cache, see the log below).
squadv2 = load_dataset("squad_v2")

Downloading builder script:   0%|          | 0.00/1.87k [00:00<?, ?B/s]

Downloading metadata:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

Downloading and preparing dataset squad_v2/squad_v2 (download: 44.34 MiB, generated: 122.41 MiB, post-processed: Unknown size, total: 166.75 MiB) to /root/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d...


Downloading data files:   0%|          | 0/2 [00:00<?, ?it/s]

Downloading data:   0%|          | 0.00/9.55M [00:00<?, ?B/s]

Downloading data:   0%|          | 0.00/801k [00:00<?, ?B/s]

Extracting data files:   0%|          | 0/2 [00:00<?, ?it/s]

Generating train split:   0%|          | 0/130319 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/11873 [00:00<?, ? examples/s]

Dataset squad_v2 downloaded and prepared to /root/.cache/huggingface/datasets/squad_v2/squad_v2/2.0.0/09187c73c1b837c95d9a249cd97c2c3f1cebada06efe667b4427714b27639b1d. Subsequent calls will reuse this data.


  0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
# Only train and validation splits are provided; the validation split is used
# as the test set further down.
squadv2

DatasetDict({
    train: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 130319
    })
    validation: Dataset({
        features: ['id', 'title', 'context', 'question', 'answers'],
        num_rows: 11873
    })
})

In [5]:
from transformers.utils.dummy_pt_objects import LayoutLMv2ForQuestionAnswering

def create_dataset(squad_data, split=None, seed=0):
    """Flatten SQuAD rows into column lists, optionally split by context.

    Every distinct ``(answer_start, text)`` pair of a row becomes one example,
    so rows whose ``answers`` lists are empty contribute nothing.

    Args:
        squad_data: Collection of rows that can be iterated twice. Each row is
            a mapping with ``'context'``, ``'question'`` and ``'answers'``
            keys, where ``row['answers']`` holds the parallel lists
            ``'answer_start'`` and ``'text'``.
        split: Falsy to build a single dataset. Otherwise the fraction of
            unique contexts (rounded down) that is held out for validation.
            Splitting on contexts keeps a passage from landing on both sides.
        seed: Seed used to shuffle the unique contexts before splitting, so
            the split is reproducible across kernel restarts. ``None`` seeds
            from system entropy. Ignored when ``split`` is falsy.

    Returns:
        When ``split`` is falsy, a dict of equally long lists keyed by
        ``'question'``, ``'context'``, ``'orig_answer'``, ``'answer_begin'``
        and ``'answer_end'`` (exclusive character offset). Otherwise a
        ``(train, valid)`` tuple of two such dicts.

    Raises:
        RuntimeError: If an answer text is not found at its recorded offset.
    """
    import random  # local import: not part of the notebook's import cell

    def empty_columns():
        return {
            'question': [],
            'context': [],
            'orig_answer': [],
            'answer_begin': [],
            'answer_end': [],
        }

    print("FIRST PASS")
    # dict.fromkeys de-duplicates while keeping first-seen order; iterating a
    # set of strings instead would differ from one interpreter run to the next.
    unique_contexts = list(dict.fromkeys(row["context"] for row in tqdm(squad_data)))

    if split:
        random.Random(seed).shuffle(unique_contexts)
        n_valid = int(split * len(unique_contexts))
        # Bucket 0 is validation, bucket 1 is train. A dict lookup replaces a
        # linear scan over every context for each answer.
        bucket_of = {
            context: (0 if position < n_valid else 1)
            for position, context in enumerate(unique_contexts)
        }
        buckets = [empty_columns(), empty_columns()]
    else:
        bucket_of = dict.fromkeys(unique_contexts, 0)
        buckets = [empty_columns()]

    print("SECOND PASS")
    for row in tqdm(squad_data):
        text = row['context']
        columns = buckets[bucket_of[text]]
        answers = row['answers']
        # A row may list the same span more than once; keep each
        # (start, text) pair once, in dataset order.
        unique_answers = dict.fromkeys(zip(answers['answer_start'], answers['text']))
        for start_idx, answer_text in unique_answers:
            end_idx = start_idx + len(answer_text)
            # Validate before appending so a failure never leaves the
            # columns with unequal lengths.
            if text[start_idx:end_idx] != answer_text:
                raise RuntimeError(
                    f"Answer {answer_text!r} does not match "
                    f"context[{start_idx}:{end_idx}] = {text[start_idx:end_idx]!r}"
                )
            columns['question'].append(row['question'])
            columns['context'].append(text)
            columns['orig_answer'].append(answer_text)
            columns['answer_begin'].append(start_idx)
            columns['answer_end'].append(end_idx)

    if not split:
        return buckets[0]
    valid_columns, train_columns = buckets
    return train_columns, valid_columns
# pd.DataFrame(data)

In [6]:
# Hold out 10% of the unique training contexts as a validation set.
train, valid = create_dataset(squadv2["train"], 0.1)
# The SQuAD v2 validation split is kept whole and used as the test set.
test = create_dataset(squadv2["validation"])

FIRST PASS


100%|██████████| 130319/130319 [00:20<00:00, 6400.06it/s]


SECOND PASS


100%|██████████| 130319/130319 [00:52<00:00, 2478.09it/s]


FIRST PASS


100%|██████████| 11873/11873 [00:01<00:00, 7516.01it/s]


SECOND PASS


100%|██████████| 11873/11873 [00:01<00:00, 6284.26it/s]


In [19]:
# Save the flattened training examples and display them.
# NOTE(review): to_csv writes the DataFrame index as an unnamed first column by
# default; pass index=False if readers of train.csv do not expect it.
train_df = pd.DataFrame(train)
train_df.to_csv('train.csv')
train_df

Unnamed: 0,question,context,orig_answer,answer_begin,answer_end
0,When did Beyonce start becoming popular?,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,in the late 1990s,269,286
1,What areas did Beyonce compete in when she was...,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,singing and dancing,207,226
2,When did Beyonce leave Destiny's Child and bec...,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,2003,526,530
3,In what city and state did Beyonce grow up?,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,"Houston, Texas",166,180
4,In which decade did Beyonce become famous?,Beyoncé Giselle Knowles-Carter (/biːˈjɒnseɪ/ b...,late 1990s,276,286
...,...,...,...,...,...
78130,In what US state did Kathmandu first establish...,"Kathmandu Metropolitan City (KMC), in order to...",Oregon,229,235
78131,What was Yangon previously known as?,"Kathmandu Metropolitan City (KMC), in order to...",Rangoon,414,421
78132,With what Belorussian city does Kathmandu have...,"Kathmandu Metropolitan City (KMC), in order to...",Minsk,476,481
78133,In what year did Kathmandu create its initial ...,"Kathmandu Metropolitan City (KMC), in order to...",1975,199,203


In [20]:
# Save the held-out validation examples and display them.
# NOTE(review): to_csv writes the DataFrame index as an unnamed first column by
# default; pass index=False if readers of valid.csv do not expect it.
valid_df = pd.DataFrame(valid)
valid_df.to_csv('valid.csv')
valid_df

Unnamed: 0,question,context,orig_answer,answer_begin,answer_end
0,In which year was reports about Beyonce perfor...,"In 2011, documents obtained by WikiLeaks revea...",2011,3,7
1,Who did Beyonce donate the money to earned fro...,"In 2011, documents obtained by WikiLeaks revea...",Clinton Bush Haiti Fund,367,390
2,Beyonce became the first female artist to perf...,"In 2011, documents obtained by WikiLeaks revea...",the 2011 Glastonbury Festival,486,515
3,Which organization did Beyonce's spokespeople ...,"In 2011, documents obtained by WikiLeaks revea...",The Huffington Post,313,332
4,Beyonce was listed in 2011 as the highest paid...,"In 2011, documents obtained by WikiLeaks revea...",minute,596,602
...,...,...,...,...,...
8681,In what century was Bhrikuti said to live?,Legendary Princess Bhrikuti (7th-century) and ...,7th,29,32
8682,When did Araniko die?,Legendary Princess Bhrikuti (7th-century) and ...,1306,69,73
8683,What religion did Araniko help to evangelize?,Legendary Princess Bhrikuti (7th-century) and ...,Buddhism,157,165
8684,How many Newar Buddhist monasteries are presen...,Legendary Princess Bhrikuti (7th-century) and ...,108,201,204


In [21]:
# Save the test examples and display them. A question can have several accepted
# answer spans; each span is its own row (e.g. rows 1 and 2 in the output below).
# NOTE(review): to_csv writes the DataFrame index as an unnamed first column by
# default; pass index=False if readers of test.csv do not expect it.
test_df = pd.DataFrame(test)
test_df.to_csv('test.csv')
test_df

Unnamed: 0,question,context,orig_answer,answer_begin,answer_end
0,In what country is Normandy located?,The Normans (Norman: Nourmands; French: Norman...,France,159,165
1,When were the Normans in Normandy?,The Normans (Norman: Nourmands; French: Norman...,in the 10th and 11th centuries,87,117
2,When were the Normans in Normandy?,The Normans (Norman: Nourmands; French: Norman...,10th and 11th centuries,94,117
3,From which countries did the Norse originate?,The Normans (Norman: Nourmands; French: Norman...,"Denmark, Iceland and Norway",256,283
4,Who was the Norse leader?,The Normans (Norman: Nourmands; French: Norman...,Rollo,308,313
...,...,...,...,...,...
10383,What is a very seldom used unit of mass in the...,"The pound-force has a metric counterpart, less...",slug,274,278
10384,What is a very seldom used unit of mass in the...,"The pound-force has a metric counterpart, less...",the metric slug,263,278
10385,What is a very seldom used unit of mass in the...,"The pound-force has a metric counterpart, less...",metric slug,267,278
10386,What seldom used term of a unit of force equal...,"The pound-force has a metric counterpart, less...",kip,712,715


# Tokenize datasets, add token-level answer labels and build dataloaders

In [13]:
tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')


def _encode_split(split_data):
    """Tokenize one split's (context, question) pairs, padded/truncated to 512 tokens."""
    return tokenizer(
        split_data['context'],
        split_data['question'],
        truncation=True,
        padding='max_length',
        max_length=512,
        return_tensors='pt',
    )


# Despite the *_tokenizer names, these hold the tokenizer's output for each split.
train_tokenizer = _encode_split(train)
valid_tokenizer = _encode_split(valid)
test_tokenizer = _encode_split(test)

Downloading:   0%|          | 0.00/28.0 [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/226k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/455k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/570 [00:00<?, ?B/s]

In [14]:
def add_token_positions(encodings, data, truncated_index=None):
  """Attach token-level answer labels to ``encodings`` in place.

  For every example the character offsets in ``data`` are converted to token
  indices with ``encodings.char_to_token`` and stored under the keys
  ``'start_positions'`` and ``'end_positions'``.

  Args:
    encodings: Tokenizer output for the same examples, in the same order as
      ``data``. It must provide ``char_to_token(batch_index, char_index)`` and
      ``update(mapping)``.
    data: Mapping with parallel ``'context'``, ``'answer_begin'`` and
      ``'answer_end'`` sequences, where ``'answer_end'`` is the exclusive
      character offset of the answer.
    truncated_index: Label assigned when a character has no token, which the
      original comments attribute to the answer passage being truncated.
      Defaults to ``tokenizer.model_max_length`` of the notebook-level
      ``tokenizer``.

  Returns:
    None. ``encodings`` is updated in place.
  """
  if truncated_index is None:
    # Same fallback value the function always used; only looked up when the
    # caller does not supply one.
    truncated_index = tokenizer.model_max_length

  start_positions = []
  end_positions = []

  for i in range(len(data['context'])):
    begin = data['answer_begin'][i]
    end = data['answer_end'][i]

    start_token = encodings.char_to_token(i, begin)
    # 'answer_end' is exclusive, so the last answer character sits at end - 1.
    # Looking up `end` itself returns the *following* token whenever the answer
    # is directly followed by a non-space character (e.g. punctuation), which
    # made the end label one token too far. max() guards an empty answer.
    end_token = encodings.char_to_token(i, max(end - 1, begin))

    start_positions.append(truncated_index if start_token is None else start_token)
    end_positions.append(truncated_index if end_token is None else end_token)

  # Update the data in dictionary
  encodings.update({'start_positions': start_positions, 'end_positions': end_positions})

In [15]:
# Add 'start_positions' / 'end_positions' token labels to each split's
# encodings; the encodings objects are updated in place.
add_token_positions(train_tokenizer, train)
add_token_positions(valid_tokenizer, valid)
add_token_positions(test_tokenizer, test)

In [18]:
# Sanity check: one row per training example, each padded/truncated to 512 tokens.
train_tokenizer.input_ids.shape

torch.Size([78135, 512])

In [None]:
class SquadDataset(torch.utils.data.Dataset):
    """Map-style dataset over tokenized question-answering encodings.

    ``encodings`` is a mapping from field name to an indexable batch of values
    (tensors or plain sequences). It must contain ``'input_ids'``, whose
    length defines the dataset size.
    """

    def __init__(self, encodings):
        self.encodings = encodings

    def __getitem__(self, idx):
        """Return example ``idx`` as a dict with one tensor per field."""
        item = {}
        for key, values in self.encodings.items():
            value = values[idx]
            if torch.is_tensor(value):
                # torch.tensor(<tensor>) emits a copy-construct UserWarning on
                # every call; clone().detach() is the recommended equivalent
                # and still returns an independent copy.
                item[key] = value.clone().detach()
            else:
                item[key] = torch.tensor(value)
        return item

    def __len__(self):
        # Key access works for plain dicts as well as tokenizer outputs.
        return len(self.encodings['input_ids'])

In [None]:
BATCH_SIZE = 8

train_dataset = SquadDataset(train_tokenizer)
val_dataset = SquadDataset(valid_tokenizer)
test_dataset = SquadDataset(test_tokenizer)

# Only the training loader shuffles. Both evaluation loaders keep dataset order
# so that evaluation is deterministic and consistent between validation and test.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True)
val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=BATCH_SIZE, shuffle=False)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Peek at the first training example to inspect the fields the model will receive.
example = train_dataset[0]

print(example)

{'input_ids': tensor([  101, 20773, 21025, 19358, 22815,  1011,  5708,  1006,  1013, 12170,
        23432, 29715,  3501, 29678, 12325, 29685,  1013, 10506,  1011, 10930,
         2078,  1011,  2360,  1007,  1006,  2141,  2244,  1018,  1010,  3261,
         1007,  2003,  2019,  2137,  3220,  1010,  6009,  1010,  2501,  3135,
         1998,  3883,  1012,  2141,  1998,  2992,  1999,  5395,  1010,  3146,
         1010,  2016,  2864,  1999,  2536,  4823,  1998,  5613,  6479,  2004,
         1037,  2775,  1010,  1998,  3123,  2000,  4476,  1999,  1996,  2397,
         4134,  2004,  2599,  3220,  1997,  1054,  1004,  1038,  2611,  1011,
         2177, 10461,  1005,  1055,  2775,  1012,  3266,  2011,  2014,  2269,
         1010, 25436, 22815,  1010,  1996,  2177,  2150,  2028,  1997,  1996,
         2088,  1005,  1055,  2190,  1011,  4855,  2611,  2967,  1997,  2035,
         2051,  1012,  2037, 14221,  2387,  1996,  2713,  1997, 20773,  1005,
         1055,  2834,  2201,  1010, 20754,  1999, 

  
