In [1]:
# Hyper-parameters for the 124M-parameter model configuration used in this notebook.
GPT_CONFIG_124M = dict(
    vocab_size=50257,       # number of token ids (presumably the GPT-2 BPE vocabulary — confirm)
    context_length=1024,    # forwarded to MultiHeadAttention as `context_length` (assumed max tokens per sequence)
    emb_dim=768,            # width used for attention d_in/d_out and LayerNorm size
    n_heads=12,             # forwarded to MultiHeadAttention as `num_heads`
    n_layers=12,            # number of stacked transformer blocks (not read in this notebook)
    drop_rate=0.1,          # dropout probability for attention and the residual branches
    qkv_bias=False,         # forwarded to MultiHeadAttention as `qkv_bias`
)

# ![transformer_block](transformer_block.png)


In [8]:
import torch
import torch.nn as nn
from utils.attention import MultiHeadAttention
from utils.layers import FeedForward, LayerNorm

In [9]:
class TransformerBlock(nn.Module):
    """Pre-LayerNorm transformer block: an attention sublayer followed by a feed-forward sublayer.

    Each sublayer normalizes its input, transforms it, applies dropout, and then
    adds the sublayer's input back as a residual shortcut.

    Args:
        config: hyper-parameter dict. This class reads "emb_dim", "context_length",
            "drop_rate", "n_heads" and "qkv_bias"; ``FeedForward`` receives the whole dict.
    """

    def __init__(self, config):
        super().__init__()
        width = config["emb_dim"]
        p_drop = config["drop_rate"]

        # Submodules are registered in a fixed order (att, ff, norm1, norm2, dropout):
        # parameter initialization draws from the global RNG in this order, and the
        # attribute names are the state_dict keys, so neither should be changed lightly.
        self.att = MultiHeadAttention(
            d_in=width,
            d_out=width,
            context_length=config["context_length"],
            dropout=p_drop,
            num_heads=config["n_heads"],
            qkv_bias=config["qkv_bias"],
        )
        self.ff = FeedForward(config)
        self.norm1 = LayerNorm(width)
        self.norm2 = LayerNorm(width)
        # One dropout module serves both residual branches; nn.Dropout holds no
        # parameters, so sharing it behaves the same as two separate instances.
        self.drop_skip_layer = nn.Dropout(p_drop)

    def forward(self,x: torch.Tensor) -> torch.Tensor:
        """Apply both sublayers, each wrapped in dropout and a residual shortcut.

        Args:
            x: input activations; assumed to be (batch, num_tokens, emb_dim) —
                TODO confirm against MultiHeadAttention's expected layout.

        Returns:
            ``h + drop(ff(norm2(h)))`` where ``h = x + drop(att(norm1(x)))``.
        """
        # Attention sublayer: normalize -> attend -> dropout, then add the residual.
        hidden = x + self.drop_skip_layer(self.att(self.norm1(x)))
        # Feed-forward sublayer: normalize -> feed-forward -> dropout, then add the residual.
        return hidden + self.drop_skip_layer(self.ff(self.norm2(hidden)))

In [10]:
# Reproducible dummy input: 2 sequences of 4 tokens, each with emb_dim features.
# Reading the width from the config (instead of a literal 768) keeps the demo in
# sync with GPT_CONFIG_124M.
torch.manual_seed(123)
x = torch.rand(2, 4, GPT_CONFIG_124M["emb_dim"])

transformer_block = TransformerBlock(GPT_CONFIG_124M)
# Keep the result so later cells can inspect the block's output, not just its input.
output = transformer_block(x)

# Blocks are meant to be stacked (see "n_layers"), so the output must keep the input's shape.
assert output.shape == x.shape, f"expected {tuple(x.shape)}, got {tuple(output.shape)}"
output

tensor([[[ 0.0620,  0.5158, -0.0550,  ...,  1.2767,  0.1946,  0.6363],
         [-0.0909,  0.1276,  0.2290,  ...,  0.6513,  0.5151,  0.7462],
         [ 0.4857,  0.4642, -0.0245,  ...,  1.2237,  0.1653,  0.7574],
         [ 0.0807,  0.8251,  0.8783,  ...,  0.4391,  0.6565,  0.7821]],

        [[ 0.3090,  1.2133,  0.5646,  ...,  0.2144, -0.0115, -0.4829],
         [-0.1076,  0.7464,  0.2806,  ...,  0.2268,  0.5715,  0.0618],
         [ 0.8299,  0.6553,  0.3559,  ...,  0.4327,  0.7580, -0.0314],
         [ 0.5128,  0.6655,  0.1042,  ...,  1.2630,  1.3332,  0.2348]]],
       grad_fn=<AddBackward0>)

In [11]:
print(x.size())  # shape of the dummy input tensor `x` (Tensor.size() is the same torch.Size as .shape)

torch.Size([2, 4, 768])
