<a href="https://colab.research.google.com/github/ansfarooq7/l4-project/blob/main/prototypes/L4_Project_third.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Imports

In [45]:
!pip install transformers
!pip install pronouncing
!pip install wikipedia
!pip install syllables
!pip install aitextgen

from transformers import RobertaTokenizer, RobertaForMaskedLM, pipeline, GPT2Tokenizer, GPT2LMHeadModel
import torch
import pronouncing
import wikipedia
import re
import random
import nltk
import syllables
nltk.download('cmudict')

[nltk_data] Downloading package cmudict to /root/nltk_data...
[nltk_data]   Package cmudict is already up-to-date!


True

In [46]:
# Masked language model (RoBERTa) and its tokenizer; get_prediction uses both
# to fill in <mask> tokens.
masked_tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
masked_model = RobertaForMaskedLM.from_pretrained('roberta-base')

# Causal language model (GPT-2) and its tokenizer; the tokenizer is also what
# get_inputs_length uses to count the tokens of a prompt.
causal_tokenizer = GPT2Tokenizer.from_pretrained("gpt2")
# add the EOS token as PAD token to avoid warnings
causal_model = GPT2LMHeadModel.from_pretrained("gpt2", pad_token_id=causal_tokenizer.eos_token_id)

loading file https://huggingface.co/roberta-base/resolve/main/vocab.json from cache at /root/.cache/huggingface/transformers/d3ccdbfeb9aaa747ef20432d4976c32ee3fa69663b379deb253ccfce2bb1fdc5.d67d6b367eb24ab43b08ad55e014cf254076934f71d832bbab9ad35644a375ab
loading file https://huggingface.co/roberta-base/resolve/main/merges.txt from cache at /root/.cache/huggingface/transformers/cafdecc90fcab17011e12ac813dd574b4b3fea39da6dd817813efa010262ff3f.5d12962c5ee615a4c803841266e9c3be9a691a924f72d395d3a6c6c81157788b
loading file https://huggingface.co/roberta-base/resolve/main/added_tokens.json from cache at None
loading file https://huggingface.co/roberta-base/resolve/main/special_tokens_map.json from cache at None
loading file https://huggingface.co/roberta-base/resolve/main/tokenizer_config.json from cache at None
loading file https://huggingface.co/roberta-base/resolve/main/tokenizer.json from cache at /root/.cache/huggingface/transformers/d53fc0fa09b8342651efd4073d75e19617b3e51287c2a535becda5

## Helper functions

In [47]:
# Vocabulary of allowed words, filled from wordFrequency.txt further down in
# this cell; get_rhymes only returns words that are members of this set.
frequent_words = set()

def set_seed(seed: int):
    """
    Seed the random number generators this notebook relies on, for reproducible behavior.

    Seeds Python's built-in ``random`` module and ``torch`` (the CPU generator and
    every CUDA device generator).

    Args:
        seed (:obj:`int`): The seed to set.
    """
    random.seed(seed)
    torch.manual_seed(seed)
    # Safe to call even if CUDA is not available.
    torch.cuda.manual_seed_all(seed)
        
# Populate the vocabulary: one entry per line of the file, with surrounding
# whitespace (including the trailing newline) stripped.
with open("wordFrequency.txt", 'r') as f:
    frequent_words.update(entry.strip() for entry in f)

def filter_rhymes(word):
    """
    Tell whether ``word`` is acceptable as a line-ending word to rhyme against.

    Args:
        word: Candidate end word (compared case-sensitively).

    Returns:
        bool: ``False`` when ``word`` is on the exclusion list below, ``True`` otherwise.
    """
    excluded = (
        'to', 'on', 'has', 'but', 'the', 'in', 'and', 'a',
        'aitch', 'angst', 'arugula', 'beige', 'blitzed', 'boing', 'bombed',
        'cairn', 'chaos', 'chocolate', 'circle', 'circus', 'cleansed', 'coif',
        'cusp', 'doth', 'else', 'eth', 'fiends', 'film', 'flange', 'fourths',
        'grilse', 'gulf', 'kiln', 'loge', 'midst', 'month', 'music', 'neutron',
        'ninja', 'oblige', 'oink', 'opus', 'orange', 'pint', 'plagued',
        'plankton', 'plinth', 'poem', 'poet', 'purple', 'quaich', 'rhythm',
        'rouged', 'silver', 'siren', 'soldier', 'sylph', 'thesp', 'toilet',
        'torsk', 'tufts', 'waltzed', 'wasp', 'wharves', 'width', 'woman',
        'yttrium',
    )
    return word not in excluded

