In [524]:
import requests
import torch
from tqdm import tqdm_notebook
import matplotlib.pyplot as plt
%matplotlib inline

In [300]:
r = requests.get('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
data = r.content.decode("utf-8")

In [539]:
data = data[:1000]
print(data[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [540]:
chars = sorted(set(''.join(data)))
stoi = {s:i for i, s in enumerate(chars)}
itos = {i:s for s, i in stoi.items()}

vocab_size = len(stoi)

print(itos)
print(f'vocab size: {vocab_size}')

{0: '\n', 1: ' ', 2: '!', 3: "'", 4: ',', 5: '.', 6: ':', 7: ';', 8: '?', 9: 'A', 10: 'B', 11: 'C', 12: 'F', 13: 'I', 14: 'L', 15: 'M', 16: 'N', 17: 'O', 18: 'R', 19: 'S', 20: 'W', 21: 'Y', 22: 'a', 23: 'b', 24: 'c', 25: 'd', 26: 'e', 27: 'f', 28: 'g', 29: 'h', 30: 'i', 31: 'j', 32: 'k', 33: 'l', 34: 'm', 35: 'n', 36: 'o', 37: 'p', 38: 'r', 39: 's', 40: 't', 41: 'u', 42: 'v', 43: 'w', 44: 'y', 45: 'z'}
vocab size: 46


In [541]:
def build_dataset():
    X = [stoi[ch] for ch in data[:len(data)-1]]
    y = [stoi[ch] for ch in data[1:len(data)]]
    return X, y

In [542]:
def one_hot_encode(X):
    onehotx = torch.zeros((len(X), vocab_size))
    rows = torch.arange(len(X))
    onehotx[rows, X] = 1
    return onehotx

In [573]:
seq_length = 25
hidden_size = 256
lr = 1e-3
epochs = 1000

In [574]:
X, y = build_dataset()

In [575]:
Wxh = torch.randn(hidden_size, vocab_size)
Whh = torch.randn(hidden_size, hidden_size)
Why = torch.randn(vocab_size, hidden_size)
bh = torch.zeros(hidden_size, 1)
by = torch.zeros(vocab_size, 1)

parameters = [Wxh, Whh, Why, bh, by]
for param in parameters:
    param.requires_grad = True

In [592]:
lr = 1e-3

In [596]:
loss_epochs = []
for epoch in tqdm_notebook(range(epochs)):
    loss_i = []
    for i in range(len(X)):
        # forward pass
        loss = 0
        pi = i * seq_length
        xs, ys = X[pi:pi+seq_length], y[pi:pi+seq_length]
        xs = one_hot_encode(xs)
        hprev = torch.zeros(hidden_size, 1)
        if xs.nelement() == 0:
            break
        for xi, yi in zip(xs, ys):
            xi = xi.reshape(-1, 1)
            hprev = torch.tanh(Whh @ hprev + Wxh @ xi + bh)
            out = Why @ hprev + by

            z = out.exp()
            probs = z / z.sum()
            loss += -probs[yi].log()

        loss_i.append(loss.item())

        #backward pass
        for param in parameters:
            param.grad = None 

        loss.backward()

        # update step
        for param in parameters:
            param.data += -lr * param.grad
    loss_epoch = sum(loss_i) / len(loss_i)
    loss_epochs.append(loss_epoch)
    if epoch % 100 == 0:
        print(f'Epoch: {epoch} | Loss: {loss_epoch:.4f}')

Please use `tqdm.notebook.tqdm` instead of `tqdm.tqdm_notebook`
  for epoch in tqdm_notebook(range(epochs)):


  0%|          | 0/1000 [00:00<?, ?it/s]

Epoch: 0 | Loss: 74.9548
Epoch: 100 | Loss: 74.5317
Epoch: 200 | Loss: 74.2170
Epoch: 300 | Loss: 74.0321
Epoch: 400 | Loss: 73.9146
Epoch: 500 | Loss: 73.8340
Epoch: 600 | Loss: 73.7741
Epoch: 700 | Loss: 73.7261
Epoch: 800 | Loss: 73.6854
Epoch: 900 | Loss: 73.6500


In [599]:
# sample
n = 500
with torch.inference_mode():
    hprev = torch.zeros(hidden_size, 1)
    xi = torch.zeros(vocab_size, 1)
    
    for i in range(n):
        hprev = torch.tanh(Whh @ hprev + Wxh @ xi + bh)
        out = Why @ hprev + by
        z = out.exp()
        probs = z / z.sum()    

        ix = torch.multinomial(probs.reshape(1,-1), 1, replacement=True)

        xi = torch.zeros(vocab_size, 1)
        xi[ix] = 1
        print(itos[ix.item()], end='')

m r mee .rvoneeoihiegIrnsSien
etnrmrrnraurtar hr h,v oeFiuvsbe worpninbb etdukeuindbAeee n gr re tOortht Ss e,olftsn erybrf dngnds
ar mw abanui lrlit ogYrfgojm.twoevogta dhbnergin  elrFi
rthitW Csecw its te o . eIarn
eat loirsdk  ir oateea.sv ,nFirr nkehyrieor  o,npownterawn
oo
bt  dgsho
  bkehafghues
ntrwnrrhrrfhherorhieibrate ?fe,ieuw.soe e n?shs rsntnlrdSoee Fiefh
ve:  :rhosinib litiersah o imntiar  oa r:ihoa hri. trno  yn
tomt
sFi esnes
hc rrtm
arruauithv?
me diei  euked anihho
e
isegC?:o ro