In [27]:
import sys
import os
import numpy as np

sys.path.append("..\\..")
import torch
import torch.nn.functional as F
from modeling.Med_SAM.image_encoder import ImageEncoderViT_3d_v2 as ImageEncoderViT_3d
from modeling.Med_SAM.prompt_encoder import PromptEncoder, TwoWayTransformer
from modeling.Med_SAM.mask_decoder import VIT_MLAHead_h as VIT_MLAHead
from functools import partial
from torch.utils.data import Dataset, DataLoader
from dataset.my_dataset import MyDataset
import matplotlib.pyplot as plt
from tqdm import tqdm
from monai.losses import DiceCELoss, DiceLoss

from einops import rearrange

sys.path.append("..")
import surface_distance
from surface_distance import metrics


# Number of positive voxels sampled as point prompts for each volume.
num_prompts = 1
# Edge length (in voxels) of the cubic patch cropped around the prompt.
patch_size = 128
# Checkpoint holding the encoder / prompt-encoder / decoder state dicts.
snapshot_path = "C:\\Users\\Jacky\\Desktop\\3DSAM-adapter\\3DSAM-adapter\\snapshot\\show_model_result\\best.pth.tar"
# Prefer the GPU when one is visible, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
# reduction="none" keeps one value per item instead of averaging, so the
# notebook can report a separate score for every volume / slice.
dice_loss = DiceLoss(
    reduction="none",
    to_onehot_y=True,
    softmax=False,
    include_background=False,
)

## Load pretrained model

In [28]:
# Read the checkpoint from disk once and reuse it for every sub-module
# (the file used to be deserialized six times).
# NOTE(review): torch.load unpickles the file, so only load snapshots you trust.
checkpoint = torch.load(snapshot_path, map_location="cpu")

# load image encoder
img_encoder = ImageEncoderViT_3d(
    depth=12,
    embed_dim=768,
    img_size=1024,
    mlp_ratio=4,
    norm_layer=partial(torch.nn.LayerNorm, eps=1e-6),
    num_heads=12,
    patch_size=16,
    qkv_bias=True,
    use_rel_pos=True,
    global_attn_indexes=[2, 5, 8, 11],
    window_size=14,
    cubic_window_size=8,
    out_chans=256,
    num_slice=16,
)
img_encoder.load_state_dict(checkpoint["encoder_dict"], strict=True)
img_encoder.to(device)
# This notebook only runs inference: switch to eval mode so that any
# train-only layers (e.g. dropout) are disabled.
img_encoder.eval()

# load prompt encoder (one per feature level; "feature_dict" is indexed by level)
prompt_encoder_list = []
for i in range(4):
    prompt_encoder = PromptEncoder(
        transformer=TwoWayTransformer(
            depth=2, embedding_dim=256, mlp_dim=2048, num_heads=8
        )
    )
    prompt_encoder.load_state_dict(checkpoint["feature_dict"][i], strict=True)
    prompt_encoder.to(device)
    prompt_encoder.eval()
    prompt_encoder_list.append(prompt_encoder)

# load mask decoder
mask_decoder = VIT_MLAHead(img_size=96).to(device)
mask_decoder.load_state_dict(checkpoint["decoder_dict"], strict=True)

# The CPU copy of the weights is no longer needed once every module is loaded.
del checkpoint

# eval() returns the module itself, so the cell still displays the decoder repr.
mask_decoder.eval()