def remove_punctuation(text):
    """Return ``text`` with every character that is neither a word character nor whitespace removed."""
    return re.sub(r'[^\w\s]', '', text)

def get_rhymes(inp, level):
    """
    Find frequent words that rhyme with ``inp`` using the CMU Pronouncing Dictionary.

    A dictionary word counts as a rhyme when the last ``level`` phonemes of its
    pronunciation equal the last ``level`` phonemes of any pronunciation of ``inp``.

    Args:
        inp (str): Word to find rhymes for. It is lowercased before the lookup,
            so capitalised words (e.g. proper nouns ending a generated line)
            are matched too.
        level (int): Number of trailing phonemes that must match.

    Returns:
        set: Words from ``frequent_words`` that rhyme with ``inp``, excluding
        ``inp`` itself. Empty when ``inp`` is not in the dictionary.
    """
    # NLTK's cmudict reader yields lowercase words, so a case-sensitive
    # comparison would never match a capitalised input.
    target = inp.lower()
    entries = nltk.corpus.cmudict.entries()

    # A word can have several pronunciations; collect the ending of each one
    # (as tuples so they can live in a set).
    endings = {tuple(pron[-level:]) for word, pron in entries if word == target}
    if not endings:
        return set()

    # Single pass over the dictionary instead of one pass per pronunciation.
    return {
        word
        for word, pron in entries
        if word != target
        and word in frequent_words
        and tuple(pron[-level:]) in endings
    }

def get_inputs_length(input):
    """Return how many token ids the GPT-2 tokenizer produces for ``input``."""
    encoded = causal_tokenizer(input)
    return len(encoded['input_ids'])

## RoBERTa

In [48]:
# Seed the random number generators up front so repeated runs start from the same state.
set_seed(0)
    
def get_prediction(sent):
    """
    Fill every mask token in ``sent`` with RoBERTa's highest-ranked usable word.

    For each mask position the top-scoring vocabulary entries are decoded and
    filtered: candidates that are empty once punctuation is removed, or that
    are the ``</s>`` token, are discarded. The surviving candidates are printed
    and the first one is taken as the best guess.

    Args:
        sent (str): Sentence containing one or more mask tokens.

    Returns:
        str: The best guess for each mask, in order, each preceded by a single
        space (e.g. ``" 's the first"``). Empty string if ``sent`` has no mask.

    Raises:
        ValueError: If no usable candidate exists for a mask even after
            considering the whole vocabulary.
    """
    token_ids = masked_tokenizer.encode(sent, return_tensors='pt')
    masked_position = (token_ids.squeeze() == masked_tokenizer.mask_token_id).nonzero()
    masked_pos = [mask.item() for mask in masked_position]

    with torch.no_grad():
        output = masked_model(token_ids)

    # Per-position scores over the vocabulary (indices into it are decoded below).
    vocab_scores = output[0].squeeze()
    vocab_size = vocab_scores.shape[-1]

    list_of_list = []
    for index, mask_index in enumerate(masked_pos):
        mask_scores = vocab_scores[mask_index]
        k = 5
        while True:
            k = min(k, vocab_size)
            top_ids = torch.topk(mask_scores, k=k, dim=0)[1]
            words = []
            for token_id in top_ids:
                word = masked_tokenizer.decode(token_id.item()).strip()
                if (remove_punctuation(word) != "") and (word != '</s>'):
                    words.append(word)
            if words or k >= vocab_size:
                break
            # Every candidate was filtered out: widen the search. Retrying with
            # the same k would yield the same candidates and never terminate.
            k *= 2
        if not words:
            raise ValueError(f"No usable prediction for mask {index+1}")
        list_of_list.append(words)
        print(f"Mask {index+1} Guesses: {words}")

    # Callers rely on the leading space before each guess.
    return "".join(" " + guesses[0] for guesses in list_of_list)

In [49]:
# Demo: fill three blanks in a hand-written sentence using RoBERTa.
sentence = "Manchester United ___ ___ ___ team"
print(f"Original Sentence: {sentence}")
# Turn each blank into a mask token and make sure the sentence ends with a full stop.
sentence = sentence.replace("___", "<mask>")
if not sentence.endswith("."):
    sentence += "."
