In [8]:
import torch
from segment_anything import SamAutomaticMaskGenerator, sam_model_registry
from image_encoder import ImageEncoderViT_3d_v2 as ImageEncoderViT_3d
from functools import partial

# Official SAM ViT-B checkpoint, relative to this notebook's directory.
SAM_CHECKPOINT = "../../ckpt/sam_vit_b_01ec64.pth"

sam = sam_model_registry["vit_b"](checkpoint=SAM_CHECKPOINT)

mask_generator = SamAutomaticMaskGenerator(sam)

# 3D image encoder whose shared hyper-parameters presumably mirror SAM ViT-B so the
# pretrained 2D weights can be copied in; cubic_window_size / num_slice are the
# 3D-specific settings (TODO confirm against ImageEncoderViT_3d_v2).
img_encoder = ImageEncoderViT_3d(
    depth=12,
    embed_dim=768,
    img_size=1024,
    mlp_ratio=4,
    norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
    num_heads=12,
    patch_size=16,
    qkv_bias=True,
    use_rel_pos=True,
    global_attn_indexes=[2, 5, 8, 11],
    window_size=14,
    cubic_window_size=8,
    out_chans=256,
    num_slice=16,
)

# strict=False lets the 3D-only parameters (depth_embed, slice_embed, rel_pos_d,
# adapter.*, ... — see the parameter listing below) keep their initial values, but it
# also silently drops any SAM key that did not match. Keep the result so the
# mismatch is visible instead of discarded.
load_result = img_encoder.load_state_dict(
    mask_generator.predictor.model.image_encoder.state_dict(), strict=False
)
print(f"missing (not provided by SAM): {len(load_result.missing_keys)} keys")
print(f"unexpected (SAM keys ignored): {load_result.unexpected_keys}")

# NOTE: this only removes the name `sam`. `mask_generator.predictor.model` still
# refers to the same model, so its memory is not released by this statement.
del sam
[name for name, _ in img_encoder.named_parameters()]

