In [1]:
import gc

import torch
import numpy as np
from torch import nn
from torch.nn import functional as F
from torch import optim
from torchinfo import summary

from collections import namedtuple
import PyPDF3

In [2]:
with open('anna.txt', 'r') as file:
    text = file.read()

In [3]:
text[:120]

'Chapter 1\n\n\nHappy families are all alike; every unhappy family is unhappy in its own\nway.\n\nEverything was in confusion i'

In [4]:
unique_chars = list(set(text))

In [5]:
unique_chars

['\n',
 'v',
 ' ',
 '9',
 '!',
 'w',
 '2',
 '4',
 'E',
 ')',
 'R',
 '/',
 '.',
 'd',
 '%',
 ';',
 ':',
 'S',
 'a',
 '@',
 'b',
 'g',
 'P',
 '7',
 '0',
 'h',
 'C',
 't',
 'F',
 '5',
 'p',
 'M',
 '(',
 'J',
 '$',
 '8',
 '3',
 '?',
 '1',
 'u',
 'x',
 'O',
 'y',
 'q',
 'K',
 'U',
 'X',
 ',',
 'Y',
 '-',
 'Z',
 'r',
 'e',
 '_',
 '&',
 'I',
 'W',
 'D',
 '*',
 'L',
 '`',
 'H',
 'm',
 'c',
 'B',
 'n',
 'j',
 'o',
 'N',
 'G',
 'i',
 '"',
 'l',
 'k',
 "'",
 'f',
 'T',
 'A',
 'z',
 '6',
 'V',
 'Q',
 's']

In [6]:
len(unique_chars)

83

In [7]:
chars2int = {char : unique_chars.index(char) for char in unique_chars}
int2char = {v : k for (k, v) in chars2int.items()}

In [8]:
def encode_text(text, unique_chars = None):
    result_tuple = namedtuple('results', ['encoded_text', 'unique_char', 'int2char', 'char2int'])
    
    if unique_chars is None:
        unique_chars = list(set(text).union(set('#[]{}+-*=!')))
    else:
        unique_chars.extend(list('#[]{}+-*=!'))
        
    char2int = {char : unique_chars.index(char) for char in unique_chars}
    int2char = {v : k for (k, v) in char2int.items()}
    
    encoded_text = np.array(list(map(lambda x: char2int[x], list(text))))
    
    return result_tuple(encoded_text, unique_chars, int2char, char2int)

In [9]:
encoded_text, unique_chars, int2char, char2int = encode_text(text)

In [10]:
def one_hot_convert(arr, n_labels):
    nrows = arr.size
    array = np.zeros(shape = [nrows, n_labels])
    array[np.arange(array.shape[0]), arr.flatten()] = 1.
    
    return array

In [11]:
one_hot_convert(np.array([[1, 2, 3, 5]]), 10)

array([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]])

In [12]:
batch_size = 32
seq_length = 16

In [13]:
numel_seq = batch_size * seq_length

In [14]:
numel_seq

512

In [15]:
encoded_text.shape[0]/numel_seq

3877.388671875

In [16]:
def batch_sequence(arr, batch_size, seq_length):
    numel_seq = batch_size * seq_length
    num_batches = arr.size // numel_seq
    
    arr = arr[: num_batches * numel_seq].reshape(batch_size, -1)
    #print(arr.shape)
    
    batched_data = [(arr[:, n : n + seq_length], arr[:, n + 1 : n + 1 + seq_length])
                    for n in range(0, arr.shape[1], seq_length)]
    
    ### Finalize final array size
    batched_data[-1] = (batched_data[-1][0],
                        np.append(batched_data[-1][1], batched_data[0][1][:, 0].reshape(-1, 1), axis = 1))
    
    ###batched_arr = [arr[n : n + numel_seq].reshape(batch_size, seq_length) for n in range(num_batches)]
    return iter(batched_data), num_batches

In [17]:
batch, _ = batch_sequence(encoded_text, 32, 16)

