In [3]:
import torch
import numpy as np
import SimpleITK as sitk
import pickle
import os
import cv2
import matplotlib.pyplot as plt
import SimpleITK as sitk
import json
import torchio as tio
import torch.nn.functional as F
import copy
from typing import TypeVar


In [4]:
img_path = '/home/t722s/Desktop/Datasets/BratsMini/imagesTs/BraTS2021_01646.nii.gz'
label_path = '/home/t722s/Desktop/Datasets/BratsMini/labelsTs/BraTS2021_01646.nii.gz'
prompts_path = '/home/t722s/Desktop/Sam-Med3DTest/BratsMini/prompts.pkl'
metadata_path = '/home/t722s/Desktop/Datasets/BratsMini/dataset.json'

with open(prompts_path, 'rb') as f:
    prompts_dict = pickle.load(f)

with open(metadata_path, 'r') as f:
    metadata = json.load(f)

img_name = os.path.basename(img_path)
    

In [5]:
class Prompt(): # Abstract class for prompts to be inputted with the SAM adjusted model
    def __init__(self, name, value):
        self.name = name
        self.value = value

    def __repr__(self):
        return f'{self.name} prompt({self.value}) '

    def _get_crop_pad_center(self): 
        """Calculate and return an approximate center of the region of interest (ROI)."""
        raise NotImplementedError('This class does not have an implemented method to get the crop/pad center')

    def _transform_xyz_to_zyx(self): 
        """Transform coordinates from xyz to zyx."""
        raise NotImplementedError('This class does not have an implemented method to change xyz to zyx')

    def _transform_with_crop_pad(self, cropping_params, padding_params):
        """Adjust the prompt according to cropping and padding parameters."""
        raise NotImplementedError('This class does not have an implemented method to appropriately change the prompt due to a cropping/padding')

class Points(Prompt):
    def __init__(self, value):
        super().__init__(name = 'points', value = value)

    def _transform_xyz_to_zyx(self): # In place method
        self.value['points'] = self.value['points'][:,::-1]

    def _get_crop_pad_center(self):
        bbox_min = self.value['points'].T.min(axis = 1) # Get an array of two points: the minimal and maximal vertices of the minimal cube parallel to the axes bounding the points
        bbox_max = self.value['points'].T.max(axis = 1) + 1 # Add 1 since we'll be using this for indexing # TESTING: Remove 'self's here
        point_center = np.mean((bbox_min, bbox_max), axis = 0)  

        return(point_center)

    def _transform_with_crop_pad(self, cropping_params, padding_params):
        axis_add, axis_sub = padding_params[::2], cropping_params[::2] 
        self.value['points'] = self.value['points'] + axis_add - axis_sub # same as value[:,i] = value[:,i] + axis_add[i] - axis_sub[i] iterating over i



In [11]:
def binarize(gt, organ, metadata): # TODO: Raise errors for metadata not having labels, labels in the wrong k,v order, and organ not being present
    organ_label = int(metadata['labels'][organ])
    gt_binary = torch.where(gt == organ_label, 1, torch.zeros_like(gt))
    return(gt_binary)

