In [None]:
import random
from copy import deepcopy
from functools import partial
from typing import Optional

import pandas as pd
import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import Dataset
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer
from transformers import DistilBertModel

In [None]:
# Pool of background-noise phrases used by the 'BG' augmentation: one entry per line of the file.
BG_NOISE_PHRASE_PATH = "../data/BG_Noise_Phrase.txt"

# Explicit encoding: the platform default (e.g. cp1252 on Windows) can fail on non-ASCII phrases.
with open(BG_NOISE_PHRASE_PATH, encoding="utf-8") as f:
    content = f.readlines()
# Blank lines are dropped so they cannot be drawn as empty "phrases".
phrase = [line.strip() for line in content if line.strip()]

# contrastive noise augmentation samples


def mergelistsBG(ns, s, prob):
    """Randomly interleave background-noise items into a sample.

    Walks both lists front to back. At each step the next noise item is taken
    with probability ``prob``, otherwise the next sample item. Once either list
    is exhausted, the rest of the other is appended in order. The relative
    order inside each list is preserved.

    Args:
        ns: noise items (``[token, slot_id]`` pairs); not modified.
        s: sample items; not modified.
        prob: probability of drawing from ``ns`` at each step
            (``<= 0`` never, ``>= 1`` always).

    Returns:
        ``(s, merged)``: the caller's sample list (same object) and the new
        interleaved list. Items in ``merged`` are shared with the inputs, not copied.
    """
    merged = []
    i = j = 0
    while i < len(ns) and j < len(s):
        # random.random() is uniform on [0, 1): an exact Bernoulli(prob) draw,
        # with no per-call 1000-element lookup list and no truncation error.
        if random.random() < prob:
            merged.append(ns[i])
            i += 1
        else:
            merged.append(s[j])
            j += 1

    # At most one of these slices is non-empty.
    merged.extend(s[j:])
    merged.extend(ns[i:])
    return s, merged

def mergelistsMC(text_packed, prob):
    """Build an (orig, aug) view pair by randomly dropping sample tokens.

    The first item is always kept unchanged in both views. Each later item is
    handled in one of two ways:

    * with probability ``prob``, it is kept in ``orig`` with its slot id
      replaced by ``'2000'`` and omitted from ``aug``;
    * otherwise it is appended unchanged to both views.

    Args:
        text_packed: list of ``[token, slot_id]`` pairs; not modified.
        prob: per-token drop probability.

    Returns:
        ``(orig, aug)`` lists. Both are empty when ``text_packed`` is empty.
    """
    if not text_packed:
        return [], []

    orig, aug = [text_packed[0]], [text_packed[0]]
    for tokens in text_packed[1:]:
        if random.random() < prob:
            # Dropped from aug; orig keeps the word but tags it with the noise slot id.
            orig.append([tokens[0], '2000'])
        else:
            orig.append(tokens)
            aug.append(tokens)
    return orig, aug

def contrastiveSampleGenerator(sample, noise_type, phrases=None):
    """Create one contrastive (original, augmented) view pair for a sentence.

    Words are tagged with their position in ``sample.split()``; noise or masked
    words are tagged ``'2000'``. The noise strength is drawn uniformly from
    {0.2, 0.4, 0.6}.

    Args:
        sample: whitespace-tokenised sentence.
        noise_type:
            * ``'MC'``: randomly drop words from the augmented view (see ``mergelistsMC``).
            * ``'BG'``: interleave 5-10 words drawn from three random background
              phrases (see ``mergelistsBG``).
        phrases: noise phrase pool for ``'BG'``. Defaults to the module-level ``phrase`` list.

    Returns:
        ``(origText, augText, origSlots, augSlots)``: four space-joined strings.

    Raises:
        ValueError: for an unknown ``noise_type``.
    """
    samplePacked = [[token, str(idx)] for idx, token in enumerate(sample.split())]
    noise_param = random.choice([0.20, 0.40, 0.60])

    if noise_type == 'MC':
        orig, aug = mergelistsMC(samplePacked, prob=noise_param)
    elif noise_type == 'BG':
        pool = phrase if phrases is None else phrases
        # Join with spaces: plain concatenation fused the last word of one
        # phrase with the first word of the next.
        noisyText = ' '.join(random.sample(pool, 3)).split()
        # Never ask for more noise words than the drawn phrases contain.
        n_noise = min(random.choice([5, 6, 7, 8, 9, 10]), len(noisyText))
        noisyPacked = [[token, '2000'] for token in random.sample(noisyText, n_noise)]
        orig, aug = mergelistsBG(noisyPacked, samplePacked, prob=noise_param)
    else:
        raise ValueError(f"unknown noise_type {noise_type!r}; expected 'MC' or 'BG'")

    # Generator-based joins also handle empty views (zip(*[]) cannot be unpacked).
    return (
        ' '.join(token for token, _ in orig),
        ' '.join(token for token, _ in aug),
        ' '.join(slot for _, slot in orig),
        ' '.join(slot for _, slot in aug),
    )

