# MedSAM test on prostate MRI

MedSAM repository: https://github.com/bowang-lab/MedSAM

In [1]:
!pip install git+https://github.com/bowang-lab/MedSAM.git
!pip install gdown

model_id = "1UAmWL88roYR7wKlnApw5Bcuzf2iQgk6_&confirm=t"
!gdown $model_id
!mv medsam_vit_b.pth ../

Collecting git+https://github.com/bowang-lab/MedSAM.git
  Cloning https://github.com/bowang-lab/MedSAM.git to c:\users\ultrastmedtech\appdata\local\temp\pip-req-build-xxoaxxoh
  Resolved https://github.com/bowang-lab/MedSAM.git to commit 71237ca7a942e48d2fee1b40483769ed369a2adb
  Preparing metadata (setup.py): started
  Preparing metadata (setup.py): finished with status 'done'


  Running command git clone --filter=blob:none --quiet https://github.com/bowang-lab/MedSAM.git 'C:\Users\ultrastmedtech\AppData\Local\Temp\pip-req-build-xxoaxxoh'




Downloading...
From (uriginal): https://drive.google.com/uc?id=1UAmWL88roYR7wKlnApw5Bcuzf2iQgk6_
From (redirected): https://drive.google.com/uc?id=1UAmWL88roYR7wKlnApw5Bcuzf2iQgk6_&confirm=t&uuid=8cdee512-a511-4e19-a55b-4741652e21ee
To: E:\ITKMedicalImageProcessing_demo\medsam_vit_b.pth

  0%|          | 0.00/375M [00:00<?, ?B/s]
  2%|1         | 5.77M/375M [00:00<00:06, 55.4MB/s]
  5%|4         | 17.3M/375M [00:00<00:04, 88.7MB/s]
  7%|6         | 26.2M/375M [00:00<00:04, 84.6MB/s]
  9%|9         | 35.1M/375M [00:00<00:03, 86.1MB/s]
 12%|#2        | 45.6M/375M [00:00<00:03, 91.8MB/s]
 15%|#5        | 57.1M/375M [00:00<00:03, 97.8MB/s]
 18%|#7        | 67.1M/375M [00:00<00:03, 94.5MB/s]
 21%|##        | 77.1M/375M [00:00<00:03, 85.9MB/s]
 23%|##2       | 86.0M/375M [00:00<00:03, 82.4MB/s]
 26%|##6       | 97.5M/375M [00:01<00:03, 90.4MB/s]
 29%|##8       | 109M/375M [00:01<00:02, 94.6MB/s] 
 32%|###1      | 118M/375M [00:01<00:02, 91.4MB/s]
 34%|###4      | 128M/375M [00:01<00:02, 89.4

## Imports

In [1]:
# %% environment and functions
import re
import cv2
import numpy as np
import matplotlib.pyplot as plt
import os
join = os.path.join
import torch
from segment_anything import sam_model_registry
from skimage import io, transform
import torch.nn.functional as F
import SimpleITK as sitk
from pathlib import Path
from lib.folder import FolderMg
from lib.sam_base import BasicSAM

In [2]:
# visualization functions
# source: https://github.com/facebookresearch/segment-anything/blob/main/notebooks/predictor_example.ipynb
# change color to avoid red and green
def show_mask(mask, ax, random_color=False):
    """Draw ``mask`` on ``ax`` as a semi-transparent RGBA overlay.

    The mask values scale the overlay color, so the 1-pixels of a 0/1 mask are
    painted and the 0-pixels stay fully transparent.

    Args:
        mask: array whose last two dimensions are (H, W); it is reshaped to (H, W, 1).
        ax: matplotlib Axes (anything with an ``imshow`` method).
        random_color: when True, use a random RGB color (alpha 0.6) instead of the
            default yellow, which was chosen to avoid red and green.
    """
    rgba = (
        np.append(np.random.random(3), 0.6)
        if random_color
        else np.array([251 / 255, 252 / 255, 30 / 255, 0.6])
    )
    height, width = mask.shape[-2:]
    overlay = mask.reshape(height, width, 1) * rgba[np.newaxis, np.newaxis, :]
    ax.imshow(overlay)

