In [1]:
import gc
import os
import pickle
from collections import namedtuple

import numpy as np
import PyPDF3
import torch
from torch import nn
from torch import optim
from torch.nn import functional as F
from torchinfo import summary

In [2]:
# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [3]:
print(device)

cpu


In [4]:
# The data directory is derived from the working directory by swapping the
# "notebooks" path component for "data".
# NOTE(review): str.replace rewrites *every* "notebooks" substring in the path,
# so this assumes the word only appears as the notebook folder name.
DATA_DIR = os.getcwd().replace('notebooks', 'data')

# Read the text that is later encoded as the training data. The encoding is
# pinned so the decoded characters do not depend on the platform locale
# (assumes anna.txt is UTF-8/ASCII -- TODO confirm).
with open(os.path.join(DATA_DIR, 'anna.txt'), 'r', encoding = 'utf-8') as file:
    text = file.read()

In [5]:
text[:120]

'Chapter 1\n\n\nHappy families are all alike; every unhappy family is unhappy in its own\nway.\n\nEverything was in confusion i'

In [6]:
def encode_text(text, extend = True, unique_chars = None):
    """
    Map every character of ``text`` to an integer index.

    Parameters
    ----------
    text:
        String to encode.
    extend:
        If True, the characters ``#[]{}+-*=!`` are appended to the vocabulary.
    unique_chars:
        Optional vocabulary (sequence of characters). It is copied, never
        modified in place. When omitted, the vocabulary is built from the
        characters of ``text`` plus ``#[]{}+-*=!``; the order of that
        vocabulary comes from a ``set`` and is therefore not guaranteed to be
        stable between interpreter runs.

    Returns
    -------
    results:
        Named tuple with the fields
        - ``encoded_text``: NumPy array holding one index per character of ``text``
        - ``unique_char``: the vocabulary list that was used
        - ``int2char``: dict mapping index -> character
        - ``char2int``: dict mapping character -> index

    Raises
    ------
    KeyError
        If ``text`` contains a character that is missing from the vocabulary.

    Notes
    -----
    A character that occurs several times in the vocabulary is mapped to the
    index of its first occurrence, so later duplicate positions have no entry
    in ``int2char``. With ``unique_chars=None`` and ``extend=True`` the special
    characters end up in the vocabulary twice; the duplicates are kept on
    purpose because the vocabulary length determines the network's layer
    sizes, and the saved vocabulary loaded in this notebook contains them too.
    """
    result_tuple = namedtuple('results', ['encoded_text', 'unique_char', 'int2char', 'char2int'])
    special_chars = '#[]{}+-*=!'

    if unique_chars is None:
        vocab = list(set(text).union(special_chars))
    else:
        # Work on a copy so the caller's list is never mutated by `extend`.
        vocab = list(unique_chars)
    if extend:
        vocab.extend(special_chars)

    # Single pass; setdefault keeps the first index of a repeated character,
    # which is exactly what list.index() would return.
    char2int = {}
    for index, char in enumerate(vocab):
        char2int.setdefault(char, index)
    int2char = {index : char for (char, index) in char2int.items()}

    encoded_text = np.array([char2int[char] for char in text])

    return result_tuple(encoded_text, vocab, int2char, char2int)

In [7]:
# Number of sequences per batch and number of characters per sequence.
batch_size = 32
seq_length = 16

In [8]:
# Number of encoded characters consumed by one batch.
numel_seq = batch_size * seq_length

In [9]:
numel_seq

512

In [10]:
def batch_sequence(arr, batch_size, seq_length):
    """
    Split an encoded array into (input, target) batches for next-character prediction.

    The array is truncated to a whole number of batches and reshaped to
    ``batch_size`` rows. Each batch is a window of ``seq_length`` columns; the
    target window is the input window shifted one column to the right.

    Parameters
    ----------
    arr:
        NumPy array of encoded characters (assumed 1-D -- TODO confirm with callers).
    batch_size:
        Number of rows (sequences) in every batch.
    seq_length:
        Number of columns (time steps) in every batch.

    Returns
    -------
    batches:
        Iterator over ``(inputs, targets)`` tuples, both of shape
        ``[batch_size, seq_length]``.
    num_batches:
        Number of batches in the iterator. When ``arr`` is too short for a
        single batch, an empty iterator and ``0`` are returned.
    """
    numel_seq = batch_size * seq_length
    num_batches = arr.size // numel_seq

    # Too little data for even one batch: return nothing instead of failing
    # with an IndexError on the (empty) batch list below.
    if num_batches == 0:
        return iter([]), 0

    arr = arr[: num_batches * numel_seq].reshape(batch_size, -1)

    batched_data = [(arr[:, n : n + seq_length], arr[:, n + 1 : n + 1 + seq_length])
                    for n in range(0, arr.shape[1], seq_length)]

    # The last target window runs off the end of the array and is one column
    # short; pad it so every target has seq_length columns.
    # NOTE(review): the padding column is the first column of the first
    # *target* window (arr[:, 1]), not the true next character -- confirm this
    # wrap-around is intended.
    last_inputs, last_targets = batched_data[-1]
    wrap_column = batched_data[0][1][:, 0].reshape(-1, 1)
    batched_data[-1] = (last_inputs, np.append(last_targets, wrap_column, axis = 1))

    return iter(batched_data), num_batches

