In [85]:
import math
import os
from typing import Optional, Tuple, Type

import torch
import torch.nn as nn
from PIL import Image
from torch.nn import functional as F
from torchvision import transforms

In [87]:
# 0.1 Helper 1: look up relative position embeddings

def get_rel_pos(q_size: int, k_size: int, rel_pos: torch.Tensor) -> torch.Tensor:
    """Gather relative-position embeddings for every (query, key) pair along one axis.

    Entry ``[i, j]`` of the result is the table row selected by the offset between
    query position ``i`` and key position ``j``. Adding such terms to the attention
    map lets attention weights depend on relative distance, not only on content.

    Args:
        q_size: number of query positions along the axis.
        k_size: number of key positions along the axis.
        rel_pos: embedding table of shape (L, C). When L is not
            ``2 * max(q_size, k_size) - 1`` the table is linearly resampled to that length.

    Returns:
        Tensor of shape (q_size, k_size, C).
    """
    span = int(2 * max(q_size, k_size) - 1)

    table = rel_pos
    if table.shape[0] != span:
        # F.interpolate works on (N, C, L): put the position axis last, resample, restore.
        as_signal = rel_pos.reshape(1, rel_pos.shape[0], -1).permute(0, 2, 1)
        resampled = F.interpolate(as_signal, size=span, mode="linear")
        table = resampled.reshape(-1, span).permute(1, 0)

    # When q and k lengths differ, the shorter axis' coordinates are stretched so both
    # cover the same range; the shift makes every offset a non-negative row index.
    q_stretch = max(k_size / q_size, 1.0)
    k_stretch = max(q_size / k_size, 1.0)
    rows = torch.arange(q_size)[:, None] * q_stretch
    cols = torch.arange(k_size)[None, :] * k_stretch
    offsets = rows - cols + (k_size - 1) * k_stretch

    return table[offsets.long()]

In [89]:
# 0.2 Helper 2: add decomposed (height + width) relative-position terms to an attention map

def add_decomposed_rel_pos(
    scores: torch.Tensor,
    q: torch.Tensor,
    rel_pos_h: torch.Tensor,
    rel_pos_w: torch.Tensor,
    q_size: Tuple[int, int],
    k_size: Tuple[int, int],
) -> torch.Tensor:
    """Return ``scores`` plus separate height-axis and width-axis relative-position terms.

    Args:
        scores: attention map of shape (B, q_h * q_w, k_h * k_w), e.g.
            ``torch.matmul(q, k_trans) / self.scale``.
        q: queries of shape (B, q_h * q_w, C).
        rel_pos_h: relative-position table (Lh, C) for the height axis.
        rel_pos_w: relative-position table (Lw, C) for the width axis.
        q_size: (q_h, q_w) spatial grid of the queries.
        k_size: (k_h, k_w) spatial grid of the keys.

    Returns:
        A new tensor of shape (B, q_h * q_w, k_h * k_w); ``scores`` itself is not modified.
    """
    (q_h, q_w), (k_h, k_w) = q_size, k_size
    Rh = get_rel_pos(q_h, k_h, rel_pos_h)  # (q_h, k_h, C)
    Rw = get_rel_pos(q_w, k_w, rel_pos_w)  # (q_w, k_w, C)

    batch, _, channels = q.shape
    q_grid = q.reshape(batch, q_h, q_w, channels)
    # Dot each query with the embedding of its offset to every key row / key column.
    bias_h = torch.einsum("bhwc,hkc->bhwk", q_grid, Rh)  # (B, q_h, q_w, k_h)
    bias_w = torch.einsum("bhwc,wkc->bhwk", q_grid, Rw)  # (B, q_h, q_w, k_w)

    # Broadcast the height term over key columns and the width term over key rows.
    grid_scores = scores.view(batch, q_h, q_w, k_h, k_w)
    grid_scores = grid_scores + bias_h.unsqueeze(-1) + bias_w.unsqueeze(-2)
    return grid_scores.view(batch, q_h * q_w, k_h * k_w)

In [91]:
# 1 MLP block
class MLPBlock(nn.Module):
    """Position-wise feed-forward block: ``lin2(act(lin1(x)))``.

    ``lin1`` maps the last axis from ``embedding_dim`` to ``mlp_dim`` and ``lin2`` maps it
    back, so the output has the same shape as the input.
    """

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        self.lin1 = nn.Linear(embedding_dim, mlp_dim)   # expand
        self.lin2 = nn.Linear(mlp_dim, embedding_dim)   # project back
        self.act = act()                                # activation class is instantiated here

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        hidden = self.lin1(x)
        hidden = self.act(hidden)
        return self.lin2(hidden)