def show_points(coords, labels, ax, marker_size=375):
    """Scatter point prompts on ``ax``: label 1 as green stars, label 0 as red stars.

    Args:
        coords: (N, 2) array of (x, y) point coordinates.
        labels: length-N array of point labels; only 1 and 0 are drawn.
        ax: matplotlib Axes (anything with a ``scatter`` method).
        marker_size: marker area passed to ``scatter`` as ``s``.
    """
    # Both selections are made before anything is drawn; positives are drawn first.
    groups = [(coords[labels == 1], 'green'), (coords[labels == 0], 'red')]
    for points, colour in groups:
        ax.scatter(
            points[:, 0],
            points[:, 1],
            color=colour,
            marker='*',
            s=marker_size,
            edgecolor='white',
            linewidth=1.25,
        )

def show_box(box, ax):
    """Outline a box prompt on ``ax`` as an unfilled green rectangle.

    Args:
        box: sequence indexed as (x0, y0, x1, y1); (x0, y0) is the rectangle's anchor corner.
        ax: matplotlib Axes that receives the patch.
    """
    x_min, y_min = box[0], box[1]
    x_max, y_max = box[2], box[3]
    outline = plt.Rectangle(
        (x_min, y_min),
        x_max - x_min,
        y_max - y_min,
        edgecolor='green',
        facecolor=(0, 0, 0, 0),
        lw=2,
    )
    ax.add_patch(outline)

@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_1024, H, W, threshold=0.5):
    """Decode a binary MedSAM mask from a precomputed image embedding and box prompt(s).

    Args:
        medsam_model: SAM-style model exposing ``prompt_encoder`` and ``mask_decoder``.
        img_embed: image embedding tensor. Its device is also used for the box tensor.
            Presumably (B, 256, 64, 64) for ViT-B; TODO confirm.
        box_1024: box prompt(s) as (x0, y0, x1, y1), with shape (4,), (B, 4) or (B, 1, 4).
            Per the name, presumably in the 1024x1024 model-input frame; verify against caller.
        H, W: output mask size. The low-resolution probabilities are bilinearly resized to it.
        threshold: probability cutoff applied after the sigmoid (default 0.5).

    Returns:
        np.uint8 array of 0/1 values. The shape is (H, W) for a single box and
        (B, H, W) for B > 1 boxes.
    """
    box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
    # Normalise to (B, 1, 4). SAM's prompt encoder takes the batch size from
    # boxes.shape[0], so an unbatched (4,) box would otherwise be read as a batch of 4.
    if box_torch.ndim == 1:
        box_torch = box_torch[None, None, :]
    elif box_torch.ndim == 2:
        box_torch = box_torch[:, None, :]

    sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
        points=None,
        boxes=box_torch,
        masks=None,
    )
    low_res_logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,
        image_pe=medsam_model.prompt_encoder.get_dense_pe(),
        sparse_prompt_embeddings=sparse_embeddings,
        dense_prompt_embeddings=dense_embeddings,
        multimask_output=False,  # one mask per box
    )

    low_res_pred = torch.sigmoid(low_res_logits)  # (B, 1, h, w) probabilities
    pred = F.interpolate(
        low_res_pred,
        size=(H, W),
        mode="bilinear",
        align_corners=False,
    )  # (B, 1, H, W)
    # Drop only the mask-channel axis, plus the batch axis for a single box. A blanket
    # squeeze() would also collapse H or W whenever either equals 1.
    pred = pred[:, 0].cpu().numpy()  # (B, H, W)
    if pred.shape[0] == 1:
        pred = pred[0]  # (H, W)
    medsam_seg = (pred > threshold).astype(np.uint8)
    return medsam_seg

In [3]:
# Input data: data/<sourceFolderName>, listed below with one sub-folder per case.
sourceFolderName = "mri-prostate-slices-resample"
sourceDataPath = Path("data") / sourceFolderName
sourceMg = FolderMg(sourceDataPath)
sourceMg.ls()


