In [1]:
!pip --quiet install transformers sentence_transformers

In [2]:
import argparse
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from transformers import AutoTokenizer, AutoModelWithLMHead
import numpy as np
import torch
import requests
import json
import pandas as pd
from sentence_transformers import SentenceTransformer, util

In [3]:
# Mount Google Drive at /content/drive (requires the Colab runtime, which provides google.colab).
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [4]:
# input_path = "/content/drive/MyDrive/val_samples_em.csv"

In [5]:
# aqua_df = pd.read_csv(input_path)
# questions = aqua_df["question"].tolist()
# print(aqua_df.shape)
# aqua_df.head()

In [6]:
def hf_masked_encode(
        tokenizer,
        sentence: str,
        *addl_sentences,
        noise_prob=0.0,
        random_token_prob=0.0,
        leave_unmasked_prob=0.0):
    """Tokenize a sentence (group) and corrupt it for masked-LM reconstruction.

    Args:
        tokenizer: object exposing ``vocab``, ``all_special_ids``, ``encode`` and
            the ``sep``/``cls``/``pad``/``mask`` ``*_token_id`` attributes.
        sentence: first sentence to encode.
        *addl_sentences: further sentences forwarded to ``tokenizer.encode``.
        noise_prob: fraction of tokens to select for corruption.
        random_token_prob: probability that a selected token is replaced by a
            random non-special, non-``[unused`` vocabulary id instead of the mask id.
        leave_unmasked_prob: probability that a selected token is left as is
            (such positions are dropped from the mask and get no target).

    Returns:
        * If ``noise_prob == 0.0``: the bare ``numpy`` array of token ids
          (no targets). Kept for backward compatibility -- callers that unpack
          two values must pass ``noise_prob > 0``.
        * Otherwise ``(tokens, mask_targets)`` as ``torch.long`` tensors of equal
          length. ``mask_targets`` holds the ORIGINAL token id at every corrupted
          position and ``tokenizer.pad_token_id`` everywhere else.
    """
    if random_token_prob > 0.0:
        # sampling weights for random replacement: uniform over "real" tokens
        weights = np.ones(len(tokenizer.vocab))
        weights[tokenizer.all_special_ids] = 0
        for k, v in tokenizer.vocab.items():
            if '[unused' in k:
                weights[v] = 0
        weights = weights / weights.sum()

    tokens = np.asarray(tokenizer.encode(sentence, *addl_sentences, add_special_tokens=True))

    if noise_prob == 0.0:
        return tokens

    sz = len(tokens)
    mask = np.full(sz, False)
    # stochastic rounding of the expected number of corrupted tokens
    num_mask = int(noise_prob * sz + np.random.rand())

    # separator / class / padding tokens are never corrupted
    protected_ids = [tokenizer.sep_token_id, tokenizer.cls_token_id, tokenizer.pad_token_id]
    mask_choice_p = np.ones(sz)
    for i in range(sz):
        if tokens[i] in protected_ids:
            mask_choice_p[i] = 0

    num_maskable = int(mask_choice_p.sum())
    if num_maskable == 0:
        # nothing can be corrupted; avoid the 0/0 normalisation below
        untouched_targets = np.full(sz, tokenizer.pad_token_id)
        return torch.tensor(tokens).long(), torch.tensor(untouched_targets).long()

    # sampling without replacement cannot draw more positions than are eligible
    num_mask = min(num_mask, num_maskable)
    mask_choice_p = mask_choice_p / mask_choice_p.sum()

    mask[np.random.choice(sz, num_mask, replace=False, p=mask_choice_p)] = True

    # decide unmasking and random replacement
    rand_or_unmask_prob = random_token_prob + leave_unmasked_prob
    if rand_or_unmask_prob > 0.0:
        rand_or_unmask = mask & (np.random.rand(sz) < rand_or_unmask_prob)
        if random_token_prob == 0.0:
            unmask = rand_or_unmask
            rand_mask = None
        elif leave_unmasked_prob == 0.0:
            unmask = None
            rand_mask = rand_or_unmask
        else:
            unmask_prob = leave_unmasked_prob / rand_or_unmask_prob
            decision = np.random.rand(sz) < unmask_prob
            unmask = rand_or_unmask & decision
            rand_mask = rand_or_unmask & (~decision)
    else:
        unmask = rand_mask = None

    if unmask is not None:
        mask = mask ^ unmask

    # remember the clean ids BEFORE they are overwritten, so the targets are
    # the original tokens rather than the mask / random replacement ids
    original_tokens = tokens.copy()

    tokens[mask] = tokenizer.mask_token_id
    if rand_mask is not None:
        num_rand = rand_mask.sum()
        if num_rand > 0:
            tokens[rand_mask] = np.random.choice(
                len(tokenizer.vocab),
                num_rand,
                p=weights,
            )

    mask_targets = np.full(len(mask), tokenizer.pad_token_id)
    mask_targets[mask] = original_tokens[mask]

    return torch.tensor(tokens).long(), torch.tensor(mask_targets).long()

