In [12]:
# Toggle: set to True when the notebook is executed inside the docker container
docker_running = False

# define repo path and add it to the path
from pathlib import Path
import sys, os
if not docker_running: # if we are running locally
    # Walk up from the current working directory until the repo root is found.
    # The repo root is recognised by the presence of a .gitignore file.
    repo_path = Path.cwd().resolve()
    while not (repo_path / '.gitignore').exists(): # while not in the root of the repo
        if repo_path.parent == repo_path:
            # Reached the filesystem root: its parent is itself, so without this
            # guard the loop would never terminate.
            raise FileNotFoundError(
                f"No '.gitignore' found in {Path.cwd().resolve()} or any of its parents; "
                "cannot locate the repository root"
            )
        repo_path = repo_path.parent # go up one level
else: # if running in the container
    # NOTE(review): this path is relative to the working directory -- confirm that
    # '/opt/usuari' (absolute) is not what is intended inside the container.
    repo_path = Path('opt/usuari')

# Make the project packages importable, without duplicating the entry on re-runs
if str(repo_path) not in sys.path:
    sys.path.insert(0, str(repo_path))

import numpy as np
import SimpleITK as sitk
import pandas as pd
from torchvision.transforms import (
    Compose,
    Resize,
    InterpolationMode,
)
from PIL import Image
from tqdm import tqdm
import shutil
import SimpleITK as sitk
import torch

# special imports
from segmentation import USSegmentation
from datasets_utils.datasets import ABUS_test
from torch.utils.data import DataLoader
import torchvision

The first thing to understand is that each image needs about 64 MB of storage for its slices.<br>
Thus, the slices are stored alongside other cache information in the `cached_data` folder.<br>

In [13]:
class lesion_seg:
    """Helper that prepares the folders, loads the segmentation models and caches image slices.

    Attributes set in ``__init__``:
        input_dir: folder with the input volumes.
        output_dir: folder where the segmentation predictions are written (created if missing).
        checkpoint_dir: path to the SAM ViT-B checkpoint file handed to ``USSegmentation``.
        cached_dir: folder for cached data (created if missing).
        slices_dir: ``cached_dir / 'slices'``, where ``save_slices`` writes the 2D slices.
        md: the ``USSegmentation`` instance holding the loaded models.
    """

    def __init__(self):
        self.input_dir = Path('./input/') if docker_running else repo_path / 'input'
        self.output_dir = Path('./predict') / 'Segmentation' if docker_running else Path(repo_path / 'predict' / 'Segmentation')
        self.output_dir.mkdir(parents=True, exist_ok=True) # make sure the output dir exists
        self.checkpoint_dir = repo_path / 'checkpoints' / 'sam_vit_b_01ec64.pth'
        self.cached_dir = repo_path / 'cached_data'
        self.cached_dir.mkdir(parents=True, exist_ok=True) # create cached dir in root
        self.slices_dir = self.cached_dir / 'slices'
        # load all folds models
        self.md = USSegmentation(self.checkpoint_dir)
        load_success = self.md.load_model()
        if load_success:
            print("Successfully loaded models")
        else:
            # presumably load_model() returns a falsy value on failure -- TODO confirm.
            # Make the failure visible instead of continuing silently.
            print("WARNING: the models could not be loaded, inference will not work")

    def save_slices(self, image_path:Path, x_expansion:int=865, y_expansion:int=865,
                    x_resizing:int=512, y_resizing:int=512, file_format:str='mha'):
        """given an nrrd image path, the slices are saved in the cached_dir/slices folder

        The slices folder is deleted and re-created on every call, so it only ever
        contains the slices of the last processed image. Each slice is zero-padded
        up to the expansion size (if smaller), resized, repeated over 3 channels and
        written as ``slice_{z}.{file_format}``.

        Args:
            image_path (Path): Path to the nrrd image
            x_expansion (int, optional): minimum size of the last array axis; smaller
                images are zero-padded at the end of that axis. Defaults to 865.
            y_expansion (int, optional): minimum size of the middle array axis; smaller
                images are zero-padded at the end of that axis. Defaults to 865.
            x_resizing (int, optional): width of the saved slices. Defaults to 512.
            y_resizing (int, optional): height of the saved slices. Defaults to 512.
            file_format (str, optional): extension used for the saved slices. Defaults to 'mha'.

        Returns:
            tuple: the size of the original image as reported by SimpleITK's ``GetSize()``
        """
        # remove folder if exists, always starts from scratch
        if self.slices_dir.exists():
            shutil.rmtree(self.slices_dir)
        self.slices_dir.mkdir(exist_ok=True, parents=True)

        # transforms
        # torchvision's Resize expects (height, width), i.e. (y, x)
        preprocess_im = Compose(
                [
                    Resize((y_resizing, x_resizing), interpolation= InterpolationMode.BILINEAR),
                ]
        )

        # get image (path converted to str, as done for WriteImage below)
        im_sitk = sitk.ReadImage(str(image_path))
        shape = im_sitk.GetSize()
        im = sitk.GetArrayFromImage(im_sitk)
        # now, we complete the images and labels to the expansion variables
        # NOTE(review): the padding is int8; when the image has another integer dtype
        # NumPy promotes the concatenated array to a common dtype, so padded and
        # non-padded volumes may end up with different dtypes -- confirm this is what
        # the downstream dataset expects before changing it.
        # NOTE(review): images larger than the expansion size are neither padded nor cropped.
        if im.shape[2]<x_expansion:
            im = np.concatenate((im, np.zeros((im.shape[0], im.shape[1], x_expansion-im.shape[2]), dtype=np.int8)), axis=2)

        if im.shape[1]<y_expansion:
            im = np.concatenate((im, np.zeros((im.shape[0], y_expansion-im.shape[1], im.shape[2]), dtype=np.int8)), axis=1)

        # one file per index of the first array axis
        for z in tqdm(range(im.shape[0])):
            # preprocess image
            im_slice = Image.fromarray(im[z])
            im_slice = preprocess_im(im_slice)
            im_slice = np.asarray(im_slice)
            # put channel first and repeat in RGB
            im_slice = np.repeat(np.expand_dims(im_slice, axis=0), 3, axis=0)

            # saving path
            save_name = f'slice_{z}.{file_format}'
            # save image
            sitk.WriteImage(sitk.GetImageFromArray(im_slice), str(self.slices_dir / save_name))

        return shape

