In [2]:
#Add repo path to the system path
from pathlib import Path
import os, sys
# Locate the repository root: the closest directory, starting at the current
# working directory and walking upwards, that contains a `.gitignore` file.
_start_dir = Path.cwd().resolve()
repo_path = next(
    (folder for folder in (_start_dir, *_start_dir.parents)
     if '.gitignore' in os.listdir(folder)),
    None,
)
if repo_path is None:
    # The parent of the filesystem root is the root itself, so an unbounded
    # upward walk would spin forever; fail loudly instead.
    raise FileNotFoundError(
        f"No '.gitignore' found in {_start_dir} or any of its parents; "
        "run this notebook from inside the repository."
    )
# Make project-local packages importable; skip if the path is already registered.
if str(repo_path) not in sys.path:
    sys.path.insert(0, str(repo_path))

import os
# GPU selection. These variables are written to the environment here, ahead of
# the torch import further down in this cell.
os.environ.update({
    "CUDA_DEVICE_ORDER": "PCI_BUS_ID",
    "CUDA_VISIBLE_DEVICES": "1",
})
# NOTE(review): with a single visible GPU, index 0 presumably refers to
# physical GPU 1 — confirm on the target machine.
device = 0

from importlib import import_module
from monai.transforms import (
    Compose,
    ScaleIntensityd,
    EnsureTyped,
    Resized,
)
import numpy as np
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import jaccard_score
import SimpleITK as sitk
from PIL import Image
import torchvision
import pandas as pd

# special imports
from datasets_utils.datasets import ABUS_test
# Register <repo>/SAMed on sys.path, unless it is already there.
_samed_dir = str(repo_path / 'SAMed')
if _samed_dir not in sys.path:
    sys.path.append(_samed_dir)
from SAMed.segment_anything import sam_model_registry

In [3]:
import re
# Define a custom sorting key function
def slice_number(filename):
    """Sorting key that extracts the slice index from an image file name.

    Args:
        filename (str): file name containing ``slice_<digits>.mha``.

    Returns:
        int: the digits following ``slice_`` as an integer, or -1 when the
        name does not contain that pattern.
    """
    found = re.search(r'slice_(\d+)\.mha', filename)
    return int(found.group(1)) if found else -1

# Single model inference

In [4]:
# --- Hyper-parameters ---------------------------------------------------------
batch_size = 8     # slices per DataLoader batch
num_classes = 1    # the mask buffer built later has num_classes + 1 channels
image_size = 512   # side length the slices are resized to

# --- SAM backbone -------------------------------------------------------------
checkpoint_dir = repo_path / 'checkpoints'
# The registry entry returns a pair; only the first element (the model) is kept.
sam, _ = sam_model_registry['vit_b'](
    image_size=image_size,
    num_classes=num_classes,
    checkpoint=str(checkpoint_dir / 'sam_vit_b_01ec64.pth'),
    pixel_mean=[0, 0, 0],
    pixel_std=[1, 1, 1],
)

# --- LoRA wrapper -------------------------------------------------------------
pkg = import_module('sam_lora_image_encoder')
# NOTE(review): the second argument (4) is presumably the LoRA rank — confirm
# against the LoRA_Sam signature.
model = pkg.LoRA_Sam(sam, 4)

# Checkpoint paths, relative to the repository root, used for inference.
optimum_weights = [
    'experiments/SAMed_ABUS/results/full-slice-lesion/fold0/weights/epoch_19.pth', #3220
]

# --- Validation-time preprocessing of the "image" entry -----------------------
val_transform = Compose([
    ScaleIntensityd(keys=["image"]),
    Resized(keys=["image"], spatial_size=(image_size, image_size), mode=['area']),
    EnsureTyped(keys=["image"]),
])

In [5]:
# Validation metadata; the `case_id` and `shape` columns are read below.
metadata_path = repo_path / 'data/challange_2023/Val/metadata.csv'
metadata = pd.read_csv(metadata_path)


# for pat_id in range(100,130,1): # each val id
pat_id = 100
patient_meta = metadata[metadata['case_id'] == pat_id]
if patient_meta.empty:
    # Without this guard the positional lookup below fails with an opaque
    # IndexError when the case id is missing from the CSV.
    raise ValueError(f'case_id {pat_id} not found in {metadata_path}')
# `shape` is expected to be text like "(a, b, c)": strip the enclosing
# brackets and parse the comma-separated integers of the first matching row.
_shape_text = patient_meta['shape'].iloc[0]
original_shape = tuple(map(int, _shape_text[1:-1].split(',')))

# Locate the slices that belong to the selected patient.
root_path = repo_path / 'data/challange_2023/Val/full-slice_512x512_all'
path_images = root_path / "image_mha"
# Keep only .mha files whose name carries this patient's id tag, ordered by
# the slice index embedded in the file name.
val_files = sorted(
    (name for name in os.listdir(path_images)
     if name.endswith('.mha') and f'id_{pat_id}_' in name),
    key=slice_number,
)
# Full paths, in slice order, handed to the dataset.
image_files = np.array([path_images / name for name in val_files])
db_val = ABUS_test(transform=val_transform, list_dir=image_files)
valloader = DataLoader(
    db_val,
    batch_size=batch_size,
    shuffle=False,  # batches follow the sorted slice order
    num_workers=8,
    pin_memory=True,
)
print(f'The patient id is {pat_id}')
print(f'The number of slices is {len(db_val)}')
# Per-patient buffer for the summed model outputs:
# (number of slices, num_classes + 1, image_size, image_size).
accumulated_mask = torch.zeros((len(db_val), num_classes + 1, image_size, image_size))