def hf_reconstruction_prob_tok(masked_tokens, target_tokens, tokenizer, model, softmax_mask, reconstruct=False, topk=1):
    """Score or resample the corrupted positions of a batch with a masked LM.

    Positions to predict are those where ``target_tokens`` differs from
    ``tokenizer.pad_token_id``.

    Args:
        masked_tokens: corrupted token ids, shape ``(L,)`` or ``(B, L)``.
        target_tokens: same shape; target id at predicted positions, pad id elsewhere.
        tokenizer: only ``pad_token_id`` is read.
        model: callable as ``model(input_ids, labels=...)`` whose output at
            index 1 is the per-position vocabulary logits.
        softmax_mask: index (boolean mask over the vocabulary in this file) of
            ids that must never be predicted; their logits are set to ``-inf``.
        reconstruct: if True, sample replacements; otherwise return a score.
        topk: number of top candidates to sample from; ``-1`` samples from the
            whole (masked) distribution.

    Returns:
        * ``reconstruct=True``: ``(tokens, fill)`` where ``tokens`` is
          ``masked_tokens`` with sampled ids written at the predicted positions
          (the input tensor is modified in place) and ``fill`` is a tensor of
          ones with the sampled ids at those positions. A 1-D input yields 1-D
          outputs, including when there is nothing to predict.
        * ``reconstruct=False``: the summed log-probability (float) of the
          target ids at the predicted positions. When nothing is predicted the
          value ``1.0`` is returned.
          NOTE(review): ``1.0`` is not a valid log-probability (``0.0`` would
          be consistent); left unchanged because callers may rely on it.
    """
    single = False

    # expand batch size 1
    if masked_tokens.dim() == 1:
        single = True
        masked_tokens = masked_tokens.unsqueeze(0)
        target_tokens = target_tokens.unsqueeze(0)

    masked_fill = torch.ones_like(masked_tokens)

    masked_index = (target_tokens != tokenizer.pad_token_id).nonzero(as_tuple=True)
    masked_orig_index = target_tokens[masked_index]

    # edge case of no masked tokens
    if len(masked_orig_index) == 0:
        if reconstruct:
            # honour the 1-D contract here too, like the regular return path
            if single:
                return masked_tokens[0], masked_fill[0]
            return masked_tokens, masked_fill
        return 1.0

    # inference only: the logits are detached right away, so do not build a graph
    with torch.no_grad():
        outputs = model(
            masked_tokens.long().to(device=next(model.parameters()).device),
            labels=target_tokens
        )

    # with labels supplied, index 0 is the loss and index 1 the logits
    features = outputs[1]

    logits = features[masked_index].detach().clone()
    for row in logits:
        row[softmax_mask] = float('-inf')
    probs = logits.softmax(dim=-1)

    if reconstruct:
        # masked_index is a per-dimension tuple, so sampling is always done
        # row by row (one row of `probs` per predicted position)
        if topk != -1:
            # sample from topk; the softmax over the top-k probabilities is
            # kept as in the original implementation
            values, indices = probs.topk(k=topk, dim=-1)
            kprobs = values.softmax(dim=-1)
            samples = torch.cat([idx[torch.multinomial(kprob, 1)] for kprob, idx in zip(kprobs, indices)])
        else:
            # unrestricted sampling
            samples = torch.cat([torch.multinomial(prob, 1) for prob in probs])

        # set samples
        masked_tokens[masked_index] = samples
        masked_fill[masked_index] = samples

        if single:
            return masked_tokens[0], masked_fill[0]
        return masked_tokens, masked_fill

    # pick, for every predicted position, the probability of its target id
    rows = torch.arange(len(masked_orig_index), device=probs.device)
    cols = masked_orig_index.to(probs.device)
    return torch.sum(torch.log(probs[rows, cols])).item()