VIT_MLAHead_h(
  (mlahead): MLAHead(
    (head2): Sequential(
      (0): Conv3d(256, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): ReLU()
      (3): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (4): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (5): ReLU()
    )
    (head3): Sequential(
      (0): Conv3d(256, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (1): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (2): ReLU()
      (3): Conv3d(128, 128, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1), bias=False)
      (4): InstanceNorm3d(128, eps=1e-05, momentum=0.1, affine=False, track_running_stats=False)
      (5): ReLU()
    )
    (head4): Sequential(
      (0): Conv3

In [29]:
# Build the three splits from the same dataset root.
train_data = MyDataset("D:\\ds", "train")
val_data = MyDataset("D:\\ds", "val")
test_data = MyDataset("D:\\ds", "test")

# Report the split sizes while the names still refer to the raw datasets.
print(f"train: {len(train_data)}")
print(f"val: {len(val_data)}")
print(f"test: {len(test_data)}")

# Wrap the splits in single-sample, fixed-order loaders.
# `train_data` / `val_data` are rebound to their loaders, whereas the test
# loader is bound to `test`: `test_data` stays a raw MyDataset, and the
# evaluation cell iterates it directly and adds the batch dimension itself.
train_data = DataLoader(train_data, shuffle=False, batch_size=1)
val_data = DataLoader(val_data, shuffle=False, batch_size=1)
test = DataLoader(test_data, shuffle=False, batch_size=1)

train: 577
val: 80
test: 168


In [30]:
# Output template for the per-slice PNGs written by plot_slices: the first
# field is the sub-folder name (plot_slices passes `fname`), the second is the
# slice index.
save_path = "C:\\Users\\Jacky\\Desktop\\3DSAM-adapter\\3DSAM-adapter\\snapshot\\show_model_result\\images\\{}\\slice_{}.png"

def plot_slices(img, predict, ground_truth, fname):
    """Save one PNG per slice showing image, prediction and ground truth.

    Each figure has three panels ("Image", "Predict", "Ground Truth") and is
    titled with the Dice score of that slice, computed as ``1 - dice_loss``.
    Files are written to ``save_path.format(fname, slice_index)``; the target
    folder is created if it does not exist yet.

    Args:
        img: 5-D image tensor (batch, channel, three spatial axes). It is
            resampled to the spatial size of ``ground_truth`` and only
            ``img[0, 0]`` is plotted.
        predict: predicted mask; must match ``ground_truth`` after ``squeeze()``.
        ground_truth: reference mask whose ``shape[2:]`` gives the target size.
        fname: sub-folder name substituted into ``save_path``.

    Raises:
        AssertionError: if the three volumes do not share the same shape after
            resampling / squeezing.
    """
    # Resample img (which has had spacing resampling and, presumably, channel
    # repetition applied) back onto the ground-truth grid.
    img = F.interpolate(img, size=ground_truth.shape[2:], mode="trilinear")
    img = img[0, 0]
    predict = predict.squeeze()
    ground_truth = ground_truth.squeeze()

    assert (
        predict.shape == ground_truth.shape == img.shape
    ), "Shapes of predict, ground_truth and img must be the same."

    # savefig does not create missing folders, so make sure the target exists.
    out_dir = os.path.dirname(save_path.format(fname, 0))
    if out_dir:
        os.makedirs(out_dir, exist_ok=True)

    for i in range(img.shape[0]):
        # Per-slice Dice score; [None, None] adds batch and channel axes.
        dice = 1 - dice_loss(predict[i][None, None], ground_truth[i][None, None])
        dice = dice.squeeze().cpu().numpy()
        fig, axes = plt.subplots(1, 3)
        fig.suptitle(f"DICE: {dice}")

        # Convert the tensors to NumPy arrays for matplotlib.
        img_np = img[i].cpu().numpy()
        predict_np = predict[i].cpu().numpy()
        ground_truth_np = ground_truth[i].cpu().numpy()

        for j, (data, title) in enumerate(zip([img_np, predict_np, ground_truth_np], ["Image", "Predict", "Ground Truth"])):
            axes[j].imshow(data, cmap="gray")
            axes[j].set_title(title)

        # Save the figure.
        fig.savefig(save_path.format(fname, i))

        # Close the figure to free memory before the next slice.
        plt.close(fig)


In [31]:
def model_predict(img, prompt, img_encoder, prompt_encoder, mask_decoder):
    """Run image encoder, prompt fusion and mask decoder on a single patch.

    Relies on the notebook-level ``patch_size`` and ``device``.

    Args:
        img: image patch tensor; ``img[0, 0]`` is also fed to the decoder as a
            down-sampled copy of the input.
        prompt: point-prompt tensor; its first two axes are swapped before use.
        img_encoder: callable returning ``(batch_features, feature_list)``.
        prompt_encoder: sequence of prompt encoders zipped with the feature
            levels; only the one paired with level index 3 is actually called.
        mask_decoder: decoder called as ``mask_decoder(features, 2, patch_size // 32)``.

    Returns:
        The decoder output with axes permuted by ``(0, 1, 4, 2, 3)``
        (presumably of shape [1, 2, patch_size, patch_size, patch_size] — TODO confirm).
    """
    # Up-sample the patch by 256 / patch_size, then swap the first two axes of
    # the first batch item to form the encoder input.
    upsampled = F.interpolate(img.float(), scale_factor=256 / patch_size, mode="trilinear")
    batch_features, feature_list = img_encoder(upsampled[0].transpose(0, 1))
    # The final encoder output becomes the last feature level (order is kept
    # as returned; it is not reversed).
    feature_list.append(batch_features)

    point_prompts = prompt.transpose(0, 1)
    cube_shape = [patch_size] * 3

    # Only level 3 is fused with the point prompts; the other levels are
    # passed through unchanged (just moved to the target device).
    new_feature = []
    for level, (feature, encoder) in enumerate(zip(feature_list, prompt_encoder)):
        feature = feature.to(device)
        if level != 3:
            new_feature.append(feature)
            continue
        new_feature.append(encoder(feature, point_prompts.clone(), cube_shape))

    # Besides the feature levels, the decoder also receives the first channel
    # of the patch, axes permuted and down-sampled by a factor of 32 / patch_size.
    first_channel = img[0, 0].permute(1, 2, 0)[None, None].to(device)
    img_resize = F.interpolate(
        first_channel,
        scale_factor=32 / patch_size,
        mode="trilinear",
    )
    new_feature.append(img_resize)

    masks = mask_decoder(new_feature, 2, patch_size // 32)
    return masks.permute(0, 1, 4, 2, 3)

In [42]:
# Seed for the prompt sampling so that repeated runs pick the same prompts.
PROMPT_SEED = 0


def _append_class_summary(class_label, scores, path="loss_summary.txt"):
    """Append max / min / median of ``scores`` for ``class_label`` to ``path``.

    Does nothing when ``scores`` is empty, so no NumPy reduction is ever run
    on an empty array.
    """
    if len(scores) == 0:
        return
    losses = np.array(scores)
    with open(path, "a+") as f:
        f.write(f"{class_label}, max: {np.max(losses)}, min: {np.min(losses)}, median: {np.median(losses)}\n")


# Evaluate every test volume: crop a patch around a sampled positive voxel,
# predict it, paste it back and score the result against the label.
# NOTE: despite its name, `loss` below is the Dice *score* (1 - dice_loss), so
# "best" is the highest value and "worst" the lowest.
# Per-class statistics are flushed whenever `name` changes, i.e. volumes of the
# same class are assumed to be contiguous in `test_data`.
with torch.no_grad():
    loss_summary = []
    last_name = None
    index = {}
    best = 0
    worst = 1
    rng = np.random.default_rng(PROMPT_SEED)
    for idx, (img, seg, name) in enumerate(tqdm(test_data)):
        seg = seg.float().unsqueeze(0)
        img = img.unsqueeze(0)
        # Resample seg (no spacing applied) to the size of img (spacing applied).
        prompt = F.interpolate(seg[None, :, :, :, :], img.shape[2:], mode="nearest")[0]
        seg = seg.to(device).unsqueeze(0)
        img = img.to(device)
        seg_pred = torch.zeros_like(prompt).to(device)

        # Randomly pick num_prompts positive voxels; their x, y, z coordinates
        # define the centre of a cube of edge length patch_size.
        positive = torch.where(prompt == 1)  # computed once, reused below
        l = len(positive[0])
        if l == 0:
            # No foreground voxel to sample a prompt from.
            tqdm.write(f"Skipping {name}: label contains no positive voxel.")
            continue
        sample = rng.choice(l, num_prompts, replace=True)

        # Axis mapping used throughout: x <-> dim 1 (d), z <-> dim 2 (h), y <-> dim 3 (w).
        x = positive[1][sample].unsqueeze(1)
        y = positive[3][sample].unsqueeze(1)
        z = positive[2][sample].unsqueeze(1)

        x_m = torch.div(torch.max(x) + torch.min(x), 2, rounding_mode="trunc")
        y_m = torch.div(torch.max(y) + torch.min(y), 2, rounding_mode="trunc")
        z_m = torch.div(torch.max(z) + torch.min(z), 2, rounding_mode="trunc")

        d_min = x_m - patch_size // 2
        d_max = x_m + patch_size // 2
        h_min = z_m - patch_size // 2
        h_max = z_m + patch_size // 2
        w_min = y_m - patch_size // 2
        w_max = y_m + patch_size // 2
        # Amount by which the cube sticks out of the volume on each side.
        d_l = max(0, -d_min)
        d_r = max(0, d_max - prompt.shape[1])
        h_l = max(0, -h_min)
        h_r = max(0, h_max - prompt.shape[2])
        w_l = max(0, -w_min)
        w_r = max(0, w_max - prompt.shape[3])

        # Convert the prompt coordinates to patch-relative positions (approximately).
        points = (
            torch.cat([x - d_min, y - w_min, z - h_min], dim=1).unsqueeze(1).float()
        )
        points_torch = points.to(device)
        d_min = max(0, d_min)
        h_min = max(0, h_min)
        w_min = max(0, w_min)
        img_patch = img[:, :, d_min:d_max, h_min:h_max, w_min:w_max].clone()
        # Pad the crop back up to a full patch_size cube.
        img_patch = F.pad(img_patch, (w_l, w_r, h_l, h_r, d_l, d_r))
        pred = model_predict(
            img_patch, points_torch, img_encoder, prompt_encoder_list, mask_decoder
        )
        # Drop the padded borders again before pasting into the full volume.
        pred = pred[
            :, :, d_l : patch_size - d_r, h_l : patch_size - h_r, w_l : patch_size - w_r
        ]
        pred = F.softmax(pred, dim=1)[:, 1]
        seg_pred[:, d_min:d_max, h_min:h_max, w_min:w_max] += pred

        # Resample the (spacing-applied) prediction back to the original size.
        final_pred = F.interpolate(
            seg_pred.unsqueeze(1), size=seg.shape[2:], mode="trilinear"
        )
        masks = final_pred > 0.5
        loss = 1 - dice_loss(masks, seg)
        loss = loss.squeeze().item()

        # On a class change, flush the statistics of the previous class BEFORE
        # recording the current score, so that this volume is counted for its
        # own class and not for the previous one.
        if name != last_name:
            if last_name is not None:
                _append_class_summary(last_name, loss_summary)
            loss_summary = []
            best = 0
            worst = 1
            last_name = name
        loss_summary.append(loss)

        # `index` remembers the best and the worst volume of every class.
        # NOTE(review): the stored image is the padded patch, while masks / seg
        # cover the whole volume — confirm this is what should be plotted.
        if loss > best:
            best = loss
            os.makedirs(f"images\\{name}_best_loss", exist_ok=True)
            index[f"{name}_best_loss"]=(img_patch, masks, seg)
        if loss < worst:
            worst = loss
            os.makedirs(f"images\\{name}_worst_loss", exist_ok=True)
            index[f"{name}_worst_loss"]=(img_patch, masks, seg)

# Flush the statistics of the last class (no-op when nothing was evaluated).
_append_class_summary(last_name, loss_summary)

26it [01:25,  3.28s/it]


In [34]:
# Save the per-slice figures for every recorded best / worst volume.
for k in index:
    img_patch, masks, seg = index[k]
    plot_slices(img_patch, masks, seg, k)
    print(f"{k} Done")

In [35]:
# with torch.no_grad():
#     for mode, data in enumerate([train_data, val_data]):
#         for idx, (img, seg, name) in enumerate(data):
#             print('seg: ', seg.sum())
#             out = F.interpolate(img.float(), scale_factor=256 / patch_size, mode='trilinear')
#             input_batch = out.to(device)
#             input_batch = input_batch[0].transpose(0, 1)
#             batch_features, feature_list = img_encoder(input_batch)
#             feature_list.append(batch_features)
#             #feature_list = feature_list[::-1]
#             l = len(torch.where(seg == 1)[0])
#             points_torch = None
#             if l > 0:
#                 sample = np.random.choice(np.arange(l), 10, replace=True)
#                 x = torch.where(seg == 1)[1][sample].unsqueeze(1)
#                 y = torch.where(seg == 1)[3][sample].unsqueeze(1)
#                 z = torch.where(seg == 1)[2][sample].unsqueeze(1)
#                 points = torch.cat([x, y, z], dim=1).unsqueeze(1).float()
#                 points_torch = points.to(device)
#                 points_torch = points_torch.transpose(0, 1)
#             l = len(torch.where(seg < 10)[0])
#             sample = np.random.choice(np.arange(l), 10, replace=True)
#             x = torch.where(seg < 10)[1][sample].unsqueeze(1)
#             y = torch.where(seg < 10)[3][sample].unsqueeze(1)
#             z = torch.where(seg < 10)[2][sample].unsqueeze(1)
#             points = torch.cat([x, y, z], dim=1).unsqueeze(1).float()
#             points_torch_negative = points.to(device)
#             points_torch_negative = points_torch_negative.transpose(0, 1)
#             if points_torch is not None:
#                 points_torch = points_torch
#             else:
#                 points_torch = points_torch_negative
#             new_feature = []
#             for i, (feature, prompt_encoder) in enumerate(zip(feature_list, prompt_encoder_list)):
#                 if i == 3:
#                     new_feature.append(
#                         prompt_encoder(feature, points_torch.clone(), [patch_size, patch_size, patch_size])
#                     )
#                 else:
#                     new_feature.append(feature)
#             img_resize = F.interpolate(img[:, 0].permute(0, 2, 3, 1).unsqueeze(1).to(device), scale_factor=32/patch_size,
#                                         mode='trilinear')
#             new_feature.append(img_resize)
#             masks = mask_decoder(new_feature, 2, patch_size//32)
#             masks = masks.permute(0, 1, 4, 2, 3)
#             seg = seg.to(device)
#             seg = seg.unsqueeze(1)
#             loss = dice_loss(masks, seg)
            
#             if mode == 0:
#                 if not os.path.isdir(f"images\\{name[0]}_{idx}"):
#                     os.makedirs(f"images\\train\\{name[0]}_{idx}")
#                 plot_slices(img_patch, masks, seg, f"{name[0]}_{idx}", "train")
#             else:
#                 if not os.path.isdir(f"images\\{name[0]}_{idx}"):
#                     os.makedirs(f"images\\val\\{name[0]}_{idx}")
#                 plot_slices(img_patch, masks, seg, f"{name[0]}_{idx}", "val")