Exercises

- Train a trigram language model, i.e. take two characters as an input to predict the 3rd one. Feel free to use either counting or a neural net. Evaluate the loss; Did it improve over a bigram model?
- Split up the dataset randomly into 80% train set, 10% dev set, 10% test set. Train the bigram and trigram models only on the training set. Evaluate them on dev and test splits. What can you see?
- Use the dev set to tune the strength of smoothing (or regularization) for the trigram model - i.e. try many possibilities and see which one works best based on the dev set loss. What patterns can you see in the train and dev set loss as you tune this strength? Take the best setting of the smoothing and evaluate on the test set once and at the end. How good of a loss do you achieve?



In [97]:
import torch
import torch.nn.functional as F
from torch.utils.data import random_split
import matplotlib.pyplot as plt

trigram language model

In [98]:
words = open('names.txt', 'r').read().splitlines()
generator = torch.Generator().manual_seed(42)
splits = random_split(words, [0.8, 0.1, 0.1], generator=generator)
train_data, val_data, test_data = splits

In [99]:
# Create validation data
val_xs, val_ys = [], []
for w in val_data:
    chs = ['.','.'] + list(w) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]
        combined_ix = ix1 * 27 + ix2
        val_xs.append(combined_ix)
        val_ys.append(stoi[ch3])

val_xs = torch.tensor(val_xs)
val_ys = torch.tensor(val_ys)

In [100]:
# Create test data (same way we created validation data)
test_xs, test_ys = [], []
for w in test_data:   
    chs = ['.','.'] + list(w) + ['.']
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]
        combined_ix = ix1 * 27 + ix2
        test_xs.append(combined_ix)
        test_ys.append(stoi[ch3])

test_xs = torch.tensor(test_xs)
test_ys = torch.tensor(test_ys)


In [101]:
chars = sorted(list(set(''.join(train_data))))
stoi = {s:i+1 for i,s in enumerate(chars)}
stoi['.'] = 0
print(stoi)
itos = {i:s for s,i in stoi.items()}
print(itos)


{'a': 1, 'b': 2, 'c': 3, 'd': 4, 'e': 5, 'f': 6, 'g': 7, 'h': 8, 'i': 9, 'j': 10, 'k': 11, 'l': 12, 'm': 13, 'n': 14, 'o': 15, 'p': 16, 'q': 17, 'r': 18, 's': 19, 't': 20, 'u': 21, 'v': 22, 'w': 23, 'x': 24, 'y': 25, 'z': 26, '.': 0}
{1: 'a', 2: 'b', 3: 'c', 4: 'd', 5: 'e', 6: 'f', 7: 'g', 8: 'h', 9: 'i', 10: 'j', 11: 'k', 12: 'l', 13: 'm', 14: 'n', 15: 'o', 16: 'p', 17: 'q', 18: 'r', 19: 's', 20: 't', 21: 'u', 22: 'v', 23: 'w', 24: 'x', 25: 'y', 26: 'z', 0: '.'}


In [102]:
# create the dataset
xs, ys = [], []
for w in train_data:
    chs = ['.', '.'] + list(w) + ['.'] # add the double padding at start 
    for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
        ix1 = stoi[ch1]
        ix2 = stoi[ch2]
        ix3 = stoi[ch3]

        combined_ix = ix1 * 27 + ix2
        
        # encode 2 letters into one unique number with no collisions
        #        Second Letter →
        #       .  a  b  c ...
        #F  .   0  1  2  3 ...
        #i  a  27 28 29 30 ...
        #r  b  54 55 56 57 ...
        #s  c  81 82 83 84 ...
        #t  .   .  .  .  . ...
        #↓
        
        xs.append(combined_ix)
        ys.append(ix3)
        
xs = torch.tensor(xs)
ys = torch.tensor(ys)
num = xs.nelement()
print('number of examples: ', num)
print('xs shape: ', xs.shape) 
print('ys shape: ', ys.shape)

number of examples:  182827
xs shape:  torch.Size([182827])
ys shape:  torch.Size([182827])


