### E01. train a trigram language model

#### Building a dataset

In [356]:
# Loading a names.txt
words = open('names.txt', 'r').read().splitlines()
# Table from characters to integers
chars = sorted(list(set(''.join(words))))
# string to integer
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
itos = {i:s for s,i in stoi.items()}

In [357]:
words[:1]

['emma']

In [358]:
import torch
# Create the dataset
xs1, xs2, ys = [], [], []
for w in words:
  chs = ['.'] + list(w) + ['.']
  for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
    ix1 = stoi[ch1]
    # print("ch1 ", ch1)
    ix2 = stoi[ch2]
    # print("ch2 ", ch2)
    ix3 = stoi[ch3]
    # print("next symb ", ch3)
    # print(ch3)
    xs1.append(ix1)
    xs2.append(ix2)
    ys.append(ix3)
xs1 = torch.tensor(xs1)
xs2 = torch.tensor(xs2)
ys = torch.tensor(ys)
num = xs1.nelement()
print("Number of examples: ", num)
print(type(num))

Number of examples:  196113
<class 'int'>


In [359]:
print(xs1.shape)
print(xs2.shape)
print(ys.shape)

torch.Size([196113])
torch.Size([196113])
torch.Size([196113])


In [360]:
# initialize the network
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((27, 27), generator=g, requires_grad=True)

In [362]:
import torch.nn.functional as F
# Gradient descent
for k in range(100):
  # Forward pass
  xenc1 = F.one_hot(xs1, num_classes=27).float()
  xenc2 = F.one_hot(xs2, num_classes=27).float()
  # print(xenc1.shape)
  # print(xenc2.shape)
  xenc = xenc1 + xenc2

  # xenc = torch.cat((xenc1, xenc2), dim=0)
  # print(xenc.shape)
  # Need to: (4, 27) @ (27, 27) -> (4, 27), so let's use xenc1+xenc2
  logits = xenc @ W
  counts = logits.exp()
  probs = counts / counts.sum(1, keepdims=True)
  # + regularization 
  loss = -probs[torch.arange(num), ys].log().mean() 
  print(loss.item())
  # Backward pass
  W.grad = None
  loss.backward() 

  # Updating the W
  W.data += -50 * W.grad

4.161541938781738
3.263045072555542
3.085920810699463
2.8875062465667725
2.958226203918457
2.7500483989715576
2.8280277252197266
2.69932222366333
2.8332326412200928
2.631650924682617
2.71332049369812
2.658245086669922
2.8357033729553223
2.5668835639953613
2.5865964889526367
2.635568618774414
2.8357856273651123
2.542372941970825
2.5375895500183105
2.573140859603882
2.7344002723693848
2.5467212200164795
2.628833293914795
2.5983617305755615
2.791351079940796
2.514540195465088
2.5233278274536133
2.5691335201263428
2.752143621444702
2.515011787414551
2.554187774658203
2.5953571796417236
2.8041296005249023
2.503009080886841
2.4922454357147217
2.5180466175079346
2.6516964435577393
2.546321392059326
2.6906049251556396
2.517285108566284
2.6251401901245117
2.5542800426483154
2.723832368850708
2.4982352256774902
2.557117462158203
2.583812713623047
2.7925169467926025
2.4880549907684326
2.4787847995758057
2.507620334625244
2.6428043842315674
2.535244941711426
2.6776294708251953
2.5099241733551025
2

### E02: split up the dataset randomly into 80% train set, 10% dev set, 10% test set. 

#### dataset

In [220]:
# Creating an already encoded dataset
xs1, xs2, ys = [], [], []
for w in words:
  chs = ['.'] + list(w) + ['.']
  for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
    ix1 = stoi[ch1]
    # print("ch1 ", ch1)
    ix2 = stoi[ch2]
    # print("ch2 ", ch2)
    ix3 = stoi[ch3]
    # print("next symb ", ch3)
    # print(ch3)
    xs1.append(ix1)
    xs2.append(ix2)
    ys.append(ix3)
