In [None]:
#import necessary libraries
from ultralytics import YOLO
import torch
from PIL import Image
import cv2
from segment_anything import SamPredictor, sam_model_registry
import numpy as np
import matplotlib.pyplot as plt
import os
import pandas as pd
import glob
import json

In [None]:
# Save a SAM mask for each image (prompted with the YOLO box) so accuracy can be measured below.

# set the path to the folder containing the images
folder_path = "PATH/TO/IMAGES"  # TODO: path to folder containing images
output_folder = "PATH/TO/SAM_MASKS"  # TODO: path to folder to save segmentation masks

# load the custom YOLO model
weight = "PATH/TO/YOLO_WEIGHTS.pt"  # TODO: path to custom YOLO object detection model
model = YOLO(weight)
YOLO_CONF = 0.5  # minimum detection confidence passed to YOLO

# Build SAM once, outside the loop: loading the ViT-H checkpoint does not depend on the
# image and is by far the slowest step. Fall back to the CPU when no GPU is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
sam_checkpoint = "sam_vit_h_4b8939.pth"
model_type = "vit_h"
sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
predictor = SamPredictor(sam)

IMAGE_EXTENSIONS = ('.png', '.jpg', '.jpeg')
os.makedirs(output_folder, exist_ok=True)

data = []

# loop over the images in a fixed (sorted) order so the output table is reproducible
for filename in sorted(os.listdir(folder_path)):
    if not filename.lower().endswith(IMAGE_EXTENSIONS):
        continue

    # Filenames are assumed to look like isolate_cultivar_rep_head_DAI....ext (TODO confirm);
    # fields that are missing are recorded as None instead of crashing the loop.
    name_fields = os.path.splitext(filename)[0].split('_')[:5]
    isolate, cultivar, rep, head, DAI = name_fields + [None] * (5 - len(name_fields))

    image_path = os.path.join(folder_path, filename)
    image_bgr = cv2.imread(image_path)
    if image_bgr is None:
        print(f"Skipping unreadable image: {image_path}")
        continue

    # Keep the most confident YOLO box. `bbox` is reset for every image so that a missed
    # detection can never reuse the box left over from a previous image.
    bbox = None
    best_conf = float('-inf')
    for result in model.predict(image_path, conf=YOLO_CONF):
        boxes = result.boxes
        if len(boxes) == 0:
            continue
        top = int(boxes.conf.argmax())
        if float(boxes.conf[top]) > best_conf:
            best_conf = float(boxes.conf[top])
            bbox = boxes.xyxy[top].tolist()

    if bbox is None:
        # No detection: write an all-background mask so every image still has a
        # prediction to score (it gets IoU 0 against a non-empty ground truth).
        mask = np.zeros(image_bgr.shape[:2], dtype=np.uint8)
        bbox = [np.nan] * 4
    else:
        # run the segmentation model on the bounding box (SAM expects RGB input)
        predictor.set_image(cv2.cvtColor(image_bgr, cv2.COLOR_BGR2RGB))
        masks, _, _ = predictor.predict(
            point_coords=None,
            point_labels=None,
            box=np.array(bbox)[None, :],
            multimask_output=False,
        )
        # masks[0] is SAM's 0/1 mask for the single requested output; store it as 0/255
        mask = (masks[0] * 255).astype(np.uint8)

    # Save the mask once. cv2.imwrite reports failure by returning False rather than raising.
    mask_path = os.path.join(output_folder, f'{filename}.png')
    if not cv2.imwrite(mask_path, mask):
        raise IOError(f"Could not write mask to {mask_path}")

    # Record the image info and the box used as the SAM prompt (NaN when nothing was detected)
    data.append({
        'filename': filename,
        'isolate': isolate,
        'cultivar': cultivar,
        'rep': rep,
        'head': head,
        'DAI': DAI,
        'x1': bbox[0],
        'y1': bbox[1],
        'x2': bbox[2],
        'y2': bbox[3],
        'mask_path': mask_path,
    })

# Create a DataFrame from the collected data
df = pd.DataFrame(data)
df.head()

In [None]:
# Save ground-truth annotations as one binary mask file per image.
# The annotation file is assumed to be COCO-style ('images', 'annotations', polygon 'segmentation').

# Set the paths to the images, the annotation file and the output mask folder
image_folder = "PATH/TO/IMAGES"  # TODO: path to folder containing images
annotation_file = "PATH/TO/ANNOTATIONS.json"  # TODO: path to json file containing hand annotations for the wheat heads
output_mask_folder = "PATH/TO/GT_MASKS"  # TODO: path to folder to save binary masks

# Load the single JSON annotation file
with open(annotation_file, 'r') as f:
    annotations = json.load(f)

# Group annotations by image id once, instead of rescanning every annotation for each image
annotations_by_image = {}
for annotation in annotations['annotations']:
    annotations_by_image.setdefault(annotation['image_id'], []).append(annotation)

