[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/DalasNoin/arena/blob/main/w1/attention.ipynb)

### Create a Transfomer and test it on a toy example

In [1]:
# in colab 
# ! wget https://raw.githubusercontent.com/DalasNoin/arena/main/w1/shakespeare.py
# ! pip install transformers
# ! pip install wandb
import torch
from torch.nn.functional import softmax
from torch import nn
from dataclasses import dataclass
import wandb
import os

os.environ["WANDB_API_KEY"] = ""

In [2]:
Q = torch.ones((2,100,64))
K = torch.ones((2,90,64))
V = torch.ones((2,90,64))


def attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
    '''
    Should return the results of self-attention (see the "Self-Attention in Detail" section of the Illustrated Transformer).

    With this function, you can ignore masking.

    Q: shape (batch, target sequence length, embedding dim)
    K: shape (batch, source sequence length, embedding dim)
    V: shape (batch, source sequence length, embedding dim)
    softmax(Q KT/sqrt(d_k))V

    Return: shape (same as Q if embedding dim same. batch, target sequence length, output embedding dim)
    '''
    sqrt_d_k = torch.sqrt(torch.tensor(K.shape[-1]))
    query_key = torch.bmm(Q,torch.transpose(K,1,2))
    # print(f"{query_key.shape=} {sqrt_d_k=}")
    result =torch.bmm(softmax(query_key/sqrt_d_k,dim=2), V)
    return result

attention(Q, K, V).shape


torch.Size([2, 100, 64])

In [3]:




Q = torch.ones((2,20,64))
K = torch.ones((2,10,64))
V = torch.ones((2,10,64))

def masked_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor):
    '''
    Should return the results of self-attention.

    You should implement masking for this function. See "The Decoder Side" for an explanation of masking.

    Q: shape (batch, target sequence length, embedding dim)
    K: shape (batch, source sequence length, embedding dim)
    V: shape (batch, source sequence length, embedding dim)
    I = Q K.T
    I.shape = target_len x source_len
    softmax((I+mask)/sqrt(d_k))V

    Return: shape (same as Q if embedding dim same. batch, target sequence length, output embedding dim)
    '''
    sqrt_d_k = torch.sqrt(torch.tensor(K.shape[-1]))
    target_seq_len = torch.tensor(Q.shape[1])
    source_seq_len = torch.tensor(K.shape[1])
    triangular = torch.triu(torch.ones((target_seq_len, source_seq_len), dtype=torch.bool), diagonal=1)
    # print(triangular)

    query_key = torch.bmm(Q, torch.transpose(K,1,2))
    masked_query_key = torch.where(triangular, -torch.inf, query_key)
    # print(masked_query_key.shape, query_key.shape, triangular.shape)
    result =torch.bmm(softmax((masked_query_key)/sqrt_d_k,dim=2), V)
    return result



result = masked_attention(Q, K, V)
print(result.shape)

torch.Size([2, 20, 64])


In [4]:
# from matplotlib import pyplot as plt

In [5]:
Q = torch.ones((2,20,4*64))
K = torch.ones((2,10,4*64))
V = torch.ones((2,10,4*64))
num_heads = 4

def multihead_masked_attention(Q: torch.Tensor, K: torch.Tensor, V: torch.Tensor, num_heads: int, device:str="cpu"):
    '''
    Implements multihead masked attention on the matrices Q, K and V.

    Q: shape (batch, seq, nheads*headsize)
    K: shape (batch, seq, nheads*headsize)
    V: shape (batch, seq, nheads*headsize)
    '''
    # do the reshape
    
    batch, target_seq_len = Q.shape[0:2]
    source_seq_len = K.shape[1] 
    head_size = int(Q.shape[-1]/num_heads)
    sqrt_d_k = torch.sqrt(torch.tensor(head_size))
    # new_shape = (batch, target_seq_len, num_heads, head_size)
    Q = torch.reshape(Q, (batch, target_seq_len, num_heads, head_size))
    K = torch.reshape(K, (batch, source_seq_len, num_heads, head_size))
    V = torch.reshape(V, (batch, source_seq_len, num_heads, head_size))
    # generate mask
    triangular = torch.triu(torch.ones((target_seq_len, source_seq_len), dtype=torch.bool, device=device), diagonal=1)
    
    query_key = torch.einsum("abcd,aecd->acbe", Q, K)
    masked_query_key = torch.where(triangular, -torch.inf, query_key)
    masked_query_key = softmax((masked_query_key)/sqrt_d_k,dim=3)
    result = torch.einsum("abcd, adbe-> acbe", masked_query_key, V)
    result = torch.reshape(result, (batch, target_seq_len, num_heads * head_size))
    return result



