# Beam search with rhythmic constraints

In [None]:
!pip install transformers

In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
import torch
from transformers import BeamSearchScorer, LogitsProcessorList, MaxLengthCriteria, StoppingCriteriaList
import string
import random
import pandas as pd
import pickle
import re

import pyphen
# German (de_DE) hyphenation dictionary; used below to split words into syllables.
hyp_dic = pyphen.Pyphen(lang="de_DE")


In [2]:
model_name = "Anjoe/german-poetry-gpt2-large"

# Tokenizer and causal-LM weights are loaded from the same checkpoint.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForCausalLM.from_pretrained(model_name),
)



In [3]:
def create_word_start_mask(tokenizer):
    """Return a 0/1 mask over the vocabulary marking word-start tokens.

    A token counts as a word start when its decoded text begins with a
    space character.

    Args:
        tokenizer: object supporting ``len(tokenizer)`` and
            ``tokenizer.decode(token_id)``.

    Returns:
        Float ``np.ndarray`` of length ``len(tokenizer)`` with 1 at
        word-start token ids and 0 everywhere else.
    """
    word_start_mask = np.zeros(len(tokenizer))
    for token_id in range(len(tokenizer)):
        # startswith() instead of [0]: a token decoding to '' must not raise.
        if tokenizer.decode(token_id).startswith(' '):
            word_start_mask[token_id] = 1
    return word_start_mask
        
def perplexity(text):
    """Compute the perplexity of ``text`` under the global ``model``.

    The text is scored with a sliding window of at most
    ``model.config.n_positions`` tokens advanced by ``stride`` tokens; in
    each window only the tokens not yet scored contribute to the loss.

    Args:
        text: the string to score (tokenized with the global ``tokenizer``).

    Returns:
        0-d numpy array holding the perplexity.

    Raises:
        ValueError: if ``text`` tokenizes to zero tokens, for which the
            perplexity is undefined.
    """
    device = model.device
    all_input_ids = tokenizer(text, return_tensors="pt").input_ids
    seq_len = all_input_ids.size(1)
    if seq_len == 0:
        raise ValueError("cannot compute the perplexity of an empty text")

    max_length = model.config.n_positions
    stride = 512

    nlls = []
    for i in range(0, seq_len, stride):
        begin_loc = max(i + stride - max_length, 0)
        end_loc = min(i + stride, seq_len)
        trg_len = end_loc - i  # may be different from stride on last loop
        input_ids = all_input_ids[:, begin_loc:end_loc].to(device)
        target_ids = input_ids.clone()
        # Label -100 is excluded from the loss: only score the trg_len new tokens.
        target_ids[:, :-trg_len] = -100

        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)
            # outputs[0] is the loss returned when labels are given; assumed to
            # be a mean over target tokens, hence rescaled to a sum here.
            neg_log_likelihood = outputs[0] * trg_len

        nlls.append(neg_log_likelihood)

    # end_loc == seq_len after the last window.
    return torch.exp(torch.stack(nlls).sum() / end_loc).cpu().detach().numpy()

In [4]:
################################################################################
# Original: https://huggingface.co/transformers/v4.6.0/_modules/transformers/generation_logits_process.html     
# Modified to work on the word level: only word-start tokens are compared,
# so e.g. "das Denkende" and "das Denken" count as the same n-gram.
################################################################################

def _get_word_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
    """Build, per hypothesis, a map from each (n-1)-gram prefix to its successors.

    Args:
        ngram_size: the ``n`` of the n-grams.
        prev_input_ids: per-hypothesis sequences supporting indexing and
            slicing (the caller passes lists of word-start token ids).
        num_hypos: number of hypotheses to process.

    Returns:
        A list of ``num_hypos`` dicts ``{prefix_tuple: [next_id, ...]}``;
        successors appear in order of occurrence.
    """
    tables = []
    for hypo in range(num_hypos):
        seq = prev_input_ids[hypo]
        table = {}
        for window in zip(*(seq[shift:] for shift in range(ngram_size))):
            table.setdefault(tuple(window[:-1]), []).append(window[-1])
        tables.append(table)
    return tables


def _get_generated_word_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
    """Return the ids that would complete an already generated word n-gram.

    The lookup key is the last ``ngram_size - 1`` entries of
    ``prev_input_ids``. ``cur_len`` is accepted for signature compatibility
    but ignored: the length of the (word-filtered) list is used instead.
    """
    length = len(prev_input_ids)
    prefix = tuple(prev_input_ids[length + 1 - ngram_size:length])
    return banned_ngrams.get(prefix, [])


