In [None]:
import os
import pandas as pd
import cv2
import numpy as np
from ultralytics import YOLO
import matplotlib.pyplot as plt
from matplotlib.patches import Circle

# Ensure inline plotting in Jupyter Notebook
%matplotlib inline

# Load YOLO model
model_path = '/work/hdd/bczm/sjafarisheshtamad/COMS571/runs/detect/train59/weights/best.pt'
model = YOLO(model_path)

# Paths to images and ground truth folders
images_folder = '/work/hdd/bczm/sjafarisheshtamad/COMS571/IDCIA v2/images'
ground_truth_folder = '/work/hdd/bczm/sjafarisheshtamad/COMS571/IDCIA v2/ground_truth'

# Fixed confidence rate and distance threshold
confidence_rate = 0.25
distance_threshold = 15  # Threshold in pixels

# Sharpening kernel
kernel = np.array([[0, -1, 0],
                   [-1, 5, -1],
                   [0, -1, 0]])

# Initialize variables for overall accuracy calculation
total_correct_predictions = 0
total_ground_truth_cells = 0

# Get the list of all image files in the dataset
image_files = sorted(os.listdir(images_folder))

def _count_one_to_one_matches(gt_points, pred_points, threshold):
    """Count ground-truth points matched to a *distinct* prediction within ``threshold``.

    Matching is one-to-one and greedy: every (ground truth, prediction) pair
    whose Euclidean distance is <= ``threshold`` is visited in order of
    increasing distance, and the pair is accepted only if neither endpoint has
    already been used. This prevents one detection from being credited to
    several nearby ground-truth cells.

    Args:
        gt_points: array-like reshapeable to (N, 2) of ground-truth (x, y) points.
        pred_points: array-like reshapeable to (M, 2) of predicted (x, y) points.
        threshold: maximum distance (same units as the points) for a match.

    Returns:
        Number of matched ground-truth points, between 0 and min(N, M).
    """
    gt = np.asarray(gt_points, dtype=float).reshape(-1, 2)
    pred = np.asarray(pred_points, dtype=float).reshape(-1, 2)
    if len(gt) == 0 or len(pred) == 0:
        return 0

    # Pairwise distances, shape (N, M): row = ground-truth point, column = prediction.
    dists = np.linalg.norm(gt[:, None, :] - pred[None, :, :], axis=-1)
    gt_idx, pred_idx = np.nonzero(dists <= threshold)
    # Closest candidate pairs first; stable sort keeps ties deterministic.
    order = np.argsort(dists[gt_idx, pred_idx], kind='stable')

    used_gt, used_pred = set(), set()
    for g, p in zip(gt_idx[order].tolist(), pred_idx[order].tolist()):
        if g in used_gt or p in used_pred:
            continue
        used_gt.add(g)
        used_pred.add(p)
    return len(used_gt)


# Evaluate every image: match detections to ground truth and visualise the result.
for image_file in image_files:
    # Skip hidden files (e.g. .DS_Store, .ipynb_checkpoints).
    if image_file.startswith('.'):
        continue

    # Ground truth CSV is expected to share the image's base name.
    base_name = os.path.splitext(image_file)[0]
    csv_file = os.path.join(ground_truth_folder, f"{base_name}.csv")
    if not os.path.exists(csv_file):
        print(f"Ground truth file not found for image: {image_file}")
        continue

    # Load ground truth data (X, Y coordinates).
    ground_truth_data = pd.read_csv(csv_file)
    if 'X' not in ground_truth_data.columns or 'Y' not in ground_truth_data.columns:
        print(f"Ground truth CSV for {image_file} does not have 'X' and 'Y' columns.")
        continue

    # cv2.imread returns None (rather than raising) for unreadable / non-image files.
    image_path = os.path.join(images_folder, image_file)
    img = cv2.imread(image_path)
    if img is None:
        print(f"Could not read image: {image_file}")
        continue

    # Sharpen the image before detection; the original image is kept for plotting.
    enhanced_img = cv2.filter2D(img, -1, kernel)

    ground_truth_points = ground_truth_data[['X', 'Y']].to_numpy(dtype=float)
    ground_truth_cells = len(ground_truth_points)
    total_ground_truth_cells += ground_truth_cells

    # Run YOLO detection on the enhanced image with the configured confidence.
    results = model(enhanced_img, conf=confidence_rate)

    # Box centres as an (M, 2) array; shape (0, 2) when nothing is detected.
    boxes_xyxy = results[0].boxes.xyxy.cpu().numpy().reshape(-1, 4)
    detected_centroids = (boxes_xyxy[:, :2] + boxes_xyxy[:, 2:4]) / 2

    # One-to-one matching so a single detection cannot satisfy several cells.
    correct_predictions = _count_one_to_one_matches(
        ground_truth_points, detected_centroids, distance_threshold
    )
    total_correct_predictions += correct_predictions

    # Plot the original (unsharpened) image with ground truth and detections.
    fig, ax = plt.subplots(figsize=(10, 10))
    ax.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    ax.scatter(ground_truth_points[:, 0], ground_truth_points[:, 1],
               c='blue', label='Ground Truth', s=10)
    ax.scatter(detected_centroids[:, 0], detected_centroids[:, 1],
               c='red', label='Detected Centroid', s=10)

    # Dashed circles of radius distance_threshold show each point's match region.
    for gt_x, gt_y in ground_truth_points:
        ax.add_patch(Circle((gt_x, gt_y), distance_threshold,
                            color='blue', fill=False, linestyle='--'))

    ax.set_title(f"Image: {image_file} | Correct Predictions: {correct_predictions}/{ground_truth_cells}")
    # Legend built from the scatter artists' own labels (circles are unlabeled).
    ax.legend(loc='upper right')
    ax.axis('off')
    plt.show()
    plt.close(fig)  # release the figure so figures do not accumulate across images

# Overall accuracy = matched ground-truth cells / all ground-truth cells, in percent.
# NOTE: this is effectively recall — unmatched (false-positive) detections are
# not penalised by this number.
if total_ground_truth_cells > 0:
    overall_accuracy = (total_correct_predictions / total_ground_truth_cells) * 100
else:
    overall_accuracy = 0

# Emit the final report as one block of lines.
summary_lines = [
    "\n--- Final Results ---",
    f"Distance Threshold: {distance_threshold} pixels",
    f"Overall Accuracy: {overall_accuracy:.2f}%",
]
print("\n".join(summary_lines))