result = multihead_masked_attention(Q, K, V, num_heads=num_heads)
print(result.shape)


torch.Size([2, 20, 256])


In [6]:
Q = torch.arange(2 * 7 * 4).reshape(2, 7, 4).type(torch.float32)
K = Q * 0.5
V = Q * 0.8
num_heads=2
multihead_masked_attention(Q, K, V, num_heads)

tensor([[[ 0.0000,  0.8000,  1.6000,  2.4000],
         [ 3.2000,  4.0000,  4.8000,  5.6000],
         [ 6.4000,  7.2000,  8.0000,  8.8000],
         [ 9.6000, 10.4000, 11.2000, 12.0000],
         [12.8000, 13.6000, 14.4000, 15.2000],
         [16.0000, 16.8000, 17.6000, 18.4000],
         [19.2000, 20.0000, 20.8000, 21.6000]],

        [[22.4000, 23.2000, 24.0000, 24.8000],
         [25.6000, 26.4000, 27.2000, 28.0000],
         [28.8000, 29.6000, 30.4000, 31.2000],
         [32.0000, 32.8000, 33.6000, 34.4000],
         [35.2000, 36.0000, 36.8000, 37.6000],
         [38.4000, 39.2000, 40.0000, 40.8000],
         [41.6000, 42.4000, 43.2000, 44.0000]]])

