In [1]:
import contextlib
import warnings
from typing import Optional, Tuple, Type

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import Tensor

from misc import get_sdpa_settings
# Silence every FutureWarning for the rest of the session.
# NOTE(review): presumably aimed at the deprecation warning of
# torch.backends.cuda.sdp_kernel; a narrower filter may be preferable — confirm.
warnings.filterwarnings("ignore", category=FutureWarning)

# Hardware-dependent SDPA switches computed by the project's misc helper.
(OLD_GPU, USE_FLASH_ATTN, MATH_KERNEL_ON) = get_sdpa_settings()

# Flipped to True by Attention.forward after a restricted SDPA call fails, so later
# calls run without any kernel restriction.
ALLOW_ALL_KERNELS = False

ModuleNotFoundError: No module named 'tqdm'

In [None]:
from sys import intern


def sdp_kernel_context(dropout_p):
    """Select the scaled-dot-product-attention backends for the next SDPA call.

    Flash Attention is preferred by default. Once the module-level ``ALLOW_ALL_KERNELS``
    switch has been set (after a failed restricted call), no restriction is applied.

    Args:
        dropout_p (float): dropout probability of the upcoming attention call. On an old
            GPU a positive value also enables the math kernel.

    Returns:
        ``torch.backends.cuda.sdp_kernel(...)`` restricted to the chosen backends, or
        ``contextlib.nullcontext()`` when all kernels are allowed.
    """
    if not ALLOW_ALL_KERNELS:
        # If the Flash Attention kernel is off, the math kernel needs to be enabled.
        math_allowed = (OLD_GPU and dropout_p > 0.0) or MATH_KERNEL_ON
        return torch.backends.cuda.sdp_kernel(
            enable_flash=USE_FLASH_ATTN,
            enable_math=math_allowed,
            enable_mem_efficient=OLD_GPU,
        )
    return contextlib.nullcontext()


