In [1]:
import torch
from transformers import GPT2Tokenizer, GPTNeoForCausalLM, GPTNeoModel
from transformers import GPT2Tokenizer, GPT2LMHeadModel
import pronouncing
from transformers import Trainer, TrainingArguments
from tqdm import tqdm
import random

import warnings
warnings.filterwarnings('ignore')

#Download Finetuned GPT-Neo
# Set the random seed to a fixed value to get reproducible results 
torch.manual_seed(42)
tokenizer = GPT2Tokenizer.from_pretrained("EleutherAI/gpt-neo-1.3B", 
                                          bos_token="<|startoftext|>",
                            eos_token="<|endoftext|>",
                            pad_token="<|pad|>")

# Download the pre-trained GPT-Neo model and transfer it to the GPU
model = GPTNeoForCausalLM.from_pretrained("FigoMe/news-gpt-neo-1.3B-keywords-line-by-line-reverse").cuda()
# Resize the token embeddings because we've just added 3 new tokens 
model.resize_token_embeddings(len(tokenizer))

def get_stress(phone):
    stress = []
    for s in phone.split():
        if s[-1].isdigit():
            if s[-1] == '2':
                stress.append(0)
            else:
                stress.append(int(s[-1]))
    return stress

def alternating(stress):
    #Check if the stress and unstress are alternating
    check1 = len(set(stress[::2])) <= 1 and (len(set(stress[1::2])) <= 1)
    check2 = len(set(stress)) == 2 if len(stress) >=2 else True
    return (check1 and check2)

def get_phones(rhyme_word):
    phone = pronouncing.phones_for_word(rhyme_word)[0]
    stress = get_stress(phone)
    p_state = stress[0]
    n_syllables = len(stress)
    return p_state, n_syllables

from torch import Tensor
from torch.nn import functional as F


def top_k_top_p_filtering(
    logits: Tensor,
    top_k: int = 0,
    top_p: float = 1.0,
    filter_value: float = -float("Inf"),
    min_tokens_to_keep: int = 1,
    return_index = False
) -> Tensor:
    """Filter a distribution of logits using top-k and/or nucleus (top-p) filtering
    Args:
        logits: logits distribution shape (batch size, vocabulary size)
        if top_k > 0: keep only top k tokens with highest probability (top-k filtering).
        if top_p < 1.0: keep the top tokens with cumulative probability >= top_p (nucleus filtering).
            Nucleus filtering is described in Holtzman et al. (http://arxiv.org/abs/1904.09751)
        Make sure we keep at least min_tokens_to_keep per batch example in the output
    From: https://gist.github.com/thomwolf/1a5a29f6962089e871b94cbd09daf317
    """
    if top_k > 0:
        top_k = min(max(top_k, min_tokens_to_keep), logits.size(-1))  # Safety check
        # Remove all tokens with a probability less than the last token of the top-k
        indices_to_remove = logits < torch.topk(logits, top_k)[0][..., -1, None]
        indices_keep = logits >= torch.topk(logits, top_k)[0][..., -1, None]
        indices_keep = indices_keep[0].tolist()
        indices_keep = [i for i,x in enumerate(indices_keep) if x == True]
        logits[indices_to_remove] = filter_value

    if top_p < 1.0:
        sorted_logits, sorted_indices = torch.sort(logits, descending=True)
        cumulative_probs = torch.cumsum(F.softmax(sorted_logits, dim=-1), dim=-1)

        # Remove tokens with cumulative probability above the threshold (token with 0 are kept)
        sorted_indices_to_remove = cumulative_probs > top_p
        if min_tokens_to_keep > 1:
            # Keep at least min_tokens_to_keep (set to min_tokens_to_keep-1 because we add the first one below)
            sorted_indices_to_remove[..., :min_tokens_to_keep] = 0
        # Shift the indices to the right to keep also the first token above the threshold
        sorted_indices_to_remove[..., 1:] = sorted_indices_to_remove[..., :-1].clone()
        sorted_indices_to_remove[..., 0] = 0

        # scatter sorted tensors to original indexing
        indices_to_remove = sorted_indices_to_remove.scatter(-1, sorted_indices, sorted_indices_to_remove)
        logits[indices_to_remove] = filter_value
    if return_index == True:
        return logits, indices_keep
    return logits


