In [2]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
import pytorch_lightning as pl
from transformers import DistilBertModel, DistilBertTokenizer
import pandas as pd
from scripts.utils import *
from models import jointBert
from TorchCRF import CRF
from seqeval.metrics import f1_score
import pickle

In [3]:
# Pretrained checkpoint name handed to the dataset's tokenizer loader.
tokenizer_weights = "distilbert-base-multilingual-cased"
# Fixed sequence length: the tokenizer pads/truncates to it and slot labels are padded to it.
max_len = 56
# Tab-separated training split read by the dataset class.
train_dir = "./data/splits/multi-train.tsv"

In [4]:
# Load the lookup that maps predicted slot ids to slot tag strings.
# NOTE(review): pickle.load can execute arbitrary code -- only load this file
# when it comes from a trusted source.
with open('./notebooks/map_ids_slots.pickle', mode='rb') as pickle_file:
    slot_dictionary = pickle.load(pickle_file)

In [5]:
def get_slot_labels(slot_labels,slot_pred,slot_dictionary):
    """Turn gold label strings and predicted slot ids into lists of tag strings.

    Args:
        slot_labels: iterable of strings, each holding whitespace-separated
            slot tags for one example.
        slot_pred: iterable of sequences of slot ids, one sequence per example.
        slot_dictionary: lookup indexed by slot id that yields the tag string.

    Returns:
        A tuple ``(processed_labels, processed_pred)`` where both items are
        lists (one entry per example) of lists of tag strings.

    Raises:
        KeyError / IndexError: if a predicted id is missing from
            ``slot_dictionary`` (whichever the lookup type raises).
    """
    gold_tags = [sentence.split() for sentence in slot_labels]
    predicted_tags = [
        [slot_dictionary[tag_id] for tag_id in sequence]
        for sequence in slot_pred
    ]
    return gold_tags, predicted_tags

In [6]:
def process_label(labels, max_len, pad_label=82):
    """Convert a string of slot ids into fixed-length label and mask tensors.

    Args:
        labels: string of whitespace-separated integer slot ids.
        max_len: length of the returned tensors.
        pad_label: id appended after the real labels to reach ``max_len``.
            Defaults to 82, the value that used to be hard-coded here
            (presumably the padding id of the slot vocabulary -- confirm).

    Returns:
        ``(slot_label, slot_mask)``: two ``torch.LongTensor`` of length
        ``max_len``. ``slot_mask`` is 1 on real label positions and 0 on
        padded positions.

    Raises:
        ValueError: if a token in ``labels`` is not an integer literal.
    """
    # Truncate to max_len so every example yields tensors of the same length;
    # without this, an over-long label sequence produced longer tensors that
    # cannot be stacked into a batch.
    slot_ids = [int(token) for token in labels.split()][:max_len]
    n_real = len(slot_ids)
    n_pad = max_len - n_real

    slot_label = torch.LongTensor(slot_ids + [pad_label] * n_pad)
    slot_mask = torch.LongTensor([1] * n_real + [0] * n_pad)

    return slot_label, slot_mask

