In [1]:
import torch 

In [3]:
# Lower-triangular matrix of ones: row i keeps columns 0..i, entries above the diagonal are zero.
torch.ones(4, 4).tril()

tensor([[1., 0., 0., 0.],
        [1., 1., 0., 0.],
        [1., 1., 1., 0.],
        [1., 1., 1., 1.]])

In [4]:
# Same lower-triangular mask with two leading singleton dimensions, i.e. shape (1, 1, 4, 4).
torch.ones(4, 4).tril().view(1, 1, 4, 4)

tensor([[[[1., 0., 0., 0.],
          [1., 1., 0., 0.],
          [1., 1., 1., 0.],
          [1., 1., 1., 1.]]]])

In [19]:
import torch
import torch.nn as nn
import math
import torch.nn.functional as F

# Attention backend switch read by CasualSelfAttention.forward: a truthy value
# uses F.scaled_dot_product_attention, 0 uses the explicit masked-softmax path.
FLASH = 0

class CasualSelfAttention(nn.Module):
    """Multi-head self-attention with a lower-triangular (causal) mask.

    The class name keeps the notebook's original spelling so existing
    references to it continue to work.

    Args:
        config: object exposing ``n_embd``, ``n_head`` and ``block_size``.
            ``n_embd`` must be divisible by ``n_head``.
    """

    def __init__(self, config):
        super().__init__()
        assert config.n_embd % config.n_head == 0
        # A single linear layer produces query, key and value, concatenated
        # along the last dimension (hence 3 * n_embd output features).
        self.c_attn = nn.Linear(config.n_embd, 3 * config.n_embd)
        # Output projection applied after the heads are merged back together.
        self.c_proj = nn.Linear(config.n_embd, config.n_embd)
        # Marker attribute; nothing in this class reads it (presumably consumed
        # by weight-initialisation code elsewhere -- TODO confirm).
        self.c_proj.LLMC_RESIDUAL_SCALE_FLAG = 1
        self.n_head = config.n_head
        self.n_embd = config.n_embd
        # Lower-triangular ones of shape (1, 1, block_size, block_size); the two
        # leading singleton dims broadcast over batch and head in forward().
        self.register_buffer("bias", torch.tril(torch.ones(config.block_size, config.block_size))
                             .view(1, 1, config.block_size, config.block_size))

    def forward(self, x):
        """Apply causal self-attention to ``x``.

        The shape of every intermediate tensor is printed to stdout.

        Args:
            x: tensor of shape (B, T, C); C has to match the ``n_embd`` the
                layers were built with.

        Returns:
            Tensor of shape (B, T, C).

        Raises:
            ValueError: on the non-flash path, when T is larger than the
                ``block_size`` the mask buffer was built with.
        """
        B, T, C = x.size()
        print(f"Input size: B={B}, T={T}, C={C}")

        qkv = self.c_attn(x)
        print(f"QKV size: {qkv.size()}")

        q, k, v = qkv.split(self.n_embd, dim=2)
        print(f"Q size: {q.size()}, K size: {k.size()}, V size: {v.size()}")

        # Split the channel dim into heads and move the head dim forward:
        # (B, T, C) -> (B, n_head, T, C // n_head).
        k = k.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        q = q.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        v = v.view(B, T, self.n_head, C // self.n_head).transpose(1, 2)
        print(f"Reshaped Q size: {q.size()}, K size: {k.size()}, V size: {v.size()}")

        if FLASH:
            y = F.scaled_dot_product_attention(q, k, v, is_causal=True)
        else:
            # The mask buffer only covers block_size positions. Slicing it with
            # a larger T silently yields a smaller-than-(T, T) mask, which then
            # fails inside masked_fill with an opaque broadcasting error, so
            # reject over-long sequences here with a clear message instead.
            max_len = self.bias.size(-1)
            if T > max_len:
                raise ValueError(
                    f"sequence length {T} exceeds block_size {max_len}"
                )

            # Scaled dot-product scores, shape (B, n_head, T, T).
            att = (q @ k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1)))
            print(f"Attention scores size: {att.size()}")

            # Positions above the diagonal get -inf so softmax gives them weight 0.
            att = att.masked_fill(self.bias[:, :, :T, :T] == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            print(f"Softmaxed attention scores size: {att.size()}")

            y = att @ v
            print(f"Attention output size: {y.size()}")

        # Merge heads back: (B, n_head, T, C // n_head) -> (B, T, C).
        # contiguous() is needed because view() cannot run on the transposed layout.
        y = y.transpose(1, 2).contiguous().view(B, T, C)
        print(f"Reshaped output size: {y.size()}")

        y = self.c_proj(y)
        print(f"Final output size: {y.size()}")
        return y

In [20]:


# Define a simple configuration class
class Config:
    """Plain attribute container for the attention hyper-parameters.

    Every constructor argument is stored unchanged on the instance under the
    same name; no validation is performed.
    """

    def __init__(self, n_embd, n_head, block_size, attn_pdrop, resid_pdrop):
        # Model geometry: embedding width, number of heads, maximum sequence length.
        self.n_embd, self.n_head, self.block_size = n_embd, n_head, block_size
        # Dropout probabilities (attention weights / residual path).
        self.attn_pdrop, self.resid_pdrop = attn_pdrop, resid_pdrop

# Create a configuration object
config = Config(
    n_embd=64,      # Embedding dimension
    n_head=8,       # Number of attention heads
    block_size=128, # Maximum sequence length
    attn_pdrop=0.1, # Dropout probability for attention
    resid_pdrop=0.1 # Dropout probability for residuals
)

# Seed the global RNG so the randomly initialised weights and the random input
# (and therefore the printed output) are identical on every Restart & Run All.
torch.manual_seed(0)

# Instantiate the CasualSelfAttention model
model = CasualSelfAttention(config)

# Sample input of shape (batch size, sequence length, embedding dimension).
# The last two sizes come from the config so they cannot drift out of sync
# with the model; with the values above this is (2, 128, 64).
x = torch.randn(2, config.block_size, config.n_embd)

# Run the model
output = model(x)

# Self-attention maps (B, T, C) -> (B, T, C); fail loudly if that ever changes.
assert output.shape == x.shape, f"unexpected output shape {tuple(output.shape)}"

# Print the output
print("Output:", output)

Input size: B=2, T=128, C=64
QKV size: torch.Size([2, 128, 192])
Q size: torch.Size([2, 128, 64]), K size: torch.Size([2, 128, 64]), V size: torch.Size([2, 128, 64])
Reshaped Q size: torch.Size([2, 8, 128, 8]), K size: torch.Size([2, 8, 128, 8]), V size: torch.Size([2, 8, 128, 8])
Attention scores size: torch.Size([2, 8, 128, 128])
Softmaxed attention scores size: torch.Size([2, 8, 128, 128])
Attention output size: torch.Size([2, 8, 128, 8])
Reshaped output size: torch.Size([2, 128, 64])
Final output size: torch.Size([2, 128, 64])
Output: tensor([[[ 0.2121,  0.0830, -0.3663,  ..., -0.4172,  0.0820,  0.0012],
         [ 0.3427,  0.2058, -0.4499,  ..., -0.2930, -0.0223,  0.0244],
         [ 0.3764,  0.2091, -0.5816,  ..., -0.3658,  0.0075,  0.0493],
         ...,
         [ 0.0141, -0.1585, -0.1081,  ...,  0.0640, -0.1367,  0.2067],
         [-0.0184, -0.1475, -0.1249,  ...,  0.0404, -0.1119,  0.1863],
         [-0.0016, -0.1797, -0.0994,  ...,  0.0016, -0.0404,  0.1661]],

        [[-0.