In [1]:
from ultralytics import YOLO
import os
from PIL import Image
Image.MAX_IMAGE_PIXELS = 1000000000000  # Allow large images to load
import numpy as np
import torch
import cv2
from pyometiff import OMETIFFReader
import sys
from torchvision.ops import nms
import math
import tifffile
import time







# Settings
input_folder = "F:\\Users\\labo\\LVerschuren\\gigapixelwoodbot beech scan\\test"  # Folder containing input images
output_folder = "F:\\Users\\labo\\LVerschuren\\gigapixelwoodbot beech scan\\seg"  # Folder for output images
model_weights = 'D:\\Users\\labo\\Documents\\JanVdB\\yolo code\\segment\\train7\\weights\\best.pt' # file containing the best model weights
object_class = 1 # indices of class to segment for example 0 (RAYS), 1 (VESSELS): 0 if only one class in the model
sub_image_size = 640  # User-defined sub-image size, should be 640 for YOLOv8
overlap_percent = 0.40  # User-defined overlap, 0.1 = 10% overlap
confidence = 0.3  # (0.4 vessels, 0.3 rays) Sets the minimum confidence threshold for detections. Objects detected with confidence below this threshold will be disregarded. Adjusting this value can help reduce false positives.
IntersectionOverUnion = 0.8  # Intersection Over Union (IoU) threshold for Non-Maximum Suppression (NMS). Higher values result in fewer detections by eliminating overlapping boxes, useful for reducing duplicates.
saveExtraImages = False # if True it will save the original image with bounding boxes and masks (takes a lot of RAM for large images!) 
countBoxes = True # if True it will calculate the number of bounding boxes to count the number of detected objects


# Initialize counters
total_images_processed = 0
total_subimages_processed = 0

def process_sub_image(sub_image, yolo_model, full_mask, confidence, IntersectionOverUnion, x, y, sub_image_size, bounding_boxes):
    global total_subimages_processed
    try:
        # Run YOLOv8 prediction on the sub-image
        results = yolo_model.predict(sub_image, save=False, save_txt=False, save_conf=False, conf=confidence, iou=IntersectionOverUnion, verbose=False, retina_masks = True)

        # Check if any objects were detected
        if results and len(results) > 0:
            masks = results[0].masks.data
            boxes = results[0].boxes.data
            clss = boxes[:, 5]
            vessel_indices = torch.where(clss == object_class)
            vessel_masks = masks[vessel_indices]
            vessel_masks = torch.any(vessel_masks, dim=0).int() * 255
            vessel_mask = np.array(vessel_masks.cpu())

            if countBoxes == True:
                # Extract bounding boxes for the class
                for box in boxes[vessel_indices]:
                    x1, y1, x2, y2, conf, cls = box
                    # Check if the bounding box is not at the edge of the sub-image
                    if x1 > 1 and y1 > 1 and x2 < sub_image_size-1 and y2 < sub_image_size-1:
                        bounding_boxes.append([x + int(x1), y + int(y1), x + int(x2), y + int(y2), conf.item()])
            else: 
                for box in boxes[vessel_indices]:
                    x1, y1, x2, y2, conf, cls = box
                    bounding_boxes.append([x + int(x1), y + int(y1), x + int(x2), y + int(y2), conf.item()])
                    
        else:
            vessel_mask = np.zeros((sub_image_size, sub_image_size))

    except Exception:
        vessel_mask = np.zeros((sub_image_size, sub_image_size))

    roi = full_mask[y:y+sub_image_size, x:x+sub_image_size]
    overlay_mask = vessel_mask[:, :]
    final_image = np.maximum(roi, overlay_mask)
    full_mask[y:y+sub_image_size, x:x+sub_image_size] = final_image