def contrastivePairs(text, noise_type):
    """Build contrastive view pairs for a list of sentences.

    Calls ``contrastiveSampleGenerator`` once per sentence, in order, and
    regroups its 4-tuples into parallel lists.

    Returns:
        A 6-tuple ``(textP1, textP2, slotsID1, slotsID2, sentID1, sentID2)``:
        view-1 texts, view-2 texts, view-1 slot-id strings, view-2 slot-id
        strings, and two separate but equal lists of sentence indices
        ``0 .. len(text) - 1``.
    """
    views_a, views_b, slots_a, slots_b = [], [], [], []

    for sentence in text:
        first, second, first_slots, second_slots = contrastiveSampleGenerator(sentence, noise_type)
        views_a.append(first)
        views_b.append(second)
        slots_a.append(first_slots)
        slots_b.append(second_slots)

    sent_ids_a = list(range(len(views_a)))
    sent_ids_b = list(range(len(views_b)))
    return views_a, views_b, slots_a, slots_b, sent_ids_a, sent_ids_b

In [None]:
def _pack_contrastive_view(texts, slot_strings, sent_ids, tokenizer):
    """Tokenize one side of the contrastive pairs and pack it into an HCL view dict."""
    token_ids, mask, processed_slots = batch_tokenizer(texts, slot_strings, tokenizer)
    token_ids, mask, sent_ids, packed_slots = list2Tensor(
        [token_ids, mask, sent_ids, processed_slots]
    )
    return {
        "token_ids": token_ids,
        "mask": mask,
        "sent_id": sent_ids,
        "token_id": packed_slots,
    }


def collate_CT_AT(batch, tokenizer, noise_type):
    """Collate a batch for combined hierarchical-contrastive + adversarial training.

    The two contrastive views (``HCLBatch``) are built from each item's
    ``'text'`` via ``contrastivePairs``. The supervised/adversarial part
    (``supBatch``) is unfinished: its slot-label loop had no body (the cell
    could not be parsed) and ``supBatch`` was never defined.

    Args:
        batch: list of dataset items with at least ``'text'``, ``'intent_id'``
            and ``'slots_id'`` keys.
        tokenizer: tokenizer forwarded to ``batch_tokenizer``.
        noise_type: augmentation type forwarded to ``contrastivePairs``.

    Raises:
        NotImplementedError: always, until the adversarial ``supBatch`` is implemented.
    """
    text = [datapoint['text'] for datapoint in batch]

    # Contrastive views for hierarchical contrastive learning.
    textP1, textP2, tokenID1, tokenID2, sentID1, sentID2 = contrastivePairs(text, noise_type)
    hcl_batch = [
        _pack_contrastive_view(textP1, tokenID1, sentID1, tokenizer),
        _pack_contrastive_view(textP2, tokenID2, sentID2, tokenizer),
    ]

    # TODO(review): build the adversarial supervised batch. The intended inputs
    # appear to be the augmented texts (textP2) with each item's 'intent_id',
    # and 'slots_id' remapped onto the augmented word positions in tokenID2
    # (noise words are tagged '2000'). The result must provide the
    # 'token_ids', 'mask', 'intent_id' and 'slots_id' keys read by
    # hierCon_model.forward(mode="ICNER"); then return
    # {"supBatch": supBatch, "HCLBatch": hcl_batch}.
    raise NotImplementedError(
        "collate_CT_AT: adversarial 'supBatch' construction is not implemented yet"
    )