class SegmenterData():
    def __init__(self, img_path, prompt, cropped_shape = None, crop_pad_center = None, xyz = True, standardize = True,
                 prev_mask = None, img_embedding = None, ground_truth = None): # xyz: if prompts and images are in xyz: need to be converted to zyx
        # Store supplied variables
        self.img_path = img_path
        self.img_name = os.path.basename(self.img_path)
        self.img = sitk.GetArrayFromImage(sitk.ReadImage(img_path))
        self.prompt = prompt
        
        self.crop_pad_center = crop_pad_center
        self.cropped_shape = (cropped_shape, cropped_shape, cropped_shape) if isinstance(cropped_shape, int) else cropped_shape
        
        self.prev_mask = prev_mask
        self.img_embedding = img_embedding

        # Store originals without any procerssing
        self.prompt_original, self.img_original = copy.deepcopy(prompt), copy.deepcopy(self.img)

        # Points generated from generate_points.py are in xyz since the image is read in through sitk.GetArrayFromImage on sitk.ReadImage output, and here self.img is also in xyz. Transform both to zyx
        if xyz:
            self.img = self.img.transpose(2,1,0)
            prompt._transform_xyz_to_zyx()

        if crop_pad_center == 'from_prompt': 
            self.crop_pad_center = prompt._get_crop_pad_center()

        if self.crop_pad_center is not None:
            self.crop_params, self.pad_params = self._get_crop_pad_params(self.crop_pad_center, self.cropped_shape)
            # Crop/pad image
            self.img = self._crop_im(self.img, self.crop_params)
            self.img = self._pad_im(self.img, self.pad_params)

            # Crop/pad prompts
            self.prompt._transform_with_crop_pad(self.crop_params, self.pad_params)

        if standardize == True:
            mask = self.img > 0
            mean, std = self.img[mask].mean(), self.img[mask].std()
            # standardize_func = tio.ZNormalization(masking_method=lambda x: x > 0)
            # self.img2 = np.array(standardize_func(torch.from_numpy(self.img).unsqueeze(0))).squeeze(0) # Gives a different result; investigate
            self.img = (self.img-mean)/std 

    def _get_crop_pad_params(self, crop_pad_center, target_shape): # Modified from TorchIO cropOrPad
        subject_shape = self.img.shape
        padding = []
        cropping = []

        for dim in range(3):
            target_dim = target_shape[dim]
            center_dim = crop_pad_center[dim]
            subject_dim = subject_shape[dim]

            center_on_index = not (center_dim % 1)
            target_even = not (target_dim % 2)

            # Approximation when the center cannot be computed exactly
            # The output will be off by half a voxel, but this is just an
            # implementation detail
            if target_even ^ center_on_index:
                center_dim -= 0.5

            begin = center_dim - target_dim / 2
            if begin >= 0:
                crop_ini = begin
                pad_ini = 0
            else:
                crop_ini = 0
                pad_ini = -begin

            end = center_dim + target_dim / 2
            if end <= subject_dim:
                crop_fin = subject_dim - end
                pad_fin = 0
            else:
                crop_fin = 0
                pad_fin = end - subject_dim

            padding.extend([pad_ini, pad_fin])
            cropping.extend([crop_ini, crop_fin])

        padding_params = np.asarray(padding, dtype=int)
        cropping_params = np.asarray(cropping, dtype=int)

        return cropping_params, padding_params  # type: ignore[return-value]

    def _crop_im(self, img, cropping_params): # Modified from TorchIO cropOrPad
        low = cropping_params[::2]
        high = cropping_params[1::2]
        index_ini = low
        index_fin = np.array(img.shape) - high 
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        image_cropped = img[i0:i1, j0:j1, k0:k1]

        return(image_cropped)

    def _pad_im(self, img, padding_params): # Modified from TorchIO cropOrPad
        paddings = padding_params[:2], padding_params[2:4], padding_params[4:]
        image_padded = np.pad(img, paddings, mode = 'constant', constant_values = 0)  

        return(image_padded)
    
    def invert_crop_or_pad(self, mask, cropping_params, padding_params):
        if padding_params is not None:
            mask = self._crop_im(mask, padding_params)
        if cropping_params is not None:
            mask = self._pad_im(mask, cropping_params)
        return(mask)

    def clear_storage(self):
        self.prev_mask = None
        self.image_embedding = None

    def generate_next_prompt(self):
        '''Generate next prompt as a function of prev mask and ground truth'''

In [12]:
class Backbone():
    def __init__(self, model):
        self.model = model

    def preprocess_img(self, img):
        '''Any further necessary preprocessing steps'''
        pass

    def postprocess_mask(self, mask):
        '''Any further necessary postprocessing steps'''
        pass
 
    def predict(self, model, inputs, device = 'cuda', keep_img_embedding = True):
        '''Obtain logits '''
        pass

SAM3D = TypeVar('SAM3D')

