In [64]:
import requests
import random
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

# Tiny Shakespeare

In [4]:
r = requests.get('https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt')
text = r.content.decode("utf-8")

In [7]:
print(f'length of dataset in characters: {len(text)}')

length of dataset in characters: 1115394


In [6]:
print(text[:1000])

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You are all resolved rather to die than to famish?

All:
Resolved. resolved.

First Citizen:
First, you know Caius Marcius is chief enemy to the people.

All:
We know't, we know't.

First Citizen:
Let us kill him, and we'll have corn at our own price.
Is't a verdict?

All:
No more talking on't; let it be done: away, away!

Second Citizen:
One word, good citizens.

First Citizen:
We are accounted poor citizens, the patricians good.
What authority surfeits on would relieve us: if they
would yield us but the superfluity, while it were
wholesome, we might guess they relieved us humanely;
but they think we are too dear: the leanness that
afflicts us, the object of our misery, is as an
inventory to particularise their abundance; our
sufferance is a gain to them Let us revenge this with
our pikes, ere we become rakes: for the gods know I
speak this in hunger for bread, not in thirst for revenge.



In [26]:
chars = sorted(set(''.join(text)))
stoi = {s: i for i, s in enumerate(char)}
itos = {i: s for i, s in enumerate(char)}

print(''.join(chars))
print(f'Num of char: {len(chars)}')


 !$&',-.3:;?ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz
Num of char: 65


In [32]:
encode = lambda s: [stoi[c] for c in s]
decode = lambda l: ''.join([itos[i] for i in l])

print(encode("Pattara"))
print(decode(encode("Pattara")))

[28, 39, 58, 58, 39, 56, 39]
Pattara


In [37]:
data = torch.tensor(encode(text), dtype=torch.long)
print(data.shape, data.type())

torch.Size([1115394]) torch.LongTensor


In [39]:
n = int(0.9*len(data))
train = data[:n]
test = data[n:]

# Name 

In [47]:
names = open('names.txt', 'r').read().split()
print(f'n names: {len(names)}')

n names: 32033


In [59]:
chars = sorted(set(''.join(names)))
stoi = {s:i+1 for i, s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s, i in stoi.items()}

vocab_size = len(stoi)

print(itos)
print(f'vocab size: {vocab_size}')

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}
vocab size: 27


In [173]:
block_size = 3
batch_size = 32

In [233]:
def build_dataset(names):
    X, y = [], []
    for n in names:
        context = [0] * block_size 
        for ch in n + '.':
            ix = stoi[ch]
            X.append(context)
            y.append(ix)
            context = context[1:] + [ix]
    return X, y

In [234]:
X, y = build_dataset(names)

for c, j in zip(X[:27], y[:27]):
    print(''.join(itos[i] for i in c), '--->', itos[j])
    
X, y = torch.tensor(X), torch.tensor(y)
X.shape, y.shape, X.dtype, y.dtype

... ---> j
..j ---> a
.ja ---> m
jam ---> e
ame ---> s
mes ---> e
ese ---> .
... ---> k
..k ---> e
.ke ---> i
kei ---> s
eis ---> y
isy ---> .
... ---> m
..m ---> e
.me ---> a
mea ---> d
ead ---> o
ado ---> w
dow ---> .
... ---> t
..t ---> a
.ta ---> e
tae ---> l
ael ---> y
ely ---> n
lyn ---> .


(torch.Size([228146, 3]), torch.Size([228146]), torch.int64, torch.int64)

In [235]:
# X, y = X[:32], y[:32]

In [396]:
n_embd = 10
hidden = 100
epochs = 10000
lr = 1e-2
n = batch_size

In [397]:
C = torch.randn((vocab_size, n_embd))
W1 = torch.randn((n_embd * block_size, hidden)) * 0.1
b1 = torch.randn(hidden) * 0.1
W2 = torch.randn((hidden, vocab_size)) * 0.1
b2 = torch.randn(vocab_size) * 0.1

parameters = [C, W1, b1, W2, b2]
print("Number of params: ", sum(p.nelement() for p in parameters))

# set require grad
for p in parameters:
    p.requires_grad = True

Number of params:  6097


