In [39]:
import numpy as np
import torch
from torch import nn

from typing import Any, Optional, Tuple, Type

In [41]:
# 1 定义LayerNorm层

class LayerNorm2d(nn.Module):
    """Layer normalization across the channel dimension of image-like tensors.

    Intended for inputs shaped (N, C, H, W): statistics are taken over dim 1
    independently at every spatial position, and the learnable per-channel
    scale/shift are broadcast as (C, 1, 1).

    Args:
        num_channels: number of channels C (size of ``weight`` and ``bias``).
        eps: small constant added to the variance for numerical stability.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))  # per-channel scale, starts at 1
        self.bias = nn.Parameter(torch.zeros(num_channels))   # per-channel shift, starts at 0
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        centered = x - x.mean(dim=1, keepdim=True)
        variance = centered.pow(2).mean(dim=1, keepdim=True)  # biased variance over channels
        normalized = centered / torch.sqrt(variance + self.eps)
        scale = self.weight.view(-1, 1, 1)
        shift = self.bias.view(-1, 1, 1)
        return normalized * scale + shift

In [45]:
# 2 定义PositionEmbeddingRandom

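# Random positional encoding built from a fixed Gaussian projection matrix.
# _position_embedding_encoding takes actual coordinates (already normalized to [0, 1]), so both forward and
# forward_with_coords must first build a coords tensor of shape [d1, d2, ..., 2],
# e.g. point coordinates of shape [batch_size, N, 2].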

class PositionEmbeddingRandom(nn.Module):
    """Positional encoding using random spatial frequencies.

    Coordinates normalized to [0, 1] are mapped to [-1, 1], projected by a fixed
    Gaussian matrix of shape (2, num_pos_feats), multiplied by 2*pi and passed
    through sin/cos. Every point therefore gets 2 * num_pos_feats features.
    """

    def __init__(self,
                 num_pos_feats: int = 64,  # number of random frequencies; output has 2x this many features
                 scale: Optional[float] = None,  # multiplier for the Gaussian matrix; None or <= 0 means 1.0
                 ) -> None:
        super().__init__()
        if scale is None or scale <= 0.0:
            scale = 1.0
        # A buffer, not a Parameter: it is part of the module state (state_dict,
        # device moves) but is not trained.
        self.register_buffer(
            "positional_encoding_gaussian_matrix",
            scale * torch.randn((2, num_pos_feats)),
        )

    def _position_embedding_encoding(self,
                                     coords: torch.Tensor,
                                     ) -> torch.Tensor:
        """Encode coordinates that are already normalized to [0, 1].

        Args:
            coords: tensor of shape [d1, ..., dk, 2]; the last dim holds (x, y)
                and the leading dims index the points (e.g. [batch_size, N, 2]
                or an [h, w, 2] grid).

        Returns:
            Tensor of shape [d1, ..., dk, 2 * num_pos_feats]: the sin features
            followed by the cos features.
        """
        coords = 2 * coords - 1  # [0, 1] -> [-1, 1]
        coords = coords @ self.positional_encoding_gaussian_matrix  # [..., 2] @ [2, F] -> [..., F]
        coords = 2 * np.pi * coords
        return torch.cat([torch.sin(coords), torch.cos(coords)], dim=-1)

    def forward(self,
                coords_size: Tuple[int, int],  # (h, w) of the grid to encode, e.g. the image-embedding size
                ) -> torch.Tensor:
        """Encode the center of every cell of an (h, w) grid.

        Cell (i, j) is encoded at normalized position ((j + 0.5) / w, (i + 0.5) / h).

        Returns:
            Tensor of shape (2 * num_pos_feats, h, w), i.e. channels-first.
        """
        h, w = coords_size
        device: Any = self.positional_encoding_gaussian_matrix.device

        grid = torch.ones((h, w), device=device, dtype=torch.float32)
        # Cumulative sums of ones give 1..w along rows and 1..h along columns;
        # subtracting 0.5 moves to cell centers before normalizing to [0, 1].
        x_embed = (grid.cumsum(dim=1) - 0.5) / w
        y_embed = (grid.cumsum(dim=0) - 0.5) / h

        coords = torch.stack([x_embed, y_embed], dim=-1)  # (h, w, 2)
        position_embedding = self._position_embedding_encoding(coords)  # (h, w, C)
        return position_embedding.permute(2, 0, 1)  # (C, h, w)

    def forward_with_coords(self,
                            coords_input: torch.Tensor,  # (..., 2) pixel coordinates (x, y), e.g. (batch_size, N, 2)
                            image_size: Tuple[int, int],  # (H, W) used to normalize the coordinates
                            ) -> torch.Tensor:
        """Encode arbitrary pixel coordinates.

        x is divided by image_size[1] (width) and y by image_size[0] (height)
        before encoding. The input tensor is never modified.

        Returns:
            Tensor of shape (..., 2 * num_pos_feats), e.g. (batch_size, N, C).
            The per-point layout is kept (no permute) because callers index the
            result per point and concatenate along the point dimension.
        """
        if coords_input.is_floating_point():
            coords = coords_input.clone()
        else:
            # Integer coordinates must become float *before* the in-place
            # divisions below, otherwise the normalized values truncate to 0.
            coords = coords_input.to(torch.float)
        coords[..., 0] = coords[..., 0] / image_size[1]  # x / width
        coords[..., 1] = coords[..., 1] / image_size[0]  # y / height
        return self._position_embedding_encoding(coords.to(torch.float))

In [49]:
# 3 定义Prompt Encoder

class PromptEncoder(nn.Module):
    """Encodes point, box and mask prompts.

    Points and boxes become sparse embeddings of shape (B, N, embed_dim); a mask
    prompt (or, without one, the learned "no mask" embedding) becomes a dense
    embedding of shape (B, embed_dim, embed_H, embed_W).
    """

    def __init__(
            self,
            embed_dim: int,  # dimensionality of every prompt embedding
            image_embedding_size: Tuple[int, int],  # (H, W) of the image embedding
            input_image_size: Tuple[int, int],  # (H, W) of the padded image fed to the image encoder
            mask_in_chans: int,  # hidden channel count used while downscaling mask prompts
            activation: Type[nn.Module] = nn.GELU,  # activation class used in the mask downscaler
            ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        self.input_image_size = input_image_size
        self.image_embedding_size = image_embedding_size

        # sin + cos features give 2 * (embed_dim // 2) channels per position.
        self.positional_embedding_layer = PositionEmbeddingRandom(embed_dim // 2)

        # Learned embeddings added on top of the positional encodings:
        # [0] -> points with label 0, [1] -> points with label 1 (presumably
        # negative / positive clicks), [2] / [3] -> first / second box corner.
        self.num_point_embeddings: int = 4
        point_embeddings = [nn.Embedding(1, embed_dim) for _ in range(self.num_point_embeddings)]
        self.point_embeddings = nn.ModuleList(point_embeddings)
        # Embedding for padding points (label == -1): their positional part is
        # zeroed in _embed_points and replaced by this learned vector.
        self.not_a_point_embed = nn.Embedding(1, embed_dim)

        # Mask prompts are expected at 4x the image-embedding resolution: the two
        # kernel-2 / stride-2 convolutions below reduce (4H, 4W) -> (2H, 2W) -> (H, W),
        # matching the dense embedding size (embed_H, embed_W).
        self.mask_input_size = (4 * image_embedding_size[0], 4 * image_embedding_size[1])
        self.mask_downscaling = nn.Sequential(
            nn.Conv2d(in_channels=1, out_channels=mask_in_chans // 4, kernel_size=2, stride=2),  # -> (2H, 2W)
            LayerNorm2d(mask_in_chans // 4),
            activation(),
            nn.Conv2d(in_channels=mask_in_chans // 4, out_channels=mask_in_chans, kernel_size=2, stride=2),  # -> (H, W)
            LayerNorm2d(mask_in_chans),
            activation(),
            nn.Conv2d(in_channels=mask_in_chans, out_channels=embed_dim, kernel_size=1),  # channels -> embed_dim
            )
        # Dense embedding used when no mask prompt is given.
        self.no_mask_embed = nn.Embedding(1, embed_dim)

    def get_dense_positional_embedding(self) -> torch.Tensor:
        '''
        Return the positional encoding of every cell of the image-embedding grid.

        This is the dense positional encoding that accompanies the image
        embedding (the same encoding scheme used for point prompts).

        Returns:
            Positional encoding of shape (1, C, embed_h, embed_w), with
            C = 2 * (embed_dim // 2).
        '''
        return self.positional_embedding_layer(self.image_embedding_size).unsqueeze(0)

    def _embed_points(
            self,
            points: torch.Tensor,  # (batch_size, N, 2) pixel coordinates (x, y)
            labels: torch.Tensor,  # (batch_size, N); -1 marks padding points
            pad: bool,  # append one padding point (label -1) per batch item
            ) -> torch.Tensor:
        """Embed point prompts.

        Returns:
            Tensor of shape (batch_size, N + int(pad), C).
        """
        points = points + 0.5  # shift to the pixel center
        if pad:
            padding_point = torch.zeros((points.shape[0], 1, 2), device=points.device)
            padding_label = -torch.ones((labels.shape[0], 1), device=labels.device)
            points = torch.cat([points, padding_point], dim=1)
            labels = torch.cat([labels, padding_label], dim=1)
        point_embedding = self.positional_embedding_layer.forward_with_coords(points, self.input_image_size)
        # Padding points carry no location: drop their positional part and use
        # only the learned "not a point" embedding.
        point_embedding[labels == -1] = 0.0
        point_embedding[labels == -1] += self.not_a_point_embed.weight
        point_embedding[labels == 0] += self.point_embeddings[0].weight
        point_embedding[labels == 1] += self.point_embeddings[1].weight
        return point_embedding

    def _embed_boxes(self, boxes: torch.Tensor) -> torch.Tensor:
        """Embed box prompts as two corner points each.

        ``boxes`` is reshaped to (-1, 2, 2), i.e. every 4 values form two (x, y)
        corners (presumably (x1, y1, x2, y2) per box — confirm with callers).

        Returns:
            Tensor of shape (num_boxes, 2, C).
        """
        boxes = boxes + 0.5  # shift to the pixel center
        coords = boxes.reshape(-1, 2, 2)
        corner_embedding = self.positional_embedding_layer.forward_with_coords(coords, self.input_image_size)
        corner_embedding[:, 0, :] += self.point_embeddings[2].weight
        corner_embedding[:, 1, :] += self.point_embeddings[3].weight
        return corner_embedding

    def _embed_masks(self, masks: torch.Tensor) -> torch.Tensor:
        """Downscale mask prompts to dense embeddings via ``mask_downscaling``."""
        mask_embedding = self.mask_downscaling(masks)
        return mask_embedding

    def _get_batch_size(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> int:
        """
        Gets the batch size of the output given the batch size of the input prompts.
        """
        if points is not None:
            return points[0].shape[0]
        elif boxes is not None:
            return boxes.shape[0]
        elif masks is not None:
            return masks.shape[0]
        else:
            return 1

    def _get_device(self) -> torch.device:
        """Device holding this module's parameters."""
        return self.point_embeddings[0].weight.device

    def forward(
        self,
        points: Optional[Tuple[torch.Tensor, torch.Tensor]],
        boxes: Optional[torch.Tensor],
        masks: Optional[torch.Tensor],
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Embeds different types of prompts, returning both sparse and dense
        embeddings.

        Arguments:
          points (tuple(torch.Tensor, torch.Tensor) or none): point coordinates
            and labels to embed.
          boxes (torch.Tensor or none): boxes to embed
          masks (torch.Tensor or none): masks to embed

        Returns:
          torch.Tensor: sparse embeddings for the points and boxes, with shape
            BxNx(embed_dim), where N is determined by the number of input points
            and boxes.
          torch.Tensor: dense embeddings for the masks, in the shape
            Bx(embed_dim)x(embed_H)x(embed_W)
        """
        bs = self._get_batch_size(points, boxes, masks)
        sparse_embeddings = torch.empty((bs, 0, self.embed_dim), device=self._get_device())
        if points is not None:
            coords, labels = points
            # A padding point is appended only when there is no box prompt.
            point_embeddings = self._embed_points(coords, labels, pad=(boxes is None))
            sparse_embeddings = torch.cat([sparse_embeddings, point_embeddings], dim=1)
        if boxes is not None:
            box_embeddings = self._embed_boxes(boxes)
            sparse_embeddings = torch.cat([sparse_embeddings, box_embeddings], dim=1)

        if masks is not None:
            dense_embeddings = self._embed_masks(masks)
        else:
            # Broadcast the learned "no mask" vector over the whole embedding grid.
            dense_embeddings = self.no_mask_embed.weight.reshape(1, -1, 1, 1).expand(
                bs, -1, self.image_embedding_size[0], self.image_embedding_size[1]
            )

        return sparse_embeddings, dense_embeddings