# In [None]:  (Jupyter cell marker left over from the notebook export)
# Trigram Language Model Implementation

import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
from torch import nn
import random
import math

# Load the names dataset: one name per line
with open('names.txt', 'r') as f:
    words = f.read().splitlines()

# Quick look at what was loaded
print(f"Number of names: {len(words)}")
print(f"Some example names: {words[:10]}")

# E02: shuffle with a fixed seed, then cut 80% / 10% / 10% into train / dev / test
random.seed(42)
random.shuffle(words)  # in place: `words` stays shuffled from here on
n = len(words)
train_end = int(n*0.8)
dev_end = int(n*0.9)
train_words = words[:train_end]
dev_words = words[train_end:dev_end]
test_words = words[dev_end:]
print(f"Train set size: {len(train_words)}")
print(f"Dev set size: {len(dev_words)}")
print(f"Test set size: {len(test_words)}")

# Vocabulary: every character seen in any split gets an index starting at 1;
# index 0 is reserved for '.', which marks both the start and the end of a name
chars = sorted(set(''.join(words)))
stoi = {ch: idx for idx, ch in enumerate(chars, start=1)}
stoi['.'] = 0
itos = {idx: ch for ch, idx in stoi.items()}
vocab_size = len(stoi)
print(f"Vocabulary size: {vocab_size}")
print(f"Vocabulary: {itos}")

# -------------------------
# E01: Trigram model (counting-based approach)
# -------------------------

# Build a trigram model that takes two characters to predict the third
def build_trigram_model(words):
    """Count character trigrams over ``words``.

    Each word is wrapped in a single '.' on both sides, so a word of length n
    contributes n trigrams. The first character of a word therefore only ever
    appears inside a context, never as the predicted (third) character.

    Args:
        words: Iterable of strings.

    Returns:
        A dict mapping each two-character context ``(ch1, ch2)`` to a dict
        ``{ch3: count}`` of the characters observed right after that context.
    """
    counts = {}
    for word in words:
        seq = ['.', *word, '.']
        for first, second, third in zip(seq, seq[1:], seq[2:]):
            # One inner dict per context, created the first time it is seen
            followers = counts.setdefault((first, second), {})
            followers[third] = followers.get(third, 0) + 1
    return counts

def calculate_trigram_loss(trigram_counts, words, smoothing=1.0):
    """Return the mean negative log-likelihood of ``words`` under a trigram count model.

    Probabilities use additive smoothing:
    ``p(ch3 | ch1, ch2) = (count + smoothing) / (context_total + smoothing * vocab_size)``.
    A context that is absent from ``trigram_counts`` falls back to the uniform
    distribution ``1 / vocab_size``. Reads the module-level ``vocab_size``.

    Args:
        trigram_counts: Mapping ``(ch1, ch2) -> {ch3: count}`` as produced by
            ``build_trigram_model``.
        words: Iterable of strings to score. Each word is wrapped in a single
            '.' on both sides, matching ``build_trigram_model``.
        smoothing: Pseudo-count added to every possible next character.

    Returns:
        The average of ``-log(p)`` (natural log) over all scored trigrams.
        ``nan`` when ``words`` yields no trigram at all (the mean is undefined),
        and ``inf`` when some trigram has probability zero, which can only
        happen with ``smoothing=0``.
    """
    uniform_p = 1.0 / vocab_size
    # Summed follower counts per context, computed once instead of once per
    # scored trigram.
    context_totals = {}
    total_nll = 0.0
    total_chars = 0

    for w in words:
        chs = ['.'] + list(w) + ['.']
        for ch1, ch2, ch3 in zip(chs, chs[1:], chs[2:]):
            context = (ch1, ch2)
            followers = trigram_counts.get(context)
            if followers is None:
                # Never-seen context: nothing to estimate from
                p = uniform_p
            else:
                if context not in context_totals:
                    context_totals[context] = sum(followers.values())
                denominator = context_totals[context] + smoothing * vocab_size
                if denominator == 0:
                    # Empty count table with no smoothing carries no information
                    p = uniform_p
                else:
                    p = (followers.get(ch3, 0) + smoothing) / denominator

            # log(0) would raise; an impossible trigram has infinite loss
            total_nll += math.inf if p == 0 else -math.log(p)
            total_chars += 1

    if total_chars == 0:
        return math.nan
    return total_nll / total_chars

