<a href="https://colab.research.google.com/github/andjoer/llm_poetry_generation/blob/main/colabs/LLM_alliterations.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Creating alliteration patterns with large language models and beam search

In [None]:
!pip install transformers

In [None]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
import torch
from transformers import BeamSearchScorer, LogitsProcessorList, MaxLengthCriteria, StoppingCriteriaList
import string
import random

In [None]:
model_name = "Anjoe/german-poetry-gpt2-large"

# The tokenizer and the causal language model are loaded from the same Hub repository id.
tokenizer, model = (
    AutoTokenizer.from_pretrained(model_name),
    AutoModelForCausalLM.from_pretrained(model_name),
)



In [None]:
def create_word_start_mask(tokenizer):
    """Mark which vocabulary entries start a new word.

    Args:
        tokenizer: object supporting ``len()`` and ``decode(token_id)``.

    Returns:
        Float numpy array of length ``len(tokenizer)``. It holds 1.0 where the decoded
        token begins with a space, which this notebook treats as a word start, and 0.0
        elsewhere.
    """
    word_start_mask = np.zeros(len(tokenizer))
    for token_id in range(len(tokenizer)):
        # startswith() instead of [0]: ids that decode to "" would otherwise raise IndexError.
        if tokenizer.decode(token_id).startswith(' '):
            word_start_mask[token_id] = 1
    return word_start_mask
        
def perplexity(text, lm=None, tok=None, stride=512):
    """Compute the sliding-window perplexity of ``text`` under a causal language model.

    The text is scored in windows of at most ``lm.config.n_positions`` tokens that advance
    by ``stride`` tokens. Within each window only the tokens not yet scored contribute to
    the loss.

    Args:
        text: the string to score.
        lm: model to score with; defaults to the notebook-global ``model``.
        tok: tokenizer to encode ``text``; defaults to the notebook-global ``tokenizer``.
        stride: number of tokens the window advances per step.

    Returns:
        A 0-d numpy array holding ``exp(total NLL / number of tokens)``.

    Raises:
        ValueError: if ``text`` tokenizes to zero tokens.
    """
    lm = model if lm is None else lm
    tok = tokenizer if tok is None else tok

    device = lm.device
    all_ids = tok(text, return_tensors="pt").input_ids
    seq_len = all_ids.size(1)
    if seq_len == 0:
        # Previously this surfaced as an UnboundLocalError on `end_loc`.
        raise ValueError("perplexity() needs text that tokenizes to at least one token")

    max_length = lm.config.n_positions

    nlls = []
    end_loc = 0
    for i in range(0, seq_len, stride):
        begin_loc = max(i + stride - max_length, 0)
        end_loc = min(i + stride, seq_len)
        trg_len = end_loc - i  # may be different from stride on last loop
        input_ids = all_ids[:, begin_loc:end_loc].to(device)
        target_ids = input_ids.clone()
        # Context tokens already scored in an earlier window are masked out;
        # -100 is the ignore label for Hugging Face causal-LM losses.
        target_ids[:, :-trg_len] = -100

        with torch.no_grad():
            outputs = lm(input_ids, labels=target_ids)
            # outputs[0] is the mean loss; scale back to a sum over the scored tokens.
            nlls.append(outputs[0] * trg_len)

    return torch.exp(torch.stack(nlls).sum() / end_loc).cpu().detach().numpy()

In [None]:
################################################################################
# Original: https://huggingface.co/transformers/v4.6.0/_modules/transformers/generation_logits_process.html     
# Modified to work on the word level: only word-initial tokens are kept before the n-grams
# are built, so e.g. "das Denkende" and "das Denken" count as the same n-gram.
################################################################################

def _get_word_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx]
        generated_ngram = generated_ngrams[idx]
        for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):

            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]

    return generated_ngrams


def _get_generated_word_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):

    cur_len = len(prev_input_ids)
    start_idx = cur_len + 1 - ngram_size

    ngram_idx = tuple(prev_input_ids[start_idx:cur_len])

    return banned_ngrams.get(ngram_idx, [])