In [7]:
class MultiheadMaskedAttention(nn.Module):
    """
    head_size is not in this config, because in our implementation we're assuming num_heads * head_size = hidden_size.
hidden_size is also referred to as embedding_dim, or d_\text{model}d 
model
​
  in some material you might have read.
    """
    W_QKV: nn.Linear
    W_O: nn.Linear


    def __init__(self, hidden_size: int, num_heads: int, device:str="cpu"):
        super().__init__()
        self.num_heads = num_heads
        self.hidden_size = hidden_size
        self.device = device
        self.W_QKV = nn.Linear(hidden_size*3, num_heads*hidden_size*3)
        self.W_O = nn.Linear(num_heads*hidden_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        '''
        x: shape (batch, seq, hidden_size)

        Return: shape (batch, seq, hidden_size)
        '''
        x = x.repeat((1,1,3)) # repeat trice along dim 2
        x = self.W_QKV(x)
        #print(f"{x.shape=} {num_heads=} {self.hidden_size=}")
        Q, K, V = torch.split(x, num_heads*self.hidden_size, 2)
        #print(f"{Q.shape=} {K.shape=} {V.shape=}")
        
        Z = multihead_masked_attention(Q, K, V, num_heads=self.num_heads, device=self.device)
        #print(f"{Z.shape=}")
        Z = self.W_O(Z)
        return Z

num_heads=4
hidden_size=64
x = torch.ones((2,10,hidden_size))       
mma = MultiheadMaskedAttention(hidden_size=hidden_size, num_heads=num_heads)
mma(x).shape



torch.Size([2, 10, 64])

In [8]:
@dataclass(frozen=True)
class TransformerConfig:
    '''Constants used throughout your decoder-only transformer model.'''

    num_layers: int
    # head_size is not in this config, because in our implementation we're assuming num_heads * head_size = hidden_size
    num_heads: int
    vocab_size: int
    # hidden_size is also referred to as embedding_dim, or d_\text{model}d model in some material you might have read.
    hidden_size: int
    # max_seq_len is used just to determine the size of the positional encoding matrix.
    max_seq_len: int 
    dropout: float = 0.1
    layer_norm_epsilon: float = 1e-05
    device: str = "cpu"



In [9]:
from torch.nn import GELU

class MLP(nn.Module):
    def __init__(self, hidden_size: int, dropout: float):
        super().__init__()
        self.hidden_size = hidden_size
        self.dropout = dropout
        self.mlp_block = nn.Sequential(
            nn.Linear(self.hidden_size, 4*self.hidden_size),
            GELU(),
            nn.Linear(4*self.hidden_size, self.hidden_size),
            nn.Dropout(self.dropout)
        )
    def forward(self, x: torch.Tensor):
        return self.mlp_block(x)

In [10]:
class DecoderBlock(nn.Module):
    def __init__(self, config: TransformerConfig):
        super().__init__()
        self.config = config
        self.device = config.device
        self.layernorm1 = nn.LayerNorm(normalized_shape=self.config.hidden_size,eps=self.config.layer_norm_epsilon)
        self.mma = MultiheadMaskedAttention(hidden_size=self.config.hidden_size,
                                            num_heads=self.config.num_heads,
                                            device=self.device)
        self.layernorm2 = nn.LayerNorm(normalized_shape=self.config.hidden_size,eps=self.config.layer_norm_epsilon)
        self.mlp = MLP(hidden_size=self.config.hidden_size,
                        dropout=self.config.dropout)
    
    def forward(self, x: torch.Tensor):
        """
        x: input tensor shape=(batch, seq_len, hidden_dim=embedding_dim)
        """
        x = self.layernorm1(x + self.mma(x))
        x = self.layernorm2(x + self.mlp(x))
        return x

    

    

In [11]:
import torch
from torch import nn, Tensor


# more efficient, buffer for pe version
class PositionalEncoding(nn.Module):

    def __init__(self, d_model: int, dropout: float = 0.1, max_len: int = 5000, device:str="cpu"):
        super().__init__()
        self.d_model = d_model
        self.dropout = nn.Dropout(p=dropout)
        self.max_len = max_len
        self.device = device
        L = self.max_len
        partial_term = torch.outer(torch.arange(L),1/10_000**(torch.arange(torch.ceil(torch.tensor(self.d_model/2)))*2/self.d_model))
        positional_encoding = torch.zeros((L, self.d_model)).to(device)
        positional_encoding[:,::2] = torch.sin(partial_term.to(device))
        positional_encoding[:,1::2] = torch.cos(partial_term.to(device))
        self.register_buffer("positional_encoding", positional_encoding)


    def forward(self, x: Tensor) -> Tensor:
        '''
        x: Tensor, shape [batch, seq_len, embedding_dim]
        '''
        L = x.shape[1]
        # print(self.device)

        return self.dropout(x.to(self.device)+ self.positional_encoding[:L,:].to(self.device))

class DecoderOnlyTransformer(nn.Module):

    def __init__(self, config: TransformerConfig, vocab_size:int):
        super().__init__()
        self.config = config
        
        self.embedding = nn.Embedding(
            num_embeddings=vocab_size,
            embedding_dim=self.config.hidden_size)
        self.positional_encoding = PositionalEncoding(
            d_model=self.config.hidden_size,
            dropout=self.config.dropout,
            max_len=self.config.max_seq_len,
            device=self.config.device
        )
        list_decoder_blocks = [DecoderBlock(config = self.config) 
                                    for _ in range(self.config.num_layers)]
        self.decoder_blocks = nn.Sequential(*list_decoder_blocks)
        self.final_layer_norm = nn.LayerNorm(normalized_shape=self.config.hidden_size,eps=self.config.layer_norm_epsilon)
        self.unembed = nn.Linear(self.config.hidden_size, vocab_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.embedding(x)
        # print(f"x.shape={x.shape}")
        x = self.positional_encoding(x)
        x = self.decoder_blocks(x)
        x = self.final_layer_norm(x)
        x = softmax(self.unembed(x),dim=2)
        return x
    
config = TransformerConfig(
    num_layers=2,
    num_heads=4,
    vocab_size=1_000,
    hidden_size=64,
    max_seq_len=100,
    device="cpu"
)
# todo use the layernorm epsilon
test_input = torch.ones((2,20)).int().to(config.device)

transformer = DecoderOnlyTransformer(config=config, vocab_size=100)
transformer.to(config.device)
result = transformer.forward(test_input)
result.shape

torch.Size([2, 20, 100])

In [12]:
from torch.utils.data import Dataset, DataLoader
from torch.optim import Adam
from torch.nn import CrossEntropyLoss, MSELoss

In [13]:
class CustomTextDataset(Dataset):
    # def __init__(self, text, labels):
    #     self.labels = labels
    #     self.text = text
    def __init__(self, config: TransformerConfig):
        self.config = config
        self.seq_len = 25
        self.total_size = 100
        self.vocab_size=config.hidden_size*2
        # self.text = torch.ones((self.total_size,
        #                         self.seq_len,
        #                         config.hidden_size)).to(config.device)
        # self.labels = torch.ones((self.total_size,
        #                         self.seq_len,
        #                         config.hidden_size)).to(config.device)
        self.text = torch.rand((self.seq_len)).int().to(config.device).repeat(self.total_size,1)
        # self.labels = torch.rand((self.seq_len,
        #                         config.hidden_size)).to(config.device).repeat(self.total_size,1,1)
        

    def __len__(self):
            return self.total_size

    def __getitem__(self, idx):
            label = self.text[idx,1:]
            text = self.text[idx,:-1]
            sample = {"text": text, "label": label}
            return text, label

In [14]:
def train(config: TransformerConfig):
    dataset = CustomTextDataset(config)
    model = DecoderOnlyTransformer(config, vocab_size=dataset.vocab_size).to(config.device)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
    model.train()
    optimizer = Adam(params=model.parameters(), lr=0.001)
    # criterion = CrossEntropyLoss()
    criterion = CrossEntropyLoss()
    for epoch_idx in range(5):
        for i, (text, target_label) in enumerate(dataloader):
            # print(i, batch["text"].shape, batch["label"].shape)
            label = model.forward(text)
            # target_label=batch["label"]
            # print(f"{label.shape=} {target_label.shape=}")
            loss = criterion(label.transpose(1,2), target_label.long())
           
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            # if i %  == 0:
        # print(f"torch.mean(label)={torch.mean(label)} \t torch.mean(target_label)={torch.mean(target_label)}")
        print(loss.detach().cpu().numpy())

    return model



model = train(config)


4.442834
4.2563744
4.1350603
4.0464044
3.9840705


In [15]:
import shakespeare

In [16]:
config = TransformerConfig(
    num_layers=2,
    num_heads=4,
    vocab_size=shakespeare.vocab_size,
    hidden_size=64,
    max_seq_len=100,
    device="cpu"
)
# dataset = shakespeare.ShakespeareDataset(config)

In [17]:
dataset = shakespeare.ShakespeareDataset(config)

In [18]:
def collate(batch: list):
    print(batch)
    max_len = max([len(text) for (text, label) in batch])
    batch_size = len(batch)
    new_batch = list()
    for i, (text, label) in enumerate(batch):
        padded_text = torch.zeros(max_len).long()
        padded_text[:len(text)]=text
        padded_label = torch.zeros(max_len).long()
        padded_label[:len(label)] = label
        new_batch.append((padded_text, padded_label))
    return new_batch


In [19]:
def collate(batch: list):
    # print(batch)
    max_len = max([len(text) for (text, label) in batch])
    batch_size = len(batch)
    new_text = torch.zeros((batch_size, max_len)).long()
    new_label = torch.zeros((batch_size, max_len)).long()
    for i, (text, label) in enumerate(batch):
        new_text[i,:len(text)]=text
        new_label[i,:len(label)] = label
    return new_text, new_label


In [20]:
wandb.init()
def train(config: TransformerConfig, dataset: Dataset):
    
    model = DecoderOnlyTransformer(config, vocab_size=dataset.vocab_size).to(config.device)
    wandb.watch(model,log_freq=100)
    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate)
    model.train()
    optimizer = Adam(params=model.parameters(), lr=0.001)
    # criterion = CrossEntropyLoss()
    criterion = CrossEntropyLoss()
    for epoch_idx in range(50):
        for i, (text, target_label) in enumerate(dataloader):
            # print(i, batch["text"].shape, batch["label"].shape)
            label = model.forward(text)
            # target_label=batch["label"]
            # print(f"{label.shape=} {target_label.shape=}")
            loss = criterion(label.transpose(1,2), target_label)
           
            loss.backward()
            optimizer.step()
            optimizer.zero_grad()
            if i % 100 == 0:
               wandb.log({"loss": loss})
        # print(f"torch.mean(label)={torch.mean(label)} \t torch.mean(target_label)={torch.mean(target_label)}")
        print(loss.detach().cpu().numpy())

    return model



model = train(config, dataset)


[34m[1mwandb[0m: Currently logged in as: [33mdalasnoin[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [None]:
dataset = shakespeare.ShakespeareDataset(config)

In [None]:
dataset[3]["label"].shape

torch.Size([19])