def apply_nms_in_windows(bounding_boxes, img_width, img_height, window_size, IntersectionOverUnion, overlap_percent):
    window_size = int(window_size * 10.5)  # bigger window size
    
    numx = math.ceil((img_width+50)/(window_size*(1-overlap_percent*1.5)))
    stridex = (img_width+50)/numx
    numy = math.ceil((img_height+50)/(window_size*(1-overlap_percent*1.5)))
    stridey = (img_height+50)/numy

    counterAreas = 0
    totalAreas = numx * numy
    sys.stdout.write(f"\nSubareas filtered: {counterAreas}/{totalAreas}")
    sys.stdout.flush()
    
    for y in np.linspace(-25, img_height+25, num = numy):
        for x in np.linspace(-25, img_width+25, num = numx):
            # Define window bounds
            x_min, y_min = x, y
            x_max, y_max = x + window_size, y + window_size
            
            # Collect boxes in the window (including boxes that cross the window boundary)
            window_boxes = [box for box in bounding_boxes if not (box[2] < x_min or box[0] > x_max or box[3] < y_min or box[1] > y_max)]

            # Remove these boxes from the original list
            bounding_boxes = [box for box in bounding_boxes if (box[2] < x_min or box[0] > x_max or box[3] < y_min or box[1] > y_max)]
            
            if window_boxes:
                bbox_array = np.array(window_boxes)
                bbox_coords = bbox_array[:, :4]
                bbox_scores = bbox_array[:, 4]

                # Apply Non-Maximum Suppression
                keep_indices = nms(torch.tensor(bbox_coords), torch.tensor(bbox_scores), 1 - IntersectionOverUnion)
                bounding_boxes.extend([window_boxes[i] for i in keep_indices])

            # counter
            counterAreas += 1
            sys.stdout.write(f"\rSubareas filtered: {counterAreas}/{totalAreas}")
            sys.stdout.flush()

    # Remove duplicates
    bounding_boxes = np.array(bounding_boxes)
    _, unique_indices = np.unique(bounding_boxes[:, :4], axis=0, return_index=True)
    bounding_boxes = bounding_boxes[unique_indices]
    
    return bounding_boxes.tolist()




# Start timing
start_time = time.time()



if __name__ == "__main__":
    # Load YOLOv8 model
    model = YOLO(model_weights)

    # List of files to process
    file_list = [f for f in os.listdir(input_folder) if f.lower().endswith((".jpg", ".png", ".jpeg", ".tif", ".tiff", ".ome.tif"))]
    total_images = len(file_list)

    # Process multiple images
    for idx, filename in enumerate(file_list):
        print(f"\nProcessing {filename} ({idx+1}/{total_images})")
        input_image_path = os.path.join(input_folder, filename)
        if filename.lower().endswith((".ome.tif")):
            reader = OMETIFFReader(input_image_path)
            img, metadata, xml_metadata = reader.read()
