In [1]:
import argparse
import sys
import os
import json
import pprint
import numpy as np
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
import re
import torch
from collections import Counter
from transformers import AutoModel, AutoTokenizer, pipeline
from tqdm import tqdm
from sklearn.metrics import f1_score

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
BASE_DIR = "/Users/katemarg/PycharmProjects/temporal_robustness_evaluation"
DATA_DIR = os.path.join(BASE_DIR, 'data')
CKPT_DIR = os.path.join(BASE_DIR, 'pretrained_models')
RES_DIR = os.path.join(BASE_DIR, 'results')
LOG_DIR = os.path.join(BASE_DIR, 'logs')
CACHE_DIR = os.path.join(BASE_DIR, 'cached')

TEMPLAMA_ORIG_DIR = os.path.join("data", "templama", "test.json")
TEMPLAMA_REPRODUCED_DIR = os.path.join("templama_docker", "reproduce1", "templama", "test.jsonl")
TEMPLAMA_NEW_DIR = os.path.join(DATA_DIR, "dynamic-templama",
                                "dataset_from_2019-1-1_to_2022-12-31_per_quarter", 
                                "test.jsonl")
lm = "cardiffnlp/twitter-roberta-base-mar2022"
# dataset_filepath=CACHE_DIR+'/{}_dynamic-templama_multiple_masks.pt'.format(lm)
dataset_filepath=CACHE_DIR+'/cardiffnlp-twitter-roberta-base-mar2022_dynamic-templama_multiple_masks.pt'

## Load dataset

In [3]:
data_dict_multi_token = torch.load(dataset_filepath)

In [4]:
# Load pre-trained model tokenizer (vocabulary)
tokenizer = AutoTokenizer.from_pretrained(lm, use_fast=False, add_prefix_space=True)

CLS = tokenizer.cls_token
PAD = tokenizer.pad_token
SEP = tokenizer.sep_token
MASK = tokenizer.mask_token

mask_id = tokenizer.mask_token_id
sep_id = tokenizer.sep_token_id
cls_id = tokenizer.cls_token_id
pad_id = tokenizer.pad_token_id

special_ids = [mask_id, sep_id, cls_id, pad_id]
def tokenizer_return_id(text, filter_special_tokens=False):
    """
    Text to token ids for a string.
    """
    output = tokenizer(text)
    if filter_special_tokens:
        token_ids = [i for i in output['input_ids'] if i not in tokenizer.all_special_ids]
    else:
        token_ids = [i for i in output['input_ids'] ]
    return token_ids

def tokenize_batch(batch):
    """
    Text to token ids for a list of strings.
    """
#     return [tokenizer.convert_tokens_to_ids(sent) for sent in batch]
    return [tokenizer_return_id(sent) for sent in batch]

def untokenize_id(ids):
    """
    Token ids to strings.
    """
#     return [tokenizer.convert_tokens_to_ids(sent) for sent in batch]
    return [tokenizer.decode(_id) for _id in ids]


def untokenize_batch(batch, filter_special_tokens=False):
    """
    Token ids to strings for a list of ids.
    """
#     return [tokenizer.convert_ids_to_tokens(sent) for sent in batch]
#     print(label_id)
    if filter_special_tokens:
        _batch = []
        for sent in batch:
            _batch.append([x for x in sent if x not in special_ids])
        batch = _batch
#         return [tokenizer.decode(label_id) for label_id in batch if label_id not in special_ids]
#     else:
    return [untokenize_id(sent) for sent in batch]

def detokenize(sent):
    """ Roughly detokenizes (mainly undoes wordpiece) """
    # not sure what this does.... their code
    new_sent = []
    for i, tok in enumerate(sent):
        if tok.startswith("##"):
            new_sent[len(new_sent) - 1] = new_sent[len(new_sent) - 1] + tok[2:]
        else:
            new_sent.append(tok)
    return new_sent

