YOLOv8: https://github.com/ultralytics/ultralytics

Imports

In [37]:
import SimpleITK as sitk
import pandas as pd
from typing import Tuple, Dict, Any
from tqdm import tqdm
import numpy as np
import cv2
import os

Basic parameters

In [38]:
dataset_path = "datasets/Train"
labels_file = "labels.csv"

output_folder = "datasets/train_png2"
use_classes = True  #If false, only one class is used (0)
val_frac = 0.2 # Validation fraction of the data

Load dataset info

In [39]:
# Load dataset
dataset = pd.read_csv(os.path.join(dataset_path, labels_file))
print("Dataset columns:", dataset.columns)

Dataset columns: Index(['case_id', 'label', 'data_path', 'mask_path'], dtype='object')


In [40]:
# Get class index
classes = dataset['label'].unique().tolist()
print("Dataset classes:", classes)

Dataset classes: ['M', 'B']


Define function for reading NRRD files

In [41]:
def ReadNRRD(filename: str) -> Tuple[sitk.Image, Dict[str, Any]]:
    reader = sitk.ImageFileReader()
    reader.SetFileName(filename)
    reader.LoadPrivateTagsOn()
    reader.ReadImageInformation()

    image = reader.Execute()
    metadata = {}
    for key in reader.GetMetaDataKeys():
        if reader.HasMetaDataKey(key):
            metadata[key] = reader.GetMetaData(key)     
            
    return image, metadata

Define function for 8-bits normalization

In [42]:
def normalize_8bits(image: np.ndarray):
    return (255.0 *(image - image.min()) / (image.max() - image.min())).astype(np.uint8)

### Format YOLOv5/v8
https://docs.ultralytics.com/datasets/segment/
<br> `<class-index> <x1> <y1> <x2> <y2> ... <xn> <yn>`
<br>Others: https://docs.ultralytics.com/yolov5/tutorials/train_custom_data/#13-prepare-dataset-for-yolov5

In [43]:
os.makedirs(output_folder, exist_ok=True)

# Create slice from data
for _, row in tqdm(dataset.iterrows(), total=len(dataset)):
    
    # get and load data
    id = row.case_id
    label = row.label
    data, metadata = ReadNRRD(os.path.join(dataset_path, row.data_path.replace('\\','/')))
    mask, _ = ReadNRRD(os.path.join(dataset_path, row.mask_path.replace('\\','/')))
    
    # Crop slices
    assert data.GetSize() == mask.GetSize()
    data_array = sitk.GetArrayFromImage(data)
    mask_array = sitk.GetArrayFromImage(mask)
    
    image_size = data_array.shape[2], data_array.shape[1]
    
    #print(data_array.shape)
    
    assert label in classes
    
    if use_classes: # Multiples data classes:
        data_class = classes.index(label)
    else:           # Single data class:
        data_class = 0
    
    for idx in range(len(data_array)): #first dimension is z in numpy (z,y,x)
        data_slice = data_array[idx, ...]
        mask_slice = mask_array[idx, ...]
        
        if np.sum(mask_slice) == 0:
            continue
        
        data_slice = normalize_8bits(data_slice)
        
        # Get contours from the mask
        contours, _ = cv2.findContours(mask_slice, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
        contours_list = [contour.squeeze().tolist() for contour in contours]
        
        image_name = f"{id:0>3}_{idx:0>3}"
        
        # Save image
        image_out = os.path.join(output_folder, f"{image_name}.png")
        cv2.imwrite(image_out, data_slice)
    
        # save  label
        label_out = os.path.join(output_folder, f"{image_name}.txt")
        with open(label_out, "w") as fp:
            for contour in contours_list:
                contour_str = " ".join([f"{point[0]/image_size[0]:0.6f} {point[1]/image_size[1]:0.6f}" for point in contour])
                fp.write(f"{data_class} {contour_str}\n")


  2%|▏         | 2/100 [00:04<03:31,  2.16s/it]

In [None]:
# Get patients list
list_items = [item[:-4] for item in os.listdir(output_folder) if item.endswith(".png")]
patients = list(set([ item.split("_")[0] for item in list_items]))

# Get lesion type per patient
patients_type = {i: [] for i in range(len(classes))}
for item in list_items:
    p = item.split("_")[0]
    item_path = os.path.join(output_folder, f"{item}.txt")
    with open(item_path) as fp:
        class_type = int(fp.readline().split(" ")[0])
    if p not in patients_type[class_type]:
        patients_type[class_type].append(p)
    
print(patients_type)
print(len(patients_type[0]),len(patients_type[1]) )
    
# Distribute patients in Train/Val using val_frac with balanced lesion types
train_p = []
val_p = []
for i in range(len(classes)):
    num_train = int(len(patients_type[i])*(1-val_frac))
    train_type_p, val_type_p = patients_type[i][:num_train], patients_type[i][num_train:]
    train_p += train_type_p
    val_p += val_type_p
    
print(train_p, len(train_p))
print(val_p, len(val_p))
    
# Get distributed patient images for train and val
train = [ os.path.join(output_folder, f"{item}.png") for item in list_items if item.split("_")[0] in train_p]
val = [ os.path.join(output_folder, f"{item}.png") for item in list_items if item.split("_")[0] in val_p]

# Generate train.txt file
train_file = os.path.join(output_folder, "../train.txt")
with open(train_file, "w") as fp:
    fp.writelines([t + '\n' for t in train])
    
# Generate val.txt file
val_file = os.path.join(output_folder, "../val.txt")
with open(val_file, "w") as fp:
    fp.writelines([v + '\n' for v in val])
    

{0: ['066', '006', '037', '019', '086', '045', '032', '068', '012', '022', '093', '088', '033', '024', '059', '089', '058', '023', '036', '009', '015', '035', '077', '094', '087', '075', '021', '008', '056', '002', '060', '076', '092', '029', '003', '067', '020', '063', '098', '044', '038', '090', '054', '000', '070', '016', '069', '080', '034', '030', '081', '072', '047', '099', '040', '028', '085', '010'], 1: ['082', '026', '043', '079', '052', '074', '031', '097', '084', '039', '046', '096', '071', '048', '014', '061', '041', '065', '018', '017', '025', '007', '064', '095', '073', '049', '004', '005', '013', '050', '057', '062', '055', '051', '011', '053', '091', '083', '042', '078', '027', '001']}
58 42
['066', '006', '037', '019', '086', '045', '032', '068', '012', '022', '093', '088', '033', '024', '059', '089', '058', '023', '036', '009', '015', '035', '077', '094', '087', '075', '021', '008', '056', '002', '060', '076', '092', '029', '003', '067', '020', '063', '098', '044', '0