In [1]:
import gc
import os
import pickle

import torch
import numpy as np

from torch import nn
from torch.nn import functional as F
from torch import optim
from torchinfo import summary

from collections import namedtuple
import PyPDF3

In [2]:
# Train on the GPU when one is visible to PyTorch, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
# Confirm which device training will run on.
print(device)

cpu


In [4]:
# Resolve the project's data directory. The notebook is expected to run from
# `<project>/notebooks`, with the data in the sibling `<project>/data` folder.
# Taking the parent directory (instead of `str.replace`) avoids rewriting every
# "notebooks" substring in the path, e.g. `/home/notebooks_user/proj/notebooks`.
# When run from elsewhere (e.g. the project root), `<cwd>/data` is used.
_cwd = os.getcwd()
PROJECT_DIR = os.path.dirname(_cwd) if os.path.basename(_cwd) == 'notebooks' else _cwd
DATA_DIR = os.path.join(PROJECT_DIR, 'data')

# Read the training corpus with an explicit encoding so the result does not
# depend on the platform's default locale.
with open(os.path.join(DATA_DIR, 'anna.txt'), 'r', encoding='utf-8') as file:
    text = file.read()

In [5]:
# Peek at the first 120 characters of the training corpus.
text[:120]

'Chapter 1\n\n\nHappy families are all alike; every unhappy family is unhappy in its own\nway.\n\nEverything was in confusion i'

In [6]:
# Characters appended to the vocabulary when `extend=True`.
_SPECIAL_CHARS = '#[]{}+-*=!'
# Built once at module level instead of on every call.
_EncodingResults = namedtuple('results', ['encoded_text', 'unique_char', 'int2char', 'char2int'])


def encode_text(text, extend = True, unique_chars = None):
    """
    Map every character of `text` to an integer id.

    Parameters
    ----------
    text:
        String to encode.
    extend:
        If True, append the special characters ``#[]{}+-*=!`` to the
        vocabulary, skipping any that are already present.
    unique_chars:
        Optional vocabulary (list of characters). When omitted, the vocabulary
        is built from the characters of `text` plus the special characters,
        sorted so it is identical across interpreter sessions. The caller's
        list is never modified.

    Returns
    -------
    results
        Namedtuple with fields:
        - encoded_text: 1-D int64 numpy array, one id per character of `text`.
        - unique_char:  the vocabulary list that was used.
        - int2char:     dict id -> character.
        - char2int:     dict character -> id. If the vocabulary contains
          duplicates, the first occurrence wins.

    Raises
    ------
    KeyError
        If `text` contains characters that are not in the vocabulary.
    """
    if unique_chars is None:
        # `set` iteration order changes between sessions (string hash
        # randomisation), so sort to get a reproducible vocabulary.
        vocab = sorted(set(text) | set(_SPECIAL_CHARS))
    else:
        # Copy so the caller's list is not mutated below.
        vocab = list(unique_chars)

    if extend:
        # Only add missing specials; blindly appending duplicated them when the
        # vocabulary already contained them.
        vocab += [ch for ch in _SPECIAL_CHARS if ch not in vocab]

    # First occurrence wins, matching the previous `list.index` lookup, but in O(n).
    char2int = {}
    for idx, ch in enumerate(vocab):
        char2int.setdefault(ch, idx)
    int2char = {idx: ch for ch, idx in char2int.items()}

    missing = set(text).difference(char2int)
    if missing:
        raise KeyError(f'characters not in vocabulary: {sorted(missing)!r}')

    # Explicit dtype so empty text still yields an integer array.
    encoded_text = np.fromiter((char2int[ch] for ch in text), dtype=np.int64, count=len(text))

    return _EncodingResults(encoded_text, vocab, int2char, char2int)

In [7]:
# Mini-batch geometry: number of parallel sequences per batch and number of
# time steps (characters) per sequence.
batch_size = 32
seq_length = 64

In [8]:
# Characters consumed per mini-batch (batch_sequence recomputes this internally).
numel_seq = batch_size * seq_length

In [9]:
# Display characters per mini-batch.
numel_seq

2048

In [10]:
def batch_sequence(arr, batch_size, seq_length):
    """
    Split an encoded sequence into (input, target) mini-batches for a char-RNN.

    The array is truncated to a whole number of batches and reshaped to
    `batch_size` rows, each row being a contiguous chunk of the sequence.
    Batches are consecutive `seq_length`-wide column windows of that grid, so
    row `r` of batch `k + 1` continues row `r` of batch `k`.

    Targets are the inputs shifted by one character. The very last target is
    the first character after the kept region when one exists, and wraps to
    `arr[0]` otherwise, so every target window has exactly `seq_length` columns.

    Parameters
    ----------
    arr:
        1-D numpy array of encoded characters.
    batch_size:
        Number of rows (parallel sequences) per batch.
    seq_length:
        Number of time steps per batch.

    Returns
    -------
    (iterator, num_batches)
        An iterator over `(inputs, targets)` pairs, each of shape
        `(batch_size, seq_length)`, and the number of pairs it yields.
        If `arr` is too short to fill one batch, the iterator is empty and
        `num_batches` is 0.
    """
    numel_seq = batch_size * seq_length
    num_batches = arr.size // numel_seq
    if num_batches == 0:
        return iter(()), 0

    n_keep = num_batches * numel_seq
    next_char = arr[n_keep] if arr.size > n_keep else arr[0]

    inputs = arr[:n_keep].reshape(batch_size, -1)
    # Shifting the flat sequence before reshaping makes the last target of
    # row r the first input of row r + 1, i.e. the true next character.
    targets = np.append(arr[1:n_keep], next_char).reshape(batch_size, -1)

    batches = ((inputs[:, n:n + seq_length], targets[:, n:n + seq_length])
               for n in range(0, inputs.shape[1], seq_length))
    return batches, num_batches

In [11]:
def one_hot_encode(arr, n_labels):
    """
    One-hot encode an integer numpy array.

    Parameters
    ----------
    arr:
        Numpy array of integer class ids, any shape.
    n_labels:
        Size of the one-hot dimension.

    Returns
    -------
    float32 array of shape ``arr.shape + (n_labels,)`` holding a single 1.0
    per position, at the index given by `arr`.
    """
    codes = arr.flatten()
    rows = np.arange(arr.size)

    encoded = np.zeros((arr.size, n_labels), dtype=np.float32)
    encoded[rows, codes] = 1.

    # Restore the caller's layout with the label axis appended last.
    return encoded.reshape(arr.shape + (n_labels,))

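**LSTM cell update** (as given in the `torch.nn.LSTM` documentation). At every time step $t$, each layer computes the input, forget, cell-candidate and output gates $i_t, f_t, g_t, o_t$. It then computes the new cell state $c_t$ and hidden state $h_t$. Here $\sigma$ is the sigmoid and $\odot$ the element-wise (Hadamard) product:

$$
\begin{array}{ll}
i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
h_t = o_t \odot \tanh(c_t)
\end{array}
$$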

**Dimension names** used for the `nn.LSTM` input and output tensor shapes:

$$
\begin{aligned}
N ={} & \text{batch size} \\
L ={} & \text{sequence length} \\
D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
H_{in} ={} & \text{input_size} \\
H_{cell} ={} & \text{hidden_size} \\
H_{out} ={} & \text{proj_size if } \text{proj_size}>0 \text{ otherwise hidden_size}
\end{aligned}
$$

In [12]:
# `help` prints the docstring itself and returns None; wrapping it in `print`
# only appended a stray "None" to the output.
help(nn.LSTM)

Help on class LSTM in module torch.nn.modules.rnn:

class LSTM(RNNBase)
 |  LSTM(*args, **kwargs)
 |  
 |  Applies a multi-layer long short-term memory (LSTM) RNN to an input
 |  sequence.
 |  
 |  
 |  For each element in the input sequence, each layer computes the following
 |  function:
 |  
 |  .. math::
 |      \begin{array}{ll} \\
 |          i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
 |          f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
 |          g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
 |          o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
 |          c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
 |          h_t = o_t \odot \tanh(c_t) \\
 |      \end{array}
 |  
 |  where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell
 |  state at time `t`, :math:`x_t` is the input at time `t`, :math:`h_{t-1}`
 |  is the hidden state of the layer at time `t-1` or the initial hidden
 |  state a

In [13]:
def get_text(fpath, sep = ' '):
    """
    Extract and concatenate the text of every page of a PDF file.

    Parameters
    ----------
    fpath:
        Path to the PDF file.
    sep:
        String placed in front of each page's text. The default single space
        matches the previous behaviour, so a non-empty PDF's result starts
        with `sep`.

    Returns
    -------
    str
        The concatenated page texts; '' for a PDF with no pages.
    """
    with open(fpath, "rb") as f:
        pdf = PyPDF3.PdfFileReader(f)
        # Extract while the file is still open; collect into a list and join
        # once, because repeated `str +` on a growing string copies it every
        # iteration (quadratic in document length).
        pages = [pdf.getPage(page_num).extractText() for page_num in range(pdf.numPages)]
    return ''.join(sep + page for page in pages)

In [14]:
# Artefacts (vocabulary, weights) live next to the data folder:
# <project>/data -> <project>/artefacts. Taking the parent directory avoids
# `str.replace` rewriting every "data" substring elsewhere in the path.
ARTEFACTS_DIR = os.path.join(os.path.dirname(DATA_DIR), 'artefacts')

# Vocabulary saved at training time; the model's input/output size is its
# length. NOTE: pickle can execute arbitrary code on load — only use a
# trusted file here.
with open(os.path.join(ARTEFACTS_DIR, 'unique_char.pkl'), 'rb') as f:
    unique_chars = pickle.load(f)

# NOTE(review): later cells reference `info` (a checkpoint dict); it is only
# defined when this load is uncommented.
#with open('weights.pt', 'rb') as f:
#    info = torch.load(f, map_location = torch.device('cpu') )

In [15]:
# Inspect the saved vocabulary. Its last ten entries repeat the special
# characters `#[]{}+-*=!` that already appear earlier in the list.
unique_chars

['2',
 '4',
 'u',
 'V',
 'Q',
 '@',
 'l',
 'F',
 '`',
 '6',
 'f',
 ' ',
 'C',
 'b',
 'x',
 '(',
 'H',
 'I',
 'X',
 'S',
 '=',
 '?',
 'L',
 'M',
 '{',
 'P',
 'k',
 'q',
 '*',
 ',',
 'y',
 'A',
 'B',
 'D',
 'W',
 'G',
 'Y',
 '9',
 '+',
 'J',
 'o',
 'z',
 ':',
 'N',
 ')',
 '3',
 'n',
 '-',
 'c',
 'g',
 '5',
 '[',
 ';',
 '7',
 '"',
 'a',
 'e',
 'R',
 'E',
 '\n',
 'h',
 ']',
 '1',
 '0',
 'm',
 'Z',
 'r',
 'j',
 '$',
 '8',
 '#',
 'p',
 'v',
 '}',
 'i',
 'd',
 '!',
 'U',
 't',
 'T',
 '_',
 's',
 '&',
 'O',
 '.',
 'K',
 "'",
 'w',
 '/',
 '%',
 '#',
 '[',
 ']',
 '{',
 '}',
 '+',
 '-',
 '*',
 '=',
 '!']

In [16]:
### Get text for validation data
# Validation corpus: the text of every page of a PDF, extracted by `get_text`.
val_text = get_text(os.path.join(DATA_DIR, "The-Prince.pdf"))

In [17]:
### Encode train data (saved vocabulary as-is, no extra characters appended)
encoded_text = encode_text(text, extend = False, unique_chars = unique_chars).encoded_text

In [18]:
### Encode validation data
# Reuse the training vocabulary (extend=False) so both splits share the same ids.
encoding_results = encode_text(val_text, False, unique_chars)
encoded_val = encoding_results.encoded_text

In [19]:
# Vocabulary returned by `encode_text`; with extend=False it holds the same
# entries as the loaded `unique_chars` (note the near-identical name).
unique_char = encoding_results.unique_char
len(unique_char)

100

In [20]:
class CharRNN(nn.Module):
    """
    Character-level LSTM language model.

    Parameters
    ----------
    hidden_size:
        Number of features in the LSTM hidden state.
    dropout:
        Dropout probability applied between stacked LSTM layers.
    batch_size:
        Number of sequences in a batch; used by `init_hidden_state`.
    D:
        Number of directions: 1 (unidirectional) or 2 (bidirectional).
    num_layers:
        Number of stacked LSTM layers.
    vocab_size:
        Size of the one-hot input and of the output logits. Defaults to
        ``len(unique_chars)`` from the notebook's global vocabulary, which is
        what this class always used before the parameter existed.

    Forward returns
    ---------------
    output:
        Logits of shape [batch_size * sequence_length, vocab_size]. Time steps
        are flattened so the result can go straight into `nn.CrossEntropyLoss`.
    hidden_state:
        Tuple (h, c), each of shape [D * num_layers, batch_size, hidden_size].
    """
    def __init__(self, hidden_size = 128, dropout = 0.25,
                 batch_size = 32, D = 1, num_layers = 2, vocab_size = None):

        super(CharRNN, self).__init__()

        if D not in (1, 2):
            raise ValueError(f'D must be 1 or 2, got {D!r}')
        if vocab_size is None:
            # Backward-compatible fallback to the notebook's global vocabulary.
            vocab_size = len(unique_chars)

        self.hidden_size = hidden_size
        self.dropout_rate = dropout
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.D = D
        self.vocab_size = vocab_size

        self.lstm = nn.LSTM(input_size = vocab_size, hidden_size = self.hidden_size,
                            dropout = self.dropout_rate, batch_first = True,
                            bidirectional = self.D == 2, bias = True,
                            num_layers = self.num_layers)

        # NOTE: not used in `forward`. Kept only so that existing checkpoints,
        # whose state_dict contains `batch_norm.*` entries, still load with the
        # default strict `load_state_dict`.
        self.batch_norm = nn.BatchNorm2d(self.batch_size)
        self.fc = nn.Linear(self.D*self.hidden_size, vocab_size)

    def forward(self, x, hidden_state):
        """
        Run the LSTM and project every time step to vocabulary logits.

        x: float tensor [batch, seq_len, vocab_size] (batch_first=True).
        hidden_state: (h, c) as returned by `init_hidden_state`.
        """
        outputs, hidden_state = self.lstm(x, hidden_state)
        # Flatten (batch, time) so every step gets its own row of logits.
        outputs = outputs.contiguous().view(-1, self.D*self.hidden_size)
        outputs = self.fc(outputs)

        return outputs, hidden_state

    def init_hidden_state(self, mean, stddev, batch_size = None, device = None):
        """
        Sample initial hidden and cell states from Normal(mean, stddev).

        Parameters
        ----------
        mean, stddev:
            Parameters of the normal distribution.
        batch_size:
            Number of sequences; defaults to ``self.batch_size``.
        device:
            Device for the returned tensors. When None, they stay where
            `torch.distributions.Normal.sample` creates them (as before).

        Returns
        -------
        (h, c), each of shape [D * num_layers, batch_size, hidden_size].
        """
        if batch_size is None:
            batch_size = self.batch_size
        shape = (self.D*self.num_layers, batch_size, self.hidden_size)
        dist = torch.distributions.Normal(mean, stddev)
        h = dist.sample(shape)
        c = dist.sample(shape)
        if device is not None:
            h, c = h.to(device), c.to(device)
        return (h, c)

In [21]:
# Use the configured `batch_size` rather than a literal, so the hidden-state
# shape produced by `init_hidden_state` always matches the batches fed in.
model = CharRNN(D = 1, hidden_size = 512, dropout = 0.25,
                batch_size = batch_size, num_layers = 2)

In [22]:
# NOTE(review): `batch_size` and `seq_length` repeat the values set in an
# earlier cell; keep the two in sync (or drop these copies).
batch_size = 32
seq_length = 64

max_norm = 1.5   # gradient-norm clipping threshold passed to clip_grad_norm_
epochs = 20
lr = 4e-2        # initial Adam learning rate

In [23]:
# Show the layer structure.
print(model)

CharRNN(
  (lstm): LSTM(100, 512, num_layers=2, batch_first=True, dropout=0.25)
  (batch_norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc): Linear(in_features=512, out_features=100, bias=True)
)


model.load_state_dict(info['model_state_dict'])  # NOTE(review): needs `info`, defined only by the commented-out torch.load in the vocabulary cell

In [24]:
# Move the model's parameters to the selected device.
model.to(device)

CharRNN(
  (lstm): LSTM(100, 512, num_layers=2, batch_first=True, dropout=0.25)
  (batch_norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (fc): Linear(in_features=512, out_features=100, bias=True)
)

In [25]:
### Objective functions and optimizer
# Adam over all model parameters. CrossEntropyLoss takes raw logits, which
# matches the flattened [N * L, vocab] output of CharRNN.forward.
opt = optim.Adam(model.parameters(), lr = lr)
criterion = nn.CrossEntropyLoss()

opt.load_state_dict(info['optimizer_state_dict'])  # NOTE(review): needs `info`, defined only by the commented-out torch.load in the vocabulary cell

In [26]:
# NOTE(review): exploratory listing of optimizer attributes; safe to remove.
dir(opt)

['__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__getstate__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__setstate__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_hook_for_profile',
 '_zero_grad_profile_name',
 'add_param_group',
 'defaults',
 'load_state_dict',
 'param_groups',
 'state',
 'state_dict',
 'step',
 'zero_grad']

opt.param_groups[0]['lr'] = 1e-5  # lower the Adam learning rate; opt has a single param group (built from model.parameters())

In [27]:
# Free unreachable objects before training; the displayed value is the number
# of unreachable objects found.
gc.collect()

11659

In [28]:
# Sanity check: batch_sequence works on numpy arrays (.size / .reshape).
type(encoded_text)

numpy.ndarray

In [29]:
### Train data
# Only the batch count is needed here; the returned iterator is discarded.
_, num_batches = batch_sequence(encoded_text, batch_size, seq_length)

In [30]:
### Validation data
# Only the batch count is needed here; the returned iterator is discarded.
_, num_batches_ = batch_sequence(encoded_val, batch_size, seq_length)

In [31]:
### Num of train batches
# (the training loop uses this to detect the last iteration of an epoch)
print(num_batches)

969


In [32]:
### Num of valid batches
# (each validation pass iterates over all of them)
print(num_batches_)

137


In [None]:
def evaluate_loss(model, data, batch_size, seq_length, n_labels, criterion, device):
    """
    Mean loss of `model` over every batch of `data`, computed without gradients.

    Leaves the model in eval mode; call `model.train()` before resuming
    training. Returns nan when `data` is too short to form a single batch.
    """
    model.eval()
    losses = []
    hidden = model.init_hidden_state(mean = 0., stddev = .5)
    # No autograd graph is needed for evaluation. This avoids holding
    # activations for every validation batch in memory.
    with torch.no_grad():
        for X_, y_ in batch_sequence(data, batch_size, seq_length)[0]:
            hidden = tuple(each.to(device) for each in hidden)
            X_ = torch.as_tensor(one_hot_encode(X_, n_labels)).to(device)
            y_ = torch.as_tensor(y_).to(device)
            outputs_, hidden = model(X_, hidden)
            losses.append(criterion(outputs_, y_.reshape(-1).long()).item())
    return float(np.mean(losses)) if losses else float('nan')


### Outer training loop
for epoch in range(1, epochs + 1):
    model.train()
    h = model.init_hidden_state(mean = 0., stddev = .5)
    train_losses = list()

    ### Inner training loop
    batches, _ = batch_sequence(encoded_text, batch_size, seq_length)
    for iteration, (X, y) in enumerate(batches, start = 1):
        X = one_hot_encode(X, len(unique_chars))
        X, y = torch.as_tensor(X).to(device), torch.as_tensor(y).to(device)

        # Truncated BPTT: keep the hidden values but cut the graph at the
        # previous batch boundary.
        h = tuple(each.detach().to(device) for each in h)
        opt.zero_grad()

        outputs, h = model(X, h)
        loss = criterion(outputs, y.reshape(-1).long())

        # The graph is rebuilt every step (h is detached), so there is no need
        # to retain it after backward.
        loss.backward()
        nn.utils.clip_grad_norm_(model.parameters(), max_norm)
        opt.step()

        train_losses.append(loss.item())

        ### Periodic validation
        if (not iteration % 20) or (iteration == num_batches):
            val_loss = evaluate_loss(model, encoded_val, batch_size, seq_length,
                                     len(unique_chars), criterion, device)
            model.train()  # evaluate_loss switched to eval mode

            ### Report training and validation losses
            train_loss = float(np.mean(train_losses))

            print('='*80)
            print(f'Epoch: {epoch}/{epochs}, Iteration {iteration}/{num_batches},',
                  f'Train Loss: {train_loss:.4f}, Valid Loss: {val_loss:.4f}')

    print('\n'+'='*80)
    print('='*80)

Epoch: 1/20, Iteration 20/969, Train Loss: 5.2870, Valid Loss: 8.0907
Epoch: 1/20, Iteration 40/969, Train Loss: 4.3389, Valid Loss: 8.3135
Epoch: 1/20, Iteration 60/969, Train Loss: 3.9462, Valid Loss: 6.9592
Epoch: 1/20, Iteration 80/969, Train Loss: 3.7479, Valid Loss: 12.6951
Epoch: 1/20, Iteration 100/969, Train Loss: 3.6272, Valid Loss: 9.8642


In [None]:
# Persist model and optimizer state together so training can be resumed.
checkpoint = {
    'model_state_dict': model.state_dict(),
    'optimizer_state_dict': opt.state_dict(),
}
with open('new-weights.net', 'wb') as f:
    torch.save(checkpoint, f)