In [7]:
from torch.utils.data import Dataset
from transformers import DistilBertTokenizerFast
import pandas as pd
import random
from torch.utils.data import DataLoader
import torch

In [8]:
class nluDataset(Dataset):
    """Joint intent / slot-filling dataset backed by a tab-separated file.

    The file is read with pandas and must provide the columns ``TEXT``,
    ``INTENT``, ``INTENT_ID``, ``SLOTS`` and ``SLOTS_ID``. ``SLOTS_ID`` holds
    one integer label per whitespace-separated word of the utterance,
    separated by whitespace.

    Each item is tokenized with a ``DistilBertTokenizerFast`` and padded /
    truncated to ``max_len`` tokens. Slot labels are expanded from word level
    to token level: the first sub-token of every word carries the word's
    label, every other position (special tokens, padding, continuation
    sub-tokens) is set to ``-100`` (the default ``ignore_index`` of
    ``torch.nn.CrossEntropyLoss``).
    """

    def __init__(self, file_dir, tokenizer, max_len, device):
        """Load the TSV file and the pretrained tokenizer.

        Args:
            file_dir: path of the tab-separated data file.
            tokenizer: name or path handed to
                ``DistilBertTokenizerFast.from_pretrained``.
            max_len: fixed sequence length (in tokens) of every item.
            device: kept for interface compatibility; it is only stored and
                is not used anywhere else in this class.
        """
        self.data = pd.read_csv(file_dir, sep='\t')
        self.tokenizer = DistilBertTokenizerFast.from_pretrained(tokenizer)
        self.max_len = max_len
        self.device = device

    def processSlotLabel(self, word_ids, slot_ids, text):
        """Expand word-level slot ids to one label per token.

        Args:
            word_ids: per-token word index as returned by
                ``BatchEncoding.word_ids()``; ``None`` marks special and
                padding tokens.
            slot_ids: whitespace-separated integer labels, one per word.
            text: unused; kept so existing callers keep working.

        Returns:
            A list as long as ``word_ids``. The first token of each word gets
            that word's label; ``None`` positions and repeated word indices
            get ``-100``.

        Raises:
            IndexError: if a word index has no corresponding label.
        """
        # split() without an argument tolerates doubled / leading / trailing
        # whitespace, where split(' ') would produce '' and int('') fails.
        word_labels = [int(tok) for tok in str(slot_ids).split()]

        token_labels = []
        previous_word = None
        for word_id in word_ids:
            if word_id is None:
                # Special / padding token. previous_word is deliberately not
                # reset here, so it still refers to the last real word.
                token_labels.append(-100)
                continue

            if word_id == previous_word:
                # Continuation sub-token of the word labelled just before.
                token_labels.append(-100)
            else:
                if word_id >= len(word_labels):
                    raise IndexError(
                        "word index %d has no slot label (only %d labels given)"
                        % (word_id, len(word_labels))
                    )
                token_labels.append(word_labels[word_id])

            previous_word = word_id

        return token_labels

    def __getitem__(self, index):
        """Return the encoded example at position ``index``.

        Returns:
            dict with
            ``token_ids`` / ``mask``: long tensors of length ``max_len``;
            ``intent_id``: scalar long tensor;
            ``slots_id``: long tensor of length ``max_len`` (see
            ``processSlotLabel``);
            ``intent_label``, ``slots_label``, ``slotsID``: raw column values;
            ``text``: the cleaned utterance that was tokenized.
        """
        # Positional lookup, so the result does not depend on the frame's
        # index labels.
        raw_text = str(self.data.TEXT.iloc[index])

        # Same cleaning as before: drop '.' and "'" and collapse whitespace.
        words = raw_text.replace('.', '').replace('\'', '').split()
        text = " ".join(words)

        # Feed the whitespace-split words so that word_ids() counts exactly
        # those words. When a raw string is tokenized instead, punctuation
        # inside a word (e.g. a hyphen) starts new "words" and the labels in
        # SLOTS_ID no longer line up.
        inputs = self.tokenizer(
            words,
            is_split_into_words=True,
            add_special_tokens=True,
            return_token_type_ids=False,
            truncation=True,
            max_length=self.max_len,
            padding='max_length',
        )

        # text encoding
        token_ids = torch.tensor(inputs['input_ids'], dtype=torch.long)
        mask = torch.tensor(inputs['attention_mask'], dtype=torch.long)

        # intent
        intent_id = torch.tensor(self.data.INTENT_ID.iloc[index], dtype=torch.long)
        intent_label = self.data.INTENT.iloc[index]

        # slot labels, expanded from word level to token level
        slot_label = self.data.SLOTS.iloc[index]
        raw_slot_ids = self.data.SLOTS_ID.iloc[index]
        slot_id = torch.tensor(
            self.processSlotLabel(inputs.word_ids(), raw_slot_ids, text),
            dtype=torch.long,
        )

        return {
            'token_ids': token_ids,
            'mask': mask,
            'intent_id': intent_id,
            'slots_id' : slot_id,
            'intent_label': intent_label,
            'slots_label' : slot_label,
            'text' : text,
            'slotsID' : raw_slot_ids
        }

    def __len__(self):
        """Number of rows in the loaded file."""
        return len(self.data)

