## import libraries

In [None]:
import sys, subprocess

# Rebuild the numeric stack so numpy, matplotlib and scikit-learn share one ABI.
# Step 1: remove the conflicting packages completely.
subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", "numpy", "matplotlib"])

# Step 2: install versions that work together on Kaggle.
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q",
                       "numpy==2.0.2",
                       "matplotlib==3.10.0",
                       "protobuf<6"])

# Step 3: reinstall scikit-learn so it is built against the numpy installed above.
subprocess.check_call([sys.executable, "-m", "pip", "uninstall", "-y", "scikit-learn"])
subprocess.check_call([sys.executable, "-m", "pip", "install", "-q", "--no-cache-dir",
                       "scikit-learn==1.5.2"])

# Step 4: pin google-cloud-bigquery-storage.
# The requirement is passed as a single argv element on purpose: written
# unquoted on a `!pip` line, a range such as `>=2.30.0,<3.0.0` is parsed by the
# shell as output/input redirections, so pip never sees the constraint.
# 2.30.0 is the version the cell ended up with and lies inside that range.
subprocess.check_call([sys.executable, "-m", "pip", "install",
                       "google-cloud-bigquery-storage==2.30.0"])


print("âœ… scikit-learn reinstalled. Now restart the session.")


In [None]:
# 1. Unpack the offline wheel archive (it is extracted into the working directory;
#    the later steps expect it to produce a ./packages folder)
!tar xfvz /kaggle/input/ultralytics-for-offline-install/archive.tar.gz

# 2. CRITICAL: delete the bundled numpy and scipy wheels so that pip cannot pick
#    them up and replace the versions already installed in the environment
!rm packages/numpy*
!rm packages/scipy*

# 3. Install ultralytics from the local wheels only (--no-index disables PyPI),
#    leaving the environment's numpy/scipy in place
!pip -q install --no-index --find-links=./packages ultralytics

# 4. Remove the extracted wheels once the install is done
!rm -rf ./packages

In [None]:
# Data Processing & Analysis
import numpy as np  # Numerical computations
import pandas as pd  # Data manipulation & analysis
from tqdm.notebook import tqdm  # Progress bar for loops

# Machine Learning & Deep Learning
import torch  # PyTorch framework
import torch.nn as nn  # Neural network layers
import torch.nn.functional as F  # Activation functions
import torch.optim as optim  # Optimizers for training
from torch.utils.data import Dataset, DataLoader  # Custom dataset handling
from torch.optim.lr_scheduler import ReduceLROnPlateau  # Learning rate scheduler

# Computer Vision & Image Processing
import cv2  # OpenCV for image processing
from PIL import Image, ImageDraw  # Pillow for handling images
import seaborn as sns  # Data visualization
import matplotlib.pyplot as plt  # Plotting library
from matplotlib.patches import Rectangle  # Drawing bounding boxes

# Dataset & File Handling
import yaml  # YAML file handling
import json  # JSON file handling
import os  # Operating system interactions
import glob  # File searching

# Threading & Parallel Processing
import threading  # Multi-threading for performance improvement
import time  # Timing operations
from contextlib import nullcontext  # Handling context management
from concurrent.futures import ThreadPoolExecutor  # Thread-based parallelism

# Deep Learning Models
from ultralytics import YOLO  # YOLO model for object detection

# Machine Learning Utilities
from sklearn.model_selection import train_test_split  # Data splitting for training/validation

# Visualization
import plotly.express as px  # Interactive visualizations

# Mathematical Operations
import math  # Mathematical functions
import random  # Random number generation


## directories

In [None]:
# Train and test folders of the competition dataset share one root; only the
# last path component differs.
Train_dir, Test_dir = (
    f"/kaggle/input/byu-locating-bacterial-flagellar-motors-2025/{split}"
    for split in ("train", "test")
)

### Check device

In [None]:
# Select the compute device: use the GPU when available, otherwise the CPU.
# (The former `os.makedirs('./', exist_ok=True)` calls were removed: the current
# directory always exists, so they created nothing.)
DEVICE = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {DEVICE}")

# Seed every random number generator the notebook relies on.
RANDOM_SEED = 42
random.seed(RANDOM_SEED)
np.random.seed(RANDOM_SEED)
torch.manual_seed(RANDOM_SEED)
if torch.cuda.is_available():
    # Seed all visible GPUs, not only the current one.
    torch.cuda.manual_seed_all(RANDOM_SEED)
    # Deterministic kernels only help if the cuDNN autotuner is off as well;
    # with benchmark enabled the selected algorithm can differ between runs.
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


## load the dataset

In [None]:
# Read the training annotations into a DataFrame and preview the first rows.
train_df = pd.read_csv(
    os.path.join("/kaggle/input/byu-locating-bacterial-flagellar-motors-2025", "train_labels.csv")
)
train_df.head()

## Motor statistics

In [None]:
# One value per tomogram (the first row of each group), then count how many
# tomograms share each motor count.
motors_per_tomogram = train_df.groupby('tomo_id')['Number of motors'].first()
ax = motors_per_tomogram.value_counts().plot(kind="bar", figsize=(10, 6), color="skyblue")

# Write the count on top of every bar, centred horizontally.
for bar in ax.patches:
    bar_top = bar.get_height()
    bar_middle = bar.get_x() + bar.get_width() / 2
    ax.annotate(
        f'{bar_top}',
        (bar_middle, bar_top),
        ha='center',
        va='bottom',
        fontsize=12,
        fontweight='bold',
        color='black',
    )

# Axis labels and title
ax.set_xlabel("Number of Motors", fontsize=14)
ax.set_ylabel("Count", fontsize=14)
ax.set_title("Distribution of Motors per Tomogram", fontsize=16)
plt.xticks(rotation=0)  # horizontal tick labels

plt.show()


In [None]:
# Smallest and largest tomogram extent along each array axis.
# Axis 0 indexes the slices (Z). Within a slice the rest of this notebook treats
# axis 1 as the row / Y coordinate (normalised by the image height) and axis 2
# as the column / X coordinate (normalised by the image width), so the labels
# below follow that convention.
tomogram_ranges = {
    "Z-axis (slices)": (train_df["Array shape (axis 0)"].min(), train_df["Array shape (axis 0)"].max()),
    "Y-axis (height)": (train_df["Array shape (axis 1)"].min(), train_df["Array shape (axis 1)"].max()),
    "X-axis (width)":  (train_df["Array shape (axis 2)"].min(), train_df["Array shape (axis 2)"].max())
}

