#### Dataset functions

In [1]:
import os
import pandas as pd
import SimpleITK as sitk

class Abus23DataLoader():
    def __init__(self, dataset_path, labels_csv):
        self.data = self.load_abus23(dataset_path, labels_csv)
        self.dataset_path = dataset_path
        self.used_data = self.data
        self.cidx = 0
        
    def load_abus23(self, dataset_path, label_file):
        dataset = pd.read_csv(os.path.join(dataset_path, label_file))
        print("Dataset columns:", dataset.columns)
        return dataset
    
    def set_subset_ids(self, list_id = [], id_label = 'case_id'):
        if list_id:
            self.used_data = self.data[self.data[id_label].isin(list_id)]
        
    def get_data_entry(self, idx):    
        return self.used_data.iloc[idx]
    
    def get_item(self, idx):
        entry = self.get_data_entry(idx).to_dict()
        output = {}
        output["id"] = entry['case_id']
        if 'label' in entry:
            output["class"] =entry['label']
        image_full_path = os.path.join(self.dataset_path, entry['data_path'].replace('\\','/'))
        output["image"] = sitk.ReadImage(image_full_path)
        output["image_path"] = image_full_path
        if 'mask_path' in entry:
            mask_full_path = os.path.join(self.dataset_path, entry['mask_path'].replace('\\','/'))
            output["mask"] = sitk.ReadImage(mask_full_path)
            output["mask_path"] = mask_full_path
            
        return output

    def get_keys(self):
        return self.used_data.columns.tolist()
        
    def __getitem__(self, idx):
       return self.get_item(idx)
   
    def __len__(self):
        return len(self.used_data)
   
def get_validation_ids(val_file):
    with open(val_file) as fp:
        lines = fp.readlines()
        #print([os.path.basename(i)for i in lines])
        patients = set([int(os.path.basename(case).split('_')[0]) for case in lines])
    return list(patients)


#### Data processing functions

In [2]:

import SimpleITK as sitk
import numpy as np
import cv2

def normalize_8bits(image: np.ndarray):
    return (255.0 *(image - image.min()) / (image.max() - image.min())).astype(np.uint8)

def get_slices(data, norm_fn = normalize_8bits):
    data_array = sitk.GetArrayFromImage(data)
    return [norm_fn(data_array[i, ...]) for i in range(len(data_array))]


# Create the volume from slices

def volume_from_slice(slices):
    mask_3d = np.stack(slices)
    output_mask = sitk.GetImageFromArray(mask_3d)
    #castImageFilter = sitk.CastImageFilter()
    #castImageFilter.SetOutputPixelType(sitk.sitkFloat32)
    #img = castImageFilter.Execute(img)
    
    return output_mask

#### Prediction functions

In [3]:
from ultralytics import YOLO
import numpy as np
import cv2

# Load a model
class YOLOPredictor:
    def __init__(self, model_file, conf_th = 0.5):
        self.model = YOLO(model_file)  # pretrained YOLOv8n model
        self.conf_th = conf_th
        
    def set_conf_th(self, conf_th = 0.5):
        self.conf_th = conf_th
        
    def __call__(self, slice, conf_th=None):
        return self.predict(slice, conf_th)
        
    def predict(self, slice, conf_th=None):
        assert len(slice.shape) == 2
        
        if conf_th is None:
            conf_th = self.conf_th
        
        cv2.imwrite("temp.png", slice)
        results = self.model("temp.png", verbose=False)[0].cpu().numpy()
        
        slice_mask = np.zeros(slice.shape)
        if results.masks is not None:
            
            pred_mask_data = results.masks.data
            for i in range(len(pred_mask_data)):
                
                pred_box_conf = results.boxes[i].conf  # confidence score, (N, )
                if pred_box_conf < conf_th:
                    continue

                m = cv2.resize(pred_mask_data[i, ...], dsize=(slice.shape[1], slice.shape[0])) # interpolation=cv2.INTER_CUBIC)
                slice_mask = np.logical_or(slice_mask, m).astype("float32")
                
        return slice_mask       

#### Run inference

In [4]:
from ultralytics import YOLO
from tqdm import tqdm

# ABUS 23
dataset_path = "datasets/Train"
label_file = "labels.csv"
validation_file = "datasets/abus23_25_png/val_seg.txt"

# Yolo model
yolo_weights = "/home/joel/abus23/runs/segment/train10/weights/best.pt"

# Output folder
output_folder = os.path.join("results_masks", "abus23_25", "raw_stack_2")


# Create output folder
os.makedirs(output_folder, exist_ok=True)

# Load dataset
dataset = Abus23DataLoader(dataset_path, label_file)

# Get validation cases
val_ids = get_validation_ids(validation_file)
dataset.set_subset_ids(val_ids)

# Load yolo predictor
yolo_predictor = YOLOPredictor(yolo_weights)