In [18]:
X, y = next(batch)

In [19]:
def one_hot_encode(arr, n_labels):
    
    # Initialize the the encoded array
    one_hot = np.zeros((arr.size, n_labels), dtype=np.float32)
    
    # Fill the appropriate elements with ones
    one_hot[np.arange(one_hot.shape[0]), arr.flatten()] = 1.
    
    # Finally reshape it to get back to the original array
    one_hot = one_hot.reshape((*arr.shape, n_labels))
    
    return one_hot

In [20]:
#X_ = one_hot_convert(X, 90)
X_ = one_hot_encode(X, len(unique_chars))

In [21]:
X.shape

(32, 16)

In [22]:
X_.shape

(32, 16, 90)

\begin{array}{ll} \\
        i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
        f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
        g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
        o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
        c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
        h_t = o_t \odot \tanh(c_t) \\
    \end{array}

In [23]:
print(help(nn.LSTM))

Help on class LSTM in module torch.nn.modules.rnn:

class LSTM(RNNBase)
 |  LSTM(*args, **kwargs)
 |  
 |  Applies a multi-layer long short-term memory (LSTM) RNN to an input
 |  sequence.
 |  
 |  
 |  For each element in the input sequence, each layer computes the following
 |  function:
 |  
 |  .. math::
 |      \begin{array}{ll} \\
 |          i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
 |          f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
 |          g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
 |          o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
 |          c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
 |          h_t = o_t \odot \tanh(c_t) \\
 |      \end{array}
 |  
 |  where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell
 |  state at time `t`, :math:`x_t` is the input at time `t`, :math:`h_{t-1}`
 |  is the hidden state of the layer at time `t-1` or the initial hidden
 |  state a

In [24]:
def get_text(fpath):
    with open(fpath, "rb") as f:
        pdf = PyPDF3.PdfFileReader(f)
        text = str()
        for page_num in range(pdf.numPages):
            page = pdf.getPage(page_num)
            text = text + ' ' + page.extractText()
    return text

In [25]:
### Get text for validation data
val_text = get_text("The-Prince.pdf")

In [26]:
### Encode validation data
encoding_results = encode_text(val_text, unique_chars)
encoded_val = encoding_results.encoded_text

In [27]:
unique_char = encoding_results.unique_char
len(unique_char)

100

\begin{aligned}
 |              N ={} & \text{batch size} \\
 |              L ={} & \text{sequence length} \\
 |              D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
 |              H_{in} ={} & \text{input\_size} \\
 |              H_{cell} ={} & \text{hidden\_size} \\
 |              H_{out} ={} & \text{proj\_size if } \text{proj\_size}>0 \text{ otherwise hidden\_size} \\
 |          \end{aligned}

In [28]:
class CharRNN(nn.Module):
    """
    Character-level LSTM.
    
    Parameters
    ----------
    hidden_size:
        Number of output features for LSTM.
    dropout:
        Dropout probabilityfor LSTM.
    batch_size:
        Number of sequences in a batch.
    D:
        Number of directions: uni- or bidirectional architecture for LSTM.
    num_layers:
        Number of LSTM stacks.
    
    Returns
    -------
    output:
        Shape: [batch_size, sequence_length, num_features]
    hidden_state:
        Tuple containing:
        - Short-term hidden state
            Shape: [batch_size, sequence_length, num_features]
        - Cell state
            Shape: [batch_size, sequence_length, num_features]
    
    """
    def __init__(self, hidden_size = 128, dropout = 0.25,
                 batch_size = 32, D = 1, num_layers = 2):
        
        super(CharRNN, self).__init__()
        
        self.hidden_size = hidden_size
        self.dropout_rate = dropout
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.D = D
        
        self.lstm = nn.LSTM(input_size = len(unique_chars), hidden_size = self.hidden_size,
                            dropout = self.dropout_rate, batch_first = True,
                            bidirectional = True if self.D == 2 else False, bias = True,
                            num_layers = self.num_layers)
        
        self.fc = nn.Linear(self.D*self.hidden_size, len(unique_chars))
        
    def forward(self, x, hidden_state):
        outputs, hidden_state = self.lstm(x, hidden_state)
        outputs = outputs.contiguous().view(-1, self.D*self.hidden_size)
        outputs = self.fc(outputs)
        
        return outputs, hidden_state
    
    def init_hidden_state(self, mean, stddev):
        """
        Initialize hidden state and context tensors.
        """
        weights = next(self.parameters()).data
        h = torch.distributions.Normal(mean, stddev).sample((self.D*self.num_layers, self.batch_size, self.hidden_size))
        c = torch.distributions.Normal(mean, stddev).sample((self.D*self.num_layers, self.batch_size, self.hidden_size))
        
        return (h, c)
        