def fill_batch(batch, min_len, max_len,
               tokenizer,
               sents,
               l,
               lines,
               labels,
               next_sent,
               num_gen,
               num_tries,
               gen_index):
    """Top up the working batch with the next usable inputs.

    ``lines`` is a list of columns; ``lines[c][k]`` is a one-element list
    holding field ``c`` of input ``k``. Starting at cursor ``next_sent``,
    inputs whose encoded length is not strictly between ``min_len`` and
    ``max_len`` are skipped. Every accepted input appends

    * a one-element list ``[(field_0, field_1, ...)]`` to ``sents``,
    * its label to ``l``,
    * a ``0`` to each of ``num_gen``, ``num_tries`` and ``gen_index``.

    The lists are extended in place and also returned, together with the
    advanced cursor:
    ``(sents, l, next_sent, num_gen, num_tries, gen_index)``.
    Filling stops when ``sents`` holds ``batch`` entries or the input runs out.
    """
    while len(sents) < batch:
        total = len(lines[0])

        # move the cursor forward to the next input of acceptable length
        while next_sent < total:
            fields = [column[next_sent][0] for column in lines]
            n_tokens = len(tokenizer.encode(*fields))
            if min_len < n_tokens < max_len:
                break
            next_sent += 1

        # ran off the end of the input: nothing more to add
        if next_sent >= total:
            break

        # register the accepted input and its fresh counters
        wrapped_fields = [column[next_sent] for column in lines]
        sents.append(list(zip(*wrapped_fields)))
        l.append(labels[next_sent])
        for counter in (num_gen, num_tries, gen_index):
            counter.append(0)
        next_sent += 1

    return sents, l, next_sent, num_gen, num_tries, gen_index