def reverse_order(line):
    line = line.replace(', ', ' , ')
    words = line.split()
    return ' '.join(reversed(words)).replace(' , ', ', ')


Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [239]:
loose_list = ['that','is','of','the','it','a','as','with','like','go','to','on','in','at','are','and']
def check_either_stress(stress, source_word, loose = True):
    if loose and source_word in loose_list:
        return True
    if len(stress) == 1 and len(pronouncing.phones_for_word(source_word))>1:
                    phone0 = pronouncing.phones_for_word(source_word)[0]
                    phone1 = pronouncing.phones_for_word(source_word)[1]
                    stress0 = [int(s[-1]) for s in phone0.split() if s[-1].isdigit()]
                    stress1 = [int(s[-1]) for s in phone1.split() if s[-1].isdigit()]
                    if stress0+stress1 ==1 and stress0*stress1 == 0:
                        return True

    return False

In [62]:
from transformers import AutoTokenizer, AutoModelForCausalLM
gpt2_tokenizer  = AutoTokenizer.from_pretrained('gpt2-large')
gpt2_model = AutoModelForCausalLM.from_pretrained('gpt2-large')
gpt2_model = gpt2_model.to(device)
gpt2_model.eval()

In [69]:
def regularBeamSearch(prompts):
	'''
	Beam search that considers the coherence by adding a new variable: previously_generated_lines
	'''
	BeamScorer = {}
	for sentence in prompts:
		loss = score_gpt2(sentence)
		BeamScorer[sentence] = [loss]
	answers = sorted(BeamScorer.items(), key=lambda x: x[1], reverse=False)
	new_prompts = [ans[0] for ans in answers]
	return new_prompts

In [67]:
def score_gpt2(sentence, normalize = True):
	'''
	Score a single sentence using the vanilla gpt2 model finetuned on lyrics
	The default setting is to normalize because we won't face the issue mentioned in function "score".
	'''
	tokens_tensor = gpt2_tokenizer.encode(sentence, add_special_tokens=False, return_tensors="pt")[0].cuda()
	with torch.no_grad():
		loss = gpt2_model(tokens_tensor, labels=tokens_tensor)[0]
	if normalize:
		return loss/len(tokens_tensor)
	else:
		return loss

In [30]:
def myBeamSearch(prompts, all_states, all_n_sys, all_keywords, beam_size = 5):
    BeamScorer = {}
    return_seq, return_stt, return_sys, return_key = [], [], [], []
    for sentence, p_state, n_sys, keywords in zip(prompts, all_states, all_n_sys, all_keywords):
        loss = score(sentence)
        BeamScorer[sentence] = [loss, p_state, n_sys, keywords]
    answers = sorted(BeamScorer.items(), key=lambda x: x[1], reverse=True)
    new_prompts = [ans[0] for ans in answers]
    new_p_states = [ans[1][1] for ans in answers]
    new_n_sys = [ans[1][2] for ans in answers]
    new_keywords = [ans[1][3] for ans in answers]
    l = len(new_prompts)
    if l > beam_size:
        return_seq += new_prompts[0:beam_size]
        return_stt += new_p_states[0:beam_size]
        return_sys += new_n_sys[0:beam_size]
        return_key += new_keywords[0:beam_size]
    else:
        return_seq +=new_prompts
        return_stt += new_p_states
        return_sys += new_n_sys
        return_key += new_keywords
    return return_seq,return_stt, return_sys, return_key


