In [None]:
import os
import cv2
import numpy as np
from pathlib import Path
import matplotlib.pyplot as plt
from matplotlib.patches import Rectangle
from PIL import Image

def extract_sub_images(image_path, white_threshold=250):
    """
    Extract sub-images from a sorted image where organisms are placed on a white background.

    The image is cut into horizontal bands separated by all-white rows, and
    each band is cut into sub-images separated by all-white columns. A pixel
    counts as white when every channel is >= ``white_threshold``. The same
    rule is used for rows and columns so that near-white JPEG noise does not
    stop bands from being separated. Content touching the image border is
    kept as well.

    Args:
        image_path: Path to the sorted image, or an already-loaded image as a
            numpy array of shape (H, W) or (H, W, C)
        white_threshold: Minimum channel value for a pixel to be treated as
            background (default 250)

    Returns:
        List of lists of sub-images (numpy arrays). There is one inner list
        per horizontal band, ordered top to bottom. Each inner list holds
        that band's sub-images, ordered left to right. Sub-images are views
        into the loaded image array.
    """

    def _content_spans(is_background):
        # Return (start, stop) pairs for maximal runs of non-background
        # entries. Padding both ends with "background" guarantees that runs
        # touching index 0 or the last index are still closed.
        padded = np.concatenate(([False], ~is_background, [False])).astype(np.int8)
        edges = np.flatnonzero(np.diff(padded))
        return list(zip(edges[0::2], edges[1::2]))

    if isinstance(image_path, np.ndarray):
        img_array = image_path
    else:
        # The context manager releases the file handle PIL otherwise keeps open.
        with Image.open(image_path) as img:
            # Palette / grayscale / other modes give 2-D or index-valued
            # arrays; normalise everything except RGB/RGBA to RGB.
            if img.mode not in ("RGB", "RGBA"):
                img = img.convert("RGB")
            img_array = np.array(img)

    # Per-pixel background mask (H, W): every channel at or above threshold.
    background = img_array >= white_threshold
    if background.ndim == 3:
        background = background.all(axis=2)

    all_sub_images = []
    for top, bottom in _content_spans(background.all(axis=1)):
        band = img_array[top:bottom]
        column_spans = _content_spans(background[top:bottom].all(axis=0))
        all_sub_images.append([band[:, left:right] for left, right in column_spans])

    return all_sub_images

def squeeze_white_pixels(sub_image, white_threshold=255):
    """
    Crop a sub-image to the bounding box of its non-white pixels.

    Only the white margins around the organism are removed. Fully-white rows
    or columns *inside* the bounding box (e.g. a gap between a body and a
    detached appendage) are kept, so the organism's shape is not distorted
    and it can still be matched against the raw image.

    Args:
        sub_image: Numpy array of the sub-image, shape (H, W) or (H, W, C)
        white_threshold: Minimum channel value for a pixel to count as white
            (default 255, i.e. only pure white margins are trimmed)

    Returns:
        Contiguous cropped numpy array, or ``sub_image`` itself when it
        contains no non-white pixels
    """
    # Per-pixel white mask (H, W): every channel at or above the threshold.
    is_white = sub_image >= white_threshold
    if is_white.ndim == 3:
        is_white = is_white.all(axis=2)

    content_rows = np.flatnonzero(~is_white.all(axis=1))
    content_cols = np.flatnonzero(~is_white.all(axis=0))

    # If no non-white pixels are found, return the original image
    if content_rows.size == 0 or content_cols.size == 0:
        return sub_image

    cropped = sub_image[content_rows[0]:content_rows[-1] + 1,
                        content_cols[0]:content_cols[-1] + 1]
    # Return a standalone contiguous array (as the original masking did),
    # not a strided view into the caller's image.
    return np.ascontiguousarray(cropped)

