In [1]:
import gc

import torch
import numpy as np
from torch import nn
from torch.nn import functional as F
from torch import optim
from torchinfo import summary

from collections import namedtuple
import PyPDF3

In [2]:
DATA_DIR = os.getcwd().replace('notebooks', 'data')

with open(os.path.join(DATA_DIR, 'anna.txt'), 'r') as file:
    text = file.read()

In [3]:
text[:120]

'Chapter 1\n\n\nHappy families are all alike; every unhappy family is unhappy in its own\nway.\n\nEverything was in confusion i'

In [4]:
unique_chars = list(set(text))

In [5]:
unique_chars

['3',
 'b',
 ':',
 'K',
 'c',
 '7',
 '*',
 'J',
 'X',
 'D',
 'n',
 '1',
 'E',
 's',
 'i',
 'o',
 'Z',
 '/',
 'S',
 'u',
 '`',
 'q',
 'T',
 '5',
 'j',
 '?',
 'm',
 'v',
 'I',
 '8',
 'A',
 'W',
 ',',
 '\n',
 'R',
 'l',
 't',
 'Y',
 'U',
 '"',
 '9',
 '$',
 'a',
 "'",
 'e',
 'O',
 'g',
 'y',
 'd',
 '6',
 'P',
 '0',
 '2',
 'M',
 'h',
 '&',
 'k',
 'f',
 'B',
 'N',
 'L',
 '_',
 '(',
 'Q',
 'F',
 'z',
 '!',
 'H',
 '%',
 'C',
 '-',
 '4',
 ')',
 'V',
 'p',
 'r',
 ';',
 'x',
 '.',
 '@',
 'G',
 'w',
 ' ']

In [6]:
len(unique_chars)

83

In [7]:
chars2int = {char : unique_chars.index(char) for char in unique_chars}
int2char = {v : k for (k, v) in chars2int.items()}

In [8]:
def encode_text(text, unique_chars = None):
    result_tuple = namedtuple('results', ['encoded_text', 'unique_char', 'int2char', 'char2int'])
    
    if unique_chars is None:
        unique_chars = list(set(text).union(set('#[]{}+-*=!')))
    else:
        unique_chars.extend(list('#[]{}+-*=!'))
        
    char2int = {char : unique_chars.index(char) for char in unique_chars}
    int2char = {v : k for (k, v) in char2int.items()}
    
    encoded_text = np.array(list(map(lambda x: char2int[x], list(text))))
    
    return result_tuple(encoded_text, unique_chars, int2char, char2int)

In [9]:
encoded_text, unique_chars, int2char, char2int = encode_text(text)

In [10]:
def one_hot_convert(arr, n_labels):
    nrows = arr.size
    array = np.zeros(shape = [nrows, n_labels])
    array[np.arange(array.shape[0]), arr.flatten()] = 1.
    
    return array

In [11]:
one_hot_convert(np.array([[1, 2, 3, 5]]), 10)

array([[0., 1., 0., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 1., 0., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 1., 0., 0., 0., 0., 0., 0.],
       [0., 0., 0., 0., 0., 1., 0., 0., 0., 0.]])

In [12]:
batch_size = 32
seq_length = 16

In [13]:
numel_seq = batch_size * seq_length

In [14]:
numel_seq

512

In [15]:
encoded_text.shape[0]/numel_seq

3877.388671875

In [16]:
def batch_sequence(arr, batch_size, seq_length):
    numel_seq = batch_size * seq_length
    num_batches = arr.size // numel_seq
    
    arr = arr[: num_batches * numel_seq].reshape(batch_size, -1)
    #print(arr.shape)
    
    batched_data = [(arr[:, n : n + seq_length], arr[:, n + 1 : n + 1 + seq_length])
                    for n in range(0, arr.shape[1], seq_length)]
    
    ### Finalize final array size
    batched_data[-1] = (batched_data[-1][0],
                        np.append(batched_data[-1][1], batched_data[0][1][:, 0].reshape(-1, 1), axis = 1))
    
    ###batched_arr = [arr[n : n + numel_seq].reshape(batch_size, seq_length) for n in range(num_batches)]
    return iter(batched_data), num_batches

In [17]:
batch, _ = batch_sequence(encoded_text, 32, 16)

In [18]:
X, y = next(batch)

In [19]:
def one_hot_encode(arr, n_labels):
    
    # Initialize the the encoded array
    one_hot = np.zeros((arr.size, n_labels), dtype=np.float32)
    
    # Fill the appropriate elements with ones
    one_hot[np.arange(one_hot.shape[0]), arr.flatten()] = 1.
    
    # Finally reshape it to get back to the original array
    one_hot = one_hot.reshape((*arr.shape, n_labels))
    
    return one_hot

