# In [3]:
import random
from copy import deepcopy
from functools import partial
from typing import Optional

import pandas as pd
import pytorch_lightning as pl
import torch
from torch.utils.data import Dataset
from torch.utils.data import Dataset, DataLoader
from transformers import AutoTokenizer

# In [7]:
# Load the pool of background-noise phrases: one phrase per line of the file,
# with surrounding whitespace (including the trailing newline) removed.
with open("../data/BG_Noise_Phrase.txt") as f:
    content = list(f)
phrase = list(map(str.strip, content))

# Helpers that build contrastive (original, noise-augmented) sample pairs.


def mergelistsBG(ns, s, prob):
    """Interleave noise items into a sample, keeping each stream's own order.

    At every step a biased coin picks which stream supplies the next item:
    with probability ``prob`` the next unused item of ``ns``, otherwise the
    next unused item of ``s``. As soon as one stream is exhausted the rest of
    the other one is appended unchanged.

    Args:
        ns: Sequence of noise items (the caller passes ``[token, slot_id]``
            pairs).
        s: Sequence of sample items, same layout as ``ns``.
        prob: Probability in ``[0, 1]`` of drawing the next item from ``ns``.
            ``0`` yields ``s`` followed by ``ns``; ``1`` yields ``ns``
            followed by ``s``.

    Returns:
        A ``(s, merged)`` tuple: the ``s`` argument itself (not a copy) and a
        new list holding deep copies of every item of both inputs.
    """
    # Work on deep copies so the merged items never alias the caller's data.
    noisy_items = deepcopy(ns)
    sample_items = deepcopy(s)

    merged = []
    noisy_pos = 0
    sample_pos = 0

    # Index cursors replace the original's repeated ``pop(0)`` calls.
    while noisy_pos < len(noisy_items) and sample_pos < len(sample_items):
        # A direct Bernoulli draw: no 1000-slot lookup table, so ``prob`` is
        # neither quantised nor skewed by float truncation.
        if random.random() < prob:
            merged.append(noisy_items[noisy_pos])
            noisy_pos += 1
        else:
            merged.append(sample_items[sample_pos])
            sample_pos += 1

    # At most one of these two slices is non-empty.
    merged.extend(noisy_items[noisy_pos:])
    merged.extend(sample_items[sample_pos:])

    return s, merged

def mergelistsMC(text_packed, prob):
    """Build an (original, augmented) pair by randomly dropping tokens.

    The first item is always kept in both outputs. Every later item is,
    with probability ``prob``, removed from the augmented list and kept in
    the original list with its slot id replaced by the string ``'2000'``
    (the same id ``contrastiveSampleGenerator`` gives to injected noise
    tokens). Otherwise the item is appended unchanged to both lists.

    Args:
        text_packed: Sequence of ``[token, slot_id]`` pairs.
        prob: Probability in ``[0, 1]`` of dropping each non-first item.

    Returns:
        An ``(orig, aug)`` tuple of lists. Kept items are shared by
        reference with ``text_packed``; ``text_packed`` itself is not
        modified. Empty input yields ``([], [])``.
    """
    if not text_packed:
        # Nothing to keep or drop; avoids indexing an empty sequence.
        return [], []

    first_item = text_packed[0]
    orig = [first_item]
    aug = [first_item]

    for item in text_packed[1:]:
        # Direct Bernoulli draw instead of sampling a 1000-slot lookup table.
        if random.random() < prob:
            orig.append([item[0], '2000'])
        else:
            orig.append(item)
            aug.append(item)

    return orig, aug

def _joinPackedColumns(packed):
    """Split ``[token, slot_id]`` pairs into two space-joined strings."""
    if not packed:
        # ``zip(*[])`` yields nothing to unpack, so handle the empty case here.
        return '', ''
    tokens, slots = zip(*packed)
    return ' '.join(tokens), ' '.join(slots)