class SAMMed3DSegmenter(Backbone):
    def __init__(self, model: SAM3D, required_shape, device = 'cuda'):
        self.model = model
        self.prev_mask = None
        self.required_shape = required_shape
        self.inputs = None
        self.device = 'cuda'

    def preprocess_img(self, img):
        '''Any further necessary preprocessing steps'''
        img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0) # add channel and batch dimensions
        return(img)
    
    def preprocess_points(self, prompt): # We only make use of point prompts
        pts = torch.from_numpy(prompt.value['points']).unsqueeze(0).to(self.device) # Must have shape B, N, 3
        labs = torch.tensor(prompt.value['labels']).unsqueeze(0).to(self.device)
        return(pts, labs)
    
    def postprocess_logits(self, mask, cropping_params, padding_params, xyz):
        mask = torch.sigmoid(mask)
        mask = (mask>0.5).numpy().astype(np.uint8)
        mask = self.inputs.invert_crop_or_pad(mask, cropping_params, padding_params)
        if xyz:
            mask = mask.transpose(2,1,0)
        return(mask)
 
    def predict(self, inputs, store_inference_data, device = 'cuda',):
        # Last preprocessing steps for prompts and image
        self.inputs = inputs
        img = self.preprocess_img(inputs.img)
        pts, labs = self.preprocess_points(inputs.prompt)

        # Obtain necessaries for inference and then get output mask
        ## Handle mask and embeddings
        low_res_spatial_shape = [dim//4 for dim in img.shape[-3:]] #batch and channel dimensions remain the same, spatial dimensions are quartered 
        
        if inputs.prev_mask is None:
            low_res_mask = torch.zeros([1,1] + low_res_spatial_shape).to(self.device) # Include batch and channel dimensions
        else:
            print('Using previous mask as an input!')
            low_res_mask = F.interpolate(inputs.prev_mask, size = low_res_spatial_shape, mode='trilinear', align_corners=False)

        with torch.no_grad():
            image_embedding = self.model.image_encoder(img.to(self.device)) # (1, 384, 16, 16, 16)

        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
                points = [pts, labs],
                boxes = None,
                masks = low_res_mask.to(self.device),
            )
        
        ## Decode
        mask_out, _ = self.model.mask_decoder(
            image_embeddings = image_embedding.to(self.device), # (B, 384, 64, 64, 64)
            image_pe = self.model.prompt_encoder.get_dense_pe(), # (1, 384, 64, 64, 64)
            sparse_prompt_embeddings = sparse_embeddings, # (B, 2, 384)
            dense_prompt_embeddings = dense_embeddings, # (B, 384, 64, 64, 64)
            multimask_output = False,
            )
        
        # If repeating inference on same image, store the mask and the embedding
        
        if store_inference_data == True:
            inputs.prev_mask = mask_out
            inputs.image_embedding = image_embedding
        
        logits = F.interpolate(mask_out, size=img.shape[-3:], mode = 'trilinear', align_corners = False).detach().cpu().squeeze()
        segmentation = self.postprocess_logits(logits, inputs.crop_params, inputs.pad_params, xyz = True)

        return(segmentation)
    
    def clear_storage(self):
        self.prev_mask = None
        self.image_embedding = None


In [13]:
from segment_anything.build_sam3D import build_sam3D_vit_b_ori

device = 'cuda'
checkpoint_path = '/home/t722s/Desktop/UniversalModels/TrainedModels/sam_med3d.pth'
model_type = 'vit_b_ori'
# Load in model

sam_model_tune = build_sam3D_vit_b_ori(checkpoint=None).to(device)
if checkpoint_path is not None:
    model_dict = torch.load(checkpoint_path, map_location=device)
    state_dict = model_dict['model_state_dict']
    sam_model_tune.load_state_dict(state_dict)


In [14]:
def compute_dice(mask_gt, mask_pred):
    """Compute soerensen-dice coefficient.
    Returns:
    the dice coeffcient as float. If both masks are empty, the result is NaN
    """
    volume_sum = mask_gt.sum() + mask_pred.sum()
    if volume_sum == 0:
        return np.NaN
    volume_intersect = (mask_gt & mask_pred).sum()
    return 2*volume_intersect / volume_sum

gt = sitk.GetArrayFromImage(sitk.ReadImage(label_path))



In [15]:
pts = prompts_dict[img_name][1]['3D']
pts_prompt = Points({'points': pts[0], 'labels': pts[1]})
inputs = SegmenterData(img_path, pts_prompt, 128, 'from_prompt')
segmenter = SAMMed3DSegmenter(sam_model_tune, (128,128,128))
segmentation = segmenter.predict(inputs, store_inference_data=True)
compute_dice(segmentation, gt)

[ 1 31  0 34  2  0] [ 0  0  2  0  0 26] (160, 160, 104)


0.7024892502998216

In [10]:
segmentation2 = segmenter.predict(inputs, store_inference_data=True)
compute_dice(segmentation2, gt)

Using previous mask as an input!


0.6565609307258039