In [5]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

from segment_anything import sam_model_registry, SamPredictor

sam_checkpoint = "module_vit_h.pth"
model_type = "vit_h"
device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)

predictor = SamPredictor(sam)

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
    
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))    

def dice_coefficient(mask1, mask2):
    intersection = np.sum(np.logical_and(mask1, mask2))
    union = np.sum(np.logical_or(mask1, mask2))
    dice = (2.0 * intersection) / (union + intersection)
    return dice


In [6]:
# 预处理文件路径
import nibabel as nib
import glob
import os
training_train = [1,2,3,4,5,6,7,8,9,10,21,22,23,24,25,26,27,28,29,30,31,32,33,34]
training_test =[35,36,37,38,39,40]
training=[1,2,3,4,5,6,7,8,9,10,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40]
image_dir = 'RawData/Training/imagesTr'
label_dir = 'RawData/Training/labelsTr'


In [9]:
# 选择单点/多点/边框
# k表示选择点的数量
# k=0是表示边框
k = 0

In [10]:
# 大致思路：
# 先利用label找到所有可能的灰度值，也就是所有不同的器官
# 接下来还是利用label对每个器官进行选取prompt(单点、多点、边界框)
# 使用SAM的官方训练模型进行训练，接着以mean dice输出训练的结果
dice_sum=0.0
for T in training:
    # 找到图片和标签的路径
    image_filename = f"img{str(T).zfill(4)}.nii.gz"
    image_path = os.path.join(image_dir, image_filename)
    file_img = nib.load(image_path)
    data_img = file_img.get_fdata()

    label_filename = f"label{str(T).zfill(4)}.nii.gz"
    label_path = os.path.join(label_dir, label_filename)
    file_label = nib.load(label_path)
    data_label = file_label.get_fdata()

    total_img = data_img.shape[2]

    # 找到所有可能的灰度值，也就是所有不同的器官
    unique_gray_val = set()
    for picnum in range(total_img):
        slice_normalized_label = cv2.normalize(data_label[:,:,picnum], None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX)
        slice_8bit_label = np.uint8(slice_normalized_label)
        non_zero_values = np.unique(slice_8bit_label[slice_8bit_label != 0])
        unique_gray_val.update(non_zero_values)
                    
    # 初始化两个列表
    num_grays = len(unique_gray_val)
    value_vector = [0] * num_grays
    num_vector = [0] * num_grays

    # 处理组中所有图片
    for picnum in range(total_img):
        slice_normalized_label = cv2.normalize(data_label[:,:,picnum], None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX)
        slice_8bit_label = np.uint8(slice_normalized_label)
        slice_bgr_label = cv2.cvtColor(slice_8bit_label, cv2.COLOR_GRAY2BGR)
        # 分解出所有像素点
        h, w = slice_8bit_label.shape
        nonzero_mask = slice_8bit_label != 0
        y_coords, x_coords = np.nonzero(nonzero_mask)
        gray_values = slice_8bit_label[nonzero_mask]
        pixel_data_label = list(zip(x_coords, y_coords, gray_values))
        pixel_data_array_label = np.array(pixel_data_label, dtype=np.int32)

        if pixel_data_array_label.shape[0] == 0 :
            continue
        unique_gray_values = np.unique(pixel_data_array_label[:, 2])

        # 将灰度图转换为三通道BGR图像
        slice_normalized_img = cv2.normalize(data_img[:,:,picnum], None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX)
        slice_8bit_img = np.uint8(slice_normalized_img)
        image = cv2.cvtColor(slice_8bit_img, cv2.COLOR_GRAY2BGR)

        predictor.set_image(image)
        # 遍历每个灰度值，对每个器官单独进行取点
        for color in unique_gray_values:
            indices = np.where(pixel_data_array_label[:, 2] == color)
            if k != 0 : 
                # 随机选k个点，并获得他们的坐标
                tmp_k=min(k, len(indices[0]))
                selected_indices = np.random.choice(len(indices[0]), tmp_k, replace=False)
                selected_points_coords = list(zip(pixel_data_array_label[indices, 0][0][selected_indices],
                                                    pixel_data_array_label[indices, 1][0][selected_indices]))
                input_label = np.ones(tmp_k, dtype=int)
                input_point = np.array(selected_points_coords)
                input_label = np.array(input_label)
                masks, scores, logits = predictor.predict(
                    point_coords=input_point,
                    point_labels=input_label,
                    multimask_output=True,
                )
            else :
                # 边框
                input_box = np.array([max(0, min(pixel_data_array_label[indices, 0][0][:])-9), 
                                      max(0, min(pixel_data_array_label[indices, 1][0][:])-9),
                                      min(w, max(pixel_data_array_label[indices, 0][0][:])+9),
                                      min(h, max(pixel_data_array_label[indices, 1][0][:])+9)])
                masks, _, _ = predictor.predict(
                    point_coords = None,
                    point_labels = None,
                    box = input_box[None, :],
                    multimask_output = False,
                )

            mask = masks[0]
            score = scores[0]
            new_space = slice_normalized_label
            # 将生成的掩码（mask）和地面实况掩码（slice_8bit_label）二值化
            slice_8bit_label_now = np.uint8(new_space)
            slice_8bit_label_now[slice_8bit_label_now != color] = 0
            binary_mask = (mask > 0).astype(np.uint8)
            binary_ground_truth = (slice_8bit_label_now > 0).astype(np.uint8)
            # 计算Dice系数
            dice = dice_coefficient(binary_mask, binary_ground_truth)
            for i, tmp in enumerate(unique_gray_val):
                if color == tmp:
                    value_vector[i] += dice
                    num_vector[i] += 1
                    break
    average_dice = 0
    for i in range(0, num_grays):
        average_dice += value_vector[i]/num_vector[i]
    average_dice=average_dice/num_grays
    print(average_dice)
    dice_sum+=average_dice

mean_dice=dice_sum/30
print(mean_dice)

0.7967357844051169
0.722585551547401
0.7396359611305311
0.7915605422904202
0.732266407943971
0.7403402572704642
0.7237502024989232
0.7371675096732362
0.6897816718341289
0.738045914445731
0.7092389497531043
0.7824487315817205
0.737526207231367
0.7468555537339778
0.7984034060551718
0.7295809442902438
0.7369996214322366
0.7437067089730862
0.7620732985119555
0.7670838182727515
0.7692502058114838
0.777934382941387
0.777669340766949
0.8246253714926941
0.7874660671871921
0.69025184656826
0.7261600790189808
0.7083221328215524
0.7850463326497233
0.7581997462660073
0.7510237516133257
