In [14]:
import numpy as np
import matplotlib.pyplot as plt
import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from pycocotools.coco import COCO
from segment_anything import sam_model_registry as sam_model_registry_lora
from segment_anything_new import SamPredictor, sam_model_registry
from segment_anything_new.utils.transforms import ResizeLongestSide
from SurfaceDice import compute_dice_coefficient
from pathlib import Path
from tqdm.auto import tqdm
# set seeds
torch.manual_seed(1234)
np.random.seed(1234)

In [15]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'

Create image list

In [16]:
task = 'test'
root_dir = Path(rf'D:\Side_project\SAMed\data\map_1-1_gastrointestinal_coco\{task}')
anno_info = COCO(rf'D:\Side_project\SAMed\data\map_1-1_gastrointestinal_coco\{task}\_annotations.coco.json')

def get_image(image_id):
    image_info = anno_info.loadImgs(image_id)[0]
    image_path = root_dir / image_info['file_name']
    ann_ids = anno_info.getAnnIds(imgIds=image_id)
    anns = anno_info.loadAnns(ann_ids)

    mask = anno_info.annToMask(anns[0])
    
    image_data = np.array(Image.open(root_dir / image_info['file_name']))

    def get_bbox_from_mask(mask):
        '''Returns a bounding box from a mask'''
        y_indices, x_indices = np.where(mask > 0)
        x_min, x_max = np.min(x_indices), np.max(x_indices)
        y_min, y_max = np.min(y_indices), np.max(y_indices)
        # add perturbation to bounding box coordinates
        H, W = mask.shape
        x_min = max(1, x_min - np.random.randint(0, 20))
        x_max = min(W-1, x_max + np.random.randint(0, 20))
        y_min = max(1, y_min - np.random.randint(0, 20))
        y_max = min(H-1, y_max + np.random.randint(0, 20))

        return np.array([x_min, y_min, x_max, y_max])

    # gt_data = io.imread(join(ts_gt_path, test_names[img_idx]))
    bbox_raw = get_bbox_from_mask(mask)
    # print(f'{bbox_raw=}')


    # # preprocess: cut-off and max-min normalization
    lower_bound, upper_bound = np.percentile(image_data, 0.5), np.percentile(image_data, 99.5)
    image_data_pre = np.clip(image_data, lower_bound, upper_bound)
    image_data_pre = (image_data_pre - np.min(image_data_pre))/(np.max(image_data_pre)-np.min(image_data_pre))*255.0
    image_data_pre[image_data==0] = 0
    image_data_pre = np.uint8(image_data_pre)
    H, W, _ = image_data_pre.shape

    return image_data_pre, bbox_raw, mask, H, W
# # predict the segmentation mask using the original SAM model
# ori_sam_predictor.set_image(image_data_pre)
# ori_sam_seg, _, _ = ori_sam_predictor.predict(point_coords=None, box=bbox_raw, multimask_output=False)

loading annotations into memory...
Done (t=0.02s)
creating index...
index created!


Create SAM model

In [17]:
# MedSAM model
ori_sam_model= sam_model_registry['vit_b'](checkpoint=r"D:\Side_project\SAMed\MedSAM\medsam_vit_b.pth")

ori_sam_model.to(device)

ori_sam_predictor = SamPredictor(ori_sam_model)

# Fine tune MedSAM model with mask decoder
checkpoint_gas = r'D:\Side_project\MedSAM\work_dir\demo2D_ICIP\sam_model_best.pth'
sam_model= sam_model_registry['vit_b'](checkpoint=checkpoint_gas)

sam_model.to(device)

# Fine tune MedSAM model with LoRA
from importlib import import_module

checkpoint_sam = r'D:\Side_project\SAMed\MedSAM\medsam_vit_b.pth'
checkpoint_lora = r'D:\Side_project\SAMed\output\MedSAM\results\gastrointestinal_512_pretrain_vit_b_epo200_bs5_lr0.0001\best_epoch=102_valid_loss=0.068.pth'
sam_model, _ = sam_model_registry_lora['vit_b'](image_size=512, 
                                            num_classes=2,
                                            checkpoint=checkpoint_sam, 
                                            pixel_mean=[0, 0, 0],
                                            pixel_std=[1, 1, 1])

sam_model.to(device)

pkg = import_module('sam_lora_image_encoder')
net = pkg.LoRA_Sam(sam_model, 4).cuda()
net.load_lora_parameters(checkpoint_lora)

