### 2-1-LiveBeeLabelCropper.ipynb



In [22]:
import matplotlib.pyplot as plt
from pathlib import Path
from tqdm import tqdm
import seaborn as sns
import numpy as np
import shutil
import torch
import os
import cv2

from PIL import Image
from scipy.ndimage import binary_fill_holes

from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

In [3]:
# Root folder that holds the raw input images and every processed output
data_dir = Path("/mnt/g/Projects/Master/Data/")

# Images are read from input_dir; the label crops are written to output_dir
input_dir = data_dir.joinpath("Raw", "LiveBees")
output_dir = data_dir.joinpath("Processed", "LiveBees", "1-LiveWingLabelCrops")

# Segment Anything 2 checkpoint file and the matching model configuration
sam2_checkpoint = "/home/wsl/bin/segment-anything-2/checkpoints/sam2_hiera_large.pt"
model_cfg = "sam2_hiera_l.yaml"

# When True, crop_label stores crops in a "Labels" subfolder and also saves a
# diagnostic figure per image in a "Process" subfolder
DEBUG = True

In [17]:
def find_white_area(image, y_coord, window_size, step_size, density_threshold):
    """Locate a bright square window on one image row, searching outward from the centre.

    Square windows of side ``window_size`` whose top edge lies on row ``y_coord``
    are tested at the horizontal centre of the image first and then at growing
    offsets (``+step_size``, ``-step_size``, ``+2 * step_size``, ...). A pixel
    counts as white when its value is >= 120.

    Args:
        image: Single-channel 2-D array (it is unpacked as ``h, w = image.shape``).
        y_coord: Row of the top edge of every tested window.
        window_size: Side length of the square window in pixels.
        step_size: Horizontal distance between successive search offsets.
        density_threshold: Fraction of white pixels at which the search stops early.

    Returns:
        ``(x, y_coord)``, the top-left corner of the first window whose white
        density reaches ``density_threshold``. If no window reaches it, the
        densest window seen is returned (the first one visited wins ties).
        ``(0, 0)`` is returned when no window fits inside the image.
    """
    h, w = image.shape
    best_coords = (0, 0)

    # Every candidate window sits on the same row, so its vertical fit is checked once
    if not 0 <= y_coord <= h - window_size:
        return best_coords

    center_x = w // 2
    max_x = w - window_size
    window_area = window_size * window_size
    max_density = -1

    for radius in range(0, w // 2, step_size):
        # Only the two positions that are new at this radius are evaluated
        # (right before left); smaller offsets were handled by earlier radii.
        if radius == 0:
            candidates = (center_x,)
        else:
            candidates = (center_x + radius, center_x - radius)

        for x in candidates:
            # Skip windows that would stick out of the image horizontally
            if not 0 <= x <= max_x:
                continue

            window = image[y_coord:y_coord + window_size, x:x + window_size]

            # Fraction of white pixels inside the window
            density = np.sum(window >= 120) / window_area

            # Strict comparison keeps the first window visited among equals
            if density > max_density:
                max_density = density
                best_coords = (x, y_coord)

            # Early termination once a good enough window is found
            if density >= density_threshold:
                return best_coords

    return best_coords
    

def identify_label(image, sampling_coords):
    """Segment the region under the given prompt points with the module-level ``predictor``.

    Args:
        image: Image handed to ``predictor.set_image``.
        sampling_coords: Sequence of ``(x, y)`` points; each one is passed to the
            predictor with point label 1.

    Returns:
        The first predicted mask with its holes filled, converted to ``int``.
    """
    prompt_points = np.array(sampling_coords)
    # One label of value 1 per prompt point
    prompt_labels = np.array(len(sampling_coords) * [1])

    predictor.set_image(image)

    # A single mask is requested, so the prediction of interest is at index 0
    masks, _scores, _extra = predictor.predict(
        point_coords=prompt_points,
        point_labels=prompt_labels,
        multimask_output=False,
    )
    first_mask = masks[0]

    # Close any holes inside the segmented region before returning it
    return binary_fill_holes(first_mask).astype(int)

def crop_from_mask(mask, image):
    """Cut the region covered by ``mask`` out of ``image``, rotated so its long side is horizontal.

    The largest external contour of the mask is enclosed in a minimum-area
    rectangle. The image is rotated about the rectangle's centre so that the
    rectangle becomes axis-aligned with its longer side horizontal, and the
    rectangle is then cropped from the rotated image.

    Args:
        mask: 2-D array that is non-zero where the label is; it is converted to
            ``uint8`` for OpenCV.
        image: Image to rotate and crop. The rotated canvas takes its size from
            ``mask``, so both are expected to share height and width.

    Returns:
        Tuple ``(mask_box_points, cropped_image)``: the four integer corner
        points of the minimum-area rectangle in the original image, and the
        aligned crop.

    Raises:
        ValueError: If no contour can be found in ``mask`` (e.g. an empty mask).
    """
    # Convert mask to 8-bit single channel
    mask = mask.astype(np.uint8)

    # Find contours and keep the largest one
    mask_contours, _ = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)
    if len(mask_contours) == 0:
        raise ValueError("crop_from_mask: no contour found in mask (is the mask empty?)")
    mask_contour = max(mask_contours, key=cv2.contourArea)

    # Minimum-area bounding box and its corner points as integer coordinates
    mask_rect = cv2.minAreaRect(mask_contour)
    mask_box_points = np.intp(cv2.boxPoints(mask_rect))

    # Add a quarter turn when needed so that the longer side ends up horizontal
    center, size, angle = mask_rect
    if size[0] < size[1]:
        angle += 90

    # Rotate the entire image about the rectangle centre; regions rotated in
    # from outside the image are filled with white
    rotation_matrix = cv2.getRotationMatrix2D(center, angle, 1.0)
    height, width = mask.shape[:2]
    rotated_image = cv2.warpAffine(image, rotation_matrix, (width, height), flags=cv2.INTER_LINEAR, borderValue=(255, 255, 255))

    # Bounding box of the rectangle after applying the same rotation to its corners
    rotated_box = np.intp(cv2.transform(np.array([mask_box_points]), rotation_matrix))[0]
    x, y, w, h = cv2.boundingRect(rotated_box)

    # The rotated rectangle can extend past the canvas. A negative start index
    # would make NumPy count from the opposite edge and return an empty or
    # wrong crop, so the window is clamped to the canvas first.
    x_start, y_start = max(x, 0), max(y, 0)
    x_stop, y_stop = min(x + w, width), min(y + h, height)
    cropped_image = rotated_image[y_start:y_stop, x_start:x_stop]

    return mask_box_points, cropped_image

    

