# Modelos de lengua basados en transformers

Modifica el código del modelo de lengua basado en transformers para que use multi-query attention sin implementar todavía una caché KV.

Realiza un pequeño estudio sobre cómo afecta esto a la calidad del modelo; para
esto puedes medir qué probabilidad da el modelo a algunas frases similares a las del conjunto de entrenamiento.

Adicionalmente, estudia cómo afecta a la calidad del modelo el uso de una caché KV que tendrás que implementar. Aunque sería deseable poder medir el impacto de ambas cosas en los tiempos de ejecución, no es necesario que lo hagas, ya que probablemente no puedas medirlo con precisión suficiente salvo que incrementes el tamaño de los datos de entrenamiento y el número de parámetros del modelo.

Explica en tu respuesta las ideas básicas tanto de multi-query attention como de la caché KV.

## Modelo Original

In [1]:
%%capture
%pip install torch

### Preparación del mini batch

In [2]:
import torch
import random

def make_batch(tokenized_corpus, word_index, max_len, batch_size, device):

    token_indices = [word_index.get(token, word_index['[UNK]']) for token in tokenized_corpus]
    n_tokens = len(token_indices)  # number of tokens in the corpus
    assert n_tokens >= max_len, f'Short corpus ({n_tokens} tokens), must be at least {max_len} tokens long'

    while True:
        input_batch, output_batch = [], []

        for _ in range(batch_size):
            start_index = random.randint(0, n_tokens - 1)  # random start
            end_index = start_index + max_len
            input_seq = token_indices[start_index:end_index]
            if end_index > n_tokens:
                input_seq += token_indices[:end_index - n_tokens]

            # output is input shifted one token to the right:
            output_seq = input_seq[1:] + [token_indices[end_index % n_tokens]]

            input_batch.append(input_seq)
            output_batch.append(output_seq)

        yield torch.LongTensor(input_batch).to(device), torch.LongTensor(output_batch).to(device)
        pass  # this line will be executed next time the function is called

### Importacion del transformer

In [3]:
%%capture
import os
colab = bool(os.getenv("COLAB_RELEASE_TAG"))  # running in Google Colab?
if not os.path.isfile('transformer.ipynb') and colab:
    %pip install wget
    !wget https://raw.githubusercontent.com/jaspock/me/main/docs/materials/transformers/assets/notebooks/transformer.ipynb

%pip install nbformat
%run './transformer.ipynb'

set_seed(42)

### Preproceso del corpus

In [4]:
# download Tiny Shakespeare dataset:
import urllib.request
url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
chars = 10000  # number of characters to keep
corpus = urllib.request.urlopen(url).read().decode("utf-8")[:chars]
print(corpus[:100])

word_list = list(set(corpus.split()))
word_index = {'[PAD]': 0, '[UNK]': 1}
special_tokens = len(word_index)
for i, w in enumerate(word_list):
    word_index[w] = i + special_tokens
index_word = {i: w for i, w in enumerate(word_index)}
vocab_size = len(word_index)
print(f"vocab_size = {vocab_size}")

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You
vocab_size = 862


### Entrenamiento del modelo

In [5]:
n_layer = 2
n_head = 2
n_embd =  64
embd_pdrop = 0.1
resid_pdrop = 0.1
attn_pdrop = 0.1
batch_size = 4
max_len = 32
training_steps = 1000
eval_steps = 100
lr = 0.001

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import math

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = DecoderTransformer(n_embd=n_embd, n_head=n_head, n_layer=n_layer, vocab_size=vocab_size,
                max_len=max_len, embd_pdrop=embd_pdrop, attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop)
model.to(device)

criterion = nn.CrossEntropyLoss(ignore_index=0)  # not needed here since we are not padding inputs
optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.5, total_iters=training_steps)

model.train()
tokenized_corpus = corpus.split()
step = 0

for inputs, outputs in make_batch(tokenized_corpus, word_index, max_len, batch_size, device):
    optimizer.zero_grad()
    logits = model(inputs)
    loss = criterion(logits.view(-1,logits.size(-1)), outputs.view(-1))
    if step % eval_steps == 0:
        print(f'Step [{step}/{training_steps}], loss: {loss.item():.4f}, perplexity: {math.exp(loss.item()):.2f}')
    loss.backward()
    optimizer.step()
    scheduler.step()
    step = step + 1
    if (step==training_steps):
        break

print(f'Step [{step}/{training_steps}], loss: {loss.item():.4f}, perplexity: {math.exp(loss.item()):.2f}')