Current Folder 'mri-prostate-slices-resample' contains 11 folders, which are:
  - 1_029_10slices
  - 2_837_19slices
  - 3_947_10slices
  - 4_920_19slices
  - 6_624_17slices
  - ...

Current Folder 'mri-prostate-slices-resample' contains NO files



In [4]:
# Outputs go under result/medsam-<sourceFolderName>; the MedSAM ViT-B wrapper does the prediction.
destinationPath = Path("result") / f"medsam-{sourceFolderName}"
samModel = BasicSAM("medsam_vit_b")

SAM Model set up finished.


In [5]:
# Inclusive [first, last] slice-index window per case id, where the case id is the
# leading number of the case folder name. Slices outside the window are not segmented.
slice_range = {
    str(caseId): [first, last]
    for caseId, first, last in (
        (1, 2, 7),
        (2, 9, 12),
        (3, 3, 6),
        (4, 6, 12),
        (6, 6, 9),
        (7, 4, 15),
        (8, 5, 7),
        (9, 6, 9),
        (10, 3, 9),
        (11, 7, 10),
        (12, 4, 7),
    )
}

## Main

In [7]:
# Segment every case folder and save one label volume per case as NIfTI.
# Slices outside the case's inclusive slice_range window are not sent to MedSAM;
# samModel.skipPredictAndSaveEmpty supplies their mask instead.
print("Processing folders:")
for fd in sourceMg.dirs:
    print(f"- {fd.name}")
    fdMg = FolderMg(fd)

    outputFolderPath = destinationPath.joinpath(fd.name)
    outputFolderPath.mkdir(parents=True, exist_ok=True)  # safe on re-run

    # Case id = first integer in the folder name, e.g. "1_029_10slices" -> "1".
    fdIdx = re.search(r"([0-9]+)", fd.stem).group(0)
    firstSlice, lastSlice = slice_range[fdIdx]

    # Order slices by their numeric index so the volume's z axis does not depend on
    # how fdMg.files happens to be ordered (a lexical order would put 10 before 2).
    indexedSlices = sorted(
        ((int(re.search(r"([0-9]+)", slicePath.stem).group(0)), slicePath) for slicePath in fdMg.files),
        key=lambda item: item[0],
    )

    sliceSegResult = []
    for sliceIdx, slicePath in indexedSlices:
        figSavePath = outputFolderPath.joinpath(slicePath.name)
        if firstSlice <= sliceIdx <= lastSlice:
            masks, _scores = samModel.predictOneImg(slicePath, figSavePath, onlyFirstMask=True)
            sliceMask = masks[0]
        else:
            sliceMask = samModel.skipPredictAndSaveEmpty(slicePath, figSavePath)
        sliceSegResult.append(np.asarray(sliceMask).astype(int).squeeze())

    if not sliceSegResult:
        print(f"  no slices found in {fd.name}, skipped")
        plt.close("all")
        continue

    # np.stack keeps the (Z, H, W) layout even for a single-slice case, whereas
    # squeeze() on a wrapped list would drop the z axis.
    # The masks are presumably 0/1 labels (TODO confirm), which fit in uint8.
    segImgArray = np.stack(sliceSegResult).astype(np.uint8)
    segImg = sitk.GetImageFromArray(segImgArray)
    # NOTE(review): spacing/origin/direction stay at SimpleITK defaults; copy them from
    # the source MRI volume if the mask must overlay the original image in a viewer.
    imgSavePath = destinationPath.joinpath(f"seg_{fd.name}.nii.gz")
    sitk.WriteImage(segImg, str(imgSavePath))  # str(): older SimpleITK rejects Path objects
    plt.close("all")
print("Finished")

Processing folders:
- 1_029_10slices
- 2_837_19slices
- 3_947_10slices
- 4_920_19slices
- 6_624_17slices
- 7_709_20slices
- 8_258_16slices
- 9_825_15slices
- 10_213_13slices
- 11_543_18slices
- 12_244_14slices
Finished


# Result

The MedSAM segmentations are not accurate with the prompts used here.