In [None]:
import torch
import torch.nn.functional as F

# Tiny training set: two-character "names" over the alphabet {a, b, c}.
# Eight of the nine possible letter bigrams occur; 'ca' is deliberately absent,
# so we can later see what the model predicts for a pair it never saw.
names = 'aa ab ac bb ba bc cc cb'.split()



['aa', 'ab', 'ac', 'bb', 'ba', 'bc', 'cc', 'cb']


In [None]:
SPECIAL_CH = '.'  # marks both the start and the end of a name

chars = sorted(set(''.join(names)))  # unique chars in the training set, e.g. ['a', 'b', 'c']
# The special char must not occur inside the names: it is forced to index 0 below,
# which would silently overwrite its data index and leave a gap in itos.
assert SPECIAL_CH not in chars, f'{SPECIAL_CH!r} appears in the training names'
stoi = {s: i + 1 for i, s in enumerate(chars)}  # real chars start at 1, e.g. {'a': 1, 'b': 2, 'c': 3}
stoi[SPECIAL_CH] = 0  # index 0 is reserved for the start/end marker
itos = {i: s for s, i in stoi.items()}  # inverse mapping, used to decode indices back to chars

NUM_CHARS = len(stoi)  # all real chars plus the special char

print(f'{NUM_CHARS=}')
print(f'{stoi=}')
print(f'{itos=}')


NUM_CHARS=4
stoi={'a': 1, 'b': 2, 'c': 3, '.': 0}
itos={1: 'a', 2: 'b', 3: 'c', 0: '.'}


In [None]:
# Build the bigram training pairs. Each name is padded with SPECIAL_CH on both sides,
# so the model also learns which chars start a name and which end it.
#   xs[i] -> index of the first char of pair i (the input)
#   ys[i] -> index of the char that follows it (the target)
xs, ys = [], []

for name in names:
    padded = [SPECIAL_CH, *name, SPECIAL_CH]
    for prev_ch, next_ch in zip(padded, padded[1:]):
        first_ix, second_ix = stoi[prev_ch], stoi[next_ch]
        xs.append(first_ix)
        ys.append(second_ix)

# 1-D integer tensors, one entry per pair, e.g. xs = tensor([0, 1, 1, 0, ...])
xs, ys = torch.tensor(xs), torch.tensor(ys)

# .numpy() gives a compact printout of the integer indices
print(f'xs={xs.numpy()}')
print(f'ys={ys.numpy()}')

# number of (input, target) examples the loss is averaged over
num_pairs_found = xs.numel()
print('number of char pairs in the training set: ', num_pairs_found)



xs=[0 1 1 0 1 2 0 1 3 0 2 2 0 2 1 0 2 3 0 3 3 0 3 2]
ys=[1 1 0 1 2 0 1 3 0 2 2 0 2 1 0 2 3 0 3 3 0 3 2 0]
number of char pairs in the training set:  24


In [None]:
# Initialize Network
# A single bias-free linear layer: a one-hot row of width NUM_CHARS maps to NUM_CHARS
# logits (one per possible next char, including the special start/end char).
# Seeding the generator makes the "random" starting weights reproducible.
g = torch.Generator()
g.manual_seed(2147483647)
W = torch.randn(NUM_CHARS, NUM_CHARS, generator=g, requires_grad=True)

print('Weights (randomly initialized):')
# W requires grad, so it must be detached from autograd before converting to numpy
print(W.detach().numpy())



Weights (randomly initialized):
[[ 1.567362   -0.23729232 -0.02738461 -1.1007794 ]
 [ 0.28588146 -0.02964334 -1.5470592   0.60489196]
 [ 0.0791362   0.90462387 -0.4712532   0.786822  ]
 [-0.32843494 -0.43297017  1.3729309   2.9333673 ]]


In [None]:
# Show the integer index of each pair's first char, then its one-hot encoding:
# exactly one of the NUM_CHARS columns is 1, e.g. [0, 1, 0, 0] encodes 'a' (index 1).
print(xs.numpy())  # xs is an integer tensor without grad, so no detach is needed
print(F.one_hot(xs, NUM_CHARS).to(torch.float32))


[0 1 1 0 1 2 0 1 3 0 2 2 0 2 1 0 2 3 0 3 3 0 3 2]
tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 1., 0., 0.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.],
        [0., 0., 0., 1.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.],
        [0., 0., 0., 1.],
        [0., 0., 1., 0.]])


In [24]:


############################################
#               GRADIENT DESCENT           #
############################################
# NOTE: W is updated in place, so re-running this cell continues training from the
# current W instead of the seeded initialization. Re-run the "Initialize Network"
# cell first to restart from the original random weights.

num_steps = 2                   # number of gradient-descent iterations
learning_rate = 50              # if the loss falls slowly, increase this to move faster
regularization_strength = 0.01  # higher values pull W toward 0, making the predicted
                                # distributions smoother (closer to uniform); too high and
                                # this term dominates the loss and W cannot grow
verbose = True                  # print one-hot inputs, W and logits (noisy output)

# The inputs never change during training, so encode them once outside the loop.
# Each row is all zeros except a 1 at the index of that pair's first char.
xenc = F.one_hot(xs, num_classes=NUM_CHARS).float()
if verbose:
    print(f'one hot encoded: {xenc}')

for k in range(num_steps):
    ############# FORWARD PASS #############
    if verbose:
        print(f'W: {W.detach().numpy()}')
    # Multiplying a one-hot row by W selects the matching row of W, so each row of
    # logits holds the log-counts for the char that follows that input char.
    # Result has one row per training pair and NUM_CHARS columns.
    logits = xenc @ W
    if verbose:
        print(f'logits: {logits.detach().numpy()}')

    counts = logits.exp()  # e^x makes every entry positive; values near 0 map to ~1
    probs = counts / counts.sum(1, keepdims=True)  # softmax: each row sums to 1
    # Negative log-likelihood of the true next char, averaged over all pairs ...
    data_loss = -probs[torch.arange(xs.nelement()), ys].log().mean()
    # ... plus an L2 penalty that pushes the weights toward 0.
    reg_loss = regularization_strength * (W**2).mean()
    loss = data_loss + reg_loss
    print(f'LOSS: {loss.item()}')  # should decrease from step to step

    ######### BACKWARD PASS ###############
    W.grad = None  # clear the gradient left over from the previous step
    loss.backward()

    ######### UPDATE THE WEIGHTS #############
    # Step against the gradient to reduce the loss. no_grad keeps the update itself
    # out of the autograd graph (the supported alternative to mutating W.data).
    with torch.no_grad():
        W -= learning_rate * W.grad

one hot encoded: tensor([[1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 1., 0., 0.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [0., 1., 0., 0.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 1., 0.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 1., 0., 0.],
        [1., 0., 0., 0.],
        [0., 0., 1., 0.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.],
        [0., 0., 0., 1.],
        [0., 0., 0., 1.],
        [1., 0., 0., 0.],
        [0., 0., 0., 1.],
        [0., 0., 1., 0.]])
W: [[-5.667372    5.0118155  -1.1700845   1.9315311 ]
 [-1.7576466   0.4487791   0.49110973  0.45801425]
 [-1.7325449  -1.4437535   2.287304    1.5704437 ]
 [ 1.6058083  -1.0105455   2.8309557  -1.5670612 ]]
logits: [[-5.667372    5.0118155  -1.1700845   1.9315311 ]
 [-1.7576466   0.4487791   0.49110973  0.45801425]
 [-1.7576466   0.4487791   0.49110973  0