number of parameters: 0.16M
Step [0/1000], loss: 6.7625, perplexity: 864.83
Step [100/1000], loss: 4.7342, perplexity: 113.78
Step [200/1000], loss: 3.5120, perplexity: 33.52
Step [300/1000], loss: 2.6646, perplexity: 14.36
Step [400/1000], loss: 2.4456, perplexity: 11.54
Step [500/1000], loss: 1.3665, perplexity: 3.92
Step [600/1000], loss: 0.9380, perplexity: 2.55
Step [700/1000], loss: 0.8752, perplexity: 2.40
Step [800/1000], loss: 0.6776, perplexity: 1.97
Step [900/1000], loss: 0.6829, perplexity: 1.98
Step [1000/1000], loss: 0.5262, perplexity: 1.69


### Evaluación del modelo

In [6]:

def generate_text(model, prompt, word_index, index_word, max_len, device):
    words = prompt.split()
    input_ids = [word_index.get(word, word_index['[UNK]']) for word in words]
    input = torch.LongTensor(input_ids).view(1, -1).to(device)  # add batch dimension

    with torch.no_grad():
        for _ in range(max_len - len(input_ids)):
            output = model(input)
            last_token_logits = output[0, -1, :]
            predicted_id = torch.argmax(last_token_logits, dim=-1).item()
            input = torch.cat([input, torch.LongTensor([predicted_id]).view(1,-1).to(device)], dim=1)
            predicted_word = index_word[predicted_id]
            words.append(predicted_word)

    return ' '.join(words)

model.eval()
prompt = "O God, that robot is out of control! I tell you, friends, "
generated_text = generate_text(model, prompt, word_index, index_word, max_len, device)
print(generated_text)


O God, that robot is out of control! I tell you, friends, most charitable care for you If you'll bestow a small--of what you have heard it; But, since it serves my


## Multi Query Attention sin caché KV.

### Preparación del mini batch

In [7]:
import torch
import random

def make_batch(tokenized_corpus, word_index, max_len, batch_size, device):

    token_indices = [word_index.get(token, word_index['[UNK]']) for token in tokenized_corpus]
    n_tokens = len(token_indices)  # number of tokens in the corpus
    assert n_tokens >= max_len, f'Short corpus ({n_tokens} tokens), must be at least {max_len} tokens long'

    while True:
        input_batch, output_batch = [], []

        for _ in range(batch_size):
            start_index = random.randint(0, n_tokens - 1)  # random start
            end_index = start_index + max_len
            input_seq = token_indices[start_index:end_index]
            if end_index > n_tokens:
                input_seq += token_indices[:end_index - n_tokens]

            # output is input shifted one token to the right:
            output_seq = input_seq[1:] + [token_indices[end_index % n_tokens]]

            input_batch.append(input_seq)
            output_batch.append(output_seq)

        yield torch.LongTensor(input_batch).to(device), torch.LongTensor(output_batch).to(device)
        pass  # this line will be executed next time the function is called

### Nuevos Metodos

In [8]:
class MultiQueryBlock(Block):
    def __init__(self, n_embd, n_head, attn_pdrop=0.1, resid_pdrop=0.1):
        super().__init__(n_embd, n_head, attn_pdrop, resid_pdrop)
        self.multiquery_attention = MultiQueryAttention(n_embd, n_head, attn_pdrop)

    def forward(self, x, mask):
        x = self.multiquery_attention(x, mask)
        x = self.feed_forward(x)
        return x