In [7]:
class nlu_dataset(Dataset):
    """Dataset over a tab-separated NLU file: tokenized utterances plus intent/slot targets.

    The file must provide the columns read in ``__getitem__``: ``utterance``,
    ``intent_ID``, ``slots_ID``, ``language``, ``intent`` and ``slot_labels``.
    """

    def __init__(self, file_dir, tokenizer, max_len):
        """Read the data file and load the tokenizer.

        Args:
            file_dir: path of the tab-separated data file.
            tokenizer: name or path passed to ``DistilBertTokenizer.from_pretrained``.
            max_len: length to which token ids and slot labels are padded.
        """
        self.data = pd.read_csv(file_dir, sep='\t')
        self.tokenizer = DistilBertTokenizer.from_pretrained(tokenizer)
        self.max_len = max_len

    def __getitem__(self, index):
        """Build the training example stored at ``index``.

        Returns:
            dict with tensors ``token_ids``, ``mask``, ``intent_target``,
            ``slot_target`` and ``slot_mask``, plus the raw ``language``,
            ``intent_label`` and ``slot_label`` values of the row.
        """
        frame = self.data

        # Collapse any run of whitespace in the utterance to a single space.
        utterance = " ".join(str(frame.utterance[index]).split())

        encoding = self.tokenizer.encode_plus(
            utterance,
            None,
            add_special_tokens=True,
            max_length=self.max_len,
            padding='max_length',
            return_token_type_ids=True,
            truncation=True,
        )
        input_ids = encoding['input_ids']
        attention = encoding['attention_mask']

        intent_target = torch.tensor(frame.intent_ID[index], dtype=torch.long)
        slot_target, slot_mask = process_label(frame.slots_ID[index], self.max_len)
        language = frame.language[index]

        sample = {
            'token_ids': torch.tensor(input_ids, dtype=torch.long),
            'mask': torch.tensor(attention, dtype=torch.long),
            'intent_target': intent_target,
            'slot_target': slot_target,
            'slot_mask': slot_mask,
            'language': language,
            'intent_label': frame.intent[index],
            'slot_label': frame.slot_labels[index],
        }
        return sample

    def __len__(self):
        """Number of rows in the loaded data file."""
        return len(self.data)

In [8]:
class jointBert(nn.Module):
    """Joint intent classifier and CRF slot tagger on top of a DistilBERT encoder."""

    def __init__(self, model_name, num_intent,num_slots,joint_loss_coef):
        """Build the encoder and the two task heads.

        Args:
            model_name: checkpoint name or path passed to
                ``DistilBertModel.from_pretrained``.
            num_intent: number of intent classes (output size of the intent head).
            num_slots: number of slot tags (output size of the slot head and CRF).
            joint_loss_coef: factor applied to the slot (CRF) loss before it is
                averaged with the intent loss.
        """
        super(jointBert,self).__init__()

        # base encoder
        self.encoder = DistilBertModel.from_pretrained(model_name,return_dict=True,output_hidden_states=True)
        # Take the hidden width from the loaded checkpoint instead of assuming
        # 768, so checkpoints with a different width also work.
        hidden_size = self.encoder.config.dim

        # intent layer
        self.intent_dropout = nn.Dropout(0.1)
        self.intent_linear = nn.Linear(hidden_size, num_intent)

        # slots layer
        self.slot_classifier = nn.Linear(hidden_size, num_slots)
        self.dropout_slots = nn.Dropout(0.1)

        self.crf = CRF(num_slots)

        self.intent_loss = nn.CrossEntropyLoss()
        self.joint_loss_coef =  joint_loss_coef

    def forward(self, input_ids, attention_mask, intent_target, slot_target,slot_mask):
        """Run the encoder and both heads; return the joint loss and predictions.

        Args:
            input_ids: token ids fed to the encoder.
            attention_mask: attention mask fed to the encoder.
            intent_target: gold intent class ids for the cross-entropy loss.
            slot_target: gold slot ids handed to the CRF.
            slot_mask: mask of valid slot positions handed to the CRF (converted
                with ``.byte()``).

        Returns:
            ``(joint_loss, slot_pred, intent_pred)``: the mean of slot and intent
            loss, the output of ``crf.viterbi_decode``, and the argmax intent ids.
        """
        encoded_output = self.encoder(input_ids, attention_mask)
        # First element of the encoder output: per-token hidden states.
        token_states = encoded_output[0]

        # intent data flow: classify from the state of the first token
        intent_hidden = self.intent_dropout(token_states[:, 0])
        intent_logits = self.intent_linear(intent_hidden)
        intent_loss = self.intent_loss(intent_logits, intent_target)
        # Softmax is monotonic, so the argmax can be taken on the logits directly.
        intent_pred = torch.argmax(intent_logits, dim=1)

        # slots data flow: per-token emission scores for the CRF
        slots_logits = self.slot_classifier(self.dropout_slots(token_states))
        # NOTE(review): position 0 of slots_logits belongs to the first encoder
        # token; confirm slot_target/slot_mask are aligned with the encoder
        # tokens rather than with the original words.
        crf_mask = slot_mask.byte()
        # The CRF output is negated (to get a quantity to minimise), scaled and
        # averaged -- presumably a per-example log-likelihood; confirm against TorchCRF.
        slot_loss = -1 * self.joint_loss_coef * self.crf(slots_logits, slot_target, mask=crf_mask)
        slot_loss = torch.mean(slot_loss)

        joint_loss = (slot_loss + intent_loss)/2.0

        slot_pred = self.crf.viterbi_decode(slots_logits, crf_mask)
        return joint_loss,slot_pred,intent_pred

