In [57]:
import torch
from torch import nn
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt

In [58]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device_name = torch.cuda.get_device_name(torch.cuda.current_device())
print(device, 'with', device_name)

cuda with NVIDIA GeForce RTX 3070


In [59]:
with open('../Data/shakespeare.txt','r',encoding='utf8') as f:
    text = f.read()
print(text[:1000])


                     1
  From fairest creatures we desire increase,
  That thereby beauty's rose might never die,
  But as the riper should by time decease,
  His tender heir might bear his memory:
  But thou contracted to thine own bright eyes,
  Feed'st thy light's flame with self-substantial fuel,
  Making a famine where abundance lies,
  Thy self thy foe, to thy sweet self too cruel:
  Thou that art now the world's fresh ornament,
  And only herald to the gaudy spring,
  Within thine own bud buriest thy content,
  And tender churl mak'st waste in niggarding:
    Pity the world, or else this glutton be,
    To eat the world's due, by the grave and thee.


                     2
  When forty winters shall besiege thy brow,
  And dig deep trenches in thy beauty's field,
  Thy youth's proud livery so gazed on now,
  Will be a tattered weed of small worth held:  
  Then being asked, where all thy beauty lies,
  Where all the treasure of thy lusty days;
  To say within thine own deep su

In [60]:
all_characters = set(text)
print(len(all_characters))

84


In [61]:
decoder = dict(enumerate(all_characters))
encoder = {char: ind for ind, char in decoder.items()}

In [62]:
encoded_text = np.array([encoder[char] for char in text])
print(type(encoded_text))
print(encoded_text[120:130])

<class 'numpy.ndarray'>
[39 40 83 39 33 17  8 39 59 60]


In [63]:
def one_hot_encoder(encoded_text, num_uni_chars):

    one_hot = np.zeros((encoded_text.size, num_uni_chars))
    one_hot = one_hot.astype(np.float32)

    one_hot[np.arange(one_hot.shape[0]), encoded_text.flatten()] = 1.0

    one_hot = one_hot.reshape((*encoded_text.shape, num_uni_chars))

    return one_hot

In [64]:
arr = np.array([1,2,0])
print(one_hot_encoder(arr, 3))

[[0. 1. 0.]
 [0. 0. 1.]
 [1. 0. 0.]]


In [65]:
example_text = np.arange(10)
print(example_text)

[0 1 2 3 4 5 6 7 8 9]


In [66]:
def generate_batches(encoded_text, sample_per_batch=10, seq_len=50):
    char_per_batch = sample_per_batch * seq_len
    num_batches_available = len(encoded_text)//char_per_batch
    encoded_text = encoded_text[:num_batches_available*char_per_batch]

    encoded_text = encoded_text.reshape((sample_per_batch, -1))

    for n in range(encoded_text.shape[1], seq_len):
        x = encoded_text[:,n:n+seq_len]
        y = np.zeros_like(x)

        try:
            y[:,:-1] = x[:,1:]
            y[:,-1] = encoded_text[:, n+seq_len]

        except:
            y[:,:-1] = x[:,1:]
            y[:,-1] = encoded_text[:, 0]

        yield x,y

In [67]:
class CharModel(nn.Module):
    def __init__(self, all_chars, num_hidden=256, num_layers=4, drop_prob=0.5):
        super().__init__()
        self.drop_prob = drop_prob
        self.num_layers = num_layers
        self.num_hidden = num_hidden

        self.all_chars = all_chars
        self.decoder = dict(enumerate(all_chars))
        self.encoder = {char:ind for ind, char in decoder.items()}

        self.lstm = nn.LSTM(len(all_chars), num_hidden, num_layers, dropout=drop_prob, batch_first=True)
        self.dropout = nn.Dropout(drop_prob)

        self.fc_linear = nn.Linear(num_hidden, len(all_chars))

    def forward(self, x, hidden_state):

        lstm_output, hidden_state = self.lstm(x, hidden_state)

        drop_output = self.dropout(lstm_output)

        drop_output = drop_output.contiguous().view(-1, self.num_hidden)

        final_out = self.fc_linear(drop_output)

        return final_out, hidden_state

In [68]:
model = CharModel(all_chars=all_characters, num_hidden=512, num_layers=3, drop_prob=0.5).to(device)
total_params = []
for p in model.parameters():
    total_params.append(int(p.numel()))
print('Total:', sum(total_params))

learning_rate = 0.001
epochs = 50
batch_size = 100
seq_len = 100
tracker = 0
num_char = max(encoded_text)+1

Total: 5470292


In [69]:
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
criterion = nn.CrossEntropyLoss()

In [71]:
train_percent = 0.9
train_index = int(len(encoded_text) * train_percent)
train_data = encoded_text[:train_index]
test_data = encoded_text[train_index:]

In [None]:
model.train()

for epoch in range(epochs):
    hidden_state = hidden = (torch.zeros(model.num_layers,batch_size,model.num_hidden).to(device),torch.zeros(model.num_layers,batch_size,model.num_hidden).to(device))

    for x, y in generate_batches(train_data, batch_size, seq_len):

        tracker += 1

        x = one_hot_encoder(x, num_char)

        inputs = torch.from_numpy(x).to(device)
        targets = torch.from_numpy(y).to(device)

        hidden = ([state.data])