def _calc_banned_word_ngram_tokens(ngram_size: int, prev_input_ids: torch.Tensor,
                                   num_hypos: int, cur_len: int, word_start_mask: np.ndarray):
    """Ban word-start tokens that would repeat an already generated word n-gram.

    Adapted from the transformers/fairseq ``no_repeat_ngram`` helper: each
    hypothesis is first reduced to its word-start tokens
    (``word_start_mask[id] == 1``), so n-grams are counted over words
    rather than over sub-word tokens.

    Args:
        ngram_size: size of the word n-gram that must not repeat.
        prev_input_ids: ids generated so far, one row per hypothesis (each
            row must support ``.tolist()``).
        num_hypos: number of hypotheses (rows) to process.
        cur_len: current token length of the hypotheses.
        word_start_mask: array indexed by token id; 1 marks a word start.

    Returns:
        For each hypothesis, the list of token ids to ban at the next step.
    """
    # Guard first so the early exit skips the per-token filtering below.
    # It still uses the token length (not the word count), as before.
    if cur_len + 1 < ngram_size:
        return [[] for _ in range(num_hypos)]

    word_ids = [[item for item in row.tolist() if word_start_mask[item] == 1]
                for row in prev_input_ids]

    generated_ngrams = _get_word_ngrams(ngram_size, word_ids, num_hypos)
    return [
        _get_generated_word_ngrams(generated_ngrams[hypo_idx], word_ids[hypo_idx], ngram_size, cur_len)
        for hypo_idx in range(num_hypos)
    ]


In [5]:
################################################################################
# Original: https://huggingface.co/transformers/v4.6.0/_modules/transformers/generation_logits_process.html     
################################################################################

def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
    """Map each token (n-1)-gram prefix to the tokens that followed it, per hypothesis.

    Args:
        ngram_size: the ``n`` of the n-grams.
        prev_input_ids: one row of token ids per hypothesis (rows must
            support ``.tolist()``).
        num_hypos: number of hypotheses to process.

    Returns:
        A list of ``num_hypos`` dicts ``{prefix_tuple: [next_id, ...]}``.
    """
    tables = []
    for hypo in range(num_hypos):
        seq = prev_input_ids[hypo].tolist()
        shifted = [seq[shift:] for shift in range(ngram_size)]
        table = {}
        for gram in zip(*shifted):
            table.setdefault(tuple(gram[:-1]), []).append(gram[-1])
        tables.append(table)
    return tables


def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
    """Return the ids that would complete an n-gram already present in the hypothesis.

    The lookup key is the ``ngram_size - 1`` tokens ending just before
    position ``cur_len`` of ``prev_input_ids`` (a row supporting ``.tolist()``).
    """
    window = prev_input_ids[cur_len + 1 - ngram_size:cur_len]
    key = tuple(window.tolist())
    return banned_ngrams.get(key, [])


def _calc_banned_ngram_tokens(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int):
    """Return, per hypothesis, the token ids that would repeat an existing token n-gram.

    Adapted from the transformers/fairseq ``no_repeat_ngram`` helper.

    Args:
        ngram_size: size of the token n-gram that must not repeat.
        prev_input_ids: ids generated so far, one row per hypothesis.
        num_hypos: number of hypotheses (rows).
        cur_len: current token length of the hypotheses.
    """
    # Too few tokens to complete even one n-gram: nothing can be banned yet.
    if cur_len + 1 < ngram_size:
        return [[] for _ in range(num_hypos)]

    ngram_tables = _get_ngrams(ngram_size, prev_input_ids, num_hypos)
    banned = []
    for hypo_idx, table in enumerate(ngram_tables):
        banned.append(_get_generated_ngrams(table, prev_input_ids[hypo_idx], ngram_size, cur_len))
    return banned

