In [55]:
import torchio as tio
from torchio.data.io import sitk_to_nib
import torch
import numpy as np
import torch
import SimpleITK as sitk
import torch.nn.functional as F
from itertools import product
import edt
from argparse import Namespace
import pickle
from torch import tensor
from copy import deepcopy

from utils.analysisUtils import compute_dice
from utils.base_classes import Points, SegmenterWrapper, Inferer
from utils.modelUtils import load_sammed3d

In [26]:
def get_pos_clicks3D(gt, n_clicks, seed = None):
    """Sample `n_clicks` distinct positive (foreground) clicks from a label volume.

    Args:
        gt: label array/tensor; voxels equal to 1 are treated as foreground.
        n_clicks: number of distinct foreground voxels to sample.
        seed: optional seed. When given, sampling uses a local
            `np.random.RandomState(seed)`. It draws exactly the same indices as
            `np.random.seed(seed)` followed by `np.random.choice`, but leaves NumPy's
            global generator untouched. When None, the global generator is used.

    Returns:
        Points prompt. Its 'coords' is an (n_clicks, gt.ndim) integer array of voxel
        indices (one column per axis of `gt`); its 'labels' is a list of 1s.

    Raises:
        RuntimeError: if `gt` has no foreground voxels, or fewer than `n_clicks`.
    """
    rng = np.random if seed is None else np.random.RandomState(seed)

    # np.where yields one index array per axis; stack into (n_fg_voxels, ndim).
    volume_fg = np.array(np.where(gt == 1)).T
    n_fg_voxels = len(volume_fg)

    if n_fg_voxels == 0:
        raise RuntimeError('No foreground voxels found! Check that the supplied label is a binary segmentation mask with foreground coded as 1')
    if n_fg_voxels < n_clicks:
        raise RuntimeError('More foreground points were requested than the number of foreground voxels in the volume')

    point_indices = rng.choice(n_fg_voxels, size = n_clicks, replace = False)
    pos_coords = volume_fg[point_indices]
    return Points({'coords': pos_coords, 'labels': [1]*len(pos_coords)})

def _padded_edt(mask):
    """Euclidean distance transform of a 3D mask, zero-padded by one voxel per side, then cropped back."""
    padded = F.pad(mask.to(torch.uint8), (1, 1, 1, 1, 1, 1), 'constant', value=0)
    dist = edt.edt(padded.cpu().numpy(), black_border=True, parallel=4)
    return torch.tensor(dist)[1:-1, 1:-1, 1:-1]

def get_next_click3D_torch_ritm(prev_seg, gt_semantic_seg):
    """Choose the next corrective click for every batch item (RITM-style).

    For each item, compute the false-negative region (missed foreground) and the
    false-positive region (spurious foreground). Pick the region whose distance
    transform has the larger maximum; ties go to false positives. Then draw a
    random voxel from that region's core, where distance > half the maximum.

    Args:
        prev_seg: previous prediction, binarised with `> 0.5`. Assumed
            (B, 1, D, H, W) — TODO confirm.
        gt_semantic_seg: ground truth, foreground where `> 0`; same shape as prev_seg.

    Returns:
        (batch_points, batch_labels): lists with one entry per batch item.
        - Each point is a (1, 1, 3) tensor of voxel indices.
        - Each label is a (1, 1) tensor: 1 if the point lies in the false-negative
          region (positive click), else 0 (negative click).

    Raises:
        ValueError: if a batch item has no segmentation error to correct.
    """
    mask_threshold = 0.5

    pred_masks = (prev_seg > mask_threshold)
    true_masks = (gt_semantic_seg > 0)
    fn_masks = torch.logical_and(true_masks, torch.logical_not(pred_masks))
    fp_masks = torch.logical_and(torch.logical_not(true_masks), pred_masks)

    batch_points = []
    batch_labels = []
    for i in range(gt_semantic_seg.shape[0]):
        # Distance maps are computed per batch item. Previously item 0's maps
        # were reused for the whole batch.
        fn_mask_dt = _padded_edt(fn_masks[i, 0])
        fp_mask_dt = _padded_edt(fp_masks[i, 0])
        fn_max_dist = torch.max(fn_mask_dt)
        fp_max_dist = torch.max(fp_mask_dt)

        # Click inside the "thicker" error region, restricted to its eroded core.
        dt = fn_mask_dt if fn_max_dist > fp_max_dist else fp_mask_dt
        to_point_mask = dt > (max(fn_max_dist, fp_max_dist) / 2.0)

        points = torch.argwhere(to_point_mask)
        if len(points) == 0:
            raise ValueError(f'Batch item {i} has no false-negative or false-positive voxels; no corrective click can be placed')
        point = points[np.random.randint(len(points))]
        is_positive = bool(fn_masks[i, 0, point[0], point[1], point[2]])

        batch_points.append(point.clone().detach().reshape(1, 1, 3))
        batch_labels.append(torch.tensor([int(is_positive),]).reshape(1, 1))

    return batch_points, batch_labels

