In [18]:
import torch
from torch import nn
from torch.functional import F

In [16]:
emb = nn.Embedding(64, 512)
X = torch.randint(0, 20, (64,))
y = emb(X[:1])
y

tensor([[-0.3291, -1.3480,  0.9504, -1.7910, -0.3196, -0.0874,  1.2529,  1.3262,
          1.6493,  1.6755,  0.8933,  0.7003, -0.2285, -0.2416, -1.0265, -0.7365,
         -0.7453, -1.0572,  0.0381,  0.6842, -0.8422,  1.6608,  0.9691,  1.2594,
          0.1552,  1.4307,  2.0422, -0.6484,  0.3191,  0.0791, -0.5932, -0.4018,
          0.0460,  0.7362, -1.9512, -0.0411,  0.5559, -0.7926, -0.7191,  1.5175,
         -1.1717, -0.9494, -0.1912,  1.3918, -0.8391, -0.9527, -0.8429,  0.6504,
         -0.4350,  0.2483,  1.3927, -0.1103, -0.8361,  0.9317,  0.8597, -0.6459,
         -2.1871,  1.6566, -0.1186, -0.8953,  0.6670, -0.3782, -0.2250,  0.3909,
          0.8970,  0.4080, -0.4974, -0.6406, -0.0366, -1.3596, -0.1455,  0.5838,
         -0.5452,  0.1360,  2.5299,  0.5774, -0.7170,  0.0148, -1.4872,  0.5589,
         -1.3776, -1.5706,  0.1089, -0.5625, -0.2405,  0.4433,  0.9503, -0.1150,
         -0.1295,  0.0909, -0.3933,  0.3517, -0.0670,  0.0380,  0.1372,  1.7550,
         -0.4367, -2.1240, -

In [22]:
n_embd = 512
qkv = nn.Linear(n_embd, 3*n_embd)
project_qkv = qkv(y)
project_qkv.shape

torch.Size([1, 1536])

In [29]:
y.size()

torch.Size([1, 512])

In [35]:
B = 1
T, C = y.size()
n_head = 8
d_k = n_embd // n_head

project_qkv = project_qkv.reshape(B, T, 3, n_head, d_k)
project_qkv = project_qkv.permute(2, 0, 3, 1, 4)
project_qkv.shape

q, k, v = project_qkv[0], project_qkv[1], project_qkv[2]

In [43]:
scale = d_k ** -0.5 
dropout = 0.1

attn_dropout = nn.Dropout(dropout)
resid_dropout = nn.Dropout(dropout)

attn_scores = (q @ k.transpose(-2, -1)) * scale

attn_probs = F.softmax(attn_scores, dim=-1)
attn_probs = attn_dropout(attn_probs)

attn_output = (attn_probs @ v)
print(attn_output.shape)

attn_output = attn_output.transpose(1, 2)
print(attn_output.shape)

attn_output = attn_output.reshape(B, T, C)
print(attn_output.shape)


out_proj = nn.Linear(n_embd, n_embd, bias=False)

attn_output = resid_dropout(out_proj(attn_output))
attn_output

torch.Size([1, 8, 1, 64])
torch.Size([1, 1, 8, 64])
torch.Size([1, 1, 512])


tensor([[[ 1.0464e-01,  5.6760e-01, -1.0019e-02,  1.8100e-01, -2.5226e-02,
          -7.5832e-01,  1.2557e-01, -6.3462e-01,  1.8910e-01,  2.3722e-01,
           3.3805e-02, -9.4745e-02, -1.2421e-01,  4.4703e-01, -9.5530e-01,
          -3.5607e-01, -7.6822e-01, -2.2390e-01, -2.6740e-01,  3.5489e-01,
           5.1837e-02,  4.1861e-01, -7.4142e-01,  2.5923e-01,  3.0335e-01,
          -8.3863e-02, -3.6304e-01, -4.8254e-02, -3.5712e-02,  0.0000e+00,
           0.0000e+00, -4.3660e-01,  7.0496e-01,  3.5344e-01, -5.5563e-01,
          -0.0000e+00, -3.2901e-01,  6.1909e-03,  2.1131e-02, -2.5607e-02,
          -4.4943e-02,  0.0000e+00, -3.4360e-02,  5.8925e-01, -7.1074e-01,
           2.6359e-01, -2.8649e-01, -7.7213e-01,  0.0000e+00,  3.1426e-02,
           0.0000e+00,  2.2114e-01,  5.8297e-01, -0.0000e+00,  5.6077e-01,
           1.2574e-01,  3.4656e-01, -6.5826e-01,  3.7308e-01, -1.6215e-01,
           4.0490e-01, -3.0029e-01, -2.2447e-01,  2.2844e-01, -1.9581e-02,
          -6.5172e-01,  3

In [17]:
class CausalSelfAttention(nn.Module):
    def __init__(self, n_embd, n_head, block_size, dropout, bias):
        super().__init__()
        assert n_embd % n_head == 0

        self.n_head = n_head
        self.d_k = n_embd // n_head
        self.scale = self.d_k ** -0.5

        self.qkv_proj = nn.Linear(n_embd, 3 * n_embd, bias=bias)
        self.out_proj = nn.Linear(n_embd, n_embd, bias=bias)
        self.attn_dropout = nn.Dropout(dropout)
        self.resid_dropout = nn.Dropout(dropout)

        self.register_buffer("mask", torch.tril(torch.ones(block_size, block_size)).view(1, 1, block_size, block_size))

    def forward(self, x):
        B, T, C = x.size()
        qkv = self.qkv_proj(x).reshape(B, T, 3, self.n_head, self.d_k).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        attn_scores = (q @ k.transpose(-2, -1)) * self.scale
        attn_scores = attn_scores.masked_fill(self.mask[:, :, :T, :T] == 0, float('-inf'))
        attn_probs = F.softmax(attn_scores, dim=-1)
        attn_probs = self.attn_dropout(attn_probs)

        attn_output = (attn_probs @ v).transpose(1, 2).reshape(B, T, C)
        attn_output = self.resid_dropout(self.out_proj(attn_output))

        return attn_output

torch.Size([1, 512])

In [63]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

# Configuration parameters
n_embd = 768  # Number of embedding dimensions
n_head = 12   # Number of attention heads
head_dim = n_embd // n_head  # Dimension per head
scale = head_dim ** -0.5  # Scaling factor for dot product attention
dropout_p = 0.1  # Dropout probability
block_size = 64  # Size of blocks for tiling
causal = True  # Assuming causal for GPT-like model

x = torch.randn(2, 512, n_embd)  # Example input tensor

# Linear layer to project input to QKV
qkv_proj = nn.Linear(n_embd, 3 * n_embd)
qkv = qkv_proj(x).view(x.size(0), x.size(1), 3, n_head, head_dim)

# Split into query, key, and value
q, k, v = qkv.unbind(2)

# Rearrange for easier computation
q = rearrange(q, 'b t h d -> (b h) t d')
k = rearrange(k, 'b t h d -> (b h) t d')
v = rearrange(v, 'b t h d -> (b h) t d')


In [64]:
# Initialize output tensor
output = torch.zeros_like(q)
output

tensor([[[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.]],

        ...,

        [[0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0.,  ..., 0., 0., 0.],
         ...,
         [0., 0., 0.,  ..., 0., 0., 0.],
         [0., 0., 0., 

In [82]:
q[:, 64:2]

(torch.Size([24, 0, 64]), torch.Size([24, 512, 64]))

In [88]:
for i in range(0, x.size(1), block_size):
    i_end = min(i + block_size, x.size(1))  # Define the end of the current block
    q_block = q[:, i:i_end]  # Extract the current block of queries
    print('q', q_block.shape)

    m = torch.full((q.shape[0], i_end - i), float('-inf'), device=q.device)  # Initialize m for numerical stability
    l = torch.zeros((q.shape[0], i_end - i), device=q.device)  # Initialize l for accumulation

    for j in range(0, x.size(1), block_size):
        j_end = min(j + block_size, x.size(1))  # Define the end of the current key/value block
        k_block = k[:, j:j_end]  # Extract the current block of keys
        v_block = v[:, j:j_end]  # Extract the current block of values
        print('k', k_block.shape)

        # Compute attention scores
        attn_block = torch.matmul(q_block, k_block.transpose(-2, -1)) * scale

    #     # Apply causal masking if necessary
    #     if causal and j > i:
    #         attn_block.fill_(float('-inf'))
    #     elif causal:
    #         causal_mask = torch.triu(torch.ones(i_end - i, j_end - j, dtype=torch.bool, device=attn_block.device), diagonal=j - i + 1)
    #         attn_block.masked_fill_(causal_mask, float('-inf'))

    #     # Compute exponential attention scores
        m_new = torch.maximum(m, attn_block.max(dim=-1)[0])
        exp_attn = torch.exp(attn_block - m_new.unsqueeze(-1))

    #     # Accumulate the exponentials
        l_new = l * torch.exp(m - m_new) + exp_attn.sum(dim=-1)
        output_block = torch.matmul(exp_attn, v_block)

    #     # Update output tensor
    #     output[:, i:i_end] += (output_block - output[:, i:i_end]) * (l / l_new).unsqueeze(-1)

    #     m, l = m_new, l_new  # Update m and l for the next block

q torch.Size([24, 64, 64])
k torch.Size([24, 64, 64])
tensor([[[0.5319, 0.2556, 0.8579,  ..., 0.5720, 0.8067, 0.2368],
         [0.5617, 0.2484, 0.7162,  ..., 0.7858, 0.4037, 0.3142],
         [0.3666, 0.4530, 0.3844,  ..., 0.2680, 0.1927, 0.8306],
         ...,
         [0.3080, 0.5550, 0.4721,  ..., 0.5330, 0.8827, 0.4427],
         [0.3601, 0.3822, 0.3473,  ..., 0.3981, 0.4210, 0.2101],
         [0.6152, 0.3377, 0.5226,  ..., 0.3232, 0.2234, 0.4977]],

        [[0.4002, 0.7167, 0.4578,  ..., 0.5990, 0.5494, 0.4916],
         [0.1576, 0.2719, 0.2555,  ..., 0.2833, 0.2457, 0.2109],
         [0.5954, 0.7846, 0.8272,  ..., 0.2973, 1.0000, 0.8668],
         ...,
         [0.5079, 0.5349, 0.2823,  ..., 0.4593, 0.3702, 0.6295],
         [0.5131, 0.5369, 0.2979,  ..., 0.3473, 0.5669, 0.5330],
         [0.6837, 0.4860, 0.4508,  ..., 0.3990, 0.8357, 0.7972]],

        [[0.6742, 0.4555, 0.4356,  ..., 0.3969, 0.2809, 0.5616],
         [0.8453, 0.5961, 0.4333,  ..., 0.8676, 0.5440, 0.6574],
    

In [66]:
# Normalize output
output[:, i:i_end] /= l.unsqueeze(-1)

# Rearrange output to original dimensions
output = rearrange(output, '(b h) t d -> b t (h d)', h=n_head)

# Linear layer to project output back to original embedding dimension
proj = nn.Linear(n_embd, n_embd)
output = proj(output)

# Output tensor shape should be [batch_size, seq_len, n_embd]
print(output.shape)  # Should be torch.Size([2, 512, 768])


torch.Size([2, 512, 768])


In [67]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange

class SimpleFlashAttention(nn.Module):
    def __init__(self, config):
        super().__init__()
        self.n_embd = config.n_embd
        self.n_head = config.n_head
        self.head_dim = self.n_embd // self.n_head
        self.scale = self.head_dim ** -0.5
        self.qkv = nn.Linear(self.n_embd, 3 * self.n_embd)
        self.proj = nn.Linear(self.n_embd, self.n_embd)
        self.dropout_p = config.dropout
        self.causal = True  # assuming causal for GPT-like model
        self.block_size = config.block_size  # size of blocks for tiling, now configurable

    def forward(self, x):
        b, t, c = x.size()
        qkv = self.qkv(x).view(b, t, 3, self.n_head, self.head_dim)
        q, k, v = qkv.unbind(2)
        q, k, v = [rearrange(x, 'b t h d -> (b h) t d') for x in (q, k, v)]

        output = torch.zeros_like(q)
        
        for i in range(0, t, self.block_size):
            i_end = min(i + self.block_size, t)
            q_block = q[:, i:i_end]
            
            m = torch.full((q.shape[0], i_end - i), float('-inf'), device=q.device)
            l = torch.zeros((q.shape[0], i_end - i), device=q.device)
            
            for j in range(0, t, self.block_size):
                j_end = min(j + self.block_size, t)
                k_block = k[:, j:j_end]
                v_block = v[:, j:j_end]
                
                attn_block = torch.matmul(q_block, k_block.transpose(-2, -1)) * self.scale
                
                if self.causal and j > i:
                    attn_block.fill_(float('-inf'))
                elif self.causal:
                    causal_mask = torch.triu(torch.ones(i_end - i, j_end - j, dtype=torch.bool, device=attn_block.device), diagonal=j - i + 1)
                    attn_block.masked_fill_(causal_mask, float('-inf'))
                
                m_new = torch.maximum(m, attn_block.max(dim=-1)[0])
                exp_attn = torch.exp(attn_block - m_new.unsqueeze(-1))
                
                l_new = l * torch.exp(m - m_new) + exp_attn.sum(dim=-1)
                output_block = torch.matmul(exp_attn, v_block)
                
                output[:, i:i_end] += (output_block - output[:, i:i_end]) * (l / l_new).unsqueeze(-1)
                
                m, l = m_new, l_new
            
            output[:, i:i_end] /= l.unsqueeze(-1)
        
        output = rearrange(output, '(b h) t d -> b t (h d)', h=self.n_head)
        return self.proj(output)

# Example usage
class Config:
    n_embd = 768
    n_head = 12
    dropout = 0.1
    block_size = 64  # Configurable block size

config = Config()
model = SimpleFlashAttention(config)
x = torch.randn(2, 512, config.n_embd)  # batch_size=2, seq_len=512
output = model(x)
print(output.shape)  # Should be torch.Size([2, 512, 768])

torch.Size([2, 512, 768])
