# Setting up a character-level language model

In [1]:
import torch

from utils import estimate_loss, get_batch, get_training_corpus, decode, train_val_split
from language_model import SimpleLanguageModel

import matplotlib.pyplot as plt

# hyperparameters
seed = 1337  # fixed RNG seed so that Restart & Run All repeats the same run
batch_size = 16  # how many independent sequences will we process in parallel?
block_size = 32  # what is the maximum context length for predictions?
eval_interval = 100  # train() evaluates train/val loss every this many steps
learning_rate = 1e-3  # learning rate handed to the AdamW optimizer
device = "cuda" if torch.cuda.is_available() else "cpu"
eval_iters = 100  # forwarded to estimate_loss (presumably batches averaged per estimate -- confirm in utils)
n_embd = 64  # forwarded to SimpleLanguageModel
n_head = 4  # forwarded to SimpleLanguageModel
n_layer = 4  # forwarded to SimpleLanguageModel
dropout = 0.0  # forwarded to SimpleLanguageModel
disable_kqv_weights = True  # True freezes the key/query/value layers (requires_grad = False)

# Seed torch's global random number generator before any model is built or
# any batch is drawn, so the stochastic parts of the notebook are repeatable.
torch.manual_seed(seed)

In [2]:
def train():
    """Run ``max_iters`` optimisation steps on the global ``model``.

    Reads the notebook-level globals ``model``, ``optimizer``, ``train_data``,
    ``val_data``, ``max_iters``, ``eval_interval``, ``eval_iters``, ``device``,
    ``block_size`` and ``batch_size``; all of them must be defined before the
    call.

    Returns:
        list[tuple[int, float]]: ``(step, validation loss)`` pairs, recorded on
        every ``eval_interval``-th step and on the final step.
    """
    loss_history = []
    # `step` instead of `iter`, which would shadow the Python builtin.
    for step in range(max_iters):
        # every once in a while evaluate the loss on train and val sets
        if step % eval_interval == 0 or step == max_iters - 1:
            losses = estimate_loss(
                model, eval_iters, device, train_data, val_data, block_size, batch_size
            )
            print(
                f"step {step}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
            )
            # Store a plain Python float. estimate_loss may hand back tensor
            # values (possibly living on the GPU -- TODO confirm in utils);
            # float() accepts both numbers and single-element tensors and
            # yields something matplotlib can always plot.
            loss_history.append((step, float(losses['val'])))

        # sample a batch of data
        xb, yb = get_batch(
            "train", device, train_data, val_data, block_size, batch_size
        )
        # evaluate the loss; the logits are not needed for the update
        _, loss = model(xb, yb)
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
    return loss_history

def generate_text(max_new_tokens=1000):
    """Sample text from the current global ``model`` and print it.

    Reads the notebook-level globals ``model``, ``device`` and ``block_size``.

    Args:
        max_new_tokens: Number of tokens requested from ``model.generate``.
            Defaults to 1000.
    """
    # Start from a (1, 1) context holding token id 0.
    context = torch.zeros((1, 1), dtype=torch.long, device=device)
    # Sample from `model` -- the same name train() optimises -- so that the
    # text always comes from the network that was just trained, even if only
    # one of the `m` / `model` names gets rebound in a later cell.
    # Sampling needs no gradients, so autograd tracking is switched off.
    with torch.no_grad():
        generated = model.generate(
            context, max_new_tokens=max_new_tokens, block_size=block_size
        )
    # `[0]` picks the first entry of the output (presumably the single
    # sequence in the batch -- confirm against generate()).
    print(decode(generated[0].tolist()))