In [None]:
def train_model(reg_strength, num_epochs=800):
    g = torch.Generator().manual_seed(1478)
    W = torch.randn((27*27, 27), generator=g, requires_grad=True)
    
    # Lists to store losses
    train_losses = []
    val_losses = []
    
    # gradient descent
    for k in range(num_epochs):
        # forward pass
        xenc = F.one_hot(xs, num_classes=27*27).float() # input to the network: one-hot encoding
        logits = xenc @ W # predict log-counts
        counts = logits.exp() # counts, equivalent to N
        probs = counts / counts.sum(1, keepdims=True) # probabilities for next character
        loss = -probs[torch.arange(num), ys].log().mean() + 0.01*(W**2).mean()
        print(loss.item())
    
        # Store training loss
        train_losses.append(loss.item())
        
        # Every 10 epochs, calculate validation loss
        if k % 10 == 0:
            # Calculate validation loss
            with torch.no_grad():
                val_xenc = F.one_hot(val_xs, num_classes=27*27).float()
                val_logits = val_xenc @ W
                val_counts = val_logits.exp()
                val_probs = val_counts / val_counts.sum(1, keepdims=True)
                val_loss = -val_probs[torch.arange(len(val_ys)), val_ys].log().mean() + 0.01*(W**2).mean()
                print(f'step {k}: train loss {loss.item():.4f}, val loss {val_loss.item():.4f}')
                val_losses.append(val_loss.item())
                
        # backward pass
        W.grad = None # set to zero the gradient
        loss.backward()
        
        # update
        W.data += -50 * W.grad
        
    return train_losses[-1], val_losses[-1], W

reg_strengths = [0.0, 0.001, 0.1, 0.5, 5.0]
results = []

for reg in reg_strengths:
    train_loss, val_loss, W = train_model(reg)
    results.append({
        'reg_strength': reg,
        'train_loss': train_loss,
        'val_loss': val_loss
    })
    print(f'Reg strength: {reg:.3f}, Train loss: {train_loss:.4f}, Val loss: {val_loss:.4f}')

# Find best regularization strength based on validation loss
best_reg = min(results, key=lambda x: x['val_loss'])['reg_strength']
print(f'\nBest regularization strength: {best_reg}')

# Train final model with best regularization 
final_train_loss, final_val_loss, best_W = train_model(best_reg)

# Evaluate on test set
with torch.no_grad():
    test_xenc = F.one_hot(test_xs, num_classes=27*27).float()
    test_logits = test_xenc @ best_W
    test_counts = test_logits.exp()
    test_probs = test_counts / test_counts.sum(1, keepdims=True)
    test_loss = -test_probs[torch.arange(len(test_ys)), test_ys].log().mean() + best_reg*(best_W**2).mean()
    
    # Calculate accuracy
    predictions = test_probs.argmax(1)
    accuracy = (predictions == test_ys).float().mean()
    
print(f'\nFinal Results with best reg strength {best_reg}:')
print(f'Test loss: {test_loss:.4f}')
print(f'Test accuracy: {accuracy:.4f}')

3.748319149017334
step 0: train loss 3.7483, val loss 3.7543
3.6344263553619385
3.550145387649536
3.484727382659912
3.429079294204712
3.3790476322174072
3.3334155082702637
3.2915987968444824
3.253203868865967
3.2179064750671387
3.185405731201172
step 10: train loss 3.1854, val loss 3.1927
3.155413866043091
3.1276533603668213
3.1018667221069336
3.0778205394744873
3.0553109645843506
3.034165382385254
3.014239549636841
2.9954140186309814
2.9775891304016113
2.960679054260254
step 20: train loss 2.9607, val loss 2.9687
2.944612503051758
2.929326057434082
2.9147629737854004
2.9008727073669434
2.8876099586486816
2.8749325275421143
2.862802743911743
2.851184844970703
2.8400468826293945
2.8293585777282715
step 30: train loss 2.8294, val loss 2.8381
2.819092273712158
2.8092219829559326
2.7997238636016846
2.7905750274658203
2.781755208969116
2.7732458114624023
2.7650275230407715
2.757084846496582
2.7494020462036133
2.741964340209961
step 40: train loss 2.7420, val loss 2.7521
2.7347593307495117
2

In [96]:
# finally, sample from the 'neural net' model
num_chars=100
g = torch.Generator().manual_seed(2147483647)
out = ['.', '.']  # start with two dots
for _ in range(num_chars):
    # Get the last two characters
    ch1, ch2 = out[-2], out[-1]
    ix1, ix2 = stoi[ch1], stoi[ch2]
    
    # Create combined index
    combined_ix = ix1 * 27 + ix2
    
    # Convert to one-hot encoding
    xenc = F.one_hot(torch.tensor([combined_ix]), num_classes=27*27).float()
    
    # Get probabilities
    logits = xenc @ best_W
    counts = logits.exp()
    probs = counts / counts.sum(1, keepdims=True)
    
    # Sample from the distribution
    ix = torch.multinomial(probs[0], num_samples=1, replacement=True, generator=g).item()
    out.append(itos[ix])
print(''.join(out[2:]))
 

junide.tanasid.prelay.adin.fai.ritonian.free.udiania.zabileniassdbduinewimbressiyanayla.perinviumtsy
