## Model

In [1]:
import torch.nn as nn

In [2]:
class LSTMModel(nn.Module):
    """Container module with an encoder, a recurrent module, and a decoder."""

    def __init__(self, ntoken, ninp, nhid, nlayers, dropout=0.5):
        super(LSTMModel, self).__init__()
        self.drop = nn.Dropout(dropout)
        self.encoder = nn.Embedding(ntoken, ninp)
        self.rnn = nn.LSTM(ninp, nhid, nlayers, dropout=dropout)
        self.decoder = nn.Linear(nhid, ntoken)

        self.init_weights()

        self.nhid = nhid
        self.nlayers = nlayers

    def init_weights(self):
        initrange = 0.1
        self.encoder.weight.data.uniform_(-initrange, initrange)
        self.decoder.bias.data.zero_()
        self.decoder.weight.data.uniform_(-initrange, initrange)

    def forward(self, input, hidden):
        emb = self.drop(self.encoder(input))
        output, hidden = self.rnn(emb, hidden)
        output = self.drop(output)
        decoded = self.decoder(output.view(output.size(0)*output.size(1), output.size(2)))
        return decoded.view(output.size(0), output.size(1), decoded.size(1)), hidden

    def init_hidden(self, bsz):
        weight = next(self.parameters())
        return (weight.new_zeros(self.nlayers, bsz, self.nhid),
                weight.new_zeros(self.nlayers, bsz, self.nhid))

## Data preprocessing

In [3]:
import os
from io import open
import torch
from nltk.tokenize import wordpunct_tokenize
import json
import enchant
d_english = enchant.Dict("en_US")
valid_noword = ['end_of_poem','!','"',',','.',':',';','?']

class Dictionary(object):
    def __init__(self):
        self.word2idx = {}
        self.idx2word = []

    def add_word(self, word):
        if word not in self.word2idx:
            self.idx2word.append(word)
            self.word2idx[word] = len(self.idx2word) - 1
        return self.word2idx[word]

    def __len__(self):
        return len(self.idx2word)


class Corpus(object):
    def __init__(self, path):
        self.dictionary = Dictionary()
        self.train = self.tokenize(os.path.join(path, 'poems.json'))

    def tokenize(self, path):
        """Tokenizes a text file."""
        assert os.path.exists(path)
        # Add words to the dictionary
        with open(path, 'r', encoding="utf8") as f:
            tokens = 0
            for poem in f:
                poem_json = json.loads(poem)
                lines = poem_json['text'].split('\n') + ['end_of_poem']
                for words in lines:
                    for word in wordpunct_tokenize(words):
                        if d_english.check(word) or word in valid_noword:
                            self.dictionary.add_word(word.lower())
                            tokens += 1

        # Tokenize file content
        with open(path, 'r', encoding="utf8") as f:
            ids = torch.LongTensor(tokens)
            token = 0
            for poem in f:
                poem_json = json.loads(poem)
                lines = poem_json['text'].split('\n') + ['end_of_poem']
                for words in lines:
                    for word in wordpunct_tokenize(words):
                        if d_english.check(word) or word in valid_noword:
                            ids[token] = self.dictionary.word2idx[word.lower()]
                            token += 1

        return ids

## Build and train

In [4]:
import time
import math
import os
import torch
import torch.nn as nn

Set the hyperparameters

In [5]:
# parameters
seed = 1111
data = './'
batch_size = 32
emsize = 200
nhid = 500
nlayers = 2
dropout = 0.5
bptt = 6
log_interval = 100
lr = 30
epochs = 40
save = 'model.pt'

# Set the random seed manually for reproducibility.
torch.manual_seed(seed)

device = torch.device("cpu")

In [6]:
###############################################################################
# Load data
###############################################################################

corpus = Corpus(data)

# Starting from sequential data, batchify arranges the dataset into columns.
# For instance, with the alphabet as the sequence and batch size 4, we'd get
# ┌ a g m s ┐
# │ b h n t │
# │ c i o u │
# │ d j p v │
# │ e k q w │
# └ f l r x ┘.
# These columns are treated as independent by the model, which means that the
# dependence of e. g. 'g' on 'f' can not be learned, but allows more efficient
# batch processing.

