In [109]:
import pandas as pd
import numpy as np
import sklearn
import torch
import torch.nn as nn
import ast
import tqdm

from transformers import AutoConfig, AutoTokenizer, BertForQuestionAnswering
from torch.utils.data import DataLoader, Dataset



    
config = AutoConfig.from_pretrained("deepset/bert-base-cased-squad2")
model = BertForQuestionAnswering.from_pretrained("deepset/bert-base-cased-squad2")
tokenizer = AutoTokenizer.from_pretrained("deepset/bert-base-cased-squad2", use_fast=True)
    

In [110]:
class SashimiDataset(Dataset):
    def __init__(self, data):
        super(SashimiDataset, self).__init__()
        self.text = data.Text
        self.label = data.Label
        
    def __len__(self):
        return self.text.shape[0]*4
    
    def __getitem__(self, idx):
        text = self.text[idx // len(QUESTIONS)]
        qs_idx = idx % len(QUESTIONS)
        label = self.label[idx // len(QUESTIONS)][qs_idx]
        question = QUESTIONS[qs_idx]
        
        inputs = tokenizer(question, text, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
        label = torch.LongTensor(label)
        
        plus_idx = (inputs['input_ids'] == 102).nonzero()[0,1]
        if label.tolist() != [0,0]:
            label += plus_idx 
        
        return inputs, label
    

QUESTIONS = [
    'What activity?',
    'Which day?',
    'What time on that day?',
    'Which place to go?'
]

data = pd.read_csv('./data/updated3500s.csv', sep='\t')
#data['Text'] = [ast.literal_eval(data.Text[i])[0] for i in range(data.shape[0])]
data['Label'] = [ast.literal_eval(data.Label[i]) for i in range(data.shape[0])]
    
dataset = SashimiDataset(data)
dataloader = DataLoader(dataset, batch_size=4)

data.head()

Unnamed: 0,Text,Label
0,Hey. Are you free on Tuesday? Yes. Do you want...,"[[15, 18], [7, 7], [0, 0], [29, 31]]"
1,"Hey. Are you free on Friday? Sorry, I have pla...","[[0, 0], [0, 0], [0, 0], [0, 0]]"
2,Sup! Sup. Would you want to go rafting on Satu...,"[[11, 14], [16, 16], [29, 30], [41, 43]]"
3,"Yo, what's up? Nothing much. Let's watch a mov...","[[14, 16], [18, 18], [42, 43], [65, 67]]"
4,"Yo, what's up? Nothing much. Do you want to gr...","[[15, 16], [18, 18], [26, 27], [35, 36]]"


In [113]:
for idx, (inputs, label) in tqdm.notebook.tqdm(enumerate(dataloader)):
    for k in inputs.keys():
        inputs[k] = inputs[k].squeeze(1)
    outputs = model(**inputs)
    start, end = torch.argmax(outputs[0], -1), torch.argmax(outputs[1], -1)
    print(f"True:    {label.tolist()}\nPredict: {torch.stack((start, end)).T.tolist()}")
    for lb in label:
        print(lb, tokenizer.decode(inputs['input_ids'][0, lb[0]:lb[1]+1].tolist()), '\n')
    if idx == 1:
        break

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

True:    [[19, 22], [11, 11], [0, 0], [35, 37]]
Predict: [[20, 22], [11, 11], [14, 14], [22, 24]]
tensor([19, 22]) go rafting 

tensor([11, 11]) Tuesday 

tensor([0, 0]) [CLS] 

tensor([35, 37]) park? OK 

True:    [[0, 0], [0, 0], [0, 0], [0, 0]]
Predict: [[0, 0], [23, 23], [26, 26], [25, 25]]
tensor([0, 0]) [CLS] 

tensor([0, 0]) [CLS] 

tensor([0, 0]) [CLS] 

tensor([0, 0]) [CLS] 




In [111]:
data.Text[0]

'Hey. Are you free on Tuesday? Yes. Do you want to go rafting? Sure! Where should we meet? How about the theme park? OK! See you there!'

In [1]:
import pandas as pd
import numpy as np
import sklearn
import torch
import torch.nn as nn
import ast
import tqdm

from transformers import AutoConfig, AutoTokenizer, BertForQuestionAnswering
from torch.utils.data import DataLoader, Dataset


class SashimiDataset(Dataset):
    def __init__(self, data):
        super(SashimiDataset, self).__init__()
        self.text = data.Text
        self.label = data.Label
        
    def __len__(self):
        return self.text.shape[0]*4
    
    def __getitem__(self, idx):
        text = self.text[idx // len(QUESTIONS)]
        qs_idx = idx % len(QUESTIONS)
        label = self.label[idx // len(QUESTIONS)][qs_idx]
        question = QUESTIONS[qs_idx]
        
        inputs = tokenizer(question, text, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
        label = torch.LongTensor(label)
        
        print(label)
        
        plus_idx = (inputs['input_ids'] == 102).nonzero()[0,1]
        label += plus_idx + 1
        
        print(plus_idx, label)
        
        return inputs, label
        
        
def get_data(path, sep=',', index_col=None):
    data = pd.read_csv(path, sep=sep, index_col=index_col)
    data['Text'] = [ast.literal_eval(data.Text[i])[0] for i in range(data.shape[0])]
    data['Label'] = [ast.literal_eval(data.Label[i]) for i in range(data.shape[0])]
    return data

if __name__ == '__main__':
    
    config = AutoConfig.from_pretrained("deepset/bert-base-cased-squad2")
    model = BertForQuestionAnswering.from_pretrained("deepset/bert-base-cased-squad2")
    tokenizer = AutoTokenizer.from_pretrained("deepset/bert-base-cased-squad2", use_fast=True)
    
    QUESTIONS = [
        'What activity?',
        'What date?',
        'What time?',
        'Where to go?'
    ]

    data = get_data('./data/training_data.csv', sep='\t')
    dataset = SashimiDataset(data)
    dataloader = DataLoader(dataset, batch_size=1)
    
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), 2e-5)
    
    model.train()
    for epoch in range(3):
        total_loss = 0
        
        for inputs, label in tqdm.tqdm(dataloader):
            # Reform the inputs
            for k in inputs.keys():
                inputs[k] = inputs[k].squeeze(1)

            optimizer.zero_grad()
            outputs = model(**inputs)
            start, end = outputs[0], outputs[1]
            label = label

            loss = loss_fn(start, label[:,0]) + loss_fn(end, label[:,1])
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            
            break
            
        break
            
        print(f'Train loss {total_loss}')
        
        


  0%|                                                                                         | 0/8000 [00:00<?, ?it/s]

tensor([10, 12])
tensor(4) tensor([15, 17])


  0%|                                                                                         | 0/8000 [00:01<?, ?it/s]


In [7]:
start, end = torch.argmax(outputs[0]), torch.argmax(outputs[1])
print(f"{start}\t{end}\t{tokenizer.decode(inputs['input_ids'][0,start:end+1].tolist())}")

15	18	a play on Friday


In [20]:
data['Text'][1000]

"Sup! Sup. Would you want to go to a party tomorrow? Sure. Where should we meet? I don't know. How about the Japanese place? Great. When? You free 4 am? Sorry, I can't. How about 7 am? Ok. See you then!"