In [188]:
def generate_next_word(input_ids1, temperature = 0.85, topk = 100, n_sample=10, device = 'cuda:0'):
    current_word = 0
    original = tokenizer.decode(input_ids1[0])
    for _ in range(1):
        outputs1 = model(input_ids1)
        #print(outputs1)
        next_token_logits1 = outputs1[0][:, -1, :]
        next_token_logits1 = top_k_top_p_filtering(next_token_logits1, top_k=topk)
        logit_zeros = torch.zeros(len(next_token_logits1)).cuda()
        #logit_zeros = torch.zeros(len(next_token_logits1), device=device)

        next_token_logits = next_token_logits1 * temperature
        probs = F.softmax(next_token_logits, dim=-1)
        next_tokens = torch.multinomial(probs, num_samples=n_sample).squeeze(1)
        #unfinished_sents = torch.ones(1, dtype=torch.long, device=device)
        unfinished_sents = torch.ones(1, dtype=torch.long).cuda()
        tokens_to_add = next_tokens * unfinished_sents + tokenizer.pad_token_id * (1 - unfinished_sents)

        temp = []
        for i in range(len(input_ids1)):
            temp +=[torch.cat([input_ids1[i].reshape(1,-1), token_to_add.reshape(1,-1)], dim=-1) for token_to_add in tokens_to_add[i]]
        input_ids1 = torch.stack(temp).view(len(temp),-1)
        # decode the generated token ids to natural words
        results = []
        input_ids1_l = []
        for input_id1 in input_ids1:
            gen = tokenizer.decode(input_id1).replace(original,'').strip(' ')
            if len(gen.split()) >0:
                gen = gen.split()[0]
                gen = gen.lower()
                if gen not in results:
                    results.append(gen)
        return results
        '''
        if tokenizer.decode(tokens_to_add[0])[0] == ' ':
            if current_word ==1:
                return tokenizer.decode(input_ids1[0]).split()[-1], False
            current_word += 1
        input_ids1 = torch.cat([input_ids1, tokens_to_add.unsqueeze(-1)], dim=-1)
        '''

In [3]:
def score(sentence, normalize = True):
	'''
	Score a single sentence using the plan-to-lyrics model.
	The recommended setting is to NOT normalize, because the input sentence is very long: it contains the title, planed keywords, and previously generated lines. 
	In addition, the candidate sentences contain the same prefix (i.e., the title, planed keywords, and previously generated lines) and only differ in the currently generated line.
	Normaling means dividing the loss by a large factor which may result in similarity accross different candidate sentences.
	'''
	tokens_tensor = tokenizer.encode(sentence, add_special_tokens=False, return_tensors="pt")[0].cuda()
	with torch.no_grad():
		loss = model(tokens_tensor, labels=tokens_tensor)[0]
	if normalize:
		return loss/len(tokens_tensor)
	else:
		return loss

In [244]:
single_character_word = ['i','a']
forbidden_words = ['dona','er','ira','ia',"'s","'m","hmm","mm"]
def get_valid_samples(prompt, p_state, n_syllables, keywords, n_sample=30, n_cands=5):
    #if n_syllables == 10 or n_syllables==11:
    if n_syllables == 10:
        return [prompt], [p_state], [n_syllables], [keywords]
    elif n_syllables > 10:
        return [], [], [],[]
    states = []
    all_n_syl = []
    
    prompts = []
    all_keywords= [] 
    #insert the keyword whenever possible
    for source_word in keywords:
        phone = pronouncing.phones_for_word(source_word)[0]
        stress = get_stress(phone)
        if not alternating(stress):
            continue

        #if the word is single syllable and can be either stressed or unstressed, flag = True
        flag = check_either_stress(stress, source_word)

        if stress[-1] == 1- p_state or flag:
            #print(source_word)
            states.append(stress[0])
            all_n_syl.append(n_syllables+len(stress))
            prompts.append(prompt+ ' ' + source_word )
            copy = keywords.copy()
            copy.remove(source_word)
            all_keywords.append(copy)    
    
    #The normal process of decoding
    input_ids = tokenizer(prompt, return_tensors='pt').input_ids.cuda()
    tokens = generate_next_word(input_ids, n_sample=n_sample)
    #print(tokens)
    for token in tokens:
        token = token.lower()
        if (len(token) == 1 and token not in single_character_word) or token in forbidden_words:
            continue
        if token not in prompt:
            try:
                phone = pronouncing.phones_for_word(token)[0]
                stress = get_stress(phone)
            except:
                continue
            if (not alternating(stress)) or (len(stress)==0):
                continue

            #if the word is single syllable and can be either stressed or unstressed, flag = True
            flag = check_either_stress(stress, token)

            if (stress[-1] == 1- p_state) or flag:
                tokens.append(token)
                if stress[-1] == 1- p_state:
                    states.append(stress[0])
                elif flag:
                    states.append(1- p_state)
                all_n_syl.append(n_syllables+len(stress))
                prompts.append(prompt+ ' ' + token )
                all_keywords.append(keywords)
                if len(prompts)>= n_cands:
                    return prompts, states, all_n_syl, all_keywords
    return prompts, states, all_n_syl, all_keywords