In [5]:
def split_dataset(data):
    """
    Split temporal dataset Dt to D_unchanged, D_new and D_updated compared to D_(t-1) for all t.
    Specifically:
    - D_unchanged: data where text_t = text_(t-1) & label_t = label_(t-1)
    - D_updated: data where text_t = text_(t-1) & label_t != label_(t-1)
    - D_new: data where text_t not in D_(t-1)
    - D_deleted: data that exist in D_(t-1) but not in D_t

    Args:
        data: a dictionary with keys the time (year/quarter/month) and values dictionaries
        data = {
                '2019-Q1':
                    {
                    'text': [list of text],
                    'labels': [list of labels],
                    'labels_ids': [list of label token ids -- for a given model/tokenizer],
                    'relations' [list of Wikidata relations]
                    },
                '2019-Q2': {...}
                }

    Returns:
        D_unchanged, D_new, D_updated, D_deleted
    """
    unchanged_t, new_t, updated_t, deleted_t = {}, {}, {}, {}
    quarters = list(data.keys())
    t_0 = quarters[0]  # t=t0
    t_1 = quarters[0]  # t-1

    for t in quarters[1:]:
        if t in ['2022-Q3', '2022-Q4']:
            continue # skip last two quarters of 2022
        data_t = data[t]      # D_t
        data_t_1 = data[t_1]  # D_(t-1)

        unchanged_t[t] = {key: [] for key in data_t.keys()}
        new_t[t] = {key: [] for key in data_t.keys()}
        updated_t[t] = {key: [] for key in data_t.keys()}
        deleted_t[t] = {key: [] for key in data_t.keys()}

        for i in range(0, len(data_t['text'])):
            text_t = data_t['text'][i]
            labels_ids_t = data_t['labels_ids'][i][0]
            if text_t in data_t_1['text']:
                t_1_index = data_t_1['text'].index(text_t)
                labels_inds_t_1 = data_t_1['labels_ids'][t_1_index][0]
                if labels_ids_t == labels_inds_t_1:
                    # text_t = text_t-1 & label_t = label_t-1
                    # add to D_unchanged
                    for key in data_t.keys():
                        unchanged_t[t][key].append(data_t[key][i])
                else:
                    # text_t = text_(t-1) & label_t != label_(t-1)
                    # add to D_updated
                    for key in ['text', 'relation']:
                        updated_t[t][key].append(data_t[key][i])
            else:
                # text_t not in D_(t-1) texts
                # add to D_new
                for key in data_t.keys():
                    new_t[t][key].append(data_t[key][i])
        for j in range(0, len(data_t_1['text'])):
            text_t_1 = data_t_1['text'][j]
            if text_t_1 not in data_t['text']:
                # text_(t+1) not in D_t
                # add to D_deleted
                for key in data_t_1.keys():
                    deleted_t[t][key].append(data_t_1[key][j])
        t_1 = t

        assert len(data_t['text']) == len(unchanged_t[t]['text']) + len(updated_t[t]['text']) + len(new_t[t]['text'])
        print(
            't={}: From total {} samples in D_t, {} are unchanged, {} are updated, {} are deleted and {} are new, compared to D_(t-1).'.format(
                t,
                len(data_t['text']),
                len(unchanged_t[t]['text']),
                len(updated_t[t]['text']),
                len(deleted_t[t]['text']),
                len(new_t[t]['text'])))
    return unchanged_t, new_t, updated_t, deleted_t, data[t_0]

## Split dataset

In [7]:
# Split dataset
unchanged_t, new_t, updated_t, deleted_t, orig = split_dataset(data_dict_multi_token)

t=2019-Q2: From total 7962 samples in D_t, 7728 are unchanged, 36 are updated, 206 are deleted and 198 are new, compared to D_(t-1).
t=2019-Q3: From total 8011 samples in D_t, 7633 are unchanged, 65 are updated, 264 are deleted and 313 are new, compared to D_(t-1).
t=2019-Q4: From total 7989 samples in D_t, 7742 are unchanged, 51 are updated, 218 are deleted and 196 are new, compared to D_(t-1).
t=2020-Q1: From total 7872 samples in D_t, 7359 are unchanged, 107 are updated, 524 are deleted and 406 are new, compared to D_(t-1).
t=2020-Q2: From total 7848 samples in D_t, 7699 are unchanged, 30 are updated, 143 are deleted and 119 are new, compared to D_(t-1).
t=2020-Q3: From total 7900 samples in D_t, 7666 are unchanged, 40 are updated, 142 are deleted and 194 are new, compared to D_(t-1).
t=2020-Q4: From total 7892 samples in D_t, 7682 are unchanged, 26 are updated, 192 are deleted and 184 are new, compared to D_(t-1).
t=2021-Q1: From total 7819 samples in D_t, 7381 are unchanged, 69 ar

## Load model

In [10]:
lm = "cardiffnlp/twitter-roberta-base-mar2022"
fill_mask_model = pipeline(
    'fill-mask', model=lm, framework="pt",
    tokenizer=tokenizer, top_k=100
)
model = fill_mask_model.model
model.eval()
cuda = torch.cuda.is_available()
if cuda:
    model = model.cuda()
year='2019-Q1'
batch_size = 100
N = batch_size # number of shots
temperature = 1.0
burnin = 200

