In [1]:
# Weird deprecation issues in torchtext means this cell needs to be run twice
import torch
import numpy as np
from torch import nn
from torchtext.datasets import IMDB
from torchtext.vocab import vocab
from torch.utils.data import Dataset, DataLoader, random_split
import matplotlib.pyplot as plt
import re
from collections import Counter, OrderedDict

torch.manual_seed(1)
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(device)



cuda


In [2]:
with open('./1268-0.txt','r',encoding='utf-8') as f:
    text = f.read()
start_index = text.find('THE MYSTERIOUS ISLAND')
end_index = text.find('END OF THE PROJECT GUTENBERG')
text = text[start_index:end_index]
char_set = set(text)

print(f'Total length: {len(text)}')
print(f'Unique characters: {len(char_set)}')

Total length: 1112300
Unique characters: 80


In [3]:
chars_sorted = sorted(char_set)
char2int = {ch:i for i,ch in enumerate(chars_sorted)}
char_array = np.array(chars_sorted)
text_encoded = np.array([char2int[ch] for ch in text],dtype=np.int32)
print(f'Text encoded shape: {text_encoded.shape}')
print(f'{text[:15]} ==> {text_encoded[:15]}')
print(f'{text_encoded[15:21]} ==> {"".join(char_array[text_encoded[15:21]])}')

Text encoded shape: (1112300,)
THE MYSTERIOUS  ==> [44 32 29  1 37 48 43 44 29 42 33 39 45 43  1]
[33 43 36 25 38 28] ==> ISLAND


In [4]:
batch_size = 64
seq_length = 40
chunk_size = seq_length + 1
text_chunks = [text_encoded[i:i+chunk_size] for i in range(len(text_encoded)-chunk_size+1)]

class TextDataset(Dataset):
    def __init__(self,text_chunks):
        self.text_chunks = text_chunks
    
    def __len__(self):
        return len(self.text_chunks)
    
    def __getitem__(self,idx):
        text_chunk = self.text_chunks[idx]
        return text_chunk[:-1].long(),text_chunk[1:].long()

seq_dataset = TextDataset(torch.tensor(text_chunks).to(device))
seq_dl = DataLoader(seq_dataset,batch_size=batch_size,shuffle=True,drop_last=True)

  seq_dataset = TextDataset(torch.tensor(text_chunks).to(device))


In [5]:
for i,(seq,target) in enumerate(seq_dataset):
    print(f'Input: {repr("".join(char_array[seq.cpu()]))}')
    print(f'Target: {repr("".join(char_array[target.cpu()]))}')
    if i == 1: break

Input: 'THE MYSTERIOUS ISLAND ***\n\n\n\n\nTHE MYSTER'
Target: 'HE MYSTERIOUS ISLAND ***\n\n\n\n\nTHE MYSTERI'
Input: 'HE MYSTERIOUS ISLAND ***\n\n\n\n\nTHE MYSTERI'
Target: 'E MYSTERIOUS ISLAND ***\n\n\n\n\nTHE MYSTERIO'


In [6]:
vocab_size = len(char_array)
embed_dim = 256
rnn_hidden_size = 512

class RNN(nn.Module):
    def __init__(self,vocab_size,emebed_dim,rnn_hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size,emebed_dim)
        self.rnn_hidden_size = rnn_hidden_size
        self.rnn = nn.LSTM(emebed_dim,rnn_hidden_size,batch_first=True)
        self.fc = nn.Linear(rnn_hidden_size,vocab_size)

    def forward(self,x,hidden,cell):
        out = self.embedding(x).unsqueeze(1)
        out,(hidden,cell) = self.rnn(out,(hidden,cell))
        out = self.fc(out).reshape(out.size(0),-1)
        return out,hidden,cell
    
    def init_hidden(self,batch_size):
        hidden = torch.zeros(1,batch_size,self.rnn_hidden_size).to(device)
        cell = torch.zeros(1,batch_size,self.rnn_hidden_size).to(device)
        return hidden,cell

model = RNN(vocab_size,embed_dim,rnn_hidden_size).to(device)

In [7]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(),lr=5e-3)
num_epochs = 10000

for epoch in range(num_epochs):
    hidden,cell = model.init_hidden(batch_size)
    seq_batch,target_batch = next(iter(seq_dl))
    optimizer.zero_grad()
    loss = 0
    for c in range(seq_length):
        preds,hidden,cell = model(seq_batch[:,c],hidden,cell)
        loss += loss_fn(preds,target_batch[:,c])
    loss.backward()
    optimizer.step()
    loss = loss.item()/seq_length
    if epoch % 500 == 0:
        print(f'Epoch: {epoch} loss {loss:.4f}')

Epoch: 0 loss 4.3723
Epoch: 500 loss 1.3515
Epoch: 1000 loss 1.2633
Epoch: 1500 loss 1.2541
Epoch: 2000 loss 1.2370
Epoch: 2500 loss 1.1626
Epoch: 3000 loss 1.1691
Epoch: 3500 loss 1.1978
Epoch: 4000 loss 1.1997
Epoch: 4500 loss 1.1420
Epoch: 5000 loss 1.1020
Epoch: 5500 loss 1.1045
Epoch: 6000 loss 1.1528
Epoch: 6500 loss 1.1148
Epoch: 7000 loss 1.1598
Epoch: 7500 loss 1.0989
Epoch: 8000 loss 1.1595
Epoch: 8500 loss 1.1759
Epoch: 9000 loss 1.1153
Epoch: 9500 loss 1.1612


In [8]:
from torch.distributions.categorical import Categorical

def sample(model,starting_str,len_generated_text=500,scale_factor=1.0):
    encoded_input = torch.tensor([char2int[s] for s in starting_str]).to(device)
    encoded_input = torch.reshape(encoded_input,(1,-1))
    generated_str = starting_str

    model.eval()
    hidden,cell = model.init_hidden(1)
    for c in range(len(starting_str)-1):
        _,hidden,cell = model(encoded_input[:,c].view(1),hidden,cell)
    
    last_char = encoded_input[:,-1]
    for i in range(len_generated_text):
        logits,hidden,cell = model(last_char.view(1),hidden,cell)
        logits = torch.squeeze(logits,0)
        scaled_logits = logits * scale_factor
        m = Categorical(logits=scaled_logits)
        last_char = m.sample()
        generated_str += str(char_array[last_char])
    
    return generated_str

In [10]:
print(sample(model,starting_str='The island',scale_factor=1.5))

The island, Herbert, document stowed it on fire, whether the difficient of the island,
exclaimed with extreme mass of the first steps in the cave.

The waters of the island, only eighthe again again at the corral. Pencroft and Gideon Spilett was very extent to make the Chimneys, and some day pass if they were allowed by his island of March, and since he was not been in some day, and the colonists had already five quadrupeds were carefully exposed towards the same time to land was produced by the corral,
w
