# MedSAM Inference (no fine-tuning)

Based on https://github.com/bowang-lab/MedSAM

1. Download the checkpoint and place it in `work_dir/MedSAM`.
2. Download the dataset.
3. Pre-process the dataset with `pre_CT_MR.py`.

## Setup

In [9]:
import os, sys
# Make project modules importable from this notebook: append the directory two
# levels up (dir1) and the parent directory (dir2) to sys.path. The membership
# check keeps re-runs of this cell from adding duplicate entries.
dir1 = os.path.abspath(os.path.join(os.path.abspath(''), '..', '..'))
if dir1 not in sys.path:
    sys.path.append(dir1)
dir2 = os.path.abspath(os.path.join(os.path.abspath(''), '..'))
if dir2 not in sys.path:
    sys.path.append(dir2)

In [2]:
from utils.environment import setup_data_vars
# Project helper from utils/environment.py. Presumably it exports the
# dataset-path environment variables read later via os.environ (nnUNet_raw,
# Anorectum, data_trainingImages, data_trainingLabels); confirm in that module.
setup_data_vars()

In [3]:
# !pip install connected-components-3d
# !wget -P $PROJECT_DIR/models/MedSAM/work_dir/MedSAM/ https://dl.fbaipublicfiles.com/segment_anything/sam_vit_b_01ec64.pth

## Plotting and Inference Functions

In [4]:
#  %% environment and functions
import numpy as np
import matplotlib.pyplot as plt
import os
join = os.path.join
import torch
from segment_anything import sam_model_registry
from skimage import io, transform
import torch.nn.functional as F

# visualization functions
# source: https://github.com/facebookresearch/segment-anything/blob/main/notebooks/predictor_example.ipynb
# change color to avoid red and green
def show_mask(mask, ax, random_color=False):
    """Overlay *mask* on the Matplotlib axes *ax* as a translucent RGBA image.

    The last two dimensions of *mask* are taken as (height, width). Each mask
    value scales an RGBA colour with alpha 0.6. That colour is a fixed yellow
    (chosen to avoid red and green) unless *random_color* is true, in which
    case the RGB part is drawn from ``np.random.random``.
    """
    rgba = (
        np.concatenate([np.random.random(3), np.array([0.6])], axis=0)
        if random_color
        else np.array([251 / 255, 252 / 255, 30 / 255, 0.6])
    )
    height, width = mask.shape[-2:]
    # Broadcast (height, width, 1) * (1, 1, 4) -> (height, width, 4).
    overlay = mask.reshape(height, width, 1) * rgba.reshape(1, 1, -1)
    ax.imshow(overlay)

def show_box(box, ax):
    """Outline the box ``[x0, y0, x1, y1]`` on *ax* with a 2-px blue border.

    The rectangle's face is fully transparent, so only the outline is drawn.
    """
    x_min, y_min = box[0], box[1]
    box_width = box[2] - box[0]
    box_height = box[3] - box[1]
    outline = plt.Rectangle(
        (x_min, y_min),
        box_width,
        box_height,
        edgecolor='blue',
        facecolor=(0, 0, 0, 0),
        lw=2,
    )
    ax.add_patch(outline)

@torch.no_grad()
def medsam_inference(medsam_model, img_embed, box_1024, H, W):
    """Segment the region inside a box prompt, starting from a precomputed image embedding.

    The function runs the model's prompt encoder and mask decoder with
    single-mask output. It then bilinearly upsamples the mask probabilities to
    (H, W) and thresholds them at 0.5.

    Args:
        medsam_model: model exposing ``prompt_encoder`` (callable, with
            ``get_dense_pe()``) and ``mask_decoder``, as in segment-anything.
        img_embed: image-encoder output, e.g. (B, 256, 64, 64) for ViT-B.
        box_1024: box prompt(s) as ``[x0, y0, x1, y1]``, presumably in the
            1024x1024 model-input frame (per the name). A 2-D (B, 4) input is
            expanded to (B, 1, 4).
        H, W: height and width of the returned mask.

    Returns:
        ``np.uint8`` array with values in {0, 1}. Its shape is (H, W) for a
        single prompt and (B, H, W) for B > 1 prompts.
    """
    box_torch = torch.as_tensor(box_1024, dtype=torch.float, device=img_embed.device)
    if box_torch.dim() == 2:
        box_torch = box_torch[:, None, :]  # (B, 4) -> (B, 1, 4)

    sparse_embeddings, dense_embeddings = medsam_model.prompt_encoder(
        points=None,
        boxes=box_torch,
        masks=None,
    )
    low_res_logits, _ = medsam_model.mask_decoder(
        image_embeddings=img_embed,  # (B, 256, 64, 64)
        image_pe=medsam_model.prompt_encoder.get_dense_pe(),  # (1, 256, 64, 64)
        sparse_prompt_embeddings=sparse_embeddings,  # (B, 2, 256)
        dense_prompt_embeddings=dense_embeddings,  # (B, 256, 64, 64)
        multimask_output=False,
    )

    low_res_pred = torch.sigmoid(low_res_logits)  # (B, 1, h, w) probabilities
    pred = F.interpolate(
        low_res_pred,
        size=(H, W),
        mode="bilinear",
        align_corners=False,
    )  # (B, 1, H, W)

    # Remove only the channel axis, and the batch axis when it is 1. A bare
    # .squeeze() would also drop H or W when either equals 1.
    pred = pred[:, 0]
    if pred.shape[0] == 1:
        pred = pred[0]
    pred = pred.cpu().numpy()
    return (pred > 0.5).astype(np.uint8)


## Pre-process Data (assumes the Anorectum dataset for now)

In [5]:
# Raw Anorectum dataset directory. os.environ[...] is used instead of .get() so
# that a missing variable raises a KeyError naming it, rather than a TypeError
# from os.path.join(None, ...).
anorectum_raw_dir = os.path.join(os.environ['nnUNet_raw'], os.environ['Anorectum'])
nii_path = os.path.join(anorectum_raw_dir, os.environ['data_trainingImages'])  # path to the nii images
gt_path = os.path.join(anorectum_raw_dir, os.environ['data_trainingLabels'])  # path to the ground truth

In [11]:
# %run /vol/bitbucket/az620/radiotherapy/models/MedSAM/pre_CT_MR.py
from pre_CT_MR import pre_CT_MR
# Output directory for the pre-processed .npy files.
# NOTE(review): this mirrors upstream MedSAM's `data/npy/<modality>_<anatomy>`
# layout and assumes CT data. Confirm the intended location.
npy_path = os.path.join('data', 'npy', 'CT_Anorectum')

pre_CT_MR(
    nii_path=nii_path,
    gt_path=gt_path,
    npy_path=npy_path,
    # NOTE(review): assumes pre_CT_MR expects an anatomy label (this section
    # targets Anorectum), not an image path. Confirm against pre_CT_MR.py.
    anatomy='Anorectum',
)

## Inference Pipeline