In [1]:
import monai.transforms as mt
import logging
import sys
import matplotlib.pyplot as plt
import ignite
import numpy as np
import torch
import monai
import torchvision.transforms as transform
import warnings
warnings.filterwarnings("ignore")  # remove some scikit-image warnings

In [15]:
def prompt_label(Ground_truth_mask,prompt = "box"):
    label_images = []
    prompt_images = []
    local_gt = Ground_truth_mask
    for i in range(1,int(np.max(local_gt))+1):
        i_gt_mask = np.where(local_gt == i, 1, 0)
        if np.sum(i_gt_mask.flatten()) == 0:
            continue
        else:
            label_images.append(i_gt_mask)
            if prompt == "point":
                # 随机选一个点
                point = np.zeros((1,2))
                indices = np.nonzero(i_gt_mask)
                random_index = np.random.randint(0,len(indices[0])-1)
                point[0,0] ,point[0,1] = indices[2][random_index],indices[1][random_index]
                prompt_images.append(point)
                #SAM_mask , scores, logits = predictor.predict(point_coords=point,point_labels=np.array([1]),multimask_output=False,)
            elif prompt == "points":
                # 随机选5个点
                indices = np.nonzero(i_gt_mask)
                random_index = np.random.random_integers(0,len(indices[0])-1,5)
                points = np.zeros((len(random_index),2))
                for j in range(len(random_index)):
                    points[j,0], points[j,1] = indices[2][random_index[j]],indices[1][random_index[j]]
                prompt_images.append(points)
                #SAM_mask , scores, logits = predictor.predict(point_coords=points,point_labels=np.ones(len(points)),multimask_output=False,)
            elif prompt == "box":
                indices = np.nonzero(i_gt_mask)
                x_min = min(indices[2])
                x_max = max(indices[2])
                y_min = min(indices[1])
                y_max = max(indices[1])
                input_box = np.array([x_min,y_min,x_max,y_max])
                prompt_images.append(input_box)
                #SAM_mask , scores, logits = predictor.predict(point_coords=None,point_labels=None,box=input_box,multimask_output=False,)
            else:
                raise NameError("prompt should be in [\"point\",\"points\",\"box\"]")
            
    return label_images, prompt_images

def Gray2RGB(image):
    # 进来的是 batchsize = 1 * 512 * 512
    # 返回一个 B*C*H*W
    input_img = torch.zeros((image.shape[1],image.shape[2],3))
    input_img[:,:,0] = input_img[:,:,1]  = input_img[:,:,2]  = image[0,:,:] #(image-torch.min(image))/(torch.max(image))*255
    input_img = np.uint8(input_img)
    #input_img = input_img.to(device="cuda")
    return input_img

def Extra_Dim(image):
    image_4d = np.zeros((1,3,image.shape[0],image.shape[1]))
    image_4d[0,0,:,:] = image[:,:,0]
    image_4d[0,1,:,:] = image[:,:,1]
    image_4d[0,2,:,:] = image[:,:,2]
    return image_4d

from segment_anything.utils.transforms import ResizeLongestSide

def image_preprocess(image,sam_model,device):
    image = Gray2RGB(image)
    transform = ResizeLongestSide(sam_model.image_encoder.img_size)
    input_image = transform.apply_image(image)
    input_image_torch = torch.as_tensor(input_image, device=device)
    transformed_image = input_image_torch.permute(2, 0, 1).contiguous()[None, :, :, :]
    input_image = sam_model.preprocess(transformed_image)
    original_image_size = image.shape[:2]
    input_size = tuple(transformed_image.shape[-2:])
    return input_image, original_image_size, input_size

import torch.nn as nn


def Dice(SAM_mask,i_gt_mask):
    output_mat = SAM_mask[0,0,:,:] * i_gt_mask[0,0,:,:]
    overlap = torch.sum(output_mat)
    return 2* overlap / (torch.sum(SAM_mask)+torch.sum(i_gt_mask))

def show_mask(mask, ax, random_color=False):
    if random_color:
        color = np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
    else:
        color = np.array([30/255, 144/255, 255/255, 0.6])
    h, w = mask.shape[-2:]
    mask_image = mask.reshape(h, w, 1) * color.reshape(1, 1, -1)
    ax.imshow(mask_image)

def show_points(coords, labels, ax, marker_size=375):
    pos_points = coords[labels==1]
    neg_points = coords[labels==0]
    ax.scatter(pos_points[:, 0], pos_points[:, 1], color='green', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)
    ax.scatter(neg_points[:, 0], neg_points[:, 1], color='red', marker='*', s=marker_size, edgecolor='white', linewidth=1.25)   
    
def show_box(box, ax):
    x0, y0 = box[0], box[1]
    w, h = box[2] - box[0], box[3] - box[1]
    ax.add_patch(plt.Rectangle((x0, y0), w, h, edgecolor='green', facecolor=(0,0,0,0), lw=2))  

