In [1]:
import torch
import numpy as np
import SimpleITK as sitk
import pickle
import os
import cv2
import matplotlib.pyplot as plt
import SimpleITK as sitk
import json
import torchio as tio
import torch.nn.functional as F
import copy
from typing import TypeVar


In [2]:
def binarize(gt, organ, metadata): # TODO: Raise errors for metadata not having labels, labels in the wrong k,v order, and organ not being present
    organ_label = int(metadata['labels'][organ])
    gt_binary = torch.where(gt == organ_label, 1, torch.zeros_like(gt))
    return(gt_binary)

def get_crop_pad_center_from_points(points):
    bbox_min = points.value['coords'].T.min(axis = 1) # Get an array of two points: the minimal and maximal vertices of the minimal cube parallel to the axes bounding the points
    bbox_max = points.value['coords'].T.max(axis = 1) + 1 # Add 1 since we'll be using this for indexing # TESTING: Remove 'self's here
    point_center = np.mean((bbox_min, bbox_max), axis = 0)  

    return(point_center)

def get_crop_pad_params(img, crop_pad_center, target_shape): # Modified from TorchIO cropOrPad
    subject_shape = img.shape
    padding = []
    cropping = []

    for dim in range(3):
        target_dim = target_shape[dim]
        center_dim = crop_pad_center[dim]
        subject_dim = subject_shape[dim]

        center_on_index = not (center_dim % 1)
        target_even = not (target_dim % 2)

        # Approximation when the center cannot be computed exactly
        # The output will be off by half a voxel, but this is just an
        # implementation detail
        if target_even ^ center_on_index:
            center_dim -= 0.5

        begin = center_dim - target_dim / 2
        if begin >= 0:
            crop_ini = begin
            pad_ini = 0
        else:
            crop_ini = 0
            pad_ini = -begin

        end = center_dim + target_dim / 2
        if end <= subject_dim:
            crop_fin = subject_dim - end
            pad_fin = 0
        else:
            crop_fin = 0
            pad_fin = end - subject_dim

        padding.extend([pad_ini, pad_fin])
        cropping.extend([crop_ini, crop_fin])

    padding_params = np.asarray(padding, dtype=int)
    cropping_params = np.asarray(cropping, dtype=int)

    return cropping_params, padding_params  # type: ignore[return-value]

def crop_im(img, cropping_params): # Modified from TorchIO cropOrPad
        low = cropping_params[::2]
        high = cropping_params[1::2]
        index_ini = low
        index_fin = np.array(img.shape) - high 
        i0, j0, k0 = index_ini
        i1, j1, k1 = index_fin
        image_cropped = img[i0:i1, j0:j1, k0:k1]

        return(image_cropped)

def pad_im(img, padding_params): # Modified from TorchIO cropOrPad
    paddings = padding_params[:2], padding_params[2:4], padding_params[4:]
    image_padded = np.pad(img, paddings, mode = 'constant', constant_values = 0)  

    return(image_padded)

def crop_pad_coords(coords, cropping_params, padding_params):
    axis_add, axis_sub = padding_params[::2], cropping_params[::2] 
    coords = coords + axis_add - axis_sub # same as value[:,i] = value[:,i] + axis_add[i] - axis_sub[i] iterating over i
    return(coords)

def invert_crop_or_pad(mask, cropping_params, padding_params):
    if padding_params is not None:
        mask = crop_im(mask, padding_params)
    if cropping_params is not None:
        mask = pad_im(mask, cropping_params)
    return(mask)

def compute_dice(mask_gt, mask_pred):
    """Compute soerensen-dice coefficient.
    Returns:
    the dice coeffcient as float. If both masks are empty, the result is NaN
    """
    volume_sum = mask_gt.sum() + mask_pred.sum()
    if volume_sum == 0:
        return np.NaN
    volume_intersect = (mask_gt & mask_pred).sum()
    return 2*volume_intersect / volume_sum

In [3]:
class Prompt(): # Abstract class for prompts to be inputted with the SAM adjusted model
    def __init__(self, name, value):
        self.name = name
        self.value = value

    def __repr__(self):
        return f'{self.name} prompt({self.value})'

class Points(Prompt):
    def __init__(self, value):
        super().__init__(name = 'points', value = value)
        