# Print formatted output
print("\n **Tomogram Size Ranges**:")
for axis, (min_val, max_val) in tomogram_ranges.items():
    print(f"{axis}: {min_val} to {max_val}")

### Relationships between motor axes and tomogram shapes

In [None]:
# Motor coordinates along the three array axes, coloured by motors per tomogram.
x, y, z = (train_df[f'Motor axis {axis}'] for axis in range(3))
colors = train_df['Number of motors']

# 3D axes on a fresh figure
fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(111, projection='3d')

# One point per label row
sc = ax.scatter(x, y, z, c=colors, cmap="viridis", s=50, alpha=0.85)

# Colour scale legend
cbar = fig.colorbar(sc, ax=ax)
cbar.set_label('Number of motors')

# Axis labels and title in a single call
ax.set(
    xlabel="Motor axis 0",
    ylabel="Motor axis 1",
    zlabel="Motor axis 2",
    title="ðŸš€ 3D Scatter Plot: Motor Axes",
)

plt.show()


In [None]:
# Tomogram extents along the three array axes, coloured by motors per tomogram.
x, y, z = (train_df[f'Array shape (axis {axis})'] for axis in range(3))
colors = train_df['Number of motors']

# 3D axes on a fresh figure
fig = plt.figure(figsize=(10, 7))
ax = fig.add_subplot(111, projection='3d')

# One point per label row
sc = ax.scatter(x, y, z, c=colors, cmap="magma", s=50, alpha=0.85)

# Colour scale legend
cbar = fig.colorbar(sc, ax=ax)
cbar.set_label('Number of motors')

# Axis labels and title in a single call
ax.set(
    xlabel="Array shape (axis 0)",
    ylabel="Array shape (axis 1)",
    zlabel="Array shape (axis 2)",
    title="ðŸ§¬ 3D Scatter Plot: Tomogram Shapes",
)

plt.show()


### Show descriptive statistics

In [None]:
# Render the summary table returned by DataFrame.describe() for the labels table.
display(train_df.describe())

###  Analyze Correlation


In [None]:
# Pairwise correlations between the numeric columns of the labels table.
correlation_matrix = train_df.corr(numeric_only=True)

# White-background figure holding the annotated heatmap
plt.figure(figsize=(9, 5), facecolor="white")
sns.heatmap(correlation_matrix, annot=True, fmt=".2f")
plt.title("Correlation Heatmap")

plt.show()

## Look at images in the dataset

In [None]:
# Define parameters
path = "/kaggle/input/byu-locating-bacterial-flagellar-motors-2025/train/tomo_00e463"
n_images = 9
is_random = True
figsize = (12, 12)

# os.listdir() returns entries in arbitrary order. Sorting makes the sequential
# branch really take the first slices and keeps the seeded random selection
# reproducible from one run/filesystem to the next.
image_names = sorted(os.listdir(path))

# Handle case where directory is empty
if not image_names:
    print("No images found in the directory.")
else:
    # Select images (random or sequential)
    if is_random:
        image_names = random.sample(image_names, min(len(image_names), n_images))
    else:
        image_names = image_names[:n_images]

    # Grid sized for the requested number of images; at least one column so the
    # row computation below can never divide by zero.
    w = max(1, int(math.sqrt(n_images)))
    h = math.ceil(n_images / w)

    # Create figure
    plt.figure(figsize=figsize)

    for ind, image_name in enumerate(image_names):
        img_path = os.path.join(path, image_name)
        img = cv2.imread(img_path)

        if img is None:
            continue  # cv2.imread returns None for files it cannot decode

        # OpenCV loads BGR; matplotlib expects RGB
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        plt.subplot(h, w, ind + 1)
        plt.imshow(img)
        plt.xticks([])
        plt.yticks([])

    # Add title
    plt.suptitle("Sample Tomogram Images")
    plt.show()


## Model used: YOLOv8


In [None]:
from tqdm import tqdm

# Tunable dataset-generation constants
TRUST = 6  # slices exported on each side of a motor's centre slice
BOX_SIZE = 28  # side length of the square bounding box, in pixels
TRAIN_SPLIT = 0.8  # share of tomograms assigned to the training split

# Where the competition data lives
data_path = "/kaggle/input/byu-locating-bacterial-flagellar-motors-2025/"
train_dir = os.path.join(data_path, "train")

# Output layout: <yolo_dir>/{images,labels}/{train,val}
yolo_dir = "/kaggle/working/yolo_dataset"
yolo_img_train = os.path.join(yolo_dir, "images/train")
yolo_img_val = os.path.join(yolo_dir, "images/val")
yolo_lbl_train = os.path.join(yolo_dir, "labels/train")
yolo_lbl_val = os.path.join(yolo_dir, "labels/val")

# Make sure all four output folders exist
for path in (yolo_img_train, yolo_img_val, yolo_lbl_train, yolo_lbl_val):
    os.makedirs(path, exist_ok=True)

def normalize_slice(slice_data):
    """
    Stretch the contrast of an image slice to the full 8-bit range.

    The 2nd and 98th intensity percentiles are mapped to 0 and 255 and values
    outside that interval are clipped.

    Args:
        slice_data: Array-like of pixel intensities.

    Returns:
        np.ndarray: uint8 array with the same shape as the input. A slice whose
        2nd and 98th percentiles coincide (e.g. a constant image) is returned
        as all zeros.
    """
    data = np.asarray(slice_data, dtype=np.float64)
    p2, p98 = np.percentile(data, [2, 98])
    span = p98 - p2
    if not span > 0:
        # Flat slice: dividing by a zero span would yield NaN/inf, and casting
        # those to uint8 is undefined.
        return np.zeros(data.shape, dtype=np.uint8)
    return np.uint8(255 * np.clip((data - p2) / span, 0, 1))

