In [None]:
import torch 
import math
import torch.nn as nn
import torch.nn.functional as F

def casual_attention_mask(seq_len, device=None):
    """Return a square boolean mask that flags future positions.

    Entry ``[i, j]`` is ``True`` when ``j > i`` (strictly above the main
    diagonal) and ``False`` otherwise, i.e. ``True`` marks the positions a
    causal model must NOT look at.

    Note: ``F.scaled_dot_product_attention`` interprets boolean masks the
    other way round (``True`` = may attend), so invert this mask (``~mask``)
    before handing it to that function.

    Args:
        seq_len: Side length of the mask.
        device: Optional device on which to allocate the mask. ``None``
            (the default) keeps the previous behaviour of using the current
            default device.

    Returns:
        A ``torch.bool`` tensor of shape ``(seq_len, seq_len)``.
    """
    # Allocating directly on the target device avoids a CPU -> GPU copy and
    # device-mismatch errors when the mask is combined with GPU activations.
    ones = torch.ones(seq_len, seq_len, dtype=torch.bool, device=device)
    return torch.triu(ones, diagonal=1)

class MultiHeadAttention(nn.Module):
    """Causal multi-head self-attention with rotary embeddings and a KV cache.

    The module projects its input to queries, keys and values, applies the
    rotary position embedding to ``q``/``k`` (offset by the number of cached
    positions), optionally prepends cached keys/values, and runs causal
    scaled-dot-product attention followed by an output projection.

    Args:
        embed_dim: Model width ``D``. Must be divisible by ``num_heads``.
        num_heads: Number of attention heads.
        rotary: Object exposing ``rotation(q, k, start_pos=...)`` that returns
            the rotated ``(q, k)`` pair. ``None`` disables positional rotation.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``embed_dim``.
    """

    def __init__(self,
                 embed_dim: int,
                 num_heads: int,
                 rotary,
        ):
        super().__init__()
        # Fail early with a clear message instead of an opaque ``view`` error
        # on the first forward pass.
        if num_heads <= 0 or embed_dim % num_heads != 0:
            raise ValueError(
                f"embed_dim ({embed_dim}) must be divisible by num_heads ({num_heads})"
            )

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.rotary = rotary
        self.head_dim = embed_dim // num_heads

        self.q_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.k_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.v_proj = nn.Linear(embed_dim, embed_dim, bias=True)
        self.out_proj = nn.Linear(embed_dim, embed_dim, bias=True)

        for lin in (self.q_proj, self.k_proj, self.v_proj, self.out_proj):
            nn.init.xavier_uniform_(lin.weight)
            nn.init.zeros_(lin.bias)

    def forward(self, inputs, past_kv=None, use_cache: bool = False):
        """Run causal self-attention over ``inputs``.

        Args:
            inputs: Tensor of shape ``(B, L, D)`` holding the new tokens.
            past_kv: Optional ``(past_k, past_v)`` pair from a previous call,
                each of shape ``(B, num_heads, past_len, head_dim)``. The new
                tokens are treated as positions ``past_len .. past_len+L-1``.
            use_cache: When ``True`` also return the updated ``(k, v)`` pair.

        Returns:
            The attention output of shape ``(B, L, D)``; or, when
            ``use_cache`` is set, ``(output, (k, v))`` where ``k`` and ``v``
            have shape ``(B, num_heads, past_len + L, head_dim)``.
        """
        B, L, D = inputs.shape

        # 1) project and split heads: (B, L, D) -> (B, num_heads, L, head_dim)
        q = self.q_proj(inputs).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        k = self.k_proj(inputs).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)
        v = self.v_proj(inputs).view(B, L, self.num_heads, self.head_dim).transpose(1, 2)

        # 2) number of positions already stored in the cache (sequence dim is 2)
        past_len = 0 if past_kv is None else past_kv[0].size(2)

        # 3) RoPE with offset, so new tokens get their absolute positions.
        # NOTE(review): assumes ``rotary.rotation`` accepts ``start_pos`` and
        # returns ``(q, k)`` -- confirm against the rotary implementation.
        if self.rotary is not None:
            q, k = self.rotary.rotation(q, k, start_pos=past_len)

        # 4) append the new (already rotated) keys/values to the cache
        if past_kv is not None:
            past_k, past_v = past_kv
            k = torch.cat((past_k, k), dim=2)
            v = torch.cat((past_v, v), dim=2)

        # 5) causal masking
        #    * L == 1: the single query may see every key -> no mask needed.
        #    * no cache, L > 1: plain square causal mask (let SDPA build it).
        #    * cache + L > 1 (chunked prefill): query i sits at absolute
        #      position past_len + i, so it may attend keys j <= past_len + i.
        mask = None
        is_causal = False
        if L > 1:
            if past_len == 0:
                is_causal = True
            else:
                # SDPA boolean masks use True = "may attend".
                mask = torch.ones(
                    L, past_len + L, dtype=torch.bool, device=q.device
                ).tril(diagonal=past_len)

        y = F.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=0.0, is_causal=is_causal
        )
        # merge heads: (B, num_heads, L, head_dim) -> (B, L, D)
        y = y.transpose(1, 2).contiguous().view(B, L, D)
        y = self.out_proj(y)

        if use_cache:
            return y, (k, v)
        return y

