In [1]:
# running as docker?
docker_running = False

# define repo path and add it to the path
from pathlib import Path
import sys, os
import warnings


def find_repo_root(start: Path, marker: str = '.gitignore') -> "Path | None":
    """Return the closest directory at or above ``start`` that contains ``marker``.

    ``start`` itself is checked first, then each of its parents up to the filesystem
    root. Returns ``None`` when no directory matches. A plain ``while`` walking
    ``.parent`` would never terminate in that case, because ``Path('/').parent``
    is ``Path('/')``.

    Args:
        start (Path): directory to start searching from.
        marker (str): file/directory name that identifies the repo root.

    Returns:
        Path | None: the first matching directory, or None if there is none.
    """
    for candidate in (start, *start.parents):
        if (candidate / marker).exists():
            return candidate
    return None


if not docker_running:  # running locally: the repo root is the first ancestor holding .gitignore
    repo_path = find_repo_root(Path.cwd().resolve())
    if repo_path is None:
        warnings.warn("No '.gitignore' found above the current directory; using cwd as repo_path")
        repo_path = Path.cwd().resolve()
else:  # running in the container
    # Absolute on purpose: the former relative 'opt/usuari' only resolved correctly when cwd was '/'.
    repo_path = Path('/opt/usuari')

# make the repo's packages (segmentation, datasets_utils, ...) importable
if str(repo_path) not in sys.path:
    sys.path.insert(0, str(repo_path))

import numpy as np
import SimpleITK as sitk
import pandas as pd
from torchvision.transforms import (
    Compose,
    Resize,
    InterpolationMode,
)
from PIL import Image
from tqdm import tqdm
import shutil
import SimpleITK as sitk
import torch

# special imports
from segmentation import USSegmentation
from datasets_utils.datasets import ABUS_test
from torch.utils.data import DataLoader
import torchvision

The first thing to understand is that the slices of each image take up 64 MB of storage.<br>
They are therefore stored, along with other cached information, in the `cached_data` folder.<br>