In [398]:
for epoch in range(epochs+1):
    # mini batch
    ix = torch.randint(0, X.shape[0], (batch_size, ))
    
    # forward pass
    emb = C[X[ix]].view(-1, n_embd * block_size) # reduction from vocab_size dim to n_embd dim
    h = emb @ W1 + b1 
    logits = h @ W2 + b2

    # softmax
    # exps = logits.exp()
    # probs = exps / exps.sum(1, keepdims=True)
    # loss = -probs[torch.arange(n), y].log().mean()
    loss = F.cross_entropy(logits, y[ix])
    if epoch % 1000 == 0:
        print(f'Epoch: {epoch} | loss: {loss:.4f}')

    # backward pass
    for p in parameters: # set zero grad
        p.grad = None
    loss.backward() # call backward
    for p in parameters: # update
        p.data += -lr * p.grad

Epoch: 0 | loss: 3.6439
Epoch: 1000 | loss: 2.7153
Epoch: 2000 | loss: 2.6550
Epoch: 3000 | loss: 2.4809
Epoch: 4000 | loss: 2.4776
Epoch: 5000 | loss: 2.5530
Epoch: 6000 | loss: 2.2666
Epoch: 7000 | loss: 2.5536
Epoch: 8000 | loss: 2.4817
Epoch: 9000 | loss: 2.4477
Epoch: 10000 | loss: 2.4612


In [399]:
emb = C[X].view(-1, n_embd * block_size) # reduction from vocab_size dim to n_embd dim
h = emb @ W1 + b1 
logits = h @ W2 + b2
loss = F.cross_entropy(logits, y)
loss

tensor(2.4187, grad_fn=<NllLossBackward0>)

# Backprop

In [528]:
# Forward pass
ix = torch.randint(0, X.shape[0], (batch_size, ))
Xb, yb = X[ix], y[ix]

# embed vocab size to n_embd
emb = C[Xb]
embcat = emb.view(emb.shape[0], -1)

h = embcat @ W1 + b1
logits = h @ W2 + b2

# cross entropy loss
# norm logits
logit_maxes = logits.max(1, keepdims=True).values
norm_logits = logits - logit_maxes
# softmax
exps = norm_logits.exp()
sum_exps = exps.sum(1, keepdims=True)
sum_exps_inv = sum_exps**-1
probs = exps * sum_exps_inv
logprobs = probs.log()
loss = -logprobs[torch.arange(n), yb].mean()

for p in parameters:
    p.grad = None
for t in [logprobs, probs, exps, sum_exps, sum_exps_inv, norm_logits, logit_maxes, logits, h, embcat, emb]:
    t.retain_grad()
loss.backward()

print(f'Loss: {loss:.4f}')

Loss: 2.1437


In [614]:
# Backward pass
dlogprobs = torch.zeros_like(logprobs)
dlogprobs[torch.arange(n), yb] = 1 * -1.0/n
dprobs = dlogprobs * (1.0 / probs)
dexps = (torch.ones_like(exps) * sum_exps_inv) * dprobs
dsum_exps_inv = (exps * dprobs).sum(1, keepdims=True)
dsum_exps = (-1 / sum_exps**2) * dsum_exps_inv
dexps += torch.ones_like(exps) * dsum_exps
dnorm_logits = exps * dexps
dlogits = torch.ones_like(logits) * dnorm_logits
dlogit_maxes = (-dnorm_logits).sum(1, keepdims=True) 
logits_ones = torch.zeros_like(logits)
logits_ones[torch.arange(n), logits.max(1).indices] = 1
dlogits += logits_ones * dlogit_maxes
dW2 = h.T @ dlogits
dh = dlogits @ W2.T
db2 = dlogits.sum(0)
dW1 = embcat.T @ dh
dembcat = dh @ W1.T
db1 = dh.sum(0)
demb = dembcat.view(emb.shape)

dC = torch.zeros_like(C)
#for dembb, chs in zip(demb, Xb):
#    for dembt, ch in zip(dembb, chs):
#        dC[ch] += dembt
        
for k in range(Xb.shape[0]):
    for j in range(Xb.shape[1]):
        ix = Xb[k, j]
        dC[ix] += demb[k, j]
        

In [616]:
torch.all(dC == C.grad).item()

True

In [617]:
torch.allclose(dC, C.grad)

True