<a href="https://colab.research.google.com/github/andjoer/llm_poetry_generation/blob/main/colabs/LLM_alliterations.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Creating alliteration patterns with large language models and beam search

In [None]:
!pip install transformers

In [1]:
from transformers import AutoTokenizer, AutoModelForCausalLM
import numpy as np
import torch
from transformers import BeamSearchScorer, LogitsProcessorList, MaxLengthCriteria, StoppingCriteriaList
import string
import random
import pandas as pd
import pickle
import re

In [2]:
# Hugging Face Hub checkpoint used throughout the notebook (presumably a
# German-poetry fine-tune of GPT-2 large, judging by the id -- confirm on the Hub).
model_name = "Anjoe/german-poetry-gpt2-large"

# Tokenizer and weights are loaded from the same checkpoint so that token ids
# agree between them. Both become notebook globals used by later cells.
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)



In [3]:
def create_word_start_mask(tokenizer):
    """Return a 0/1 mask over the vocabulary marking word-initial tokens.

    A token counts as word-initial when its decoded text begins with a space.

    Args:
        tokenizer: object supporting ``len(tokenizer)`` and
            ``tokenizer.decode(token_id)``.

    Returns:
        Float ``np.ndarray`` of length ``len(tokenizer)``; entry ``i`` is 1.0
        when token ``i`` starts a word and 0.0 otherwise.
    """
    word_start_mask = np.zeros(len(tokenizer))
    for token_id in range(len(tokenizer)):
        # startswith() instead of [0]: a token that decodes to '' would
        # otherwise raise IndexError.
        if tokenizer.decode(token_id).startswith(' '):
            word_start_mask[token_id] = 1
    return word_start_mask
        
def perplexity(text, stride=512):
    """Return the perplexity of ``text`` under the notebook-global ``model``.

    Long inputs are scored with a sliding window. Each window spans at most
    ``model.config.n_positions`` tokens and advances by ``stride`` tokens.
    Only the tokens new to a window are scored; the earlier ones are masked
    with -100 and act as context.

    Args:
        text: string to score, tokenized with the global ``tokenizer``.
        stride: window step in tokens; must lie in ``1..n_positions``.

    Returns:
        0-d numpy array holding ``exp(total NLL / number of tokens)``.

    Raises:
        ValueError: if ``text`` produces no tokens or ``stride`` is out of range.
    """
    device = model.device
    all_input_ids = tokenizer(text, return_tensors="pt").input_ids
    seq_len = all_input_ids.size(1)
    max_length = model.config.n_positions

    # Without these guards an empty input leaves end_loc unbound (the loop
    # never runs), and a stride above max_length breaks the -100 masking.
    if seq_len == 0:
        raise ValueError("perplexity() needs text that produces at least one token")
    if not 0 < stride <= max_length:
        raise ValueError(f"stride must be in 1..{max_length}, got {stride}")

    nlls = []
    end_loc = 0
    for i in range(0, seq_len, stride):
        begin_loc = max(i + stride - max_length, 0)
        end_loc = min(i + stride, seq_len)
        trg_len = end_loc - i  # may be shorter than stride on the last window
        input_ids = all_input_ids[:, begin_loc:end_loc].to(device)
        target_ids = input_ids.clone()
        # Only the last trg_len tokens are scored; the rest is context.
        target_ids[:, :-trg_len] = -100

        with torch.no_grad():
            outputs = model(input_ids, labels=target_ids)
            # Assumes outputs[0] is the mean loss over the scored tokens
            # (HF causal-LM convention); scale it back to a sum.
            neg_log_likelihood = outputs[0] * trg_len

        nlls.append(neg_log_likelihood)

    return torch.exp(torch.stack(nlls).sum() / end_loc).cpu().detach().numpy()

In [4]:
################################################################################
# Original: https://huggingface.co/transformers/v4.6.0/_modules/transformers/generation_logits_process.html     
# Modified so that it works more on a word level. 
# Example "das Denkende" and "das Denken" are the same n-gram. 
################################################################################