In [11]:
def one_hot_encode(arr, n_labels):
    
    # Initialize the the encoded array
    one_hot = np.zeros((arr.size, n_labels), dtype=np.float32)
    
    # Fill the appropriate elements with ones
    one_hot[np.arange(one_hot.shape[0]), arr.flatten()] = 1.
    
    # Finally reshape it to get back to the original array
    one_hot = one_hot.reshape((*arr.shape, n_labels))
    
    return one_hot

$$
\begin{array}{ll} \\
        i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
        f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
        g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
        o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
        c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
        h_t = o_t \odot \tanh(c_t) \\
    \end{array}
$$

In [12]:
# help() writes the documentation itself and returns None, so it must not be
# wrapped in print() (that would add a stray "None" after the docs).
help(nn.LSTM)

Help on class LSTM in module torch.nn.modules.rnn:

class LSTM(RNNBase)
 |  LSTM(*args, **kwargs)
 |  
 |  Applies a multi-layer long short-term memory (LSTM) RNN to an input
 |  sequence.
 |  
 |  
 |  For each element in the input sequence, each layer computes the following
 |  function:
 |  
 |  .. math::
 |      \begin{array}{ll} \\
 |          i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\
 |          f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\
 |          g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\
 |          o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\
 |          c_t = f_t \odot c_{t-1} + i_t \odot g_t \\
 |          h_t = o_t \odot \tanh(c_t) \\
 |      \end{array}
 |  
 |  where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell
 |  state at time `t`, :math:`x_t` is the input at time `t`, :math:`h_{t-1}`
 |  is the hidden state of the layer at time `t-1` or the initial hidden
 |  state a

In [13]:
def get_text(fpath):
    """
    Extract the text of every page of a PDF file.

    Parameters
    ----------
    fpath:
        Path of the PDF file to read.

    Returns
    -------
    text:
        The extracted page texts in page order, each one preceded by a single
        space (so a non-empty result starts with a space).
    """
    with open(fpath, "rb") as pdf_file:
        reader = PyPDF3.PdfFileReader(pdf_file)
        page_texts = (reader.getPage(index).extractText()
                      for index in range(reader.numPages))
        # Joined while the file is still open, because pages are read lazily.
        return ''.join(' ' + page_text for page_text in page_texts)

In [14]:
# Restore the saved vocabulary.
# NOTE: pickle.load can execute arbitrary code -- only open files you trust.
with open('unique_char.pkl', 'rb') as f:
    unique_chars = pickle.load(f)

# Load the saved checkpoint, mapping its tensors onto the CPU regardless of the
# device they were stored from.
with open('weights.pt', 'rb') as f:
    info = torch.load(f, map_location = 'cpu')

In [15]:
unique_chars

['2',
 '4',
 'u',
 'V',
 'Q',
 '@',
 'l',
 'F',
 '`',
 '6',
 'f',
 ' ',
 'C',
 'b',
 'x',
 '(',
 'H',
 'I',
 'X',
 'S',
 '=',
 '?',
 'L',
 'M',
 '{',
 'P',
 'k',
 'q',
 '*',
 ',',
 'y',
 'A',
 'B',
 'D',
 'W',
 'G',
 'Y',
 '9',
 '+',
 'J',
 'o',
 'z',
 ':',
 'N',
 ')',
 '3',
 'n',
 '-',
 'c',
 'g',
 '5',
 '[',
 ';',
 '7',
 '"',
 'a',
 'e',
 'R',
 'E',
 '\n',
 'h',
 ']',
 '1',
 '0',
 'm',
 'Z',
 'r',
 'j',
 '$',
 '8',
 '#',
 'p',
 'v',
 '}',
 'i',
 'd',
 '!',
 'U',
 't',
 'T',
 '_',
 's',
 '&',
 'O',
 '.',
 'K',
 "'",
 'w',
 '/',
 '%',
 '#',
 '[',
 ']',
 '{',
 '}',
 '+',
 '-',
 '*',
 '=',
 '!']

