<a href="https://colab.research.google.com/github/JuanJoseMV/neuraltextgen/blob/main/Bert_V2.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [29]:
!pip3 install pytorch_pretrained_bert



In [30]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [31]:
import numpy as np
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

In [32]:
# Load pre-trained model (weights)
model_version = 'bert-base-uncased'
model = BertForMaskedLM.from_pretrained(model_version)
model.eval()
cuda = torch.cuda.is_available()
if cuda:
    model = model.cuda(0)

# Load pre-trained model tokenizer (vocabulary)
tokenizer = BertTokenizer.from_pretrained(model_version, do_lower_case=model_version.endswith("uncased"))

def tokenize_batch(batch):
    return [tokenizer.convert_tokens_to_ids(sent) for sent in batch]

def untokenize_batch(batch):
    return [tokenizer.convert_ids_to_tokens(sent) for sent in batch]

def detokenize(sent):
    """ Roughly detokenizes (mainly undoes wordpiece) """
    new_sent = []
    for i, tok in enumerate(sent):
        if tok.startswith("##"):
            new_sent[len(new_sent) - 1] = new_sent[len(new_sent) - 1] + tok[2:]
        else:
            new_sent.append(tok)
    return new_sent

CLS = '[CLS]'
SEP = '[SEP]'
MASK = '[MASK]'
mask_id = tokenizer.convert_tokens_to_ids([MASK])[0]
sep_id = tokenizer.convert_tokens_to_ids([SEP])[0]
cls_id = tokenizer.convert_tokens_to_ids([CLS])[0]

100%|██████████| 407873900/407873900 [00:11<00:00, 34978241.28B/s]
100%|██████████| 231508/231508 [00:00<00:00, 1146403.15B/s]


# Generations

In [33]:
def generate_step(out, gen_idx, temperature=None, top_k=0, sample=False, return_list=True):
    """ Generate a word from from out[gen_idx]
    
    args:
        - out (torch.Tensor): tensor of logits of size batch_size x seq_len x vocab_size
        - gen_idx (int): location for which to generate for
        - top_k (int): if >0, only sample from the top k most probable words
        - sample (Bool): if True, sample from full distribution. Overridden by top_k 
    """
    logits = out[:, gen_idx]
    if temperature is not None:
        logits = logits / temperature
    if top_k > 0:
        kth_vals, kth_idx = logits.topk(top_k, dim=-1)
        dist = torch.distributions.categorical.Categorical(logits=kth_vals)
        idx = kth_idx.gather(dim=1, index=dist.sample().unsqueeze(-1)).squeeze(-1)
    elif sample:
        dist = torch.distributions.categorical.Categorical(logits=logits)
        idx = dist.sample().squeeze(-1)
    else:
        idx = torch.argmax(logits, dim=-1)
    return idx.tolist() if return_list else idx

In [34]:
# Generation modes as functions
import math
import time

def get_init_text(seed_text, max_len, batch_size = 1, rand_init=False):
    """ Get initial sentence by padding seed_text with either masks or random words to max_len """
    batch = [seed_text + [MASK] * max_len + [SEP] for _ in range(batch_size)]
    #if rand_init:
    #    for ii in range(max_len):
    #        init_idx[seed_len+ii] = np.random.randint(0, len(tokenizer.vocab))
    
    return tokenize_batch(batch)

def parallel_sequential_generation(seed_text, max_len=15, top_k=0, temperature=None, max_iter=300, burnin=200,
                                   cuda=False, print_every=10, verbose=True):
    """ Generate for one random position at a timestep
    
    args:
        - burnin: during burn-in period, sample from full distribution; afterwards take argmax
    """
    seed_len = len(seed_text)
    batch = get_init_text(seed_text, max_len, batch_size)
    
    for ii in range(max_iter):
        kk = np.random.randint(0, max_len)
        for jj in range(batch_size):
            batch[jj][seed_len+kk] = mask_id
        inp = torch.tensor(batch).cuda() if cuda else torch.tensor(batch)
        out = model(inp)
        topk = top_k if (ii >= burnin) else 0
        idxs = generate_step(out, gen_idx=seed_len+kk, top_k=topk, temperature=temperature, sample=(ii < burnin))
        for jj in range(batch_size):
            batch[jj][seed_len+kk] = idxs[jj]
            
        if verbose and np.mod(ii+1, print_every) == 0:
            for_print = tokenizer.convert_ids_to_tokens(batch[0])
            for_print = for_print[:seed_len+kk+1] + ['(*)'] + for_print[seed_len+kk+1:]
            print("iter", ii+1, " ".join(for_print))
            
    return untokenize_batch(batch)