In [11]:
def generate_step(out, gen_idx, temperature=None, top_k=0, sample=False, return_list=True, num_masks=0):
    """ Generate a word from out[gen_idx]

    args:
        - out (torch.Tensor): tensor of logits of size batch_size x seq_len x vocab_size
        - gen_idx (int): location for which to generate for
        - top_k (int): if >0, only sample from the top k most probable words
        - sample (Bool): if True, sample from full distribution. Overridden by top_k
    """
    logits = out[:, gen_idx]
    if num_masks==1:  # if single-token then return the argmax
        idx = torch.argmax(logits, dim=-1)
        return idx.tolist() if return_list else idx
    else:
        if temperature is not None:
            logits = logits / temperature
        if top_k > 0:
            kth_vals, kth_idx = logits.topk(top_k, dim=-1)
            dist = torch.distributions.categorical.Categorical(logits=kth_vals)
            idx = kth_idx.gather(dim=1, index=dist.sample().unsqueeze(-1)).squeeze(-1)
        elif sample:
            dist = torch.distributions.categorical.Categorical(logits=logits)
            idx = dist.sample().squeeze(-1)
        else:
            idx = torch.argmax(logits, dim=-1)
        return idx.tolist() if return_list else idx

In [12]:
text_list = data_dict_multi_token[year]['text']
labels_list = data_dict_multi_token[year]['labels']
labels_ids_list = data_dict_multi_token[year]['labels_ids']
relation_list = data_dict_multi_token[year]['relation']
num_answers_list = data_dict_multi_token[year]['num_answers']

## For a single test example

In [374]:
M = 5 # max masks to try all from 1, ..., M
i = 0
text_i = text_list[i][0]
# text_i = ['Lionel M plays for <mask> <mask>.'][0]
labels_i = labels_list[i][0]
labels_ids_i = labels_ids_list[i][0] # there is an 'extra' list that is why we put [0]
relation_i = relation_list[i]
num_answers_i = num_answers_list[i]
print('Example: {}, Labels: {}, Ids: {}'.format(text_i, labels_i, labels_ids_i))

Example: Lionel Messi plays for <mask> <mask>., Labels: [' FC Barcelona'], Ids: [5429, 4612]


In [375]:
# tokenize text
tokenized_sentence_orig = tokenizer_return_id(text_i)

# mask indices to generate
mask_inds_orig = list(np.where(np.array(tokenized_sentence_orig) == mask_id)[0])

# number of masks (tokens)
gold_num_masks = len(mask_inds_orig)

In [376]:
# find first and last masks
first_mask_idx = mask_inds_orig[0]
last_mask_idx = mask_inds_orig[-1]

In [377]:
# split tokenized sentence to list of token ids before and after the masks
before_mask_ids = tokenized_sentence_orig[:first_mask_idx]
after_mask_ids = tokenized_sentence_orig[last_mask_idx+1:]

In [378]:
assert mask_id not in before_mask_ids
assert mask_id not in after_mask_ids

In [379]:
tokenized_sentence_orig

[0, 15350, 9711, 1974, 13, 50264, 50264, 4, 2]

In [380]:
# create all M combinations of masks [1,M]
all_M_mask_combos = []
for i in range (1,M+1):
    
    all_M_mask_combos.append([mask_id for _ in range(i)])

In [381]:
all_M_mask_combos

[[50264],
 [50264, 50264],
 [50264, 50264, 50264],
 [50264, 50264, 50264, 50264],
 [50264, 50264, 50264, 50264, 50264]]

In [382]:
# add all M mask combos to list of token ids
tokenized_sentence_list = []
for mask_seq in all_M_mask_combos:
    tokenized_sentence_list.append(before_mask_ids + mask_seq + after_mask_ids) # list of M lists (variable number of masks)