def _calc_banned_word_ngram_tokens(ngram_size: int, prev_input_ids: torch.Tensor, 
                                   num_hypos: int, cur_len: int, word_start_mask: np.array):
    """Copied from fairseq for no_repeat_ngram in beam_search"""
    prev_input_ids = [[item for item in prev_inputs.tolist() if word_start_mask[item]==1] 
                      for prev_inputs in prev_input_ids] # MODIFICATION
    if cur_len + 1 < ngram_size:
        return [[] for _ in range(num_hypos)]

    generated_ngrams = _get_word_ngrams(ngram_size, prev_input_ids, num_hypos)


    banned_tokens = [
        _get_generated_word_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len)
        for hypo_idx in range(num_hypos)
    ]

    return banned_tokens


In [None]:
################################################################################
# Original: https://huggingface.co/transformers/v4.6.0/_modules/transformers/generation_logits_process.html     
################################################################################

def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].tolist()
        generated_ngram = generated_ngrams[idx]
        for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
    return generated_ngrams


def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
    # Before decoding the next token, prevent decoding of ngrams that have already appeared
    start_idx = cur_len + 1 - ngram_size
    ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
    return banned_ngrams.get(ngram_idx, [])


def _calc_banned_ngram_tokens(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int):
    """Copied from fairseq for no_repeat_ngram in beam_search"""
    if cur_len + 1 < ngram_size:
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        return [[] for _ in range(num_hypos)]

    generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)

    banned_tokens = [
        _get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len)
        for hypo_idx in range(num_hypos)
    ]
    return banned_tokens

In [None]:
def find_word_beginning(token_ids, tok=None):
    """Locate the start of the word that the last token of ``token_ids`` belongs to.

    Walks backwards from the last token, looking for a token whose decoded text starts
    with a space and is not whitespace-only. Such a token is the word-initial token.

    Args:
        token_ids: non-empty sequence (list or 1-D tensor) of token ids.
        tok: tokenizer used for decoding; defaults to the notebook-global ``tokenizer``.

    Returns:
        ``(word_start, n_tokens)``: the stripped text of the word-initial token and the
        number of tokens from it to the end of ``token_ids`` (inclusive). The first token
        is never tested. If no word start is found among the others, the first token is
        treated as the word start and ``n_tokens == len(token_ids)``.
    """
    tok = tokenizer if tok is None else tok
    for j in range(1, len(token_ids)):
        candidate = tok.decode(token_ids[-j])
        # startswith() instead of candidate[0]: ids decoding to "" would raise IndexError.
        if candidate.startswith(' ') and candidate.strip():
            return candidate.strip(), j
    return tok.decode(token_ids[0]).strip(), len(token_ids)

In [None]:
from transformers import LogitsProcessor
import numpy as np

