# Assignment 4

## Importing Libraries

In [None]:
import os
import matplotlib.pyplot as plt
from dataclasses import dataclass
from sklearn.model_selection import train_test_split
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import Dataset, DataLoader
from src.utils import load_text, set_seed, configure_device
%matplotlib inline

## Configuration

In [None]:
@dataclass
class MLPConfig:
    """Hyper-parameters and paths shared by the cells of this notebook.

    The class is read through its class attributes (e.g. ``MLPConfig.seed``).
    ``device`` and ``vocab_size`` hold placeholder defaults that later cells
    overwrite.
    """

    # Paths: root_dir keeps its trailing slash because dataset_path is
    # appended to it with plain string concatenation.
    root_dir: str = f"{os.getcwd()}/../../"
    dataset_path: str = "data/names.txt"
    # CPU placeholder; the device cell replaces it with configure_device().
    device: torch.device = torch.device("cpu")

    # Tokenizer
    vocab_size: int = 0  # filled in once the character vocabulary is built

    # Model
    context_size: int = 3  # number of preceding characters per prediction
    d_embed: int = 16  # embedding width per character
    d_hidden: int = 64  # hidden-layer width

    # Training
    val_size: float = 0.1  # fraction of names held out for validation
    val_interval: int = 1_000  # steps between validation passes
    batch_size: int = 32
    lr: float = 2e-3
    max_steps: int = 10_000

    # Reproducibility
    seed: int = 101

## Reproducibility

In [None]:
# Seed the random number generators through the project helper so runs are reproducible.
# NOTE(review): set_seed lives in src.utils; presumably it seeds Python/NumPy/torch — confirm there.
set_seed(MLPConfig.seed)

## Device

In [None]:
# Replace the CPU placeholder on the shared config with the device chosen by src.utils.
# NOTE(review): presumably configure_device picks an accelerator when available, else CPU — confirm.
# The model and every batch are later moved to MLPConfig.device.
MLPConfig.device = configure_device()

## Tokenizer

In [None]:
# Vocabulary: the "." special token at index 0 (used as padding and as the
# end-of-name marker), followed by the lowercase letters a-z.
chars = ["."] + [chr(code) for code in range(ord("a"), ord("z") + 1)]
MLPConfig.vocab_size = len(chars)
str2idx = {ch: i for i, ch in enumerate(chars)}  # character -> token index
idx2str = dict(enumerate(chars))  # token index -> character

## Dataset

In [None]:
# Read the names file (repo root + dataset_path) and split it into one name per line.
# NOTE(review): assumes load_text returns the whole file as a str, and that every name uses only
# characters present in str2idx (lowercase a-z) — otherwise dataset building raises KeyError.
names = load_text(MLPConfig.root_dir + MLPConfig.dataset_path).splitlines()

## Preprocessing

In [None]:
# Train-Val Split: hold out a MLPConfig.val_size fraction of the *names* for validation, so no
# name contributes examples to both splits. random_state makes the split identical across runs.
train_names, val_names = train_test_split(names, test_size=MLPConfig.val_size, random_state=MLPConfig.seed)

In [None]:
def prepare_dataset(_names, context_size=None, stoi=None):
    """Build (context, next-character) training pairs from a list of names.

    Each name is terminated with the "." token. Every character of the
    terminated name becomes one target whose input is the ``context_size``
    preceding token indices, left-padded with index 0.

    Args:
        _names: Iterable of name strings.
        context_size: Tokens of history per example. Defaults to
            ``MLPConfig.context_size``.
        stoi: Character -> token index mapping. Defaults to the module-level
            ``str2idx``.

    Returns:
        ``(inputs, targets)``: a LongTensor of shape
        ``(num_examples, context_size)`` and a LongTensor of shape
        ``(num_examples,)``.

    Raises:
        KeyError: If a name contains a character missing from ``stoi``.
    """
    if context_size is None:
        context_size = MLPConfig.context_size
    if stoi is None:
        stoi = str2idx

    _inputs, _targets = [], []

    for name in _names:
        _context = [0] * context_size

        for char in name + ".":
            idx = stoi[char]
            _inputs.append(_context)
            _targets.append(idx)
            _context = _context[1:] + [idx]  # Shift the context by 1 character (builds a new list)

    # Explicit dtype and shape so an empty name list still yields a (0, context_size)
    # LongTensor rather than a 1-D float tensor.
    _inputs = torch.tensor(_inputs, dtype=torch.long).reshape(len(_targets), context_size)
    _targets = torch.tensor(_targets, dtype=torch.long)

    return _inputs, _targets

### Task 1: PyTorch DataLoader