In [7]:
class SSMBA:
    """Masked-LM based sentence augmentation with semantic-similarity labels.

    Each input is repeatedly corrupted (``hf_masked_encode``) and reconstructed
    by a masked language model (``hf_reconstruction_prob_tok``). Every new
    reconstruction is then compared with its source sentence using a
    sentence-embedding model; the cosine similarity is thresholded into a 0/1
    label.

    Attributes:
        r_model: reconstruction model loaded from ``model_path`` (eval mode,
            moved to CUDA when available).
        tokenizer: tokenizer loaded from the same ``model_path``.
        sem_model: ``SentenceTransformer("paraphrase-mpnet-base-v2")`` used for
            the similarity labels.
    """

    def __init__(self, model_path="bert-base-uncased", seed=1212):
        """Load the models and seed torch / numpy.

        Args:
            model_path: name or path passed to ``from_pretrained`` for both the
                LM and its tokenizer.
            seed: seed applied to the global torch and numpy generators.
        """
        torch.manual_seed(seed)
        np.random.seed(seed)
        self.r_model = AutoModelWithLMHead.from_pretrained(model_path)
        self.tokenizer = AutoTokenizer.from_pretrained(model_path)
        self.r_model.eval()
        if torch.cuda.is_available():
            self.r_model.cuda()
        self.sem_model = SentenceTransformer("paraphrase-mpnet-base-v2")

    def _label_candidates(self, gen_sents, threshold):
        """Pair every generated sentence with a 0/1 similarity label.

        ``gen_sents[0]`` is the source tuple, ``gen_sents[1:]`` the generated
        tuples; only the first field of each generated tuple is scored.
        Returns a list of ``(sentence, label)`` with label 1 when the cosine
        similarity to the source exceeds ``threshold``. Returns ``[]`` when
        nothing was generated.
        """
        original_sent = gen_sents[0]
        current_sent = [sg[0] for sg in gen_sents[1:]]
        if not current_sent:
            # nothing to embed; avoid encoding an empty batch
            return []

        original_embedding = self.sem_model.encode(original_sent, convert_to_tensor=True)
        processed_embeddings = self.sem_model.encode(current_sent, convert_to_tensor=True)
        cosine_scores = util.pytorch_cos_sim(original_embedding, processed_embeddings)
        current_labels = [1 if score > threshold else 0 for score in cosine_scores[0]]
        return list(zip(current_sent, current_labels))

    def augment(self, questions, num_samples=8, noise_prob=0.25, topk=10, shard=0, num_shards=1, batch=8, min_len=4, max_len=512, max_tries=10,
                     random_token_prob=0.1, leave_unmasked_prob=0.1, threshold=0.8):
        """Generate labelled augmentations for a list of sentences.

        Args:
            questions: list of strings; tab characters separate the fields of a
                multi-sentence input.
            num_samples: augmentations to produce per input.
            noise_prob: corruption rate forwarded to ``hf_masked_encode``
                (must be > 0, otherwise that function returns no targets).
            topk: top-k sampling width for reconstruction (-1 = unrestricted).
            shard, num_shards: process only slice ``shard`` of ``num_shards``.
            batch: number of inputs reconstructed together.
            min_len, max_len: inputs whose token count is not strictly between
                these bounds are skipped and produce no output entry.
            max_tries: an input is retired once it has failed more than this
                many consecutive times.
            random_token_prob, leave_unmasked_prob: forwarded to
                ``hf_masked_encode``.
            threshold: cosine-similarity cut-off for the 0/1 label.

        Returns:
            A list with one entry per processed (non-skipped) input, ordered as
            in ``questions``. Each entry is a list of ``(augmented_sentence,
            label)`` tuples; it holds ``num_samples`` items unless the input was
            retired early by ``max_tries``, in which case it is shorter
            (possibly empty).
        """
        if len(questions) == 0:
            return []

        # Inputs are consumed back to front (kept from the original flow);
        # the results are put back into caller order before returning.
        questions = questions[::-1]

        # remove unused vocab and special ids from sampling
        softmax_mask = np.full(len(self.tokenizer.vocab), False)
        softmax_mask[self.tokenizer.all_special_ids] = True
        for k, v in self.tokenizer.vocab.items():
            if '[unused' in k:
                softmax_mask[v] = True

        # load the inputs: one column per tab-separated field
        lines = [tuple(s.strip().split('\t')) for s in questions]
        num_lines = len(lines)
        lines = [[[s] for s in s_list] for s_list in list(zip(*lines))]

        # the per-input "label" carries the position in the caller's list so
        # that results can be re-ordered regardless of when an input finishes
        labels = list(range(num_lines - 1, -1, -1))

        # shard the input and labels
        if num_shards > 0:
            shard_start = (int(num_lines/num_shards) + 1) * shard
            shard_end = (int(num_lines/num_shards) + 1) * (shard + 1)
            lines = [s_list[shard_start:shard_end] for s_list in lines]
            labels = labels[shard_start:shard_end]

        # (source position, labelled augmentations) for every retired input
        finished = []

        # parallel per-slot state of the working batch:
        sents = []       # [source tuple, generated tuple, ...]
        l = []           # source position of the slot
        num_gen = []     # number of sentences generated so far
        gen_index = []   # index into sents[i] of the sentence to noise next
        num_tries = []   # consecutive failed attempts

        # next sentence index to draw from
        next_sent = 0

        sents, l, next_sent, num_gen, num_tries, gen_index = \
                fill_batch(batch, min_len, max_len,
                        self.tokenizer,
                        sents,
                        l,
                        lines,
                        labels,
                        next_sent,
                        num_gen,
                        num_tries,
                        gen_index)

        # main augmentation loop
        while sents:

            # retire inputs that are complete or have failed too often;
            # iterate backwards so popping does not shift pending indices
            for i in range(len(num_gen))[::-1]:
                if num_gen[i] == num_samples or num_tries[i] > max_tries:
                    gen_sents = sents.pop(i)
                    num_gen.pop(i)
                    gen_index.pop(i)
                    # keep the try counter aligned with the other slot lists
                    num_tries.pop(i)
                    source_position = l.pop(i)
                    finished.append((source_position, self._label_candidates(gen_sents, threshold)))

            # fill batch
            sents, l, next_sent, num_gen, num_tries, gen_index = \
                    fill_batch(batch, min_len, max_len,
                            self.tokenizer,
                            sents,
                            l,
                            lines,
                            labels,
                            next_sent,
                            num_gen,
                            num_tries,
                            gen_index)

            # break if done dumping
            if len(sents) == 0:
                break

            # build batch
            toks = []
            masks = []

            for i in range(len(gen_index)):
                s = sents[i][gen_index[i]]
                tok, mask = hf_masked_encode(
                        self.tokenizer,
                        *s,
                        noise_prob=noise_prob,
                        random_token_prob=random_token_prob,
                        leave_unmasked_prob=leave_unmasked_prob,
                )
                toks.append(tok)
                masks.append(mask)

            # pad up to the longest sequence of this batch; a separate name is
            # used so the `max_len` length filter stays intact for fill_batch
            pad_len = max(len(tok) for tok in toks)
            pad_tok = self.tokenizer.pad_token_id

            toks = [F.pad(tok, (0, pad_len - len(tok)), 'constant', pad_tok) for tok in toks]
            masks = [F.pad(mask, (0, pad_len - len(mask)), 'constant', pad_tok) for mask in masks]
            toks = torch.stack(toks)
            masks = torch.stack(masks)

            # load to GPU if available
            if torch.cuda.is_available():
                toks = toks.cuda()
                masks = masks.cuda()

            # predict reconstruction (inference only)
            with torch.no_grad():
                rec, _ = hf_reconstruction_prob_tok(toks, masks, self.tokenizer, self.r_model, softmax_mask, reconstruct=True, topk=topk)

            # decode reconstructions and append to lists
            for i in range(len(rec)):
                rec_work = rec[i].cpu().tolist()
                # drop padding, then the first and last (special) tokens
                kept_ids = [val for val in rec_work if val != self.tokenizer.pad_token_id][1:-1]
                s_rec = tuple(s.strip() for s in self.tokenizer.decode(kept_ids).split(self.tokenizer.sep_token))

                # check if identical reconstruction or empty
                if s_rec not in sents[i] and '' not in s_rec:
                    sents[i].append(s_rec)
                    num_gen[i] += 1
                    num_tries[i] = 0
                    gen_index[i] = 0

                # otherwise try next sentence
                else:
                    num_tries[i] += 1
                    gen_index[i] += 1
                    if gen_index[i] == len(sents[i]):
                        gen_index[i] = 0

            # clean up tensors
            del toks
            del masks

        # hand results back in the order the caller supplied the inputs
        finished.sort(key=lambda item: item[0])
        return [pairs for _, pairs in finished]
        