xs1 = torch.tensor(xs1)
xs1 = F.one_hot(xs1, num_classes=27).float()
xs2 = torch.tensor(xs2)
xs2 = F.one_hot(xs2, num_classes=27).float()
xs = xs1 + xs2
ys = torch.tensor(ys)

In [221]:
dataset = torch.cat([xs, ys.view(-1, 1)], dim=1)

# With view
idx = torch.randperm(dataset[..., -1].nelement())
dataset_shuffled = dataset[idx, :]


TRAIN_SIZE = int(0.8 * len(dataset_shuffled[..., -1]))
DEV_SIZE = int(0.1 * len(dataset_shuffled[..., -1]))
TEST_SIZE = len(dataset_shuffled[..., -1]) - TRAIN_SIZE - DEV_SIZE
labels = dataset_shuffled[..., -1]
features = dataset_shuffled[..., :-1]
features_train, features_dev, features_test = torch.split(features, [TRAIN_SIZE, DEV_SIZE, TEST_SIZE])
labels_train, labels_dev, labels_test = torch.split(labels, [TRAIN_SIZE, DEV_SIZE, TEST_SIZE])

In [222]:
print(dataset_shuffled[0])
print(features_train[0])
print(labels_train.shape)
print(features_train.shape)
print(labels_train.int())

tensor([0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0., 9.])
tensor([0., 1., 0., 0., 1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
        0., 0., 0., 0., 0., 0., 0., 0., 0.])
torch.Size([156890])
torch.Size([156890, 27])
tensor([ 9, 14, 18,  ...,  5,  5, 26], dtype=torch.int32)


#### Training loop

In [245]:
# initialize the network
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((27, 27), generator=g, requires_grad=True)

In [249]:
# Gradient descent
for k in range(100):
  # Forward pass
  # xenc1 = F.one_hot(xs1, num_classes=27).float()
  # xenc2 = F.one_hot(xs2, num_classes=27).float()
  # # print(xenc1.shape)
  # # print(xenc2.shape)
  # xenc = xenc1 + xenc2
  xenc = features_train.float()
  # xenc = torch.cat((xenc1, xenc2), dim=0)
  # Need to: (4, 27) @ (27, 27) -> (4, 27), so let's use xenc1+xenc2
  logits = xenc @ W
  counts = logits.exp()
  probs = counts / counts.sum(1, keepdims=True)
  loss = -probs[torch.arange(TRAIN_SIZE), labels_train.int()].log().mean() 
  print(loss.item())
  # Backward pass
  W.grad = None
  loss.backward() 

  # Updating the W
  W.data += -20 * W.grad

3.533879518508911
3.2061102390289307
2.9651730060577393
2.778259754180908
2.6462581157684326
2.561558723449707
2.5088400840759277
2.472064733505249
2.444540023803711
2.42516827583313
2.413764476776123
2.408475637435913
2.406306505203247
2.4052703380584717
2.4046387672424316
2.4041941165924072
2.403857707977295
2.403592109680176
2.4033753871917725
2.4031944274902344
2.4030416011810303
2.402909517288208
2.40279483795166
2.402693271636963
2.4026036262512207
2.4025235176086426
2.402451276779175
2.402385950088501
2.4023261070251465
2.4022715091705322
2.402221202850342
2.402175188064575
2.402132511138916
2.402092695236206
2.4020557403564453
2.4020214080810547
2.401988983154297
2.40195894241333
2.401930570602417
2.4019038677215576
2.401878595352173
2.4018547534942627
2.4018325805664062
2.401811361312866
2.401791572570801
2.4017722606658936
2.4017539024353027
2.4017369747161865
2.4017205238342285
2.401705026626587
2.4016900062561035
2.4016759395599365
2.4016621112823486
2.401649236679077
2.401

### E03: use the dev set to tune the strength of smoothing (or regularization) for the trigram mode

