<a href="https://colab.research.google.com/github/PacktPublishing/Modern-Computer-Vision-with-PyTorch-2E/blob/main/Chapter15/self-attention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
%%capture
%pip install torch-snippets lovely-tensors pysnooper

In [None]:
%reload_ext autoreload
%autoreload 2
from torch_snippets import *
from pysnooper import snoop
from builtins import print

In [18]:
class SelfAttention(nn.Module):
    def __init__(self, embed_size):
        super(SelfAttention, self).__init__()
        self.embed_size = embed_size

        # Query, Key, Value projections
        self.query = nn.Linear(embed_size, embed_size)
        self.key = nn.Linear(embed_size, embed_size)
        self.value = nn.Linear(embed_size, embed_size)

    @snoop()
    def forward(self, x):
        # x shape: (batch_size, seq_len, embed_size)
        query = self.query(x)  # shape: (batch_size, seq_len, embed_size)
        key = self.key(x)      # shape: (batch_size, seq_len, embed_size)
        value = self.value(x)  # shape: (batch_size, seq_len, embed_size)

        # Compute the attention scores
        # query shape: (batch_size, seq_len, embed_size)
        # key shape: (batch_size, seq_len, embed_size)
        # scores shape: (batch_size, seq_len, seq_len)
        scores = torch.bmm(query, key.transpose(1, 2)) / (self.embed_size ** 0.5)

        # Apply softmax to get the attention weights
        # dim=-1 ensures softmax is applied across the sequence length
        weights = F.softmax(scores, dim=-1)

        # Apply the attention weights to the values
        out = torch.bmm(weights, value)  # shape: (batch_size, seq_len, embed_size)
        return out

SA = SelfAttention(64)
x = torch.randn(5, 3, 64)
SA(x)

[33m[2mSource path:... [22m<ipython-input-18-b5025a279c6a>[0m
[32m[2mStarting var:.. [22mself = SelfAttention(  (query): Linear(in_features=64, ...near(in_features=64, out_features=64, bias=True))[0m
[32m[2mStarting var:.. [22mx = tensor[5, 3, 64] n=960 (3.8Kb) x∈[-3.357, 2.783] μ=-0.029 σ=0.997[0m
[2m10:44:59.383754 call        13[0m     def forward(self, x):
[2m10:44:59.386178 line        15[0m         query = self.query(x)  # shape: (batch_size, seq_len, embed_size)
[32m[2mNew var:....... [22mquery = tensor[5, 3, 64] n=960 (3.8Kb) x∈[-1.923, 1.820] μ=0.003 σ=0.608 grad ViewBackward0[0m
[2m10:44:59.387399 line        16[0m         key = self.key(x)      # shape: (batch_size, seq_len, embed_size)
[32m[2mNew var:....... [22mkey = tensor[5, 3, 64] n=960 (3.8Kb) x∈[-1.557, 1.585] μ=-0.019 σ=0.571 grad ViewBackward0[0m
[2m10:44:59.388865 line        17[0m         value = self.value(x)  # shape: (batch_size, seq_len, embed_size)
[32m[2mNew var:....... [22mval

tensor[5, 3, 64] n=960 (3.8Kb) x∈[-1.309, 1.075] μ=-0.013 σ=0.333 grad BmmBackward0

---

Actual Implementation in pytorch with multihead

In [6]:
# ??F.multi_head_attention_forward
# !ln -s /usr/local/lib/python3.10/dist-packages/torch/nn/functional.py .
# Add snoop() to F.multi_head_attention_forward in above mentioned python file and run the code below

In [3]:
import torch
import torch.nn as nn

class TransformerEncoderModule(nn.Module):
    def __init__(self, embed_size, num_heads, dropout_rate=0.1):
        super(TransformerEncoderModule, self).__init__()
        self.layer_norm = nn.LayerNorm(embed_size)
        self.multi_head_attention = nn.MultiheadAttention(embed_dim=embed_size, num_heads=num_heads)
        self.dropout = nn.Dropout(dropout_rate)
        self.feed_forward = nn.Sequential(
            nn.Linear(embed_size, embed_size * 4),
            nn.ReLU(),
            nn.Linear(embed_size * 4, embed_size),
            nn.Dropout(dropout_rate)
        )

    def forward(self, src):
        # Normalize and compute self-attention
        src = self.layer_norm(src)
        attention_output, _ = self.multi_head_attention(src, src, src)
        src = src + self.dropout(attention_output)

        # Apply feed-forward network
        src = self.layer_norm(src)
        feed_forward_output = self.feed_forward(src)
        src = src + self.dropout(feed_forward_output)
        return src

# Parameters
embed_size = 512  # Embedding size
num_heads = 8     # Number of attention heads (ensure embed_size % num_heads == 0)
dropout_rate = 0.1

# Create the transformer encoder module
transformer_encoder = TransformerEncoderModule(embed_size, num_heads, dropout_rate)

# Example input (Batch size x Time steps x Embedding size)
input_tensor = torch.randn(5, 3, 512)  # 1 batch, 3 time steps, 512 embeddings each

# Forward pass through the transformer encoder
output_tensor = transformer_encoder(input_tensor)

print(output_tensor)


[32m[2mStarting var:.. [22mquery = tensor[5, 3, 512] n=7680 (30Kb) x∈[-3.828, 3.956] μ=-9.934e-10 σ=1.000 grad NativeLayerNormBackward0[0m
[32m[2mStarting var:.. [22mkey = tensor[5, 3, 512] n=7680 (30Kb) x∈[-3.828, 3.956] μ=-9.934e-10 σ=1.000 grad NativeLayerNormBackward0[0m
[32m[2mStarting var:.. [22mvalue = tensor[5, 3, 512] n=7680 (30Kb) x∈[-3.828, 3.956] μ=-9.934e-10 σ=1.000 grad NativeLayerNormBackward0[0m
[32m[2mStarting var:.. [22membed_dim_to_check = 512[0m
[32m[2mStarting var:.. [22mnum_heads = 8[0m
[32m[2mStarting var:.. [22min_proj_weight = Parameter containing:Parameter[1536, 512] n=786432 (3Mb) x∈[-0.054, 0.054] μ=-1.519e-05 σ=0.031 grad[0m
[32m[2mStarting var:.. [22min_proj_bias = Parameter containing:Parameter[1536] 6Kb all_zeros grad[0m
[32m[2mStarting var:.. [22mbias_k = None[0m
[32m[2mStarting var:.. [22mbias_v = None[0m
[32m[2mStarting var:.. [22madd_zero_attn = False[0m
[32m[2mStarting var:.. [22mdropout_p = 0.0[0m
[32m[

tensor[5, 3, 512] n=7680 (30Kb) x∈[-3.756, 3.655] μ=-0.005 σ=1.034 grad AddBackward0


[32m[2mModified var:.. [22mattn_output_weights = tensor[3, 5, 5] n=75 x∈[0.132, 0.308] μ=0.200 σ=0.036 grad MeanBackward1[0m
[2m10:25:15.425229 line      5460[0m         if not is_batched:
[2m10:25:15.447279 line      5464[0m         return attn_output, attn_output_weights
[2m10:25:15.464976 return    5464[0m         return attn_output, attn_output_weights
[36m[2mReturn value:.. [22m(tensor[5, 3, 512] n=7680 (30Kb) x∈[-0.833, 0.78...0.132, 0.308] μ=0.200 σ=0.036 grad MeanBackward1)[0m
[33m[2mElapsed time: [22m00:00:01.104722[0m
