In [1]:
from PIL import Image, ImageDraw
import random
import seaborn as sns
from matplotlib.patches import Rectangle
import yaml
import json
import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm
from sklearn.model_selection import train_test_split
from sklearn.metrics import f1_score, precision_score, recall_score, fbeta_score
import cv2
import threading
import time
from contextlib import nullcontext
from concurrent.futures import ThreadPoolExecutor
import math

In [2]:
# Setting
DEBUG_MODE = False
PATCH_SIZE = 640
TRUST = 5
BOX_SIZE = 30
TRAIN_SPLIT = 0.8
CONCENTRATION = 3

In [3]:
def normalize_slice(slice_data):
    """
    Normalize slice data using the 2nd and 98th percentiles.
    
    Args:
        slice_data (numpy.array): Input image slice.
    
    Returns:
        np.uint8: Normalized image in the range [0, 255].
    """
    p2 = np.percentile(slice_data, 2)
    p98 = np.percentile(slice_data, 98)
    clipped_data = np.clip(slice_data, p2, p98)
    normalized = 255 * (clipped_data - p2) / (p98 - p2)
    return np.uint8(normalized)

# Define the preprocessing function to extract slices, normalize, and generate YOLO annotations.
def prepare_yolo_dataset(trust=TRUST, train_split=TRAIN_SPLIT, dir_info=None):
    """
    Extract slices containing motors and save images with corresponding YOLO annotations.
    
    Steps:
    - Load the motor labels.
    - Perform a train/validation split by tomogram.
    - For each motor, extract slices in a range (± trust parameter).
    - Normalize each slice and save it.
    - Generate YOLO format bounding box annotations with a fixed box size.
    - Create a YAML configuration file for YOLO training.
    
    Returns:
        dict: A summary containing dataset statistics and file paths.
    """
    # Load the labels CSV
    labels_df = pd.read_csv(os.path.join(dir_info['data_path'], "train_labels.csv"))
    
    total_motors = labels_df['Number of motors'].sum()
    print(f"Total number of motors in the dataset: {total_motors}")
    
    # Consider only tomograms with at least one motor
    tomo_df = labels_df[labels_df['Number of motors'] > 0].copy()
    unique_tomos = tomo_df['tomo_id'].unique()
    print(f"Found {len(unique_tomos)} unique tomograms with motors")
    
    # Shuffle and split tomograms into train and validation sets
    np.random.shuffle(unique_tomos)
    split_idx = int(len(unique_tomos) * train_split)
    train_tomos = unique_tomos[:split_idx]
    val_tomos = unique_tomos[split_idx:]
    print(f"Split: {len(train_tomos)} tomograms for training, {len(val_tomos)} tomograms for validation")
    
    # Helper function to process a list of tomograms
    def process_tomogram_set(tomogram_ids, images_dir, labels_dir, set_name):
        motor_counts = []
        for tomo_id in tomogram_ids:
            # Get motor annotations for the current tomogram
            tomo_motors = labels_df[labels_df['tomo_id'] == tomo_id]
            for _, motor in tomo_motors.iterrows():
                if pd.isna(motor['Motor axis 0']):
                    continue
                motor_counts.append(
                    (tomo_id, 
                     int(motor['Motor axis 0']), 
                     int(motor['Motor axis 1']), 
                     int(motor['Motor axis 2']),
                     int(motor['Array shape (axis 0)']))
                )
        
        print(f"Will process approximately {len(motor_counts) * (2 * trust + 1)} slices for {set_name}")
        processed_slices = 0
        
        # Loop over each motor annotation
        for tomo_id, z_center, y_center, x_center, z_max in tqdm(motor_counts, desc=f"Processing {set_name} motors"):
            z_min = max(0, z_center - trust)
            z_max_bound = min(z_max - 1, z_center + trust)
            for z in range(z_min, z_max_bound + 1):
                if z % 3 == 0:
                    continue
                # Create the slice filename and source path
                slice_filename = f"slice_{z:04d}.jpg"
                src_path = os.path.join(dir_info['train_dir'], tomo_id, slice_filename)
                if not os.path.exists(src_path):
                    print(f"Warning: {src_path} does not exist, skipping.")
                    continue
                
                # Load, normalize, and save the image slice
                img = Image.open(src_path)
                img_array = np.array(img)
                normalized_img = normalize_slice(img_array)
                dest_filename = f"{tomo_id}_z{z:04d}_y{y_center:04d}_x{x_center:04d}.jpg"
                dest_path = os.path.join(images_dir, dest_filename)
                Image.fromarray(normalized_img).save(dest_path)
                
                # Prepare YOLO bounding box annotation (normalized values)
                img_width, img_height = img.size
                x_center_norm = x_center / img_width
                y_center_norm = y_center / img_height
                box_width_norm = BOX_SIZE / img_width
                box_height_norm = BOX_SIZE / img_height
                label_path = os.path.join(labels_dir, dest_filename.replace('.jpg', '.txt'))
                with open(label_path, 'w') as f:
                    f.write(f"0 {x_center_norm} {y_center_norm} {box_width_norm} {box_height_norm}\n")
                
                processed_slices += 1
        
        return processed_slices, len(motor_counts)
    
    # Process training tomograms
    train_slices, train_motors = process_tomogram_set(train_tomos, dir_info['yolo_images_train'], dir_info['yolo_labels_train'], "training")
    # Process validation tomograms
    val_slices, val_motors = process_tomogram_set(val_tomos, dir_info['yolo_images_val'], dir_info['yolo_labels_val'], "validation")
    
    # Generate YAML configuration for YOLO training
    yaml_content = {
        'path': dir_info['yolo_dataset_dir'],
        'train': 'images/train',
        'val': 'images/val',
        'names': {0: 'motor'}
    }
    with open(os.path.join(dir_info['yolo_dataset_dir'], 'dataset.yaml'), 'w') as f:
        yaml.dump(yaml_content, f, default_flow_style=False)
    
    print(f"\nProcessing Summary:")
    print(f"- Train set: {len(train_tomos)} tomograms, {train_motors} motors, {train_slices} slices")
    print(f"- Validation set: {len(val_tomos)} tomograms, {val_motors} motors, {val_slices} slices")
    print(f"- Total: {len(train_tomos) + len(val_tomos)} tomograms, {train_motors + val_motors} motors, {train_slices + val_slices} slices")
    
    return {
        "dataset_dir": dir_info['yolo_dataset_dir'],
        "yaml_path": os.path.join(dir_info['yolo_dataset_dir'], 'dataset.yaml'),
        "train_tomograms": len(train_tomos),
        "val_tomograms": len(val_tomos),
        "train_motors": train_motors,
        "val_motors": val_motors,
        "train_slices": train_slices,
        "val_slices": val_slices
    }