In [383]:
# We do not know a priori the correct number of masks so we try all in range 1, ..., M (M=5)
# tokenized_sentence_b_np, tokenized_sentence_b, logits_m, logits_m = [], [], [], []
top_ranked_gen_token_seqs = []
top_ranked_log_probs = []
for num_mask_j in range(M):
    generated_seq_tokens = [[] for _ in range(batch_size)]  # for each trial
    logits_list = [[] for _ in range(batch_size)]
    
    print('##'*11)
    print('We try with {}  mask(s)'.format(num_mask_j+1))
    print('##'*11)
    tokenized_sentence = tokenized_sentence_list[num_mask_j]
    print(tokenized_sentence)
    # list of indeces of the mask tokens
    mask_inds = list(np.where(np.array(tokenized_sentence) == mask_id)[0])
    
    # create batch (same input, N times)
    tokenized_sentence_b = [tokenized_sentence for _ in range(batch_size)]
    
    for m in mask_inds:
        # tensor
        inp = torch.tensor(tokenized_sentence_b).cuda() if cuda else torch.tensor(tokenized_sentence_b)

        # get logits
        out = model(inp)
        logits = out.logits  # batch_size x max_len x vocab

        # get new ids
        idxs_b = generate_step(logits, gen_idx=m, top_k=10, temperature=temperature, # sample=(m < burnin)
                              )

        # replace mask with predicted token id
        tokenized_sentence_b_np = np.array(tokenized_sentence_b)                       
        tokenized_sentence_b_np[:,m] = np.array(idxs_b)
        tokenized_sentence_b = tokenized_sentence_b_np.tolist() 

        # the following code does not work *and I don't know why*
        # for jj in range(len(idxs_b)):
        #     print('before: {}, predicted token id: {}'.format(tokenized_sentence_b[jj][m],idxs_b[jj]))
        #     tokenized_sentence_b[jj][m] = idxs_b[jj]

        assert sorted(idxs_b) == sorted([sent[m] for sent in tokenized_sentence_b])

        # find logits
        logits_m = logits[:,m]  # logits for the mask in position m
        logits_b = logits_m[:,idxs_b].tolist()[0]  # logits for sampled tokens

        assert len(idxs_b) == len(logits_b)

        # add generated tokens and corresponding logits to lists
        for j in range(N):
            generated_seq_tokens[j].append(idxs_b[j])
            logits_list[j].append(logits_b[j])
    
    # calculate sum of logits for each generated sequence of tokens
    sum_logits = np.array(logits_list).sum(axis=-1) / len(mask_inds)

    # finding ranking (return indices of the parallel lists for logits and generated tokens)
    ranked_inds_of_list = sum_logits.argsort()[::-1]

    # ranked logits sum in descending order
    ranked_logits = sum_logits[ranked_inds_of_list]

    # ranked generated tokens in descending order
    ranked_generated_tokens = np.array(generated_seq_tokens)[ranked_inds_of_list]
    
    for no, p in enumerate(ranked_generated_tokens[:20]):
        print('{}: {}, {}'.format(no,"".join(untokenize_id(p)), ranked_logits[no]))
    
    top_ranked_gen_token_seqs.append(ranked_generated_tokens.tolist()[0])
    top_ranked_log_probs.append(ranked_logits.tolist()[0])

######################
We try with 1  mask(s)
######################
[0, 15350, 9711, 1974, 13, 50264, 4, 2]
0:  Barcelona, 14.143179893493652
1:  Barcelona, 14.143179893493652
2:  Barcelona, 14.143179893493652
3:  Barcelona, 14.143179893493652
4:  Barcelona, 14.143179893493652
5:  Barcelona, 14.143179893493652
6:  Barcelona, 14.143179893493652
7:  Barcelona, 14.143179893493652
8:  Barcelona, 14.143179893493652
9:  Barcelona, 14.143179893493652
10:  Barcelona, 14.143179893493652
11:  Barcelona, 14.143179893493652
12:  Barcelona, 14.143179893493652
13:  Barcelona, 14.143179893493652
14:  Barcelona, 14.143179893493652
15:  Barcelona, 14.143179893493652
16:  Barcelona, 14.143179893493652
17:  Barcelona, 14.143179893493652
18:  Barcelona, 14.143179893493652
19:  Barcelona, 14.143179893493652
######################
We try with 2  mask(s)
######################
[0, 15350, 9711, 1974, 13, 50264, 50264, 4, 2]
0:  Lionel Messi, 10.740882396697998
1:  Lionel Messi, 10.740882396697998
2:  Lionel 

In [396]:
top_ranked_gen_token_seqs

[[4612],
 [15350, 9711],
 [9711, 18, 950],
 [10, 950, 11, 1005],
 [5, 9711, 12, 574, 625]]

In [397]:
top_ranked_log_probs

[14.143179893493652,
 10.740882396697998,
 10.54391606648763,
 8.914562225341797,
 8.192081356048584]

## Evaluation

In [395]:
# compare with gold labels
import evaluate
# assert len(labels_ids_i) == M
f1_micro_list, f1_macro_list = [], []
rouge_list, bleu_list, bleu_uni_list, bert_score_list = [], [], [], []

rouge = evaluate.load('rouge')
bleu = evaluate.load("bleu")
bertscore = load("bertscore")

gold_ids = labels_ids_i
gold_tok = labels_i  # list of strings