print(f"Original Sentence replaced with mask: {sentence}")
print("\n")

predicted_blanks = get_prediction(sentence)
print(f"\nBest guess for fill in the blanks: {predicted_blanks}")

Original Sentence: Manchester United ___ ___ ___ team
Original Sentence replaced with mask: Manchester United <mask> <mask> <mask> team.


Mask 1 Guesses: ["'s", 'vs', 'and', 'as']
Mask 2 Guesses: ['the', 'their', 'a', 'The']
Mask 3 Guesses: ['first', 'football', 'best', 'national', 'B']

Best guess for fill in the blanks:  's the first


## GPT-2 + RoBERTa

In [50]:
# Hugging Face text-generation pipeline wrapping the GPT-2 model and tokenizer loaded earlier.
# NOTE(review): not referenced by the functions visible in this section — confirm it is still needed.
text_generation = pipeline("text-generation", model=causal_model, tokenizer=causal_tokenizer)
# NOTE(review): mid-notebook import; conventionally this belongs in the imports cell at the top.
from aitextgen import aitextgen

# Without any parameters, aitextgen() will download, cache, and load the 124M GPT-2 "small" model
# `ai` is the generator that get_line and get_rhyming_line call to produce text.
ai = aitextgen()


def get_line(prompt, inputs_len):
    """
    Generate a short piece of text following ``prompt`` and return only the new part.

    Args:
        prompt (str): Text to condition the generator on; a "." is appended before generating.
        inputs_len (int): Token length of the prompt; generation is capped at
            ``inputs_len + 7``.

    Returns:
        The generated text with its first ``len(prompt) + 2`` characters dropped.
    """
    generated = ai.generate_one(prompt=prompt + ".", max_length=inputs_len + 7)
    # Assumes the generated text starts with the prompt: skip it, the appended
    # "." and the single character that follows.
    continuation_start = len(prompt) + 2
    return generated[continuation_start:]

def get_rhyming_line(prompt, rhyming_word, inputs_len):
    """
    Build a line that starts with generated text and ends with ``rhyming_word``.

    A short opening is generated from ``prompt``, three blanks are placed
    between that opening and ``rhyming_word``, and get_prediction fills the
    blanks. Progress is printed along the way.

    Args:
        prompt (str): Text to condition the generator on; a "." is appended before generating.
        rhyming_word (str): Word the finished line must end with.
        inputs_len (int): Token length of the prompt; generation is capped at
            ``inputs_len + 4``.

    Returns:
        str: The generated opening, the predicted fill-in words and ``rhyming_word``.
    """
    generated = ai.generate_one(prompt=prompt + ".", max_length=inputs_len + 4)
    # Assumes the generated text starts with the prompt: skip it, the appended
    # "." and the single character that follows.
    gpt2_sentence = generated[len(prompt) + 2:]
    print(f"\nGetting rhyming line starting with '{gpt2_sentence}' and ending with rhyming word '{rhyming_word}'")

    sentence = " ".join([gpt2_sentence, "___ ___ ___", rhyming_word])
    print(f"Original Sentence: {sentence}")

    # Turn each blank into a mask token and make sure the sentence ends with a full stop.
    sentence = sentence.replace("___", "<mask>")
    if not sentence.endswith("."):
        sentence += "."
    print(f"Original Sentence replaced with mask: {sentence}")
    print("\n")

    predicted_blanks = get_prediction(sentence)
    print(f"\nBest guess for fill in the blanks: {predicted_blanks}")

    # predicted_blanks already starts with a space, so no separator before it.
    final_sentence = "".join([gpt2_sentence, predicted_blanks, " ", rhyming_word])
    print(f"Final Sentence: {final_sentence}")
    return final_sentence

## Limerick generation

