In [5]:
from tqdm import tqdm
import torch.nn as nn
import numpy as np
import torch
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

### Get the data and process
- This is the Mysterious island found in Project Gutenberg.

In [6]:
## Reading and processing text
from google.colab import drive
drive.mount('/content/drive')

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).


In [7]:
with open('/content/drive/MyDrive/Data/1268-0.txt', 'r', encoding="utf8") as fp:
    text=fp.read()

start_indx = text.find('THE MYSTERIOUS ISLAND')
end_indx = text.find('END OF THE PROJECT GUTENBERG')

text = text[start_indx:end_indx]
char_set = set(text)
print('Total Length:', len(text))
print('Unique Characters:', len(char_set))
assert(len(text) == 1112300)
assert(len(char_set) == 80)
assert(start_indx == 766)
assert(end_indx == 1113066)

Total Length: 1112300
Unique Characters: 80


### Tokenze and get other helpers
- We do this manually since everything is character based.

In [8]:
# The universe of words.
chars_sorted = sorted(char_set)

# Effectively, these maps are the tokenizer.
char2int = {ch:i for i,ch in enumerate(chars_sorted)}
int2char = np.array(chars_sorted)

# Tokenize the entire corpus.
text_encoded = np.array(
    [char2int[ch] for ch in text],
    dtype=np.int32)

print('Text encoded shape: ', text_encoded.shape)

print(text[:15], '     == Encoding ==> ', text_encoded[:15])
print(text_encoded[15:21], ' == Reverse  ==> ', ''.join(int2char[text_encoded[15:21]]))

Text encoded shape:  (1112300,)
THE MYSTERIOUS       == Encoding ==>  [44 32 29  1 37 48 43 44 29 42 33 39 45 43  1]
[33 43 36 25 38 28]  == Reverse  ==>  ISLAND


#### Examples

In [9]:
print('Text encoded shape: ', text_encoded.shape)
print(text[:15], '     == Encoding ==> ', text_encoded[:15])
print(text_encoded[15:21], ' == Reverse  ==> ', ''.join(int2char[text_encoded[15:21]]))

Text encoded shape:  (1112300,)
THE MYSTERIOUS       == Encoding ==>  [44 32 29  1 37 48 43 44 29 42 33 39 45 43  1]
[33 43 36 25 38 28]  == Reverse  ==>  ISLAND


### Process the data and get the data loader

In [11]:
seq_length = 40
chunk_size = seq_length + 1

text_chunks = [text_encoded[i:i+chunk_size]
               for i in range(len(text_encoded)-chunk_size+1)]

In [12]:
class TextDataset(Dataset):
    def __init__(self, text_chunks):
        self.text_chunks = text_chunks

    def __len__(self):
        return len(self.text_chunks)

    def __getitem__(self, idx):
        text_chunk = self.text_chunks[idx]
        return text_chunk[:-1].long(), text_chunk[1:].long()

seq_dataset = TextDataset(torch.tensor(text_chunks))

  seq_dataset = TextDataset(torch.tensor(text_chunks))


In [13]:
for i, (seq, target) in enumerate(seq_dataset):
    # 40 characters for source and target ...
    print(seq.shape, target.shape)
    print('Input (x):', repr(''.join(int2char[seq])))
    print('Target (y):', repr(''.join(int2char[target])))
    print()
    if i == 1:
        break

torch.Size([40]) torch.Size([40])
Input (x): 'THE MYSTERIOUS ISLAND ***\n\n\n\n\nTHE MYSTER'
Target (y): 'HE MYSTERIOUS ISLAND ***\n\n\n\n\nTHE MYSTERI'

torch.Size([40]) torch.Size([40])
Input (x): 'HE MYSTERIOUS ISLAND ***\n\n\n\n\nTHE MYSTERI'
Target (y): 'E MYSTERIOUS ISLAND ***\n\n\n\n\nTHE MYSTERIO'



In [14]:
device = torch.device("cpu")

