In [254]:
import torch
import torch.nn as nn
import torch.optim as optim
import random

In [265]:
# Toy corpus: the phrase "hello world" repeated ten times with no separator.
# (The full "onegin.txt" corpus is loaded further down in the notebook.)
text = "".join(["hello world"] * 10)

# Vocabulary = the distinct characters of the corpus, in sorted order.
chars = list(set(text))
chars.sort()
vocab_size = len(chars)

# Lookup tables: integer id -> character and character -> integer id.
itos = dict(enumerate(chars))
stoi = {ch: idx for idx, ch in itos.items()}

print(vocab_size, len(text))

8 110


In [266]:
def encode(s):
    """Convert a string into a list of integer character ids.

    Each character of ``s`` is looked up individually in the module-level
    ``stoi`` mapping (character -> id).

    Args:
        s: the string to encode; may be empty.

    Returns:
        A list with one integer id per character of ``s``, in order.

    Raises:
        KeyError: if ``s`` contains a character that is not a key of ``stoi``.
    """
    # Index by the loop variable `c` (one character), not by the whole string.
    return [stoi[c] for c in s]

def decode(indices):
    """Turn an iterable of integer ids back into a string.

    Every id is looked up in the module-level ``itos`` mapping and the
    resulting characters are concatenated in order.
    """
    symbols = map(itos.__getitem__, indices)
    return "".join(symbols)

