In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
from collections import OrderedDict
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

In [2]:
class Tokenizer():
    def __init__(self, file_path):
        self.sonnets = []
        self.vocab = self.create_vocab(file_path)
        self.stoi, self.itos = self.create_dict()
        self.file_path = file_path
        self.max_len = max([len(i) for i in self.sonnets])
        self.indexes = torch.arange(len(self.vocab))

    def create_vocab(self, path):
        f = open(path)
        vocab_set = {'<SOS>', '<EOS>'}
        sonnet = []
        for line in f.readlines():
            try:
                int(line.strip())
                if sonnet:
                    sonnet = ['<SOS>'] + sonnet + ['<EOS>']
                    self.sonnets.append(sonnet)
                sonnet = []
            except ValueError:
                words = [i.lower().strip() for i in line.split()]
                sonnet += (words)
                vocab_set.update(set(words))
        sonnet = ['<SOS>'] + sonnet + ['<EOS>']
        self.sonnets.append(sonnet)
        return vocab_set

    def create_dict(self):
        stoi = OrderedDict()
        itos = OrderedDict()
        for i, token in enumerate(self.vocab):
            stoi[token] = i
            itos[i] = token
        return stoi, itos

    def sonnet_to_tensor(self, sonnet):
        x = torch.zeros(self.max_len, len(self.vocab))
        for idx, word in enumerate(sonnet):
            x[idx, self.stoi[word]] = 1
        return x

    def get_tensor_data(self):
        data_tensor = torch.zeros(len(self.sonnets), self.max_len, len(self.vocab))
        for idx, sonnet in enumerate(self.sonnets):
            data_tensor[idx] = self.sonnet_to_tensor(sonnet)
        return data_tensor

    def word_to_one_hot(self, index, batch_size=1):
        x = torch.zeros(batch_size, 1, len(self.vocab))
        x[:, :, index] = 1
        return x

    def word_to_one_hot_batch(self, index):
        x = torch.zeros(index.shape[0], 1, len(self.vocab))
        for i, idx in enumerate(index):
            x[i] = self.word_to_one_hot(idx)  
        return x
        
    def decode(self, sonnet):
        words = ''
        try:
            for idx, row in enumerate(sonnet):
                words += self.itos[self.indexes[row.to(torch.bool)].item()] + ' '        
        except RuntimeError:
            pass
        return words

<strong> Recurrent Neural Networks</strong>
<p>  Hidden state update : </p> $$h_t = tanh(x_tW^T_{ih} + b_{ih} + h_{t-1}W^T_{hh} + b_{hh})$$
<p>$x_t$ : input at time step t</p>
<p>$W_{ih}$: input to hidden weight matrix</p>
<p>$b_{ih}$: input to hidden biases</p>
<p>$h_{t-1}$: previous hidden state</p>
<p>$W_{hh}$: weight matrix for previous hidden state to current hidden state</p>
<p>$h_t$: current hidden state</p>
<br>

<strong> Gated Recurrent Unit</strong>
\begin{align*}
r_t &= \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\
z_t &= \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\
n_t &= \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)} + b_{hn})) \\
h_t &= (1 - z_t) \odot n_t + z_t \odot h_{(t-1)}
\end{align*}

<p>$x_t$ : input at time step t</p>
<p>$h_{t-1}$: previous hidden state</p>
<p>$h_t$: current hidden state</p>
<p>$r_t$: reset gate</p>
<p>$z_t$: update gate</p>
<p>$n_t$: candidate for current hidden state</p>
<p>$h_t$: current hidden state</p>

<strong> Long short term memory</strong>
<p>Hidden state update :</p>

\begin{align*}
\Gamma_u &= \sigma(W_{iu} x_t + b_{iu} + W_{hu} h_{t-1} + b_{hu}) \\
\Gamma_f &= \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
\tilde{c_t} &= \tanh(W_{ic} x_t + b_{ic} + W_{hc} h_{t-1} + b_{hc}) \\
\Gamma_o &= \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
c_t &= \Gamma_f \odot c_{t-1} + \Gamma_u \odot \tilde{c_t} \\
h_t &= \Gamma_o \odot \tanh(c_t)
\end{align*}

<p>$x_t$ : input at time step t</p>
<p>$h_{t-1}$: previous hidden state</p>
<p>$h_t$: current hidden state</p>
<p>$\Gamma_u$: update gate</p>
<p>$\Gamma_o$: output gate</p>
<p>$\Gamma_f$: forget gate</p>
<p>$\tilde{c_t} $: candidate for current memory state</p>
<p>$c_t$: current memory cell state</p>

In [3]:
tokenizer = Tokenizer('shakespeare.txt')

