# Bonus Track Assignment: Char RNN

## Book chosen: The Lord of the Rings, first book (LOTR 1)

In [1]:
import torch 
import math
import torch.nn as nn
import torch.nn.functional as F

from learning import *

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [2]:
# Read the whole corpus and lower-case it. The `with` block closes the file
# handle deterministically instead of leaving it to the garbage collector.
# NOTE(review): the file is decoded with the platform default encoding, as
# before; pass an explicit `encoding=` once the file's encoding is confirmed.
with open('lotr1.txt', 'r') as corpus_file:
    text = corpus_file.read().lower()
# remove some unwanted characters
text = text.replace('\xa0', ' ')  # non-breaking space -> regular space
text = text.replace('/', '')
text = text.replace('*', '')
len(text)

938278

In [3]:
# Build the character <-> index vocabulary.
# The characters are sorted before being numbered: iterating a bare `set` of
# strings yields an order that can change between interpreter runs (string
# hashing is randomised), which would silently remap every index and make a
# checkpoint saved in one kernel session decode to the wrong characters in
# another.
char_to_idx = {char: idx for (idx, char) in enumerate(sorted(set(text)))}
idx_to_char = {idx: char for (char, idx) in char_to_idx.items()}
# Width of a one-hot vector: one slot per distinct character.
one_hot_size = len(char_to_idx)
print(f'Number of unique characters: {one_hot_size}')

Number of unique characters: 57


In [4]:
# Wrap the corpus in the project's OneHotDataset (imported from `learning`).
# NOTE(review): 60 and 3 are presumably the window length and the stride
# between consecutive windows — confirm against OneHotDataset's signature.
data = OneHotDataset(60, 3, text, char_to_idx, one_hot_size, device)

In [5]:
# Training configuration used by the cells below:
#   epochs      - number of passes over the data loader
#   hidden_size - hidden size handed to CharRNN
#   batch_size  - batch size handed to the DataLoader
epochs, hidden_size, batch_size = 30, 256, 512
# Learning rate handed to the Adam optimiser.
lr = 1e-3
# When true, the training loop prints the mean loss after every epoch.
verbose = True

In [6]:
# Model, batch loader, optimiser and loss criterion.
char_rnn = CharRNN(
    hidden_size=hidden_size,
    vocab_size=len(char_to_idx),
    input_size=one_hot_size,
).to(device)
loader = torch.utils.data.DataLoader(data, batch_size=batch_size)
optim = torch.optim.Adam(char_rnn.parameters(), lr=lr)
loss = torch.nn.CrossEntropyLoss()

for epoch in range(epochs):
    # Sum of the per-batch losses seen during this epoch.
    epoch_loss_total = 0.0
    for inputs, targets in loader:
        optim.zero_grad()
        # Swap the first two axes before the forward pass.
        # NOTE(review): presumably the loader yields batch-first tensors and
        # the model wants (seq_len, batch, input_size) — confirm.
        logits, hidden = char_rnn(inputs.transpose(0, 1), None)
        batch_loss = loss(logits, targets)
        batch_loss.backward()
        optim.step()

        epoch_loss_total += batch_loss.item()

    if not verbose:
        continue
    # Report the mean loss per batch for the epoch.
    mean_loss = epoch_loss_total / len(loader)
    print(f'Epoch {epoch}, loss: {mean_loss}')

Epoch 0, loss: 2.372396403755962
Epoch 1, loss: 1.8951143701415991
Epoch 2, loss: 1.689835849643339
Epoch 3, loss: 1.5628491493762027
Epoch 4, loss: 1.4750569395464883
Epoch 5, loss: 1.4100031522994345
Epoch 6, loss: 1.3586820596173625
Epoch 7, loss: 1.3162987154166164
Epoch 8, loss: 1.2797546778880242
Epoch 9, loss: 1.2471281524180586
Epoch 10, loss: 1.2172560764022426
Epoch 11, loss: 1.1893460916419272
Epoch 12, loss: 1.162863648367006
Epoch 13, loss: 1.1373860793105903
Epoch 14, loss: 1.11264721168824
Epoch 15, loss: 1.0898446897242933
Epoch 16, loss: 1.0664823427918304
Epoch 17, loss: 1.0460560579346752
Epoch 18, loss: 1.0290625857253317
Epoch 19, loss: 1.0189671581778705
Epoch 20, loss: 1.0160599636758407
Epoch 21, loss: 1.0106739944209835
Epoch 22, loss: 1.0058257832464728
Epoch 23, loss: 0.996593036940367
Epoch 24, loss: 0.9873158620539352
Epoch 25, loss: 0.9805808860424498
Epoch 26, loss: 0.9730567108784877
Epoch 27, loss: 0.9671263012870439
Epoch 28, loss: 0.9561812377210139
E

In [7]:
# Persist the trained weights to 'char_rnn.pt'.
# Comment this line out to avoid overwriting an existing checkpoint.
torch.save(char_rnn.state_dict(), 'char_rnn.pt')