In [6]:
def find_word_beginning(token_ids):
    """Locate the token that starts the last word of ``token_ids``.

    Scans backwards from the last token, decoding each one with the global
    ``tokenizer``, and stops at the first token whose text starts with a
    space and is not pure whitespace. ``token_ids[0]`` is never inspected
    (NOTE(review): presumably to skip a leading special token -- confirm).

    Args:
        token_ids: sequence of token ids, indexable from the end.

    Returns:
        ``(word, j)``: the stripped text of the word-start token and its
        distance from the end (1 = last token), or ``(False, 0)`` if no
        such token is found.
    """
    for j in range(1, len(token_ids)):
        possible_beginning = tokenizer.decode(token_ids[-j])
        # startswith() also copes with tokens that decode to ''.
        if possible_beginning.startswith(' ') and possible_beginning.strip():
            return possible_beginning.strip(), j
    return False, 0
    
 

## Beam search

- Tokens that cannot lead to a correct meter (according to a stress dictionary) are blocked.
- Tokens that would repeat an already generated word or token n-gram are blocked.

In [12]:
from transformers import LogitsProcessor
import numpy as np

class verse_logits(LogitsProcessor):
    """Logits processor that keeps beam search on a metrical (rhythmic) pattern.

    For every beam it bans vocabulary tokens that, according to pickled
    stress lookup tables, cannot continue the verse with the expected
    stressed/unstressed syllable. It also bans repeated word n-grams and
    token n-grams, and can optionally multiply the scores by random noise.

    Args:
        tokenizer: tokenizer of the model being decoded.
        first_stress: stress (1 = stressed, 0 = unstressed) expected for the
            next syllable when the number of completed syllables is a
            multiple of ``len_metrum``; otherwise ``(1 - first_stress) ** 2``.
        offset: index of the first token decoded when counting syllables
            (the notebook passes the prompt length).
        ngram_size_words: size of word n-grams that may not repeat.
        ngram_size_tokens: size of token n-grams that may not repeat.
        max_word_len, len_rand: stored but not used by this class.
        randomize: if True, multiply scores by ``1 + N(0, 1) / random``.
        random: noise divisor; larger values mean less noise.
        len_metrum: number of syllables per metrical foot.
        len_verse: used for the newline rule: a newline is allowed once at
            least ``len_verse - 5`` syllables are complete and the last
            word's start stress matches the expected stress.

    The pickles ``notstressed``, ``stressed``, ``notstressed_start``,
    ``stressed_start`` and the table ``word_rythm.csv`` are read from the
    current working directory.
    """

    # Any run of non-word characters (or underscores) separates words.
    _NON_WORD = re.compile(r'[\W_]+')

    def __init__(self, tokenizer,
                 first_stress=1,
                 offset=1,
                 ngram_size_words=2,
                 ngram_size_tokens=4,
                 max_word_len=4,
                 len_rand=True,
                 randomize=True,
                 random=4,
                 len_metrum=2,
                 len_verse=10):
        self.tokenizer = tokenizer
        self.ngram_size_words = ngram_size_words
        self.ngram_size_tokens = ngram_size_tokens
        self.word_start_mask = create_word_start_mask(tokenizer)
        self.max_word_len = max_word_len
        self.len_rand = len_rand
        self.randomize = randomize
        self.random = random
        self.first_stress = first_stress
        self.len_metrum = len_metrum
        self.delim = tokenizer.encode('.') + tokenizer.encode(',')
        self.new_line = tokenizer.encode('\n')[0]
        self.len_verse = len_verse
        self.start = offset

        # NOTE: pickle.load can execute arbitrary code -- only load trusted files.
        with open(r'notstressed', 'rb') as f:
            self.lst_0 = pickle.load(f)
        with open(r'stressed', 'rb') as f:
            self.lst_1 = pickle.load(f)
        with open(r'notstressed_start', 'rb') as f:
            self.lst_0_s = pickle.load(f)
        with open(r'stressed_start', 'rb') as f:
            self.lst_1_s = pickle.load(f)
        self.rythm_df = pd.read_csv('word_rythm.csv')

    def _expected_stress(self, num_syll):
        """Return the stress (0/1) the next syllable should carry after ``num_syll`` syllables."""
        if num_syll % self.len_metrum == 0:
            return self.first_stress
        return (1 - self.first_stress) ** 2

    def _word_stress(self, word):
        """Return ``(start, end)`` stress of ``word`` from ``word_rythm.csv``; ``(5, 5)`` if unknown."""
        matches = self.rythm_df.loc[self.rythm_df['word'] == word]
        if matches.empty:
            return 5, 5
        return matches['start'].values[0], matches['end'].values[0]

    def _rhythm_banned_tokens(self, beam_input_ids):
        """Return the token ids that would break the metre for one beam."""
        _, word_len = find_word_beginning(beam_input_ids)
        text = self.tokenizer.decode(beam_input_ids[self.start:])  # start behind end of text token

        all_words = self._NON_WORD.sub(' ', text).split()
        # The last word may still be unfinished, so only the words before it
        # contribute to the syllable count.
        num_syll = sum(len(hyp_dic.inserted(word, ' ').split()) for word in all_words[:-1])

        next_stress = self._expected_stress(num_syll)
        target_dict = self.lst_1 if next_stress == 1 else self.lst_0

        if all_words:
            word_start_stress, word_end_stress = self._word_stress(all_words[-1])
        else:
            word_start_stress, word_end_stress = 5, 5

        # Token ids of the current (possibly unfinished) word; each candidate
        # id is appended to form the lookup key. .tolist() yields plain ints,
        # so the key is formatted like '[123, 456]' on every numpy version
        # (str() of numpy-2 scalars would give 'np.int64(123)' instead).
        # NOTE(review): word_len == 0 makes [-0:] keep the whole list -- confirm intended.
        word_prefix = beam_input_ids[self.start:].tolist()[-word_len:]

        # These overrides depend only on the last word, not on the candidate.
        if word_end_stress == 0 and word_start_stress == next_stress:
            start_table = self.lst_1_s
        elif word_end_stress == 1 and word_start_stress == next_stress:
            start_table = self.lst_0_s
        else:
            start_table = None
        allow_newline = num_syll >= self.len_verse - 5 and word_start_stress == next_stress

        banned = []
        for token_id in range(len(self.tokenizer)):
            # Always looked up (as before), even when overridden below.
            # NOTE(review): if target_dict is a defaultdict, this inserts keys.
            probability = target_dict[str(word_prefix + [token_id])]
            if start_table is not None:
                probability = start_table[token_id]
            if allow_newline and token_id == self.new_line:
                probability = 1
            # NOTE(review): ids below 33 are always banned -- presumably
            # punctuation/symbol tokens of this vocabulary; confirm.
            if probability == 0 or token_id < 33:
                banned.append(token_id)
        return banned

    def __call__(self, input_ids, scores):
        """Return ``scores`` (modified in place) with disallowed tokens set to ``-inf``.

        Args:
            input_ids: token ids generated so far, one row per beam.
            scores: next-token scores, one row per beam.
        """
        banned_tokens = [self._rhythm_banned_tokens(beam_input_ids)
                         for beam_input_ids, _ in zip(input_ids, scores)]

        num_batch_hypotheses = scores.shape[0]
        cur_len = input_ids.shape[-1]
        banned_word_tokens = _calc_banned_word_ngram_tokens(self.ngram_size_words,
                                                            input_ids,
                                                            num_batch_hypotheses,
                                                            cur_len,
                                                            self.word_start_mask)
        banned_token_tokens = _calc_banned_ngram_tokens(self.ngram_size_tokens,
                                                        input_ids,
                                                        num_batch_hypotheses,
                                                        cur_len)

        if self.randomize:
            # Applied BEFORE banning: the factor 1 + N(0,1)/random can be <= 0,
            # which would turn a banned -inf into +inf (or NaN). randn_like
            # also keeps the noise on the scores' device and dtype.
            scores *= torch.randn_like(scores) / self.random + 1

        for i, rhythm_banned in enumerate(banned_tokens):
            scores[i, rhythm_banned + banned_word_tokens[i] + banned_token_tokens[i]] = -float("inf")

        return scores

