# Flagellar Motor Detection in Bacterial Tomograms

## Introduction

This notebook implements a solution for the BYU Locating Bacterial Flagellar Motors 2025 Kaggle competition. We'll develop an algorithm to automatically identify the presence and 3D coordinates of flagellar motors in cryogenic electron tomography (cryo-ET) data of bacteria.

### The Challenge

Flagellar motors are molecular machines that enable bacterial movement. While cryo-ET imaging allows visualization in near-native conditions, identifying these structures in tomograms is challenging due to:

- Poor signal-to-noise ratio (negative SNR values, i.e. below 0 dB)
- Variable tomogram sizes (especially z-axis)
- Motors appearing as darker regions than surroundings
- Variable motor orientations
- Proximity to cell boundaries

### Our Approach

We'll implement an ensemble approach that combines a 2D YOLOv8-based detector with 3D post-processing, which finds potential motor candidates, and a 3D CNN that validates those candidates. The pipeline consists of:
1. Robust preprocessing with adaptive normalization
2. YOLOv8 training with hyperparameter optimization
3. Efficient slice processing during inference
4. Advanced 3D clustering for final motor localization
5. Two-stage detection pipeline:
    - YOLO efficiently processes all 2D slices to find potential candidates
    - 3D CNN validates candidates by analyzing the full 3D context around each potential motor

In [1]:
# Install dependencies
# Fetch two Kaggle datasets via kagglehub: 'yolo-pkg' (holds the ultralytics
# wheel installed below) and 'yolo-model' (presumably pre-trained YOLO
# weights — TODO confirm).
import kagglehub
kagglehub.dataset_download('rachiteagles/yolo-pkg')
kagglehub.dataset_download('rachiteagles/yolo-model')

# Install the YOLO package from the downloaded wheel file.
# --no-index / --no-deps: install only this local wheel, without contacting
# PyPI or resolving dependencies (assumes the wheel is available under
# /kaggle/input/yolo-pkg, e.g. attached as a notebook input).
!pip install --no-index --no-deps /kaggle/input/yolo-pkg/yolo/ultralytics-8.3.112-py3-none-any.whl

Processing /kaggle/input/yolo-pkg/yolo/ultralytics-8.3.112-py3-none-any.whl
Installing collected packages: ultralytics
Successfully installed ultralytics-8.3.112


## Setup and Imports

In [None]:
# Import necessary libraries
import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image, ImageDraw
import cv2
import torch
import yaml
import random
import time
import threading
from tqdm.notebook import tqdm
from contextlib import nullcontext
from concurrent.futures import ThreadPoolExecutor
from sklearn.model_selection import train_test_split, KFold
from ultralytics import YOLO
import shutil
import optuna

# Reproducibility: pin the Python, NumPy and PyTorch RNGs (and every CUDA
# device, when one is present) to the same fixed seed.
random.seed(42)
np.random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Kaggle locations: competition inputs and the writable output directory.
DATA_DIR = '/kaggle/input/byu-locating-bacterial-flagellar-motors-2025'
TRAIN_DIR, TEST_DIR, TRAIN_CSV = (
    os.path.join(DATA_DIR, entry) for entry in ('train', 'test', 'train_labels.csv')
)
WORKING_DIR = '/kaggle/working'

## 1. Data Exploration and Analysis

We first explore the dataset to understand its characteristics before building our detection model.

In [None]:
def _value_summary(series, list_key):
    """Summarize a numeric column as {list_key: sorted distinct values, 'min', 'max', 'mean'}."""
    values = series.dropna()
    return {
        # .tolist() yields plain Python numbers, so the printed list is not
        # cluttered with NumPy scalar reprs such as np.int64(300).
        list_key: sorted(values.unique().tolist()),
        'min': values.min(),
        'max': values.max(),
        'mean': values.mean(),
    }


def explore_dataset():
    """
    Explore and analyze the training labels, print a summary and return key statistics.

    The labels CSV holds one row per motor, so tomogram-level attributes (array
    shape, voxel spacing) repeat for multi-motor tomograms. Their statistics are
    therefore computed over one row per ``tomo_id``, so that every tomogram
    counts once in the reported mean.

    Returns:
        dict: Dictionary containing dataset statistics with keys
        'train_labels', 'total_records', 'unique_tomos', 'tomos_with_motors',
        'tomos_without_motors', 'percent_with_motors', 'motor_distribution',
        'axis_stats' ({axis: {'sizes', 'min', 'max', 'mean'}}) and
        'voxel_stats' ({'values', 'min', 'max', 'mean'}).
    """
    # Load the training labels and show a preview (IPython `display`)
    train_labels = pd.read_csv(TRAIN_CSV)
    display(train_labels.head())

    # Basic counts
    total_records = len(train_labels)
    unique_tomos = train_labels['tomo_id'].nunique()
    tomos_with_motors = train_labels.loc[train_labels['Number of motors'] > 0, 'tomo_id'].nunique()
    tomos_without_motors = unique_tomos - tomos_with_motors
    # Guard against an empty CSV instead of raising ZeroDivisionError
    percent_with_motors = (tomos_with_motors / unique_tomos) * 100 if unique_tomos else 0.0

    # Distribution of motors per tomogram (value repeated on each of its rows)
    motors_per_tomo = train_labels.groupby('tomo_id')['Number of motors'].first()
    motor_distribution = motors_per_tomo.value_counts().sort_index()

    # One row per tomogram for tomogram-level attributes
    per_tomo = train_labels.drop_duplicates(subset='tomo_id')
    axis_stats = {
        axis: _value_summary(per_tomo[f'Array shape (axis {axis})'], 'sizes')
        for axis in (0, 1, 2)
    }
    voxel_stats = _value_summary(per_tomo['Voxel spacing'], 'values')

    # Print summary statistics
    print(f"Total number of records: {total_records}")
    print(f"Number of unique tomograms: {unique_tomos}")
    print(f"Number of tomograms with motors: {tomos_with_motors}")
    print(f"Number of tomograms without motors: {tomos_without_motors}")
    print(f"Percentage of tomograms with motors: {percent_with_motors:.2f}%")

    print("\nDistribution of motors per tomogram:")
    display(motor_distribution)

    print("\nTomogram size statistics:")
    for axis, stats in axis_stats.items():
        print(f"Axis {axis} sizes: {stats['sizes']}")
        print(f"Min: {stats['min']}, Max: {stats['max']}, Mean: {stats['mean']:.2f}")

    print("\nVoxel spacing values (in angstroms per voxel):")
    print(voxel_stats['values'])
    print(f"Min: {voxel_stats['min']}, Max: {voxel_stats['max']}, Mean: {voxel_stats['mean']:.2f}")

    return {
        'train_labels': train_labels,
        'total_records': total_records,
        'unique_tomos': unique_tomos,
        'tomos_with_motors': tomos_with_motors,
        'tomos_without_motors': tomos_without_motors,
        'percent_with_motors': percent_with_motors,
        'motor_distribution': motor_distribution,
        'axis_stats': axis_stats,
        'voxel_stats': voxel_stats
    }

# Run data exploration (prints the summary and returns a dict of statistics)
dataset_stats = explore_dataset()
# The raw labels DataFrame is reused by the visualization and preprocessing cells
train_labels = dataset_stats['train_labels']  # Store for later use

### Data Visualization

Visualize sample tomogram slices with motors to better understand what we're looking for.

In [None]:
class ImageProcessor:
    """
    Class for processing and normalizing tomogram slices.
    """

    @staticmethod
    def _percentile_to_uint8(img_array, low=2, high=98):
        """Contrast-stretch ``img_array`` between two percentiles into uint8 [0, 255]."""
        p_low = np.percentile(img_array, low)
        p_high = np.percentile(img_array, high)
        if p_high <= p_low:
            # Flat slice: the stretch would divide by zero and produce NaNs,
            # whose uint8 cast is undefined. Return a black image instead.
            return np.zeros(np.shape(img_array), dtype=np.uint8)
        normalized = np.clip(img_array, p_low, p_high)
        normalized = 255 * (normalized - p_low) / (p_high - p_low)
        return np.uint8(normalized)

    @staticmethod
    def normalize_slice(img_array, method='adaptive'):
        """
        Normalize slice using different normalization methods

        Args:
            img_array (np.ndarray): Input image array
            method (str): Normalization method ('percentile', 'histogram', 'adaptive').
                Any other value returns ``img_array`` unchanged.

        Returns:
            np.ndarray: Normalized image (uint8 for the three named methods)
        """
        if method == 'percentile':
            # Percentile-based normalization (2nd to 98th percentile)
            return ImageProcessor._percentile_to_uint8(img_array)

        elif method == 'histogram':
            # Global histogram equalization (OpenCV expects an 8-bit single-channel image)
            return cv2.equalizeHist(img_array)

        elif method == 'adaptive':
            # Percentile stretch first, then contrast-limited adaptive
            # histogram equalization (CLAHE) on the 8-bit result
            normalized = ImageProcessor._percentile_to_uint8(img_array)
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            return clahe.apply(normalized)

        else:
            return img_array

def visualize_motor_slices(train_labels, n_samples=3, context_slices=2):
    """
    Visualize motor slices with context slices above and below.

    Randomly samples up to ``n_samples`` annotated motors. For each one, shows
    the adaptively normalized slices from ``z - context_slices`` to
    ``z + context_slices``, clamped to ``[0, depth - 1]``. The motor's own
    slice gets a red 24-pixel box and a red title. Missing or unreadable
    slices are reported in their panel instead of aborting the figure.

    Args:
        train_labels (pd.DataFrame): DataFrame with motor annotations
        n_samples (int): Number of random motors to visualize
        context_slices (int): Number of slices to show above and below motor
    """
    # Keep rows whose 'Motor axis 0' is present and positive
    motors_df = train_labels[(~pd.isna(train_labels['Motor axis 0'])) &
                             (train_labels['Motor axis 0'] > 0)]

    if len(motors_df) == 0:
        print("No valid motors found in the dataset!")
        return

    # Sample random motors
    sample_motors = motors_df.sample(min(n_samples, len(motors_df)))

    # Size the grid by the motors actually drawn. squeeze=False keeps `axes`
    # 2-D even for a single row and/or a single column.
    n_rows = len(sample_motors)
    n_cols = 2 * context_slices + 1
    _, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows),
                           squeeze=False)

    for i, (_, motor) in enumerate(sample_motors.iterrows()):
        tomo_id = motor['tomo_id']
        z_center = int(motor['Motor axis 0'])
        y_center = int(motor['Motor axis 1'])
        x_center = int(motor['Motor axis 2'])

        # Slice range clamped to the tomogram (slices are numbered from 0)
        z_min = max(0, z_center - context_slices)
        z_max = min(int(motor['Array shape (axis 0)']) - 1, z_center + context_slices)
        n_slices = z_max - z_min + 1

        for j, z in enumerate(range(z_min, z_max + 1)):
            ax = axes[i, j]
            try:
                slice_filename = f"slice_{z:04d}.jpg"
                slice_path = os.path.join(TRAIN_DIR, tomo_id, slice_filename)

                if os.path.exists(slice_path):
                    img = np.array(Image.open(slice_path))
                    normalized_img = ImageProcessor.normalize_slice(img, 'adaptive')
                    ax.imshow(normalized_img, cmap='gray')

                    # Draw bounding box on the motor slice
                    if z == z_center:
                        box_size = 24  # Reasonable motor size
                        rect = plt.Rectangle((x_center - box_size//2, y_center - box_size//2),
                                             box_size, box_size,
                                             linewidth=2, edgecolor='r', facecolor='none')
                        ax.add_patch(rect)
                        ax.set_title(f"Motor Slice (z={z})", color='red')
                    else:
                        ax.set_title(f"Slice z={z}")

                    ax.axis('on')
                    ax.set_xticks([])
                    ax.set_yticks([])
                else:
                    ax.text(0.5, 0.5, f"Slice {z} not found",
                            horizontalalignment='center', verticalalignment='center')
                    ax.axis('off')
            except Exception as e:
                # Best-effort display: report the failure in the panel and carry on
                print(f"Error processing slice z={z} for tomogram {tomo_id}: {e}")
                ax.text(0.5, 0.5, f"Error with slice {z}",
                        horizontalalignment='center', verticalalignment='center')
                ax.axis('off')

        # Hide panels not used because the range was clamped at a tomogram edge
        for ax in axes[i, max(n_slices, 0):]:
            ax.axis('off')

        # Add text with motor coordinates
        plt.figtext(0.01, 0.95 - (i / n_rows),
                    f"Tomogram: {tomo_id}\nMotor at: z={z_center}, y={y_center}, x={x_center}",
                    fontsize=9, bbox=dict(facecolor='white', alpha=0.8))

    plt.tight_layout()
    plt.show()

def compare_normalization_methods(train_labels):
    """
    Show one motor-bearing slice under every normalization method side by side.

    The first row of ``train_labels`` with a present, positive 'Motor axis 0'
    is used. Its slice is shown raw and after 'percentile', 'histogram' and
    'adaptive' normalization, with a red 24-pixel box centred on the motor in
    every panel. Returns early (after printing a message) when no such motor
    exists or its slice file is missing.

    Args:
        train_labels (pd.DataFrame): DataFrame with motor annotations
    """
    z_column = train_labels['Motor axis 0']
    candidates = train_labels[z_column.notna() & (z_column > 0)]
    if candidates.empty:
        print("No valid motors found in the dataset!")
        return

    chosen = candidates.iloc[0]
    tomo_id = chosen['tomo_id']
    z_center, y_center, x_center = (int(chosen[f'Motor axis {k}']) for k in range(3))

    slice_path = os.path.join(TRAIN_DIR, tomo_id, f"slice_{z_center:04d}.jpg")
    if not os.path.exists(slice_path):
        print(f"Slice {slice_path} not found!")
        return

    _, axes = plt.subplots(1, 4, figsize=(16, 4))

    raw = np.array(Image.open(slice_path))
    # (panel title, normalization method; None means "show the raw slice")
    panels = (
        ("Original", None),
        ("Percentile (2-98)", 'percentile'),
        ("Histogram Equalization", 'histogram'),
        ("Adaptive (Combined)", 'adaptive'),
    )
    for ax, (title, method) in zip(axes, panels):
        shown = raw if method is None else ImageProcessor.normalize_slice(raw, method)
        ax.imshow(shown, cmap='gray')
        ax.set_title(title)

    # Same motor box on every panel
    side = 24
    corner = (x_center - side // 2, y_center - side // 2)
    for ax in axes:
        ax.add_patch(plt.Rectangle(corner, side, side,
                                   linewidth=2, edgecolor='r', facecolor='none'))
        ax.axis('on')
        ax.set_xticks([])
        ax.set_yticks([])

    plt.suptitle(f"Normalization methods comparison for {tomo_id}, Motor at z={z_center}, y={y_center}, x={x_center}")
    plt.tight_layout()
    plt.show()

# Three random motors, each shown with two context slices on either side
visualize_motor_slices(train_labels, n_samples=3, context_slices=2)

# Side-by-side view of the available normalization methods on one motor slice
compare_normalization_methods(train_labels)

## 2. Data Preprocessing Pipeline

This preprocessing pipeline prepares our data for the YOLO model.

In [None]:
class TomogramProcessor:
    """
    Comprehensive tomogram processing class for extracting, normalizing,
    and preparing slices for YOLO training.

    Everything is written below ``<working_dir>/yolo_dataset``:
    - normalized slices in ``images/{train,val}``;
    - one YOLO label file per image in ``labels/{train,val}``;
    - the ``dataset.yaml`` read by ultralytics.
    """
    def __init__(self, train_dir, labels_df, working_dir,
                 context_range=5, box_size=24, test_split=0.2,
                 norm_method='adaptive'):
        """
        Initialize the processor and create the output directory tree.

        Args:
            train_dir (str): Path to directory containing training tomograms
            labels_df (pd.DataFrame): DataFrame with motor annotations
            working_dir (str): Path to working directory for processed data
            context_range (int): Number of slices to include above and below motors
            box_size (int): Size of bounding box for motor annotations (pixels)
            test_split (float): Fraction of tomograms to use for validation
            norm_method (str): Normalization method to use
        """
        self.train_dir = train_dir
        self.labels_df = labels_df
        self.working_dir = working_dir
        self.context_range = context_range
        self.box_size = box_size
        self.test_split = test_split
        self.norm_method = norm_method

        # Define YOLO dataset structure
        self.yolo_dataset_dir = os.path.join(working_dir, "yolo_dataset")
        self.yolo_images_train = os.path.join(self.yolo_dataset_dir, "images", "train")
        self.yolo_images_val = os.path.join(self.yolo_dataset_dir, "images", "val")
        self.yolo_labels_train = os.path.join(self.yolo_dataset_dir, "labels", "train")
        self.yolo_labels_val = os.path.join(self.yolo_dataset_dir, "labels", "val")

        # Create directories
        for dir_path in self._output_dirs():
            os.makedirs(dir_path, exist_ok=True)

    def _output_dirs(self):
        """Return the four image/label output directories."""
        return (self.yolo_images_train, self.yolo_images_val,
                self.yolo_labels_train, self.yolo_labels_val)

    def _reset_output_dirs(self):
        """Empty (and recreate) the image/label output directories."""
        for dir_path in self._output_dirs():
            shutil.rmtree(dir_path, ignore_errors=True)
            os.makedirs(dir_path, exist_ok=True)

    @staticmethod
    def _yolo_label_line(x_center, y_center, box_size, img_width, img_height):
        """Format one YOLO row: class 0, then box centre and size normalized to [0, 1]."""
        return (f"0 {x_center / img_width} {y_center / img_height} "
                f"{box_size / img_width} {box_size / img_height}\n")

    def normalize_slice(self, img_array):
        """
        Normalize slice using the configured method (see ``ImageProcessor.normalize_slice``).

        Args:
            img_array (np.ndarray): Input image array

        Returns:
            np.ndarray: Normalized image
        """
        return ImageProcessor.normalize_slice(img_array, self.norm_method)

    def prepare_dataset(self, clean_existing=True):
        """
        Prepare the YOLO dataset by extracting slices and creating annotations.

        Only tomograms with at least one motor are used. They are shuffled with
        NumPy's global RNG and split at tomogram level into train/validation.

        Args:
            clean_existing (bool): Delete previously written images/labels first.
                Without this, re-running with a different random split would
                leave stale slices (e.g. of validation tomograms) in the
                training folders.

        Returns:
            dict: Summary information about the prepared dataset
        """
        if clean_existing:
            self._reset_output_dirs()

        # Filter to get only tomograms with motors
        tomo_df = self.labels_df[self.labels_df['Number of motors'] > 0].copy()
        unique_tomos = tomo_df['tomo_id'].unique()

        print(f"Found {len(unique_tomos)} unique tomograms with motors")

        # Perform a train-val split at the tomogram level
        np.random.shuffle(unique_tomos)
        split_idx = int(len(unique_tomos) * (1 - self.test_split))
        train_tomos = unique_tomos[:split_idx]
        val_tomos = unique_tomos[split_idx:]

        print(f"Split: {len(train_tomos)} tomograms for training, {len(val_tomos)} tomograms for validation")

        train_slices, train_motors = self._process_tomogram_set(
            train_tomos, self.yolo_images_train, self.yolo_labels_train, "training")
        val_slices, val_motors = self._process_tomogram_set(
            val_tomos, self.yolo_images_val, self.yolo_labels_val, "validation")

        # Create YAML configuration file for YOLO
        yaml_content = {
            'path': self.yolo_dataset_dir,
            'train': 'images/train',
            'val': 'images/val',
            'names': {0: 'motor'}
        }

        yaml_path = os.path.join(self.yolo_dataset_dir, 'dataset.yaml')
        with open(yaml_path, 'w') as f:
            yaml.dump(yaml_content, f, default_flow_style=False)

        print(f"\nProcessing Summary:")
        print(f"- Train set: {len(train_tomos)} tomograms, {train_motors} motors, {train_slices} slices")
        print(f"- Validation set: {len(val_tomos)} tomograms, {val_motors} motors, {val_slices} slices")
        print(f"- Total: {len(train_tomos) + len(val_tomos)} tomograms, {train_motors + val_motors} motors, {train_slices + val_slices} slices")

        return {
            "dataset_dir": self.yolo_dataset_dir,
            "yaml_path": yaml_path,
            "train_tomograms": len(train_tomos),
            "val_tomograms": len(val_tomos),
            "train_motors": train_motors,
            "val_motors": val_motors,
            "train_slices": train_slices,
            "val_slices": val_slices
        }

    def _process_tomogram_set(self, tomogram_ids, images_dir, labels_dir, set_name):
        """
        Process a set of tomograms to extract slices and create annotations.

        For every motor, each slice within ``context_range`` of its z is
        normalized and saved. Its label file lists that motor plus every other
        motor of the same tomogram whose own context window covers the slice,
        so a second motor on the slice is not treated as background.

        Args:
            tomogram_ids (list): List of tomogram IDs to process
            images_dir (str): Directory to save extracted images
            labels_dir (str): Directory to save labels
            set_name (str): Name of the dataset (for logging)

        Returns:
            tuple: (Number of processed slices, Number of motors)
        """
        motor_counts = []
        for tomo_id in tomogram_ids:
            tomo_motors = self.labels_df[self.labels_df['tomo_id'] == tomo_id]
            for _, motor in tomo_motors.iterrows():
                if pd.isna(motor['Motor axis 0']):
                    continue
                motor_counts.append(
                    (tomo_id,
                     int(motor['Motor axis 0']),
                     int(motor['Motor axis 1']),
                     int(motor['Motor axis 2']),
                     int(motor['Array shape (axis 0)']))
                )

        # Distinct motor centres per tomogram (dict keys keep insertion order)
        motors_by_tomo = {}
        for tomo_id, mz, my, mx, _ in motor_counts:
            motors_by_tomo.setdefault(tomo_id, {})[(mz, my, mx)] = None

        print(f"Will process approximately {len(motor_counts) * (2 * self.context_range + 1)} slices for {set_name}")

        processed_slices = 0

        for tomo_id, z_center, y_center, x_center, depth in tqdm(motor_counts, desc=f"Processing {set_name} motors"):
            z_min = max(0, z_center - self.context_range)
            z_max = min(depth - 1, z_center + self.context_range)
            current = (z_center, y_center, x_center)

            for z in range(z_min, z_max + 1):
                slice_filename = f"slice_{z:04d}.jpg"
                src_path = os.path.join(self.train_dir, tomo_id, slice_filename)

                if not os.path.exists(src_path):
                    print(f"Warning: {src_path} does not exist, skipping.")
                    continue

                # The context manager closes the file handle PIL keeps open
                with Image.open(src_path) as img:
                    img_width, img_height = img.size
                    img_array = np.array(img)

                normalized_img = self.normalize_slice(img_array)

                # Destination filename is unique per (tomogram, slice, motor)
                dest_filename = f"{tomo_id}_z{z:04d}_y{y_center:04d}_x{x_center:04d}.jpg"
                dest_path = os.path.join(images_dir, dest_filename)
                Image.fromarray(normalized_img).save(dest_path)

                # Triggering motor first, then other motors visible on this slice
                boxes = [(x_center, y_center)]
                boxes.extend(
                    (mx, my)
                    for mz, my, mx in motors_by_tomo[tomo_id]
                    if (mz, my, mx) != current and abs(mz - z) <= self.context_range
                )
                label_lines = [
                    self._yolo_label_line(bx, by, self.box_size, img_width, img_height)
                    for bx, by in boxes
                ]

                label_path = os.path.join(labels_dir, os.path.splitext(dest_filename)[0] + '.txt')
                with open(label_path, 'w') as f:
                    f.writelines(label_lines)

                processed_slices += 1

        return processed_slices, len(motor_counts)

# Create and run the tomogram processor.
# Output is written under WORKING_DIR/yolo_dataset (images/, labels/ and dataset.yaml).
processor = TomogramProcessor(
    train_dir=TRAIN_DIR,
    labels_df=train_labels,
    working_dir=WORKING_DIR,
    context_range=5,  # Include 5 slices above and below each motor
    box_size=24,      # 24x24 bounding box for motors
    test_split=0.2,   # 20% of tomograms for validation
    norm_method='adaptive'  # Use adaptive normalization
)

# Prepare the dataset: split motor-bearing tomograms into train/val, write
# slices + labels, and keep the returned summary dict (includes 'yaml_path').
dataset_summary = processor.prepare_dataset()

## 3. Model Training with Hyperparameter Optimization

We implement a training pipeline with hyperparameter optimization for our YOLOv8 model.

In [None]:
class YOLOTrainer:
    """
    YOLOv8 training class with hyperparameter optimization
    """
    def __init__(self, dataset_yaml, working_dir):
        """
        Initialize the trainer
        
        Args:
            dataset_yaml (str): Path to YOLO dataset YAML file
            working_dir (str): Working directory for outputs
        """
        self.dataset_yaml = dataset_yaml
        self.working_dir = working_dir
        self.weights_dir = os.path.join(working_dir, "yolo_weights")
        self.model_name = "motor_detector"
        
        # Create weights directory
        os.makedirs(self.weights_dir, exist_ok=True)
    
    def train_model(self, pretrained_weights, epochs=30, batch_size=16, img_size=640,
                   optimizer='AdamW', lr=1e-4, dropout=0.1, patience=5, box=7.5, cls=0.5, dfl=1.5):
        """
        Fine-tune a YOLO detector on ``self.dataset_yaml`` and plot its DFL loss.

        Schedule and augmentation settings not exposed as arguments are fixed
        below. The run is stored in ``<weights_dir>/<model_name>`` (an existing
        run of that name is reused via ``exist_ok=True``). Afterwards
        ``plot_dfl_loss_curve`` is called on that directory.

        Args:
            pretrained_weights (str): Path to pretrained weights file
            epochs (int): Number of training epochs
            batch_size (int): Batch size for training
            img_size (int): Input image size
            optimizer (str): Optimizer to use
            lr (float): Initial learning rate (``lr0``)
            dropout (float): Dropout rate forwarded to ultralytics
                (NOTE(review): ultralytics documents dropout as classification-only)
            patience (int): Early stopping patience
            box (float): Box loss gain
            cls (float): Class loss gain
            dfl (float): DFL loss gain

        Returns:
            tuple: (YOLO model after training, value returned by ``model.train``)
        """
        banner = [
            "Training YOLO model with:",
            f"- Pretrained weights: {pretrained_weights}",
            f"- Epochs: {epochs}",
            f"- Batch size: {batch_size}",
            f"- Image size: {img_size}",
            f"- Optimizer: {optimizer}, LR: {lr}",
            f"- Loss gains: box={box}, cls={cls}, dfl={dfl}",
        ]
        print("\n".join(banner))

        train_args = dict(
            data=self.dataset_yaml,
            epochs=epochs,
            batch=batch_size,
            imgsz=img_size,
            # Optimizer and learning-rate schedule
            optimizer=optimizer,
            lr0=lr,
            lrf=0.01,
            momentum=0.937,
            weight_decay=0.0005,
            warmup_epochs=3.0,
            warmup_momentum=0.8,
            warmup_bias_lr=0.1,
            # Loss gains
            box=box,
            cls=cls,
            dfl=dfl,
            # Augmentation
            hsv_h=0.015,
            hsv_s=0.7,
            hsv_v=0.4,
            degrees=45.0,
            translate=0.1,
            scale=0.5,
            fliplr=0.5,
            mosaic=1.0,
            mixup=0.1,
            dropout=dropout,
            # Output / run control
            project=self.weights_dir,
            name=self.model_name,
            exist_ok=True,
            patience=patience,
            save_period=5,
            verbose=True,
        )

        model = YOLO(pretrained_weights)
        results = model.train(**train_args)

        self.plot_dfl_loss_curve(os.path.join(self.weights_dir, self.model_name))

        return model, results
    
    def plot_dfl_loss_curve(self, run_dir):
        """
        Plot train/validation DFL loss of a run and save it as ``dfl_loss_curve.png``.

        Args:
            run_dir (str): Directory where the training results are stored
                (``results.csv`` is read from it and the plot is saved there)

        Returns:
            tuple: ``(best_index, best_val_loss)``. ``best_index`` is the
            ``results.csv`` row label with the lowest validation DFL loss; the
            epoch number itself is that row's 'epoch' value. Returns
            ``(None, None)`` when the CSV or the DFL columns are missing, or
            when no validation loss was recorded.
        """
        results_csv = os.path.join(run_dir, 'results.csv')

        if not os.path.exists(results_csv):
            print(f"Results file not found at {results_csv}")
            return None, None

        results_df = pd.read_csv(results_csv)
        # Headers may be space-padded (presumably why loss columns are matched
        # by substring). Strip them so 'epoch' can also be looked up by name.
        results_df.columns = results_df.columns.str.strip()

        train_dfl_col = next((c for c in results_df.columns if 'train/dfl_loss' in c), None)
        val_dfl_col = next((c for c in results_df.columns if 'val/dfl_loss' in c), None)

        if train_dfl_col is None or val_dfl_col is None:
            print("DFL loss columns not found in results CSV")
            print(f"Available columns: {results_df.columns.tolist()}")
            return None, None

        val_losses = results_df[val_dfl_col]
        if val_losses.isna().all():
            # Empty or all-NaN column: idxmin would fail or return NaN
            print("No validation DFL loss values recorded")
            return None, None

        best_epoch = val_losses.idxmin()
        best_val_loss = results_df.loc[best_epoch, val_dfl_col]

        plt.figure(figsize=(10, 6))
        plt.plot(results_df['epoch'], results_df[train_dfl_col], label='Train DFL Loss')
        plt.plot(results_df['epoch'], val_losses, label='Validation DFL Loss')

        # Mark the best model with a vertical line
        plt.axvline(x=results_df.loc[best_epoch, 'epoch'], color='r', linestyle='--',
                    label=f'Best Model (Epoch {int(results_df.loc[best_epoch, "epoch"])}, Val Loss: {best_val_loss:.4f})')

        plt.xlabel('Epoch')
        plt.ylabel('DFL Loss')
        plt.title('Training and Validation DFL Loss')
        plt.legend()
        plt.grid(True, linestyle='--', alpha=0.7)

        # Save (not shown inline) and release the figure
        plot_path = os.path.join(run_dir, 'dfl_loss_curve.png')
        plt.savefig(plot_path)
        plt.close()

        print(f"Loss curve saved to {plot_path}")

        return best_epoch, best_val_loss
    
    def optimize_hyperparameters(self, n_trials=10, max_epochs=15, final_epochs=30, final_patience=5):
        """
        Search YOLO hyperparameters with Optuna, then train a final model.

        Each trial trains for at most ``max_epochs`` epochs (early-stopping
        patience 3). It is scored by its lowest validation DFL loss divided by
        the trial's ``dfl_gain``. The division is needed because ultralytics
        logs each loss term already multiplied by its gain; without it the
        search would simply favour small ``dfl_gain`` values. Failed trials
        score ``inf``.

        Args:
            n_trials (int): Number of optimization trials
            max_epochs (int): Maximum epochs per trial
            final_epochs (int): Training epochs for the final model
            final_patience (int): Early-stopping patience for the final model

        Returns:
            tuple: (Best hyperparameters dict, final trained YOLO model)
        """
        print(f"Starting hyperparameter optimization with {n_trials} trials")

        def objective(trial):
            # Define hyperparameters to optimize
            pretrained_model = trial.suggest_categorical("pretrained_model",
                                                         ["yolov8n.pt", "yolov8s.pt", "yolov8m.pt"])
            batch_size = trial.suggest_categorical("batch_size", [8, 16, 32])
            img_size = trial.suggest_categorical("img_size", [640, 800, 960])
            optimizer = trial.suggest_categorical("optimizer", ["AdamW", "SGD"])
            lr = trial.suggest_float("lr", 1e-5, 1e-3, log=True)
            # NOTE(review): ultralytics documents dropout as classification-only,
            # so this dimension may have no effect on detection training.
            dropout = trial.suggest_float("dropout", 0.0, 0.3)
            box_gain = trial.suggest_float("box_gain", 5.0, 10.0)
            cls_gain = trial.suggest_float("cls_gain", 0.3, 1.0)
            dfl_gain = trial.suggest_float("dfl_gain", 1.0, 2.0)

            trial_name = f"{self.model_name}_trial_{trial.number}"

            try:
                model = YOLO(pretrained_model)
                model.train(
                    data=self.dataset_yaml,
                    epochs=max_epochs,
                    batch=batch_size,
                    imgsz=img_size,
                    optimizer=optimizer,
                    lr0=lr,
                    lrf=0.01,
                    box=box_gain,
                    cls=cls_gain,
                    dfl=dfl_gain,
                    dropout=dropout,
                    project=self.weights_dir,
                    name=trial_name,
                    exist_ok=True,
                    patience=3,  # Early stopping for trials
                    verbose=False
                )

                results_csv = os.path.join(self.weights_dir, trial_name, 'results.csv')
                if os.path.exists(results_csv):
                    results_df = pd.read_csv(results_csv)
                    val_dfl_cols = [col for col in results_df.columns if 'val/dfl_loss' in col]
                    if val_dfl_cols:
                        best_val_loss = results_df[val_dfl_cols[0]].min()
                        if pd.notna(best_val_loss):
                            # Remove the gain so trials are compared on the same scale
                            return float(best_val_loss) / dfl_gain

                # If something went wrong, return a high loss
                return float('inf')

            except Exception as e:
                print(f"Error in trial {trial.number}: {e}")
                return float('inf')

        study = optuna.create_study(direction="minimize")
        study.optimize(objective, n_trials=n_trials)

        if not np.isfinite(study.best_value):
            print("Warning: no trial produced a finite validation loss; "
                  "the hyperparameters below were not validated.")

        best_params = study.best_params
        print("Best hyperparameters found:")
        for param, value in best_params.items():
            print(f"- {param}: {value}")

        print("\nTraining final model with best hyperparameters...")
        final_model, _ = self.train_model(
            pretrained_weights=best_params["pretrained_model"],
            batch_size=best_params["batch_size"],
            img_size=best_params["img_size"],
            optimizer=best_params["optimizer"],
            lr=best_params["lr"],
            dropout=best_params["dropout"],
            box=best_params["box_gain"],
            cls=best_params["cls_gain"],
            dfl=best_params["dfl_gain"],
            epochs=final_epochs,  # More epochs than the search trials
            patience=final_patience
        )

        return best_params, final_model
    
    def predict_on_samples(self, model, num_samples=4):
        """
        Run predictions on random validation samples and display results

        Images are drawn at random from ``images/val`` next to the dataset YAML.
        Each panel shows the YOLO boxes predicted at conf >= 0.25 (red). When the
        filename encodes a y/x position, it also shows a fixed 24-px ground-truth
        box (green). The grid has two columns and as many rows as the sample count
        needs. The figure is saved to ``<weights_dir>/sample_predictions.png`` and
        then shown.

        Args:
            model: Trained YOLO model
            num_samples (int): Number of random samples to display
        """
        # Get validation images directory
        val_dir = os.path.join(os.path.dirname(self.dataset_yaml), 'images', 'val')

        if not os.path.exists(val_dir):
            print(f"Validation directory not found at {val_dir}")
            return

        # Get all validation images
        val_images = [f for f in os.listdir(val_dir) if f.endswith(('.jpg', '.jpeg', '.png'))]

        if not val_images:
            print("No validation images found.")
            return

        # Clamp so that a zero/negative request does not crash random.sample
        n_show = max(0, min(num_samples, len(val_images)))
        if n_show == 0:
            print("No samples requested.")
            return
        samples = random.sample(val_images, n_show)

        # Size the grid to the samples instead of a fixed 2x2 (which silently
        # dropped anything past the 4th sample); 6 inches per panel, so 4 -> 12x12
        ncols = min(2, n_show)
        nrows = (n_show + ncols - 1) // ncols
        fig, axes = plt.subplots(nrows, ncols, figsize=(6 * ncols, 6 * nrows), squeeze=False)
        axes = axes.flatten()

        for ax, img_file in zip(axes, samples):
            img_path = os.path.join(val_dir, img_file)

            # Run prediction
            results = model.predict(img_path, conf=0.25)[0]

            # Copy the pixels out and release the file handle immediately
            with Image.open(img_path) as img:
                img_array = np.array(img)
            ax.imshow(img_array, cmap='gray')

            # Parse ground truth from filename
            # Format: tomo_id_zXXXX_yYYYY_xZZZZ.jpg
            y_coord = x_coord = None
            for part in img_file.split('_'):
                token = part.split('.')[0]  # drop the extension on the last part
                try:
                    if token.startswith('y'):
                        y_coord = int(token[1:])
                    elif token.startswith('x'):
                        x_coord = int(token[1:])
                except ValueError:
                    pass  # a part starting with x/y that is not a coordinate

            # Draw ground truth box if coordinates found
            if y_coord is not None and x_coord is not None:
                box_size = 24  # display-only; the true motor extent is not encoded
                half = box_size // 2
                ax.add_patch(plt.Rectangle((x_coord - half, y_coord - half),
                                           box_size, box_size,
                                           linewidth=1, edgecolor='g', facecolor='none'))
                ax.text(x_coord - half, y_coord - half - 5,
                        "Ground Truth", color='green', fontsize=8)

            # Draw predictions
            if len(results.boxes) > 0:
                boxes = results.boxes.xyxy.cpu().numpy()
                confs = results.boxes.conf.cpu().numpy()

                for (x1, y1, x2, y2), conf in zip(boxes, confs):
                    ax.add_patch(plt.Rectangle((x1, y1), x2 - x1, y2 - y1,
                                               linewidth=1, edgecolor='r', facecolor='none'))
                    ax.text(x1, y1 - 5, f'Conf: {conf:.2f}', color='red', fontsize=8)

            ax.set_title(f"Image: {img_file}")
            ax.axis('on')
            ax.set_xticks([])
            ax.set_yticks([])

        # Hide the spare panel left over for odd sample counts
        for ax in axes[n_show:]:
            ax.axis('off')

        plt.tight_layout()
        plt.savefig(os.path.join(self.weights_dir, 'sample_predictions.png'))
        plt.show()
        plt.close(fig)

# Build a trainer around the dataset YAML recorded in dataset_summary; outputs go under WORKING_DIR
trainer = YOLOTrainer(
    dataset_yaml=dataset_summary['yaml_path'],
    working_dir=WORKING_DIR,
)

# Optuna search over 5 trials (each capped at 10 epochs), followed by a final
# retrain using the best hyperparameters; returns (best_params, final model)
print("Starting YOLO model training with hyperparameter optimization...")
best_params, yolo_model = trainer.optimize_hyperparameters(
    n_trials=5,
    max_epochs=10,
)

# Visual sanity check: overlay predictions on a few random validation images
trainer.predict_on_samples(yolo_model, num_samples=4)

## 4. 3D CNN Implementation

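Here we implement a complementary 3D CNN that better captures the volumetric nature of the data. It classifies cubic subvolumes cut around YOLO candidates, using 3D context that per-slice 2D detection cannot see.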

In [None]:
class AttentionBlock(torch.nn.Module):
    """
    3D self-attention block for focusing on motor-specific features

    Query/key/value come from 1x1x1 convolutions. Every voxel attends to every
    other voxel of the feature map, and the result is added back to the input
    through a learnable gate ``gamma``. The attention map is (D*H*W) x (D*H*W)
    per sample, so only apply this block to heavily downsampled features.
    """
    def __init__(self, in_channels, reduction=8):
        """
        Initialize the attention block

        Args:
            in_channels (int): Number of input channels
            reduction (int): Channel reduction factor for the query/key
                projections. The projected width is clamped to at least 1.
                Without the clamp, blocks with fewer than ``reduction`` channels
                get zero-width projections, which silently degrade to uniform
                (mean-pooling) attention.
        """
        super().__init__()
        qk_channels = max(1, in_channels // reduction)
        self.query = torch.nn.Conv3d(in_channels, qk_channels, kernel_size=1)
        self.key = torch.nn.Conv3d(in_channels, qk_channels, kernel_size=1)
        self.value = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1)
        # gamma starts at 0, so the block is an exact identity at initialization
        self.gamma = torch.nn.Parameter(torch.zeros(1))
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, x):
        """
        Forward pass through the attention block

        Args:
            x (torch.Tensor): Input of shape (batch, in_channels, D, H, W)

        Returns:
            torch.Tensor: Tensor of the same shape, ``gamma * attended + x``
        """
        batch_size, channels, depth, height, width = x.size()
        n_voxels = depth * height * width

        # Flatten the spatial grid into n_voxels tokens
        query = self.query(x).view(batch_size, -1, n_voxels).permute(0, 2, 1)  # (B, N, C')
        key = self.key(x).view(batch_size, -1, n_voxels)                        # (B, C', N)
        value = self.value(x).view(batch_size, -1, n_voxels)                    # (B, C, N)

        # (B, N, N) map; softmax over the last dim makes each row a distribution
        attention = self.softmax(torch.bmm(query, key))

        # Each output voxel is the attention-weighted sum of the value vectors
        out = torch.bmm(value, attention.permute(0, 2, 1))
        out = out.view(batch_size, channels, depth, height, width)

        # Residual connection with learnable gate
        return self.gamma * out + x

class Motor3DCNN(torch.nn.Module):
    """
    Candidate-validation network for tomogram subvolumes.

    Architecture:
    * four Conv3d -> BatchNorm3d -> ReLU stages (the last three use stride 2);
    * a self-attention block and global average pooling;
    * a two-layer MLP trunk with dropout, which branches into two heads:
      a sigmoid presence score of shape (batch, 1) and an unbounded
      3-value coordinate regression of shape (batch, 3).
    """
    def __init__(self, input_channels=1, dropout_rate=0.3):
        """
        Build the layers.

        Args:
            input_channels (int): Channels of the input volume
            dropout_rate (float): Dropout probability used after fc1 and after fc2
        """
        super().__init__()
        nn = torch.nn

        # Keep these attribute names and their creation order: they are the
        # state_dict keys of saved checkpoints and fix the weight-init sequence.
        self.conv1 = nn.Conv3d(input_channels, 16, kernel_size=3, padding=1)
        self.bn1 = nn.BatchNorm3d(16)
        self.conv2 = nn.Conv3d(16, 32, kernel_size=3, padding=1, stride=2)
        self.bn2 = nn.BatchNorm3d(32)
        self.conv3 = nn.Conv3d(32, 64, kernel_size=3, padding=1, stride=2)
        self.bn3 = nn.BatchNorm3d(64)
        self.conv4 = nn.Conv3d(64, 128, kernel_size=3, padding=1, stride=2)
        self.bn4 = nn.BatchNorm3d(128)

        self.attention = AttentionBlock(128)
        self.global_avg_pool = nn.AdaptiveAvgPool3d(1)

        # Shared trunk, then the presence head (fc3) and the coordinate head (reg)
        self.fc1 = nn.Linear(128, 64)
        self.fc2 = nn.Linear(64, 32)
        self.fc3 = nn.Linear(32, 1)
        self.reg = nn.Linear(32, 3)

        # Parameter-free helpers
        self.relu = nn.ReLU()
        self.dropout = nn.Dropout(dropout_rate)
        self.sigmoid = nn.Sigmoid()

    def forward(self, x):
        """
        Run the network on a batch of subvolumes.

        Args:
            x (torch.Tensor): Input batch; assumes (batch, input_channels, D, H, W)

        Returns:
            tuple: (presence probabilities, shape (batch, 1);
                    coordinate regression, shape (batch, 3))
        """
        stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        features = x
        for conv, norm in stages:
            features = self.relu(norm(conv(features)))

        features = self.attention(features)
        pooled = torch.flatten(self.global_avg_pool(features), start_dim=1)

        hidden = self.dropout(self.relu(self.fc1(pooled)))
        hidden = self.dropout(self.relu(self.fc2(hidden)))

        presence = self.sigmoid(self.fc3(hidden))
        location = self.reg(hidden)
        return presence, location
    
class Tomogram3DDataset(torch.utils.data.Dataset):
    """
    Dataset for 3D subvolumes extracted from tomograms

    Each tomogram is read from ``<tomogram_dir>/<tomo_id>/slice_XXXX.jpg``, with
    one image per z index. A subvolume is a cube of ``subvolume_size`` voxels
    spanning ``[c - size//2, c - size//2 + size)`` on every axis around a center
    c. It is built by stacking per-slice percentile-normalized crops. Voxels
    outside the tomogram, or on missing slices, stay zero.
    """
    def __init__(self, tomogram_dir, labels_df, subvolume_size=64, transform=None, train=True):
        """
        Initialize the dataset

        Args:
            tomogram_dir (str): Directory containing tomogram slice folders
            labels_df (pd.DataFrame): DataFrame with motor annotations
            subvolume_size (int): Size of cubic subvolume to extract
            transform (callable, optional): Optional transform to apply to subvolumes
            train (bool): Whether this is for training or inference
        """
        self.tomogram_dir = tomogram_dir
        self.labels_df = labels_df
        self.subvolume_size = subvolume_size
        self.transform = transform
        self.train = train
        # tomo_id -> (z, y, x) shape from labels_df; avoids re-filtering the
        # DataFrame on every __getitem__ call
        self._shape_cache = {}

        if train:
            self.motor_data = []
            self.no_motor_data = []

            # Positive samples: one per annotated motor. Rows whose z coordinate
            # is missing or non-positive are treated as "no motor".
            for _, row in labels_df.iterrows():
                if pd.isna(row['Motor axis 0']) or row['Motor axis 0'] <= 0:
                    continue
                self.motor_data.append({
                    'tomo_id': row['tomo_id'],
                    'center': (int(row['Motor axis 0']),
                               int(row['Motor axis 1']),
                               int(row['Motor axis 2'])),
                    'label': 1.0,
                })

            # Negative samples: 3 random centers per tomogram without motors
            for tomo_id in labels_df['tomo_id'].unique():
                tomo_rows = labels_df[labels_df['tomo_id'] == tomo_id]
                if tomo_rows['Number of motors'].iloc[0] != 0:
                    continue
                z_max, y_max, x_max = self._tomo_shape(tomo_id)
                for _ in range(3):
                    center = (self._random_center(z_max),
                              self._random_center(y_max),
                              self._random_center(x_max))
                    self.no_motor_data.append({
                        'tomo_id': tomo_id,
                        'center': center,
                        'label': 0.0,
                    })

            # Balance the dataset
            if len(self.no_motor_data) > len(self.motor_data):
                self.no_motor_data = random.sample(self.no_motor_data, len(self.motor_data))

            # Combine motor and no-motor data
            self.samples = self.motor_data + self.no_motor_data
            random.shuffle(self.samples)
        else:
            # For inference, we'll provide this dynamically
            self.samples = []

    def __len__(self):
        """
        Get the number of samples in the dataset

        Returns:
            int: Number of samples
        """
        return len(self.samples)

    def __getitem__(self, idx):
        """
        Get a sample from the dataset

        Args:
            idx (int): Index of the sample

        Returns:
            tuple: (subvolume of shape (1, S, S, S) float32,
                    motor presence of shape (1,),
                    normalized coordinates of shape (3,), zeros for negatives)
        """
        sample = self.samples[idx]
        tomo_id = sample['tomo_id']
        z, y, x = sample['center']

        subvolume = self._extract_subvolume(tomo_id, z, y, x)

        # O(1) label lookup. Samples added without a 'label' key (e.g. by hand
        # for inference, where motor_data does not exist) fall back to a
        # membership test.
        is_motor = sample.get('label')
        if is_motor is None:
            is_motor = 1.0 if sample in getattr(self, 'motor_data', ()) else 0.0

        # Normalize coordinates to [0, 1] for regression.
        # NOTE(review): positive crops are centered on the motor itself, so these
        # absolute tomogram-normalized coordinates cannot be inferred from the
        # crop contents. Consider regressing an in-crop offset instead.
        if is_motor:
            z_max, y_max, x_max = self._tomo_shape(tomo_id)
            coords = torch.tensor([z / z_max, y / y_max, x / x_max], dtype=torch.float32)
        else:
            coords = torch.zeros(3, dtype=torch.float32)

        # Apply transform if provided
        if self.transform:
            subvolume = self.transform(subvolume)

        # Add channel dimension
        subvolume = torch.as_tensor(subvolume, dtype=torch.float32).unsqueeze(0)

        return subvolume, torch.tensor([is_motor]).float(), coords

    def extract_candidate_subvolumes(self, tomo_id, slice_indices, confidence_map):
        """
        Extract candidate subvolumes based on 2D YOLO detections

        Candidates whose cube would cross the tomogram boundary are skipped.

        Args:
            tomo_id (str): Tomogram ID
            slice_indices (list): List of slice indices
            confidence_map (dict): Dictionary mapping slice index to list of (y, x, conf) detections

        Returns:
            list: List of (subvolume, z, y, x, conf) tuples, where subvolume is a
            (1, S, S, S) float32 tensor
        """
        candidates = []

        shape = self._tomo_shape(tomo_id)
        if shape is None:
            # Test tomograms have no label rows: derive dimensions from the images
            tomo_dir = os.path.join(self.tomogram_dir, tomo_id)
            slice_files = sorted(f for f in os.listdir(tomo_dir) if f.endswith('.jpg'))
            if not slice_files:
                return candidates
            z_max = len(slice_files)
            with Image.open(os.path.join(tomo_dir, slice_files[0])) as im:
                y_max, x_max = np.array(im).shape
        else:
            z_max, y_max, x_max = shape

        before = self.subvolume_size // 2
        after = self.subvolume_size - before
        for z in slice_indices:
            for y, x, conf in confidence_map.get(z, ()):
                # Skip if too close to edge
                if (z < before or z >= z_max - after or
                        y < before or y >= y_max - after or
                        x < before or x >= x_max - after):
                    continue
                subvolume = self._extract_subvolume(tomo_id, z, y, x)
                candidates.append((torch.as_tensor(subvolume).unsqueeze(0).float(), z, y, x, conf))

        return candidates

    def _tomo_shape(self, tomo_id):
        """Return the memoized (z, y, x) shape from labels_df, or None if tomo_id has no rows."""
        if tomo_id not in self._shape_cache:
            rows = self.labels_df[self.labels_df['tomo_id'] == tomo_id]
            if rows.empty:
                return None
            self._shape_cache[tomo_id] = (rows['Array shape (axis 0)'].iloc[0],
                                          rows['Array shape (axis 1)'].iloc[0],
                                          rows['Array shape (axis 2)'].iloc[0])
        return self._shape_cache[tomo_id]

    def _random_center(self, dim):
        """Draw a center on one axis so that the cube stays inside [0, dim) when possible."""
        before = self.subvolume_size // 2
        after = self.subvolume_size - before
        if dim - after > before:
            return np.random.randint(before, dim - after)
        # Axis no longer than the cube: center it and let zero-padding fill the rest
        # (np.random.randint would raise ValueError here)
        return int(dim) // 2

    def _load_normalized_slice(self, tomo_id, z_pos):
        """
        Load one JPEG slice and rescale its 2nd-98th percentile range to [0, 1].

        Returns None if the slice file does not exist. A slice whose 2nd and
        98th percentiles coincide maps to zeros instead of NaN (0/0).
        """
        slice_path = os.path.join(self.tomogram_dir, tomo_id, f"slice_{z_pos:04d}.jpg")
        if not os.path.exists(slice_path):
            return None
        with Image.open(slice_path) as im:
            img = np.array(im)
        p2, p98 = np.percentile(img, [2, 98])
        span = p98 - p2
        if span <= 0:
            return np.zeros(img.shape, dtype=np.float64)
        return (np.clip(img, p2, p98) - p2) / span

    def _extract_subvolume(self, tomo_id, z, y, x):
        """Build the zero-padded (S, S, S) float32 cube centered at (z, y, x)."""
        size = self.subvolume_size
        before = size // 2
        z, y, x = int(z), int(y), int(x)
        # Cube origin in tomogram coordinates
        z0, y0, x0 = z - before, y - before, x - before
        subvolume = np.zeros((size, size, size), dtype=np.float32)

        for i, z_pos in enumerate(range(z0, z0 + size)):
            if z_pos < 0:
                continue
            normalized = self._load_normalized_slice(tomo_id, z_pos)
            if normalized is None:
                continue
            # Source window clipped to the slice bounds
            y_start, y_end = max(0, y0), min(normalized.shape[0], y0 + size)
            x_start, x_end = max(0, x0), min(normalized.shape[1], x0 + size)
            if y_end <= y_start or x_end <= x_start:
                continue
            # The same window expressed in cube coordinates
            subvolume[i, y_start - y0:y_end - y0, x_start - x0:x_end - x0] = \
                normalized[y_start:y_end, x_start:x_end]

        return subvolume
    
class Motor3DCNNTrainer:
    """
    Trainer for the 3D CNN model

    Training proceeds as follows:
    * build a ``Tomogram3DDataset`` from the labels and split it 80/20 into
      train/validation;
    * train a ``Motor3DCNN`` on BCE presence loss plus 0.5 x MSE coordinate
      loss, where the coordinate loss covers positive samples only;
    * return the weights from the epoch with the lowest monitored loss.
    """
    def __init__(self, train_dir, labels_df, working_dir, subvolume_size=64, batch_size=8, device='auto'):
        """
        Initialize the trainer

        Args:
            train_dir (str): Directory containing training tomograms
            labels_df (pd.DataFrame): DataFrame with motor annotations
            working_dir (str): Directory for saving models and results
            subvolume_size (int): Size of cubic subvolumes
            batch_size (int): Training batch size
            device (str): Device to use ('cpu', 'cuda', or 'auto')
        """
        self.train_dir = train_dir
        self.labels_df = labels_df
        self.working_dir = working_dir
        self.subvolume_size = subvolume_size
        self.batch_size = batch_size

        # Set device
        if device == 'auto':
            self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device

        print(f"Using device: {self.device} for 3D CNN")

        # Create model directory
        self.model_dir = os.path.join(working_dir, '3dcnn_models')
        os.makedirs(self.model_dir, exist_ok=True)

    def train(self, epochs=20, lr=1e-4, dropout=0.3):
        """
        Train the 3D CNN model

        Validation loss drives LR scheduling and checkpoint selection. If the
        split leaves no validation samples, training loss is used instead.

        Args:
            epochs (int): Number of training epochs
            lr (float): Learning rate
            dropout (float): Dropout rate

        Returns:
            torch.nn.Module: Trained model (best checkpoint of this run, or the
            final-epoch weights if no checkpoint was saved)

        Raises:
            ValueError: If the dataset contains no samples (no annotated motors)
        """
        # Create dataset
        dataset = Tomogram3DDataset(
            tomogram_dir=self.train_dir,
            labels_df=self.labels_df,
            subvolume_size=self.subvolume_size,
            train=True
        )
        if len(dataset) == 0:
            raise ValueError("3D CNN dataset is empty: labels_df contains no annotated motors")

        # Split dataset into train and validation
        train_size = int(0.8 * len(dataset))
        val_size = len(dataset) - train_size
        train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
        has_val = len(val_dataset) > 0

        # Create data loaders
        train_loader = torch.utils.data.DataLoader(
            train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=4
        )

        val_loader = torch.utils.data.DataLoader(
            val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=4
        )

        # Create model
        model = Motor3DCNN(input_channels=1, dropout_rate=dropout)
        model.to(self.device)

        # Define loss functions
        cls_criterion = torch.nn.BCELoss()
        reg_criterion = torch.nn.MSELoss()

        # Create optimizer
        optimizer = torch.optim.Adam(model.parameters(), lr=lr)

        # `verbose` is deprecated (PyTorch >= 2.2); the LR is printed each epoch instead
        scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            optimizer,
            mode='min',
            factor=0.5,
            patience=3,
        )

        # Training loop
        best_val_loss = float('inf')
        best_model_path = os.path.join(self.model_dir, '3dcnn_best.pt')
        saved_best = False  # never reload a stale checkpoint from an earlier run

        for epoch in range(epochs):
            train_loss, train_cls_loss, train_reg_loss, _ = self._run_epoch(
                model, train_loader, f"Epoch {epoch+1}/{epochs} (Train)",
                cls_criterion, reg_criterion, optimizer=optimizer,
            )
            val_loss, val_cls_loss, val_reg_loss, val_acc = self._run_epoch(
                model, val_loader, f"Epoch {epoch+1}/{epochs} (Val)",
                cls_criterion, reg_criterion,
            )

            monitored = val_loss if has_val else train_loss

            # Update learning rate
            scheduler.step(monitored)

            # Print progress
            print(f"Epoch {epoch+1}/{epochs}:")
            print(f"  Train Loss: {train_loss:.4f} (Cls: {train_cls_loss:.4f}, Reg: {train_reg_loss:.4f})")
            print(f"  Val Loss: {val_loss:.4f} (Cls: {val_cls_loss:.4f}, Reg: {val_reg_loss:.4f}), Acc: {val_acc:.4f}")
            print(f"  LR: {optimizer.param_groups[0]['lr']:.2e}")

            # Save best model (NaN never compares less, so diverged epochs are not saved)
            if monitored < best_val_loss:
                best_val_loss = monitored
                torch.save(model.state_dict(), best_model_path)
                saved_best = True
                print(f"  New best model saved to {best_model_path}")

        # Load best model
        if saved_best:
            model.load_state_dict(torch.load(best_model_path, map_location=self.device))
        else:
            print("  Warning: no checkpoint saved during this run; returning final-epoch weights")

        return model

    def _run_epoch(self, model, loader, desc, cls_criterion, reg_criterion, optimizer=None):
        """
        Run one pass over ``loader``: train when ``optimizer`` is given, otherwise evaluate.

        Returns:
            tuple: (total loss, cls loss, reg loss, accuracy). Each is averaged per
            sample over ``loader.dataset``; all are 0.0 for an empty loader.
        """
        training = optimizer is not None
        model.train(training)
        sum_loss = sum_cls = sum_reg = 0.0
        correct_preds = 0

        with torch.set_grad_enabled(training):
            for subvolumes, labels, coords in tqdm(loader, desc=desc):
                # Move data to device
                subvolumes = subvolumes.to(self.device)
                labels = labels.to(self.device)
                coords = coords.to(self.device)

                if training:
                    optimizer.zero_grad()
                cls_output, reg_output = model(subvolumes)

                cls_loss = cls_criterion(cls_output, labels)
                # Only positive samples carry meaningful coordinates
                pos_mask = (labels > 0.5).view(-1)
                if torch.sum(pos_mask) > 0:
                    reg_loss = reg_criterion(reg_output[pos_mask], coords[pos_mask])
                else:
                    reg_loss = torch.tensor(0.0, device=self.device)

                # Total loss (weighted sum)
                loss = cls_loss + 0.5 * reg_loss

                if training:
                    loss.backward()
                    optimizer.step()

                batch = subvolumes.size(0)
                sum_loss += loss.item() * batch
                sum_cls += cls_loss.item() * batch
                sum_reg += reg_loss.item() * batch
                correct_preds += torch.sum((cls_output > 0.5).float() == labels).item()

        n = len(loader.dataset)
        if n == 0:
            return 0.0, 0.0, 0.0, 0.0
        return sum_loss / n, sum_cls / n, sum_reg / n, correct_preds / n

# Phase 2: fit the 3D CNN that validates YOLO candidates using 3D context
print("\nTraining 3D CNN model...")
cnn3d_trainer = Motor3DCNNTrainer(
    TRAIN_DIR,       # train_dir
    train_labels,    # labels_df
    WORKING_DIR,     # working_dir
    subvolume_size=64,
    batch_size=4,
    device='auto',
)

# 20 epochs, lr 1e-4, dropout 0.3; the trainer returns its best-epoch weights
cnn3d_model = cnn3d_trainer.train(
    epochs=20,
    lr=1e-4,
    dropout=0.3,
)

## 5. Inference Pipeline for Test Data

This section builds an efficient inference pipeline that processes the test tomograms and generates the submission predictions.

In [None]:
class EnsembleMotorDetector:
    """
    Ensemble detector combining 2D YOLOv8 and 3D CNN for improved motor detection.

    Pipeline per tomogram:
      1. YOLO runs on the (optionally subsampled) 2D slices. Boxes at or above
         ``yolo_confidence_threshold`` become 3D candidates (slice number as z,
         box centre as y/x) and are merged with a distance-based 3D NMS.
      2. The top ``max_cnn_candidates`` candidates are cropped into
         ``subvolume_size`` cubes and scored by the 3D CNN; candidates scoring
         at least ``cnn_confidence_threshold`` are kept.
      3. The kept candidate with the highest mean (YOLO, CNN) confidence is
         reported. If none is kept, the best YOLO candidate is reported only
         when its confidence exceeds ``yolo_fallback_threshold``.
    At most one motor is reported per tomogram.
    """

    # Column order of the submission CSV.
    SUBMISSION_COLUMNS = ['tomo_id', 'Motor axis 0', 'Motor axis 1', 'Motor axis 2']

    def __init__(self, yolo_model_path, cnn3d_model_path, test_dir, submission_path, 
                 device='auto', yolo_confidence_threshold=0.30, cnn_confidence_threshold=0.45):
        """
        Initialize the detector with both YOLO and 3D CNN models
        
        Args:
            yolo_model_path (str): Path to trained YOLO model weights
            cnn3d_model_path (str): Path to trained 3D CNN model weights
            test_dir (str): Path to test tomograms directory (one sub-directory
                of ``slice_XXXX.jpg`` files per tomogram)
            submission_path (str): Path to save submission CSV
            device (str): Device to use ('cpu', 'cuda', or 'auto')
            yolo_confidence_threshold (float): Confidence threshold for YOLO detections
            cnn_confidence_threshold (float): Confidence threshold for 3D CNN detections
        """
        self.yolo_model_path = yolo_model_path
        self.cnn3d_model_path = cnn3d_model_path
        self.test_dir = test_dir
        self.submission_path = submission_path
        self.yolo_confidence_threshold = yolo_confidence_threshold
        self.cnn_confidence_threshold = cnn_confidence_threshold
        
        # Set device
        if device == 'auto':
            self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device
            
        print(f"Using device: {self.device} for ensemble detection")
        
        # Detection parameters
        self.nms_iou_threshold = 0.2  # scales the 3D NMS merge distance (see perform_3d_nms)
        self.batch_size = 8  # Default batch size, adjusted in load_models() on GPU
        self.concentration = 1.0  # Fraction of slices to process (1.0 = all)
        self.subvolume_size = 64  # Edge length of the 3D CNN input cube
        # Forwarded to every YOLO predict call as ``half=``; enabled by load_models().
        self.use_half = False
        # Number of top YOLO candidates (after NMS) scored by the 3D CNN.
        self.max_cnn_candidates = 10
        # YOLO confidence needed to report a motor when the CNN confirms nothing.
        self.yolo_fallback_threshold = 0.65

    @staticmethod
    def _no_motor_result(tomo_id):
        """Return the submission row meaning "no motor found" (all axes -1)."""
        return {'tomo_id': tomo_id, 'Motor axis 0': -1, 'Motor axis 1': -1, 'Motor axis 2': -1}

    @staticmethod
    def _motor_result(tomo_id, detection):
        """Build a submission row from a detection dict with 'z', 'y', 'x' keys."""
        return {
            'tomo_id': tomo_id,
            'Motor axis 0': round(detection['z']),
            'Motor axis 1': round(detection['y']),
            'Motor axis 2': round(detection['x']),
        }
    
    def load_models(self):
        """
        Load both YOLO and 3D CNN models
        
        On CUDA this also fuses YOLO layers, enables FP16 YOLO inference on
        compute capability >= 7, and sets ``self.batch_size`` from device memory.
        
        Returns:
            tuple: (YOLO model, 3D CNN model in eval mode)
        """
        print(f"Loading YOLO model from {self.yolo_model_path}")
        yolo_model = YOLO(self.yolo_model_path)
        yolo_model.to(self.device)
        
        if self.device.startswith('cuda'):
            yolo_model.fuse()
            
            # Casting ``yolo_model.model`` to half here would be undone: the
            # ultralytics predictor re-casts the network according to its own
            # ``half`` argument. FP16 is therefore requested per predict call.
            if torch.cuda.get_device_capability(0)[0] >= 7:  # Volta or newer
                self.use_half = True
                print("Using half precision (FP16) for YOLO inference")
                
            # Heuristic: total device memory minus what PyTorch has allocated in
            # this process (an upper bound, not the true free memory).
            free_mem = torch.cuda.get_device_properties(0).total_memory / 1e9 - torch.cuda.memory_allocated(0) / 1e9
            self.batch_size = max(4, min(32, int(free_mem * 4)))  # 4 images per GB as rough estimate
            print(f"Dynamic batch size set to {self.batch_size} based on available GPU memory")
        
        print(f"Loading 3D CNN model from {self.cnn3d_model_path}")
        cnn3d_model = Motor3DCNN(input_channels=1, dropout_rate=0.0)  # No dropout for inference
        cnn3d_model.load_state_dict(torch.load(self.cnn3d_model_path, map_location=self.device))
        cnn3d_model.to(self.device)
        cnn3d_model.eval()
        
        return yolo_model, cnn3d_model
    
    def preload_image_batch(self, file_paths):
        """
        Load a batch of images into CPU memory.
        
        Kept for API compatibility; the detection pipeline passes file paths
        directly to YOLO and does not call this method.
        
        Args:
            file_paths (list): List of file paths to load
            
        Returns:
            list: Loaded images (cv2 BGR arrays, or PIL-decoded arrays when
            cv2 cannot read a file)
        """
        images = []
        for path in file_paths:
            img = cv2.imread(path)
            if img is None:
                # Try with PIL as fallback
                img = np.array(Image.open(path))
            images.append(img)
        return images
    
    def perform_3d_nms(self, detections, iou_threshold):
        """
        Greedy distance-based 3D non-maximum suppression.
        
        Detections are visited in descending confidence. Each kept detection
        suppresses every remaining one whose Euclidean (z, y, x) distance is at
        most ``24 * iou_threshold``. Despite its name, ``iou_threshold`` is
        only a multiplier on the 24-pixel annotation box size, not an IoU.
        
        Args:
            detections (list): Dicts with 'z', 'y', 'x' and 'confidence' keys
            iou_threshold (float): Multiplier for the suppression distance
            
        Returns:
            list: Kept detections, ordered by descending confidence
        """
        if not detections:
            return []
        
        remaining = sorted(detections, key=lambda d: d['confidence'], reverse=True)
        
        def distance_3d(d1, d2):
            return np.sqrt((d1['z'] - d2['z'])**2 + 
                           (d1['y'] - d2['y'])**2 + 
                           (d1['x'] - d2['x'])**2)
        
        box_size = 24  # Same as annotation box size
        distance_threshold = box_size * iou_threshold
        
        kept = []
        while remaining:
            best_detection = remaining.pop(0)
            kept.append(best_detection)
            remaining = [d for d in remaining if distance_3d(d, best_detection) > distance_threshold]
        
        return kept

    @staticmethod
    def _normalize_slice(img):
        """
        Percentile-normalize a slice to [0, 1] and return it as float32.
        
        Intensities are clipped to the 2nd/98th percentiles and rescaled
        linearly. When both percentiles coincide (e.g. a constant slice) the
        result is all zeros, avoiding NaN/inf from a division by zero.
        """
        p2 = np.percentile(img, 2)
        p98 = np.percentile(img, 98)
        if p98 <= p2:
            return np.zeros(np.shape(img), dtype=np.float32)
        normalized = np.clip(img, p2, p98)
        normalized = (normalized - p2) / (p98 - p2)
        return normalized.astype(np.float32)
    
    def extract_subvolume(self, tomo_id, z, y, x):
        """
        Extract a subvolume centered at the given coordinates
        
        Slices ``z - half .. z + half - 1`` are read from
        ``<test_dir>/<tomo_id>/slice_XXXX.jpg``. Each is percentile-normalized
        and its in-plane window around (y, x) is copied into the cube. Parts
        outside the tomogram (negative or missing slices, image borders) stay 0.
        
        Args:
            tomo_id (str): Tomogram ID
            z (int): Z coordinate (slice number)
            y (int): Y coordinate (row)
            x (int): X coordinate (column)
            
        Returns:
            torch.Tensor: float32 tensor of shape [1, 1, D, H, W]
        """
        size = self.subvolume_size
        half_size = size // 2
        subvolume = np.zeros((size, size, size), dtype=np.float32)
        tomo_dir = os.path.join(self.test_dir, tomo_id)
        
        for i, z_pos in enumerate(range(z - half_size, z + half_size)):
            if z_pos < 0:
                continue
            slice_path = os.path.join(tomo_dir, f"slice_{z_pos:04d}.jpg")
            if not os.path.exists(slice_path):
                continue
                
            # NOTE(review): assumes single-channel (2D) slice images.
            img = np.array(Image.open(slice_path))
            normalized = self._normalize_slice(img)
            
            # In-plane source window, clamped to the image bounds
            y_start = max(0, y - half_size)
            y_end = min(img.shape[0], y + half_size)
            x_start = max(0, x - half_size)
            x_end = min(img.shape[1], x + half_size)
            
            # Offset inside the cube so the (y, x) centre lands at half_size
            target_y_start = max(0, half_size - (y - y_start))
            target_x_start = max(0, half_size - (x - x_start))
            
            height = min(y_end - y_start, size - target_y_start)
            width = min(x_end - x_start, size - target_x_start)
            if height <= 0 or width <= 0:
                continue
                
            subvolume[i,
                      target_y_start:target_y_start + height,
                      target_x_start:target_x_start + width] = normalized[y_start:y_start + height,
                                                                          x_start:x_start + width]
        
        # Add batch and channel dimensions: the 3D CNN expects [B, C, D, H, W]
        return torch.from_numpy(subvolume).unsqueeze(0).unsqueeze(0)

    def _detect_yolo_candidates(self, yolo_model, tomo_dir, slice_files):
        """
        Run YOLO over ``slice_files`` in batches and collect per-box detections.
        
        Each detection dict holds the slice number parsed from
        ``slice_XXXX.jpg`` as 'z', the rounded box centre as 'y'/'x', and the
        box confidence. Boxes below ``yolo_confidence_threshold`` are dropped.
        """
        detections = []
        for batch_start in range(0, len(slice_files), self.batch_size):
            batch_files = slice_files[batch_start:batch_start + self.batch_size]
            batch_paths = [os.path.join(tomo_dir, f) for f in batch_files]
            batch_z = [int(f.split('_')[1].split('.')[0]) for f in batch_files]
            
            results = yolo_model(batch_paths, verbose=False, half=self.use_half, device=self.device)
            
            for z, result in zip(batch_z, results):
                boxes = result.boxes
                if len(boxes) == 0:
                    continue
                # Cast to float32 so FP16 outputs do not lose precision in the centre sum
                confidences = boxes.conf.float().cpu().numpy()
                corners = boxes.xyxy.float().cpu().numpy()
                for (x1, y1, x2, y2), confidence in zip(corners, confidences):
                    if confidence < self.yolo_confidence_threshold:
                        continue
                    detections.append({
                        'z': z,
                        'y': round((y1 + y2) / 2),
                        'x': round((x1 + x2) / 2),
                        'confidence': float(confidence),
                        'model': 'yolo',
                    })
        return detections

    def _validate_with_cnn(self, tomo_id, tomo_dir, slice_files, num_slices, candidates, cnn3d_model):
        """
        Score the top candidates with the 3D CNN and return the confirmed ones.
        
        Candidates closer than ``subvolume_size // 2`` to any volume boundary
        are skipped. Candidates that fail to process are logged and skipped.
        ``cls_output.item()`` is compared directly to
        ``cnn_confidence_threshold``.
        NOTE(review): assumes the CNN returns a probability; confirm.
        """
        half_size = self.subvolume_size // 2
        # In-plane size read once per tomogram (loop-invariant)
        try:
            sample_path = os.path.join(tomo_dir, slice_files[0])
            height, width = np.array(Image.open(sample_path)).shape[:2]
        except Exception as e:
            print(f"Error reading slice shape for {tomo_id}: {e}")
            return []
        
        validated = []
        with torch.no_grad():
            for detection in candidates[:self.max_cnn_candidates]:
                z, y, x = detection['z'], detection['y'], detection['x']
                inside = (half_size <= z < num_slices - half_size and
                          half_size <= y < height - half_size and
                          half_size <= x < width - half_size)
                if not inside:
                    continue
                try:
                    subvolume = self.extract_subvolume(tomo_id, z, y, x).to(self.device)
                    cls_output, _ = cnn3d_model(subvolume)
                    cnn_confidence = cls_output.item()
                except Exception as e:
                    print(f"Error processing candidate at z={z}, y={y}, x={x}: {e}")
                    continue
                if cnn_confidence >= self.cnn_confidence_threshold:
                    validated.append({
                        'z': z,
                        'y': y,
                        'x': x,
                        'yolo_confidence': detection['confidence'],
                        'cnn_confidence': cnn_confidence,
                        'ensemble_confidence': (detection['confidence'] + cnn_confidence) / 2,
                        'model': 'ensemble',
                    })
        return validated
    
    def process_tomogram(self, tomo_id, yolo_model, cnn3d_model, index=0, total=1):
        """
        Process a single tomogram using both YOLO and 3D CNN models
        
        Args:
            tomo_id (str): Tomogram ID
            yolo_model: YOLO model
            cnn3d_model: 3D CNN model
            index (int): Current tomogram index (for progress output)
            total (int): Total number of tomograms (for progress output)
            
        Returns:
            dict: Submission row with 'tomo_id' and 'Motor axis 0/1/2'
            (all -1 when no motor is reported)
        """
        print(f"Processing tomogram {tomo_id} ({index}/{total})")
        
        tomo_dir = os.path.join(self.test_dir, tomo_id)
        all_slice_files = sorted(f for f in os.listdir(tomo_dir) if f.endswith('.jpg'))
        # Depth is taken before subsampling so the CNN edge check compares z
        # (an absolute slice number) against the real tomogram depth.
        num_slices = len(all_slice_files)
        
        slice_files = all_slice_files
        if self.concentration < 1.0:
            selected_indices = np.linspace(0, num_slices - 1, int(num_slices * self.concentration))
            selected_indices = np.round(selected_indices).astype(int)
            slice_files = [all_slice_files[i] for i in selected_indices]
            print(f"Processing {len(slice_files)} out of {len(os.listdir(tomo_dir))} slices")
        
        # STEP 1: YOLO on 2D slices, merged into 3D candidates
        print("STEP 1: Running YOLO detection on 2D slices...")
        yolo_detections = self._detect_yolo_candidates(yolo_model, tomo_dir, slice_files)
        yolo_detections = self.perform_3d_nms(yolo_detections, self.nms_iou_threshold)
        print(f"YOLO found {len(yolo_detections)} potential candidates")
        
        if not yolo_detections:
            return self._no_motor_result(tomo_id)
        
        # STEP 2: Validate candidates with the 3D CNN
        print("STEP 2: Validating candidates with 3D CNN...")
        validated_detections = self._validate_with_cnn(
            tomo_id, tomo_dir, all_slice_files, num_slices, yolo_detections, cnn3d_model)
        print(f"3D CNN validated {len(validated_detections)} candidates")
        
        # STEP 3: Final prediction
        if validated_detections:
            best_detection = max(validated_detections, key=lambda d: d['ensemble_confidence'])
            return self._motor_result(tomo_id, best_detection)
        
        # No CNN confirmation: fall back to YOLO only if it is very confident
        best_yolo = max(yolo_detections, key=lambda d: d['confidence'])
        if best_yolo['confidence'] > self.yolo_fallback_threshold:
            return self._motor_result(tomo_id, best_yolo)
        return self._no_motor_result(tomo_id)
    
    def generate_submission(self):
        """
        Process all test tomograms using the ensemble approach and generate submission
        
        Tomograms are processed sequentially. A tomogram whose processing raises
        is logged and recorded as "no motor". The CSV is written to
        ``submission_path``. Models are only loaded when there is at least one
        tomogram.
        
        Returns:
            pd.DataFrame: Submission dataframe with ``SUBMISSION_COLUMNS``
        """
        test_tomos = sorted(d for d in os.listdir(self.test_dir)
                            if os.path.isdir(os.path.join(self.test_dir, d)))
        total_tomos = len(test_tomos)
        print(f"Found {total_tomos} tomograms in test directory")
        
        results = []
        motors_found = 0
        
        if test_tomos:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            yolo_model, cnn3d_model = self.load_models()
            
            for i, tomo_id in enumerate(test_tomos, 1):
                # Release cached GPU memory between tomograms
                if torch.cuda.is_available():
                    torch.cuda.empty_cache()
                try:
                    result = self.process_tomogram(tomo_id, yolo_model, cnn3d_model, i, total_tomos)
                except Exception as e:
                    print(f"Error processing {tomo_id}: {e}")
                    results.append(self._no_motor_result(tomo_id))
                    continue
                
                results.append(result)
                if result['Motor axis 0'] != -1:
                    motors_found += 1
                    print(f"Motor found in {tomo_id} at position: "
                          f"z={result['Motor axis 0']}, y={result['Motor axis 1']}, x={result['Motor axis 2']}")
                else:
                    print(f"No motor detected in {tomo_id}")
                print(f"Current detection rate: {motors_found}/{len(results)} ({motors_found/len(results)*100:.1f}%)")
        
        # Explicit columns also give the right header when there are no rows
        submission_df = pd.DataFrame(results, columns=self.SUBMISSION_COLUMNS)
        submission_df.to_csv(self.submission_path, index=False)
        
        detection_rate = motors_found / total_tomos * 100 if total_tomos else 0.0
        print("\nSubmission complete!")
        print(f"Motors detected: {motors_found}/{total_tomos} ({detection_rate:.1f}%)")
        print(f"Submission saved to: {self.submission_path}")
        
        print("\nSubmission preview:")
        print(submission_df.head())
        
        return submission_df

# Trained weights produced by the YOLO and 3D CNN training phases
yolo_model_path = os.path.join(
    WORKING_DIR, "yolo_weights", "motor_detector", "weights", "best.pt"
)
cnn3d_model_path = os.path.join(WORKING_DIR, "3dcnn_models", "3dcnn_best.pt")

# Two-stage ensemble: a permissive YOLO threshold surfaces candidates,
# then the 3D CNN threshold filters them.
ensemble_detector = EnsembleMotorDetector(
    yolo_model_path,
    cnn3d_model_path,
    TEST_DIR,
    os.path.join(WORKING_DIR, "ensemble_submission.csv"),
    device='auto',
    yolo_confidence_threshold=0.30,
    cnn_confidence_threshold=0.45,
)

# Run inference over every test tomogram and write the submission CSV
ensemble_submission = ensemble_detector.generate_submission()

## 6. Results Visualization and Analysis

Visualize a few predictions to understand how the model is performing.

In [None]:
def visualize_detection(tomo_id, coordinates, test_dir, context_slices=2):
    """
    Visualize a motor detection with context slices.

    Shows slices ``z - context_slices .. z + context_slices`` (clamped at 0,
    and skipping slices whose JPEG does not exist). Each slice is
    percentile-normalized for display. The detected slice gets a red
    24-pixel box centred on (y, x).

    Args:
        tomo_id (str): Tomogram ID
        coordinates (dict): 'Motor axis 0/1/2' -> z/y/x; z == -1 means no motor
        test_dir (str): Test directory path
        context_slices (int): Number of context slices on each side of z
    """
    # Cast so values read back from CSV/DataFrames (possibly floats) work with range()
    z_center = int(coordinates['Motor axis 0'])
    y_center = int(coordinates['Motor axis 1'])
    x_center = int(coordinates['Motor axis 2'])
    
    if z_center == -1:
        print(f"No motor detected in tomogram {tomo_id}")
        return
    
    tomo_dir = os.path.join(test_dir, tomo_id)
    z_min = max(0, z_center - context_slices)
    z_max = z_center + context_slices
    
    slices = []
    for z in range(z_min, z_max + 1):
        slice_path = os.path.join(tomo_dir, f"slice_{z:04d}.jpg")
        if not os.path.exists(slice_path):
            continue
        img = np.array(Image.open(slice_path))
        p2, p98 = np.percentile(img, (2, 98))
        if p98 > p2:
            normalized = 255 * (np.clip(img, p2, p98) - p2) / (p98 - p2)
        else:
            # Flat slice: avoid a division by zero and show it as black
            normalized = np.zeros(img.shape, dtype=np.float64)
        slices.append((z, np.uint8(normalized)))
    
    # plt.subplots cannot create a figure with zero columns
    if not slices:
        print(f"No slices found for tomogram {tomo_id} around z={z_center}")
        return
    
    n_slices = len(slices)
    # squeeze=False always yields a 2D array of axes, even for a single slice
    fig, axes = plt.subplots(1, n_slices, figsize=(4 * n_slices, 4), squeeze=False)
    box_size = 24
    
    for ax, (z, img) in zip(axes[0], slices):
        ax.imshow(img, cmap='gray')
        ax.set_title(f"Slice z={z}")
        
        if z == z_center:
            rect = plt.Rectangle((x_center - box_size // 2, y_center - box_size // 2),
                                 box_size, box_size,
                                 linewidth=2, edgecolor='r', facecolor='none')
            ax.add_patch(rect)
            ax.set_title(f"Motor Slice (z={z})", color='red')
        
        ax.axis('on')
        ax.set_xticks([])
        ax.set_yticks([])
    
    plt.suptitle(f"Tomogram: {tomo_id}, Motor at: z={z_center}, y={y_center}, x={x_center}")
    plt.tight_layout()
    plt.show()

# Visualize up to three positive detections from the ensemble submission just
# generated. Rows predicting "no motor" (-1) are filtered out first, so up to
# n_visualize detections are shown whenever that many exist.
n_visualize = 3
positive_rows = ensemble_submission[ensemble_submission['Motor axis 0'] != -1].head(n_visualize)
for _, row in positive_rows.iterrows():
    visualize_detection(
        row['tomo_id'],
        {
            'Motor axis 0': row['Motor axis 0'],
            'Motor axis 1': row['Motor axis 1'],
            'Motor axis 2': row['Motor axis 2'],
        },
        TEST_DIR,
        context_slices=2,
    )