In [15]:
# Build the helper: creates the output and cache folders and loads the models
segmenter = lesion_seg()

Models loaded on CUDA
Successfully loaded models


Explore examples

In [17]:
# glob() yields entries in arbitrary order: sort them so the example picked below
# is the same on every run
image_paths = sorted(segmenter.input_dir.glob("*"))
assert image_paths, f'No input images found in {segmenter.input_dir}'
image_path = image_paths[0]
image_path

PosixPath('/home/ricardo/ABUS2023_documents/tdsc_abus23/input/DATA_101.nrrd')

In [19]:
# Cache the slices of the example volume on disk (the slices folder is wiped first).
# The returned value is the size of the original image as given by SimpleITK's GetSize().
original_shape = segmenter.save_slices(image_path)

100%|██████████| 348/348 [00:02<00:00, 128.11it/s]


Now the probability map must be created

In [6]:
def process_image(self, input_image, image_size: int = 256):
    """Predict a label map for one image with the ensemble of models and flip augmentation.

    For every model in ``self.models`` the image and its flipped version
    (``self.h_flip``) are forwarded; the flipped prediction is flipped back and
    averaged with the direct one. The per-model averages are then averaged over
    the ensemble and the most probable class is selected per pixel.

    Args:
        self: object providing ``test_transform``, ``device``, ``h_flip`` and ``models``.
        input_image: image handed to ``self.test_transform`` under the ``"image"`` key.
        image_size (int, optional): size passed as third argument to each model.
            Defaults to 256.

    Returns:
        torch.Tensor: class indices (argmax over dim 1, kept as a singleton dimension).

    Raises:
        ValueError: if ``self.models`` is empty.
    """
    n_models = len(self.models)
    if n_models == 0:
        raise ValueError("process_image needs at least one loaded model")

    image = self.test_transform({"image": input_image})["image"]
    image = image.to(device=self.device).unsqueeze(0)
    input_h_flipped = self.h_flip(image)

    # The accumulator is created from the first prediction, so neither the number of
    # classes nor the output resolution has to be hard-coded.
    final_output = None
    with torch.no_grad():  # inference only: no autograd graph needed
        for model in self.models:
            model.eval()
            outputs = model(image, True, image_size)
            outputs_h_flip = model(input_h_flipped, True, image_size)

            # undo the flip on the second prediction before averaging
            output_masks_t = (outputs['masks'] + self.h_flip(outputs_h_flip['masks'])) / 2
            output_masks_t = output_masks_t.to(torch.float32)
            final_output = output_masks_t if final_output is None else final_output + output_masks_t

    output_masks = torch.argmax(torch.softmax(final_output / n_models, dim=1), dim=1, keepdim=True)

    return output_masks