We have been building plain Python lists and then converting them to PyTorch tensors. This is inefficient because the entire dataset is materialized in memory at once.

PyTorch provides the `Dataset` and `DataLoader` classes to load data on the fly, one batch at a time. [PyTorch Documentation](https://pytorch.org/tutorials/beginner/basics/data_tutorial.html)

Refactor the `prepare_dataset` function into a PyTorch `Dataset` class and use the `DataLoader` to efficiently load the data in batches.

In [None]:
# Dataset
class NamesDataset(Dataset):
    """Character-level next-token dataset built from a list of names.

    Every name is terminated with the "." token, and each character of the
    terminated name is one example. Its input is the ``context_size``
    preceding token indices, left-padded with the "." index. Its target is
    the character's own index. These are the same pairs ``prepare_dataset``
    builds, but contexts are sliced on demand in ``__getitem__``; only the
    per-name token lists are stored.

    Args:
        _names: Sequence of name strings.
        context_size: Number of preceding tokens fed to the model.
        stoi: Character -> index mapping. Defaults to the module-level
            ``str2idx``.
    """

    def __init__(self, _names, context_size, stoi=None):
        self.names = _names
        self.context_size = context_size
        self.stoi = str2idx if stoi is None else stoi
        self._pad_idx = self.stoi["."]
        # Encode once up front; an unknown character raises KeyError here
        # instead of partway through training.
        self._encoded = [[self.stoi[ch] for ch in name + "."] for name in _names]
        # One (name position, character position) entry per example.
        self._index = [
            (name_pos, char_pos)
            for name_pos, tokens in enumerate(self._encoded)
            for char_pos in range(len(tokens))
        ]

    def __len__(self):
        """Return the number of (context, target) examples, not the number of names."""
        return len(self._index)

    def __getitem__(self, idx):
        """Return ``(context, target)`` for example ``idx``.

        ``context`` is a LongTensor of shape ``(context_size,)`` and
        ``target`` is a scalar LongTensor, so the default DataLoader collate
        stacks them into ``(batch, context_size)`` and ``(batch,)``.
        """
        name_pos, char_pos = self._index[idx]
        tokens = self._encoded[name_pos]
        context = tokens[max(0, char_pos - self.context_size):char_pos]
        context = [self._pad_idx] * (self.context_size - len(context)) + context
        return (
            torch.tensor(context, dtype=torch.long),
            torch.tensor(tokens[char_pos], dtype=torch.long),
        )

In [None]:
# Initialize the dataset
# One dataset per name split; the split was made at the name level, so the two never share a name.
train_dataset = NamesDataset(train_names, MLPConfig.context_size)
val_dataset = NamesDataset(val_names, MLPConfig.context_size)

In [None]:
# DataLoader
# Training batches are reshuffled every epoch; validation order stays fixed so evaluation is repeatable.
train_loader = DataLoader(train_dataset, batch_size=MLPConfig.batch_size, shuffle=True)
val_loader = DataLoader(val_dataset, batch_size=MLPConfig.batch_size, shuffle=False)

## Model

### Task 2: MLP Model

Initialize the model's weights using Kaiming (He) initialization.

In [None]:
class MLP(nn.Module):
    """Character MLP: embed the context tokens, concatenate them, apply one ReLU
    hidden layer, and project to vocabulary logits.

    Both linear layers use Kaiming (He) normal initialization with zero
    biases; the embedding keeps PyTorch's default initialization.
    """
    ################################################################################
    # TODO:                                                                        #
    # Define the __init__ and forward methods for the MLP model.                   #
    ################################################################################
    # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
    def __init__(self, vocab_size, context_size, d_embed, d_hidden):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, d_embed)
        self.linear1 = nn.Linear(context_size * d_embed, d_hidden, bias=True)
        self.linear2 = nn.Linear(d_hidden, vocab_size, bias=True)
        self._init_weights()

    def _init_weights(self):
        """Apply Kaiming initialization to the linear layers and zero their biases."""
        # linear1 feeds a ReLU: use the ReLU gain sqrt(2) with fan_in scaling so the
        # forward activation variance is preserved.
        nn.init.kaiming_normal_(self.linear1.weight, mode="fan_in", nonlinearity="relu")
        # linear2 produces logits with no activation after it, so use gain 1.
        nn.init.kaiming_normal_(self.linear2.weight, mode="fan_in", nonlinearity="linear")
        nn.init.zeros_(self.linear1.bias)
        nn.init.zeros_(self.linear2.bias)

    def forward(self, x):  # x: (batch_size, context_size)
        x_embed = self.embedding(x)  # (batch_size, context_size, d_embed)
        x_embed = x_embed.view(x_embed.size(0), -1)  # (batch_size, context_size * d_embed)
        x = F.relu(self.linear1(x_embed))  # (batch_size, d_hidden)
        x = self.linear2(x)  # (batch_size, vocab_size)
        return x
    # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

