# Generating Haiku with a Naive LSTM Network

In [None]:
%load_ext autoreload
%autoreload 2
%aimport haikulib.utils.data

%config InlineBackend.figure_format = 'svg'
%matplotlib inline

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

import torch
import torch.nn as nn

# Apply seaborn's default theme so every matplotlib figure drawn below picks it up.
sns.set()

In [None]:
class WordLanguageModel(nn.Module):
    """Token-level language model: embedding -> LSTM/GRU -> linear decoder.

    The embedding ("encoder") maps token indices to ``nembeddings``-sized
    vectors, a stacked recurrent network of ``nlayers`` layers turns them into
    ``nhidden``-sized states, and the linear "decoder" projects each state to
    ``ntokens`` scores. Dropout with probability ``dropout`` is applied to the
    embeddings and to the recurrent output, and is also passed to the
    recurrent module itself.

    Raises:
        ValueError: if ``rnn_type`` is neither ``"LSTM"`` nor ``"GRU"``, or if
            ``tie_weights`` is set while ``nhidden != nembeddings``.
    """

    def __init__(
        self,
        rnn_type,
        ntokens,
        nembeddings,
        nhidden,
        nlayers,
        dropout=0.5,
        tie_weights=False,
    ):
        super().__init__()

        # Sub-modules are created in the order encoder -> rnn -> dropout ->
        # decoder; their default initialisers draw from the global RNG, so
        # this order matters for seeded reproducibility.
        self.encoder = nn.Embedding(ntokens, nembeddings)

        if rnn_type not in {"LSTM", "GRU"}:
            raise ValueError(f"RNN type '{rnn_type}' unknown.")
        recurrent_cls = getattr(nn, rnn_type)
        self.rnn = recurrent_cls(nembeddings, nhidden, nlayers, dropout=dropout)

        self.dropout = nn.Dropout(dropout)
        self.decoder = nn.Linear(nhidden, ntokens)

        # Optional weight tying between the input embedding and the output
        # projection, following:
        #   "Using the Output Embedding to Improve Language Models"
        #   (Press & Wolf 2016) https://arxiv.org/abs/1608.05859
        #   "Tying Word Vectors and Word Classifiers: A Loss Framework for
        #   Language Modeling" (Inan et al. 2016)
        #   https://arxiv.org/abs/1611.01462
        if tie_weights and nhidden != nembeddings:
            raise ValueError(
                "When using the tied flag, number of hidden units must be equal to the number of embeddings"
            )
        if tie_weights:
            # Both layers now share a single parameter tensor.
            self.decoder.weight = self.encoder.weight

        # Remembered so init_hidden() can build a state of the right kind/shape.
        self.rnn_type = rnn_type
        self.nhidden = nhidden
        self.nlayers = nlayers

        self.init_weights()

    def init_weights(self):
        """Re-initialise encoder/decoder weights to U(-0.1, 0.1) and zero the decoder bias.

        The recurrent module keeps whatever initialisation PyTorch gave it.
        """
        bound = 0.1
        nn.init.uniform_(self.encoder.weight, -bound, bound)
        nn.init.zeros_(self.decoder.bias)
        # When weights are tied this is the same tensor as the encoder weight,
        # so it is simply drawn a second time.
        nn.init.uniform_(self.decoder.weight, -bound, bound)

    def forward(self, minibatch, hidden):
        """Run one forward pass and return ``(scores, new_hidden)``.

        ``scores`` has the recurrent output's first two dimensions followed by
        a final dimension of size ``ntokens``. The recurrent module is built
        without ``batch_first``, so ``minibatch`` is presumably laid out as
        (sequence, batch) token indices -- TODO confirm against the data
        pipeline. ``hidden`` is a state such as the one from ``init_hidden``.
        """
        embedded = self.dropout(self.encoder(minibatch))
        rnn_out, hidden = self.rnn(embedded, hidden)
        rnn_out = self.dropout(rnn_out)

        # Collapse the two leading dimensions for the linear layer, then
        # restore them on the way out.
        dim0, dim1, width = rnn_out.size(0), rnn_out.size(1), rnn_out.size(2)
        scores = self.decoder(rnn_out.view(dim0 * dim1, width))
        return scores.view(dim0, dim1, scores.size(1)), hidden

    def init_hidden(self, bsz):
        """Return an all-zero initial state for a batch of ``bsz`` sequences.

        LSTMs get a ``(h, c)`` pair, other cell types a single tensor; every
        tensor is ``(nlayers, bsz, nhidden)`` and is created via ``new_zeros``
        on the first model parameter. The recurrent state is meant to be
        reset with this at the beginning of every epoch.
        """
        reference = next(self.parameters())
        state_shape = (self.nlayers, bsz, self.nhidden)
        if self.rnn_type != "LSTM":
            return reference.new_zeros(state_shape)
        return reference.new_zeros(state_shape), reference.new_zeros(state_shape)

In [None]:
# TODO: Split into training, test, and validation splits.
# Until then, one dataset backs one shuffled loader.
dataset = haikulib.utils.data.HaikuVocabIndexDataset(
    seq_len=3,
    method="words",
)
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=256,
    shuffle=True,
    num_workers=4,
)