In [252]:
# Gradient descent
for k in range(100):
  # Forward pass
  # xenc1 = F.one_hot(xs1, num_classes=27).float()
  # xenc2 = F.one_hot(xs2, num_classes=27).float()
  # # print(xenc1.shape)
  # # print(xenc2.shape)
  # xenc = xenc1 + xenc2
  xenc = features_dev.float()
  # xenc = torch.cat((xenc1, xenc2), dim=0)
  # Need to: (4, 27) @ (27, 27) -> (4, 27), so let's use xenc1+xenc2
  logits = xenc @ W
  counts = logits.exp()
  probs = counts / counts.sum(1, keepdims=True)
  loss = -probs[torch.arange(DEV_SIZE), labels_dev.int()].log().mean() + 0.001 * (W**2).mean()
  print(loss.item())
  # Backward pass
  W.grad = None
  loss.backward() 

  # Updating the W
  W.data += -20 * W.grad

2.3897366523742676
2.389726161956787
2.3897154331207275
2.389705181121826
2.389694929122925
2.3896844387054443
2.389673948287964
2.3896636962890625
2.3896536827087402
2.389643907546997
2.389633893966675
2.3896234035491943
2.3896141052246094
2.3896045684814453
2.389594793319702
2.389585018157959
2.389575242996216
2.389565944671631
2.389556646347046
2.389547109603882
2.389538288116455
2.389529228210449
2.3895199298858643
2.3895106315612793
2.3895018100738525
2.3894927501678467
2.389483690261841
2.389474868774414
2.3894665241241455
2.3894577026367188
2.389449119567871
2.3894402980804443
2.389431953430176
2.389423370361328
2.3894150257110596
2.389406442642212
2.3893985748291016
2.389390468597412
2.3893821239471436
2.389374017715454
2.3893659114837646
2.389357566833496
2.3893494606018066
2.3893418312072754
2.389334201812744
2.389326333999634
2.3893184661865234
2.389310836791992
2.389302968978882
2.3892955780029297
2.3892877101898193
2.389280319213867
2.389273166656494
2.389265775680542
2.38

### Test part

In [257]:
# Gradient descent
for k in range(1):  
  # Forward pass
  # xenc1 = F.one_hot(xs1, num_classes=27).float()
  # xenc2 = F.one_hot(xs2, num_classes=27).float()
  # # print(xenc1.shape)
  # # print(xenc2.shape)
  # xenc = xenc1 + xenc2
  xenc = features_test.float()
  # xenc = torch.cat((xenc1, xenc2), dim=0)
  # Need to: (4, 27) @ (27, 27) -> (4, 27), so let's use xenc1+xenc2
  logits = xenc @ W
  counts = logits.exp()
  probs = counts / counts.sum(1, keepdims=True)
  loss = -probs[torch.arange(TEST_SIZE), labels_test.int()].log().mean() + 0.001 * (W**2).mean()
  print(loss.item())


2.4140918254852295


## E04: delete our use of F.one_hot in favor of simply indexing into rows of W

In [363]:
import torch
# Create the dataset
xs, ys = [], []
for w in words:
  chs = ['.'] + list(w) + ['.']
  for ch1, ch2 in zip(chs, chs[1:]):
    ix1 = stoi[ch1]
    # print("ch1 ", ch1)
    ix2 = stoi[ch2]
    # print("next symb ", ch3)
    # print(ch3)
    xs.append(ix1)
    ys.append(ix2)
xs = torch.tensor(xs)
ys = torch.tensor(ys)
num = ys.nelement()
print("Number of examples: ", num)
print(type(num))

Number of examples:  228146
<class 'int'>


In [364]:
# initialize the network
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((27, 27), generator=g, requires_grad=True)

In [365]:
import torch.nn.functional as F
# Gradient descent
for k in range(100):
  # Need to: (4, 27) @ (27, 27) -> (4, 27), so let's use xenc1+xenc2
  logits = W[[x.item() for x in xs]]
  counts = logits.exp()
  probs = counts / counts.sum(1, keepdims=True)
  # # # + regularization 
  loss = -probs[torch.arange(num), ys].log().mean() 
  print(loss.item())
  # Backward pass
  W.grad = None
  loss.backward() 

  # Updating the W
  W.data += -50 * W.grad

