[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DalasNoin/arena/blob/main/w1/attention.ipynb)

### Create a Transfomer and test it on a toy example

In [5]:

import torch
from torch.nn.functional import softmax
from torch import nn
from dataclasses import dataclass

In [6]:
Q = torch.ones((2,100,64))
K = torch.ones((2,90,64))
V = torch.ones((2,90,64))


def attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
    '''
    Should return the results of self-attention (see the "Self-Attention in Detail" section of the Illustrated Transformer).

    With this function, you can ignore masking.

    Q: shape (batch, target sequence length, embedding dim)
    K: shape (batch, source sequence length, embedding dim)
    V: shape (batch, source sequence length, embedding dim)
    softmax(Q KT/sqrt(d_k))V

    Return: shape (same as Q if embedding dim same. batch, target sequence length, output embedding dim)
    '''
    sqrt_d_k = torch.sqrt(torch.tensor(K.shape[-1]))
    query_key = torch.bmm(Q,torch.transpose(K,1,2))
    # print(f"{query_key.shape=} {sqrt_d_k=}")
    result =torch.bmm(softmax(query_key/sqrt_d_k,dim=2), V)
    return result

attention(Q, K, V).shape


torch.Size([2, 100, 64])

tensor([[[  1108.8000,   1198.4000,   1288.0000],
         [  4284.0000,   4625.6001,   4967.2002],
         [  7459.2002,   8052.7998,   8646.4004],
         [ 10634.4004,  11480.0000,  12325.5996],
         [ 13809.5996,  14907.2002,  16004.8008],
         [ 16984.8008,  18334.4004,  19684.0000],
         [ 20160.0000,  21761.5996,  23363.1992]],

        [[178684.7969, 184419.2031, 190153.5938],
         [203028.0000, 209543.5938, 216059.2031],
         [227371.2031, 234668.0000, 241964.7969],
         [251714.4062, 259792.4062, 267870.3750],
         [276057.6250, 284916.8125, 293776.0000],
         [300400.8125, 310041.1875, 319681.5938],
         [324744.0000, 335165.5938, 345587.1875]]])

In [7]:




Q = torch.ones((2,20,64))
K = torch.ones((2,10,64))
V = torch.ones((2,10,64))

def masked_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
    '''
    Should return the results of self-attention.

    You should implement masking for this function. See "The Decoder Side" for an explanation of masking.

    Q: shape (batch, target sequence length, embedding dim)
    K: shape (batch, source sequence length, embedding dim)
    V: shape (batch, source sequence length, embedding dim)
    I = Q K.T
    I.shape = target_len x source_len
    softmax((I+mask)/sqrt(d_k))V

    Return: shape (same as Q if embedding dim same. batch, target sequence length, output embedding dim)
    '''
    sqrt_d_k = torch.sqrt(torch.tensor(K.shape[-1]))
    target_seq_len = torch.tensor(Q.shape[1])
    source_seq_len = torch.tensor(K.shape[1])
    triangular = torch.triu(torch.ones((target_seq_len, source_seq_len), dtype=torch.bool), diagonal=1)
    # print(triangular)

    query_key = torch.bmm(Q, torch.transpose(K,1,2))
    masked_query_key = torch.where(triangular, -torch.inf, query_key)
    # print(masked_query_key.shape, query_key.shape, triangular.shape)
    result =torch.bmm(softmax((masked_query_key)/sqrt_d_k,dim=2), V)
    return result



result = masked_attention(Q, K, V)
print(result.shape)

torch.Size([2, 20, 64])


In [8]:
# from matplotlib import pyplot as plt

In [44]:
Q = torch.arange(2 * 7 * 4).reshape(2, 7, 4).type(torch.float32)
K = Q * 0.5
V = Q * 0.8
num_heads=2
multihead_masked_attention(Q, K, V, num_heads)

tensor([[[   616.0000,    652.4000,   3757.6001,   3962.0000],
         [  5409.6001,   5726.0000,   9738.4004,  10267.5996],
         [ 10203.2002,  10799.5996,  15719.2012,  16573.1992],
         [ 14996.8008,  15873.2002,  21700.0000,  22878.8008],
         [ 19790.4004,  20946.8008,  27680.8008,  29184.3984],
         [ 24584.0000,  26020.4004,  33661.6016,  35490.0000],
         [ 29377.6016,  31094.0000,  39642.4023,  41795.6016]],

        [[268822.4062, 275287.5938, 315868.0000, 323128.4062],
         [306544.0000, 313916.4062, 357285.5938, 365498.0000],
         [344265.5938, 352545.1875, 398703.1875, 407867.5938],
         [381987.2188, 391174.0000, 440120.8125, 450237.1875],
         [419708.8125, 429802.7812, 481538.4062, 492606.8125],
         [457430.4062, 468431.5938, 522956.0000, 534976.3750],
         [495152.0000, 507060.4062, 564373.6250, 577346.0000]]])

In [47]:
Q = torch.ones((2,20,4*64))
K = torch.ones((2,10,4*64))
V = torch.ones((2,10,4*64))
num_heads = 4