In [4]:
data_tensor = tokenizer.get_tensor_data()

In [5]:
class SonnetDataset(Dataset):
    def __init__(self, x):
        self.x = x
        self.samples = x.shape[0]

    def __getitem__(self, idx):
        return self.x[idx]

    def __len__(self):
        return self.samples

In [6]:
train_dataset = SonnetDataset(data_tensor)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

In [7]:
class SeqGeneration(nn.Module):
    def __init__(self, input_size, hidden_size, num_cells, vocab_length, cell_init='rnn', bidirectional=False):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_cells = num_cells
        self.num_directions = 2 if bidirectional else 1
        self.cell_init = cell_init
        if cell_init == 'rnn':
            self.cell = nn.RNN(input_size, hidden_size, num_cells, batch_first=True, bidirectional=bidirectional, dropout=0.2)
        elif cell_init == 'gru':
            self.cell = nn.GRU(input_size, hidden_size, num_cells, batch_first=True, bidirectional=bidirectional, dropout=0.2)
        elif cell_init == 'lstm':
            self.cell = nn.LSTM(input_size, hidden_size, num_cells, batch_first=True, bidirectional=bidirectional, dropout=0.2)
        else:
            raise Exception('Invalid cell type')
        self.fc1 = nn.Linear(self.num_directions*hidden_size, vocab_length)
        self.softmax = nn.LogSoftmax(dim=-1)
        

    def forward(self, x, h0=None, c0=None, device='mps'):
        # data you pass will be the first time step only.
        if h0 is None:
            h0 = torch.zeros(self.num_directions*self.num_cells, x.shape[0], self.hidden_size, device=device)
        if self.cell_init == 'lstm':
            if c0 is None:
                c0 = torch.zeros(self.num_directions*self.num_cells, x.shape[0], self.hidden_size, device=device)
            hidden_states, (ht_all, c0) = self.cell(x, (h0, c0))
        else:
            hidden_states, ht_all = self.cell(x, h0)
            # ht will always be of the shape h0 
        ht = hidden_states[:, -1, :]
        logits = self.fc1(ht)
        scores = self.softmax(logits)
        return scores, ht_all, c0

In [8]:
model = SeqGeneration(len(tokenizer.vocab), 128, 2, len(tokenizer.vocab), 'lstm', True)
# model.load_state_dict(torch.load('SeqGeneration.pth'))
device = 'mps'
model = model.to(device)
loss = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [9]:
def train(model, train_loader, train_params, tokenizer):
    min_loss = float('inf')
    epochs = train_params['epochs']
    loss = train_params['loss']
    optimizer = train_params['optimizer']
    for epoch in range(epochs):
        loss_at_batch = None
        for x in tqdm(train_loader, desc=f'Epoch: {epoch}/{epochs}, best loss: {min_loss}'):
            loss_at_batch = []
            ht = None
            c0 = None
            x = x.to(device)
            for t in range(x.shape[1] - 1):
                xt = x[:, t:t+1, :]
                scores, ht, c0 = model(xt, ht, c0)
                next_word_index = torch.argmax(scores, dim=1)
                labels = x[:, t+1:t+2, :]
                labels = torch.argmax(labels, dim=2).flatten()
                if labels.sum() == 0:
                    break
                loss_at_timestep = loss(scores, labels)
                loss_at_batch.append(loss_at_timestep)
                xt = tokenizer.word_to_one_hot_batch(next_word_index)
                xt = xt.to(device)
            loss_at_batch = sum(loss_at_batch)/len(loss_at_batch)
            optimizer.zero_grad()
            loss_at_batch.backward()
            torch.nn.utils.clip_grad_value_(model.parameters(), 0.5)
            optimizer.step()
        if loss_at_batch < min_loss:
            torch.save(model.state_dict(), 'SeqGeneration_lstm.pth')
            min_loss = loss_at_batch
            print('saving model')

In [10]:
def generate_text(model, sos_tensor, tokenizer):
    curr_token = ''
    sonnet = ''
    pred = ''
    ht = None
    model.eval()
    with torch.no_grad():  
        while curr_token != '<EOS>' and len(sonnet.split()) < tokenizer.max_len:
            sonnet += pred + ' '
            scores, ht, ct = model(sos_tensor, ht, None, 'cpu')
            # pred = tokenizer.itos[torch.argmax(scores).item()] always generates the same sequence
            # sample to generate different sequences each time.
            pred = tokenizer.itos[torch.argmax(scores).item()]
            random_word_index = torch.multinomial(torch.exp(scores).flatten(), num_samples=1).item()
            curr_token = pred
            sos_tensor = tokenizer.word_to_one_hot(tokenizer.stoi[tokenizer.itos[random_word_index]], 1)
        return sonnet