class Backbone():
    def __init__(self, model):
        self.model = model

    def preprocess_img(self, img):
        '''Any necessary preprocessing steps'''
        pass

    def postprocess_mask(self, mask):
        '''Any necessary postprocessing steps'''
        pass
 
    def predict(self, model, inputs, device = 'cuda', keep_img_embedding = True):
        '''Obtain logits '''
        pass

SAM3D = TypeVar('SAM3D')

In [4]:
from segment_anything.build_sam3D import build_sam3D_vit_b_ori

device = 'cuda'
checkpoint_path = '/home/t722s/Desktop/UniversalModels/TrainedModels/sam_med3d.pth'
model_type = 'vit_b_ori'
# Load in model

sam_model_tune = build_sam3D_vit_b_ori(checkpoint=None).to(device)
if checkpoint_path is not None:
    model_dict = torch.load(checkpoint_path, map_location=device)
    state_dict = model_dict['model_state_dict']
    sam_model_tune.load_state_dict(state_dict)


In [56]:
class SAMMed3DSegmenter(Backbone):
    def __init__(self, model: SAM3D, required_shape, device = 'cuda'):
        self.model = model
        self.prev_mask = None
        self.required_shape = required_shape
        self.inputs = None
        self.device = device

    def preprocess_img(self, img, crop_params, pad_params):
        '''Any necessary preprocessing steps'''

        img = crop_im(img, crop_params) 
        img = pad_im(img, pad_params)

        mask = img > 0
        mean, std = img[mask].mean(), img[mask].std()
        # standardize_func = tio.ZNormalization(masking_method=lambda x: x > 0)
        # img2 = np.array(standardize_func(torch.from_numpy(img).unsqueeze(0))).squeeze(0) # Gives a different result; investigate
        img = (img-mean)/std 

        img = torch.from_numpy(img).unsqueeze(0).unsqueeze(0).to(device) # add channel and batch dimensions
        return(img)
    
    def preprocess_points(self, points, crop_params, pad_params): # We only make use of point prompts
        coords, labs = points.value['coords'], points.value['labels']

        coords = crop_pad_coords(coords, crop_params, pad_params)

        coords = torch.from_numpy(coords).unsqueeze(0).to(self.device) # Must have shape B, N, 3
        labs = torch.tensor(labs).unsqueeze(0).to(self.device)
        return(coords, labs)
    
    def postprocess_logits(self, mask, cropping_params, padding_params):
        mask = torch.sigmoid(mask)
        mask = (mask>0.5).numpy().astype(np.uint8)
        mask = invert_crop_or_pad(mask, cropping_params, padding_params)

        mask = mask.transpose(2,1,0)
        return(mask)
 
    def predict(self, img, prompt, store_inference_data, use_prev_mask, use_prev_embedding, device = 'cuda',):
        # Last preprocessing steps for prompts and image
        if not isinstance(prompt, Points):
            raise RuntimeError('Currently only points are supported')

        self.crop_pad_center = get_crop_pad_center_from_points(prompt)
        self.crop_params, self.pad_params = get_crop_pad_params(img, self.crop_pad_center, self.required_shape)
        img = self.preprocess_img(img, self.crop_params, self.pad_params)
        pts, labs = self.preprocess_points(prompt, self.crop_params, self.pad_params)

        # Obtain necessaries for inference and then get output mask
        ## Handle mask and embeddings
        low_res_spatial_shape = [dim//4 for dim in img.shape[-3:]] #batch and channel dimensions remain the same, spatial dimensions are quartered 
        
        if use_prev_mask:
            if self.prev_mask is not None:
                print('Using previous mask as an input!')
                low_res_mask = F.interpolate(self.prev_mask, size = low_res_spatial_shape, mode='trilinear', align_corners=False)
            else:
                raise RuntimeError('Tried to use previous mask, but no mask is stored.')
        else:
            low_res_mask = torch.zeros([1,1] + low_res_spatial_shape).to(self.device) # [1,1] is batch and channel dimensions

        if use_prev_embedding:
            if self.image_embedding is not None:
                image_embedding = self.image_embedding
            else:
                raise RuntimeError('Tried to use previous embedding, but no embedding is stored.')
        else:
            with torch.no_grad():
                image_embedding = self.model.image_encoder(img) # (1, 384, 16, 16, 16)

        sparse_embeddings, dense_embeddings = self.model.prompt_encoder(
                points = [pts, labs],
                boxes = None,
                masks = low_res_mask.to(self.device),
            )
        
        ## Decode
        mask_out, _ = self.model.mask_decoder(
            image_embeddings = image_embedding.to(self.device), # (B, 384, 64, 64, 64)
            image_pe = self.model.prompt_encoder.get_dense_pe(), # (1, 384, 64, 64, 64)
            sparse_prompt_embeddings = sparse_embeddings, # (B, 2, 384)
            dense_prompt_embeddings = dense_embeddings, # (B, 384, 64, 64, 64)
            multimask_output = False,
            )
        
        # If repeating inference on same image, store the mask and the embedding
        
        if store_inference_data == True:
            self.prev_mask = mask_out
            self.image_embedding = image_embedding
        
        logits = F.interpolate(mask_out, size=img.shape[-3:], mode = 'trilinear', align_corners = False).detach().cpu().squeeze()
        segmentation = self.postprocess_logits(logits, self.crop_params, self.pad_params)

        return(segmentation)
    
    def clear_storage(self):
        self.prev_mask = None
        self.image_embedding = None


In [57]:
img_path = '/home/t722s/Desktop/Datasets/BratsMini/imagesTs/BraTS2021_01646.nii.gz'
label_path = '/home/t722s/Desktop/Datasets/BratsMini/labelsTs/BraTS2021_01646.nii.gz'
prompts_path = '/home/t722s/Desktop/Sam-Med3DTest/BratsMini/prompts.pkl'
metadata_path = '/home/t722s/Desktop/Datasets/BratsMini/dataset.json'

with open(prompts_path, 'rb') as f:
    prompts_dict = pickle.load(f)

with open(metadata_path, 'r') as f:
    metadata = json.load(f)

img_name = os.path.basename(img_path)
img = sitk.GetArrayFromImage(sitk.ReadImage(img_path)).transpose(2,1,0)
gt = sitk.GetArrayFromImage(sitk.ReadImage(label_path))


pts = prompts_dict[img_name][1]['3D']
pts_prompt = Points({'coords': pts[0][:,::-1], 'labels': pts[1]})
segmenter = SAMMed3DSegmenter(sam_model_tune, (128,128,128))
# segmentation = segmenter.predict(img, pts_prompt, store_inference_data=False, use_prev_mask=False, use_prev_embedding=False)
# compute_dice(segmentation, gt)

In [68]:
prompts = []
for i in range(5):
    prompts.append(Points({'coords': pts[0][i:i+1,::-1], 'labels': pts[1][i:i+1]}))

segmenter.clear_storage()
segmentations = []
segmentation = segmenter.predict(img, prompts[0], store_inference_data=False, use_prev_mask=False, use_prev_embedding=False)
print(compute_dice(segmentation, gt))
segmentations.append(segmentation)


storing
0.7855765042623359


In [None]:

for i in range(0,5):
    segmentation = segmenter.predict(img, prompts[:i], store_inference_data=False, use_prev_mask=False, use_prev_embedding=False)
    print(compute_dice(segmentation, gt))
    segmentations.append(segmentation)

In [59]:
with open('/home/t722s/Desktop/test/embed1.pkl', 'rb') as f:
    t1 = pickle.load(f)

with open('/home/t722s/Desktop/test/embed2.pkl', 'rb') as f:
    t2 = pickle.load(f)

In [67]:
segmenter.pad_params

array([ 4,  0,  7,  0,  0, 24])

In [None]:
crop_pad_center
pad_params, crop_params = get_crop_pad_params(img, crop_pad_center, (128,128,128))

img2 = crop_im(img, crop_params) 
img2 = pad_im(img2, pad_params)

mask = img2 > 0
mean, std = img2[mask].mean(), img2[mask].std()
# standardize_func = tio.ZNormalization(masking_method=lambda x: x > 0)
# img2 = np.array(standardize_func(torch.from_numpy(img).unsqueeze(0))).squeeze(0) # Gives a different result; investigate
img2 = (img2-mean)/std 

img2 = torch.from_numpy(img2).unsqueeze(0).unsqueeze(0).to(device) # add channel and batch dimensions