<a href="https://colab.research.google.com/github/Papa-Panda/Paper_reading/blob/main/DeepSeek%2Battention.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
# Paper: DeepSeek-V3 Technical Report — https://github.com/deepseek-ai/DeepSeek-V3/blob/main/DeepSeek_V3.pdf

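# This video walks through implementing attention from scratch:
# 【手写self-attention的四重境界-part1 pure self-attention】 ("Hand-writing self-attention, four levels of mastery — part 1: pure self-attention") https://www.bilibili.com/video/BV19YbFeHETz/?share_source=copy_web&vd_source=cdc9fab15e0ce1d464719ce689a12b14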

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Queries, keys and values are all linear projections of the same input.
    The embedding is split into ``num_heads`` chunks of ``head_dim`` features,
    attention is computed independently per head, and the heads are then
    concatenated and passed through a final linear layer.

    Args:
        embed_dim: Size of the last dimension of both the input and the output.
        num_heads: Number of attention heads; must divide ``embed_dim``.

    Raises:
        AssertionError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads):
        super(MultiHeadAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads

        # Each head works on an equal slice of the embedding, so the split must be exact.
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by the number of heads."

        self.head_dim = embed_dim // num_heads

        # Learnable projection matrices for query, key, and value
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        # Output projection applied after the heads are concatenated
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def _split_heads(self, projected):
        """Reshape (batch, seq, embed_dim) into (batch, num_heads, seq, head_dim)."""
        batch_size, seq_len, _ = projected.size()
        return projected.view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

    def forward(self, x, mask=None):
        """Apply self-attention over the sequence dimension of ``x``.

        Args:
            x: Tensor of shape (batch_size, seq_len, embed_dim).
            mask: Optional tensor broadcastable to
                (batch_size, num_heads, seq_len, seq_len), indexed as
                (..., query position, key position). A boolean mask marks with
                ``True`` the pairs that may attend; ``False`` entries are
                excluded. Any other dtype is treated as an additive bias on the
                attention scores (use ``-inf`` to exclude a pair). ``None``
                (default) means unrestricted attention.

        Returns:
            Tensor of shape (batch_size, seq_len, embed_dim).

        Note:
            A query row whose keys are all excluded produces NaN, because the
            softmax is taken over a row consisting only of ``-inf``.
        """
        batch_size, seq_len, embed_dim = x.size()

        # Project inputs to per-head queries, keys, and values
        queries = self._split_heads(self.q_proj(x))
        keys = self._split_heads(self.k_proj(x))
        values = self._split_heads(self.v_proj(x))

        # Scaled dot-product attention: (B, H, S, D) @ (B, H, D, S) -> (B, H, S, S)
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if mask is not None:
            if mask.dtype == torch.bool:
                # True = keep; excluded pairs get -inf so softmax assigns them zero weight.
                scores = scores.masked_fill(~mask, float("-inf"))
            else:
                scores = scores + mask

        attention_weights = F.softmax(scores, dim=-1)
        attended_values = torch.matmul(attention_weights, values)

        # Concatenate heads: (B, H, S, D) -> (B, S, H, D) -> (B, S, embed_dim).
        # transpose() leaves the tensor non-contiguous, so copy before view().
        attended_values = attended_values.transpose(1, 2).contiguous()
        attended_values = attended_values.view(batch_size, seq_len, embed_dim)

        # Apply the output projection
        return self.out_proj(attended_values)

# Example usage
if __name__ == "__main__":
    # Toy dimensions: just large enough to exercise the head split (16 / 4 = 4 per head).
    batch_size, seq_len = 2, 5
    embed_dim, num_heads = 16, 4

    # Build the layer first, then a random input batch.
    # NOTE: the variable is called `mla` although this is plain multi-head
    # attention; the name is kept unchanged here.
    mla = MultiHeadAttention(embed_dim=embed_dim, num_heads=num_heads)
    x = torch.randn((batch_size, seq_len, embed_dim))

    # Self-attention keeps the input shape: (batch_size, seq_len, embed_dim).
    output = mla(x)
    print("Output shape:", output.shape)


Output shape: torch.Size([2, 5, 16])


In [6]:
import torch
import torch.nn as nn
import torch.nn.functional as F

class MultiHeadLatentAttention(nn.Module):
    """Multi-head cross-attention from a fixed set of learned latent queries to the input.

    The queries are projections of ``num_latents`` learned vectors, while keys
    and values are projections of the input. The output therefore has
    ``num_latents`` positions regardless of the input sequence length.

    NOTE(review): despite the name, the queries here are learned latent vectors
    (Perceiver-style cross-attention). This does not appear to implement the
    low-rank key/value compression that the DeepSeek-V3 report calls
    Multi-head Latent Attention — confirm the naming is intentional.

    Args:
        embed_dim: Size of the last dimension of both the input and the output.
        num_heads: Number of attention heads; must divide ``embed_dim``.
        num_latents: Number of learned latent query vectors (output positions).

    Raises:
        AssertionError: If ``embed_dim`` is not divisible by ``num_heads``.
    """

    def __init__(self, embed_dim, num_heads, num_latents):
        super(MultiHeadLatentAttention, self).__init__()
        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.num_latents = num_latents

        # Each head works on an equal slice of the embedding, so the split must be exact.
        assert embed_dim % num_heads == 0, "Embedding dimension must be divisible by the number of heads."

        self.head_dim = embed_dim // num_heads

        # Latent representations (slots), shared by every batch element
        self.latents = nn.Parameter(torch.randn(num_latents, embed_dim))

        # Learnable projection matrices for query, key, and value
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        # Output projection applied after the heads are concatenated
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x, mask=None):
        """Attend from the learned latents to the sequence ``x``.

        Args:
            x: Tensor of shape (batch_size, seq_len, embed_dim).
            mask: Optional tensor broadcastable to
                (batch_size, num_heads, num_latents, seq_len). A boolean mask
                marks with ``True`` the input positions a latent may attend to;
                ``False`` entries are excluded (e.g. padding). Any other dtype
                is treated as an additive bias on the attention scores.
                ``None`` (default) means unrestricted attention.

        Returns:
            Tensor of shape (batch_size, num_latents, embed_dim).

        Note:
            A latent whose keys are all excluded produces NaN, because the
            softmax is taken over a row consisting only of ``-inf``.
        """
        batch_size, seq_len, embed_dim = x.size()

        # The queries do not depend on x, so project the latents once and let
        # matmul broadcast them over the batch instead of projecting an
        # expanded (batch_size, num_latents, embed_dim) copy.
        # (num_latents, embed_dim) -> (1, num_heads, num_latents, head_dim)
        queries = (
            self.q_proj(self.latents)
            .view(self.num_latents, self.num_heads, self.head_dim)
            .transpose(0, 1)
            .unsqueeze(0)
        )

        # Keys and values come from the input: (B, S, E) -> (B, H, S, D)
        keys = self.k_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)
        values = self.v_proj(x).view(batch_size, seq_len, self.num_heads, self.head_dim).transpose(1, 2)

        # Scaled dot-product attention: (1, H, L, D) @ (B, H, D, S) -> (B, H, L, S)
        scores = torch.matmul(queries, keys.transpose(-2, -1)) / (self.head_dim ** 0.5)

        if mask is not None:
            if mask.dtype == torch.bool:
                # True = keep; excluded positions get -inf so softmax assigns them zero weight.
                scores = scores.masked_fill(~mask, float("-inf"))
            else:
                scores = scores + mask

        attention_weights = F.softmax(scores, dim=-1)
        attended_values = torch.matmul(attention_weights, values)

        # Concatenate heads: (B, H, L, D) -> (B, L, H, D) -> (B, L, embed_dim)
        attended_values = attended_values.transpose(1, 2).contiguous()
        attended_values = attended_values.view(batch_size, self.num_latents, embed_dim)

        # Apply the output projection
        return self.out_proj(attended_values)

# Example usage
if __name__ == "__main__":
    # Toy dimensions: 5 input positions are summarised into 3 latent slots.
    batch_size, seq_len = 2, 5
    embed_dim, num_heads = 16, 4
    num_latents = 3

    # Build the layer first, then a random input batch.
    mla = MultiHeadLatentAttention(
        embed_dim=embed_dim,
        num_heads=num_heads,
        num_latents=num_latents,
    )
    x = torch.randn((batch_size, seq_len, embed_dim))

    # The output has one row per latent: (batch_size, num_latents, embed_dim).
    output = mla(x)
    print("Output shape:", output.shape)


Output shape: torch.Size([2, 3, 16])
