<a href="https://colab.research.google.com/github/andjoer/llm_poetry_generation/blob/main/colabs/llm_theo_lutz.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Inspired by Theo Lutz

## Combining large language models and patterns

Click the 'Run all' button (in Colab: *Runtime → Run all*) to execute the whole notebook.

<img src = 'https://github.com/andjoer/llm_poetry_generation/blob/main/graphics/colab_en.jpg?raw=true'>

German version:



<img src = 'https://github.com/andjoer/llm_poetry_generation/blob/main/graphics/colab.jpg?raw=true'>

In [None]:
!pip install transformers

In [None]:
from transformers import GPT2Tokenizer,GPT2LMHeadModel, pipeline
import numpy as np
import torch
import spacy
import functools
import random


In [None]:
# Download the large German spaCy pipeline. Its POS tags (`pos_`), dependency
# labels (`dep_`) and morphological features (Case, Number, Gender, Person,
# VerbForm) are what the pattern checks below read from each token.
!python -m spacy download "de_core_news_lg"

# Global spaCy pipeline used by the n-gram search and the insertion nodes.
nlp = spacy.load("de_core_news_lg")

## Defining the model to use

In [None]:
# Hugging Face hub id of the fine-tuned German GPT-2 model used for generation.
gpt2_model = "Anjoe/kant-gpt2-large"

# Text-generation pipeline shared by the helpers below. Run on the first GPU when
# one is available (device=0), otherwise on CPU (device=-1): sampling dozens of
# sequences from this large (~2.9 GB) model is slow on CPU.
generator = pipeline('text-generation', model=gpt2_model,
                     tokenizer=gpt2_model, framework='pt',
                     device=0 if torch.cuda.is_available() else -1)

Downloading:   0%|          | 0.00/1.02k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.92G [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/795 [00:01<?, ?B/s]

Downloading:   0%|          | 0.00/848k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/515k [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/2.14M [00:00<?, ?B/s]

Downloading:   0%|          | 0.00/387 [00:00<?, ?B/s]

In [None]:
def gpt2_generate(input_text, max_length=5, num_return_sequences=30):
    """Generate continuations of `input_text` with the shared text-generation pipeline.

    Args:
        input_text: prompt to continue.
        max_length: number of tokens to generate beyond the prompt.
        num_return_sequences: number of continuations to return.

    Returns:
        List of generated continuations without the prompt (return_full_text=False).
    """
    # Reuse the tokenizer the pipeline already holds instead of loading a new one
    # from the hub/cache on every call; the pipeline tokenizes with it anyway.
    prompt_tokens = len(generator.tokenizer.encode(input_text))
    generated = generator(input_text,
                          max_length=max_length + prompt_tokens,
                          return_full_text=False,
                          num_return_sequences=num_return_sequences)
    return [item['generated_text'] for item in generated]


def gpt2_top_k(input_text, max_length=10, num_return_sequences=30):
    """Sample continuations of `input_text` with top-k sampling.

    Args:
        input_text: prompt to continue.
        max_length: number of tokens to generate beyond the prompt.
        num_return_sequences: number of sampled continuations; also used as top_k.

    Returns:
        List of decoded continuations (prompt tokens removed, special tokens skipped).
    """
    # Reuse the model/tokenizer already loaded by the pipeline instead of
    # reloading the multi-GB checkpoint from disk on every call.
    tokenizer = generator.tokenizer
    model = generator.model
    input_ids = tokenizer.encode(input_text, return_tensors='pt').to(model.device)
    start = input_ids.size(1)
    output = model.generate(
        input_ids,
        do_sample=True,
        max_length=start + max_length,
        top_k=num_return_sequences,
        num_return_sequences=num_return_sequences,
        # Forbid repeating any token bigram (this also covers bigrams of the prompt).
        # The original spelled this `num_repeat_ngram_size`, which generate() rejects.
        no_repeat_ngram_size=2,
        # GPT-2 has no pad token; pad with EOS as the original model config did.
        pad_token_id=tokenizer.eos_token_id,
    )
    return [tokenizer.decode(sample[start:], skip_special_tokens=True) for sample in output]

## Define the patterns

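Each pattern is a list of "nodes". A `'generate'` node lets the language model produce text and keeps an n-gram that matches the node's part-of-speech and other constraints. An `'insertion'` node inserts predefined text. Every node names the node that follows it: `'next'`, a random choice from `'rand_next'`, or `'if_failed'` when a generate node finds nothing.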

In [None]:
# Pattern 1: "<article> <noun phrase> ist/sind (nicht) <adjective | der/die/das noun> denn ..."
patterns_1 = [
    # node 0: let the language model produce a noun phrase
    {'type': 'generate',                 # use the language model to produce text
     'any': True,                        # the n-gram may be found anywhere in the generated text
     'num_samples': 20,                  # number of samples that comply with the criteria
     'pos': [['ADJ', 'NOUN'], ['NOUN']], # either adjective + noun, or a noun alone
     'case': [['', 'Nom'], ['Nom']],     # the noun should be nominative
     'add_det': [1, 0],                  # per alternative: position of the word whose gender/number picks the article
     'next': 1},                         # next node
    # node 1: copula that agrees in number with the previous node
    {'type': 'insertion',                # insert a predefined sequence
     'dependent': ['number'],            # property of the earlier text it depends on
     'dependency': -1,                   # which previous node it depends on
     'Plur': ['sind', 'sind nicht'],     # used when that node is plural
     'Sing': ['ist', 'ist nicht'],       # used when that node is singular
     'next': 2},
    # node 2: adjective, or determiner + noun, at the start of the continuation
    {'type': 'generate',
     'any': False,
     'num_samples': 8,
     'pos': [['ADJ'], ['DET', 'NOUN']],
     'case': [[''], ['', 'Nom']],
     'words': [[], [['der', 'die', 'das'], []]],
     'next': 3},
    # node 3: connective, then start over
    {'type': 'insertion',
     'text': ['denn'],
     'next': 0},
]



In [None]:
# Pattern 2: "<pronoun> <modal verb> <infinitive> denn ..."
patterns_2 = [
    # node 0: subject pronoun ('ihr' is left out because spaCy does not label it correctly)
    {'type': 'insertion',
     'text': ['ich', 'du', 'er', 'sie', 'es', 'wir'],
     'next': 1},
    # node 1: modal verb agreeing in number and person with the pronoun
    {'type': 'insertion',
     'dependent': ['number', 'person'],
     'dependency': -1,
     'Plur 1': ['dürfen', 'müssen', 'dürfen nicht', 'müssen nicht', 'können', 'können nicht'],
     'Plur 2': ['dürft', 'müsst', 'dürft nicht', 'müsst nicht', 'könnt', 'könnt nicht'],
     'Plur 3': ['dürfen', 'müssen', 'dürfen nicht', 'müssen nicht', 'können', 'können nicht'],
     'Sing 1': ['darf', 'muss', 'darf nicht', 'muss nicht', 'kann', 'kann nicht'],
     'Sing 2': ['darfst', 'musst', 'darfst nicht', 'musst nicht', 'kannst', 'kannst nicht'],
     'Sing 3': ['darf', 'muss', 'darf nicht', 'muss nicht', 'kann', 'kann nicht'],
     'next': 2},
    # node 2: infinitive (optionally after an auxiliary) at the start of the continuation
    {'type': 'generate',
     'any': False,
     'num_samples': 20,
     'pos': [['VERB'], ['AUX', 'VERB']],
     'verbform': [['Inf'], ['', 'Inf']],
     'next': 3},
    # node 3: connective, then start over
    {'type': 'insertion',
     'text': ['denn'],
     'next': 0},
]

In [None]:
# Pattern 3: nodes 0-3 as in pattern 1, but with random branching and fallbacks.
patterns_3 = [
    # node 0: noun phrase found anywhere in the generated text, with article
    {'type': 'generate',
     'any': True,
     'num_samples': 20,
     'pos': [['ADJ', 'NOUN'], ['NOUN']],
     'case': [['', 'Nom'], ['Nom']],
     'add_det': [1, 0],
     'rand_next': [1, 4]},               # the next node is picked at random from this list
    # node 1: copula agreeing in number with the previous node
    {'type': 'insertion',
     'dependent': ['number'],
     'dependency': -1,
     'Plur': ['sind', 'sind nicht'],
     'Sing': ['ist', 'ist nicht'],
     'rand_next': [2, 5]},
    # node 2: adjective, or determiner + noun, at the start of the continuation
    {'type': 'generate',
     'any': False,
     'num_samples': 8,
     'pos': [['ADJ'], ['DET', 'NOUN']],
     'case': [[''], ['', 'Nom']],
     'words': [[], [['der', 'die', 'das'], []]],
     'next': 3},
    # node 3: connective, then start over
    {'type': 'insertion',
     'text': ['denn'],
     'next': 0},
    # node 4: infinitive (optionally after an auxiliary) at the start of the continuation
    {'type': 'generate',
     'any': False,
     'num_samples': 100,
     'pos': [['VERB'], ['AUX', 'VERB']],
     'verbform': [['Inf'], ['', 'Inf']],
     'next': 3,
     'if_failed': 1},                    # node to continue with when no option is found
    # node 5: adjective or auxiliary at the start of the continuation
    {'type': 'generate',
     'any': False,
     'num_samples': 90,
     'pos': [['ADJ'], ['AUX']],
     'next': 3,
     'if_failed': 2},
]

In [None]:
def check_pos(pos_tags, idx, doc):
    """Return True if the coarse POS tags of `doc` equal alternative `idx` of `pos_tags` exactly."""
    expected = pos_tags[idx]
    actual = list(token.pos_ for token in doc)
    return actual == expected

def check_word(words, idx, doc):
    """Return True if every token of `doc` satisfies its word constraint for alternative `idx`.

    `words[idx]` holds one constraint per token position:
      - empty list / empty string: any word is accepted at that position;
      - list of strings: the token text must be one of them;
      - plain string: the token text must equal it.
    An empty constraint list for an alternative means "no word constraint".
    Positions beyond the end of the constraint list are unconstrained.
    """
    for token, allowed in zip(doc, words[idx]):
        if not allowed:
            continue
        if isinstance(allowed, str):
            allowed = [allowed]
        if token.text not in allowed:
            return False
    return True

def check_case(cases, idx, doc):
    """Return True if every token of `doc` has the grammatical case required by alternative `idx`.

    `cases[idx]` holds one entry per token position; '' means "any case".
    A token without a 'Case' feature fails a non-empty requirement instead of
    raising KeyError.
    """
    for token, case in zip(doc, cases[idx]):
        if case and token.morph.to_dict().get('Case') != case:
            return False
    return True

def check_verbform(verbforms, idx, doc):
    """Return True if every token of `doc` has the verb form required by alternative `idx`.

    `verbforms[idx]` holds one entry per token position; '' means "any form".
    A token without a 'VerbForm' feature fails a non-empty requirement.
    """
    for token, form in zip(doc, verbforms[idx]):
        if form and token.morph.to_dict().get('VerbForm') != form:
            return False
    return True

def check_all(conditions, max_idx, doc):
    """Return True if some alternative index in range(max_idx) passes every condition.

    Each condition is called as condition(idx, doc). For a given idx, evaluation
    stops at the first failing condition, so earlier conditions guard later ones.
    """
    return any(
        all(condition(idx, doc) for condition in conditions)
        for idx in range(max_idx)
    )
        
    
def store_words(doc):
    """Snapshot a spaCy span/doc into plain Python data.

    Returns:
        dict with
          'text':  token texts joined by single spaces,
          'pos':   list of coarse POS tags (token.pos_),
          'morph': list of morphology dicts (token.morph.to_dict()),
          'dep':   list of dependency labels (token.dep_),
        where the lists are aligned with the tokens.
    """
    tokens = list(doc)
    return {
        'text': ' '.join(token.text for token in tokens),
        'pos': [token.pos_ for token in tokens],
        'morph': [token.morph.to_dict() for token in tokens],
        'dep': [token.dep_ for token in tokens],
    }
        

In [None]:
def get_pos_idx(pos_tags, n_gram):
    """Return the indices of all alternatives in `pos_tags` equal to the stored n-gram's 'pos' list."""
    target = n_gram['pos']
    return [idx for idx, candidate in enumerate(pos_tags) if candidate == target]

            
def get_criteria_idx(criteria, n_gram):
    """Return the first pattern-alternative index accepted by every criterion for `n_gram`.

    Each criterion maps the stored n-gram dict to the list of alternative indices
    it accepts (e.g. get_pos_idx bound to a pattern's 'pos'). The index lists are
    intersected across criteria.

    Raises:
        IndexError: if no index is accepted by all criteria.
    """
    poss_idx = np.asarray(criteria[0](n_gram))
    for criterium in criteria[1:]:
        # Keep only indices that every criterion accepts (previously a call to the
        # non-existent np.settdif1d, which raised AttributeError).
        poss_idx = np.intersect1d(poss_idx, criterium(n_gram))
    return poss_idx[0]
        

In [None]:
def get_any(generated, check, lengths, min_return=5):
    """Collect n-grams from anywhere in the generated texts that satisfy `check`.

    Every sample is parsed with spaCy, and for each length in `lengths` every
    contiguous n-gram (including the one ending on the last token) is tested.

    Args:
        generated: list of generated strings.
        check: callable taking a spaCy span and returning a truthy value on match.
        lengths: n-gram lengths to try.
        min_return: stop as soon as this many matches have been collected.

    Returns:
        List of store_words dicts (may be empty or shorter than min_return).
    """
    found_words = []
    for sent in generated:
        doc = nlp(sent)
        for length in lengths:
            # +1 so that the n-gram ending on the last token is also considered.
            for start in range(len(doc) - length + 1):
                n_gram = doc[start:start + length]
                if check(n_gram):
                    found_words.append(store_words(n_gram))
                    if len(found_words) >= min_return:
                        return found_words
    return found_words
    
def get_first(generated, check, lengths, min_return=5):
    """Collect n-grams at the very start of each generated text that satisfy `check`.

    Each sample is stripped and parsed with spaCy. For every length in `lengths`,
    the first `length` tokens are tested.

    Args:
        generated: list of generated strings.
        check: callable taking a spaCy span and returning a truthy value on match.
        lengths: n-gram lengths to try.
        min_return: stop as soon as this many matches have been collected.

    Returns:
        List of store_words dicts (may be empty or shorter than min_return).
    """
    found_words = []
    for sent in generated:
        doc = nlp(sent.strip())
        for length in lengths:
            n_gram = doc[:length]
            if check(n_gram):
                found_words.append(store_words(n_gram))
                if len(found_words) >= min_return:
                    return found_words
    return found_words



def process_generative_pattern(pattern, prompt):
    """Sample continuations of `prompt` and pick one n-gram satisfying a 'generate' node.

    Pattern keys read here:
      'any'         -- truthy: search n-grams anywhere in each sample (gpt2_generate);
                       otherwise only at the start of each sample (gpt2_top_k).
      'num_samples' -- matching n-grams to collect before choosing (default 5).
      'pos'         -- required; list of alternative POS-tag sequences.
      'case', 'words', 'verbform'
                    -- optional per-alternative, per-position constraints.

    Returns:
        (word, idx): `word` is the store_words dict of a randomly chosen match and
        `idx` the index of its alternative in pattern['pos']; ('', 0) if nothing matched.

    Raises:
        ValueError: if the pattern has no 'pos' entry.
    """
    if 'pos' not in pattern:
        raise ValueError("a 'generate' pattern needs a 'pos' entry: %r" % (pattern,))

    anywhere = bool(pattern.get('any', False))
    num_samples = pattern.get('num_samples', 5)
    # Always draw at least 30 model samples so there is enough text to search.
    num_gpt_samples = max(30, num_samples)

    if anywhere:
        generated = gpt2_generate(prompt, num_return_sequences=num_gpt_samples)
    else:
        generated = gpt2_top_k(prompt, num_return_sequences=num_gpt_samples)

    pos_tags = pattern['pos']
    max_idx = len(pos_tags)
    # The POS check goes first: check_all stops at the first failing condition,
    # so the remaining checks only see n-grams whose POS sequence (and therefore
    # length) matches the alternative being tested.
    conditions = [functools.partial(check_pos, pos_tags)]
    criteria = [functools.partial(get_pos_idx, pos_tags)]
    # 'case' used to be declared in the patterns but never checked.
    for key, checker in (('case', check_case), ('words', check_word), ('verbform', check_verbform)):
        if key in pattern:
            conditions.append(functools.partial(checker, pattern[key]))

    lengths = sorted({len(alternative) for alternative in pos_tags})
    check = functools.partial(check_all, conditions, max_idx)

    collect = get_any if anywhere else get_first
    found_words = collect(generated, check, lengths, min_return=num_samples)

    if not found_words:
        return '', 0

    word = random.choice(found_words)
    return word, get_criteria_idx(criteria, word)
    
        
                        

In [None]:
# Singular nominative definite article by grammatical gender.
article_dict = {'Fem':'die','Masc':'der','Neut':'das'}


def _article_for(morph):
    """Return the nominative definite article for a token morphology dict, or None.

    Plural always takes 'die'; otherwise the article is looked up in article_dict
    by 'Gender'. Returns None when the gender needed to decide is missing.
    """
    if morph.get('Number') == 'Plur':
        return 'die'
    return article_dict.get(morph.get('Gender'))


def generate_patterns(prompt, patterns, loops, stop_pattern):
    """Walk a pattern graph and return the text it produces.

    Starting at node 0, each node either generates text with the language model
    ('generate', via process_generative_pattern) or inserts predefined text
    ('insertion'). It then selects the next node via 'next', a random choice from
    'rand_next', or 'if_failed' for a 'generate' node that found nothing. Each time
    the selected next node equals `stop_pattern`, one loop is counted; the walk
    ends after `loops` loops.

    The initial `prompt` is replaced by the first piece of text found; later pieces
    are appended with a single space.

    Returns:
        The text built so far. If a 'generate' node finds nothing and has no
        'if_failed' entry, the node is printed and the partial text is returned.
    """
    next_pattern = 0
    found_lst = []
    count = 0
    while count < loops:
        pattern = patterns[next_pattern]
        if pattern['type'] == 'generate':
            found_word, index = process_generative_pattern(pattern, prompt)
            success = bool(found_word)
            if not success and 'if_failed' not in pattern:
                print(pattern)
                return prompt

            if success and 'add_det' in pattern:
                # add_det[index] is the position, inside the matched n-gram, of the
                # word whose gender/number selects the article; -1 means no article.
                position = pattern['add_det'][index]
                if position > -1:
                    article = _article_for(found_word['morph'][position])
                    if article:
                        found_word['text'] = article + ' ' + found_word['text']
        else:
            success = True
            if 'dependent' in pattern:
                # Choose the form that agrees with the last token of an earlier node,
                # keyed like 'Sing' or 'Plur 3'.
                morph = found_lst[pattern['dependency']]['morph'][-1]
                query_parts = []
                if 'number' in pattern['dependent']:
                    query_parts.append(morph['Number'])
                if 'person' in pattern['dependent']:
                    query_parts.append(str(morph['Person']))
                word = random.choice(pattern[' '.join(query_parts)])
            else:
                word = random.choice(pattern['text'])

            # Tag the inserted words in the context of the text so far, so their
            # morphology is available to later dependent nodes.
            word_size = len(nlp(word))
            doc = nlp(prompt + ' ' + word)[-word_size:]
            found_word = store_words(doc)

        if success:
            if found_lst:
                prompt += ' ' + found_word['text'].strip()
            else:
                # First piece of text: it replaces the seed prompt.
                prompt = found_word['text'].strip()
            found_lst.append(found_word)

            if 'rand_next' in pattern:
                next_pattern = random.choice(pattern['rand_next'])
            else:
                next_pattern = pattern['next']
        else:
            next_pattern = pattern['if_failed']

        if next_pattern == stop_pattern:
            count += 1

    return prompt
    

In [None]:
# Seed prompt: GPT-2's end-of-text token followed by a period (presumably so that
# generation starts as if at the beginning of a new document; TODO confirm).
# generate_patterns replaces it with the first piece of text it finds.
prompt = "<|endoftext|>."

In [None]:
# One text from patterns_3; generation ends once node 3 ('denn') has been selected
# as the next node twice.
print(generate_patterns(prompt, patterns_3, loops=2, stop_pattern=3))

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.
Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


der Verstand ist die wesentlicheEinheit denn das Verhältnis ist die Materie


## Create multiple lines

In [None]:
# Print 100 independent lines from patterns_2. Each line ends once node 3 ('denn')
# has been selected as the next node twice. (Use random.choice([2, 3]) as the loop
# count for varying line lengths; the previous cell overwrote that choice with 2.)
for _ in range(100):
    print(generate_patterns(prompt, patterns_2, 2, stop_pattern=3))