In [11]:
train_params = {'epochs': 200, 'loss': loss, 'optimizer': optimizer} # set epochs to train
train(model, train_loader, train_params, tokenizer)

Epoch: 0/200, best loss: inf: 100%|█████████████| 10/10 [00:07<00:00,  1.41it/s]


saving model


Epoch: 1/200, best loss: 7.480195045471191: 100%|█| 10/10 [00:06<00:00,  1.52it/


saving model


Epoch: 2/200, best loss: 6.760316848754883: 100%|█| 10/10 [00:06<00:00,  1.54it/


saving model


Epoch: 3/200, best loss: 6.553677558898926: 100%|█| 10/10 [00:06<00:00,  1.54it/
Epoch: 4/200, best loss: 6.553677558898926: 100%|█| 10/10 [00:06<00:00,  1.52it/


saving model


Epoch: 5/200, best loss: 6.544550895690918: 100%|█| 10/10 [00:06<00:00,  1.53it/


saving model


Epoch: 6/200, best loss: 6.442972183227539: 100%|█| 10/10 [00:06<00:00,  1.50it/
Epoch: 7/200, best loss: 6.442972183227539: 100%|█| 10/10 [00:06<00:00,  1.52it/
Epoch: 8/200, best loss: 6.442972183227539: 100%|█| 10/10 [00:06<00:00,  1.46it/
Epoch: 9/200, best loss: 6.442972183227539: 100%|█| 10/10 [00:06<00:00,  1.50it/
Epoch: 10/200, best loss: 6.442972183227539: 100%|█| 10/10 [00:06<00:00,  1.49it
Epoch: 11/200, best loss: 6.442972183227539: 100%|█| 10/10 [00:06<00:00,  1.50it
Epoch: 12/200, best loss: 6.442972183227539: 100%|█| 10/10 [00:06<00:00,  1.54it
Epoch: 13/200, best loss: 6.442972183227539: 100%|█| 10/10 [00:06<00:00,  1.53it
Epoch: 14/200, best loss: 6.442972183227539: 100%|█| 10/10 [00:06<00:00,  1.54it
Epoch: 15/200, best loss: 6.442972183227539: 100%|█| 10/10 [00:06<00:00,  1.52it
Epoch: 16/200, best loss: 6.442972183227539: 100%|█| 10/10 [00:06<00:00,  1.53it
Epoch: 17/200, best loss: 6.442972183227539: 100%|█| 10/10 [00:06<00:00,  1.54it
Epoch: 18/200, best loss: 6.

saving model


Epoch: 20/200, best loss: 6.39254093170166: 100%|█| 10/10 [00:06<00:00,  1.54it/
Epoch: 21/200, best loss: 6.39254093170166: 100%|█| 10/10 [00:06<00:00,  1.54it/
Epoch: 22/200, best loss: 6.39254093170166: 100%|█| 10/10 [00:06<00:00,  1.53it/


saving model


Epoch: 23/200, best loss: 6.244823932647705: 100%|█| 10/10 [00:06<00:00,  1.52it
Epoch: 24/200, best loss: 6.244823932647705: 100%|█| 10/10 [00:06<00:00,  1.53it
Epoch: 25/200, best loss: 6.244823932647705: 100%|█| 10/10 [00:06<00:00,  1.54it
Epoch: 26/200, best loss: 6.244823932647705: 100%|█| 10/10 [00:06<00:00,  1.54it
Epoch: 27/200, best loss: 6.244823932647705: 100%|█| 10/10 [00:06<00:00,  1.55it
Epoch: 28/200, best loss: 6.244823932647705: 100%|█| 10/10 [00:06<00:00,  1.55it
Epoch: 29/200, best loss: 6.244823932647705: 100%|█| 10/10 [00:06<00:00,  1.55it
Epoch: 30/200, best loss: 6.244823932647705: 100%|█| 10/10 [00:06<00:00,  1.54it
Epoch: 31/200, best loss: 6.244823932647705: 100%|█| 10/10 [00:06<00:00,  1.53it
Epoch: 32/200, best loss: 6.244823932647705: 100%|█| 10/10 [00:06<00:00,  1.45it


saving model


Epoch: 33/200, best loss: 6.216181755065918: 100%|█| 10/10 [00:06<00:00,  1.50it


saving model


Epoch: 34/200, best loss: 6.1520538330078125: 100%|█| 10/10 [00:06<00:00,  1.51i


saving model


Epoch: 35/200, best loss: 6.10606575012207: 100%|█| 10/10 [00:06<00:00,  1.51it/
Epoch: 36/200, best loss: 6.10606575012207: 100%|█| 10/10 [00:06<00:00,  1.51it/


saving model


Epoch: 37/200, best loss: 5.804404258728027: 100%|█| 10/10 [00:06<00:00,  1.52it
Epoch: 38/200, best loss: 5.804404258728027: 100%|█| 10/10 [00:06<00:00,  1.52it
Epoch: 39/200, best loss: 5.804404258728027: 100%|█| 10/10 [00:06<00:00,  1.52it
Epoch: 40/200, best loss: 5.804404258728027: 100%|█| 10/10 [00:06<00:00,  1.52it


saving model


Epoch: 41/200, best loss: 5.752346992492676: 100%|█| 10/10 [00:06<00:00,  1.51it
Epoch: 42/200, best loss: 5.752346992492676: 100%|█| 10/10 [00:06<00:00,  1.52it
Epoch: 43/200, best loss: 5.752346992492676: 100%|█| 10/10 [00:06<00:00,  1.53it
Epoch: 44/200, best loss: 5.752346992492676: 100%|█| 10/10 [00:06<00:00,  1.53it
Epoch: 45/200, best loss: 5.752346992492676: 100%|█| 10/10 [00:06<00:00,  1.54it
Epoch: 46/200, best loss: 5.752346992492676: 100%|█| 10/10 [00:06<00:00,  1.51it
Epoch: 47/200, best loss: 5.752346992492676: 100%|█| 10/10 [00:06<00:00,  1.51it
Epoch: 48/200, best loss: 5.752346992492676: 100%|█| 10/10 [00:06<00:00,  1.51it
Epoch: 49/200, best loss: 5.752346992492676: 100%|█| 10/10 [00:06<00:00,  1.51it
Epoch: 50/200, best loss: 5.752346992492676: 100%|█| 10/10 [00:06<00:00,  1.52it
Epoch: 51/200, best loss: 5.752346992492676: 100%|█| 10/10 [00:06<00:00,  1.52it
Epoch: 52/200, best loss: 5.752346992492676: 100%|█| 10/10 [00:06<00:00,  1.51it
Epoch: 53/200, best loss: 5.

saving model


Epoch: 54/200, best loss: 5.5898356437683105: 100%|█| 10/10 [00:06<00:00,  1.51i
Epoch: 55/200, best loss: 5.5898356437683105: 100%|█| 10/10 [00:06<00:00,  1.52i
Epoch: 56/200, best loss: 5.5898356437683105: 100%|█| 10/10 [00:06<00:00,  1.54i
Epoch: 57/200, best loss: 5.5898356437683105: 100%|█| 10/10 [00:06<00:00,  1.55i


saving model


Epoch: 58/200, best loss: 5.418381214141846: 100%|█| 10/10 [00:06<00:00,  1.52it
Epoch: 59/200, best loss: 5.418381214141846: 100%|█| 10/10 [00:06<00:00,  1.52it
Epoch: 60/200, best loss: 5.418381214141846: 100%|█| 10/10 [00:06<00:00,  1.54it
Epoch: 61/200, best loss: 5.418381214141846: 100%|█| 10/10 [00:06<00:00,  1.54it
Epoch: 62/200, best loss: 5.418381214141846: 100%|█| 10/10 [00:06<00:00,  1.53it


saving model


Epoch: 63/200, best loss: 5.39573335647583: 100%|█| 10/10 [00:06<00:00,  1.53it/
Epoch: 64/200, best loss: 5.39573335647583: 100%|█| 10/10 [00:06<00:00,  1.53it/
Epoch: 65/200, best loss: 5.39573335647583: 100%|█| 10/10 [00:06<00:00,  1.52it/


saving model


Epoch: 66/200, best loss: 5.25770378112793: 100%|█| 10/10 [00:06<00:00,  1.52it/
Epoch: 67/200, best loss: 5.25770378112793: 100%|█| 10/10 [00:06<00:00,  1.51it/
Epoch: 68/200, best loss: 5.25770378112793: 100%|█| 10/10 [00:06<00:00,  1.52it/
Epoch: 69/200, best loss: 5.25770378112793: 100%|█| 10/10 [00:06<00:00,  1.50it/
Epoch: 70/200, best loss: 5.25770378112793: 100%|█| 10/10 [00:06<00:00,  1.51it/


saving model


Epoch: 71/200, best loss: 5.220528602600098: 100%|█| 10/10 [00:06<00:00,  1.51it
Epoch: 72/200, best loss: 5.220528602600098: 100%|█| 10/10 [00:06<00:00,  1.52it
Epoch: 73/200, best loss: 5.220528602600098: 100%|█| 10/10 [00:06<00:00,  1.52it
Epoch: 74/200, best loss: 5.220528602600098: 100%|█| 10/10 [00:06<00:00,  1.51it


saving model


Epoch: 75/200, best loss: 5.124361515045166: 100%|█| 10/10 [00:06<00:00,  1.53it


saving model


Epoch: 76/200, best loss: 5.10618257522583: 100%|█| 10/10 [00:06<00:00,  1.52it/
Epoch: 77/200, best loss: 5.10618257522583: 100%|█| 10/10 [00:06<00:00,  1.54it/


saving model


Epoch: 78/200, best loss: 5.063533782958984: 100%|█| 10/10 [00:06<00:00,  1.53it
Epoch: 79/200, best loss: 5.063533782958984: 100%|█| 10/10 [00:06<00:00,  1.53it
Epoch: 80/200, best loss: 5.063533782958984: 100%|█| 10/10 [00:06<00:00,  1.52it
Epoch: 81/200, best loss: 5.063533782958984: 100%|█| 10/10 [00:06<00:00,  1.51it


saving model


Epoch: 82/200, best loss: 5.003179550170898: 100%|█| 10/10 [00:06<00:00,  1.50it
Epoch: 83/200, best loss: 5.003179550170898: 100%|█| 10/10 [00:06<00:00,  1.51it


saving model


Epoch: 84/200, best loss: 4.873672008514404: 100%|█| 10/10 [00:06<00:00,  1.51it


saving model


Epoch: 85/200, best loss: 4.843679904937744: 100%|█| 10/10 [00:06<00:00,  1.51it


saving model


Epoch: 86/200, best loss: 4.788336753845215: 100%|█| 10/10 [00:06<00:00,  1.54it


saving model


Epoch: 87/200, best loss: 4.756350040435791: 100%|█| 10/10 [00:06<00:00,  1.54it
Epoch: 88/200, best loss: 4.756350040435791: 100%|█| 10/10 [00:06<00:00,  1.54it
Epoch: 89/200, best loss: 4.756350040435791: 100%|█| 10/10 [00:06<00:00,  1.54it
Epoch: 90/200, best loss: 4.756350040435791: 100%|█| 10/10 [00:07<00:00,  1.41it


saving model


Epoch: 91/200, best loss: 4.731354713439941: 100%|█| 10/10 [00:06<00:00,  1.45it


saving model


Epoch: 92/200, best loss: 4.65305233001709: 100%|█| 10/10 [00:07<00:00,  1.34it/


saving model


Epoch: 93/200, best loss: 4.652480125427246: 100%|█| 10/10 [00:07<00:00,  1.36it


saving model


Epoch: 94/200, best loss: 4.634428024291992: 100%|█| 10/10 [00:07<00:00,  1.43it


saving model


Epoch: 95/200, best loss: 4.5104289054870605: 100%|█| 10/10 [00:07<00:00,  1.28i


saving model


Epoch: 96/200, best loss: 4.484066009521484: 100%|█| 10/10 [00:06<00:00,  1.47it


saving model


Epoch: 97/200, best loss: 4.482104301452637: 100%|█| 10/10 [00:06<00:00,  1.54it
Epoch: 98/200, best loss: 4.482104301452637: 100%|█| 10/10 [00:06<00:00,  1.51it
Epoch: 99/200, best loss: 4.482104301452637: 100%|█| 10/10 [00:06<00:00,  1.53it


saving model


Epoch: 100/200, best loss: 4.335834980010986: 100%|█| 10/10 [00:06<00:00,  1.52i
Epoch: 101/200, best loss: 4.335834980010986: 100%|█| 10/10 [00:06<00:00,  1.53i
Epoch: 102/200, best loss: 4.335834980010986: 100%|█| 10/10 [00:06<00:00,  1.53i


saving model


Epoch: 103/200, best loss: 4.286771297454834: 100%|█| 10/10 [00:06<00:00,  1.52i


saving model


Epoch: 104/200, best loss: 4.240783214569092: 100%|█| 10/10 [00:07<00:00,  1.40i
Epoch: 105/200, best loss: 4.240783214569092: 100%|█| 10/10 [00:06<00:00,  1.43i
Epoch: 106/200, best loss: 4.240783214569092: 100%|█| 10/10 [00:07<00:00,  1.41i


saving model


Epoch: 107/200, best loss: 4.153110027313232: 100%|█| 10/10 [00:06<00:00,  1.49i
Epoch: 108/200, best loss: 4.153110027313232: 100%|█| 10/10 [00:06<00:00,  1.52i


saving model


Epoch: 109/200, best loss: 4.047351837158203: 100%|█| 10/10 [00:06<00:00,  1.51i
Epoch: 110/200, best loss: 4.047351837158203: 100%|█| 10/10 [00:06<00:00,  1.51i
Epoch: 111/200, best loss: 4.047351837158203: 100%|█| 10/10 [00:06<00:00,  1.51i


saving model


Epoch: 112/200, best loss: 4.004422187805176: 100%|█| 10/10 [00:06<00:00,  1.51i
Epoch: 113/200, best loss: 4.004422187805176: 100%|█| 10/10 [00:06<00:00,  1.51i


saving model


Epoch: 114/200, best loss: 3.9444375038146973: 100%|█| 10/10 [00:06<00:00,  1.52


saving model


Epoch: 115/200, best loss: 3.708221197128296: 100%|█| 10/10 [00:06<00:00,  1.52i
Epoch: 116/200, best loss: 3.708221197128296: 100%|█| 10/10 [00:06<00:00,  1.52i
Epoch: 117/200, best loss: 3.708221197128296: 100%|█| 10/10 [00:06<00:00,  1.53i
Epoch: 118/200, best loss: 3.708221197128296: 100%|█| 10/10 [00:06<00:00,  1.51i
Epoch: 119/200, best loss: 3.708221197128296: 100%|█| 10/10 [00:06<00:00,  1.52i


saving model


Epoch: 120/200, best loss: 3.652432918548584: 100%|█| 10/10 [00:06<00:00,  1.51i
Epoch: 121/200, best loss: 3.652432918548584: 100%|█| 10/10 [00:06<00:00,  1.53i
Epoch: 122/200, best loss: 3.652432918548584: 100%|█| 10/10 [00:06<00:00,  1.54i
Epoch: 123/200, best loss: 3.652432918548584: 100%|█| 10/10 [00:06<00:00,  1.48i
Epoch: 124/200, best loss: 3.652432918548584: 100%|█| 10/10 [00:06<00:00,  1.51i


saving model


Epoch: 125/200, best loss: 3.520200729370117: 100%|█| 10/10 [00:06<00:00,  1.50i
Epoch: 126/200, best loss: 3.520200729370117: 100%|█| 10/10 [00:06<00:00,  1.50i
Epoch: 127/200, best loss: 3.520200729370117: 100%|█| 10/10 [00:06<00:00,  1.50i


saving model


Epoch: 128/200, best loss: 3.488630771636963: 100%|█| 10/10 [00:06<00:00,  1.47i
Epoch: 129/200, best loss: 3.488630771636963: 100%|█| 10/10 [00:06<00:00,  1.51i
Epoch: 130/200, best loss: 3.488630771636963: 100%|█| 10/10 [00:06<00:00,  1.51i


saving model


Epoch: 131/200, best loss: 3.329408645629883: 100%|█| 10/10 [00:06<00:00,  1.51i
Epoch: 132/200, best loss: 3.329408645629883: 100%|█| 10/10 [00:06<00:00,  1.52i


saving model


Epoch: 133/200, best loss: 3.302119493484497: 100%|█| 10/10 [00:06<00:00,  1.52i
Epoch: 134/200, best loss: 3.302119493484497: 100%|█| 10/10 [00:06<00:00,  1.51i
Epoch: 135/200, best loss: 3.302119493484497: 100%|█| 10/10 [00:06<00:00,  1.52i


saving model


Epoch: 136/200, best loss: 3.2707035541534424: 100%|█| 10/10 [00:06<00:00,  1.50


saving model


Epoch: 137/200, best loss: 3.205918073654175: 100%|█| 10/10 [00:06<00:00,  1.51i
Epoch: 138/200, best loss: 3.205918073654175: 100%|█| 10/10 [00:06<00:00,  1.52i


saving model


Epoch: 139/200, best loss: 3.0541133880615234: 100%|█| 10/10 [00:06<00:00,  1.50
Epoch: 140/200, best loss: 3.0541133880615234: 100%|█| 10/10 [00:06<00:00,  1.51
Epoch: 141/200, best loss: 3.0541133880615234: 100%|█| 10/10 [00:06<00:00,  1.52
Epoch: 142/200, best loss: 3.0541133880615234: 100%|█| 10/10 [00:06<00:00,  1.53
Epoch: 143/200, best loss: 3.0541133880615234: 100%|█| 10/10 [00:06<00:00,  1.54


saving model


Epoch: 144/200, best loss: 3.036862373352051: 100%|█| 10/10 [00:06<00:00,  1.52i


saving model


Epoch: 145/200, best loss: 2.984029769897461: 100%|█| 10/10 [00:06<00:00,  1.49i
Epoch: 146/200, best loss: 2.984029769897461: 100%|█| 10/10 [00:06<00:00,  1.52i


saving model


Epoch: 147/200, best loss: 2.956364393234253: 100%|█| 10/10 [00:06<00:00,  1.53i
Epoch: 148/200, best loss: 2.956364393234253: 100%|█| 10/10 [00:06<00:00,  1.53i
Epoch: 149/200, best loss: 2.956364393234253: 100%|█| 10/10 [00:06<00:00,  1.53i


saving model


Epoch: 150/200, best loss: 2.879908323287964: 100%|█| 10/10 [00:06<00:00,  1.54i


saving model


Epoch: 151/200, best loss: 2.6977150440216064: 100%|█| 10/10 [00:06<00:00,  1.52
Epoch: 152/200, best loss: 2.6977150440216064: 100%|█| 10/10 [00:06<00:00,  1.53
Epoch: 153/200, best loss: 2.6977150440216064: 100%|█| 10/10 [00:06<00:00,  1.54
Epoch: 154/200, best loss: 2.6977150440216064: 100%|█| 10/10 [00:06<00:00,  1.52


saving model


Epoch: 155/200, best loss: 2.6623713970184326: 100%|█| 10/10 [00:06<00:00,  1.52
Epoch: 156/200, best loss: 2.6623713970184326: 100%|█| 10/10 [00:06<00:00,  1.53


saving model


Epoch: 157/200, best loss: 2.6055119037628174: 100%|█| 10/10 [00:06<00:00,  1.48
Epoch: 158/200, best loss: 2.6055119037628174: 100%|█| 10/10 [00:06<00:00,  1.53
Epoch: 159/200, best loss: 2.6055119037628174: 100%|█| 10/10 [00:06<00:00,  1.54
Epoch: 160/200, best loss: 2.6055119037628174: 100%|█| 10/10 [00:06<00:00,  1.55


saving model


Epoch: 161/200, best loss: 2.6040306091308594: 100%|█| 10/10 [00:06<00:00,  1.53


saving model


Epoch: 162/200, best loss: 2.5842671394348145: 100%|█| 10/10 [00:06<00:00,  1.53
Epoch: 163/200, best loss: 2.5842671394348145: 100%|█| 10/10 [00:06<00:00,  1.54


saving model


Epoch: 164/200, best loss: 2.479668617248535: 100%|█| 10/10 [00:06<00:00,  1.53i
Epoch: 165/200, best loss: 2.479668617248535: 100%|█| 10/10 [00:06<00:00,  1.54i


saving model


Epoch: 166/200, best loss: 2.4312191009521484: 100%|█| 10/10 [00:06<00:00,  1.53
Epoch: 167/200, best loss: 2.4312191009521484: 100%|█| 10/10 [00:06<00:00,  1.53


saving model


Epoch: 168/200, best loss: 2.3524136543273926: 100%|█| 10/10 [00:06<00:00,  1.53
Epoch: 169/200, best loss: 2.3524136543273926: 100%|█| 10/10 [00:06<00:00,  1.52


saving model


Epoch: 170/200, best loss: 2.321885824203491: 100%|█| 10/10 [00:06<00:00,  1.52i
Epoch: 171/200, best loss: 2.321885824203491: 100%|█| 10/10 [00:06<00:00,  1.53i


saving model


Epoch: 172/200, best loss: 2.273533344268799: 100%|█| 10/10 [00:06<00:00,  1.48i


saving model


Epoch: 173/200, best loss: 2.259803295135498: 100%|█| 10/10 [00:06<00:00,  1.47i
Epoch: 174/200, best loss: 2.259803295135498: 100%|█| 10/10 [00:06<00:00,  1.50i
Epoch: 175/200, best loss: 2.259803295135498: 100%|█| 10/10 [00:06<00:00,  1.51i


saving model


Epoch: 176/200, best loss: 2.2388885021209717: 100%|█| 10/10 [00:06<00:00,  1.48


saving model


Epoch: 177/200, best loss: 2.137296199798584: 100%|█| 10/10 [00:06<00:00,  1.47i
Epoch: 178/200, best loss: 2.137296199798584: 100%|█| 10/10 [00:06<00:00,  1.51i
Epoch: 179/200, best loss: 2.137296199798584: 100%|█| 10/10 [00:06<00:00,  1.52i


saving model


Epoch: 180/200, best loss: 2.123990058898926: 100%|█| 10/10 [00:06<00:00,  1.49i


saving model


Epoch: 181/200, best loss: 2.0945470333099365: 100%|█| 10/10 [00:06<00:00,  1.50
Epoch: 182/200, best loss: 2.0945470333099365: 100%|█| 10/10 [00:06<00:00,  1.49


saving model


Epoch: 183/200, best loss: 2.0410103797912598: 100%|█| 10/10 [00:06<00:00,  1.52
Epoch: 184/200, best loss: 2.0410103797912598: 100%|█| 10/10 [00:06<00:00,  1.54


saving model


Epoch: 185/200, best loss: 1.922293782234192: 100%|█| 10/10 [00:06<00:00,  1.54i
Epoch: 186/200, best loss: 1.922293782234192: 100%|█| 10/10 [00:06<00:00,  1.54i
Epoch: 187/200, best loss: 1.922293782234192: 100%|█| 10/10 [00:06<00:00,  1.54i
Epoch: 188/200, best loss: 1.922293782234192: 100%|█| 10/10 [00:06<00:00,  1.54i


saving model


Epoch: 189/200, best loss: 1.8946218490600586: 100%|█| 10/10 [00:06<00:00,  1.54


saving model


Epoch: 190/200, best loss: 1.8923827409744263: 100%|█| 10/10 [00:06<00:00,  1.53


saving model


Epoch: 191/200, best loss: 1.8654778003692627: 100%|█| 10/10 [00:06<00:00,  1.51
Epoch: 192/200, best loss: 1.8654778003692627: 100%|█| 10/10 [00:06<00:00,  1.44
Epoch: 193/200, best loss: 1.8654778003692627: 100%|█| 10/10 [00:06<00:00,  1.51
Epoch: 194/200, best loss: 1.8654778003692627: 100%|█| 10/10 [00:06<00:00,  1.51


saving model


Epoch: 195/200, best loss: 1.856586217880249: 100%|█| 10/10 [00:07<00:00,  1.40i
Epoch: 196/200, best loss: 1.856586217880249: 100%|█| 10/10 [00:06<00:00,  1.44i


saving model


Epoch: 197/200, best loss: 1.8069896697998047: 100%|█| 10/10 [00:07<00:00,  1.30


saving model


Epoch: 198/200, best loss: 1.7870066165924072: 100%|█| 10/10 [00:07<00:00,  1.40


saving model


Epoch: 199/200, best loss: 1.769547939300537: 100%|█| 10/10 [00:06<00:00,  1.47i

saving model





In [12]:
model.load_state_dict(torch.load('SeqGeneration_lstm.pth'))  # with lstm
model.to('cpu') 
generate_text(model, tokenizer.word_to_one_hot(tokenizer.stoi['<SOS>'], 1), tokenizer)

' so centre when nothing power not beauty want all power and the fever his be subject and thou fear a time lovely me my lovely what gentle thou nothing birth, to did supposing thee thou live, want me winter check thee will, be his the power once live, i will, thee live, and wish, the lovely not fear and fever of live, were me creatures of heart of time the power and live, the power live, thee thou centre thee creatures thee will, of fear lovely with will, blamed thee heart substance her beauty thee time of the old want that live, power journey to journey and force, hath show creatures thee as lovely to my lovely with heart power nothing me gentle of love, check not live, to time creatures me '

In [13]:
model = SeqGeneration(len(tokenizer.vocab), 128, 1, len(tokenizer.vocab), 'rnn', True)  #with rnns
model.load_state_dict(torch.load('SeqGeneration.pth'))
model.to('cpu')
generate_text(model, tokenizer.word_to_one_hot(tokenizer.stoi['<SOS>'], 1), tokenizer)



" devouring wond'ring prophecies hid wond'ring wond'ring well-contented unknown, very barrenly fire barrenly rough heaven: therefore endowed, lend least o'erpressed wont ceremony wond'ring stars moan prophecies out. deserts seeing moan crown devise. power disgrace. sweets: prophecies faster nor endowed, go, glad their noon: sullen stars budding contend. weed, fed, endowed, poor crown wooing celestial ward, wailing world, preserve preserve preserve preserve preserve of stop stop world, preserve preserve preserve preserve of light stop world, preserve preserve preserve preserve preserve preserve world, preserve preserve preserve preserve preserve world, preserve preserve preserve preserve preserve preserve preserve preserve preserve preserve preserve preserve preserve preserve world, preserve world, preserve preserve preserve world, world, preserve preserve preserve preserve preserve world, that world, preserve preserve hopes, preserve merit? world, world, preserve world, preserve preser

<p> LSTMs were able to generate a <EOS> token before the specifed max limit.</p>
<p> RNNs however could not do this and relatively bad generation.</p>
<p> In the future, we will use <strong>Attention</strong> to make these models better. Attention module helps network focus on specific parts of sentence. This context helps model generate better sentences.</p>