In [1]:
import torch
import torch.nn.functional as F

In [2]:
raw_names = open("./names.txt").read()
names = raw_names.split('\n')

In [3]:
X = []
Y = []
block_size = 3

letters_n = {v:i+1 for i,v in enumerate(sorted(list(set(raw_names.replace('\n','')))))}
letters_n['.'] = 0
n_letters = {v:i for i,v in letters_n.items()}

for name in names:
    context = [0] * block_size
    for char in name:
        n_char = letters_n[char]
        X.append(context)
        Y.append(n_char)
        
        context = context[1:] + [n_char]

X = torch.Tensor(X)
Y = torch.Tensor(Y)

idx1 = int(0.8*len(X))
idx2 = int(0.9*len(X))

Xtr = X[:idx1]
Ytr = Y[:idx1]
Xvl = X[idx1:idx2]
Yvl = Y[idx1:idx2]
Xts = X[idx2:]
Yts = Y[idx2:]

X.dtype, Y.dtype

(torch.float32, torch.float32)

In [4]:
Xtr.shape,Ytr.shape,Xvl.shape,Yvl.shape,Xts.shape,Yts.shape

(torch.Size([157192, 3]),
 torch.Size([157192]),
 torch.Size([19649, 3]),
 torch.Size([19649]),
 torch.Size([19649, 3]),
 torch.Size([19649]))

In [5]:
gen = torch.Generator().manual_seed(2147483647)
C = torch.randn((27,10),requires_grad=True, generator=gen)
W1 = torch.randn((30, 200), requires_grad=True, generator=gen)
b1 = torch.randn(200, requires_grad=True, generator=gen)
W2 = torch.randn((200, 100), requires_grad=True, generator=gen)
b2 = torch.randn(100, requires_grad=True, generator=gen)
W3 = torch.randn((100, 27), requires_grad=True, generator=gen)
b3 = torch.randn(27, requires_grad=True, generator=gen)
parameters = [C, W1, b1, W2, b2, W3, b3]
epsilon = 0.01

In [6]:
for __ in range(10000):

    mini_batch = torch.randint(0, Xtr.shape[0], (128,))

    embeddings = C[Xtr[mini_batch].type(dtype=torch.int64)]
    hypothesis = torch.tanh(embeddings.view(-1,30) @ W1 + b1)                                                # First Hidden Layer 
    hypothesis = torch.tanh(hypothesis @ W2 + b2)                                                    # Second Hidden Layer 
    logits = hypothesis @ W3 + b3                                                                           # Logits
    # counts = logits.exp()                                                                                   # Output Layer
    # prob = counts / counts.sum(1, keepdim=True)                                                             # Softmax
    # loss = -prob[torch.arange(embeddings.shape[0]).type(torch.int64), Y.type(torch.int64)].log().mean()     # Loss
    loss = F.cross_entropy(logits, Ytr[mini_batch].type(torch.int64))   
    for p in parameters:
        p.grad = None
    loss.backward()
    for p in parameters:
        p.data -= 0.01 * p.grad.type(torch.float32)

print(loss.item())

3.975724935531616


In [7]:
embeddings = C[Xvl.type(dtype=torch.int64)]
hypothesis = torch.tanh(embeddings.view(-1,30) @ W1 + b1)                                                # First Hidden Layer 
hypothesis = torch.tanh(hypothesis @ W2 + b2)                                   
logits = hypothesis @ W3 + b3                                                                      
loss = F.cross_entropy(logits, Yvl.type(torch.int64))   
print(loss)

tensor(4.5347, grad_fn=<NllLossBackward0>)


In [8]:
C

tensor([[ 1.6250, -0.3055,  0.1146, -1.2928,  0.3351, -0.1023, -1.6223,  0.6491,
          0.0079,  0.9782],
        [-0.6107,  0.8769, -0.1614, -0.6589,  1.5111,  3.1819,  1.6182, -1.6964,
          0.8306, -1.0097],
        [ 1.2450, -0.1344, -1.4076,  0.4298, -0.0413,  2.5197,  2.7214, -0.5226,
         -1.2087,  0.2822],
        [-1.0444,  0.6657, -0.0829,  1.0715, -0.9042, -0.8043, -1.8547,  1.3889,
         -1.0808,  1.3453],
        [-0.9298, -0.2919, -0.4072,  0.2365, -1.0983,  1.1543,  0.6226,  0.4737,
          0.1372, -1.5975],
        [ 0.3030,  0.7680,  0.4961, -1.3511,  0.0957, -1.1171,  0.4569,  0.4894,
         -1.1740, -0.6337],
        [ 0.1188, -0.1041, -1.0650, -1.4020,  0.8579, -0.6199,  1.8903,  1.9736,
         -0.7834,  0.5189],
        [ 0.6950,  1.5675, -1.1910, -0.8245, -0.7150,  1.8651,  0.1190,  1.2726,
         -0.7284,  0.2650],
        [-1.2191,  1.1412, -0.1764, -1.5431,  1.2509,  0.4777,  1.1769, -0.8420,
         -1.3493,  1.2023],
        [-1.3837,  