def _get_word_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx]
        generated_ngram = generated_ngrams[idx]
        for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):

            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]

    return generated_ngrams


def _get_generated_word_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):

    cur_len = len(prev_input_ids)
    start_idx = cur_len + 1 - ngram_size

    ngram_idx = tuple(prev_input_ids[start_idx:cur_len])

    return banned_ngrams.get(ngram_idx, [])


def _calc_banned_word_ngram_tokens(ngram_size: int, prev_input_ids: torch.Tensor, 
                                   num_hypos: int, cur_len: int, word_start_mask: np.array):
    """Copied from fairseq for no_repeat_ngram in beam_search"""
    prev_input_ids = [[item for item in prev_inputs.tolist() if word_start_mask[item]==1] 
                      for prev_inputs in prev_input_ids] # MODIFICATION
    if cur_len + 1 < ngram_size:
        return [[] for _ in range(num_hypos)]

    generated_ngrams = _get_word_ngrams(ngram_size, prev_input_ids, num_hypos)


    banned_tokens = [
        _get_generated_word_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len)
        for hypo_idx in range(num_hypos)
    ]

    return banned_tokens


In [5]:
################################################################################
# Original: https://huggingface.co/transformers/v4.6.0/_modules/transformers/generation_logits_process.html     
################################################################################

def _get_ngrams(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int):
    generated_ngrams = [{} for _ in range(num_hypos)]
    for idx in range(num_hypos):
        gen_tokens = prev_input_ids[idx].tolist()
        generated_ngram = generated_ngrams[idx]
        for ngram in zip(*[gen_tokens[i:] for i in range(ngram_size)]):
            prev_ngram_tuple = tuple(ngram[:-1])
            generated_ngram[prev_ngram_tuple] = generated_ngram.get(prev_ngram_tuple, []) + [ngram[-1]]
    return generated_ngrams


def _get_generated_ngrams(banned_ngrams, prev_input_ids, ngram_size, cur_len):
    # Before decoding the next token, prevent decoding of ngrams that have already appeared
    start_idx = cur_len + 1 - ngram_size
    ngram_idx = tuple(prev_input_ids[start_idx:cur_len].tolist())
    return banned_ngrams.get(ngram_idx, [])


def _calc_banned_ngram_tokens(ngram_size: int, prev_input_ids: torch.Tensor, num_hypos: int, cur_len: int):
    """Copied from fairseq for no_repeat_ngram in beam_search"""
    if cur_len + 1 < ngram_size:
        # return no banned tokens if we haven't generated no_repeat_ngram_size tokens yet
        return [[] for _ in range(num_hypos)]

    generated_ngrams = _get_ngrams(ngram_size, prev_input_ids, num_hypos)

    banned_tokens = [
        _get_generated_ngrams(generated_ngrams[hypo_idx], prev_input_ids[hypo_idx], ngram_size, cur_len)
        for hypo_idx in range(num_hypos)
    ]
    return banned_tokens

In [6]:
def find_word_beginning(token_ids, tok=None):
    """Locate the most recent word-initial token in ``token_ids``.

    The search runs backwards from the last token. It stops at the first token
    whose decoded text starts with a space and is not only whitespace.

    Args:
        token_ids: indexable sequence of token ids (list or 1-D tensor).
        tok: tokenizer used for decoding. Defaults to the notebook-global
            ``tokenizer`` so existing calls keep working.

    Returns:
        ``(word, j)``: ``word`` is the stripped text of that token and
        ``token_ids[-j]`` is the token itself. Returns ``(False, 0)`` when no
        such token exists.
    """
    if tok is None:
        tok = tokenizer
    # range up to len + 1 so that token_ids[0] is inspected too (the old
    # range stopped one short).
    for j in range(1, len(token_ids) + 1):
        piece = tok.decode(token_ids[-j])
        # startswith() avoids IndexError for tokens decoding to ''.
        if piece.startswith(' ') and piece.strip():
            return piece.strip(), j
    return False, 0
    
 