def add_noize(image, label, noize_level=0.05):
    """
    Add Gaussian noise to an image.
    
    Args:
        image (numpy.array): Input image.
        noize_level (float): Standard deviation of the Gaussian noise.
    
    Returns:
        numpy.array: Noisy image.
    """
    # Validation data
    if not isinstance(image, np.ndarray):
        raise ValueError("Image must be a numpy array.")
    if image.ndim != 3:
        raise ValueError("Image must be a 3D array (height, width, channels).")
    if image.shape[2] != 3:
        raise ValueError("Image must have 3 channels (RGB).")
    if not isinstance(label, list):
        raise ValueError("Label must be a list.")
    noise = np.random.normal(0, noize_level * 255, image.shape).astype(np.uint8)
    noisy_image = cv2.add(image, noise)
    return noisy_image, label

def add_blur(image, label, blur_level=5):
    """
    Add Gaussian blur to an image.
    
    Args:
        image (numpy.array): Input image.
        blur_level (int): Size of the Gaussian kernel.
    
    Returns:
        numpy.array: Blurred image.
    """
    # Validation data
    if not isinstance(image, np.ndarray):
        raise ValueError("Image must be a numpy array.")
    if image.ndim != 3:
        raise ValueError("Image must be a 3D array (height, width, channels).")
    if image.shape[2] != 3:
        raise ValueError("Image must have 3 channels (RGB).")
    if not isinstance(label, list):
        raise ValueError("Label must be a list.")
    if blur_level % 2 == 0:
        blur_level += 1
    blurred_image = cv2.GaussianBlur(image, (blur_level, blur_level), 0)
    return blurred_image, label

def add_contrast(image, label, contrast_level=1.5):
    """
    Adjust the contrast of an image.
    
    Args:
        image (numpy.array): Input image.
        contrast_level (float): Contrast adjustment factor.
    
    Returns:
        numpy.array: Contrast-adjusted image.
    """
    # Validation data
    if not isinstance(image, np.ndarray):
        raise ValueError("Image must be a numpy array.")
    if image.shape[2] != 3:
        raise ValueError("Image must have 3 channels (RGB).")
    if not isinstance(label, list):
        raise ValueError("Label must be a list.")
    # Validation for contrast level
    if contrast_level < 0:
        raise ValueError("Contrast level must be non-negative.")
    if contrast_level == 0:
        return image, label
    
    lab = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
    l, a, b = cv2.split(lab)
    l = np.clip(contrast_level * l, 0, 255).astype(np.uint8)
    lab = cv2.merge((l, a, b))
    contrast_image = cv2.cvtColor(lab, cv2.COLOR_LAB2BGR)
    return contrast_image, label

