**Connecting attention and linear layers in a transformet block**

In [1]:
import torch
import torch.nn as nn

inputs = torch.tensor([
    [0.43, 0.15, 0.89],  # your        (x^1)
    [0.55, 0.87, 0.66],  # journey     (x^2)
    [0.57, 0.85, 0.64],  # starts     (x^3)
    [0.22, 0.58, 0.33],  # with       (x^4)
    [0.77, 0.25, 0.10],  # one        (x^5)
    [0.05, 0.80, 0.55]  # step        (X^6)
])


In [10]:
GPT_CONFIG_124M = {
    "vocab_size": 50257,  # vocabsize from BPE tokenizer
    "context_length": 1024,  # context length
    "emb_dim": 768,  # embedding dimension
    "n_heads": 12,  # number of attention heads
    "n_layers": 12,  # number of layers
    "drop_rate": 0.1,  # dropout rate
    "qkv_bias": False  # query-key-value bias
}


**Multi-head attention**

In [2]:

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, context_length, dropout, num_heads, qkv_bias=False):
        super().__init__()

        assert d_out % num_heads == 0, "d_out must be divisible by num_heads"

        self.d_out = d_out
        self.num_heads = num_heads
        # reduce projection dim to match desired output dim
        self.head_dim = d_out//num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_key = nn.Linear(d_in, d_out, bias=qkv_bias)
        self.W_value = nn.Linear(d_in, d_out, bias=qkv_bias)

        # use linear layer to combine head output
        self.out_proj = nn.Linear(d_out, d_in)
        self.dropout = nn.Dropout(dropout)
        self.register_buffer(
            'mask',
            torch.triu(torch.ones(context_length, context_length), diagonal=1)
        )

    def forward(self, x):
        b, num_tokens, d_in = x.shape
        keys = self.W_key(x)
        queries = self.W_query(x)
        values = self.W_value(x)

        # from (b, num_tokens, d_out) to (b, num_token, num_heads, head_dim) unrolling
        keys = keys.view(b, num_tokens, self.num_heads, self.head_dim)
        queries = queries.view(b, num_tokens, self.num_heads, self.head_dim)
        values = values.view(b, num_tokens, self.num_heads, self.head_dim)

        # (b, num_tokens, num_heads, head_dim) to (b, num_heads, num_tokens, head_dim)
        keys = keys.transpose(1, 2)
        queries = queries.transpose(1, 2)
        values = values.transpose(1, 2)

        # keys (b, num_heads, num_tokens, head_dim) to (b, num_heads, head_dim, num_tokens)
        attn_scores = queries@keys.transpose(2, 3)
        mask_bool = self.mask.bool()[:num_tokens, :num_tokens]
        attn_scores.masked_fill_(mask_bool, -torch.inf)

        attn_weights = torch.softmax(attn_scores/keys.shape[-1]**0.5, dim=-1)
        attn_weights = self.dropout(attn_weights)
        context_vec = (attn_weights@values).transpose(1, 2)

        # combining heads: self.d_out=self.num_heads*self.head_dim
        context_vec = context_vec.contiguous().view(b, num_tokens, self.d_out)
        context_vec = self.out_proj(context_vec)
        return context_vec
    
    

In [3]:
batch = torch.stack((inputs, inputs), dim=0)
print(batch)
print(batch.shape)


tensor([[[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]],

        [[0.4300, 0.1500, 0.8900],
         [0.5500, 0.8700, 0.6600],
         [0.5700, 0.8500, 0.6400],
         [0.2200, 0.5800, 0.3300],
         [0.7700, 0.2500, 0.1000],
         [0.0500, 0.8000, 0.5500]]])
torch.Size([2, 6, 3])


In [4]:
torch.manual_seed(123)
batch_size, context_length, d_in = batch.shape
d_out = 2
mha = MultiHeadAttention(d_in=d_in, d_out=d_out,
                         context_length=context_length, dropout=0.0, num_heads=2)
context_vec = mha(batch)
print(context_vec)
print(context_vec.shape)


tensor([[[-0.1933,  0.0272, -0.2507],
         [-0.2179, -0.0689, -0.4201],
         [-0.2267, -0.0993, -0.4760],
         [-0.2430, -0.0712, -0.4813],
         [-0.2484, -0.0658, -0.4875],
         [-0.2548, -0.0558, -0.4908]],

        [[-0.1933,  0.0272, -0.2507],
         [-0.2179, -0.0689, -0.4201],
         [-0.2267, -0.0993, -0.4760],
         [-0.2430, -0.0712, -0.4813],
         [-0.2484, -0.0658, -0.4875],
         [-0.2548, -0.0558, -0.4908]]], grad_fn=<ViewBackward0>)
torch.Size([2, 6, 3])


**Layer norm**

In [5]:

class LayerNorm(nn.Module):
    def __init__(self, emb_dim):
        super().__init__()
        self.eps = 1e-5
        self.scale = nn.Parameter(torch.ones(emb_dim))
        self.shift = nn.Parameter(torch.zeros(emb_dim))

    def forward(self, x):
        mean = x.mean(dim=-1, keepdim=True)
        var = x.var(dim=-1, keepdim=True, unbiased=False)
        norm_x = (x-mean)/torch.sqrt(var+self.eps)
        return self.scale*norm_x+self.shift
    
    

**GELU**

In [None]:
class GELU(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, x):
        return 0.5*x*(1+torch.tanh(torch.sqrt(torch.tensor(2.0/torch.pi))*(x+0.044715*torch.pow(x, 3))))
    
    

**Feed Forward**

In [8]:

class FeedForward(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(cfg["emb_dim"], 4*cfg["emb_dim"]),
            GELU(),
            nn.Linear(4*cfg["emb_dim"], cfg["emb_dim"])

        )

    def forward(self, x):
        return self.layers(x)
    

**Transformer Block**

In [12]:
class TransformerBlock(nn.Module):
    def __init__(self, cfg):
        super().__init__()
        self.att=MultiHeadAttention(
            d_in=cfg["emb_dim"],
            d_out=cfg["emb_dim"],
            context_length=cfg["context_length"],
            dropout=cfg["drop_rate"],
            num_heads=cfg["n_heads"],
            qkv_bias=cfg["qkv_bias"]
        )

        self.ff=FeedForward(cfg)
        self.norm1=LayerNorm(cfg["emb_dim"])
        self.norm2=LayerNorm(cfg["emb_dim"])
        self.drop_shortcut=nn.Dropout(cfg["drop_rate"])


    def forward(self,x):
        shortcut=x
        x=self.norm1(x)
        x=self.att(x)
        x=self.drop_shortcut(x)
        x=x+shortcut

        shortcut=x
        x=self.norm2(x)
        x=self.ff(x)
        x=self.drop_shortcut(x)
        x=x+shortcut
        return x



In [14]:
torch.manual_seed(123)
x=torch.rand(2,4,768)
block=TransformerBlock(GPT_CONFIG_124M)
output=block(x)

print(x.shape)
print(output.shape)

torch.Size([2, 4, 768])
torch.Size([2, 4, 768])
