In [1]:
import torch
from torch import nn
import torch.nn.functional as F

import numpy as np
import matplotlib.pyplot as plt
%matplotlib inline

In [2]:
with open('../Data/shakespeare.txt', 'r', encoding='utf8') as f:
    text = f.read()

In [3]:
type(text)

str

In [4]:
print(text[:100])


                     1
  From fairest creatures we desire increase,
  That thereby beauty's rose mi


In [5]:
len(text)

5445609

In [6]:
all_characters = set(text)

In [7]:
all_characters

{'\n',
 ' ',
 '!',
 '"',
 '&',
 "'",
 '(',
 ')',
 ',',
 '-',
 '.',
 '0',
 '1',
 '2',
 '3',
 '4',
 '5',
 '6',
 '7',
 '8',
 '9',
 ':',
 ';',
 '<',
 '>',
 '?',
 'A',
 'B',
 'C',
 'D',
 'E',
 'F',
 'G',
 'H',
 'I',
 'J',
 'K',
 'L',
 'M',
 'N',
 'O',
 'P',
 'Q',
 'R',
 'S',
 'T',
 'U',
 'V',
 'W',
 'X',
 'Y',
 'Z',
 '[',
 ']',
 '_',
 '`',
 'a',
 'b',
 'c',
 'd',
 'e',
 'f',
 'g',
 'h',
 'i',
 'j',
 'k',
 'l',
 'm',
 'n',
 'o',
 'p',
 'q',
 'r',
 's',
 't',
 'u',
 'v',
 'w',
 'x',
 'y',
 'z',
 '|',
 '}'}

In [8]:
len(all_characters)

84

In [9]:
# num --> letter
decoder = dict(enumerate(all_characters))
decoder

{0: 'D',
 1: 'e',
 2: 'O',
 3: 'f',
 4: '\n',
 5: 's',
 6: 'A',
 7: '6',
 8: 'u',
 9: ',',
 10: 'g',
 11: '&',
 12: ']',
 13: 'W',
 14: 'V',
 15: '-',
 16: '5',
 17: 'I',
 18: 'H',
 19: 'z',
 20: '8',
 21: 'Q',
 22: 'h',
 23: '[',
 24: 'd',
 25: 'i',
 26: 'b',
 27: 'c',
 28: 'C',
 29: 'r',
 30: 'l',
 31: 'o',
 32: '}',
 33: 'Z',
 34: '|',
 35: '<',
 36: '1',
 37: 'm',
 38: 'j',
 39: 'L',
 40: 'y',
 41: ' ',
 42: ')',
 43: 'q',
 44: 'p',
 45: '2',
 46: 'K',
 47: '!',
 48: 'U',
 49: "'",
 50: '"',
 51: 'B',
 52: 'X',
 53: '7',
 54: 'R',
 55: '(',
 56: 't',
 57: 'v',
 58: '0',
 59: 'a',
 60: ':',
 61: 'J',
 62: 'F',
 63: 'G',
 64: 'N',
 65: '>',
 66: '.',
 67: 'M',
 68: '4',
 69: 'Y',
 70: '_',
 71: 'w',
 72: '3',
 73: 'T',
 74: '9',
 75: 'P',
 76: '?',
 77: '`',
 78: 'n',
 79: ';',
 80: 'E',
 81: 'S',
 82: 'x',
 83: 'k'}

In [10]:
# letter --> num
encoder = {char: ind for ind, char in decoder.items()}

In [11]:
encoder

{'D': 0,
 'e': 1,
 'O': 2,
 'f': 3,
 '\n': 4,
 's': 5,
 'A': 6,
 '6': 7,
 'u': 8,
 ',': 9,
 'g': 10,
 '&': 11,
 ']': 12,
 'W': 13,
 'V': 14,
 '-': 15,
 '5': 16,
 'I': 17,
 'H': 18,
 'z': 19,
 '8': 20,
 'Q': 21,
 'h': 22,
 '[': 23,
 'd': 24,
 'i': 25,
 'b': 26,
 'c': 27,
 'C': 28,
 'r': 29,
 'l': 30,
 'o': 31,
 '}': 32,
 'Z': 33,
 '|': 34,
 '<': 35,
 '1': 36,
 'm': 37,
 'j': 38,
 'L': 39,
 'y': 40,
 ' ': 41,
 ')': 42,
 'q': 43,
 'p': 44,
 '2': 45,
 'K': 46,
 '!': 47,
 'U': 48,
 "'": 49,
 '"': 50,
 'B': 51,
 'X': 52,
 '7': 53,
 'R': 54,
 '(': 55,
 't': 56,
 'v': 57,
 '0': 58,
 'a': 59,
 ':': 60,
 'J': 61,
 'F': 62,
 'G': 63,
 'N': 64,
 '>': 65,
 '.': 66,
 'M': 67,
 '4': 68,
 'Y': 69,
 '_': 70,
 'w': 71,
 '3': 72,
 'T': 73,
 '9': 74,
 'P': 75,
 '?': 76,
 '`': 77,
 'n': 78,
 ';': 79,
 'E': 80,
 'S': 81,
 'x': 82,
 'k': 83}