class Attention(nn.Module):
    """Multi-head attention with an optional down-projected internal dimension.

    Queries, keys and values are linearly projected to
    ``internal_dim = embedding_dim // downsample_rate``, split into ``num_heads`` heads,
    combined with ``F.scaled_dot_product_attention`` and projected back to
    ``embedding_dim``. Keys/values may have a different channel width (``kv_in_dim``).

    Args:
        embedding_dim (int): channel width of the queries and of the output.
        num_heads (int): number of attention heads; must divide ``internal_dim``.
        downsample_rate (int, optional): divisor applied to ``embedding_dim`` to obtain the
            internal attention width. Defaults to 1.
        dropout (float, optional): attention dropout probability, applied only in training
            mode. Defaults to 0.0.
        kv_in_dim (int, optional): channel width of the key/value inputs. Defaults to
            ``embedding_dim``.
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        downsample_rate: int = 1,
        dropout: float = 0.0,
        kv_in_dim: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.embedding_dim = embedding_dim
        self.internal_dim = embedding_dim // downsample_rate
        self.kv_in_dim = kv_in_dim if kv_in_dim is not None else embedding_dim
        self.num_heads = num_heads
        assert (
            self.internal_dim % self.num_heads == 0
        ), "number of heads must divide internal dimension"

        self.q_proj = nn.Linear(embedding_dim, self.internal_dim)
        self.k_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        self.v_proj = nn.Linear(self.kv_in_dim, self.internal_dim)
        self.out_proj = nn.Linear(self.internal_dim, embedding_dim)
        self.dropout_p = dropout

    def _separate_heads(self, x: Tensor, num_heads: int) -> Tensor:
        """(B, N, C) -> (B, num_heads, N, C // num_heads)."""
        batch_size, seq_len, internal_dim = x.shape
        # reshape (not view) so non-contiguous inputs are accepted as well
        x = x.reshape(batch_size, seq_len, num_heads, internal_dim // num_heads)
        return x.transpose(1, 2)

    def _recombine_heads(self, x: Tensor) -> Tensor:
        """(B, num_heads, N, C_per_head) -> (B, N, num_heads * C_per_head)."""
        b, n_heads, n_tokens, dim_per_head = x.shape
        return x.transpose(1, 2).reshape(b, n_tokens, n_heads * dim_per_head)

    def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
        """Attend from ``q`` to ``k``/``v``.

        Args:
            q: queries, (B, N_q, embedding_dim).
            k: keys, (B, N_kv, kv_in_dim).
            v: values, (B, N_kv, kv_in_dim).

        Returns:
            Tensor of shape (B, N_q, embedding_dim).
        """
        # Declared up front: the fallback below flips the module-level switch.
        global ALLOW_ALL_KERNELS

        # Input projections
        q = self.q_proj(q)
        k = self.k_proj(k)
        v = self.v_proj(v)

        # Split into heads: B x N_heads x N_tokens x dim_per_head
        q = self._separate_heads(q, self.num_heads)
        k = self._separate_heads(k, self.num_heads)
        v = self._separate_heads(v, self.num_heads)

        dropout_p = self.dropout_p if self.training else 0.0
        # dropout_p must be passed by keyword: the 4th positional parameter of
        # scaled_dot_product_attention is attn_mask.
        try:
            with sdp_kernel_context(dropout_p):
                out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)
        except Exception as e:
            warnings.warn(
                f"Flash Attention kernel failed due to: {e}\nFalling back to all available "
                f"kernels for scaled_dot_product_attention (which may have a slower speed).",
                category=UserWarning,
                stacklevel=2,
            )
            # Later calls skip the kernel restriction entirely.
            ALLOW_ALL_KERNELS = True
            out = F.scaled_dot_product_attention(q, k, v, dropout_p=dropout_p)

        out = self._recombine_heads(out)
        return self.out_proj(out)

In [None]:
class TwoWayAttentionBlock(nn.Module):
    """Transformer block that updates sparse queries and dense keys in both directions.

    The block has four layers, applied in order:
      1. self-attention over the sparse queries,
      2. cross-attention from the sparse queries to the dense keys,
      3. an MLP on the sparse queries,
      4. cross-attention from the dense keys back to the sparse queries.

    Each layer is followed by a residual connection and a LayerNorm. The exception is
    that, with ``skip_first_layer_pe``, layer 1's output replaces the queries instead of
    being added to them. The layout follows ``TwoWayAttentionBlock`` from Segment
    Anything (SAM / SAM 2).
    """

    def __init__(
        self,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int = 2048,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample_rate: int = 2,
        skip_first_layer_pe: bool = False,
        *,
        attention_dowansample_rate: Optional[int] = None,
    ) -> None:
        """
        Args:
            embedding_dim (int): channel dimension of queries and keys.
            num_heads (int): number of heads in every attention layer.
            mlp_dim (int, optional): hidden width of the MLP. Defaults to 2048.
            activation (Type[nn.Module], optional): activation class used inside the MLP.
                Defaults to nn.ReLU.
            attention_downsample_rate (int, optional): internal-dimension downsample rate
                of the two cross-attention layers. Defaults to 2.
            skip_first_layer_pe (bool, optional): if True, the self-attention layer does
                not add the query positional encoding, and its output replaces the queries
                (no residual). Defaults to False.
            attention_dowansample_rate (int, optional): misspelled keyword alias of
                ``attention_downsample_rate``, kept for callers of the earlier signature;
                overrides it when given.
        """
        super().__init__()
        if attention_dowansample_rate is not None:
            attention_downsample_rate = attention_dowansample_rate

        # 1. sparse-query self-attention (full internal dimension)
        self.self_attn = Attention(embedding_dim, num_heads)
        self.norm1 = nn.LayerNorm(embedding_dim)

        # 2. sparse queries -> dense keys cross-attention
        self.cross_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

        # 3. two-layer MLP on the sparse queries
        self.mlp = nn.Sequential(
            nn.Linear(embedding_dim, mlp_dim),
            activation(),
            nn.Linear(mlp_dim, embedding_dim),
        )
        self.norm3 = nn.LayerNorm(embedding_dim)

        # 4. dense keys -> sparse queries cross-attention
        self.cross_attn_image_to_token = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample_rate
        )
        self.norm4 = nn.LayerNorm(embedding_dim)

        self.skip_first_layer_pe = skip_first_layer_pe

    def forward(
        self, queries: Tensor, keys: Tensor, query_pe: Tensor, key_pe: Tensor
    ) -> Tuple[Tensor, Tensor]:
        """Run the four layers and return the updated ``(queries, keys)``.

        Args:
            queries: sparse tokens, (B, N_queries, embedding_dim).
            keys: dense tokens, (B, N_keys, embedding_dim).
            query_pe: positional encoding added to ``queries`` before attention.
            key_pe: positional encoding added to ``keys`` before attention.

        Returns:
            Tuple ``(queries, keys)`` with the same shapes as the inputs.
        """
        # 1. self-attention of the sparse queries
        if self.skip_first_layer_pe:
            queries = self.self_attn(q=queries, k=queries, v=queries)
        else:
            q = queries + query_pe
            queries = queries + self.self_attn(q=q, k=q, v=queries)
        queries = self.norm1(queries)

        # 2. sparse queries attend to the dense keys
        q = queries + query_pe
        k = keys + key_pe
        queries = queries + self.cross_attn_token_to_image(q=q, k=k, v=keys)
        queries = self.norm2(queries)

        # 3. MLP on the sparse queries
        queries = self.norm3(queries + self.mlp(queries))

        # 4. dense keys attend to the updated sparse queries
        q = queries + query_pe
        k = keys + key_pe
        keys = keys + self.cross_attn_image_to_token(q=k, k=q, v=queries)
        keys = self.norm4(keys)

        return queries, keys