In [20]:
#X_ = one_hot_convert(X, 90)
X_ = one_hot_encode(X, len(unique_chars))

In [21]:
X.shape

(32, 16)

In [22]:
X_.shape

(32, 16, 90)

\begin{array}{ll} \\
        i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
        f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
        g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
        o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
        c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
        h_t = o_t \odot \tanh(c_t) \\
    \end{array}

In [23]:
print(help(nn.LSTM))

Help on class LSTM in module torch.nn.modules.rnn:

class LSTM(RNNBase)
 |  LSTM(*args, **kwargs)
 |  
 |  Applies a multi-layer long short-term memory (LSTM) RNN to an input
 |  sequence.
 |  
 |  
 |  For each element in the input sequence, each layer computes the following
 |  function:
 |  
 |  .. math::
 |      \begin{array}{ll} \\
 |          i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
 |          f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
 |          g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
 |          o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
 |          c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
 |          h_t = o_t \odot \tanh(c_t) \\
 |      \end{array}
 |  
 |  where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell
 |  state at time `t`, :math:`x_t` is the input at time `t`, :math:`h_{t-1}`
 |  is the hidden state of the layer at time `t-1` or the initial hidden
 |  state a

In [24]:
def get_text(fpath):
    with open(fpath, "rb") as f:
        pdf = PyPDF3.PdfFileReader(f)
        text = str()
        for page_num in range(pdf.numPages):
            page = pdf.getPage(page_num)
            text = text + ' ' + page.extractText()
    return text

In [25]:
### Get text for validation data
val_text = get_text("The-Prince.pdf")

In [26]:
### Encode validation data
encoding_results = encode_text(val_text, unique_chars)
encoded_val = encoding_results.encoded_text

In [27]:
unique_char = encoding_results.unique_char
len(unique_char)

100

\begin{aligned}
 |              N ={} & \text{batch size} \\
 |              L ={} & \text{sequence length} \\
 |              D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
 |              H_{in} ={} & \text{input\_size} \\
 |              H_{cell} ={} & \text{hidden\_size} \\
 |              H_{out} ={} & \text{proj\_size if } \text{proj\_size}>0 \text{ otherwise hidden\_size} \\
 |          \end{aligned}

In [28]:
class CharRNN(nn.Module):
    """
    Character-level LSTM.
    
    Parameters
    ----------
    hidden_size:
        Number of output features for LSTM.
    dropout:
        Dropout probability for LSTM.
    batch_size:
        Number of sequences in a batch.
    D:
        Number of directions: uni- or bidirectional architecture for LSTM.
    num_layers:
        Number of LSTM stacks.
    
    Returns
    -------
    output:
        Shape: [batch_size, sequence_length, num_features]
    hidden_state:
        Tuple containing:
        - Short-term hidden state
            Shape: [batch_size, sequence_length, num_features]
        - Cell state
            Shape: [batch_size, sequence_length, num_features]
    
    """
    def __init__(self, hidden_size = 128, dropout = 0.25,
                 batch_size = 32, D = 1, num_layers = 2):
        
        super(CharRNN, self).__init__()
        
        self.hidden_size = hidden_size
        self.dropout_rate = dropout
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.D = D
        
        self.lstm = nn.LSTM(input_size = len(unique_chars), hidden_size = self.hidden_size,
                            dropout = self.dropout_rate, batch_first = True,
                            bidirectional = True if self.D == 2 else False, bias = True,
                            num_layers = self.num_layers)
        
        self.fc = nn.Linear(self.D*self.hidden_size, len(unique_chars))
        
    def forward(self, x, hidden_state):
        outputs, hidden_state = self.lstm(x, hidden_state)
        outputs = outputs.contiguous().view(-1, self.D*self.hidden_size)
        outputs = self.fc(outputs)
        
        return outputs, hidden_state
    
    def init_hidden_state(self, mean, stddev):
        """
        Initialize hidden state and context tensors.
        """
        weights = next(self.parameters()).data
        h = torch.distributions.Normal(mean, stddev).sample((self.D*self.num_layers, self.batch_size, self.hidden_size))
        c = torch.distributions.Normal(mean, stddev).sample((self.D*self.num_layers, self.batch_size, self.hidden_size))
        
        return (h, c)
        

In [29]:
model = CharRNN(D = 1)

In [30]:
batch_size = 32
seq_length = 16

max_norm = 1.5
epochs = 20
lr = 1e-4

In [31]:
print(model)

CharRNN(
  (lstm): LSTM(100, 128, num_layers=2, batch_first=True, dropout=0.25)
  (fc): Linear(in_features=128, out_features=100, bias=True)
)