In [None]:
class dataset(Dataset):
    """Map-style dataset backed by a tab-separated file.

    Each row must provide the columns TEXT, INTENT, INTENT_ID, SLOTS and
    SLOTS_ID. Items are looked up by index label; the default RangeIndex
    produced by ``read_csv`` makes that the row position.
    """

    def __init__(self, file_dir):
        # The whole file is loaded eagerly into a DataFrame.
        self.data = pd.read_csv(file_dir, sep="\t")

    def __getitem__(self, index):
        frame = self.data
        utterance = str(frame.TEXT[index])
        intent_name, intent_code = frame.INTENT[index], frame.INTENT_ID[index]
        slot_names, slot_codes = frame.SLOTS[index], frame.SLOTS_ID[index]

        sample = dict(text=utterance)
        sample["intent_id"] = intent_code
        sample["intent_label"] = intent_name
        sample["slots_id"] = slot_codes
        sample["slots_label"] = slot_names
        return sample

    def __len__(self):
        return self.data.shape[0]

In [None]:
class dataloader(pl.LightningDataModule):
    """Lightning data module: builds train/val loaders and picks the collate function per mode.

    Args:
        train_dir: path to the training TSV (read by ``dataset``).
        batch_size: batch size for both loaders.
        num_workers: worker processes for both loaders.
        args: run configuration. ``args.mode`` selects the training collate
            ('BASELINE', 'AT' or 'CT'), and ``args.noise_type`` is forwarded to
            the AT/CT collates. When omitted, a global ``args`` is used, which
            is the behaviour this class originally relied on.
        val_dir: optional path to the validation TSV.
        tokenizer_name, tokenizer_cache_dir: forwarded to ``AutoTokenizer.from_pretrained``.

    Raises:
        ValueError: if no ``args`` is passed and no global ``args`` exists.
    """

    def __init__(self, train_dir, batch_size, num_workers, args=None, val_dir=None,
                 tokenizer_name='distilbert-base-cased',
                 tokenizer_cache_dir='/efs-storage/tokenizer/'):
        super().__init__()
        if args is None:
            # Backward compatibility: callers used to rely on a notebook-global `args`.
            args = globals().get('args')
        if args is None:
            raise ValueError(
                "dataloader requires run arguments: pass args=... (or define a global `args`)"
            )

        self.train_dir = train_dir
        self.val_dir = val_dir
        self.batch_size = batch_size
        self.num_worker = num_workers
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, cache_dir=tokenizer_cache_dir)
        self.mode = args.mode
        self.args = args
        self.val = None

    def setup(self, stage: Optional[str] = None):
        """Load the datasets and choose the train/val collate functions for ``self.mode``."""
        self.train = dataset(self.train_dir)
        if self.val_dir is not None:
            self.val = dataset(self.val_dir)

        if self.mode == 'BASELINE':
            self.train_collate = partial(collate_sup, tokenizer=self.tokenizer)
        elif self.mode == 'AT':
            self.train_collate = partial(collate_AT, tokenizer=self.tokenizer, noise_type=self.args.noise_type)
        elif self.mode == 'CT':
            self.train_collate = partial(collate_CT, tokenizer=self.tokenizer, noise_type=self.args.noise_type)
        else:
            raise ValueError(f"unknown mode {self.mode!r}; expected 'BASELINE', 'AT' or 'CT'")
        # Validation uses the supervised collate in every mode.
        self.val_collate = partial(collate_sup, tokenizer=self.tokenizer)

    def train_dataloader(self):
        return DataLoader(
            self.train, batch_size=self.batch_size, shuffle=True,
            collate_fn=self.train_collate, num_workers=self.num_worker,
        )

    def val_dataloader(self):
        if self.val is None:
            raise RuntimeError("no validation set: construct the data module with val_dir=...")
        return DataLoader(
            self.val, batch_size=self.batch_size,
            collate_fn=self.val_collate, num_workers=self.num_worker,
        )

