In [130]:
import torch
import numpy as np
import SimpleITK as sitk
import pickle
import os
import cv2
import matplotlib.pyplot as plt
import SimpleITK as sitk
import json
import torchio as tio
import torch.nn.functional as F

In [133]:
# Inputs for the single case inspected in this notebook: image, label map,
# precomputed prompts, and the dataset metadata (dataset.json).
img_path = '/home/t722s/Desktop/Datasets/BratsMini/imagesTs/BraTS2021_01646.nii.gz'
label_path = '/home/t722s/Desktop/Datasets/BratsMini/labelsTs/BraTS2021_01646.nii.gz'
prompts_path = '/home/t722s/Desktop/Sam-Med3DTest/BratsMini/prompts.pkl'
metadata_path = '/home/t722s/Desktop/Datasets/BratsMini/dataset.json'

# NOTE: pickle.load can execute arbitrary code -- only load prompt files you produced yourself.
with open(prompts_path, 'rb') as prompts_file:
    prompts_dict = pickle.load(prompts_file)

with open(metadata_path, 'r') as metadata_file:
    metadata = json.load(metadata_file)

# File name of the case; used below as the lookup key into prompts_dict.
img_name = os.path.basename(img_path)
    

In [134]:
# Raw image volume and ground-truth label map of the selected case, as numpy arrays.
img, gt = (sitk.GetArrayFromImage(sitk.ReadImage(path)) for path in (img_path, label_path))

In [135]:
## comparison
def binarize(gt, organ, metadata):
    """Return a binary mask of ``gt``: 1 where the voxel belongs to ``organ``, 0 elsewhere.

    The result has the same container type (``torch.Tensor`` or ``np.ndarray``) and dtype as ``gt``.

    Args:
        gt: Label map, as a ``torch.Tensor`` or ``np.ndarray``.
        organ: Label name, looked up as a key of ``metadata['labels']``.
        metadata: dataset.json-style dict whose ``'labels'`` entry maps label names to a
            label value (int or numeric string), or to a list/tuple of label values whose
            union forms the region.

    Raises:
        KeyError: ``metadata`` has no ``'labels'`` entry, or ``organ`` is not one of its keys
            (with a specific message when the mapping looks inverted, i.e. value -> name).
        ValueError: ``organ`` maps to an empty list of label values.
    """
    if 'labels' not in metadata:
        raise KeyError("metadata has no 'labels' entry")
    labels = metadata['labels']

    if organ not in labels:
        # e.g. {"1": "tumor"} instead of the expected {"tumor": 1}
        if organ in labels.values():
            raise KeyError(f"'{organ}' is a value, not a key, of metadata['labels']; "
                           "expected a name -> label value mapping but got value -> name")
        raise KeyError(f"'{organ}' not found in metadata['labels'] (available: {list(labels)})")

    label_value = labels[organ]
    # A list/tuple denotes a region composed of several label values (their union).
    if isinstance(label_value, (list, tuple)):
        values = [int(v) for v in label_value]
    else:
        values = [int(label_value)]
    if not values:
        raise ValueError(f"metadata['labels'][{organ!r}] is an empty list of label values")

    mask = gt == values[0]
    for value in values[1:]:
        mask = mask | (gt == value)

    # sitk.GetArrayFromImage yields numpy arrays, so support them alongside torch tensors.
    if isinstance(gt, np.ndarray):
        return mask.astype(gt.dtype)
    gt_binary = torch.where(mask, 1, torch.zeros_like(gt))
    return gt_binary