In [9]:
# Keep only the .mha entries of the slices folder (the listing order is arbitrary)
slice_files = list(filter(lambda name: name.endswith('.mha'), os.listdir(segmenter.slices_dir)))
# Order them by the integer that follows the underscore in 'slice_<z>.mha'
slice_files.sort(key=lambda name: int(name.split('.')[0].split('_')[1]))

# Full paths of the slices, in slice order
image_files = np.array([segmenter.slices_dir / name for name in slice_files])

# Dataset and loader over the cached slices; shuffle stays off to preserve the slice order
db_val = ABUS_test(transform=segmenter.md.test_transform, list_dir=image_files)
valloader = DataLoader(
    db_val,
    batch_size=32,
    shuffle=False,
    num_workers=12,
    pin_memory=True,
)
print(f'The number of slices is {len(db_val)}')

The number of slices is 350


In [24]:
num_classes = segmenter.md.num_classes
image_size = segmenter.md.image_size
device = segmenter.md.device

# 2. Create probability volume
# Sum of the raw model outputs, one entry per slice of the dataset
accumulated_mask = torch.zeros((len(db_val),num_classes+1,image_size,image_size)) # store final mask per patient

for model in segmenter.md.models: # for each model learned
    model.eval() # inference mode, as done in process_image

    model_mask = [] # for appending slices of same model
    with torch.no_grad():
        for sample_batch in tqdm(valloader):
            # get data
            image_batch = sample_batch["image"].to(device)
            # forward pass
            outputs = model(image_batch, True, image_size)
            # keep the masks on the CPU
            model_mask.append(outputs['masks'].detach().cpu())
    # stack tensors in a single one (the loader is not shuffled, so slice order is kept)
    model_mask = torch.cat(model_mask, dim=0)
    accumulated_mask += model_mask
print(f'The shape of the accumulated mask is {accumulated_mask.shape}')

# get the mean
accumulated_mask /= len(segmenter.md.models)
accumulated_mask = torch.softmax(accumulated_mask, dim=1)[:,1] # get lesion probability
accumulated_mask = accumulated_mask.cpu().numpy()

# reshape each slice back to the padded size
# NOTE(review): these values must match the expansion used in save_slices
x_expansion = 865
y_expansion = 865
# Built once, outside the loop. torchvision's Resize expects (height, width), i.e. (y, x).
resize_back = torchvision.transforms.Resize(
    (y_expansion, x_expansion),
    interpolation= torchvision.transforms.InterpolationMode.BILINEAR, # bilinear or nearest? probably bilinear
    )
resized_mask = []
for slice_num in tqdm(range(accumulated_mask.shape[0])):
    im_slice = accumulated_mask[slice_num,:,:]
    im_slice = Image.fromarray(im_slice)
    im_slice_comeback = resize_back(im_slice)
    # store plain arrays rather than PIL images
    resized_mask.append(np.asarray(im_slice_comeback))
# stack all slices
resized_mask = np.stack(resized_mask, axis=0)
# get original size and save: original_shape comes from GetSize(), whose first two
# entries correspond to the last and middle array axes respectively
final_mask = resized_mask[:,:original_shape[1],:original_shape[0]]
print(f'The shape of the final output is {final_mask.shape}')

100%|██████████| 11/11 [00:10<00:00,  1.03it/s]
100%|██████████| 11/11 [00:10<00:00,  1.03it/s]
100%|██████████| 11/11 [00:11<00:00,  1.00s/it]
100%|██████████| 11/11 [00:10<00:00,  1.02it/s]
100%|██████████| 11/11 [00:11<00:00,  1.01s/it]


The shape of the accumulated mask is torch.Size([350, 2, 512, 512])


In [25]:
# NOTE(review): `saving_dir` and `pat_id` were not defined in any cell of this notebook,
# so this cell failed on a fresh kernel. They are derived here from the segmenter's
# output folder and from the input file name (e.g. 'DATA_101.nrrd' -> '101') --
# confirm this matches the intended naming.
saving_dir = segmenter.output_dir
pat_id = image_path.stem.split('_')[-1]
saving_path = saving_dir  / f'MASK_{pat_id}.nii.gz'

# save the mask as nii.gz
sitk.WriteImage(sitk.GetImageFromArray(final_mask), str(saving_path), True, 0)

100%|██████████| 350/350 [00:01<00:00, 316.10it/s]


The shape of the resized mask is (350, 865, 865)
The shape of the final output is (350, 682, 865)