for m in range(len(top_ranked_log_probs)):
    pred_ids_m = top_ranked_gen_token_seqs[m]
    pred_tok_m = ["".join(untokenize_id(top_ranked_gen_token_seqs[m]))] # list of strings
    
    # F1 score
    # F1_micro calculates metrics globally by counting the total true positives, 
    # false negatives and false positives.
    if len(gold_ids) == len(pred_ids_m):
        f1_micro_list.append(f1_score(gold_ids,pred_ids_m, average='micro')) 
        # F1_macro calculates metrics for each label, and finds their unweighted mean. 
        # This does not take label imbalance into account.
        f1_macro_list.append(f1_score(gold_ids,pred_ids_m, average='macro'))
    else:
        f1_micro_list.append(0.0)
        f1_macro_list.append(0.0)
        
    # BLEU
    bleu_list.append(bleu.compute(references=gold_tok,
          predictions=pred_tok_m)['bleu'])
    # unigrams
    bleu_uni_list.append(bleu.compute(references=gold_tok,
              predictions=pred_tok_m)['precisions'][0])

    # ROUGE
    rouge_list.append(rouge.compute(references=gold_tok,
                  predictions=pred_tok_m, use_aggregator=True,
                                    use_stemmer=True)['rouge1'].mid.fmeasure)
    # BERT_SCORE
    bert_score_list.append(bertscore.compute(references=gold_tok,
              predictions=pred_tok_m, lang="en")['f1'][0])

In [398]:
pred_strings = ["".join(untokenize_id(pred)) for pred in top_ranked_gen_token_seqs]
print('The gold label is <<{}>> with {} masks, while the predictions are:\n'.format(gold_tok[0], gold_num_masks))
for i in range(1,M+1):
    print('<<{}>> with {} mask(s)'.format(pred_strings[i-1], i))

The gold label is << FC Barcelona>> with 2 masks, while the predictions are:

<< Barcelona>> with 1 mask(s)
<< Lionel Messi>> with 2 mask(s)
<< Messi's club>> with 3 mask(s)
<< a club in Europe>> with 4 mask(s)
<< the Messi-Lad>> with 5 mask(s)


In [401]:
bleu_uni_list

[1.0, 0.0, 0.0, 0.0, 0.0]

In [403]:
pd.DataFrame(data={'masks':list(range(1,M+1)), 'predictions':pred_strings, 
 'f1_micro': f1_micro_list, 'f1_macro':f1_macro_list,
'rouge': rouge_list, 'bleu': bleu_list, 'blue_uni_precision':bleu_uni_list, 'bert_score': bert_score_list,
                  'norm_log_probs':top_ranked_log_probs
                  })

Unnamed: 0,masks,predictions,f1_micro,f1_macro,rouge,bleu,blue_uni_precision,bert_score,norm_log_probs
0,1,Barcelona,0.0,0.0,0.666667,0.0,1.0,0.955906,14.14318
1,2,Lionel Messi,0.0,0.0,0.0,0.0,0.0,0.889325,10.740882
2,3,Messi's club,0.0,0.0,0.0,0.0,0.0,0.877186,10.543916
3,4,a club in Europe,0.0,0.0,0.0,0.0,0.0,0.863004,8.914562
4,5,the Messi-Lad,0.0,0.0,0.0,0.0,0.0,0.842462,8.192081


## I think we should consider every argmax generated token sequence for all M masks as equal predictions, and keep the highest score from any. In this case the first.

----

In [311]:
n_samples = 5
batch_size = 1
max_len = 40
top_k = 100
temperature = 1.0
generation_mode = "parallel-sequential"
leed_out_len = 5 # max_len
burnin = 250
sample = True
max_iter = 500

# # Choose the prefix context
# seed_text = "[CLS]".split()
# bert_sents = generate(n_samples, seed_text=seed_text, batch_size=batch_size, max_len=max_len,
#                       generation_mode=generation_mode,
#                       sample=sample, top_k=top_k, temperature=temperature, burnin=burnin, max_iter=max_iter,
#                       cuda=cuda)

def sequential_generation(seed_text, batch_size=10, max_len=15, leed_out_len=15, 
                          top_k=0, temperature=None, sample=True, cuda=False):
    """ Generate one word at a time, in L->R order """
    seed_len = len(seed_text)
    batch = get_init_text(seed_text, max_len, batch_size)
    
    for ii in range(max_len):
        inp = [sent[:seed_len+ii+leed_out_len]+[sep_id] for sent in batch]
        inp = torch.tensor(batch).cuda() if cuda else torch.tensor(batch)
        out = model(inp)
        idxs = generate_step(out, gen_idx=seed_len+ii, top_k=top_k, temperature=temperature, sample=sample)
        for jj in range(batch_size):
            batch[jj][seed_len+ii] = idxs[jj]
        
    return untokenize_batch(batch)