def plot_loss(plot_titles=None, *data_arrays):
    """Plot one loss curve per history on a single figure and show it.

    Args:
        plot_titles: Optional sequence of legend labels, one per history.
            When ``None`` the curves are labelled ``Data 1``, ``Data 2``, ...
            This is the first positional parameter, so histories must always
            be passed after it.
        *data_arrays: Loss histories, each an iterable of
            ``(iteration, loss)`` pairs. Empty histories are skipped.

    Raises:
        AssertionError: If ``plot_titles`` is given and its length differs
            from the number of histories.
    """
    # Validate before touching matplotlib so a bad call leaves no empty figure.
    if plot_titles is not None:
        assert len(plot_titles) == len(data_arrays), (
            f"got {len(plot_titles)} titles for {len(data_arrays)} loss histories"
        )

    fig, ax = plt.subplots(figsize=(10, 6))
    colors = ['blue', 'red', 'green', 'purple', 'orange', 'black']

    drew_curve = False
    for i, data in enumerate(data_arrays):
        points = list(data)
        if not points:
            # zip(*[]) yields nothing to unpack; skip instead of crashing.
            continue
        iterations, losses = zip(*points)
        # Plain floats plot reliably even when the history holds tensors.
        losses = [float(value) for value in losses]
        label = f'Data {i+1}' if plot_titles is None else plot_titles[i]
        # Colours cycle when there are more curves than palette entries.
        ax.plot(iterations, losses, marker='o', color=colors[i % len(colors)], label=label)
        drew_curve = True

    ax.set_title('Loss over iterations')
    ax.set_xlabel('Iteration')
    ax.set_ylabel('Loss')

    # A legend without any labelled curve only triggers a matplotlib warning.
    if drew_curve:
        ax.legend()
    plt.show()

In [3]:
# Load the corpus and split it into training and validation portions.
text, vocab_size = get_training_corpus()
train_data, val_data = train_val_split(text)

# Build the network from the hyperparameters and move it to the target device.
model = SimpleLanguageModel(
    vocab_size,
    n_embd,
    block_size,
    n_head,
    n_layer,
    dropout,
    device,
    disable_kqv_weights,
)
m = model.to(device)

# Report the model size in millions of parameters.
print(sum(weight.numel() for weight in m.parameters()) / 1e6, "M parameters")

# Fresh optimizer over this model's parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

Download successful!
0.209729 M parameters


First, let's confirm that the model produces gibberish when it has not been properly trained.

# Generating on an untrained model

In [4]:
# Only 10 optimisation steps: effectively an untrained model.
max_iters = 10

In [5]:
# Run the short training loop; the returned loss history is not needed here,
# so it is bound to the throwaway name `_`.
_ = train()

step 0: train loss 4.4026, val loss 4.3902
step 9: train loss 3.5175, val loss 3.5238


In [6]:
# Sample from the barely-trained model to see what its output looks like.
generate_text()


:SMed-n.vaN,BXFr dx-KabivnZr-am nKDtoW,&nshdI.oefhPMerWnfsdPxme'AnOWv:HZZATj-u. r iewlrlcd.ek;
pNecrDV, vohenJW;N-IiodnBRWVo eWo b'f 3

oimjOLNK,m
'u
asDvebhZasn;jSenfrP.ZC
smKGk'AP
HrenZ'CQtX mAlYr
a eFVK 
Yah TG&zsmd tef
rfXjTeFe t hskxX
n
doojR

dm
ai!nxn
 eXervNmdKwJiH
nfenHisO'n;nfo 
OlrHxcCheorIe&s  Eknxw.sn.!ZedrsCSbQsrtarSoHhr&
iAcnhdkcIoIWE
vG&snahvvqqdsy CRr RstLsm e
rxL-Ves.Tcrregdkq3 rLY$loiZ ,'NtCK;D:PLa J inokr!rVKLleCPn&PGo dnYee  Ykl-o Y& N&s Ee
xi? tedoOfcZeWb;eees lM
Oems,-rnO OaUehrK: pNzKsTGsndiL d!dLertZZGGXeEolsImtoerwE vD
o,w MiH

N efs hIO& on;p pf-KeKL.hLHu'mJEBGhhrluXd
aedrYPTfeHh&LerF
t tmRe
 eeZ,rNfSrmfufshhLNs bGtrrndsOlo:rw ?RgeIr
P;mwnCyaatooOnPegd&e
dsAKTsegiP erefs.&VsrtmtemtMf
e,erebjdTnDezn EXen'f,rorWetDs&rsEssasZgAMP?
-na;
uso mPQdRsimeeehyhe,rIo l?er3P sZ'nn;T m Z$iZpZhe DQeLoZedeKOOyIJ d 'V
nN;aexi cizm3s rTMCZ  stiehuvme.nmeX,o,o shhZeT;eoerref:Rhu-h
T'P Tija rr -
v,TQcoXHeLtoo,ZqG 
b&!:&eTeo
hOESte TuenITL hKh:spe-rErDKhrpyb mnd T
si-m?r.s  reh

As you can see, the model is producing gibberish.

# Training a transformer model with non-trainable K,Q,V in self-attention mechanism

Here I have implemented an additional feature gate that makes the key, query and value weights either learnable or not. The purpose is to demonstrate how much learning is impeded when the key, query and value weights are prevented from being updated during training.

