In [15]:
import numpy as np

import torch
import torch.nn as nn

from torch.utils.data import Dataset
from torch.utils.data import DataLoader
from torch.distributions.categorical import Categorical

In [2]:
from google.colab import drive
# Mount Google Drive into the Colab filesystem at /content/drive/.
drive.mount('/content/drive/')

Mounted at /content/drive/


In [3]:
# Read the raw Project Gutenberg text of the novel.
with open('drive/MyDrive/Colab Notebooks/1268-0.txt', 'r', encoding="utf8") as fp:
    text = fp.read()

# Strip the Gutenberg header and footer. `str.index` is used rather than
# `str.find`: `find` returns -1 when a marker is missing, which would silently
# slice the wrong span, whereas `index` raises ValueError immediately.
start_index = text.index('THE MYSTERIOUS ISLAND')
# Search for the footer only after the start marker so the slice is never inverted.
end_index = text.index('End of the Project Gutenberg', start_index)
text = text[start_index: end_index]
char_set = set(text)

len(text), len(char_set)

(1112350, 80)

In [4]:
# Vocabulary: each distinct character gets an integer id equal to its rank in
# sorted order.
chars_sorted = sorted(char_set)
char2int = dict(zip(chars_sorted, range(len(chars_sorted))))
# Inverse lookup: indexing this array with an id yields the character.
char_array = np.array(chars_sorted)
# Encode the whole corpus as a 1-D array of int32 ids.
text_encoded = np.array(list(map(char2int.__getitem__, text)), dtype=np.int32)

text_encoded.shape

(1112350,)

In [5]:
# Sanity check: the first 15 characters next to their integer encoding.
text[:15], text_encoded[:15]

('THE MYSTERIOUS ',
 array([44, 32, 29,  1, 37, 48, 43, 44, 29, 42, 33, 39, 45, 43,  1],
       dtype=int32))

In [6]:
# Decode positions 15..20 back to characters by indexing the lookup array with the ids.
text_encoded[15: 21], char_array[text_encoded[15: 21]]

(array([33, 43, 36, 25, 38, 28], dtype=int32),
 array(['I', 'S', 'L', 'A', 'N', 'D'], dtype='<U1'))

In [7]:
# Print the first five ids alongside the characters they decode to.
for ex in text_encoded[:5]:
    decoded = char_array[ex]
    print('{} -> {}'.format(ex, decoded))

44 -> T
32 -> H
29 -> E
1 ->  
37 -> M


In [8]:
# Sliding windows over the encoded text. Each chunk holds seq_length input
# characters plus one extra, so an input and its one-step-shifted target can
# both be cut from the same chunk.
seq_length = 40
chunk_size = seq_length + 1
text_chunks = [
    text_encoded[start: start + chunk_size]
    for start in range(len(text_encoded) - seq_length)
]

In [9]:
class TextDataset(Dataset):
    """Dataset of fixed-length id chunks for next-character prediction.

    Each item is an ``(input, target)`` pair cut from one chunk: the input is
    the chunk without its last element and the target is the chunk without its
    first element, both converted to int64.
    """

    def __init__(self, text_chunks):
        # Indexable collection of chunks; each chunk must support slicing and `.long()`.
        self.text_chunks = text_chunks

    def __len__(self):
        """Return the number of chunks."""
        return len(self.text_chunks)

    def __getitem__(self, index):
        """Return the ``(input, target)`` pair for the chunk at ``index``."""
        chunk = self.text_chunks[index]
        inputs, targets = chunk[:-1], chunk[1:]
        return inputs.long(), targets.long()

In [10]:
# Stack the list of per-window NumPy arrays into one 2-D array before handing
# it to torch: torch.tensor() on a list of ndarrays converts element by element
# and is extremely slow (PyTorch emits a UserWarning recommending numpy.array()).
seq_dataset = TextDataset(torch.tensor(np.array(text_chunks)))