In [16]:
info

{'model_state_dict': OrderedDict([('lstm.weight_ih_l0',
               tensor([[ 0.0153, -0.0876,  0.1679,  ...,  0.0518, -0.0367, -0.0339],
                       [-0.1801,  0.0243,  0.0953,  ..., -0.0628,  0.0793,  0.0684],
                       [-0.2093,  0.0910,  0.6245,  ..., -0.0290,  0.0812, -0.0207],
                       ...,
                       [ 0.1932, -0.1144, -0.2806,  ...,  0.0126,  0.0307,  0.0137],
                       [ 0.2034,  0.1398,  0.4949,  ...,  0.0112,  0.0776,  0.0411],
                       [ 0.0148,  0.1126,  0.0408,  ..., -0.0366,  0.0138,  0.0754]])),
              ('lstm.weight_hh_l0',
               tensor([[ 0.2525,  0.0819, -0.1941,  ..., -0.1613,  0.0800, -0.1786],
                       [ 0.4471,  0.2502, -0.4032,  ..., -0.0685,  0.2477, -0.5673],
                       [ 0.1754, -0.2310, -0.3944,  ..., -0.2395, -0.0897,  0.0155],
                       ...,
                       [ 0.2115, -0.0121, -0.0146,  ..., -0.3699,  0.0284, -0.2815],

In [17]:
info.keys()

dict_keys(['model_state_dict', 'optimizer_state_dict', 'epoch', 'train_loss', 'test_loss'])

In [18]:
### Get text for validation data
val_text = get_text(os.path.join(DATA_DIR, "The-Prince.pdf"))

In [19]:
### Encode train data
# Read the field by name (as the validation cell does) instead of unpacking the
# other fields into `_`, which would shadow IPython's last-output variable.
encoded_text = encode_text(text, False, unique_chars).encoded_text

In [20]:
### Encode validation data
encoding_results = encode_text(val_text, False, unique_chars)
encoded_val = encoding_results.encoded_text

In [21]:
unique_char = encoding_results.unique_char
len(unique_char)

100

$$
\begin{aligned}
    N ={} & \text{batch size} \\
    L ={} & \text{sequence length} \\
    D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\
    H_{in} ={} & \text{input\_size} \\
    H_{cell} ={} & \text{hidden\_size} \\
    H_{out} ={} & \text{proj\_size if } \text{proj\_size}>0 \text{ otherwise hidden\_size} \\
\end{aligned}
$$

In [22]:
class CharRNN(nn.Module):
    """
    Character-level LSTM followed by a linear layer that scores every character.

    Parameters
    ----------
    hidden_size:
        Number of features in the LSTM hidden state.
    dropout:
        Dropout probability passed to the LSTM.
    batch_size:
        Default number of sequences in a batch; used by ``init_hidden_state``.
    D:
        Number of directions: 1 (unidirectional) or 2 (bidirectional).
    num_layers:
        Number of stacked LSTM layers.
    vocab_size:
        Size of the one-hot input and of the output layer. When omitted it
        falls back to ``len(unique_chars)``, the notebook-level vocabulary.

    Returns
    -------
    output:
        Shape: [batch_size * sequence_length, vocab_size]
    hidden_state:
        Tuple containing:
        - Short-term hidden state
            Shape: [D * num_layers, batch_size, hidden_size]
        - Cell state
            Shape: [D * num_layers, batch_size, hidden_size]

    Raises
    ------
    ValueError
        If ``D`` is neither 1 nor 2.
    """
    def __init__(self, hidden_size = 128, dropout = 0.25,
                 batch_size = 32, D = 1, num_layers = 2, vocab_size = None):

        super().__init__()

        # Any other value would build a unidirectional LSTM whose output width
        # does not match the D * hidden_size expected by the linear layer.
        if D not in (1, 2):
            raise ValueError(f'D must be 1 or 2, got {D!r}')
        if vocab_size is None:
            vocab_size = len(unique_chars)

        self.hidden_size = hidden_size
        self.dropout_rate = dropout
        self.num_layers = num_layers
        self.batch_size = batch_size
        self.D = D
        self.vocab_size = vocab_size

        self.lstm = nn.LSTM(input_size = self.vocab_size, hidden_size = self.hidden_size,
                            dropout = self.dropout_rate, batch_first = True,
                            bidirectional = self.D == 2, bias = True,
                            num_layers = self.num_layers)

        self.fc = nn.Linear(self.D*self.hidden_size, self.vocab_size)

    def forward(self, x, hidden_state):
        """
        Run the LSTM over ``x`` and score every time step.

        ``x`` is batch-first (the LSTM is built with ``batch_first=True``).
        The LSTM outputs are flattened to one row per (sequence, time step)
        pair before the linear layer is applied.
        """
        outputs, hidden_state = self.lstm(x, hidden_state)
        outputs = outputs.contiguous().view(-1, self.D*self.hidden_size)
        outputs = self.fc(outputs)

        return outputs, hidden_state

    def init_hidden_state(self, mean, stddev, batch_size = None):
        """
        Initialize hidden state and context tensors.

        Both tensors are drawn from Normal(mean, stddev), have the shape
        [D * num_layers, batch_size, hidden_size] and are placed on the same
        device as the model parameters. ``batch_size`` defaults to the value
        given to the constructor.
        """
        if batch_size is None:
            batch_size = self.batch_size

        shape = (self.D*self.num_layers, batch_size, self.hidden_size)
        target_device = next(self.parameters()).device
        distribution = torch.distributions.Normal(mean, stddev)

        h = distribution.sample(shape).to(target_device)
        c = distribution.sample(shape).to(target_device)

        return (h, c)
        

In [23]:
model = CharRNN(D = 1)

In [24]:
# Same values as the earlier batch_size / seq_length cell, declared again here.
batch_size = 32
seq_length = 16

# NOTE(review): max_norm, epochs and lr are not referenced by any cell visible
# in this notebook -- presumably left over from training; confirm before removing.
max_norm = 1.5
epochs = 5
lr = 1e-4

In [25]:
print(model)

CharRNN(
  (lstm): LSTM(100, 128, num_layers=2, batch_first=True, dropout=0.25)
  (fc): Linear(in_features=128, out_features=100, bias=True)
)


In [26]:
model.load_state_dict(info['model_state_dict'])

<All keys matched successfully>

In [27]:
model.to(device)

CharRNN(
  (lstm): LSTM(100, 128, num_layers=2, batch_first=True, dropout=0.25)
  (fc): Linear(in_features=128, out_features=100, bias=True)
)

In [28]:
# Inference only: freeze every parameter and put the model in evaluation mode
# so the LSTM dropout is not applied while generating text.
# (Assigning `model.requires_grad_ = False` only replaces the method with a
# bool; it neither freezes parameters nor disables dropout.)
model.requires_grad_(False)
model.eval();

In [29]:
int2char = encoding_results.int2char
char2int = encoding_results.char2int

In [48]:
seed = 'David'
k = 10

In [49]:
# Feed the seed text through the network to warm up the hidden state, then
# sample the first generated character from the final prediction.
seed_list = list(seed)
if not seed_list:
    raise ValueError('seed must contain at least one character')

# Sizes are taken from the model and the vocabulary instead of being hard-coded.
n_labels = len(unique_chars)
state_shape = (model.D * model.num_layers, 1, model.hidden_size)

# Random initial hidden and cell state for a single sequence, on the model's device.
a = torch.distributions.Normal(scale = 0.5, loc = 0.).sample(state_shape).to(device)
b = torch.distributions.Normal(scale = 0.5, loc = 0.).sample(state_shape).to(device)
h = (a, b)

# No gradients are needed for generation, so nothing has to be detached by hand.
with torch.no_grad():
    for char in seed_list:
        x = torch.tensor(one_hot_encode(np.array(char2int[char]).reshape(1, -1), n_labels)).to(device)
        out, h = model(x, h)

    # Only the prediction that follows the last seed character is used.
    p = F.softmax(out, dim=-1)
    top_p, chars = p.topk(k = k, dim=-1)
    # Sample among the k most likely characters in proportion to their
    # probabilities (torch.multinomial normalises the weights itself).
    next_char = chars[0, torch.multinomial(top_p[0], num_samples = 1)].item()

seed_list.append(int2char[next_char])

In [50]:
seed_list

['D', 'a', 'v', 'i', 'd', 'y']

In [51]:
p.topk(k = 5, dim=-1)

torch.return_types.topk(
values=tensor([[0.2452, 0.2362, 0.1562, 0.0549, 0.0481]], grad_fn=<TopkBackward0>),
indices=tensor([[30, 56, 74, 40, 11]]))

In [52]:
# Generate characters one at a time, feeding each sampled character back in.
# NOTE: this cell appends to seed_list in place, so running it again extends
# the text generated so far.
n_new_chars = 1000
n_labels = len(unique_chars)

# Make sure the carried-over state lives on the same device as the model.
h = tuple(each.to(device) for each in h)

# No gradients are needed for generation, so nothing has to be detached by hand.
with torch.no_grad():
    for ix in range(n_new_chars):
        x = torch.tensor(one_hot_encode(np.array(char2int[seed_list[-1]]).reshape(1, -1), n_labels)).to(device)
        out, h = model(x, h)
        p = F.softmax(out, dim=-1)
        top_p, chars = p.topk(k = k, dim=-1)
        # Sample among the k most likely characters in proportion to their
        # probabilities (torch.multinomial normalises the weights itself).
        next_char = chars[0, torch.multinomial(top_p[0], num_samples = 1)].item()
        seed_list.append(int2char[next_char])
    

In [53]:
chars.shape

torch.Size([1, 10])

In [54]:
pred_text = ''.join(seed_list)

In [55]:
print(pred_text)

Davidya
or easll hllwifomesibracy:-'All,:'-_
Lvvr'
medltiny-lrationslibfhs-wounstss ollrine elblr styamsine,
swnolivua say hrope
pelly;".
Ardiabtaces;;,)e-Lvyu.'s
attwyatieft._S
ato" an efphing,).)
I'd,";",'m to
cieft,'g,_,,. Lots.-"Teen.n.''.f)-
hrwsoon't,; ad so chty.,
Agoanizcley'ls;- t erf id.',
alame
tulisf hrilkey'-.-Bnuvy.-Betrucus.,'s.-Wimhlods mmacpor fucceribll tmea twuert. Ad
oniade sailiess,, bvet'ostey fmignaif,'lyeldse._n-.I"ss
Spife.-"T'houblused"-_Lvru
knand
ime bvencily ng dose-daypiny-bowaictsed's- holfaucy
shit.
"Be,;'r.,,'h.?,? Haspevtins.,ns,'"--is exe thry-to
lugkl bleam hyapinaciefunl,'g-hanchs-wott adonivuler,;.;.n
I, wilkiglil feerisiozie mucabeem
bousht wyirk,'c tuimiel,),;)st,
ifenistor,, smilf-or iln.s."'-"Ivrry
trmmougen?_'Ald-rweodibye
cutnanly-dhad'f?e""'-_Syokn'g'k
hhavcoece._' hons
ago,
ene
oboskayialy' wlhent
slulf,.'
Hh surdeemeyth,s twalcsoryshs;-saiche
conannsko timoig.'"".n-How,,;;_'s);)
Kesyit,'m..;;,'t,;-.
Age,-'flompuem,ss wavy;.-Befut,.' alsoan

In [36]:
print(pred_text)

Davide,
treinily wonle.. "Tely,
way a poisule,.

I'ld,...
"All, aled,,
trepill to
many, answey,
thoogs. I'r
samficatinc tham....".
Ans trivalialy
breathiom, thind trings tarth a leatisher,"
haves
a sencioul faringet,
takay
any
anspisely,
steivn on
shoold howered,
she
he wosddevenced. "All a lanctiomss." Ho tounglliens. "When'her
same..
The clovoube homs, we so wife
atticuted
im a lanct...

A his
mothach.
"I
al at
im deciling.. To
you'se. Havies." 
Shavarevs whateder,""" sinicatiagay,
at
tiself.
"
Tinky
had. At an toubted
trouttlien
tay strildialy at
hid work whis,,
sain
too,
weriged.
 "Yay tell.
I
asterst. In
takiness
as
is
wiflire.
Tusko he hoors, tak it. Ithored thints oven
had,"
ser iflenstallite, we as hereles,,
attonts at
its anst imsisfed, takant
at her, brand a morow-dreat of
it any
tensed,"."s"."

Show have
herstichs,"
he'ce
hig stoob took her
facurang on tak anowithiods,,. Alderadlive,"" taled," soid, stations
wher, tho town, a stroumly
sourd,
went tira is and.."
"And soletely

In [37]:
p.shape

torch.Size([1, 100])

In [38]:
out.shape

torch.Size([1, 100])

In [40]:
(torch.tensor(one_hot_encode(np.array(char2int[char]).reshape(1, -1), 100))).shape

torch.Size([1, 1, 100])