# Build the trigram count table from the training split only; the dev and
# test splits are scored against these counts further down.
trigram_model = build_trigram_model(train_words)

# Also build a bigram model for comparison
def build_bigram_model(words):
    """Count character bigrams over ``words``.

    Each word is wrapped in a single '.' on both sides, so a word of length n
    contributes n + 1 bigrams (including the start and end transitions).

    Args:
        words: Iterable of strings.

    Returns:
        A dict mapping each character ``ch1`` to a dict ``{ch2: count}`` of the
        characters observed immediately after it.
    """
    counts = {}
    for word in words:
        seq = ['.', *word, '.']
        for prev, nxt in zip(seq, seq[1:]):
            # One inner dict per preceding character, created on first sight
            followers = counts.setdefault(prev, {})
            followers[nxt] = followers.get(nxt, 0) + 1
    return counts

def calculate_bigram_loss(bigram_counts, words, smoothing=1.0):
    """Return the mean negative log-likelihood of ``words`` under a bigram count model.

    Probabilities use additive smoothing:
    ``p(ch2 | ch1) = (count + smoothing) / (context_total + smoothing * vocab_size)``.
    A preceding character that is absent from ``bigram_counts`` falls back to
    the uniform distribution ``1 / vocab_size``. Reads the module-level
    ``vocab_size``.

    Args:
        bigram_counts: Mapping ``ch1 -> {ch2: count}`` as produced by
            ``build_bigram_model``.
        words: Iterable of strings to score. Each word is wrapped in a single
            '.' on both sides, matching ``build_bigram_model``.
        smoothing: Pseudo-count added to every possible next character.

    Returns:
        The average of ``-log(p)`` (natural log) over all scored bigrams.
        ``nan`` when ``words`` is empty (the mean is undefined), and ``inf``
        when some bigram has probability zero, which can only happen with
        ``smoothing=0``.
    """
    uniform_p = 1.0 / vocab_size
    # Summed follower counts per preceding character, computed once instead of
    # once per scored bigram.
    context_totals = {}
    total_nll = 0.0
    total_chars = 0

    for w in words:
        chs = ['.'] + list(w) + ['.']
        for ch1, ch2 in zip(chs, chs[1:]):
            followers = bigram_counts.get(ch1)
            if followers is None:
                # Never-seen preceding character: nothing to estimate from
                p = uniform_p
            else:
                if ch1 not in context_totals:
                    context_totals[ch1] = sum(followers.values())
                denominator = context_totals[ch1] + smoothing * vocab_size
                if denominator == 0:
                    # Empty count table with no smoothing carries no information
                    p = uniform_p
                else:
                    p = (followers.get(ch2, 0) + smoothing) / denominator

            # log(0) would raise; an impossible bigram has infinite loss
            total_nll += math.inf if p == 0 else -math.log(p)
            total_chars += 1

    if total_chars == 0:
        return math.nan
    return total_nll / total_chars

# Build the bigram count table from the training split only
bigram_model = build_bigram_model(train_words)

# Report mean negative log-likelihood on train, dev, and test with the default
# smoothing of 1.0 (no smoothing argument is passed here).
# NOTE(review): the bigram loss averages over len(word) + 1 targets per word,
# the trigram loss over len(word) targets (with a single leading '.', the first
# character is never a trigram target), so the two columns are not averaged
# over identical prediction targets.
print("Counting-based models evaluation:")
print("--------------------------------")
print(f"Bigram train loss: {calculate_bigram_loss(bigram_model, train_words):.4f}")
print(f"Trigram train loss: {calculate_trigram_loss(trigram_model, train_words):.4f}")
print(f"Bigram dev loss: {calculate_bigram_loss(bigram_model, dev_words):.4f}")
print(f"Trigram dev loss: {calculate_trigram_loss(trigram_model, dev_words):.4f}")
print(f"Bigram test loss: {calculate_bigram_loss(bigram_model, test_words):.4f}")
print(f"Trigram test loss: {calculate_trigram_loss(trigram_model, test_words):.4f}")

