In [43]:
import torch
import torch.nn as nn
import torch.optim as optim
from collections import OrderedDict
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm

In [2]:
class Tokenizer():
    def __init__(self, file_path):
        self.sonnets = []
        self.vocab = self.create_vocab(file_path)
        self.stoi, self.itos = self.create_dict()
        self.file_path = file_path
        self.max_len = max([len(i) for i in self.sonnets])
        self.indexes = torch.arange(len(self.vocab))

    def create_vocab(self, path):
        f = open(path)
        vocab_set = {'<SOS>', '<EOS>'}
        sonnet = []
        for line in f.readlines():
            try:
                int(line.strip())
                if sonnet:
                    sonnet = ['<SOS>'] + sonnet + ['<EOS>']
                    self.sonnets.append(sonnet)
                sonnet = []
            except ValueError:
                words = [i.lower().strip() for i in line.split()]
                sonnet += (words)
                vocab_set.update(set(words))
        sonnet = ['<SOS>'] + sonnet + ['<EOS>']
        self.sonnets.append(sonnet)
        return vocab_set

    def create_dict(self):
        stoi = OrderedDict()
        itos = OrderedDict()
        for i, token in enumerate(self.vocab):
            stoi[token] = i
            itos[i] = token
        return stoi, itos

    def sonnet_to_tensor(self, sonnet):
        x = torch.zeros(self.max_len, len(self.vocab))
        for idx, word in enumerate(sonnet):
            x[idx, self.stoi[word]] = 1
        return x

    def get_tensor_data(self):
        data_tensor = torch.zeros(len(self.sonnets), self.max_len, len(self.vocab))
        for idx, sonnet in enumerate(self.sonnets):
            data_tensor[idx] = self.sonnet_to_tensor(sonnet)
        return data_tensor

    def word_to_one_hot(self, index, batch_size=1):
        x = torch.zeros(batch_size, 1, len(self.vocab))
        x[:, :, index] = 1
        return x

    def word_to_one_hot_batch(self, index):
        x = torch.zeros(index.shape[0], 1, len(self.vocab))
        for i, idx in enumerate(index):
            x[i] = self.word_to_one_hot(idx)  
        return x
        
    def decode(self, sonnet):
        words = ''
        try:
            for idx, row in enumerate(sonnet):
                words += self.itos[self.indexes[row.to(torch.bool)].item()] + ' '        
        except RuntimeError:
            pass
        return words

In [3]:
tokenizer = Tokenizer('shakespeare.txt')

In [4]:
data_tensor = tokenizer.get_tensor_data()

In [5]:
class SonnetDataset(Dataset):
    def __init__(self, x):
        self.x = x
        self.samples = x.shape[0]

    def __getitem__(self, idx):
        return self.x[idx]

    def __len__(self):
        return self.samples

In [6]:
train_dataset = SonnetDataset(data_tensor)
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

In [7]:
class SeqGeneration(nn.Module):
    def __init__(self, input_size, hidden_size, num_cells, vocab_length, cell_init='rnn', bidirectional=False):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.num_cells = num_cells
        self.num_directions = 2 if bidirectional else 1
        self.cell_init = cell_init
        if cell_init == 'rnn':
            self.cell = nn.RNN(input_size, hidden_size, num_cells, batch_first=True, bidirectional=bidirectional)
        elif cell_init == 'gru':
            self.cell = nn.GRU(input_size, hidden_size, num_cells, batch_first=True, bidirectional=bidirectional)
        elif cell_init == 'lstm':
            self.cell = nn.LSTM(input_size, hidden_size, num_cells, batch_first=True, bidirectional=bidirectional)
        else:
            raise Exception('Invalid cell type')
        self.fc1 = nn.Linear(self.num_directions*hidden_size, vocab_length)
        self.softmax = nn.LogSoftmax(dim=-1)
        

    def forward(self, x, h0=None, c0=None, device='mps'):
        # data you pass will be the first time step only.
        if h0 is None:
            h0 = torch.zeros(self.num_directions*self.num_cells, x.shape[0], self.hidden_size, device=device)
        if self.cell_init == 'lstm':
            if c0 is None:
                c0 = torch.zeros(self.num_directions*self.num_cells, x.shape[0], self.hidden_size, device=device)
            hidden_states, (ht_all, c0) = self.cell(x, (h0, c0))
        else:
            hidden_states, ht_all = self.cell(x, h0)
            # ht will always be of the shape h0 
        ht = hidden_states[:, -1, :]
        logits = self.fc1(ht)
        scores = self.softmax(logits)
        return scores, ht_all, c0