class SegmenterInputs():
    """Prepare an image and its prompts as input for a segmenter.

    The image is read with SimpleITK. Its axis order can be reversed together with the prompt
    points. It can then be cropped/padded to ``cropped_shape`` around a center, and
    z-normalized over its voxels > 0.

    Args:
        img_path: Path to an image readable by ``sitk.ReadImage``.
        dim: Not used by this class; kept for interface compatibility.
        prompt: Dict of prompt type -> prompt data. A ``'points'`` entry is expected to be a
            ``(coords, labels)`` pair whose ``coords`` supports ``coords[:, ::-1]`` and
            ``coords.min(axis=0)`` (presumably an ``(N, 3)`` numpy array -- TODO confirm).
            The dict is shallow-copied, so the caller's dict is not modified.
        cropped_shape: Target shape after crop/pad; an int ``s`` means ``(s, s, s)``.
            Required when ``crop_pad_center`` is set, ignored otherwise.
        crop_pad_center: ``None`` (no crop/pad), a per-axis center, or ``'from_points'`` to use
            the center of the bounding box of the prompt points.
        xyz: If True, reverse the axis order of the image array and of the point coordinates.
        standardize: If True, z-normalize the image with the mean/std of its voxels > 0.

    Besides the arguments, sets ``img`` (the processed array), ``img_name`` and
    ``prompt_types``. When cropping, it also sets ``crop_pad_params`` and, if points were
    given, ``points_transformed`` (point coordinates in the cropped/padded frame).

    Raises:
        RuntimeError: ``'from_points'`` requested without points; or, when standardizing,
            no voxels > 0 or a zero standard deviation among them.
        ValueError: unknown string for ``crop_pad_center``, or ``crop_pad_center`` set
            without ``cropped_shape``.
    """

    def __init__(self, img_path, dim, prompt, cropped_shape = None, crop_pad_center = None, xyz = True, standardize = True,):
        self.img_path = img_path
        self.img_name = os.path.basename(self.img_path)
        # Shallow copy: the 'points' entry is reassigned below and must not leak into the caller's dict.
        self.prompt = dict(prompt)
        self.prompt_types = self.prompt.keys()
        self.crop_pad_center = crop_pad_center

        # Only a string can request the points-derived center. Comparing a numpy-array center
        # with '==' would give an array whose truth value is ambiguous.
        center_from_points = isinstance(crop_pad_center, str) and crop_pad_center == 'from_points'

        # Validate arguments before the (potentially slow) image read.
        if center_from_points and 'points' not in self.prompt_types:
            raise RuntimeError('Inferring the crop/pad center from points was requested, but no points were supplied')
        if isinstance(crop_pad_center, str) and not center_from_points:
            raise ValueError(f"Unknown crop_pad_center {crop_pad_center!r}; expected None, a per-axis center or 'from_points'")
        if crop_pad_center is not None and cropped_shape is None:
            raise ValueError('cropped_shape must be given when crop_pad_center is set')

        self.img = sitk.GetArrayFromImage(sitk.ReadImage(img_path))

        # Reverse the axis order of the image array and, in lockstep, the columns of the point
        # coordinates, so the two stay aligned. Per the original note, the points come from
        # generate_points.py, which also reads images through sitk.GetArrayFromImage.
        # NOTE(review): sitk.GetArrayFromImage already returns [z, y, x]-indexed arrays, so
        # this transpose yields [x, y, z] indexing -- confirm against what the model expects.
        if xyz:
            self.img = self.img.transpose(2, 1, 0)
            if 'points' in self.prompt_types:
                points = self.prompt['points']
                self.prompt['points'] = (points[0][:, ::-1], points[1])

        if isinstance(cropped_shape, int):
            cropped_shape = (cropped_shape, cropped_shape, cropped_shape)

        if center_from_points:
            coords = self.prompt['points'][0]
            # Axis-aligned bounding box of the points. The +1 makes the upper bound exclusive
            # (slicing convention), so the center is that of the voxels the points occupy.
            bbox_min = coords.min(axis=0)
            bbox_max = coords.max(axis=0) + 1
            self.crop_pad_center = np.mean((bbox_min, bbox_max), axis=0)

        if self.crop_pad_center is not None:
            self.crop_pad_params = self._get_crop_pad_params(self.img, self.crop_pad_center, cropped_shape)
            cropping_params, padding_params = self.crop_pad_params
            self.img = self._pad(self._crop(self.img, cropping_params), padding_params)

            # Express the points in the cropped/padded frame.
            if 'points' in self.prompt_types:
                self.points_transformed = self._transform_points(self.prompt['points'][0], cropping_params, padding_params)

        if standardize:
            foreground = self.img > 0
            if not foreground.any():
                raise RuntimeError(f'Cannot standardize {self.img_name}: no voxels > 0')
            values = self.img[foreground]
            # numpy's std() is the population std (ddof=0). tio.ZNormalization uses torch's
            # unbiased std, which may explain previously observed differences (not verified).
            mean, std = values.mean(), values.std()
            if std == 0:
                raise RuntimeError(f'Cannot standardize {self.img_name}: zero standard deviation over voxels > 0')
            self.img = (self.img - mean) / std

    def _crop(self, image, cropping_params): # Modified from TorchIO cropOrPad
        """Crop a 3D ``image`` by per-axis amounts.

        ``cropping_params`` is ``[low0, high0, low1, high1, low2, high2]``: the number of voxels
        removed from the start (low) and the end (high) of each axis.
        """
        low = cropping_params[::2]
        high = cropping_params[1::2]
        i0, j0, k0 = low
        i1, j1, k1 = np.array(image.shape) - high
        return image[i0:i1, j0:j1, k0:k1]

    def _pad(self, image, padding_params): # Modified from TorchIO cropOrPad
        """Zero-pad a 3D ``image``.

        ``padding_params`` is ``[low0, high0, low1, high1, low2, high2]``: the number of zeros
        added before (low) and after (high) each axis.
        """
        paddings = padding_params[:2], padding_params[2:4], padding_params[4:]
        return np.pad(image, paddings, mode = 'constant', constant_values = 0)

    def _get_crop_pad_params(self, img, crop_pad_center, target_shape): # Modified from TorchIO cropOrPad
        """Compute the crop and pad amounts that center a ``target_shape`` window on ``crop_pad_center``.

        Returns ``(cropping_params, padding_params)``, each an int array laid out as
        ``[low0, high0, low1, high1, low2, high2]`` for ``_crop`` / ``_pad``. Along each axis,
        the part of the window inside the image is kept (crop), and the part outside is padded.
        """
        subject_shape = img.shape
        padding = []
        cropping = []

        for dim in range(3):
            target_dim = target_shape[dim]
            center_dim = crop_pad_center[dim]
            subject_dim = subject_shape[dim]

            center_on_index = not (center_dim % 1)
            target_even = not (target_dim % 2)

            # Approximation when the center cannot be computed exactly: the output is off by
            # half a voxel (implementation detail inherited from TorchIO).
            if target_even ^ center_on_index:
                center_dim -= 0.5

            begin = center_dim - target_dim / 2
            if begin >= 0:
                crop_ini, pad_ini = begin, 0
            else:
                crop_ini, pad_ini = 0, -begin

            end = center_dim + target_dim / 2
            if end <= subject_dim:
                crop_fin, pad_fin = subject_dim - end, 0
            else:
                crop_fin, pad_fin = 0, end - subject_dim

            padding.extend([pad_ini, pad_fin])
            cropping.extend([crop_ini, crop_fin])

        # The int conversion truncates any non-integral amount.
        padding_params = np.asarray(padding, dtype=int)
        cropping_params = np.asarray(cropping, dtype=int)

        return cropping_params, padding_params  # type: ignore[return-value]

    def _transform_points(self, pts, cropping_params, padding_params):
        """Map point coordinates into the cropped/padded frame.

        Adds the low-side padding and subtracts the low-side cropping on each axis. ``pts``
        is broadcast against the 3-vectors, so it is presumably ``(N, 3)``.
        """
        axis_add, axis_sub = padding_params[::2], cropping_params[::2]
        return pts + axis_add - axis_sub
    