def generate(n_samples, seed_text="[CLS]", batch_size=10, max_len=25, 
             sample=True, top_k=100, temperature=1.0, burnin=200, max_iter=500,
             cuda=False, print_every=1):
    # main generation function to call
    sentences = []
    n_batches = math.ceil(n_samples / batch_size)
    start_time = time.time()
    for batch_n in range(n_batches):

        batch = parallel_sequential_generation(seed_text, max_len=max_len, top_k=top_k,
                                               temperature=temperature, burnin=burnin, max_iter=max_iter, 
                                               cuda=cuda, verbose=False)
    

        if (batch_n + 1) % print_every == 0:
            print("Finished batch %d in %.3fs" % (batch_n + 1, time.time() - start_time))
            start_time = time.time()
        
        sentences += batch
    return sentences

In [35]:
# Utility functions

def printer(sent, should_detokenize=True):
    if should_detokenize:
        sent = detokenize(sent)[1:-1]
    print(" ".join(sent))
    
def read_sents(in_file, should_detokenize=False):
    sents = [sent.strip().split() for sent in open(in_file).readlines()]
    if should_detokenize:
        sents = [detokenize(sent) for sent in sents]
    return sents

def write_sents(out_file, sents, should_detokenize=False):
    with open(out_file, "w") as out_fh:
        for sent in sents:
            sent = detokenize(sent[1:-1]) if should_detokenize else sent
            ###print("%s\n" % " ".join(sent))
            out_fh.write("%s\n" % " ".join(sent))

In [36]:
n_samples = 100 #1000
batch_size = 5 #50
max_len = 40
top_k = 100
temperature = 0.7

leed_out_len = 5 # max_len
burnin = 250
sample = True
max_iter = 500

# Choose the prefix context
seed_text = "[CLS]".split()

for temp in [1.0]:
    bert_sents = generate(n_samples, seed_text=seed_text, batch_size=batch_size, max_len=max_len,
                          sample=sample, top_k=top_k, temperature=temp, burnin=burnin, max_iter=max_iter,
                          cuda=True)
    out_file = "/content/%s-len%d-burnin%d-topk%d-temp%.3f.txt" % (model_version, max_len, burnin, top_k, temp)
    write_sents(out_file, bert_sents, should_detokenize=True)

Finished batch 1 in 9.629s
Finished batch 2 in 9.392s
Finished batch 3 in 9.487s
Finished batch 4 in 9.558s
Finished batch 5 in 9.661s
Finished batch 6 in 9.736s
Finished batch 7 in 9.856s
Finished batch 8 in 9.943s
Finished batch 9 in 10.003s
Finished batch 10 in 10.098s
Finished batch 11 in 10.105s
Finished batch 12 in 10.105s
Finished batch 13 in 10.110s
Finished batch 14 in 10.094s
Finished batch 15 in 10.109s
Finished batch 16 in 10.144s
Finished batch 17 in 10.177s
Finished batch 18 in 10.211s
Finished batch 19 in 10.263s
Finished batch 20 in 10.304s


In [37]:
in_file = "/content/bert-base-uncased-len40-burnin250-topk100-temp1.000.txt"
bert_sents = read_sents(in_file, should_detokenize=False)

In [38]:
for i in range(25):
    printer(bert_sents[i], should_detokenize=True)

later , he appeared on live albums . they were live from cleveland , where he " went wild with feeling " before the performance , and dead kennedys , where " taxi star " performed briefly
... the man from earlier , remembering his conversation with doctor marlin , the so called doc who always plugged into my brain , would stop me from thinking about the doctor who worked in space
the years he had heard that several members of the council had shown up at evernight academy before and they were pushing them back , too , and it took as much time as anyone possibly could
is however very keen to go to dimarzio to retrieve a piece of the dark archive , which is the result of explosion . meanwhile , simon , emily and carol are currently on the move
trees had expired . the shadows who lay dead on the hill had grown so black that there was no joy , no suffering or grief or fear , no recognition of the shadows who were long gone
change in restaurants , eateries , and urban cuisines . westphal sys

# Evaluation

In [39]:
from nltk.translate import bleu_score as bleu

## Quality Measures

How similar are the generated sentences to the original training data (Toronto Book Corpus and Wikipedia dumps). We follow Yu et al., (2017) and compute the BLEU between the generations and the test sets of both corpora by treating the test set as the references for each generation. The tests sets are large; we subsample 5000 examples from each.

In [40]:
def prepare_data(data_file, replacements={}, uncased=True):
    data = [d.strip().split() for d in open(data_file, 'r').readlines()]
    if uncased:
        data = [[t.lower() for t in sent] for sent in data]
        
    for k, v in replacements.items():
        data = [[t if t != k else v for t in sent] for sent in data]
 
    return data

def prepare_wiki(data_file, uncased=True):
    replacements = {"@@unknown@@": "[UNK]"}
    return prepare_data(data_file, replacements=replacements, uncased=uncased)

def prepare_tbc(data_file):        
    replacements = {"``": "\"", "\'\'": "\""}
    return prepare_data(data_file, replacements=replacements)