class alit_logits(LogitsProcessor):
    """Logits processor that steers beam search towards alliterations.

    On every call, for each beam, it sets the score of these tokens to ``-inf``:

    * tokens that would break the letter pattern encoded in ``block_token_dict``;
    * word-continuation tokens, once the current word already has the maximum length;
    * word-start tokens that would repeat a word n-gram (size ``ngram_size_words``);
    * tokens that would repeat a token n-gram (size ``ngram_size_tokens``).

    Args:
        tokenizer: tokenizer used to decode generated ids.
        block_token_dict: output of ``create_block_token_dict``. It maps the letter the
            previous word (or token) started with to lists of banned ids. It is only
            read, never modified.
        word_beginnings: if True the pattern applies to word-initial letters,
            otherwise to the first letter of every token.
        ngram_size_words: word n-gram size for the no-repeat check.
        ngram_size_tokens: token n-gram size for the no-repeat check.
        max_word_len: maximum number of tokens per word. With ``len_rand=False``, a
            falsy value disables the limit.
        len_rand: if True, draw the limit uniformly from ``1..max_word_len`` for every
            beam at every step.
    """

    def __init__(self, tokenizer,block_token_dict, 
              word_beginnings = False,
              ngram_size_words = 2,
              ngram_size_tokens = 4,
              max_word_len = 4,
              len_rand = True):
        self.tokenizer = tokenizer
        self.block_token_dict = block_token_dict
        self.word_beginnings = word_beginnings
        self.ngram_size_words = ngram_size_words
        self.ngram_size_tokens = ngram_size_tokens
        self.word_start_mask = create_word_start_mask(tokenizer)
        # Ids of tokens that continue a word. Computed once here instead of per beam per step.
        self.non_word_start_ids = np.where(self.word_start_mask == 0)[0].tolist()
        self.max_word_len = max_word_len
        self.len_rand = len_rand

    def _letter_bans(self, beam_input_ids):
        """Return ``(banned_ids, word_len)`` for one beam.

        ``banned_ids`` is a fresh list that the caller may extend freely.
        """
        word_len = 0
        if len(beam_input_ids) > 0:
            if self.word_beginnings:
                last_word, word_len = find_word_beginning(beam_input_ids)
                last_letter = last_word[:1]  # [:1] tolerates an empty word
            else:
                last_letter = self.tokenizer.decode(beam_input_ids[-1])[:1]
                _, word_len = find_word_beginning(beam_input_ids)
        else:
            last_letter = next(iter(self.block_token_dict))

        last_letter = last_letter.lower()
        if last_letter in self.block_token_dict:
            source = self.block_token_dict[last_letter][0]
        else:
            # Unknown previous letter: force the first letter of the pattern.
            first_entry = next(iter(self.block_token_dict.values()))
            source = first_entry[1] if self.word_beginnings else first_entry[0]
        # Copy: extending the dict's own list would permanently grow the banned set.
        return list(source), word_len

    def _word_length_limit(self):
        """Return the maximum word length in tokens for the current beam and step."""
        if self.len_rand:
            return random.randint(1, self.max_word_len)
        return self.max_word_len

    def __call__(self, input_ids, scores):
        banned_tokens = []
        for beam_input_ids in input_ids:
            banned, word_len = self._letter_bans(beam_input_ids)
            max_word_len = self._word_length_limit()
            if max_word_len and word_len >= max_word_len:
                # The current word is long enough: only a new word may start.
                banned += self.non_word_start_ids
            banned_tokens.append(banned)

        num_batch_hypotheses = scores.shape[0]
        cur_len = input_ids.shape[-1]
        banned_word_tokens = _calc_banned_word_ngram_tokens(self.ngram_size_words,
                                                            input_ids,
                                                            num_batch_hypotheses,
                                                            cur_len,
                                                            self.word_start_mask)
        banned_token_tokens = _calc_banned_ngram_tokens(self.ngram_size_tokens,
                                                        input_ids,
                                                        num_batch_hypotheses,
                                                        cur_len)

        for i, banned in enumerate(banned_tokens):
            all_banned = banned + banned_word_tokens[i] + banned_token_tokens[i]
            scores[i, all_banned] = -float("inf")

        return scores

In [None]:
def check_beginnings(letter, decoded, keep_tokens=('.',), word_beginnings=True):
    """Decide whether a decoded token must be blocked when the next letter is ``letter``.

    A token is blocked when it is not listed in ``keep_tokens`` (compared after
    stripping) and one of these holds:

    * it begins a word and does not start with ``letter`` (case-insensitive);
    * its stripped text is not purely alphabetic (punctuation, digits, empty, ...).

    Args:
        letter: lower-case letter the next word/token must start with.
        decoded: decoded text of the token.
        keep_tokens: stripped token texts that are never blocked.
        word_beginnings: if True, a token "begins a word" only when it starts with a
            space. If False, every token counts as a word beginning.

    Returns:
        True if the token must be blocked, False otherwise.
    """
    # startswith() instead of decoded[0]: an empty decoded string would raise IndexError.
    is_word_start = decoded.startswith(' ') if word_beginnings else True

    stripped = decoded.strip()
    wrong_letter = not stripped or stripped[0].lower() != letter
    not_alpha = not stripped.isalpha()

    if stripped in keep_tokens:
        return False
    return (is_word_start and wrong_letter) or not_alpha

In [None]:
def create_block_token_dict(tokenizer,letters,keep_tokens,word_beginnings):
    """Precompute, for each letter transition, the token ids that must be blocked.

    ``letters`` is treated as a cycle. After a word (or token) starting with
    ``letters[i-1]``, the next one must start with ``letters[i]``. ``letters[-1]`` wraps
    around to ``letters[0]``.

    Args:
        tokenizer: tokenizer supporting ``len()`` and ``decode(token_id)``.
        letters: sequence of lower-case letters forming the pattern.
        keep_tokens: stripped token texts that are never blocked.
        word_beginnings: see ``check_beginnings``.

    Returns:
        Dict mapping the previous letter to a list of banned-id lists for the next letter:

        * ``[0]`` is built with the given ``word_beginnings``;
        * when ``word_beginnings`` is True, ``[1]`` is built with
          ``word_beginnings=False`` (every token must start with the letter).

        Both lists of one entry always refer to the same next letter, even when a
        letter appears more than once in ``letters``.
    """
    # Decode the vocabulary once instead of once per letter and again in a second pass.
    decoded_vocab = [tokenizer.decode(token_id) for token_id in range(len(tokenizer))]
    cache = {}

    def _blocked(letter, wb):
        # Ids whose decoded text check_beginnings rejects for ``letter``; memoised per (letter, wb).
        key = (letter, wb)
        if key not in cache:
            cache[key] = [token_id for token_id, decoded in enumerate(decoded_vocab)
                          if check_beginnings(letter, decoded,
                                              keep_tokens=keep_tokens, word_beginnings=wb)]
        return list(cache[key])

    block_token_dict = {}
    for i, letter in enumerate(letters):
        entry = [_blocked(letter, word_beginnings)]
        if word_beginnings:
            entry.append(_blocked(letter, False))
        block_token_dict[letters[i - 1]] = entry

    return block_token_dict
    

