In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import math

class ModelArgs:
    """Hyper-parameters for the attention layer and the example below.

    Args:
        dim: Model (embedding) width.
        n_heads: Number of query heads.
        n_kv_heads: Number of key/value heads; ``None`` means "same as n_heads".
        max_batch_size: Batch capacity used to size the KV cache.
        max_seq_len: Number of positions used to size the KV cache.
    """

    def __init__(self, dim=512, n_heads=8, n_kv_heads=None, max_batch_size=2, max_seq_len=10):
        self.dim = dim
        self.n_heads = n_heads
        # Without an explicit value, every query head gets its own KV head.
        if n_kv_heads is None:
            n_kv_heads = n_heads
        self.n_kv_heads = n_kv_heads
        self.max_batch_size = max_batch_size
        self.max_seq_len = max_seq_len


class SelfAttention(nn.Module):
    """Multi-head / grouped-query self-attention backed by a key/value cache.

    Keys and values of every processed position are written into a
    fixed-size cache, so each call only projects its new token(s) and then
    attends over all positions cached so far.
    """

    def __init__(self, args: ModelArgs):
        super().__init__()

        # Number of key/value heads (falls back to the number of query heads).
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads

        # Number of query heads.
        self.n_heads_q = args.n_heads
        if self.n_heads_q % self.n_kv_heads != 0:
            # repeat_kv can only widen KV heads by an integer factor; a remainder
            # would leave fewer KV heads than query heads and the score matmul
            # in forward() would fail later with an opaque shape error.
            raise ValueError(
                f"n_heads ({self.n_heads_q}) must be divisible by "
                f"n_kv_heads ({self.n_kv_heads})"
            )
        # How many query heads share each key/value head.
        self.n_rep = self.n_heads_q // self.n_kv_heads

        # Width of a single head.
        self.head_dim = args.dim // args.n_heads

        self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
        self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
        self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)

        # KV cache laid out as (batch, position, kv_head, head_dim). Registered as
        # non-persistent buffers so `.to(device)` / `.half()` move and cast them
        # together with the weights, while keeping them out of state_dict.
        cache_shape = (args.max_batch_size, args.max_seq_len, self.n_kv_heads, self.head_dim)
        self.register_buffer("cache_k", torch.zeros(cache_shape), persistent=False)
        self.register_buffer("cache_v", torch.zeros(cache_shape), persistent=False)

    def forward(self, x: torch.Tensor, start_pos: int, freqs_complex: torch.Tensor):
        """Attend from the tokens in ``x`` to every cached position up to them.

        Args:
            x: Input of shape (batch, seq_len, dim).
            start_pos: Absolute position of ``x[:, 0]``; this call's keys and
                values are written into the cache starting at that index.
            freqs_complex: Rotary-embedding frequencies. Currently unused,
                because the RoPE calls below are commented out.

        Returns:
            Tensor of shape (batch, seq_len, dim).
        """
        batch_size, seq_len, _ = x.size()
        end_pos = start_pos + seq_len

        # (B, seq_len, dim) -> (B, seq_len, heads, head_dim)
        xq = self.wq(x).view(batch_size, seq_len, self.n_heads_q, self.head_dim)
        xk = self.wk(x).view(batch_size, seq_len, self.n_kv_heads, self.head_dim)
        xv = self.wv(x).view(batch_size, seq_len, self.n_kv_heads, self.head_dim)

        # TODO: apply rotary position embeddings once apply_rotary_embeddings exists.
        # xq = apply_rotary_embeddings(xq, freqs_complex, device=x.device)
        # xk = apply_rotary_embeddings(xk, freqs_complex, device=x.device)

        # Store this call's keys/values at their absolute positions.
        # NOTE(review): with autograd enabled, the cache keeps references to the
        # graphs of earlier calls; presumably meant to run under
        # torch.inference_mode() for generation — confirm.
        self.cache_k[:batch_size, start_pos:end_pos] = xk
        self.cache_v[:batch_size, start_pos:end_pos] = xv

        # Everything cached so far: (B, end_pos, n_kv_heads, head_dim).
        keys = self.cache_k[:batch_size, :end_pos]
        values = self.cache_v[:batch_size, :end_pos]

        # Share each KV head across n_rep query heads -> (B, end_pos, n_heads_q, head_dim).
        keys = repeat_kv(keys, self.n_rep)
        values = repeat_kv(values, self.n_rep)

        # Heads first: (B, heads, positions, head_dim).
        xq = xq.transpose(1, 2)
        keys = keys.transpose(1, 2)
        values = values.transpose(1, 2)

        # (B, H_Q, seq_len, head_dim) @ (B, H_Q, head_dim, end_pos) -> (B, H_Q, seq_len, end_pos)
        scores = torch.matmul(xq, keys.transpose(2, 3)) / math.sqrt(self.head_dim)
        if seq_len > 1:
            # Causal mask for multi-token calls (e.g. prompt prefill): query i is
            # at absolute position start_pos + i and must not see later keys.
            # A single token already sees only positions <= its own, so skip it.
            q_pos = torch.arange(start_pos, end_pos, device=x.device).unsqueeze(1)
            k_pos = torch.arange(end_pos, device=x.device).unsqueeze(0)
            scores = scores.masked_fill(k_pos > q_pos, float("-inf"))
        # Softmax is computed in float32, then cast back to the query dtype.
        scores = F.softmax(scores.float(), dim=-1).type_as(xq)

        # (B, H_Q, seq_len, end_pos) @ (B, H_Q, end_pos, head_dim) -> (B, H_Q, seq_len, head_dim)
        output = torch.matmul(scores, values)

        # Merge heads back: (B, seq_len, H_Q * head_dim), then project to dim.
        output = output.transpose(1, 2).contiguous().view(batch_size, seq_len, -1)
        return self.wo(output)

