In [3]:
from ultralytics import YOLO

model = YOLO("best.pt")  # trained checkpoint (reported by info() below as YOLOv10m)
model.info()  # prints the layers/parameters/GFLOPs summary; its returned tuple is the cell output


YOLOv10m summary: 288 layers, 16,487,602 parameters, 0 gradients, 64.0 GFLOPs


(288, 16487602, 0, 63.9804672)

In [6]:
import time
from ultralytics import YOLO
import cv2
import numpy as np

model = YOLO("best.pt")

img = cv2.imread("selected/images/1-11_frame_0.jpg")
if img is None:
    # cv2.imread returns None (it does not raise) when the file cannot be read.
    # The blank fallback keeps the loop runnable, but its shape and content differ
    # from the real frames, so the resulting FPS is not representative -- say so.
    print("Warning: sample image could not be read; benchmarking on a blank 640x640 frame.")
    img = 255 * np.ones((640, 640, 3), dtype=np.uint8)

# Warm-up call, excluded from timing so one-time setup does not skew the average.
model.predict(img)

n = 50
# perf_counter is monotonic and high-resolution; time.time() is wall-clock and
# can jump (NTP/clock adjustments), which corrupts short benchmarks.
start = time.perf_counter()
for _ in range(n):
    results = model.predict(img, verbose=False)
end = time.perf_counter()

total_time = end - start
fps = n / total_time

print(f"Inference FPS: {fps:.2f} ({1000 * total_time / n:.1f} ms/image)")


0: 768x1024 44 normal sperms, 945.6ms
Speed: 9.6ms preprocess, 945.6ms inference, 3.2ms postprocess per image at shape (1, 3, 768, 1024)
Inference FPS: 1.04


## First attempt: detects objects, but does not count them

In [10]:
from ultralytics import YOLO
import cv2
from PIL import Image
import numpy as np

image_path = "selected/images/11_frame_0.jpg"
model = YOLO("best.pt")

results = model(image_path)

for result in results:
    boxes = result.boxes

    # cv2.imread returns None instead of raising on an unreadable file; fail here
    # with a clear message rather than inside cv2.rectangle / cv2.imwrite.
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"Could not read image at {image_path}")

    for box in boxes:
        x1, y1, x2, y2 = map(int, box.xyxy[0])
        cls_id = int(box.cls[0])
        conf = float(box.conf[0])
        cls_name = model.names[cls_id]

        cv2.rectangle(img, (x1, y1), (x2, y2), (0, 255, 0), 2)

        # Draw the label just above the box, but keep it on the canvas for boxes
        # touching the top edge (y1 - 10 would be negative there).
        label = f"{cls_name}: {conf:.2f}"
        label_y = max(y1 - 10, 15)
        cv2.putText(img, label, (x1, label_y), cv2.FONT_HERSHEY_SIMPLEX, 0.5, (0, 255, 0), 2)

    # NOTE: the suffix is appended to the full file name, e.g. "11_frame_0.jpg_result.jpg".
    p = f"{image_path}_result.jpg"
    # cv2.imwrite reports failure through its return value, not an exception.
    if not cv2.imwrite(p, img):
        raise OSError(f"Could not write result to {p}")


image 1/1 /home/sofi/Documents/GitThes/Code/selected/images/11_frame_0.jpg: 768x1024 44 normal sperms, 1065.1ms
Speed: 7.4ms preprocess, 1065.1ms inference, 0.4ms postprocess per image at shape (1, 3, 768, 1024)


## Second attempt: detects and counts, for a single image or a whole directory of images

In [13]:
from ultralytics import YOLO
import cv2
import numpy as np
import os
from collections import Counter

# Loaded YOLO models keyed by model_path, so repeated calls (e.g. one per image
# from process_directory) do not reload the weights from disk every time.
# Re-running this cell resets the cache, which picks up a changed checkpoint.
_YOLO_MODEL_CACHE = {}


