In [6]:
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from typing import Optional, Tuple, Type

In [10]:
# 1 Define the MLP block
class MLPBlock(nn.Module):
    """Two-layer feed-forward block: linear -> activation -> linear.

    The first linear layer maps the last dimension of the input from
    ``embedding_dim`` to ``mlp_dim``; the second maps it back to
    ``embedding_dim``.

    Args:
        embedding_dim: Size of the last dimension of the input and the output.
        mlp_dim: Size of the hidden representation between the two layers.
        act: Activation module class; it is instantiated with no arguments.
    """

    def __init__(
        self,
        embedding_dim: int,
        mlp_dim: int,
        act: Type[nn.Module] = nn.GELU,
    ) -> None:
        super().__init__()
        # Sub-modules are registered in the order lin1, lin2, act; this order
        # determines parameter ordering and the sequence of random inits.
        self.lin1 = nn.Linear(in_features=embedding_dim, out_features=mlp_dim)
        self.lin2 = nn.Linear(in_features=mlp_dim, out_features=embedding_dim)
        self.act = act()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply ``lin1``, then the activation, then ``lin2`` to ``x``."""
        hidden = self.lin1(x)
        activated = self.act(hidden)
        return self.lin2(activated)

In [16]:
# 2 Define the LayerNorm block
class LayerNorm2d(nn.Module):
    """Normalise over dimension 1, then apply a per-channel scale and shift.

    Args:
        num_channels: Length of the learnable ``weight`` and ``bias`` vectors.
        eps: Constant added to the variance before the square root.
    """

    def __init__(self, num_channels: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = nn.Parameter(torch.ones(num_channels))
        self.bias = nn.Parameter(torch.zeros(num_channels))
        self.eps = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return ``x`` normalised along dimension 1 and affinely transformed.

        The variance is the mean of squared deviations along dimension 1.
        ``weight`` and ``bias`` are indexed to shape ``(num_channels, 1, 1)``
        before broadcasting, so ``x`` is presumably ``(B, C, H, W)`` with
        ``C == num_channels`` -- TODO confirm against callers.
        """
        mean = x.mean(1, keepdim=True)
        centered = x - mean
        variance = centered.pow(2).mean(1, keepdim=True)
        normalized = centered / torch.sqrt(variance + self.eps)
        scale = self.weight[:, None, None]
        shift = self.bias[:, None, None]
        return scale * normalized + shift

In [20]:
# 3 Define the Patch Embedding class, implemented with a convolution
class PatchEmbed(nn.Module):
    """Patch embedding implemented with a single 2D convolution.

    Args:
        kernel_size: Kernel size passed to the convolution.
        stride: Stride passed to the convolution.
        padding: Padding passed to the convolution.
        in_chans: Number of input channels of the convolution.
        embed_dim: Number of output channels of the convolution.
    """

    def __init__(
            self,
            kernel_size: Tuple[int, int] = (16, 16),
            stride: Tuple[int, int] = (16, 16),
            padding: Tuple[int, int] = (0, 0),
            in_chans: int = 3,
            embed_dim: int = 768,
    ) -> None:
        super().__init__()
        self.projection = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Project ``x`` with the convolution and move channels to the last axis.

        Takes a tensor and returns a tensor; the result of the convolution is
        permuted from (B, C, H, W) to (B, H, W, C).
        """
        patches = self.projection(x)
        # Swap dimensions: (B, C, H, W) -> (B, H, W, C).
        return patches.permute(0, 2, 3, 1)

In [24]:
# 3.1 Patch Embedding smoke test on an image read from disk
# NOTE(review): machine-specific absolute path; point this at a local image before running.
IMAGE_PATH = '/Users/kalen/Desktop/Python_env/segment-anything/cat2.jpg'

# Open inside a context manager so the file handle is released, and convert to
# RGB so grayscale/RGBA/CMYK files still match PatchEmbed's default in_chans=3.
# convert() returns a new image, so it stays usable after the file is closed.
with Image.open(IMAGE_PATH) as image_file:
    image = image_file.convert('RGB')

# Create a PatchEmbed instance using the default arguments.
patch_embed = PatchEmbed()
# Equivalent explicit form:
# patch_embed = PatchEmbed(in_chans=3, embed_dim=768, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0))
transform = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
])

image_tensor = transform(image)
image_tensor = image_tensor.unsqueeze(0)  # add a leading batch dimension
output = patch_embed(image_tensor)
print(output.shape)  # shape = (B, H, W, C)

torch.Size([1, 16, 16, 768])


In [None]:
# TODO: unfinished