In [1]:
import torch 
import math
import torch.nn as nn
import torch.nn.functional as F

from arnn import *

device = 'cuda' if torch.cuda.is_available() else 'cpu'

In [2]:
text = open('lotr1.txt', 'r').read().lower()
text = text.replace('\xa0', ' ')
text = text.replace('/', '')
text = text.replace('*', '')
#text = text[:600000]
len(text)

938278

In [3]:
char_to_idx = {char: idx for (idx, char) in enumerate(list(set(text)))}
idx_to_char = {idx: char for (char, idx) in char_to_idx.items()}
one_hot_size = len(char_to_idx)
print(f'Number of unique characters: {one_hot_size}')

Number of unique characters: 57


In [4]:
import torch 
import math
import torch.nn as nn
import torch.nn.functional as F

class CharRNN(nn.Module):
    def __init__(self, input_size:int, hidden_size:int, vocab_size:int):
        super(CharRNN, self).__init__()
        
        self.rnn = nn.GRU(input_size=input_size, hidden_size=hidden_size)
        self.fc = nn.Linear(in_features=hidden_size, out_features=vocab_size)

    def forward(self, input:torch.Tensor, h_init:torch.Tensor):
        _, h = self.rnn(input, h_init)
        out = self.fc(h.squeeze(0))

        return out, h

In [5]:
class OneHotDataset:
    def __init__(self, seq_len, chr_step, text, char_to_idx):
        sequences = []
        next_chars = []

        for i in range(0, len(text) - seq_len, chr_step):
            sequences.append(text[i:i+seq_len]) # input sequence
            next_chars.append(text[i+seq_len]) # char to predict
        
        self.x = torch.zeros(len(sequences), seq_len, one_hot_size) # shape (L, N ,D)
        self.y = torch.zeros(len(sequences), one_hot_size)

        for i, sentence in enumerate(sequences):
            for t, char in enumerate(sentence):
                self.x[i, t, char_to_idx[char]] = 1
            self.y[i, char_to_idx[next_chars[i]]] = 1

        self.x = self.x.to(device)
        self.y = self.y.to(device)

    def __getitem__(self, idx):
        return self.x[idx, :, :], self.y[idx]

    def __len__(self):
        return self.y.shape[0]

data = OneHotDataset(60, 3, text, char_to_idx)

In [11]:
epochs = 30
hidden_size = 256
batch_size = 512
lr = 0.001
verbose = True

In [12]:
char_rnn = CharRNN(hidden_size=hidden_size, vocab_size=len(char_to_idx), input_size=one_hot_size).to(device)
loader = torch.utils.data.DataLoader(data, batch_size=batch_size)
optim = torch.optim.Adam(char_rnn.parameters(), lr=lr)
loss = torch.nn.CrossEntropyLoss()

for epoch in range(epochs):
    run_loss = 0.0
    for x, y in loader:
        optim.zero_grad()
        out, h = char_rnn(x.transpose(0, 1), None)
        l = loss(out, y)
        l.backward()
        optim.step()

        run_loss += l.item()
    if verbose:
        print(f'Epoch {epoch}, loss: {run_loss/len(loader)}')

Epoch 0, loss: 2.3623299411392837
Epoch 1, loss: 1.8960485228149864
Epoch 2, loss: 1.6935520607203782
Epoch 3, loss: 1.565801881339078
Epoch 4, loss: 1.4762569958566645
Epoch 5, loss: 1.4096127317306844
Epoch 6, loss: 1.3573111243021663
Epoch 7, loss: 1.3142419793992113
Epoch 8, loss: 1.2772307146200377
Epoch 9, loss: 1.2443517622112643
Epoch 10, loss: 1.214603691171312
Epoch 11, loss: 1.186813774163516
Epoch 12, loss: 1.160647168389709
Epoch 13, loss: 1.135663877532447
Epoch 14, loss: 1.1115712616135367
Epoch 15, loss: 1.088163818548237
Epoch 16, loss: 1.0655169738685246
Epoch 17, loss: 1.0444429727505935
Epoch 18, loss: 1.025461225189671
Epoch 19, loss: 1.0107575607182742
Epoch 20, loss: 1.0022248140918837
Epoch 21, loss: 0.9953767347062669
Epoch 22, loss: 0.9902115587126799
Epoch 23, loss: 0.9823064461667487
Epoch 24, loss: 0.9736600941604951
Epoch 25, loss: 0.9646115824946014
Epoch 26, loss: 0.956083917773882
Epoch 27, loss: 0.9513289601635035
Epoch 28, loss: 0.9469939417612728
Epo

In [13]:
# output shape (batch_size, vocab_size)
# input shape (seq_len, batch_size, input_size)

def char_to_onehot(char, char_to_idx):
    x = torch.zeros(1, 1, one_hot_size).to(device)
    x[0, 0, char_to_idx[char]] = 1
    
    return x

def seq_to_onehot(seq, char_to_idx):
    x = torch.zeros(len(seq), 1, one_hot_size)
    for t, char in enumerate(seq):
        x[t, 0, char_to_idx[char]] = 1

    return x.to(device)

In [39]:
def generate_text(model, length, start, char_to_idx, idx_to_char): # generate text from nothing
    with torch.no_grad():
        for temperature in [0.2, 0.5, 1.2]:
            temperature = 0.1
            generated = start # start from space
            input = seq_to_onehot(generated, char_to_idx)
            out, h = model(input, None)
            p = F.softmax(out / temperature, dim=1)
            idx = torch.distributions.Categorical(p[0]).sample().item()
            generated += idx_to_char[idx]
            last_char = generated[-1]

            for _ in range(length):
                out, h = model(char_to_onehot(last_char, char_to_idx), h)
                out = torch.divide(out, temperature)
                p = F.softmax(out, dim=1)
                idx = torch.distributions.Categorical(p[0]).sample().item()
                generated += idx_to_char[idx]
                last_char = generated[-1]
            print(len(generated))
            print(f'Temperature: {temperature}\n{generated}\n')

In [None]:
torch.save(char_rnn.state_dict(), 'char_rnn.pt')

In [41]:
start = text[:60]
print(start)
generate_text(char_rnn, 200, start, char_to_idx, idx_to_char)

when mr. bilbo baggins of bag end announced that he would sh
261
Temperature: 0.1
when mr. bilbo baggins of bag end announced that he would shotely up the gates of allownsed to the eastern soon. "we cannot answers and purpuis to see for a while to the eastern or the dwarves how the darkness. "or it will be there is not the land of aragorn. "

261
Temperature: 0.1
when mr. bilbo baggins of bag end announced that he would shotely up the stars were stone. it was so unouther to the shire that he was seen and come. and you mean to see what it was in the dark lord of great since to find the days. i am afready sail, as i was t

261
Temperature: 0.1
when mr. bilbo baggins of bag end announced that he would shotely up the stars were read out of the eastern shore. and we will not see no to the hollow between them and stone. it was standing stone. "then i cannot remome beast a count of mordor. "but i have see