def process_tomograms(tomo_ids, img_dir, lbl_dir, labels_df, trust, set_name):
    """
    Export the slices around every annotated motor and write their YOLO labels.

    For each motor, the slices within `trust` of its z position are read from
    `train_dir/<tomo_id>/slice_XXXX.jpg`, contrast-normalised and saved to
    `img_dir`; a label file with the same stem is written to `lbl_dir`.

    Each label file lists every motor of the tomogram whose z position lies
    within `trust` slices of the exported slice, not only the motor encoded in
    the file name, so that a second nearby motor is not left unlabeled.

    Args:
        tomo_ids: Iterable of tomogram ids to export.
        img_dir (str): Destination folder for the slice images.
        lbl_dir (str): Destination folder for the label files.
        labels_df (pd.DataFrame): Label table with the 'tomo_id',
            'Motor axis 0/1/2' and 'Array shape (axis 0)' columns.
        trust (int): Number of slices taken on each side of a motor.
        set_name (str): Name used in progress messages.

    Returns:
        tuple: (number of slices written, number of motors processed).
    """
    motor_data = labels_df[labels_df['tomo_id'].isin(tomo_ids)]
    # Rows without a motor position carry NaN in 'Motor axis 0'.
    motor_data = motor_data[motor_data['Motor axis 0'].notna()]

    motor_count = [
        (tomo_id, int(z), int(y), int(x), int(depth))
        for tomo_id, z, y, x, depth in zip(
            motor_data['tomo_id'],
            motor_data['Motor axis 0'],
            motor_data['Motor axis 1'],
            motor_data['Motor axis 2'],
            motor_data['Array shape (axis 0)'],
        )
    ]

    # All motor positions of each tomogram, used to label every motor that is
    # visible in an exported slice.
    motors_by_tomo = {}
    for tomo_id, motor_z, motor_y, motor_x, _ in motor_count:
        motors_by_tomo.setdefault(tomo_id, []).append((motor_z, motor_y, motor_x))

    print(f"Processing {len(motor_count)} motors for {set_name}...")
    processed = 0

    for tomo_id, z_center, y_center, x_center, z_max in tqdm(motor_count, desc=f"{set_name} data"):
        # Clamp the slice window to the valid range [0, z_max - 1].
        for z in range(max(0, z_center - trust), min(z_max - 1, z_center + trust) + 1):
            src_path = os.path.join(train_dir, tomo_id, f"slice_{z:04d}.jpg")

            if not os.path.exists(src_path):
                continue  # Skip missing slices

            img = np.array(Image.open(src_path))
            dest_name = f"{tomo_id}_z{z:04d}_y{y_center:04d}_x{x_center:04d}.jpg"
            Image.fromarray(normalize_slice(img)).save(os.path.join(img_dir, dest_name))

            # YOLO format: class x_center y_center width height, all relative
            # to the image size.
            img_h, img_w = img.shape[:2]
            label_path = os.path.join(lbl_dir, os.path.splitext(dest_name)[0] + '.txt')
            with open(label_path, 'w') as f:
                for motor_z, motor_y, motor_x in motors_by_tomo[tomo_id]:
                    if abs(motor_z - z) <= trust:
                        f.write(f"0 {motor_x/img_w} {motor_y/img_h} {BOX_SIZE/img_w} {BOX_SIZE/img_h}\n")

            processed += 1
    return processed, len(motor_count)

def prepare_yolo_dataset():
    """
    Build the YOLO dataset: split tomograms, export slices/labels, write the YAML.

    Only tomograms with at least one motor are used. They are shuffled with the
    global NumPy RNG and split by TRAIN_SPLIT into train and validation sets.
    The four output folders are emptied first so that files from an earlier run
    with a different shuffle cannot end up in both splits.
    """
    labels_df = pd.read_csv(os.path.join(data_path, "train_labels.csv"))
    tomo_ids = labels_df[labels_df['Number of motors'] > 0]['tomo_id'].unique()

    # Drop leftovers of a previous run; otherwise a tomogram that changes split
    # would be present in train and validation at the same time.
    for out_dir in (yolo_img_train, yolo_img_val, yolo_lbl_train, yolo_lbl_val):
        os.makedirs(out_dir, exist_ok=True)
        for stale_path in glob.glob(os.path.join(out_dir, '*')):
            if os.path.isfile(stale_path):
                os.remove(stale_path)

    # Train-validation split (by tomogram, so slices of one tomogram stay together)
    np.random.shuffle(tomo_ids)
    split_idx = int(len(tomo_ids) * TRAIN_SPLIT)
    train_tomos, val_tomos = tomo_ids[:split_idx], tomo_ids[split_idx:]

    train_slices, train_motors = process_tomograms(train_tomos, yolo_img_train, yolo_lbl_train, labels_df, TRUST, "Train")
    val_slices, val_motors = process_tomograms(val_tomos, yolo_img_val, yolo_lbl_val, labels_df, TRUST, "Validation")

    # Generate YAML config for YOLO
    with open(os.path.join(yolo_dir, 'dataset.yaml'), 'w') as f:
        yaml.dump({'path': yolo_dir, 'train': 'images/train', 'val': 'images/val', 'names': {0: 'motor'}}, f)

    print("\nDataset Preparation Complete!")
    print(f"- Train: {len(train_tomos)} tomograms, {train_motors} motors, {train_slices} slices")
    print(f"- Validation: {len(val_tomos)} tomograms, {val_motors} motors, {val_slices} slices")
    print(f"- Dataset directory: {yolo_dir}")

# Run preprocessing
prepare_yolo_dataset()


## Visualize yolo data

In [None]:
# Training split of the generated YOLO dataset: images and their label files
yolo_dataset_dir = "/kaggle/working/yolo_dataset"
train_images_dir, train_labels_dir = (
    os.path.join(yolo_dataset_dir, kind, "train") for kind in ("images", "labels")
)