In [10]:
def create_verse(prompt, tokenizer,
                 offset=1,
                 max_length=24,
                 num_beams=15,
                 num_return_beams=14,
                 ngram_size_words=2,
                 ngram_size_tokens=4,
                 max_word_len=4,
                 len_rand=False,
                 randomize=False,
                 random=4,
                 first_stress=1,
                 len_metrum=2,
                 len_verse=10,
                 ):
    """Continue ``prompt`` with rhythmically constrained beam search.

    Args:
        prompt: text (or batch of texts) to continue.
        tokenizer: tokenizer matching the global ``model``.
        offset: index of the first token the rhythm constraint decodes when
            counting syllables (forwarded to ``verse_logits``).
        max_length: total token length, prompt included, at which the
            search stops.
        num_beams: number of beams searched per prompt.
        num_return_beams: number of finished hypotheses returned per prompt;
            clamped to ``num_beams`` with a printed warning.
        ngram_size_words, ngram_size_tokens, max_word_len, len_rand,
        randomize, random, first_stress, len_metrum, len_verse:
            forwarded unchanged to ``verse_logits``.

    Returns:
        The result of ``model.beam_search``; the notebook iterates it row by
        row as token-id sequences.
    """
    if num_beams < num_return_beams:
        print('warning: setting number of return beams equal to number of beams')
        num_return_beams = num_beams

    # Inputs must live on the model's device for beam_search.
    prompt_ids = tokenizer(prompt, return_tensors='pt')['input_ids'].to(model.device)

    beam_scorer = BeamSearchScorer(
        batch_size=prompt_ids.shape[0],
        num_beams=num_beams,
        num_beam_hyps_to_keep=num_return_beams,
        device=model.device,
    )

    logits_processor = LogitsProcessorList([verse_logits(tokenizer,
                                                         first_stress=first_stress,
                                                         offset=offset,
                                                         ngram_size_words=ngram_size_words,
                                                         ngram_size_tokens=ngram_size_tokens,
                                                         max_word_len=max_word_len,
                                                         len_rand=len_rand,
                                                         randomize=randomize,
                                                         random=random,
                                                         len_metrum=len_metrum,
                                                         len_verse=len_verse)])

    # Expand each prompt into num_beams adjacent rows (the layout transformers'
    # generate() builds with repeat_interleave). For a single prompt this is
    # identical to concatenating num_beams copies.
    generated = model.beam_search(
        prompt_ids.repeat_interleave(num_beams, dim=0),
        beam_scorer,
        logits_processor=logits_processor,
        stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)]),
    )
    return generated