def extract_organisms(sorted_image_path, min_size):
    """
    Extract organisms from a sorted image and keep them in memory.

    Each sub-image found by ``extract_sub_images`` is tightly cropped with
    ``squeeze_white_pixels``. Crops whose height and width are both strictly
    greater than ``min_size`` are kept. Crops that are not already
    3-channel (e.g. 2-D grayscale or RGBA) are converted to RGB via PIL, so
    that downstream template matching always receives three channels.

    Args:
        sorted_image_path: Path to the sorted image
        min_size: Size threshold in pixels; a crop is kept only if both its
            width and height exceed this value

    Returns:
        List of (identifier, image_array) tuples of extracted organisms, where
        identifier is "<base>_organism_<row>_<col>_<width>_<height>"
    """
    # Get base name of input file without extension
    base_name = os.path.splitext(os.path.basename(sorted_image_path))[0]

    sub_images = extract_sub_images(sorted_image_path)
    extracted_organisms = []

    for i, row in enumerate(sub_images):
        for j, sub_image in enumerate(row):
            squeezed_image = squeeze_white_pixels(sub_image)
            h, w = squeezed_image.shape[:2]

            if h <= min_size or w <= min_size:
                continue

            identifier = f"{base_name}_organism_{i}_{j}_{w}_{h}"

            # Normalise to 3-channel RGB. Checking ndim first avoids the
            # IndexError a 2-D grayscale crop would raise on shape[2].
            if squeezed_image.ndim == 2 or squeezed_image.shape[2] != 3:
                rgb_image = Image.fromarray(squeezed_image).convert('RGB')
                squeezed_image = np.array(rgb_image)

            extracted_organisms.append((identifier, squeezed_image))
            print(f"Extracted: organism at position ({i},{j}) with size {w}x{h}")

    print(f"\nTotal organisms extracted: {len(extracted_organisms)}")
    return extracted_organisms

def find_organisms_in_raw_image(raw_image_path, organism_images, match_threshold=0.7, fallback_threshold=0.5):
    """
    Find extracted organisms in the raw image and return their bounding box coordinates.

    Each organism is located with normalized cross-correlation template
    matching (``cv2.TM_CCOEFF_NORMED``), and its best-scoring location is
    taken. A match scoring above ``match_threshold`` is accepted. Otherwise
    it is still accepted, with a message, if it scores above
    ``fallback_threshold``.

    Args:
        raw_image_path: Path to the raw plankton image
        organism_images: List of (filename, image_array) tuples; arrays are
            expected in RGB channel order (H, W, 3) or as 2-D grayscale
        match_threshold: Score above which a match is considered good
        fallback_threshold: Lower score above which the best match is still used

    Returns:
        List of tuples (organism_filename, x, y, width, height) where x,y is the top-left corner
    """
    raw_img = cv2.imread(raw_image_path)
    # cv2.imread returns None on failure; check before any cv2 call uses it.
    if raw_img is None:
        print(f"Error: Could not read raw image at {raw_image_path}")
        return []

    if not organism_images:
        print("No organism images provided")
        return []

    raw_h, raw_w = raw_img.shape[:2]
    detections = []

    for filename, org_img_rgb in organism_images:
        # cv2.imread gives BGR, so convert the RGB (or grayscale) template to BGR.
        if not isinstance(org_img_rgb, np.ndarray):
            print(f"Warning: Unexpected image format for {filename}")
            continue
        if org_img_rgb.ndim == 3 and org_img_rgb.shape[2] == 3:
            org_img_cv = cv2.cvtColor(org_img_rgb, cv2.COLOR_RGB2BGR)
        elif org_img_rgb.ndim == 2:
            org_img_cv = cv2.cvtColor(org_img_rgb, cv2.COLOR_GRAY2BGR)
        else:
            print(f"Warning: Unexpected image format for {filename}")
            continue

        h, w = org_img_rgb.shape[:2]
        # matchTemplate raises if the template is larger than the search image.
        if h > raw_h or w > raw_w:
            print(f"Warning: {filename} ({w}x{h}) is larger than the raw image; skipping")
            continue

        result = cv2.matchTemplate(raw_img, org_img_cv, cv2.TM_CCOEFF_NORMED)
        _, max_val, _, max_loc = cv2.minMaxLoc(result)

        if max_val > match_threshold:
            detections.append((filename, max_loc[0], max_loc[1], w, h))
            print(f"Found {filename} at position {max_loc} with confidence {max_val:.2f}")
        else:
            print(f"Could not find a good match for {filename}. Best match: {max_val:.2f}")
            if max_val > fallback_threshold:
                detections.append((filename, max_loc[0], max_loc[1], w, h))
                print(f"Using best available match for {filename}")

    return detections