def display_random_samples(num_samples=4):
    """
    Displays a specified number of random training images with YOLO annotations.

    Images are searched recursively below `train_images_dir`; the label file of
    an image is looked up under `train_labels_dir` at the same relative path
    with a `.txt` extension. Each image is contrast-stretched between its 2nd
    and 98th percentiles before the boxes are drawn.

    Args:
        num_samples (int): The number of random images to display. It is capped
            at the number of images found; nothing is drawn when it is below 1.
    """
    # Collect all image files (supports multiple formats)
    image_files = []
    for ext in ['*.jpg', '*.jpeg', '*.png']:
        image_files.extend(glob.glob(os.path.join(train_images_dir, "**", ext), recursive=True))

    if not image_files:
        print("No images found in the training directory.")
        return

    num_samples = min(num_samples, len(image_files))
    if num_samples < 1:
        # plt.subplots cannot build a grid with zero rows.
        return

    # glob order is arbitrary; sort so that a seeded RNG selects the same files.
    image_files.sort()
    selected_images = random.sample(image_files, num_samples)

    # Two-column grid; squeeze=False always yields a 2-D array of axes, so the
    # single-image case needs no special handling.
    rows = int(np.ceil(num_samples / 2))
    cols = min(num_samples, 2)
    fig, axes = plt.subplots(rows, cols, figsize=(14, 5 * rows), squeeze=False)
    axes = axes.flatten()

    for ax, img_path in zip(axes, selected_images):
        try:
            # Get the corresponding label file
            relative_path = os.path.relpath(img_path, train_images_dir)
            label_path = os.path.join(train_labels_dir, os.path.splitext(relative_path)[0] + '.txt')

            # Load and normalize the image
            img = Image.open(img_path)
            img_width, img_height = img.size
            img_array = np.array(img)
            p2, p98 = np.percentile(img_array, [2, 98])
            if p98 > p2:
                stretched = 255 * (np.clip(img_array, p2, p98) - p2) / (p98 - p2)
            else:
                # Flat image: avoid dividing by a zero intensity span.
                stretched = np.zeros(img_array.shape, dtype=np.float64)
            img_normalized = Image.fromarray(np.uint8(stretched))

            # Prepare image for annotation
            img_rgb = img_normalized.convert('RGB')
            overlay = Image.new('RGBA', img_rgb.size, (0, 0, 0, 0))
            draw = ImageDraw.Draw(overlay)

            # Load YOLO annotations if they exist (values are relative to the image size)
            annotations = []
            if os.path.exists(label_path):
                with open(label_path, 'r') as label_file:
                    for line in label_file:
                        fields = line.split()
                        if not fields:
                            continue  # tolerate blank lines
                        class_id, x_center, y_center, width, height = map(float, fields)
                        annotations.append({
                            'class_id': int(class_id),
                            'x_center': x_center * img_width,
                            'y_center': y_center * img_height,
                            'width': width * img_width,
                            'height': height * img_height,
                        })

            # Draw annotations, clamped to the image borders
            for ann in annotations:
                x1 = max(0, int(ann['x_center'] - ann['width'] / 2))
                y1 = max(0, int(ann['y_center'] - ann['height'] / 2))
                x2 = min(img_width, int(ann['x_center'] + ann['width'] / 2))
                y2 = min(img_height, int(ann['y_center'] + ann['height'] / 2))
                draw.rectangle([x1, y1, x2, y2], fill=(255, 0, 0, 64), outline=(255, 0, 0, 200))
                draw.text((x1, y1 - 10), f"Class {ann['class_id']}", fill=(255, 0, 0, 255))

            # If no annotations, indicate it on the image
            if not annotations:
                draw.text((10, 10), "No annotations found", fill=(255, 0, 0, 255))

            # Composite and show image
            img_final = Image.alpha_composite(img_rgb.convert('RGBA'), overlay).convert('RGB')
            ax.imshow(np.array(img_final))
            ax.set_title(f"Image: {os.path.basename(img_path)}\nAnnotations: {len(annotations)}")
            ax.axis('off')

        except Exception as e:
            # Best-effort display: report the failure and continue with the next image.
            print(f"Error processing image {img_path}: {e}")
            ax.text(0.5, 0.5, f"Error loading: {os.path.basename(img_path)}", ha='center', va='center')
            ax.axis('off')

    # Hide any extra axes
    for ax in axes[len(selected_images):]:
        ax.axis('off')

    plt.tight_layout()
    plt.show()
    print(f"Displayed {num_samples} random images with YOLO annotations.")

# Call the function to visualize random training samples
display_random_samples(4)


##  Training


In [None]:
# yaml fixing
def fix_yaml_paths(yaml_path):
    """Write a copy of the dataset YAML whose 'path' entry is `yolo_dataset_dir`.

    The copy is saved as /kaggle/working/fixed_dataset.yaml and its location is
    returned; the file at `yaml_path` is left untouched.
    """
    fixed_yaml_path = "/kaggle/working/fixed_dataset.yaml"

    with open(yaml_path, 'r') as source:
        config = yaml.safe_load(source)

    # Point the dataset root at the generated dataset directory
    config['path'] = yolo_dataset_dir

    with open(fixed_yaml_path, 'w') as target:
        yaml.dump(config, target)

    print(f"Fixed YAML created at {fixed_yaml_path}")
    return fixed_yaml_path


In [None]:
# Locations used by the training cells in the Kaggle environment
yolo_dataset_dir = "/kaggle/working/yolo_dataset"
yolo_weights_dir = "/kaggle/working/yolo_weights"
yolo_pretrained_weights = "/kaggle/input/ultralytics-for-offline-install/yolov8n.pt"  # pre-downloaded weights

# Reseed the RNGs before training (note: NumPy is seeded with 0, the others with 42)
torch.manual_seed(42)
random.seed(42)
np.random.seed(0)

# Training runs are written below this folder
os.makedirs(yolo_weights_dir, exist_ok=True)