# -------------------------
# E03: Tune smoothing strength for trigram model
# -------------------------

# Candidate additive-smoothing strengths to compare
smoothing_values = [0.1, 0.5, 1.0, 2.0, 5.0, 10.0]
# Per-candidate losses, in the same order as smoothing_values
train_losses = []
dev_losses = []

print("\nTuning smoothing parameter for trigram model:")
print("--------------------------------------------")
# The count table is fixed; only the smoothing used at scoring time changes
for smoothing in smoothing_values:
    train_loss = calculate_trigram_loss(trigram_model, train_words, smoothing)
    dev_loss = calculate_trigram_loss(trigram_model, dev_words, smoothing)
    train_losses.append(train_loss)
    dev_losses.append(dev_loss)
    print(f"Smoothing: {smoothing}, Train Loss: {train_loss:.4f}, Dev Loss: {dev_loss:.4f}")

# Pick the candidate with the lowest dev loss (the test split is not consulted
# for the selection)
best_smoothing_idx = np.argmin(dev_losses)
best_smoothing = smoothing_values[best_smoothing_idx]
print(f"\nBest smoothing value: {best_smoothing}")

# Score the test split once, using the smoothing selected on dev
test_loss = calculate_trigram_loss(trigram_model, test_words, best_smoothing)
print(f"Test loss with best smoothing: {test_loss:.4f}")

# Plot train and dev loss against the smoothing value
plt.figure(figsize=(10, 6))
plt.plot(smoothing_values, train_losses, 'b-o', label='Train Loss')
plt.plot(smoothing_values, dev_losses, 'r-o', label='Dev Loss')
# Vertical marker at the smoothing value that minimised dev loss
plt.axvline(x=best_smoothing, color='g', linestyle='--', label=f'Best Smoothing: {best_smoothing}')
# Log-scaled x axis: the candidates span 0.1 to 10.0
plt.xscale('log')
plt.xlabel('Smoothing Value')
plt.ylabel('Negative Log Likelihood Loss')
plt.title('Impact of Smoothing on Trigram Model Performance')
plt.legend()
plt.grid(True)
plt.show()

# -------------------------
# Neural Network-based Trigram Model Implementation
# -------------------------

# Function to build dataset for training
def build_dataset(words, context_size=2):
    """Turn words into (context, next-character) index tensors.

    Every word is followed by the end token '.', and its context window starts
    out filled with index 0 (the '.' token), so a word of length n yields
    n + 1 examples. Characters are mapped through the module-level ``stoi``.

    Args:
        words: Iterable of strings whose characters are all keys of ``stoi``.
        context_size: Number of preceding character indices per example.

    Returns:
        A tuple ``(X, Y)`` of int64 tensors: ``X`` has shape
        ``(num_examples, context_size)`` and ``Y`` has shape
        ``(num_examples,)``. Both keep that dtype and shape when ``words`` is
        empty.

    Raises:
        KeyError: If a character is not present in ``stoi``.
    """
    X, Y = [], []
    for w in words:
        context = [0] * context_size  # window of start tokens
        for ch in w + '.':
            ix = stoi[ch]
            # `context` is rebound to a fresh list below, so the stored
            # reference is never mutated afterwards.
            X.append(context)
            Y.append(ix)
            # Slide the window: drop the oldest index, append the newest
            context = context[1:] + [ix]

    # Explicit dtype and shape: without them an empty input would produce
    # float32 tensors of shape (0,), which cannot be used as embedding indices.
    inputs = torch.tensor(X, dtype=torch.long).reshape(len(X), context_size)
    targets = torch.tensor(Y, dtype=torch.long)
    return inputs, targets