In [2]:
four_seasons_story_line = [
['snow', 'falling', 'future'],
['winter', 'is', 'coming'],
['gather', 'honest', 'humor'],
['spring', 'happy', 'blooming'],
['air', 'heat', 'warm'],
['little', 'birds', 'may'],
['flowers', 'leaves', 'storm'],
['summer','moon', 'day'],
['blue', 'sky', 'clouds'],
['sudden', 'rain', 'thunder'],
['Summer', 'fill', 'crowds'],
['Spring', 'no', 'wonder'],
['seasons','years', 'keep'],
['future', 'months', 'reap']]


In [245]:
example_title = 'Four Seasons'
beam_size=20
previous = ''
enforce_keywords = False
for kws in tqdm(four_seasons_story_line):
    success=False
    n_sample = 30
    while success != True:
        print(kws)
        rhyme_word = kws[-1]
        prefix =  '''Keywords: ''' + '; '.join(kws) +'. Sentence in reverse order: '
        prompt = '''<|startoftext|> Title: ''' + example_title + ' ' + previous + prefix + rhyme_word
        p_state, n_syllables = get_phones(rhyme_word)
        result_list = []
        i=0
        prompts, all_states, all_n_sys, all_keywords = get_valid_samples(prompt,p_state, n_syllables, keywords = kws[:2], n_sample=n_sample,n_cands=5)
        while i<7:
            print(i)
            new_prompts, new_states, new_n_sys, new_keywords = [], [], [], []
            for prompt, p_state, n_syllables, keyword in zip(prompts, all_states, all_n_sys, all_keywords):
                t_p, t_state, t_sys, t_keywords = get_valid_samples(prompt, p_state, n_syllables, keyword,n_sample=n_sample)
                new_prompts+=t_p
                new_states+=t_state
                new_n_sys+=t_sys
                new_keywords+=t_keywords
            prompts, all_states, all_n_sys, all_keywords = new_prompts, new_states, new_n_sys, new_keywords

            prompts, all_states, all_n_sys, all_keywords = myBeamSearch(prompts,all_states, all_n_sys, all_keywords, beam_size=beam_size)
            i += 1
        correct_prompts = [reverse_order(p.split('order: ')[1]) for p in prompts]
        result_list = regularBeamSearch(correct_prompts)
        print(result_list)
        if len(result_list)!=0:
            success=True
            found = False
            if enforce_keywords:
                for r in result_list:
                    if kws[0] in r and kws[1] in r:
                        previous = previous + r + ','
                        found = True
                        break
            if found == False:
                for r in result_list:
                    if kws[0] in r or kws[1] in r:
                        previous = previous + r + ','
                        found = True
                        break
            if found == False:
                previous = previous + result_list[0]+','
                n_sample = n_sample*3


  0%|          | 0/14 [00:00<?, ?it/s]

['snow', 'falling', 'future']
0
1
2
3
4
5
6


  7%|▋         | 1/14 [00:06<01:27,  6.69s/it]