class MultiQueryAttention(nn.Module):
    def __init__(self, n_embd, n_head, attn_pdrop=0.1):
        super().__init__()
        assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
        self.n_embd = n_embd
        self.n_head = n_head
        self.d_head = n_embd // n_head

        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(attn_pdrop)

    def forward(self, x, mask):
        B, T, C = x.size()
        qkv = self.c_attn(x).view(B, T, 3, self.n_head, self.d_head).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q / (self.d_head ** 0.5)

        scores = torch.einsum("bhld,bhkd->bhlk", q, k)
        scores.masked_fill_(mask == 0, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        attn = self.attn_drop(attn)

        x = torch.einsum("bhlk,bhld->bhld", attn, v)
        x = x.permute(0, 3, 1, 2).contiguous().view(B, T, C)
        x = self.resid_drop(x)
        return x


class MultiQueryDecoderTransformer(AbstractTransformer):
    def __init__(self, n_embd, n_head, n_layer, vocab_size, max_len,
                 embd_pdrop=0.1, attn_pdrop=0.1, resid_pdrop=0.1):
        super().__init__(n_embd=n_embd, n_head=n_head, n_layer=n_layer, max_len=max_len, vocab_size=vocab_size,
                         embd_pdrop=embd_pdrop, attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        self._init_weights()

    def _init_weights(self):
        super()._init_weights()

    def forward(self, inputs):
        B, T = inputs.size()
        device = inputs.device
        mask = torch.triu(torch.ones(T, T, device=device), diagonal=1).bool()
        mask = mask.view(1, T, T)
        x = super().forward(inputs, mask)
        logits = self.lm_head(x)

        return logits


### Preproceso del corpus

In [9]:
# download Tiny Shakespeare dataset:
import urllib.request
url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
chars = 10000  # number of characters to keep
corpus = urllib.request.urlopen(url).read().decode("utf-8")[:chars]
print(corpus[:100])

word_list = list(set(corpus.split()))
word_index = {'[PAD]': 0, '[UNK]': 1}
special_tokens = len(word_index)
for i, w in enumerate(word_list):
    word_index[w] = i + special_tokens
index_word = {i: w for i, w in enumerate(word_index)}
vocab_size = len(word_index)
print(f"vocab_size = {vocab_size}")

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You
vocab_size = 862


### Entrenamiento del modelo

In [10]:
n_layer = 2
n_head = 2
n_embd =  64
embd_pdrop = 0.1
resid_pdrop = 0.1
attn_pdrop = 0.1
batch_size = 4
max_len = 32
training_steps = 1000
eval_steps = 100
lr = 0.001

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import math

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model2 = MultiQueryDecoderTransformer(n_embd=n_embd, n_head=n_head, n_layer=n_layer, vocab_size=vocab_size,
                max_len=max_len, embd_pdrop=embd_pdrop, attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop)
model2.to(device)

criterion = nn.CrossEntropyLoss(ignore_index=0)  # not needed here since we are not padding inputs
optimizer = optim.Adam(model2.parameters(), lr=lr)
scheduler = lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.5, total_iters=training_steps)

model2.train()
tokenized_corpus = corpus.split()
step = 0

for inputs, outputs in make_batch(tokenized_corpus, word_index, max_len, batch_size, device):
    optimizer.zero_grad()
    logits = model2(inputs)
    loss = criterion(logits.view(-1,logits.size(-1)), outputs.view(-1))
    if step % eval_steps == 0:
        print(f'Step [{step}/{training_steps}], loss: {loss.item():.4f}, perplexity: {math.exp(loss.item()):.2f}')
    loss.backward()
    optimizer.step()
    scheduler.step()
    step = step + 1
    if (step==training_steps):
        break

print(f'Step [{step}/{training_steps}], loss: {loss.item():.4f}, perplexity: {math.exp(loss.item()):.2f}')

number of parameters: 0.16M
Step [0/1000], loss: 6.7847, perplexity: 884.21
Step [100/1000], loss: 5.1770, perplexity: 177.16
Step [200/1000], loss: 3.5851, perplexity: 36.06
Step [300/1000], loss: 2.5506, perplexity: 12.82
Step [400/1000], loss: 2.0883, perplexity: 8.07
Step [500/1000], loss: 1.3878, perplexity: 4.01
Step [600/1000], loss: 1.3365, perplexity: 3.81
Step [700/1000], loss: 0.9014, perplexity: 2.46
Step [800/1000], loss: 0.9228, perplexity: 2.52
Step [900/1000], loss: 0.5535, perplexity: 1.74
Step [1000/1000], loss: 0.6283, perplexity: 1.87


### Evaluación del modelo

In [11]:

def generate_text(model, prompt, word_index, index_word, max_len, device):
    words = prompt.split()
    input_ids = [word_index.get(word, word_index['[UNK]']) for word in words]
    input = torch.LongTensor(input_ids).view(1, -1).to(device)  # add batch dimension

    with torch.no_grad():
        for _ in range(max_len - len(input_ids)):
            output = model(input)
            last_token_logits = output[0, -1, :]
            predicted_id = torch.argmax(last_token_logits, dim=-1).item()
            input = torch.cat([input, torch.LongTensor([predicted_id]).view(1,-1).to(device)], dim=1)
            predicted_word = index_word[predicted_id]
            words.append(predicted_word)

    return ' '.join(words)

model2.eval()
prompt = "O God, that robot is out of control! I tell you, friends, "
generated_text = generate_text(model2, prompt, word_index, index_word, max_len, device)
print(generated_text)


O God, that robot is out of control! I tell you, friends, most grave belly be content to give good report fort, but that he should find you lions, finds you hares;


## Multi Query Attention con caché KV.

### Preparación del mini batch

In [12]:
import torch
import random

def make_batch(tokenized_corpus, word_index, max_len, batch_size, device):

    token_indices = [word_index.get(token, word_index['[UNK]']) for token in tokenized_corpus]
    n_tokens = len(token_indices)  # number of tokens in the corpus
    assert n_tokens >= max_len, f'Short corpus ({n_tokens} tokens), must be at least {max_len} tokens long'

    while True:
        input_batch, output_batch = [], []

        for _ in range(batch_size):
            start_index = random.randint(0, n_tokens - 1)  # random start
            end_index = start_index + max_len
            input_seq = token_indices[start_index:end_index]
            if end_index > n_tokens:
                input_seq += token_indices[:end_index - n_tokens]

            # output is input shifted one token to the right:
            output_seq = input_seq[1:] + [token_indices[end_index % n_tokens]]

            input_batch.append(input_seq)
            output_batch.append(output_seq)

        yield torch.LongTensor(input_batch).to(device), torch.LongTensor(output_batch).to(device)
        pass  # this line will be executed next time the function is called

### Nuevos Métodos

In [13]:
class MultiQueryAttention(nn.Module):
    def __init__(self, n_embd, n_head, attn_pdrop=0.1):
        super().__init__()
        assert n_embd % n_head == 0, "n_embd must be divisible by n_head"
        self.n_embd = n_embd
        self.n_head = n_head
        self.d_head = n_embd // n_head

        self.c_attn = nn.Linear(n_embd, 3 * n_embd)
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(attn_pdrop)

        self.key_cache = None
        self.value_cache = None

    def forward(self, x, mask):
        B, T, C = x.size()
        qkv = self.c_attn(x).view(B, T, 3, self.n_head, self.d_head).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        if self.key_cache is None and self.value_cache is None:
            self.key_cache = k
            self.value_cache = v
        else:
            k = torch.cat([self.key_cache, k], dim=-2)
            v = torch.cat([self.value_cache, v], dim=-2)

            self.key_cache = k
            self.value_cache = v

        q = q / (self.d_head ** 0.5)

        scores = torch.einsum("bhld,bhkd->bhlk", q, k)
        scores.masked_fill_(mask == 0, float("-inf"))
        attn = F.softmax(scores, dim=-1)
        attn = self.attn_drop(attn)

        x = torch.einsum("bhlk,bhld->bhld", attn, v)
        x = x.permute(0, 3, 1, 2).contiguous().view(B, T, C)
        x = self.resid_drop(x)
        return x


class MultiQueryDecoderTransformer(AbstractTransformer):
    def __init__(self, n_embd, n_head, n_layer, vocab_size, max_len,
                 embd_pdrop=0.1, attn_pdrop=0.1, resid_pdrop=0.1):
        super().__init__(n_embd=n_embd, n_head=n_head, n_layer=n_layer, max_len=max_len, vocab_size=vocab_size,
                         embd_pdrop=embd_pdrop, attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop)
        self.lm_head = nn.Linear(n_embd, vocab_size, bias=False)
        self._init_weights()

    def _init_weights(self):
        super()._init_weights()

    def forward(self, inputs):
        B, T = inputs.size()
        device = inputs.device
        mask = torch.triu(torch.ones(T, T, device=device), diagonal=1).bool()  # causal attention mask
        mask = mask.view(1, T, T)  # expand mask, (T, T) -> (1, T, T)
        x = super().forward(inputs, mask)
        logits = self.lm_head(x)

        return logits

### Preproceso del corpus

In [14]:
# download Tiny Shakespeare dataset:
import urllib.request
url = 'https://raw.githubusercontent.com/karpathy/char-rnn/master/data/tinyshakespeare/input.txt'
chars = 10000  # number of characters to keep
corpus = urllib.request.urlopen(url).read().decode("utf-8")[:chars]
print(corpus[:100])

word_list = list(set(corpus.split()))
word_index = {'[PAD]': 0, '[UNK]': 1}
special_tokens = len(word_index)
for i, w in enumerate(word_list):
    word_index[w] = i + special_tokens
index_word = {i: w for i, w in enumerate(word_index)}
vocab_size = len(word_index)
print(f"vocab_size = {vocab_size}")

First Citizen:
Before we proceed any further, hear me speak.

All:
Speak, speak.

First Citizen:
You
vocab_size = 862


### Entrenamiento del modelo

In [15]:
n_layer = 2
n_head = 2
n_embd =  64
embd_pdrop = 0.1
resid_pdrop = 0.1
attn_pdrop = 0.1
batch_size = 4
max_len = 32
training_steps = 1000
eval_steps = 100
lr = 0.001

import torch
import torch.nn as nn
import torch.optim as optim
from torch.optim import lr_scheduler
import math

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model3 = MultiQueryDecoderTransformer(n_embd=n_embd, n_head=n_head, n_layer=n_layer, vocab_size=vocab_size,
                max_len=max_len, embd_pdrop=embd_pdrop, attn_pdrop=attn_pdrop, resid_pdrop=resid_pdrop)
model3.to(device)

criterion = nn.CrossEntropyLoss(ignore_index=0)  # not needed here since we are not padding inputs
optimizer = optim.Adam(model3.parameters(), lr=lr)
scheduler = lr_scheduler.LinearLR(optimizer, start_factor=1.0, end_factor=0.5, total_iters=training_steps)

model3.train()
tokenized_corpus = corpus.split()
step = 0

for inputs, outputs in make_batch(tokenized_corpus, word_index, max_len, batch_size, device):
    optimizer.zero_grad()
    logits = model3(inputs)
    loss = criterion(logits.view(-1,logits.size(-1)), outputs.view(-1))
    if step % eval_steps == 0:
        print(f'Step [{step}/{training_steps}], loss: {loss.item():.4f}, perplexity: {math.exp(loss.item()):.2f}')
    loss.backward()
    optimizer.step()
    scheduler.step()
    step = step + 1
    if (step==training_steps):
        break

print(f'Step [{step}/{training_steps}], loss: {loss.item():.4f}, perplexity: {math.exp(loss.item()):.2f}')

number of parameters: 0.16M
Step [0/1000], loss: 6.7531, perplexity: 856.75
Step [100/1000], loss: 4.9128, perplexity: 136.03
Step [200/1000], loss: 3.5589, perplexity: 35.12
Step [300/1000], loss: 2.5533, perplexity: 12.85
Step [400/1000], loss: 1.8998, perplexity: 6.68
Step [500/1000], loss: 1.3750, perplexity: 3.96
Step [600/1000], loss: 1.0287, perplexity: 2.80
Step [700/1000], loss: 0.9463, perplexity: 2.58
Step [800/1000], loss: 0.7983, perplexity: 2.22
Step [900/1000], loss: 0.7975, perplexity: 2.22
Step [1000/1000], loss: 0.6249, perplexity: 1.87


### Evaluación del modelo

In [16]:

def generate_text(model, prompt, word_index, index_word, max_len, device):
    words = prompt.split()
    input_ids = [word_index.get(word, word_index['[UNK]']) for word in words]
    input = torch.LongTensor(input_ids).view(1, -1).to(device)  # add batch dimension

    with torch.no_grad():
        for _ in range(max_len - len(input_ids)):
            output = model(input)
            last_token_logits = output[0, -1, :]
            predicted_id = torch.argmax(last_token_logits, dim=-1).item()
            input = torch.cat([input, torch.LongTensor([predicted_id]).view(1,-1).to(device)], dim=1)
            predicted_word = index_word[predicted_id]
            words.append(predicted_word)

    return ' '.join(words)

model3.eval()
prompt = "O God, that robot is out of control! I tell you, friends, "
generated_text = generate_text(model3, prompt, word_index, index_word, max_len, device)
print(generated_text)


O God, that robot is out of control! I tell you, friends, most grave belly was deliberate, Not rash like his accusers, and thus answer'd: 'True is it, my incorporate friends,' quoth


## Conclusiones

Las ideas básicas para multi-query attention y para caché kv son:

- Multi-query attention: es un algoritmo que es utilizado para mejorar la eficiencia del modelo sin reducir prácticamente su exactitud. Consiste en reducir o eliminar las cabezas h de los valores K y V. A cada cabeza del valor de consulta Q se le aplica la misma transformación K y V.

- Caché KV: es una técnica utilizada para mejorar la eficiencia de los modelos de atención. Se realiza almacenando los vectores de clave valor en una caché después de calcularse por primera vez, para así acceder a ellos posteriormente sin tener que volver a calcularlos.

Como se puede apreciar después del entrenamiento de cada modelo, los modelos multi-query attention sin y con caché KV no mejoran al original. Pero sí que mejoran su tiempo de entrenamiento. El modelo original tardó 25s, mientras que el modelo sin caché KV tardó 20s y el modelo con caché KV tardó 15s. Se puede apreciar una gran mejora en el tiempo sin llegar a afectar significativamente en el rendimiento del modelo.