In [9]:
# Training dataset: reads the TSV split and loads the tokenizer for the configured checkpoint.
trainDS = nlu_dataset(
    file_dir=train_dir,
    tokenizer=tokenizer_weights,
    max_len=max_len,
)

In [10]:
# Serve the training set in shuffled mini-batches of two examples.
trainDL = DataLoader(dataset=trainDS, batch_size=2, shuffle=True)

In [11]:
# Reuse the configured checkpoint name so the encoder always matches the
# tokenizer loaded by the dataset (it used to be a second copy of the literal).
# NOTE(review): 17 and 159 are presumably the sizes of the intent and slot
# vocabularies of the training data -- confirm against the label maps.
model = jointBert(
    model_name=tokenizer_weights,
    num_intent=17,
    num_slots=159,
    joint_loss_coef=1.0,
)

In [12]:
def accuracy(pred,target):
    """Return the fraction of positions where ``pred`` equals ``target``.

    Args:
        pred: tensor of predicted class ids.
        target: tensor of gold class ids; its length is the denominator.

    Returns:
        A scalar tensor: number of matches divided by ``len(target)``.
    """
    n_correct = torch.sum(pred == target)
    n_total = float(len(target))
    return n_correct / n_total

In [13]:
# Smoke test: push the first (shuffled) training batch through the model and
# print slot F1 and intent accuracy. The loop stops after one batch.
# NOTE(review): no model.eval() / torch.no_grad() is visible before this call,
# so the head dropout layers appear to be active -- confirm this is intended.
for idx,batch in enumerate(trainDL):
    
    token_ids = batch['token_ids']
    mask = batch['mask']
    intent_target = batch['intent_target']
    slot_target = batch['slot_target']
    slot_label = batch['slot_label']
    slot_mask = batch['slot_mask']
    joint_loss , slot_pred, intent_pred = (model(token_ids,mask,intent_target,slot_target,slot_mask))
    #print(slot_pred,slot_label,slot_mask)
    # From here on slot_target and slot_pred are rebound: they no longer hold
    # the id tensor / decoded ids but lists of tag strings for seqeval.
    slot_target,slot_pred = get_slot_labels(slot_label,slot_pred,slot_dictionary)
    print(slot_target,slot_pred)
    print(f1_score(slot_target,slot_pred))
    print(accuracy(intent_pred,intent_target))
    print(intent_pred,intent_target)
    # Only the first batch is inspected.
    break

[['O', 'O', 'O', 'O', 'O', 'B-fromloc.city_name', 'I-fromloc.city_name', 'O', 'B-toloc.city_name', 'O', 'O', 'B-depart_date.day_name'], ['O', 'O', 'O', 'O', 'O', 'B-fromloc.city_name', 'O', 'B-toloc.city_name', 'I-toloc.city_name']] [['B-arrive_time.start_time', 'I-arrive_date.day_name', 'I-airport_code', 'B-arrive_date.day_number', 'B-flight_days', 'I-depart_time.start_time', 'B-arrive_time.start_time', 'B-today_relative', 'I-return_date.day_number', 'B-today_relative', 'B-day_number', 'B-arrive_date.day_number'], ['I-depart_date.day_number', 'B-toloc.airport_code', 'B-depart_time.start_time', 'B-class_type', 'B-flight_stop', 'I-day_number', 'B-or', 'I-airline_code', 'I-airline_name']]
0.0
tensor(0.)