In [7]:
from transformers import LogitsProcessor
import numpy as np

class alit_logits(LogitsProcessor):
    """Logits processor that constrains beams to a syllable stress pattern.

    For every beam at every step it:

    1. counts the syllables of the complete words generated so far, using
       ``word_rythm.csv``;
    2. selects the stressed or unstressed lookup tables for the next syllable,
       depending on ``first_stress`` and ``len_metrum``;
    3. bans every token ``i`` for which
       ``max(table[str(current_word_ids + [i])], start_table[i]) == 0``.
       '.' and ',' are exempt once at least one syllable exists, and so is the
       newline once more than ``len_verse - 4`` syllables exist;
    4. bans repeated word n-grams and token n-grams;
    5. optionally keeps a single random candidate from the top-p set.

    Required files in the working directory: ``notstressed``, ``stressed``,
    ``notstressed_start``, ``stressed_start`` (pickles) and ``word_rythm.csv``
    (with at least the columns ``word`` and ``num_syll``).

    ``word_beginnings``, ``max_word_len`` and ``len_rand`` are stored but are
    not read by ``__call__``.
    """

    def __init__(self, tokenizer,
                 first_stress = 0,
                 word_beginnings = False,
                 ngram_size_words = 2,
                 ngram_size_tokens = 4,
                 max_word_len = 4,
                 len_rand = True,
                 nucleus_sampling = False,
                 top_p = 0.1,
                 len_metrum = 2,
                 len_verse = 10):
        """Store the settings and load the stress tables and the syllable CSV.

        Args:
            tokenizer: tokenizer matching the model being constrained.
            first_stress: stress value (0/1). At syllable counts divisible by
                ``len_metrum`` the tables for stress ``(1 - first_stress) ** 2``
                are used; otherwise those for ``first_stress``.
            word_beginnings, max_word_len, len_rand: stored only.
            ngram_size_words: size of word n-grams that may not repeat.
            ngram_size_tokens: size of token n-grams that may not repeat.
            nucleus_sampling: if true, apply the top-p selection step.
            top_p: cumulative-probability threshold for that step.
            len_metrum: period of the stress pattern, in syllables.
            len_verse: verse length in syllables.

        Raises:
            FileNotFoundError: if one of the data files is missing.
        """
        self.tokenizer = tokenizer
        self.word_beginnings = word_beginnings
        self.ngram_size_words = ngram_size_words
        self.ngram_size_tokens = ngram_size_tokens
        self.word_start_mask = create_word_start_mask(tokenizer)
        self.max_word_len = max_word_len
        self.len_rand = len_rand
        self.nucleus_sampling = nucleus_sampling
        self.top_p = top_p
        self.first_stress = first_stress
        self.len_metrum = len_metrum
        self.len_verse = len_verse
        self.delim = tokenizer.encode('.') + tokenizer.encode(',')
        self.new_line = tokenizer.encode('\n')

        # NOTE: pickle.load can execute arbitrary code; only load trusted files.
        self.lst_0 = self._load_pickle('notstressed')
        self.lst_1 = self._load_pickle('stressed')
        self.lst_0_s = self._load_pickle('notstressed_start')
        self.lst_1_s = self._load_pickle('stressed_start')

        self.rythm_df = pd.read_csv('word_rythm.csv')
        # Map each word to its syllable count. The first row for a word wins,
        # as with the former `.loc[...].values[0]` lookup, but each lookup is
        # now O(1) instead of a scan over the whole frame.
        self._syllables_by_word = {}
        for word, count in zip(self.rythm_df['word'], self.rythm_df['num_syll']):
            self._syllables_by_word.setdefault(word, count)

    @staticmethod
    def _load_pickle(path):
        """Load one pickled lookup table from ``path``."""
        with open(path, 'rb') as f:
            return pickle.load(f)

    def _count_syllables(self, beam_input_ids):
        """Sum the syllables of the complete words in one beam.

        The first token is skipped (presumably the ``<|endoftext|>`` prompt).
        The last word is left out because it may still be incomplete.

        Raises:
            KeyError: if a word is missing from ``word_rythm.csv``.
        """
        text = self.tokenizer.decode(beam_input_ids[1:])
        # Raw string: '\W' in a plain literal is an invalid escape sequence.
        words = re.sub(r'[\W_]+', ' ', text).split()[:-1]
        num_syll = 0
        for word in words:
            if word not in self._syllables_by_word:
                raise KeyError(f'word {word!r} not found in word_rythm.csv')
            num_syll += self._syllables_by_word[word]
        return num_syll

    def _stress_tables(self, num_syll):
        """Return (continuation table, word-start table) for the next syllable."""
        if num_syll % self.len_metrum == 0:
            next_stress = (1 - self.first_stress) ** 2
        else:
            next_stress = self.first_stress
        if next_stress == 1:
            return self.lst_1, self.lst_1_s
        return self.lst_0, self.lst_0_s

    def _metric_banned_tokens(self, beam_input_ids):
        """List the token ids that violate the stress pattern for one beam."""
        _, word_len = find_word_beginning(beam_input_ids)
        num_syll = self._count_syllables(beam_input_ids)
        target_dict, target_dict_s = self._stress_tables(num_syll)

        if num_syll > self.len_verse - 4:
            keep = self.delim + self.new_line
        else:
            keep = self.delim
        # Punctuation (and the newline) is only force-allowed after the first syllable.
        always_allowed = set(keep) if num_syll > 0 else set()

        # Ids of the word currently being built. `.tolist()` yields plain ints,
        # so the key reads "[a, b, c]" on every numpy version (numpy>=2 would
        # render np.int64 as "np.int64(a)"). With no word start found
        # (word_len == 0) the prefix is empty. The old `[-0:]` slice silently
        # took the whole sequence instead.
        ids = beam_input_ids.tolist()
        prefix = ids[-word_len:] if word_len > 0 else []

        banned = []
        for i in range(len(self.tokenizer)):
            probability = max(target_dict[str(prefix + [i])], target_dict_s[i])
            if i in always_allowed:
                probability = 1
            if probability == 0:
                banned.append(i)
        return banned

    def _nucleus_select(self, scores):
        """Keep one token per row, drawn uniformly from that row's top-p set.

        The top-p set is the smallest run of tokens, ordered by probability,
        whose cumulative probability reaches ``self.top_p``. It always holds at
        least one token. All other scores in the row are set to -inf in place.
        """
        probs = torch.softmax(scores, dim=1).cpu().detach().numpy()
        for i in range(scores.shape[0]):
            order = np.argsort(-probs[i])
            cumulative = np.cumsum(probs[i][order])
            # First position where the cumulative mass reaches top_p, + 1 for a count.
            num_candidates = int(np.searchsorted(cumulative, self.top_p)) + 1
            num_candidates = min(num_candidates, len(order))
            score_idx = int(random.choice(order[:num_candidates]))
            scores[i, :score_idx] = -float("inf")
            scores[i, score_idx + 1:] = -float("inf")

    def __call__(self, input_ids, scores):
        """Mask ``scores`` in place for the current step and return it.

        Args:
            input_ids: token ids generated so far, one row per hypothesis.
            scores: next-token scores, one row per hypothesis (modified in place).
        """
        banned_tokens = [self._metric_banned_tokens(beam_input_ids)
                         for beam_input_ids, _ in zip(input_ids, scores)]

        num_batch_hypotheses = scores.shape[0]
        cur_len = input_ids.shape[-1]
        banned_word_tokens = _calc_banned_word_ngram_tokens(self.ngram_size_words,
                                                            input_ids,
                                                            num_batch_hypotheses,
                                                            cur_len,
                                                            self.word_start_mask)
        banned_token_tokens = _calc_banned_ngram_tokens(self.ngram_size_tokens,
                                                        input_ids,
                                                        num_batch_hypotheses,
                                                        cur_len)

        for i in range(len(banned_tokens)):
            banned_tokens[i] += banned_word_tokens[i] + banned_token_tokens[i]

        for i, banned_token in enumerate(banned_tokens):
            scores[i, banned_token] = -float("inf")

        if self.nucleus_sampling:
            self._nucleus_select(scores)

        return scores