In [8]:
model = SeqGeneration(len(tokenizer.vocab), 256, 1, len(tokenizer.vocab), 'lstm', True)
# model.load_state_dict(torch.load('SeqGeneration.pth'))
device = 'mps'
model = model.to(device)
loss = nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=1e-3)

In [9]:
def train(model, train_loader, train_params, tokenizer):
    min_loss = float('inf')
    epochs = train_params['epochs']
    loss = train_params['loss']
    optimizer = train_params['optimizer']
    for epoch in range(epochs):
        loss_at_batch = None
        for x in tqdm(train_loader, desc=f'Epoch: {epoch}/{epochs}, best loss: {min_loss}'):
            loss_at_batch = []
            ht = None
            c0 = None
            x = x.to(device)
            for t in range(x.shape[1] - 1):
                xt = x[:, t:t+1, :]
                scores, ht, c0 = model(xt, ht, c0)
                next_word_index = torch.argmax(scores, dim=1)
                labels = x[:, t+1:t+2, :]
                labels = torch.argmax(labels, dim=2).flatten()
                loss_at_timestep = loss(scores, labels)
                loss_at_batch.append(loss_at_timestep)
                xt = tokenizer.word_to_one_hot_batch(next_word_index)
                xt = xt.to(device)
            loss_at_batch = sum(loss_at_batch)/len(loss_at_batch)
            optimizer.zero_grad()
            loss_at_batch.backward()
            torch.nn.utils.clip_grad_value_(model.parameters(), 0.5)
            optimizer.step()
        if loss_at_batch < min_loss:
            torch.save(model.state_dict(), 'SeqGeneration_gru.pth')
            min_loss = loss_at_batch
            print('saving model')

In [15]:
def generate_text(model, sos_tensor, tokenizer):
    curr_token = ''
    sonnet = ''
    pred = ''
    ht = None
    while curr_token != '<EOS>' and len(sonnet.split()) < tokenizer.max_len:
        sonnet += pred + ' '
        scores, ht, ct = model(sos_tensor, ht, None, 'cpu')
        # pred = tokenizer.itos[torch.argmax(scores).item()] always generates the same sequence
        # sample to generate different sequences each time.
        pred = tokenizer.itos[torch.argmax(scores).item()]
        random_word_index = torch.multinomial(torch.exp(scores).flatten(), num_samples=1).item()
        curr_token = pred
        sos_tensor = tokenizer.word_to_one_hot(tokenizer.stoi[tokenizer.itos[random_word_index]], 1)
    return sonnet

In [11]:
train_params = {'epochs': 200, 'loss': loss, 'optimizer': optimizer}
train(model, train_loader, train_params, tokenizer)

Epoch: 0/200, best loss: inf: 100%|█████████████| 10/10 [00:08<00:00,  1.21it/s]


saving model