# Create (context, target) tensors for each split; a context of two
# characters predicting the next one corresponds to a trigram model
X_train, Y_train = build_dataset(train_words, context_size=2)  # For trigram model
X_dev, Y_dev = build_dataset(dev_words, context_size=2)
X_test, Y_test = build_dataset(test_words, context_size=2)

# We'll now implement our neural network-based trigram model

# E04: Remove F.one_hot by directly indexing into rows of W
class TrigramModelDirect(nn.Module):
    """Two-character-context model: per-position embeddings feeding one linear layer.

    Each of the two context positions has its own embedding table; the two
    looked-up vectors are concatenated and projected to ``vocab_size`` logits.
    Rows are fetched by index, so no one-hot encoding is materialised.
    """

    def __init__(self, vocab_size, embedding_dim=10):
        """Create the two embedding tables and the output projection.

        Args:
            vocab_size: Number of distinct character indices (also the number
                of output logits).
            embedding_dim: Width of each per-position embedding vector.
        """
        super().__init__()
        self.vocab_size = vocab_size
        # One table per context position (older character first)
        self.embeddings0 = nn.Embedding(vocab_size, embedding_dim)
        self.embeddings1 = nn.Embedding(vocab_size, embedding_dim)
        # Maps the concatenated pair of embeddings to next-character logits
        self.linear = nn.Linear(2 * embedding_dim, vocab_size)

    def forward(self, x):
        """Return next-character logits for a batch of contexts.

        Args:
            x: Integer tensor whose columns 0 and 1 hold the two context
                character indices, one row per example.

        Returns:
            Tensor of shape ``(batch_size, vocab_size)`` with unnormalised
            scores.
        """
        first_idx, second_idx = x[:, 0], x[:, 1]
        features = torch.cat(
            (self.embeddings0(first_idx), self.embeddings1(second_idx)),
            dim=1,
        )  # (batch_size, 2 * embedding_dim)
        return self.linear(features)

# E05: Using F.cross_entropy
def train_model(model, X, Y, lr=0.1, epochs=100):
    """Fit ``model`` with full-batch SGD on a cross-entropy objective.

    Every epoch runs one forward pass over all of ``X``, one backward pass and
    one optimizer step. The loss is printed on every tenth epoch (starting
    with epoch 0).

    Args:
        model: Module whose call ``model(X)`` returns class logits.
        X: Inputs passed to the model as a single batch.
        Y: Target class indices for ``F.cross_entropy``.
        lr: SGD learning rate.
        epochs: Number of optimizer steps to take.

    Returns:
        List with one float per epoch: the loss measured before that epoch's
        parameter update.
    """
    sgd = torch.optim.SGD(model.parameters(), lr=lr)
    history = []

    for epoch in range(epochs):
        # Clear gradients left over from the previous step
        sgd.zero_grad()

        # Forward pass and loss on the whole dataset
        loss = F.cross_entropy(model(X), Y)

        # Backward pass and parameter update
        loss.backward()
        sgd.step()

        history.append(loss.item())
        if epoch % 10 == 0:
            print(f"Epoch {epoch}, Loss: {loss.item():.4f}")

    return history

# Function to evaluate model
def evaluate(model, X, Y):
    """Return the cross-entropy loss of ``model`` on ``(X, Y)`` as a Python float.

    The forward pass runs under ``torch.no_grad()``, so no gradients are
    tracked or accumulated.

    Args:
        model: Callable returning class logits for ``X``.
        X: Inputs passed to the model as a single batch.
        Y: Target class indices.
    """
    with torch.no_grad():
        return F.cross_entropy(model(X), Y).item()

# Train the neural trigram model on the training split with train_model's
# default learning rate and epoch count
print("\nNeural Network-based Trigram Model:")
print("--------------------------------")
trigram_nn = TrigramModelDirect(vocab_size)
# NOTE: this rebinds `train_losses`, which earlier held the smoothing-sweep
# losses; from here on it is the per-epoch training loss history.
train_losses = train_model(trigram_nn, X_train, Y_train)