In [8]:
# Build the augmenter with its defaults (model_path="bert-base-uncased", seed=1212).
ssmba = SSMBA()

# ssmba.augment(questions[:16], num_samples=8, threshold=0.8)

Some weights of the model checkpoint at bert-base-uncased were not used when initializing BertForMaskedLM: ['cls.seq_relationship.weight', 'cls.seq_relationship.bias']
- This IS expected if you are initializing BertForMaskedLM from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForMaskedLM from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).


In [9]:
# Load the question-pair table from the mounted Drive.
# NOTE(review): the path is relative, so it only resolves when the working
# directory is the parent of the `drive` mount point -- confirm.
qqd = pd.read_csv("drive/My Drive/Literature Review/Dataset versions/QQD-cleaner-less.csv")

In [10]:
# Display the loaded frame.
qqd

Unnamed: 0.1,Unnamed: 0,id,qid1,qid2,question1,question2,is_duplicate,len1,len2
0,0,0,1,2,what is step by step guide to invest in share ...,what is step by step guide to invest in share ...,0,66,57
1,1,1,3,4,what is story of kohinoor koh i noor diamond,what would happen if Indian government stole k...,0,51,88
2,2,2,5,6,how can i increase speed of my internet connec...,how can internet speed be increased by hacking...,0,73,59
3,3,3,7,8,why am i mentally very lonely how can i solve it,find remainder when math 23 24 math is divided...,0,50,65
4,4,4,9,10,one dissolve in water quickly sugar salt metha...,fish would survive in salt water,0,76,39
...,...,...,...,...,...,...,...,...,...
404213,404285,404285,433578,379845,how many keywords are there in racket programm...,how many keywords are there in perl programmin...,0,85,79
404214,404286,404286,18840,155606,do you believe there is life after death,is it true that there is life after death,1,41,42
404215,404287,404287,537928,537929,what is one coin,what s this coin,0,17,17
404216,404288,404288,537930,537931,what is approx annual cost of living while stu...,i am having little hairfall problem but i want...,0,94,127