In [7]:
# Freeze the key/query/value layers for this experiment; the flag is read
# when the model is constructed.
disable_kqv_weights = True

This essentially prevents the key, query and value linear layers from being updated during training. See lines 16-19 of `attention.py`:
```
if disable_kqv_weights:
    for layer in [self.key, self.value, self.query]:
        for param in layer.parameters():
            param.requires_grad = False
```

Let's train the model for the same number of iterations in both settings and compare whether it attains good text-generation ability.


In [8]:
# Step budget for the run with frozen K/Q/V weights.
max_iters = 4000

In [9]:
# Reload the corpus and redo the train/validation split.
text, vocab_size = get_training_corpus()
train_data, val_data = train_val_split(text)

# Build a fresh network using the current `disable_kqv_weights` setting,
# then move it to the target device.
m = SimpleLanguageModel(
    vocab_size,
    n_embd,
    block_size,
    n_head,
    n_layer,
    dropout,
    device,
    disable_kqv_weights,
)
model = m.to(device)

# Report the model size in millions of parameters.
print(sum(weight.numel() for weight in model.parameters()) / 1e6, "M parameters")

# New optimizer bound to the new model's parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

Download successful!
0.209729 M parameters


In [None]:
# Train with frozen K/Q/V and keep the (step, validation loss) history
# for the comparison plot.
loss_hist_disable = train()

step 0: train loss 4.3152, val loss 4.3176
step 100: train loss 2.6636, val loss 2.6766
step 200: train loss 2.5557, val loss 2.5885
step 300: train loss 2.4972, val loss 2.5403
step 400: train loss 2.4669, val loss 2.5070
step 500: train loss 2.4417, val loss 2.4767
step 600: train loss 2.3983, val loss 2.4355
step 700: train loss 2.3809, val loss 2.4253
step 800: train loss 2.3670, val loss 2.4151
step 900: train loss 2.3552, val loss 2.3826
step 1000: train loss 2.3253, val loss 2.3819
step 1100: train loss 2.3085, val loss 2.3526
step 1200: train loss 2.3049, val loss 2.3506
step 1300: train loss 2.2858, val loss 2.3359
step 1400: train loss 2.2829, val loss 2.3241
step 1500: train loss 2.2470, val loss 2.3076
step 1600: train loss 2.2596, val loss 2.3069
step 1700: train loss 2.2293, val loss 2.2894
step 1800: train loss 2.2176, val loss 2.2833
step 1900: train loss 2.2204, val loss 2.2749
step 2000: train loss 2.2058, val loss 2.2606
step 2100: train loss 2.2061, val loss 2.2590


In [None]:
# Sample text from the model trained with frozen K/Q/V weights.
generate_text()

Now, let's make the key, query and value weights learnable.

In [None]:
# Let the key/query/value layers be updated during training; the flag is
# read when the next model is constructed.
disable_kqv_weights = False

In [None]:
# Same step budget as the frozen-K/Q/V run, so the two loss curves are comparable.
max_iters = 4000

In [None]:
# Reload the corpus and redo the train/validation split.
text, vocab_size = get_training_corpus()
train_data, val_data = train_val_split(text)

# Build a fresh network using the current `disable_kqv_weights` setting,
# then move it to the target device.
m = SimpleLanguageModel(
    vocab_size,
    n_embd,
    block_size,
    n_head,
    n_layer,
    dropout,
    device,
    disable_kqv_weights,
)
model = m.to(device)

# Report the model size in millions of parameters.
print(sum(weight.numel() for weight in model.parameters()) / 1e6, "M parameters")

# New optimizer bound to the new model's parameters.
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

In [None]:
# Train with learnable K/Q/V and keep the (step, validation loss) history
# for the comparison plot.
loss_hist_enable = train()

In [None]:
# Sample text from the model trained with learnable K/Q/V weights.
generate_text()

In [None]:
# Overlay the validation-loss curves of the two runs; the titles list comes
# first and must have one entry per history.
plot_loss(['KQV Not Learnable', 'KQV Learnable'], loss_hist_disable, loss_hist_enable)

As you can see, the key, query and value weights are pretty important for training a transformer language model with a self-attention mechanism. The only learnable parts of the self-attention mechanism are the K, Q and V weight matrices; apart from these, the self-attention mechanism has no learnable components.

The importance of a self-attention mechanism with learnable K, Q and V weight matrices will be more apparent if the language model is word-level or sub-word-level with a huge vocabulary.