['means that falling parents distant future', 'falling melting snow uncertain future', 'clearly falling even better future', 'is and falling something better future', 'snow the falling parents distant future', 'old and falling parents distant future', 'falling soon remote uncertain future', 'game itself is falling better future', 'air itself is falling better future', 'night itself is falling better future', 'sun itself is falling better future', 'school and falling parents distant future', 'cold and falling parents distant future', 'from that falling parents distant future', 'saw that falling parents distant future', 'see that falling parents distant future', 'news and falling parents distant future', 'always falling something better future', 'likely falling parents distant future', 'never falling parents distant future']
['winter', 'is', 'coming']
0
1
2
3
4
5
6


 14%|█▍        | 2/14 [00:14<01:24,  7.07s/it]

['gather', 'honest', 'humor']
0
1
2
3
4
5
6


 21%|██▏       | 3/14 [00:18<01:06,  6.06s/it]

['others gather longer having humor', 'gather never longer having humor', 'honest never longer having humor', 'gather honest always having humor', 'gather seasons longer having humor', 'gather honest longer having humor', 'honest seasons longer having humor', 'honest gather longer having humor', 'many seasons longer having humor', 'being honest longer having humor', 'never honest longer having humor', 'children gather longer having humor', 'always gather honest having humor', 'often gather honest having humor', 'only gather honest having humor', 'better gather longer having humor', 'even gather honest rather humor', 'better gather honest rather humor']
['spring', 'happy', 'blooming']
0
1
2
3
4
5
6


 29%|██▊       | 4/14 [00:24<00:59,  5.94s/it]

['several thousand flowers happy blooming', 'cheerful yellow garden always blooming', 'buzzing head like flowers happy blooming', 'spring like really happy flower blooming', 'spring with happy lovely garden blooming', 'truly really happy flower blooming', 'met with thousand flowers happy blooming', 'colors really rather happy blooming', 'town with happy lovely garden blooming', 'lovely garden rather happy blooming', 'planting season rather happy blooming', 'seven thousand flowers happy blooming', 'always really rather happy blooming', 'happy country lovely garden blooming', 'into country lovely garden blooming', 'spring already lovely garden blooming', 'into happy lovely garden blooming', 'given season very happy blooming', 'season really rather happy blooming', 'empty garden rather happy blooming']
['air', 'heat', 'warm']
0
1
2
3
4
5
6


 36%|███▌      | 5/14 [00:29<00:51,  5.68s/it]

['especially after slowly getting warm', 'with children rated over keeping warm', 'with heat refreshing climate keeping warm', 'already bearing over getting warm', 'like heat eternal purpose keeping warm']
['little', 'birds', 'may']
0
1
2
3
4


 43%|████▎     | 6/14 [00:34<00:41,  5.25s/it]

5
6
['believe however little many may', 'however comfort little many may']
['flowers', 'leaves', 'storm']
0
1
2
3
4
5
6


 50%|█████     | 7/14 [00:39<00:37,  5.37s/it]

['recycled metal pillar cater storm', 'considered normal after flowers storm', 'like flowers metal pillar cater storm', 'like lonely flowers pillar cater storm']
['summer', 'moon', 'day']
0
1
2
3
4
5
6
[]
['summer', 'moon', 'day']
0
1
2
3
4
5
6
[]
['summer', 'moon', 'day']
0
1
2
3
4


 57%|█████▋    | 8/14 [00:50<00:43,  7.17s/it]

5
6
['extremely soothing summer mood like day', 'another humid summer summers day']
['blue', 'sky', 'clouds']
0
1
2
3


 64%|██████▍   | 9/14 [00:56<00:32,  6.55s/it]

4
5
6
['creating new expecting yellow clouds', 'expected blue expecting yellow clouds']
['sudden', 'rain', 'thunder']
0
1
2
3
4
5
6


 71%|███████▏  | 10/14 [00:57<00:20,  5.09s/it]

['lightning started rather sudden thunder', 'almost always pretty sudden thunder', 'maybe only really sudden thunder', 'seeing something really sudden thunder', 'notice something really sudden thunder', 'started only really sudden thunder', 'started almost very sudden thunder', 'rather something really sudden thunder', 'started rather pretty sudden thunder']
['Summer', 'fill', 'crowds']
0
1
2
3


 79%|███████▊  | 11/14 [01:03<00:15,  5.33s/it]