In [8]:
def check_beginnings(letter, decoded, keep_tokens=('.',), word_beginnings=True):
    """Decide whether the token text ``decoded`` must be blocked for ``letter``.

    A token is blocked (returns True) unless its stripped text is in
    ``keep_tokens`` and one of these holds:

    * it starts a word (or ``word_beginnings`` is False) and does not begin
      with ``letter`` (case-insensitive);
    * its stripped text is not purely alphabetic (including the empty string).

    Args:
        letter: required lower-case first letter.
        decoded: decoded token text; it may be empty.
        keep_tokens: stripped token texts that are never blocked.
        word_beginnings: if True, only tokens starting with a space are
            checked for the letter; otherwise every token is.

    Returns:
        True if the token should be blocked, else False.
    """
    # startswith() instead of decoded[0]: tokens can decode to ''.
    word_beginning = decoded.startswith(' ') if word_beginnings else True
    stripped = decoded.strip()
    wrong_letter = not stripped or stripped[0].lower() != letter
    not_alpha = not stripped.isalpha()
    not_keep = stripped not in keep_tokens
    return ((word_beginning and wrong_letter) or not_alpha) and not_keep

In [9]:
def create_block_token_dict(tokenizer,letters,keep_tokens,word_beginnings):
    """Precompute the token ids that ``check_beginnings`` blocks for each letter.

    Entries are keyed by the *preceding* letter in ``letters``, cyclically: the
    list computed for ``letters[i]`` is stored under ``letters[i - 1]``, so the
    one for ``letters[0]`` sits under ``letters[-1]``. Each value is a list that
    starts with the block list computed with ``word_beginnings``. When
    ``word_beginnings`` is true, a second list computed with
    ``word_beginnings=False`` is appended. If a key repeats, the first pass
    overwrites it and the second pass appends to it.
    """
    if len(letters) == 0:
        return {}

    # Decode the vocabulary once instead of once per letter and pass.
    vocab = [tokenizer.decode(token_id) for token_id in range(len(tokenizer))]

    def blocked_ids(letter, check_word_start):
        return [token_id for token_id, text in enumerate(vocab)
                if check_beginnings(letter, text, keep_tokens=keep_tokens,
                                    word_beginnings=check_word_start)]

    predecessors = [letters[idx - 1] for idx in range(len(letters))]

    block_token_dict = {}
    for key, letter in zip(predecessors, letters):
        block_token_dict[key] = [blocked_ids(letter, word_beginnings)]

    if word_beginnings:
        for key, letter in zip(predecessors, letters):
            block_token_dict[key].append(blocked_ids(letter, False))

    return block_token_dict
    

