In [None]:
from typing import List, Optional

import torch 
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange
from stable_diffusion_model import BaSicTransformerBlock, CrossAttentionBlock, LinearCrossAttentionBlock, FeedFoward, Normal_Attention

In [None]:
class SpatialTransformer(nn.Module):
    """Run a stack of transformer blocks over the spatial positions of a feature map.

    The input is group-normalised, projected by a 1x1 convolution, flattened
    from ``(b, c, h, w)`` to a ``(b, h*w, c)`` token sequence, passed through
    ``block_num`` transformer blocks together with the conditioning tensor,
    reshaped back to ``(b, c, h, w)`` and projected by a second 1x1 convolution.

    Args:
        channels: Channel count of the input feature map. It is used as the
            GroupNorm channel count (32 groups) and as the token width.
        heads: Number of attention heads; each head gets ``channels // heads`` dims.
        block_num: Number of transformer blocks in the stack.
        model_dim: Stored on the instance; not otherwise used by this class.
        clip_dim: Feature size of the conditioning tensor, handed to every
            block as ``cond_dim``.
        linear: Attention selector handed unchanged to every block.
            NOTE(review): assumes ``BaSicTransformerBlock`` accepts a
            ``linear`` keyword, as the definition in this notebook does.
    """

    def __init__(self,
                 channels: int,
                 heads: int,
                 block_num: int,
                 model_dim: int,
                 clip_dim: int,
                 linear: List
                 ):
        super().__init__()
        self.channels = channels
        self.heads = heads
        self.block_num = block_num
        self.model_dim = model_dim
        self.clip_dim = clip_dim
        self.linear = linear

        self.norm = torch.nn.GroupNorm(32, channels, eps=1e-6, affine=True)
        self.proj_in = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)

        # `linear` was previously stored but never reached the blocks, so the
        # attention type could not be selected from here.
        self.transformer_blocks = nn.ModuleList(
            [
                BaSicTransformerBlock(channels, heads, channels // heads, cond_dim=clip_dim, linear=linear)
                for _ in range(block_num)
            ]
        )

        self.proj_out = nn.Conv2d(channels, channels, kernel_size=1, stride=1, padding=0)

    def forward(self,
                x: torch.Tensor,
                clip_dim: torch.Tensor,
                linear: Optional[List] = None,
                fast_connect: bool = False) -> torch.Tensor:
        """Transform a ``(b, c, h, w)`` feature map.

        Args:
            x: Input feature map of shape ``(b, c, h, w)``.
            clip_dim: Conditioning tensor passed to every transformer block
                (the parameter name is kept for backward compatibility even
                though it holds a tensor, not a dimension).
            linear: Accepted for interface compatibility; not used here because
                the attention type is fixed when the blocks are constructed.
            fast_connect: When True, the untouched input is added back to the
                output as a skip connection.

        Returns:
            A tensor with the same shape as ``x``.
        """
        _, _, h, w = x.shape
        # Keep the raw input tensor itself; wrapping it in nn.Identity(...) would
        # create a module object rather than a tensor that can be added back.
        x_in = x

        x = self.norm(x)
        x = self.proj_in(x)
        # einops groups axes with parentheses: (h w) flattens the spatial grid.
        x = rearrange(x, 'b c h w -> b (h w) c')
        for block in self.transformer_blocks:
            x = block(x, clip_dim)
        # h and w must be supplied so the flattened axis can be split again.
        x = rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
        x = self.proj_out(x)

        if fast_connect:
            return x_in + x
        return x

We will use two attention options here: one is the linear focus attention, the other is the normal cross-attention.


In [None]:
class BaSicTransformerBlock(nn.Module):
    """Transformer block: self-attention, conditioned cross-attention, feed-forward.

    Each sub-layer output is layer-normalised and then added back to the
    sub-layer input.

    Args:
        emd_dim: Width of the token embeddings flowing through the block.
        num_heads: Number of attention heads.
        head_dim: Width of each attention head.
        cond_dim: Feature size of the conditioning tensor used by the second
            attention layer.
        linear: Two-entry selector. A non-None entry at index 0 / 1 makes the
            first / second attention layer a ``LinearCrossAttentionBlock``;
            ``None`` selects ``CrossAttentionBlock``. Omitting the argument, or
            passing fewer than two entries, selects ``CrossAttentionBlock`` for
            the unspecified layers.
    """

    def __init__(self,
                 emd_dim: int,
                 num_heads: int,
                 head_dim: int,
                 cond_dim: int,
                 linear: Optional[List] = None,
                 ):
        """
        The model_dim would be the input channels of the ViT model, thus no explicit model_dim param is needed
        """
        super().__init__()
        self.emd_dim = emd_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.cond_dim = cond_dim

        # The previous default was the typing object `List[None]`, which cannot
        # be indexed. Normalise to exactly two entries so both lookups are safe.
        flags = list(linear) if linear is not None else []
        self.linear = (flags + [None, None])[:2]

        self_attn_cls = LinearCrossAttentionBlock if self.linear[0] is not None else CrossAttentionBlock
        self.attention1 = self_attn_cls(emd_dim, emd_dim, num_heads, head_dim)
        self.norm1 = nn.LayerNorm(emd_dim)

        cross_attn_cls = LinearCrossAttentionBlock if self.linear[1] is not None else CrossAttentionBlock
        self.attention2 = cross_attn_cls(emd_dim, cond_dim, num_heads, head_dim)
        # norm2 normalises the attention output that is added to x, so it must
        # be emd_dim wide; sizing it with cond_dim fails whenever the two differ.
        self.norm2 = nn.LayerNorm(emd_dim)

        self.ffn = FeedFoward(emd_dim)
        self.norm3 = nn.LayerNorm(emd_dim)

    def forward(self, x: torch.Tensor, cond: torch.Tensor):
        """Apply the three sub-layers in order and return the updated tokens.

        Args:
            x: Token tensor whose last dimension is ``emd_dim``.
            cond: Conditioning tensor given to the second attention layer.
        """
        x = self.norm1(self.attention1(x)) + x
        x = self.norm2(self.attention2(x, cond=cond)) + x
        # The feed-forward sub-layer gets the same residual as the attention
        # sub-layers; it was previously the only one without it.
        x = self.norm3(self.ffn(x)) + x
        return x

In [None]:
# Module-level instance of the imported Normal_Attention helper.
# NOTE(review): nothing in the visible cells reads this name — confirm it is still needed.
normal_attention = Normal_Attention()

In [None]:
class CrossAttentionBlock(nn.Module):
    def __init__(self, 
                 emd_dim: int, 
                 cond_dim: int, 
                 num_heads: int, 
                 head_dim: int, 
                 use_spe_attn: bool = False,
                 inplace:bool = True
                 ):
        super().__init__()
        self.emd_dim = emd_dim
        self.cond_dim = cond_dim
        self.num_heads = num_heads
        self.head_dim = head_dim
        self.inplace = inplace

        self.scale_factor = head_dim ** 0.5
        model_dim = self.head_dim ** 0.5
        self.to_q = nn.Linear(self.emb_dim, model_dim, bias=False)
        self.to_k = nn.Linear(self.emb_dim, model_dim, bias=False)
        self.to_v = nn.Linear(self.emb_dim, model_dim, bias=False)

        self.to_out = nn.Linear(model_dim,self.emb_dim)

        try:
            from stable_diffusion_model import EfficientAttention
            self.efficient_attention = EfficientAttention()
        except ImportError:
            self.efficient_attention = None

    def forward(self, x:torch.Tensor, cond:Optional[torch.Tensor]=None):
        has_cond = cond is not None
        if has_cond is None:
            cond = x
        q = self.to_q(x)
        k = self.to_k(x)
        v = self.to_v(v)
        
        if self.efficient_attention is not None and self.use_spe_attn == True and has_cond is not None and self.head_dim >= 128:
            return self.efficient_attention(q, k, v)
        else:
            return self.normal_attention(q, k, v)
        
    def normal_attention(self, q: torch.Tensor, k: torch.Tensor, v: torch.Tensor):
        """
        split the Q, K, V into multiple heads and calculate the attention
        """
        q = rearrange(q, 'b l (d h) -> b l h d')
        k = rearrange(k, 'b l (d h) -> b l h d')
        v = rearrange(v, 'b l (d h) -> b l h d')
    
        attn = torch.einsum('b i h d, b j h d -> b h i j', q, k) * self.scale_factor

        if self.inplace:
            