# Registry of click-selection strategies keyed by method name (presumably looked
# up via a `point_method` setting such as `args.point_method` — confirm).
click_methods = dict(default=get_next_click3D_torch_ritm)

def get_img_gt_sammed3d(img_path, gt_path):
    """Load an image/label pair for SAM-Med3D inference.

    Steps:
    1. When the image's origin or direction differs from the label's, overwrite it
       with the label's.
    2. Convert both volumes with `sitk_to_nib`.
    3. Wrap them in a torchio Subject and apply `tio.ToCanonical()`.

    NOTE(review): the affines returned by `sitk_to_nib` are discarded, so the torchio
    images are built from the arrays alone. Confirm whether `ToCanonical` has any
    effect here.

    Returns:
        (image, label): cloned, detached data tensors of the transformed subject.
    """
    sitk_image = sitk.ReadImage(img_path)
    sitk_label = sitk.ReadImage(gt_path)

    # Put the image on the label's physical frame so the two stay aligned.
    label_origin = sitk_label.GetOrigin()
    if sitk_image.GetOrigin() != label_origin:
        sitk_image.SetOrigin(label_origin)
    label_direction = sitk_label.GetDirection()
    if sitk_image.GetDirection() != label_direction:
        sitk_image.SetDirection(label_direction)

    image_arr = sitk_to_nib(sitk_image)[0]
    label_arr = sitk_to_nib(sitk_label)[0]

    subject = tio.Subject(
        image = tio.ScalarImage(tensor=image_arr),
        label = tio.LabelMap(tensor=label_arr),
    )
    subject = tio.Compose([tio.ToCanonical()])(subject)

    return subject.image.data.clone().detach(), subject.label.data.clone().detach()

def preprocess_into_patches(img3D, pts_prompt = None, cheat = False, gt = None, offset_mode="center"):
    subject = tio.Subject(
        image = tio.ScalarImage(tensor=img3D)
    )
    
    if cheat:
        subject.add_image(tio.LabelMap(tensor = gt,
                                       affine = subject.image.affine,),
                          image_name = 'label')
        crop_transform = tio.CropOrPad(mask_name='label', 
                            target_shape=(128,128,128))
    else:
        coords = pts_prompt.value['coords']
        crop_mask = torch.zeros_like(subject.image.data)
        crop_mask[*coords.T] = 1
        subject.add_image(tio.LabelMap(tensor = crop_mask,
                                        affine = subject.image.affine),
                            image_name="crop_mask")
        crop_transform = tio.CropOrPad(mask_name='crop_mask', 
                                target_shape=(128,128,128))
        

    padding_params, cropping_params = crop_transform.compute_crop_or_pad(subject)
    # cropping_params: (x_start, x_max-(x_start+roi_size), y_start, ...)
    # padding_params: (x_left_pad, x_right_pad, y_left_pad, ...)
    if(cropping_params is None): cropping_params = (0,0,0,0,0,0)
    if(padding_params is None): padding_params = (0,0,0,0,0,0)
    roi_shape = crop_transform.target_shape
    vol_bound = (0, img3D.shape[1], 0, img3D.shape[2], 0, img3D.shape[3])
    center_oob_ori_roi=(
        cropping_params[0]-padding_params[0], cropping_params[0]+roi_shape[0]-padding_params[0],
        cropping_params[2]-padding_params[2], cropping_params[2]+roi_shape[1]-padding_params[2],
        cropping_params[4]-padding_params[4], cropping_params[4]+roi_shape[2]-padding_params[4],
    )
    window_list = []
    offset_dict = {
        "rounded": list(product((-32,+32,0), repeat=3)),
        "center": [(0,0,0)],
    }
    for offset in offset_dict[offset_mode]:
        # get the position in original volume~(allow out-of-bound) for current offset
        oob_ori_roi = (
            center_oob_ori_roi[0]+offset[0], center_oob_ori_roi[1]+offset[0],
            center_oob_ori_roi[2]+offset[1], center_oob_ori_roi[3]+offset[1],
            center_oob_ori_roi[4]+offset[2], center_oob_ori_roi[5]+offset[2],
        )
        # get corresponing padding params based on `vol_bound`
        padding_params = [0 for i in range(6)]
        for idx, (ori_pos, bound) in enumerate(zip(oob_ori_roi, vol_bound)):
            pad_val = 0
            if(idx%2==0 and ori_pos<bound): # left bound
                pad_val = bound-ori_pos
            if(idx%2==1 and ori_pos>bound):
                pad_val = ori_pos-bound
            padding_params[idx] = pad_val
        # get corresponding crop params after padding
        cropping_params = (
            oob_ori_roi[0]+padding_params[0], vol_bound[1]-oob_ori_roi[1]+padding_params[1],
            oob_ori_roi[2]+padding_params[2], vol_bound[3]-oob_ori_roi[3]+padding_params[3],
            oob_ori_roi[4]+padding_params[4], vol_bound[5]-oob_ori_roi[5]+padding_params[5],
        )
        # pad and crop for the original subject
        pad_and_crop = tio.Compose([
            tio.Pad(padding_params, padding_mode=crop_transform.padding_mode),
            tio.Crop(cropping_params),
        ])
        subject_roi = pad_and_crop(subject)  
        img3D_roi, = subject_roi.image.data.clone().detach().unsqueeze(0)
        norm_transform = tio.ZNormalization(masking_method=lambda x: x > 0)
        img3D_roi = norm_transform(img3D_roi) # (N, C, W, H, D)
        img3D_roi = img3D_roi.unsqueeze(dim=0)
        

        # collect all position information, and set correct roi for sliding-windows in 
        # todo: get correct roi window of half because of the sliding 
        windows_clip = [0 for i in range(6)]
        for i in range(3):
            if(offset[i]<0):
                windows_clip[2*i] = 0
                windows_clip[2*i+1] = -(roi_shape[i]+offset[i])
            elif(offset[i]>0):
                windows_clip[2*i] = roi_shape[i]-offset[i]
                windows_clip[2*i+1] = 0
        pos3D_roi = dict(
            padding_params=padding_params, cropping_params=cropping_params, 
            ori_roi=(
                cropping_params[0]+windows_clip[0], cropping_params[0]+roi_shape[0]-padding_params[0]-padding_params[1]+windows_clip[1],
                cropping_params[2]+windows_clip[2], cropping_params[2]+roi_shape[1]-padding_params[2]-padding_params[3]+windows_clip[3],
                cropping_params[4]+windows_clip[4], cropping_params[4]+roi_shape[2]-padding_params[4]-padding_params[5]+windows_clip[5],
            ),
            pred_roi=(
                padding_params[0]+windows_clip[0], roi_shape[0]-padding_params[1]+windows_clip[1],
                padding_params[2]+windows_clip[2], roi_shape[1]-padding_params[3]+windows_clip[3],
                padding_params[4]+windows_clip[4], roi_shape[2]-padding_params[5]+windows_clip[5],
            ))

        window_list.append((img3D_roi, pos3D_roi))
    return cropping_params, padding_params, window_list