In [93]:
# 2 LayerNorm block (normalises over the channel axis)
class LayerNorm2d(nn.Module):
    """Layer normalisation across dim 1 with a learnable per-channel scale and shift.

    Expects channels on dim 1 followed by two spatial dims, e.g. (N, C, H, W);
    ``weight`` and ``bias`` are broadcast as (C, 1, 1).
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        mean = x.mean(dim=1, keepdim=True)
        centered = x - mean
        var = centered.pow(2).mean(dim=1, keepdim=True)   # biased variance
        normed = centered / torch.sqrt(var + self.eps)
        scale = self.weight.view(-1, 1, 1)
        shift = self.bias.view(-1, 1, 1)
        return scale * normed + shift

In [95]:
# 3 Patch embedding, implemented with a strided convolution
class PatchEmbed(nn.Module):
    """Split an image into patches and embed each one with a single Conv2d.

    Input is channels-first (B, in_chans, H, W); output is channels-last
    (B, H', W', embed_dim), where H' and W' follow the usual Conv2d output-size rule.
    """

    def __init__(
            self,
            kernel_size: Tuple[int, int] = (16, 16),
            stride: Tuple[int, int] = (16, 16),
            padding: Tuple[int, int] = (0, 0),
            in_chans: int = 3,
            embed_dim: int = 768,
    ) -> None:
        super().__init__()
        self.projection = nn.Conv2d(in_chans, embed_dim, kernel_size, stride, padding)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # (B, C, H, W) --conv--> (B, embed_dim, H', W') --permute--> (B, H', W', embed_dim)
        return self.projection(x).permute(0, 2, 3, 1)

In [97]:
# 3.1 PatchEmbed smoke test on a real image
# The path can be overridden via the SAM_DEMO_IMAGE environment variable so the cell also
# runs on other machines; the author's original path is kept as the fallback.
IMAGE_PATH = os.environ.get(
    "SAM_DEMO_IMAGE", "/Users/kalen/Desktop/Python_env/segment-anything/cat2.jpg"
)
# convert("RGB") guarantees 3 channels: grayscale, RGBA or palette images would otherwise
# produce a tensor that does not match PatchEmbed's in_chans=3.
image = Image.open(IMAGE_PATH).convert("RGB")

# Build a PatchEmbed with its default arguments
patch_embed = PatchEmbed()
# or set them explicitly, e.g. PatchEmbed(in_chans=3, embed_dim=768, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0))
transform = transforms.Compose([transforms.Resize((256, 256)),
                                transforms.ToTensor()
                                ])

image_tensor = transform(image)
image_tensor = image_tensor.unsqueeze(0)  # add batch axis: (1, 3, 256, 256)
output = patch_embed(image_tensor)
# 256 / 16 = 16 patches per side, 768-dim embedding, channels-last
assert output.shape == (1, 16, 16, 768), output.shape
print(output.shape)  # shape = (B, H, W, C)

torch.Size([1, 16, 16, 768])


In [99]:
# 4 Attention module (updated): a single nn.Linear produces q, k and v together,
#   keeping all learnable projection parameters in one layer for efficiency.

# Pipeline before an image reaches Attention:
# image -> tensor -> PatchEmbed -> Attention
# so the input is channels-last: (B, H, W, C) = (1, 16, 16, 768) for the demo image.

class Attention(nn.Module):
    """Multi-head self-attention over a channels-last 2D grid of tokens.

    Args:
        dmodel: embedding size of each token (768 for the PatchEmbed demo).
            Must be divisible by ``num_heads``.
        num_heads: number of attention heads.
        qkv_bias: whether the fused q/k/v projection has a bias term.
        use_rel_pos: if True, add decomposed relative-position terms to the
            attention scores via ``add_decomposed_rel_pos``.
        rel_pos_zero_init: NOTE(review): accepted but currently unused; the
            relative-position tables are always created with zeros.
        input_size: (H, W) of the token grid. Required when ``use_rel_pos`` is
            True; it sizes the relative-position tables.
    """

    def __init__(self,
                 dmodel: int,
                 num_heads: int = 8,
                 qkv_bias: bool = True,
                 use_rel_pos: bool = False,
                 rel_pos_zero_init: bool = True,
                 input_size: Optional[Tuple[int, int]] = None) -> None:
        super().__init__()
        # forward() splits the fused projection into num_heads equal chunks; fail
        # early here instead of with an opaque reshape error at call time.
        assert dmodel % num_heads == 0, (
            f"dmodel ({dmodel}) must be divisible by num_heads ({num_heads})"
        )
        self.num_heads = num_heads
        dmodel_per_head = dmodel // num_heads
        # Scores are divided by sqrt(head_dim) (scaled dot-product attention).
        self.scale = math.sqrt(dmodel_per_head)

        # One linear layer emits q, k and v concatenated along the last axis.
        self.qkv = nn.Linear(dmodel, dmodel * 3, bias=qkv_bias)
        self.output_linear = nn.Linear(dmodel, dmodel)

        self.use_rel_pos = use_rel_pos
        if self.use_rel_pos:
            assert (
                input_size is not None
            ), "如果用了相对位置编码，则必须提供输入的size"
            # 2 * size - 1 rows per axis: one per possible relative offset along that axis.
            self.rel_pos_h = nn.Parameter(torch.zeros(2 * input_size[0] - 1, dmodel_per_head), requires_grad=True)
            self.rel_pos_w = nn.Parameter(torch.zeros(2 * input_size[1] - 1, dmodel_per_head), requires_grad=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Self-attend over ``x`` of shape (B, H, W, dmodel); returns the same shape."""
        B, H, W, _ = x.shape
        qkv_combine = self.qkv(x)  # (B, H, W, 3*dmodel)
        qkv_combine = qkv_combine.reshape(B, H * W, 3, self.num_heads, -1)  # (B, H*W, 3, num_heads, head_dim)
        qkv_combine = qkv_combine.permute(2, 0, 3, 1, 4)  # (3, B, num_heads, H*W, head_dim)
        # Fold heads into the batch axis so every head is an independent attention problem.
        qkv_combine = qkv_combine.reshape(3, B * self.num_heads, H * W, -1)  # (3, B*num_heads, H*W, head_dim)
        q, k, v = qkv_combine.unbind(0)  # each (B*num_heads, H*W, head_dim)

        k_trans = k.transpose(-2, -1)  # (B*num_heads, head_dim, H*W)
        scores = torch.matmul(q, k_trans) / self.scale  # (B*num_heads, H*W, H*W)

        if self.use_rel_pos:
            # The keyword must match add_decomposed_rel_pos's parameter name, `scores`.
            scores = add_decomposed_rel_pos(
                scores=scores,
                q=q,
                rel_pos_h=self.rel_pos_h,
                rel_pos_w=self.rel_pos_w,
                q_size=(H, W),
                k_size=(H, W),
            )

        # Normalise over the key axis so each query's weights sum to 1.
        weights = torch.softmax(scores, dim=-1)
        output = torch.matmul(weights, v)  # (B*num_heads, H*W, head_dim)
        output = output.view(B, self.num_heads, H, W, -1)  # (B, num_heads, H, W, head_dim)
        output = output.permute(0, 2, 3, 1, 4)  # (B, H, W, num_heads, head_dim)
        output = output.reshape(B, H, W, -1)  # (B, H, W, dmodel): heads concatenated
        output = self.output_linear(output)

        return output


In [101]:
# 4.1 Attention smoke test on random patch tokens
torch.manual_seed(0)  # makes the random input and the layer's initial weights reproducible

# NOTE(review): `input` shadows Python's built-in input(); consider renaming it if no later cell depends on it.
input = torch.randn((1, 16, 16, 768))
print(input.shape)

dmodel = input.shape[-1]
num_heads = 8
qkv_bias = True
use_rel_pos = False
rel_pos_zero_init = False

attn = Attention(dmodel=dmodel, num_heads=num_heads, qkv_bias=qkv_bias, use_rel_pos=use_rel_pos, rel_pos_zero_init=rel_pos_zero_init)

output = attn(input)
# Self-attention preserves the (B, H, W, C) token layout.
assert output.shape == input.shape, output.shape
print(output.shape)

torch.Size([1, 16, 16, 768])
torch.Size([1, 16, 16, 768])