In [None]:
def plot_curves(run_dir):
    """Plot the train/validation DFL loss curves and mark the lowest validation loss.

    Reads `results.csv` from `run_dir`, saves the figure as `dfl_loss_curve.png`
    in `run_dir` and in /kaggle/working, and closes it.

    Args:
        run_dir (str): Training run directory containing `results.csv`.

    Returns:
        tuple | None: (epoch, validation DFL loss) for the row with the lowest
        validation DFL loss, where `epoch` is the value of the `epoch` column
        (not the row index). None when the file or the required columns are
        missing.
    """
    results_csv = os.path.join(run_dir, 'results.csv')
    if not os.path.exists(results_csv):
        print(f"Results file not found: {results_csv}")
        return

    df = pd.read_csv(results_csv)
    # Header names may carry surrounding whitespace (the substring matching
    # below allows for it); strip it so the fixed 'epoch' lookup also works.
    df.columns = df.columns.str.strip()
    train_dfl_col = next((col for col in df.columns if 'train/dfl_loss' in col), None)
    val_dfl_col = next((col for col in df.columns if 'val/dfl_loss' in col), None)

    if not train_dfl_col or not val_dfl_col:
        print("DFL loss columns not found.")
        return

    # Row with the smallest validation DFL loss and the epoch recorded on it
    best_row = df[val_dfl_col].idxmin()
    best_epoch = df.loc[best_row, 'epoch']
    best_val_loss = df.loc[best_row, val_dfl_col]

    plt.figure(figsize=(10, 6))
    plt.plot(df['epoch'], df[train_dfl_col], label='Train DFL Loss')
    plt.plot(df['epoch'], df[val_dfl_col], label='Validation DFL Loss')
    plt.axvline(x=best_epoch, color='r', linestyle='--',
                label=f'Best Model (Epoch {best_epoch}, Loss: {best_val_loss:.4f})')
    plt.xlabel('Epoch')
    plt.ylabel('DFL Loss')
    plt.title('Training and Validation DFL Loss')
    plt.legend()
    plt.grid(True, linestyle='--', alpha=0.7)

    plot_path = os.path.join(run_dir, 'dfl_loss_curve.png')
    plt.savefig(plot_path)
    plt.savefig('/kaggle/working/dfl_loss_curve.png')
    print(f"Loss curve saved to {plot_path}")
    plt.close()

    return best_epoch, best_val_loss

In [None]:
import os
def train_yolo_model(yaml_path, pretrained_weights_path, epochs=30, batch_size=16, img_size=640):
    """
    Fine-tune a YOLO model on the prepared dataset and plot its loss curves.

    The run is stored under `yolo_weights_dir/motor_detector`.

    Args:
        yaml_path (str): Dataset YAML file passed to the trainer.
        pretrained_weights_path (str): Existing `.pt` weights file to start from.
        epochs (int): Number of training epochs.
        batch_size (int): Training batch size.
        img_size (int): Training image size.

    Returns:
        tuple: (model, results) - the YOLO object and the value returned by
        `model.train`.
    """
    print(f"Loading pre-trained weights from: {pretrained_weights_path}")
    assert os.path.exists(pretrained_weights_path), f"Weights not found: {pretrained_weights_path}"
    assert pretrained_weights_path.endswith(".pt"), f"Expected a .pt file, got: {pretrained_weights_path}"

    run_name = 'motor_detector'
    train_options = dict(
        data=yaml_path,
        epochs=epochs,
        batch=batch_size,
        imgsz=img_size,
        project=yolo_weights_dir,
        name=run_name,
        exist_ok=True,
        patience=5,
        save_period=5,
        val=True,
        verbose=True,
    )

    model = YOLO(pretrained_weights_path)
    results = model.train(**train_options)

    # plot_curves returns None when results.csv or its columns are missing
    best_epoch_info = plot_curves(os.path.join(yolo_weights_dir, run_name))
    if best_epoch_info:
        print(f"\nBest model found at epoch {best_epoch_info[0]} with validation DFL loss: {best_epoch_info[1]:.4f}")

    return model, results

In [None]:

def predict_on_samples(model, num_samples=4):
    """
    Run predictions on random validation samples and display results.

    The ground-truth box (green) is reconstructed from the `_yNNNN_xNNNN`
    parts of the file name; predictions (red) come from `model.predict` with a
    confidence threshold of 0.25. The figure is saved to
    /kaggle/working/predictions.png.

    Args:
        model: Trained YOLO model.
        num_samples (int): Number of random samples to test. The grid grows
            with this value (two columns, at least two rows).
    """
    val_dir = os.path.join(yolo_dataset_dir, 'images', 'val')
    if not os.path.exists(val_dir):
        print(f"Validation directory not found at {val_dir}")
        val_dir = os.path.join(yolo_dataset_dir, 'images', 'train')
        print(f"Using train directory for predictions instead: {val_dir}")

    if not os.path.exists(val_dir):
        print("No images directory found for predictions")
        return

    # Only image files, in a stable order so a seeded RNG picks the same samples.
    val_images = sorted(
        f for f in os.listdir(val_dir) if f.lower().endswith(('.jpg', '.jpeg', '.png'))
    )
    if len(val_images) == 0:
        print("No images found for prediction")
        return

    num_samples = min(num_samples, len(val_images))
    if num_samples < 1:
        return
    samples = random.sample(val_images, num_samples)

    # Two columns; keep the 2x2 layout for up to four samples and add rows
    # beyond that instead of silently dropping samples.
    cols = 2
    rows = max(2, math.ceil(num_samples / cols))
    fig, axes = plt.subplots(rows, cols, figsize=(12, 6 * rows))
    axes = axes.flatten()

    box_size = 28  # side of the ground-truth square, in pixels

    for ax, img_file in zip(axes, samples):
        img_path = os.path.join(val_dir, img_file)
        results = model.predict(img_path, conf=0.25)[0]
        img = Image.open(img_path)
        ax.imshow(np.array(img), cmap='gray')

        # Ground truth from the file name: first 'y<digits>' and 'x<digits>' parts.
        parts = os.path.splitext(img_file)[0].split('_')
        y_gt = next((int(p[1:]) for p in parts if p.startswith('y') and p[1:].isdigit()), None)
        x_gt = next((int(p[1:]) for p in parts if p.startswith('x') and p[1:].isdigit()), None)
        if y_gt is not None and x_gt is not None:
            rect_gt = Rectangle((x_gt - box_size//2, y_gt - box_size//2), box_size, box_size,
                                linewidth=1, edgecolor='g', facecolor='none')
            ax.add_patch(rect_gt)

        if len(results.boxes) > 0:
            boxes = results.boxes.xyxy.cpu().numpy()
            confs = results.boxes.conf.cpu().numpy()
            for box, conf in zip(boxes, confs):
                x1, y1, x2, y2 = box
                rect_pred = Rectangle((x1, y1), x2-x1, y2-y1, linewidth=1, edgecolor='r', facecolor='none')
                ax.add_patch(rect_pred)
                ax.text(x1, y1-5, f'{conf:.2f}', color='red')

        ax.set_title(f"Image: {img_file}\nGT (green) vs Pred (red)")

    # Hide the grid cells that received no image
    for ax in axes[len(samples):]:
        ax.axis('off')

    plt.tight_layout()
    plt.savefig(os.path.join('/kaggle/working', 'predictions.png'))
    plt.show()

In [None]:
def prepare_dataset():
    """
    Report which dataset folders exist and return the YAML file to train with.

    When `dataset.yaml` exists in `yolo_dataset_dir` it is passed through
    `fix_yaml_paths`; otherwise a new YAML is written to
    /kaggle/working/dataset.yaml, using the training images for validation if
    there is no validation folder.

    Returns:
        str: Path to the YAML file to use for training.
    """
    split_dirs = {
        'Train images': os.path.join(yolo_dataset_dir, 'images', 'train'),
        'Val images': os.path.join(yolo_dataset_dir, 'images', 'val'),
        'Train labels': os.path.join(yolo_dataset_dir, 'labels', 'train'),
        'Val labels': os.path.join(yolo_dataset_dir, 'labels', 'val'),
    }

    print("Directory status:")
    for label, directory in split_dirs.items():
        print(f"- {label} exists: {os.path.exists(directory)}")

    original_yaml_path = os.path.join(yolo_dataset_dir, 'dataset.yaml')
    if os.path.exists(original_yaml_path):
        print(f"Found original dataset.yaml at {original_yaml_path}")
        return fix_yaml_paths(original_yaml_path)

    print("Original dataset.yaml not found, creating a new one")
    # Fall back to the training images when no validation folder is present
    val_split = 'images/val' if os.path.exists(split_dirs['Val images']) else 'images/train'
    yaml_data = {
        'path': yolo_dataset_dir,
        'train': 'images/train',
        'val': val_split,
        'names': {0: 'motor'}
    }
    new_yaml_path = "/kaggle/working/dataset.yaml"
    with open(new_yaml_path, 'w') as f:
        yaml.dump(yaml_data, f)
    print(f"Created new YAML at {new_yaml_path}")
    return new_yaml_path


In [None]:
import shutil, os

# Copy the pre-downloaded YOLOv8n weights from the read-only input dataset
# into the working directory.
src = "/kaggle/input/ultralytics-for-offline-install/yolov8n.pt"
dst = "/kaggle/working/yolov8n.pt"

assert os.path.exists(src), f"Missing weights at {src}"
shutil.copy(src, dst)
print(f"Copied to: {dst} size: {os.path.getsize(dst)}")


In [None]:
# Resolve the dataset YAML and echo its contents into the cell output
yaml_path = prepare_dataset()
print(f"Using YAML file: {yaml_path}")
with open(yaml_path, 'r') as f:
    yaml_text = f.read()
print(f"YAML contents:\n{yaml_text}")

# Train for 30 epochs starting from the pre-downloaded weights
model, results = train_yolo_model(yaml_path, pretrained_weights_path=yolo_pretrained_weights, epochs=30)

# Visual check of the trained model on a few samples
print("predictions")
predict_on_samples(model, num_samples=4)


## create submission

In [None]:
# Detection and post-processing parameters
CONFIDENCE_THRESHOLD = 0.45  # minimum box confidence kept as a detection
MAX_DETECTIONS_PER_TOMO = 3
NMS_IOU_THRESHOLD = 0.2
CONCENTRATION = 1  # fraction of slices processed per tomogram (1 = all)

# Reseed NumPy and PyTorch before inference
torch.manual_seed(42)
np.random.seed(42)

# Input data and output file
data_path = "/kaggle/input/byu-locating-bacterial-flagellar-motors-2025/"
test_dir = os.path.join(data_path, "test")
submission_path = "/kaggle/working/submission.csv"

# Best checkpoint of the training run (adjust if the run name changes)
model_path = "/kaggle/working/yolo_weights/motor_detector/weights/best.pt"

# GPU profiling context manager for timing
class GPUProfiler:
    """Context manager that prints how long the enclosed block took.

    When CUDA is available the device is synchronised on entry and on exit, so
    GPU work still queued at those points is included in the measurement.
    The result is printed as `[PROFILE] <name>: <seconds>s`.
    """

    def __init__(self, name):
        # Label shown in the printed timing line
        self.name = name
        # Start timestamp from time.perf_counter(); set in __enter__
        self.start_time = None

    def __enter__(self):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        # perf_counter is monotonic and high resolution; time.time() can jump
        # when the system clock is adjusted.
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, *args):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - self.start_time
        print(f"[PROFILE] {self.name}: {elapsed:.3f}s")

# Set device and dynamic batch size
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'
BATCH_SIZE = 8
if device.startswith('cuda'):
    # Inference favours speed over bit-exact reproducibility
    torch.backends.cudnn.benchmark = True
    torch.backends.cudnn.deterministic = False
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True
    gpu_name = torch.cuda.get_device_name(0)
    gpu_mem = torch.cuda.get_device_properties(0).total_memory / 1e9
    print(f"Using GPU: {gpu_name} with {gpu_mem:.2f} GB memory")
    # mem_get_info returns (free, total) bytes for the device. Subtracting
    # torch.cuda.memory_allocated from the total only accounts for tensors of
    # this process and overestimates what is actually free.
    free_bytes, _total_bytes = torch.cuda.mem_get_info(0)
    free_mem = free_bytes / 1e9
    # Four images per free GB, clamped to the range [8, 32]
    BATCH_SIZE = max(8, min(32, int(free_mem * 4)))
    print(f"Dynamic batch size set to {BATCH_SIZE} based on {free_mem:.2f}GB free memory")
else:
    print("GPU not available, using CPU")
    BATCH_SIZE = 4

## Inference Creation