def visualize_detections(raw_image_path, detections, output_path):
    """
    Create a visualization of the raw image with bounding boxes around detected organisms.

    Two files are written:
    - a matplotlib rendering at ``output_path``;
    - an OpenCV-drawn JPEG next to it, named ``<output_path stem>_cv2.jpg``.

    Boxes are labelled with the row index parsed from the organism identifier.

    Args:
        raw_image_path: Path to the raw plankton image
        detections: List of tuples (organism_filename, x, y, width, height)
        output_path: Path to save the visualization
    """

    def _row_label(filename):
        # Identifiers look like "<base>_organism_<row>_<col>_<w>_<h>". The base
        # name may itself contain underscores, so parse from the right.
        parts = filename.split('_')
        if len(parts) >= 5 and parts[-5] == 'organism':
            return parts[-4]
        return parts[2] if len(parts) > 2 else filename

    raw_img = cv2.imread(raw_image_path)
    if raw_img is None:
        print(f"Error: Could not read raw image at {raw_image_path}")
        return

    raw_img_rgb = cv2.cvtColor(raw_img, cv2.COLOR_BGR2RGB)

    fig, ax = plt.subplots(1, figsize=(16, 10))
    ax.imshow(raw_img_rgb)

    for i, (filename, x, y, w, h) in enumerate(detections):
        rect = Rectangle((x, y), w, h, linewidth=2, edgecolor='r', facecolor='none')
        ax.add_patch(rect)
        ax.text(x, y-5, f"{i+1}: {_row_label(filename)}", color='red', fontsize=8,
                bbox=dict(facecolor='white', alpha=0.7))

    ax.set_xticks([])
    ax.set_yticks([])
    ax.set_title(f"Detected Organisms: {len(detections)}")

    fig.tight_layout()
    fig.savefig(output_path, dpi=300, bbox_inches='tight')
    plt.close(fig)

    print(f"Visualization saved to {output_path}")

    # Also create a version drawn directly with OpenCV (BGR colour order).
    for filename, x, y, w, h in detections:
        x, y, w, h = int(x), int(y), int(w), int(h)
        cv2.rectangle(raw_img, (x, y), (x+w, y+h), (0, 0, 255), 2)
        # Label with the organism's row index
        cv2.putText(raw_img, _row_label(filename), (x, y-5),
                    cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 0, 255), 1)

    # Derive the name from the stem so a non-.png output_path is never
    # overwritten and directory names containing ".png" are untouched.
    cv2_output_path = os.path.splitext(output_path)[0] + '_cv2.jpg'
    cv2.imwrite(cv2_output_path, raw_img)
    print(f"OpenCV visualization saved to {cv2_output_path}")
    
def export_to_yolo_format(raw_image_path, detections, output_folder, class_mapping=None):
    """
    Export detections to YOLO format.

    YOLO format: <class> <x_center> <y_center> <width> <height>
    Where all values are normalized between 0 and 1.

    Files written under ``output_folder``:
    - ``<raw base name>.txt`` with the annotations;
    - ``dataset.yaml``;
    - ``classes.txt``, only when no mapping is given.

    In addition, byte-for-byte copies of the raw image go into ``images/``
    and ``train/images/``, and a copy of the annotation goes into
    ``train/labels/``. Empty ``val``/``test`` split folders are also created.

    Args:
        raw_image_path: Path to the raw plankton image
        detections: List of tuples (organism_filename, x, y, width, height)
        output_folder: Folder to save YOLO annotations
        class_mapping: Dictionary mapping organism types (as parsed from the
            filename, see below) to class indices. When omitted, the type is
            the row index from identifiers of the form
            "<base>_organism_<row>_<col>_<w>_<h>". Indices are then assigned
            in sorted order of the types.
    """
    import shutil

    def _organism_type(filename):
        # Parse from the right: the base name may itself contain underscores
        # (e.g. "M3A_2011-08-27__45um_..."), so a fixed left index such as
        # parts[2] would pick a piece of the base name instead of the row.
        parts = filename.split('_')
        if len(parts) >= 5 and parts[-5] == 'organism':
            return parts[-4]
        if len(parts) > 2:
            return parts[2]  # legacy fallback for other naming schemes
        return "organism"

    def _copy_file(src, dst):
        # Copy bytes verbatim (re-encoding a JPEG with cv2.imwrite is lossy);
        # skip when source and destination are the same file.
        if os.path.abspath(src) != os.path.abspath(dst):
            shutil.copyfile(src, dst)

    os.makedirs(output_folder, exist_ok=True)

    # Read the raw image to validate it and get its dimensions
    raw_img = cv2.imread(raw_image_path)
    if raw_img is None:
        print(f"Error: Could not read raw image at {raw_image_path}")
        return

    img_height, img_width = raw_img.shape[:2]

    if class_mapping is None:
        organism_types = sorted({_organism_type(filename) for filename, *_ in detections})
        class_mapping = {organism_type: i for i, organism_type in enumerate(organism_types)}

        with open(os.path.join(output_folder, 'classes.txt'), 'w') as f:
            for organism_type in organism_types:
                f.write(f"plankton_class_{organism_type}\n")

    base_name = os.path.splitext(os.path.basename(raw_image_path))[0]
    annotation_path = os.path.join(output_folder, f"{base_name}.txt")

    with open(annotation_path, 'w') as f:
        for filename, x, y, width, height in detections:
            organism_type = _organism_type(filename)
            if organism_type in class_mapping:
                class_idx = class_mapping[organism_type]
            else:
                print(f"Warning: No class mapping for {organism_type}, using 0")
                class_idx = 0

            # Convert pixel box (top-left + size) to normalized centre + size
            x_center = (x + width / 2) / img_width
            y_center = (y + height / 2) / img_height
            w_normalized = width / img_width
            h_normalized = height / img_height

            f.write(f"{class_idx} {x_center:.6f} {y_center:.6f} {w_normalized:.6f} {h_normalized:.6f}\n")

    print(f"YOLO annotations saved to {annotation_path}")
    print(f"Class mapping: {class_mapping}")

    # names[k] must describe class index k, so order keys by their index.
    # For an auto-generated mapping this equals sorted key order.
    class_names = [f"plankton_class_{key}" for key in sorted(class_mapping, key=class_mapping.get)]
    dataset_config = f"""
# YOLO Dataset Configuration
train: {output_folder}/train
val: {output_folder}/val
test: {output_folder}/test

# number of classes
nc: {len(class_mapping)}

# class names
names: {class_names}
"""

    yaml_path = os.path.join(output_folder, 'dataset.yaml')
    with open(yaml_path, 'w') as f:
        f.write(dataset_config)

    print(f"Dataset configuration saved to {yaml_path}")

    # Copy the raw image to the YOLO images folder
    images_folder = os.path.join(output_folder, 'images')
    os.makedirs(images_folder, exist_ok=True)
    _copy_file(raw_image_path, os.path.join(images_folder, os.path.basename(raw_image_path)))

    # Create directory structure for YOLO dataset
    for split in ['train', 'val', 'test']:
        os.makedirs(os.path.join(output_folder, split, 'images'), exist_ok=True)
        os.makedirs(os.path.join(output_folder, split, 'labels'), exist_ok=True)

    # Place the image and its annotation in the train split as an example
    _copy_file(raw_image_path,
               os.path.join(output_folder, 'train', 'images', os.path.basename(raw_image_path)))
    _copy_file(annotation_path,
               os.path.join(output_folder, 'train', 'labels', f"{base_name}.txt"))

    print(f"YOLO dataset structure created in {output_folder}")

