In [3]:
import torch
import torch.nn as nn

class HeadLatentAttention(nn.Module):
    """Single-head latent attention that pools a sequence into one vector.

    A latent summary vector is built by projecting every token to
    ``latent_dim`` and averaging over the sequence axis. Each token's query
    (also in ``latent_dim``) is scored against that latent vector, the scores
    are softmax-normalized over the sequence, and the resulting weights pool
    the projected values into a single vector per batch element.

    Args:
        embed_dim: Size of the last dimension of the input.
        num_heads: Stored on the module but not used by this single-head
            variant.
        latent_dim: Size of the query / latent space used for scoring.

    Shapes:
        input  ``x``: ``(batch_size, seq_len, embed_dim)``
        output:       ``(batch_size, 1, embed_dim)``
    """

    def __init__(self, embed_dim, num_heads, latent_dim):
        super().__init__()
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim

        # Queries are projected into latent space so they can be dotted with
        # the latent summary vector.
        self.W_q = nn.Linear(embed_dim, latent_dim)
        # W_k is not used by forward(); it is kept so existing state_dicts
        # keep the same keys and parameter-init order is unchanged.
        self.W_k = nn.Linear(embed_dim, embed_dim)
        self.W_v = nn.Linear(embed_dim, embed_dim)

        # Per-token projection that is averaged into the latent summary vector.
        self.latent_proj = nn.Linear(embed_dim, latent_dim)

        # Final output projection.
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Pool ``x`` into one vector per batch element.

        Args:
            x: Tensor of shape ``(batch_size, seq_len, embed_dim)``.

        Returns:
            Tensor of shape ``(batch_size, 1, embed_dim)``.

        Raises:
            ValueError: if ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected x of shape (batch_size, seq_len, embed_dim), got {tuple(x.shape)}"
            )

        queries = self.W_q(x)  # (batch_size, seq_len, latent_dim)
        values = self.W_v(x)   # (batch_size, seq_len, embed_dim)

        # Latent summary: project each token, then average over the sequence.
        latent = self.latent_proj(x).mean(dim=1, keepdim=True)  # (batch_size, 1, latent_dim)

        # Scaled dot product between every query and the single latent vector.
        scores = torch.matmul(queries, latent.transpose(-2, -1)) / (self.latent_dim ** 0.5)  # (batch_size, seq_len, 1)
        # Normalize over the sequence axis so each batch element's weights sum to 1.
        weights = torch.softmax(scores, dim=1)  # (batch_size, seq_len, 1)

        # Weighted sum of the values over the sequence.
        pooled = weights.transpose(-2, -1) @ values  # (batch_size, 1, embed_dim)

        return self.out_proj(pooled)  # (batch_size, 1, embed_dim)


# Smoke test: pool a random batch and check the output shape.
# Seed first so the random input and parameter init are reproducible on re-run.
torch.manual_seed(0)

batch_size = 2
seq_len = 5
embed_dim = 32
latent_dim = 16
num_heads = 4

x = torch.randn(batch_size, seq_len, embed_dim)  # input
print(x.shape)
mla = HeadLatentAttention(embed_dim, num_heads, latent_dim)
output = mla(x)

print(output.shape)  # (batch_size, 1, embed_dim)
assert output.shape == (batch_size, 1, embed_dim), output.shape

torch.Size([2, 5, 32])
torch.Size([2, 1, 32])


In [None]:
import torch
import torch.nn as nn