In [None]:
def normalize_slice(slice_data):
    """
    Normalize slice data using the 2nd and 98th percentiles.

    Intensities are clipped to the [p2, p98] interval and rescaled to 0-255.

    Args:
        slice_data: Array-like of pixel intensities.

    Returns:
        np.ndarray: uint8 array with the same shape as the input. When the two
        percentiles coincide (e.g. a constant slice) an all-zero array is
        returned.
    """
    data = np.asarray(slice_data, dtype=np.float64)
    p2 = np.percentile(data, 2)
    p98 = np.percentile(data, 98)
    span = p98 - p2
    if not span > 0:
        # Flat slice: dividing by a zero span would produce NaN/inf, and
        # casting those to uint8 is undefined.
        return np.zeros(data.shape, dtype=np.uint8)
    clipped_data = np.clip(data, p2, p98)
    normalized = 255 * (clipped_data - p2) / span
    return np.uint8(normalized)

def preload_image_batch(file_paths):
    """Decode every file in `file_paths` and return the images as a list.

    OpenCV is tried first; when `cv2.imread` returns None the file is opened
    with Pillow instead.
    """
    def _read(image_path):
        decoded = cv2.imread(image_path)
        return decoded if decoded is not None else np.array(Image.open(image_path))

    return [_read(image_path) for image_path in file_paths]

def perform_3d_nms(detections, iou_threshold):
    """
    Greedy distance-based suppression of 3D detections.

    Detections are visited from highest to lowest confidence; one is kept only
    if its Euclidean (z, y, x) distance to every detection kept so far exceeds
    `28 * iou_threshold`.

    Args:
        detections: List of dicts with 'z', 'y', 'x' and 'confidence' keys.
        iou_threshold: Factor applied to the 28-pixel box size to obtain the
            minimum allowed distance.

    Returns:
        list: The kept detections, ordered by decreasing confidence.
    """
    if not detections:
        return []

    box_size = 28
    min_separation = box_size * iou_threshold

    def _far_apart(first, second):
        gap = np.sqrt(
            (first['z'] - second['z'])**2
            + (first['y'] - second['y'])**2
            + (first['x'] - second['x'])**2
        )
        return gap > min_separation

    kept = []
    for candidate in sorted(detections, key=lambda det: det['confidence'], reverse=True):
        # Equivalent to removing everything close to each winner in turn
        if all(_far_apart(candidate, winner) for winner in kept):
            kept.append(candidate)

    return kept

def process_tomogram(tomo_id, model, index=0, total=1):
    """
    Process a single tomogram and return the most confident motor detection.

    The `.jpg` slices in `test_dir/<tomo_id>` are subsampled according to
    CONCENTRATION, passed to `model` in batches of BATCH_SIZE, and every box
    with confidence >= CONFIDENCE_THRESHOLD is recorded by its centre. The
    detections are merged with `perform_3d_nms` and the highest-confidence
    survivor is returned.

    Args:
        tomo_id: Name of the tomogram folder inside `test_dir`.
        model: Callable taking a list of image paths (and `verbose`) and
            returning one result per image with a `.boxes` attribute.
        index: Position of this tomogram, used only in the progress message.
        total: Number of tomograms, used only in the progress message.

    Returns:
        dict: 'tomo_id' plus 'Motor axis 0/1/2' (z, y, x). All three axes are
        -1 when no detection passed the threshold.
    """
    print(f"Processing tomogram {tomo_id} ({index}/{total})")
    tomo_dir = os.path.join(test_dir, tomo_id)
    slice_files = sorted([f for f in os.listdir(tomo_dir) if f.endswith('.jpg')])

    # Evenly spaced subset of the slices; CONCENTRATION is the fraction kept.
    selected_indices = np.linspace(0, len(slice_files)-1, int(len(slice_files) * CONCENTRATION))
    selected_indices = np.round(selected_indices).astype(int)
    slice_files = [slice_files[i] for i in selected_indices]

    # NOTE(review): the "out of" figure counts every directory entry, not only .jpg files.
    print(f"Processing {len(slice_files)} out of {len(os.listdir(tomo_dir))} slices (CONCENTRATION={CONCENTRATION})")
    all_detections = []

    # Up to four CUDA streams on GPU; a single placeholder on CPU.
    if device.startswith('cuda'):
        streams = [torch.cuda.Stream() for _ in range(min(4, BATCH_SIZE))]
    else:
        streams = [None]

    next_batch_thread = None
    next_batch_images = None

    for batch_start in range(0, len(slice_files), BATCH_SIZE):
        # Wait for the preload thread started in the previous iteration.
        # NOTE(review): the thread's return value is never captured, so the
        # preloaded images are not reused by the model call below.
        if next_batch_thread is not None:
            next_batch_thread.join()
            next_batch_images = None

        batch_end = min(batch_start + BATCH_SIZE, len(slice_files))
        batch_files = slice_files[batch_start:batch_end]

        # Start reading the following batch in the background while this one is processed.
        next_batch_start = batch_end
        next_batch_end = min(next_batch_start + BATCH_SIZE, len(slice_files))
        next_batch_files = slice_files[next_batch_start:next_batch_end] if next_batch_start < len(slice_files) else []
        if next_batch_files:
            next_batch_paths = [os.path.join(tomo_dir, f) for f in next_batch_files]
            next_batch_thread = threading.Thread(target=preload_image_batch, args=(next_batch_paths,))
            next_batch_thread.start()
        else:
            next_batch_thread = None

        # Split the batch into one sub-batch per stream.
        sub_batches = np.array_split(batch_files, len(streams))
        for i, sub_batch in enumerate(sub_batches):
            if len(sub_batch) == 0:
                continue
            stream = streams[i % len(streams)]
            with torch.cuda.stream(stream) if stream and device.startswith('cuda') else nullcontext():
                sub_batch_paths = [os.path.join(tomo_dir, slice_file) for slice_file in sub_batch]
                # Slice number is the text between the first '_' and the next '.'
                # (assumes names of the form '<prefix>_<number>.jpg' - TODO confirm).
                sub_batch_slice_nums = [int(slice_file.split('_')[1].split('.')[0]) for slice_file in sub_batch]
                with GPUProfiler(f"Inference batch {i+1}/{len(sub_batches)}"):
                    sub_results = model(sub_batch_paths, verbose=False)
                for j, result in enumerate(sub_results):
                    if len(result.boxes) > 0:
                        for box_idx, confidence in enumerate(result.boxes.conf):
                            if confidence >= CONFIDENCE_THRESHOLD:
                                # Keep the box centre as the motor position in this slice.
                                x1, y1, x2, y2 = result.boxes.xyxy[box_idx].cpu().numpy()
                                x_center = (x1 + x2) / 2
                                y_center = (y1 + y2) / 2
                                all_detections.append({
                                    'z': round(sub_batch_slice_nums[j]),
                                    'y': round(y_center),
                                    'x': round(x_center),
                                    'confidence': float(confidence)
                                })
        # Make sure all queued GPU work of this batch has finished.
        if device.startswith('cuda'):
            torch.cuda.synchronize()

    if next_batch_thread is not None:
        next_batch_thread.join()

    # Merge nearby detections, then rank the survivors by confidence.
    final_detections = perform_3d_nms(all_detections, NMS_IOU_THRESHOLD)
    final_detections.sort(key=lambda x: x['confidence'], reverse=True)

    # -1 on every axis marks "no motor found".
    if not final_detections:
        return {'tomo_id': tomo_id, 'Motor axis 0': -1, 'Motor axis 1': -1, 'Motor axis 2': -1}

    best_detection = final_detections[0]
    return {
        'tomo_id': tomo_id,
        'Motor axis 0': round(best_detection['z']),
        'Motor axis 1': round(best_detection['y']),
        'Motor axis 2': round(best_detection['x'])
    }