# Evaluate the trained model on train, dev, and test sets.
# NOTE: `train_loss`, `dev_loss` and `test_loss` are also rebound here; they
# now refer to the neural trigram model and are read again in the comparison
# printed further down.
train_loss = evaluate(trigram_nn, X_train, Y_train)
dev_loss = evaluate(trigram_nn, X_dev, Y_dev)
test_loss = evaluate(trigram_nn, X_test, Y_test)

print(f"Final Train Loss: {train_loss:.4f}")
print(f"Dev Loss: {dev_loss:.4f}")
print(f"Test Loss: {test_loss:.4f}")

# Datasets for the bigram neural model used as a baseline: a context of a
# single preceding character
X_train_bigram, Y_train_bigram = build_dataset(train_words, context_size=1)
X_dev_bigram, Y_dev_bigram = build_dataset(dev_words, context_size=1)
X_test_bigram, Y_test_bigram = build_dataset(test_words, context_size=1)

class BigramModelDirect(nn.Module):
    """Single-character-context model: one embedding lookup feeding one linear layer."""

    def __init__(self, vocab_size, embedding_dim=10):
        """Create the embedding table and the output projection.

        Args:
            vocab_size: Number of distinct character indices (also the number
                of output logits).
            embedding_dim: Width of the embedding vector.
        """
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, embedding_dim)
        self.linear = nn.Linear(embedding_dim, vocab_size)

    def forward(self, x):
        """Return next-character logits for a batch of one-character contexts.

        Args:
            x: Integer tensor whose column 0 holds the preceding character
                index, one row per example.

        Returns:
            Tensor of shape ``(batch_size, vocab_size)`` with unnormalised
            scores.
        """
        prev_idx = x[:, 0]
        return self.linear(self.embedding(prev_idx))

# Train the neural bigram baseline with the same default hyperparameters as
# the trigram model
print("\nNeural Network-based Bigram Model:")
print("--------------------------------")
bigram_nn = BigramModelDirect(vocab_size)
train_losses_bigram = train_model(bigram_nn, X_train_bigram, Y_train_bigram)

# Evaluate bigram model on all three splits
bigram_train_loss = evaluate(bigram_nn, X_train_bigram, Y_train_bigram)
bigram_dev_loss = evaluate(bigram_nn, X_dev_bigram, Y_dev_bigram)
bigram_test_loss = evaluate(bigram_nn, X_test_bigram, Y_test_bigram)

print(f"Bigram Final Train Loss: {bigram_train_loss:.4f}")
print(f"Bigram Dev Loss: {bigram_dev_loss:.4f}")
print(f"Bigram Test Loss: {bigram_test_loss:.4f}")

# Compare bigram and trigram neural models side by side; train_loss, dev_loss
# and test_loss here are the neural trigram values computed above
print("\nComparison of Neural Network Bigram vs Trigram models:")
print("---------------------------------------------------")
print(f"Bigram Train Loss: {bigram_train_loss:.4f}, Trigram Train Loss: {train_loss:.4f}")
print(f"Bigram Dev Loss: {bigram_dev_loss:.4f}, Trigram Dev Loss: {dev_loss:.4f}")
print(f"Bigram Test Loss: {bigram_test_loss:.4f}, Trigram Test Loss: {test_loss:.4f}")

# -------------------------
# E06: Generate names using our trigram model
# -------------------------