def add_brightness(image, label, brightness_level=50):
    """
    Adjust the brightness of an image.
    
    Args:
        image (numpy.array): Input image.
        brightness_level (int): Brightness adjustment value.
    
    Returns:
        numpy.array: Brightness-adjusted image.
    """
    hsv = cv2.cvtColor(image, cv2.COLOR_BGR2HSV)
    h, s, v = cv2.split(hsv)
    v = np.clip(v + brightness_level, 0, 255).astype(np.uint8)
    hsv = cv2.merge((h, s, v))
    brightness_image = cv2.cvtColor(hsv, cv2.COLOR_HSV2BGR)
    return brightness_image, label

def add_flip(image, label):
    """
    Flip an image horizontally.
    
    Args:
        image (numpy.array): Input image.
        label (list): List of labels.
            line: [x_center, y_center, width, height]
    
    Returns:
        numpy.array: Flipped image.
    """
    flipped_image = cv2.flip(image, 0)
    if label is None:
        return flipped_image, label
    if len(label) == 0:
        return flipped_image, label

    for i, line in enumerate(label):
        if len(line) != 5:
            raise ValueError(f"Invalid label format: {line}. Expected 5 values.")
        _, _, y_center, _, _ = line
        y_center = 1 - y_center
        line[2] = y_center
        label[i] = line
    return flipped_image, label

def apply_augmentation(augmentations:dict, dir_info=None):
    """
    Apply a series of augmentations to an image.
    
    Args:
        augmentations (dict): List of augmentation functions to apply.
        dir_info (dict): Directory information for saving images and labels.
    
    Returns:
        numpy.array: Augmented image.
    """
    
    def _process_tomogram_set(_tomo_paths, _label_paths, images_dir, labels_dir, aug, _aug_name): #TODO: add augmentations
        """
        Process a set of tomograms and apply augmentations.
        """

        assert os.path.exists(images_dir), f"Images directory {images_dir} does not exist."
        assert os.path.exists(labels_dir), f"Labels directory {labels_dir} does not exist."
        assert len(_tomo_paths) == len(_label_paths), f"Number of tomograms and labels do not match."
        
        images = []
        labels = []
        # Placeholder for processing logic
        for tomo_path, label_path in tqdm(zip(_tomo_paths, _label_paths), desc=f"Processing {_aug_name} tomograms"):
            # Load the image and labels
            if not os.path.exists(label_path):
                print(f"Warning: {label_path} does not exist, skipping.")
                continue
            if not os.path.exists(tomo_path):
                print(f"Warning: {tomo_path} does not exist, skipping.")
                continue
            image = Image.open(tomo_path)
            image = image.convert("RGB")
            image = np.array(image)
            with open(label_path, 'r') as f:
                label = f.readlines()
            # split the label into a list of lists
            label = [list(map(float, line.split())) for line in label]
            
            _data_name = os.path.basename(tomo_path).replace('.jpg', '')
            _data_name = f"{_data_name}_{_aug_name}.jpg"
            image_dest_path = os.path.join(images_dir, _data_name)
            label_path = os.path.join(labels_dir, _data_name.replace('.jpg', '.txt'))
            label_dest_path = os.path.join(labels_dir, _data_name.replace('.jpg', '.txt'))
            _data_name = _data_name + f"_{_aug_name}"
            image, label = aug(image, label)
            image = Image.fromarray(image)
            if image.mode != 'RGB':
                image = image.convert('RGB')
            image.save(image_dest_path)
            images.append(image)
            # Save the label
            labels.append(label)
            with open(label_dest_path, 'w') as f:
                for line in label:
                    line = f"{int(line[0])} {line[1]} {line[2]} {line[3]} {line[4]}"
                    f.write(line + '\n')
        return images, labels

    # Process training tomograms
    train_tomos = os.listdir(dir_info['yolo_images_train'])
    train_tomos = [item for item in train_tomos if not item.startswith('.')]
    train_tomos = [os.path.join(dir_info['yolo_images_train'], item) for item in train_tomos]
    train_labels = os.listdir(dir_info['yolo_labels_train'])
    train_labels = [item for item in train_labels if not item.startswith('.')]
    train_labels = [os.path.join(dir_info['yolo_labels_train'], item) for item in train_labels]
    for aug_name, aug in augmentations.items():
        train_slices, train_motors = _process_tomogram_set(train_tomos, train_labels,
                                                           dir_info['yolo_images_train'],
                                                           dir_info['yolo_labels_train'], aug, aug_name)
    
    # Process validation tomograms
    val_tomos = os.listdir(dir_info['yolo_images_val'])
    val_tomos = [item for item in val_tomos if not item.startswith('.')]
    val_tomos = [os.path.join(dir_info['yolo_images_val'], item) for item in val_tomos]
    val_labels = os.listdir(dir_info['yolo_labels_val'])
    val_labels = [item for item in val_labels if not item.startswith('.')]
    val_labels = [os.path.join(dir_info['yolo_labels_val'], item) for item in val_labels]
    for aug_name, aug in augmentations.items():
        val_slices, val_motors = _process_tomogram_set(val_tomos,val_labels,
                                                       dir_info['yolo_images_val'],
                                                       dir_info['yolo_labels_val'], aug, aug_name)
    
    # Generate YAML configuration for YOLO training
    yaml_content = {
        'path': dir_info['yolo_dataset_dir'],
        'train': 'images/train',
        'val': 'images/val',
        'names': {0: 'motor'}
    }
    
    with open(os.path.join(dir_info['yolo_dataset_dir'], 'dataset.yaml'), 'w') as f:
        yaml.dump(yaml_content, f, default_flow_style=False)
    
    print(f"\nProcessing Summary:")
    print(f"- Train set: {len(train_tomos)} tomograms, {len(train_motors)} motors, {len(train_slices)} slices")
    print(f"- Validation set: {len(val_tomos)} tomograms, {len(val_motors)} motors, {len(val_slices)} slices")
    print(f"- Total: {len(train_tomos) + len(val_tomos)} tomograms, {len(train_motors) + len(val_motors)} motors, {len(train_slices) + len(val_slices)} slices")
    
    return {
        "dataset_dir": dir_info['yolo_dataset_dir'],
        "yaml_path": os.path.join(dir_info['yolo_dataset_dir'], 'dataset.yaml'),
        "train_tomograms": len(train_tomos),
        "val_tomograms": len(val_tomos),
        "train_motors": train_motors,
        "val_motors": val_motors,
        "train_slices": train_slices,
        "val_slices": val_slices
    }