def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
    """Repeat each key/value head ``n_rep`` times along the head axis.

    Args:
        x: Tensor of shape (batch, seq_len, n_kv_heads, head_dim).
        n_rep: Number of copies to produce for every head.

    Returns:
        ``x`` itself when ``n_rep == 1``; otherwise a tensor of shape
        (batch, seq_len, n_kv_heads * n_rep, head_dim) in which the copies of
        head ``h`` occupy indices ``h * n_rep`` through ``h * n_rep + n_rep - 1``.
    """
    bsz, slen, kv_heads, hdim = x.shape  # unpacking also rejects non-4-D input
    if n_rep != 1:
        # Add a singleton axis after the heads, broadcast it to n_rep copies,
        # then fold it into the head axis so copies of one head stay adjacent.
        widened = x.unsqueeze(3).expand(bsz, slen, kv_heads, n_rep, hdim)
        return widened.reshape(bsz, slen, kv_heads * n_rep, hdim)
    return x

# Example hyper-parameters: 512-dim model, 8 heads, KV cache sized for batch 2 x 10 positions.
args = ModelArgs(dim=512, n_heads=8, max_batch_size=2, max_seq_len=10)
self_attention = SelfAttention(args)

# Build example input data.
batch_size = 2
seq_len = 1  # assume SelfAttention processes a single token per call
x = torch.randn(batch_size, seq_len, args.dim)  # random input of shape (batch, seq_len, dim)

# Dummy freqs_complex for the rotary embedding.
# NOTE(review): SelfAttention.forward has its RoPE calls commented out, so this
# tensor is currently never read; its shape/dtype are placeholders — confirm
# against the real rotary-embedding implementation before enabling it.
freqs_complex = torch.randn(batch_size, seq_len, args.n_heads, args.dim // args.n_heads)

# Run one forward step, writing this token's keys/values to cache position 0.
start_pos = 0
output = self_attention(x, start_pos, freqs_complex)

print("Output Shape:", output.shape)
print("Output:", output)


Output Shape: torch.Size([2, 1, 512])
Output: tensor([[[-0.1775,  0.2161,  0.3237,  ..., -0.0079,  0.1935,  0.2962]],

        [[ 0.1869,  0.2499, -0.5362,  ..., -0.0165,  0.2009, -0.1461]]],
       grad_fn=<UnsafeViewBackward0>)