In [10]:
def create_aliterations(prompt, letters, tokenizer,
                       word_beginnings = False,
                       max_length = 24,
                       num_beams = 15,
                       num_return_beams = 14,
                       ngram_size_words = 2,
                       ngram_size_tokens = 4,
                       max_word_len = 4,
                       len_rand = False,
                       nucleus_sampling = False,
                       top_p = 0.1,
                       first_stress = 0,
                       len_metrum = 2,
                       len_verse = 10
                       ):
    """Run beam search on the global ``model`` constrained by ``alit_logits``.

    Args:
        prompt: text to condition on.
        letters: currently unused. ``alit_logits`` receives no letter
            constraint, so this argument has no effect (kept for API
            compatibility).
        tokenizer: tokenizer matching ``model``.
        word_beginnings, ngram_size_words, ngram_size_tokens, max_word_len,
        len_rand, nucleus_sampling, top_p, first_stress, len_metrum,
        len_verse: forwarded to ``alit_logits``.
        max_length: stop once sequences reach this many tokens (prompt included).
        num_beams: beams explored per step.
        num_return_beams: hypotheses kept and returned by the beam scorer;
            clamped to ``num_beams``.

    Returns:
        Whatever ``model.beam_search`` returns (by default, the generated token
        id sequences, one row per returned hypothesis).
    """
    if num_beams < num_return_beams:
        print('warning: setting number of return beams equal to number of beams')
        num_return_beams = num_beams

    # Put the prompt on the model's device; otherwise a GPU model fails on CPU input.
    prompt_ids = tokenizer(prompt, return_tensors='pt')['input_ids'].to(model.device)

    beam_scorer = BeamSearchScorer(
        batch_size = prompt_ids.shape[0],
        num_beams = num_beams,
        num_beam_hyps_to_keep = num_return_beams,
        device = model.device
    )

    logits_processor = LogitsProcessorList([alit_logits(tokenizer,
                                                        first_stress = first_stress,
                                                        word_beginnings = word_beginnings,
                                                        ngram_size_words = ngram_size_words,
                                                        ngram_size_tokens = ngram_size_tokens,
                                                        max_word_len = max_word_len,
                                                        len_rand = len_rand,
                                                        nucleus_sampling = nucleus_sampling,
                                                        top_p = top_p,
                                                        len_metrum = len_metrum,
                                                        len_verse = len_verse)])

    # Beam search expects each prompt's beams to be adjacent
    # ([p0, p0, ..., p1, p1, ...]). repeat_interleave gives that layout; for a
    # single prompt it is identical to concatenating copies.
    return model.beam_search(
        prompt_ids.repeat_interleave(num_beams, dim=0),
        beam_scorer,
        logits_processor = logits_processor,
        stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=max_length)])
    )



