## Intro to language modeling: building makemore
### Part 2: MLP

Based on the YouTube video https://www.youtube.com/watch?v=TCH_1BHY58I by Andrej Karpathy (@AndrejKarpathy).

#### Library imports

In [1]:
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
%matplotlib inline

#### Dataset import (`names.txt`, one name per line)

In [2]:
# Load the training names: one name per line in names.txt.
# The `with` block closes the file handle deterministically instead of leaving
# it open until garbage collection. An explicit encoding makes decoding
# independent of the platform's default locale (presumably the file is plain
# ASCII, in which case the result is unchanged).
with open('names.txt', 'r', encoding='utf-8') as f:
    words = f.read().splitlines()
len(words)

32033

In [3]:
# ---------- Multilayer Perceptron ---------- #

In [4]:
# Vocabulary: every distinct character appearing in `words`, sorted, mapped to
# the integers 1..N. Index 0 is reserved for '.', the padding / end-of-name token.
chars = sorted(set(''.join(words)))
stoi = {ch: idx for idx, ch in enumerate(chars, start=1)}
stoi['.'] = 0
# inverse mapping: integer -> character
itos = {idx: ch for ch, idx in stoi.items()}
print(itos)

{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


In [5]:
# build the dataset of (context -> next character) pairs

# block_size = context length: how many previous characters are used
# to predict the next one
block_size = 3

X, Y = [], []
for w in words[:5]:  # first 5 names only, so the printout stays readable
    print(w)
    context = [0] * block_size  # start fully padded with '.' (index 0)
    for ch in w + '.':          # the trailing '.' marks the end of the name
        target = stoi[ch]
        X.append(context)
        Y.append(target)
        print(''.join(itos[i] for i in context), '---->', itos[target])
        # rolling window: drop the oldest index, append the newest one
        context = context[1:] + [target]

X = torch.tensor(X)
Y = torch.tensor(Y)

emma
... ----> e
..e ----> m
.em ----> m
emm ----> a
mma ----> .
olivia
... ----> o
..o ----> l
.ol ----> i
oli ----> v
liv ----> i
ivi ----> a
via ----> .
ava
... ----> a
..a ----> v
.av ----> a
ava ----> .
isabella
... ----> i
..i ----> s
.is ----> a
isa ----> b
sab ----> e
abe ----> l
bel ----> l
ell ----> a
lla ----> .
sophia
... ----> s
..s ----> o
.so ----> p
sop ----> h
oph ----> i
phi ----> a
hia ----> .


In [6]:
# sanity check: shapes and integer dtypes of the inputs X and the targets Y
(X.shape, X.dtype, Y.shape, Y.dtype)

(torch.Size([32, 3]), torch.int64, torch.Size([32]), torch.int64)

In [7]:
# Lookup table C: one 2-D embedding vector per character (27 = 26 letters + '.').
# C is drawn from its own seeded generator so that the embeddings, and therefore
# every loss printed below, are reproducible across kernel restarts. Previously
# it used the global, unseeded RNG. The seed differs from the parameter
# generator's seed (2147483647) so that C does not duplicate the first values
# later drawn for W1.
g_emb = torch.Generator().manual_seed(2147483646)
C = torch.randn((27, 2), generator=g_emb)

In [8]:
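# C[5]
# and
# F.one_hot(torch.tensor(5), num_classes=27).float() @ C
# give the same row: indexing the lookup table is equivalent to multiplying
# a one-hot vector by C, so the embedding acts like a bias-free linear layer.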

In [9]:
# Embed every context character: indexing C (27, 2) with the integer tensor X
# (one row of block_size indices per example) gathers one row of C per index,
# so each example becomes block_size 2-D vectors.
emb = C[X]
emb.shape

torch.Size([32, 3, 2])

In [10]:
# Network parameters, all drawn in a fixed order from one seeded generator
# for reproducibility.
g = torch.Generator().manual_seed(2147483647)

# Layer sizes are derived from existing state rather than hard-coded, so they
# cannot silently disagree with block_size or with the shape of C.
n_in = block_size * C.shape[1]   # flattened context: 3 chars * 2 dims = 6
n_hidden = 100                   # width of the tanh hidden layer
vocab_size = C.shape[0]          # one output logit per character (27)

# first layer: flattened embeddings -> hidden
W1 = torch.randn((n_in, n_hidden), generator=g)
b1 = torch.randn(n_hidden, generator=g)
# second layer: hidden -> logits
W2 = torch.randn((n_hidden, vocab_size), generator=g)
b2 = torch.randn(vocab_size, generator=g)

parameters = [C, W1, b1, W2, b2]

In [11]:
# total number of scalar parameters across all tensors
sum(t.numel() for t in parameters)

3481

In [12]:
# Hidden layer activations.
# - emb.view(N, -1) flattens each example's block_size embedding vectors into
#   one row without copying. -1 lets view infer the width (3 * 2 = 6) instead
#   of hard-coding it.
# - b1, of shape (100,), is broadcast across all rows of the matmul result.
# - tanh squashes every activation into [-1, 1].
h = torch.tanh(emb.view(emb.shape[0], -1) @ W1 + b1)

In [13]:
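# # Alternative way to flatten the embeddings: unbind emb along dim 1 and
# # concatenate the pieces along dim 1. This works for any block_size:
# torch.cat(torch.unbind(emb, 1), 1).shape
# # but it allocates new memory, so it is less efficient than .view, which
# # only reinterprets the existing storage with a new shape, e.g.:

# arr = torch.arange(18)
# arr.view(3,3,2)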

In [14]:
logits = h @ W2 + b2  # unnormalized scores, one column per character
# F.cross_entropy computes the same value as the manual softmax + negative
# log-likelihood below. It is fused into one call and avoids materializing
# logits.exp(), which can overflow for large logits:
#   counts = logits.exp()
#   prob = counts / counts.sum(1, keepdim=True)
#   loss = -prob[torch.arange(logits.shape[0]), Y].log().mean()
# Ideally every prob[i, Y[i]] would be 1, giving a loss of 0.
loss = F.cross_entropy(logits, Y)
loss

tensor(16.1090)

In [15]:
# Ask autograd to track every parameter so loss.backward() fills in .grad.
for param in parameters:
    param.requires_grad_(True)

In [18]:
# Full-batch gradient descent on the small dataset built above.
# NOTE(review): the execution counter jumps from In [15] to In [18], so the
# recorded output may come from more than one run of this cell (the parameters
# persist between runs). Confirm with Restart & Run All.
learning_rate = 0.1
n_steps = 1000

for _ in range(n_steps):
    # forward pass
    emb = C[X]                                            # (32, 3, 2)
    h = torch.tanh(emb.view(emb.shape[0], -1) @ W1 + b1)  # (32, 100)
    logits = h @ W2 + b2                                  # (32, 27)
    loss = F.cross_entropy(logits, Y)

    # backward pass: clear the previous step's gradients, then backpropagate
    for p in parameters:
        p.grad = None
    loss.backward()

    # update: in-place SGD step. no_grad keeps the update out of the autograd
    # graph without reaching around autograd through .data.
    with torch.no_grad():
        for p in parameters:
            p -= learning_rate * p.grad

# This is the loss from the last forward pass, computed before its update step.
print(loss.item())
# overfitting a single batch of the data and getting low loss

0.25461822748184204


In [19]:
# Highest logit per example (values) and the character index it predicts
# (indices), to compare against the targets Y. Rows whose context is '...'
# cannot all be right: that single context precedes several different first
# letters.
logits.max(dim=1)

torch.return_types.max(
values=tensor([16.2235, 18.9979, 18.1264, 17.2726, 15.8082, 16.2235, 18.3182, 18.4859,
        17.2660, 21.6965, 18.1526, 17.2265, 16.2235, 14.1414, 24.0474, 17.3209,
        16.2235, 21.7283, 22.4534, 14.7754, 16.1589, 16.5105, 16.1017, 16.5161,
        14.8172, 16.2235, 22.7258, 16.9645, 21.7419, 22.2422, 17.3828, 18.0312],
       grad_fn=<MaxBackward0>),
indices=tensor([ 9, 13, 13,  1,  0,  9, 12,  9, 22,  9,  1,  0,  9, 22,  1,  0,  9, 19,
         1,  2,  5, 12, 12,  1,  0,  9, 15, 16,  8,  9,  1,  0]))