In [29]:
model = CharRNN(D = 1)

In [30]:
batch_size = 32
seq_length = 16

max_norm = 1.5
epochs = 20
lr = 1e-4

In [31]:
print(model)

CharRNN(
  (lstm): LSTM(100, 128, num_layers=2, batch_first=True, dropout=0.25)
  (fc): Linear(in_features=128, out_features=100, bias=True)
)


In [32]:
### Objective functions and optimizer
opt = optim.Adam(model.parameters(), lr = lr)
criterion = nn.CrossEntropyLoss()

In [33]:
gc.collect()

11627

In [34]:
### Train data
_, num_batches = batch_sequence(encoded_text, batch_size, seq_length)

In [35]:
### Validation data
_, num_batches_ = batch_sequence(encoded_val, batch_size, seq_length)

In [36]:
### Num of train batches
print(num_batches)

3877


In [37]:
### Num of valid batches
print(num_batches_)

551


In [None]:
### Outer training loop
for epoch in range(1, epochs + 1):
    h = model.init_hidden_state(mean = 0., stddev = .5)
    iteration = 0
    train_losses = list()
    
    ### Inner training loop
    for X, y in batch_sequence(encoded_text, batch_size, seq_length)[0]:
        X = one_hot_encode(X, len(unique_chars))
        X, y = torch.as_tensor(X), torch.as_tensor(y)
        
        model.train()
        iteration += 1
        
        h = tuple([each.data for each in h])
        opt.zero_grad()
        
        outputs, h = model(X, h)
        
        loss = criterion(outputs, y.reshape(-1,).long())
        
        loss.backward(retain_graph = True)
        nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        opt.step()
        
        train_losses.append(loss.item())
        
        ### Outer validation loop
        if (not iteration % 20) or (iteration == num_batches):
            i = 0
            val_losses = list()
            model.eval()
            h_ = model.init_hidden_state(mean = 0., stddev = .5)
            
            ### Inner validation loop
            for X_, y_ in batch_sequence(encoded_val, batch_size, seq_length)[0]:
                i += 1
                
                h_ = tuple([each.data for each in h_])
                
                X_ = torch.as_tensor(one_hot_encode(X_, len(unique_chars)))
                y_ = torch.as_tensor(y_)
                
                outputs_, h_ = model(X_, h_)
                
                val_loss = criterion(outputs_, y_.reshape(-1,).long())
                val_losses.append(val_loss.item())

            ### Report training and validation losses
            val_loss = torch.Tensor(val_losses).mean().item()

            train_loss = torch.Tensor(train_losses).mean().item()
            
            print('='*80)
            print(f'Epoch: {epoch}/{epochs}, Iteration {iteration}/{num_batches},',
                  f'Train Loss: {train_loss:.4f}, Valid Loss: {val_loss:.4f}')
        
    print('\n'+'='*80)
    print('='*80)
    #print('='*60)
    #print(f'Epoch: {epoch}/{epochs}, Train Loss: {train_loss:.4f}, Valid Loss: {val_loss:.4f}\n')
    #print('='*60)

