In [1]:
import numpy as np
import matplotlib.pyplot as plt
import os
from tqdm import tqdm
import torch
from torch.utils.data import Dataset, DataLoader
import monai
from local_segment_anything import SamPredictor, sam_model_registry
from local_segment_anything.utils.transforms import ResizeLongestSide
import argparse
# Fix the global torch and numpy RNG seeds so stochastic steps (e.g. random box jitter)
# are reproducible across runs.
GLOBAL_SEED = 2023
torch.manual_seed(GLOBAL_SEED)
np.random.seed(GLOBAL_SEED)
#self dataset
from datasets.self_dataset import PolypDataset
from utils.tools import compute_num_params
from utils.tools import set_save_path
from utils.tools import calc_seg

In [2]:
def get_bbox_from_mask(mask, max_perturb=20, rng=None):
    '''Return a randomly enlarged bounding box ``[x_min, y_min, x_max, y_max]`` for a 2-D mask.

    The tight box around every pixel with ``mask > 0`` is computed first. Each side is then
    pushed outward by an independent random offset in ``[0, max_perturb)`` pixels. The result
    is clamped to ``[0, W]`` for x and ``[0, H]`` for y.

    Args:
        mask: 2-D array-like of shape (H, W); pixels with value > 0 are foreground.
        max_perturb: exclusive upper bound of the random outward offset applied to each side.
            A value <= 0 disables the perturbation and returns the tight box.
        rng: object providing ``randint(low, high)`` (e.g. ``np.random.RandomState``).
            Defaults to the global ``np.random`` state, so ``np.random.seed`` controls the jitter.

    Returns:
        Integer numpy array ``[x_min, y_min, x_max, y_max]``.

    Raises:
        ValueError: if ``mask`` is not 2-D or contains no foreground pixel.
    '''
    mask = np.asarray(mask)
    if mask.ndim != 2:
        raise ValueError(f'expected a 2-D mask, got shape {mask.shape}')
    y_indices, x_indices = np.nonzero(mask > 0)
    if x_indices.size == 0:
        # np.min on an empty array would fail with an unhelpful message; fail clearly instead.
        raise ValueError('mask has no foreground pixels; cannot derive a bounding box')
    if rng is None:
        rng = np.random

    def _jitter():
        # Outward offset in [0, max_perturb); randint(0, 0) would raise, so 0 disables jitter.
        return int(rng.randint(0, max_perturb)) if max_perturb > 0 else 0

    H, W = mask.shape
    # Draw order (x_min, x_max, y_min, y_max) matches the original so seeded runs reproduce.
    x_min = max(0, x_indices.min() - _jitter())
    x_max = min(W, x_indices.max() + _jitter())
    y_min = max(0, y_indices.min() - _jitter())
    y_max = min(H, y_indices.max() + _jitter())

    return np.array([x_min, y_min, x_max, y_max])

In [3]:
# --- Configuration -----------------------------------------------------------
# Paths default to the original author's machine. Set the KVASIR_ROOT and
# SAM_CHECKPOINT environment variables to run elsewhere without editing the notebook.
model_name = 'vit_b'
dataset_root = os.environ.get('KVASIR_ROOT', '/home/hjh/data_disk/hjh/dataset/Kvasir-SEG')
image_root = os.path.join(dataset_root, 'images')
mask_root = os.path.join(dataset_root, 'masks_new')
separate_info = os.path.join(dataset_root, 'fold_info', 'kvasir_fold_0.csv')
ori_model_path = os.environ.get('SAM_CHECKPOINT', '/home/hjh/code_lib/SAM/sam_vit_b_01ec64.pth')
# Fall back to CPU so the notebook still runs on machines without a CUDA device.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

sam_model_ori = sam_model_registry[model_name](checkpoint=ori_model_path).to(device)
# `sam_model` is the model evaluated below. To evaluate a fine-tuned checkpoint instead:
# finetune_model_path = '...'
# sam_model = sam_model_registry[model_name](checkpoint=finetune_model_path).to(device)
sam_model = sam_model_ori