In [32]:
### Objective functions and optimizer
opt = optim.Adam(model.parameters(), lr = lr)
criterion = nn.CrossEntropyLoss()

In [33]:
gc.collect()

11627

In [34]:
### Train data
_, num_batches = batch_sequence(encoded_text, batch_size, seq_length)

In [35]:
### Validation data
_, num_batches_ = batch_sequence(encoded_val, batch_size, seq_length)

In [36]:
### Num of train batches
print(num_batches)

3877


In [37]:
### Num of valid batches
print(num_batches_)

551


In [38]:
### Outer training loop
for epoch in range(1, epochs + 1):
    h = model.init_hidden_state(mean = 0., stddev = .5)
    iteration = 0
    train_losses = list()
    
    ### Inner training loop
    for X, y in batch_sequence(encoded_text, batch_size, seq_length)[0]:
        X = one_hot_encode(X, len(unique_chars))
        X, y = torch.as_tensor(X), torch.as_tensor(y)
        
        model.train()
        iteration += 1
        
        h = tuple([each.data for each in h])
        opt.zero_grad()
        
        outputs, h = model(X, h)
        
        loss = criterion(outputs, y.reshape(-1,).long())
        
        loss.backward(retain_graph = True)
        nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        opt.step()
        
        train_losses.append(loss.item())
        
        ### Outer validation loop
        if (not iteration % 20) or (iteration == num_batches):
            i = 0
            val_losses = list()
            model.eval()
            h_ = model.init_hidden_state(mean = 0., stddev = .5)
            
            ### Inner validation loop
            for X_, y_ in batch_sequence(encoded_val, batch_size, seq_length)[0]:
                i += 1
                
                h_ = tuple([each.data for each in h_])
                
                X_ = torch.as_tensor(one_hot_encode(X_, len(unique_chars)))
                y_ = torch.as_tensor(y_)
                
                outputs_, h_ = model(X_, h_)
                
                val_loss = criterion(outputs_, y_.reshape(-1,).long())
                val_losses.append(val_loss.item())

            ### Report training and validation losses
            val_loss = torch.Tensor(val_losses).mean().item()

            train_loss = torch.Tensor(train_losses).mean().item()
            
            print('='*80)
            print(f'Epoch: {epoch}/{epochs}, Iteration {iteration}/{num_batches},',
                  f'Train Loss: {train_loss:.4f}, Valid Loss: {val_loss:.4f}')
        
    print('\n'+'='*80)
    print('='*80)
    #print('='*60)
    #print(f'Epoch: {epoch}/{epochs}, Train Loss: {train_loss:.4f}, Valid Loss: {val_loss:.4f}\n')
    #print('='*60)

Epoch: 1/20, Iteration 20/3877, Train Loss: 4.6072, Valid Loss: 4.5920


KeyboardInterrupt: 

In [39]:
X.shape

torch.Size([32, 16, 100])

In [40]:
y.shape

torch.Size([32, 16])

In [41]:
y

tensor([[38, 22, 18,  1, 37, 18, 77, 82, 88, 14, 29, 24, 88,  1, 37, 14],
        [ 0, 87, 14, 23, 32, 88, 37, 14, 68, 18, 88, 38, 29, 88,  1, 37],
        [32, 18, 18, 29, 88, 14, 29, 23,  1, 37, 33, 29, 22, 88, 18, 80],
        [87, 29, 82, 88, 32, 37, 38, 87, 33, 29, 22, 88, 37, 18, 77, 88],
        [ 0, 45, 88,  1, 37, 14,  1, 88, 32, 37, 18, 88, 18, 80, 76, 18],
        [18, 88, 44, 29, 45, 38, 77,  1, 44, 29, 14,  1, 18, 88, 65, 33],
        [38, 65, 18, 88, 87, 18, 88, 65, 44, 32,  1, 88,  0, 38, 38, 43],
        [ 1, 37, 18, 77, 88, 45, 77, 18, 18, 88, 45, 77, 38, 65, 88, 77],
        [44, 33, 11, 43,  0, 23, 88, 11, 14, 65, 18, 88,  7, 14, 11, 43],
        [11, 38, 65, 18, 88, 14, 29, 24, 88, 37, 14, 68, 18, 83, 11, 38],
        [38, 68, 33, 29, 22, 88, 38, 44,  1, 32, 33, 24, 18, 88,  7, 38],
        [65, 18, 29,  1, 88, 45, 38, 77, 88, 37, 18, 77, 88, 11, 77, 33],
        [14, 24, 83,  7, 18, 18, 29, 88,  1, 38, 38, 88,  7, 38,  0, 24],
        [88,  7, 23, 88, 33,  1, 81, 8

In [44]:
y_ = model(X, h)

In [46]:
y_[0].shape

torch.Size([512, 100])