In [1]:
# Add the repository root to sys.path so repo-local packages can be imported.
# The root is the first directory, starting from the current working directory
# and walking upwards, that contains a '.gitignore' file.
from pathlib import Path
import os, sys
repo_path = Path.cwd().resolve()
while '.gitignore' not in os.listdir(repo_path):  # not yet at the repo root
    # Path('/').parent is Path('/') itself, so without this check the loop
    # would spin forever when no ancestor contains a '.gitignore'.
    if repo_path.parent == repo_path:
        raise FileNotFoundError(
            f"No '.gitignore' found in {Path.cwd().resolve()} or any of its parents"
        )
    repo_path = repo_path.parent  # go up one level
if str(repo_path) not in sys.path:
    sys.path.insert(0, str(repo_path))

# Number GPUs by PCI bus order and expose only GPU 0 to this process.
# Set here, before torch is imported below, so CUDA sees them when it initialises.
os.environ.update({"CUDA_DEVICE_ORDER": "PCI_BUS_ID", "CUDA_VISIBLE_DEVICES": "0"})
# NOTE(review): with CUDA_VISIBLE_DEVICES="0" only one GPU is visible (as index 0),
# so an index of 1 looks inconsistent — confirm how `device` is used further down.
device = 1

from importlib import import_module
from monai.transforms import (
    Compose,
    ScaleIntensityd,
    EnsureTyped,
    Resized,
)
import numpy as np
import torch
from torch.utils.data import DataLoader
from sklearn.metrics import jaccard_score
import SimpleITK as sitk
from PIL import Image
import torchvision
import pandas as pd

# special imports
from datasets_utils.datasets import ABUS_test
if str(repo_path / 'SAMed') not in sys.path:  # make SAMed's own top-level modules importable
    sys.path.append(str(repo_path / 'SAMed'))
from SAMed.segment_anything import sam_model_registry

In [3]:
import re
# Define a custom sorting key function
def slice_number(filename):
    """Return the slice index encoded in a file name, for use as a sort key.

    Args:
        filename (str): file name such as ``id_100_slice_12.mha``.

    Returns:
        int: the digits after ``slice_`` and before ``.mha`` as an int,
        or -1 when the name does not contain that pattern.
    """
    found = re.search(r'slice_(\d+)\.mha', filename)
    # Unmatched names get -1 so they sort before every real slice.
    return -1 if found is None else int(found.group(1))

# Single model inference

In [4]:
# Hyper-parameters
batch_size, num_classes, image_size = 8, 1, 512

# Base SAM ViT-B model built from the checkpoint in <repo>/checkpoints.
# pixel_mean=0 / pixel_std=1 presumably turn SAM's internal (x - mean) / std
# normalisation into a no-op, since intensities are rescaled elsewhere — confirm.
# The registry returns two values; only the model (first) is needed here.
checkpoint_dir = repo_path / 'checkpoints'
sam, _ = sam_model_registry['vit_b'](
    image_size=image_size,
    num_classes=num_classes,
    checkpoint=str(checkpoint_dir / 'sam_vit_b_01ec64.pth'),
    pixel_mean=[0, 0, 0],
    pixel_std=[1, 1, 1],
)

# Wrap SAM with LoRA adapters. The module is imported by name, so it must be on
# sys.path (presumably provided by the SAMed folder appended to sys.path earlier).
# NOTE(review): the second argument (4) is presumably the LoRA rank — confirm.
pkg = import_module('sam_lora_image_encoder')
model = pkg.LoRA_Sam(sam, 4)

# Fine-tuned LoRA weight file(s) to evaluate. Anchored on the repo root like the
# other paths in this notebook: the working directory may be a sub-folder (that is
# why repo_path is searched for), so a bare relative path would not resolve there.
# Kept as str so string-based consumers still work; `repo_path / w` and
# os.path.join(repo_path, w) also still yield the same absolute path.
optimum_weights = [
    str(repo_path / 'experiments/SAMed_ABUS/results/full-slice-lesion/fold0/weights/epoch_19.pth'),  # 3220
]

# Validation preprocessing applied to the "image" key only:
# intensity rescaling, resize to (image_size, image_size) with 'area' mode,
# then type conversion via EnsureTyped.
val_transform = Compose([
    ScaleIntensityd(keys=["image"]),
    Resized(keys=["image"], spatial_size=(image_size, image_size), mode=['area']),
    EnsureTyped(keys=["image"]),
])

In [5]:
# Validation metadata: one row per case, including a 'shape' column.
metadata_path = repo_path / 'data/challange_2023/Val/metadata.csv'
metadata = pd.read_csv(metadata_path)

# Validation case to run (an earlier draft looped over range(100, 130), one per val id).
pat_id = 100
patient_meta = metadata[metadata['case_id'] == pat_id]
if patient_meta.empty:
    raise ValueError(f"case_id {pat_id} not found in {metadata_path}")
# 'shape' is stored as a bracketed, comma-separated string (presumably "(d0, d1, d2)"):
# drop the first/last character, split on commas and parse each dimension as int.
# Empty pieces (e.g. from a trailing comma) are skipped.
original_shape = tuple(
    int(dim) for dim in patient_meta['shape'].iloc[0][1:-1].split(',') if dim.strip()
)

# Per-slice .mha images for all validation cases (512x512 according to the folder name).
root_path = repo_path / 'data/challange_2023/Val/full-slice_512x512_all'
path_images = root_path / "image_mha"
all_mha = [name for name in os.listdir(path_images) if name.endswith('.mha')]  # unordered
# Keep only this patient's slices; the trailing '_' stops e.g. id_10 from matching id_100.
val_files = sorted((name for name in all_mha if f'id_{pat_id}_' in name), key=slice_number)
if not val_files:
    raise FileNotFoundError(f"No .mha slices for case {pat_id} in {path_images}")
# Full paths, ordered by slice index.
image_files = np.array([path_images / name for name in val_files])


In [6]:
# Display this patient's slice paths. They were sorted with `slice_number`, so the
# order is numeric by slice index (0, 1, 2, ..., 10), not lexical.
# NOTE(review): this dumps the whole array, including absolute local paths —
# consider `image_files[:5]` before sharing the notebook.
image_files

array([PosixPath('/home/ricardo/ABUS2023_documents/tdsc_abus23/data/challange_2023/Val/full-slice_512x512_all/image_mha/id_100_slice_0.mha'),
       PosixPath('/home/ricardo/ABUS2023_documents/tdsc_abus23/data/challange_2023/Val/full-slice_512x512_all/image_mha/id_100_slice_1.mha'),
       PosixPath('/home/ricardo/ABUS2023_documents/tdsc_abus23/data/challange_2023/Val/full-slice_512x512_all/image_mha/id_100_slice_2.mha'),
       PosixPath('/home/ricardo/ABUS2023_documents/tdsc_abus23/data/challange_2023/Val/full-slice_512x512_all/image_mha/id_100_slice_3.mha'),
       PosixPath('/home/ricardo/ABUS2023_documents/tdsc_abus23/data/challange_2023/Val/full-slice_512x512_all/image_mha/id_100_slice_4.mha'),
       PosixPath('/home/ricardo/ABUS2023_documents/tdsc_abus23/data/challange_2023/Val/full-slice_512x512_all/image_mha/id_100_slice_5.mha'),
       PosixPath('/home/ricardo/ABUS2023_documents/tdsc_abus23/data/challange_2023/Val/full-slice_512x512_all/image_mha/id_100_slice_6.mha'),
      