## Imports

In [1]:
import numpy as np
import string
import sonnetTools as tools
from hmmlearn import hmm
import warnings
from sklearn.preprocessing import LabelEncoder

# Silence DeprecationWarnings caused by a scikit-learn bug that is fixed
# upstream but not yet in a released version.
warnings.filterwarnings("ignore", category=DeprecationWarning)

## Helper Functions

In [2]:
def generateLine(HMModel, encoder, BiDict, syllableDict, lineSyllableCountTarget=10):
    """Generate one line of words whose syllable count can equal the target.

    Words are drawn one at a time, each from its own ``HMModel.sample(1)``
    call. A drawn word is kept when at least one of its syllable counts keeps
    the running total at or below the target; the line ends as soon as some
    combination of counts hits the target exactly. A word that would push
    every possible total past the target is discarded and another is drawn.

    Parameters
    ----------
    HMModel : object
        Must provide ``sample(n)``; element 0 of the returned value is
        flattened to obtain the encoded word.
    encoder : object
        Must provide ``inverse_transform``; the first decoded element is used
        as the key into ``BiDict``.
    BiDict : mapping
        Maps the decoded value to the word string.
    syllableDict : mapping
        Maps a word string to an iterable of syllable-count strings. A count
        containing ``'E'`` (e.g. ``'E1'``) is only valid when the word is the
        last word of the line.
    lineSyllableCountTarget : int, optional
        Number of syllables the finished line must be able to have.

    Returns
    -------
    list of str
        The words of the generated line, in order.

    Notes
    -----
    There is no limit on the number of draws: the loop only exits once the
    target is reached, so it does not terminate if no drawn word sequence can
    ever sum to the target.
    """
    currentLine = []
    # Distinct syllable totals the accepted words could add up to so far.
    # Only membership matters, so a set is used; a list would accumulate
    # duplicates and grow multiplicatively with every multi-count word.
    reachableTotals = {0}
    while True:
        # Draw a candidate: sample -> decode -> look up the word string.
        nextWordEncoded = HMModel.sample(1)[0].flatten()
        nextWordNumeric = encoder.inverse_transform(nextWordEncoded)[0]
        nextWordString = BiDict[nextWordNumeric]

        # Combine every syllable option of the word with every total so far.
        extendedTotals = set()
        for wordSyllables in syllableDict[nextWordString]:
            endOnly = 'E' in wordSyllables
            count = int(wordSyllables.strip('E')) if endOnly else int(wordSyllables)
            for total in reachableTotals:
                candidate = total + count
                # An end-only count is usable solely when it finishes the line.
                if not endOnly or candidate == lineSyllableCountTarget:
                    extendedTotals.add(candidate)

        if lineSyllableCountTarget in extendedTotals:
            # Some reading of the words hits the target exactly: finish here.
            currentLine.append(nextWordString)
            return currentLine

        # Drop totals that already overshoot; they can never be completed.
        extendedTotals = {total for total in extendedTotals
                          if total <= lineSyllableCountTarget}
        if extendedTotals:
            currentLine.append(nextWordString)
            reachableTotals = extendedTotals
        # Otherwise the word is a dead end: discard it and draw again.

def generateSonnet(HMModel, encoder, BiDict, syllableDict):
    """Generate a sonnet as a list of 14 lines.

    Each line is produced by a separate ``generateLine`` call with the default
    syllable target, so every element of the result is a list of word strings.
    """
    return [generateLine(HMModel, encoder, BiDict, syllableDict) for _ in range(14)]

def writeSonnetsToFile(sonnets, filepath, header=None):
    """Write sonnets to a text file, one formatted line per sonnet line.

    Parameters
    ----------
    sonnets : iterable
        Each sonnet is a sequence of lines; each line is a sequence of word
        strings.
    filepath : str
        Destination path. The file is truncated if it already exists.
    header : str, optional
        When given, written first and followed by a blank line.

    Notes
    -----
    Lines at even positions are capitalized and end with a comma; lines at
    odd positions end with a period. A blank line follows every sonnet.
    """
    # The context manager guarantees the handle is flushed and closed, even
    # if a write fails part-way through.
    with open(filepath, 'w+') as f:
        if header is not None:
            f.write(header)
            f.write('\n\n')
        for sonnet in sonnets:
            for lineInd, line in enumerate(sonnet):
                strLine = ' '.join(line)
                if lineInd % 2 == 0:
                    # Even index lines are the start of sentences
                    strLine = strLine.capitalize() + ',\n'
                else:
                    # Odd index lines are the end of sentences
                    strLine = strLine + '.\n'
                f.write(strLine)
            f.write('\n')
    return
    
def formatSonnetToText(sonnet):
    """Render one sonnet (a sequence of word lists) as a single string.

    Lines at even positions open a sentence: they are capitalized and end
    with a comma. Lines at odd positions close it and end with a period.
    Every rendered line is terminated by a newline.
    """
    renderedLines = []
    for position, words in enumerate(sonnet):
        joined = ' '.join(words)
        if position % 2:
            # Odd position: closes the sentence.
            renderedLines.append(joined + '.\n')
        else:
            # Even position: opens the sentence.
            renderedLines.append(joined.capitalize() + ',\n')
    return ''.join(renderedLines)

def trainHMM(preparedSonnets, sonnetLengths, n_states):
    """Fit a ``hmm.MultinomialHMM`` with ``n_states`` hidden components.

    Parameters
    ----------
    preparedSonnets : array-like
        Observations handed to ``fit``; in this notebook, the encoded corpus
        reshaped to a single column.
    sonnetLengths : sequence of int
        Passed to ``fit`` as ``lengths``; in this notebook, one entry per
        line of the corpus.
    n_states : int
        Number of hidden components.

    Returns
    -------
    The fitted model object.
    """
    # NOTE(review): newer hmmlearn releases moved the categorical-emission
    # model to CategoricalHMM -- confirm which version this notebook targets.
    fittedModel = hmm.MultinomialHMM(n_components=n_states)
    fittedModel.fit(preparedSonnets, lengths=sonnetLengths)
    return fittedModel

## Load Data

In [3]:
# Both lookups below are read from the same syllable-dictionary file.
# syllableDict is indexed by word string in generateLine to get that word's
# syllable-count options.
syllableDict = tools.readInSyllableCounts('data/Syllable_dictionary.txt')
# BiDict is indexed by corpus words in the conversion cell and by decoded
# values in generateLine -- presumably a two-way word <-> id mapping; confirm
# in sonnetTools.
BiDict = tools.readInWords('data/Syllable_dictionary.txt')
# Later cells iterate this as sonnet -> line -> word.
sonnets = tools.readInSonnets('data/shakespeare.txt', syllableDict)

## Convert sonnets to a form compatible with the HMM library

In [4]:
# Replace every word with its BiDict entry, keeping the sonnet -> line -> word
# nesting.
# NOTE: this rebinds `sonnets`, so the cell is not idempotent -- running it a
# second time applies the BiDict lookup to the already-converted values.
sonnets = [[[BiDict[word] for word in line] for line in sonnet] for sonnet in sonnets]

In [5]:
# Number of words in each line, grouped per sonnet.
lineLengths = [[len(line) for line in sonnet] for sonnet in sonnets]
# One length per line across the whole corpus; this is what trainHMM receives
# as the sequence lengths, so each line is treated as its own sequence.
flattenedLengths = [length for sonnet in lineLengths for length in sonnet]
# Every word value of the corpus, in reading order, as one flat list.
flattenedSonnets = [word for sonnet in sonnets for line in sonnet for word in line]

In [6]:
# Learn the label set and encode the flat corpus in one step; the fitted
# encoder is reused later to decode sampled values.
encoder = LabelEncoder()
encodedSonnets = encoder.fit_transform(flattenedSonnets)
# Reshape to a single column (one observation per row) for the HMM fit.
hmmPreparedSonnets = encodedSonnets.reshape(-1, 1)

## Train HMMs

In [7]:
# Hidden-state counts to compare; one model is trained per entry and stored
# at the same index in trainedModels.
stateCounts = [2, 4, 8, 16, 32, 64]
trainedModels = []
for n_states in stateCounts:
    # Append as soon as each fit finishes so models that are already trained
    # survive an interrupted run.
    trainedModels.append(trainHMM(hmmPreparedSonnets, flattenedLengths, n_states))

## Create and Save Sonnets

In [8]:
n_sonnets = 10  # Sonnets to make per model
for ind, n_states in enumerate(stateCounts):
    # trainedModels is index-aligned with stateCounts.
    model = trainedModels[ind]
    title = "Naive Hidden Markov Model: %d hidden states" % n_states
    filetarget = "naiveHMM-%02d-states.txt" % n_states
    # NOTE: this rebinds `sonnets`, which held the training corpus in the
    # earlier cells; re-run the loading cells before re-running those.
    sonnets = [generateSonnet(model, encoder, BiDict, syllableDict) for _ in range(n_sonnets)]
    writeSonnetsToFile(sonnets, filetarget, header=title)