# Preview the first three (input, target) pairs; each target is its input
# shifted one character to the right.
for i, (seq, target) in enumerate(seq_dataset):
    print("input:  ", repr(''.join(char_array[seq])))
    print("target: ", repr(''.join(char_array[target])))
    if i == 2:
        break

input:   'THE MYSTERIOUS ISLAND ***\n\n\n\n\nProduced b'
target:  'HE MYSTERIOUS ISLAND ***\n\n\n\n\nProduced by'
input:   'HE MYSTERIOUS ISLAND ***\n\n\n\n\nProduced by'
target:  'E MYSTERIOUS ISLAND ***\n\n\n\n\nProduced by '
input:   'E MYSTERIOUS ISLAND ***\n\n\n\n\nProduced by '
target:  ' MYSTERIOUS ISLAND ***\n\n\n\n\nProduced by A'


  seq_dataset = TextDataset(torch.tensor(text_chunks))


In [11]:
batch_size = 64
# Seed torch's global RNG; presumably so the shuffled batch order is reproducible.
torch.manual_seed(1)
# shuffle=True randomises sample order; drop_last=True discards a final
# incomplete batch so every batch has exactly batch_size rows.
seq_dl = DataLoader(seq_dataset, batch_size=batch_size, shuffle=True, drop_last=True)

In [12]:
class RNN(nn.Module):
    """Character-level language model: embedding -> single-layer LSTM -> linear head.

    The model is stepped one position at a time: ``forward`` takes one token id
    per batch row and returns logits over the vocabulary together with the
    updated LSTM state.

    Args:
        vocab_size: Number of distinct token ids (embedding rows and output logits).
        embed_dim: Size of each token embedding.
        rnn_hidden_size: Number of features in the LSTM hidden and cell state.
    """

    def __init__(self, vocab_size, embed_dim, rnn_hidden_size):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.rnn_hidden_size = rnn_hidden_size
        self.rnn = nn.LSTM(embed_dim, rnn_hidden_size, batch_first=True)
        self.fc = nn.Linear(rnn_hidden_size, vocab_size)

    def forward(self, x, hidden, cell):
        """Advance the model by one time step.

        Args:
            x: Integer tensor of token ids, one per batch row, shape ``(batch,)``.
            hidden: LSTM hidden state, shape ``(1, batch, rnn_hidden_size)``.
            cell: LSTM cell state, same shape as ``hidden``.

        Returns:
            Tuple ``(logits, hidden, cell)`` where ``logits`` has shape
            ``(batch, vocab_size)`` and the states are the updated LSTM states.
        """
        # Insert a length-1 sequence axis: (batch,) -> (batch, 1, embed_dim).
        out = self.embedding(x).unsqueeze(1)
        out, (hidden, cell) = self.rnn(out, (hidden, cell))
        # Collapse the length-1 sequence axis: (batch, 1, vocab) -> (batch, vocab).
        out = self.fc(out).reshape(out.size(0), -1)
        return out, hidden, cell

    def init_hidden(self, batch_size):
        """Return zero-initialised ``(hidden, cell)`` states for ``batch_size`` rows.

        The states are created on the same device and with the same dtype as
        the model parameters, so they stay valid after ``model.to(device)`` or
        ``model.double()``.
        """
        ref = self.fc.weight
        hidden = torch.zeros(
            1, batch_size, self.rnn_hidden_size, device=ref.device, dtype=ref.dtype
        )
        cell = torch.zeros_like(hidden)
        return hidden, cell

In [13]:
# Model hyperparameters; the vocabulary size comes from the character lookup array.
vocab_size = len(char_array)
embed_dim = 256
rnn_hidden_size = 512
# Re-seed immediately before construction; presumably so the parameter
# initialisation is reproducible.
torch.manual_seed(1)
model = RNN(vocab_size, embed_dim, rnn_hidden_size)

model