['pos_embed',
 'depth_embed',
 'patch_embed.proj.weight',
 'patch_embed.proj.bias',
 'slice_embed.weight',
 'slice_embed.bias',
 'blocks.0.norm1.weight',
 'blocks.0.norm1.bias',
 'blocks.0.attn.rel_pos_h',
 'blocks.0.attn.rel_pos_w',
 'blocks.0.attn.rel_pos_d',
 'blocks.0.attn.lr',
 'blocks.0.attn.qkv.weight',
 'blocks.0.attn.qkv.bias',
 'blocks.0.attn.proj.weight',
 'blocks.0.attn.proj.bias',
 'blocks.0.norm2.weight',
 'blocks.0.norm2.bias',
 'blocks.0.mlp.lin1.weight',
 'blocks.0.mlp.lin1.bias',
 'blocks.0.mlp.lin2.weight',
 'blocks.0.mlp.lin2.bias',
 'blocks.0.adapter.linear1.weight',
 'blocks.0.adapter.linear1.bias',
 'blocks.0.adapter.conv.weight',
 'blocks.0.adapter.conv.bias',
 'blocks.0.adapter.linear2.weight',
 'blocks.0.adapter.linear2.bias',
 'blocks.1.norm1.weight',
 'blocks.1.norm1.bias',
 'blocks.1.attn.rel_pos_h',
 'blocks.1.attn.rel_pos_w',
 'blocks.1.attn.rel_pos_d',
 'blocks.1.attn.lr',
 'blocks.1.attn.qkv.weight',
 'blocks.1.attn.qkv.bias',
 'blocks.1.attn.proj.weigh

In [4]:
# Keys of the original 2D SAM image encoder's state_dict (parameters plus any
# persistent buffers), for comparison with the 3D encoder's parameter names.
list(mask_generator.predictor.model.image_encoder.state_dict())

['pos_embed',
 'patch_embed.proj.weight',
 'patch_embed.proj.bias',
 'blocks.0.norm1.weight',
 'blocks.0.norm1.bias',
 'blocks.0.attn.rel_pos_h',
 'blocks.0.attn.rel_pos_w',
 'blocks.0.attn.qkv.weight',
 'blocks.0.attn.qkv.bias',
 'blocks.0.attn.proj.weight',
 'blocks.0.attn.proj.bias',
 'blocks.0.norm2.weight',
 'blocks.0.norm2.bias',
 'blocks.0.mlp.lin1.weight',
 'blocks.0.mlp.lin1.bias',
 'blocks.0.mlp.lin2.weight',
 'blocks.0.mlp.lin2.bias',
 'blocks.1.norm1.weight',
 'blocks.1.norm1.bias',
 'blocks.1.attn.rel_pos_h',
 'blocks.1.attn.rel_pos_w',
 'blocks.1.attn.qkv.weight',
 'blocks.1.attn.qkv.bias',
 'blocks.1.attn.proj.weight',
 'blocks.1.attn.proj.bias',
 'blocks.1.norm2.weight',
 'blocks.1.norm2.bias',
 'blocks.1.mlp.lin1.weight',
 'blocks.1.mlp.lin1.bias',
 'blocks.1.mlp.lin2.weight',
 'blocks.1.mlp.lin2.bias',
 'blocks.2.norm1.weight',
 'blocks.2.norm1.bias',
 'blocks.2.attn.rel_pos_h',
 'blocks.2.attn.rel_pos_w',
 'blocks.2.attn.qkv.weight',
 'blocks.2.attn.qkv.bias',
 'bloc

In [17]:
H, W, D = 32, 32, 32  # height, width and depth of the toy input volume
WINDOW_SIZE = 8       # cubic window edge (presumably the same as cubic_window_size=8)
SHIFT_SIZE = 4        # cyclic shift applied before windowing

# Zero tensor of shape [1, H, W, D, 1]; every voxel will receive the integer id of
# the shifted-window region it belongs to (mirrors the Swin-style mask construction).
img_mask = torch.zeros((1, H, W, D, 1))

# Region boundaries along one axis: everything before the last window, the last
# window minus the shift, and the shifted-in tail. All three axes use the same
# boundaries, so define them once.
shift_slices = (slice(0, -WINDOW_SIZE),
                slice(-WINDOW_SIZE, -SHIFT_SIZE),
                slice(-SHIFT_SIZE, None))
h_slices = w_slices = d_slices = shift_slices

cnt = 0
# Give each (h, w, d) region combination a distinct label: 3 ** 3 = 27 labels,
# label = 9 * h_region + 3 * w_region + d_region.
for h in h_slices:
    for w in w_slices:
        for d in d_slices:
            img_mask[:, h, w, d, :] = cnt
            cnt += 1

# h and w in their first region, d in its middle region -> every entry is label 1.
print(img_mask[0, :H - WINDOW_SIZE, :W - WINDOW_SIZE, D - WINDOW_SIZE:D - SHIFT_SIZE, 0])

tensor([[[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         ...,
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         ...,
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         ...,
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        ...,

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         ...,
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         ...,
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.]],

        [[1., 1., 1., 1.],
         [1., 1., 1., 1.],
         [1., 1., 1., 1.],
         ...,
 

In [10]:
import numpy as np
SEED = 0
rng = np.random.default_rng(SEED)  # seeded so the sampled points are reproducible
torch.manual_seed(SEED)            # also makes `foo` below reproducible

seg = torch.ones(1, 128, 128, 128)  # dummy mask: every voxel is foreground (== 1)

# Indices of all foreground voxels, one index tensor per dimension of `seg`.
# Computed once and reused instead of re-running torch.where for every axis.
fg_idx = torch.where(seg == 1)
l = len(fg_idx[0])  # number of foreground voxels
print(l)

# Draw 10 integers uniformly from [0, l), with replacement (repeats are possible).
sample = rng.integers(0, l, size=10)
print(sample)

# Columns are (dim 1, dim 3, dim 2) of `seg`: y comes from the last axis and z from
# the second-to-last.
# NOTE(review): presumably deliberate, to match the coordinate order the prompt
# encoder expects downstream — confirm.
x = fg_idx[1][sample].unsqueeze(1)
y = fg_idx[3][sample].unsqueeze(1)
z = fg_idx[2][sample].unsqueeze(1)
print(z)
point_coord = torch.cat([x, y, z], dim=1).unsqueeze(1).float()  # [10, 1, 3]

# 20 extra standard-normal points (not voxel coordinates) appended to the 10
# sampled ones, giving 30 points for the shape test.
foo = torch.randn(1, 20, 3)
point_coord = point_coord.transpose(0, 1)           # [1, 10, 3]
point_coord = torch.cat([point_coord, foo], dim=1)  # [1, 30, 3]
print(point_coord.size())



2097152
[ 970475 1634943  237443  278763  792592  183698 1383205 1848259 1226765
  703493]
tensor([[ 29],
        [100],
        [ 63],
        [  1],
        [ 48],
        [ 27],
        [ 54],
        [103],
        [112],
        [120]])
torch.Size([1, 30, 3])


In [11]:
# NOTE(review): PromptEncoder and TwoWayTransformer are not imported anywhere in this
# notebook as shown, so this cell only runs in a kernel where they were imported
# earlier. Add the import to the top cell (module path not visible here).
foo_feature = torch.randn(1, 256, 32, 32, 32)  # dummy image embedding [1, 256, 32, 32, 32]
patch_size = 128  # edge of the input volume (presumably the 128**3 `seg` above)

two_way_transformer = TwoWayTransformer(
    depth=2,
    embedding_dim=256,
    mlp_dim=2048,
    num_heads=8,
)
prompt_encoder = PromptEncoder(transformer=two_way_transformer)
prompt_encoder.to("cpu")

# Arguments: image embedding [1, 256, 32, 32, 32], points [1, 30, 3],
# spatial size of the input volume [128, 128, 128].
ans = prompt_encoder(foo_feature, point_coord, [patch_size] * 3)

ans.size()  # recorded output: torch.Size([1, 256, 32, 32, 32])

送進transformer的三個參數image_embeddings, image_pe, point_coord torch.Size([1, 256, 32, 32, 32]) torch.Size([1, 256, 32, 32, 32]) torch.Size([1, 1, 1, 30, 3])
===init===
image_embedding init torch.Size([1, 256, 32, 32, 32])
point_coord init torch.Size([1, 1, 1, 30, 3])

point_embedding after grid sample torch.Size([1, 256, 1, 1, 30])
point_pe after grid sample torch.Size([1, 256, 1, 1, 30])

        之所以維度由[1,256,32,32,32]變成[1,256,1,1,30], 是因為point_coord [1,1,1,30,3]中包含了30個xyz的座標(已正規化到-1~1之間)
        定位了在image_embedding中的30個位置(維度中為32的D*H*W), 並對原始在對應image_embedding空間上的特徵進行插值(僅限這30個點)
        因此結果會是[1,256,1,1,30], 最後一個維度代表其中某一個通道在這30個點中的特徵值
        

        接下來squeeze去除1維度
        
point_embedding after squeeze torch.Size([1, 256, 30])
point_pe after squeeze torch.Size([1, 256, 30])

        permute後, 現在我們有包含了點座標資訊的point_embedding特徵以及包含了點座標資訊的point_pe(一個固定的位置編碼矩陣)
        
point_embedding after permute torch.Size([1, 30, 256])
point_pe after permute torch.Size([1, 30, 256])

        把沒有經過給定點插植

torch.Size([1, 256, 32, 32, 32])