In [4]:
yolo_dataset_dir = os.path.join("../input","yolo_dataset")
os.makedirs(yolo_dataset_dir, exist_ok=True)
yolo_images_train = os.path.join(yolo_dataset_dir, "images/train")
yolo_images_val = os.path.join(yolo_dataset_dir, "images/val")
yolo_labels_train = os.path.join(yolo_dataset_dir, "labels/train")
yolo_labels_val = os.path.join(yolo_dataset_dir, "labels/val")
os.makedirs(yolo_images_train, exist_ok=True)
os.makedirs(yolo_images_val, exist_ok=True)
os.makedirs(yolo_labels_train, exist_ok=True)
os.makedirs(yolo_labels_val, exist_ok=True)

dir_info = {
    "data_path": "../input/full_data",
    "train_dir": "../input/full_data/train",
    # "data_path": "../input/data",
    # "train_dir": "../input/data/train",
    "yolo_dataset_dir": yolo_dataset_dir,
    "yolo_images_train": yolo_images_train,
    "yolo_labels_train": yolo_labels_train,
    "yolo_images_val": yolo_images_val,
    "yolo_labels_val": yolo_labels_val
}

# Prepare the YOLO dataset
summary = prepare_yolo_dataset(trust=TRUST, train_split=TRAIN_SPLIT, dir_info=dir_info)

Total number of motors in the dataset: 831
Found 362 unique tomograms with motors
Split: 289 tomograms for training, 73 tomograms for validation
Will process approximately 4004 slices for training


Processing training motors:   0%|          | 0/364 [00:00<?, ?it/s]

Will process approximately 957 slices for validation


Processing validation motors:   0%|          | 0/87 [00:00<?, ?it/s]


Processing Summary:
- Train set: 289 tomograms, 364 motors, 2664 slices
- Validation set: 73 tomograms, 87 motors, 633 slices
- Total: 362 tomograms, 451 motors, 3297 slices


In [5]:
augmentations = {
    "contrast": add_contrast,
    "brightness": add_brightness,
    "blur": add_blur,
    "flip": add_flip,
    "noize": add_noize
}

summary = apply_augmentation(augmentations=augmentations, dir_info=dir_info)

Processing contrast tomograms: 0it [00:00, ?it/s]

Processing brightness tomograms: 0it [00:00, ?it/s]

Processing blur tomograms: 0it [00:00, ?it/s]

Processing flip tomograms: 0it [00:00, ?it/s]

Processing noize tomograms: 0it [00:00, ?it/s]

Processing contrast tomograms: 0it [00:00, ?it/s]

Processing brightness tomograms: 0it [00:00, ?it/s]

Processing blur tomograms: 0it [00:00, ?it/s]

Processing flip tomograms: 0it [00:00, ?it/s]

Processing noize tomograms: 0it [00:00, ?it/s]


Processing Summary:
- Train set: 2596 tomograms, 2596 motors, 2596 slices
- Validation set: 701 tomograms, 701 motors, 701 slices
- Total: 3297 tomograms, 3297 motors, 3297 slices