def generate_name(model, max_length=20):
    """Sample one name, character by character, from a two-character-context model.

    Generation starts from the context ``[0, 0]`` (two '.' tokens) and stops
    when index 0 (the end token) is sampled or when ``max_length`` characters
    have been produced. Indices are mapped to characters through the
    module-level ``itos``.

    Args:
        model: Callable taking an integer tensor of shape ``(1, 2)`` and
            returning logits of shape ``(1, vocab_size)``.
        max_length: Upper bound on the length of the returned name. With
            ``max_length <= 0`` the empty string is returned without sampling.

    Returns:
        The sampled name, without start/end tokens.
    """
    # Start with two '.' tokens
    context = torch.tensor([[0, 0]])  # Shape: (1, 2)
    name = ''

    # Checking the cap before sampling keeps the result within max_length even
    # for max_length == 0 (checking only after appending would emit one char).
    while len(name) < max_length:
        with torch.no_grad():
            logits = model(context)
            probs = F.softmax(logits, dim=1)

        # Sample the next character index from the predicted distribution
        ix = torch.multinomial(probs, num_samples=1).item()

        # Index 0 is the end token
        if ix == 0:
            break

        name += itos[ix]

        # Slide the context window: drop the oldest index, append the new one
        context = torch.cat([context[:, 1:], torch.tensor([[ix]])], dim=1)

    return name

# Generate 10 names with the trained neural trigram model; each call samples
# independently, starting again from the two-'.' context
print("\nGenerating names with our trigram model:")
print("--------------------------------------")
for i in range(10):
    print(generate_name(trigram_nn))

# Let's also try sampling with different "temperatures"
def generate_name_with_temperature(model, temperature=1.0, max_length=20):
    """Sample one name from a two-character-context model with temperature scaling.

    Logits are divided by ``temperature`` before the softmax: values below 1
    sharpen the distribution, values above 1 flatten it. Generation starts
    from the context ``[0, 0]`` (two '.' tokens) and stops when index 0 (the
    end token) is sampled or when ``max_length`` characters have been
    produced. Indices are mapped to characters through the module-level
    ``itos``.

    Args:
        model: Callable taking an integer tensor of shape ``(1, 2)`` and
            returning logits of shape ``(1, vocab_size)``.
        temperature: Positive scaling factor for the logits.
        max_length: Upper bound on the length of the returned name. With
            ``max_length <= 0`` the empty string is returned without sampling.

    Returns:
        The sampled name, without start/end tokens.

    Raises:
        ValueError: If ``temperature`` is not positive. Zero would divide the
            logits by zero, and a negative value would silently invert the
            distribution.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")

    context = torch.tensor([[0, 0]])
    name = ''

    # Checking the cap before sampling keeps the result within max_length even
    # for max_length == 0 (checking only after appending would emit one char).
    while len(name) < max_length:
        with torch.no_grad():
            logits = model(context)
            # Apply temperature scaling to logits
            probs = F.softmax(logits / temperature, dim=1)

        ix = torch.multinomial(probs, num_samples=1).item()

        # Index 0 is the end token
        if ix == 0:
            break

        name += itos[ix]
        # Slide the context window: drop the oldest index, append the new one
        context = torch.cat([context[:, 1:], torch.tensor([[ix]])], dim=1)

    return name

# Sample a few names at two temperatures to show the effect of logit scaling
print("\nGenerating names with different temperatures:")
print("-------------------------------------------")
# temperature < 1 divides the logits by a number below one, sharpening the
# sampling distribution
print("Low temperature (more conservative):")
for i in range(5):
    print(generate_name_with_temperature(trigram_nn, temperature=0.5))

# temperature > 1 flattens the sampling distribution
print("\nHigh temperature (more creative):")
for i in range(5):
    print(generate_name_with_temperature(trigram_nn, temperature=2.0))

# Summary of what we've learned and conclusions.
# These are fixed strings; they are not derived from the losses computed above.
print("\nSummary and Conclusions:")
print("----------------------")
print("1. Trigram models generally perform better than bigram models because they capture more context.")
print("2. Proper smoothing is crucial for generalization to unseen data.")
print("3. Neural network-based models can learn more complex patterns than simple counting-based models.")
print("4. Using embeddings and direct indexing is more efficient than one-hot encoding.")
print("5. F.cross_entropy is a more numerically stable and efficient way to compute loss.")
print("6. Temperature in sampling allows control over the creativity/randomness of generated names.")