# BYU Locating Flagellar Motors - Submission Notebook

This submission notebook processes test tomograms to locate bacterial flagellar motors using an ensemble approach combining 2D YOLOv8 detection and 3D CNN validation. The approach consists of:

1. Loading pre-trained models (YOLOv8 and 3D CNN)
2. Processing each tomogram slice-by-slice
3. Finding potential motor candidates using YOLOv8
4. Validating candidates with a 3D CNN model
5. Generating the final submission file

## Setup and Dependencies

In [None]:
# Fetch the two Kaggle datasets this notebook depends on: the trained model
# weights and a wheel of the ultralytics (YOLO) package. The local paths these
# calls return are discarded; later cells hard-code paths under /kaggle/input/.
# NOTE(review): confirm those paths exist in the target environment (they do
# when the datasets are attached to the Kaggle notebook).
import kagglehub
kagglehub.dataset_download('michaelkoo21/flagellar-motor-model-2')
kagglehub.dataset_download('rachiteagles/yolo-pkg')

# Install the YOLO package from the downloaded wheel file (IPython shell magic).
# --no-index keeps pip offline; --no-deps skips dependency resolution, so the
# wheel's runtime dependencies must already be present in the environment.
!pip install --no-index --no-deps /kaggle/input/yolo-pkg/yolo/ultralytics-8.3.112-py3-none-any.whl

### Import Libraries

In [None]:
# Import necessary libraries
import os
import glob
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import cv2
import torch
import yaml
import random
import time
import threading
from tqdm import tqdm
from contextlib import nullcontext
from concurrent.futures import ThreadPoolExecutor
from ultralytics import YOLO

# Set random seeds for reproducibility (Python, NumPy and PyTorch RNGs).
# Note: this does not force deterministic cuDNN kernels.
np.random.seed(42)
random.seed(42)
torch.manual_seed(42)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(42)

# Define paths for Kaggle environment.
# Attached datasets are mounted at /kaggle/input/<dataset-slug> without the
# owner prefix -- the same convention used by the YOLO wheel path
# (/kaggle/input/yolo-pkg) and the model weight paths in main().
DATA_DIR = '/kaggle/input/byu-locating-bacterial-flagellar-motors-2025'
TRAIN_DIR = os.path.join(DATA_DIR, 'train')
TEST_DIR = os.path.join(DATA_DIR, 'test')
MODEL_DIR = '/kaggle/input/flagellar-motor-model-2'
SUBMISSION_PATH = '/kaggle/working/submission.csv'
WORKING_DIR = '/kaggle/working'

## 3D CNN Model Definition

In [3]:
class AttentionBlock(torch.nn.Module):
    """
    Self-attention over every voxel of a 3D feature map, gated by a learned scalar.

    Query/key projections reduce the channel count by a factor of 8; the value
    projection keeps it. The attended features are scaled by ``gamma`` and added
    back to the input, so with ``gamma`` at its initial value of 0 the block is
    the identity mapping.
    """
    def __init__(self, in_channels):
        """
        Build the 1x1x1 projection convolutions and the residual gate.

        Args:
            in_channels (int): Number of channels of the incoming feature map
        """
        super().__init__()
        reduced_channels = in_channels // 8
        self.query = torch.nn.Conv3d(in_channels, reduced_channels, kernel_size=1)
        self.key = torch.nn.Conv3d(in_channels, reduced_channels, kernel_size=1)
        self.value = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1)
        self.gamma = torch.nn.Parameter(torch.zeros(1))
        self.softmax = torch.nn.Softmax(dim=-1)

    def forward(self, x):
        """
        Apply voxel-to-voxel attention and the gated residual connection.

        Args:
            x (torch.Tensor): Feature map of shape (B, C, D, H, W)

        Returns:
            torch.Tensor: Tensor with the same shape as ``x``
        """
        b, c, d, h, w = x.size()
        n_voxels = d * h * w

        q = self.query(x).view(b, -1, n_voxels).transpose(1, 2)  # (B, N, C/8)
        k = self.key(x).view(b, -1, n_voxels)                    # (B, C/8, N)
        v = self.value(x).view(b, -1, n_voxels)                  # (B, C, N)

        # Row i of `weights` is a distribution over all voxels for query voxel i.
        weights = self.softmax(torch.bmm(q, k))                  # (B, N, N)
        attended = torch.bmm(v, weights.transpose(1, 2)).view(b, c, d, h, w)

        return x + self.gamma * attended

class Motor3DCNN(torch.nn.Module):
    """
    3D CNN that scores a tomogram subvolume for motor presence.

    Architecture: four Conv3d+BatchNorm+ReLU stages (the last three with
    stride 2), an ``AttentionBlock`` over the 128-channel map, global average
    pooling, then two hidden fully connected layers feeding two heads:
    a sigmoid presence score and an unconstrained 3-value regression output.

    The attribute names below are the keys of saved ``state_dict`` checkpoints
    and must not be renamed.
    """
    def __init__(self, input_channels=1, dropout_rate=0.3):
        """
        Create all layers.

        Args:
            input_channels (int): Number of channels of the input volume
            dropout_rate (float): Dropout probability after each hidden FC layer
        """
        super().__init__()

        # Convolutional encoder (channels 16 -> 32 -> 64 -> 128)
        self.conv1 = torch.nn.Conv3d(input_channels, 16, kernel_size=3, padding=1)
        self.bn1 = torch.nn.BatchNorm3d(16)
        self.conv2 = torch.nn.Conv3d(16, 32, kernel_size=3, padding=1, stride=2)
        self.bn2 = torch.nn.BatchNorm3d(32)
        self.conv3 = torch.nn.Conv3d(32, 64, kernel_size=3, padding=1, stride=2)
        self.bn3 = torch.nn.BatchNorm3d(64)
        self.conv4 = torch.nn.Conv3d(64, 128, kernel_size=3, padding=1, stride=2)
        self.bn4 = torch.nn.BatchNorm3d(128)

        self.attention = AttentionBlock(128)
        self.global_avg_pool = torch.nn.AdaptiveAvgPool3d(1)

        # Fully connected trunk and heads
        self.fc1 = torch.nn.Linear(128, 64)
        self.fc2 = torch.nn.Linear(64, 32)
        self.fc3 = torch.nn.Linear(32, 1)  # presence score (before sigmoid)
        self.reg = torch.nn.Linear(32, 3)  # 3-value regression head

        self.relu = torch.nn.ReLU()
        self.dropout = torch.nn.Dropout(dropout_rate)
        self.sigmoid = torch.nn.Sigmoid()

    def forward(self, x):
        """
        Run the network.

        Args:
            x (torch.Tensor): Input batch of shape (B, C, D, H, W)

        Returns:
            tuple: (presence score of shape (B, 1) in [0, 1],
                    regression output of shape (B, 3))
        """
        encoder_stages = (
            (self.conv1, self.bn1),
            (self.conv2, self.bn2),
            (self.conv3, self.bn3),
            (self.conv4, self.bn4),
        )
        for conv, norm in encoder_stages:
            x = self.relu(norm(conv(x)))

        features = self.global_avg_pool(self.attention(x)).flatten(1)

        for hidden in (self.fc1, self.fc2):
            features = self.dropout(self.relu(hidden(features)))

        return self.sigmoid(self.fc3(features)), self.reg(features)

## Image Processing and Helper Functions

In [4]:
class ImageProcessor:
    """
    Class for processing and normalizing tomogram slices.
    """
    @staticmethod
    def _percentile_stretch(img_array, low=2, high=98):
        """
        Clip to the [low, high] percentiles and rescale linearly to uint8 [0, 255].

        A flat input (high percentile equal to the low one) would otherwise
        divide 0/0 and produce NaN pixels, so it maps to an all-zero image.

        Args:
            img_array (np.ndarray): Input image array
            low (float): Lower percentile
            high (float): Upper percentile

        Returns:
            np.ndarray: uint8 image with the same shape as ``img_array``
        """
        p_low, p_high = np.percentile(img_array, [low, high])
        if p_high <= p_low:
            return np.zeros(np.shape(img_array), dtype=np.uint8)
        normalized = np.clip(img_array, p_low, p_high)
        normalized = 255 * (normalized - p_low) / (p_high - p_low)
        return np.uint8(normalized)

    @staticmethod
    def normalize_slice(img_array, method='adaptive'):
        """
        Normalize slice using different normalization methods

        Args:
            img_array (np.ndarray): Input image array
            method (str): Normalization method ('percentile', 'histogram', 'adaptive');
                any other value returns ``img_array`` unchanged

        Returns:
            np.ndarray: Normalized image
        """
        if method == 'percentile':
            # 2nd-98th percentile contrast stretch
            return ImageProcessor._percentile_stretch(img_array)

        if method == 'histogram':
            # Global histogram equalization (cv2 requires a single-channel uint8 input)
            return cv2.equalizeHist(img_array)

        if method == 'adaptive':
            # Percentile stretch first, then local contrast enhancement with CLAHE
            normalized = ImageProcessor._percentile_stretch(img_array)
            clahe = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
            return clahe.apply(normalized)

        return img_array

# GPU profiling context manager
class GPUProfiler:
    """
    Context manager that prints the time spent inside its ``with`` body.

    When CUDA is available the device is synchronized on entry and exit, so
    asynchronously queued GPU work is included in the measurement. The
    measured duration is also kept in ``elapsed`` (seconds) after exit.
    Exceptions raised in the body are not suppressed.
    """
    def __init__(self, name):
        """
        Args:
            name (str): Label printed with the timing
        """
        self.name = name
        self.start_time = None
        self.elapsed = None  # seconds; set when the context exits

    def __enter__(self):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # if the system clock is adjusted during the measurement.
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, *args):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        self.elapsed = time.perf_counter() - self.start_time
        print(f"[PROFILE] {self.name}: {self.elapsed:.3f}s")

## Ensemble Motor Detector Implementation

In [5]:
class EnsembleMotorDetector:
    """
    Ensemble detector combining 2D YOLO detection with 3D CNN validation.

    For each test tomogram (a directory of ``slice_XXXX.jpg`` files):

    1. YOLO runs on the 2D slices; boxes with confidence at least
       ``yolo_confidence_threshold`` become candidates (box centre + slice index).
    2. Nearby candidates are merged with a distance-based 3D NMS.
    3. Up to ``max_cnn_candidates`` of the most confident candidates are scored
       by the 3D CNN on a cube of side ``subvolume_size`` centred on each one.
    4. The validated candidate with the highest mean YOLO/CNN confidence is
       reported; failing that, the top YOLO candidate if its confidence exceeds
       ``yolo_fallback_threshold``; otherwise (-1, -1, -1).
    """

    # Column order of the submission file.
    SUBMISSION_COLUMNS = ['tomo_id', 'Motor axis 0', 'Motor axis 1', 'Motor axis 2']

    def __init__(self, yolo_model_path, cnn3d_model_path, test_dir, submission_path,
                 device='auto', yolo_confidence_threshold=0.30, cnn_confidence_threshold=0.45):
        """
        Initialize the detector with both YOLO and 3D CNN models

        Args:
            yolo_model_path (str): Path to trained YOLO model weights
            cnn3d_model_path (str): Path to trained 3D CNN model weights
            test_dir (str): Path to test tomograms directory
            submission_path (str): Path to save submission CSV
            device (str): Device to use ('cpu', 'cuda', or 'auto')
            yolo_confidence_threshold (float): Confidence threshold for YOLO detections
            cnn_confidence_threshold (float): Confidence threshold for 3D CNN detections
        """
        self.yolo_model_path = yolo_model_path
        self.cnn3d_model_path = cnn3d_model_path
        self.test_dir = test_dir
        self.submission_path = submission_path
        self.yolo_confidence_threshold = yolo_confidence_threshold
        self.cnn_confidence_threshold = cnn_confidence_threshold

        # Set device
        if device == 'auto':
            self.device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device

        print(f"Using device: {self.device} for ensemble detection")

        # Detection parameters
        self.nms_iou_threshold = 0.2  # Scales the 3D NMS distance threshold (see perform_3d_nms)
        self.batch_size = 8  # Default batch size; load_models may change it on CUDA
        self.concentration = 1.0  # Fraction of slices to run YOLO on (1.0 = all)
        self.subvolume_size = 64  # Side length of the 3D CNN input cube
        self.max_cnn_candidates = 10  # At most this many post-NMS candidates go to the 3D CNN
        # Used only when no candidate passes the CNN: the best YOLO candidate is
        # reported if its confidence is strictly above this value.
        self.yolo_fallback_threshold = 0.65

    @staticmethod
    def _make_result(tomo_id, z=-1, y=-1, x=-1):
        """Build one submission row; the defaults encode 'no motor found'."""
        return {
            'tomo_id': tomo_id,
            'Motor axis 0': z,
            'Motor axis 1': y,
            'Motor axis 2': x
        }

    def load_models(self):
        """
        Load both YOLO and 3D CNN models

        On CUDA devices this also fuses the YOLO layers, casts YOLO to FP16 on
        compute capability >= 7, and sets ``self.batch_size`` from GPU memory.

        Returns:
            tuple: (YOLO model, 3D CNN model)
        """
        # Load YOLO model
        print(f"Loading YOLO model from {self.yolo_model_path}")
        yolo_model = YOLO(self.yolo_model_path)
        yolo_model.to(self.device)

        # Fuse layers for faster inference if using GPU
        if self.device.startswith('cuda'):
            yolo_model.fuse()

            # Use half precision if on compatible GPU.
            # NOTE(review): ultralytics' predictor may re-cast the model dtype
            # when it is set up on the first call (its `half` predict argument
            # defaults to False); confirm FP16 is actually in effect.
            if torch.cuda.get_device_capability(0)[0] >= 7:  # Volta or newer
                yolo_model.model.half()
                print("Using half precision (FP16) for YOLO inference")

            # Rough estimate: total device memory minus memory allocated by this process
            free_mem = torch.cuda.get_device_properties(0).total_memory / 1e9 - torch.cuda.memory_allocated(0) / 1e9
            self.batch_size = max(4, min(32, int(free_mem * 4)))  # 4 images per GB as rough estimate
            print(f"Dynamic batch size set to {self.batch_size} based on available GPU memory")

        # Load 3D CNN model
        print(f"Loading 3D CNN model from {self.cnn3d_model_path}")
        cnn3d_model = Motor3DCNN(input_channels=1, dropout_rate=0.0)  # No dropout for inference
        cnn3d_model.load_state_dict(torch.load(self.cnn3d_model_path, map_location=self.device))
        cnn3d_model.to(self.device)
        cnn3d_model.eval()  # Set to evaluation mode

        return yolo_model, cnn3d_model

    def preload_image_batch(self, file_paths):
        """
        Read a batch of images from disk into CPU memory.

        ``process_tomogram`` runs this in a background thread for the next
        batch and discards the return value, so its effect there is only to
        touch the files ahead of YOLO (presumably warming the OS file cache).

        Args:
            file_paths (list): List of file paths to load

        Returns:
            list: List of loaded images
        """
        images = []
        for path in file_paths:
            img = cv2.imread(path)
            if img is None:
                # Try with PIL as fallback
                img = np.array(Image.open(path))
            images.append(img)
        return images

    def perform_3d_nms(self, detections, iou_threshold):
        """
        Greedy distance-based 3D non-maximum suppression.

        Detections are visited in descending confidence order. Each kept
        detection removes every remaining detection whose Euclidean distance
        to it (in z/y/x voxel units) is at most ``24 * iou_threshold``.
        Despite its name, ``iou_threshold`` only scales that distance; no IoU
        is computed.

        Args:
            detections (list): Dicts with 'z', 'y', 'x' and 'confidence' keys
            iou_threshold (float): Multiplier for the 24-voxel box size

        Returns:
            list: Kept detections, highest confidence first
        """
        if not detections:
            return []

        # Sort by confidence (highest first); `sorted` leaves the input untouched
        detections = sorted(detections, key=lambda x: x['confidence'], reverse=True)
        final_detections = []

        def distance_3d(d1, d2):
            return np.sqrt((d1['z'] - d2['z'])**2 +
                           (d1['y'] - d2['y'])**2 +
                           (d1['x'] - d2['x'])**2)

        box_size = 24  # Same as annotation box size
        distance_threshold = box_size * iou_threshold

        while detections:
            best_detection = detections.pop(0)
            final_detections.append(best_detection)
            # Drop detections that are too close to the one just kept
            detections = [d for d in detections if distance_3d(d, best_detection) > distance_threshold]

        return final_detections

    def extract_subvolume(self, tomo_id, z, y, x):
        """
        Extract a subvolume centered at the given coordinates

        Each slice is stretched to [0, 1] between its 2nd and 98th percentile.
        Slices that are missing, or lie outside the tomogram, stay zero, as do
        regions beyond the image border. A flat slice (98th == 2nd percentile)
        is left as zeros instead of producing NaN from a 0/0 division.

        Args:
            tomo_id (str): Tomogram ID
            z (int): Z coordinate (slice index)
            y (int): Y coordinate
            x (int): X coordinate

        Returns:
            torch.Tensor: Extracted subvolume as a tensor with shape [1, 1, D, H, W]
        """
        half_size = self.subvolume_size // 2
        subvolume = np.zeros((self.subvolume_size, self.subvolume_size, self.subvolume_size), dtype=np.float32)
        tomo_dir = os.path.join(self.test_dir, tomo_id)

        for i, z_pos in enumerate(range(z - half_size, z + half_size)):
            if z_pos < 0:
                continue

            slice_path = os.path.join(tomo_dir, f"slice_{z_pos:04d}.jpg")
            if not os.path.exists(slice_path):
                continue

            with Image.open(slice_path) as pil_img:
                img = np.array(pil_img)

            p2 = np.percentile(img, 2)
            p98 = np.percentile(img, 98)
            if p98 <= p2:
                # Flat slice: nothing to normalize, keep this plane at zero
                continue
            normalized = np.clip(img, p2, p98)
            normalized = ((normalized - p2) / (p98 - p2)).astype(np.float32)

            # Source window in the slice, clipped to the image bounds
            y_start = max(0, y - half_size)
            y_end = min(img.shape[0], y + half_size)
            x_start = max(0, x - half_size)
            x_end = min(img.shape[1], x + half_size)

            # Where that window lands inside the cube (offset when clipped at the top/left)
            target_y_start = max(0, half_size - (y - y_start))
            target_x_start = max(0, half_size - (x - x_start))

            height = min(y_end - y_start, self.subvolume_size - target_y_start)
            width = min(x_end - x_start, self.subvolume_size - target_x_start)
            if height <= 0 or width <= 0:
                continue

            subvolume[i,
                      target_y_start:target_y_start + height,
                      target_x_start:target_x_start + width] = normalized[y_start:y_start + height,
                                                                          x_start:x_start + width]

        # Add batch and channel dimensions: [B, C, D, H, W] as Conv3d expects
        return torch.tensor(subvolume).float().unsqueeze(0).unsqueeze(0)

    def _result_to_candidates(self, result, slice_num):
        """Convert one YOLO result for slice ``slice_num`` into candidate dicts."""
        candidates = []
        boxes = result.boxes
        if len(boxes) == 0:
            return candidates
        for box_idx, confidence in enumerate(boxes.conf):
            if confidence >= self.yolo_confidence_threshold:
                x1, y1, x2, y2 = boxes.xyxy[box_idx].cpu().numpy()
                candidates.append({
                    'z': round(slice_num),
                    'y': round((y1 + y2) / 2),
                    'x': round((x1 + x2) / 2),
                    'confidence': float(confidence),
                    'model': 'yolo'
                })
        return candidates

    def _run_yolo(self, tomo_dir, slice_files, yolo_model):
        """Run YOLO over ``slice_files`` in batches and return raw (pre-NMS) candidates."""
        detections = []
        use_cuda = self.device.startswith('cuda')
        streams = [torch.cuda.Stream() for _ in range(min(4, self.batch_size))] if use_cuda else [None]

        prefetch_thread = None
        for batch_start in range(0, len(slice_files), self.batch_size):
            # Wait for the previous prefetch before issuing a new one
            if prefetch_thread is not None:
                prefetch_thread.join()

            batch_end = min(batch_start + self.batch_size, len(slice_files))
            batch_files = slice_files[batch_start:batch_end]

            # Read the next batch in the background; its images are discarded
            next_files = slice_files[batch_end:batch_end + self.batch_size]
            if next_files:
                next_paths = [os.path.join(tomo_dir, f) for f in next_files]
                prefetch_thread = threading.Thread(target=self.preload_image_batch, args=(next_paths,))
                prefetch_thread.start()
            else:
                prefetch_thread = None

            # Split batch across streams
            for i, sub_batch in enumerate(np.array_split(batch_files, len(streams))):
                if len(sub_batch) == 0:
                    continue
                stream = streams[i % len(streams)]
                with torch.cuda.stream(stream) if stream is not None and use_cuda else nullcontext():
                    sub_batch_paths = [os.path.join(tomo_dir, slice_file) for slice_file in sub_batch]
                    # File names look like "slice_0123.jpg"
                    slice_nums = [int(slice_file.split('_')[1].split('.')[0]) for slice_file in sub_batch]
                    sub_results = yolo_model(sub_batch_paths, verbose=False)
                    for slice_num, result in zip(slice_nums, sub_results):
                        detections.extend(self._result_to_candidates(result, slice_num))

            if use_cuda:
                torch.cuda.synchronize()

        if prefetch_thread is not None:
            prefetch_thread.join()
        return detections

    def _validate_candidates(self, tomo_id, tomo_dir, reference_slice, depth, candidates, cnn3d_model):
        """Score the top candidates with the 3D CNN and return those that pass."""
        half_size = self.subvolume_size // 2

        # Slice height/width is read once per tomogram from a reference slice
        try:
            with Image.open(os.path.join(tomo_dir, reference_slice)) as pil_img:
                height, width = np.array(pil_img).shape[:2]
        except Exception as e:
            print(f"Error reading slice shape for {tomo_id}: {e}")
            return []

        validated = []
        with torch.no_grad():
            for detection in candidates[:self.max_cnn_candidates]:
                z, y, x = detection['z'], detection['y'], detection['x']

                # Skip candidates whose cube would cross a tomogram edge
                if not (half_size <= z < depth - half_size and
                        half_size <= y < height - half_size and
                        half_size <= x < width - half_size):
                    continue

                try:
                    subvolume = self.extract_subvolume(tomo_id, z, y, x).to(self.device)
                    cls_output, _ = cnn3d_model(subvolume)
                    cnn_confidence = cls_output.item()
                except Exception as e:
                    print(f"Error processing candidate at z={z}, y={y}, x={x}: {e}")
                    continue

                if cnn_confidence >= self.cnn_confidence_threshold:
                    validated.append({
                        'z': z,
                        'y': y,
                        'x': x,
                        'yolo_confidence': detection['confidence'],
                        'cnn_confidence': cnn_confidence,
                        'ensemble_confidence': (detection['confidence'] + cnn_confidence) / 2,
                        'model': 'ensemble'
                    })
        return validated

    def process_tomogram(self, tomo_id, yolo_model, cnn3d_model, index=0, total=1):
        """
        Process a single tomogram using both YOLO and 3D CNN models

        Args:
            tomo_id (str): Tomogram ID
            yolo_model: YOLO model
            cnn3d_model: 3D CNN model
            index (int): Current tomogram index (for logging)
            total (int): Total number of tomograms (for logging)

        Returns:
            dict: Submission row with 'tomo_id' and 'Motor axis 0/1/2'
                  (all -1 when no motor is reported)
        """
        print(f"Processing tomogram {tomo_id} ({index}/{total})")

        tomo_dir = os.path.join(self.test_dir, tomo_id)
        all_slice_files = sorted(f for f in os.listdir(tomo_dir) if f.endswith('.jpg'))
        # Full tomogram depth for the CNN edge check, taken before subsampling
        # so that concentration < 1.0 does not shrink it.
        depth = len(all_slice_files)
        slice_files = all_slice_files

        if self.concentration < 1.0:
            selected_indices = np.linspace(0, depth - 1, int(depth * self.concentration))
            selected_indices = np.round(selected_indices).astype(int)
            slice_files = [all_slice_files[i] for i in selected_indices]
            print(f"Processing {len(slice_files)} out of {depth} slices")

        # STEP 1: YOLO candidates, consolidated with 3D NMS
        print("STEP 1: Running YOLO detection on 2D slices...")
        yolo_detections = self.perform_3d_nms(self._run_yolo(tomo_dir, slice_files, yolo_model),
                                               self.nms_iou_threshold)
        print(f"YOLO found {len(yolo_detections)} potential candidates")

        if not yolo_detections:
            return self._make_result(tomo_id)

        # STEP 2: 3D CNN validation
        print("STEP 2: Validating candidates with 3D CNN...")
        validated_detections = self._validate_candidates(
            tomo_id, tomo_dir, slice_files[0], depth, yolo_detections, cnn3d_model)
        print(f"3D CNN validated {len(validated_detections)} candidates")

        # STEP 3: final prediction
        if validated_detections:
            best = max(validated_detections, key=lambda d: d['ensemble_confidence'])
            return self._make_result(tomo_id, round(best['z']), round(best['y']), round(best['x']))

        # No validated candidate: trust YOLO alone only when it is very confident
        best_yolo = max(yolo_detections, key=lambda d: d['confidence'])
        if best_yolo['confidence'] > self.yolo_fallback_threshold:
            return self._make_result(tomo_id, round(best_yolo['z']), round(best_yolo['y']), round(best_yolo['x']))
        return self._make_result(tomo_id)

    def generate_submission(self):
        """
        Process all test tomograms using the ensemble approach and generate submission

        Tomograms are the subdirectories of ``test_dir``, processed in sorted
        order. A tomogram whose processing raises gets a (-1, -1, -1) row. With
        no tomograms, models are not loaded and a header-only CSV is written.

        Returns:
            pd.DataFrame: Submission dataframe
        """
        test_tomos = sorted([d for d in os.listdir(self.test_dir) if os.path.isdir(os.path.join(self.test_dir, d))])
        total_tomos = len(test_tomos)

        print(f"Found {total_tomos} tomograms in test directory")

        results = []
        motors_found = 0

        if test_tomos:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            yolo_model, cnn3d_model = self.load_models()

            # A single worker: the GPU is shared, so tomograms run one at a time
            with ThreadPoolExecutor(max_workers=1) as executor:
                futures = [
                    (tomo_id, executor.submit(self.process_tomogram, tomo_id, yolo_model, cnn3d_model, i, total_tomos))
                    for i, tomo_id in enumerate(test_tomos, 1)
                ]

                # Collect in submission order
                for tomo_id, future in futures:
                    try:
                        if torch.cuda.is_available():
                            torch.cuda.empty_cache()

                        result = future.result()
                        results.append(result)

                        if result['Motor axis 0'] != -1:
                            motors_found += 1
                            print(f"Motor found in {tomo_id} at position: "
                                  f"z={result['Motor axis 0']}, y={result['Motor axis 1']}, x={result['Motor axis 2']}")
                        else:
                            print(f"No motor detected in {tomo_id}")

                        print(f"Current detection rate: {motors_found}/{len(results)} ({motors_found/len(results)*100:.1f}%)")

                    except Exception as e:
                        print(f"Error processing {tomo_id}: {e}")
                        results.append(self._make_result(tomo_id))
        else:
            print("No tomograms to process; writing an empty submission.")

        # Passing `columns` fixes the order and keeps the header even when `results` is empty
        submission_df = pd.DataFrame(results, columns=self.SUBMISSION_COLUMNS)
        submission_df.to_csv(self.submission_path, index=False)

        detection_rate = motors_found / total_tomos * 100 if total_tomos else 0.0
        print(f"\nSubmission complete!")
        print(f"Motors detected: {motors_found}/{total_tomos} ({detection_rate:.1f}%)")
        print(f"Submission saved to: {self.submission_path}")

        print("\nSubmission preview:")
        print(submission_df.head())

        return submission_df

## Main Execution

In [None]:
def debug_image_loading(test_dir):
    """
    Print diagnostics for loading the middle slice of the first test tomogram.

    Tomograms are the subdirectories of ``test_dir`` in sorted order (the same
    order ``generate_submission`` uses). The slice is loaded with PIL, with
    OpenCV in grayscale, and with OpenCV in colour converted to RGB, and the
    resulting shapes/dtypes are printed. This is a diagnostic helper: every
    problem is printed rather than raised.

    Args:
        test_dir (str): Directory containing one subdirectory per tomogram
    """
    if not os.path.isdir(test_dir):
        print(f"Test directory not found: {test_dir}")
        return

    test_tomos = sorted(d for d in os.listdir(test_dir) if os.path.isdir(os.path.join(test_dir, d)))
    if not test_tomos:
        print("No test tomograms found!")
        return

    tomo_id = test_tomos[0]
    tomo_dir = os.path.join(test_dir, tomo_id)
    slice_files = sorted([f for f in os.listdir(tomo_dir) if f.endswith('.jpg')])

    if not slice_files:
        print(f"No image files found in {tomo_dir}")
        return

    print(f"Found {len(slice_files)} image files in {tomo_dir}")
    sample_file = slice_files[len(slice_files) // 2]  # Middle slice
    img_path = os.path.join(tomo_dir, sample_file)

    try:
        # Method 1: PIL
        with Image.open(img_path) as img_pil:
            img_array_pil = np.array(img_pil)
        print(f"PIL Image shape: {img_array_pil.shape}, dtype: {img_array_pil.dtype}")

        # Method 2: OpenCV grayscale (cv2.imread returns None instead of raising)
        img_cv2 = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        if img_cv2 is None:
            raise ValueError("OpenCV could not decode the image")
        print(f"OpenCV Image shape: {img_cv2.shape}, dtype: {img_cv2.dtype}")

        # Method 3: OpenCV colour (BGR) converted to RGB
        img_bgr = cv2.imread(img_path)
        if img_bgr is None:
            raise ValueError("OpenCV could not decode the image in colour mode")
        img_rgb = cv2.cvtColor(img_bgr, cv2.COLOR_BGR2RGB)
        print(f"OpenCV RGB Image shape: {img_rgb.shape}, dtype: {img_rgb.dtype}")

        print("Image loading successful!")
    except Exception as e:
        print(f"Error loading image {img_path}: {e}")

# Entry point: sanity-check image loading, then run the ensemble on the test set.
def main():
    """
    Run the full submission pipeline.

    Prints image-loading diagnostics for ``TEST_DIR``, builds an
    ``EnsembleMotorDetector`` with the trained YOLO and 3D CNN weights and
    writes the submission to ``SUBMISSION_PATH``.

    Returns:
        pd.DataFrame: The submission dataframe
    """
    print("Testing image loading...")
    debug_image_loading(TEST_DIR)

    detector_config = dict(
        yolo_model_path="/kaggle/input/flagellar-motor-model-2/yolo_weights/motor_detector/weights/best.pt",
        cnn3d_model_path="/kaggle/input/flagellar-motor-model-2/3dcnn_models/3dcnn_best.pt",
        test_dir=TEST_DIR,
        submission_path=SUBMISSION_PATH,
        device='auto',
        # A lenient YOLO threshold admits more candidates; the 3D CNN then filters them
        yolo_confidence_threshold=0.30,
        cnn_confidence_threshold=0.45,
    )
    ensemble_detector = EnsembleMotorDetector(**detector_config)
    return ensemble_detector.generate_submission()


if __name__ == "__main__":
    main()

This notebook implements a robust ensemble-based approach for detecting bacterial flagellar motors in cryo-ET tomograms. The main advantages of this approach are:

1. Efficiency: Uses GPU optimization techniques like CUDA streams and batch processing
2. Accuracy: Combines 2D YOLO detection with 3D CNN validation for better results
3. Robustness: Handles variable tomogram sizes and poor signal-to-noise ratio
4. Offline Operation: Works completely offline using pre-downloaded models

The workflow follows these steps:

1. Process each tomogram slice-by-slice with YOLO to identify potential motor locations
2. Perform 3D non-maximum suppression to merge nearby detections
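3. Extract 64×64×64 subvolumes around up to 10 of the most confident candidates (candidates within 32 voxels of a tomogram edge are skipped)
4. Validate candidates using the 3D CNN model, keeping those whose CNN score reaches the CNN confidence threshold
5. Generate the final submission: the validated candidate with the highest mean YOLO/CNN confidence; if no candidate is validated, the top YOLO detection when its confidence exceeds 0.65; otherwise (-1, -1, -1) for tomograms without motors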