In [None]:
class TwoWayTransformer(nn.Module):
    """Two-way transformer that jointly updates sparse query tokens and a dense embedding.

    A stack of ``TwoWayAttentionBlock`` layers is followed by a final cross-attention from
    the queries to the dense tokens, plus a LayerNorm. The layout follows
    ``TwoWayTransformer`` from Segment Anything (SAM / SAM 2).
    """

    def __init__(
        self,
        depth: int,
        embedding_dim: int,
        num_heads: int,
        mlp_dim: int,
        activation: Type[nn.Module] = nn.ReLU,
        attention_downsample: int = 2,
    ) -> None:
        """
        Args:
            depth (int): number of TwoWayAttentionBlock layers.
            embedding_dim (int): channel dimension of both inputs.
            num_heads (int): number of attention heads.
            mlp_dim (int): hidden width of the MLP inside each block.
            activation (Type[nn.Module], optional): activation class used in the MLPs.
                Defaults to nn.ReLU.
            attention_downsample (int, optional): internal-dimension downsample rate of the
                cross-attention layers. Defaults to 2.
        """
        super().__init__()
        self.depth = depth
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.mlp_dim = mlp_dim
        self.layers = nn.ModuleList()

        for i in range(depth):
            # Passed positionally: the block's 5th parameter is its cross-attention
            # downsample rate. Only the first layer skips the query positional encoding
            # in its self-attention.
            self.layers.append(
                TwoWayAttentionBlock(
                    embedding_dim,
                    num_heads,
                    mlp_dim,
                    activation,
                    attention_downsample,
                    skip_first_layer_pe=(i == 0),
                )
            )

        # final queries -> dense tokens attention after the block stack
        self.final_attn_token_to_image = Attention(
            embedding_dim, num_heads, downsample_rate=attention_downsample
        )
        self.norm_final_attn = nn.LayerNorm(embedding_dim)

    def forward(
        self,
        image_embedding: Tensor,
        image_pe: Tensor,
        point_embedding: Tensor,
    ) -> Tuple[Tensor, Tensor]:
        """Run the two-way transformer.

        Args:
            image_embedding: dense embedding, (B, embedding_dim, H, W).
            image_pe: positional encoding for ``image_embedding``, same shape.
            point_embedding: sparse query tokens, (B, N_points, embedding_dim); also used
                as the query positional encoding in every layer.

        Returns:
            Tuple ``(queries, keys)``: the processed sparse tokens (B, N_points,
            embedding_dim) and the processed dense tokens (B, H*W, embedding_dim).
        """
        # B x C x H x W -> B x (H*W) x C
        image_embedding = image_embedding.flatten(2).permute(0, 2, 1)
        image_pe = image_pe.flatten(2).permute(0, 2, 1)

        queries = point_embedding
        keys = image_embedding

        for layer in self.layers:
            queries, keys = layer(queries, keys, point_embedding, image_pe)

        # final attention from the queries to the dense tokens, with residual + norm
        q = queries + point_embedding
        k = keys + image_pe
        queries = queries + self.final_attn_token_to_image(q=q, k=k, v=keys)
        queries = self.norm_final_attn(queries)

        return queries, keys