In [25]:
import torch 
import torch.nn as nn 
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader

import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import DistilBertModel

import pytorch_lightning as pl
from pytorch_lightning import seed_everything, loggers as pl_loggers
from pytorch_lightning.callbacks import ModelCheckpoint
from transformers import  DistilBertModel,DistilBertTokenizer

from pytorch_metric_learning import miners, losses

import pandas as pd
import random

from scripts.dataset import *
from scripts.utils import *
from arguments import jointBert_argument

In [26]:
# Experiment configuration, grouped into model ('mc'), training ('tc'),
# data ('dc') and miscellaneous ('misc') sections.
config = dict(
    # model parameters
    mc=dict(
        model_name='distilbert-base-multilingual-cased',
        tokenizer_name='distilbert-base-multilingual-cased',
        joint_loss_coef=0.5,
        id_1=0.29868357362720055,
        id_2=0.2226859356474008,
        sd=0.3180000141987541,
        Ihs=77,
        freeze_decoder=True,
    ),
    # training parameters
    tc=dict(
        lr=0.00003,
        epoch=40,
        batch_size=15,
        weight_decay=0.003,
        shuffle_data=True,
        num_worker=8,
    ),
    # data parameters
    dc=dict(
        # NOTE(review): train_dir points at a file named test_EN.tsv -- confirm
        # this is really the intended training split.
        train_dir='/content/drive/MyDrive/research/Infinite/data/multiATIS/split/train/WWTLE_Augmented/test_EN.tsv',
        val_dir='/content/drive/MyDrive/research/Infinite/data/multiATIS/split/valid/clean/val.tsv',
        intent_num=18,
        slots_num=159,
        max_len=56,
    ),
    # misc
    misc=dict(
        fix_seed=False,
        gpus=-1,
        log_dir='./',
        precision=16,
    ),
)

In [27]:
class contraNLUDataset(Dataset):
    """Map-style dataset yielding lightly cleaned utterances from a TSV file.

    The file is read with pandas using a tab separator and must contain a
    ``TEXT`` column; no other column is read by this class.
    """

    # Translation table that deletes full stops and apostrophes.
    _STRIP_PUNCT = str.maketrans("", "", ".'")

    def __init__(self, file_dir):
        """Load the whole tab-separated file at ``file_dir`` into memory."""
        self.data = pd.read_csv(file_dir, sep="\t")

    def __len__(self):
        """Number of rows in the loaded file."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``{"text": ...}`` for row ``index``.

        The value is cast to ``str``, has ``.`` and ``'`` removed, and has
        runs of whitespace collapsed to single spaces.
        """
        utterance = str(self.data.TEXT[index]).translate(self._STRIP_PUNCT)
        return {"text": " ".join(utterance.split())}


class contra_Dataset_pl(pl.LightningDataModule):
    """Lightning data module serving ``contraNLUDataset`` train/val loaders.

    Args:
        train_dir: path of the training TSV handed to ``contraNLUDataset``.
        val_dir: path of the validation TSV handed to ``contraNLUDataset``.
        batch_size: batch size used by both loaders.
        num_worker: value passed as ``num_workers`` to both loaders.
        shuffle_data: whether the *training* loader shuffles its samples
            (defaults to True). The validation loader never shuffles.
    """

    def __init__(
        self, train_dir, val_dir, batch_size, num_worker, shuffle_data=True
    ):

        super().__init__()
        self.train_dir = train_dir
        self.val_dir = val_dir
        self.batch_size = batch_size
        self.num_worker = num_worker
        self.shuffle_data = shuffle_data

    # The annotation is quoted so it is never evaluated at definition time and
    # needs no extra import; the previous ``[str]`` was a list literal, not a type.
    def setup(self, stage: "str | None" = None):
        """Build the train and validation datasets; ``stage`` is not used."""
        self.train = contraNLUDataset(self.train_dir)

        self.val = contraNLUDataset(self.val_dir)

    def train_dataloader(self):
        """Return the training loader, shuffled when ``shuffle_data`` is true."""
        return DataLoader(
            self.train,
            batch_size=self.batch_size,
            shuffle=self.shuffle_data,
            num_workers=self.num_worker,
        )

    def val_dataloader(self):
        """Return the validation loader (fixed order)."""
        return DataLoader(
            self.val, batch_size=self.batch_size, num_workers=self.num_worker
        )

In [28]:
# Ad-hoc dataset for interactive inspection; uses a relative path rather than config['dc'].
ds = contraNLUDataset(file_dir='./data/multiATIS/split/train/clean/train.tsv')

In [44]:
# Shuffled loader over `ds`, 14 utterances per batch.
dl = DataLoader(dataset=ds, shuffle=True, batch_size=14)

In [45]:
# Pretrained multilingual DistilBERT, configured to return dict-style outputs
# that include the hidden states.
encoder = DistilBertModel.from_pretrained(
    'distilbert-base-multilingual-cased',
    output_hidden_states=True,
    return_dict=True,
)

In [48]:
# Smoke test: tokenize one batch from `dl` and push it through `encoder`.
# Build the tokenizer here so the cell runs on a fresh kernel; it uses the
# DistilBertTokenizer class imported at the top of the notebook.
tokenizer = DistilBertTokenizer.from_pretrained('distilbert-base-multilingual-cased')

for b in dl:
    # Tokenize the whole batch in one call. Iterating over whatever the batch
    # actually contains (instead of a fixed count) keeps this working for a
    # short final batch or a different loader batch size.
    inputs = tokenizer(
        list(b['text']),
        add_special_tokens=True,
        return_token_type_ids=False,
        truncation=True,
        max_length=56,
        padding="max_length",
        return_tensors="pt",
    )
    token_ids = inputs["input_ids"]
    mask = inputs["attention_mask"]
    hidden = encoder(token_ids, mask)

    print(token_ids, mask, hidden[0])
    break  # only the first batch is inspected

tensor([[   101,  13416,  10105,  10446,  78881,  55650,  10135,  11951,  20714,
          80341,  23005,  10238,  11132,  10188,  44555,  10500,  10114,  10134,
          30809,  11183,  65200,  10106,  10134,  30809,  11183,  10948,  15821,
          11166,  10392,  10111, 103209,  52160,    102,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0],
        [   101,  11897,  10911,  10105,  55650,  65200,  12166,  10192,  10263,
          10106,  20873,  65258,  10246,  10188,  10140,  12563,    102,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
              0,      0,      0,      0,      0,      0,      0,      0,      0,
  