def finetune_model_predict3D(img3D, sam_model_tune, prompt, device='cuda', prev_masks=None, dump_inputs_path=None):
    """Run one SAM-Med3D forward pass on a single patch and return a binary mask.

    Args:
        img3D: 5D image patch (N, C, X, Y, Z).
        sam_model_tune: model exposing `image_encoder`, `prompt_encoder` (with
            `get_dense_pe`) and `mask_decoder`.
        prompt: Points prompt whose 'coords'/'labels' are already batched and on `device`.
        device: device the image, image embedding and mask prompt are moved to.
        prev_masks: optional previous mask with `img3D`'s shape; zeros when None.
        dump_inputs_path: if given, pickle (points, labels, low-res mask, image) to
            this path for debugging. Previously the inputs were always written to a
            hard-coded local path, which failed when that directory did not exist.

    Returns:
        np.uint8 array (squeezed) at `img3D`'s spatial size: sigmoid(logits) > 0.5.
    """
    batch_points, batch_labels = prompt.value['coords'], prompt.value['labels']

    if prev_masks is None:
        prev_masks = torch.zeros_like(img3D)

    # Mask prompt at a quarter of the patch's spatial size (32^3 for 128^3 patches).
    low_res_size = tuple(dim // 4 for dim in img3D.shape[-3:])
    low_res_masks = F.interpolate(prev_masks.float(), size=low_res_size)

    if dump_inputs_path is not None:
        with open(dump_inputs_path, 'wb') as f:
            pickle.dump((batch_points, batch_labels, low_res_masks, img3D), f)

    with torch.no_grad():
        image_embedding = sam_model_tune.image_encoder(img3D.to(device))
        sparse_embeddings, dense_embeddings = sam_model_tune.prompt_encoder(
            points=[batch_points, batch_labels],
            boxes=None,
            masks=low_res_masks.to(device),
        )
        low_res_logits, _ = sam_model_tune.mask_decoder(
            image_embeddings=image_embedding.to(device),
            image_pe=sam_model_tune.prompt_encoder.get_dense_pe(),
            sparse_prompt_embeddings=sparse_embeddings,
            dense_prompt_embeddings=dense_embeddings,
            multimask_output=False,
        )
        # Upsample decoder logits back to the patch size before binarising.
        logits = F.interpolate(low_res_logits, size=img3D.shape[-3:], mode='trilinear', align_corners=False)
        medsam_seg_prob = torch.sigmoid(logits).cpu().numpy().squeeze()

    return (medsam_seg_prob > 0.5).astype(np.uint8)

def preprocess_prompt(pts_prompt, cropping_params, padding_params, use_only_first_point=False, device = 'cuda'):
    """Map full-volume click coordinates into patch coordinates and batch them.

    Steps:
    1. Drop the leading (channel) column of every coordinate row.
    2. Shift the spatial part by (padding - cropping) per axis.
    3. Prepend a batch dimension and move points and labels to `device`.

    Args:
        pts_prompt: Points prompt with a numpy 'coords' array and a 'labels' sequence.
        cropping_params, padding_params: 6-tuples laid out (x0, x1, y0, y1, z0, z1);
            only the leading entry of each axis pair is used.
        use_only_first_point: keep only the first click.
        device: target device for the returned tensors.

    Returns:
        A new Points prompt holding batched 'coords' and 'labels' tensors.
    """
    raw_coords = pts_prompt.value['coords']
    raw_labels = pts_prompt.value['labels']

    shift = np.array([padding_params[axis] - cropping_params[axis] for axis in (0, 2, 4)])
    patch_coords = raw_coords[..., 1:] + shift

    points = torch.from_numpy(patch_coords).unsqueeze(0).to(device)
    labels = torch.tensor(raw_labels).unsqueeze(0).to(device)
    if use_only_first_point:
        # The model wasn't trained to receive multiple points in one go.
        points, labels = points[:, :1], labels[:, :1]

    return Points(value = {'coords': points, 'labels': labels})

In [119]:
class SAMMed3DWrapper(SegmenterWrapper):
    """Callable wrapper that runs SAM-Med3D on one preprocessed image patch.

    A call returns the decoder's mask output upsampled to the patch's spatial size,
    moved to the CPU and squeezed. No sigmoid is applied, so these are raw logits.
    """
    def __init__(self, model, device, dump_inputs_path=None):
        self.model = model.to(device)
        self.device = device
        # Optional debugging hook: when set, each call pickles its model inputs here.
        # Previously the inputs were always written to a hard-coded local path.
        self.dump_inputs_path = dump_inputs_path

    def __call__(self, img, prompt):
        """Segment `img` (5D patch) from `prompt`; returns squeezed CPU logits.

        Only Points prompts are used. For any other prompt type, the points passed
        to the prompt encoder are [None, None].
        """
        coords, labs = None, None
        boxes = None
        if isinstance(prompt, Points):
            coords, labs = prompt.value['coords'], prompt.value['labels']

        img = img.to(self.device)

        # Empty mask prompt: batch and channel dims are 1, spatial dims are quartered.
        low_res_spatial_shape = [dim // 4 for dim in img.shape[-3:]]
        low_res_mask = torch.zeros([1, 1] + low_res_spatial_shape).to(self.device)

        dump_path = getattr(self, 'dump_inputs_path', None)
        if dump_path is not None:
            with open(dump_path, 'wb') as f:
                pickle.dump((coords, labs, low_res_mask, img), f)

        with torch.no_grad():
            sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
                points = [coords, labs],
                boxes = boxes,
                masks = low_res_mask,
            )

            image_embedding = self.model.image_encoder(img)

            mask_out, _ = self.model.mask_decoder(
                image_embeddings = image_embedding.to(self.device),
                image_pe = self.model.prompt_encoder.get_dense_pe(),
                sparse_prompt_embeddings = sparse_embeddings,
                dense_prompt_embeddings = dense_embeddings,
                multimask_output = False,
                )

        logits = F.interpolate(mask_out, size=img.shape[-3:], mode = 'trilinear', align_corners = False).detach().cpu().squeeze()

        return logits
    
class SAMMed3DInferer(Inferer):
    supported_prompts = supported_prompts = (Points,)
    required_shape = (128, 128, 128) # Hard code to match training

    def __init__(self, segmenter_wrapper: SAMMed3DWrapper, device = 'cuda', use_only_first_point = False):
        self.segmenter = segmenter_wrapper
        self.device = device
        self.use_only_first_point = use_only_first_point
        self.offset_mode = 'center'

    def preprocess_into_patches(self, img3D, prompt = None, cheat = False, gt = None):
        subject = tio.Subject(
            image = tio.ScalarImage(tensor=img3D)
        )
        
        if cheat:
            subject.add_image(tio.LabelMap(tensor = gt,
                                        affine = subject.image.affine,),
                            image_name = 'label')
            crop_transform = tio.CropOrPad(mask_name='label', 
                                target_shape=(128,128,128))
        else:
            coords = prompt.value['coords']
            crop_mask = torch.zeros_like(subject.image.data)
            crop_mask[*coords.T] = 1
            subject.add_image(tio.LabelMap(tensor = crop_mask,
                                            affine = subject.image.affine),
                                image_name="crop_mask")
            crop_transform = tio.CropOrPad(mask_name='crop_mask', 
                                    target_shape=(128,128,128))
            

        padding_params, cropping_params = crop_transform.compute_crop_or_pad(subject)
        # cropping_params: (x_start, x_max-(x_start+roi_size), y_start, ...)
        # padding_params: (x_left_pad, x_right_pad, y_left_pad, ...)
        if(cropping_params is None): cropping_params = (0,0,0,0,0,0)
        if(padding_params is None): padding_params = (0,0,0,0,0,0)
        roi_shape = crop_transform.target_shape
        vol_bound = (0, img3D.shape[1], 0, img3D.shape[2], 0, img3D.shape[3])
        center_oob_ori_roi=(
            cropping_params[0]-padding_params[0], cropping_params[0]+roi_shape[0]-padding_params[0],
            cropping_params[2]-padding_params[2], cropping_params[2]+roi_shape[1]-padding_params[2],
            cropping_params[4]-padding_params[4], cropping_params[4]+roi_shape[2]-padding_params[4],
        )
        window_list = []
        offset_dict = {
            "rounded": list(product((-32,+32,0), repeat=3)),
            "center": [(0,0,0)],
        }
        for offset in offset_dict[self.offset_mode]:
            # get the position in original volume~(allow out-of-bound) for current offset
            oob_ori_roi = (
                center_oob_ori_roi[0]+offset[0], center_oob_ori_roi[1]+offset[0],
                center_oob_ori_roi[2]+offset[1], center_oob_ori_roi[3]+offset[1],
                center_oob_ori_roi[4]+offset[2], center_oob_ori_roi[5]+offset[2],
            )
            # get corresponing padding params based on `vol_bound`
            padding_params = [0 for i in range(6)]
            for idx, (ori_pos, bound) in enumerate(zip(oob_ori_roi, vol_bound)):
                pad_val = 0
                if(idx%2==0 and ori_pos<bound): # left bound
                    pad_val = bound-ori_pos
                if(idx%2==1 and ori_pos>bound):
                    pad_val = ori_pos-bound
                padding_params[idx] = pad_val
            # get corresponding crop params after padding
            cropping_params = (
                oob_ori_roi[0]+padding_params[0], vol_bound[1]-oob_ori_roi[1]+padding_params[1],
                oob_ori_roi[2]+padding_params[2], vol_bound[3]-oob_ori_roi[3]+padding_params[3],
                oob_ori_roi[4]+padding_params[4], vol_bound[5]-oob_ori_roi[5]+padding_params[5],
            )
            # pad and crop for the original subject
            pad_and_crop = tio.Compose([
                tio.Pad(padding_params, padding_mode=crop_transform.padding_mode),
                tio.Crop(cropping_params),
            ])
            subject_roi = pad_and_crop(subject)  
            img3D_roi, = subject_roi.image.data.clone().detach().unsqueeze(0)
            norm_transform = tio.ZNormalization(masking_method=lambda x: x > 0)
            img3D_roi = norm_transform(img3D_roi) # (N, C, W, H, D)
            img3D_roi = img3D_roi.unsqueeze(dim=0)
            

            # collect all position information, and set correct roi for sliding-windows in 
            # todo: get correct roi window of half because of the sliding 
            windows_clip = [0 for i in range(6)]
            for i in range(3):
                if(offset[i]<0):
                    windows_clip[2*i] = 0
                    windows_clip[2*i+1] = -(roi_shape[i]+offset[i])
                elif(offset[i]>0):
                    windows_clip[2*i] = roi_shape[i]-offset[i]
                    windows_clip[2*i+1] = 0
            pos3D_roi = dict(
                padding_params=padding_params, cropping_params=cropping_params, 
                ori_roi=(
                    cropping_params[0]+windows_clip[0], cropping_params[0]+roi_shape[0]-padding_params[0]-padding_params[1]+windows_clip[1],
                    cropping_params[2]+windows_clip[2], cropping_params[2]+roi_shape[1]-padding_params[2]-padding_params[3]+windows_clip[3],
                    cropping_params[4]+windows_clip[4], cropping_params[4]+roi_shape[2]-padding_params[4]-padding_params[5]+windows_clip[5],
                ),
                pred_roi=(
                    padding_params[0]+windows_clip[0], roi_shape[0]-padding_params[1]+windows_clip[1],
                    padding_params[2]+windows_clip[2], roi_shape[1]-padding_params[3]+windows_clip[3],
                    padding_params[4]+windows_clip[4], roi_shape[2]-padding_params[5]+windows_clip[5],
                ))

            window_list.append((img3D_roi, pos3D_roi))
        return cropping_params, padding_params, window_list

    def preprocess_prompt(self, pts_prompt):
        coords = pts_prompt.value['coords']
        labels = pts_prompt.value['labels']

        point_offset = np.array([self.padding_params[0]-self.cropping_params[0], self.padding_params[2]-self.cropping_params[2], self.padding_params[4]-self.cropping_params[4]])
        coords = coords[...,1:] + point_offset
        
        batch_points = torch.from_numpy(coords).unsqueeze(0).to(self.device)
        batch_labels = torch.tensor(labels).unsqueeze(0).to(self.device)
        if self.use_only_first_point: # use only the first point since the model wasn't trained to receive multiple points in one go 
            batch_points = batch_points[:, :1]
            batch_labels = batch_labels[:, :1]
        
        pts_prompt = Points(value = {'coords': batch_points, 'labels': batch_labels})
        return pts_prompt
    
    def predict(self, img, prompt, cheat = False, gt = None):
        if not isinstance(prompt, SAMMed3DInferer.supported_prompts):
            raise ValueError(f'Unsupported prompt type: got {type(prompt)}')
    
        img, prompt = deepcopy(img), deepcopy(prompt)

        self.cropping_params, self.padding_params, patch_list = self.preprocess_into_patches(img, prompt, cheat, gt)

        prompt = self.preprocess_prompt(prompt)

        pred  = torch.zeros_like(img).numpy()
        for (image3D_patch, pos3D) in patch_list:
            image3D_patch = image3D_patch.to(self.device)
            logits = self.segmenter(image3D_patch, prompt)
            seg_mask = (logits>0.5).numpy().astype(np.uint8)
            ori_roi, pred_roi = pos3D["ori_roi"], pos3D["pred_roi"]
            
            seg_mask_roi = seg_mask[..., pred_roi[0]:pred_roi[1], pred_roi[2]:pred_roi[3], pred_roi[4]:pred_roi[5]]
            pred[..., ori_roi[0]:ori_roi[1], ori_roi[2]:ori_roi[3], ori_roi[4]:ori_roi[5]] = seg_mask_roi
        
        return(pred.astype(np.uint8))

In [120]:
# Load the SAM-Med3D "turbo" checkpoint and build the wrapper and inferer used below.
# NOTE(review): absolute local path — consider a configurable checkpoint directory.
checkpoint_path = '/home/t722s/Desktop/UniversalModels/TrainedModels/sam_med3d_turbo.pth'
device = 'cuda'
sam_model_tune = load_sammed3d(checkpoint_path, device = device)
wrapper = SAMMed3DWrapper(sam_model_tune, device)
# The inferer is not given `device`; it uses its own default ('cuda').
inferer = SAMMed3DInferer(wrapper)

In [121]:
# Spleen case: full pipeline through SAMMed3DInferer.
# 5 random foreground clicks; the crop is centred on the ground truth (cheat=True).
img_path = '/home/t722s/Desktop/Datasets/preprocessed/spleen/AbdomenAtlasJHU_2img/imagesTr/BDMAP_00000001.nii.gz'
gt_path = '/home/t722s/Desktop/Datasets/preprocessed/spleen/AbdomenAtlasJHU_2img/labelsTr/BDMAP_00000001.nii.gz'
img, gt = get_img_gt_sammed3d(img_path, gt_path)

seed = 2024
n = 5
prompt = get_pos_clicks3D(gt, n, seed = seed)

pred2 = inferer.predict(img, prompt, cheat = True, gt = gt)

# NOTE(review): the saved Dice reflects whatever `predict` was loaded at the time;
# re-run after changing the inferer.
compute_dice(pred2, gt.numpy())

0.9317160712240197

In [62]:
# Aorta case: the pipeline written out by hand with the standalone helpers.
# It calls `finetune_model_predict3D` directly instead of the wrapper/inferer.
img_path = '/home/t722s/Desktop/Datasets/preprocessed/aorta/AbdomenAtlasJHU_2img/imagesTr/BDMAP_00000002.nii.gz'
gt_path = '/home/t722s/Desktop/Datasets/preprocessed/aorta/AbdomenAtlasJHU_2img/labelsTr/BDMAP_00000002.nii.gz'
img, gt = get_img_gt_sammed3d(img_path, gt_path)

seed = 2024
n = 5
prompt = get_pos_clicks3D(gt, n, seed = seed)

offset_mode = 'center'
cropping_params, padding_params, patch_list = preprocess_into_patches(img, pts_prompt = prompt, cheat = True, gt = gt, offset_mode=offset_mode)
# NOTE(review): `point_offset` is unused here; preprocess_prompt computes the same offset itself.
point_offset = np.array([padding_params[0]-cropping_params[0], padding_params[2]-cropping_params[2], padding_params[4]-cropping_params[4]])

prompt = preprocess_prompt(prompt, cropping_params, padding_params)

pred  = torch.zeros_like(gt).numpy()
for (image3D_patch, pos3D) in patch_list:
    seg_mask = finetune_model_predict3D(
        image3D_patch, sam_model_tune, device=device, 
        prompt = prompt,
        prev_masks=None)
    
    
    # Paste the valid part of the patch prediction back into the full volume.
    ori_roi, pred_roi = pos3D["ori_roi"], pos3D["pred_roi"]
    
    seg_mask_roi = seg_mask[..., pred_roi[0]:pred_roi[1], pred_roi[2]:pred_roi[3], pred_roi[4]:pred_roi[5]]
    pred[..., ori_roi[0]:ori_roi[1], ori_roi[2]:ori_roi[3], ori_roi[4]:ori_roi[5]] = seg_mask_roi

compute_dice(pred, gt.numpy())

0.7685747088732163

In [99]:
# Display the current `prompt`.
# NOTE(review): the saved output shows raw 4-column voxel coordinates and a list of
# labels, i.e. a prompt that has not been through `preprocess_prompt` (which returns
# tensors). It was therefore captured from a different kernel state than the cell
# above leaves behind.
prompt

points prompt({'coords': array([[  0, 116,  89,  45],
       [  0, 123,  77,  78],
       [  0, 121,  78,  74],
       [  0, 114,  72,  84],
       [  0, 115,  77,  63]]), 'labels': [1, 1, 1, 1, 1]})

In [102]:
# Aorta case again, running each patch through `wrapper` (SAMMed3DWrapper) instead
# of `finetune_model_predict3D`, to compare the two code paths.
img_path = '/home/t722s/Desktop/Datasets/preprocessed/aorta/AbdomenAtlasJHU_2img/imagesTr/BDMAP_00000002.nii.gz'
gt_path = '/home/t722s/Desktop/Datasets/preprocessed/aorta/AbdomenAtlasJHU_2img/labelsTr/BDMAP_00000002.nii.gz'
img, gt = get_img_gt_sammed3d(img_path, gt_path)

seed = 2024
n = 5
prompt = get_pos_clicks3D(gt, n, seed = seed)

offset_mode = 'center'
cropping_params, padding_params, patch_list = preprocess_into_patches(img, pts_prompt = prompt, cheat = True, gt = gt, offset_mode=offset_mode)

prompt = preprocess_prompt(prompt, cropping_params, padding_params)

pred = torch.zeros_like(gt).numpy()
for (image3D_patch, pos3D) in patch_list:
    image3D_patch = image3D_patch.to(device)

    logits = wrapper(image3D_patch, prompt)
    # The wrapper returns pre-sigmoid logits. Threshold the probability at 0.5 so this
    # matches finetune_model_predict3D; thresholding the raw logits at 0.5 does not.
    seg_mask = (torch.sigmoid(logits) > 0.5).numpy().astype(np.uint8)
    ori_roi, pred_roi = pos3D["ori_roi"], pos3D["pred_roi"]

    # Paste the valid part of the patch prediction back into the full volume.
    seg_mask_roi = seg_mask[..., pred_roi[0]:pred_roi[1], pred_roi[2]:pred_roi[3], pred_roi[4]:pred_roi[5]]
    pred[..., ori_roi[0]:ori_roi[1], ori_roi[2]:ori_roi[3], ori_roi[4]:ori_roi[5]] = seg_mask_roi

compute_dice(pred, gt.numpy())

In [95]:

# Load two pickled debug dumps of model inputs for side-by-side comparison.
# NOTE(review): no cell in this notebook writes files with these names; they were
# presumably renamed by hand from earlier debug dumps.
# pickle.load can execute arbitrary code — only load files you produced yourself.
with open('/home/t722s/Desktop/test/inputs_inferer.pkl', 'rb') as f:
    inputs_inferer = pickle.load(f)
with open('/home/t722s/Desktop/test/inputs_no_inferer.pkl', 'rb') as f:
    inputs_no_inferer = pickle.load(f)

In [96]:
# Inputs dumped from the hand-written (no inferer) path, going by the filename.
inputs_no_inferer

(tensor([[[ 62,  68,  64],
          [ 69,  56,  97],
          [ 67,  57,  93],
          [ 60,  51, 103],
          [ 61,  56,  82]]], device='cuda:0'),
 tensor([[1, 1, 1, 1, 1]], device='cuda:0'),
 tensor([[[[[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]],
 
           [[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]],
 
           [[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 

In [97]:
# Inputs dumped from the inferer path, going by the filename. The point coordinates
# differ from the dump above.
inputs_inferer

(tensor([[[61, 72, 44],
          [68, 60, 77],
          [66, 61, 73],
          [59, 55, 83],
          [60, 60, 62]]], device='cuda:0'),
 tensor([[1, 1, 1, 1, 1]], device='cuda:0'),
 tensor([[[[[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]],
 
           [[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.]],
 
           [[0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
            ...,
            [0., 0., 0.,  ..., 0., 0., 0.],
            [0., 0., 0.,  ..., 0., 0., 0.],
          

In [None]:
# NOTE(review): an identical `args` Namespace is built again further down, and
# nothing in between reads it, so one copy can be removed.
args = Namespace(sliding_window = False,
                 num_clicks = 1,
                 point_method = 'default',
                 crop_size = 128,
                 seed = seed,
                 test_data_path = '/home/t722s/Desktop/Datasets/preprocessed/')
def run_inference(img_path, gt_path, sam_model_tune, n = 5, seed = 2024, device = 'cuda'):
    """Segment one image with SAM-Med3D from `n` random foreground clicks and score it.

    As in the cells above, the crop is centred on the ground truth (`cheat=True`).

    Args:
        img_path, gt_path: image and label files readable by SimpleITK.
        sam_model_tune: the SAM-Med3D model to run. Previously this argument was
            ignored and the global `wrapper` was used instead.
        n: number of positive clicks.
        seed: seed for click sampling.
        device: device for the model and patches.

    Returns:
        (pred, dice): the uint8 prediction shaped like the label, and
        `compute_dice(pred, gt)`.
    """
    img, gt = get_img_gt_sammed3d(img_path, gt_path)

    pts_prompt = get_pos_clicks3D(gt, n, seed = seed)

    cropping_params, padding_params, patch_list = preprocess_into_patches(
        img, pts_prompt = pts_prompt, cheat = True, gt = gt, offset_mode='center')

    pts_prompt = preprocess_prompt(pts_prompt, cropping_params, padding_params, device = device)

    segmenter = SAMMed3DWrapper(sam_model_tune, device)

    pred = torch.zeros_like(gt).numpy()
    for (image3D_patch, pos3D) in patch_list:
        logits = segmenter(image3D_patch.to(device), pts_prompt)
        # The wrapper returns pre-sigmoid logits; threshold the probability at 0.5.
        seg = (torch.sigmoid(logits) > 0.5).numpy().astype(np.uint8)

        ori_roi, pred_roi = pos3D["ori_roi"], pos3D["pred_roi"]

        # Paste the valid part of the patch prediction back into the full volume.
        seg_mask_roi = seg[..., pred_roi[0]:pred_roi[1], pred_roi[2]:pred_roi[3], pred_roi[4]:pred_roi[5]]
        pred[..., ori_roi[0]:ori_roi[1], ori_roi[2]:ori_roi[3], ori_roi[4]:ori_roi[5]] = seg_mask_roi

    res = compute_dice(pred, gt.numpy())
    return(pred, res)
from glob import glob
import os.path as osp
import os
from tqdm.notebook import tqdm
join = osp.join
args = Namespace(sliding_window = False,
                    num_clicks = 1,
                    point_method = 'default',
                    crop_size = 128,
                    seed = seed,
                    test_data_path = '/home/t722s/Desktop/Datasets/preprocessed/')

# Dataset directories are <test_data_path>/<organ>/<dataset>. Skip any path
# containing 'background', and sort so the run order is deterministic.
all_dataset_paths = sorted(
    p for p in glob(join(args.test_data_path, "*", "*"))
    if osp.isdir(p) and 'background' not in p
)

organs = [f.split('/')[-2] for f in all_dataset_paths]

res_dict = {organ: {} for organ in organs}
for ds in all_dataset_paths:
    organ = ds.split('/')[-2]
    imgs_dir = join(ds, 'imagesTr')
    labels_dir = join(ds, 'labelsTr')

    for img_name in tqdm(sorted(os.listdir(imgs_dir)), desc=osp.basename(ds)):
        img_path = join(imgs_dir, img_name)
        # Build the label path from the dataset dir, not by string-replacing the whole path.
        gt_path = join(labels_dir, img_name)
        tqdm.write(f'Inferring on {img_path}')
        _, dice = run_inference(img_path, gt_path, sam_model_tune, seed = args.seed)

        res_dict[organ][img_name] = dice

res_dict