3.758953332901001
3.371081829071045
3.1540329456329346
3.02036714553833
2.9277074337005615
2.8603997230529785
2.809727191925049
2.7701010704040527
2.738072633743286
2.711496353149414
2.6890029907226562
2.6696884632110596
2.6529300212860107
2.638277053833008
2.6253879070281982
2.613990545272827
2.60386323928833
2.5948216915130615
2.586712121963501
2.579403877258301
2.572788953781128
2.5667762756347656
2.5612878799438477
2.5562589168548584
2.551633834838867
2.547365665435791
2.543415069580078
2.5397486686706543
2.536336660385132
2.533154249191284
2.5301806926727295
2.5273969173431396
2.5247862339019775
2.5223348140716553
2.520029067993164
2.5178580284118652
2.515810489654541
2.513878345489502
2.512052059173584
2.510324001312256
2.5086867809295654
2.5071349143981934
2.5056614875793457
2.5042612552642822
2.502929210662842
2.5016608238220215
2.5004520416259766
2.4992988109588623
2.498197317123413
2.497144937515259
2.496137857437134
2.495173692703247
2.4942495822906494
2.493363380432129
2.49

In [267]:
print(xs.shape)
print(xenc.shape)
print(ys.shape)
print(probs.shape)

torch.Size([5])
torch.Size([5, 27])
torch.Size([5])
torch.Size([5, 27])


In [355]:
print([x.item() for x in xs])
print(xenc.shape)
print(xenc[2, :])
print(xenc[2] @ W)
print(logits[2])
# print(probs[1])
ind = [0, 5, 13, 13, 1]
logits = W[[x.item() for x in xs]]
logits[2]