Epoch: 1/20, Iteration 20/3877, Train Loss: 4.6059, Valid Loss: 4.5866
Epoch: 1/20, Iteration 40/3877, Train Loss: 4.5874, Valid Loss: 4.5384
Epoch: 1/20, Iteration 60/3877, Train Loss: 4.5558, Valid Loss: 4.3946
Epoch: 1/20, Iteration 80/3877, Train Loss: 4.4362, Valid Loss: 3.5124
Epoch: 1/20, Iteration 100/3877, Train Loss: 4.2252, Valid Loss: 3.2408
Epoch: 1/20, Iteration 120/3877, Train Loss: 4.0584, Valid Loss: 3.1710
Epoch: 1/20, Iteration 140/3877, Train Loss: 3.9301, Valid Loss: 3.1434
Epoch: 1/20, Iteration 160/3877, Train Loss: 3.8287, Valid Loss: 3.1301
Epoch: 1/20, Iteration 180/3877, Train Loss: 3.7477, Valid Loss: 3.1231
Epoch: 1/20, Iteration 200/3877, Train Loss: 3.6839, Valid Loss: 3.1208
Epoch: 1/20, Iteration 220/3877, Train Loss: 3.6354, Valid Loss: 3.1206
Epoch: 1/20, Iteration 240/3877, Train Loss: 3.5926, Valid Loss: 3.1201
Epoch: 1/20, Iteration 260/3877, Train Loss: 3.5572, Valid Loss: 3.1205
Epoch: 1/20, Iteration 280/3877, Train Loss: 3.5254, Valid Loss: 3.1

Epoch: 1/20, Iteration 1100/3877, Train Loss: 3.2142, Valid Loss: 3.1257
Epoch: 1/20, Iteration 1120/3877, Train Loss: 3.2123, Valid Loss: 3.1243
Epoch: 1/20, Iteration 1140/3877, Train Loss: 3.2104, Valid Loss: 3.1242
Epoch: 1/20, Iteration 1160/3877, Train Loss: 3.2089, Valid Loss: 3.1247
Epoch: 1/20, Iteration 1180/3877, Train Loss: 3.2069, Valid Loss: 3.1241
Epoch: 1/20, Iteration 1200/3877, Train Loss: 3.2053, Valid Loss: 3.1255
Epoch: 1/20, Iteration 1220/3877, Train Loss: 3.2036, Valid Loss: 3.1258
Epoch: 1/20, Iteration 1240/3877, Train Loss: 3.2017, Valid Loss: 3.1262
Epoch: 1/20, Iteration 1260/3877, Train Loss: 3.2001, Valid Loss: 3.1249
Epoch: 1/20, Iteration 1280/3877, Train Loss: 3.1987, Valid Loss: 3.1268
Epoch: 1/20, Iteration 1300/3877, Train Loss: 3.1972, Valid Loss: 3.1287
Epoch: 1/20, Iteration 1320/3877, Train Loss: 3.1960, Valid Loss: 3.1298
Epoch: 1/20, Iteration 1340/3877, Train Loss: 3.1946, Valid Loss: 3.1296
Epoch: 1/20, Iteration 1360/3877, Train Loss: 3.193

Epoch: 1/20, Iteration 2180/3877, Train Loss: 3.1577, Valid Loss: 3.0878
Epoch: 1/20, Iteration 2200/3877, Train Loss: 3.1571, Valid Loss: 3.0854
Epoch: 1/20, Iteration 2220/3877, Train Loss: 3.1564, Valid Loss: 3.0841
Epoch: 1/20, Iteration 2240/3877, Train Loss: 3.1555, Valid Loss: 3.0807
Epoch: 1/20, Iteration 2260/3877, Train Loss: 3.1546, Valid Loss: 3.0756
Epoch: 1/20, Iteration 2280/3877, Train Loss: 3.1538, Valid Loss: 3.0717
Epoch: 1/20, Iteration 2300/3877, Train Loss: 3.1529, Valid Loss: 3.0665
Epoch: 1/20, Iteration 2320/3877, Train Loss: 3.1519, Valid Loss: 3.0605
Epoch: 1/20, Iteration 2340/3877, Train Loss: 3.1510, Valid Loss: 3.0576
Epoch: 1/20, Iteration 2360/3877, Train Loss: 3.1500, Valid Loss: 3.0512
Epoch: 1/20, Iteration 2380/3877, Train Loss: 3.1490, Valid Loss: 3.0445
Epoch: 1/20, Iteration 2400/3877, Train Loss: 3.1481, Valid Loss: 3.0384
Epoch: 1/20, Iteration 2420/3877, Train Loss: 3.1470, Valid Loss: 3.0337
Epoch: 1/20, Iteration 2440/3877, Train Loss: 3.145