In [21]:
labels_ids_i

[5429, 4612]

In [22]:
ranked_generated_tokens

array([[ 4868,   534],
       [ 4868,   534],
       [ 4868,   534],
       [ 4868,   534],
       [ 4868,   534],
       [    5,  4612],
       [    5,  4612],
       [    5,  4612],
       [ 4612,  4612],
       [ 4612,  4612],
       [ 2361,   412],
       [ 2361,   412],
       [ 2361,   412],
       [ 2361,   412],
       [    5,  3622],
       [    5,   950],
       [    5,   950],
       [    5,   950],
       [ 2361,   315],
       [ 2361,   315],
       [ 2361,   315],
       [ 2361,   315],
       [ 2361,   315],
       [ 2361,   315],
       [ 2361,   315],
       [15350,  9711],
       [15350,  9711],
       [15350,  9711],
       [15350,  9711],
       [15350,  9711],
       [15350,  9711],
       [15350,  9711],
       [15350,  9711],
       [15350,  9711],
       [15350,  9711],
       [15350,  9711],
       [15350,  9711],
       [15350,  9711],
       [15350,  9711],
       [15350,  9711],
       [15350,  9711],
       [15350,  9711],
       [15350,  9711],
       [153

In [23]:
for j in range(0,N):
    gold_ids = labels_ids_i
    pred_ids_j = ranked_generated_tokens[j]
    gold_tok = labels_i  # list of strings
    pred_tok_j = ["".join(untokenize_id(ranked_generated_tokens[j]))] # list of strings
    
    # F1 score
    # F1_micro calculates metrics globally by counting the total true positives, 
    # false negatives and false positives.
    f1_micro_list.append(f1_score(gold_ids,pred_ids_j, average='micro')) 
    # F1_macro calculates metrics for each label, and finds their unweighted mean. 
    # This does not take label imbalance into account.
    f1_macro_list.append(f1_score(gold_ids,pred_ids_j, average='macro'))
    
    # BLEU
    bleu_list.append(bleu.compute(references=gold_tok,
              predictions=pred_tok_j)['bleu'])

    # ROUGE
    rouge_list.append(rouge.compute(references=gold_tok,
                  predictions=pred_tok_j))
    
    # exact match / P@1
max_f1_micro = np.array(f1_micro_list).max()
max_f1_macro = np.array(f1_macro_list).max()
max_bleu = np.array(bleu_list).max()

In [24]:
print(max_f1_micro, max_f1_macro, max_bleu)

0.5 0.3333333333333333 0.0


In [None]:
def generate_step(out, gen_idx, temperature=None, top_k=0, sample=False, return_list=True):
    """ Generate a word from from out[gen_idx]
    
    args:
        - out (torch.Tensor): tensor of logits of size batch_size x seq_len x vocab_size
        - gen_idx (int): location for which to generate for
        - top_k (int): if >0, only sample from the top k most probable words
        - sample (Bool): if True, sample from full distribution. Overridden by top_k 
    """
    logits = out[:, gen_idx]
    if temperature is not None:
        logits = logits / temperature
    if top_k > 0:
        kth_vals, kth_idx = logits.topk(top_k, dim=-1)
        dist = torch.distributions.categorical.Categorical(logits=kth_vals)
        idx = kth_idx.gather(dim=1, index=dist.sample().unsqueeze(-1)).squeeze(-1)
    elif sample:
        dist = torch.distributions.categorical.Categorical(logits=logits)
        idx = dist.sample().squeeze(-1)
    else:
        idx = torch.argmax(logits, dim=-1)
    return idx.tolist() if return_list else idx
  
def get_init_text(seed_text, max_len, batch_size = 1, rand_init=False):
    """ Get initial sentence by padding seed_text with either masks or random words to max_len """
    batch = [seed_text + [MASK] * max_len + [SEP] for _ in range(batch_size)]
    #if rand_init:
    #    for ii in range(max_len):
    #        init_idx[seed_len+ii] = np.random.randint(0, len(tokenizer.vocab))
    
    return tokenize_batch(batch)

def printer(sent, should_detokenize=True):
    if should_detokenize:
        sent = detokenize(sent)[1:-1]
    print(" ".join(sent))

In [None]:
# Generation modes as functions
import math
import time

def parallel_sequential_generation(seed_text, batch_size=10, max_len=15, top_k=0, temperature=None, max_iter=300, burnin=200,
                                   cuda=False, print_every=10, verbose=True):
    """ Generate for one random position at a timestep
    
    args:
        - burnin: during burn-in period, sample from full distribution; afterwards take argmax
    """
    seed_len = len(seed_text)
    batch = get_init_text(seed_text, max_len, batch_size)
    
    for ii in range(max_iter):
        kk = np.random.randint(0, max_len)
        for jj in range(batch_size):
            batch[jj][seed_len+kk] = mask_id
        inp = torch.tensor(batch).cuda() if cuda else torch.tensor(batch)
        out = model(inp)
        topk = top_k if (ii >= burnin) else 0
        idxs = generate_step(out, gen_idx=seed_len+kk, top_k=topk, temperature=temperature, sample=(ii < burnin))
        for jj in range(batch_size):
            batch[jj][seed_len+kk] = idxs[jj]
            
        if verbose and np.mod(ii+1, print_every) == 0:
            for_print = tokenizer.convert_ids_to_tokens(batch[0])
            for_print = for_print[:seed_len+kk+1] + ['(*)'] + for_print[seed_len+kk+1:]
            print("iter", ii+1, " ".join(for_print))
            
    return untokenize_batch(batch)

def parallel_generation(seed_text, batch_size=10, max_len=15, top_k=0, temperature=None, max_iter=300, sample=True, 
                        cuda=False, print_every=10, verbose=True):
    """ Generate for all positions at each time step """
    seed_len = len(seed_text)
    batch = get_init_text(seed_text, max_len, batch_size)
    
    for ii in range(max_iter):
        inp = torch.tensor(batch).cuda() if cuda else torch.tensor(batch)
        out = model(inp)
        for kk in range(max_len):
            idxs = generate_step(out, gen_idx=seed_len+kk, top_k=top_k, temperature=temperature, sample=sample)
            for jj in range(batch_size):
                batch[jj][seed_len+kk] = idxs[jj]
            
        if verbose and np.mod(ii, print_every) == 0:
            print("iter", ii+1, " ".join(tokenizer.convert_ids_to_tokens(batch[0])))
    
    return untokenize_batch(batch)
            
def sequential_generation(seed_text, batch_size=10, max_len=15, leed_out_len=15, 
                          top_k=0, temperature=None, sample=True, cuda=False):
    """ Generate one word at a time, in L->R order """
    seed_len = len(seed_text)
    batch = get_init_text(seed_text, max_len, batch_size)
    
    for ii in range(max_len):
        inp = [sent[:seed_len+ii+leed_out_len]+[sep_id] for sent in batch]
        inp = torch.tensor(batch).cuda() if cuda else torch.tensor(batch)
        out = model(inp)
        idxs = generate_step(out, gen_idx=seed_len+ii, top_k=top_k, temperature=temperature, sample=sample)
        for jj in range(batch_size):
            batch[jj][seed_len+ii] = idxs[jj]
        
    return untokenize_batch(batch)

def generate(n_samples, seed_text="[CLS]", batch_size=10, max_len=25, 
             generation_mode="parallel-sequential",
             sample=True, top_k=100, temperature=1.0, burnin=200, max_iter=500,
             cuda=False, print_every=1):
    # main generation function to call
    sentences = []
    n_batches = math.ceil(n_samples / batch_size)
    start_time = time.time()
    for batch_n in range(n_batches):
        if generation_mode == "parallel-sequential":
            batch = parallel_sequential_generation(seed_text, batch_size=batch_size, max_len=max_len, top_k=top_k,
                                                   temperature=temperature, burnin=burnin, max_iter=max_iter, 
                                                   cuda=cuda, verbose=False)
        elif generation_mode == "sequential":
            batch = sequential_generation(seed_text, batch_size=batch_size, max_len=max_len, top_k=top_k, 
                                          temperature=temperature, leed_out_len=leed_out_len, sample=sample,
                                          cuda=cuda)
        elif generation_mode == "parallel":
            batch = parallel_generation(seed_text, batch_size=batch_size,
                                        max_len=max_len, top_k=top_k, temperature=temperature, 
                                        sample=sample, max_iter=max_iter, 
                                        cuda=cuda, verbose=False)
        
        if (batch_n + 1) % print_every == 0:
            print("Finished batch %d in %.3fs" % (batch_n + 1, time.time() - start_time))
            start_time = time.time()
        
        sentences += batch
    return sentences

In [None]:
n_samples = 5
batch_size = 5
max_len = 40
top_k = 100
temperature = 1.0
generation_mode = "parallel-sequential"
leed_out_len = 5 # max_len
burnin = 250
sample = True
max_iter = 500

# Choose the prefix context
seed_text = "[CLS]".split()
bert_sents = generate(n_samples, seed_text=seed_text, batch_size=batch_size, max_len=max_len,
                      generation_mode=generation_mode,
                      sample=sample, top_k=top_k, temperature=temperature, burnin=burnin, max_iter=max_iter,
                      cuda=cuda)

In [None]:
for sent in bert_sents:
  printer(sent, should_detokenize=True)

In [None]:
def tokenizer_return_id(text, filter_special_tokens=False):
    """
    Text to token ids for a string.
    """
    if type(text) is list: text = " ".join(text)
    output = tokenizer(text)
    if filter_special_tokens:
        token_ids = [i for i in output['input_ids'] if i not in tokenizer.all_special_ids]
    else:
        token_ids = [i for i in output['input_ids'] ]
    return token_ids

def tokenize_batch(batch):
    """
    Text to token ids for a list of strings.
    """
#     return [tokenizer.convert_tokens_to_ids(sent) for sent in batch]
    return [tokenizer_return_id(sent) for sent in batch]

In [None]:
get_init_text(seed_text, max_len, batch_size)

In [None]:
for ii in range(max_iter):
    kk = np.random.randint(0, max_len)  # change this to id_first_mask to id_last_mask ??
    for jj in range(batch_size):
        batch[jj][seed_len+kk] = mask_id
    inp = torch.tensor(batch).cuda() if cuda else torch.tensor(batch)
    out = model(inp)
    topk = top_k if (ii >= burnin) else 0
    idxs = generate_step(out, gen_idx=seed_len+kk, top_k=topk, temperature=temperature, sample=(ii < burnin))
    for jj in range(batch_size):
        batch[jj][seed_len+kk] = idxs[jj]

    if verbose and np.mod(ii+1, print_every) == 0:
        for_print = tokenizer.convert_ids_to_tokens(batch[0])
        for_print = for_print[:seed_len+kk+1] + ['(*)'] + for_print[seed_len+kk+1:]
        print("iter", ii+1, " ".join(for_print))

In [None]:
for ii in range(max_iter):
    kk = np.random.randint(0, max_len)  # change this to id_first_mask to id_last_mask ??
    for jj in range(batch_size):
        batch[jj][seed_len+kk] = mask_id
    inp = torch.tensor(batch).cuda() if cuda else torch.tensor(batch)
    out = model(inp)
    topk = top_k if (ii >= burnin) else 0
    idxs = generate_step(out, gen_idx=seed_len+kk, top_k=topk, temperature=temperature, sample=(ii < burnin))
    for jj in range(batch_size):
        batch[jj][seed_len+kk] = idxs[jj]

    if verbose and np.mod(ii+1, print_every) == 0:
        for_print = tokenizer.convert_ids_to_tokens(batch[0])
        for_print = for_print[:seed_len+kk+1] + ['(*)'] + for_print[seed_len+kk+1:]
        print("iter", ii+1, " ".join(for_print))

Original implementation

In [None]:
def get_init_text(seed_text, max_len, batch_size = 1, rand_init=False):
    """ Get initial sentence by padding seed_text with either masks or random words to max_len """
    batch = [seed_text + [MASK] * max_len + [SEP] for _ in range(batch_size)]
    #if rand_init:
    #    for ii in range(max_len):
    #        init_idx[seed_len+ii] = np.random.randint(0, len(tokenizer.vocab))
#     return batch
    return tokenize_batch(batch)

In [None]:
seed_text

In [None]:
n_samples = 5
batch_size = 5
max_len = 40
top_k = 100
temperature = 1.0
generation_mode = "parallel-sequential"
leed_out_len = 5 # max_len
burnin = 250
sample = True
max_iter = 500

# Choose the prefix context
seed_text = "I have 3 more weeks for my <mask>".split()
seed_len = len(seed_text)
batch = get_init_text(seed_text, 0, batch_size)



In [None]:
batch

In [None]:
for ii in range(max_len):
    inp = [sent[:seed_len+ii+leed_out_len]+[sep_id] for sent in batch]
    inp = torch.tensor(batch).cuda() if cuda else torch.tensor(batch)
    out = model(inp)
    idxs = generate_step(out, gen_idx=seed_len+ii, top_k=top_k, temperature=temperature, sample=sample)
    for jj in range(batch_size):
        batch[jj][seed_len+ii] = idxs[jj]

return untokenize_batch(batch)