#            img = Image.fromarray(img)   
            img_width, img_height = img.shape[1], img.shape[0]
        else:
            img = Image.open(input_image_path)
            img_width, img_height = img.size
            
        full_mask = np.zeros((img_height, img_width), dtype=np.uint8) # Initialize full binary mask
        bounding_boxes = []

        overlap = int(sub_image_size * overlap_percent)
        stride = sub_image_size - overlap

        # Adjust stride to fit within image dimensions
        stride_x = min(stride, img_width - sub_image_size)
        stride_y = min(stride, img_height - sub_image_size)

        total_subimages = ((img_height // stride_y) * (img_width // stride_x)) + (img_width // stride_x) + (img_height // stride_y) + 1
        subimages_done = 0

        for y in range(0, img_height - sub_image_size + 1, stride_y):
            for x in range(0, img_width - sub_image_size + 1, stride_x):
                box = (x, y, x + sub_image_size, y + sub_image_size)
                left, upper, right, lower = box
                if filename.lower().endswith((".ome.tif")):
                    sub_img = img[upper:lower, left:right] # img[upper:lower, left:right]
                    sub_img = Image.fromarray(sub_img)
                else:
                    sub_img = img.crop(box) # img.crop((left, upper, right, lower))
                process_sub_image(sub_img, model, full_mask, confidence, IntersectionOverUnion, x, y, sub_image_size, bounding_boxes)
                subimages_done += 1
                sys.stdout.write(f"\rSubimages processed: {subimages_done}/{total_subimages}")
                sys.stdout.flush()

        # pass over bottom row
        for x in range(0, img_width - sub_image_size + 1, stride_x):
            box = (x, img_height - sub_image_size, x + sub_image_size, img_height)
            left, upper, right, lower = box
            if filename.lower().endswith((".ome.tif")):
                sub_img = img[upper:lower, left:right] # img[upper:lower, left:right]
                sub_img = Image.fromarray(sub_img)
            else:
                sub_img = img.crop(box)
            process_sub_image(sub_img, model, full_mask, confidence, IntersectionOverUnion, x, img_height - sub_image_size, sub_image_size, bounding_boxes)
            subimages_done += 1
            sys.stdout.write(f"\rSubimages processed: {subimages_done}/{total_subimages}")
            sys.stdout.flush()

        # pass over right row 
        for y in range(0, img_height - sub_image_size + 1, stride_y):
            box = (img_width - sub_image_size , y, img_width, y + sub_image_size)
            left, upper, right, lower = box
            if filename.lower().endswith((".ome.tif")):
                sub_img = img[upper:lower, left:right] # img[upper:lower, left:right]
                sub_img = Image.fromarray(sub_img)
            else:
                sub_img = img.crop(box)
            process_sub_image(sub_img, model, full_mask, confidence, IntersectionOverUnion, img_width - sub_image_size , y, sub_image_size, bounding_boxes)
            subimages_done += 1
            sys.stdout.write(f"\rSubimages processed: {subimages_done}/{total_subimages}")
            sys.stdout.flush()

        # pass over right bottom corner
        box = (img_width - sub_image_size , img_height - sub_image_size, img_width, img_height)
        left, upper, right, lower = box
        if filename.lower().endswith((".ome.tif")):
            sub_img = img[upper:lower, left:right] # img[upper:lower, left:right]
            sub_img = Image.fromarray(sub_img)
        else:
            sub_img = img.crop(box)
        process_sub_image(sub_img, model, full_mask, confidence, IntersectionOverUnion, img_width - sub_image_size, img_height - sub_image_size, sub_image_size, bounding_boxes)
        subimages_done += 1
        sys.stdout.write(f"\rSubimages processed: {subimages_done}/{total_subimages}")
        sys.stdout.flush()

        # write mask
        mask_output_path = os.path.join(output_folder, f"mask_{filename}")
        cv2.imwrite(mask_output_path, full_mask)

        
        # After processing all sub-images Convert the bounding boxes to a format suitable for NMS

        if countBoxes == True:
            if len(bounding_boxes) > 0:
                
                # Save unfiltered bounding boxes
                bbox_output_path = os.path.join(output_folder, f"unfiltered_bboxes_{filename}.txt")
                with open(bbox_output_path, 'w') as f:
                    for bbox in bounding_boxes:
                        f.write(f"{bbox[0]},{bbox[1]},{bbox[2]},{bbox[3]},{bbox[4]}\n")
            
                # Apply Non-Maximum Suppression as a moving window
                filtered_bounding_boxes = apply_nms_in_windows(bounding_boxes, img_width, img_height, sub_image_size, IntersectionOverUnion, overlap_percent)
    
                # print the ammount of bounding boxes
                print(f"\nNumber of bounding boxes: {len(filtered_bounding_boxes)}")
            
                # Save filtered bounding boxes
                bbox_output_path = os.path.join(output_folder, f"filtered_bboxes_{filename}.txt")
                with open(bbox_output_path, 'w') as f:
                    for bbox in filtered_bounding_boxes:
                        f.write(f"{bbox[0]},{bbox[1]},{bbox[2]},{bbox[3]},{bbox[4]}\n")
                
            else: 
                print("\nNo bounding boxes found.")

        if saveExtraImages == True:
            # convert to np
            if not filename.lower().endswith((".ome.tif")):
                img = np.array(img)
            full_mask = np.array(full_mask)

            # Ensure both img and full_mask are of type uint8
            img = img.astype(np.uint8)
            full_mask = full_mask.astype(np.uint8)
            
            # Convert the mask to 3 channels
            background = np.zeros_like(full_mask)
            full_mask = cv2.merge([full_mask, background, background])

            # Discard the alpha channel in original image if present
            if img.shape[2] == 4:
                img = img[:, :, :3]

            # Overlay the mask on the image
            overlay = cv2.addWeighted(img, 1, full_mask, 0.3, 0)
            # Write image with mask
            ImageAndMask_output_path = os.path.join(output_folder, f"imageAndMask_{filename}")
            cv2.imwrite(ImageAndMask_output_path, cv2.cvtColor(overlay, cv2.COLOR_RGB2BGR)) 

            if countBoxes == True: 
                if len(bounding_boxes) > 0:
                    # Ensure img is contiguous
                    if not img.flags['C_CONTIGUOUS']:
                        img = np.ascontiguousarray(img)
                
                    # Draw bounding boxes on the original image
                    for bbox in filtered_bounding_boxes:
                        x1, y1, x2, y2 = map(int, bbox[:4])
                        
                        # Clip the bounding box coordinates to be within the image dimensions
                        x1 = np.clip(x1, 0, img.shape[1] - 1)
                        y1 = np.clip(y1, 0, img.shape[0] - 1)
                        x2 = np.clip(x2, 0, img.shape[1] - 1)
                        y2 = np.clip(y2, 0, img.shape[0] - 1)
                        
                        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 1)
            
                    # Write image with bounding boxes
                    if countBoxes == True:
                        ImageAndBoxes_output_path = os.path.join(output_folder, f"imageAndBoxes_{filename}")
                        cv2.imwrite(ImageAndBoxes_output_path, cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
                

        # Update the image counter
        total_images_processed += 1

    
    # End timing
    end_time = time.time()
    # Calculate elapsed time in hours
    elapsed_time_in_hours = (end_time - start_time) / 3600


    
    print("\nJob is done")
    print(f"Total images processed: {total_images_processed}/{total_images}")
    print(f"Elapsed Time: {elapsed_time_in_hours:.3f} hours.")



Processing corrected_tile_4_56.tiff (1/3)
Subimages processed: 81/81
No bounding boxes found.

Processing corrected_tile_7_22.tiff (2/3)
Subimages processed: 81/81
Subareas filtered: 4/4
Number of bounding boxes: 7164

Processing corrected_tile_7_222.tiff (3/3)
Subimages processed: 384/425
Subareas filtered: 12/12
Number of bounding boxes: 7374

Job is done
Total images processed: 3/3
This took 0.0120 hours.


In [None]:
# merge two masks

from PIL import Image
Image.MAX_IMAGE_PIXELS = 1000000000000  # Allow large images to load
import numpy as np
import cv2
import tifffile as tiff

def load_mask(file_path):
    # Load the binary mask as a grayscale image
    mask = Image.open(file_path).convert('L')
    mask_np = np.array(mask)
    # Ensure the mask is binary
    _, binary_mask = cv2.threshold(mask_np, 127, 255, cv2.THRESH_BINARY)
    return binary_mask

def merge_masks(mask1, mask2):
    # Create an empty color image
    combined_image = np.zeros((mask1.shape[0], mask1.shape[1], 3), dtype=np.uint8)
    
    # Add cyan color to the first mask
    combined_image[mask1 > 0] = [0, 255, 255]  # Cyan color (R, G, B order in Pillow)
    
    # Add red color to the second mask
    combined_image[mask2 > 0] = [255, 0, 0]  # Red color
    
    # Where both masks overlap, make it purple (red + cyan)
    overlap = np.bitwise_and(mask1, mask2)
    combined_image[overlap > 0] = [255, 0, 255]  # Purple color
    
    return combined_image

def save_image(image, file_path):
    # Save the image using tifffile
    tiff.imwrite(file_path, image)
    print(f"Image saved to {file_path}")

def main(mask1_path, mask2_path, output_path):
    mask1 = load_mask(mask1_path)
    mask2 = load_mask(mask2_path)
    
    combined_image = merge_masks(mask1, mask2)
    
    save_image(combined_image, output_path)

# Example usage
if __name__ == "__main__":
    mask1_path = 'F:\\Users\\labo\\LVerschuren\\gigapixelwoodbot beech scan\\segmented image\\vessels_mask_crop.tif'
    mask2_path = 'F:\\Users\\labo\\LVerschuren\\gigapixelwoodbot beech scan\\segmented image\\rays_mask_crop.tif'
    output_path = 'F:\\Users\\labo\\LVerschuren\\gigapixelwoodbot beech scan\\segmented image\\merged_mask_crop.ome.tif'
    
    main(mask1_path, mask2_path, output_path)