In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Dict, Optional, Tuple, Union
# Build the `dinov3_convnext_small` entry point from the local hub checkout under
# models/backbone/dinov3, passing the checkpoint path through the `weights` argument.
backbone = torch.hub.load(
    "models/backbone/dinov3",
    "dinov3_convnext_small",
    source="local",
    weights="./checkpoints/dinov3_convnext_small_pretrain_lvd1689m-296db49d.pth",
)
# Probe call on one random 224x224 RGB image; element [3] of whatever
# forward_features returns is the cell's displayed output.
# NOTE(review): the return structure of forward_features is defined in the local
# dinov3 checkout, not in this notebook - confirm that indexing with 3 is valid.
backbone.forward_features(torch.randn(1, 3, 224, 224))[3]

In [None]:
class GaussianPooling(nn.Module):
    """Depthwise Gaussian smoothing of a 4-D feature map.

    Every channel is convolved independently (``groups=channels``) with the same
    normalised ``kernel_size x kernel_size`` Gaussian kernel. The convolution uses
    stride 1 and zero padding of ``kernel_size // 2``, so the output has the same
    shape as the input.

    Args:
        kernel_size: Side length of the square kernel. Must be a positive odd
            integer so that the kernel has a centre tap and the padding keeps
            the spatial size unchanged.
        sigma: Standard deviation of the Gaussian, in pixels. Must be non-zero.

    Raises:
        ValueError: If ``kernel_size`` is not a positive odd integer.
    """

    def __init__(self, kernel_size: int = 7, sigma: float = 2.0):
        super().__init__()
        # An even size would build kernel_size + 1 taps from the symmetric
        # arange below and then fail when reshaping; reject it up front.
        if kernel_size < 1 or kernel_size % 2 == 0:
            raise ValueError(
                f"kernel_size must be a positive odd integer, got {kernel_size}"
            )
        self.kernel_size = kernel_size
        self.sigma = sigma
        self.padding = kernel_size // 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Smooth ``x`` channel by channel.

        Args:
            x: Tensor with exactly four dimensions; dimension 1 is treated as
                the channel axis and the last two as the spatial axes.

        Returns:
            Tensor with the same shape and dtype as ``x``.

        Raises:
            ValueError: If ``x`` is not 4-dimensional.
        """
        if x.dim() != 4:
            raise ValueError(
                f"expected a 4-D input (batch, channels, height, width), got {x.dim()}-D"
            )
        channels = x.size(1)

        # Build the kernel in at least float32 so that low-precision inputs
        # (half / bfloat16) still get an accurately normalised kernel.
        compute_dtype = torch.float64 if x.dtype == torch.float64 else torch.float32
        ax = torch.arange(
            -self.padding, self.padding + 1, device=x.device, dtype=compute_dtype
        )
        # Explicit indexing avoids the deprecation warning; "ij" is the layout
        # the implicit default produced, and the kernel is symmetric anyway.
        xx, yy = torch.meshgrid(ax, ax, indexing="ij")
        kernel = torch.exp(-(xx**2 + yy**2) / (2 * self.sigma**2))
        kernel = kernel / kernel.sum()

        # conv2d requires weight and input to share a dtype.
        kernel = kernel.to(x.dtype)
        # One identical (1, k, k) filter per channel for the grouped convolution.
        kernel = kernel.view(1, 1, self.kernel_size, self.kernel_size).repeat(channels, 1, 1, 1)

        return F.conv2d(x, kernel, padding=self.padding, groups=channels)

In [None]:
import math


class GatedMultiHeadAttention(nn.Module):
    """
    Gated multi-head self-attention from "Gated Attention for Large Language Models:
    Non-linearity, Sparsity, and Attention-Sink-Free" (arXiv:2505.06708).

    Key idea:
    - A sigmoid gate is applied to the output of scaled dot-product attention (SDPA).
    - Formula: Y' = Y ⊙ σ(XW_θ)
    - The gating is head-specific, which is meant to mitigate the attention-sink
      phenomenon.
    """

    def __init__(
        self,
        d_model: int,
        num_heads: int,
        dropout: float = 0.1,
        gate_type: str = "elementwise",  # "elementwise" or "headwise"
        bias: bool = True,
    ):
        """
        Args:
            d_model: Hidden dimension of the model.
            num_heads: Number of attention heads; must divide ``d_model``.
            dropout: Dropout probability.
            gate_type: "elementwise" (one gate per feature) or "headwise"
                (a single gate value per head).
            bias: Whether the linear layers use a bias term.

        Raises:
            AssertionError: If ``d_model`` is not divisible by ``num_heads``.
            ValueError: If ``gate_type`` is not one of the two supported values.
        """
        super().__init__()

        assert d_model % num_heads == 0, "d_model must be divisible by num_heads"

        self.d_model = d_model
        self.num_heads = num_heads
        self.d_k = d_model // num_heads
        self.gate_type = gate_type

        # Q, K, V projection layers
        self.W_q = nn.Linear(d_model, d_model, bias=bias)
        self.W_k = nn.Linear(d_model, d_model, bias=bias)
        self.W_v = nn.Linear(d_model, d_model, bias=bias)

        # Output projection
        self.W_o = nn.Linear(d_model, d_model, bias=bias)

        # Gate projection (the addition over plain multi-head attention)
        # elementwise: an independent gate for every feature dimension
        # headwise: a single gate value per head
        if gate_type == "elementwise":
            self.W_gate = nn.Linear(d_model, d_model, bias=bias)
        elif gate_type == "headwise":
            self.W_gate = nn.Linear(d_model, num_heads, bias=bias)
        else:
            raise ValueError(f"Unknown gate_type: {gate_type}")

        # The same dropout module is applied to the attention weights and to the
        # projected output.
        self.dropout = nn.Dropout(dropout)
        self.scale = math.sqrt(self.d_k)

    def forward(
        self,
        x: torch.Tensor,
        mask: Optional[torch.Tensor] = None,
        return_gate_scores: bool = False,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        """
        Args:
            x: Input tensor (batch_size, seq_len, d_model).
            mask: Attention mask (batch_size, seq_len, seq_len) or
                (batch_size, 1, seq_len, seq_len). Positions equal to 0 are
                blocked. A query whose row is entirely 0 receives an all-zero
                attention output instead of NaN.
            return_gate_scores: Whether to also return the gate values.

        Returns:
            output: (batch_size, seq_len, d_model)
            gate_scores (optional): gate values shaped
                (batch_size, num_heads, seq_len, d_k) for "elementwise" or
                (batch_size, num_heads, seq_len, 1) for "headwise".
        """
        batch_size, seq_len, d_model = x.shape

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # (batch, seq, heads * n) -> (batch, heads, seq, n)
            return t.view(batch_size, seq_len, self.num_heads, -1).transpose(1, 2)

        # 1-2. Linear projections for Q, K, V, split into heads
        Q = split_heads(self.W_q(x))
        K = split_heads(self.W_k(x))
        V = split_heads(self.W_v(x))
        # Shape: (batch_size, num_heads, seq_len, d_k)

        # 3. Scaled dot-product attention (SDPA)
        scores = torch.matmul(Q, K.transpose(-2, -1)) / self.scale
        # Shape: (batch_size, num_heads, seq_len, seq_len)

        fully_masked = None
        if mask is not None:
            if mask.dim() == 3:
                mask = mask.unsqueeze(1)  # Add head dimension
            blocked = mask == 0
            # Rows with no visible key would turn into all -inf and make softmax
            # return NaN, which poisons the output and every gradient. Leave
            # their scores finite here and zero their weights after the softmax.
            fully_masked = blocked.all(dim=-1, keepdim=True)
            scores = scores.masked_fill(blocked & ~fully_masked, float('-inf'))

        # Attention weights
        attn_weights = F.softmax(scores, dim=-1)
        if fully_masked is not None:
            attn_weights = attn_weights.masked_fill(fully_masked, 0.0)
        attn_weights = self.dropout(attn_weights)

        # Attention output
        Y = torch.matmul(attn_weights, V)
        # Shape: (batch_size, num_heads, seq_len, d_k)

        # 4. Gating mechanism (position G1 in the paper): sigma(X W_theta)
        gate_scores = torch.sigmoid(self.W_gate(x))
        # elementwise -> (batch_size, num_heads, seq_len, d_k)
        # headwise    -> (batch_size, num_heads, seq_len, 1), broadcast over d_k
        gate_scores = split_heads(gate_scores)

        # Apply gate: Y' = Y ⊙ σ(XW_θ)
        Y_gated = Y * gate_scores

        # 5. Concatenate heads -> (batch_size, seq_len, d_model)
        Y_gated = Y_gated.transpose(1, 2).contiguous().view(batch_size, seq_len, d_model)

        # 6. Final output projection
        output = self.W_o(Y_gated)
        output = self.dropout(output)

        if return_gate_scores:
            return output, gate_scores
        return output

In [None]:
class MetaSpace(nn.Module):
    """Work-in-progress module that pools backbone features around keypoints.

    It owns a 5x5 Gaussian pooling layer, a learnable ``meta_spaces`` parameter
    of shape (num_meta_spaces, feature_dim) and three head-wise gated attention
    layers with 8 heads each. ``meta_spaces`` and ``gmha`` are not used by any
    method yet, and ``forward`` is still a stub.

    Args:
        size: (height, width) of the image space the keypoints are expressed in.
        feature_dim: Feature dimension of ``meta_spaces`` and of the attention
            layers; must be divisible by 8 because the attention layers use
            8 heads.
        num_meta_spaces: Number of rows in the ``meta_spaces`` parameter.
    """

    def __init__(self, size, feature_dim: int, num_meta_spaces: int):
        super().__init__()
        self.original_size = size
        self.pool = GaussianPooling(kernel_size=5, sigma=2.0)
        self.feature_dim = feature_dim
        self.num_meta_spaces = num_meta_spaces
        self.meta_spaces = nn.Parameter(torch.randn(num_meta_spaces, feature_dim))
        self.gmha = nn.ModuleList(
            GatedMultiHeadAttention(
                d_model=feature_dim,
                num_heads=8,
                gate_type="headwise",
            )
            for _ in range(3)
        )

    def forward_features(self, x: torch.Tensor, keypoints: torch.Tensor) -> torch.Tensor:
        """Gaussian-pool every feature level except the last and sample it at the keypoints.

        NOTE(review): unfinished - nothing is returned yet, and the sampled
        keypoint features are not consumed.

        Args:
            x: Sequence of feature maps; the last entry is skipped.
                Assumed to be (batch, channels, height, width) each - TODO confirm.
            keypoints: Iterated along its first dimension; each item is indexed
                as ``[:, 0]`` (x) and ``[:, 1]`` (y) in ``original_size``
                coordinates.
        """
        pooled_features = []
        for feature_space in x[:-1]:
            height, width = feature_space.shape[-2:]
            # Smooth the whole map once. Sampling the smoothed map at a keypoint
            # gives the Gaussian-weighted average of that keypoint's
            # neighbourhood; pooling the already-gathered (B, C, K) tensor is
            # not possible because the pooling layer only accepts 4-D input.
            smoothed = self.pool(feature_space)
            for kpts in keypoints:
                resized_kpts = self.cal_resized_keypoints(kpts, (height, width))
                # Clamp so that out-of-image keypoints cannot wrap around
                # (negative index) or raise an IndexError.
                px = resized_kpts[:, 0].long().clamp(0, width - 1)
                py = resized_kpts[:, 1].long().clamp(0, height - 1)
                gaussian_feature = smoothed[:, :, py, px]

                # TODO: further processing using feature_space and resized_kpts is needed
            pooled_features.append(smoothed)
            # TODO: further processing logic is needed

    def cal_resized_keypoints(self, keypoints: torch.Tensor, target_size: Tuple[int, int]) -> torch.Tensor:
        """Rescale keypoints from ``original_size`` to ``target_size`` coordinates.

        Args:
            keypoints: Tensor whose last dimension holds x at index 0 and y at
                index 1. The input is not modified.
            target_size: (height, width) of the target grid.

        Returns:
            A new floating-point tensor with x multiplied by
            ``target_w / orig_w`` and y by ``target_h / orig_h``. Floating-point
            inputs keep their dtype; integer inputs are converted to float.
        """
        orig_h, orig_w = self.original_size
        target_h, target_w = target_size
        scale_x = target_w / orig_w
        scale_y = target_h / orig_h
        resized_keypoints = keypoints.clone()
        # In-place multiplication by a float fails on integer tensors.
        if not resized_keypoints.is_floating_point():
            resized_keypoints = resized_keypoints.float()
        resized_keypoints[..., 0] *= scale_x
        resized_keypoints[..., 1] *= scale_y
        return resized_keypoints

    def forward(self, x: torch.Tensor, keypoints: torch.Tensor) -> torch.Tensor:
        """Not implemented yet; currently returns ``None``."""
        # TODO: implement the forward pass.
        pass
        

In [None]:
class FSKD(nn.Module):
    """Backbone -> neck -> head model skeleton.

    The backbone is the `dinov3_convnext_small` entry loaded from the local hub
    checkout. ``neck`` and ``head`` are empty ``nn.Sequential`` containers, so
    they currently pass their input through unchanged.

    Args:
        in_channels: Accepted but not used by the current implementation.
        out_channels: Accepted but not used by the current implementation.
    """

    def __init__(
            self,
            in_channels: int,
            out_channels: int
        ):
        super().__init__()
        # Instantiating the model reads the hub code and the checkpoint from disk.
        self.backbone = torch.hub.load(
            "models/backbone/dinov3",
            "dinov3_convnext_small",
            source="local",
            weights="./checkpoints/dinov3_convnext_small_pretrain_lvd1689m-296db49d.pth",
        )

        # Placeholders: empty containers to be filled in later.
        self.neck = nn.Sequential()
        self.head = nn.Sequential()

    def forward_features(
            self,
            x: torch.Tensor,
            masks: Optional[torch.Tensor] = None
        ) -> List[Dict[str, torch.Tensor]]:
        """Run backbone, neck and head on a single input.

        ``x`` and ``masks`` are wrapped in one-element lists for the backbone's
        ``forward_features_list``; the first element of its return value is
        dropped before the rest is handed to the neck and then the head.

        NOTE(review): what ``forward_features_list`` returns is defined in the
        local dinov3 checkout - confirm that dropping element 0 is intended.
        """
        backbone_outputs = self.backbone.forward_features_list([x], [masks])
        return self.head(self.neck(backbone_outputs[1:]))

    def forward(self, x: torch.Tensor) -> List[Dict[str, torch.Tensor]]:
        """Call ``forward_features`` on ``x`` without masks."""
        return self.forward_features(x)