In [0]:
!pip install transformers
!pip install wget

Collecting transformers
[?25l  Downloading https://files.pythonhosted.org/packages/50/10/aeefced99c8a59d828a92cc11d213e2743212d3641c87c82d61b035a7d5c/transformers-2.3.0-py3-none-any.whl (447kB)
[K     |▊                               | 10kB 31.1MB/s eta 0:00:01[K     |█▌                              | 20kB 1.9MB/s eta 0:00:01[K     |██▏                             | 30kB 2.7MB/s eta 0:00:01[K     |███                             | 40kB 1.8MB/s eta 0:00:01[K     |███▋                            | 51kB 2.2MB/s eta 0:00:01[K     |████▍                           | 61kB 2.7MB/s eta 0:00:01[K     |█████▏                          | 71kB 3.1MB/s eta 0:00:01[K     |█████▉                          | 81kB 3.5MB/s eta 0:00:01[K     |██████▋                         | 92kB 3.9MB/s eta 0:00:01[K     |███████▎                        | 102kB 3.0MB/s eta 0:00:01[K     |████████                        | 112kB 3.0MB/s eta 0:00:01[K     |████████▉                       | 122kB 3.0M

In [0]:
import json
import os

import numpy as np
import pandas as pd
import torch
import wget
from torch.utils.data import Dataset, DataLoader
from transformers import BertForQuestionAnswering, BertTokenizer

In [0]:
# Load the uncased BERT tokenizer together with a question-answering model.
# BertForQuestionAnswering (the class imported above) is required here: the
# prediction cell unpacks start/end span scores from the model output.
# NOTE(review): 'bert-base-uncased' is the generic pre-trained checkpoint, so
# the span-prediction head presumably still needs the SQuAD fine-tuning step
# before its answers are meaningful.
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
model = BertForQuestionAnswering.from_pretrained('bert-base-uncased')

**Fine-Tune the BERT Model on SQuAD**

In [0]:
def SQuAD_organize(SQuAD, max_paragraphs=2, max_questions=2):
    """
    Flatten a parsed SQuAD json dict into a DataFrame of question/answer rows.

    Every question is paired with the context of the paragraph it belongs to.
    Only the first listed answer of a question is kept; questions without any
    answer (the "impossible" ones) are "answered" with an empty string.
    Context, question and answer are all lower-cased.

    Parameters
    ----------
    SQuAD : dict
        Parsed SQuAD json. SQuAD['data'] is a list of subjects, each holding a
        'paragraphs' list; every paragraph has a 'context' string and a 'qas'
        list; every qa has a 'question' string and an 'answers' list of dicts
        with a 'text' key.
    max_paragraphs : int or None, default 2
        Number of leading paragraphs used per subject. The default keeps the
        small sample used for testing; pass None to use every paragraph.
    max_questions : int or None, default 2
        Number of leading questions used per paragraph. The default keeps the
        small sample used for testing; pass None to use every question.

    Returns
    -------
    pandas.DataFrame
        One row per question with columns ['context', 'question', 'answer'].
    """
    rows = []

    for subject in SQuAD['data']:
        # a slice bound of None means "no limit"
        for paragraph in subject['paragraphs'][:max_paragraphs]:
            the_context = paragraph['context'].lower()
            # only walk the questions attached to THIS paragraph, so a question
            # is never matched with the context of a different paragraph
            for qa in paragraph['qas'][:max_questions]:
                answers = qa['answers']
                # first answer is the label; empty string for impossible questions
                the_answer = answers[0]['text'] if answers else ''
                rows.append([the_context, qa['question'].lower(), the_answer.lower()])

    return pd.DataFrame(rows, columns=['context', 'question', 'answer'])

In [0]:
# Download the SQuAD v2.0 train and dev splits from the SQuAD explorer site,
# skipping any file that is already present in the working directory.
url_train = 'https://rajpurkar.github.io/SQuAD-explorer/dataset/train-v2.0.json'
url_dev = 'https://rajpurkar.github.io/SQuAD-explorer/dataset/dev-v2.0.json'
if not os.path.exists('./train-v2.0.json'):
  wget.download(url_train, './train-v2.0.json')
if not os.path.exists('./dev-v2.0.json'):
  wget.download(url_dev, './dev-v2.0.json')

# The encoding is given to open(): json.load() no longer accepts an `encoding`
# keyword (removed in Python 3.9), and without it open() would fall back to
# the platform's locale encoding.
with open('./train-v2.0.json', 'r', encoding='utf-8') as json_train:
  SQuAD_train = json.load(json_train)
with open('./dev-v2.0.json', 'r', encoding='utf-8') as json_val:
  SQuAD_val = json.load(json_val)

# Flatten both splits into (context, question, answer) DataFrames and preview
# the training one.
SQuAD_train_df = SQuAD_organize(SQuAD_train)
SQuAD_val_df = SQuAD_organize(SQuAD_val)
SQuAD_train_df.head()

In [0]:
# TODO: fine-tuning step goes here (not implemented yet)


**Make Predictions with the Fine-Tuned Model**

In [0]:
question = 'How many picks does it take to get to the Tootsie roll centre of a Tootsie pop?'
# NOTE: `answer` is encoded as the second segment of the pair, i.e. it is the
# text the predicted span is cut out of.
answer = 'Nobody knows...'
input_ids = tokenizer.encode(question, answer)
print(input_ids)

# Segment ids: 0 up to and including the first [SEP], 1 for everything after.
# The [SEP] id is taken from the tokenizer instead of hard-coding 102, and is
# looked up once rather than once per token.
first_sep_position = input_ids.index(tokenizer.sep_token_id)
sequence_ids = [0 if i <= first_sep_position else 1 for i in range(len(input_ids))]
print(sequence_ids)

# Inference only, so no gradients are needed.
with torch.no_grad():
    outputs = model(torch.tensor([input_ids]), token_type_ids=torch.tensor([sequence_ids]))
# Positional indexing works for a plain tuple return as well as for the output
# objects of newer transformers releases.
# Assumes `model` is a question-answering model whose first two outputs are
# the start and end scores.
start_scores, end_scores = outputs[0], outputs[1]

all_tokens = tokenizer.convert_ids_to_tokens(input_ids)
print(all_tokens)

# Highest-scoring start and end token positions, as plain ints for slicing.
start_index = int(torch.argmax(start_scores))
end_index = int(torch.argmax(end_scores))
print(' '.join(all_tokens[start_index : end_index + 1]))

[101, 2129, 2116, 11214, 2515, 2009, 2202, 2000, 2131, 2000, 1996, 2205, 3215, 2666, 4897, 2803, 1997, 1037, 2205, 3215, 2666, 3769, 1029, 102, 6343, 4282, 1012, 1012, 1012, 102]
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1]
['[CLS]', 'how', 'many', 'picks', 'does', 'it', 'take', 'to', 'get', 'to', 'the', 'too', '##ts', '##ie', 'roll', 'centre', 'of', 'a', 'too', '##ts', '##ie', 'pop', '?', '[SEP]', 'nobody', 'knows', '.', '.', '.', '[SEP]']
nobody knows
