In [1]:
import torch
import torch.nn as nn
from torchvision import transforms
from PIL import Image
from typing import Optional, Tuple, Type

In [5]:
# Patch embedding layer, implemented with a single convolution.

class PatchEmbed(nn.Module):
    """Project an image batch into a grid of patch embeddings.

    A single ``nn.Conv2d`` does the work: each kernel window of the input is
    mapped to one ``embed_dim``-dimensional vector. With the default settings
    (``kernel_size == stride`` and no padding) the windows are non-overlapping
    16x16 patches.

    Args:
        kernel_size: Height and width of the convolution kernel (patch size).
        stride: Step between neighbouring kernel windows.
        padding: Zero padding passed to the convolution for height and width.
        in_chans: Number of channels in the input image.
        embed_dim: Number of output channels, i.e. the embedding size per patch.
    """

    def __init__(
        self,
        kernel_size: Tuple[int, int] = (16, 16),
        stride: Tuple[int, int] = (16, 16),
        padding: Tuple[int, int] = (0, 0),
        in_chans: int = 3,
        embed_dim: int = 768,
    ) -> None:
        super().__init__()

        # Attribute name is kept as `projection` so state_dict keys stay stable.
        self.projection = nn.Conv2d(
            in_channels=in_chans,
            out_channels=embed_dim,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a batch of images.

        Args:
            x: Input tensor laid out as (B, C, H, W).

        Returns:
            The convolution output with the channel axis moved last,
            i.e. (B, C, H, W) -> (B, H, W, C).
        """
        # Convolve, then reorder the axes so the embedding dimension is last.
        return self.projection(x).permute(0, 2, 3, 1)

In [7]:
# Load a sample photo to test with.
# NOTE: this absolute path is machine-specific; point IMAGE_PATH at your own copy of the image.
IMAGE_PATH = '/Users/kalen/Desktop/Python_env/segment-anything/cat2.jpg'

# Image.open is lazy and keeps the underlying file open. The context manager
# releases the handle, while convert('RGB') forces the pixel data to be read
# and guarantees 3 channels, matching PatchEmbed's default in_chans=3
# (a JPEG may otherwise be grayscale or CMYK).
with Image.open(IMAGE_PATH) as source_image:
    image = source_image.convert('RGB')

In [13]:
# Downsample the photo and convert it to a tensor, then embed it with PatchEmbed.

# Preprocessing pipeline: resize to 256x256 first, then convert to a Tensor.
tensorlizer = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
    ]
)

# Default hyper-parameters are used here; the explicit equivalent is
# PatchEmbed(in_chans=3, embed_dim=768, kernel_size=(16, 16), stride=(16, 16), padding=(0, 0)).
patchembed = PatchEmbed()

# The pipeline yields (C, H, W); unsqueeze(0) prepends the batch axis -> (1, C, H, W), B = 1.
image_tensor = tensorlizer(image).unsqueeze(0)
output = patchembed(image_tensor)  # shape = (B, H, W, C)
print(output.shape)

torch.Size([1, 16, 16, 768])