4
5
6
['already closing seasons Summer crowds']
['Spring', 'no', 'wonder']
0
1
2
3
4
5
6


 86%|████████▌ | 12/14 [01:05<00:08,  4.33s/it]

['nearly always very peaceful wonder', 'truly no beginnings only wonder', 'weather always very peaceful wonder', 'very oddly even peaceful wonder', 'always no beginnings only wonder', 'ended oddly even peaceful wonder', 'into Spring beginnings only wonder']
['seasons', 'years', 'keep']
0
1
2
3


 93%|█████████▎| 13/14 [01:10<00:04,  4.37s/it]

4
5
6
['remaining years conditions therefore keep', 'although exact conditions therefore keep']
['future', 'months', 'reap']
0
1
2
3


100%|██████████| 14/14 [01:14<00:00,  5.29s/it]

4
5
6
['intentions threaten even children reap']





In [236]:
print('Not Enforce keywords:')

print(previous.replace(',',',\n'))

Not Enforce keywords:
warmer sun and we falling better future,
clearly mean winter weather always coming,
trying being honest see is humor,
very happy was already blooming,
humid air amazing power staying warm,
sneaky little squirrel maybe even may,
refreshing autumn flowers chilly storm,
summer moon enjoying any special day,
citrus orange sky remember seeing clouds,
notice something really sudden thunder,
peaceful Summer county over quiet crowds,
tourist market farmers also wonder,
seasons old exciting novel title keep,
thrilling news awaited future only reap,



In [246]:
print('Not Enforce keywords:')

print(previous.replace(',',',\n'))

Not Enforce keywords:
means that falling parents distant future,
winter colder air and also coming,
others gather longer having humor,
several thousand flowers happy blooming,
with heat refreshing climate keeping warm,
believe however little many may,
considered normal after flowers storm,
extremely soothing summer mood like day,
expected blue expecting yellow clouds,
lightning started rather sudden thunder,
already closing seasons Summer crowds,
truly no beginnings only wonder,
remaining years conditions therefore keep,
intentions threaten even children reap,



In [220]:
print('Enforce keywords:')

print(previous.replace(',',',\n'))

Enforce keywords:
snow the falling new uncertain future,
there already is and always coming,
gather being honest any humor,
spring amazing looking happy blooming,
brutal winter heat surprises rather warm,
singing many other merry little may,
flowers leaves discovered how unwanted storm,
refreshing sunny early summer day,
colors sky expected fluffy yellow clouds,
rain another sudden rolling thunder,
follows Summer fill expecting festive crowds,
really no whats truly peaceful wonder,
busy your entire seasons parcel keep,
future months potential harvest only reap,



In [197]:
print(previous.replace(',',',\n'))

snow and own falling dark endless future,
follows is been something always coming,
elders never gather honest humor,
sunny spring already happy blooming,



In [189]:
print(previous.replace(',',',\n'))

snow the never falling better future,
winters truly is and really coming,
gather honest very little humor,
happy early spring already blooming,
hunting season almost always rather warm,
lonely fellow other pretty little may,
purple flowers yellow leaves impending storm,
summer harvest moon another special day,
crimson thunder bearing over angry clouds,
sudden heavy rain approaching thunder,
colors Summer fill erupted into crowds,
excitement yet beginnings only wonder,
seasons recent years successful title keep,
many months exciting looking future reap,



In [177]:
print(previous.replace(',',',\n'))

surely snow the falling better future,
winters really is and always coming,
helping people gather honest humor,
happy pretty yellow flowers blooming,
hunting season also slowly getting warm,
started singing even fluffy little may,
purple flowers autumn leaves impending storm,
summer moon another very special day,
sweater only barely hiding any clouds,
sudden heavy rain approaching thunder,
Summer fill already vibrant city crowds,
lonely local corner never wonder,

