In [1]:
import numpy as np
import matplotlib.pyplot as plt
import itertools
import nltk
from nltk.corpus import cmudict
nltk.download('cmudict')
import pickle
from HMM_sol import *

[nltk_data] Error loading cmudict: <urlopen error [SSL:
[nltk_data]     CERTIFICATE_VERIFY_FAILED] certificate verify failed
[nltk_data]     (_ssl.c:749)>


In [2]:
# Load in pickled preprocessed data.
# Each file is opened in a `with` block so the handle is closed even if
# unpickling raises.
# NOTE(security): pickle.load can execute arbitrary code while loading; only
# point these paths at files produced by this project's own preprocessing.

with open('../data/sonnets.pkl', 'rb') as file:
    sonnets = pickle.load(file)

with open('../data/shakespeare_dics.pkl', 'rb') as file:
    dic_to_ids, dic_to_syl, ids_to_dic, syl_to_dic = pickle.load(file)

with open('../data/dic_to_meter.pkl', 'rb') as file:
    dic_to_meter = pickle.load(file)

with open('../data/meter_to_dic.pkl', 'rb') as file:
    meter_to_dic = pickle.load(file)

with open('../data/shakespeare_rhymes.pkl', 'rb') as file:
    rhyme_data_raw = pickle.load(file)

# Update rhyming word set to only include words that rhyme, but also end on a stressed syllable
# (these will go at the end of lines).
# A word is kept when the last element of its dic_to_meter entry converts to 1.
temp_rhymes = []

for rhyme_set in rhyme_data_raw:
    temp_rhyme_set = []
    for word in rhyme_set:
        if int(dic_to_meter[word][-1]) == 1:
            temp_rhyme_set.append(word)
    # One (possibly empty) filtered list is kept per original rhyme set, so
    # positions in rhyme_data_raw are preserved.
    temp_rhymes.append(temp_rhyme_set)

rhyme_data_raw = temp_rhymes


In [3]:
def _word_fits(word, syllables_left, stress_pattern, dic_to_syl, dic_to_meter):
    """Return True if `word` can be placed next in the line being generated.

    The word must not exceed `syllables_left` syllables and, when
    `stress_pattern` is non-empty, each of its meter marks must equal the
    corresponding leading entry of `stress_pattern`.
    """
    if dic_to_syl[word] > syllables_left:
        return False
    if len(stress_pattern) == 0:
        return True

    word_stress = dic_to_meter[word].split(', ')
    # A meter entry longer than the remaining pattern cannot match it; reject
    # the word instead of indexing past the end of stress_pattern.
    if len(word_stress) > len(stress_pattern):
        return False
    return all(int(actual) == int(wanted)
               for actual, wanted in zip(word_stress, stress_pattern))


def generate_emission_sequential(num_syl, stress_pattern, dic_to_syl, dic_to_ids, dic_to_meter, A, O,
                                 id_to_word=None):
    """Sample one line of word ids whose syllable counts sum to `num_syl`.

    Parameters
    ----------
    num_syl : int
        Number of syllables the line must contain.
    stress_pattern : sequence
        Required stress mark for each syllable of the line; an empty sequence
        disables the stress constraint. The caller's object is not modified.
    dic_to_syl : mapping
        word -> syllable count.
    dic_to_ids : mapping
        word -> id stored in the returned emission.
    dic_to_meter : mapping
        word -> string of stress marks separated by ', '.
    A : sequence of sequences
        A[i] is used as the probability vector for the state following state i.
    O : sequence of sequences
        O[i] is used as the probability vector over word ids emitted by state i.
    id_to_word : mapping, optional
        word id -> word. When omitted, the notebook-level `ids_to_dic` is used,
        which is what this function always relied on.

    Returns
    -------
    (emission, states) : tuple of lists
        The emitted word ids and the state each one was emitted from.

    Raises
    ------
    ValueError
        If the current state gives zero probability to every word that fits
        the remaining syllables and stress. Rejection sampling would never
        terminate in that situation.
    """
    if id_to_word is None:
        id_to_word = ids_to_dic

    emission = []
    states = []
    syllable_count = 0

    # choose starting state uniformly at random
    y_i = np.random.choice(len(O))

    while syllable_count < num_syl:
        states.append(y_i)
        syllables_left = num_syl - syllable_count

        # Keep the emission probability of every word that fits and zero out
        # the rest. Drawing from the renormalised weights gives the same
        # distribution as redrawing until a fitting word appears, without an
        # unbounded retry loop. Zero-probability words are never looked up.
        weights = [
            prob if prob > 0 and _word_fits(id_to_word[word_id], syllables_left, stress_pattern,
                                            dic_to_syl, dic_to_meter) else 0.0
            for word_id, prob in enumerate(O[y_i])
        ]
        total = sum(weights)
        if total <= 0:
            raise ValueError(
                'state {} cannot emit any word that fits the remaining {} syllable(s) '
                'and stress pattern {}'.format(y_i, syllables_left, list(stress_pattern)))

        observation_index = np.random.choice(len(weights), p=[w / total for w in weights])
        word = id_to_word[observation_index]

        syllable_count += dic_to_syl[word]
        emission.append(dic_to_ids[word])
        # Drop the stress marks consumed by this word (rebinds the local only).
        stress_pattern = stress_pattern[dic_to_syl[word]:]

        # change states
        y_i = np.random.choice(len(A[y_i]), p=A[y_i])

    return emission, states