def debug_image_loading(tomo_id):
    """
    Smoke-test image loading and model inference on one slice of a tomogram.

    The middle `.jpg` slice of `test_dir/<tomo_id>` is loaded with Pillow and
    with OpenCV (grayscale and RGB), printing shape and dtype each time; then a
    YOLO model is loaded from `model_path` and run on that slice. Failures are
    printed rather than raised.
    """
    tomo_dir = os.path.join(test_dir, tomo_id)
    slice_files = sorted(name for name in os.listdir(tomo_dir) if name.endswith('.jpg'))
    if not slice_files:
        print(f"No image files found in {tomo_dir}")
        return

    print(f"Found {len(slice_files)} image files in {tomo_dir}")
    middle_slice = slice_files[len(slice_files)//2]
    img_path = os.path.join(tomo_dir, middle_slice)

    # Step 1: the three image-loading routes
    try:
        pil_array = np.array(Image.open(img_path))
        print(f"PIL Image shape: {pil_array.shape}, dtype: {pil_array.dtype}")

        gray = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
        print(f"OpenCV Image shape: {gray.shape}, dtype: {gray.dtype}")

        rgb = cv2.cvtColor(cv2.imread(img_path), cv2.COLOR_BGR2RGB)
        print(f"OpenCV RGB Image shape: {rgb.shape}, dtype: {rgb.dtype}")

        print("Image loading successful!")
    except Exception as e:
        print(f"Error loading image {img_path}: {e}")

    # Step 2: a single inference call on the same file
    try:
        YOLO(model_path)([img_path], verbose=False)
        print("YOLO model successfully processed the test image")
    except Exception as e:
        print(f"Error with YOLO processing: {e}")

### Generate CSV

In [None]:
# Every sub-directory of the test folder is one tomogram
test_tomos = sorted([d for d in os.listdir(test_dir) if os.path.isdir(os.path.join(test_dir, d))])
total_tomos = len(test_tomos)
print(f"Total tomograms in test:{total_tomos}")

# Quick sanity check of image loading and inference on the first tomogram
if test_tomos:
    debug_image_loading(test_tomos[0])

if torch.cuda.is_available():
    torch.cuda.empty_cache()

print(f"Loading YOLO model from {model_path}")
model = YOLO(model_path)
model.to(device)
if device.startswith('cuda'):
    model.fuse()
    # FP16 only on GPUs with compute capability 7.0 or newer
    if torch.cuda.get_device_capability(0)[0] >= 7:
        model.model.half()
        print("Using half precision (FP16) for inference")

results = []
motors_found = 0

# A single worker processes the tomograms one after another, in submission order
with ThreadPoolExecutor(max_workers=1) as executor:
    future_to_tomo = {}
    for i, tomo_id in enumerate(test_tomos, 1):
        future = executor.submit(process_tomogram, tomo_id, model, i, total_tomos)
        future_to_tomo[future] = tomo_id

    for future in future_to_tomo:
        tomo_id = future_to_tomo[future]
        try:
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
            result = future.result()
            results.append(result)
            # process_tomogram marks "no motor" with -1 (never NaN), so compare
            # against the sentinel; pd.isna() counted every tomogram as a hit.
            has_motor = result['Motor axis 0'] != -1
            if has_motor:
                motors_found += 1
                print(f"Motor found in {tomo_id} at position: z={result['Motor axis 0']}, y={result['Motor axis 1']}, x={result['Motor axis 2']}")
            else:
                print(f"No motor detected in {tomo_id}")
            print(f"Current detection rate: {motors_found}/{len(results)} ({motors_found/len(results)*100:.1f}%)")
        except Exception as e:
            # Keep the submission complete: a failed tomogram is reported as "no motor"
            print(f"Error processing {tomo_id}: {e}")
            results.append({'tomo_id': tomo_id, 'Motor axis 0': -1, 'Motor axis 1': -1, 'Motor axis 2': -1})

submission_df = pd.DataFrame(results)
submission_df = submission_df[['tomo_id', 'Motor axis 0', 'Motor axis 1', 'Motor axis 2']]
submission_df.to_csv(submission_path, index=False)


In [None]:
## display submission
import os

# Kaggle expects this exact filename in the notebook outputs
submission_path = "/kaggle/working/submission.csv"

# Rebuild the table from the collected results, keeping only the required columns
submission_df = pd.DataFrame(results)[['tomo_id', 'Motor axis 0', 'Motor axis 1', 'Motor axis 2']]
submission_df.to_csv(submission_path, index=False)

print("Wrote:", submission_path)
print("Exists:", os.path.exists(submission_path))

# Preview the first rows
submission_df.head()