In [11]:
# One flat list: every `question1` value followed by every `question2` value.
ls = list(qqd['question1']) + list(qqd['question2'])

In [12]:
# Number of questions after stacking both columns.
len(ls)

808436

In [13]:
# Single-column frame holding the stacked questions.
df = pd.DataFrame(ls, columns = ['question'])

In [14]:
# Keep only the first 1000 rows for this run; `df` is rebound to the subset.
df = df[:1000]
# Plain Python list of the retained questions, used by the augmentation loop.
questions = df['question'].tolist()

In [15]:
# Check the size of the subset and preview its first rows.
print(df.shape)
df.head()

(1000, 1)


Unnamed: 0,question
0,what is step by step guide to invest in share ...
1,what is story of kohinoor koh i noor diamond
2,how can i increase speed of my internet connec...
3,why am i mentally very lonely how can i solve it
4,one dissolve in water quickly sugar salt metha...


In [16]:
# Output table: the question plus eight augmentation columns and eight label
# columns, all initialised to the empty string (which marks "not yet filled").
aug_cols = [f"aug_{i}" for i in range(1, 9)]
label_cols = [f"aug_label_{i}" for i in range(1, 9)]
train_set = df[['question']].copy()
for col in aug_cols + label_cols:
    train_set[col] = ""
print(train_set.shape)
train_set.head()

(1000, 17)


Unnamed: 0,question,aug_1,aug_2,aug_3,aug_4,aug_5,aug_6,aug_7,aug_8,aug_label_1,aug_label_2,aug_label_3,aug_label_4,aug_label_5,aug_label_6,aug_label_7,aug_label_8
0,what is step by step guide to invest in share ...,,,,,,,,,,,,,,,,
1,what is story of kohinoor koh i noor diamond,,,,,,,,,,,,,,,,
2,how can i increase speed of my internet connec...,,,,,,,,,,,,,,,,
3,why am i mentally very lonely how can i solve it,,,,,,,,,,,,,,,,
4,one dissolve in water quickly sugar salt metha...,,,,,,,,,,,,,,,,


In [17]:
# CSV file the augmentation table is written to (relative to the working directory).
output_path = "EM_SSMBA_val.csv"