In [4]:
# Validation split of the configured fold (train_flag=False), iterated one image at a time
# in a fixed order.
valid_dataset = PolypDataset(
    image_root=image_root,
    gt_root=mask_root,
    separate_info=separate_info,
    train_flag=False,
)
valid_dataloader = DataLoader(
    dataset=valid_dataset,
    batch_size=1,
    shuffle=False,
    num_workers=4,
)

number of dataset is 200
[Resize(always_apply=False, p=1.0, height=1024, width=1024, interpolation=1), Normalize(always_apply=False, p=1.0, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225], max_pixel_value=255)]


In [6]:
import medpy.metric.binary as eval_tool
# Evaluate `sam_model` on the validation loader with box prompts, then report the
# metrics returned by calc_seg (SM, EM, mean Dice, mean IoU).
sam_model.eval()
epoch = 0
pred_batches = []
gt_batches = []

# Pure inference: no gradients are needed for any part of the model in this loop.
with torch.no_grad():
    # The dense positional encoding does not depend on the batch, so compute it once.
    image_pe = sam_model.prompt_encoder.get_dense_pe()  # (1, 256, 64, 64)

    for image, gt2D, boxes in tqdm(valid_dataloader, desc='valid'):
        image_embedding = sam_model.image_encoder(image.to(device))  # (B, 256, 64, 64)

        # Boxes are used as produced by the dataset. Presumably they are already in the
        # 1024x1024 frame the dataset resizes images to; TODO confirm in PolypDataset.
        box_torch = boxes.to(device)
        if box_torch.ndim == 2:
            box_torch = box_torch[:, None, :]  # (B, 4) -> (B, 1, 4)

        sparse_embeddings, dense_embeddings = sam_model.prompt_encoder(
            points=None,
            boxes=box_torch,
            masks=None,
        )
        low_res_masks, iou_predictions = sam_model.mask_decoder(
            image_embeddings=image_embedding,
            image_pe=image_pe,
            sparse_prompt_embeddings=sparse_embeddings,  # (B, 2, 256)
            dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
            multimask_output=False,
        )
        pred_batches.append(torch.sigmoid(low_res_masks))
        gt_batches.append(gt2D)

pred_list = torch.cat(pred_batches, 0)
gt_list = torch.cat(gt_batches, 0)

sm, em, mDice, mIoU = calc_seg(pred_list, gt_list)
print('epoch %d validation results: SM(%.4f) EM(%.4f) mDice(%.2f) mIoU(%.2f)'%(epoch, sm, em, mDice*100, mIoU*100))

valid: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████| 200/200 [00:46<00:00,  4.32it/s]


epoch 0 validation results: SM(0.9133) EM(0.9017) mDice(89.82) mIoU(83.68)


In [None]:
# Freezing policy:
#   - mask_decoder parameters are trained;
#   - prompt_encoder parameters are frozen;
#   - image_encoder parameters are frozen unless their name contains 'interaction' or 'SPM';
#   - every remaining parameter is trained.
for pname, param in sam_model.named_parameters():
    keep_frozen = 'mask_decoder' not in pname and (
        'prompt_encoder' in pname
        or ('image_encoder' in pname and 'interaction' not in pname and 'SPM' not in pname)
    )
    param.requires_grad_(not keep_frozen)

In [None]:
# Print the name of every parameter of sam_model whose requires_grad flag is set, one per line.
trainable_names = [pname for pname, param in sam_model.named_parameters() if param.requires_grad]
for pname in trainable_names:
    print(pname)

In [None]:
import numpy as np
# Scratch check: one draw from [0, 10) returned as a length-1 array.
np.random.randint(low=0, high=10, size=(1,))

In [None]:
# Scratch check: randint's upper bound is exclusive, so only indices 0-2 are drawn and
# the last element (4) is never picked.
[1, 2, 3, 4][int(np.random.randint(0, 3, size=1)[0])]

In [None]:
# Scratch check: a 3x4 float64 array of zeros.
np.zeros(shape=(3, 4), dtype=float)

In [None]:
from matplotlib import pyplot as plt

In [None]:
# Synthetic 100x200 float image: background 0 and one 30x30 square of ones covering
# rows 20-49 and columns 70-99, built by zero-padding a block of ones.
test_image = np.pad(np.ones((30, 30)), ((20, 50), (70, 100)), mode='constant')
plt.imshow(test_image)