def corpus_bleu(generated, references):
    """ Compute similarity between two corpora as measured by
    comparing each sentence of `generated` against all sentences in `references` 
    
    args:
        - generated (List[List[str]]): list of sentences (split into tokens)
        - references (List[List[str]]): list of sentences (split into tokens)
        
    returns:
        - bleu (float)
    """    
    return bleu.corpus_bleu([references for _ in range(len(generated))], generated)

In [41]:
!git clone https://github.com/nyu-dl/bert-gen
wiki103_file = 'bert-gen/data/wiki103.5k.txt'
tbc_file = 'bert-gen/data/tbc.5k.txt'

wiki_data = prepare_wiki(wiki103_file)
tbc_data = prepare_tbc(tbc_file)

fatal: destination path 'bert-gen' already exists and is not an empty directory.


In [51]:
def cleaner(data):
  len_mask = []
  for i in range(len(data)):
    if len(data[i]) <4:
      len_mask.append(False)
    else:
      len_mask.append(True)

  data = [b for a, b in zip(len_mask, data) if a]
  return data

wiki_data = cleaner(wiki_data)
tbc_data = cleaner(tbc_data)

In [53]:
print("BERT-TBC BLEU: %.2f" % (100 * corpus_bleu(bert_sents, tbc_data)))
print("BERT-Wiki103 BLEU: %.2f" % (100 * corpus_bleu(bert_sents, wiki_data)))
print("BERT-{TBC + Wiki103} BLEU: %.2f" % (100 * corpus_bleu(bert_sents, tbc_data[:2500] + wiki_data[:2500])))

BERT-TBC BLEU: 7.48
BERT-Wiki103 BLEU: 8.83
BERT-{TBC + Wiki103} BLEU: 9.37


## Diversity Measures

Self-BLEU: treat each sentence as a hypothesis and treat rest of corpus as reference. Lower is better.

In [54]:
from collections import Counter
from nltk.util import ngrams

def self_bleu(sents):
    return bleu.corpus_bleu([[s for (j, s) in enumerate(sents) if j != i] for i in range(len(sents))], sents)

def get_ngram_counts(sents, max_n=4):
    size2count = {}
    for i in range(1, max_n + 1):
        size2count[i] = Counter([n for sent in sents for n in ngrams(sent, i)])
    return size2count

def ref_unique_ngrams(preds, refs, max_n=4):
    # get # of *distinct* pred ngrams that don't appear in ref
    pct_unique = {}
    pred_ngrams = get_ngram_counts(preds, max_n)
    ref_ngrams = get_ngram_counts(refs, max_n)
    for i in range(1, max_n + 1):
      pred_ngram_counts = set(pred_ngrams[i].keys())
      total = sum(pred_ngrams[i].values())
      ref_ngram_counts = set(ref_ngrams[i].keys())
      pct_unique[i] = len(pred_ngram_counts.difference(ref_ngram_counts)) / total
    return pct_unique
        
def self_unique_ngrams(preds, max_n=4):
    # get # of pred ngrams with count 1
    pct_unique = {}
    pred_ngrams = get_ngram_counts(preds, max_n)
    for i in range(1, max_n + 1):
        n_unique = len([k for k, v in pred_ngrams[i].items() if v == 1])
        total = sum(pred_ngrams[i].values())
        pct_unique[i] = n_unique / total
    return pct_unique

In [55]:
print("BERT self-BLEU: %.2f" % (100 * self_bleu(bert_sents)))

BERT self-BLEU: 15.15


Corpus/Sentence contains 0 counts of 4-gram overlaps.
BLEU scores might be undesirable; use SmoothingFunction().


In [56]:
max_n = 4

pct_uniques = ref_unique_ngrams(bert_sents, wiki_data, max_n)
for i in range(1, max_n + 1):
    print("BERT unique %d-grams relative to Wiki: %.2f" % (i, 100 * pct_uniques[i]))
pct_uniques = ref_unique_ngrams(bert_sents, tbc_data, max_n)
for i in range(1, max_n + 1):
    print("BERT unique %d-grams relative to TBC: %.2f" % (i, 100 * pct_uniques[i]))
pct_uniques = self_unique_ngrams(bert_sents, max_n)
for i in range(1, max_n + 1):
    print("BERT unique %d-grams relative to self: %.2f" % (i, 100 * pct_uniques[i]))

BERT unique 1-grams relative to Wiki: 10.68
BERT unique 2-grams relative to Wiki: 64.27
BERT unique 3-grams relative to Wiki: 93.85
BERT unique 4-grams relative to Wiki: 99.19
BERT unique 1-grams relative to TBC: 16.87
BERT unique 2-grams relative to TBC: 68.44
BERT unique 3-grams relative to TBC: 94.42
BERT unique 4-grams relative to TBC: 99.30
BERT unique 1-grams relative to self: 29.62
BERT unique 2-grams relative to self: 81.44
BERT unique 3-grams relative to self: 97.43
BERT unique 4-grams relative to self: 99.39