In [9]:
# Build the dataset from the English test split, using the multilingual
# DistilBERT tokenizer and a fixed sequence length of 56 tokens.
ds = nluDataset(
    file_dir='../data/multiATIS/split/test/WWTLE/25per/v1/test_EN.tsv',
    tokenizer='distilbert-base-multilingual-cased',
    max_len=56,
    device=1,
)

In [10]:
# Wrap the dataset in a loader that yields batches of four examples
# (default settings otherwise, i.e. sequential order).
dl = DataLoader(dataset=ds, batch_size=4)

In [12]:
# Smoke test: walk through the loader once and print the shape of the
# token-id and attention-mask tensors of every batch.
for batch in dl:
    token_ids = batch['token_ids']
    attention_mask = batch['mask']
    intent_target = batch['intent_id']
    slots_target = batch['slots_id']
    print(token_ids.size(), attention_mask.size())

    # Disabled model call, kept for reference (no `model` is defined in the
    # cells shown here):
    #out = model(token_ids,attention_mask,intent_target,slots_target)
    #intent_pred, slot_pred = out['intent_pred'], out['slot_pred']

    a = 1  # placeholder statement left over from debugging

torch.Size([4, 56]) torch.Size([4, 56])
torch.Size([4, 56]) torch.Size([4, 56])
torch.Size([4, 56]) torch.Size([4, 56])
torch.Size([4, 56]) torch.Size([4, 56])
torch.Size([4, 56]) torch.Size([4, 56])
torch.Size([4, 56]) torch.Size([4, 56])
torch.Size([4, 56]) torch.Size([4, 56])
torch.Size([4, 56]) torch.Size([4, 56])
torch.Size([4, 56]) torch.Size([4, 56])
torch.Size([4, 56]) torch.Size([4, 56])
torch.Size([4, 56]) torch.Size([4, 56])
torch.Size([4, 56]) torch.Size([4, 56])
torch.Size([4, 56]) torch.Size([4, 56])
torch.Size([4, 56]) torch.Size([4, 56])
torch.Size([4, 56]) torch.Size([4, 56])
torch.Size([4, 56]) torch.Size([4, 56])
torch.Size([4, 56]) torch.Size([4, 56])
torch.Size([4, 56]) torch.Size([4, 56])
torch.Size([4, 56]) torch.Size([4, 56])
torch.Size([4, 56]) torch.Size([4, 56])
torch.Size([4, 56]) torch.Size([4, 56])
torch.Size([4, 56]) torch.Size([4, 56])
torch.Size([4, 56]) torch.Size([4, 56])
torch.Size([4, 56]) torch.Size([4, 56])
torch.Size([4, 56]) torch.Size([4, 56])