def contrastiveSampleGenerator(sample, noise_type):
    """Create one (original, augmented) text pair with per-token slot ids.

    The sample is split on whitespace and each token is tagged with its
    position as a string. Depending on ``noise_type``:

    * ``'MC'``: tokens are randomly dropped via ``mergelistsMC``.
    * ``'BG'``: 5-10 words drawn from three random entries of the
      module-level ``phrase`` pool are tagged with slot id ``'2000'`` and
      interleaved into the sample via ``mergelistsBG``.

    The noise probability is drawn uniformly from ``{0.2, 0.4, 0.6}``.

    Args:
        sample: Whitespace-separated input text.
        noise_type: ``'MC'`` or ``'BG'``.

    Returns:
        ``(origText, augText, origSlots, augSlots)``, each a single
        space-joined string.

    Raises:
        ValueError: If ``noise_type`` is neither ``'MC'`` nor ``'BG'``
            (the previous implementation silently returned ``None``).
    """
    if noise_type not in ('MC', 'BG'):
        raise ValueError(
            "noise_type must be 'MC' or 'BG', got {!r}".format(noise_type)
        )

    samplePacked = [[token, str(idx)] for idx, token in enumerate(sample.split())]
    noise_param = random.choice([0.20, 0.40, 0.60])

    if noise_type == 'MC':
        if samplePacked:
            orig, aug = mergelistsMC(samplePacked, prob=noise_param)
        else:
            # An empty sample has no tokens to drop.
            orig, aug = [], []
    else:
        # Join the phrases with a space: plain concatenation fused the last
        # word of one phrase with the first word of the next.
        chosenPhrases = random.sample(phrase, min(3, len(phrase)))
        noiseWords = ' '.join(chosenPhrases).split()

        # Never request more words than are available; random.sample would
        # raise ValueError on short phrases otherwise.
        wanted = random.choice([5, 6, 7, 8, 9, 10])
        noisyTokens = random.sample(noiseWords, min(wanted, len(noiseWords)))
        noisyPacked = [[token, '2000'] for token in noisyTokens]

        orig, aug = mergelistsBG(noisyPacked, samplePacked, prob=noise_param)

    origText, origSlots = _joinPackedColumns(orig)
    augText, augSlots = _joinPackedColumns(aug)
    return origText, augText, origSlots, augSlots

def contrastivePairs(text, noise_type):
    """Generate contrastive views for every sentence in ``text``.

    Each sentence is passed to ``contrastiveSampleGenerator`` once, in
    order. View 1 collects the original text/slot strings and view 2 the
    augmented ones; both views label a sentence with its position in
    ``text``.

    Returns:
        ``(textP1, textP2, slotsID1, slotsID2, sentID1, sentID2)`` as six
        lists of equal length.
    """
    generated = [
        contrastiveSampleGenerator(sentence, noise_type) for sentence in text
    ]

    # Generator output order: (origText, augText, origSlots, augSlots).
    textP1 = [views[0] for views in generated]
    textP2 = [views[1] for views in generated]
    slotsID1 = [views[2] for views in generated]
    slotsID2 = [views[3] for views in generated]

    # Two independent lists so callers may modify one without the other.
    sentID1 = list(range(len(generated)))
    sentID2 = list(range(len(generated)))

    return textP1, textP2, slotsID1, slotsID2, sentID1, sentID2

# In [8]:
def collate_CT(batch, tokenizer, noise_type):
    """Collate a batch into a supervised part and two contrastive views.

    Args:
        batch: Iterable of datapoints exposing the keys ``'text'``,
            ``'intent_id'`` and ``'slots_id'``.
        tokenizer: Passed through to ``batch_tokenizer``.
        noise_type: Passed through to ``contrastivePairs``.

    Returns:
        ``{"supBatch": {...}, "HCLBatch": [view1, view2]}``. ``supBatch``
        holds ``token_ids``, ``mask``, ``intent_id`` and ``slots_id``; each
        view holds ``token_ids``, ``mask``, ``sent_id`` and ``token_id``.
    """
    # Single pass over the batch, then split into columns.
    rows = [
        (point['text'], point['intent_id'], point['slots_id'])
        for point in batch
    ]
    text = [row[0] for row in rows]
    intent_id = [row[1] for row in rows]
    slot_id = [row[2] for row in rows]

    # --- supervised part ---
    # batch_tokenizer / list2Tensor are defined elsewhere; presumably they
    # tokenize the text and convert the lists to tensors (confirm there).
    sup_tokens, sup_mask, sup_slots = batch_tokenizer(text, slot_id, tokenizer)
    sup_tokens, sup_mask, intent_id, sup_slots = list2Tensor(
        [sup_tokens, sup_mask, intent_id, sup_slots]
    )
    supBatch = {
        "token_ids": sup_tokens,
        "mask": sup_mask,
        "intent_id": intent_id,
        "slots_id": sup_slots,
    }

    # --- hierarchical contrastive part ---
    textP1, textP2, tokenID1, tokenID2, sentID1, sentID2 = contrastivePairs(
        text, noise_type
    )

    def _pack_view(view_text, view_token_ids, view_sent_ids):
        # Tokenize one contrastive view and bundle it in the output layout.
        ids, attention, aligned = batch_tokenizer(
            view_text, view_token_ids, tokenizer
        )
        ids, attention, sent, packed = list2Tensor(
            [ids, attention, view_sent_ids, aligned]
        )
        return {
            "token_ids": ids,
            "mask": attention,
            "sent_id": sent,
            "token_id": packed,
        }

    CP1 = _pack_view(textP1, tokenID1, sentID1)
    CP2 = _pack_view(textP2, tokenID2, sentID2)

    return {"supBatch": supBatch, "HCLBatch": [CP1, CP2]}

