<a href="https://colab.research.google.com/github/andjoer/llm_poetry_generation/blob/main/colabs/llm_theo_lutz.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Inspired by Theo Lutz

## Combining large language models and patterns

Click the 'Run all' button to execute the whole notebook.

<img src = 'https://github.com/andjoer/llm_poetry_generation/blob/main/graphics/colab_en.jpg?raw=true'>

German



<img src = 'https://github.com/andjoer/llm_poetry_generation/blob/main/graphics/colab.jpg?raw=true'>

In [None]:
# Install the Hugging Face `transformers` package, which provides the GPT-2 classes and `pipeline` imported below.
!pip install transformers

In [2]:
from transformers import GPT2Tokenizer,GPT2LMHeadModel, pipeline
import numpy as np
import torch
import spacy
import functools
import random
import re

  VERSION_SPEC = originalTextFor(_VERSION_SPEC)("specifier")
  MARKER_EXPR = originalTextFor(MARKER_EXPR())("marker")


In [3]:
# Load the large German spaCy pipeline used for POS tagging and morphology.
# spacy.load raises OSError when the model package is not installed (e.g. on a
# fresh Colab runtime); download it once in that case so "Run all" works
# without manually un-commenting a shell command.
try:
    nlp = spacy.load("de_core_news_lg")
except OSError:
    spacy.cli.download("de_core_news_lg")
    nlp = spacy.load("de_core_news_lg")

## Defining the model to use

In [4]:
# Hugging Face model id of the GPT-2 checkpoint used by every generation helper in this notebook.
gpt2_model = "Anjoe/kant-gpt2-large"

# High-level text-generation pipeline (PyTorch backend); used by gpt2_generate.
generator = pipeline('text-generation', model=gpt2_model,
                 tokenizer=gpt2_model, framework = 'pt')

# Tokenizer and model loaded directly for the helpers that call the model themselves
# (gpt2_top_k and gpt_sample_systematic). The EOS token doubles as padding token.
# NOTE(review): the pipeline above presumably keeps its own copy of the weights, so the
# checkpoint may be held in memory twice - confirm whether both are needed.
tokenizer = GPT2Tokenizer.from_pretrained(gpt2_model)
model = GPT2LMHeadModel.from_pretrained(gpt2_model,pad_token_id = tokenizer.eos_token_id)


In [5]:
def gpt2_generate(input_text,max_length= 5, num_return_sequences=30):
    """Sample continuations of ``input_text`` with the text-generation pipeline.

    Args:
        input_text: prompt that the model should continue.
        max_length: number of tokens to generate *in addition to* the prompt.
        num_return_sequences: how many continuations to sample.

    Returns:
        A list with the generated continuations (prompt not included).
    """
    # Reuse the module-level tokenizer instead of re-loading it from the
    # checkpoint on every call; only the prompt length in tokens is needed.
    prompt_length = tokenizer.encode(input_text,return_tensors='pt').size(1)

    # The pipeline's max_length counts the prompt tokens as well.
    generated = generator(input_text,
                          max_length=max_length + prompt_length,
                          return_full_text = False,
                          num_return_sequences=num_return_sequences)

    return [item['generated_text'] for item in generated]


def gpt2_top_k(input_text,max_length = 10,num_return_sequences=30):
    """Sample continuations of ``input_text`` with top-k sampling.

    Args:
        input_text: prompt that the model should continue.
        max_length: number of tokens to generate *in addition to* the prompt.
        num_return_sequences: number of continuations to sample; the same
            value is also used as the ``top_k`` cut-off.

    Returns:
        A list of decoded continuations without the prompt and without
        special tokens.
    """
    input_ids = tokenizer.encode(input_text,return_tensors='pt')
    prompt_length = input_ids.size(1)

    output = model.generate(
        input_ids,
        do_sample = True,
        max_length = max_length + prompt_length,    # generate() counts the prompt tokens too
        top_k = num_return_sequences,
        num_return_sequences = num_return_sequences,
        early_stopping = True,
        no_repeat_ngram_size = 2                    # was misspelled `num_repeat_ngram_size`, which generate() does not know
    )

    # Strip the prompt tokens from every sample before decoding.
    return [tokenizer.decode(sample_output[prompt_length:],skip_special_tokens=True) for sample_output in output]

In [86]:
def remove_last(possible_tokens, possible_logits,tokenizer,max_word_count):
    """Backtrack one step in the candidate stack.

    ``possible_tokens`` / ``possible_logits`` are parallel stacks: one list of
    candidates per depth, where the last entry of each list is the currently
    selected one. The selected candidate of the deepest level is dropped at
    least once, and dropping continues while the decoded selection still
    contains more than ``max_word_count`` words. Levels that run empty are
    removed and the selection of the level below is dropped as well.

    Returns the (possibly shortened) pair ``(possible_tokens, possible_logits)``;
    the first level may be left empty when the search space is exhausted.
    """

    def selected_word_count(stack):
        # Decode the currently selected token of every level and count the
        # words after replacing everything but (German) letters by spaces.
        decoded = tokenizer.decode([level[-1] for level in stack])
        letters_only = re.sub(r'[^A-Za-zÄÖÜäöüß ]', ' ', decoded)
        return len(letters_only.split())

    must_drop_once = True
    while True:
        too_many_words = selected_word_count(possible_tokens) > max_word_count
        if not too_many_words and not (must_drop_once and possible_tokens):
            break

        # discard the currently selected candidate of the deepest level
        possible_tokens[-1] = possible_tokens[-1][:-1]
        possible_logits[-1] = possible_logits[-1][:-1]
        must_drop_once = False

        # an exhausted level is removed together with the selection that led to it
        while possible_tokens[0] and not possible_tokens[-1]:
            possible_tokens = possible_tokens[:-1]
            possible_logits = possible_logits[:-1]
            possible_tokens[-1] = possible_tokens[-1][:-1]
            possible_logits[-1] = possible_logits[-1][:-1]

        if not possible_tokens[0]:
            # nothing left on the first level: the search space is exhausted
            break

    return possible_tokens, possible_logits


def gpt_sample_systematic(input,check, pattern,num_return_sequences = 1,loop_limit = 15000, top_p = None,top_k = 1000,top_k_0 = 0, temperature = 0.9,random_first = True, random_all = False, block_non_alpha = True,
                        trunkate_after = 100,repetition_penalty=1.2):

    '''
    Systematically search continuations of ``input`` that satisfy ``check``.

    A stack of candidate next tokens is built level by level (``possible_tokens``,
    with the matching probabilities in ``possible_logits``). The last entry of
    every level is the currently selected token. At each step the model is run on
    the prompt plus the current selection; depending on the outcome a new level
    of candidates is pushed, the selection is recorded as a result, or the search
    backtracks via ``remove_last``.

    Parameters
    ----------
    input : prompt text; whitespace before punctuation is removed first.
    check : callable taking a spaCy span and returning ``(matches, pattern_idx)``.
    pattern : pattern node; only the lengths of its list-valued conditions are read here.
    num_return_sequences : stop once this many matching continuations were found.
    loop_limit : maximum number of model calls.
    top_p : cumulative-probability cut-off for the candidates (0.5 when falsy).
    top_k, top_k_0 : rank window ``[top_k_0:top_k]`` that can replace the top_p
        selection on the first level.
    temperature : divisor applied to the logits.
    random_first, random_all : shuffle the candidates (weighted by probability)
        on the first level only / on every level.
    trunkate_after : number of consecutive unsuccessful steps after which the
        stack is cut back.
    block_non_alpha, repetition_penalty : not read in this function body.

    Returns
    -------
    list of spaCy docs, one per accepted continuation (decoded without the prompt).
    '''
    # normalise the prompt: strip it and remove whitespace in front of punctuation
    input = re.sub(r'\s([,.!?;:](?:\s|$))', r'\1', input.strip())
 
    if not top_p:
        top_p = 0.5

    
  
    sm = torch.nn.Softmax(dim = 1)

    # parallel stacks: one list of candidate token ids / probabilities per depth
    possible_tokens = []
    possible_logits = []

    # accepted token sequences and the sum of their selected probabilities
    possible_combinations = []
    combination_logits = []

    max_word_count = 3

    inputs = tokenizer(input,return_tensors='pt')['input_ids']

    max_token_count = 5  

    fulfill_pos = False

    possible_end = False
    pos_match_end = False
    depth_lst = []          # stack depth at every step since the last success / reset

    last_word_start = 0   # stop building endless words

    # lengths of the alternatives of every list-valued condition of the pattern
    condition_length = []
    for condition in pattern.values():
        if type(condition) == list:
            condition_length.append([len(cond) for cond in condition if type(cond) == list])

    with torch.no_grad():
        for i in range(loop_limit): 
   
            if len(possible_tokens) > 0:
                

                # drop exhausted levels from the top of the stack
                while not possible_tokens[-1]:
                    possible_tokens = possible_tokens[:-1]
                    possible_logits = possible_logits[:-1]

                    if not possible_tokens:
                        break

                if not possible_tokens:
                    break

                try:
                    # currently selected token of every level, shaped (1, depth)
                    new_tokens =  torch.reshape(torch.IntTensor([tokens[-1] for tokens in possible_tokens]),(1,-1))
              
                except: 

                    print('possible tokens')
                    print(possible_tokens)
                    raise Exception
                input_tokens = torch.cat((inputs,new_tokens),1)

            else: 
                input_tokens = inputs

            depth_lst.append(len(possible_tokens))
         
            outputs = model(input_tokens)

            # next-token logits, scaled by the temperature
            logits = outputs.logits[:,-1,:]/temperature

            last_token_test = 'test' + tokenizer.decode(torch.argmax(logits))                                   # check if next token contains a space
      
            # text of the current selection, and the same text reduced to letters and single spaces
            gen_part = tokenizer.decode([tokens[-1]for tokens in possible_tokens])
            generated = ' '.join(re.sub(r'[^A-Za-zÄÖÜäöüß ]', ' ',gen_part ).strip().split())


            # word boundary: the most likely next token starts with whitespace or is not purely alphabetic
            if generated and (len(last_token_test.split()) > 1 or not tokenizer.decode(torch.argmax(logits)).isalpha()):

                fulfill_requirements = True
                possible_end = True
                last_word_start = len(possible_tokens)
                # analyse the generated words in the context of the prompt; keep only the generated part
                generated_doc = nlp(input + ' ' + generated.strip())[-len(generated.strip().split()):]

                # reject selections whose last word has fewer than two characters
                if len(generated.split()[-1]) < 2:
                    fulfill_requirements = False

                check_val, pattern_idx = check(generated_doc)
                if not check_val:
           
                    fulfill_requirements = False
                    possible_end = False

                if pattern_idx > -1:
                    if max([cond[pattern_idx] for cond in condition_length if len(cond) > pattern_idx]) > len(generated_doc):  # this is  workaround since this sampling method was implemented later
                        possible_end = False
           
           
            else:
                if len(possible_tokens) - last_word_start < 5:          # no endless compund tokens without space separation
                    fulfill_requirements = True
                else: 
                    fulfill_requirements = False
                possible_end = False


            # the selection contains characters other than letters and spaces
            if ' '.join(re.sub(r'[^A-Za-zÄÖÜäöüß ]', ' ',gen_part ).strip().split()) != ' '.join(gen_part.strip().split()): 
 
                fulfill_requirements = False
    
            if fulfill_requirements:

                # rank the vocabulary and mark the tokens inside the top_p mass
                logits_sorted,indices_sorted = torch.sort(logits, descending=True)
                logits_sorted = sm(logits_sorted)
                cum_sum = torch.cumsum(logits_sorted, dim=-1)
                cum_sum[:,0] = 0                  
                                              
                token_inside_top_p = cum_sum <= top_p                                   # keep at least one index


            # --- backtrack: selection too deep or not acceptable ---
            if len(possible_tokens) >= max_token_count or not fulfill_requirements:  
            
                if len(depth_lst) > trunkate_after*2:                                     # too many repetitions with same trunk -> the trunk could be the problem
                    possible_tokens = possible_tokens[:1]
                    possible_logits = possible_logits[:1]
                    depth_lst = []

                elif len(depth_lst) > trunkate_after:
                    # cut the stack to half of the smallest recently visited depth (at least one level)
                    cut = max(int(min(depth_lst[-(trunkate_after+int(trunkate_after*0.8)):])/2),1)
                    if cut == 1:
                        depth_lst = []
                    possible_tokens = possible_tokens[:cut]
                    possible_logits = possible_logits[:cut]

                possible_tokens, possible_logits = remove_last(possible_tokens, possible_logits,tokenizer,max_word_count)
                if len(possible_tokens) == 1:
                    depth_lst = []
                                
            # --- success: record the current selection and continue with the next candidate ---
            elif fulfill_requirements and possible_end:
                #print(tokenizer.decode([tokens[-1].item() for tokens in possible_tokens]))
                '''print('rythm in generation function')
                print(generated_verse.text)
                print(generated_verse.rythm)
                print(generated_verse.token_pos)'''

                depth_lst = []

                possible_combinations.append([tokens[-1].item() for tokens in possible_tokens])
     
                last_logits_sum = sum([logits[-1].item() for logits in possible_logits])
                combination_logits.append(last_logits_sum)

                possible_tokens, possible_logits = remove_last(possible_tokens, possible_logits,tokenizer,max_word_count)
                
                if not possible_tokens[0] or len(possible_combinations) >= num_return_sequences:
                    break

            # --- extend: push a new level of candidate tokens ---
            elif fulfill_requirements:

                # first level: fall back to the rank window [top_k_0:top_k] if the top_p set is smaller than top_k or an offset is requested
                if len(possible_tokens) == 0 and (len(indices_sorted[token_inside_top_p])  < top_k or top_k_0 > 0):

                    indices_filtered = torch.flip(indices_sorted[0,top_k_0:top_k],dims=[-1])    # highest probability last so it gets accessed first
                    logits_filtered = torch.flip(logits_sorted[0,top_k_0:top_k],dims=[-1])           
            
                else:
                    indices_filtered = torch.flip(indices_sorted[token_inside_top_p],dims=[-1])    # highest probability last so it gets accessed first
                    logits_filtered = torch.flip(logits_sorted[token_inside_top_p],dims=[-1])           

                if random_all or (random_first and not possible_tokens):            # without randomness always the same poem would be created from the same prompt
                    # draw all candidates without replacement, i.e. a probability-weighted permutation
                    all_indices_ran = torch.multinomial(logits_filtered,num_samples = len(logits_filtered))
                    logits_filtered = logits_filtered[all_indices_ran]
                    indices_filtered = indices_filtered[all_indices_ran]
                

                possible_tokens.append(list(indices_filtered))
                possible_logits.append(list(logits_filtered))
            else:
                # not reachable: the first branch already handles `not fulfill_requirements`
                possible_tokens, possible_logits = remove_last(possible_tokens, possible_logits,tokenizer,max_word_count)

            if not possible_tokens:
                break

            if not possible_tokens[0]:
                break


    # decode every accepted token sequence and analyse it with spaCy
    return [nlp(tokenizer.decode(combination)) for combination in possible_combinations]

## Define the patterns

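Each pattern is a list of "nodes". A node either inserts predefined text (`'type': 'insertion'`) or asks the language model to generate text that satisfies grammatical constraints (`'type': 'generate'`), and it names the node to continue with (`'next'` or `'rand_next'`).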

In [7]:
# Pattern in the spirit of Theo Lutz' original texts:
# determiner -> noun -> "ist" -> adjective -> connective, then start over.
patterns_lutz_org = [
    # 0: insert one of the predefined determiners
    {'type': 'insertion',
     'text': ['eine', 'jede', 'keine', 'nicht jede',
              'ein', 'jeder', 'kein', 'nicht jeder',
              'ein', 'jedes', 'kein', 'nicht jedes'],
     'next': 1},

    # 1: use the language model to produce a noun
    {'type': 'generate',
     'dependent': [[['gender']]],   # property taken over from the node referenced by 'dependency'
     'dependency': -1,              # on which previous node it depends
     'any': False,                  # True would allow the word anywhere in the generated text
     'num_samples': 20,             # number of samples that comply with the criteria
     'pos': [['NOUN']],             # a single noun
     'number': [['Sing']],          # singular
     'case': [['Nom']],             # nominative
     'next': 2},                    # next node

    # 2: insert the fixed word 'ist'
    {'type': 'insertion',
     'dependent': [''],             # property of the generated text it depends on (none)
     'dependency': -1,              # on which previous node it depends
     '': ['ist'],                   # text for the empty property value
     'next': 3},

    # 3: generate an adjective
    {'type': 'generate',
     'any': False,
     'num_samples': 8,
     'pos': [['ADJ']],
     'next': 4},

    # 4: connective or sentence end ('.' is listed five times in the choices)
    {'type': 'insertion',
     'text': ['und', 'oder', 'so gilt', '.', '.', '.', '.', '.'],
     'next': 0},
]



In [8]:
# Variation of the Lutz pattern: the noun is generated first and the matching
# determiner is inserted in front of it afterwards.
patterns_lutz_var = [
    # 0: use the language model to produce a noun
    {'type': 'generate',
     'any': True,                   # word can be found anywhere in the text
     'num_samples': 20,             # number of samples that comply with the criteria
     'pos': [['NOUN']],             # a single noun
     'number': [['Sing']],          # singular
     'case': [['Nom']],             # nominative
     'next': 1},                    # next node

    # 1: determiner chosen by the gender of the referenced node
    {'type': 'insertion',
     'dependent': ['gender'],
     'dependency': -1,
     'position': 'before',          # only possible with insertion
     'Fem': ['eine', 'jede', 'keine', 'nicht jede'],
     'Masc': ['ein', 'jeder', 'kein', 'nicht jeder'],
     'Neut': ['ein', 'jedes', 'kein', 'nicht jedes'],
     'next': 2},

    # 2: insert the fixed word 'ist'
    {'type': 'insertion',
     'dependent': [''],             # property of the generated text it depends on (none)
     'dependency': -1,              # on which previous node it depends
     '': ['ist'],                   # text for the empty property value
     'next': 3},

    # 3: adjective, or determiner + nominative noun
    {'type': 'generate',
     'any': False,
     'num_samples': 64,
     'pos': [['ADJ'], ['DET', 'NOUN']],
     'case': [[''], ['', 'Nom']],
     'words': [[], [['ein', 'eine'], []]],
     'next': 4},

    # 4: connective or sentence end
    {'type': 'insertion',
     'text': ['und', 'oder', 'so gilt', '.', '.', '.', '.', '.'],
     'next': 0},
]



In [9]:
# Pattern 1: noun phrase -> "ist/sind (nicht)" -> adjective or noun phrase -> "denn" -> start over.
patterns_1 = [
    # 0: use the language model to produce text
    {'type': 'generate',
     'any': True,                               # word can be found anywhere in the text
     'num_samples': 20,                         # number of samples that comply with the criteria
     'pos': [['ADJ', 'NOUN'], ['NOUN']],        # either the sequence adjective, noun or only a noun
     'case': [['', 'Nom'], ['Nom']],            # the noun should be nominative
     'add_det': [1, 0],                         # add an article in front of the sequence
     'next': 1},                                # next node

    # 1: insert a predefined sequence
    {'type': 'insertion',
     'dependent': ['number'],                   # property of the generated text it depends on
     'dependency': -1,                          # on which previous node it depends
     'Plur': ['sind', 'sind nicht'],            # use when the last node is plural
     'Sing': ['ist', 'ist nicht'],              # use when the last node is singular
     'next': 2},

    # 2: adjective, or determiner + nominative noun
    {'type': 'generate',
     'any': False,
     'num_samples': 8,
     'pos': [['ADJ'], ['DET', 'NOUN']],
     'case': [[''], ['', 'Nom']],
     'words': [[], [['der', 'die', 'das'], []]],
     'next': 3},

    # 3: connective
    {'type': 'insertion',
     'text': ['denn'],
     'next': 0},
]



In [10]:
# Pattern 1b: like pattern 1, but the article is inserted first and the
# generated noun phrase has to agree with it.
patterns_1b = [
    # 0: definite article
    {'type': 'insertion',
     'text': ['der', 'die', 'das'],
     'next': 1},

    # 1: use the language model to produce text
    {'type': 'generate',
     'dependent': [['', ['gender', 'number']], [['gender', 'number']]],
     'dependency': -1,
     'num_samples': 20,                         # number of samples that comply with the criteria
     'pos': [['ADJ', 'NOUN'], ['NOUN']],        # either the sequence adjective, noun or only a noun
     'case': [['', 'Nom'], ['Nom']],            # the noun should be nominative
     'next': 2},                                # next node

    # 2: insert a predefined sequence
    {'type': 'insertion',
     'dependent': ['number'],                   # property of the generated text it depends on
     'dependency': -1,                          # on which previous node it depends
     'Plur': ['sind', 'sind nicht'],            # use when the last node is plural
     'Sing': ['ist', 'ist nicht'],              # use when the last node is singular
     'next': 3},

    # 3: adjective, or determiner + nominative noun
    {'type': 'generate',
     'any': False,
     'num_samples': 8,
     'pos': [['ADJ'], ['DET', 'NOUN']],
     'case': [[''], ['', 'Nom']],
     'words': [[], [['der', 'die', 'das'], []]],
     'next': 4},

    # 4: connective
    {'type': 'insertion',
     'text': ['denn'],
     'next': 0},
]



In [11]:
# Pattern 2: pronoun -> modal verb agreeing with it -> infinitive -> "denn" -> start over.
patterns_2 = [
    # 0: personal pronoun ('ihr' is not labeled correctly by spaCy)
    {'type': 'insertion',
     'text': ['ich', 'du', 'er', 'sie', 'es', 'wir'],
     'next': 1},

    # 1: modal verb selected by number and person of the referenced node
    {'type': 'insertion',
     'dependent': ['number', 'person'],
     'dependency': -1,
     'Plur 1': ['dürfen', 'müssen', 'dürfen nicht', 'müssen nicht', 'können', 'können nicht'],
     'Plur 2': ['dürft', 'müsst', 'dürft nicht', 'müsst nicht', 'könnt', 'könnt nicht'],
     'Plur 3': ['dürfen', 'müssen', 'dürfen nicht', 'müssen nicht', 'können', 'können nicht'],
     'Sing 1': ['darf', 'muss', 'darf nicht', 'muss nicht', 'kann', 'kann nicht'],
     'Sing 2': ['darfst', 'musst', 'darfst nicht', 'musst nicht', 'kannst', 'kannst nicht'],
     'Sing 3': ['darf', 'muss', 'darf nicht', 'muss nicht', 'kann', 'kann nicht'],
     'next': 2},

    # 2: infinitive, optionally preceded by an auxiliary
    {'type': 'generate',
     'any': False,
     'num_samples': 20,
     'pos': [['VERB'], ['AUX', 'VERB']],
     'verbform': [['Inf'], ['', 'Inf']],
     'next': 3},

    # 3: connective
    {'type': 'insertion',
     'text': ['denn'],
     'next': 0},
]

In [12]:
# Pattern 3: branching variant of pattern 1 with randomly chosen successors
# ('rand_next') and fallback nodes ('if_failed').
patterns_3 = [
    # 0: noun phrase
    {'type': 'generate',
     'any': True,
     'num_samples': 20,
     'pos': [['ADJ', 'NOUN'], ['NOUN']],
     'case': [['', 'Nom'], ['Nom']],
     'add_det': [1, 0],
     'rand_next': [1, 4]},                      # list of next nodes to choose from randomly

    # 1: "ist/sind (nicht)" depending on the number of the referenced node
    {'type': 'insertion',
     'dependent': ['number'],
     'dependency': -1,
     'Plur': ['sind', 'sind nicht'],
     'Sing': ['ist', 'ist nicht'],
     'rand_next': [2, 5]},

    # 2: adjective, or determiner + nominative noun
    {'type': 'generate',
     'any': False,
     'num_samples': 8,
     'pos': [['ADJ'], ['DET', 'NOUN']],
     'case': [[''], ['', 'Nom']],
     'words': [[], [['der', 'die', 'das'], []]],
     'next': 3},

    # 3: connective
    {'type': 'insertion',
     'text': ['denn'],
     'next': 0},

    # 4: finite verb, optionally preceded by an auxiliary
    {'type': 'generate',
     'dependent': [[['number', 'person']], ['', ['number', 'person']]],
     'dependency': -1,
     'any': False,
     'num_samples': 100,
     'pos': [['VERB'], ['AUX', 'VERB']],
     'verbform': [['Fin'], ['', 'Fin']],
     'next': 3,
     'if_failed': 1},                           # where to continue if no option is found

    # 5: adjective or auxiliary
    {'type': 'generate',
     'any': False,
     'num_samples': 90,
     'pos': [['ADJ'], ['AUX']],
     'next': 3,
     'if_failed': 2},
]

In [13]:
# Pattern 3b: branching pattern that starts with an inserted article; the
# generated noun phrase has to agree with it.
patterns_3b = [
    # 0: definite article
    {'type': 'insertion',
     'text': ['der', 'die', 'das'],
     'next': 1},

    # 1: noun phrase agreeing with the referenced node in gender and number
    {'type': 'generate',
     'dependent': [['', ['gender', 'number']], [['gender', 'number']]],
     'dependency': -1,
     'any': False,
     'num_samples': 20,
     'pos': [['ADJ', 'NOUN'], ['NOUN']],
     'case': [['', 'Nom'], ['Nom']],
     'rand_next': [2, 5]},                      # list of next nodes to choose from randomly

    # 2: "ist/sind (nicht)" depending on the number of the referenced node
    {'type': 'insertion',
     'dependent': ['number'],
     'dependency': -1,
     'Plur': ['sind', 'sind nicht'],
     'Sing': ['ist', 'ist nicht'],
     'rand_next': [3, 6]},

    # 3: adjective, or determiner + nominative noun
    {'type': 'generate',
     'any': False,
     'num_samples': 8,
     'pos': [['ADJ'], ['DET', 'NOUN']],
     'case': [[''], ['', 'Nom']],
     'words': [[], [['der', 'die', 'das'], []]],
     'next': 4},

    # 4: connective
    {'type': 'insertion',
     'text': ['denn'],
     'next': 0},

    # 5: finite verb, optionally preceded by an auxiliary
    {'type': 'generate',
     'dependent': [[['number', 'person']], ['', ['number', 'person']]],
     'dependency': -1,
     'any': False,
     'num_samples': 100,
     'pos': [['VERB'], ['AUX', 'VERB']],
     'verbform': [['Fin'], ['', 'Fin']],
     'next': 4,
     'if_failed': 2},                           # where to continue if no option is found

    # 6: adjective
    {'type': 'generate',
     'any': False,
     'num_samples': 90,
     'pos': [['ADJ']],
     'next': 4,
     'if_failed': 2},
]

In [79]:
def parse_dep(pattern,found_lst):
    """Turn the 'dependent' spec of a generative node into concrete conditions.

    For every property named in ``pattern['dependent']`` ('number', 'person',
    'gender') the value is looked up in the morphology of the last word of
    ``found_lst[pattern['dependency']]`` and written to ``pattern[property]``
    in the same nested layout as the spec: one list per alternative, one entry
    per position, ``''`` where a position does not depend on the property.

    Args:
        pattern: node dictionary; it is modified in place and also returned.
        found_lst: list of dictionaries with a 'morph' entry (list of
            morphology dicts, one per word) as produced by ``store_words``.

    Returns:
        The same ``pattern`` object, extended by the resolved properties.

    Raises:
        LookupError: if a required property other than 'person' cannot be
            resolved from ``found_lst``.
    """
    translate_dict = {'number':'Number','person':'Person','gender':'Gender'}

    dependency = pattern['dependency']

    # all property names that occur anywhere in the spec (empty slots skipped)
    dependencies = sorted({key
                           for possibility in pattern['dependent']
                           for dep_lst in possibility if dep_lst
                           for key in dep_lst})

    # one (initially empty) condition list per alternative
    for dep in dependencies:
        pattern[dep] = [[] for _ in pattern['dependent']]

    for slot, possibility in enumerate(pattern['dependent']):
        for dep_lst in possibility:
            for key in dependencies:
                if key not in dep_lst:
                    pattern[key][slot].append('')     # position is unconstrained for this property
                    continue

                try:
                    value = found_lst[dependency]['morph'][-1][translate_dict[key]]
                except (KeyError, IndexError) as err:
                    if key != 'person':
                        raise LookupError(
                            "cannot resolve dependent property '%s' from node %s"
                            % (key, dependency)) from err
                    value = 3                         # a noun has no person tag, but it is 3rd person

                pattern[key][slot].append(value)

    return pattern



In [60]:
def check_gender(genders, idx,doc):
    """Return True if the tokens of ``doc`` have the genders required by alternative ``idx``.

    ``genders`` holds one list of expected values per alternative; if ``idx``
    is out of range the first alternative is used, and a non-list alternative
    is treated as a one-element list. An empty string accepts any token.
    Only as many positions as ``doc`` has tokens are compared.
    """
    expected = genders[idx] if idx < len(genders) else genders[0]
    if type(expected) != list:
        expected = [expected]

    return all(
        wanted == '' or token.morph.to_dict().get('Gender') == wanted
        for wanted, token in zip(expected, doc)
    )

def check_pos(pos_tags, idx,doc):
    """Return True if the POS tags of ``doc`` equal the beginning of alternative ``idx``.

    A ``doc`` shorter than the alternative matches as a prefix; a longer one
    never matches.
    """
    observed = []
    for token in doc:
        observed.append(token.pos_)

    expected_prefix = pos_tags[idx][:len(observed)]
    return observed == expected_prefix

def check_word(words, idx,doc):
    """Compare the token texts of ``doc`` with the word entries of alternative ``idx``.

    Returns False as soon as a non-empty entry is *equal* to the text of the
    token at the same position, True otherwise.

    NOTE(review): the patterns in this notebook pass lists such as
    ['der', 'die', 'das'] as entries; a list never equals a token text, so
    with those patterns this check always returns True. It is unclear whether
    an allow-list (``text in entry``) or a deny-list was intended - confirm
    before changing.
    """
    allowed = True
    for position, entry in enumerate(words[idx][:len(doc)]):
        if entry != [] and entry == doc[position].text:
            allowed = False

    return allowed

def check_case(cases, idx,doc):
    """Return True if the tokens of ``doc`` have the grammatical cases of alternative ``idx``.

    Falls back to the first alternative when ``idx`` is out of range. An empty
    string accepts any token; positions beyond the length of ``doc`` are ignored.
    """
    wanted_cases = cases[0] if len(cases) <= idx else cases[idx]

    mismatches = [
        position
        for position, wanted in enumerate(wanted_cases[:len(doc)])
        if wanted != '' and doc[position].morph.to_dict().get('Case') != wanted
    ]
    return not mismatches

def check_num(nums, idx,doc):
    """Return True if the tokens of ``doc`` have the grammatical numbers of alternative ``idx``.

    Falls back to the first alternative when ``idx`` is out of range. An empty
    string accepts any token; positions beyond the length of ``doc`` are ignored.
    """
    if idx < len(nums):
        required = nums[idx]
    else:
        required = nums[0]

    def differs(position, wanted):
        return doc[position].morph.to_dict().get('Number') != wanted

    return not any(
        wanted != '' and differs(position, wanted)
        for position, wanted in enumerate(required[:len(doc)])
    )


def check_person(persons, idx,doc):
    """Return True if the tokens of ``doc`` have the grammatical persons of alternative ``idx``.

    Falls back to the first alternative when ``idx`` is out of range. Values
    are compared as strings, so ``3`` and ``'3'`` are equivalent. An empty
    string accepts any token; positions beyond the length of ``doc`` are ignored.
    """
    required = persons[0] if len(persons) <= idx else persons[idx]

    matches = True
    for token, wanted in zip(doc, required):
        if wanted == '':
            continue
        found = token.morph.to_dict().get('Person')
        if str(found) != str(wanted):
            matches = False
    return matches

def check_verbform(verbforms, idx,doc):
    """Return True unless a token of ``doc`` has a VerbForm that contradicts alternative ``idx``.

    An empty string accepts any token. Tokens without a 'VerbForm' entry in
    their morphology are accepted as well. Positions beyond the length of
    ``doc`` are ignored.
    """
    acceptable = True
    for token, wanted in zip(doc, verbforms[idx]):
        if wanted == '':
            continue
        morphology = token.morph.to_dict()
        if 'VerbForm' in morphology.keys() and morphology.get('VerbForm') != wanted:
            acceptable = False

    return acceptable

def check_all(conditions,max_idx,doc):
    """Find the first alternative for which every condition holds.

    Each condition is called as ``condition(idx, doc)`` for ``idx`` in
    ``range(max_idx)``.

    Returns:
        ``(True, idx)`` for the first alternative accepted by all conditions,
        otherwise ``(False, -1)``.
    """
    for candidate in range(max_idx):
        if all(condition(candidate, doc) for condition in conditions):
            return True, candidate

    return False, -1
        
    
def store_words(doc):
    """Collect text, POS tags, morphology and dependency labels of ``doc`` in a dictionary.

    Returns:
        ``{'text': ..., 'pos': [...], 'morph': [...], 'dep': [...]}`` where
        'text' is the token texts joined by single spaces and the lists hold
        one entry per token ('morph' entries are plain dictionaries).
    """
    joined_text = ' '.join([token.text for token in doc])

    return {
        'text': joined_text,
        'pos': [token.pos_ for token in doc],
        'morph': [token.morph.to_dict() for token in doc],
        'dep': [token.dep_ for token in doc],
    }
        

In [16]:
def get_pos_idx(pos_tags,n_gram):
    """Return the indices of all alternatives in ``pos_tags`` that equal ``n_gram['pos']``."""
    wanted = n_gram['pos']
    return [position for position, tags in enumerate(pos_tags) if wanted == tags]

            
def get_criteria_idx(criteria, n_gram):
    """Return the first alternative index selected by ``criteria`` for ``n_gram``.

    The first criterion yields the candidate indices; the indices returned by
    every further criterion are removed from them with a set difference.

    Args:
        criteria: non-empty list of callables, each mapping ``n_gram`` to a
            list of indices.
        n_gram: object handed to every criterion.

    Returns:
        The smallest remaining index, or 0 if no index is left.

    NOTE(review): the original called the non-existent ``np.settdif1d`` and
    therefore failed as soon as more than one criterion was passed. The typo
    is fixed to ``np.setdiff1d``; if the indices satisfying *all* criteria were
    meant, ``np.intersect1d`` would be the right call - confirm the intent.
    """
    poss_idx = np.asarray(criteria[0](n_gram))
    for criterium in criteria[1:]:
        poss_idx = np.setdiff1d(poss_idx,criterium(n_gram))

    try:
        return poss_idx[0]
    except IndexError:
        # no candidate left: fall back to the first alternative
        return 0
        

In [62]:
def get_cond_len(pattern):
    """Return, for every list-valued entry of ``pattern``, the lengths of its list items.

    Entries that are not lists are skipped; non-list items inside a list entry
    are skipped as well (such an entry can therefore yield an empty list).
    """
    return [
        [len(alternative) for alternative in value if type(alternative) == list]
        for value in pattern.values()
        if type(value) == list
    ]

def get_any(generated,prompt,pattern,check,lengths,min_return = 5):
    """Collect n-grams that satisfy ``check`` anywhere inside the generated samples.

    Every sample is analysed together with the prompt; only the trailing
    tokens belonging to the sample (counted by whitespace-separated words) are
    searched. For each length in ``lengths`` a window is moved over these
    tokens and kept if ``check`` accepts it and its length equals the longest
    condition of the matched alternative.

    Args:
        generated: iterable of generated text samples.
        prompt: text the samples continue.
        pattern: node dictionary (used for the condition lengths).
        check: callable mapping a span to ``(matches, pattern_idx)``.
        lengths: window sizes to try.
        min_return: the search stops once more than this many n-grams are found.

    Returns:
        List of matching spaCy spans.
    """
    found_words = []
    condition_length = get_cond_len(pattern)
    for sent in generated:
            sample = sent.strip()
            if not sample:
                # An empty sample would make the slice below `[-0:]`, i.e. the
                # whole prompt, and matches would be taken from the prompt.
                continue
            doc = nlp(prompt.strip() + ' ' + sample)[-len(sample.split()):]
            for length in lengths:
                for j in range(len(doc)):
                    n_gram = doc[j:j+length]
                    check_val, pattern_idx = check(n_gram)
                    if check_val:
                        # keep only complete matches, not prefixes of a longer alternative
                        if max([cond[pattern_idx] for cond in condition_length if len(cond) > pattern_idx]) == len(n_gram):
                            found_words.append(n_gram)

                        if len(found_words) > min_return:
                            return found_words

    return found_words
    
def get_first(generated,prompt,pattern,check,lengths, min_return = 5):
    """Collect n-grams that satisfy ``check`` at the start of the generated samples.

    Every sample is analysed together with the prompt; only the trailing
    tokens belonging to the sample (counted by whitespace-separated words) are
    considered. For each length in ``lengths`` the leading tokens are kept if
    ``check`` accepts them and their number equals the longest condition of
    the matched alternative.

    Args:
        generated: iterable of generated text samples.
        prompt: text the samples continue.
        pattern: node dictionary (used for the condition lengths).
        check: callable mapping a span to ``(matches, pattern_idx)``.
        lengths: n-gram sizes to try.
        min_return: the search stops once more than this many n-grams are found.

    Returns:
        List of matching spaCy spans.
    """
    found_words = []
    condition_length = get_cond_len(pattern)
    for sent in generated:
            sample = sent.strip()
            if not sample:
                # An empty sample would make the slice below `[-0:]`, i.e. the
                # whole prompt, and matches would be taken from the prompt.
                continue
            doc = nlp(prompt.strip() + ' ' + sample)[-len(sample.split()):]
            for length in lengths:
                n_gram = doc[:length]
                check_val, pattern_idx = check(n_gram)
                if check_val:
                    # keep only complete matches, not prefixes of a longer alternative
                    if max([cond[pattern_idx] for cond in condition_length if len(cond) > pattern_idx]) == len(n_gram):
                        found_words.append(n_gram)

                    if len(found_words) > min_return:
                        return found_words

    return found_words


def process_generative_pattern(pattern, prompt,found_lst):
    """Generate one word (group) for ``prompt`` that satisfies ``pattern``.

    The pattern's constraint entries ('pos', 'words', 'number', 'person',
    'gender', 'case', 'verbform') are turned into partial checker functions
    and combined through ``check_all``. Depending on ``pattern['any']`` the
    candidate is then searched in one of two ways:

    * truthy-equal to ``True``: sample continuations with ``gpt2_generate``,
      scan them with ``get_any`` and pick one match at random;
    * otherwise: use the first result of ``gpt_sample_systematic``.

    Args:
        pattern: dict describing the pattern. ``'pos'`` is required;
            ``'num_samples'`` defaults to 5.
        prompt: text generated so far.
        found_lst: words accepted so far. Not used here; when the pattern
            has a 'dependent' key the caller passes it through ``parse_dep``
            before calling this function. Kept for interface compatibility.

    Returns:
        ``(word, idx)`` where ``word`` is the result of ``store_words`` and
        ``idx`` the result of ``get_criteria_idx`` for it, or ``('', 0)``
        when no candidate was found.

    Raises:
        ValueError: if ``pattern`` has no ``'pos'`` entry. The number of
            'pos' alternatives is a mandatory argument of ``check_all``.
    """
    if 'pos' not in pattern:
        # Previously this surfaced later as an UnboundLocalError on max_idx.
        raise ValueError("generative pattern needs a 'pos' entry to define its alternatives")

    # Compared with == True on purpose: only True/1 enable the mode.
    anywhere = pattern.get('any') == True

    num_samples = pattern.get('num_samples', 5)
    # Sample at least 30 continuations, more if more results are requested.
    num_gpt_samples = max(30, num_samples)

    max_idx = len(pattern['pos'])
    conditions = [functools.partial(check_pos,pattern['pos'])]
    criteria = [functools.partial(get_pos_idx,pattern['pos'])]
    # Distinct n-gram sizes that occur among the 'pos' alternatives.
    lengths = list(set(len(item) for item in pattern['pos']))

    if 'words' in pattern:
        conditions.append(functools.partial(check_word,pattern['words']))

    if 'number' in pattern:
        conditions.append(functools.partial(check_num,pattern['number']))

    if 'person' in pattern:
        conditions.append(functools.partial(check_person,pattern['person']))

    if 'gender' in pattern:
        conditions.append(functools.partial(check_gender,pattern['gender']))

    if 'case' in pattern:
        conditions.append(functools.partial(check_case,pattern['case']))

    if 'verbform' in pattern:
        conditions.append(functools.partial(check_verbform,pattern['verbform']))

    check = functools.partial(check_all,conditions,max_idx)
    get_idx = functools.partial(get_criteria_idx,criteria)

    if anywhere:
        generated = gpt2_generate(prompt,num_return_sequences=num_gpt_samples)
        found_words = get_any(generated,prompt,pattern,check,lengths, min_return = num_samples)

        if not found_words:
            return '', 0
        word = random.choice(found_words)

    else:
        found_words = gpt_sample_systematic(prompt,check,pattern)
        if not found_words:
            return '', 0
        word = found_words[0]

    word = store_words(word)

    return word, get_idx(word)
    
        
                        

In [73]:
# Definite article for each value of the 'Gender' morph feature. Plural words
# get 'die' directly in generate_patterns and never consult this table.
article_dict = {'Fem':'die','Masc':'der','Neut':'das'}

def generate_patterns(prompt,patterns,loops,stop_pattern,print_loops=False):
    """Walk through ``patterns`` and grow the text one word (group) at a time.

    Starting at pattern index 0, each step either generates a word with
    ``process_generative_pattern`` (``pattern['type'] == 'generate'``) or
    picks one at random from the pattern's own word lists. The accepted word
    is appended to the text, or inserted before its last word when
    ``pattern['position'] == 'before'``. The very first accepted word
    replaces ``prompt`` entirely.

    The next pattern is ``random.choice(pattern['rand_next'])`` if present,
    else ``pattern['next']``; after a failed generation it is
    ``pattern['if_failed']``. The walk ends once ``stop_pattern`` has been
    selected as the next pattern ``loops`` times. There is no other exit, so
    the pattern graph must reach ``stop_pattern``.

    Args:
        prompt: seed text passed to the first pattern.
        patterns: sequence (or mapping) of pattern dicts, indexed by the
            values stored under 'next' / 'rand_next' / 'if_failed'.
        loops: number of times ``stop_pattern`` must be reached.
        stop_pattern: pattern index that counts as the end of one loop.
        print_loops: if true, print the text after every step.

    Returns:
        The generated text. If a generation fails and the pattern has no
        'if_failed' entry, a notice is printed and the text built so far is
        returned.
    """
    next_pattern = 0
    found_lst = []
    count = 0
    while count < loops:

        pattern = patterns[next_pattern]
        before = pattern.get('position') == 'before'

        if pattern['type'] == 'generate':
            if 'dependent' in pattern:
                pattern = parse_dep(pattern,found_lst)

            found_word, index = process_generative_pattern(pattern,prompt,found_lst)

            if not found_word:
                success = False
                if 'if_failed' not in pattern:
                    print(prompt)

                    print('''no completion found, please try again by pressing the play button
                    of this cell''')

                    return prompt
            else:
                success = True

            if 'add_det' in pattern and success:
                # add_det maps the matched alternative to the index of the
                # token whose morphology selects the article; -1 = no article.
                dependency = pattern['add_det'][index]
                if dependency > -1:
                    morph = found_word['morph'][dependency]

                    if morph['Number'] == 'Plur':
                        article = 'die'
                    else:
                        # Gender is only needed (and only read) for singular.
                        article = article_dict[morph['Gender']]

                    found_word['text'] = article + ' ' + found_word['text']
        else:
            success = True
            if 'dependent' in pattern:
                # Build the key of the word list from the morphology of the
                # last token of an earlier word, e.g. Number + ' ' + Person.
                query = ''
                dependency = pattern['dependency']
                if 'number' in pattern['dependent']:
                    query += found_lst[dependency]['morph'][-1]['Number']
                if 'person' in pattern['dependent']:
                    query += ' ' + str(found_lst[dependency]['morph'][-1]['Person'])
                if 'gender' in pattern['dependent']:
                    query += found_lst[dependency]['morph'][-1]['Gender']

                word = random.choice(pattern[query])

            else:
                word = random.choice(pattern['text'])

            # Parse the word in the context of the prompt, then keep only
            # the tokens that belong to the word itself.
            word_size = len(nlp(word))
            doc = nlp(prompt + ' ' + word)[-word_size:]
            found_word = store_words(doc)

        if success:
            new_text = found_word['text'].strip()
            if found_lst:                                 # not the first loop
                if before:
                    prompt_lst = prompt.split()
                    prompt = ' '.join(prompt_lst[:-1] + [new_text] + [prompt_lst[-1]])
                    found_lst = [found_word] + found_lst
                else:
                    prompt += ' ' + new_text
                    found_lst.append(found_word)
            else:
                # First accepted word: the seed prompt is dropped.
                prompt = new_text
                found_lst.append(found_word)

            if 'rand_next' in pattern:
                next_pattern = random.choice(pattern['rand_next'])
            else:
                next_pattern = pattern['next']
        else:
            next_pattern = pattern['if_failed']

        if next_pattern == stop_pattern:
            count += 1
        if print_loops:
            print(prompt)
    return prompt
    

In [21]:
# Seed prompt for generation. generate_patterns overwrites the prompt with the
# first word it accepts, so this seed is only in effect until then.
prompt = '<|endoftext|>.'

## Generating the pattern

In [87]:
# Run four passes through patterns_3 (a pass ends when pattern 3 is selected),
# printing the text after every step and the final text once more at the end.
number_patterns = 4
print(
    generate_patterns(
        prompt,
        patterns_3,
        loops=number_patterns,
        stop_pattern=3,
        print_loops=True,
    )
)

Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


die Gründe
die Gründe belegen
die Gründe belegen denn


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


die Gründe belegen denn die Erfahrung
die Gründe belegen denn die Erfahrung ruht
die Gründe belegen denn die Erfahrung ruht denn


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


die Gründe belegen denn die Erfahrung ruht denn das freyes
die Gründe belegen denn die Erfahrung ruht denn das freyes standhafte
die Gründe belegen denn die Erfahrung ruht denn das freyes standhafte denn


Setting `pad_token_id` to `eos_token_id`:50256 for open-end generation.


die Gründe belegen denn die Erfahrung ruht denn das freyes standhafte denn die Art
die Gründe belegen denn die Erfahrung ruht denn das freyes standhafte denn die Art neigt
die Gründe belegen denn die Erfahrung ruht denn das freyes standhafte denn die Art neigt


## The original scheme of Theo Lutz. 
Depending on the corpus the model was trained on, the sequence "ist" followed by an adjective is not very common, so you might need to run this cell a few times. Unfortunately, the Kant model will almost always return "absolute" as the adjective if the adjective is required to directly follow the "ist". That is why the option "any" is set to True, which in turn makes the results a bit random.

In [93]:
# Run five passes through patterns_lutz_org (a pass ends when pattern 4 is
# selected), printing the text after every step and the final text at the end.
number_patterns = 5
print(
    generate_patterns(
        prompt,
        patterns_lutz_org,
        loops=number_patterns,
        stop_pattern=4,
        print_loops=True,
    )
)

jede
jede Merkwürdigkeit
jede Merkwürdigkeit ist
jede Merkwürdigkeit ist besondere
jede Merkwürdigkeit ist besondere .
jede Merkwürdigkeit ist besondere . jede
jede Merkwürdigkeit ist besondere . jede Zueignung
jede Merkwürdigkeit ist besondere . jede Zueignung ist
jede Merkwürdigkeit ist besondere . jede Zueignung ist eigener
jede Merkwürdigkeit ist besondere . jede Zueignung ist eigener .
jede Merkwürdigkeit ist besondere . jede Zueignung ist eigener . nicht jede
jede Merkwürdigkeit ist besondere . jede Zueignung ist eigener . nicht jede Launigkeit
jede Merkwürdigkeit ist besondere . jede Zueignung ist eigener . nicht jede Launigkeit ist
jede Merkwürdigkeit ist besondere . jede Zueignung ist eigener . nicht jede Launigkeit ist begründeter
jede Merkwürdigkeit ist besondere . jede Zueignung ist eigener . nicht jede Launigkeit ist begründeter .
jede Merkwürdigkeit ist besondere . jede Zueignung ist eigener . nicht jede Launigkeit ist begründeter . nicht jede
jede Merkwürdigkeit ist beso

## Create multiple lines

In [None]:
# Number of passes through patterns_2 per printed line. The value is fixed;
# assign random.choice([2,3]) here instead to vary the line length.
loop = 2

# Print ten independently generated lines.
for k in range(10):
    print(generate_patterns(prompt,patterns_2,loop,stop_pattern = 3))