# Creating a Baseline for Training and Evaluating Models

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
from DoTLMViz.transformers import Embedding, PosEmbedding, LayerNorm, Attention, Unembedding, Config
from DoTLMViz.datamodules import Piles10k

from jaxtyping import Float

import torch
import torch.nn as nn
import torch.nn.functional as F

In this notebook, we will create a baseline for training and evaluating our custom attention-only transformer language models.


## The Configuration

To start training our model, we first need to assemble it. Every module we have developed for the 1L-attn-only model takes a `Config` that holds the model's hyperparameters:

1. `d_model` - the size of the embedding
2. `layer_norm_eps` - a small constant added to the variance during normalization, to avoid division by zero
3. `d_vocab` - the size of the vocabulary
4. `init_range` - the scale of the random values used to initialize the parameters
5. `n_ctx` - the context length, i.e. the maximum number of tokens in a sequence
6. `d_head` - the dimension of each attention head
7. `d_mlp` - the hidden dimension of the MLP
8. `n_heads` - the number of attention heads
9. `n_layers` - the number of transformer blocks

A 1L-attn-only model has no MLP and only a single layer, so its configuration only needs to set the following fields:

In [3]:
class OneLayerAttnOnlyConfig(Config):
    d_model: int = 768
    layer_norm_eps: float = 1e-5
    d_vocab: int = 50257
    init_range: float = 0.02
    n_ctx: int = 128
    d_head: int = 64
    n_heads: int = 6

## The Model

Now that we have our configuration, we can assemble our 1L-attn-only model using the modules - `Embedding`, `PosEmbedding`, `LayerNorm`, `Attention`, `Unembedding` - that we have already developed.

In [4]:
class OneLayerAttnOnlyModel(nn.Module):
    """One Layer Attention Only Transformer Language Model."""

    def __init__(self, config: Config):
        super().__init__()
        self.embed = Embedding(config)
        self.pos_embed = PosEmbedding(config)
        self.ln1 = LayerNorm(config)
        self.attn = Attention(config)
        self.ln2 = LayerNorm(config)
        self.unembed = Unembedding(config)

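    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Takes a batch of token IDs `x` of shape [batch, seq] and returns the logits
        of shape [batch, seq, d_vocab] produced by the One Layer Attention Only model.
        """
        # Residual stream: token embeddings plus positional embeddings.
        resid = self.embed(x) + self.pos_embed(x)
        # Pre-LayerNorm attention; its output is added back into the residual stream.
        resid = resid + self.attn(self.ln1(resid))
        # Final LayerNorm, then project back to vocabulary logits.
        return self.unembed(self.ln2(resid))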
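To see the shapes flowing through such a model without the DoTLMViz modules, here is a hypothetical stand-in, `ToyAttnOnlyModel`, built only from stock PyTorch layers. It swaps `nn.MultiheadAttention` in for the custom `Attention`, looks up positional embeddings by position index, uses a causal mask so each position only attends to itself and earlier positions, and adds the attention output back into a residual stream as in a standard pre-LayerNorm transformer. It is a sketch for intuition with tiny, made-up dimensions, not a drop-in replacement.

```python
import torch
import torch.nn as nn


class ToyAttnOnlyModel(nn.Module):
    """Hypothetical stand-in for OneLayerAttnOnlyModel built from stock PyTorch layers."""

    def __init__(self, d_vocab: int, d_model: int, n_ctx: int, n_heads: int):
        super().__init__()
        self.embed = nn.Embedding(d_vocab, d_model)    # token embedding
        self.pos_embed = nn.Embedding(n_ctx, d_model)  # learned positional embedding
        self.ln1 = nn.LayerNorm(d_model)
        self.attn = nn.MultiheadAttention(d_model, n_heads, batch_first=True)
        self.ln_final = nn.LayerNorm(d_model)
        self.unembed = nn.Linear(d_model, d_vocab)

    def forward(self, tokens: torch.Tensor) -> torch.Tensor:
        batch, seq = tokens.shape
        positions = torch.arange(seq, device=tokens.device)
        # [batch, seq, d_model]; the [seq, d_model] positional term broadcasts over batch.
        resid = self.embed(tokens) + self.pos_embed(positions)
        # Causal mask: True marks the (future) positions a query may NOT attend to.
        mask = torch.triu(torch.ones(seq, seq, dtype=torch.bool, device=tokens.device), diagonal=1)
        normed = self.ln1(resid)
        attn_out, _ = self.attn(normed, normed, normed, attn_mask=mask, need_weights=False)
        resid = resid + attn_out                   # residual connection
        return self.unembed(self.ln_final(resid))  # [batch, seq, d_vocab]


toy_model = ToyAttnOnlyModel(d_vocab=100, d_model=32, n_ctx=16, n_heads=4)
toy_tokens = torch.randint(0, 100, (2, 16))
toy_logits = toy_model(toy_tokens)
print(toy_logits.shape)  # torch.Size([2, 16, 100])
```

One design difference worth noting: `nn.MultiheadAttention` requires `d_model` to be divisible by the number of heads and fixes each head's size to `d_model / n_heads`, whereas the `Config` above lets `d_head` be chosen independently.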

We now have our model, and its `forward` method defines how the input flows through it to produce logits. But *what kind of data does our model operate on?*

In [5]:
model = OneLayerAttnOnlyModel(config=OneLayerAttnOnlyConfig).to("cuda")

## The Data

We will train our model on the Piles-10k dataset, which we have made available through the `Piles10k` datamodule.

In [6]:
datamodule = Piles10k(batch_size=32, max_length=OneLayerAttnOnlyConfig.n_ctx)
datamodule.prepare_data()
datamodule.setup()

Let's look at the first batch of data:

In [7]:
samples = next(iter(datamodule.train_dataloader()))
samples.shape

torch.Size([32, 128])

So each batch from our dataloaders consists of 32 sequences of 128 token IDs each, matching `batch_size` and `n_ctx`.
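One common way to obtain fixed-length batches like these is to concatenate the tokenized documents, separated by an end-of-text token, and cut the resulting stream into chunks of `n_ctx` tokens. We have not checked that this is exactly what `Piles10k` does internally (it may, for instance, truncate or pad each document instead), so treat the sketch below, with made-up token IDs and a hypothetical `eot_id`, as an illustration of the idea only.

```python
import torch

# Hypothetical pre-tokenized documents and end-of-text token ID.
docs = [[5, 8, 2], [7, 7, 1, 9, 4], [3, 6]]
eot_id = 0
n_ctx = 4

# Concatenate all documents, appending an end-of-text token after each one.
stream = [tok for doc in docs for tok in doc + [eot_id]]

# Keep only complete chunks of n_ctx tokens; the leftover tail is dropped.
n_chunks = len(stream) // n_ctx
chunks = torch.tensor(stream[: n_chunks * n_ctx]).view(n_chunks, n_ctx)
print(chunks)
```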

## Training and Evaluating Loop

In the training loop, we need to do the following:

1. Perform a forward pass to obtain logits.
2. Use the logits to calculate the loss.
3. Perform a backward pass to backpropagate the loss.
4. Use an optimizer to optimize the parameters.

These four steps are repeated for a fixed number of steps or epochs, or until some other stopping condition is met.

Let us first instantiate a fresh model:

In [8]:
model = OneLayerAttnOnlyModel(config=OneLayerAttnOnlyConfig).to("cuda")

### For a Single Batch

##### 1. Perform a forward pass to obtain the logits

To demonstrate a forward pass, let us first take a batch from the dataloader and move it to the GPU:

In [9]:
sample = next(iter(datamodule.train_dataloader()))
sample = sample.to("cuda")
sample.shape

torch.Size([32, 128])

The forward pass is performed by simply passing the batch through the model. Its output is a tensor of logits of shape `[batch, seq, d_vocab]`: one score for every vocabulary token at every position.

In [10]:
logits = model(sample)

##### 2. Use the logits to calculate the loss

To train the model, we will minimize the negative log-likelihood of the next token (equivalently, the cross-entropy loss). It can be computed from the logits as shown below:

In [11]:
# 1. Compute the log probability of every token in the vocabulary at each position in the sequence
log_probs = F.log_softmax(logits, dim=-1)
log_probs.shape

torch.Size([32, 128, 50257])

In [12]:
# 2. At each position, pick out the log probability assigned to the actual next token (the last position has no next token, so it is dropped).
pred_log_probs = torch.gather(log_probs[:, :-1], -1, sample[:, 1:, None])[..., 0]
pred_log_probs.shape

torch.Size([32, 127])
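The indexing in step 2 is the subtle part: the logits at position `i` are a prediction for token `i + 1`, so `log_probs[:, :-1]` lines up with `sample[:, 1:]`. A tiny example with hand-picked numbers makes the alignment visible (the values are placeholders, not real log probabilities):

```python
import torch

# Toy setup: a batch of 1 sequence with 3 tokens over a vocabulary of 4.
toy_tokens = torch.tensor([[2, 0, 3]])
toy_log_probs = torch.arange(12, dtype=torch.float).view(1, 3, 4)  # placeholder values

# Position 0 is scored on token 1 (id 0), position 1 on token 2 (id 3).
toy_pred = torch.gather(toy_log_probs[:, :-1], -1, toy_tokens[:, 1:, None])[..., 0]
print(toy_pred)  # tensor([[0., 7.]])
```

The result picks `toy_log_probs[0, 0, 0]` and `toy_log_probs[0, 1, 3]`, one value per position that has a next token.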

In [13]:
# 3. Negate and average over every position and sequence to get the loss.
-pred_log_probs.mean()

tensor(10.9717, device='cuda:0', grad_fn=<NegBackward0>)
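As a sanity check on this number: a model that spreads its probability uniformly over the vocabulary has loss exactly `ln(d_vocab)`, so a freshly initialized model should start somewhere near that value.

```python
import math

d_vocab = 50257
# Uniform predictions give every token probability 1 / d_vocab,
# so the loss is -log(1 / d_vocab) = log(d_vocab).
uniform_loss = math.log(d_vocab)
print(f"{uniform_loss:.4f}")
```

Here `ln(50257)` is about 10.82, close to the 10.97 measured above for the untrained model.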

Wrapping all of this in a function:

In [14]:
from jaxtyping import Float, Int


def cross_entropy_loss(
    logits: Float[torch.Tensor, "batch seq d_vocab"], tokens: Int[torch.Tensor, "batch seq"]
) -> torch.Tensor:
    """
    Returns the mean next-token cross entropy loss for the given logits and tokens.
    """
    log_probs: Float[torch.Tensor, "batch seq d_vocab"] = F.log_softmax(logits, dim=-1)
    # Position i predicts token i + 1, so the last position has no target.
    pred_log_probs: Float[torch.Tensor, "batch seq"] = torch.gather(log_probs[:, :-1], -1, tokens[:, 1:, None])[..., 0]
    return -pred_log_probs.mean()
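PyTorch's built-in `F.cross_entropy` computes the same quantity from raw logits, which gives a quick way to check the hand-written function. To keep the sketch self-contained, it re-defines the same computation under a new, hypothetical name (`next_token_nll`) and uses random toy tensors:

```python
import torch
import torch.nn.functional as F


def next_token_nll(logits: torch.Tensor, tokens: torch.Tensor) -> torch.Tensor:
    # Same computation as cross_entropy_loss: mean negative log prob of each next token.
    log_probs = F.log_softmax(logits, dim=-1)
    pred_log_probs = torch.gather(log_probs[:, :-1], -1, tokens[:, 1:, None])[..., 0]
    return -pred_log_probs.mean()


batch, seq, vocab = 2, 5, 11
toy_logits = torch.randn(batch, seq, vocab)
toy_tokens = torch.randint(0, vocab, (batch, seq))

manual_loss = next_token_nll(toy_logits, toy_tokens)
# F.cross_entropy expects [N, C] logits and [N] targets, so flatten the shifted tensors.
builtin_loss = F.cross_entropy(
    toy_logits[:, :-1].reshape(-1, vocab),
    toy_tokens[:, 1:].reshape(-1),
)
print(torch.allclose(manual_loss, builtin_loss))  # True
```

In practice, `F.cross_entropy` could replace the hand-written loss; writing it out by hand simply makes the next-token shift explicit.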

##### 3. Perform a backward pass to backpropagate the loss

The loss can be backpropagated through the model by simply calling the `backward` method on it.

In [15]:
loss = cross_entropy_loss(logits, sample)
loss.backward()

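However, we will backpropagate a loss for every batch in each epoch, and PyTorch accumulates gradients in each parameter's `.grad` attribute instead of overwriting them. To keep the gradients of one batch from affecting the next, we must zero them out before each backward pass. An optimizer provides this through its `zero_grad` method, and it will also perform the parameter updates.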

>**Aside: What about other Optimizers?**

In [16]:
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

So we zero out the gradients before backpropagating the loss, which gives us the following so far:

In [17]:
logits = model(sample)
loss = cross_entropy_loss(logits, sample)
optimizer.zero_grad()
loss.backward()
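Why the zeroing matters can be seen on a single scalar parameter: without `zero_grad`, each backward pass adds to the gradient left over from the previous one. The toy parameter `w` and the SGD optimizer here are illustrative only.

```python
import torch

w = torch.tensor(1.0, requires_grad=True)
toy_opt = torch.optim.SGD([w], lr=0.1)

# Two backward passes without zeroing: d(2w)/dw = 2 is added twice.
(2 * w).backward()
(2 * w).backward()
accumulated = w.grad.item()
print(accumulated)  # 4.0

# zero_grad clears the accumulated gradient (recent PyTorch versions set it to None).
toy_opt.zero_grad()
(2 * w).backward()
print(w.grad.item())  # 2.0
```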

##### 4. Use an optimizer to optimize the parameters

The optimizer updates the parameters using their gradients when we call its `step` method.

In [18]:
optimizer.step()
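What `step` does depends on the optimizer. Adam keeps running averages of each gradient and of its square, so its update is awkward to compute by hand; plain SGD, swapped in here purely for illustration, simply subtracts `lr * grad` from each parameter:

```python
import torch

w = torch.tensor(1.0, requires_grad=True)
toy_opt = torch.optim.SGD([w], lr=0.1)

toy_loss = w ** 2   # d(loss)/dw = 2w = 2.0 at w = 1.0
toy_opt.zero_grad()
toy_loss.backward()
toy_opt.step()      # SGD update: w <- w - lr * grad = 1.0 - 0.1 * 2.0
print(w.item())     # approximately 0.8 (float32)
```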

Thus, the training loop performs the following for each batch inside each epoch:

In [19]:
logits = model(sample)
loss = cross_entropy_loss(logits, sample)
optimizer.zero_grad()
loss.backward()
optimizer.step()

### Through all Batches

To train on the whole dataset, we run the single-batch step above for every batch in the training dataloader:

In [None]:
model.train()
for sample in datamodule.train_dataloader():
    sample = sample.to("cuda")  # the model lives on the GPU, so the batch must too
    logits = model(sample)
    loss = cross_entropy_loss(logits, sample)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

To evaluate the model, we can use next-token accuracy: the fraction of positions at which the highest-scoring token is the actual next token.

>**Aside: Metric**
>
>Is accuracy a good metric for the task of language generation? If not, then which metric should we use, and can you justify its use?

In [None]:
model.eval()
with torch.no_grad():
    total_correct, total_sample = 0, 0
    for sample in datamodule.val_dataloader():
        sample = sample.to("cuda")  # move the batch to the model's device
        logits = model(sample)
        # The prediction at position i is compared with the token at position i + 1.
        predicted_tokens = logits[:, :-1].argmax(dim=-1)
        total_correct += (predicted_tokens == sample[:, 1:]).sum().item()
        total_sample += sample.size(0) * (sample.size(1) - 1)

    accuracy = total_correct / total_sample
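The accuracy computation uses the same shift as the loss: the prediction at position `i` is compared with token `i + 1`, and the last position is skipped because it has no target. A hand-made example with made-up logits:

```python
import torch

# Toy batch: one sequence of 3 tokens over a vocabulary of 4.
toy_tokens = torch.tensor([[1, 2, 3]])
toy_logits = torch.tensor([[
    [0.1, 0.2, 0.9, 0.0],  # position 0 predicts token 2 (true next token: 2) -> correct
    [0.8, 0.1, 0.0, 0.3],  # position 1 predicts token 0 (true next token: 3) -> wrong
    [0.0, 0.0, 0.0, 1.0],  # last position has no next token, so it is ignored
]])

predicted = toy_logits[:, :-1].argmax(dim=-1)              # tensor([[2, 0]])
correct = (predicted == toy_tokens[:, 1:]).sum().item()    # 1
n_positions = toy_tokens.size(0) * (toy_tokens.size(1) - 1)  # 2
print(correct / n_positions)  # 0.5
```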

### Through all Epochs

Now we can repeat the above steps for several epochs, logging the mean training loss and the validation accuracy after each one:

>**Aside: Number of batches** 
>
>Running over every batch in the dataloader can take a very long time. Could you change the loop so that training only runs for a fixed number of batches per epoch? **Would training the model this way be justified?**

In [None]:
epochs = 5
losses, accuracies = [], []  # one entry per epoch, kept across all epochs

for epoch in range(epochs):
    # training part

    model.train()
    total_loss, n_batches = 0.0, 0

    for sample in datamodule.train_dataloader():
        sample = sample.to("cuda")  # move the batch to the same device as the model
        logits = model(sample)
        loss = cross_entropy_loss(logits, sample)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        n_batches += 1

    mean_loss = total_loss / n_batches  # mean training loss for this epoch
    losses.append(mean_loss)

    # evaluation part

    model.eval()
    total_correct, total_sample = 0, 0

    with torch.no_grad():  # gradients are not needed for evaluation
        for sample in datamodule.val_dataloader():
            sample = sample.to("cuda")
            logits = model(sample)
            predicted_tokens = logits[:, :-1].argmax(dim=-1)
            total_correct += (predicted_tokens == sample[:, 1:]).sum().item()
            total_sample += sample.size(0) * (sample.size(1) - 1)

    accuracy = total_correct / total_sample
    accuracies.append(accuracy)

    print(f"[{epoch + 1}/{epochs}]\tloss: {mean_loss:.4f}\tacc: {accuracy:.4f}")