In [3]:
from monai.apps import datasets
import json
import os
from monai import transforms
import torchvision
from monai.transforms import LoadImaged,EnsureChannelFirstd,Compose,ToTensord
import logging

image_transform = Compose([
    LoadImaged(keys=("image", "label")),
#    torchvision.transforms.Grayscale(num_output_channels=3),
    ToTensord(keys=("image", "label"))
])

with open("./training_data.json") as file1:
    dataset = json.load(file1)

#train_dataset =  monai.data.CacheDataset(dataset["training"], transform=image_transform)
train_dataset =  monai.data.Dataset(dataset, transform=image_transform)
train_loader  = monai.data.DataLoader(train_dataset, batch_size=1, shuffle=True, num_workers=1, drop_last=True)

In [18]:
from segment_anything import SamPredictor, sam_model_registry 
import torch 
import os 
device = "cuda" if torch.cuda.is_available() else "cpu"
# 根据检查点加载模型

sam_model = sam_model_registry["vit_b"](checkpoint="./checkpoint/sam_vit_b_01ec64.pth")
sam_model.train()
predictor = SamPredictor(sam_model)
for name, param in sam_model.named_parameters():
  if name.startswith("vision_encoder") or name.startswith("prompt_encoder"):
    param.requires_grad_(False)
sam_model.to(device=device)
# 定义损失函数和优化器
# hyperparameters
max_epochs = 1
lr = 5e-6
wd = 0
optimizer = torch.optim.Adam(sam_model.mask_decoder.parameters(), lr=lr, weight_decay=wd)
seg_loss = monai.losses.DiceCELoss(sigmoid=True, squared_pred=True, reduction='mean')

In [22]:
from tqdm import tqdm
from statistics import mean
import torch
import torch.nn.functional as F
from torchvision.transforms import Resize
from segment_anything import SamPredictor, sam_model_registry, utils
losses = []
for epoch in range(max_epochs):
    epoch_losses = []
    batch_number = 0
    for batch in tqdm(train_loader):
        optimizer.zero_grad()
        batch_number += 1
    # forward pass
    # batch image : [batch_size * 512 * 512], label : [batch_size * 512 * 512]
        loss = 0
        input_image, input_label = batch["image"],batch["label"]
        input_image, original_image_size, input_size = image_preprocess(input_image,sam_model,device)
        label_images, prompt_images = prompt_label(input_label)
        with torch.no_grad():
            image_embedding = sam_model.image_encoder(input_image)
        for k in range(len(label_images)):
            box = ResizeLongestSide(sam_model.image_encoder.img_size).apply_boxes(prompt_images[k], original_image_size)
            box_torch = torch.as_tensor(box, dtype=torch.float, device=device)
            box_torch = box_torch[None, :]
            with torch.no_grad():
                sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(points=None, boxes=box_torch, masks=None)
        # compute loss
            low_res_masks, iou_predictions = sam_model.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=sam_model.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
            )
            upscaled_masks = sam_model.postprocess_masks(low_res_masks, input_size, original_image_size).to(device)
            binary_mask = F.normalize(F.threshold(upscaled_masks, 0.0, 0))

            gt_mask_resized = torch.from_numpy(np.resize(label_images[k], (1, 1, label_images[k].shape[1], label_images[k].shape[2]))).to(device)
            gt_binary_mask = torch.as_tensor(gt_mask_resized > 0, dtype=torch.float32)
    
            loss += 1-Dice(binary_mask, gt_binary_mask)
            loss /= len(label_images)
        
        loss.backward()
        optimizer.step()
        epoch_losses.append(loss.item())
        losses.append(epoch_losses)
        torch.cuda.empty_cache()
        if batch_number % 10 == 0:
            print(f'EPOCH: {epoch},processf{batch_number}')
            print(f'Mean loss: {mean(epoch_losses)}')
    PATH = f"finetune/fine_tuned_sam_{1+epoch}.pth"
    torch.save(sam_model.state_dict(), PATH)

  1%|          | 10/1770 [01:22<3:35:13,  7.34s/it]

EPOCH: 0,processf10
Mean loss: 0.25659630112349985


  1%|          | 20/1770 [02:34<3:23:48,  6.99s/it]

EPOCH: 0,processf20
Mean loss: 0.20913966819643975


  2%|▏         | 30/1770 [03:41<2:58:53,  6.17s/it]

EPOCH: 0,processf30
Mean loss: 0.2256810624152422


  2%|▏         | 40/1770 [04:50<3:21:15,  6.98s/it]

EPOCH: 0,processf40
Mean loss: 0.19625952476635575


  3%|▎         | 50/1770 [06:02<3:36:58,  7.57s/it]

EPOCH: 0,processf50
Mean loss: 0.18174210786819459


  3%|▎         | 57/1770 [06:51<3:16:37,  6.89s/it]

In [6]:
PATH = "finetune/fine_tuned_sam.pth"
torch.save(sam_model.state_dict(), PATH)