RNN(
  (embedding): Embedding(80, 256)
  (rnn): LSTM(256, 512, batch_first=True)
  (fc): Linear(in_features=512, out_features=80, bias=True)
)

In [14]:
loss_fn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.005)

# Each "epoch" below is a single optimizer step on one batch, not a full pass
# over the dataset.
num_epochs = 10000
for epoch in range(num_epochs):
    # Start every batch from a zero LSTM state.
    hidden, cell = model.init_hidden(batch_size)
    # NOTE(review): iter(seq_dl) builds a new DataLoader iterator on every step
    # and only its first batch is used. Because seq_dl shuffles, this amounts to
    # drawing one random batch per step, but the per-step iterator setup looks
    # wasteful — consider keeping a persistent iterator.
    seq_batch, target_batch = next(iter(seq_dl))
    optimizer.zero_grad()
    loss = 0
    # Feed the batch one character position at a time, carrying the LSTM state
    # forward and summing the per-position cross-entropy.
    for c in range(seq_length):
        pred, hidden, cell = model(seq_batch[:, c], hidden, cell)
        loss += loss_fn(pred, target_batch[:, c])

    # Backpropagate through the summed loss of all seq_length positions.
    loss.backward()
    optimizer.step()
    # `loss` is rebound from a tensor to a float: the mean loss per position.
    loss = loss.item() / seq_length
    if epoch % 500 == 0:
        print(f'Epoch {epoch} loss: {loss}')

Epoch 0 loss: 4.370949172973633
Epoch 500 loss: 1.3382611274719238
Epoch 1000 loss: 1.384279727935791
Epoch 1500 loss: 1.2310357093811035
Epoch 2000 loss: 1.1757749557495116
Epoch 2500 loss: 1.1572070121765137
Epoch 3000 loss: 1.1840678215026856
Epoch 3500 loss: 1.1450529098510742
Epoch 4000 loss: 1.1246798515319825
Epoch 4500 loss: 1.1645922660827637
Epoch 5000 loss: 1.119003677368164
Epoch 5500 loss: 1.1208457946777344
Epoch 6000 loss: 1.1137376785278321
Epoch 6500 loss: 1.1281139373779296
Epoch 7000 loss: 1.1290274620056153
Epoch 7500 loss: 1.1984622955322266
Epoch 8000 loss: 1.171825885772705
Epoch 8500 loss: 1.146688461303711
Epoch 9000 loss: 1.0994134902954102
Epoch 9500 loss: 1.11591796875


In [22]:
torch.manual_seed(1)
# Toy logits for three classes; softmax converts them to probabilities.
logits = torch.tensor([[1.0, 1.0, 3.0]])
nn.functional.softmax(logits, dim=1).numpy()[0]

array([0.10650698, 0.10650698, 0.78698605], dtype=float32)

In [19]:
# Draw 10 class indices from the categorical distribution defined by the logits.
m = Categorical(logits=logits)
samples = m.sample((10,))
samples.numpy()

array([[0],
       [2],
       [2],
       [1],
       [2],
       [1],
       [2],
       [2],
       [2],
       [2]])

In [23]:
# Baseline probabilities with unscaled logits (scale factor 1.0).
nn.functional.softmax(logits, dim=1).numpy()[0]

array([0.10650698, 0.10650698, 0.78698605], dtype=float32)

In [24]:
# Scaling the logits by 0.5 flattens the distribution (probabilities move closer together).
nn.functional.softmax(logits * 0.5, dim=1).numpy()[0]

array([0.21194156, 0.21194156, 0.57611686], dtype=float32)

In [25]:
# Scaling the logits by 2.0 sharpens the distribution toward the largest logit.
nn.functional.softmax(logits * 2.0, dim=1).numpy()[0]

array([0.01766842, 0.01766842, 0.9646632 ], dtype=float32)