def process_plankton_images(raw_image_path, sorted_image_path, min_size=50, output_folder="output"):
    """
    Run the whole pipeline for one raw/sorted image pair.

    Steps: crop organisms out of the sorted image, locate each crop in the
    raw image, then (only if anything was located) draw the detections and
    export them as a YOLO dataset. Everything is written below
    ``output_folder``, which is created if missing.

    Args:
        raw_image_path: Path to the raw plankton image
        sorted_image_path: Path to the sorted image with organisms on white background
        min_size: Size threshold forwarded to ``extract_organisms``
        output_folder: Folder to save all outputs
    """
    os.makedirs(output_folder, exist_ok=True)

    print("1. Extracting organisms from sorted image...")
    organisms = extract_organisms(sorted_image_path, min_size)

    print("\n2. Finding organisms in raw image...")
    found = find_organisms_in_raw_image(raw_image_path, organisms)

    # Nothing to draw or export without detections.
    if not found:
        print("\nNo organisms were detected in the raw image")
        return

    print(f"\n3. Creating visualization for {len(found)} detected organisms...")
    visualization_path = os.path.join(output_folder, "detection_visualization.png")
    visualize_detections(raw_image_path, found, visualization_path)

    print("\n4. Exporting to YOLO format...")
    export_to_yolo_format(raw_image_path, found, os.path.join(output_folder, "yolo_dataset"))

    print(f"\nComplete! All results saved to {output_folder}")

In [None]:
# Run the workflow on one scan. Edit the locations below for your machine.
project_dir = r"C:\Users\acer\Desktop\Work_IGB\Georgia Zooplankton\igb-georgia"
images_dir = project_dir + r"\data\Zooplankton scanner raw and sorted images\raw_and_sorted"
scan_name = "M3A_2011-08-27__45um_above200um_x1_2400dpi_1-of-3"

raw_image_path = images_dir + "\\" + scan_name + ".jpg"             # raw plankton image
sorted_image_path = images_dir + "\\" + scan_name + "_sorted.jpg"   # organisms on white background
output_folder = project_dir + r"\output"                            # where all results go
min_size = 200                                                      # crops must exceed this in both dimensions

process_plankton_images(raw_image_path, sorted_image_path, min_size, output_folder)