In [19]:
ori_sam_dsc_list = []
tuned_sam_mask_decoder_list = []
lora_list = []
for image_id in tqdm(range(len(list(root_dir.glob('*.png'))))):
    image_data_pre, bbox_raw, mask, H, W = get_image(image_id)

    ## Original MedSAM predict
    ori_sam_predictor.set_image(image_data_pre)
    ori_sam_seg, _, _ = ori_sam_predictor.predict(point_coords=None, box=bbox_raw, multimask_output=False)
    ori_sam_dsc = compute_dice_coefficient(mask>0, ori_sam_seg>0)
    ori_sam_dsc_list.append(ori_sam_dsc)


    ## Fine tune MedSAM with mask decoder predict
    sam_transform = ResizeLongestSide(sam_model.image_encoder.img_size)
    resize_img = sam_transform.apply_image(image_data_pre)
    # print(resize_img.shape)
    # plt.imshow(resize_img)

    resize_img_tensor = transforms.ToTensor()(resize_img).to(device)
    # print(resize_img_tensor.shape)
    # resize_img_tensor = torch.as_tensor(resize_img.transpose(2, 0, 1)).to(device)

    input_image = sam_model.preprocess(resize_img_tensor.unsqueeze(dim=0)) # (1, 3, 1024, 1024)
    # assert input_image.shape == (1, 3, sam_model.image_encoder.img_size, sam_model.image_encoder.img_size), 'input image should be resized to 1024*1024'

    with torch.no_grad():
        # pre-compute the image embedding
        ts_img_embedding = sam_model.image_encoder(input_image)
        # convert box to 1024x1024 grid
        bbox = sam_transform.apply_boxes(bbox_raw, (H, W))
        # print(f'{bbox_raw=} -> {bbox=}')
        box_torch = torch.as_tensor(bbox, dtype=torch.float, device=device)
        if len(box_torch.shape) == 2:
            box_torch = box_torch[:, None, :] # (B, 4) -> (B, 1, 4)
        
        sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
            points=None,
            boxes=box_torch,
            masks=None,
        )
        medsam_seg_prob, _ = sam_model.mask_decoder(
            image_embeddings=ts_img_embedding.to(device), # (B, 256, 64, 64)
            image_pe=sam_model.prompt_encoder.get_dense_pe(), # (1, 256, 64, 64)
            sparse_prompt_embeddings=sparse_embeddings, # (B, 2, 256)
            dense_prompt_embeddings=dense_embeddings, # (B, 256, 64, 64)
            multimask_output=False,
            )
        medsam_seg_prob = torch.sigmoid(medsam_seg_prob)

        pred_shapped_masks = F.interpolate(
            medsam_seg_prob, # shape: (1, 1, 256, 256)
            (224, 224),
            mode="bilinear",
            align_corners=False,
        )

        # convert soft mask to hard mask
        pred_shapped_masks = pred_shapped_masks.cpu().numpy().squeeze()
        medsam_seg = (pred_shapped_masks > 0.5).astype(np.uint8)
    medsam_dsc = compute_dice_coefficient(mask>0, medsam_seg>0)
    tuned_sam_mask_decoder_list.append(medsam_dsc)

    ### Fine tune MedSAM with LoRA predict
    transform = transforms.Compose([transforms.ToPILImage(),
                                transforms.Resize(size=(512,512)),
                                transforms.ToTensor()])
    img_tensor = transform(image_data_pre)
    img_tensor = img_tensor.unsqueeze(dim=0).to(device)
    with torch.no_grad():
        output = net(img_tensor, False, 512)
    transform = transforms.Compose([transforms.Resize(size=(224,224))])

    output_masks = torch.argmax(torch.softmax(output['masks'], dim=1), dim=1, keepdim=True)
    output_masks = output_masks.squeeze(dim=0) * 50
    output_masks = transform(output_masks)
    output_masks_np_lora = output_masks.cpu().numpy()
    output_masks_np_lora = (output_masks_np_lora > 0).astype(np.uint8)
    sam_lora_dsc = compute_dice_coefficient(mask>0, output_masks_np_lora>0)
    lora_list.append(sam_lora_dsc)

100%|██████████| 396/396 [03:00<00:00,  2.19it/s]


Test

In [20]:
print(f'Mean dice of Original MedSAM: {sum(ori_sam_dsc_list) / len(ori_sam_dsc_list):.4f}')
print(f'Mean dice of Tuned MedSAM with mask decoder: {sum(tuned_sam_mask_decoder_list) / len(tuned_sam_mask_decoder_list):.4f}')
print(f'Mean dice of Tuned MedSAM with LoRA: {sum(lora_list) / len(lora_list):.4f}')

Mean dice of Original MedSAM: 0.7393
Mean dice of Tuned MedSAM with mask decoder: 0.4291
Mean dice of Tuned MedSAM with LoRA: 0.3022


Valid

In [11]:
print(f'Mean dice of Original MedSAM: {sum(ori_sam_dsc_list) / len(ori_sam_dsc_list):.4f}')
print(f'Mean dice of Tuned MedSAM with mask decoder: {sum(tuned_sam_mask_decoder_list) / len(tuned_sam_mask_decoder_list):.4f}')
print(f'Mean dice of Tuned MedSAM with LoRA: {sum(lora_list) / len(lora_list):.4f}')

Mean dice of Original MedSAM: 0.6442
Mean dice of Tuned MedSAM with mask decoder: 0.2445
Mean dice of Tuned MedSAM with LoRA: 0.5942
