In [1]:
import numpy as np
import torch
import matplotlib.pyplot as plt
import cv2

In [3]:
def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)
    
def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
    
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))    

def dice_coefficient(mask1, mask2):
    intersection = np.sum(np.logical_and(mask1, mask2))
    union = np.sum(np.logical_or(mask1, mask2))
    
    dice = (2.0 * intersection) / (union + intersection)
    return dice


In [5]:
# 预处理文件路径
import nibabel as nib
import glob
import os
training_train = [1,2,3,4,5,6,7,8,9,10,21,22,23,24,25,26,27,28,29,30,31,32,33,34]
training_test =[35,36,37,38,39,40]
training=[1,2,3,4,5,6,7,8,9,10,21,22,23,24,25,26,27,28,29,30,31,32,33,34,35,36,37,38,39,40]
image_dir = 'RawData/Training/imagesTr'
label_dir = 'RawData/Training/labelsTr'

In [6]:
# 初始化和配置一个用于图像分割任务的深度学习模型
import sys
sys.path.append("..")
from segment_anything import sam_model_registry, SamPredictor
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
device ="cuda"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)

In [12]:
# 设置adam优化器，学习率为1e-5
# 每次执行10次迭代
# 损失函数采用(1-dice)^2
# 使用反向传播来计算梯度
for T in training_train:
    optimizer=torch.optim.Adam(sam.parameters(),lr=1e-5)
    for epoch in range(10):
        optimizer.zero_grad()
        # 找到图片和标签的路径
        image_filename = f"img{str(T).zfill(4)}.nii.gz"
        image_path = os.path.join(image_dir, image_filename)
        file_img = nib.load(image_path)
        data_img = file_img.get_fdata()

        label_filename = f"label{str(T).zfill(4)}.nii.gz"
        label_path = os.path.join(label_dir, label_filename)
        file_label = nib.load(label_path)
        data_label = file_label.get_fdata()

        total_img = data_img.shape[2]

        unique_gray_val = set()
        for picnum in range(total_img):
            slice_normalized_label = cv2.normalize(data_label[:,:,picnum], None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX)
            slice_8bit_label = np.uint8(slice_normalized_label)
            non_zero_values = np.unique(slice_8bit_label[slice_8bit_label != 0])
            unique_gray_val.update(non_zero_values)
                        
        # 初始化两个列表
        num_grays = len(unique_gray_val)
        value_vector = [0] * num_grays
        num_vector = [0] * num_grays

        for picnum in range(total_img):
            slice_normalized_label = cv2.normalize(data_label[:,:,picnum], None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX)
            slice_8bit_label = np.uint8(slice_normalized_label)
            slice_bgr_label = cv2.cvtColor(slice_8bit_label, cv2.COLOR_GRAY2BGR)
            # 分解出所有像素点
            h, w = slice_8bit_label.shape
            nonzero_mask = slice_8bit_label != 0
            y_coords, x_coords = np.nonzero(nonzero_mask)
            gray_values = slice_8bit_label[nonzero_mask]
            pixel_data_label = list(zip(x_coords, y_coords, gray_values))
            pixel_data_array_label = np.array(pixel_data_label, dtype=np.int32)

            if pixel_data_array_label.shape[0] == 0 :
                continue
            unique_gray_values = np.unique(pixel_data_array_label[:, 2])

            # 将灰度图转换为三通道BGR图像
            slice_normalized_img = cv2.normalize(data_img[:,:,picnum], None, alpha=0, beta=255, norm_type=cv2.NORM_MINMAX)
            slice_8bit_img = np.uint8(slice_normalized_img)
            image = cv2.cvtColor(slice_8bit_img, cv2.COLOR_GRAY2BGR)

            predictor.set_image(image)
            # 遍历每个灰度值，对每个器官单独进行取点
            for color in unique_gray_values:
                indices = np.where(pixel_data_array_label[:, 2] == color)
                # 随机选k个点，并获得他们的坐标
                k=1
                tmp_k=min(k, len(indices[0]))
                selected_indices = np.random.choice(len(indices[0]), tmp_k, replace=False)
                selected_points_coords = list(zip(pixel_data_array_label[indices, 0][0][selected_indices],
                                                    pixel_data_array_label[indices, 1][0][selected_indices]))
                input_label = np.ones(tmp_k, dtype=int)
                input_point = np.array(selected_points_coords)
                input_label = np.array(input_label)
                masks, scores, logits = predictor.predict(
                    point_coords=input_point,
                    point_labels=input_label,
                    multimask_output=True,
                )

                mask = masks[0]
                score = scores[0]
                new_space = slice_normalized_label
                # 将生成的掩码（mask）和地面实况掩码（slice_8bit_label）二值化
                slice_8bit_label_now = np.uint8(new_space)
                slice_8bit_label_now[slice_8bit_label_now != color] = 0
                binary_mask = (mask > 0).astype(np.uint8)
                binary_ground_truth = (slice_8bit_label_now > 0).astype(np.uint8)
                # 计算Dice系数
                dice = dice_coefficient(binary_mask, binary_ground_truth)
                for i, tmp in enumerate(unique_gray_val):
                    if color == tmp:
                        value_vector[i] += dice
                        num_vector[i] += 1
                        break
        
        dice = 0
        for i in range(0, num_grays):
            dice += value_vector[i]/num_vector[i]
        dice=dice/num_grays
        print(dice)
        dice_tensor=torch.tensor(dice)
        loss=torch.tensor((1-dice_tensor)*(1-dice_tensor),requires_grad=True)
        loss.backward()
        optimizer.step()

sam.eval()

# 保存模型状态字典
torch.save(sam.state_dict(), 'module_vit_h.pth')

0.6532324359954136


  loss=torch.tensor((1-dice_tensor)*(1-dice_tensor),requires_grad=True)


0.6605175390205862
0.648856640173609
0.6711742786764132
0.6652599318006805
0.6592977586462829
0.6706734907585711
0.6687158290685811
0.6673886078245513
0.6557984743762273
0.5193166556237327
0.5256575644967201
0.5443172748669228
0.5412868123750355
0.5217958726964907
0.5307700860317741
0.5213421686231395
0.5312663021224208
0.510967533901885
0.5286986194897922
0.6344788458758233
0.6256560608253608
0.6298765060453938
0.630665755127036
0.6314204463564164
0.6215037040702137
0.6269229856449327
0.6267997781619916
0.6353832207573376
0.6316990769083605
0.6886643906140657
0.6939859973618082
0.687750718241865
0.6872339104058725
0.693233868619386
0.6741680662844434
0.693167086708595
0.6864280964707353
0.6909404378338018
0.6791668268942045
0.5589229691596402
0.5564485888726638
0.5584315848998022
0.5474607731143751
0.5562945168720554
0.5550176198562747
0.546262928416384
0.5538672244487896
0.5536101783028354
0.5499158343126873
0.559004300569324
0.567112972288801
0.5561154216491715
0.5716140370210687
0.