In [1]:
import pandas as pd
import numpy as np
import sklearn
import torch
import torch.nn as nn
import ast
import tqdm

from transformers import AutoConfig, AutoTokenizer, BertForQuestionAnswering
from torch.utils.data import DataLoader, Dataset


class SashimiDataset(Dataset):
    def __init__(self, data):
        super(SashimiDataset, self).__init__()
        self.text = data.Text
        self.label = data.Label
        
    def __len__(self):
        return self.text.shape[0]*4
    
    def __getitem__(self, idx):
        text = self.text[idx // len(QUESTIONS)]
        qs_idx = idx % len(QUESTIONS)
        label = self.label[idx // len(QUESTIONS)][qs_idx]
        question = QUESTIONS[qs_idx]
        
        inputs = tokenizer(question, text, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
        label = torch.LongTensor(label)
        
        plus_idx = (inputs['input_ids'] == 102).nonzero()[0,1]
        label += plus_idx + 1
        
        return inputs, label
    
    
config = AutoConfig.from_pretrained("deepset/bert-base-cased-squad2")
model = BertForQuestionAnswering.from_pretrained("deepset/bert-base-cased-squad2")
tokenizer = AutoTokenizer.from_pretrained("deepset/bert-base-cased-squad2", use_fast=True)
    

In [2]:
QUESTIONS = [
    'What activity?',
    'What date?',
    'What time?',
    'Where to go?'
]

data = pd.read_csv('./data/training_data.csv', sep='\t')
data['Text'] = [ast.literal_eval(data.Text[i])[0] for i in range(data.shape[0])]
data['Label'] = [ast.literal_eval(data.Label[i]) for i in range(data.shape[0])]
    
dataset = SashimiDataset(data)
dataloader = DataLoader(dataset, batch_size=3, shuffle=True)

data.head()

Unnamed: 0,Text,Label
0,Sup! Sup. Let's see a play on Friday? Ok. Wher...,"[[10, 12], [13, 14], [51, 52], [31, 33]]"
1,Hey. Would you want to go somewhere on Wednesd...,"[[7, 8], [9, 10], [50, 51], [28, 29]]"
2,"Yo, what's up? Nothing much. Let's go to a par...","[[14, 17], [18, 19], [61, 62], [38, 40]]"
3,"Yo, what's up? Nothing much. Would you want to...","[[15, 16], [17, 18], [59, 60], [33, 35]]"
4,"Hey, how are you doing? I’m good, and you? I’m...","[[26, 28], [29, 30], [70, 71], [49, 51]]"


In [295]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), 2e-5)

for epoch in range(3):
    total_loss = 0
    
    for inputs, label in tqdm.notebook.tqdm(dataloader):
        # Reform the inputs
        for k in inputs.keys():
            inputs[k] = inputs[k].squeeze(1)

        optimizer.zero_grad()
        outputs = model(**inputs)
        start, end = outputs[0], outputs[1]

        loss = loss_fn(start, label[:,0]) + loss_fn(end, label[:,1])
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
    print(f'Train loss {total_loss}')

HBox(children=(FloatProgress(value=0.0, max=2667.0), HTML(value='')))




KeyboardInterrupt: 

In [4]:
TEXT = "Sup! Sup. Let's see a play on Friday? Ok. Where should we meet? I don't know. How about the theme park? Sure! When? 9 pm? Sorry, I can't. How about 11 pm? Sure. See you then!"

for qs in QUESTIONS:
    inputs = tokenizer(qs, TEXT, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**inputs)
        start, end = torch.argmax(outputs[0]), torch.argmax(outputs[1])
        print(f"{qs}\t{start}\t{end}\t{tokenizer.decode(inputs['input_ids'][0,start:end+1].tolist())}")
    

What activity?	11	16	Let's see a play
What date?	18	18	Friday
What time?	43	44	9 pm
Where to go?	0	38	[CLS] Where to go? [SEP] Sup! Sup. Let's see a play on Friday? Ok. Where should we meet? I don't know. How about the theme park


In [None]:
model.load_state_dict()

In [3]:
data.Text[0]

"Sup! Sup. Let's see a play on Friday? Ok. Where should we meet? I don't know. How about the theme park? Sure! When? 9 pm? Sorry, I can't. How about 11 pm? Sure. See you then!"