def batchify(data, bsz):
    # Work out how cleanly we can divide the dataset into bsz parts.
    nbatch = data.size(0) // bsz
    # Trim off any extra elements that wouldn't cleanly fit (remainders).
    data = data.narrow(0, 0, nbatch * bsz)
    # Evenly divide the data across the bsz batches.
    data = data.view(bsz, -1).t().contiguous()
    return data.to(device)

train_data = batchify(corpus.train, batch_size)

###############################################################################
# Build the model
###############################################################################

ntokens = len(corpus.dictionary)
model = LSTMModel(ntokens, emsize, nhid, nlayers, dropout).to(device)

criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

###############################################################################
# Training code
###############################################################################

def repackage_hidden(h):
    """Wraps hidden states in new Tensors, to detach them from their history."""
    if isinstance(h, torch.Tensor):
        return h.detach()
    else:
        return tuple(repackage_hidden(v) for v in h)


# get_batch subdivides the source data into chunks of length bptt.
# If source is equal to the example output of the batchify function, with
# a bptt-limit of 2, we'd get the following two Variables for i = 0:
# ┌ a g m s ┐ ┌ b h n t ┐
# └ b h n t ┘ └ c i o u ┘
# Note that despite the name of the function, the subdivison of data is not
# done along the batch dimension (i.e. dimension 1), since that was handled
# by the batchify function. The chunks are along dimension 0, corresponding
# to the seq_len dimension in the LSTM.

def get_batch(source, i):
    seq_len = min(bptt, len(source) - 1 - i)
    data = source[i:i+seq_len]
    target = source[i+1:i+1+seq_len].view(-1)
    return data, target

def train():
    # Turn on training mode which enables dropout.
    model.train()
    start_time = time.time()
    ntokens = len(corpus.dictionary)
    hidden = model.init_hidden(batch_size)
    for batch, i in enumerate(range(0, train_data.size(0) - 1, bptt)):
        data, targets = get_batch(train_data, i)
        # Starting each batch, we detach the hidden state from how it was previously produced.
        # If we didn't, the model would try backpropagating all the way to start of the dataset.
        hidden = repackage_hidden(hidden)
        optimizer.zero_grad()
        output, hidden = model(data, hidden)
        loss = criterion(output.view(-1, ntokens), targets)
        loss.backward()                
        optimizer.step()
        # output to screen
        if batch % log_interval == 0 and batch > 0:
            elapsed = time.time() - start_time
            print('| epoch {:3d} | {:5d}/{:5d} batches | lr {:02.2f} | ms/batch {:5.2f} | '
                    'loss {:5.2f} '.format(
                epoch, batch, len(train_data) // bptt, lr,
                elapsed * 1000 / log_interval, loss.item()))
            start_time = time.time()

In [7]:
# Loop over epochs.
best_val_loss = None

# At any point you can hit Ctrl + C to break out of training early.
try:
    for epoch in range(1, epochs+1):
        epoch_start_time = time.time()
        train()
        print('-' * 89)
        print('| end of epoch {:3d} | time: {:5.2f}s '.format(epoch, (time.time() - epoch_start_time)))
        print('-' * 89)
        # Save the model
        with open(save, 'wb') as f:
            torch.save(model, f)
except KeyboardInterrupt:
    print('-' * 89)
    print('Exiting from training early')

| epoch   1 |   100/ 4550 batches | lr 30.00 | ms/batch 435.33 | loss  9.91 
| epoch   1 |   200/ 4550 batches | lr 30.00 | ms/batch 440.69 | loss  8.08 
| epoch   1 |   300/ 4550 batches | lr 30.00 | ms/batch 413.98 | loss  7.78 
| epoch   1 |   400/ 4550 batches | lr 30.00 | ms/batch 420.85 | loss  7.56 
| epoch   1 |   500/ 4550 batches | lr 30.00 | ms/batch 428.54 | loss  7.46 
| epoch   1 |   600/ 4550 batches | lr 30.00 | ms/batch 411.12 | loss  7.48 
| epoch   1 |   700/ 4550 batches | lr 30.00 | ms/batch 394.20 | loss  6.98 
| epoch   1 |   800/ 4550 batches | lr 30.00 | ms/batch 387.46 | loss  7.13 
| epoch   1 |   900/ 4550 batches | lr 30.00 | ms/batch 390.85 | loss  6.93 
| epoch   1 |  1000/ 4550 batches | lr 30.00 | ms/batch 390.45 | loss  6.96 
| epoch   1 |  1100/ 4550 batches | lr 30.00 | ms/batch 390.93 | loss  6.91 
| epoch   1 |  1200/ 4550 batches | lr 30.00 | ms/batch 390.63 | loss  6.91 
| epoch   1 |  1300/ 4550 batches | lr 30.00 | ms/batch 390.92 | loss  7.30 

  "type " + obj.__name__ + ". It won't be checked "