# In [9]:
class dataset(Dataset):
    """Dataset backed by a tab-separated file.

    Items are built from the columns ``TEXT``, ``INTENT``, ``INTENT_ID``,
    ``SLOTS`` and ``SLOTS_ID`` of the loaded table.
    """

    def __init__(self, file_dir):
        """Read the tab-separated file at ``file_dir`` into ``self.data``."""
        frame = pd.read_csv(file_dir, sep="\t")
        self.data = frame

    def __len__(self):
        """Return the number of rows in the table."""
        return len(self.data)

    def __getitem__(self, index):
        """Return the row at ``index`` as a dict of text, intent and slot fields."""

        def cell(column_name):
            # Attribute-style column lookup followed by indexing with ``index``.
            return getattr(self.data, column_name)[index]

        # The text is always returned as ``str``; other fields are returned
        # exactly as stored in the table.
        utterance = str(cell("TEXT"))

        intent_name = cell("INTENT")
        intent_number = cell("INTENT_ID")

        slot_names = cell("SLOTS")
        slot_numbers = cell("SLOTS_ID")

        item = {"text": utterance}
        item["intent_id"] = intent_number
        item["intent_label"] = intent_name
        item["slots_id"] = slot_numbers
        item["slots_label"] = slot_names
        return item

# In [10]:
class dataloader(pl.LightningDataModule):
    """Lightning data module that serves the train and validation datasets.

    The training collate function depends on ``args.mode``:

    * ``'BASELINE'``: ``collate_sup``
    * ``'AT'``: ``collate_AT`` with ``args.noise_type``
    * ``'CT'``: ``collate_CT`` with ``args.noise_type``

    Validation always uses ``collate_sup``.
    """

    def __init__(self, train_dir, batch_size, num_workers, val_dir=None, args=None):
        """Store the configuration and load the tokenizer.

        Args:
            train_dir: Path of the training file handed to ``dataset``.
            batch_size: Batch size for both loaders.
            num_workers: Worker count for both loaders.
            val_dir: Path of the validation file. When omitted, ``args.val_dir``
                is used if ``args`` has such an attribute.
            args: Object exposing ``mode`` and, for the ``'AT'``/``'CT'`` modes,
                ``noise_type``. When omitted, a module-level ``args`` is used,
                which is what the previous implementation relied on.

        Raises:
            NameError: If ``args`` is not passed and no module-level ``args``
                exists.
        """
        super().__init__()

        if args is None:
            # Backward compatibility: fall back to the global the original
            # code read implicitly.
            try:
                args = globals()["args"]
            except KeyError:
                raise NameError(
                    "dataloader needs an `args` object (pass args=... or "
                    "define a module-level `args`)"
                ) from None

        # Previously never stored, although setup() reads it.
        self.train_dir = train_dir
        # NOTE(review): the original never defined val_dir at all; the
        # fallback to args.val_dir is an assumption, confirm the intended
        # source of the validation path.
        self.val_dir = val_dir if val_dir is not None else getattr(args, "val_dir", None)

        self.batch_size = batch_size
        self.num_worker = num_workers
        self.tokenizer = AutoTokenizer.from_pretrained('distilbert-base-cased',cache_dir = '/efs-storage/tokenizer/')
        self.mode = args.mode
        self.args = args

    def setup(self, stage: Optional[str] = None):
        """Create the datasets and pick the collate functions for ``self.mode``.

        Raises:
            ValueError: If no validation path is configured or ``self.mode``
                is not one of ``'BASELINE'``, ``'AT'``, ``'CT'``.
        """
        if self.val_dir is None:
            raise ValueError(
                "no validation path configured: pass val_dir=... to dataloader"
            )

        # Resolve the collate functions first so that an unsupported mode
        # fails before any file is read. The if/elif chain keeps the lookup
        # of collate_AT / collate_CT lazy, as in the original.
        if self.mode == 'BASELINE':
            train_collate = partial(collate_sup, tokenizer=self.tokenizer)
        elif self.mode == 'AT':
            train_collate = partial(
                collate_AT, tokenizer=self.tokenizer, noise_type=self.args.noise_type
            )
        elif self.mode == 'CT':
            train_collate = partial(
                collate_CT, tokenizer=self.tokenizer, noise_type=self.args.noise_type
            )
        else:
            # Previously an unknown mode left the collate attributes unset and
            # only failed later inside train_dataloader().
            raise ValueError(
                "unsupported mode {!r}; expected 'BASELINE', 'AT' or 'CT'".format(self.mode)
            )

        self.train_collate = train_collate
        self.val_collate = partial(collate_sup, tokenizer=self.tokenizer)

        self.train = dataset(self.train_dir)
        self.val = dataset(self.val_dir)

    def train_dataloader(self):
        """Return the shuffled training loader."""
        return DataLoader(
            self.train,
            batch_size=self.batch_size,
            shuffle=True,
            collate_fn=self.train_collate,
            num_workers=self.num_worker,
        )

    def val_dataloader(self):
        """Return the validation loader (no shuffling)."""
        return DataLoader(
            self.val,
            batch_size=self.batch_size,
            collate_fn=self.val_collate,
            num_workers=self.num_worker,
        )