In [2]:
class lesion_seg:
    """Notebook driver for lesion segmentation of nrrd volumes.

    Resolves the input/output/cache folders, loads ``USSegmentation`` from the
    checkpoint under ``checkpoints/``, caches per-slice images in
    ``cached_data/slices`` and saves the probability map to
    ``cached_data/probs/prob_map.npy``.
    """

    # Slices are zero-padded (at the end of the axis) up to at least this size before
    # resizing. In numpy (z, y, x) order these apply to axes 2 (x) and 1 (y).
    X_EXPANSION = 865
    Y_EXPANSION = 865
    # In-plane size of the slices written to disk.
    X_RESIZING = 512
    Y_RESIZING = 512
    SLICE_FORMAT = 'mha'

    def __init__(self):
        self.input_dir = Path('./input/') if docker_running else repo_path / 'input'
        self.output_dir = Path('./predict') / 'Segmentation' if docker_running else repo_path / 'predict' / 'Segmentation'
        self.output_dir.mkdir(parents=True, exist_ok=True)  # make sure the output dir exists
        self.checkpoint_dir = repo_path / 'checkpoints' / 'sam_vit_b_01ec64.pth'
        self.cached_dir = repo_path / 'cached_data'
        self.cached_dir.mkdir(parents=True, exist_ok=True)  # create cached dir in root
        self.slices_dir = self.cached_dir / 'slices'
        self.probs_dir = self.cached_dir / 'probs'
        # load all folds models
        self.md = USSegmentation(self.checkpoint_dir)
        load_success = self.md.load_model()
        if load_success:
            print("Successfully loaded models")
        else:
            # NOTE(review): semantics of a falsy load_model() result are defined in USSegmentation — confirm
            print("Warning: USSegmentation.load_model() did not report success")

    @staticmethod
    def _pad_volume(volume: np.ndarray, min_y: int, min_x: int) -> np.ndarray:
        """Zero-pad a (z, y, x) array at the end of axes 1 and 2 to at least (min_y, min_x).

        Axes that are already large enough are left untouched. ``np.pad`` keeps the
        input dtype; concatenating ``int8`` zeros (the previous approach) promoted a
        uint8 volume to int16, which changed the PIL mode and dtype of the saved slices
        only for volumes that needed padding.

        Args:
            volume (np.ndarray): 3D array indexed as (z, y, x).
            min_y (int): minimum size of axis 1.
            min_x (int): minimum size of axis 2.

        Returns:
            np.ndarray: padded array with the same dtype as ``volume``.
        """
        pad_y = max(min_y - volume.shape[1], 0)
        pad_x = max(min_x - volume.shape[2], 0)
        if pad_y == 0 and pad_x == 0:
            return volume
        return np.pad(volume, ((0, 0), (0, pad_y), (0, pad_x)), mode='constant', constant_values=0)

    def save_slices(self, image_path: Path):
        """Save every axial slice of an nrrd volume to ``self.slices_dir``.

        The folder is wiped first, so it only ever holds the slices of the latest
        volume. Each slice is padded up to (Y_EXPANSION, X_EXPANSION), resized
        bilinearly to (Y_RESIZING, X_RESIZING), repeated into 3 channels
        (channel-first) and written as ``slice_{z}.{SLICE_FORMAT}``.

        Args:
            image_path (Path): Path to the nrrd image.

        Returns:
            tuple: ``GetSize()`` of the original SimpleITK image (before padding).
        """
        # remove folder if exists, always starts from scratch
        if self.slices_dir.exists():
            shutil.rmtree(self.slices_dir)
        self.slices_dir.mkdir(exist_ok=True, parents=True)

        # torchvision's Resize takes (height, width), i.e. (y, x) for a 2D slice
        preprocess_im = Compose(
            [Resize((self.Y_RESIZING, self.X_RESIZING), interpolation=InterpolationMode.BILINEAR)]
        )

        im_sitk = sitk.ReadImage(str(image_path))
        shape = im_sitk.GetSize()
        im = self._pad_volume(sitk.GetArrayFromImage(im_sitk), self.Y_EXPANSION, self.X_EXPANSION)

        for z in tqdm(range(im.shape[0])):
            im_slice = np.asarray(preprocess_im(Image.fromarray(im[z])))
            # put channel first and repeat in RGB
            im_slice = np.repeat(im_slice[np.newaxis], 3, axis=0)
            save_path = self.slices_dir / f'slice_{z}.{self.SLICE_FORMAT}'
            sitk.WriteImage(sitk.GetImageFromArray(im_slice), str(save_path))

        return shape

    def prob_map(self, image_path: Path):
        """Create and cache the probability map for one volume.

        Slices the volume with ``save_slices``, runs ``self.md.process_image`` on the
        slice folder and saves its result to ``self.probs_dir / 'prob_map.npy'``.

        Args:
            image_path (Path): path of the nrrd original image.
        """
        original_shape = self.save_slices(image_path)  # save slices and get original shape
        prob_map = self.md.process_image(slices_dir=self.slices_dir, original_shape=original_shape)
        # probs_dir is not created in __init__, so ensure it exists before saving
        self.probs_dir.mkdir(parents=True, exist_ok=True)
        np.save(self.probs_dir / 'prob_map.npy', prob_map)

    # def seed_definition():



In [3]:
# Build the segmenter: resolves the input/output/cache folders and loads USSegmentation
# from checkpoints/sam_vit_b_01ec64.pth (prints a message when loading succeeds).
segmenter = lesion_seg()

Models loaded on CUDA
Successfully loaded models


In [4]:
# Sort so the selected volume is deterministic (glob order depends on the filesystem),
# and only pick up .nrrd files, the format lesion_seg.save_slices documents.
image_paths = sorted(segmenter.input_dir.glob("*.nrrd"))
assert image_paths, f"No .nrrd volumes found in {segmenter.input_dir}"
image_path = image_paths[0]
image_path

PosixPath('/home/ricardo/ABUS2023_documents/tdsc_abus23/input/DATA_101.nrrd')

In [5]:
# Slice the volume into cached_data/slices, run USSegmentation.process_image on it and
# save the resulting probability map to cached_data/probs/prob_map.npy.
segmenter.prob_map(image_path=image_path)

100%|██████████| 348/348 [00:02<00:00, 129.36it/s]
Processing slices: 100%|██████████| 11/11 [00:13<00:00,  1.20s/it]
Processing slices: 100%|██████████| 11/11 [00:10<00:00,  1.10it/s]
Processing slices: 100%|██████████| 11/11 [00:10<00:00,  1.07it/s]
Processing slices: 100%|██████████| 11/11 [00:10<00:00,  1.08it/s]
Processing slices: 100%|██████████| 11/11 [00:10<00:00,  1.08it/s]


The shape of the accumulated mask is torch.Size([348, 2, 512, 512])


100%|██████████| 348/348 [00:01<00:00, 281.60it/s]


The shape of the final output is (348, 682, 865)


Explore examples