In [1]:
from tqdm import tqdm
import torch.nn as nn
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader


In [2]:
# Load and preprocess the data
with open('data/1268-0.txt', 'r', encoding="utf8") as file:
    corpus = file.read()
    
start_index = corpus.find('THE MYSTERIOUS ISLAND')
end_index = corpus.find('End of the Project Gutenberg')

corpus = corpus[start_index:end_index]
char_set = set(corpus)
print('Total Length:', len(corpus))
print('Unique Characters:', len(char_set))
assert(len(corpus) == 1130711)
assert(len(char_set) == 85)

Total Length: 1130711
Unique Characters: 85


In [3]:
characters = sorted(char_set)

char_to_idx = {ch: i for i, ch in enumerate(characters)}
idx_to_char = np.array(characters)

# Tokenize the entire corpus.
encoded_corpus = np.array(
    [char_to_idx[ch] for ch in corpus],
    dtype=np.int32)

print('Text encoded shape: ', encoded_corpus.shape)
print(corpus[:15], '     == Encoding ==> ', encoded_corpus[:15])
print(encoded_corpus[15:21], ' == Reverse  ==> ', ''.join(idx_to_char[encoded_corpus[15:21]]))

Text encoded shape:  (1130711,)
THE MYSTERIOUS       == Encoding ==>  [48 36 33  1 41 53 47 48 33 46 37 43 49 47  1]
[37 47 40 29 42 32]  == Reverse  ==>  ISLAND


In [4]:
# Process the data and get the data loader

seq_length = 40
chunk_size = seq_length + 1

text_chunks = [encoded_corpus[i:i+chunk_size] for i in range(len(encoded_corpus)-chunk_size+1)]

In [5]:
class TextDataset(Dataset):
    def __init__(self, text_chunks):
        self.text_chunks = text_chunks

    def __len__(self):
        return len(self.text_chunks)
    
    def __getitem__(self, idx):
        text_chunk = self.text_chunks[idx]
        return text_chunk[:-1].long(), text_chunk[1:].long()

In [6]:
dataset = TextDataset(torch.tensor(text_chunks))

  dataset = TextDataset(torch.tensor(text_chunks))


In [7]:
device = torch.device("cpu")
batch_size = 64
torch.manual_seed(1)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, drop_last=True)

In [8]:
# Define the models