Epoch: 1/200, best loss: 6.950695037841797: 100%|█| 10/10 [00:07<00:00,  1.28it/


saving model


Epoch: 2/200, best loss: 6.654219627380371: 100%|█| 10/10 [00:07<00:00,  1.28it/


saving model


Epoch: 3/200, best loss: 6.40230655670166: 100%|█| 10/10 [00:07<00:00,  1.29it/s


saving model


Epoch: 4/200, best loss: 6.383993625640869: 100%|█| 10/10 [00:07<00:00,  1.28it/


saving model


Epoch: 5/200, best loss: 6.356734275817871: 100%|█| 10/10 [00:08<00:00,  1.22it/
Epoch: 6/200, best loss: 6.356734275817871: 100%|█| 10/10 [00:08<00:00,  1.22it/


saving model


Epoch: 7/200, best loss: 6.340819358825684: 100%|█| 10/10 [00:08<00:00,  1.20it/


saving model


Epoch: 8/200, best loss: 6.26102352142334: 100%|█| 10/10 [00:07<00:00,  1.28it/s


saving model


Epoch: 9/200, best loss: 6.187981605529785: 100%|█| 10/10 [00:07<00:00,  1.28it/
Epoch: 10/200, best loss: 6.187981605529785: 100%|█| 10/10 [00:07<00:00,  1.26it
Epoch: 11/200, best loss: 6.187981605529785: 100%|█| 10/10 [00:08<00:00,  1.19it


saving model


Epoch: 12/200, best loss: 6.1354522705078125: 100%|█| 10/10 [00:07<00:00,  1.27i
Epoch: 13/200, best loss: 6.1354522705078125: 100%|█| 10/10 [00:07<00:00,  1.27i
Epoch: 14/200, best loss: 6.1354522705078125: 100%|█| 10/10 [00:07<00:00,  1.27i
Epoch: 15/200, best loss: 6.1354522705078125: 100%|█| 10/10 [00:07<00:00,  1.28i


saving model


Epoch: 16/200, best loss: 6.075062274932861: 100%|█| 10/10 [00:07<00:00,  1.27it
Epoch: 17/200, best loss: 6.075062274932861: 100%|█| 10/10 [00:07<00:00,  1.26it
Epoch: 18/200, best loss: 6.075062274932861: 100%|█| 10/10 [00:07<00:00,  1.28it
Epoch: 19/200, best loss: 6.075062274932861: 100%|█| 10/10 [00:07<00:00,  1.27it


saving model


Epoch: 20/200, best loss: 6.068410396575928: 100%|█| 10/10 [00:07<00:00,  1.28it


saving model


Epoch: 21/200, best loss: 6.013192176818848: 100%|█| 10/10 [00:08<00:00,  1.23it
Epoch: 22/200, best loss: 6.013192176818848: 100%|█| 10/10 [00:08<00:00,  1.20it
Epoch: 23/200, best loss: 6.013192176818848: 100%|█| 10/10 [00:08<00:00,  1.24it


saving model


Epoch: 24/200, best loss: 5.893638610839844: 100%|█| 10/10 [00:07<00:00,  1.26it


saving model


Epoch: 25/200, best loss: 5.891335964202881: 100%|█| 10/10 [00:07<00:00,  1.29it


saving model


Epoch: 26/200, best loss: 5.826259613037109: 100%|█| 10/10 [00:07<00:00,  1.28it


saving model


Epoch: 27/200, best loss: 5.674631118774414: 100%|█| 10/10 [00:07<00:00,  1.29it


saving model


Epoch: 28/200, best loss: 5.653427600860596: 100%|█| 10/10 [00:07<00:00,  1.29it
Epoch: 29/200, best loss: 5.653427600860596: 100%|█| 10/10 [00:07<00:00,  1.29it


saving model


Epoch: 30/200, best loss: 5.608627796173096: 100%|█| 10/10 [00:07<00:00,  1.29it


saving model


Epoch: 31/200, best loss: 5.494716167449951: 100%|█| 10/10 [00:07<00:00,  1.29it


saving model


Epoch: 32/200, best loss: 5.478898525238037: 100%|█| 10/10 [00:07<00:00,  1.29it


saving model


Epoch: 33/200, best loss: 5.4262847900390625: 100%|█| 10/10 [00:07<00:00,  1.29i
Epoch: 34/200, best loss: 5.4262847900390625: 100%|█| 10/10 [00:07<00:00,  1.29i


saving model


Epoch: 35/200, best loss: 5.342709064483643: 100%|█| 10/10 [00:07<00:00,  1.28it
Epoch: 36/200, best loss: 5.342709064483643: 100%|█| 10/10 [00:07<00:00,  1.29it


saving model


Epoch: 37/200, best loss: 5.28896951675415: 100%|█| 10/10 [00:07<00:00,  1.28it/


saving model


Epoch: 38/200, best loss: 5.088871002197266: 100%|█| 10/10 [00:07<00:00,  1.28it
Epoch: 39/200, best loss: 5.088871002197266: 100%|█| 10/10 [00:07<00:00,  1.27it


saving model


Epoch: 40/200, best loss: 5.086554527282715: 100%|█| 10/10 [00:08<00:00,  1.23it


saving model


Epoch: 41/200, best loss: 4.8733296394348145: 100%|█| 10/10 [00:08<00:00,  1.21i
Epoch: 42/200, best loss: 4.8733296394348145: 100%|█| 10/10 [00:07<00:00,  1.26i


saving model


Epoch: 43/200, best loss: 4.830920219421387: 100%|█| 10/10 [00:07<00:00,  1.27it


saving model


Epoch: 44/200, best loss: 4.6931962966918945: 100%|█| 10/10 [00:07<00:00,  1.27i


saving model


Epoch: 45/200, best loss: 4.53828763961792: 100%|█| 10/10 [00:07<00:00,  1.27it/
Epoch: 46/200, best loss: 4.53828763961792: 100%|█| 10/10 [00:07<00:00,  1.26it/


saving model


Epoch: 47/200, best loss: 4.348770618438721: 100%|█| 10/10 [00:07<00:00,  1.26it


saving model


Epoch: 48/200, best loss: 4.146220684051514: 100%|█| 10/10 [00:07<00:00,  1.28it
Epoch: 49/200, best loss: 4.146220684051514: 100%|█| 10/10 [00:07<00:00,  1.29it


saving model


Epoch: 50/200, best loss: 3.999760627746582: 100%|█| 10/10 [00:07<00:00,  1.29it


saving model


Epoch: 51/200, best loss: 3.966996192932129: 100%|█| 10/10 [00:07<00:00,  1.29it


saving model


Epoch: 52/200, best loss: 3.962934970855713: 100%|█| 10/10 [00:07<00:00,  1.29it


saving model


Epoch: 53/200, best loss: 3.796117067337036: 100%|█| 10/10 [00:07<00:00,  1.29it


saving model


Epoch: 54/200, best loss: 3.748600482940674: 100%|█| 10/10 [00:07<00:00,  1.29it


saving model


Epoch: 55/200, best loss: 3.473721742630005: 100%|█| 10/10 [00:07<00:00,  1.27it
Epoch: 56/200, best loss: 3.473721742630005: 100%|█| 10/10 [00:07<00:00,  1.26it


saving model


Epoch: 57/200, best loss: 3.245720386505127: 100%|█| 10/10 [00:08<00:00,  1.25it


saving model


Epoch: 58/200, best loss: 3.146211624145508: 100%|█| 10/10 [00:08<00:00,  1.24it


saving model


Epoch: 59/200, best loss: 3.0613207817077637: 100%|█| 10/10 [00:08<00:00,  1.23i


saving model


Epoch: 60/200, best loss: 2.9243290424346924: 100%|█| 10/10 [00:07<00:00,  1.26i


saving model


Epoch: 61/200, best loss: 2.769000768661499: 100%|█| 10/10 [00:07<00:00,  1.27it
Epoch: 62/200, best loss: 2.769000768661499: 100%|█| 10/10 [00:08<00:00,  1.16it


saving model


Epoch: 63/200, best loss: 2.5599892139434814: 100%|█| 10/10 [00:08<00:00,  1.22i


saving model


Epoch: 64/200, best loss: 2.426161766052246: 100%|█| 10/10 [00:08<00:00,  1.19it


saving model


Epoch: 65/200, best loss: 2.3817026615142822: 100%|█| 10/10 [00:08<00:00,  1.23i


saving model


Epoch: 66/200, best loss: 2.162590265274048: 100%|█| 10/10 [00:07<00:00,  1.27it
Epoch: 67/200, best loss: 2.162590265274048: 100%|█| 10/10 [00:07<00:00,  1.27it


saving model


Epoch: 68/200, best loss: 1.9723268747329712: 100%|█| 10/10 [00:07<00:00,  1.27i


saving model


Epoch: 69/200, best loss: 1.9094074964523315: 100%|█| 10/10 [00:07<00:00,  1.27i


saving model


Epoch: 70/200, best loss: 1.7295185327529907: 100%|█| 10/10 [00:07<00:00,  1.29i
Epoch: 71/200, best loss: 1.7295185327529907: 100%|█| 10/10 [00:07<00:00,  1.27i


saving model


Epoch: 72/200, best loss: 1.679146409034729: 100%|█| 10/10 [00:07<00:00,  1.27it


saving model


Epoch: 73/200, best loss: 1.4014545679092407: 100%|█| 10/10 [00:07<00:00,  1.27i
Epoch: 74/200, best loss: 1.4014545679092407: 100%|█| 10/10 [00:07<00:00,  1.25i


saving model


Epoch: 75/200, best loss: 1.3043417930603027: 100%|█| 10/10 [00:07<00:00,  1.28i
Epoch: 76/200, best loss: 1.3043417930603027: 100%|█| 10/10 [00:07<00:00,  1.27i


saving model


Epoch: 77/200, best loss: 1.1688367128372192: 100%|█| 10/10 [00:07<00:00,  1.28i


saving model


Epoch: 78/200, best loss: 1.0335795879364014: 100%|█| 10/10 [00:08<00:00,  1.25i
Epoch: 79/200, best loss: 1.0335795879364014: 100%|█| 10/10 [00:07<00:00,  1.28i


saving model


Epoch: 80/200, best loss: 0.9890214800834656: 100%|█| 10/10 [00:07<00:00,  1.27i


saving model


Epoch: 81/200, best loss: 0.9470802545547485: 100%|█| 10/10 [00:07<00:00,  1.28i


saving model


Epoch: 82/200, best loss: 0.8834066390991211: 100%|█| 10/10 [00:07<00:00,  1.27i


saving model


Epoch: 83/200, best loss: 0.8569352030754089: 100%|█| 10/10 [00:07<00:00,  1.28i


saving model


Epoch: 84/200, best loss: 0.7317513823509216: 100%|█| 10/10 [00:07<00:00,  1.28i
Epoch: 85/200, best loss: 0.7317513823509216: 100%|█| 10/10 [00:07<00:00,  1.28i


saving model


Epoch: 86/200, best loss: 0.6732689142227173: 100%|█| 10/10 [00:07<00:00,  1.27i


saving model


Epoch: 87/200, best loss: 0.6570603847503662: 100%|█| 10/10 [00:08<00:00,  1.25i


saving model


Epoch: 88/200, best loss: 0.6239015460014343: 100%|█| 10/10 [00:07<00:00,  1.27i


saving model


Epoch: 89/200, best loss: 0.5515395998954773: 100%|█| 10/10 [00:07<00:00,  1.28i


saving model


Epoch: 90/200, best loss: 0.5243595242500305: 100%|█| 10/10 [00:07<00:00,  1.27i
Epoch: 91/200, best loss: 0.5243595242500305: 100%|█| 10/10 [00:07<00:00,  1.27i


saving model


Epoch: 92/200, best loss: 0.4837191700935364: 100%|█| 10/10 [00:07<00:00,  1.28i


saving model


Epoch: 93/200, best loss: 0.47021445631980896: 100%|█| 10/10 [00:07<00:00,  1.29
Epoch: 94/200, best loss: 0.47021445631980896: 100%|█| 10/10 [00:07<00:00,  1.28


saving model


Epoch: 95/200, best loss: 0.4264366924762726: 100%|█| 10/10 [00:07<00:00,  1.29i


saving model


Epoch: 96/200, best loss: 0.4147954285144806: 100%|█| 10/10 [00:08<00:00,  1.22i


saving model


Epoch: 97/200, best loss: 0.3745182454586029: 100%|█| 10/10 [00:07<00:00,  1.25i
Epoch: 98/200, best loss: 0.3745182454586029: 100%|█| 10/10 [00:08<00:00,  1.23i
Epoch: 99/200, best loss: 0.3745182454586029: 100%|█| 10/10 [00:07<00:00,  1.25i


saving model


Epoch: 100/200, best loss: 0.3397887945175171: 100%|█| 10/10 [00:07<00:00,  1.26
Epoch: 101/200, best loss: 0.3397887945175171: 100%|█| 10/10 [00:08<00:00,  1.23
Epoch: 102/200, best loss: 0.3397887945175171: 100%|█| 10/10 [00:07<00:00,  1.27


saving model


Epoch: 103/200, best loss: 0.3146854043006897: 100%|█| 10/10 [00:07<00:00,  1.27


saving model


Epoch: 104/200, best loss: 0.2845969498157501: 100%|█| 10/10 [00:07<00:00,  1.29
Epoch: 105/200, best loss: 0.2845969498157501: 100%|█| 10/10 [00:07<00:00,  1.28
Epoch: 106/200, best loss: 0.2845969498157501: 100%|█| 10/10 [00:08<00:00,  1.25


saving model


Epoch: 107/200, best loss: 0.2833605706691742: 100%|█| 10/10 [00:07<00:00,  1.26


saving model


Epoch: 108/200, best loss: 0.2633247673511505: 100%|█| 10/10 [00:07<00:00,  1.27


saving model


Epoch: 109/200, best loss: 0.25878435373306274: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 110/200, best loss: 0.25878435373306274: 100%|█| 10/10 [00:08<00:00,  1.2


saving model


Epoch: 111/200, best loss: 0.2415006458759308: 100%|█| 10/10 [00:07<00:00,  1.27


saving model


Epoch: 112/200, best loss: 0.2168470174074173: 100%|█| 10/10 [00:07<00:00,  1.28
Epoch: 113/200, best loss: 0.2168470174074173: 100%|█| 10/10 [00:07<00:00,  1.29


saving model


Epoch: 114/200, best loss: 0.20311322808265686: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 115/200, best loss: 0.20311322808265686: 100%|█| 10/10 [00:07<00:00,  1.2


saving model


Epoch: 116/200, best loss: 0.19884194433689117: 100%|█| 10/10 [00:08<00:00,  1.2
Epoch: 117/200, best loss: 0.19884194433689117: 100%|█| 10/10 [00:07<00:00,  1.2


saving model


Epoch: 118/200, best loss: 0.19006690382957458: 100%|█| 10/10 [00:07<00:00,  1.2


saving model


Epoch: 119/200, best loss: 0.17266996204853058: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 120/200, best loss: 0.17266996204853058: 100%|█| 10/10 [00:07<00:00,  1.2


saving model


Epoch: 121/200, best loss: 0.1698712408542633: 100%|█| 10/10 [00:07<00:00,  1.26
Epoch: 122/200, best loss: 0.1698712408542633: 100%|█| 10/10 [00:07<00:00,  1.25


saving model


Epoch: 123/200, best loss: 0.16769453883171082: 100%|█| 10/10 [00:08<00:00,  1.2


saving model


Epoch: 124/200, best loss: 0.16276489198207855: 100%|█| 10/10 [00:07<00:00,  1.2


saving model


Epoch: 125/200, best loss: 0.16051678359508514: 100%|█| 10/10 [00:07<00:00,  1.2


saving model


Epoch: 126/200, best loss: 0.1546589881181717: 100%|█| 10/10 [00:07<00:00,  1.26


saving model


Epoch: 127/200, best loss: 0.14229649305343628: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 128/200, best loss: 0.14229649305343628: 100%|█| 10/10 [00:07<00:00,  1.2


saving model


Epoch: 129/200, best loss: 0.14102840423583984: 100%|█| 10/10 [00:08<00:00,  1.2
Epoch: 130/200, best loss: 0.14102840423583984: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 131/200, best loss: 0.14102840423583984: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 132/200, best loss: 0.14102840423583984: 100%|█| 10/10 [00:07<00:00,  1.2


saving model


Epoch: 133/200, best loss: 0.1389019936323166: 100%|█| 10/10 [00:07<00:00,  1.27


saving model


Epoch: 134/200, best loss: 0.13383986055850983: 100%|█| 10/10 [00:07<00:00,  1.2


saving model


Epoch: 135/200, best loss: 0.1261666715145111: 100%|█| 10/10 [00:07<00:00,  1.26
Epoch: 136/200, best loss: 0.1261666715145111: 100%|█| 10/10 [00:07<00:00,  1.27


saving model


Epoch: 137/200, best loss: 0.12049009650945663: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 138/200, best loss: 0.12049009650945663: 100%|█| 10/10 [00:07<00:00,  1.2


saving model


Epoch: 139/200, best loss: 0.12033380568027496: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 140/200, best loss: 0.12033380568027496: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 141/200, best loss: 0.12033380568027496: 100%|█| 10/10 [00:07<00:00,  1.2


saving model


Epoch: 142/200, best loss: 0.11324534565210342: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 143/200, best loss: 0.11324534565210342: 100%|█| 10/10 [00:07<00:00,  1.2


saving model


Epoch: 144/200, best loss: 0.10973193496465683: 100%|█| 10/10 [00:07<00:00,  1.2


saving model


Epoch: 145/200, best loss: 0.10847112536430359: 100%|█| 10/10 [00:07<00:00,  1.2


saving model


Epoch: 146/200, best loss: 0.09986764937639236: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 147/200, best loss: 0.09986764937639236: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 148/200, best loss: 0.09986764937639236: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 149/200, best loss: 0.09986764937639236: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 150/200, best loss: 0.09986764937639236: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 151/200, best loss: 0.09986764937639236: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 152/200, best loss: 0.09986764937639236: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 153/200, best loss: 0.09986764937639236: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 154/200, best loss: 0.09986764937639236: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 155/200, best loss: 0.09986764937639236: 100%|█| 10/10 [00:07<00:00,  1.2


saving model


Epoch: 156/200, best loss: 0.094920814037323: 100%|█| 10/10 [00:07<00:00,  1.27i
Epoch: 157/200, best loss: 0.094920814037323: 100%|█| 10/10 [00:07<00:00,  1.29i
Epoch: 158/200, best loss: 0.094920814037323: 100%|█| 10/10 [00:07<00:00,  1.28i


saving model


Epoch: 159/200, best loss: 0.09450068324804306: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 160/200, best loss: 0.09450068324804306: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 161/200, best loss: 0.09450068324804306: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 162/200, best loss: 0.09450068324804306: 100%|█| 10/10 [00:07<00:00,  1.2


saving model


Epoch: 163/200, best loss: 0.09368669241666794: 100%|█| 10/10 [00:07<00:00,  1.2


saving model


Epoch: 164/200, best loss: 0.09320423007011414: 100%|█| 10/10 [00:08<00:00,  1.2


saving model


Epoch: 165/200, best loss: 0.09187803417444229: 100%|█| 10/10 [00:08<00:00,  1.2
Epoch: 166/200, best loss: 0.09187803417444229: 100%|█| 10/10 [00:07<00:00,  1.2


saving model


Epoch: 167/200, best loss: 0.0841362476348877: 100%|█| 10/10 [00:07<00:00,  1.25
Epoch: 168/200, best loss: 0.0841362476348877: 100%|█| 10/10 [00:07<00:00,  1.27


saving model


Epoch: 169/200, best loss: 0.08056255429983139: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 170/200, best loss: 0.08056255429983139: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 171/200, best loss: 0.08056255429983139: 100%|█| 10/10 [00:07<00:00,  1.2


saving model


Epoch: 172/200, best loss: 0.07850842922925949: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 173/200, best loss: 0.07850842922925949: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 174/200, best loss: 0.07850842922925949: 100%|█| 10/10 [00:07<00:00,  1.2


saving model


Epoch: 175/200, best loss: 0.07430317252874374: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 176/200, best loss: 0.07430317252874374: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 177/200, best loss: 0.07430317252874374: 100%|█| 10/10 [00:07<00:00,  1.2


saving model


Epoch: 178/200, best loss: 0.07423420995473862: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 179/200, best loss: 0.07423420995473862: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 180/200, best loss: 0.07423420995473862: 100%|█| 10/10 [00:08<00:00,  1.2


saving model


Epoch: 181/200, best loss: 0.0730835497379303: 100%|█| 10/10 [00:07<00:00,  1.27
Epoch: 182/200, best loss: 0.0730835497379303: 100%|█| 10/10 [00:07<00:00,  1.27
Epoch: 183/200, best loss: 0.0730835497379303: 100%|█| 10/10 [00:07<00:00,  1.26


saving model


Epoch: 184/200, best loss: 0.0715344250202179: 100%|█| 10/10 [00:08<00:00,  1.24
Epoch: 185/200, best loss: 0.0715344250202179: 100%|█| 10/10 [00:09<00:00,  1.10


saving model


Epoch: 186/200, best loss: 0.06693289428949356: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 187/200, best loss: 0.06693289428949356: 100%|█| 10/10 [00:08<00:00,  1.2
Epoch: 188/200, best loss: 0.06693289428949356: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 189/200, best loss: 0.06693289428949356: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 190/200, best loss: 0.06693289428949356: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 191/200, best loss: 0.06693289428949356: 100%|█| 10/10 [00:08<00:00,  1.2


saving model


Epoch: 192/200, best loss: 0.06672032177448273: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 193/200, best loss: 0.06672032177448273: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 194/200, best loss: 0.06672032177448273: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 195/200, best loss: 0.06672032177448273: 100%|█| 10/10 [00:07<00:00,  1.2
Epoch: 196/200, best loss: 0.06672032177448273: 100%|█| 10/10 [00:07<00:00,  1.2


saving model


Epoch: 197/200, best loss: 0.06570402532815933: 100%|█| 10/10 [00:07<00:00,  1.2


saving model


Epoch: 198/200, best loss: 0.0657012090086937: 100%|█| 10/10 [00:08<00:00,  1.25
Epoch: 199/200, best loss: 0.0657012090086937: 100%|█| 10/10 [00:07<00:00,  1.28

saving model





In [36]:
model.load_state_dict(torch.load('SeqGeneration_gru.pth'))  # with lstm
model.to('cpu') 
generate_text(model, tokenizer.word_to_one_hot(tokenizer.stoi['<SOS>'], 1), tokenizer)

' o me! what potions hand hath call world love-god lying now, as thou shalt so from the be that time the and i grant and made my with fortune and i better her glass and my renew mine an in that do not the and i shalt forsake thou art fool and and should nor i the eyes how can more and love my verse so where art of not her lips are at to old are be wise as thou grant thou face with as hand o but heart with substance, do so am i o in the old are in the old of thy glass and be thee how or glory in with i the world thus can my glass is my glass but she upon since grant thou art '

In [42]:
model = SeqGeneration(len(tokenizer.vocab), 128, 1, len(tokenizer.vocab), 'rnn', True)  #with rnns
model.load_state_dict(torch.load('SeqGeneration.pth'))
model.to('cpu')
generate_text(model, tokenizer.word_to_one_hot(tokenizer.stoi['<SOS>'], 1), tokenizer)

" hopes, loud tyranny, wherever sue idle wood sooner vanished husband's bound an tune drinks 'tis unlearned art, green, art, special-blest, groan never-resting slanderers beds' allow? idle unlettered ignorance precious sue ignorance ignorance dost tyranny, backward (inferior rondure wary, refined richly gone) unlearned sheaves never-resting mortality unlearned art sue mortality born, grew: flattered naked) never-resting if feeds sunk bound richly drinks tired, needy husband's ignorance steep bound 'will', self's sue borne self's weeds, bound constant sooner day ignorance beds' brief sue might unknown, poesy there poesy ignorance slanderers ignorance if prophetic self sue ignorance graves unlearned ignorance needy youthful tongue green, ignorance (dressed hopes, sue spirit if still unlearned ignorance rightly trial and spirit err, allow? tired, needy tired, found. crossed: censures never-resting make, tyranny, particulars hopes, allow? love's blunting heat, unlearned near, "