In [1]:
import string
import torch
import torch.nn as nn

## The LSTM model

In [2]:
class LSTMModel(nn.Module):
    """Language model built from an embedding, a stacked LSTM and a linear decoder.

    Dropout is applied to the embedded input and again to the LSTM output
    before it is projected onto the vocabulary.
    """

    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        """Build the sub-modules.

        ntoken  -- vocabulary size (embedding rows / decoder outputs)
        ninp    -- embedding width, i.e. LSTM input size
        nhid    -- LSTM hidden size
        nlayers -- number of stacked LSTM layers
        dropout -- dropout probability shared by ``self.drop`` and the LSTM
        """
        super(LSTMModel, self).__init__()
        self.nhid = nhid
        self.nlayers = nlayers

        # Registration order matters: init_hidden() takes its dtype/device
        # from the first registered parameter.
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

        self.init_weights()

    def init_weights(self):
        """Re-initialise embedding/decoder weights uniformly in [-0.1, 0.1], decoder bias to 0."""
        bound = 0.1
        with torch.no_grad():
            self.encoder.weight.uniform_(-bound, bound)
            self.decoder.bias.zero_()
            self.decoder.weight.uniform_(-bound, bound)

    def forward(self, input, hidden):
        """Run one forward pass.

        Returns a pair ``(scores, hidden)``: ``scores`` keeps the first two
        dimensions of the LSTM output and has the vocabulary size as its last
        dimension; ``hidden`` is the state returned by the LSTM.
        """
        embedded = self.drop(self.encoder(input))
        rnn_out, hidden = self.rnn(embedded, hidden)
        rnn_out = self.drop(rnn_out)

        dim0, dim1, width = rnn_out.size(0), rnn_out.size(1), rnn_out.size(2)
        # The linear layer is applied to a 2-D view, then the leading
        # two dimensions are restored.
        flat_scores = self.decoder(rnn_out.view(dim0 * dim1, width))
        scores = flat_scores.view(dim0, dim1, flat_scores.size(1))
        return scores, hidden

    def init_hidden(self, bsz):
        """Return an all-zero ``(h, c)`` pair, each of shape (nlayers, bsz, nhid)."""
        reference = next(self.parameters())
        state_shape = (self.nlayers, bsz, self.nhid)
        h0 = reference.new_zeros(state_shape)
        c0 = reference.new_zeros(state_shape)
        return (h0, c0)

## The data corpus

In [20]:
import os
from io import open
import torch
from nltk.tokenize import wordpunct_tokenize
import json
import enchant
# enchant dictionary for "en_US"; Corpus.tokenize keeps a token only when
# d_english.check() accepts it or the token is listed in valid_noword.
d_english = enchant.Dict("en_US")
# Tokens that are kept regardless of the dictionary check: the end-of-poem
# marker appended after every poem, plus a small set of punctuation marks.
valid_noword = ['end_of_poem','!','"',',','.',':',';','?']

class Dictionary(object):
    """Two-way vocabulary: ``word2idx`` maps word -> id, ``idx2word`` maps id -> word.

    Ids are assigned consecutively from 0 in order of first insertion.
    """

    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def add_word(self, word):
        """Return the id of ``word``, registering it under the next free id if unseen."""
        try:
            return self.word2idx[word]
        except KeyError:
            new_id = len(self.idx2word)
            self.idx2word.append(word)
            self.word2idx[word] = new_id
            return new_id

    def __len__(self):
        """Number of distinct words stored."""
        return len(self.idx2word)


class Corpus(object):
    """Vocabulary plus the id-encoded text of ``poems.json`` found under ``path``.

    Attributes:
        dictionary: the ``Dictionary`` filled while tokenizing.
        train: 1-D int64 tensor of word ids returned by ``tokenize``.
    """

    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'poems.json'))

    def tokenize(self, path):
        """Tokenizes a JSON-lines poem file into a 1-D tensor of word ids.

        Each non-blank line of the file is parsed as a JSON object whose
        ``'text'`` field holds one poem. The poem is split on newlines, an
        ``'end_of_poem'`` marker is appended, and every line is tokenized with
        ``wordpunct_tokenize``. A token is kept when ``d_english.check``
        accepts it or it is listed in ``valid_noword``; kept tokens are
        lower-cased, added to ``self.dictionary`` and their ids collected.

        The file is read once: ids are gathered in a list while the vocabulary
        is built, instead of a counting pass followed by a second identical
        parsing/spell-checking pass.

        Args:
            path: location of the JSON-lines file.

        Returns:
            ``torch.int64`` tensor of shape ``(number_of_kept_tokens,)``.

        Raises:
            AssertionError: if ``path`` does not exist.
        """
        assert os.path.exists(path)

        ids = []
        with open(path, 'r', encoding="utf8") as f:
            for poem in f:
                # Blank lines (e.g. a stray empty line between records) carry
                # no poem and would make json.loads raise.
                if not poem.strip():
                    continue
                poem_json = json.loads(poem)
                lines = poem_json['text'].split('\n') + ['end_of_poem']
                for words in lines:
                    for word in wordpunct_tokenize(words):
                        # The check runs on the token as written; only the
                        # stored vocabulary entry is lower-cased.
                        if d_english.check(word) or word in valid_noword:
                            ids.append(self.dictionary.add_word(word.lower()))

        return torch.tensor(ids, dtype=torch.long)

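## Generation parameters and loading the trained model

Paths, the number of words to generate, and the device are set here, and the saved model and corpus are loaded. The random seed is set in the generation cell further down.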

In [24]:
import pickle

# Generation settings.
checkpoint = './model.pt'     # serialized model to load
outf = 'generated.txt'        # file the generated text is written to
words = 500                   # maximum number of words to generate
data = './'                   # data directory (not referenced by the cells shown here)
device = torch.device("cpu")

# SECURITY NOTE: torch.load and pickle.load execute arbitrary code embedded in
# the file they read -- only load checkpoints/corpora you created or trust.
#
# map_location remaps the stored tensors while loading, so a checkpoint that
# was saved on a GPU can still be opened on a CPU-only machine; calling
# .to(device) afterwards is too late because loading itself would already fail.
# NOTE(review): recent PyTorch releases default to weights_only=True, which
# rejects fully pickled modules such as this one; pass weights_only=False
# there (for trusted files only).
with open(checkpoint, 'rb') as f:
    model = torch.load(f, map_location=device).to(device)
model.eval()  # inference mode: disables dropout

# Load the saved corpus; the context manager closes the file handle, which the
# previous bare open() call never did.
with open("corpus.p", "rb") as f:
    corpus = pickle.load(f)

ntokens = len(corpus.dictionary)

## The prompt based generation
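With no prompt, the first input token is drawn at random from the vocabulary and the hidden state comes from `model.init_hidden(1)` (all zeros in the `LSTMModel` defined above). With a prompt, the first prompt word replaces the random start token.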

In [11]:
# Fix torch's random seed so the sampled text is reproducible.
seed = 926
torch.manual_seed(seed)

# Starting point without a prompt: the hidden state for a batch of one and a
# single start token drawn at random from the vocabulary.
hidden = model.init_hidden(1)
input = torch.randint(0, ntokens, (1, 1), dtype=torch.long).to(device)

# Optional prompt: a list of vocabulary words, e.g. ['meow', 'meow', 'meow'].
# Leave it empty for purely random generation; otherwise its first word
# overwrites the random start token.
prompt_in = []
if prompt_in:
    prompt_index = corpus.dictionary.word2idx[prompt_in[0]]
    input.fill_(prompt_index)

In [12]:
# Write the prompt (if any) followed by up to `words` sampled words to `outf`.
# The file is written as UTF-8, matching the encoding used to read the corpus.
with open(outf, 'w', encoding="utf8") as outfile:
    with torch.no_grad():  # no tracking history
        # --- Prompt: echo each word and prime the hidden state -------------
        # Every prompt word except the last is run through the model here.
        # The last one is only placed in `input`, because the first
        # generation step below feeds `input` to the model. Calling the model
        # before advancing `input` (as was done previously) fed the first
        # prompt word twice.
        last_prompt_pos = len(prompt_in) - 1
        for prompt_pos, prompt_word in enumerate(prompt_in):
            prompt_index = corpus.dictionary.word2idx[prompt_word]
            # A punctuation token ends the current output line.
            if prompt_word in string.punctuation:
                write_down = prompt_word + '\n'
            else:
                write_down = prompt_word + ' '
            outfile.write(write_down)

            input.fill_(prompt_index)
            if prompt_pos < last_prompt_pos:
                output, hidden = model(input, hidden)

        # --- Generation: sample one word at a time -------------------------
        for i in range(words):
            output, hidden = model(input, hidden)
            # Softmax over the vocabulary turns the scores into sampling weights.
            word_weights = torch.nn.functional.softmax(output.squeeze(), 0)
            # .item() yields a plain int for both the list lookup and fill_.
            word_idx = torch.multinomial(word_weights, 1)[0].item()
            input.fill_(word_idx)
            word = corpus.dictionary.idx2word[word_idx]

            print('Generated {}/{} words'.format(i, words))

            # The end-of-poem marker stops generation and is not written out.
            if word == 'end_of_poem':
                break

            if word in string.punctuation:
                write_down = word + '\n'
            else:
                write_down = word + ' '
            outfile.write(write_down)

Generated 0/500 words
Generated 1/500 words
Generated 2/500 words
Generated 3/500 words
Generated 4/500 words
Generated 5/500 words
Generated 6/500 words
Generated 7/500 words
Generated 8/500 words
Generated 9/500 words
Generated 10/500 words
Generated 11/500 words
Generated 12/500 words
Generated 13/500 words
Generated 14/500 words
Generated 15/500 words
Generated 16/500 words
Generated 17/500 words
Generated 18/500 words
Generated 19/500 words
Generated 20/500 words
Generated 21/500 words
Generated 22/500 words
Generated 23/500 words
Generated 24/500 words
Generated 25/500 words
Generated 26/500 words
Generated 27/500 words
Generated 28/500 words
Generated 29/500 words
Generated 30/500 words
Generated 31/500 words
Generated 32/500 words
Generated 33/500 words
Generated 34/500 words
Generated 35/500 words
Generated 36/500 words
Generated 37/500 words
Generated 38/500 words
Generated 39/500 words
Generated 40/500 words
Generated 41/500 words
Generated 42/500 words
Generated 43/500 word