In [1]:
import torch
import torch.nn as nn
import math
from dataclasses import dataclass
from torch.nn import functional as F

In [2]:
@dataclass
class GPTConfig:
    """Hyperparameter bundle consumed by the model components in this notebook.

    Attributes:
        block_size: Side length of the square causal mask built by
            ``CausalSelfAttention``; sequences are sliced against it.
        vocab_size: Number of rows in the token-embedding table. The default
            is the GPT-2 vocabulary (50257) padded up to the nearest multiple
            of 64 for efficiency.
        n_layer: Layer count (not read by any code visible in this notebook
            excerpt).
        n_head: Number of attention heads; must divide ``n_embd`` evenly.
        n_embd: Embedding / hidden width.
        bias: Whether the attention ``nn.Linear`` projections carry a bias.
    """

    block_size: int = 1024
    vocab_size: int = 50304
    n_layer: int = 12
    n_head: int = 12
    n_embd: int = 768
    bias: bool = False

In [3]:
class CausalSelfAttention(nn.Module):
    """Multi-head causal self-attention with an optional key-padding mask.

    The module projects the input to queries, keys and values in one batched
    linear layer, splits them into ``n_head`` heads, applies scaled dot-product
    attention restricted to the left context, and projects the re-assembled
    heads back to ``n_embd``.

    Args:
        config: Object exposing ``n_embd``, ``n_head``, ``block_size`` and
            ``bias``. ``n_embd`` must be divisible by ``n_head``.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # key, query, value projections for all heads, but in a batch
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd, bias=config.bias)
        # output projection
        self.c_proj = nn.Linear(config.n_embd, config.n_embd, bias=config.bias)
        # causal mask to ensure that attention is only applied to the left in the input sequence;
        # stored as (1, 1, block_size, block_size) so it broadcasts over (batch, head)
        self.register_buffer("decoder_mask", torch.tril(torch.ones(config.block_size, config.block_size))
                                        .view(1, 1, config.block_size, config.block_size))
        self.n_head = config.n_head
        self.n_embd = config.n_embd

    def forward(self, x, attention_mask=None, verbose=True):
        """Apply masked multi-head self-attention to ``x``.

        Args:
            x: Tensor of shape (B, T, n_embd). ``T`` must not exceed the
                ``block_size`` the causal mask was built with.
            attention_mask: Optional (B, T) tensor; entries equal to 0 mark
                key positions (e.g. padding) that no query may attend to.
            verbose: When True (default, matching the original behaviour) the
                logits/weights of one sample row are printed at each masking
                stage. The inspected indices are clamped to the input size.

        Returns:
            Tuple ``(att, attention_mask, y)`` where ``att`` is the
            (B, n_head, T, T) attention-weight tensor, ``attention_mask`` is
            the input mask reshaped to (B, 1, 1, T) (or ``None`` if none was
            given) and ``y`` is the (B, T, n_embd) projected output.

            A query row with no visible key (for instance a left-padded first
            position) receives all-zero attention weights instead of NaN.
        """
        B, T, C = x.size() # batch size, sequence length, embedding dimensionality (n_embd)
        head_size = C // self.n_head

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        q, k, v = self.c_attn(x).split(self.n_embd, dim=2)  # each (B, T, n_embd)
        q = q.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
        k = k.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)
        v = v.view(B, T, self.n_head, head_size).transpose(1, 2) # (B, nh, T, hs)

        # Scaled dot-product logits: one (T, T) matrix per (sample, head), each
        # computed over the per-head dimension hs rather than the full n_embd.
        att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))) # (B, nh, T, T)

        # Debug view: sample 1 / head 0 / token 2 when they exist, otherwise the
        # last valid index, so small batches or short sequences do not raise.
        show = verbose and B > 0 and T > 0
        if show:
            print_b, print_head, print_time, print_dims = min(1, B - 1), 0, min(2, T - 1), T
            print(" Sample#: {print_b} , Head#: {print_head}, Token# Under Attention Investigation:{print_time}, "
            "Total Attention Weights for Ref Token: {print_dims}".format(
                print_b=print_b,print_head=print_head,print_time=print_time,print_dims=print_dims))
            print("Attention Weight Logits Before Self-Maksing ", att[print_b,print_head,print_time,:print_dims])

        causal_blocked = self.decoder_mask[:,:,:T,:T] == 0  # (1, 1, T, T), True where key is in the future
        att = att.masked_fill(causal_blocked, float('-inf'))
        visible = ~causal_blocked
        if show:
            print("Attention Weight Logits After Decoder-Self Masking ", att[print_b,print_head,print_time,:print_dims])

        key_mask = None
        if attention_mask is not None:
            # (B, T) -> (B, 1, 1, T) so it broadcasts over heads and query positions
            key_mask = attention_mask.unsqueeze(1).unsqueeze(2)
            key_blocked = key_mask == 0
            att = att.masked_fill(key_blocked, float('-inf'))
            visible = visible & ~key_blocked  # (B, 1, T, T)
            if show:
                print("Attention Weight Logits After External Attention-Masking ", att[print_b,print_head,print_time,:print_dims])

        # Rows where every key is masked would be all -inf and softmax would
        # return NaN (poisoning y and the gradients). Give them finite logits
        # for the softmax, then zero their weights afterwards.
        fully_masked = ~visible.any(dim=-1, keepdim=True)  # (B or 1, 1, T, 1)
        att = att.masked_fill(fully_masked, 0.0)
        att = F.softmax(att, dim=-1) # (B, nh, T, T)
        att = att.masked_fill(fully_masked, 0.0)
        if show:
            print("Attention Weight Logits After Softmax ", att[print_b,print_head,print_time,:print_dims]*100)

        # Values are weighted separately per head and position.
        y = att @ v # (B, nh, T, T) x (B, nh, T, hs) -> (B, nh, T, hs)
        # transpose makes the tensor non-contiguous, so .contiguous() is needed before .view
        y = y.transpose(1, 2).contiguous().view(B, T, C) # re-assemble all head outputs side by side
        # output projection; y is back to (B, T, n_embd), same as the input
        y = self.c_proj(y)
        return att, key_mask, y

In [4]:
class MockEmbeddingModel(nn.Module):
    """Minimal stand-in model that only holds a token-embedding table."""

    def __init__(self, gptconf, *args, **kwargs):
        """Build a ``vocab_size`` x ``n_embd`` embedding from ``gptconf``.

        Extra positional/keyword arguments are forwarded to ``nn.Module``.
        """
        super().__init__(*args, **kwargs)
        num_tokens, width = gptconf.vocab_size, gptconf.n_embd
        self.wte = nn.Embedding(num_tokens, width)

    def forward(self, x):
        """Look up the embedding vector for every token id in ``x``."""
        embedded = self.wte(x)
        return embedded

In [6]:
# Example usage: push a two-prompt batch (one of them left-padded) through CausalSelfAttention.

# Seed the global RNG so the randomly initialised embedding/attention weights,
# and therefore the printed attention values, are the same on every run.
torch.manual_seed(0)

# Illustrative small-model numbers. They are NOT passed to the config below,
# which uses the GPTConfig() defaults.
hidden_size = 4
num_heads = 2
head_dim = hidden_size // num_heads  # 2

gptconf = GPTConfig()
mock_emb = MockEmbeddingModel(gptconf)

B, T = 2, 4
# Integer token ids built directly as an integer tensor (no float round-trip).
token_ids = torch.tensor(
    [
        [12, 1, 43, 44],
        [gptconf.vocab_size - 1, 23, 2, 173],
    ],
    dtype=torch.int,
)
assert token_ids.shape == (B, T)

x = mock_emb(token_ids)  # (B, T, n_embd) embeddings fed to the attention layer
print(x.shape)

# Attention mask (Prompt 1: all real, Prompt 2: 1 padding)
attention_mask = torch.tensor([
    [1, 1, 1, 1],  # Prompt 1: 4 real tokens
    [0, 1, 1, 1],  # Prompt 2: 1 padding, 3 real tokens
], dtype=torch.float32)

attention_model = CausalSelfAttention(gptconf)
# NOTE: attention_mask is rebound to the mask returned by forward(), which has
# been reshaped to (B, 1, 1, T); the (B, T) version above is no longer bound.
att, attention_mask, y = attention_model(x, attention_mask)

torch.Size([2, 4, 768])
 Sample#: 1 , Head#: 0, Token# Under Attention Investigation:2, Total Attention Weights for Ref Token: 4
Attention Weight Logits Before Self-Maksing  tensor([-0.0410,  0.0649, -0.2804, -0.3085], grad_fn=<SelectBackward0>)
Attention Weight Logits After Decoder-Self Masking  tensor([-0.0410,  0.0649, -0.2804,    -inf], grad_fn=<SelectBackward0>)
Attention Weight Logits After External Attention-Masking  tensor([   -inf,  0.0649, -0.2804,    -inf], grad_fn=<SelectBackward0>)
Attention Weight Logits After Softmax  tensor([ 0.0000, 58.5487, 41.4513,  0.0000], grad_fn=<MulBackward0>)