class MultiHeadLatentAttention(nn.Module):
    """Multi-head latent attention that pools a sequence into one vector.

    Queries and a latent summary vector (the per-token ``latent_proj`` output
    averaged over the sequence) are both split into ``num_heads`` heads of
    size ``latent_dim // num_heads``. Each head scores every token against
    its slice of the latent vector and softmax-normalizes over the sequence.
    The per-head weight distributions are then averaged and used to pool the
    (unsplit) values.

    NOTE(review): because the head weights are averaged before being applied
    to a single shared value tensor, the heads are mixed rather than
    concatenated as in standard multi-head attention. Confirm this is intended.

    Args:
        embed_dim: Size of the last dimension of the input.
        num_heads: Number of heads; must be positive and divide ``latent_dim``.
        latent_dim: Size of the query / latent space used for scoring.

    Raises:
        ValueError: if ``num_heads`` is not positive or does not divide
            ``latent_dim``.

    Shapes:
        input  ``x``: ``[batch_size, seq_len, embed_dim]``
        output:       ``[batch_size, 1, embed_dim]``
    """

    def __init__(self, embed_dim, num_heads, latent_dim):
        super().__init__()
        # Fail fast: forward() reshapes latent_dim into (num_heads, head_dim),
        # which is impossible unless num_heads divides latent_dim.
        if num_heads <= 0 or latent_dim % num_heads != 0:
            raise ValueError(
                f"latent_dim ({latent_dim}) must be divisible by a positive num_heads ({num_heads})"
            )
        self.num_heads = num_heads
        self.embed_dim = embed_dim
        self.latent_dim = latent_dim
        self.head_dim = latent_dim // num_heads

        # Query / key / value projections. Creation order is preserved so
        # parameter initialization under a fixed seed is unchanged.
        self.q_proj = nn.Linear(embed_dim, latent_dim)
        # k_proj is not used by forward() (the latent vector plays the role of
        # the key); kept so state_dict keys stay unchanged.
        self.k_proj = nn.Linear(embed_dim, latent_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)

        # Per-token projection that is averaged into the latent summary vector.
        self.latent_proj = nn.Linear(embed_dim, latent_dim)

        # Output projection.
        self.out_proj = nn.Linear(embed_dim, embed_dim)

    def forward(self, x):
        """Pool ``x`` into one vector per batch element.

        Args:
            x: Tensor of shape ``[batch_size, seq_len, embed_dim]``.

        Returns:
            Tensor of shape ``[batch_size, 1, embed_dim]``.

        Raises:
            ValueError: if ``x`` is not 3-dimensional.
        """
        if x.dim() != 3:
            raise ValueError(
                f"expected x of shape [batch_size, seq_len, embed_dim], got {tuple(x.shape)}"
            )
        batch_size, seq_len, _ = x.shape

        # Per-head queries: [batch_size, num_heads, seq_len, head_dim]
        queries = (
            self.q_proj(x)
            .view(batch_size, seq_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        values = self.v_proj(x)  # [batch_size, seq_len, embed_dim]

        # Latent summary averaged over the sequence, then split per head:
        # [batch_size, 1, latent_dim] -> [batch_size, num_heads, 1, head_dim]
        latent = (
            self.latent_proj(x)
            .mean(dim=1, keepdim=True)
            .view(batch_size, 1, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )

        # Scaled dot-product of each query against its head's latent vector.
        scores = torch.matmul(queries, latent.transpose(2, 3)) / (self.head_dim ** 0.5)  # [batch_size, num_heads, seq_len, 1]
        # Normalize over the sequence axis (dim=2) per head.
        weights = torch.softmax(scores, dim=2)  # [batch_size, num_heads, seq_len, 1]

        # Average the per-head distributions; each still sums to 1 over seq_len.
        weights = weights.mean(dim=1)  # [batch_size, seq_len, 1]

        # Weighted sum of the values over the sequence.
        pooled = torch.bmm(weights.transpose(1, 2), values)  # [batch_size, 1, embed_dim]

        return self.out_proj(pooled)  # [batch_size, 1, embed_dim]


# Smoke test: pool a random batch through the multi-head variant and check the shape.
# Seed first so the random input and parameter init are reproducible on re-run.
torch.manual_seed(0)

batch_size = 2
seq_len = 5
embed_dim = 32
latent_dim = 16
num_heads = 4

x = torch.randn(batch_size, seq_len, embed_dim)
mla = MultiHeadLatentAttention(embed_dim, num_heads, latent_dim)
output = mla(x)

print(output.shape)  # expected: torch.Size([2, 1, 32])
assert output.shape == (batch_size, 1, embed_dim), output.shape