[0, 5, 13, 13, 1, 0, 15, 12, 9, 22, 9, 1, 0, 1, 22, 1, 0, 9, 19, 1, 2, 5, 12, 12, 1, 0, 19, 15, 16, 8, 9, 1, 0, 3, 8, 1, 18, 12, 15, 20, 20, 5, 0, 13, 9, 1, 0, 1, 13, 5, 12, 9, 1, 0, 8, 1, 18, 16, 5, 18, 0, 5, 22, 5, 12, 25, 14, 0, 1, 2, 9, 7, 1, 9, 12, 0, 5, 13, 9, 12, 25, 0, 5, 12, 9, 26, 1, 2, 5, 20, 8, 0, 13, 9, 12, 1, 0, 5, 12, 12, 1, 0, 1, 22, 5, 18, 25, 0, 19, 15, 6, 9, 1, 0, 3, 1, 13, 9, 12, 1, 0, 1, 18, 9, 1, 0, 19, 3, 1, 18, 12, 5, 20, 20, 0, 22, 9, 3, 20, 15, 18, 9, 1, 0, 13, 1, 4, 9, 19, 15, 14, 0, 12, 21, 14, 1, 0, 7, 18, 1, 3, 5, 0, 3, 8, 12, 15, 5, 0, 16, 5, 14, 5, 12, 15, 16, 5, 0, 12, 1, 25, 12, 1, 0, 18, 9, 12, 5, 25, 0, 26, 15, 5, 25, 0, 14, 15, 18, 1, 0, 12, 9, 12, 25, 0, 5, 12, 5, 1, 14, 15, 18, 0, 8, 1, 14, 14, 1, 8, 0, 12, 9, 12, 12, 9, 1, 14, 0, 1, 4, 4, 9, 19, 15, 14, 0, 1, 21, 2, 18, 5, 25, 0, 5, 12, 12, 9, 5, 0, 19, 20, 5, 12, 12, 1, 0, 14, 1, 20, 1, 12, 9, 5, 0, 26, 15, 5, 0, 12, 5, 1, 8, 0, 8, 1, 26, 5, 12, 0, 22, 9, 15, 12, 5, 20, 0, 1, 21, 18, 15, 18, 1, 

tensor([ 1.5240,  3.3328,  0.2035, -0.2227, -0.1026,  2.0897, -0.5970, -0.6320,
        -0.0525,  2.5788, -1.4108, -0.1672, -1.0031,  0.2535, -0.8377,  1.1275,
        -0.0787, -1.0871,  0.2169, -1.4395, -0.7131, -0.5790, -1.0060, -0.7365,
        -0.0779,  0.8942, -0.9922], grad_fn=<SelectBackward0>)

### E05: look up and use F.cross_entropy instead.

In [368]:
# Example of target with class probabilities
input = torch.randn(3, 5, requires_grad=True)
target = torch.randint(5, (3,), dtype=torch.int64)
print(input)
print(target)
# loss = F.cross_entropy(input, target)
# loss.backward()

tensor([[ 0.0649,  0.0162, -0.6918, -0.9527,  0.0810],
        [ 0.8010, -0.4380, -1.3161, -1.0206,  0.0606],
        [ 1.1185,  0.7397,  1.3518,  1.5928, -0.4411]], requires_grad=True)
tensor([4, 2, 3])


In [370]:
# Create the dataset
xs, ys = [], []
for w in words:
  chs = ['.'] + list(w) + ['.']
  for ch1, ch2 in zip(chs, chs[1:]):
    ix1 = stoi[ch1]
    ix2 = stoi[ch2]
    xs.append(ix1)
    ys.append(ix2)
xs = torch.tensor(xs)
ys = torch.tensor(ys)
num = xs.nelement()
print("Number of examples: ", num)

# initialize the network
g = torch.Generator().manual_seed(2147483647)
W = torch.randn((27, 27), generator=g, requires_grad=True)

Number of examples:  228146


In [374]:
print(xs.shape)
print(ys.shape)
print(xs[1])
print(ys[1])

torch.Size([228146])
torch.Size([228146])
tensor(5)
tensor(13)


In [377]:
xenc = F.one_hot(xs, num_classes=27).float()
print((xenc @ W).shape)

torch.Size([228146, 27])


In [378]:
# Gradient descent
for k in range(100):
  # Forward pass
  xenc = F.one_hot(xs, num_classes=27).float()
  logits = xenc @ W
  loss = F.cross_entropy(logits, ys)
  # counts = logits.exp()
  # probs = counts / counts.sum(1, keepdims=True)
  # # + regularization 
  # loss = -probs[torch.arange(num), ys].log().mean() + 0.01*(W**2).mean()
  print(loss.item())
  # Backward pass
  W.grad = None
  loss.backward() 

  # Updating the W
  W.data += -50 * W.grad

3.758953332901001
3.371100902557373
3.154043197631836
3.020374059677124
2.927711248397827
2.860402822494507
2.8097290992736816
2.770102024078369
2.7380728721618652
2.711496591567993
2.6890034675598145
2.6696884632110596
2.6529300212860107
2.638277769088745
2.6253879070281982
2.6139907836914062
2.603863000869751
2.5948219299316406
2.586712121963501
2.579403877258301
2.572789192199707
2.5667762756347656
2.5612878799438477
2.5562586784362793
2.551633596420288
2.547366142272949
2.5434155464172363
2.5397486686706543
2.536336660385132
2.5331544876098633
2.5301806926727295
2.5273969173431396
2.5247862339019775
2.522334575653076
2.520029306411743
2.517857789993286
2.515810966491699
2.513878345489502
2.512052059173584
2.510324001312256
2.5086872577667236
2.5071349143981934
2.5056614875793457
2.5042612552642822
2.502929210662842
2.5016608238220215
2.5004522800445557
2.4992990493774414
2.4981977939605713
2.4971446990966797
2.496137857437134
2.495173215866089
2.4942495822906494
2.493363380432129
2