# Loop over the images
for image in annotations['images']:
    # Only the image size is needed, so read it as a single channel
    image_path = os.path.join(image_folder, image['file_name'])
    img = cv2.imread(image_path, cv2.IMREAD_GRAYSCALE)
    if img is None:
        raise FileNotFoundError(f"Could not read image {image_path}")

    # Create a binary (0/255) mask from the annotations
    mask = np.zeros(img.shape[:2], dtype=np.uint8)
    for annotation in annotations_by_image.get(image['id'], []):
        segmentation = annotation['segmentation']
        if isinstance(segmentation, dict):
            raise ValueError(
                f"RLE segmentation is not supported (annotation {annotation.get('id')} in {image['file_name']})"
            )
        # COCO stores a list of polygons [[x1, y1, x2, y2, ...], ...]; also accept one flat polygon
        if segmentation and isinstance(segmentation[0], (list, tuple)):
            polygons = segmentation
        else:
            polygons = [segmentation]
        for polygon in polygons:
            if len(polygon) < 6:  # fewer than 3 points cannot enclose an area
                continue
            vertices = np.array(polygon, dtype=np.int32).reshape((-1, 2))
            # Fill each polygon on its own so overlaps are unioned rather than cut out as holes
            # and separate parts are not joined into a single shape.
            cv2.fillPoly(mask, [vertices], 255)

    # Save the binary mask as <image stem>_binary_mask.png (works for any image extension)
    stem = os.path.splitext(image['file_name'])[0]
    binary_mask_path = os.path.join(output_mask_folder, stem + '_binary_mask.png')
    os.makedirs(os.path.dirname(binary_mask_path), exist_ok=True)
    # cv2.imwrite reports failure by returning False rather than raising
    if not cv2.imwrite(binary_mask_path, mask):
        raise IOError(f"Could not write mask to {binary_mask_path}")

In [None]:
# Score each SAM mask against its ground-truth mask with intersection-over-union (IoU).

# Set the paths to the ground truth and predicted mask folders and the output table
gt_mask_folder = "PATH/TO/GT_MASKS"  # TODO: path to folder containing ground truth masks
pred_mask_folder = "PATH/TO/SAM_MASKS"  # TODO: path to folder containing predicted SAM masks
csv_path = "PATH/TO/iou.csv"  # TODO: path to save csv file

MASK_EXTENSIONS = ('.png', '.jpg', '.jpeg')


def _mask_key(mask_file):
    """Reduce a mask filename to the stem of the image it was made from.

    Strips every trailing image extension and a '_binary_mask' suffix, so that
    'x_binary_mask.png' (ground-truth naming) and 'x.jpg.png' (SAM naming) both map to 'x'.
    """
    stem = mask_file
    while os.path.splitext(stem)[1].lower() in MASK_EXTENSIONS:
        stem = os.path.splitext(stem)[0]
    if stem.endswith('_binary_mask'):
        stem = stem[:-len('_binary_mask')]
    return stem


def _list_masks(folder):
    """Return {image key: mask filename} for the image files in `folder`; raise on key collisions."""
    masks = {}
    for name in sorted(os.listdir(folder)):
        if not name.lower().endswith(MASK_EXTENSIONS):
            continue  # ignore non-image files such as .DS_Store
        key = _mask_key(name)
        if key in masks:
            raise ValueError(f"Masks {masks[key]} and {name} in {folder} both map to image '{key}'")
        masks[key] = name
    return masks


def _iou(gt_mask, pred_mask):
    """IoU of the non-zero pixels of two same-shaped masks; NaN when both are empty."""
    gt_on = gt_mask > 0
    pred_on = pred_mask > 0
    union = np.count_nonzero(gt_on | pred_on)
    if union == 0:
        return float('nan')  # IoU is undefined when neither mask has any foreground
    return np.count_nonzero(gt_on & pred_on) / union


# os.listdir order is arbitrary, so pair masks by image name rather than by position
gt_masks = _list_masks(gt_mask_folder)
pred_masks = _list_masks(pred_mask_folder)
missing_pred = sorted(set(gt_masks) - set(pred_masks))
missing_gt = sorted(set(pred_masks) - set(gt_masks))
assert not missing_pred and not missing_gt, (
    f"Unpaired masks - no prediction for: {missing_pred}; no ground truth for: {missing_gt}"
)

# Create a table to store the IoU data
data = []
for image_name in sorted(gt_masks):
    gt_mask_path = os.path.join(gt_mask_folder, gt_masks[image_name])
    pred_mask_path = os.path.join(pred_mask_folder, pred_masks[image_name])
    gt_mask = cv2.imread(gt_mask_path, cv2.IMREAD_GRAYSCALE)
    pred_mask = cv2.imread(pred_mask_path, cv2.IMREAD_GRAYSCALE)
    if gt_mask is None or pred_mask is None:
        raise IOError(f"Could not read {gt_mask_path} or {pred_mask_path}")
    if gt_mask.shape != pred_mask.shape:
        raise ValueError(
            f"Shape mismatch for {image_name}: ground truth {gt_mask.shape} vs prediction {pred_mask.shape}"
        )
    data.append({'image_name': image_name, 'iou': _iou(gt_mask, pred_mask)})

# Create a DataFrame from the collected data (explicit columns keep 'iou' present even if empty)
df = pd.DataFrame(data, columns=['image_name', 'iou'])

# Mean IoU over images; NaN entries (both masks empty) are skipped by .mean()
overall_iou = df['iou'].mean()
print(f"Mean IoU over {df['iou'].notna().sum()} images: {overall_iou:.4f}")

# save the dataframe to a csv file
df.to_csv(csv_path)
df.head()