| epoch   2 |   100/ 4550 batches | lr 30.00 | ms/batch 393.83 | loss  6.69 
| epoch   2 |   200/ 4550 batches | lr 30.00 | ms/batch 387.86 | loss  6.86 
| epoch   2 |   300/ 4550 batches | lr 30.00 | ms/batch 387.70 | loss  6.92 
| epoch   2 |   400/ 4550 batches | lr 30.00 | ms/batch 388.09 | loss  6.91 
| epoch   2 |   500/ 4550 batches | lr 30.00 | ms/batch 387.72 | loss  7.02 
| epoch   2 |   600/ 4550 batches | lr 30.00 | ms/batch 386.88 | loss  6.96 
| epoch   2 |   700/ 4550 batches | lr 30.00 | ms/batch 387.04 | loss  6.72 
| epoch   2 |   800/ 4550 batches | lr 30.00 | ms/batch 384.91 | loss  6.94 
| epoch   2 |   900/ 4550 batches | lr 30.00 | ms/batch 387.31 | loss  6.66 
| epoch   2 |  1000/ 4550 batches | lr 30.00 | ms/batch 389.24 | loss  6.63 
| epoch   2 |  1100/ 4550 batches | lr 30.00 | ms/batch 389.57 | loss  6.61 
| epoch   2 |  1200/ 4550 batches | lr 30.00 | ms/batch 392.23 | loss  6.72 
| epoch   2 |  1300/ 4550 batches | lr 30.00 | ms/batch 394.38 | loss  7.09 

## Generate poems

In [8]:
import string

###############################################################################
# Language Modeling
#
# This file generates new sentences sampled from the language model
#
###############################################################################

checkpoint = './model.pt'
outf = 'generated.txt'
words = 500

# Set the random seed manually for reproducibility.
torch.manual_seed(0)

with open(checkpoint, 'rb') as f:
    model = torch.load(f).to(device)
model.eval()

corpus = Corpus(data)
ntokens = len(corpus.dictionary)

hidden = model.init_hidden(1)
input = torch.randint(ntokens, (1, 1), dtype=torch.long).to(device)

with open(outf, 'w') as outf:
    with torch.no_grad():  # no tracking history
        # write the initial inputs
        word_idx = input.item()
        word = corpus.dictionary.idx2word[word_idx]
        if word in string.punctuation:
            write_down = word + '\n'
        else:
            write_down = word + ' '
        outf.write(write_down)
        for i in range(words):
            output, hidden = model(input, hidden)
            # use softmax to calculate the output weights
            word_weights = torch.nn.functional.softmax(output.squeeze(),0)
            word_idx = torch.multinomial(word_weights, 1)[0]
            input.fill_(word_idx)
            word = corpus.dictionary.idx2word[word_idx]
            
            print('Generated {}/{} words'.format(i, words))
            
            if word == 'end_of_poem':
                break
            
            if word in string.punctuation:
                write_down = word + '\n'
            else:
                write_down = word + ' '
            outf.write(write_down)

Generated 0/500 words
Generated 1/500 words
Generated 2/500 words
Generated 3/500 words
Generated 4/500 words
Generated 5/500 words
Generated 6/500 words
Generated 7/500 words
Generated 8/500 words
Generated 9/500 words
Generated 10/500 words
Generated 11/500 words
Generated 12/500 words
Generated 13/500 words
Generated 14/500 words
Generated 15/500 words
Generated 16/500 words
Generated 17/500 words
Generated 18/500 words
Generated 19/500 words
Generated 20/500 words
Generated 21/500 words
Generated 22/500 words
Generated 23/500 words
Generated 24/500 words
Generated 25/500 words
Generated 26/500 words
Generated 27/500 words
Generated 28/500 words
Generated 29/500 words
Generated 30/500 words
Generated 31/500 words
Generated 32/500 words
Generated 33/500 words
Generated 34/500 words
Generated 35/500 words
Generated 36/500 words
Generated 37/500 words
Generated 38/500 words
Generated 39/500 words
Generated 40/500 words
Generated 41/500 words
Generated 42/500 words
Generated 43/500 word