def _get_yolo_model(model_path):
    """Return the YOLO model for ``model_path``, loading it only on first use."""
    model = _YOLO_MODEL_CACHE.get(model_path)
    if model is None:
        model = YOLO(model_path)
        _YOLO_MODEL_CACHE[model_path] = model
    return model


def _draw_summary(img, counts, colors, default_color):
    """Overlay per-class counts, the total, and per-class percentages at the top-left of ``img``."""
    font = cv2.FONT_HERSHEY_SIMPLEX
    y_pos = 30
    for cls_name, count in counts.items():
        color = colors.get(cls_name.lower(), default_color)
        cv2.putText(img, f"{cls_name}: {count}", (10, y_pos), font, 0.7, color, 2)
        y_pos += 30

    total_count = sum(counts.values())
    if total_count == 0:
        return

    # Small extra gap between the per-class counts and the totals section.
    y_pos += 10
    cv2.putText(img, f"Total count: {total_count}", (10, y_pos), font, 0.7, (255, 255, 255), 2)
    for cls_name, count in counts.items():
        y_pos += 30
        percentage = (count / total_count) * 100
        color = colors.get(cls_name.lower(), default_color)
        cv2.putText(img, f"{cls_name}: {percentage:.1f}%", (10, y_pos), font, 0.7, color, 2)


def process_sperm_image(image_path, model_path="best.pt", conf_threshold=0.25, save_result=True):
    """Detect and count objects per class in one image and draw the annotations.

    Args:
        image_path: path of the image to analyse.
        model_path: YOLO weights to load (cached across calls, keyed by this value).
        conf_threshold: confidence threshold forwarded to the model as ``conf``.
        save_result: if True, write the annotated image to
            ``<image stem>_analyzed.jpg`` next to the input and print its path.

    Returns:
        (counts, img): a Counter mapping class name (from ``model.names``) to the
        number of detected boxes, and the annotated image array as loaded by
        ``cv2.imread`` (BGR channel order).

    Raises:
        ValueError: if the image cannot be read.
        OSError: if ``save_result`` is True and the annotated image cannot be written.
    """
    # Validate the input before paying for model loading / inference.
    # cv2.imread signals failure by returning None rather than raising.
    img = cv2.imread(image_path)
    if img is None:
        raise ValueError(f"Could not read image at {image_path}")

    model = _get_yolo_model(model_path)
    results = model(image_path, conf=conf_threshold)
    counts = Counter()

    # BGR colours (OpenCV channel order), keyed by lower-cased class name.
    colors = {
        'normal sperm': (0, 255, 0),     # Green
        'small sperm': (0, 255, 255),    # Yellow
        'pinhead': (0, 165, 255),        # Orange
        'cluster': (0, 0, 255)           # Red
    }
    default_color = (0, 255, 0)

    for result in results:
        for box in result.boxes:
            x1, y1, x2, y2 = map(int, box.xyxy[0])
            cls_name = model.names[int(box.cls[0])]
            conf = float(box.conf[0])
            counts[cls_name] += 1

            color = colors.get(cls_name.lower(), default_color)
            cv2.rectangle(img, (x1, y1), (x2, y2), color, 2)

            # Label just above the box, kept on-canvas for boxes at the top edge.
            label_y = max(y1 - 10, 15)
            cv2.putText(img, f"{cls_name}: {conf:.2f}", (x1, label_y),
                        cv2.FONT_HERSHEY_SIMPLEX, 0.5, color, 2)

    _draw_summary(img, counts, colors, default_color)

    if save_result:
        output_path = f"{os.path.splitext(image_path)[0]}_analyzed.jpg"
        # cv2.imwrite reports failure via its return value, not an exception;
        # don't claim success when nothing was written.
        if not cv2.imwrite(output_path, img):
            raise OSError(f"Could not write result to {output_path}")
        print(f"Saved result to {output_path}")

    return counts, img