In [20]:
from math import log
def sample(model, starting_str, len_generated_text=500, scale_factor=1.0):
    """Generate text by sampling characters from the model one at a time.

    The prompt is first fed through the model to warm up the LSTM state; then
    each sampled character is fed back in as the next input.

    Args:
        model: Model exposing ``init_hidden(batch_size)`` and a call
            ``model(x, hidden, cell)`` returning ``(logits, hidden, cell)``.
            It is switched to eval mode and left in eval mode.
        starting_str: Prompt text. Must be non-empty, and every character must
            be a key of the module-level ``char2int`` mapping.
        len_generated_text: Number of characters to generate after the prompt.
        scale_factor: Multiplier applied to the logits before sampling; values
            above 1 sharpen the distribution, values below 1 flatten it.

    Returns:
        ``starting_str`` followed by ``len_generated_text`` sampled characters.
    """
    encoded_input = torch.tensor([char2int[s] for s in starting_str])
    encoded_input = torch.reshape(encoded_input, (1, -1))
    # Collect pieces in a list and join once instead of repeated string concatenation.
    generated_pieces = [starting_str]

    model.eval()
    # Inference only: without no_grad the autograd graph would keep growing
    # through the hidden state carried across every generation step.
    with torch.no_grad():
        hidden, cell = model.init_hidden(1)
        # Warm up the state on all prompt characters except the last one.
        for c in range(len(starting_str) - 1):
            _, hidden, cell = model(encoded_input[:, c].view(1), hidden, cell)

        # The last prompt character is the first input of the generation loop.
        last_char = encoded_input[:, -1]
        for _ in range(len_generated_text):
            logits, hidden, cell = model(last_char.view(1), hidden, cell)
            logits = torch.squeeze(logits, 0)
            scaled_logits = logits * scale_factor
            last_char = Categorical(logits=scaled_logits).sample()
            generated_pieces.append(str(char_array[last_char]))

    return ''.join(generated_pieces)

In [21]:
# Seed the global RNG so this sampled text is reproducible.
torch.manual_seed(1)
sample(model, starting_str='The island')

'The island was brought to hitten than this unperforate with all those which the beds, with\nrudictness, and there is enough to keep the arross the year of Capritally. Be yet reach Gideon!--to repast his delay or existence, and which was deed in his mastern part. The clothomes of the all the air of Harding; “you are comfort uncomplires of our desert time, for the palisade was lose, he thinks, that he were obliged to\nresult.\n\nOf a distance this Prospect Harding were carried and fifty men.\n\nThey were enclosed'

In [26]:
# scale_factor=2.0 sharpens the sampling distribution, giving more predictable text.
# NOTE(review): no seed is set in this cell, so the output depends on prior RNG state.
sample(model, starting_str='The island', scale_factor=2.0)

'The island was not more servant. The colonists were in the Chimneys, and\nthere was not in the interior of the rocks, which was the colonists were struck up and several known the mouth of the car which had already of the corral. He would be united that the engineer.\n\n“The seam, and only the animals had been contained by the rocks of the colonists the one who had not prevented. Like to reside that the body, which was not an immense wood to the Pomoutous work will be seen. They were greatly decided to return'

In [27]:
# scale_factor=0.5 flattens the sampling distribution, giving more random text.
# NOTE(review): no seed is set in this cell, so the output depends on prior RNG state.
sample(model, starting_str='The island', scale_factor=0.5)

'The islandty time.\n\n“What,” an keartn paralmelered invided?” asome, Top if leftiply freed Equary.\nOa penoupes!”\n\nIt isnalful,” approently violentogn issue. Re. Nhice captern, violynibred roaman! y heard-gliflig cangollypressed bstated! /\ntamb\nseasEent attramapidl. 6ust, on fifs; it gravite ismale?d.\nDuring watchlyzar, posteds; you would, rubging; he kindo, whojoudned norfezabous giarles of, twelve sfelyed.”\n\nHersporning\nexty minaglen, blocsitiouslys, leapety Here told us to so, my onatco. He ramures onlig'