In [3]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange, repeat
from torch import einsum

def default(val, d):
    return val if val is not None else d

class CrossAttention(nn.Module):
    def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.):
        super().__init__()
        inner_dim = dim_head * heads
        context_dim = default(context_dim, query_dim)

        self.scale = dim_head ** -0.5
        self.heads = heads

        self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
        self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_k_control = nn.Linear(context_dim, inner_dim, bias=False)
        self.to_v_control = nn.Linear(context_dim, inner_dim, bias=False)

        self.to_out = nn.Sequential(
            nn.Linear(inner_dim, query_dim),
            nn.Dropout(dropout)
        )

    def forward(self, x, context=None, control=None, mask=None, lambda_=1.0):
        h = self.heads

        q = self.to_q(x)
        context = default(context, x)
        k = self.to_k(context)
        v = self.to_v(context)

        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))

        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale

        if exists(mask):
            k_control = self.to_k_control(control)
            v_control = self.to_v_control(control)
            k_control, v_control = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (k_control, v_control))

            sim_control = einsum('b i d, b j d -> b i j', q, k_control) * self.scale

            mask = rearrange(mask, 'b ... -> b (...)')
            max_neg_value = -torch.finfo(sim_control.dtype).max
            mask = repeat(mask, 'b j -> (b h) () j', h=h)
            sim_control.masked_fill_(~mask, max_neg_value)

            attn_control = sim_control.softmax(dim=-1)
            out_control = einsum('b i j, b j d -> b i d', attn_control, v_control)
            out_control = rearrange(out_control, '(b h) n d -> b n (h d)', h=h)

        attn = sim.softmax(dim=-1)
        out = einsum('b i j, b j d -> b i d', attn, v)
        out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
        
        if exists(mask):
            out = out + lambda_ * out_control

        return self.to_out(out)

def exists(val):
    return val is not None

# 示例输入
query_dim = 320
context_dim = 768
heads = 8
dim_head = 40
dropout = 0.

cross_attention = CrossAttention(query_dim, context_dim, heads, dim_head, dropout)

x = torch.randn(2, 4096, query_dim)  # batch_size x seq_length x dim
context = torch.randn(2, 10, context_dim)
control = torch.randn(2, 10, context_dim)
mask = torch.randint(0, 2, (2,64,64)).bool()

output = cross_attention(x, context=context, control=control, mask=mask, lambda_=1.0)
print(output.shape)  # 应输出: torch.Size([2, 10, query_dim])


RuntimeError: The size of tensor a (10) must match the size of tensor b (4096) at non-singleton dimension 2