def crop_label(image, output_dir, jpg_file, coords=None):
    """Segment the label in ``image``, crop it and save the crop to disk.

    Args:
        image: Colour image array. The cells in this notebook pass RGB data
            (converted from OpenCV's BGR), and it is saved/plotted as RGB here.
        output_dir: ``Path`` of the base output folder.
        jpg_file: File name used for the saved crop (and debug figure).
        coords: Optional sequence of ``(x, y)`` prompt points. When empty or
            ``None`` the prompt points are searched for automatically.

    Side effects:
        With ``DEBUG`` set, the crop is written to ``output_dir / "Labels"`` and
        a three-panel diagnostic figure to ``output_dir / "Process"``. Otherwise
        the crop is written directly into ``output_dir``.
    """
    # Take manually chosen coordinates or identify coordinates automatically
    if coords:
        coords_list = coords
    else:
        # The image is treated as RGB everywhere else in this function, so the
        # grayscale conversion has to use the RGB channel order as well.
        gray = cv2.cvtColor(image, cv2.COLOR_RGB2GRAY)
        blurred = cv2.medianBlur(gray, 5)

        # Search a fixed row for a nearly all-white square. find_white_area
        # falls back to (0, 0) when no window fits (e.g. image too small).
        window_size = 200
        label_coords = find_white_area(blurred, y_coord=1000, window_size=window_size, step_size=200, density_threshold=0.99)
        # Second prompt point: the centre of the window that was found
        modified_coord = (label_coords[0] + window_size // 2, label_coords[1] + window_size // 2)
        coords_list = [label_coords, modified_coord]

    mask = identify_label(image, coords_list)

    mask_box_points, cropped_image = crop_from_mask(mask, image)

    if DEBUG:
        # Save cropped label
        label_dir = output_dir / "Labels"
        os.makedirs(label_dir, exist_ok=True)
        Image.fromarray(cropped_image).save(label_dir / jpg_file)

        # Create an image directory for the diagnostic figures
        image_dir = output_dir / "Process"
        os.makedirs(image_dir, exist_ok=True)

        # RGBA overlay: red and opaque where the mask is set, transparent elsewhere
        mask_channel = (mask * 255).astype(np.uint8)
        png_image = np.zeros((image.shape[0], image.shape[1], 4), dtype=np.uint8)
        png_image[:, :, 0] = mask_channel
        png_image[:, :, 3] = mask_channel

        # Draw the bounding rectangle on a copy of the image for visualization
        label_image = image.copy()
        cv2.drawContours(label_image, [mask_box_points], 0, (255, 0, 0), 40)

        # Create a 1x3 grid of images
        fig, axes = plt.subplots(nrows=1, ncols=3, figsize=(30, 20))
        # Show mask together with the prompt points
        x_coords, y_coords = zip(*coords_list)
        axes[0].imshow(image)
        axes[0].imshow(png_image, alpha=0.6)
        axes[0].scatter(x_coords, y_coords, c="red", s=20, edgecolor='black')
        axes[0].axis("off")
        # Show rectangle
        axes[1].imshow(label_image)
        axes[1].axis("off")
        # Show cropped image
        axes[2].imshow(cropped_image)
        axes[2].axis("off")
        # Save and close this specific figure so figures do not pile up in a loop
        fig.savefig(image_dir / jpg_file)
        plt.close(fig)
    else:
        # Save cropped label
        os.makedirs(output_dir, exist_ok=True)
        Image.fromarray(cropped_image).save(output_dir / jpg_file)

In [11]:
# Pick the device for computation: CUDA first, then MPS, otherwise the CPU
device = torch.device(
    "cuda" if torch.cuda.is_available()
    else "mps" if torch.backends.mps.is_available()
    else "cpu"
)
print(f"using device: {device}")

# Build the SAM 2 model on that device and wrap it in the image predictor
sam2_model = build_sam2(model_cfg, sam2_checkpoint, device=device)
predictor = SAM2ImagePredictor(sam2_model)

# Eight-colour "hls" palette from seaborn
sns_colors = sns.color_palette("hls", 8)

using device: cpu


In [4]:
try:
    # Ensure the input directory exists
    if not os.path.exists(input_dir):
        raise FileNotFoundError(f"Input directory '{input_dir}' was not found.")

    # Create the output directory
    if os.path.exists(output_dir):
        print("WARNING: Output directory already exists.")
    os.makedirs(output_dir, exist_ok=True)

    # crop_label writes crops into a "Labels" subfolder only when DEBUG is set,
    # so the "already processed" check has to look in the matching place.
    auto_label_dir = output_dir / "Labels" if DEBUG else output_dir

    # Find all jpg files below the input directory
    jpg_files = [
        os.path.join(root, file)
        for root, _, files in os.walk(input_dir)
        for file in files
        if file.endswith((".JPG", ".jpg"))
    ]

    # Loop through every file
    for jpg_file_path in tqdm(jpg_files, desc="Processing files", ncols=145):
        # Flatten the path relative to input_dir into a single file name
        relative_jpg_path = str(jpg_file_path).removeprefix(str(input_dir)).lstrip("/")
        new_jpg_basename = relative_jpg_path.replace("/", "-")

        # Skip if output file exists
        if os.path.exists(auto_label_dir / new_jpg_basename):
            continue

        # cv2.imread returns None instead of raising when a file cannot be decoded
        image = cv2.imread(jpg_file_path)
        if image is None:
            tqdm.write(f"WARNING: Could not read image '{jpg_file_path}', skipping.")
            continue

        # Process the file (OpenCV loads BGR; the rest of the pipeline uses RGB)
        image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
        crop_label(image, output_dir, new_jpg_basename)

# Handle exceptions
except FileNotFoundError as e:
    print(e)

# Allow the long-running loop to be stopped by hand without a traceback
except KeyboardInterrupt:
    pass

using device: cpu


Processing files: 100%|██████████████████████████████████████████████████████████████████████████████████████| 1194/1194 [03:41<00:00,  5.39it/s]


In [19]:
# Output directory for the label crops improved with manually selected coordinates
manual_dir = data_dir.joinpath("Processed", "LiveBees", "2-LiveWingLabelCropsManuallyImproved")

In [25]:
# Manually chosen prompt points, keyed by the flattened file name without its extension
coords = {"Round01-Hive01-2024_06_05-h01bee36": ((4324, 1399), (4869, 1814)),
          "Round02-hive14-2024_06_20-h14b42": ((3229, 889), (3559, 1304)),
          "Round02-hive14-2024_06_27-h14b23": ((3064, 1079), (3374, 1304)),
          "Round02-hive14-2024_06_27-h14b30": ((2999, 1194), (3359, 1324)),
          "Round04-hive34-2024_07_19-h34b25": ((3124, 634), (3514, 919)),
          "Round04-hive35-2024_07_23-h35b27": ((2809, 1034), (3339, 1409))}

try:
    # Create the new output directory
    if os.path.exists(manual_dir):
        print("WARNING: Output directory already exists.")
    os.makedirs(manual_dir, exist_ok=True)

    # crop_label writes crops into a "Labels" subfolder only when DEBUG is set;
    # read the automatic crops from, and write the copies to, the matching place.
    source_label_dir = output_dir / "Labels" if DEBUG else output_dir
    manual_label_dir = manual_dir / "Labels" if DEBUG else manual_dir
    # shutil.copy does not create missing folders, so make sure the target exists
    os.makedirs(manual_label_dir, exist_ok=True)

    # Find all jpg files below the input directory
    jpg_files = [
        os.path.join(root, file)
        for root, _, files in os.walk(input_dir)
        for file in files
        if file.endswith((".JPG", ".jpg"))
    ]

    # Loop through every file
    for jpg_file_path in tqdm(jpg_files, desc="Processing files", ncols=145):
        # Flatten the path relative to input_dir into a single file name
        relative_jpg_path = str(jpg_file_path).removeprefix(str(input_dir)).lstrip("/")
        new_jpg_basename = relative_jpg_path.replace("/", "-")

        # Skip if output file exists
        if os.path.exists(manual_label_dir / new_jpg_basename):
            continue

        # Strip the extension regardless of its case (".JPG" as well as ".jpg")
        image_key = os.path.splitext(new_jpg_basename)[0]

        if image_key in coords:
            # cv2.imread returns None instead of raising when a file cannot be decoded
            image = cv2.imread(jpg_file_path)
            if image is None:
                tqdm.write(f"WARNING: Could not read image '{jpg_file_path}', skipping.")
                continue

            # Process the file using manual coordinates
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            crop_label(image, manual_dir, new_jpg_basename, coords[image_key])

        else:
            # Copy the automatically produced crop; a missing crop only skips this file
            source_jpg = source_label_dir / new_jpg_basename
            if not os.path.exists(source_jpg):
                tqdm.write(f"WARNING: No automatic crop found at '{source_jpg}', skipping.")
                continue
            shutil.copy(source_jpg, manual_label_dir / new_jpg_basename)

# Handle exceptions
except FileNotFoundError as e:
    print(e)

# Allow the long-running loop to be stopped by hand without a traceback
except KeyboardInterrupt:
    pass



Processing files: 100%|██████████████████████████████████████████████████████████████████████████████████████| 1194/1194 [00:32<00:00, 37.00it/s]