In [51]:
def generate(topic, num_limericks=1):
    """
    Generate limericks about ``topic`` from its Wikipedia summary.

    The punctuation-stripped summary is used as the prompt for every generated
    line. Lines 1 and 3 are generated freely and retried until their last word
    is allowed by filter_rhymes and has enough rhymes; lines 2 and 5 are built
    to end with rhymes of line 1, and line 4 with a rhyme of line 3.
    Progress is printed throughout.

    Args:
        topic (str): Topic passed to ``wikipedia.summary``.
        num_limericks (int): How many limericks to generate. Defaults to 1.

    Returns:
        str: A header line followed by every generated limerick.
    """
    limericks = []

    topic_summary = remove_punctuation(wikipedia.summary(topic))
    word_list = topic_summary.split()
    topic_summary_len = len(topic_summary)
    no_of_words = len(word_list)
    inputs_len = get_inputs_length(topic_summary)
    print(f"Topic Summary: {topic_summary}")
    print(f"Topic Summary Length: {topic_summary_len}")
    print(f"No of Words in Summary: {no_of_words}")
    print(f"Length of Input IDs: {inputs_len}")

    for i in range(num_limericks):
        print(f"\nGenerating limerick {i+1}")

        # Line 1: retry until the end word is allowed and has at least 3 rhymes.
        while True:
            first_line = get_line(topic_summary, inputs_len)
            # A whitespace-only line is truthy but has no words; treat it like
            # an empty one instead of indexing into an empty list.
            line_words = first_line.split() if first_line else []
            if not line_words:
                continue
            end_word = remove_punctuation(line_words[-1])
            if not filter_rhymes(end_word):
                continue
            print(f"\nFirst Line: {first_line}")
            rhyming_words_125 = list(get_rhymes(end_word, 3))
            print(f"Rhyming words for '{end_word}' are {rhyming_words_125}")
            if len(rhyming_words_125) >= 3:
                break
        limerick = first_line + "\n"

        rhyming_word = rhyming_words_125[0]
        second_line = get_rhyming_line(topic_summary, rhyming_word, inputs_len)
        print(f"\nSecond Line: {second_line}")
        limerick += second_line + "\n"

        # Line 3: retry until the end word is allowed and has at least 2 rhymes.
        while True:
            third_line = get_line(topic_summary, inputs_len)
            line_words = third_line.split() if third_line else []
            if not line_words:
                continue
            print(f"\nThird Line: {third_line}")
            end_word = remove_punctuation(line_words[-1])
            valid_rhyme = filter_rhymes(end_word)
            print(f"Does '{end_word}'' have valid rhymes: {valid_rhyme}")
            rhyming_words_34 = list(get_rhymes(end_word, 3))
            print(f"Rhyming words for '{end_word}' are {rhyming_words_34}")
            if valid_rhyme and len(rhyming_words_34) >= 2:
                break
        limerick += third_line + "\n"

        rhyming_word = rhyming_words_34[0]
        fourth_line = get_rhyming_line(topic_summary, rhyming_word, inputs_len)
        print(f"\nFourth Line: {fourth_line}")
        limerick += fourth_line + "\n"

        rhyming_word = rhyming_words_125[1]
        fifth_line = get_rhyming_line(topic_summary, rhyming_word, inputs_len)
        print(f"\nFifth Line: {fifth_line}")
        limerick += fifth_line + "\n"

        limericks.append(limerick)

    print("\n")
    output = f"Generated {len(limericks)} limericks: \n"

    print(output)
    for limerick in limericks:
        print(limerick)
        output += "\n" + limerick

    return output

In [58]:
# Ask for a topic interactively and generate from it. As the last expression
# of the cell, the string returned by generate is echoed as the cell output.
topic = input("Enter topic: ")
generate(topic)

Enter topic: manchester united
Topic Summary: Manchester United Football Club is a professional football club based in Old Trafford Greater Manchester England that competes in the Premier League the top flight of English football Nicknamed the Red Devils the club was founded as Newton Heath LYR Football Club in 1878 but changed its name to Manchester United in 1902 The club moved from Newton Heath to its current stadium Old Trafford in 1910
Manchester United have won the most trophies in English club football including a record 20 League titles 12 FA Cups five League Cups and a record 21 FA Community Shields They have won the European CupUEFA Champions League three times and the UEFA Europa League the UEFA Cup Winners Cup the UEFA Super Cup the Intercontinental Cup and the FIFA Club World Cup once each In 1968 under the management of Matt Busby 10 years after eight of the clubs players were killed in the Munich air disaster they became the first English club to win the European Cup Ale

"Generated 1 limericks: \n\nManchester United's current owners\nManchester United have released political prisoners\nThe Glazer family has been\nManchester United win a round robin\nFrom 2012 the government opened four centers\n"