Epoch: 1/20, Iteration 3260/3877, Train Loss: 3.0805, Valid Loss: 2.8311
Epoch: 1/20, Iteration 3280/3877, Train Loss: 3.0787, Valid Loss: 2.8311
Epoch: 1/20, Iteration 3300/3877, Train Loss: 3.0769, Valid Loss: 2.8287
Epoch: 1/20, Iteration 3320/3877, Train Loss: 3.0751, Valid Loss: 2.8187
Epoch: 1/20, Iteration 3340/3877, Train Loss: 3.0735, Valid Loss: 2.8163
Epoch: 1/20, Iteration 3360/3877, Train Loss: 3.0716, Valid Loss: 2.8116
Epoch: 1/20, Iteration 3380/3877, Train Loss: 3.0698, Valid Loss: 2.8083
Epoch: 1/20, Iteration 3400/3877, Train Loss: 3.0683, Valid Loss: 2.8031
Epoch: 1/20, Iteration 3420/3877, Train Loss: 3.0671, Valid Loss: 2.7969
Epoch: 1/20, Iteration 3440/3877, Train Loss: 3.0653, Valid Loss: 2.7908
Epoch: 1/20, Iteration 3460/3877, Train Loss: 3.0635, Valid Loss: 2.7891
Epoch: 1/20, Iteration 3480/3877, Train Loss: 3.0619, Valid Loss: 2.7835
Epoch: 1/20, Iteration 3500/3877, Train Loss: 3.0603, Valid Loss: 2.7779
Epoch: 1/20, Iteration 3520/3877, Train Loss: 3.058

Epoch: 2/20, Iteration 440/3877, Train Loss: 2.5865, Valid Loss: 2.6249
Epoch: 2/20, Iteration 460/3877, Train Loss: 2.5856, Valid Loss: 2.6206
Epoch: 2/20, Iteration 480/3877, Train Loss: 2.5831, Valid Loss: 2.6182
Epoch: 2/20, Iteration 500/3877, Train Loss: 2.5814, Valid Loss: 2.6155
Epoch: 2/20, Iteration 520/3877, Train Loss: 2.5797, Valid Loss: 2.6145
Epoch: 2/20, Iteration 540/3877, Train Loss: 2.5777, Valid Loss: 2.6127
Epoch: 2/20, Iteration 560/3877, Train Loss: 2.5774, Valid Loss: 2.6079
Epoch: 2/20, Iteration 580/3877, Train Loss: 2.5746, Valid Loss: 2.6063
Epoch: 2/20, Iteration 600/3877, Train Loss: 2.5728, Valid Loss: 2.6014
Epoch: 2/20, Iteration 620/3877, Train Loss: 2.5704, Valid Loss: 2.5996
Epoch: 2/20, Iteration 640/3877, Train Loss: 2.5688, Valid Loss: 2.5998
Epoch: 2/20, Iteration 660/3877, Train Loss: 2.5685, Valid Loss: 2.5960
Epoch: 2/20, Iteration 680/3877, Train Loss: 2.5680, Valid Loss: 2.5958
Epoch: 2/20, Iteration 700/3877, Train Loss: 2.5673, Valid Loss: