# Wslatts Miniproject 3: Poem Generation

### Download Data

In [None]:
import random
import re
import string

import matplotlib.pyplot as plt
import numpy as np
import requests
import requests
import torch
import torch.nn as nn
import torch.optim as optim

In [None]:
# Maps a local filename (download_file saves to this name) -> URL it is fetched from.
url_dict = {
    'shakespeare.txt': 'https://caltech-cs155.s3.us-east-2.amazonaws.com/miniprojects/project3/data/shakespeare.txt',
    'spenser.txt': 'https://caltech-cs155.s3.us-east-2.amazonaws.com/miniprojects/project3/data/spenser.txt',
    'syllable_dict.txt' : 'https://caltech-cs155.s3.us-east-2.amazonaws.com/miniprojects/project3/data/Syllable_dictionary.txt',
    'about_syllable_dict.docx' : 'https://caltech-cs155.s3.us-east-2.amazonaws.com/miniprojects/project3/data/syllable_dict_explanation.docx'
}

def download_file(file_path, chunk_size=1024 * 1024, timeout=60):
    '''
    Download the file registered under `file_path` in `url_dict` and save it
    locally under the same name.

    Args:
        file_path:  Key into `url_dict`; also used as the destination path.
        chunk_size: Bytes per streamed chunk (1 MiB by default). The previous
                    hard-coded 1 GiB chunk could hold an entire response in
                    memory, which defeats the point of `stream=True`.
        timeout:    Seconds to wait on connect/read before giving up. Without
                    a timeout, `requests` can block indefinitely.

    Raises:
        KeyError:              `file_path` is not in `url_dict`.
        requests.HTTPError:    The server returned an error status.
    '''
    url = url_dict[file_path]
    print('Start downloading...')
    with requests.get(url, stream=True, timeout=timeout) as r:
        r.raise_for_status()
        with open(file_path, 'wb') as f:
            for chunk in r.iter_content(chunk_size=chunk_size):
                f.write(chunk)
    print('Complete')

# Fetch the two poem corpora, the syllable dictionary, and its format description.
download_file('shakespeare.txt')
download_file('spenser.txt')
download_file('syllable_dict.txt')
download_file('about_syllable_dict.docx')

Start downloading...
Complete
Start downloading...
Complete
Start downloading...
Complete
Start downloading...
Complete


### Preprocessing

In [None]:
def load_sonnet_words(filename):
    '''
    Parse a sonnet corpus into nested word lists.

    Sonnets are separated by blank lines ("\\n\\n"). Each line is lowercased and
    stripped of punctuation, except apostrophes and hyphens (so "feed'st" and
    "self-substantial" survive). Lines with fewer than two words are dropped;
    this removes the sonnet-number header lines. Sonnets left with no lines
    are dropped too.

    Returns:
        list[list[list[str]]]: sonnets -> lines -> words.
    '''
    with open(filename, "r", encoding="utf-8") as handle:
        raw_text = handle.read()

    removable = string.punctuation.replace("'", "").replace("-", "")
    strip_punct = str.maketrans("", "", removable)

    def _tokenize(block):
        rows = (ln.lower().translate(strip_punct).split() for ln in block.split("\n"))
        return [row for row in rows if len(row) > 1]

    parsed = (_tokenize(block) for block in raw_text.strip().split("\n\n"))
    return [sonnet for sonnet in parsed if sonnet]

def load_syllable_dict(filename):
    '''
    Read the syllable dictionary as raw text.

    Each non-empty line is split at its first space into a word (lowercased)
    and the remainder, which is stored unparsed (e.g. "E2 3"). Lines without
    a space are ignored.

    Returns:
        dict[str, str]: word -> raw syllable-count string.
    '''
    entries = {}
    with open(filename, "r", encoding="utf-8") as handle:
        for raw in handle:
            stripped = raw.strip()
            if not stripped:
                continue
            head, sep, tail = stripped.partition(" ")
            if sep:
                entries[head.lower()] = tail
    return entries

def load_syllable_dict_end_sep(filename):
    '''
    Read the syllable dictionary, separating ordinary and line-final counts.

    Each line has the form "word c1 c2 ...". A count prefixed with "E"
    (e.g. "E2") applies only when the word ends a line. All other counts are
    ordinary counts. Repeated words accumulate their counts. Blank lines and
    lines that list a word without any counts are skipped. Previously the
    latter raised IndexError; this matches load_syllable_dict, which also
    ignores such lines.

    Returns:
        dict[str, dict[str, list[int]]]: word (lowercased) ->
            {"normal": [counts...], "end": [end-of-line counts...]}.
    '''
    syllable_dict = {}
    with open(filename, "r", encoding="utf-8") as file:
        for line in file:
            # e.g. "acquainted E2 3" -> word="acquainted", counts="E2 3"
            word, _, counts = line.strip().partition(" ")
            tokens = counts.split()
            if not tokens:
                # Blank line, or a word with no syllable information.
                continue

            entry = syllable_dict.setdefault(word.lower(), {"normal": [], "end": []})
            for token in tokens:
                if token.startswith("E"):
                    entry["end"].append(int(token[1:]))
                else:
                    entry["normal"].append(int(token))

    return syllable_dict

def convert_sonnets_to_syllables(sonnets, syllable_dict):
    '''
    Replace every word in `sonnets` by its value in `syllable_dict`.

    Words missing from the dictionary map to 0. The nesting
    (sonnets -> lines -> words) is preserved.
    '''
    def _count(word):
        return syllable_dict.get(word, 0)

    return [[list(map(_count, line)) for line in sonnet] for sonnet in sonnets]

def convert_sonnets_to_syllables_smart(sonnets, syllable_dict):
    '''
    Assign one syllable count to every word, aware of line-final counts.

    `syllable_dict` maps word -> {"normal": [...], "end": [...]}. Unknown words
    count as 0 syllables.
      - Last word of a line: the first "end" count if any, else the first
        "normal" count.
      - Other words with several "normal" options: the option equal to
        10 minus the syllables used so far in the line, if present;
        otherwise the first option.

    Returns:
        list[list[list[int]]]: same nesting as `sonnets`, with counts
        instead of words.
    '''
    fallback = {"normal": [0], "end": [0]}

    def _pick(options, is_last, used_so_far):
        if is_last:
            return options["end"][0] if options["end"] else options["normal"][0]
        normal = options["normal"]
        target = 10 - used_so_far
        if len(normal) > 1 and target in normal:
            return target
        return normal[0]

    result = []
    for sonnet in sonnets:
        sonnet_counts = []
        for line in sonnet:
            last_pos = len(line) - 1
            running = 0
            counts = []
            for pos, word in enumerate(line):
                chosen = _pick(syllable_dict.get(word, fallback), pos == last_pos, running)
                running += chosen
                counts.append(chosen)
            sonnet_counts.append(counts)
        result.append(sonnet_counts)
    return result

In [None]:
# Parse the Shakespeare corpus into sonnets -> lines -> lowercase word tokens.
sonnets = load_sonnet_words("shakespeare.txt")
print('sonnets: ', sonnets)
print('number of sonnets: ', len(sonnets))

number of sonnets:  154


In [None]:
# Syllable dictionary with ordinary ("normal") and line-final ("end", E-prefixed) counts kept apart.
syllable_dict_with_end = load_syllable_dict_end_sep("syllable_dict.txt")
print('syllable dict: ', syllable_dict_with_end)
print('syllable dict words: ', syllable_dict_with_end.keys())
print('# of words in syllable dict: ', len(syllable_dict_with_end))

# of words in syllable dict:  3205


In [None]:
# Raw dictionary: word -> unparsed count string (e.g. "E2 3").
syllable_dict = load_syllable_dict("syllable_dict.txt")
# Per-word syllable counts for every corpus line, using the end-aware dictionary.
sonnet_syllables = convert_sonnets_to_syllables_smart(sonnets, syllable_dict_with_end)
print('sonnet syllables: ', sonnet_syllables)
print('sonnet syllables len: ', len(sonnet_syllables))

sonnet syllables:  [[[1, 2, 2, 1, 2, 2], [1, 2, 2, 1, 1, 2, 1], [1, 1, 1, 2, 1, 1, 1, 2], [1, 2, 1, 1, 1, 1, 3], [1, 1, 3, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 4, 1], [2, 1, 2, 1, 3, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 3], [1, 2, 2, 1, 1, 2, 1], [2, 1, 1, 1, 2, 1, 2], [1, 2, 1, 1, 1, 1, 3], [2, 1, 1, 1, 1, 1, 2, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], [[1, 2, 2, 1, 2, 1, 1], [1, 1, 1, 2, 1, 1, 2, 1], [1, 1, 1, 3, 1, 1, 1, 1], [1, 1, 1, 2, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 2, 1], [1, 1, 1, 2, 1, 1, 2, 1], [1, 1, 2, 1, 1, 1, 2, 1], [1, 1, 3, 1, 1, 2, 1], [1, 1, 1, 1, 2, 1, 2, 1], [1, 1, 1, 2, 0, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 0], [2, 1, 2, 1, 3, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1]], [[1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 1, 1, 1, 1, 1, 1, 2], [1, 1, 2, 1, 1, 1, 1, 2], [1, 1, 2, 1, 1, 2, 1, 1], [1, 1, 1, 1, 1, 1, 1, 2, 1], [2, 1, 2, 1, 1, 3], [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], [1, 1, 2, 1, 1, 4], [1, 1, 1, 2, 1, 1, 1, 1, 1], [1, 1, 1, 2, 

In [None]:
def get_rhyme_scheme(sonnet):
    '''
    Label each line of `sonnet` with a rhyme letter.

    Two lines are treated as rhyming when the last two characters of their
    final words match (the whole word if it is shorter). Letters are handed
    out in order of first appearance: 'a', 'b', 'c', ...

    Args:
        sonnet: list of lines, each a non-empty list of words (str).

    Returns:
        list[str]: one single-character label per line. Join with ''.join
        for a string such as "ababcdcd...".
    '''
    letter_for_ending = {}
    scheme = []
    for line in sonnet:
        ending = line[-1][-2:]
        if ending not in letter_for_ending:
            letter_for_ending[ending] = chr(ord('a') + len(letter_for_ending))
        scheme.append(letter_for_ending[ending])
    return scheme

def get_syllable_counts(sonnet, syllable_dict=None):
    '''
    Return the syllable count of each line in `sonnet`.

    The last word of a line uses its first "end" count when it has one.
    Every other word uses its first "normal" count. Words not in the
    dictionary count as 1 syllable, and a warning is printed for each.

    Args:
        sonnet:        list of lines, each a list of words (str).
        syllable_dict: word -> {"normal": [...], "end": [...]}. Defaults to
                       the module-level `syllable_dict_with_end`, which was
                       previously a hidden hard-coded dependency.

    Returns:
        list[int]: one total per line.
    '''
    if syllable_dict is None:
        syllable_dict = syllable_dict_with_end

    syllable_counts = []
    for line in sonnet:
        last_pos = len(line) - 1
        line_syllables = 0
        for pos, word in enumerate(line):
            entry = syllable_dict.get(word)
            if entry is None:
                line_syllables += 1
                print(f"Warning: '{word}' not found in syllable dictionary")
            elif pos == last_pos and entry['end']:
                line_syllables += entry['end'][0]
            else:
                line_syllables += entry['normal'][0]
        syllable_counts.append(line_syllables)

    return syllable_counts

def sonnet_summary(sonnet):
    '''
    Print a human-readable report on `sonnet`.

    The report contains the text of the sonnet followed by a blank line, the
    line count, the word and syllable counts of each line (syllables come
    from get_syllable_counts), and the rhyme scheme from get_rhyme_scheme.

    Args:
        sonnet: a list of lines, each a list of words (str).
    '''
    print('Sonnet: ')
    for words in sonnet:
        print(" ".join(words))
    print()

    print(f"Number of lines: {len(sonnet)}")

    per_line_counts = get_syllable_counts(sonnet)
    for line_no, (words, n_syllables) in enumerate(zip(sonnet, per_line_counts), start=1):
        print(f"Line {line_no}: number of words = {len(words)}, number of syllables = {n_syllables}")

    print("Rhyme scheme:", ''.join(get_rhyme_scheme(sonnet)))

In [None]:
sonnet_summary(sonnets[0])  # sanity check on the first training sonnet

Sonnet: 
from fairest creatures we desire increase
that thereby beauty's rose might never die
but as the riper should by time decease
his tender heir might bear his memory
but thou contracted to thine own bright eyes
feed'st thy light's flame with self-substantial fuel
making a famine where abundance lies
thy self thy foe to thy sweet self too cruel
thou that art now the world's fresh ornament
and only herald to the gaudy spring
within thine own bud buriest thy content
and tender churl mak'st waste in niggarding
pity the world or else this glutton be
to eat the world's due by the grave and thee

Number of lines: 14
Line 1: number of words = 6, number of syllables = 10
Line 2: number of words = 7, number of syllables = 10
Line 3: number of words = 8, number of syllables = 10
Line 4: number of words = 7, number of syllables = 10
Line 5: number of words = 8, number of syllables = 10
Line 6: number of words = 7, number of syllables = 10
Line 7: number of words = 6, number of syllables = 10

# Implementing Models

## HMM Code

In [None]:
class HiddenMarkovModel:
    '''
    Class implementation of Hidden Markov Models.
    '''

    def __init__(self, A, O):
        '''
        Initializes an HMM. Assumes the following:
            - States and observations are integers starting from 0.
            - There is a start state (see notes on A_start below). There
              is no integer associated with the start state, only
              probabilities in the vector A_start.
            - There is no end state.
        Arguments:
            A:          Transition matrix with dimensions L x L.
                        The (i, j)^th element is the probability of
                        transitioning from state i to state j. Note that
                        this does not include the starting probabilities.
            O:          Observation matrix with dimensions L x D.
                        The (i, j)^th element is the probability of
                        emitting observation j given state i.
        Parameters:
            L:          Number of states.

            D:          Number of observations.

            A:          The transition matrix.

            O:          The observation matrix.

            A_start:    Starting transition probabilities. The i^th element
                        is the probability of transitioning from the start
                        state to state i. For simplicity, we assume that
                        this distribution is uniform.
        '''

        self.L = len(A)
        self.D = len(O[0])
        self.A = A
        self.O = O
        self.A_start = [1. / self.L for _ in range(self.L)]

    @staticmethod
    def _normalized(vec):
        '''
        Return `vec` rescaled to sum to 1, as a list. An all-zero vector is
        returned unchanged rather than turned into NaNs by a 0/0 division.
        '''
        total = np.sum(vec)
        if total == 0:
            return list(vec)
        return [v / total for v in vec]

    def forward(self, x, normalize=False):
        '''
        Uses the forward algorithm to calculate the alpha probability
        vectors corresponding to a given input sequence.
        Arguments:
            x:          Input sequence in the form of a list of length M,
                        consisting of integers ranging from 0 to D - 1.
            normalize:  Whether to normalize each set of alpha_j(i) vectors
                        at each i (every row 1..M, including the first).
                        This is useful to avoid underflow in unsupervised
                        learning.
        Returns:
            alphas:     Vector of alphas.
                        The (i, j)^th element of alphas is alpha_j(i),
                        i.e. the probability of observing prefix x^1:i
                        and state y^i = j.
                        e.g. alphas[1][0] corresponds to the probability
                        of observing x^1:1, i.e. the first observation,
                        given that y^1 = 0, i.e. the first state is 0.
                        Row 0 is unused and stays all zeros. For an empty
                        sequence, only that row is returned.
        '''

        M = len(x)      # Length of sequence.
        alphas = [[0. for _ in range(self.L)] for _ in range(M + 1)]
        if M == 0:
            return alphas

        for curr_state in range(self.L):
            alphas[1][curr_state] = self.A_start[curr_state] * self.O[curr_state][x[0]]
        if normalize:
            # The original skipped this row, contradicting the docstring.
            alphas[1] = self._normalized(alphas[1])

        for d in range(2, M + 1):
            obs = x[d - 1]
            for curr_state in range(self.L):
                incoming = 0.
                for prev_state in range(self.L):
                    incoming += alphas[d - 1][prev_state] * self.A[prev_state][curr_state]
                alphas[d][curr_state] = self.O[curr_state][obs] * incoming

            if normalize:
                alphas[d] = self._normalized(alphas[d])

        return alphas

    def backward(self, x, normalize=False):
        '''
        Uses the backward algorithm to calculate the beta probability
        vectors corresponding to a given input sequence.
        Arguments:
            x:          Input sequence in the form of a list of length M,
                        consisting of integers ranging from 0 to D - 1.
            normalize:  Whether to normalize each set of beta_j(i) vectors
                        at each i < M. This is useful to avoid underflow in
                        unsupervised learning.
        Returns:
            betas:      Vector of betas.
                        The (i, j)^th element of betas is beta_j(i), i.e.
                        the probability of observing suffix x^(i+1):M given
                        state y^i = j.
                        e.g. betas[M][0] corresponds to the probability
                        of observing x^M+1:M, i.e. no observations,
                        given that y^M = 0, i.e. the last state is 0.
                        Row 0 uses the start distribution A_start in place
                        of A.
        '''

        M = len(x)      # Length of sequence.
        betas = [[0. for _ in range(self.L)] for _ in range(M + 1)]

        for i in range(self.L):
            betas[M][i] = 1

        for d in range(M - 1, -1, -1):
            obs = x[d]
            for curr_state in range(self.L):
                prob = 0.
                for next_state in range(self.L):
                    # Row 0 transitions out of the (integer-less) start state.
                    trans = self.A_start[next_state] if d == 0 else self.A[curr_state][next_state]
                    prob += betas[d + 1][next_state] * trans * self.O[next_state][obs]
                betas[d][curr_state] = prob

            if normalize:
                betas[d] = self._normalized(betas[d])

        return betas

    def unsupervised_learning(self, X, N_iters):
        '''
        Trains the HMM using the Baum-Welch algorithm on an unlabeled
        datset X. Note that this method does not return anything, but
        instead updates the attributes of the HMM object.
        Arguments:
            X:          A dataset consisting of input sequences in the form
                        of variable-length lists, consisting of integers
                        ranging from 0 to D - 1. In other words, a list of
                        lists.
            N_iters:    The number of iterations to train on.
        '''

        for _ in range(N_iters):
            A_numer = np.zeros((self.L, self.L))
            A_denom = np.zeros((self.L, self.L))
            O_numer = np.zeros((self.L, self.D))
            O_denom = np.zeros((self.L, self.D))

            for x in X:
                M = len(x)
                alphas = self.forward(x, normalize=True)
                betas = self.backward(x, normalize=True)

                # gamma[s]: P(y^d = s | x), normalized per position d.
                for d in range(1, M + 1):
                    gamma = np.array([alphas[d][s] * betas[d][s] for s in range(self.L)])
                    gamma /= np.sum(gamma)

                    for s in range(self.L):
                        O_numer[s][x[d - 1]] += gamma[s]
                        O_denom[s] += gamma[s]          # whole row: same denominator for every observation
                        if d != M:
                            A_denom[s] += gamma[s]      # last position has no outgoing transition

                # xi[s][t]: P(y^d = s, y^(d+1) = t | x), normalized per position d.
                for d in range(1, M):
                    xi = np.array([[alphas[d][s]
                                    * self.O[t][x[d]]
                                    * self.A[s][t]
                                    * betas[d + 1][t]
                                    for t in range(self.L)]
                                   for s in range(self.L)])
                    xi /= np.sum(xi)
                    A_numer += xi

            self.A = A_numer / A_denom
            self.O = O_numer / O_denom

    def generate_emission(self, M, seed=None):
        '''
        Generates an emission of length M, assuming that the starting state
        is chosen uniformly at random.
        Arguments:
            M:          Length of the emission to generate.
            seed:       If given, draws come from a fresh
                        np.random.default_rng(seed), so the result is
                        reproducible for that seed. Previously the seed was
                        silently ignored. If None, the global np.random
                        state is used, so np.random.seed(...) still controls
                        the output as before.
        Returns:
            emission:   The randomly generated emission as a list.
            states:     The randomly generated states as a list.
        '''

        if seed is None:
            state = np.random.randint(0, self.L)
            choice = np.random.choice
        else:
            rng = np.random.default_rng(seed=seed)
            state = int(rng.integers(0, self.L))
            choice = rng.choice

        observation_values = list(range(self.D))
        state_values = list(range(self.L))
        emission = []
        states = []

        for _ in range(M):
            emission.append(choice(observation_values, p=self.O[state]))
            states.append(state)
            # Also drawn after the final emission, which keeps the global-RNG
            # draw sequence identical to earlier versions.
            state = choice(state_values, p=self.A[state])

        return emission, states

    def generate_emission_reverse(self, M, last_observation, seed=None):
        '''
        Generates an emission of length M in reverse, starting with a given last observation.

        Arguments:
            M:              Length of the emission to generate (including the last observation).
            last_observation: The last observation (word) in the sequence.
            seed:           Random seed for reproducibility.

        Returns:
            emission:       The randomly generated emission as a list.
            states:         The randomly generated states as a list.
        '''
        rng = np.random.default_rng(seed=seed)

        emission = [last_observation]
        states = []

        # Reversed transition matrix P(y_prev = j | y = i), via Bayes' rule with
        # A_start as the prior over the previous state (an approximation).
        A = np.asarray(self.A, dtype=float)
        prior = np.asarray(self.A_start, dtype=float)
        joint = A * prior[:, None]                      # joint[j][i] = prior[j] * A[j][i]
        A_reverse = (joint / joint.sum(axis=0, keepdims=True)).T

        # Choose the final state in proportion to how likely it emits last_observation.
        state_probs = np.array([self.O[i][last_observation] for i in range(self.L)])
        if sum(state_probs) > 0:
            state_probs = state_probs / sum(state_probs)
            state = rng.choice(self.L, p=state_probs)
        else:
            state = rng.integers(0, self.L)  # Random state if observation impossible

        states.append(state)

        # Walk backwards: previous state, then an observation from it.
        for _ in range(1, M):
            state = rng.choice(self.L, p=A_reverse[state])
            states.append(state)

            obs = rng.choice(self.D, p=self.O[state])
            emission.append(obs)

        # Built back-to-front; flip into chronological order.
        emission.reverse()
        states.reverse()

        return emission, states

In [None]:
# Vocabulary: every distinct corpus word, indexed in sorted order.
word_to_idx = {word: i for i, word in enumerate(sorted(set(word for sonnet in sonnets for line in sonnet for word in line)))}
idx_to_word = {i: word for word, i in word_to_idx.items()}

# HMM training data: one observation sequence (word indices) per sonnet line.
X = [[word_to_idx[word] for word in line] for sonnet in sonnets for line in sonnet]

def unsupervised_HMM(X, n_states, N_iters, seed=None):
    '''
    Build a randomly initialised HMM and fit it to X with Baum-Welch.

    Args:
        X:        list of observation sequences (lists of ints).
        n_states: number of hidden states.
        N_iters:  Baum-Welch iterations.
        seed:     if given, seeds the global np.random before initialisation.

    The observation count is taken as the largest index in X plus one. The
    transition matrix is drawn first, then the emission matrix; each row is
    uniform random, normalised to sum to 1.
    '''
    if seed is not None:
        np.random.seed(seed)

    n_observations = 1 + max(max(seq) for seq in X)

    def _random_stochastic(rows, cols):
        mat = np.random.random((rows, cols))
        return mat / mat.sum(axis=1, keepdims=True)

    transition = _random_stochastic(n_states, n_states)
    emission = _random_stochastic(n_states, n_observations)

    model = HiddenMarkovModel(transition, emission)
    model.unsupervised_learning(X, N_iters)
    return model

In [None]:
def generate_sonnet(hmm_model, idx_to_word, syllable_dict, num_lines=14, seed=None, syllables_per_line=10):
    '''
    Generate a sonnet by sampling one long HMM emission and cutting it into lines.

    Words are appended to the current line until its syllable total reaches
    `syllables_per_line` or more. The total counts the newest word with its
    end-of-line count and all earlier words with their normal counts. A line
    may therefore overshoot the target. Words missing from `syllable_dict`
    count as 1 syllable.

    Args:
        hmm_model:          object with generate_emission(M, seed=...) returning
                            (word indices, states).
        idx_to_word:        word index -> word.
        syllable_dict:      word -> {"normal": [...], "end": [...]}.
        num_lines:          maximum number of lines to produce.
        seed:               seeds np.random and is passed to generate_emission.
        syllables_per_line: target syllables per line (default 10).

    Returns:
        list[list[str]]: completed lines. A trailing partial line is dropped.
        There are fewer than `num_lines` lines if the emission runs out.
    '''
    if seed is not None:
        np.random.seed(seed)

    # Provided every word counts as >= 1 syllable, this many words is enough
    # to fill num_lines lines.
    M = num_lines * syllables_per_line
    emission, _ = hmm_model.generate_emission(M, seed=seed)
    emission_words = [idx_to_word[idx] for idx in emission]

    def get_normal_count(word):
        entry = syllable_dict.get(word)
        if entry and entry.get('normal'):
            return entry['normal'][0]
        return 1

    def get_end_count(word):
        entry = syllable_dict.get(word)
        if entry and entry.get('end'):
            return entry['end'][0]
        return get_normal_count(word)

    sonnet = []
    current_line = []
    # Normal-form syllables of every word in current_line except the newest.
    body_syllables = 0

    for word in emission_words:
        if len(sonnet) == num_lines:
            break

        current_line.append(word)
        if body_syllables + get_end_count(word) >= syllables_per_line:
            sonnet.append(current_line)
            current_line = []
            body_syllables = 0
        else:
            body_syllables += get_normal_count(word)

    return sonnet

In [None]:
def generate_sonnet_per_line(hmm_model, idx_to_word, syllable_dict, num_lines=14, seed=None, line_lengths=None):
    '''
    Generate a sonnet one line at a time.

    For each line, a word count is drawn uniformly from `line_lengths`.
    Drawing from the observed corpus lengths reproduces their distribution.
    A fresh HMM emission of that length becomes the line.

    Args:
        hmm_model:     object with generate_emission(M) returning (indices, states).
                       Called without a seed, so it draws from the global
                       np.random state set here.
        idx_to_word:   word index -> word.
        syllable_dict: not used by this generator. It is accepted so the
                       signature matches generate_sonnet.
        num_lines:     number of lines to generate.
        seed:          if given, seeds np.random first.
        line_lengths:  candidate words-per-line values. Defaults to the word
                       count of every line in the module-level `sonnets`
                       corpus (previously the only, hidden, option).

    Returns:
        list[list[str]]: `num_lines` lines of words.
    '''
    if seed is not None:
        np.random.seed(seed)

    if line_lengths is None:
        line_lengths = [len(line) for sonnet in sonnets for line in sonnet]

    sonnet_lines = []
    for _ in range(num_lines):
        words_per_line = np.random.choice(line_lengths)
        idx, _states = hmm_model.generate_emission(words_per_line)
        sonnet_lines.append([idx_to_word[i] for i in idx])

    return sonnet_lines

In [None]:
# Hidden-state sweep (20 Baum-Welch iterations), sampling each line independently.
# unsupervised_HMM is called without a seed, so trained models differ between runs.
n_states_list = [3, 7, 9, 11, 13, 15]
N_iters = 20

for n_state in n_states_list: # generated one line at a time
  print('n_state=', n_state)
  hmm_model = unsupervised_HMM(X, n_states=n_state, N_iters=N_iters)
  generated_sonnet = generate_sonnet_per_line(hmm_model, idx_to_word, syllable_dict_with_end, num_lines=14, seed=42)
  sonnet_summary(generated_sonnet)

n_state= 3
Sonnet: 
self do at not a the eternal i'll in
chide me stand of art doth will sun built is
deeds eve's in reigns thy in the most with
that holds children beauty's sleeping a rich
which use my me patience his idle thing
flatter that he to minutes deceivest mayst which rare some
to by true one even doth light his doth
for blanks the compounds fortune's in drink power my nine
that special man unless every shall all golden deepest
west dissuade which flowers thou made brave truth
ere true think bed-vow to advocate so and purple giving
and thee grown of fill would unlettered love's
and from silver where use a in with
for hateth proud read buds you look self

Number of lines: 14
Line 1: number of words = 9, number of syllables = 11
Line 2: number of words = 10, number of syllables = 10
Line 3: number of words = 9, number of syllables = 9
Line 4: number of words = 7, number of syllables = 10
Line 5: number of words = 8, number of syllables = 10
Line 6: number of words = 10, number 

In [None]:
# Hidden-state sweep (100 Baum-Welch iterations): one long emission, split into
# lines by syllable count. unsupervised_HMM is again called without a seed.
n_states_list = [3, 7, 9, 11, 13, 15]
N_iters = 100

for n_state in n_states_list: # generated all words at once, then split into lines
  print('n_state=', n_state)
  hmm_model = unsupervised_HMM(X, n_states=n_state, N_iters=N_iters)
  generated_sonnet = generate_sonnet(hmm_model, idx_to_word, syllable_dict_with_end, num_lines=14, seed=42)
  sonnet_summary(generated_sonnet)

n_state= 3
Sonnet: 
the thou love make deceive be where you refigured
and it would but is me should absent of
a frown'st receiv'st dead desert live of thrall
issue summer's of with than oft orphans
eye's always is alters that put seem third
and scope on of womb such grew that love up
of dear not whether sweet pride bearing expiate
to best tyrants saw in corrupting mortal
full crowned oaths another fled spacious
your worst allow reap my not as fire merit
interim the my to for though but jewel thou
such new fair you from with tongues have to kindness
thy the of which every need yet an drugs
then hold my good wrongfully we i so

Number of lines: 14
Line 1: number of words = 9, number of syllables = 12
Line 2: number of words = 9, number of syllables = 10
Line 3: number of words = 8, number of syllables = 10
Line 4: number of words = 7, number of syllables = 10
Line 5: number of words = 8, number of syllables = 10
Line 6: number of words = 10, number of syllables = 10
Line 7: number of wor

In [None]:
# Final HMM run with the hidden-state count picked from the sweeps above.
n_state = 11 # best parameter
N_iters = 100

hmm_model = unsupervised_HMM(X, n_states=n_state, N_iters=N_iters)
generated_sonnet = generate_sonnet(hmm_model, idx_to_word, syllable_dict_with_end, num_lines=14, seed=42)
sonnet_summary(generated_sonnet)

Sonnet: 
tears thou in of choirs an were your pleasures
are i you call in my self against our
appetite have not good eyes for nothing thou
i thine own world's smell of proud hand am i
am that not storm-beaten thing be the parts
not you show me shadow of truth mine beated
me with self shall be hooks trim but to progress
giving do it eve's external o'erpressed
away came shines year women's and place no
mow and even lies my thee most to hour thee
birth his them so more by worth common which
threescore delight there his thing to mine whose
lion's me yet a dateless stand esteemed
me can worms to her prize doth that bed-vow

Number of lines: 14
Line 1: number of words = 9, number of syllables = 10
Line 2: number of words = 9, number of syllables = 10
Line 3: number of words = 8, number of syllables = 11
Line 4: number of words = 10, number of syllables = 10
Line 5: number of words = 8, number of syllables = 10
Line 6: number of words = 9, number of syllables = 11
Line 7: number of words = 10

## RNN Code

In [None]:
def get_char_mapping(text):
    '''
    Build character <-> index lookup tables for `text`.

    Indices follow the sorted order of the distinct characters.

    Returns:
        (char_to_idx, idx_to_char)
    '''
    vocab = sorted(set(text))
    idx_to_char = dict(enumerate(vocab))
    char_to_idx = {ch: i for i, ch in idx_to_char.items()}
    return char_to_idx, idx_to_char

def get_data(text, char_to_idx, sequence_length=40, step=3):
    '''
    Slice `text` into overlapping next-character prediction examples.

    Window starts are 0, step, 2*step, ..., for as long as a full window of
    sequence_length + 1 characters fits. The input is the first
    `sequence_length` characters and the target is the same span shifted by
    one character.

    Args:
        text:            training text. Every character must be in char_to_idx.
        char_to_idx:     character -> integer index.
        sequence_length: characters per example.
        step:            stride between consecutive window starts.

    Returns:
        (inputs, targets): two torch.long tensors of shape
        (num_samples, sequence_length). When the text is too short, both
        have shape (0, sequence_length); previously a 1-D empty tensor was
        returned.
    '''
    # Encode the text once instead of re-encoding every overlapping window.
    encoded = torch.tensor([char_to_idx[ch] for ch in text], dtype=torch.long)
    window = sequence_length + 1

    if encoded.numel() < window:
        empty = torch.empty((0, sequence_length), dtype=torch.long)
        return empty, empty.clone()

    windows = encoded.unfold(0, window, step)     # (num_samples, sequence_length + 1)
    inputs = windows[:, :-1].contiguous()         # Shape: (num_samples, seq_length)
    targets = windows[:, 1:].contiguous()         # Shape: (num_samples, seq_length)

    return inputs, targets

class LSTM(nn.Module):
    '''
    Character-level language model: embedding -> single-layer LSTM ->
    linear projection to per-character logits.

    Input:  (batch, seq_len) integer character indices.
    Output: (batch, seq_len, vocab_size) unnormalised logits.
    '''

    def __init__(self, vocab_size, hidden_size):
        super().__init__()
        self.hidden_size = hidden_size
        # Registration order (embedding, lstm, fc) fixes parameter-init order under torch.manual_seed.
        self.embedding = nn.Embedding(vocab_size, hidden_size)
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)
        self.fc = nn.Linear(hidden_size, vocab_size)

    def forward(self, x):
        embedded = self.embedding(x)
        hidden_seq, _final_state = self.lstm(embedded)
        return self.fc(hidden_seq)


def train_lstm(model, inputs, targets, epochs=20, lr=0.001, batch_size=64):
    '''
    Train `model` for next-character prediction with Adam and cross-entropy.

    Args:
        model:      module mapping (batch, seq_len) indices to
                    (batch, seq_len, vocab_size) logits.
        inputs:     long tensor (num_samples, seq_len).
        targets:    long tensor (num_samples, seq_len).
        epochs:     passes over the shuffled data.
        lr:         Adam learning rate.
        batch_size: mini-batch size (was hard-coded to 64).

    Prints the mean batch loss after each epoch and returns the trained model
    (the same object, updated in place).
    '''
    criterion = nn.CrossEntropyLoss()
    optimizer = optim.Adam(model.parameters(), lr=lr)

    dataset = torch.utils.data.TensorDataset(inputs, targets)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # generate_text() puts the model in eval mode; make sure training runs in train mode.
    model.train()
    num_batches = max(len(dataloader), 1)  # avoid 0-division on an empty dataset

    for epoch in range(epochs):
        total_loss = 0.0
        for batch_inputs, batch_targets in dataloader:
            optimizer.zero_grad()
            outputs = model(batch_inputs)  # Shape: (batch, seq_len, vocab_size)
            # Flatten (batch, seq_len) into one axis so every position is a sample.
            loss = criterion(outputs.reshape(-1, outputs.shape[-1]), batch_targets.reshape(-1))
            loss.backward()
            optimizer.step()
            total_loss += loss.item()

        print(f"Epoch {epoch+1}/{epochs}, Loss: {total_loss / num_batches:.4f}")

    return model

def generate_text(model, idx_to_char, char_to_idx, seed_text, length=200, temperature=1.0, context_length=40):
    """Sample *length* characters from *model*, continuing *seed_text*.

    At each step the model sees at most the last ``context_length`` characters
    (seed plus everything generated so far). The logits of the final position
    are divided by *temperature*, and the next character is drawn from the
    resulting softmax distribution.

    Args:
        model: Module mapping ``(1, seq)`` index tensors to ``(1, seq, vocab)``
            logits. It is switched to eval mode.
        idx_to_char: Index -> character mapping used in training.
        char_to_idx: Character -> index mapping used in training.
        seed_text: Non-empty prompt; every character must be in *char_to_idx*.
        length: Number of characters to generate.
        temperature: Softmax temperature (> 0); lower values are greedier.
        context_length: Maximum number of trailing characters fed to the
            model; should match the training ``sequence_length``.

    Returns:
        ``seed_text`` followed by the generated characters.

    Raises:
        ValueError: if *seed_text* is empty, *temperature* is not positive, or
            *context_length* is smaller than 1.
        KeyError: if *seed_text* contains a character outside the vocabulary.
    """
    if not seed_text:
        raise ValueError("seed_text must be non-empty")
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    if context_length < 1:
        raise ValueError("context_length must be at least 1")

    model.eval()
    # Build inputs on the same device as the model's parameters (CPU if none).
    first_param = next(model.parameters(), None)
    device = first_param.device if first_param is not None else torch.device("cpu")

    context = [char_to_idx[ch] for ch in seed_text][-context_length:]
    pieces = [seed_text]

    with torch.no_grad():
        for _ in range(length):
            input_tensor = torch.tensor([context], dtype=torch.long, device=device)
            logits = model(input_tensor)[:, -1, :]  # logits for the next character

            probabilities = torch.softmax(logits / temperature, dim=-1)
            next_char_idx = torch.multinomial(probabilities, 1).item()
            pieces.append(idx_to_char[next_char_idx])

            # Slide the window: append the new character and keep only the most
            # recent context_length characters. A short seed lets the window
            # grow up to context_length instead of staying at the seed length.
            context.append(next_char_idx)
            if len(context) > context_length:
                del context[0]

    return "".join(pieces)

In [None]:
# Seed every RNG used below (torch, numpy, random) so training is reproducible.
seed = 49
torch.manual_seed(seed)
np.random.seed(seed)
random.seed(seed)

with open('shakespeare.txt', 'r', encoding='utf-8') as file:
    text = file.read().lower()
# Strip "[n]" / "(n)" markers and any other digit runs (presumably the sonnet
# numbers) so the model only sees poem text. `clean_text` is just an alias.
# NOTE(review): `re` is not in the import cell at the top of the notebook;
# presumably it is imported in an earlier cell — confirm.
text = clean_text = re.sub(r'\[\d+\]|\(\d+\)|\d+', '', text)

# Character vocabulary, then 40-character windows taken every 3 characters.
char_to_idx, idx_to_char = get_char_mapping(text)
sequence_length = 40
step = 3
inputs, targets = get_data(text, char_to_idx, sequence_length, step)

# Embedding and LSTM hidden width are both hidden_size (see LSTM.__init__).
vocab_size = len(char_to_idx)
hidden_size = 128
model = LSTM(vocab_size, hidden_size)

epochs = 50
learning_rate = 0.001
trained_model = train_lstm(model, inputs, targets, epochs, learning_rate)

Epoch 1/50, Loss: 1.9799
Epoch 2/50, Loss: 1.5982
Epoch 3/50, Loss: 1.4945
Epoch 4/50, Loss: 1.4276
Epoch 5/50, Loss: 1.3757
Epoch 6/50, Loss: 1.3321
Epoch 7/50, Loss: 1.2938
Epoch 8/50, Loss: 1.2597
Epoch 9/50, Loss: 1.2287
Epoch 10/50, Loss: 1.2003
Epoch 11/50, Loss: 1.1746
Epoch 12/50, Loss: 1.1512
Epoch 13/50, Loss: 1.1295
Epoch 14/50, Loss: 1.1101
Epoch 15/50, Loss: 1.0924
Epoch 16/50, Loss: 1.0762
Epoch 17/50, Loss: 1.0609
Epoch 18/50, Loss: 1.0478
Epoch 19/50, Loss: 1.0351
Epoch 20/50, Loss: 1.0233
Epoch 21/50, Loss: 1.0124
Epoch 22/50, Loss: 1.0028
Epoch 23/50, Loss: 0.9939
Epoch 24/50, Loss: 0.9850
Epoch 25/50, Loss: 0.9774
Epoch 26/50, Loss: 0.9696
Epoch 27/50, Loss: 0.9631
Epoch 28/50, Loss: 0.9566
Epoch 29/50, Loss: 0.9502
Epoch 30/50, Loss: 0.9450
Epoch 31/50, Loss: 0.9392
Epoch 32/50, Loss: 0.9342
Epoch 33/50, Loss: 0.9297
Epoch 34/50, Loss: 0.9249
Epoch 35/50, Loss: 0.9209
Epoch 36/50, Loss: 0.9165
Epoch 37/50, Loss: 0.9125
Epoch 38/50, Loss: 0.9085
Epoch 39/50, Loss: 0.

In [None]:
# Continue Shakespeare's opening line at three softmax temperatures (higher =
# more random) and keep each sample for the summary cell below.
seed_text = "shall i compare thee to a summer's day"
temperatures = [1.5, 0.75,0.25]
generated_texts = []
for temperature in temperatures:
  generated_text = generate_text(trained_model, idx_to_char, char_to_idx, seed_text, length=450, temperature=temperature)
  print(f"Generated Text (Temperature = {temperature}):\n {generated_text}")
  generated_texts.append(generated_text)
  print()


Generated Text (Temperature = 1.5):
 shall i compare thee to a summer's day?
nour remove,
to may too,
me frught,ing toy,
i self infall-com'stannees despair injurious distainless life thou to me within be:
let my lay chast in one of say infecrieg,
a is it not did exchese:
so fron.
for from date:
  i being dombid mine,
rhy our most to uge,
when other in you art,
to wantanchoods which alters when it fil'st be by spirit in one, sip why wish a bate thy cure,
thy mind must graces, carting sing:
  in my love for summer since j

Generated Text (Temperature = 0.75):
 shall i compare thee to a summer's day,
when i all others worse alone.


                     
that thou dost lives my desire doth with lossess on your sight,
by thy parted with all my name refeit.
so art to the time do i ten with my love stol'n i do boundly pening,
and nature basome hath in my poor from when they 'tis long of your love,
to make me to god.
my that i must mild me, some vanicled,
that love in love to thee,
yet what d

In [None]:
# Split each sample into lines of words for sonnet_summary. str.split() with no
# argument collapses runs of spaces, so indented lines ("  i being ...") do not
# yield empty-string "words"; whitespace-only lines are not counted as lines.
for generated_text in generated_texts:
  generated_text_lines = [line.split() for line in generated_text.split('\n') if line.strip()]
  sonnet_summary(generated_text_lines)

Sonnet: 
shall i compare thee to a summer's day?
nour remove,
to may too,
me frught,ing toy,
i self infall-com'stannees despair injurious distainless life thou to me within be:
let my lay chast in one of say infecrieg,
a is it not did exchese:
so fron.
for from date:
  i being dombid mine,
rhy our most to uge,
when other in you art,
to wantanchoods which alters when it fil'st be by spirit in one, sip why wish a bate thy cure,
thy mind must graces, carting sing:
  in my love for summer since j

Number of lines: 15
Line 1: number of words = 8, number of syllables = 10
Line 2: number of words = 2, number of syllables = 2
Line 3: number of words = 3, number of syllables = 3
Line 4: number of words = 3, number of syllables = 3
Line 5: number of words = 12, number of syllables = 16
Line 6: number of words = 9, number of syllables = 9
Line 7: number of words = 6, number of syllables = 6
Line 8: number of words = 2, number of syllables = 2
Line 9: number of words = 3, number of syllables = 3
L

In [None]:
# One more sample of 300 characters at temperature 0.9 from the same prompt.
seed_text = "shall i compare thee to a summer's day"
generated_text = generate_text(trained_model, idx_to_char, char_to_idx, seed_text, length=300, temperature=0.9)

In [None]:
# split() (no argument) collapses repeated/leading spaces so indented lines do
# not yield empty-string words; whitespace-only lines are not counted as lines.
generated_text_lines = [line.split() for line in generated_text.split('\n') if line.strip()]
sonnet_summary(generated_text_lines)

Sonnet: 
shall i compare thee to a summer's day?
but it feir most pride.
if hast his so shake thou bestow'st writ the treason,
i straight in hours may, to leave gift in me, thy return that is such a counters when i topt sea,
  brey i hand me alone:
but dear'st taket him, must beauty living faults whether i frond despite the time so breath o'er-s

Number of lines: 6
Line 1: number of words = 8, number of syllables = 10
Line 2: number of words = 5, number of syllables = 5
Line 3: number of words = 10, number of syllables = 11
Line 4: number of words = 21, number of syllables = 22
Line 5: number of words = 7, number of syllables = 7
Line 6: number of words = 17, number of syllables = 20
Rhyme scheme: abcdef


## Additional Goal

### Rhyme

In [None]:
def build_rhyming_dict(corpus=None, ending_length=2):
  """Group line-final words by their last ``ending_length`` characters.

  Words that share the same ending are treated as (approximate) rhymes.

  Args:
    corpus: Sonnets as lists of lines, each line a list of words (the format
      returned by ``load_sonnet_words``). Defaults to the global ``sonnets``.
    ending_length: Number of trailing characters that form the rhyme key.

  Returns:
    Dict mapping each ending to the distinct line-final words with that
    ending, in order of first appearance. Empty lines are skipped.
  """
  if corpus is None:
    corpus = sonnets

  rhyming_dict = {}
  seen = set()  # a word always maps to the same ending, so one set suffices
  for sonnet in corpus:
    for line in sonnet:
      if not line:
        continue
      word = line[-1]
      if word in seen:
        continue
      seen.add(word)
      rhyming_dict.setdefault(word[-ending_length:], []).append(word)
  return rhyming_dict

In [None]:
# Rhyme groups keyed by the last two letters of each line-final word in `sonnets`.
rhyming_dict = build_rhyming_dict()
print(rhyming_dict)

{'se': ['increase', 'decease', 'praise', 'use', 'abuse', 'cease', 'lease', 'decrease', 'muse', 'verse', 'rehearse', 'recompense', 'sense', 'cause', 'arise', 'suppose', 'those', 'choose', 'lose', 'disperse', 'devise', 'curse', 'worse', 'inhearse', 'horse', 'expense', 'rose', 'enclose', 'case', 'dispense', 'despise', 'chase', 'disease', 'please'], 'ie': ['die', 'lie'], 'ry': ['memory', 'husbandry', 'usury', 'wary', 'chary', 'injury', 'masonry', 'pry', 'cry', 'glory', 'story', 'history', 'idolatry', 'flattery'], 'es': ['eyes', 'lies', 'leaves', 'sheaves', 'graces', 'faces', 'shines', 'declines', 'cries', 'foes', 'roses', 'discloses', 'prophecies', 'subscribes', 'tribes', 'spies', 'subtleties', 'enemies', 'injuries'], 'el': ['fuel', 'cruel', 'excel', 'feel', 'steel', 'level', 'bevel', 'jewel'], 'nt': ['ornament', 'content', 'unprovident', 'evident', 'moment', 'comment', 'invent', 'excellent', 'account', 'surmount', 'argument', 'spent', 'monument', 'accident', 'discontent', 'rent', 'bent'],

In [None]:
def generate_sonnet_line_reverse(hmm_model, idx_to_word, syllable_dict, seed_word, seed=None,
                                 word_to_idx=None, num_syllables=10):
    """Generate one line from an HMM emission anchored on *seed_word*.

    Calls ``hmm_model.generate_emission_reverse`` for a 10-observation
    emission. It then keeps the shortest suffix of that emission whose
    syllable count reaches ``num_syllables``, or the whole emission if it
    never does. The last word is counted with its end-of-line syllable
    count; every other word uses its normal count. Words missing from
    *syllable_dict* count as one syllable.

    The returned line always ends with the emission's last element; this
    assumes ``generate_emission_reverse`` places the seed word last so the
    line ends on the rhyme word (TODO confirm against the HMM implementation).

    Args:
        hmm_model: HMM exposing ``generate_emission_reverse(M, start_idx, seed=...)``
            that returns ``(emission, states)``.
        idx_to_word: Mapping observation index -> word.
        syllable_dict: ``word -> {'normal': [...], 'end': [...]}``; only the
            first entry of each list is used.
        seed_word: Word the line is anchored on.
        seed: Optional seed for ``np.random`` and the HMM sampler.
        word_to_idx: Optional word -> index mapping. When omitted it is derived
            from *idx_to_word*, so the two mappings cannot disagree.
        num_syllables: Syllable target for the line.

    Returns:
        List of words forming the line. It may exceed ``num_syllables``
        because only whole words are added.
    """
    if seed is not None:
        np.random.seed(seed)
    if word_to_idx is None:
        word_to_idx = {word: idx for idx, word in idx_to_word.items()}

    M = 10
    emission, _ = hmm_model.generate_emission_reverse(M, word_to_idx[seed_word], seed=seed)
    emission_words = [idx_to_word[idx] for idx in emission]

    def get_normal_count(word):
        if word in syllable_dict and syllable_dict[word].get('normal'):
            return syllable_dict[word]['normal'][0]
        return 1

    def get_end_count(word):
        if word in syllable_dict and syllable_dict[word].get('end'):
            return syllable_dict[word]['end'][0]
        return get_normal_count(word)

    # Walk backwards from the end of the emission, adding words until the
    # syllable target is met. The first word taken (the line's last word) uses
    # its end-of-line count; the running total is updated incrementally.
    line = []
    syllables = 0
    for word in reversed(emission_words):
        if syllables >= num_syllables:
            break
        if line:
            syllables += get_normal_count(word)
        else:
            syllables = get_end_count(word)
        line.append(word)

    line.reverse()
    return line

In [None]:
# Sonnet with rhyme scheme ababcdcdefefgg: pick one rhyme ending per letter,
# then generate each line from a rhyming seed word.
rhyme_scheme = ['a', 'b', 'a', 'b', 'c', 'd', 'c', 'd', 'e', 'f', 'e', 'f', 'g', 'g']  # ababcdcdefefgg

# Usable endings: not ending in an apostrophe, not the lone "i", and with at
# least two words so each letter can get two different rhyming words.
candidate_endings = [
    ending for ending, words in rhyming_dict.items()
    if ending[-1] != "'" and ending != "i" and len(words) > 1
]
# Draw the seven endings without replacement. Independent draws could give two
# letters the same ending; an ending with only 2-3 words then runs out of
# unused words and a retry-until-unused loop would never terminate.
letters = range(ord('a'), ord('g') + 1)
rhyme_endings = dict(zip(letters, random.sample(candidate_endings, len(letters))))

sonnet = []
used_words = []
for letter in rhyme_scheme:
    ending = rhyme_endings[ord(letter)]
    # Prefer a word not used yet; fall back to any word of the ending.
    unused = [w for w in rhyming_dict[ending] if w not in used_words]
    seed_word = random.choice(unused or rhyming_dict[ending])
    used_words.append(seed_word)
    line = generate_sonnet_line_reverse(hmm_model, idx_to_word, syllable_dict_with_end, seed_word)
    sonnet.append(line)

sonnet_summary(sonnet)

Sonnet: 
captain thorns o purest then eternity
painting beauty that black but sick my mine fault
things hide bower th' apparel incertainty
mother's beauty disgrace done your tyrant halt
up-locked bewailed me publish carcanet
insults truth nay you my own decays done
tied earth it in and thorns turn broke on set
truth to in doth heaven shall enclose to mine
appear ten self his woe prison the wide days
control place doth am hell doth wind comes respect
that riper the time a dull be decays
and do objects methinks thy to defect
beweep the night forgot refined dearly
entertain mourn self full thou she this fly

Number of lines: 14
Line 1: number of words = 6, number of syllables = 11
Line 2: number of words = 9, number of syllables = 11
Line 3: number of words = 6, number of syllables = 10
Line 4: number of words = 7, number of syllables = 11
Line 5: number of words = 5, number of syllables = 12
Line 6: number of words = 8, number of syllables = 10
Line 7: number of words = 10, number of syl

### Incorporating additional texts

In [None]:
# Combined corpus: Shakespeare sonnets followed by Spenser's. copy() is shallow,
# so the global `sonnets` list itself is not extended (its inner lists are shared).
spenser = load_sonnet_words('spenser.txt')
sonnets_incl_spenser = sonnets.copy()
sonnets_incl_spenser.extend(spenser)

print('number of sonnets: ', len(sonnets_incl_spenser))

number of sonnets:  244


In [None]:
# Sweep over several hidden-state counts on the combined corpus.
n_state = [11, 13, 15]
N_iters = 100

# Vocabulary = sorted set of every word in the combined corpus. X encodes each
# line as a list of word indices (one training sequence per line). Rebinding
# word_to_idx / idx_to_word here changes the vocabulary seen by later cells.
word_to_idx = {word: i for i, word in enumerate(sorted(set(word for sonnet in sonnets_incl_spenser for line in sonnet for word in line)))}
idx_to_word = {i: word for word, i in word_to_idx.items()}
X = [[word_to_idx[word] for word in line] for sonnet in sonnets_incl_spenser for line in sonnet]

# NOTE: hmm_model is overwritten on every iteration, so after this cell it is
# the model with the last state count in n_state (15).
for state in n_state:
  print('n_state=', state)
  hmm_model = unsupervised_HMM(X, n_states=state, N_iters=N_iters)
  generated_sonnet = generate_sonnet(hmm_model, idx_to_word, syllable_dict_with_end, num_lines=14, seed=42)
  sonnet_summary(generated_sonnet)
  print()

n_state= 11
Sonnet: 
thence thine league life fair and with your pain are
it worthy declines i men pride a perish
alters do should dying gaze in never thee
it stop newer will that of respite i am
it am silence of the through boldness passed
of sweetest work shall else silence lead to
repair boast some will renew pluck deep dust
to every vengeance of i compare me doth
confounds mourners and do then your you and
quite or no cut do lead in this love to
his sweet bones in thou self of doth your eyes
with thy far thoughts in true th' not whose hate me
yet anew full thou fulfil posterity
eyes will thy grave passed doth that basest no

Number of lines: 14
Line 1: number of words = 10, number of syllables = 10
Line 2: number of words = 8, number of syllables = 10
Line 3: number of words = 8, number of syllables = 11
Line 4: number of words = 9, number of syllables = 10
Line 5: number of words = 8, number of syllables = 10
Line 6: number of words = 8, number of syllables = 10
Line 7: number of 

## Visualization

In [None]:
# `n_state` still holds the list [11, 13, 15] from the sweep cell, which is not a
# state count; the visualizations below assume a single 11-state model.
hmm_model = unsupervised_HMM(X, n_states=11, N_iters=N_iters)

NameError: name 'unsupervised_HMM' is not defined

In [None]:
import seaborn as sns
import nltk
nltk.download('averaged_perceptron_tagger_eng')
from nltk import pos_tag
from nltk.tokenize import word_tokenize
import pandas as pd
from collections import defaultdict


# Settings for the visualization cell: an 11-state HMM trained for 100 iterations.
n_state = 11
N_iters = 100


def visualize_transition(hmm_model, X, n_states, N_iters, seed=None):
    """Plot the HMM transition matrix ``hmm_model.A`` as an annotated heatmap.

    Rows are the current state and columns the next state.

    Args:
        hmm_model: HMM exposing its transition matrix as ``A`` (nested lists
            or an ndarray).
        X: Unused; kept for call compatibility.
        n_states: Expected number of states. Tick labels are derived from the
            matrix itself, so they always cover every row and column; a
            warning is printed when this value disagrees with the matrix size.
        N_iters: Unused; kept for call compatibility.
        seed: Unused; kept for call compatibility.
    """
    transition_matrix = np.asarray(hmm_model.A)
    size = transition_matrix.shape[0]
    if n_states != size:
        # e.g. a caller passes 11 while hmm_model is still the last model of a
        # state-count sweep with a different size.
        print(f"Warning: n_states={n_states} does not match A ({size}x{size}); labelling {size} states.")
    state_labels = list(range(size))

    plt.figure(figsize=(8, 6))
    sns.heatmap(transition_matrix, cmap="YlGnBu", annot=True, xticklabels=state_labels, yticklabels=state_labels)
    plt.title("Transition Matrix (A)")
    plt.xlabel("Next State")
    plt.ylabel("Current State")
    plt.show()


# X and N_iters are accepted but not used by visualize_transition; only
# hmm_model.A is plotted.
visualize_transition(hmm_model, X, n_states=11, N_iters=10)



def get_top_n_words_for_states(hmm_model, idx_to_word, top_n=10):
    """Return the ``top_n`` most probable emitted words for every hidden state.

    Args:
        hmm_model: HMM exposing an emission matrix ``O`` with one row per
            state and one column per observation; nested lists or an ndarray
            are accepted.
        idx_to_word: Mapping from observation index to word.
        top_n: Number of words per state; ``top_n <= 0`` yields empty lists.

    Returns:
        Dict ``{state: [word, ...]}`` with words in descending probability.
    """
    # np.asarray so list-of-lists emission matrices (which lack .shape) work too.
    O = np.asarray(hmm_model.O)
    n_states = O.shape[0]

    if top_n <= 0:
        # Slicing with [-0:] would select every word instead of none.
        return {state: [] for state in range(n_states)}

    top_words_per_state = {}
    for state in range(n_states):
        top_indices = np.argsort(O[state])[-top_n:][::-1]
        top_words_per_state[state] = [idx_to_word[idx] for idx in top_indices]

    return top_words_per_state

# Print the ten highest-probability emissions of each hidden state.
top_words = get_top_n_words_for_states(hmm_model, idx_to_word, top_n=10)

for state, words in top_words.items():
    print(f"State {state}: {words}")

def visualize_sparsities(hmm, O_max_cols=50, O_vmax=0.1):
    """Show heatmaps of the transition matrix A and the first columns of O.

    A is drawn with its colour scale capped at 1.0. Only the first
    ``O_max_cols`` columns of O are shown, with the scale capped at ``O_vmax``
    so small emission probabilities remain visible.
    """
    plt.close('all')
    plt.set_cmap('viridis')

    def _show(matrix, title, **imshow_kwargs):
        # One figure per matrix: image, colour bar, title.
        plt.imshow(matrix, **imshow_kwargs)
        plt.colorbar()
        plt.title(title)
        plt.show()

    _show(hmm.A, 'Sparsity of A matrix', vmax=1.0)
    _show(np.array(hmm.O)[:, :O_max_cols], 'Sparsity of O matrix', vmax=O_vmax, aspect='auto')

# Sparsity heatmaps of A and of the first 50 columns of O.
visualize_sparsities(hmm_model, O_max_cols=50)

# NOTE(review): not referenced anywhere below. nltk's default tagger emits Penn
# Treebank tags, where adjectives are 'JJ', so 'ADJ' would never match — confirm intent.
desired_pos_tags = {'NN', 'VB', 'DT', 'ADJ'}

def get_pos_tag(word):
    """Return the tag nltk's ``pos_tag`` assigns to *word* when tagged on its own."""
    tagged = pos_tag([word])
    _token, tag = tagged[0]
    return tag


def count_pos_tags_for_states_nltk(hmm, idx_to_word, max_words=50):
    """Count POS tags of the words each hidden state emits.

    Samples one emission of ``max_words`` observations from *hmm*. Each
    emitted word is tagged in isolation with ``get_pos_tag``, and the tag is
    tallied under the state that produced it.

    Returns:
        Nested ``defaultdict``: ``counts[state][tag] -> int``.
    """
    counts = defaultdict(lambda: defaultdict(int))
    emission, states = hmm.generate_emission(M=max_words)

    # `tag` (not `pos_tag`) so the imported nltk function is not shadowed.
    for obs_idx, state in zip(emission, states):
        tag = get_pos_tag(idx_to_word[obs_idx])
        counts[state][tag] += 1

    return counts

# Tag a 500-word sample emission and tally tags per hidden state.
pos_tag_counts = count_pos_tags_for_states_nltk(hmm_model, idx_to_word, max_words=500)

def plot_pos_tag_distribution(pos_tag_counts):
    """Draw a grouped bar chart of POS-tag counts per hidden state.

    Args:
        pos_tag_counts: Mapping ``state -> {tag: count}``, such as the result
            of ``count_pos_tags_for_states_nltk``.

    States and tags are sorted so the bar and legend order is identical on
    every run; iterating a ``set`` of strings gives an order that can change
    between interpreter sessions.
    """
    state_names = sorted(pos_tag_counts)
    tags = sorted({tag for counts in pos_tag_counts.values() for tag in counts})

    # One row per (state, tag) pair, with explicit zeros for tags a state never
    # emitted so every state gets a bar slot for every tag.
    data = [
        {'State': state, 'POS Tag': tag, 'Count': pos_tag_counts[state].get(tag, 0)}
        for state in state_names
        for tag in tags
    ]
    df = pd.DataFrame(data)

    plt.figure(figsize=(10, 6))
    sns.barplot(data=df, x='State', y='Count', hue='POS Tag')
    plt.title('POS Tag Distribution in States')
    plt.ylabel('Count')
    plt.xlabel('State')
    plt.xticks(rotation=90)
    plt.show()

# Bar chart of the per-state POS-tag counts computed above.
plot_pos_tag_distribution(pos_tag_counts)

In [None]:

####################
# WORDCLOUD FUNCTIONS
####################

def mask():
    """Return a 257x257 uint8 image: 0 inside a radius-128 disc, 255 outside.

    Used as the ``mask`` argument of the word clouds in ``text_to_wordcloud``.
    """
    radius = 128
    size = 2 * radius + 1

    offsets = np.arange(size) - radius  # -radius .. radius along each axis
    dist_sq = offsets[:, None] ** 2 + offsets[None, :] ** 2
    return np.where(dist_sq <= radius ** 2, 0, 255).astype(np.uint8)

def text_to_wordcloud(text, max_words=50, title='', show=True):
    """Render *text* as a circular word cloud and optionally display it.

    Args:
        text: Space-separated words to visualize.
        max_words: Maximum number of words drawn.
        title: Figure title, used only when ``show`` is True.
        show: Whether to display the image with matplotlib.

    Returns:
        The generated ``WordCloud`` object.
    """
    # NOTE(review): `WordCloud` (wordcloud package) is not imported in the
    # visible import cells; confirm it is imported before this cell runs.
    plt.close('all')

    cloud = WordCloud(
        random_state=0,
        max_words=max_words,
        background_color='white',
        mask=mask(),
    ).generate(text)

    if not show:
        return cloud

    plt.imshow(cloud, interpolation='bilinear')
    plt.axis('off')
    plt.title(title, fontsize=24)
    plt.show()
    return cloud

def states_to_wordclouds(hmm, obs_map, max_words=50, show=True, M=100000):
    """Build one word cloud per hidden state from a long sampled emission.

    Args:
        hmm: HMM exposing a transition matrix ``A`` (its length is taken as
            the number of states) and ``generate_emission(M)`` returning
            ``(emission, states)``.
        obs_map: Mapping ``word -> observation index``. It must be the
            vocabulary the HMM was trained on, or the wrong words are shown.
        max_words: Maximum words per cloud.
        show: Whether to display each cloud.
        M: Length of the sampled emission (defaults to 100000).

    Returns:
        List of word clouds for the states that emitted at least one word, in
        state order. States without emissions are skipped with a warning, so
        list positions do not necessarily equal state indices.
    """
    n_states = len(hmm.A)
    obs_map_r = obs_map_reverser(obs_map)
    wordclouds = []

    # Generate a large emission and convert it once rather than once per state.
    emission, states = hmm.generate_emission(M)
    emission = np.asarray(emission)
    states = np.asarray(states)

    for i in range(n_states):
        # Observations that were emitted while in state i.
        obs_lst = emission[states == i]
        if len(obs_lst) == 0:
            print(f"Warning: State {i} has no observations.")
            continue

        sentence_str = ' '.join(obs_map_r[j] for j in obs_lst)
        if not sentence_str.strip():
            print(f"Warning: State {i} has no valid words.")
            continue

        wordclouds.append(text_to_wordcloud(sentence_str, max_words=max_words, title=f'State {i}', show=show))

    return wordclouds



####################
# HMM FUNCTIONS
####################

def parse_observations(text):
    """Encode *text* as integer observation sequences, one per non-empty line.

    Each whitespace-separated token is stripped of every non-word character
    (anything other than letters, digits and underscore, so apostrophes and
    hyphens are removed too) and lowercased. Tokens that become empty, such
    as pure punctuation like ``--``, are dropped instead of entering the
    vocabulary as ``''``. Lines left with no tokens are skipped.

    Returns:
        ``(obs, obs_map)``: ``obs`` is a list of index lists; ``obs_map`` maps
        each word to its index, assigned in order of first appearance.
    """
    non_word = re.compile(r'[^\w]')
    obs = []
    obs_map = {}

    for raw_line in text.split('\n'):
        obs_elem = []
        for raw_word in raw_line.split():
            word = non_word.sub('', raw_word).lower()
            if not word:
                continue
            # setdefault gives unseen words the next free index.
            obs_elem.append(obs_map.setdefault(word, len(obs_map)))
        if obs_elem:
            obs.append(obs_elem)

    return obs, obs_map

def obs_map_reverser(obs_map):
    """Invert a ``word -> index`` mapping into ``index -> word``.

    If several words share an index, the one appearing last in *obs_map* wins.
    """
    return {index: word for word, index in obs_map.items()}

def sample_sentence(hmm, obs_map, n_words=100, seed=None):
    """Sample ``n_words`` observations from *hmm* and render them as text.

    The words are joined with spaces, passed through ``str.capitalize``
    (first letter upper-case, the rest lower-case), and followed by ``'...'``.
    """
    index_to_word = obs_map_reverser(obs_map)
    emission, _states = hmm.generate_emission(n_words, seed=seed)
    words = ' '.join(index_to_word[idx] for idx in emission)
    return f"{words.capitalize()}..."

In [None]:
# hmm_model was trained on X, which is encoded with word_to_idx (the combined
# Shakespeare + Spenser vocabulary). parse_observations() builds an unrelated
# index map over Shakespeare only, so decoding the HMM's emissions through
# obs_map would show the wrong words (and could raise KeyError for indices
# beyond its size). Decode with the training vocabulary instead.
obs, obs_map = parse_observations(text)
wordclouds = states_to_wordclouds(hmm_model, word_to_idx)