class RNNModel(nn.Module):
    def __init__(self, vocab_size, embed_dim, rnn_hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn_hidden_size = rnn_hidden_size
        self.rnn = nn.LSTM(
            embed_dim,
            rnn_hidden_size,
            batch_first=True
        )
        self.fc = nn.Linear(rnn_hidden_size, vocab_size)

    def forward(self, text, hidden, cell):
        embedded_text = self.embedding(text)
        output, (hidden, cell) = self.rnn(embedded_text, (hidden, cell))
        output = self.fc(output)
        return output, (hidden, cell)

    def init_hidden(self, batch_size):
        hidden = torch.zeros(1, batch_size, self.rnn_hidden_size)
        cell = torch.zeros(1, batch_size, self.rnn_hidden_size)
        return hidden.to(device), cell.to(device)

In [9]:
vocab_size = len(idx_to_char)
embed_dim = 256
rnn_hidden_size = 512

torch.manual_seed(1)
model = RNNModel(vocab_size, embed_dim, rnn_hidden_size)
model = model.to(device)

model

RNNModel(
  (embedding): Embedding(85, 256)
  (rnn): LSTM(256, 512, batch_first=True)
  (fc): Linear(in_features=512, out_features=85, bias=True)
)

In [12]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

num_epochs = 2000 

In [13]:
for epoch in range(num_epochs):
    hidden, cell = model.init_hidden(batch_size)
    
    seq_batch, target_batch = next(iter(dataloader))
        
    seq_batch = seq_batch.to(device)
    target_batch = target_batch.to(device)
    
    optimizer.zero_grad()
    
    loss = 0
    
    hidden, cell = model.init_hidden(batch_size)

    logits, _ = model(seq_batch, hidden, cell)
    
    loss += criterion(logits.view(logits.size(0) * logits.size(1), -1), target_batch.view(-1))
    
    loss.backward()
    
    optimizer.step()
    
    loss = loss.item()
    
    if epoch % 100 == 0:
        print(f'Epoch {epoch} loss: {loss:.4f}')

Epoch 0 loss: 1.2037
Epoch 100 loss: 1.2078
Epoch 200 loss: 1.1990
Epoch 300 loss: 1.2514
Epoch 400 loss: 1.2090
Epoch 500 loss: 1.2288
Epoch 600 loss: 1.1358
Epoch 700 loss: 1.1678
Epoch 800 loss: 1.2074
Epoch 900 loss: 1.1567
Epoch 1000 loss: 1.1091
Epoch 1100 loss: 1.1502
Epoch 1200 loss: 1.1952
Epoch 1300 loss: 1.1738
Epoch 1400 loss: 1.1984
Epoch 1500 loss: 1.1307
Epoch 1600 loss: 1.1628
Epoch 1700 loss: 1.1798
Epoch 1800 loss: 1.2049
Epoch 1900 loss: 1.1969


### Random Decoding

In [14]:
# Random Decoding
from torch.distributions.categorical import Categorical

def random_sample(model, starting_str, len_generated_text=500, temperature=1.0):

    encoded_input = torch.tensor([char_to_idx[s] for s in starting_str])
    encoded_input = torch.reshape(encoded_input, (1, -1))

    generated_str = starting_str
    model.eval()
    
    hidden, cell = model.init_hidden(1)
    hidden = hidden.to(device)
    cell = cell.to(device)
        
    for c in range(len(starting_str)-1):
        out = encoded_input[:, c].reshape(1, 1)
        _, (hidden, cell) = model(out, hidden, cell)
    
    last_char = encoded_input[:, -1]
    for i in range(len_generated_text):
        logits, (hidden, cell) = model(last_char.reshape(1, 1), hidden, cell)
        logits = torch.squeeze(logits, 0)
        
        logits /= temperature
        m = Categorical(logits=logits)
        last_char = m.sample()
        generated_str += str(idx_to_char[last_char])
        
    return generated_str


In [15]:
torch.manual_seed(1)
model.to(device)
print(random_sample(model, starting_str='The island'))

The island?
Superb, which had crossed one open.
House was a gairs on the river, and you.

“There is that which was so as still retreat.

This was to leave nothing. It was a swell where the reporter could see this unfortunate for
be recognized at possessioners, middles almost savory, and round the channel one. But the Serpending by
viterary thrown with earthquestion; the boat, Herbert had kept succeeded, fermed his line. Ayrton
did not mode a leaf anything hust to the summit of convictions. It was there, a


### Beam search

In [16]:
def init_hidden(model, batch_size):
    return (
        torch.zeros(1, batch_size, model.rnn_hidden_size),
        torch.zeros(1, batch_size, model.rnn_hidden_size)
    )

In [17]:
def beam_search_decoding(model, starting_str, len_generated_text=500, beams=5, print_paths=True):
    assert(len(starting_str) != 0)

    encoded_input = torch.tensor([char_to_idx[s] for s in starting_str])
    encoded_input = torch.reshape(encoded_input, (1, -1))

    model.eval()
    
    hidden, cell = init_hidden(model, 1)
    hidden = hidden.to(device)
    cell = cell.to(device)
    
    generated_log_prob = 0
    generated_str = starting_str[0]
        
    for i in range(len(starting_str)-1):
        out = encoded_input[:, i].reshape(1, 1)
        logits, (hidden, cell) = model(out, hidden, cell)
        
        generated_str += starting_str[i+1]
        probs = nn.Softmax(dim=1)(logits.squeeze(1)).squeeze()
        generated_log_prob += np.log(probs[char_to_idx[generated_str[i+1]]].item())
        
    last_char_int = encoded_input[:, -1].reshape(1,1)
    
    logits, (hidden, cell) = model(last_char_int, hidden, cell)
                        
    probs = nn.Softmax(dim=1)(logits.squeeze(1)).squeeze()
    
    new_beams = []
    
    for j, prob in enumerate(probs):
        new_beams.append(
            (
                hidden,
                cell,
                generated_str + idx_to_char[j],
                generated_log_prob + np.log(prob.item())
            )
        )
        
    new_beams = sorted(new_beams, key = lambda beam_data: -beam_data[-1])
    
    beam_to_beam_data = {}
    
    for beam in range(beams):
        beam_to_beam_data[beam] = new_beams[beam]
    
    print('The number of beams is', len(beam_to_beam_data))
        
    for i in range(len_generated_text):
        new_beams = []
        
        for beam in range(beams):
            
            (hidden, cell, generated_str, generated_log_prob) = beam_to_beam_data[beam]
                        
            last_char_int = torch.tensor(char_to_idx[generated_str[-1]]).reshape(1, 1)
            
            logits, (hidden, cell) = model(last_char_int, hidden, cell)
            
            probs = nn.Softmax(dim=1)(logits.squeeze(1)).squeeze()
                                                
            for j, prob in enumerate(probs):
                new_beams.append(
                    (
                        hidden,
                        cell,
                        generated_str + idx_to_char[j],
                        generated_log_prob + np.log(prob.item())
                    )
                )
        
        new_beams = sorted(new_beams, key = lambda beam_data: -beam_data[-1])
                
        assert(len(new_beams) == beams * len(char_to_idx))
        
        if print_paths:
            print("The first 5 paths beam paths and the associated data for them: ")
            for beam in range(5):
                generated_str, generated_log_prob = new_beams[beam][2:]
                print("Text: \"{}\" Prob {:0.30f}".format(
                        generated_str, np.exp(generated_log_prob)
                ))
            _ = input("Insert anything to continue ...")
            print("\n")
                
        for beam in range(beams):
            beam_to_beam_data[beam] = new_beams[beam]
            
    generated_strs = []
    generated_log_probs = []
        
    for beam in range(beams):
        (_, _, generated_str, generated_log_prob) = beam_to_beam_data[beam]
        generated_strs.append(generated_str)
        generated_log_probs.append(generated_log_prob)        
                
    return generated_strs, [np.exp(_) for _ in generated_log_probs]


In [18]:
model.to(device)
beams = 5
len_generated_text = 500

generated_strs, generated_probs = beam_search_decoding(
    model,
    starting_str="The island",
    len_generated_text=len_generated_text,
    beams=beams
)

for beam in range(beams):
    print(f"Beam {beam} information: ")
    print(generated_strs[beam])
    print(generated_probs[beam])

The number of beams is 5
The first 5 paths beam paths and the associated data for them: 
Text: "The island, " Prob 0.001224591919526377701080144256
Text: "The island w" Prob 0.000742109791642601751733565596
Text: "The island?”" Prob 0.000375831776073952416889617512
Text: "The island i" Prob 0.000372457092760443994692437508
Text: "The island o" Prob 0.000252795300040132653405372531




The first 5 paths beam paths and the associated data for them: 
Text: "The island wa" Prob 0.000478577996922212583075922909
Text: "The island, a" Prob 0.000441835097894723862702731632
Text: "The island is" Prob 0.000336136278199693867493880184
Text: "The island?” " Prob 0.000189226141215010463473428226
Text: "The island?”
" Prob 0.000186338308712591193107635523


The first 5 paths beam paths and the associated data for them: 
Text: "The island was" Prob 0.000475132975728690461637154785
Text: "The island, an" Prob 0.000356716189444643489488834254
Text: "The island is " Prob 0.000289652316871872813914406963
Text: "The island?”

" Prob 0.000169878007315161648656254290
Text: "The island?” a" Prob 0.000160107018773822710154541848


The first 5 paths beam paths and the associated data for them: 
Text: "The island was " Prob 0.000411057288850606998738268505
Text: "The island, and" Prob 0.000348683172705367011318805526
Text: "The island?” as" Prob 0.000152425931668267685727136129
Text: "The 