def process_directory(directory_path, model_path="best.pt", conf_threshold=0.25):
    """Run ``process_sperm_image`` on every image in a directory and summarise the counts.

    Images are processed in sorted file-name order. Files whose stem ends with
    ``_analyzed`` are skipped: process_sperm_image writes its annotated output as
    ``<stem>_analyzed.jpg`` into the same directory, and re-ingesting those on a
    second run would double-count every detection.

    Per-image failures are logged and skipped (best-effort batch) and listed in
    the printed summary. A ``sperm_analysis_summary.csv`` with per-class counts
    and percentages is written into ``directory_path``.

    Args:
        directory_path: directory to scan (non-recursive).
        model_path: YOLO weights, forwarded to process_sperm_image.
        conf_threshold: confidence threshold, forwarded to process_sperm_image.

    Returns:
        Counter of detections per class summed over all successfully processed
        images (empty when the directory contains no images).

    Raises:
        ValueError: if ``directory_path`` is not a directory.
    """
    import csv

    if not os.path.isdir(directory_path):
        raise ValueError(f"{directory_path} is not a valid directory")

    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tif', '.tiff'}
    image_files = sorted(
        name for name in os.listdir(directory_path)
        if os.path.isfile(os.path.join(directory_path, name))
        and os.path.splitext(name)[1].lower() in image_extensions
        and not os.path.splitext(name)[0].endswith("_analyzed")
    )

    if not image_files:
        print(f"No image files found in {directory_path}")
        return Counter()

    all_counts = Counter()
    failed = []
    total_images = len(image_files)

    for i, image_file in enumerate(image_files, start=1):
        image_path = os.path.join(directory_path, image_file)
        print(f"Processing image {i}/{total_images}: {image_file}")

        try:
            counts, _ = process_sperm_image(image_path, model_path, conf_threshold)
        except Exception as e:
            # Best-effort: keep going, but remember the failure for the summary.
            print(f"Error processing {image_file}: {e}")
            failed.append(image_file)
            continue

        all_counts += counts
        print(f"Counts for {image_file}: {dict(counts)}")

    print("\nSummary Statistics:")
    print("-" * 50)
    total_cells = sum(all_counts.values())

    if total_cells > 0:
        print(f"Total cells detected across all images: {total_cells}")
        for cls_name, count in all_counts.items():
            percentage = (count / total_cells) * 100
            print(f"{cls_name}: {count} ({percentage:.1f}%)")
    else:
        print("No cells detected in any images.")

    if failed:
        print(f"Failed to process {len(failed)} of {total_images} image(s): {', '.join(failed)}")

    csv_path = os.path.join(directory_path, "sperm_analysis_summary.csv")
    # csv.writer quotes class names containing commas/quotes; newline="" is required
    # by the csv module, and lineterminator="\n" keeps the previous Unix line endings.
    with open(csv_path, "w", newline="", encoding="utf-8") as f:
        writer = csv.writer(f, lineterminator="\n")
        writer.writerow(["Class", "Count", "Percentage"])
        for cls_name, count in all_counts.items():
            percentage = (count / total_cells) * 100 if total_cells > 0 else 0
            writer.writerow([cls_name, count, f"{percentage:.1f}"])
        # With no detections the total is 0, so 100% would be misleading.
        writer.writerow(["Total", total_cells, "100.0" if total_cells > 0 else "0.0"])

    print(f"Summary saved to {csv_path}")
    return all_counts

if __name__ == "__main__":
    # Smoke test on one sample frame. The annotated image returned here is discarded,
    # but with the default save_result=True a copy is still written next to the input.
    counts, _ = process_sperm_image(image_path="selected/images/11_frame_0.jpg")
    print(counts)
  


image 1/1 /home/sofi/Documents/GitThes/Code/selected/images/11_frame_0.jpg: 768x1024 44 normal sperms, 1243.2ms
Speed: 9.2ms preprocess, 1243.2ms inference, 0.2ms postprocess per image at shape (1, 3, 768, 1024)
Saved result to selected/images/11_frame_0_analyzed.jpg
Counter({'normal sperm': 44})
