<a href="https://colab.research.google.com/github/ChriSuh10/Poetix18/blob/master/BERT.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
!pip install pytorch-pretrained-bert
!pip install ftfy

Collecting pytorch-pretrained-bert
[?25l  Downloading https://files.pythonhosted.org/packages/5d/3c/d5fa084dd3a82ffc645aba78c417e6072ff48552e3301b1fa3bd711e03d4/pytorch_pretrained_bert-0.6.1-py3-none-any.whl (114kB)
[K    100% |████████████████████████████████| 122kB 3.7MB/s 
Installing collected packages: pytorch-pretrained-bert
Successfully installed pytorch-pretrained-bert-0.6.1
Collecting ftfy
[?25l  Downloading https://files.pythonhosted.org/packages/8f/86/df789c5834f15ae1ca53a8d4c1fc4788676c2e32112f6a786f2625d9c6e6/ftfy-5.5.1-py3-none-any.whl (43kB)
[K    100% |████████████████████████████████| 51kB 2.0MB/s 
Installing collected packages: ftfy
Successfully installed ftfy-5.5.1


In [2]:
from google.colab import drive
drive.mount('/content/drive')

Go to this URL in a browser: https://accounts.google.com/o/oauth2/auth?client_id=947318989803-6bn6qk8qdgf4n4g3pfee6491hc0brc4i.apps.googleusercontent.com&redirect_uri=urn%3Aietf%3Awg%3Aoauth%3A2.0%3Aoob&scope=email%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdocs.test%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fdrive.photos.readonly%20https%3A%2F%2Fwww.googleapis.com%2Fauth%2Fpeopleapi.readonly&response_type=code

Enter your authorization code:
··········
Mounted at /content/drive


In [3]:
%cd drive/My\ Drive/Colab\ Notebooks/Poetix/Hugging\ Face

/content/drive/My Drive/Colab Notebooks/Poetix/Hugging Face


In [4]:
import torch
from pytorch_pretrained_bert import BertTokenizer, BertModel, BertForMaskedLM

import pickle
import random
import itertools

Better speed can be achieved with apex installed from https://www.github.com/nvidia/apex.


In [5]:
model = BertForMaskedLM.from_pretrained('bert-base-uncased')
model.to('cuda')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

100%|██████████| 407873900/407873900 [00:08<00:00, 48436558.43B/s]
100%|██████████| 231508/231508 [00:00<00:00, 5984241.11B/s]


In [6]:
def raw_text_to_tokens(text, masks):
  tokenized_text = ['[CLS]']
  for i, w in enumerate(tokenizer.tokenize(text)):
    if i in masks:
      w = '[MASK]'
    elif w == '.':
      w = '[SEP]'
    tokenized_text.append(w)
  
  return tokenized_text

def token_segment_tensors(tokens, to_cuda=True):
  indexed_tokens = tokenizer.convert_tokens_to_ids(tokens)
  segment_ids = [0 for i in range(len(tokens))]
  
  tokens_tensor = torch.tensor([indexed_tokens])
  segment_tensor = torch.tensor([segment_ids])
  
  if to_cuda:
    tokens_tensor = tokens_tensor.to('cuda')
    segment_tensor = segment_tensor.to('cuda')
  
  return tokens_tensor, segment_tensor
  
def feed_bert(text, masks, no_punct=True):
  tokenized_text = raw_text_to_tokens(text, masks)
  
  tokens_tensor, segment_tensor = token_segment_tensors(tokenized_text)
  
  model.eval()
  
  with torch.no_grad():
    predictions = model(tokens_tensor, segment_tensor)
  
  pred_str = text.split()
  if no_punct:
    punct_indices = [1000, 1012, 1010, 1013, 1005, 1011, 1035, 1007, 1006]
    top_indices = torch.argsort(predictions, dim=2, descending=True)
    for i in masks:
      for j in range(len(punct_indices) + 1):
        potential_prediction = top_indices[0, i + 1, j].item()
        if potential_prediction not in punct_indices:
          predicted_token = tokenizer.convert_ids_to_tokens([potential_prediction])[0]
          pred_str.pop(i)
          pred_str.insert(i, predicted_token)
          break
  else:
    for i in masks:
      # Because we add [CLS] to the beginning of the sentence
      predicted_index = torch.argmax(predictions[0, i + 1]).item()
      predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
      pred_str.pop(i)
      pred_str.insert(i, predicted_token)
  return pred_str, tokenized_text

text = 'There once was a woman from Niger.'
masks = [3, 5]

pred, masked = feed_bert(text, masks)

print(text)
print(' '.join(masked))
print(' '.join(pred))

There once was a woman from Niger.
[CLS] there once was [MASK] woman [MASK] niger [SEP]
There once was a woman in Niger.


In [7]:
def mask_in_range(tokens, masks):
  masked_tokens = []
  for i in range(len(tokens)):
    if i in masks:
      masked_tokens.append('[MASK]')
    else:
      masked_tokens.append(tokens[i])
  return masked_tokens
    

### For sequence (x1, x2, ..., xn),
### returns p(x1) * p(x2 | x1) * ... * p(xn | x1, ..., xn-1)
def score_sentence_bert(text):
  tokenized_text = raw_text_to_tokens(text, [])
#   print(tokenized_text)
  # compute log probabilities across all words for each token
  log_prob = torch.nn.LogSoftmax(dim=2)
  prob_by_spot = []
  probs = 0
  # since sentence begins with '[CLS]'
  for i in range(1, len(tokenized_text)):
    masked_tokens = tokenized_text[:i]
    masked_tokens.append('[MASK]')
    tokens_tensor, segment_tensor = token_segment_tensors(masked_tokens)
    
    model.eval()
    with torch.no_grad():
      log_probs = log_prob(model(tokens_tensor))
      this_log_prob = log_probs[0, i, tokens_tensor[0, i]]
      probs += this_log_prob.item()
      prob_by_spot.append(this_log_prob.item())

#   print(prob_by_spot)
  return probs / len(tokenized_text)
    
    

human_text = 'Who wanted to become a ballerina'
generated_text = 'been one perturbations of his liar'

print(score_sentence_bert(human_text))
print(score_sentence_bert(generated_text))

-11.797262072563171
-12.830491924285889


In [0]:
# Create data structures
with open('templates.p', 'rb') as f:
  dataset, second_line_, third_line_, last_two_lines = pickle.load(f)

with open('likelihood_collocations.p', 'rb') as f:
  coll_dict = pickle.load(f)
  
# map second word to first word
rev_coll_dict = {}
for k, v in coll_dict.items():
  for w, score in v:
    if w not in rev_coll_dict:
      rev_coll_dict[w] = []
    rev_coll_dict[w].append((k, score))
for l in rev_coll_dict.values():
  l.sort(key=lambda x: x[1], reverse=True)
  
with open('postag_dict_all.p', 'rb') as f:
  postag_dict = pickle.load(f)

pos_to_words = postag_dict[1]
words_to_pos = postag_dict[2]

# map words to meter
with open('cmudict-0.7b.txt') as f:
    lines = [line.rstrip("\n").split() for line in f if (";;;" not in line)]
    dict_meters = {}
    for i in range(len(lines)):
        line = lines[i]
        newLine = [line[0].lower()]
        if("(" in newLine[0] and ")" in newLine[0]):
            newLine[0] = newLine[0][:-3]
        chars = ""
        for word in line[1:]:
            for ch in word:
                if(ch in "012"):
                    if(ch == "2"):
                        chars+="1"
                    else:
                        chars+=ch
        newLine+=[chars]
        lines[i] = newLine
        if(newLine[0] not in dict_meters): #THIS IF STATEMENT ALLOWS FOR MULTIPLE PRONUNCIATIONS OF A WORD
            dict_meters[newLine[0]]=[chars]
        else:
            if(chars not in dict_meters[newLine[0]]):
                dict_meters[newLine[0]]+=[chars]
    dict_meters[','] = ['']
    dict_meters['.'] = ['']
    
# map pos to possible syllables
pos_syllables = {}
for k, v in pos_to_words.items():
    pos_syllables[k] = set()
    for w in v:
        try:
            pos_syllables[k].add(len(dict_meters[w][0]))
        except:
            continue
pos_syllables[','].add(0)
pos_syllables['.'].add(0)

In [0]:
first_line = { 
    'NN':[
      (['DT', 'NN', 'WP', 'VBD', 'DT', 'NN'], ['a', 'tutor', 'who', 'tooted', 'a', 'flute']),
      (['EX', 'VBD', 'DT', 'JJ', 'NN', 'IN', 'DT', 'NN'], ['there', 'was', 'an', 'old', 'man', 'with', 'a', 'beard']),
      (['DT', 'JJ', 'NN', 'NN', 'IN', 'NN'], ['a', 'young', 'gourmet', 'dining', 'at', 'crewe']),
      (['EX', 'RB', 'VBD', 'DT', 'NN', 'IN', 'DT', 'NN'], ['there', 'once', 'was', 'a', 'fly', 'on', 'the', 'wall'])
    ],
    
    'NNP': [
      (['DT', 'JJ', 'JJ', 'NN', 'IN', 'NNP'], ['a', 'silly', 'young', 'man', 'from', 'clyde']),
      (['DT', 'JJ', 'JJ', 'NN', 'VBN', 'NNP'], ['a', 'nifty', 'young', 'flapper', 'named', 'jane']),
      (['EX', 'VBD', 'DT', 'JJ', 'NN', 'IN', 'NNP'], ['there', 'was', 'an', 'old', 'man', 'of', 'nantucket']),
      (['DT', 'JJ', 'NN', 'IN', 'NN', 'NNP'], ['an', 'elderly', 'bride', 'of', 'port', 'jervis']),
      (['DT', 'NN', 'NN', 'VBN', 'NNP'], ['a', 'crossword', 'compiler', 'named', 'moss']),
      (['EX', 'VBD', 'DT', 'JJ', 'NN', 'VBN', 'NNP'], ['there', 'was', 'a', 'young', 'fellow', 'called', 'binn']),
      (['DT', 'JJ', 'VBG', 'NN', 'IN', 'NNP'], ['an', 'odd', 'looking', 'guy', 'from', 'beruit']),
      (['EX', 'VBD', 'DT', 'JJ', 'NN', 'VBN', 'NNP'], ['there', 'was', 'a', 'young', 'lady', 'named', 'rose'])
    ],
              
    ',': [
      (['EX', 'VBD', 'DT', 'JJ', 'NN', 'SO', 'JJ'], ['there', 'was', 'a', 'young', 'man', 'so', 'benighted']),
      (['DT', 'NN', 'IN', 'NN', ',', 'JJ', 'NN', ','], ['a', 'maiden', 'at', 'college', ',', 'miss', 'breeze', ',']),
      (['EX', 'VBD', 'DT', 'JJ', 'NN', 'IN', 'NN', ','], ['there', 'was', 'a', 'young', 'lady', 'of', 'cork', ',']),
      (['EX', 'VBD', 'DT', 'JJ', 'NN', 'VBN', 'NN', ','], ['there', 'was', 'a', 'young', 'woman', 'named', 'kite', ',']),
      (['EX', 'VBD', 'DT', 'JJ', 'NN', 'IN', 'NN', ','], ['there', 'was', 'a', 'dear', 'lady', 'of', 'eden', ',']),
      (['EX', 'VBD', 'DT', 'JJ', 'NN', 'VBN', 'JJ', ','], ['there', 'was', 'an', 'old', 'fellow', 'named', 'green', ',']),
      (['EX', 'VBD', 'DT', 'JJ', 'NN', 'VBN', 'NN', ','], ['there', 'was', 'a', 'young', 'lady', 'named', 'hannah', ',']),
      (['EX', 'VBD', 'DT', 'JJ', 'NN', 'IN', 'DT', 'NN', ','], ['there', 'was', 'an', 'old', 'man', 'in', 'a', 'hearse', ',']),
      (['EX', 'VBD', 'DT', 'JJ', 'NN', 'IN', 'NN', ','], ['there', 'was', 'a', 'young', 'lady', 'of', 'kent', ',']),
      (['EX', 'VBD', 'DT', 'JJ', 'NN', 'IN', 'NN', ','], ['there', 'was', 'a', 'young', 'lady', 'of', 'lynn', ',']),
      (['EX', 'VBD', 'DT', 'JJ', 'NN', 'VBN', 'MD', ','], ['there', 'was', 'a', 'young', 'lady', 'named', 'may', ',']),
      (['EX', 'VBD', 'DT', 'JJ', 'NN', 'IN', 'NN', ','], ['there', 'was', 'a', 'young', 'lady', 'of', 'munich', ',']),
      (['EX', 'VBD', 'DT', 'JJ', 'NN', 'IN', 'NN', ','], ['there', 'was', 'a', 'young', 'lady', 'from', 'niger', ',']),
      (['EX', 'RB', 'VBD', 'DT', 'NN', 'VBN', 'NN', ','], ['there', 'once', 'was', 'a', 'guy', 'named', 'othello', ',']),
      (['EX', 'VBD', 'DT', 'JJ', 'NN', 'VBN', 'NNS', ','], ['there', 'was', 'a', 'young', 'lady', 'named', 'perkins', ',']),
      (['RB', 'VBZ', 'DT', 'JJ', 'NN', 'VBN', 'NN', ','], ['here', 'lies', 'a', 'young', 'salesman', 'named', 'phipps', ',']),
      (['EX', 'VBD', 'DT', 'JJ', 'NN', 'VBN', 'NN', ','], ['there', 'was', 'a', 'young', 'fellow', 'named', 'weir', ',']),
      (['EX', 'VBD', 'DT', 'JJ', 'NN', 'VBN', 'NN', ','], ['there', 'was', 'a', 'young', 'person', 'called', 'smarty', ',']),
      (['EX', 'RB', 'VBD', 'DT', 'JJ', 'NN', 'IN', 'NN', ','], ['there', 'once', 'was', 'an', 'old', 'man', 'of', 'esser', ',']),
      (['EX', 'VBD', 'DT', 'JJ', 'NN', 'IN', 'NN', ','], ['there', 'was', 'a', 'young', 'lady', 'from', 'hyde', ',']),
      (['EX', 'RB', 'VBD', 'DT', 'NN', 'VBN', 'NN', ','], ['there', 'once', 'was', 'a', 'girl', 'named', 'irene', ',']),
      (['EX', 'VBD', 'DT', 'JJ', 'NN', 'IN', 'NN', ','], ['there', 'was', 'an', 'old', 'man', 'from', 'milan', ','])
    ],
              
    'VBD': [    
      (['EX', 'VBD', 'DT', 'JJ', 'NN', 'WP', 'VBD'], ['there', 'was', 'a', 'young', 'dentist', 'who', 'thrilled'])
    ]
}

In [0]:
def valid_permutation_sylls(num_sylls, template, last_word_sylls):
        """
        Finds and returns the first integer partition of num_sylls with a total
        number of integers equal to the length of template - 1 for which each
        assignment of syllables to pos is valid.

        Parameters
        ----------
        num_sylls : int
            The total number of syllables to be distributed across the words in
            the line.
        template : list
            A list of str containing the pos for each word in the line.
        last_word_sylls : int
            The number of syllables of the last word in the line.

        Returns
        -------
        list
            A list of ints corresponding to a valid assignment of syllables to
            each word in the line.
        """
        def get_all_partition_size_n(n, partition_size):
            """
            Returns all integer partitions of an int with a partition_size number
            of ints.
            """
            def get_all_partitions(n, I=1):
                yield (n,)
                for i in range(I, n//2 + 1):
                    for p in get_all_partitions(n-i, i):
                        yield (i,) + p
            return [p for p in get_all_partitions(n) if len(p) == partition_size]
        def valid_syll(sylls, template):
            """
            Checks if a template and syllable mapping are compatible.
            """
            if 'POS' in template:
              print(template)
            for i in range(len(template) - 1):
                # Add in zeros to account for punctuation
                if template[i] == ',' or template[i] == '.':
                    sylls.insert(i, 0)
                if sylls[i] not in pos_syllables[template[i]]:
                    return False
            return True
        syllables_left = num_sylls - last_word_sylls
        # Punctuation takes up no syllables, so subtract to get number of partitions
        num_zero_sylls = sum(1 if pos == '.' or pos == ',' else 0 for pos in template)
        num_words_left = len(template) - num_zero_sylls - 1

        for partition in get_all_partition_size_n(syllables_left, num_words_left):
            # Goes through all permutations by index, not numbers,
            # inefficient implementation
            permutations = list(itertools.permutations(partition))
            random.shuffle(permutations)
            for perm in permutations:
                perm = list(perm)
                # Last word is fixed
                perm.append(last_word_sylls)
                if valid_syll(perm, template):
                    return perm

In [0]:
def add_colls_to_line(i, sequence, template, template_sylls):
  curr_word = sequence[i]
  if curr_word == '[MASK]':
    return
  colls = rev_coll_dict[curr_word]
  for j in reversed(range(max(0, i - 3), i)):
    for coll_word, score in colls:
      if sequence[j] == '[MASK]' and coll_word not in sequence and template[j] in words_to_pos[coll_word] and coll_word in dict_meters and template_sylls[j] == len(dict_meters[coll_word][0]):
        sequence[j] = coll_word

def fill_from_back(sequence, template, template_sylls):
  for i in reversed(range(len(sequence))):
      add_colls_to_line(i, sequence, template, template_sylls)
      
def feed_bert_templates(text, masks, template, template_syll, no_punct=True):
  tokenized_text = raw_text_to_tokens(text, masks)
  
  tokens_tensor, segment_tensor = token_segment_tensors(tokenized_text)
  
  model.eval()
  
  with torch.no_grad():
    predictions = model(tokens_tensor, segment_tensor)
  
  pred_str = text.split()
  if no_punct:
    punct_indices = [1000, 1012, 1010, 1013, 1005, 1011, 1035, 1007, 1006]
    top_indices = torch.argsort(predictions, dim=2, descending=True)
    for i in masks:
      for j in range(len(top_indices[0, i + 1])):
        potential_prediction = top_indices[0, i + 1, j].item()
        if potential_prediction not in punct_indices:
          predicted_token = tokenizer.convert_ids_to_tokens([potential_prediction])[0]
          if not (predicted_token in words_to_pos and predicted_token in dict_meters):
            continue
          if template[i] in words_to_pos[predicted_token] and len(dict_meters[predicted_token][0]) == template_syll[i] and predicted_token not in pred_str:
            pred_str.pop(i)
            pred_str.insert(i, predicted_token)
            break
  else:
    for i in masks:
      # Because we add [CLS] to the beginning of the sentence
      predicted_index = torch.argmax(predictions[0, i + 1]).item()
      predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
      pred_str.pop(i)
      pred_str.insert(i, predicted_token)
  return pred_str, tokenized_text

def try_all_low_freq_pos(w1, template, template_sylls):
  # iterate over all permutions of low pos words
  return

def get_template(pos, line_number, num_templates=1):
  if line_number == 1:
    templates = random.sample(first_line[pos], k=num_templates)
  if line_number == 2:
    templates = random.sample(second_line_[pos], k=num_templates)
  if line_number == 3:
    templates = random.sample(third_line_[pos], k=num_templates)
  if line_number == 4:
    templates = random.sample(last_two_lines[pos], k=num_templates)
  return templates

def gen_line(w1, line=2, template=None, num_sylls=10):
  pos = words_to_pos[w1][0]
  if template is None:
    template, original = get_template(pos, line)[0]
  
  last_word_sylls = len(dict_meters[w1][0])
  template_sylls = valid_permutation_sylls(num_sylls, template, last_word_sylls)
  if template_sylls is None:
    print(num_sylls, w1)
  
  seq = ['[MASK]' for i in range(len(template_sylls) - 1)]
  seq.append(w1)
  fill_from_back(seq, template, template_sylls)
  
  masks = [i for i in range(len(seq)) if seq[i] == '[MASK]']
  as_str = ' '.join(seq)
  print(as_str)
  pred_seq, masked_tokens = feed_bert_templates(as_str, masks, template, template_sylls)
  
  for iteration in range(10):
    if '[MASK]' in pred_seq:
      masks = [i for i in range(len(seq)) if seq[i] == '[MASK]']
      as_str = ' '.join(pred_seq)
      pred_seq, masked_tokens = feed_bert_templates(as_str, masks, template, template_sylls)
    else:
      break
  score = score_sentence_bert(' '.join(pred_seq))
  
  return pred_seq, score, template

In [14]:
seq = ['[MASK]', '[MASK]', '[MASK]', '[MASK]', 'amount']
template = ['JJ', 'VB', 'JJ', 'VB', 'NN']
template_sylls = [1, 1, 1, 1, 2]

add_colls_to_line(4, seq, template, template_sylls)
print(seq)

['[MASK]', '[MASK]', 'fair', '[MASK]', 'amount']


In [15]:
gen_line('amount')

[MASK] [MASK] [MASK] [MASK] [MASK] [MASK] amount


(['whose', 'team', 'continued', 'then', 'more', '[MASK]', 'amount'],
 -12.245478749275208,
 ['WP$', 'NN', 'VBD', 'RB', 'JJR', 'THAN', 'NN'])

In [16]:
s = 0
for k, v in tokenizer.wordpiece_tokenizer.vocab.items():
  if k in words_to_pos:
    s += 1
print(s)
print(s / len(words_to_pos))

11439
0.44566953675926285


In [0]:
words_to_pos['discomfort'].append('NN')
words_to_pos['discomfort'].remove('NNP')
words_to_pos['discomfort'].remove('VB')
pos_to_words['NNP'].remove('discomfort')
pos_to_words['VB'].remove('discomfort')

In [61]:
def gen_poem(five_words):
  templates = []
#   first_pos = words_to_pos[five_words[0]][0]
  # Assume for now
  first_pos = 'NNP'
  templates.append(get_template(first_pos, 1)[0][0])
  second_pos = words_to_pos[five_words[1]][0]
  templates.append(get_template(second_pos, 2)[0][0])
  third_pos = words_to_pos[five_words[2]][0]
  templates.append(get_template(third_pos, 3)[0][0])
  fourth_pos = words_to_pos[five_words[3]][0]
  fifth_pos = words_to_pos[five_words[4]][0]
  last_template, original, idx = get_template(fourth_pos + "-" + fifth_pos, 4)[0]
  fourth_template = last_template[:idx + 1]
  fifth_template = last_template[idx + 1:]
  templates.extend((fourth_template, fifth_template))
  
#   print(templates)
  
  lines = []
  syllables = [len(t) + len(dict_meters[five_words[i]][0]) + 1 for i, t in enumerate(templates)]
  for i, word in enumerate(five_words):
    lines.append(gen_line(five_words[i], template=templates[i], num_sylls=syllables[i]))
    
  return lines

def print_poem(five_words):         
  out = gen_poem(five_words)
  print("*********************")
  for l, s, t in out:
    print('{:60} line score: {:2.3f}'.format(' '.join(l), s))
    print(t)

print_poem(('greece', 'fleece', 'pull', 'wool', 'piece'))

[MASK] [MASK] [MASK] [MASK] [MASK] [MASK] greece
[MASK] [MASK] [MASK] [MASK] [MASK] [MASK] fleece
[MASK] stabbed [MASK] [MASK] pull
[MASK] [MASK] [MASK] [MASK] wool
lined [MASK] [MASK] block piece
*********************
there was the first oracle called greece                     line score: -11.824
['EX', 'VBD', 'DT', 'JJ', 'NN', 'VBN', 'NNP']
who was alcohol one of her fleece                            line score: -12.555
['WHO', 'VBD', 'IN', 'CD', 'IN', 'PRP$', 'NNS']
he stabbed behaviour to pull                                 line score: -11.884
['PRP', 'VBD', 'CD', 'TO', 'NN']
that included to her wool                                    line score: -11.644
['WDT', 'VBD', 'TO', 'PRP$', 'NN']
lined down every block piece                                 line score: -11.438
['VBD', 'RB', 'DT', 'NN', 'NN']


In [66]:
print_poem(('york', 'fork', 'eat', 'meat', 'pork'))

[MASK] [MASK] [MASK] [MASK] york
[MASK] [MASK] [MASK] [MASK] [MASK] [MASK] fork
[MASK] [MASK] [MASK] [MASK] eat
[MASK] [MASK] [MASK] [MASK] [MASK] [MASK] meat
[MASK] [MASK] dried [MASK] [MASK] beef [MASK] [MASK] pork
*********************
the west east recorded york                                  line score: -11.785
['DT', 'NN', 'NN', 'VBN', 'NNP']
who was i then on another fork                               line score: -12.303
['WHO', 'VBD', 'PRP', 'RB', 'IN', 'DT', 'NN']
i included one to eat                                        line score: -11.691
['PRP', 'VBD', 'CD', 'TO', 'NN']
do he be [MASK] represented the meat                         line score: -12.324
['VB', 'PRP', 'VB', ',', 'VBD', 'DT', 'NN']
[MASK] it dried on that beef in the pork                     line score: -12.482
['CC', 'PRP', 'VBD', 'IN', 'DT', 'NN', 'IN', 'DT', 'NN']


In [72]:
print_poem(('cain', 'pain', 'comfort', 'discomfort', 'vein'))

[MASK] ominous fault [MASK] stain cain
[MASK] [MASK] [MASK] [MASK] [MASK] [MASK] [MASK] pain
[MASK] loved [MASK] [MASK] comfort
[MASK] [MASK] [MASK] discomfort
[MASK] [MASK] [MASK] [MASK] [MASK] [MASK] vein
*********************
the ominous fault of stain cain                              line score: -12.056
['DT', 'JJ', 'NN', 'IN', 'NN', 'NNP']
who was as i included in the pain                            line score: -12.542
['WHO', 'VBD', 'AS', 'PRP', 'VBD', 'IN', 'DT', 'NN']
he loved alcohol the comfort                                 line score: -11.829
['PRP', 'VBD', 'IN', 'DT', 'NN']
as another new discomfort                                    line score: -10.887
['AS', 'DT', 'JJ', 'NN']
in another church of wife to vein                            line score: -12.340
['IN', 'DT', 'NN', 'IN', 'NN', 'TO', 'NN']


In [0]:
poetic_lines = ["There was an Old Man with a beard Who said It is just as I feared",
                "Two Owls and a Hen Four Larks and a Wren",
                "Have all built their nests in my beard",
               "There was a Young Lady of Dorking Who bought a large bonnet for walking",
                "But its colour and size So bedazzled her eyes",
                "That she very soon went back to Dorking"]

generated_lines = ["who at years deserted last night on sky",
                   "told me for what no birth if i heard the blue earth",
                   "the extraordinary way you lie",
                  "of the whole thing everything alternate who fascinated me now as a mate",
                   "i continued , with a dream undertone eyes and the team",
                   "cant sing the fascination of her state"]

poetic_scores = []
for line in poetic_lines:
  poetic_scores.append(score_sentence_bert(line))
  print(line)
  
generated_scores = []
for line in generated_lines:
  generated_scores.append(score_sentence_bert(line))
  print(line)

['[CLS]', 'there', 'was', 'an', 'old', 'man', 'with', 'a', 'beard', 'who', 'said', 'it', 'is', 'just', 'as', 'i', 'feared']
[-15.141115188598633, -12.286706924438477, -12.675008773803711, -13.918147087097168, -13.837422370910645, -14.138650894165039, -14.68960952758789, -14.954139709472656, -14.437015533447266, -14.79336166381836, -13.459794998168945, -13.66992473602295, -13.885150909423828, -14.508549690246582, -13.66031551361084, -14.599510192871094]
There was an Old Man with a beard Who said It is just as I feared
['[CLS]', 'two', 'owls', 'and', 'a', 'hen', 'four', 'lark', '##s', 'and', 'a', 'wren']
[-15.141115188598633, -14.01447868347168, -14.642720222473145, -14.090720176696777, -13.978071212768555, -13.346451759338379, -14.218205451965332, -15.601720809936523, -13.77589225769043, -14.648971557617188, -13.974457740783691]
Two Owls and a Hen Four Larks and a Wren
['[CLS]', 'have', 'all', 'built', 'their', 'nests', 'in', 'my', 'beard']
[-15.141115188598633, -13.587356567382812, -13

In [0]:
sum(generated_scores) / len(generated_scores)

-12.576219562985424

In [0]:
sum(poetic_scores) / len(poetic_scores)

-12.84572671754764

In [0]:
def possiblePartsSpeechPaths():
    #SO dict["verb"] == set(adverb, noun, ...) BUT NOT set(adjective, determiner, etc)
    pos_list = ["CC","CD","DT","EX","FW","IN","JJ","JJR","JJS", "LS","MD","NN","NNS","NNP","NNPS", \
                "PDT","POS","PRP","PRP$","RB","RBR","RBS","RP","TO","UH","VB","VBD","VBG","VBN","VBP", \
                "VBZ","WDT","WP","WP$","WRB",',']
    dictTags = {}
    for tag in pos_list:
        s = set([])
        if("VB" in tag):
            s = set(["CC","RB","RBR","RBS","NN","NN","NNS","NNP","NNPS","MD","PRP"])
            sing_nouns = set(["NN","NNP"])
            plur_nouns = set(["NNS","NNPS"])
            if(tag in set(["VB","VBG","VBP","VBN"])):
                s.difference(sing_nouns)
            if(tag in set(["VBG","VBZ","VBN"])):
                s.difference(plur_nouns)
            if(tag in set(["VBG","VBN"])):
                s.union(set(["VB","VBD","VBP","VBZ"]))
        else:
            s=set(pos_list)
            if("IN"==tag):
                t = set(["IN","DT","CC"]) #maybe not CC
                s.difference(t)
            if("JJ" in tag):
                t = set(["NN","NNS","NNP","NNPS"])
                s.difference(t)
            if("TO"==tag):
                t = set(["DT","CC","IN"])
                s.difference(t)
            if("CC"==tag):
                t = set(["DT","JJ","JJR","JJS"])
                s.difference(t)
            if("NN" in tag):
                t = set(["NN","NNS","NNP","NNPS","PRP","CC"]) #maybe not CC
                s.difference(t)
            if("MD"==tag):
                t = set(["DT","VB","VBD","VBG","VBN","VBP","VBZ"])
                s.difference(t)
            if("PRP"==tag):
                t = set(["CC","JJ","JJR","JJS","NN","NNS","NNP","NNPS","DT"])
                s.difference(t)
            if("PRP$"==tag):
                t = set(["CC","DT","VB","VBD","VBG","VBN","VBP","VBZ","PRP"])
                s.difference(t)
            adv = set(["RB","RBR","RBS"])
            if(tag not in adv):
                s.remove(tag)
        dictTags[tag] = s
    return dictTags

dictPossiblePOS = possiblePartsSpeechPaths()  

In [0]:
def partsOfSpeechFilter(sequence, index, word):
    if index == len(sequence) - 1 or sequence[index + 1] == '[MASK]':
      return False
    post_word = sequence[index + 1]
    try:
        tag1 = set(words_to_pos[post_word])
    except KeyError:
        return True
    try:
        tag2 = set(words_to_pos[word])
    except KeyError:
        return True
    #if(tag1==tag2 and tag1 not in okay_tags):
    #    return True
    
    for post_pos in tag1:
      if len(tag2.intersection(set(dictPossiblePOS[post_pos]))) > 0:
        return False
    return True

def get_a_partition_of_n(n, last_word_syll, min_length):
    """
    Returns all integer partitions of an int with a partition_size number
    of ints.
    """
    def get_all_partitions(n, I=1):
        yield (n,)
        for i in range(I, n//2 + 1):
            for p in get_all_partitions(n-i, i):
                yield (i,) + p
    all_possible = [p for p in get_all_partitions(n) if p[-1] == last_word_syll and len(p) >= min_length]
    random.shuffle(all_possible)
    return all_possible[0]
          
                  
def add_colls_to_line(i, sequence, template_sylls):
  curr_word = sequence[i]
  if curr_word == '[MASK]':
    return
  colls = rev_coll_dict[curr_word]
  for j in reversed(range(max(0, i - 3), i)):
    for coll_word, score in colls:
      if sequence[j] == '[MASK]' and coll_word not in sequence and coll_word in dict_meters and template_sylls[j] == len(dict_meters[coll_word][0]):
        sequence[j] = coll_word

def fill_from_back(sequence, template_sylls):
  for i in reversed(range(len(sequence))):
      add_colls_to_line(i, sequence, template_sylls)
      
def feed_bert(text, masks, template_syll, no_punct=True):
  tokenized_text = raw_text_to_tokens(text, masks)
  
  tokens_tensor, segment_tensor = token_segment_tensors(tokenized_text)
  
  model.eval()
  
  with torch.no_grad():
    predictions = model(tokens_tensor, segment_tensor)
  
  pred_str = text.split()
  if no_punct:
    punct_indices = [1000, 1012, 1010, 1013, 1005, 1011, 1035, 1007, 1006]
    top_indices = torch.argsort(predictions, dim=2, descending=True)
    for i in reversed(masks):
      print(i)
      for j in range(len(top_indices[0, i + 1])):
        potential_prediction = top_indices[0, i + 1, j].item()
        if potential_prediction not in punct_indices:
          predicted_token = tokenizer.convert_ids_to_tokens([potential_prediction])[0]
          if predicted_token not in dict_meters:
            continue
          if predicted_token not in pred_str and len(dict_meters[predicted_token][0]) == template_syll[i] and not partsOfSpeechFilter(pred_str, i, predicted_token):
            pred_str.pop(i)
            pred_str.insert(i, predicted_token)
            break
  else:
    for i in masks:
      # Because we add [CLS] to the beginning of the sentence
      predicted_index = torch.argmax(predictions[0, i + 1]).item()
      predicted_token = tokenizer.convert_ids_to_tokens([predicted_index])[0]
      pred_str.pop(i)
      pred_str.insert(i, predicted_token)
  return pred_str, tokenized_text

def gen_line(w1, num_sylls=10):
  last_word_sylls = len(dict_meters[w1][0])
  template_sylls = get_a_partition_of_n(num_sylls, last_word_sylls, 4)
  print(template_sylls)
  
  seq = ['[MASK]' for i in range(len(template_sylls) - 1)]
  seq.append(w1)
#   fill_from_back(seq, template_sylls)
  
  masks = [i for i in range(len(seq)) if seq[i] == '[MASK]']
  as_str = ' '.join(seq)
  print(as_str)
  pred_seq, masked_tokens = feed_bert(as_str, masks, template_sylls)
  score = score_sentence_bert(' '.join(pred_seq))
  
  return pred_seq, score

In [0]:
gen_line('amount', num_sylls=7)

(1, 2, 2, 2)
[MASK] [MASK] [MASK] amount
2
1
0
['[CLS]', 'is', 'again', 'also', 'amount']
[-15.141115188598633, -14.059334754943848, -12.64249324798584, -13.567127227783203]


(['is', 'again', 'also', 'amount'], -11.082014083862305)

# GPT Experiments

In [0]:
from pytorch_pretrained_bert import GPT2Tokenizer, GPT2LMHeadModel

tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
model = GPT2LMHeadModel.from_pretrained('gpt2')
model.eval()

100%|██████████| 1042301/1042301 [00:00<00:00, 12407131.58B/s]
100%|██████████| 456318/456318 [00:00<00:00, 8036751.99B/s]
100%|██████████| 548118077/548118077 [00:11<00:00, 46582573.99B/s]
100%|██████████| 176/176 [00:00<00:00, 31333.99B/s]


GPT2LMHeadModel(
  (transformer): GPT2Model(
    (wte): Embedding(50257, 768)
    (wpe): Embedding(1024, 768)
    (h): ModuleList(
      (0): Block(
        (ln_1): BertLayerNorm()
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
        )
        (ln_2): BertLayerNorm()
        (mlp): MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
        )
      )
      (1): Block(
        (ln_1): BertLayerNorm()
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
        )
        (ln_2): BertLayerNorm()
        (mlp): MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
        )
      )
      (2): Block(
        (ln_1): BertLayerNorm()
        (attn): Attention(
          (c_attn): Conv1D()
          (c_proj): Conv1D()
        )
        (ln_2): BertLayerNorm()
        (mlp): MLP(
          (c_fc): Conv1D()
          (c_proj): Conv1D()
        )
      )
      (3): Block(
        (ln_1): BertLayerNorm()
      

In [0]:
def feed_gpt(text):
  indexed_tokens = tokenizer.encode(text)
  tokens_tensor = torch.tensor([indexed_tokens])
  
  tokens_tensor = tokens_tensor.to('cuda')
  model.to('cuda')
  
  with torch.no_grad():
    predictions, past = model(tokens_tensor)
    
  predicted_index = torch.argmax(predictions[0, -1, :]).item()
  predicted_token = tokenizer.decode([predicted_index])
  
  return predicted_token

text = 'Who was Jim Henson? Why was he'
feed_gpt(text)

' so'