# Setup and Imports

In [1]:
import torch
import torch.nn.functional as F
import random

In [2]:
words = open('names.txt').read().splitlines()

In [3]:
words

['emma',
 'olivia',
 'ava',
 'isabella',
 'sophia',
 'charlotte',
 'mia',
 'amelia',
 'harper',
 'evelyn',
 'abigail',
 'emily',
 'elizabeth',
 'mila',
 'ella',
 'avery',
 'sofia',
 'camila',
 'aria',
 'scarlett',
 'victoria',
 'madison',
 'luna',
 'grace',
 'chloe',
 'penelope',
 'layla',
 'riley',
 'zoey',
 'nora',
 'lily',
 'eleanor',
 'hannah',
 'lillian',
 'addison',
 'aubrey',
 'ellie',
 'stella',
 'natalie',
 'zoe',
 'leah',
 'hazel',
 'violet',
 'aurora',
 'savannah',
 'audrey',
 'brooklyn',
 'bella',
 'claire',
 'skylar',
 'lucy',
 'paisley',
 'everly',
 'anna',
 'caroline',
 'nova',
 'genesis',
 'emilia',
 'kennedy',
 'samantha',
 'maya',
 'willow',
 'kinsley',
 'naomi',
 'aaliyah',
 'elena',
 'sarah',
 'ariana',
 'allison',
 'gabriella',
 'alice',
 'madelyn',
 'cora',
 'ruby',
 'eva',
 'serenity',
 'autumn',
 'adeline',
 'hailey',
 'gianna',
 'valentina',
 'isla',
 'eliana',
 'quinn',
 'nevaeh',
 'ivy',
 'sadie',
 'piper',
 'lydia',
 'alexa',
 'josephine',
 'emery',
 'julia'

In [4]:
chars = sorted(list(set(''.join(words))))
stoi = {s:i for i, s in enumerate(chars, 1)}
stoi['.'] = 0
itos = {i:s for s, i in stoi.items()}

# E01:

train a trigram language model, i.e. take two characters as an input to predict the 3rd one. Feel free to use either counting or a neural net. Evaluate the loss; Did it improve over a bigram model?

### Counting Model

In [5]:
# creating the trigram counting model

N = torch.zeros(size=(27, 27, 27))

for word in words:
    chs = ['.'] + list(word) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        ix1, ix2, ix3 = stoi[ch1], stoi[ch2], stoi[ch3]
        N[ix1, ix2, ix3] += 1

In [6]:
# sampling from the model

# first, sample the second letter as a bigram model (because you only have the special dot token as input)


P = N / N.sum(dim=2, keepdim=True)

for i in range(10):
    out = []
    ix1 = 0
    p = N[ix1].float()
    p = p.sum(axis=0) / p.sum() # look for the probability of the second letter

    ix2 = torch.multinomial(p, 1, replacement=True).item()

    out.append(itos[ix2])

# all letters from now on will depend on the previous two characters (trigram)
    
    while True:
        p = P[ix1, ix2]
        ix1 = ix2
        ix2 = torch.multinomial(p, 1, replacement=True).item()
        out.append(itos[ix2])
        if ix2 == 0:
            break

    print(''.join(out))



arlay.
is.
unalenioud.
laila.
aletum.
ron.
alearicha.
lyinsie.
alynoveilan.
he.


In [7]:
# calculating the negative log likliehood

nll = 0
n = 0
for word in words:
    chs = ['.'] + list(word) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        ix1, ix2, ix3 = stoi[ch1], stoi[ch2], stoi[ch3]
        nll += -torch.log(P[ix1, ix2, ix3])
        n += 1
        
print(nll/n)

tensor(2.0620)


### Neural Network

In [9]:
# preparing the dataset

xs, ys = [], []
for word in words:
    chs = ['.'] + list(word) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        ix1, ix2, ix3 = stoi[ch1], stoi[ch2], stoi[ch3]
        xs.append([ix1, ix2])
        ys.append(ix3)
        
xs, ys = torch.tensor(xs), torch.tensor(ys)

In [10]:
W = torch.randn(27*2, 27, requires_grad=True)

for k in range(200):
    xenc = F.one_hot(xs, num_classes=27).reshape(-1, 27*2).float()
    logits = xenc @ W
    counts = logits.exp()
    probs = counts / counts.sum(axis=1, keepdims=True)
    loss = -probs[torch.arange(xs.shape[0]), ys].log().mean()
    print(loss.item())
    
    W.grad = None
    loss.backward()
    
    W.data += -50*W.grad

4.389932155609131
3.4587507247924805
3.112456798553467
2.9231925010681152
2.8012492656707764
2.7153713703155518
2.6516430377960205
2.6029317378997803
2.5642783641815186
2.5330092906951904
2.5071616172790527
2.48553729057312
2.4671428203582764
2.4513232707977295
2.437530040740967
2.4253957271575928
2.4146127700805664
2.4049689769744873
2.396282434463501
2.3884217739105225
2.3812711238861084
2.3747427463531494
2.3687589168548584
2.3632571697235107
2.358182191848755
2.3534884452819824
2.3491344451904297
2.3450868129730225
2.3413140773773193
2.337789535522461
2.3344898223876953
2.3313937187194824
2.3284833431243896
2.325742483139038
2.3231558799743652
2.320711374282837
2.31839656829834
2.31620192527771
2.314117908477783
2.31213641166687
2.3102493286132812
2.308450937271118
2.306734323501587
2.3050942420959473
2.30352520942688
2.30202317237854
2.300583600997925
2.2992031574249268
2.297877311706543
2.2966034412384033
2.2953789234161377
2.2941999435424805
2.293064832687378
2.291970729827881
2

In [11]:
# sampling from the model

# first, sample the second letter as a bigram model (because you only have the special dot token as input)



for i in range(10):
    out = []
    
    ix1 = None
    ix2 = 0
    
    xenc = torch.concat([F.one_hot(torch.tensor([ix2]), num_classes=27), torch.zeros(1,27)], axis=1).float()
    logits = xenc @ W
    counts = logits.exp()
    p = counts / counts.sum(axis=1, keepdims=True)
    
    ix1 = ix2
    ix2 = torch.multinomial(p, 1, replacement=True).item()

    out.append(itos[ix2])

# all letters from now on will depend on the previous two characters (trigram)
    
    while True:
        xenc = torch.concat([F.one_hot(torch.tensor([ix1]), num_classes=27),
                             F.one_hot(torch.tensor([ix2]), num_classes=27)], axis=1).float()
        logits = xenc @ W
        counts = logits.exp()
        p = counts / counts.sum(axis=1, keepdims=True)
        
        ix1 = ix2
        ix2 = torch.multinomial(p, 1, replacement=True).item()
        out.append(itos[ix2])
        if ix2 == 0:
            break

    print(''.join(out))



daxann.
son.
enaemzian.
augwiron.
datrialabiusa.
sh.
dale.
vaeli.
el.
wdwan.


# E02: 
split up the dataset randomly into 80% train set, 10% dev set, 10% test set. Train the bigram and trigram models only on the training set. Evaluate them on dev and test splits. What can you see?

In [12]:
# preparing the dataset

xs, ys = [], []
for word in words:
    chs = ['.'] + list(word) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        ix1, ix2, ix3 = stoi[ch1], stoi[ch2], stoi[ch3]
        xs.append([ix1, ix2])
        ys.append(ix3)
        
xs, ys = torch.tensor(xs), torch.tensor(ys)

In [13]:
train_split, dev_split = int(0.8*len(xs)), int(0.9*len(xs))
print(train_split, dev_split)

156890 176501


Shuffling the dataset first

In [14]:
idxs = list(range(xs.shape[0]))
random.shuffle(idxs)

xs, ys = xs[idxs], ys[idxs]

In [15]:
xs_train, ys_train = xs[:train_split], ys[:train_split]

In [16]:
xs_dev, ys_dev = xs[train_split:dev_split], ys[train_split:dev_split]

In [17]:
xs_test, ys_test = xs[dev_split:], ys[dev_split:]

In [18]:
W = torch.randn(27*2, 27, requires_grad=True)

for k in range(200):
    xenc = F.one_hot(xs_train, num_classes=27).reshape(-1, 27*2).float()
    logits = xenc @ W
    counts = logits.exp()
    probs = counts / counts.sum(axis=1, keepdims=True)
    loss = -probs[torch.arange(xs_train.shape[0]), ys_train].log().mean()
    print(f'*train set loss = {loss.item()}')
    
    W.grad = None
    loss.backward()
    
    W.data += -50*W.grad
    
    
    ## eval on dev set
    with torch.no_grad():
        xenc = F.one_hot(xs_dev, num_classes=27).reshape(-1, 27*2).float()
        logits = xenc @ W
        counts = logits.exp()
        probs = counts / counts.sum(axis=1, keepdims=True)
        loss = -probs[torch.arange(xs_dev.shape[0]), ys_dev].log().mean()
        print(f'**dev set loss = {loss.item()}')

*train set loss = 4.258572578430176
**dev set loss = 3.4777674674987793
*train set loss = 3.4889206886291504
**dev set loss = 3.102198362350464
*train set loss = 3.1163411140441895
**dev set loss = 2.891940116882324
*train set loss = 2.905864715576172
**dev set loss = 2.7637836933135986
*train set loss = 2.7774341106414795
**dev set loss = 2.6776998043060303
*train set loss = 2.6912286281585693
**dev set loss = 2.6161861419677734
*train set loss = 2.629861831665039
**dev set loss = 2.5701406002044678
*train set loss = 2.5839269161224365
**dev set loss = 2.5343356132507324
*train set loss = 2.548229217529297
**dev set loss = 2.5057027339935303
*train set loss = 2.519646167755127
**dev set loss = 2.4822006225585938
*train set loss = 2.4961564540863037
**dev set loss = 2.462482452392578
*train set loss = 2.47641921043396
**dev set loss = 2.445636749267578
*train set loss = 2.459522247314453
**dev set loss = 2.431015729904175
*train set loss = 2.4448328018188477
**dev set loss = 2.41817712

*train set loss = 2.25931978225708
**dev set loss = 2.2520947456359863
*train set loss = 2.259103775024414
**dev set loss = 2.251908302307129
*train set loss = 2.2588918209075928
**dev set loss = 2.251725673675537
*train set loss = 2.258683443069458
**dev set loss = 2.251546621322632
*train set loss = 2.258478879928589
**dev set loss = 2.251370906829834
*train set loss = 2.258277654647827
**dev set loss = 2.2511980533599854
*train set loss = 2.258080005645752
**dev set loss = 2.251028299331665
*train set loss = 2.2578859329223633
**dev set loss = 2.2508623600006104
*train set loss = 2.257694959640503
**dev set loss = 2.2506983280181885
*train set loss = 2.25750732421875
**dev set loss = 2.2505381107330322
*train set loss = 2.257322311401367
**dev set loss = 2.250380039215088
*train set loss = 2.257140636444092
**dev set loss = 2.250225067138672
*train set loss = 2.256962299346924
**dev set loss = 2.250072717666626
*train set loss = 2.256786584854126
**dev set loss = 2.24992299079895
*t

In [19]:
# eval on test set

xenc = F.one_hot(xs_test, num_classes=27).reshape(-1, 27*2).float()
logits = xenc @ W
counts = logits.exp()
probs = counts / counts.sum(axis=1, keepdims=True)
loss = -probs[torch.arange(xs_test.shape[0]), ys_test].log().mean()
print(f'**test set loss = {loss.item()}')

**test set loss = 2.2460031509399414


# E03: 
use the dev set to tune the strength of smoothing (or regularization) for the trigram model - i.e. try many possibilities and see which one works best based on the dev set loss. What patterns can you see in the train and dev set loss as you tune this strength? Take the best setting of the smoothing and evaluate on the test set once and at the end. How good of a loss do you achieve?

In [20]:
W = torch.randn(27*2, 27, requires_grad=True)

for alpha in [0.0001, 0.001, 0.01, 0.1]:
    print(f'Testing with alpha={alpha}')
    for k in range(200):
        xenc = F.one_hot(xs_train, num_classes=27).reshape(-1, 27*2).float()
        logits = xenc @ W
        counts = logits.exp()
        probs = counts / counts.sum(axis=1, keepdims=True)
        loss = -probs[torch.arange(xs_train.shape[0]), ys_train].log().mean() + alpha * (W**2).mean()
#         print(f'*train set loss = {loss.item()}')

        W.grad = None
        loss.backward()

        W.data += -50*W.grad


        ## eval on dev set
        with torch.no_grad():
            xenc = F.one_hot(xs_dev, num_classes=27).reshape(-1, 27*2).float()
            logits = xenc @ W
            counts = logits.exp()
            probs = counts / counts.sum(axis=1, keepdims=True)
            loss = -probs[torch.arange(xs_dev.shape[0]), ys_dev].log().mean()
    print(f'**final dev set loss = {loss.item()}')

Testing with alpha=0.0001
**final dev set loss = 2.2440547943115234
Testing with alpha=0.001
**final dev set loss = 2.238290786743164
Testing with alpha=0.01
**final dev set loss = 2.2375736236572266
Testing with alpha=0.1
**final dev set loss = 2.2579474449157715


alpha of 0.01 seems to be the best choice

In [21]:
xenc = F.one_hot(xs_test, num_classes=27).reshape(-1, 27*2).float()
logits = xenc @ W
counts = logits.exp()
probs = counts / counts.sum(axis=1, keepdims=True)
loss = -probs[torch.arange(xs_test.shape[0]), ys_test].log().mean()
print(f'**test set loss = {loss.item()}')

**test set loss = 2.263033628463745


# E04:
we saw that our 1-hot vectors merely select a row of W, so producing these vectors explicitly feels wasteful. Can you delete our use of F.one_hot in favor of simply indexing into rows of W?

In [22]:
W = torch.randn(27*2, 27, requires_grad=True)

xs_train_offset = xs_train.clone()
xs_dev_offset = xs_dev.clone()

xs_train_offset[:, 1] = xs_train[:, 1] + 27
xs_dev_offset[:, 1] = xs_dev[:, 1] + 27

for k in range(200):
    logits = W[xs_train_offset].sum(axis=1)
    counts = logits.exp()
    probs = counts / counts.sum(axis=1, keepdims=True)
    loss = -probs[torch.arange(xs_train.shape[0]), ys_train].log().mean() + 0.01 * (W**2).mean()
    print(f'*train set loss = {loss.item()}')

    W.grad = None
    loss.backward()

    W.data += -50*W.grad


    ## eval on dev set
    with torch.no_grad():
        logits = W[xs_dev_offset].sum(axis=1)
        counts = logits.exp()
        probs = counts / counts.sum(axis=1, keepdims=True)
        loss = -probs[torch.arange(xs_dev.shape[0]), ys_dev].log().mean()
        print(f'**dev set loss = {loss.item()}')

*train set loss = 4.474884510040283
**dev set loss = 3.547820806503296
*train set loss = 3.5490829944610596
**dev set loss = 3.124401092529297
*train set loss = 3.131977081298828
**dev set loss = 2.92669415473938
*train set loss = 2.938084840774536
**dev set loss = 2.798306703567505
*train set loss = 2.8128182888031006
**dev set loss = 2.70963191986084
*train set loss = 2.7264153957366943
**dev set loss = 2.644716739654541
*train set loss = 2.662768840789795
**dev set loss = 2.595512866973877
*train set loss = 2.614443302154541
**dev set loss = 2.5573415756225586
*train set loss = 2.576620101928711
**dev set loss = 2.526750326156616
*train set loss = 2.546212673187256
**dev set loss = 2.501702308654785
*train set loss = 2.521108388900757
**dev set loss = 2.4806315898895264
*train set loss = 2.4999682903289795
**dev set loss = 2.462686538696289
*train set loss = 2.4818451404571533
**dev set loss = 2.4470856189727783
*train set loss = 2.4661011695861816
**dev set loss = 2.433446645736694

*train set loss = 2.2717292308807373
**dev set loss = 2.257068395614624
*train set loss = 2.271503210067749
**dev set loss = 2.2568418979644775
*train set loss = 2.2712812423706055
**dev set loss = 2.256619453430176
*train set loss = 2.2710630893707275
**dev set loss = 2.2564010620117188
*train set loss = 2.2708489894866943
**dev set loss = 2.2561862468719482
*train set loss = 2.270638942718506
**dev set loss = 2.2559752464294434
*train set loss = 2.270432233810425
**dev set loss = 2.255767345428467
*train set loss = 2.2702291011810303
**dev set loss = 2.255563259124756
*train set loss = 2.2700297832489014
**dev set loss = 2.2553629875183105
*train set loss = 2.269833564758301
**dev set loss = 2.2551653385162354
*train set loss = 2.269641160964966
**dev set loss = 2.2549710273742676
*train set loss = 2.269451856613159
**dev set loss = 2.2547802925109863
*train set loss = 2.2692654132843018
**dev set loss = 2.254592180252075
*train set loss = 2.269082546234131
**dev set loss = 2.2544074

# E05:
look up and use F.cross_entropy instead. You should achieve the same result. Can you think of why we'd prefer to use F.cross_entropy instead?

In [23]:
W = torch.randn(27*2, 27, requires_grad=True)

for k in range(200):
    xenc = F.one_hot(xs_train, num_classes=27).reshape(-1, 27*2).float()
    logits = xenc @ W
#     counts = logits.exp()
#     probs = counts / counts.sum(axis=1, keepdims=True)
    loss = F.cross_entropy(logits, ys_train)
    print(f'*train set loss = {loss.item()}')
    
    W.grad = None
    loss.backward()
    
    W.data += -50*W.grad
    
    
    ## eval on dev set
    xenc = F.one_hot(xs_dev, num_classes=27).reshape(-1, 27*2).float()
    logits = xenc @ W
#     counts = logits.exp()
#     probs = counts / counts.sum(axis=1, keepdims=True)
    loss = F.cross_entropy(logits, ys_dev)
    print(f'**dev set loss = {loss.item()}')

*train set loss = 4.2981672286987305
**dev set loss = 3.5289106369018555
*train set loss = 3.534932851791382
**dev set loss = 3.1508803367614746
*train set loss = 3.157095432281494
**dev set loss = 2.9201271533966064
*train set loss = 2.9275128841400146
**dev set loss = 2.7816343307495117
*train set loss = 2.789914131164551
**dev set loss = 2.690558433532715
*train set loss = 2.699413299560547
**dev set loss = 2.626573085784912
*train set loss = 2.6359446048736572
**dev set loss = 2.578982353210449
*train set loss = 2.5884926319122314
**dev set loss = 2.541457414627075
*train set loss = 2.551177978515625
**dev set loss = 2.511112689971924
*train set loss = 2.520751476287842
**dev set loss = 2.4857144355773926
*train set loss = 2.4954302310943604
**dev set loss = 2.4644439220428467
*train set loss = 2.4740352630615234
**dev set loss = 2.4461588859558105
*train set loss = 2.455779790878296
**dev set loss = 2.430565595626831
*train set loss = 2.440061092376709
**dev set loss = 2.416930437

*train set loss = 2.25945782661438
**dev set loss = 2.25349760055542
*train set loss = 2.2592458724975586
**dev set loss = 2.253302812576294
*train set loss = 2.259037494659424
**dev set loss = 2.2531116008758545
*train set loss = 2.2588324546813965
**dev set loss = 2.2529242038726807
*train set loss = 2.2586312294006348
**dev set loss = 2.252739667892456
*train set loss = 2.2584331035614014
**dev set loss = 2.252558708190918
*train set loss = 2.2582385540008545
**dev set loss = 2.252380609512329
*train set loss = 2.258047103881836
**dev set loss = 2.2522056102752686
*train set loss = 2.2578585147857666
**dev set loss = 2.2520334720611572
*train set loss = 2.2576732635498047
**dev set loss = 2.2518646717071533
*train set loss = 2.257491111755371
**dev set loss = 2.2516982555389404
*train set loss = 2.2573115825653076
**dev set loss = 2.2515347003936768
*train set loss = 2.2571351528167725
**dev set loss = 2.2513740062713623
*train set loss = 2.2569615840911865
**dev set loss = 2.251215

# E06: 
meta-exercise! Think of a fun/interesting exercise and complete it.


In [24]:
# combining input words through addition instead of concat


import torch.nn.functional as F

W = torch.randn(27, 27, requires_grad=True)

for k in range(200):
    xenc = F.one_hot(xs, num_classes=27).sum(dim=1).float()
    logits = xenc @ W
    counts = logits.exp()
    probs = counts / counts.sum(axis=1, keepdims=True)
    loss = -probs[torch.arange(xs.shape[0]), ys].log().mean()
    print(loss.item())
    
    W.grad = None
    loss.backward()
    
    W.data += -50*W.grad

4.392314910888672
3.2774765491485596
3.0092504024505615
2.904827833175659
3.014988422393799
2.71494197845459
2.715484619140625
2.7319507598876953
2.9249188899993896
2.607860803604126
2.591737985610962
2.610976219177246
2.754019021987915
2.6020848751068115
2.717416763305664
2.6011815071105957
2.751828908920288
2.5629937648773193
2.6494317054748535
2.6133389472961426
2.802103281021118
2.5276906490325928
2.5407650470733643
2.590041399002075
2.7804267406463623
2.521580457687378
2.5434582233428955
2.5929455757141113
2.793792963027954
2.5111072063446045
2.513646125793457
2.5554215908050537
2.7291479110717773
2.516855001449585
2.5774052143096924
2.5917978286743164
2.7979848384857178
2.498685359954834
2.4901933670043945
2.5184285640716553
2.657482385635376
2.5382745265960693
2.6728897094726562
2.523441791534424
2.6494579315185547
2.5343017578125
2.6761486530303955
2.5170063972473145
2.631889820098877
2.5414092540740967
2.7006607055664062
2.5007052421569824
2.5804762840270996
2.569719076156616


In [25]:
# sampling from the model

# first, sample the second letter as a bigram model (because you only have the special dot token as input)



for i in range(10):
    out = []
    
    ix1 = None
    ix2 = 0
    
    xenc = torch.concat([F.one_hot(torch.tensor([ix2]), num_classes=27)], axis=1).float()
    logits = xenc @ W
    counts = logits.exp()
    p = counts / counts.sum(axis=1, keepdims=True)
    
    ix1 = ix2
    ix2 = torch.multinomial(p, 1, replacement=True).item()

    out.append(itos[ix2])

# all letters from now on will depend on the previous two characters (trigram)
    
    while True:
        xenc = (F.one_hot(torch.tensor([ix1]), num_classes=27) + F.one_hot(torch.tensor([ix2]), num_classes=27)).float()
        logits = xenc @ W
        counts = logits.exp()
        p = counts / counts.sum(axis=1, keepdims=True)
        
        ix1 = ix2
        ix2 = torch.multinomial(p, 1, replacement=True).item()
        out.append(itos[ix2])
        if ix2 == 0:
            break

    print(''.join(out))



aeriallyeamuahiasnerenannadnynahandaemseriagnsyshekniaannydedylenrentnnedlaeraliady.
alwertihtnaabseisaadnyo.
ahajnedliyh.
elenlelnaiatynannakali.
ejaydevrisyenuaniandajarjicheyleivniijozdaahleadianr.
oh.
alnianlielieliy.
roy.
reenezyndionaomnazlizariantaaadyaevielnial.
criadartteoriighrweynarchaneemaelsahnax.