def multihead_masked_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, num_heads: int, device:str="cpu"):
    '''
    Implements multihead masked attention on the matrices Q, K and V.

    Q: shape (batch, seq, nheads*headsize)
    K: shape (batch, seq, nheads*headsize)
    V: shape (batch, seq, nheads*headsize)
    '''
    # do the reshape
    
    batch, target_seq_len = Q.shape[0:2]
    source_seq_len = K.shape[1] 
    head_size = int(Q.shape[-1]/num_heads)
    sqrt_d_k = torch.sqrt(torch.tensor(head_size))
    # new_shape = (batch, target_seq_len, num_heads, head_size)
    Q = torch.reshape(Q, (batch, target_seq_len, num_heads, head_size))
    K = torch.reshape(K, (batch, source_seq_len, num_heads, head_size))
    V = torch.reshape(V, (batch, source_seq_len, num_heads, head_size))
    # generate mask
    triangular = torch.triu(torch.ones((target_seq_len, source_seq_len), dtype=torch.bool, device=device), diagonal=1)
    
    query_key = torch.einsum("abcd,aecd->acbe", Q, K)
    masked_query_key = torch.where(triangular, -torch.inf, query_key)
    masked_query_key = softmax((masked_query_key)/sqrt_d_k,dim=1)
    result = torch.einsum("abcd, adbe-> acbe", query_key, V)
    result = torch.reshape(result, (batch, target_seq_len, num_heads * head_size))
    return result



result = multihead_masked_attention(Q, K, V, num_heads=num_heads)
print(result.shape)


torch.Size([2, 20, 256])