In [12]:
# Rebuild the network with the same hyper-parameters and restore the saved
# weights. `map_location=device` remaps the stored tensors onto whichever
# device is available now, so a checkpoint written on a CUDA machine can still
# be loaded on a CPU-only one (without it, torch.load tries to restore the
# tensors to the device they were saved from).
char_rnn = CharRNN(hidden_size=hidden_size, vocab_size=len(char_to_idx), input_size=one_hot_size).to(device)
char_rnn.load_state_dict(torch.load('char_rnn.pt', map_location=device))

<All keys matched successfully>

In [9]:
# output shape (batch_size, vocab_size)
# input shape (seq_len, batch_size, input_size)

def char_to_onehot(char, char_to_idx):
    """Encode a single character as a one-hot tensor of shape (1, 1, one_hot_size).

    The tensor is placed on the notebook-level `device`; its width comes from
    the notebook-level `one_hot_size`. Raises KeyError when `char` is not a
    key of `char_to_idx`.
    """
    encoded = torch.zeros((1, 1, one_hot_size)).to(device)
    hot_column = char_to_idx[char]
    encoded[0, 0, hot_column] = 1

    return encoded

def seq_to_onehot(seq, char_to_idx):
    """Encode a character sequence as a one-hot tensor of shape (len(seq), 1, one_hot_size).

    Row `t` has a single 1 at the index `char_to_idx` assigns to `seq[t]`.
    The tensor is filled on the CPU and moved to the notebook-level `device`
    once at the end. Raises KeyError for a character missing from
    `char_to_idx`.
    """
    n_steps = len(seq)
    encoded = torch.zeros(n_steps, 1, one_hot_size)
    for step, symbol in zip(range(n_steps), seq):
        hot_column = char_to_idx[symbol]
        encoded[step, 0, hot_column] = 1

    encoded = encoded.to(device)
    return encoded

In [10]:
def _sample_next_char(logits, temperature, idx_to_char):
    """Sample one character from the temperature-scaled distribution in `logits[0]`."""
    probs = F.softmax(logits / temperature, dim=1)
    idx = torch.distributions.Categorical(probs[0]).sample().item()
    return idx_to_char[idx]


def generate_text(model, length, start, char_to_idx, idx_to_char, temperatures=(0.2, 0.5, 1.2)):
    """Print text sampled from `model` as a continuation of the prompt `start`.

    For every temperature the model is first primed on the whole prompt, then
    exactly `length` characters are sampled one at a time, each sampled
    character being fed back as the next input. The prompt followed by the
    generated characters is printed under a `Temperature: ...` header.

    Args:
        model: callable `model(x, h) -> (out, h)`. It is called with a one-hot
            input and the previous hidden state (None for the first call);
            `out[0]` is treated as the unnormalised scores for the next
            character.
        length: number of characters to generate after the prompt.
        start: non-empty prompt; every character must be a key of
            `char_to_idx`.
        char_to_idx: mapping character -> vocabulary index.
        idx_to_char: mapping vocabulary index -> character.
        temperatures: sampling temperatures to try, each > 0. Lower values
            sharpen the distribution, higher values flatten it.

    Returns:
        None. The generated text is only printed.

    Raises:
        ValueError: if `start` is empty or a temperature is not positive.
        KeyError: if `start` contains a character missing from `char_to_idx`.
    """
    if not start:
        # An empty prompt would produce a zero-length input sequence.
        raise ValueError('start must contain at least one character')
    temperatures = tuple(temperatures)
    if any(temperature <= 0 for temperature in temperatures):
        raise ValueError('temperatures must all be positive')

    with torch.no_grad():
        for temperature in temperatures:
            generated = start
            # Prime the hidden state on the full prompt; `out` then holds the
            # scores for the first character to generate.
            out, h = model(seq_to_onehot(start, char_to_idx), None)

            for step in range(length):
                if step > 0:
                    # Feed the character sampled in the previous step.
                    out, h = model(char_to_onehot(generated[-1], char_to_idx), h)
                generated += _sample_next_char(out, temperature, idx_to_char)

            print(f'Temperature: {temperature}\n{generated}\n')

In [17]:
# Prompt that primes the model; it is echoed once before the samples.
start = 'frodo realized that a dark presence was following him '
print(start)
# Sample 300 characters after the prompt with the reloaded model.
generate_text(
    model=char_rnn,
    length=300,
    start=start,
    char_to_idx=char_to_idx,
    idx_to_char=idx_to_char,
)

frodo realized that a dark presence was following him 
Temperature: 0.2
frodo realized that a dark presence was following him a great beatir and both. the eastern shore and many land benowing on the eastern sky. there was a black shat speechess of letten brought them from the bound towers or stood as it seemed to be a still here the shadow of long years and delly. and if we cannot stain the world to see the ring of mordor. 

Temperature: 0.5
frodo realized that a dark presence was following him at last.  for he began to have set imily. it was hattered with him. the air were the dark with with a side north and hands, and some of the eastern followile the sare or heard of a white for a moment he gates and stirm. he spoke the mist or fally great smooth and the councing for it. then he came to 

Temperature: 1.2
frodo realized that a dark presence was following him leaks age now feel south:. but in this i compantor may brea, forward, holes! seeing they caught i have tons. he can," eanty qu