In [85]:
import pandas as pd
import numpy as np
import sklearn
import torch
import torch.nn as nn
import ast
import tqdm

from transformers import AutoConfig, AutoTokenizer, BertForQuestionAnswering
from torch.utils.data import DataLoader, Dataset


class SashimiDataset(Dataset):
    def __init__(self, data):
        super(SashimiDataset, self).__init__()
        self.text = data.Text
        self.label = data.Label
        
    def __len__(self):
        return self.text.shape[0]*4
    
    def __getitem__(self, idx):
        text = self.text[idx // len(QUESTIONS)]
        qs_idx = idx % len(QUESTIONS)
        label = self.label[idx // len(QUESTIONS)][qs_idx]
        question = QUESTIONS[qs_idx]
        
        inputs = tokenizer(question, text, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
        label = torch.LongTensor(label)
        
        plus_idx = (inputs['input_ids'] == 102).nonzero()[0,1]
        label += plus_idx 
        
        return inputs, label
    
    
config = AutoConfig.from_pretrained("deepset/bert-base-cased-squad2")
model = BertForQuestionAnswering.from_pretrained("deepset/bert-base-cased-squad2")
tokenizer = AutoTokenizer.from_pretrained("deepset/bert-base-cased-squad2", use_fast=True)
    

In [86]:
QUESTIONS = [
    'What activity?',
    'What date?',
    'What time?',
    'Where to go?'
]

data = pd.read_csv('./data/updated3500s.csv', sep='\t')
#data['Text'] = [ast.literal_eval(data.Text[i])[0] for i in range(data.shape[0])]
data['Label'] = [ast.literal_eval(data.Label[i]) for i in range(data.shape[0])]
    
dataset = SashimiDataset(data)
dataloader = DataLoader(dataset, batch_size=4)

data.head()

Unnamed: 0,Text,Label
0,Hey. Are you free on Tuesday? Yes. Do you want...,"[[15, 18], [7, 7], [0, 0], [29, 31]]"
1,"Hey. Are you free on Friday? Sorry, I have pla...","[[0, 0], [0, 0], [0, 0], [0, 0]]"
2,Sup! Sup. Would you want to go rafting on Satu...,"[[11, 14], [16, 16], [29, 30], [41, 43]]"
3,"Yo, what's up? Nothing much. Let's watch a mov...","[[14, 16], [18, 18], [42, 43], [65, 67]]"
4,"Yo, what's up? Nothing much. Do you want to gr...","[[15, 16], [18, 18], [26, 27], [35, 36]]"


In [59]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.AdamW(model.parameters(), 2e-5)

for epoch in range(3):
    total_loss = 0
    
    for inputs, label in tqdm.notebook.tqdm(dataloader):
        # Reform the inputs
        for k in inputs.keys():
            inputs[k] = inputs[k].squeeze(1)

        optimizer.zero_grad()
        outputs = model(**inputs)
        predict = torch.stack(outputs, -1)

        loss = loss_fn(predict, label)
        loss.backward()
        optimizer.step()
        
        total_loss += loss.item()
        
        break
        
    print(f'Train loss {total_loss}')

HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train loss 0.008059692569077015


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train loss 0.1382034569978714


HBox(children=(FloatProgress(value=0.0, max=2000.0), HTML(value='')))


Train loss 0.004682702012360096


In [58]:
loss_fn(torch.stack([start, end], -1), label)

tensor(0.0105, grad_fn=<NllLoss2DBackward>)

In [53]:
torch.stack([start, start], -1).shape

torch.Size([4, 2])

In [21]:
TEXT = "Sup! Sup. Let's see a play on Friday? Ok. Where should we meet? I don't know. How about the theme park? Sure! When? 9 pm? Sorry, I can't. How about 11 pm? Sure. See you then!"

for qs in QUESTIONS:
    inputs = tokenizer(qs, data.Text[1000], max_length=128, padding='max_length', truncation=True, return_tensors='pt')
    with torch.no_grad():
        outputs = model(**inputs)
        start, end = torch.argmax(outputs[0]), torch.argmax(outputs[1])
        print(f"{qs}\t{start}\t{end}\t{tokenizer.decode(inputs['input_ids'][0,start:end+1].tolist())}\t{tokenizer.decode(inputs['input_ids'][0,start:end+1].tolist())}")
    

What activity?	0	18	[CLS] What activity? [SEP] Sup! Sup. Would you want to go to a party
What date?	0	0	[CLS]
What time?	58	59	7 am
Where to go?	0	0	[CLS]


In [88]:
for idx, (inputs, label) in tqdm.notebook.tqdm(enumerate(dataloader)):
    for k in inputs.keys():
        inputs[k] = inputs[k].squeeze(1)
    outputs = model(**inputs)
    start, end = torch.argmax(outputs[0], -1), torch.argmax(outputs[1], -1)
    print(f"True:    {label.tolist()}\nPredict: {torch.stack((start, end)).T.tolist()}")
    for lb in label:
        print(lb, tokenizer.decode(inputs['input_ids'][0, lb[0]:lb[1]+1].tolist()), '\n')
    if idx == 2:
        break

HBox(children=(FloatProgress(value=1.0, bar_style='info', max=1.0), HTML(value='')))

True:    [[19, 22], [11, 11], [4, 4], [34, 36]]
Predict: [[20, 22], [0, 0], [11, 11], [21, 36]]
tensor([19, 22]) go rafting 

tensor([11, 11]) Tuesday 

tensor([4, 4]) [SEP] 

tensor([34, 36]) theme park? 

True:    [[4, 4], [4, 4], [4, 4], [5, 5]]
Predict: [[0, 0], [23, 23], [23, 23], [24, 24]]
tensor([4, 4]) [SEP] 

tensor([4, 4]) [SEP] 

tensor([4, 4]) [SEP] 

tensor([5, 5]) Hey 

True:    [[15, 18], [20, 20], [33, 34], [46, 48]]
Predict: [[16, 18], [20, 20], [20, 20], [21, 21]]
tensor([15, 18]) go rafting 

tensor([20, 20]) Saturday 

tensor([33, 34]) 7 pm 

tensor([46, 48]) ##nyi? 




In [83]:
data

Unnamed: 0,Text,Label
0,Hey. Are you free on Tuesday? Yes. Do you want...,"[[15, 18], [7, 7], [0, 0], [29, 31]]"
1,"Hey. Are you free on Friday? Sorry, I have pla...","[[0, 0], [0, 0], [0, 0], [0, 0]]"
2,Sup! Sup. Would you want to go rafting on Satu...,"[[11, 14], [16, 16], [29, 30], [41, 43]]"
3,"Yo, what's up? Nothing much. Let's watch a mov...","[[14, 16], [18, 18], [42, 43], [65, 67]]"
4,"Yo, what's up? Nothing much. Do you want to gr...","[[15, 16], [18, 18], [26, 27], [35, 36]]"
...,...,...
3495,Hey! Let's grab coffee on Tuesday? OK! Where s...,"[[6, 7], [9, 9], [47, 48], [20, 22]]"
3496,"Yo, what's up? Nothing much. Do you want to gr...","[[15, 16], [18, 18], [29, 30], [41, 43]]"
3497,Sup! Sup. Would you want to bake today? Ok. Wh...,"[[11, 12], [13, 13], [36, 37], [54, 55]]"
3498,Sup! Sup. Would you want to see a concert on W...,"[[11, 13], [15, 15], [41, 42], [26, 28]]"


In [1]:
import pandas as pd
import numpy as np
import sklearn
import torch
import torch.nn as nn
import ast
import tqdm

from transformers import AutoConfig, AutoTokenizer, BertForQuestionAnswering
from torch.utils.data import DataLoader, Dataset


class SashimiDataset(Dataset):
    def __init__(self, data):
        super(SashimiDataset, self).__init__()
        self.text = data.Text
        self.label = data.Label
        
    def __len__(self):
        return self.text.shape[0]*4
    
    def __getitem__(self, idx):
        text = self.text[idx // len(QUESTIONS)]
        qs_idx = idx % len(QUESTIONS)
        label = self.label[idx // len(QUESTIONS)][qs_idx]
        question = QUESTIONS[qs_idx]
        
        inputs = tokenizer(question, text, padding='max_length', truncation=True, max_length=128, return_tensors='pt')
        label = torch.LongTensor(label)
        
        print(label)
        
        plus_idx = (inputs['input_ids'] == 102).nonzero()[0,1]
        label += plus_idx + 1
        
        print(plus_idx, label)
        
        return inputs, label
        
        
def get_data(path, sep=',', index_col=None):
    data = pd.read_csv(path, sep=sep, index_col=index_col)
    data['Text'] = [ast.literal_eval(data.Text[i])[0] for i in range(data.shape[0])]
    data['Label'] = [ast.literal_eval(data.Label[i]) for i in range(data.shape[0])]
    return data

if __name__ == '__main__':
    
    config = AutoConfig.from_pretrained("deepset/bert-base-cased-squad2")
    model = BertForQuestionAnswering.from_pretrained("deepset/bert-base-cased-squad2")
    tokenizer = AutoTokenizer.from_pretrained("deepset/bert-base-cased-squad2", use_fast=True)
    
    QUESTIONS = [
        'What activity?',
        'What date?',
        'What time?',
        'Where to go?'
    ]

    data = get_data('./data/training_data.csv', sep='\t')
    dataset = SashimiDataset(data)
    dataloader = DataLoader(dataset, batch_size=1)
    
    loss_fn = nn.CrossEntropyLoss()
    optimizer = torch.optim.AdamW(model.parameters(), 2e-5)
    
    model.train()
    for epoch in range(3):
        total_loss = 0
        
        for inputs, label in tqdm.tqdm(dataloader):
            # Reform the inputs
            for k in inputs.keys():
                inputs[k] = inputs[k].squeeze(1)

            optimizer.zero_grad()
            outputs = model(**inputs)
            start, end = outputs[0], outputs[1]
            label = label

            loss = loss_fn(start, label[:,0]) + loss_fn(end, label[:,1])
            loss.backward()
            optimizer.step()
            
            total_loss += loss.item()
            
            break
            
        break
            
        print(f'Train loss {total_loss}')
        
        


  0%|                                                                                         | 0/8000 [00:00<?, ?it/s]

tensor([10, 12])
tensor(4) tensor([15, 17])


  0%|                                                                                         | 0/8000 [00:01<?, ?it/s]


In [7]:
start, end = torch.argmax(outputs[0]), torch.argmax(outputs[1])
print(f"{start}\t{end}\t{tokenizer.decode(inputs['input_ids'][0,start:end+1].tolist())}")

15	18	a play on Friday


In [20]:
data['Text'][1000]

"Sup! Sup. Would you want to go to a party tomorrow? Sure. Where should we meet? I don't know. How about the Japanese place? Great. When? You free 4 am? Sorry, I can't. How about 7 am? Ok. See you then!"