In [None]:
# Fill `train_set` batch by batch, checkpointing the CSV after every batch.
# Rows whose `aug_1` is still "" are considered pending; work resumes at the
# first of them.
BATCH_SIZE = 8            # questions handed to SSMBA per call
NUM_AUG = len(aug_cols)   # augmentations stored per question


def _padded(values, size, fill=""):
    """Return `values` as a list of exactly `size` items, padded with `fill`."""
    values = list(values)[:size]
    return values + [fill] * (size - len(values))


def _store_augmentations(row, pairs):
    """Write the (sentence, label) pairs of one question into row `row`.

    Fewer than NUM_AUG pairs (an input retired early) leave the remaining
    cells blank instead of raising on a length mismatch.
    """
    augs = [sentence for sentence, _ in pairs]
    labels = [label for _, label in pairs]
    train_set.loc[row, aug_cols + label_cols] = _padded(augs, NUM_AUG) + _padded(labels, NUM_AUG)


pending = train_set.index[train_set["aug_1"] == ""]
if len(pending) == 0:
    print("[INFO] No rows left to augment.")
else:
    start = pending[0]
    print(f"[INFO] Starting from {start}.")
    for i in range(start, len(questions), BATCH_SIZE):
        questions_ = questions[i: i+BATCH_SIZE]
        generated = ssmba.augment(questions_, num_samples=NUM_AUG, threshold=0.8)

        # Results are matched to rows by position, which needs exactly one
        # result per question. NOTE(review): this also relies on `augment`
        # returning results in input order -- confirm.
        if len(generated) != len(questions_):
            print(f"[WARN] expected {len(questions_)} results, got {len(generated)}; "
                  f"re-running questions {i+1}:{i+len(questions_)} one at a time.")
            generated = []
            for question in questions_:
                result = ssmba.augment([question], num_samples=NUM_AUG, threshold=0.8)
                # a question rejected by the length filter yields no result
                generated.append(result[0] if result else [])

        print(f"[INFO] {i+1}:{i+len(questions_)} questions generated.")
        for idx, gen in enumerate(generated):
            _store_augmentations(i+idx, gen)
        # checkpoint so an interrupted run loses at most one batch
        train_set.to_csv(output_path, index=False)
train_set.to_csv(output_path, index=False)

[INFO] Starting from 0.
[INFO] 1:8 questions generated.
[INFO] 9:16 questions generated.
[INFO] 17:24 questions generated.
[INFO] 25:32 questions generated.
[INFO] 33:40 questions generated.
[INFO] 41:48 questions generated.
[INFO] 49:56 questions generated.
[INFO] 57:64 questions generated.
[INFO] 65:72 questions generated.
[INFO] 73:80 questions generated.
[INFO] 81:88 questions generated.
[INFO] 89:96 questions generated.
[INFO] 97:104 questions generated.
[INFO] 105:112 questions generated.
[INFO] 113:120 questions generated.
[INFO] 121:128 questions generated.
[INFO] 129:136 questions generated.
[INFO] 137:144 questions generated.
[INFO] 145:152 questions generated.
[INFO] 153:160 questions generated.
[INFO] 161:168 questions generated.
[INFO] 169:176 questions generated.
[INFO] 177:184 questions generated.
[INFO] 185:192 questions generated.
[INFO] 193:200 questions generated.
[INFO] 201:208 questions generated.
[INFO] 209:216 questions generated.
[INFO] 217:224 questions generat

In [None]:
# Index labels of the rows whose first augmentation column is still blank.
print(train_set[train_set["aug_1"]==""].index)

In [None]:
# Show only the first nine columns: transpose, keep the first nine rows, transpose back.
train_set.transpose().iloc[:9].transpose()

In [None]:
# idx = 30000
# a, l = list(zip(*ssmba.augment([train_set.loc[idx, "question"]], num_samples=8, threshold=0.8)[0]))
# train_set.loc[idx, aug_cols+label_cols] = a + l
# train_set.loc[idx]

In [None]:
# Write the current state of the table to `output_path` (without the index column).
train_set.to_csv(output_path, index=False)