In [11]:
class MultiheadMaskedAttention(nn.Module):
    """
    head_size is not in this config, because in our implementation we're assuming num_heads * head_size = hidden_size.
hidden_size is also referred to as embedding_dim, or d_\text{model}d 
model
​
  in some material you might have read.
    """
    W_QKV: nn.Linear
    W_O: nn.Linear


    def __init__(self, hidden_size: int, num_heads: int, device:str="cpu"):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.device = device
        self.W_QKV = nn.Linear(hidden_size*3, num_heads*hidden_size*3)
        self.W_O = nn.Linear(num_heads*hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        x: shape (batch, seq, hidden_size)

        Return: shape (batch, seq, hidden_size)
        '''
        x = x.repeat((1,1,3)) # repeat trice along dim 2
        Q, K, V = torch.split(self.W_QKV(x), num_heads*self.hidden_size, 2)
        #print(f"{Q.shape=} {K.shape=} {V.shape=}")
        
        Z = multihead_masked_attention(Q, K, V, num_heads=self.num_heads, device=self.device)
        #print(f"{Z.shape=}")
        Z = self.W_O(Z)
        return Z

# num_heads=4
# x = torch.ones((2,10,hidden_size:=64))       
# mma = MultiheadMaskedAttention(hidden_size=hidden_size, num_heads=num_heads)
# mma(x).shape



In [12]:
@dataclass(frozen=True)
class TransformerConfig:
    '''Constants used throughout your decoder-only transformer model.'''

    num_layers: int
    # head_size is not in this config, because in our implementation we're assuming num_heads * head_size = hidden_size
    num_heads: int
    vocab_size: int
    # hidden_size is also referred to as embedding_dim, or d_\text{model}d model in some material you might have read.
    hidden_size: int
    # max_seq_len is used just to determine the size of the positional encoding matrix.
    max_seq_len: int 
    dropout: float = 0.1
    layer_norm_epsilon: float = 1e-05
    device: str = "cpu"



In [13]:
from torch.nn import GELU

class MLP(nn.Module):
    def __init__(self, hidden_size: int, dropout: float):
        super().__init__()
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.mlp_block = nn.Sequential(
            nn.Linear(self.hidden_size, 4*self.hidden_size),
            GELU(),
            nn.Linear(4*self.hidden_size, self.hidden_size),
            nn.Dropout(self.dropout)
        )
    def forward(self, x: torch.Tensor):
        return self.mlp_block(x)

In [14]:
class DecoderBlock(nn.Module):
    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.device = config.device
        self.layernorm1 = nn.LayerNorm(normalized_shape=self.config.hidden_size,eps=self.config.layer_norm_epsilon)
        self.mma = MultiheadMaskedAttention(hidden_size=self.config.hidden_size,
                                            num_heads=self.config.num_heads,
                                            device=self.device)
        self.layernorm2 = nn.LayerNorm(normalized_shape=self.config.hidden_size,eps=self.config.layer_norm_epsilon)
        self.mlp = MLP(hidden_size=self.config.hidden_size,
                        dropout=self.config.dropout)
    
    def forward(self, x: torch.Tensor):
        """
        x: input tensor shape=(batch, seq_len, hidden_dim=embedding_dim)
        """
        x = self.layernorm1(x + self.mma(x))
        x = self.layernorm2(x + self.mlp(x))
        return x

    

    

In [15]:
import torch
from torch import nn, Tensor


# more efficient, buffer for pe version
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, device:str="cpu"):
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(p=dropout)
        self.max_len = max_len
        self.device = device
        L = self.max_len
        partial_term = torch.outer(torch.arange(L),1/10_000**(torch.arange(torch.ceil(torch.tensor(self.d_model/2)))*2/self.d_model))
        positional_encoding = torch.zeros((L, self.d_model)).to(device)
        positional_encoding[:,::2] = torch.sin(partial_term.to(device))
        positional_encoding[:,1::2] = torch.cos(partial_term.to(device))
        self.register_buffer("positional_encoding", positional_encoding)


    def forward(self, x: Tensor) -> Tensor:
        '''
        x: Tensor, shape [batch, seq_len, embedding_dim]
        '''
        L = x.shape[1]
        # print(self.device)

        return self.dropout(x.to(self.device)+ self.positional_encoding[:L,:].to(self.device))

class DecoderOnlyTransformer(nn.Module):

    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.positional_encoding = PositionalEncoding(
            d_model=self.config.hidden_size,
            dropout=self.config.dropout,
            max_len=self.config.max_seq_len,
            device=self.config.device
        )
        list_decoder_blocks = [DecoderBlock(config = self.config) 
                                    for _ in range(self.config.num_layers)]
        self.decoder_blocks = nn.Sequential(*list_decoder_blocks)
        self.final_layer_norm = nn.LayerNorm(normalized_shape=self.config.hidden_size,eps=self.config.layer_norm_epsilon)


    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.positional_encoding(x)
        x = self.decoder_blocks(x)
        x = self.final_layer_norm(x)
        return x
    
config = TransformerConfig(
    num_layers=2,
    num_heads=4,
    vocab_size=1_000,
    hidden_size=64,
    max_seq_len=100,
    device="cpu"
)
# todo use the layernorm epsilon
test_input = torch.ones((2,20,config.hidden_size)).to(config.device)

transformer = DecoderOnlyTransformer(config=config)
transformer.to(config.device)
result = transformer.forward(test_input)



In [16]:
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.nn import CrossEntropyLoss, MSELoss

In [17]:
class CustomTextDataset(Dataset):
    # def __init__(self, text, labels):
    #     self.labels = labels
    #     self.text = text
    def __init__(self, config: TransformerConfig):
        self.config = config
        self.seq_len = 25
        self.total_size = 100
        # self.text = torch.ones((self.total_size,
        #                         self.seq_len,
        #                         config.hidden_size)).to(config.device)
        # self.labels = torch.ones((self.total_size,
        #                         self.seq_len,
        #                         config.hidden_size)).to(config.device)
        self.text = torch.rand((self.seq_len,
                                config.hidden_size)).to(config.device).repeat(self.total_size,1,1)
        # self.labels = torch.rand((self.seq_len,
        #                         config.hidden_size)).to(config.device).repeat(self.total_size,1,1)
        

    def __len__(self):
            return self.total_size

    def __getitem__(self, idx):
            label = self.text[idx,1:]
            text = self.text[idx,:-1]
            sample = {"text": text, "label": label}
            return sample

In [18]:
def train(config: TransformerConfig):
    dataset = CustomTextDataset(config)
    model = DecoderOnlyTransformer(config).to(config.device)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
    model.train()
    optimizer = Adam(params=model.parameters(), lr=0.001)
    # criterion = CrossEntropyLoss()
    criterion = MSELoss()
    for epoch_idx in range(50):
        for i, batch in enumerate(dataloader):
            # print(i, batch["text"].shape, batch["label"].shape)
            label = model.forward(batch["text"])
            target_label=batch["label"]
            loss = criterion(label, target_label)
           
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # if i %  == 0:
        # print(f"torch.mean(label)={torch.mean(label)} \t torch.mean(target_label)={torch.mean(target_label)}")
        print(loss.detach().cpu().numpy())

    return model



model = train(config)


1.2152405
1.1913815
1.1765445
1.1361688
1.0981227
1.0835243
1.0615007
1.0276015
1.007597
0.98298997
0.9608886
0.93912405
0.912196
0.89202505
0.8756698
0.85887486
0.8408169
0.8232762
0.80576414
0.7894452
0.7710479
0.76532936
0.75823
0.7349519
0.7195223
0.6991127
0.7018892
0.6864446
0.6691892
0.6633661
0.65107113
0.6473933
0.6359388
0.62032527
0.6142461
0.60213536
0.59114593
0.5818282
0.573105
0.55832285
0.54055184
0.5344665
0.5248639
0.51531905
0.5062099
0.49551716
0.4900469
0.4768146
0.47255838
0.47133112


In [19]:
import shakespeare

In [23]:
config = TransformerConfig(
    num_layers=2,
    num_heads=4,
    vocab_size=shakespeare.vocab_size,
    hidden_size=64,
    max_seq_len=100,
    device="cpu"
)
dataset = shakespeare.ShakespeareDataset(config)

In [22]:
from torch.nn.functional import one_hot
one_hot(torch.Tensor([1,4,34]).to(torch.int64), 5000)

tensor([[0, 1, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0],
        [0, 0, 0,  ..., 0, 0, 0]])