# Dataset

In [4]:
device = torch.device("mps")

with open('names.txt') as f:
    content = f.read()
    words = content.splitlines()

print("Dataset size: ", len(words))
print("Smallest length: ", min(len(w) for w in words))
print("Largest length: ", max(len(w) for w in words))
print("Examples: ", words[:10])

Dataset size:  32033
Smallest length:  2
Largest length:  15
Examples:  ['emma', 'olivia', 'ava', 'isabella', 'sophia', 'charlotte', 'mia', 'amelia', 'harper', 'evelyn']


In [7]:
import torch

# Use tensors to capture the frequency of bigrams. First lets determine the characters
# in the dataset used as offsets within the tensor.
chars = sorted(list(set(''.join(words))))
stoi = { s: i+1 for i, s in enumerate(chars) }
stoi['.'] = 0
itos = { i: s for s, i in stoi.items() }

NUM_1D = len(stoi)
NUM_CLASSES = NUM_1D * NUM_1D

def compute_token(ch1: str, ch2: str) -> str:
    return stoi[ch1] * NUM_1D + stoi[ch2]


N = torch.zeros(NUM_CLASSES, NUM_1D, dtype=torch.int32)
for w in words:
    chs = ['.']  + list(w) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        input_token = compute_token(ch1, ch2)
        N[input_token, stoi[ch3]] += 1

N = N.to(device=device)


In [8]:
NUM_CLASSES

729

# Training set of trigrams

In [10]:
# Initialize the weights
g = torch.Generator(device=device).manual_seed(2147483647)
W = torch.randn((NUM_CLASSES, NUM_1D), generator=g, requires_grad=True, device=device)

xs, ys = [], []
for w in words:
    chs = ['.', '.']  + list(w) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        input_token = compute_token(ch1, ch2)
        ix3 = stoi[ch3]
        xs.append(input_token)
        ys.append(ix3)

xs = torch.tensor(xs).to(device=device)
ys = torch.tensor(ys).to(device=device)
num = xs.nelement()
print("number of examples: ", num)

number of examples:  228146


In [11]:
W.shape

torch.Size([729, 27])

In [82]:
xs.shape


torch.Size([228146])

In [86]:
xs[6:0]

tensor([], device='mps:0', dtype=torch.int64)

In [97]:
import random
import torch.nn.functional as F
from torch.profiler import profile, record_function, ProfilerActivity

BATCH_SIZE = 128
num_records = xs.shape[0]
rang = torch.arange(num, device=device)

with profile(activities=[ProfilerActivity.CPU], record_shapes=True) as prof:

    for k in range(500):
        batch_start = int(xs.shape[0] * random.random())        
        batch = xs[batch_start:batch_start+BATCH_SIZE].to(device)
        xenc = F.one_hot(xs, num_classes=NUM_CLASSES).float().to(device)

        # Forward pass    
        logits = (xenc @ W) # log counts
        counts = logits.exp()  # equivalent N
        probs = counts.div(counts.sum(1, keepdim=True))
        # Add regularization loss
        loss = -probs[rang, ys].log().mean() + 0.01 * (W**2).mean()
        print(loss.item())

        # Backward pass
        W.grad = None
        loss.backward()

        # Update
        W.data += -50 * W.grad


STAGE:2024-08-04 23:15:00 38631:273768667 ActivityProfilerController.cpp:311] Completed Stage: Warm Up


2.2228968143463135
2.222891330718994
2.222885847091675
2.2228801250457764
2.222874641418457


STAGE:2024-08-04 23:15:02 38631:273768667 ActivityProfilerController.cpp:317] Completed Stage: Collection
STAGE:2024-08-04 23:15:02 38631:273768667 ActivityProfilerController.cpp:321] Completed Stage: Post Processing


In [98]:
print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=10))


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                             aten::item         0.01%     199.000us        94.58%        2.548s      56.626ms            45  
                              aten::_local_scalar_dense        94.56%        2.548s        94.57%        2.548s      56.621ms            45  
                                               aten::mm         2.44%      65.719ms         2.44%      65.719ms       6.572ms            10  
                                           aten::matmul         0.01%     180.000us         2.38%      63.991ms      12.798ms             5  
      

# Inference

In [99]:
# This neural net is identical to the bigram model as above, but achieved through a different technique
g = torch.Generator(device=device).manual_seed(2147483647)
for i in range(5):
    out = []
    ix1, ix2 = '.', '.'
    while True:
        token = compute_token(ix1, ix2)
        xenc = F.one_hot(torch.tensor([token]), num_classes=NUM_CLASSES).float().to(device)
        logits = xenc @ W
        counts = logits.exp()
        p = counts / counts.sum(1, keepdim=True)
        ix = torch.multinomial(p, num_samples=1, replacement=True, generator=g).item()
        out.append(itos[ix])
        if ix == 0:
            break
        ix1 = ix2
        ix2 = itos[ix]
    print(''.join(out))

daima.
salon.
ods.
sy.
jayascovann.