class Points():
    """Pairs point coordinates with their foreground/background labels."""

    def __init__(self, points_list, points_labels):
        # Stored as given: no copying or validation.
        self.pts, self.labs = points_list, points_labels

    def __str__(self):
        # One line per point: "<coordinate>: fg" for truthy labels, "<coordinate>: bg" otherwise.
        lines = []
        for coord, label in zip(self.pts, self.labs):
            tag = 'fg' if label else 'bg'
            lines.append(f'{coord}: {tag}')
        return '\n'.join(lines)
    
class Box():
    """Axis-aligned box given by its minimal and maximal corner coordinates."""

    def __init__(self, min_coord, max_coord):
        self.min_coord = min_coord
        self.max_coord = max_coord

    def __str__(self):
        # Two lines: the min corner, then the max corner. The f-string is returned directly;
        # passing a single string to '\n'.join would put a newline between every character.
        return f'min coord: {self.min_coord}\nmax coord: {self.max_coord}'


In [136]:
# 3D point prompts stored for this case; the [1] index presumably selects the entry for
# label 1 -- TODO confirm against the prompt generator.
pts = prompts_dict[img_name][1]['3D']
t = SegmenterInputs(img_path, dim=3, prompt={'points': pts}, cropped_shape=128, crop_pad_center='from_points')