In [None]:
def create_aliterations(prompt, letters, tokenizer,
                       keep_tokens = None,
                       word_beginnings = False,
                       max_length = 24,
                       num_beams = 15,
                       num_return_beams = 14,
                       ngram_size_words = 2,
                       ngram_size_tokens = 4,
                       max_word_len = 4,
                       len_rand = False
                       ):
    """Generate alliterating continuations of ``prompt`` with constrained beam search.

    Uses the notebook-global ``model``.

    Args:
        prompt: text the generation starts from.
        letters: cycle of start letters. After a word (or token) starting with
            ``letters[i-1]``, the next one must start with ``letters[i]``.
        tokenizer: tokenizer matching ``model``.
        keep_tokens: stripped token texts that are never blocked; ``None`` means none.
        word_beginnings: apply the pattern to word-initial letters (True) or to every
            token (False).
        max_length: total sequence length in tokens, prompt included, at which
            generation stops.
        num_beams: number of beams.
        num_return_beams: number of hypotheses returned; capped at ``num_beams``.
        ngram_size_words, ngram_size_tokens, max_word_len, len_rand: forwarded to
            ``alit_logits``.

    Returns:
        The output of ``model.beam_search``. It is consumed in this notebook as one
        id sequence per returned hypothesis.
    """
    if keep_tokens is None:
        keep_tokens = []

    if num_beams < num_return_beams:
        print('warning: setting number of return beams equal to number of beams')
        num_return_beams = num_beams

    block_token_dict = create_block_token_dict(tokenizer,letters,keep_tokens,word_beginnings)

    # Put the prompt on the model's device so generation also works when the model is on a GPU.
    prompt_tokenized = tokenizer(prompt, return_tensors='pt')['input_ids'].to(model.device)

    beam_scorer = BeamSearchScorer(
        batch_size = prompt_tokenized.shape[0],
        num_beams = num_beams,
        num_beam_hyps_to_keep = num_return_beams,
        device=model.device
    )

    logits_processor = LogitsProcessorList([alit_logits(tokenizer,
                                                        block_token_dict,
                                                        word_beginnings = word_beginnings,
                                                        ngram_size_words = ngram_size_words,
                                                        ngram_size_tokens = ngram_size_tokens,
                                                        max_word_len = max_word_len,
                                                        len_rand = len_rand)])

    # NOTE(review): `model.beam_search` / `BeamSearchScorer` are deprecated public APIs in
    # newer transformers releases; keep the installed version compatible.
    generated = model.beam_search(
        torch.cat([prompt_tokenized] * num_beams),
        beam_scorer,
        logits_processor = logits_processor,
        stopping_criteria=StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
    )
    return generated



In [None]:
prompt = '<|endoftext|>'

# Words alternate between starting with 'd' and starting with 'a'.
letters = ['d','a']
generated = create_aliterations(prompt,letters,tokenizer,
                                word_beginnings = True, # only word-initial tokens must start with the pattern letter;
                                                        # if False, every token must start with it
                                max_length = 28,        # total length in tokens (prompt included) at which beam search stops
                                num_beams = 30,         # number of beams explored in parallel
                                num_return_beams = 18,  # number of finished hypotheses returned (capped at num_beams)
                                ngram_size_words = 2,   # size of the word n-grams that may not repeat
                                ngram_size_tokens = 4,  # size of the token n-grams that may not repeat
                                max_word_len = 4,       # maximum number of tokens a word may contain
                                len_rand = False)       # if True, draw a random maximum word length (1..max_word_len)
                                                        # for every beam at every step