In [14]:
prompt = '''Nur durch das Morgentor des Schönen
'''

# Generate four verses; the best beam of each round is appended to the prompt.
for _ in range(4):
    offset = len(tokenizer.encode(prompt))
    len_0 = len(prompt.split('\n'))
    generated = create_verse(prompt, tokenizer,
                             offset=offset,
                             max_length=offset + 12,  # total token length at which search stops (12 new tokens)
                             num_beams=14,            # number of beams searched
                             num_return_beams=5,      # number of finished hypotheses returned and printed
                             ngram_size_words=2,      # word n-grams of this size may not repeat
                             ngram_size_tokens=4,     # token n-grams of this size may not repeat
                             max_word_len=4,          # forwarded to verse_logits (stored, not used there)
                             len_rand=False,          # forwarded to verse_logits (stored, not used there)
                             randomize=True,          # multiply scores by random noise
                             random=6)                # noise divisor: larger means less random

    for index, output_tokenized in enumerate(generated):
        output = tokenizer.decode(output_tokenized, skip_special_tokens=True)
        print(f'beam {index}: {output}')

    # The prompt ends with '\n', so line len_0 - 1 of the best beam is the new verse.
    out = tokenizer.decode(generated[0], skip_special_tokens=True).split('\n')[len_0 - 1]
    prompt += out + '\n'

beam 0: Nur durch das Morgentor des Schönen
Winkt ein feuervoller Blick hinaus 


beam 1: Nur durch das Morgentor des Schönen
Winkt ein feuervoller Blick herab 


beam 2: Nur durch das Morgentor des Schönen
Winkt ein feuervoller Blick hinaus 
 halbe
beam 3: Nur durch das Morgentor des Schönen
Winkt ein feuervoller Blick herab 
 und
beam 4: Nur durch das Morgentor des Schönen
Winkt ein feuervoller Blick herab 
 Lim
beam 0: Nur durch das Morgentor des Schönen
Winkt ein feuervoller Blick hinaus 
Heiter auferstanden ist dereinst 


beam 1: Nur durch das Morgentor des Schönen
Winkt ein feuervoller Blick hinaus 
Heiter auferstanden sind im dunkeln Todtenreich
beam 2: Nur durch das Morgentor des Schönen
Winkt ein feuervoller Blick hinaus 
Heiter auferstanden sind im dunkeln Todtenb
beam 3: Nur durch das Morgentor des Schönen
Winkt ein feuervoller Blick hinaus 
Heiter auferstanden sind im dunkeln Todtenkranz
beam 4: Nur durch das Morgentor des Schönen
Winkt ein feuervoller Blick hinaus 
Heiter