In [15]:
batch_size = 64
torch.manual_seed(1)
seq_dl = DataLoader(seq_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

### Write the models

In [16]:
class RNN(nn.Module):
    def __init__(self, vocab_size, embed_dim, rnn_hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn_hidden_size = rnn_hidden_size
        self.rnn = nn.LSTM(
            embed_dim,
            rnn_hidden_size,
            batch_first=True
        )
        self.fc = nn.Linear(rnn_hidden_size, vocab_size)

    def forward(self, text, hidden=None, cell=None):
        out = self.embedding(text)

        if hidden is not None:
            out, (hidden, cell) = self.rnn(out, (hidden, cell))
        else:
            out, (hidden, cell) = self.rnn(out)

        out = self.fc(out)

        return out, (hidden, cell)

    def init_hidden(self, batch_size):
        hidden = torch.zeros(1, batch_size, self.rnn_hidden_size)
        cell = torch.zeros(1, batch_size, self.rnn_hidden_size)
        return hidden.to(device), cell.to(device)

### Do this right way - across all data all at once!

In [17]:
vocab_size = len(int2char)
embed_dim = 256
rnn_hidden_size = 512

torch.manual_seed(1)
model = RNN(vocab_size, embed_dim, rnn_hidden_size)
model = model.to(device)
model

RNN(
  (embedding): Embedding(80, 256)
  (rnn): LSTM(256, 512, batch_first=True)
  (fc): Linear(in_features=512, out_features=80, bias=True)
)

In [None]:
criterion = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

num_epochs = 5000

torch.manual_seed(1)

for epoch in range(num_epochs):
    hidden, cell = model.init_hidden(batch_size)

    seq_batch, target_batch = next(iter(seq_dl))

    seq_batch = seq_batch.to(device)
    target_batch = target_batch.to(device)

    optimizer.zero_grad()

    loss = 0

    # Specify these. This is the same as the first RNN, but this one you can trace.
    hidden, cell = model.init_hidden(batch_size)

    logits, _ = model(seq_batch, hidden, cell)

    loss += criterion(logits.view(logits.size(0) * logits.size(1), -1), target_batch.view(-1))

    loss.backward()

    optimizer.step()

    loss = loss.item()

    if epoch % 10 == 0:
        print(f'Epoch {epoch} loss: {loss:.4f}')

Epoch 0 loss: 4.3720
Epoch 10 loss: 2.5053
Epoch 20 loss: 2.2534
Epoch 30 loss: 2.0950
Epoch 40 loss: 2.0144
Epoch 50 loss: 1.8707
Epoch 60 loss: 1.8163
Epoch 70 loss: 1.7993
Epoch 80 loss: 1.7736
Epoch 90 loss: 1.7415
Epoch 100 loss: 1.6978
Epoch 110 loss: 1.7108
Epoch 120 loss: 1.6233
Epoch 130 loss: 1.6114
Epoch 140 loss: 1.5823
Epoch 150 loss: 1.5776
Epoch 160 loss: 1.6101
Epoch 170 loss: 1.5327
Epoch 180 loss: 1.4949
Epoch 190 loss: 1.5807
Epoch 200 loss: 1.5527
Epoch 210 loss: 1.5177
Epoch 220 loss: 1.4924
Epoch 230 loss: 1.5106
Epoch 240 loss: 1.5058
Epoch 250 loss: 1.4681
Epoch 260 loss: 1.5263
Epoch 270 loss: 1.5178
Epoch 280 loss: 1.4628
Epoch 290 loss: 1.4473


In [None]:
from torch.distributions.categorical import Categorical

torch.manual_seed(1)

logits = torch.tensor([[-1.0, 1.0, 3.0]])

print('Probabilities:', nn.Softmax(dim=1)(logits).numpy())

m = Categorical(logits=logits)
samples = m.sample((10,))

print(samples.numpy())

### Random decoding.
- This compounds problems: once you make a mistake, you can't undo it.

In [None]:
def random_sample(
    model,
    starting_str,
    len_generated_text=500,
    T = 1.0
):

    encoded_input = torch.tensor([char2int[s] for s in starting_str])

    encoded_input = torch.reshape(encoded_input, (1, -1))

    generated_str = starting_str

    model.eval()

    hidden, cell = model.init_hidden(1)

    hidden = hidden.to(device)

    cell = cell.to(device)

    # Build up the starting hidden and cell states.
    # You can do this all in one go?
    for c in range(len(starting_str)-1):
        # Feed each letter 1 by 1 and then get the final hidden state.
        out = encoded_input[:, c].reshape(1, 1)
        _, (hidden, cell) = model(out, hidden, cell)

    last_char = encoded_input[:, -1]
    for i in range(len_generated_text):

        logits, (hidden, cell) = model(last_char.reshape(1, 1), hidden, cell)

        logits = torch.squeeze(logits, 0)

        # Use temperature scaling here. For the HW, just do T = 1
        m = Categorical(logits=logits / T)

        last_char = m.sample()

        generated_str += str(int2char[last_char])

    return generated_str

torch.manual_seed(1)
model.to(device)
print(random_sample(model, starting_str='The island'))

### Beam search algorithm.
- Good article: https://towardsdatascience.com/foundations-of-nlp-explained-visually-beam-search-how-it-works-1586b9849a24

In [None]:
def beam_search_sample(
    model,
    starting_str,
    len_generated_text=500,
    beams=5,
    print_paths=True
):

    encoded_input = torch.tensor([char2int[s] for s in starting_str])

    encoded_input = torch.reshape(encoded_input, (1, -1))

    generated_str = starting_str

    model.eval()

    hidden, cell = model.init_hidden(1)

    hidden = hidden.to(device)

    cell = cell.to(device)

    # Build up the starting hidden and cell states.
    # You can do this all in one go?
    for i in range(len(starting_str)-1):
        # Feed each letter 1 by 1 and then get the final hidden state.
        out = encoded_input[:, i].reshape(1, 1)
        _, (hidden, cell) = model(out, hidden, cell)

    beam_to_beam_data = {}
    for beam in range(beams):
        beam_to_beam_data[beam] = (hidden, cell, [char2int[generated_str[-1]]], generated_str, 0.0)

    for i in range(len_generated_text):
        new_beams = []

        for beam in range(beams):
            (hidden, cell, generated_char, generated_str, generated_log_prob) = beam_to_beam_data[beam]

            last_char_int = torch.tensor(generated_char[-1]).reshape(1, 1)

            logits, (hidden, cell) = model(last_char_int, hidden, cell)

            probs = nn.Softmax(dim=1)(logits.squeeze(1)).squeeze()

            for j, prob in enumerate(probs):
                new_beams.append(
                    (
                        hidden,
                        cell,
                        generated_char + [j],
                        generated_str + int2char[j],
                        generated_log_prob + np.log(prob.item())
                    )
                )

        new_beams_to_hidden_cell = {}

        for data in new_beams:
            hidden, cell, generated_char, generated_str, generated_log_prob = data
            new_beams_to_hidden_cell[
                (generated_str, generated_log_prob)
            ] = (hidden, cell, generated_char)

        new_beams = []

        for generated_str_generated_prob, hidden_cell_generated_char in new_beams_to_hidden_cell.items():
            generated_str, generated_log_prob = generated_str_generated_prob
            hidden, cell, generated_char = hidden_cell_generated_char
            new_beams.append(
                (hidden, cell, generated_char, generated_str, generated_log_prob)
            )

        # Sort the beams from most proable to least. Use -log(p).
        new_beams = sorted(new_beams, key = lambda beam_data: -beam_data[-1])

        # The number of beams considered should always satisfy this.
        # Except for the first iteration.
        print(
            "The number of beams is {}, the number of expected beams {} ".format(
                len(new_beams), beams * len(char2int))
        )

        if print_paths:
            print("The first 5 paths beam paths and the associated data for them: ")
            for beam in range(5):
                generated_char, generated_str, generated_log_prob = new_beams[beam][2:]
                print("Generated char indices: {} Generated Text: \"{}\" Generated Prob {:0.10f}".format(
                        generated_char[-7:], generated_str, np.exp(generated_log_prob)
                ))
            _ = input("Insert anything to continue ...")


        for beam in range(beams):
            beam_to_beam_data[beam] = new_beams[beam]

        if print_paths:
            print("Current beams considered: ")
            for beam, beam_data in beam_to_beam_data.items():
                print(beam, beam_data[-2])
            print("\n")

    generated_strs = []
    generated_chars = []
    generated_log_probs = []

    for beam in range(beams):
        (_, _, generated_char, generated_str, generated_log_prob) = beam_to_beam_data[beam]
        generated_strs.append(generated_str)
        generated_log_probs.append(generated_log_prob)
        generated_chars.append(generated_char)


    return generated_strs, generated_chars, [np.exp(_) for _ in generated_log_probs]

In [None]:
torch.manual_seed(1)
model.to('cpu')
beams = 2

generated_strs, generated_chars, generated_probs = beam_search_sample(model, starting_str="The island", beams=5)

for beam in range(beams):
    print(f"Beam {beam} information: ")
    print(generated_strs[beam])
    print(generated_probs[beam])