for index, output_tokenized in enumerate(generated):
    output = tokenizer.decode(output_tokenized,skip_special_tokens = True)
    print(f'beam {index}: {output}')

beam 0:  dämmernd aus der Abendröte des Abends des achten Dorfes an dem alten Dorf Anger dort an den alten Dörfern Anger Dörfer Anger
beam 1:  dämmernd aus der Abendröte des Abends des achten Dorfes an dem alten Dorf Anger de Anger da Anger Der Anger dient als
beam 2:  dämmernd aus der Abendröte des Abends des achten Dorfes an dem alten Dorf Anger de Anger da Anger Der Anger diente als
beam 3:  dämmernd aus der Abendröte des Abends des achten Dorfes an dem alten Dorf Anger dort an den alten deutschen Adel deutscher Adels
beam 4:  dämmernd aus der Abendröte des Abends des achten Dorfes an dem alten Dorf Anger dort an den alten Dörfern Anger daran Anger
beam 5:  dämmernd aus der Abendröte des Abends des achten Dorfes an dem alten Dorf Anger dort an den Anger dieses alte deutsche Anger
beam 6:  dämmernd aus der Abendröte des Abends des achten Dorfes an dem alten Dorf Anger dort an den alten Dörfern Anger dran Anger
beam 7:  dämmernd aus der Abendröte des Abends des achten Dorfes an dem al

In [None]:
import os

# Idempotent replacement for `%mkdir logs`, which reports an error when the folder already exists.
os.makedirs('logs', exist_ok=True)

In [None]:
import glob
import re

# Continue numbering after the highest existing poem_<n>.log so earlier runs are not overwritten.
files = glob.glob("logs/*.log")
max_idx = 0
for file in files:
    # Only poem_<n>.log files count; other .log files (with or without digits) are ignored.
    match = re.search(r'poem_(\d+)\.log$', file)
    if match:
        max_idx = max(int(match.group(1)), max_idx)

start_idx = max_idx + 1

# Broader alternative letter set:
# ['a','b','d','e','f','g','h','i','j','k','l','m','n','o','s','t','u']

max_num_letters = 2
# Candidate letters for each position of the alliteration cycle. Every position gets the
# same set here; replace an entry to use different sets for different positions.
base_letters = ['a','d','e','i','o','s','u']
letter_selection = [list(base_letters) for _ in range(max_num_letters)]

prompt = '<|endoftext|>'

num_beams_range = [15,30]     # inclusive bounds for random.randint
num_return_beams = 15
max_length_range = [12,35]    # inclusive bounds for random.randint (tokens, prompt included)
max_perplexity = 400          # a generation only joins the final poem below this perplexity


for iteration in range(5):
    save_text = model_name + '\n'
    final_output = []
    for _ in range(3):
        num_letters = random.randint(1, max_num_letters)
        letters = [random.choice(letter_selection[i]) for i in range(num_letters)]

        num_beams = random.randint(*num_beams_range)
        max_length = random.randint(*max_length_range)

        generated = create_aliterations(prompt,letters,tokenizer,
                                word_beginnings = True,
                                max_length = max_length,
                                num_beams = num_beams,
                                num_return_beams = num_return_beams,
                                ngram_size_words = 2,
                                ngram_size_tokens = 4,
                                max_word_len = 4,
                                len_rand = False)

        perplexities = []
        output_tmp = []

        print('calc perp')
        for output_tokenized in generated:
            output = tokenizer.decode(output_tokenized, skip_special_tokens = True).strip()
            perpl = perplexity(output + '.')
            perplexities.append(perpl)
            save_text += output + ' //perplexity: ' + str(perpl) + '\n'
            output_tmp.append(output)

        # Skip NaN/inf perplexities: np.argmin would pick a NaN, and min() with NaN is order-dependent.
        perplexity_values = np.asarray(perplexities, dtype=float)
        finite = np.isfinite(perplexity_values)
        if finite.any():
            choice_idx = int(np.argmin(np.where(finite, perplexity_values, np.inf)))
            if perplexity_values[choice_idx] < max_perplexity:
                final_output.append(output_tmp[choice_idx])

    if final_output:
        save_text += '\n*** final output ***\n'

        for output in final_output:
            save_text += output + '\n\n'

        save_text += '***'

        # Explicit UTF-8: the generated German text contains umlauts and other non-ASCII characters.
        with open('logs/poem_' + str(iteration + start_idx) + '.log', 'w', encoding='utf-8') as f:
            f.write(save_text)
            