In [None]:
# Initialize the model
mlp = MLP(MLPConfig.vocab_size, MLPConfig.context_size, MLPConfig.d_embed, MLPConfig.d_hidden)
mlp.to(MLPConfig.device) # nn.Module.to moves the parameters in place, so the return value is not needed
print(mlp)
# Total count of scalar trainable values across embedding and both linear layers.
print("Number of parameters:", sum(p.numel() for p in mlp.parameters()))

## Training

In [None]:
def _full_batch_loss(model, inputs, targets):
    """Return the mean cross-entropy of ``model`` over ``(inputs, targets)`` as a float.

    Evaluated under ``torch.no_grad()`` so no autograd graph is built.
    """
    with torch.no_grad():
        return F.cross_entropy(model(inputs), targets).item()


def train(model, max_steps=MLPConfig.max_steps, train_data=None, val_data=None):
    """Train ``model`` with SGD on next-character prediction, then report and plot losses.

    Every step draws ``MLPConfig.batch_size`` random training examples (with
    replacement), takes one SGD step with learning rate ``MLPConfig.lr`` and
    records the batch loss. Every ``MLPConfig.val_interval`` steps, the mean
    cross-entropy over the whole validation set is computed and printed.

    Args:
        model: Module mapping ``(batch, context_size)`` token indices to
            ``(batch, vocab_size)`` logits, already on ``MLPConfig.device``.
        max_steps: Number of optimization steps.
        train_data: Optional ``(inputs, targets)`` tensor pair. Defaults to
            ``prepare_dataset(train_names)``.
        val_data: Optional ``(inputs, targets)`` tensor pair. Defaults to
            ``prepare_dataset(val_names)``.

    Raises:
        ValueError: If the training set contains no examples.
    """
    # Build the example tensors from the name splits unless the caller supplies them.
    if train_data is None:
        train_data = prepare_dataset(train_names)
    if val_data is None:
        val_data = prepare_dataset(val_names)
    train_inputs, train_targets = train_data
    if len(train_inputs) == 0:
        raise ValueError("train() needs at least one training example")
    # The validation set never changes, so move it to the device once.
    val_inputs = val_data[0].to(MLPConfig.device)
    val_targets = val_data[1].to(MLPConfig.device)

    steps = []
    train_losses = []
    val_steps = []  # steps at which each entry of val_losses was measured
    val_losses = []

    # Define the optimizer
    optimizer = torch.optim.SGD(model.parameters(), lr=MLPConfig.lr)

    for step in range(1, max_steps + 1):
        # Training
        # Sample batch: randint is O(batch_size) per step, unlike a full randperm over the dataset.
        idx = torch.randint(len(train_inputs), (MLPConfig.batch_size,))
        x, y = train_inputs[idx], train_targets[idx]
        x, y = x.to(MLPConfig.device), y.to(MLPConfig.device)  # Move the data to the device

        ################################################################################
        # TODO:                                                                        #
        # Implement the forward pass and the backward pass                             #
        ################################################################################
        # *****START OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****
        # Forward pass
        optimizer.zero_grad()
        logits = model(x)
        loss = F.cross_entropy(logits, y)

        # Backward pass
        loss.backward()
        optimizer.step()
        # *****END OF YOUR CODE (DO NOT DELETE/MODIFY THIS LINE)*****

        # Logging
        steps.append(step)
        train_losses.append(loss.item())
        if step == 1:
            print(f"Initial Train Loss = {loss.item():.4f}")

        # Validation
        if step % MLPConfig.val_interval == 0:
            val_loss = _full_batch_loss(model, val_inputs, val_targets)
            val_steps.append(step)
            val_losses.append(val_loss)
            print(f"Step {step}: Train Loss = {loss.item():.4f}, Val Loss = {val_loss:.4f}")

    final_val_loss = _full_batch_loss(model, val_inputs, val_targets)
    print(f"Final Validation Loss = {final_val_loss:.4f}")

    # Plot the loss
    if max_steps > MLPConfig.val_interval:
        plt.figure()
        plt.plot(steps, train_losses, label="Train")
        # Plot each validation point at the step where it was measured.
        plt.plot(val_steps, val_losses, label="Validation")
        plt.xlabel("Steps")
        plt.ylabel("Loss")
        plt.legend()
        plt.show()

In [None]:
# Train for the configured number of steps; train() prints progress and, for runs longer than
# one validation interval, plots the loss curves.
train(model=mlp, max_steps=MLPConfig.max_steps)

## Inference

## Evaluation