In [12]:
encoded_text = np.array([encoder[char] for char in text])

In [13]:
encoded_text[:500]

array([ 4, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41,
       41, 41, 41, 41, 41, 36,  4, 41, 41, 62, 29, 31, 37, 41,  3, 59, 25,
       29,  1,  5, 56, 41, 27, 29,  1, 59, 56,  8, 29,  1,  5, 41, 71,  1,
       41, 24,  1,  5, 25, 29,  1, 41, 25, 78, 27, 29,  1, 59,  5,  1,  9,
        4, 41, 41, 73, 22, 59, 56, 41, 56, 22,  1, 29,  1, 26, 40, 41, 26,
        1, 59,  8, 56, 40, 49,  5, 41, 29, 31,  5,  1, 41, 37, 25, 10, 22,
       56, 41, 78,  1, 57,  1, 29, 41, 24, 25,  1,  9,  4, 41, 41, 51,  8,
       56, 41, 59,  5, 41, 56, 22,  1, 41, 29, 25, 44,  1, 29, 41,  5, 22,
       31,  8, 30, 24, 41, 26, 40, 41, 56, 25, 37,  1, 41, 24,  1, 27,  1,
       59,  5,  1,  9,  4, 41, 41, 18, 25,  5, 41, 56,  1, 78, 24,  1, 29,
       41, 22,  1, 25, 29, 41, 37, 25, 10, 22, 56, 41, 26,  1, 59, 29, 41,
       22, 25,  5, 41, 37,  1, 37, 31, 29, 40, 60,  4, 41, 41, 51,  8, 56,
       41, 56, 22, 31,  8, 41, 27, 31, 78, 56, 29, 59, 27, 56,  1, 24, 41,
       56, 31, 41, 56, 22

In [14]:
decoder[27]

'c'

In [15]:
def one_hot_encoder(encoded_text, num_uni_chars):
    
    # encoded_text --> batch of encoded text
    # num_uni_chars --> len(set(text))
    
    one_hot = np.zeros((encoded_text.size, num_uni_chars))
    
    one_hot = one_hot.astype(np.float32)
    
    one_hot[np.arange(one_hot.shape[0]), encoded_text.flatten()]  = 1.0
    
    one_hot = one_hot.reshape((*encoded_text.shape, num_uni_chars))
    
    return one_hot

In [16]:
arr = np.array([1, 2, 0])
arr

array([1, 2, 0])

In [17]:
one_hot_encoder(arr, 3)

array([[0., 1., 0.],
       [0., 0., 1.],
       [1., 0., 0.]], dtype=float32)

In [18]:
example_text = np.arange(10)

In [19]:
example_text

array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

In [20]:
example_text.reshape((5,-1))

array([[0, 1],
       [2, 3],
       [4, 5],
       [6, 7],
       [8, 9]])

In [21]:
def generate_batches(encoded_text, samp_per_batch=10, seq_len=50):
    
    # X : encoded text of legth seq_len
    # Y : encoded text shifted by one
    
    # how many chars per batch?
    char_per_batch = samp_per_batch * seq_len
    
    # how many batches can we make, given the len of the encoded text?
    num_batches_avail = int(len(encoded_text)/char_per_batch)
    
    # Cut off the end of the encoded text, that won't fit evenly into a batch
    encoded_text = encoded_text[:num_batches_avail * char_per_batch]

    encoded_text = encoded_text.reshape((samp_per_batch, -1))
    
    for n in range(0, encoded_text.shape[1], seq_len):
        
        x = encoded_text[:, n:n+seq_len]
        
        # zeros array to the same shape as x
        y = np.zeros_like(x)
        
        try:
            y[:, :-1] = x[:, 1:]
            y[:, -1] = encoded_text[:, n+seq_len]
        except:
            y[:, :-1] = x[:, 1:]
            y[:,-1] = encoded_text[:, 0]
        
        yield x,y

In [22]:
sample_text = encoded_text[:20]
sample_text

array([ 4, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41, 41,
       41, 41, 41])

In [23]:
batch_generator = generate_batches(sample_text, samp_per_batch=2, seq_len=5)

In [24]:
x,y = next(batch_generator)

In [25]:
x

array([[ 4, 41, 41, 41, 41],
       [41, 41, 41, 41, 41]])

In [26]:
y

array([[41, 41, 41, 41, 41],
       [41, 41, 41, 41, 41]])

In [27]:
class CharModel(nn.Module):
    
    def __init__(self, all_chars, num_hidden=256, num_layers=4, drop_prob=0.5, use_gpu=False):
        
        super().__init__()
        
        self.drop_prob = drop_prob
        self.num_layers = num_layers
        self.num_hidden = num_hidden
        self.use_gpu = use_gpu
        
        self.all_chars = all_chars
        self.decoder = dict(enumerate(all_chars))
        self.encoder = {char: ind for ind, char in decoder.items()}
        
        self.lstm = nn.LSTM(len(self.all_chars), num_hidden, num_layers, dropout=drop_prob, batch_first=True)
        
        self.dropout = nn.Dropout(drop_prob)
        
        self.fc_linear = nn.Linear(num_hidden, len(self.all_chars))
    
    def forward(self, x, hidden):
        
        lstm_output, hidden = self.lstm(x, hidden)
        
        drop_output = self.dropout(lstm_output)
        
        drop_output = drop_output.contiguous().view(-1, self.num_hidden)
        
        final_out = self.fc_linear(drop_output)
        
        return final_out, hidden
    
    def hidden_state(self, batch_size):
        
        if self.use_gpu:
            
            hidden = (torch.zeros(self.num_layers, batch_size, self.num_hidden).cuda(), torch.zeros(self.num_layers, batch_size, self.num_hidden).cuda())
        else:
            hidden = (torch.zeros(self.num_layers, batch_size, self.num_hidden), torch.zeros(self.num_layers, batch_size, self.num_hidden))
        
        return hidden
    

In [28]:
model = CharModel(all_chars=all_characters,
                  num_hidden=512,
                  num_layers=3,
                  drop_prob=0.5,
                  use_gpu=True)

In [29]:
total_param = []

for p in model.parameters():
    total_param.append(int(p.numel()))

sum(total_param)

5470292

In [30]:
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
criterion = nn.CrossEntropyLoss()

In [31]:
train_percent = 0.9
train_ind = int(len(encoded_text)*train_percent)

train_data = encoded_text[:train_ind]
val_data = encoded_text[train_ind:]

len(train_data), len(val_data)

(4901048, 544561)

In [34]:
# VARIABLES

epochs = 2
batch_size = 100

seq_len = 100

tracker = 0
num_char = max(encoded_text) + 1


In [36]:
model.train()

if model.use_gpu:
    model.cuda()
    
for i in range(epochs):
    
    hidden = model.hidden_state(batch_size)
    
    for x,y in generate_batches(train_data, batch_size, seq_len):
    
        tracker += 1
        
        x = one_hot_encoder(x, num_char)
        
        inputs = torch.from_numpy(x)
        targets = torch.from_numpy(y)
        
        if model.use_gpu:
            inputs = inputs.cuda()
            targets = targets.cuda()
        
        hidden = tuple([state.data for state in hidden])
        
        model.zero_grad()
            
        lstm_output, hidden = model.forward(inputs, hidden)
        loss = criterion(lstm_output, targets.view(batch_size*seq_len).long())
        loss.backward()
        
        nn.utils.clip_grad_norm_(model.parameters(), max_norm=5)
        
        optimizer.step()
        
        if tracker % 25 == 0:
            
            val_hidden = model.hidden_state(batch_size)
            val_losses = []
            model.eval()
            
            for x,y in generate_batches(val_data, batch_size, seq_len):
                
                x = one_hot_encoder(x, num_char)
        
                inputs = torch.from_numpy(x)
                targets = torch.from_numpy(y)

                if model.use_gpu:
                    inputs = inputs.cuda()
                    targets = targets.cuda()
                
                val_hidden = tuple([state.data for state in val_hidden])
                
                lstm_output, val_hidden = model.forward(inputs, val_hidden)
                val_loss = criterion(lstm_output, targets.view(batch_size*seq_len).long())
                val_losses.append(val_loss.item())
            
            model.train()
            
            print(f"EPOCH: {i} step: {tracker} VAL LOSS: {val_loss.item()}")
    

EPOCH: 0 step: 225 VAL LOSS: 1.6385951042175293
EPOCH: 0 step: 250 VAL LOSS: 1.6289540529251099
EPOCH: 0 step: 275 VAL LOSS: 1.6253702640533447
EPOCH: 0 step: 300 VAL LOSS: 1.6129294633865356
EPOCH: 0 step: 325 VAL LOSS: 1.6040384769439697
EPOCH: 0 step: 350 VAL LOSS: 1.6027451753616333
EPOCH: 0 step: 375 VAL LOSS: 1.5917539596557617
EPOCH: 0 step: 400 VAL LOSS: 1.581817626953125
EPOCH: 0 step: 425 VAL LOSS: 1.5739343166351318
EPOCH: 0 step: 450 VAL LOSS: 1.5720685720443726
EPOCH: 0 step: 475 VAL LOSS: 1.5666931867599487
EPOCH: 0 step: 500 VAL LOSS: 1.5568277835845947
EPOCH: 0 step: 525 VAL LOSS: 1.5578359365463257
EPOCH: 0 step: 550 VAL LOSS: 1.547835111618042
EPOCH: 0 step: 575 VAL LOSS: 1.540914535522461
EPOCH: 0 step: 600 VAL LOSS: 1.5369573831558228
EPOCH: 0 step: 625 VAL LOSS: 1.5342597961425781
EPOCH: 0 step: 650 VAL LOSS: 1.5311095714569092
EPOCH: 0 step: 675 VAL LOSS: 1.5279278755187988
EPOCH: 1 step: 700 VAL LOSS: 1.527523159980774
EPOCH: 1 step: 725 VAL LOSS: 1.5133775472640

In [37]:
model_name = "hidden512_layers3_shakes_saul.net"

In [39]:
torch.save(model.state_dict(), model_name)