# for model_path in optimum_weights: # for each model learned
model_path = optimum_weights[0]
# Load the LoRA weights of this checkpoint and switch the model to eval mode.
load_path = repo_path / model_path
model.load_lora_parameters(str(load_path))
model.eval()
model.to(device);

# Forward every batch of slices with gradient tracking disabled and collect
# the 'masks' entry of each output on the CPU.
model_mask = []
with torch.no_grad():
    for sample_batch in valloader:
        image_batch = sample_batch["image"].to(device)
        outputs = model(image_batch, True, image_size)
        model_mask.append(outputs['masks'].detach().cpu())
# Concatenate the per-batch tensors along the first (slice) axis.
model_mask = torch.cat(model_mask, dim=0)
print(f'The shape of the independent model mask is {model_mask.shape}')
accumulated_mask += model_mask

# Average over the number of checkpoints listed in `optimum_weights`.
accumulated_mask /= len(optimum_weights)
# The argmax reduction is kept disabled: the probability cell below applies a
# softmax over dim=1 of `accumulated_mask`, so it needs the per-class scores.
# accumulated_mask = torch.argmax(torch.softmax(accumulated_mask, dim=1), dim=1, keepdim=False)
# print(f'The shape of the accumulated output is {accumulated_mask.shape}')

The patient id is 100
The number of slices is 353
The shape of the independent model mask is torch.Size([353, 2, 512, 512])


Lesion probability saving

In [10]:
# Per-pixel probability of class index 1 (softmax over the class axis), as float32.
softed_mask = torch.softmax(accumulated_mask, dim=1)[:, 1].numpy().astype(np.float32)

# Upsample every slice to a fixed square with nearest-neighbour interpolation.
# NOTE(review): 865 is hard-coded; presumably tied to this patient's original
# in-plane size — confirm before running other patient ids.
x_expansion = 865
y_expansion = 865
upsample = torchvision.transforms.Resize(
    (x_expansion, y_expansion),
    interpolation=torchvision.transforms.InterpolationMode.NEAREST,
)
# One PIL image per slice, resized, then stacked back into a volume.
resized_mask = np.stack(
    [upsample(Image.fromarray(prob_slice)) for prob_slice in softed_mask],
    axis=0,
)
print(f'The shape of the resized mask is {resized_mask.shape}')

# Crop the square slices to the extents recorded in the metadata.
final_mask = resized_mask[:, :original_shape[1], :original_shape[0]]
print(f'The shape of the final output is {final_mask.shape}')
print(f'The dtype of the final output is {final_mask.dtype}')

# Output location: <repo>/experiments/.../full-size/fold0_probmap/MASK_<id>.nii.gz
saving_dir = repo_path / 'experiments/inference/segmentation/data/predictions' / 'full-size' / 'fold0_probmap'
saving_dir.mkdir(parents=True, exist_ok=True)
saving_path = saving_dir / f'MASK_{pat_id}.nii.gz'

# Write the volume as a NIfTI file.
probability_image = sitk.GetImageFromArray(final_mask)
sitk.WriteImage(probability_image, str(saving_path))

The shape of the resized mask is (353, 865, 865)
The shape of the final output is (353, 682, 865)
The dtype of the final output is float32


Binary mask saving

In [None]:
# Turn the averaged model output into a hard uint8 label map.
# `accumulated_mask` is deliberately left untouched (not rebound to numpy), so
# this cell and the probability cell above can be re-run in any order.
mask_tensor = torch.as_tensor(accumulated_mask)
if mask_tensor.ndim == 4:
    # Still (slices, classes, H, W) scores because the argmax in the inference
    # cell is commented out: pick the winning class per pixel here.
    # argmax over the raw scores selects the same class as argmax over softmax.
    mask_tensor = torch.argmax(mask_tensor, dim=1)
binary_mask = mask_tensor.cpu().numpy().astype(np.uint8)

# Upsample every 2-D slice with nearest-neighbour interpolation so that label
# values are never blended.
# NOTE(review): 865 is hard-coded; presumably tied to this patient's original
# in-plane size — confirm before running other patient ids.
x_expansion = 865
y_expansion = 865
upsample = torchvision.transforms.Resize(
    (x_expansion, y_expansion),
    interpolation=torchvision.transforms.InterpolationMode.NEAREST,
)
resized_mask = []
for slice_num in range(binary_mask.shape[0]):
    im_slice = Image.fromarray(binary_mask[slice_num, :, :])
    im_slice_comeback = upsample(im_slice)
    resized_mask.append(im_slice_comeback)
# stack all slices
resized_mask = np.stack(resized_mask, axis=0)
print(f'The shape of the resized mask is {resized_mask.shape}')

The shape of the resized mask is (353, 865, 865)


In [None]:
# Crop the square slices to the extents recorded in the metadata.
final_mask = resized_mask[:, :original_shape[1], :original_shape[0]]
print(f'The shape of the final output is {final_mask.shape}')

# Output location: <repo>/experiments/.../full-size/fold0/MASK_<id>.nii.gz
saving_dir = repo_path / 'experiments/inference/segmentation/data/predictions' / 'full-size' / 'fold0'
saving_dir.mkdir(parents=True, exist_ok=True)
saving_path = saving_dir / f'MASK_{pat_id}.nii.gz'

# Write the volume as a NIfTI file.
mask_image = sitk.GetImageFromArray(final_mask)
sitk.WriteImage(mask_image, str(saving_path))

The shape of the final output is (353, 682, 865)