# For each item in the dataset
gt_files = []
pred_files = []
for item in tqdm(dataset):
    
        # Get image slice
        image_slices = get_slices(item['image'])
    
        # Get predicted slices
        yolo_masks_slices = [yolo_predictor(slice, conf_th=0.6) for slice in image_slices]

        # Create 3D volum
        mask_volum = volume_from_slice(yolo_masks_slices)
        
        # Copy metadata from predited image
        mask_volum.CopyInformation(item['image'])
        
        # Save NRRD mask prediction
        mask_file = os.path.join(output_folder, f"{item['id']}.nrrd")
        sitk.WriteImage(mask_volum, mask_file, useCompression=True )
        
        # Save file names for evaluation
        pred_files.append(mask_file)
        gt_files.append(item['mask_path'])
        

Dataset columns: Index(['case_id', 'label', 'data_path', 'mask_path'], dtype='object')


100%|██████████| 20/20 [05:44<00:00, 17.21s/it]


#### Validation

In [None]:
from TDSCABUS2023.Metrics import segmentation


def Validate(pred_list, gt_list, cvs_pred_file = None, csv_gt_file = None):
    
    print("Segmentation:")
    print("------------------------------------------")
    
    scores = {'DiceCoefficient': [], 'HDCoefficient': [], 'score': []}
    for pred, gt in zip(pred_list, gt_list):
        try:
            result = segmentation.score_case(gt, pred)
        except Exception as e:
           result = {'DiceCoefficient': 0, 'HDCoefficient': 0, 'score': 0} #HD coefficient if fail?
        print("Case:", os.path.basename(pred), "  Results:",  result)
        
        for k, v in result.items():
            scores[k].append(v)
        
    for k, values in scores.items():
        values = np.array(values)
        print(f"\n{k}:")
        print(f"   - Min: {values.min():0.4f}")
        print(f"   - Max: {values.max():0.4f}")
        print(f"   - Mean: {values.mean():0.4f}")
    

In [9]:
gt_files = [g.replace("DATA", "MASK") for g in gt_files]
print(pred_files)
print(gt_files)


['results_masks/abus23_25/raw_stack_2/7.nrrd', 'results_masks/abus23_25/raw_stack_2/9.nrrd', 'results_masks/abus23_25/raw_stack_2/10.nrrd', 'results_masks/abus23_25/raw_stack_2/14.nrrd', 'results_masks/abus23_25/raw_stack_2/22.nrrd', 'results_masks/abus23_25/raw_stack_2/25.nrrd', 'results_masks/abus23_25/raw_stack_2/31.nrrd', 'results_masks/abus23_25/raw_stack_2/39.nrrd', 'results_masks/abus23_25/raw_stack_2/53.nrrd', 'results_masks/abus23_25/raw_stack_2/61.nrrd', 'results_masks/abus23_25/raw_stack_2/62.nrrd', 'results_masks/abus23_25/raw_stack_2/66.nrrd', 'results_masks/abus23_25/raw_stack_2/67.nrrd', 'results_masks/abus23_25/raw_stack_2/74.nrrd', 'results_masks/abus23_25/raw_stack_2/75.nrrd', 'results_masks/abus23_25/raw_stack_2/80.nrrd', 'results_masks/abus23_25/raw_stack_2/87.nrrd', 'results_masks/abus23_25/raw_stack_2/89.nrrd', 'results_masks/abus23_25/raw_stack_2/90.nrrd', 'results_masks/abus23_25/raw_stack_2/99.nrrd']
['datasets/Train/MASK/MASK_007.nrrd', 'datasets/Train/MASK/MA

In [14]:
Validate(pred_files, gt_files)


Segmentation:
------------------------------------------
Case: 7.nrrd   Results: {'DiceCoefficient': 0.4108192408071463, 'HDCoefficient': 293.4177908716511, 'score': -293.00697163084396}
Case: 9.nrrd   Results: {'DiceCoefficient': 0.0, 'HDCoefficient': 230.59488285736091, 'score': -230.59488285736091}
Case: 10.nrrd   Results: {'DiceCoefficient': 0.5383357066362257, 'HDCoefficient': 226.57890457851542, 'score': -226.0405688718792}
Case: 14.nrrd   Results: {'DiceCoefficient': 0.13659776114277866, 'HDCoefficient': 340.74917461382057, 'score': -340.6125768526778}
Case: 22.nrrd   Results: {'DiceCoefficient': 0.5968216750111011, 'HDCoefficient': 57.706152185014034, 'score': -57.109330510002934}
Case: 25.nrrd   Results: {'DiceCoefficient': 0.6108740813545691, 'HDCoefficient': 45.221676218380054, 'score': -44.61080213702549}
Case: 31.nrrd   Results: {'DiceCoefficient': 0, 'HDCoefficient': 0, 'score': 0}
Case: 39.nrrd   Results: {'DiceCoefficient': 0.5201721162144022, 'HDCoefficient': 232.13358