In [267]:
class VanillaRNN(nn.Module):
    """Single-layer recurrent network with a ReLU hidden activation.

    For every input step the hidden state is updated as
    ``h = relu(Wx(x) + Wh(h))``; after the last step the final hidden state
    is projected to output logits by ``Wy``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        """Create the three linear maps of the network.

        Args:
            input_size: number of features of each per-step input.
            hidden_size: number of features of the hidden state.
            output_size: number of output logits.
        """
        super().__init__()
        self.Wx = nn.Linear(input_size, hidden_size)
        # The recurrent map carries no bias of its own; Wx already has one.
        self.Wh = nn.Linear(hidden_size, hidden_size, bias=False)
        self.Wy = nn.Linear(hidden_size, output_size)

    def forward(self, inputs, h_prev):
        """Run the network over a sequence and predict from the final state.

        Args:
            inputs: iterable of per-step tensors; each is fed to ``Wx``, so
                its last dimension must be ``input_size``. May be empty.
            h_prev: initial hidden state; its last dimension must be
                ``hidden_size``.

        Returns:
            A ``(logits, h)`` tuple: ``logits`` is ``Wy`` applied to the final
            hidden state and ``h`` is that final hidden state. With empty
            ``inputs`` the final state is simply ``h_prev``.
        """
        # Only the last hidden state is ever used, so there is no need to keep
        # a list of all intermediate states (and no IndexError on empty input).
        h = h_prev
        for x in inputs:
            h = torch.relu(self.Wx(x) + self.Wh(h))
        logits = self.Wy(h)
        return logits, h

In [268]:
# Hyperparameters.
# Characters are fed in as one-hot vectors and the model emits one logit per
# character, so the input and output widths both equal the vocabulary size.
input_size = output_size = vocab_size
hidden_size = 32  # width of the recurrent hidden state
seq_len = 7  # number of context characters read before predicting the next one

In [269]:
# Seed PyTorch's global RNG so that weight initialisation and the multinomial
# sampling in generate() are repeatable when the notebook is re-run from the top.
torch.manual_seed(0)

model = VanillaRNN(input_size, hidden_size, output_size)
# CrossEntropyLoss consumes the raw logits together with an integer class id.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

def one_hot(index, vocab_size):
    """Return a ``(1, vocab_size)`` float tensor that is 1.0 at ``index`` and 0.0 elsewhere."""
    row = torch.zeros(vocab_size)
    row[index] = 1.0
    # Add the leading dimension of size 1 that callers expect.
    return row.unsqueeze(0)

In [270]:
def generate(start_text=" ", length=20):
    """Sample text from the module-level ``model`` one character at a time.

    Args:
        start_text: non-empty seed string; every character must be a key of
            ``stoi``. With a multi-character seed, all characters except the
            last are first run through the model to warm up the hidden state;
            a single-character seed behaves exactly as before.
        length: number of characters to sample after the seed.

    Returns:
        The seed followed by ``length`` sampled characters, as one string.

    Raises:
        ValueError: if ``start_text`` is empty.
        KeyError: if ``start_text`` contains a character missing from ``stoi``.
    """
    if len(start_text) == 0:
        raise ValueError("start_text must contain at least one character")
    chars = [stoi[ch] for ch in start_text]

    # Remember the current mode so that sampling in the middle of training
    # does not leave the model stuck in eval mode afterwards.
    was_training = model.training
    model.eval()
    try:
        # Sampling needs no gradients; without no_grad the autograd graph
        # would keep growing through the carried hidden state.
        with torch.no_grad():
            h = torch.zeros(1, hidden_size)
            if len(chars) > 1:
                warmup = [one_hot(i, vocab_size) for i in chars[:-1]]
                _, h = model(warmup, h)
            for _ in range(length):
                x = one_hot(chars[-1], vocab_size)
                logits, h = model([x], h)
                probs = torch.softmax(logits, dim=-1)
                # Draw the next character id from the predicted distribution.
                idx = torch.multinomial(probs, num_samples=1).item()
                chars.append(idx)
    finally:
        model.train(was_training)
    return decode(chars)

In [271]:
n_epochs = 300
log_every = 50  # number of epochs between progress reports

# Loop-invariant work, done once instead of for every window of every epoch:
# encode the corpus as ids and build one one-hot vector per vocabulary entry.
# The model only reads its inputs, so the vectors can be shared between windows.
encoded_text = [stoi[ch] for ch in text]
one_hot_table = [one_hot(idx, vocab_size) for idx in range(vocab_size)]

for epoch in range(n_epochs):
    # generate() calls model.eval(); make sure every epoch trains in train mode.
    model.train()
    total_loss = 0  # sum (not mean) of the per-window losses of this epoch

    for i in range(len(encoded_text) - seq_len):
        # seq_len context characters in, the character right after them as target.
        x_seq = [one_hot_table[idx] for idx in encoded_text[i:i + seq_len]]
        y_target = torch.tensor([encoded_text[i + seq_len]])

        # Every window starts from a zero hidden state.
        h_prev = torch.zeros(1, hidden_size)
        logits, h = model(x_seq, h_prev)

        loss = loss_fn(logits, y_target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
    if (epoch + 1) % log_every == 0:
        print(f"Epoch: {epoch+1} | Loss: {total_loss:.4f}")
        print(generate("h", 11))

Epoch: 50 | Loss: 0.0006
hodo wooldhe
Epoch: 100 | Loss: 0.0000
hword worldh
Epoch: 150 | Loss: 99.6665
hlll w rrdh 
Epoch: 200 | Loss: 0.0019
hddlorl  wor
Epoch: 250 | Loss: 0.0001
hlolowworldh
Epoch: 300 | Loss: 0.0000
hhlllo world


In [307]:
# Sample four characters after the seed "w".
print(generate(start_text="w", length=4))

world


In [289]:
# Sample four characters after the seed "h".
print(generate(start_text="h", length=4))

hwhll


In [330]:
# Over 100 independent trials, count how many samples seeded with "h" are
# needed before the model emits exactly "hello world", then report the mean.
attempt_counts = []
for _ in range(100):
    attempts = 1
    while generate("h", 10) != "hello world":
        attempts += 1
    attempt_counts.append(float(attempts))

mean_attempts = torch.tensor(attempt_counts).mean()
print(mean_attempts.item())

6.28000020980835


Alright! We've had enough fun with the simple RNN.
Time to scale up, optimize, and learn from more complex data!

In [331]:
# Load the full corpus. The context manager closes the file handle as soon as
# the text has been read instead of leaving it open.
# NOTE(review): no encoding is passed, so the platform default is used —
# consider pinning it explicitly once the file's encoding is confirmed.
with open("onegin.txt", "r") as corpus_file:
    text = corpus_file.read()

# Vocabulary = the distinct characters of the corpus, in sorted order.
chars = sorted(set(text))
vocab_size = len(chars)

# Lookup tables: character -> integer id and integer id -> character.
stoi = {ch: i for i, ch in enumerate(chars)}
itos = {i: ch for ch, i in stoi.items()}

print(vocab_size, len(text))

145 165519


In [332]:
def encode(s):
    """Convert a string into a list of integer character ids.

    Each character of ``s`` is looked up individually in the module-level
    ``stoi`` mapping (character -> id).

    Args:
        s: the string to encode; may be empty.

    Returns:
        A list with one integer id per character of ``s``, in order.

    Raises:
        KeyError: if ``s`` contains a character that is not a key of ``stoi``.
    """
    # Index by the loop variable `c` (one character), not by the whole string.
    return [stoi[c] for c in s]

def decode(indices):
    """Turn an iterable of integer ids back into a string.

    Every id is looked up in the module-level ``itos`` mapping and the
    resulting characters are concatenated in order.
    """
    symbols = map(itos.__getitem__, indices)
    return "".join(symbols)

In [333]:
class VanillaRNN(nn.Module):
    """Single-layer recurrent network with a ReLU hidden activation.

    For every input step the hidden state is updated as
    ``h = relu(Wx(x) + Wh(h))``; after the last step the final hidden state
    is projected to output logits by ``Wy``.
    """

    def __init__(self, input_size, hidden_size, output_size):
        """Create the three linear maps of the network.

        Args:
            input_size: number of features of each per-step input.
            hidden_size: number of features of the hidden state.
            output_size: number of output logits.
        """
        super().__init__()
        self.Wx = nn.Linear(input_size, hidden_size)
        # The recurrent map carries no bias of its own; Wx already has one.
        self.Wh = nn.Linear(hidden_size, hidden_size, bias=False)
        self.Wy = nn.Linear(hidden_size, output_size)

    def forward(self, inputs, h_prev):
        """Run the network over a sequence and predict from the final state.

        Args:
            inputs: iterable of per-step tensors; each is fed to ``Wx``, so
                its last dimension must be ``input_size``. May be empty.
            h_prev: initial hidden state; its last dimension must be
                ``hidden_size``.

        Returns:
            A ``(logits, h)`` tuple: ``logits`` is ``Wy`` applied to the final
            hidden state and ``h`` is that final hidden state. With empty
            ``inputs`` the final state is simply ``h_prev``.
        """
        # Only the last hidden state is ever used, so there is no need to keep
        # a list of all intermediate states (and no IndexError on empty input).
        h = h_prev
        for x in inputs:
            h = torch.relu(self.Wx(x) + self.Wh(h))
        logits = self.Wy(h)
        return logits, h

In [334]:
# HyperParameters
# Characters are fed in as one-hot vectors and the model emits one logit per
# character, so the input and output widths both equal the vocabulary size.
input_size = vocab_size
hidden_size = 32
output_size = vocab_size
seq_len = 7  # how many characters to read before predicting the next

# Seed both RNGs the notebook draws from so a re-run from the top is
# repeatable: torch (weight initialisation, multinomial sampling in generate())
# and Python's random (random.choice picks the seed character when sampling).
random.seed(0)
torch.manual_seed(0)

model = VanillaRNN(input_size, hidden_size, output_size)
# CrossEntropyLoss consumes the raw logits together with an integer class id.
loss_fn = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)

def one_hot(index, vocab_size):
    """Return a ``(1, vocab_size)`` float tensor that is 1.0 at ``index`` and 0.0 elsewhere."""
    row = torch.zeros(vocab_size)
    row[index] = 1.0
    # Add the leading dimension of size 1 that callers expect.
    return row.unsqueeze(0)

In [335]:
def generate(start_text=" ", length=20):
    """Sample text from the module-level ``model`` one character at a time.

    Args:
        start_text: non-empty seed string; every character must be a key of
            ``stoi``. With a multi-character seed, all characters except the
            last are first run through the model to warm up the hidden state;
            a single-character seed behaves exactly as before.
        length: number of characters to sample after the seed.

    Returns:
        The seed followed by ``length`` sampled characters, as one string.

    Raises:
        ValueError: if ``start_text`` is empty.
        KeyError: if ``start_text`` contains a character missing from ``stoi``.
    """
    if len(start_text) == 0:
        raise ValueError("start_text must contain at least one character")
    chars = [stoi[ch] for ch in start_text]

    # Remember the current mode so that sampling in the middle of training
    # does not leave the model stuck in eval mode afterwards.
    was_training = model.training
    model.eval()
    try:
        # Sampling needs no gradients; without no_grad the autograd graph
        # would keep growing through the carried hidden state.
        with torch.no_grad():
            h = torch.zeros(1, hidden_size)
            if len(chars) > 1:
                warmup = [one_hot(i, vocab_size) for i in chars[:-1]]
                _, h = model(warmup, h)
            for _ in range(length):
                x = one_hot(chars[-1], vocab_size)
                logits, h = model([x], h)
                probs = torch.softmax(logits, dim=-1)
                # Draw the next character id from the predicted distribution.
                idx = torch.multinomial(probs, num_samples=1).item()
                chars.append(idx)
    finally:
        model.train(was_training)
    return decode(chars)

In [337]:
n_epochs = 5

# Loop-invariant work, done once instead of for every window of every epoch:
# encode the corpus as ids and build one one-hot vector per vocabulary entry.
# The model only reads its inputs, so the vectors can be shared between windows.
encoded_text = [stoi[ch] for ch in text]
one_hot_table = [one_hot(idx, vocab_size) for idx in range(vocab_size)]

for epoch in range(n_epochs):
    # generate() calls model.eval(); make sure every epoch trains in train mode.
    model.train()
    total_loss = 0  # sum (not mean) of the per-window losses of this epoch

    # One optimizer step per window, i.e. len(text) - seq_len steps per epoch.
    for i in range(len(encoded_text) - seq_len):
        # seq_len context characters in, the character right after them as target.
        x_seq = [one_hot_table[idx] for idx in encoded_text[i:i + seq_len]]
        y_target = torch.tensor([encoded_text[i + seq_len]])

        # Every window starts from a zero hidden state.
        h_prev = torch.zeros(1, hidden_size)
        logits, h = model(x_seq, h_prev)

        loss = loss_fn(logits, y_target)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Report after every epoch, sampling from a randomly chosen seed character.
    print(f"Epoch: {epoch+1} | Loss: {total_loss:.4f}")
    print(generate(random.choice(chars)))

KeyboardInterrupt: 