In [11]:
prompt = '<|endoftext|>'   # start from the end-of-text token only (no text prompt)
letters = ['d', 'a']       # NOTE(review): create_aliterations does not use letters yet

generated = create_aliterations(
    prompt, letters, tokenizer,
    word_beginnings=True,    # stored on alit_logits but not read by it at the moment
    max_length=15,           # stop once sequences are 15 tokens long (prompt included)
    num_beams=10,            # beams explored at each step
    num_return_beams=5,      # hypotheses kept and returned by the beam scorer
    ngram_size_words=2,      # size of word n-grams (word-initial tokens) that may not repeat
    ngram_size_tokens=4,     # size of token n-grams that may not repeat
    max_word_len=4,          # maximum tokens per word (currently unused)
    len_rand=False,          # randomise max_word_len (currently unused)
    nucleus_sampling=False,  # extra top-p step: keep one random candidate per beam
    top_p=0.1,               # cumulative-probability threshold for that step
)

for beam_idx, token_ids in enumerate(generated):
    text = tokenizer.decode(token_ids, skip_special_tokens=True)
    print(f'beam {beam_idx}: {text}')

FileNotFoundError: [Errno 2] No such file or directory: 'notstressed'

In [105]:
rythm_df = pd.read_csv('word_rythm.csv')

# Spot-check the num_syll column on a sample verse: one count per word.
sample_verse = 'und schön ists zu scherzen und scherzen mit dir und mit den andern'
words = sample_verse.split()

for word in words:
    syllables = rythm_df.loc[rythm_df['word'] == word, 'num_syll'].values[0]
    print(syllables)

1
1
1
1
2
1
2
1
1
1
1
1
2


In [162]:
# Value of the `rythm` column stored for the word 'sah'.
sah_rythm = rythm_df.loc[rythm_df['word'] == 'sah', 'rythm'].values[0]
print(sah_rythm)

[0.5]


In [138]:
# Token id(s) of the newline character. alit_logits stores the same encoding
# as `self.new_line` and force-allows it once a verse is nearly full.
tokenizer.encode('\n')

[199]

In [None]:
# NOTE(review): stray literal; its value is only displayed and nothing reads it.
# Candidate for removal.
[14,12]