In [None]:
class hierCon_model(nn.Module):
    """DistilBERT encoder with joint intent/slot heads and a hierarchical
    (sentence- and token-level) SimSiam-style contrastive objective.

    ``forward(batch, mode)`` returns one of:

    * the supervised loss dict, for ``mode == "ICNER"``;
    * the scalar contrastive loss, for ``mode == "hierCon"``.
    """

    # Token ids excluded from the token-level contrast:
    # * -100: CrossEntropyLoss's default ignore index; presumably special or
    #   padding positions (confirm against batch_tokenizer).
    # * 2000: slot id the augmentation functions give to injected or masked
    #   noise words.
    IGNORE_TOKEN_ID = -100
    NOISE_TOKEN_ID = 2000

    def __init__(self, args):
        super().__init__()

        self.encoder = DistilBertModel.from_pretrained(
            args.encoder, return_dict=True, output_hidden_states=True,
            sinusoidal_pos_embds=True, cache_dir='/efs-storage/model/'
        )
        # Width of the encoder's hidden states (768 for distilbert-base).
        hidden_dim = self.encoder.config.dim

        self.intent_head = nn.Sequential(
            nn.Dropout(args.intent_dropout),
            nn.Linear(256, args.intent_count),
        )
        self.slots_head = nn.Sequential(
            nn.Dropout(args.slots_dropout),
            nn.Linear(256, args.slots_count),
        )

        # Projectors map encoder states into the 256-d space shared by the heads
        # and the contrastive losses. Predictors are the SimSiam-style prediction MLPs.
        # Submodule order and layout are unchanged, so state_dict keys still match.
        self.token_contrast_proj = self._projector(hidden_dim)
        self.sent_contrast_proj = self._projector(hidden_dim)
        self.intent_predictor = self._predictor()
        self.slots_predictor = self._predictor()

        self.criterion = nn.CosineSimilarity(dim=1)
        self.CE_loss = nn.CrossEntropyLoss()

        self.icnerCoef = args.icnerCoef
        self.hierConCoef = args.hierConCoef
        self.args = args

    @staticmethod
    def _projector(in_dim):
        """Linear(in_dim->512) -> BatchNorm -> ReLU -> Linear(512->256) -> ReLU."""
        return nn.Sequential(
            nn.Linear(in_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Linear(512, 256),
            nn.ReLU(inplace=True),
        )

    @staticmethod
    def _predictor():
        """Linear(256->256) -> BatchNorm -> ReLU -> Linear(256->256)."""
        return nn.Sequential(
            nn.Linear(256, 256),
            nn.BatchNorm1d(256),
            nn.ReLU(inplace=True),
            nn.Linear(256, 256),
        )

    def ICNER_loss(self, encoded_output, intent_target, slots_target):
        """Joint intent-classification + slot-filling loss.

        Args:
            encoded_output: encoder output. Element 0 is used as the
                (batch, seq, hidden) last hidden state.
            intent_target: one intent class id per sentence.
            slots_target: slot ids per token, flattened to (batch*seq,).
                Entries equal to -100 are ignored by CrossEntropyLoss.

        Returns:
            A dict with these keys:

            * ``joint_loss``: ``icnerCoef * ic_loss + (1 - icnerCoef) * ner_loss``.
            * ``ic_loss`` and ``ner_loss``: the two component losses.
            * ``intent_pred``: shape (batch,).
            * ``slot_pred``: shape (batch, seq).
        """
        hidden_states = encoded_output[0]
        batch_size, seq_len = hidden_states.shape[:2]

        # Intent: position 0 (presumably [CLS]) -> sentence projector -> intent head.
        intent_logits = self.intent_head(self.sent_contrast_proj(hidden_states[:, 0]))
        intent_loss = self.CE_loss(intent_logits, intent_target)
        # argmax(softmax(x)) == argmax(x); the softmax was redundant.
        intent_pred = intent_logits.argmax(dim=1)

        # Slots: every position -> token projector -> slot head.
        slots_logits = self.slots_head(
            self.token_contrast_proj(hidden_states.reshape(batch_size * seq_len, -1))
        )
        slots_loss = self.CE_loss(slots_logits, slots_target.reshape(-1))
        slots_pred = slots_logits.argmax(dim=1).view(batch_size, seq_len)

        joint_loss = self.icnerCoef * intent_loss + (1.0 - self.icnerCoef) * slots_loss

        return {
            "joint_loss": joint_loss,
            "ic_loss": intent_loss,
            "ner_loss": slots_loss,
            "intent_pred": intent_pred,
            "slot_pred": slots_pred,
        }

    def _symmetric_neg_cosine(self, p1, p2, z1, z2):
        """Symmetric negative-cosine loss with stop-gradient on the projections.

        ``Tensor.detach()`` returns a new tensor. The previous bare
        ``p.detach()`` calls discarded that result, so no gradient was actually
        stopped. Here the detached tensors are used directly.
        """
        return -(self.criterion(p2.detach(), z1).mean() + self.criterion(p1.detach(), z2).mean()) * 0.5

    def sentCL(self, sentz1, sentz2):
        """Sentence-level contrastive loss between two views' sentence embeddings.

        Args:
            sentz1, sentz2: one embedding per sentence, shape (batch, hidden).

        Returns:
            A scalar loss in [-1, 1].
        """
        p1, p2 = self.sent_contrast_proj(sentz1), self.sent_contrast_proj(sentz2)
        z1, z2 = self.intent_predictor(p1), self.intent_predictor(p2)
        return self._symmetric_neg_cosine(p1, p2, z1, z2)

    def _contrastable(self, token_ids, device):
        """Boolean mask over flattened ``token_ids`` of tokens that take part in the token contrast."""
        ids = torch.flatten(token_ids).to(device)
        return (ids != self.IGNORE_TOKEN_ID) & (ids != self.NOISE_TOKEN_ID)

    def tokenCL(self, tokenEmb1, tokenEmb2, tokenID1, tokenID2):
        """Token-level contrastive loss between aligned tokens of two views.

        Tokens whose id is -100 or 2000 are dropped from each view. The
        remaining tokens are paired in order, so both views must keep the same
        number of tokens.

        Args:
            tokenEmb1, tokenEmb2: (batch, seq_i, hidden) token states. The two
                sequence lengths may differ.
            tokenID1, tokenID2: per-token ids with batch*seq_i elements each.

        Returns:
            A scalar loss. It is 0 when no token survives filtering.

        Raises:
            ValueError: if the two views keep different numbers of tokens.
        """
        # Reshape each view with its own shape; the views can be padded to different lengths.
        emb1 = tokenEmb1.reshape(-1, tokenEmb1.size(-1))
        emb2 = tokenEmb2.reshape(-1, tokenEmb2.size(-1))
        emb1 = emb1[self._contrastable(tokenID1, emb1.device)]
        emb2 = emb2[self._contrastable(tokenID2, emb2.device)]

        if emb1.size(0) != emb2.size(0):
            raise ValueError(
                f"token views are misaligned: {emb1.size(0)} vs {emb2.size(0)} contrastable tokens"
            )
        if emb1.size(0) == 0:
            return tokenEmb1.new_zeros(())

        # Token-level loss uses the slot predictor (the intent predictor is for sentences).
        p1, p2 = self.token_contrast_proj(emb1), self.token_contrast_proj(emb2)
        z1, z2 = self.slots_predictor(p1), self.slots_predictor(p2)
        return self._symmetric_neg_cosine(p1, p2, z1, z2)

    def forward(self, batch, mode):
        """Compute the loss for ``mode``.

        Modes:
            * ``"ICNER"``: uses ``batch['supBatch']`` and returns the
              ``ICNER_loss`` dict.
            * ``"hierCon"``: uses ``batch['HCLBatch'][0]`` and ``[1]`` and
              returns ``hierConCoef * sentCL + (1 - hierConCoef) * tokenCL``.

        Raises:
            ValueError: for any other mode.
        """
        if mode == "ICNER":
            sup = batch['supBatch']
            encoded_output = self.encoder(sup['token_ids'], sup['mask'])
            return self.ICNER_loss(encoded_output, sup['intent_id'], sup['slots_id'])

        if mode == "hierCon":
            view0, view1 = batch['HCLBatch'][0], batch['HCLBatch'][1]
            hidden0 = self.encoder(view0['token_ids'], view0['mask'])[0]
            hidden1 = self.encoder(view1['token_ids'], view1['mask'])[0]

            sent_loss = self.sentCL(hidden0[:, 0], hidden1[:, 0])
            token_loss = self.tokenCL(hidden0, hidden1, view0['token_id'], view1['token_id'])

            return self.args.hierConCoef * sent_loss + (1.0 - self.args.hierConCoef) * token_loss

        raise ValueError(f"unknown mode {mode!r}; expected 'ICNER' or 'hierCon'")