In [4]:
def generate_haiku_test(test_HMM, stress_pattern, syllable_pattern, hmm_name):
    """Generate a poem from `test_HMM` and write it to a sample text file.

    One line is generated per entry of `stress_pattern`, using the syllable
    count at the same position in `syllable_pattern`. Each stress entry must
    be empty or have exactly as many marks as the line has syllables.

    The file '../models/general/<hmm_name>_sample.txt' receives three sections
    separated by blank lines: the poem lines, the per-word syllable counts of
    each line followed by their sum, and the per-word meter entries of each
    line.

    Relies on the notebook-level lookup tables dic_to_syl, dic_to_ids,
    dic_to_meter and ids_to_dic. Returns None.
    """
    assert len(stress_pattern) == len(syllable_pattern)

    poem = []
    for position, line_stress in enumerate(stress_pattern):
        line_syllables = syllable_pattern[position]

        assert (len(line_stress) == 0) or (line_syllables == len(line_stress))

        # Only the emitted word ids are needed; the hidden states are discarded.
        emission, _ = generate_emission_sequential(
            line_syllables, line_stress, dic_to_syl, dic_to_ids, dic_to_meter,
            test_HMM.A, test_HMM.O)
        poem.append(' '.join(str(ids_to_dic[word_id]) for word_id in emission))

    sample_path = '../models/general/{}_sample.txt'.format(hmm_name)
    with open(sample_path, 'w+') as f:
        # Section 1: the poem itself, one line per row.
        for text in poem:
            f.write(text + "\n")
        f.write("\n")

        # Section 2: syllable count of every word, then the line total.
        for text in poem:
            counts = [dic_to_syl[token] for token in text.split()]
            f.write("{} {}".format(counts, sum(counts)) + "\n")
        f.write("\n")

        # Section 3: meter entry of every word.
        for text in poem:
            meters = [dic_to_meter[token] for token in text.split()]
            f.write(str(meters) + "\n")


In [5]:
# Grid of (hidden states, training iterations) naming the saved models to sample from
hidden_state_args = [30]
num_iter_args = [200]

# Haiku (no stress constraint) -- swap in for the block below to use it
# stress_pattern = [[], [], []]
# syllable_pattern = [5, 7, 5]
# name = 'haiku'

# Iambic Haiku: 5-7-5 syllables, alternating 0/1 stress marks starting on 0
stress_pattern = [[0, 1, 0, 1, 0], [0, 1, 0, 1, 0, 1, 0], [0, 1, 0, 1, 0]]
syllable_pattern = [5, 7, 5]
name = 'iambichaiku'

# Visit every combination in the same order as nested loops would
# (hidden_state outer, num_iter inner).
for hidden_state, num_iter in itertools.product(hidden_state_args, num_iter_args):
    model_path = '../models/hmm_hpt_{}_{}.hmm'.format(hidden_state, num_iter)
    sample_name = '{}_hidden_{}_iter_{}'.format(name, hidden_state, num_iter)
    with open(model_path, 'rb') as f:
        test_HMM = pickle.load(f)
        generate_haiku_test(test_HMM, stress_pattern, syllable_pattern, sample_name)