In [1]:
!pip install ultralytics

Collecting ultralytics
  Downloading ultralytics-8.3.28-py3-none-any.whl.metadata (35 kB)
Collecting ultralytics-thop>=2.0.0 (from ultralytics)
  Downloading ultralytics_thop-2.0.11-py3-none-any.whl.metadata (9.4 kB)
Downloading ultralytics-8.3.28-py3-none-any.whl (881 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m881.2/881.2 kB[0m [31m19.5 MB/s[0m eta [36m0:00:00[0m00:01[0m
[?25hDownloading ultralytics_thop-2.0.11-py3-none-any.whl (26 kB)
Installing collected packages: ultralytics-thop, ultralytics
Successfully installed ultralytics-8.3.28 ultralytics-thop-2.0.11


In [2]:
def load_and_predict_image(image, model_path, device):
    """
    Load an image, apply transformations, and predict using a pretrained model.
    
    Parameters:
    - img_path (str): Path to the image file.
    - model_path (str): Path to the pretrained model file.
    - device (torch.device): The device (GPU/CPU) to use for computation.
    
    Returns:
    - float: The predicted probability of the image having Glaucoma.
    - int: The binary classification result (1 for Glaucoma, 0 for No Glaucoma).
    """
    # Define transformation for the image
    def test_transform():
        return transforms.Compose([
            transforms.Resize((384, 384)),
            transforms.ToTensor(),
            transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        ])

    # Load and transform the image
#     image = Image.open(img_path).convert('RGB')
    image = test_transform()(image).unsqueeze(0)  # Add batch dimension

    # Load the model
    model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1)
    num_features = model.heads.head.in_features
    model.heads.head = nn.Linear(num_features, 1)
    model.load_state_dict(torch.load(model_path, map_location=device))
    model.to(device)
    model.eval()

    # Predict the image
    with torch.no_grad():
        image = image.to(device)
        outputs = model(image)
        probs = torch.sigmoid(outputs).squeeze().item()  # Squeeze to remove batch dimensions and get the scalar probability

    # Determine binary classification
    binary_pred = int(probs > 0.5)

    return probs, binary_pred

In [4]:
import os
import cv2
import numpy as np
import torch
import torch.nn as nn
from ultralytics import YOLO
from PIL import Image
from torchvision import transforms
from torchvision.models import vit_b_16, ViT_B_16_Weights

# Define the transformation pipeline
preprocess = transforms.Compose([
    transforms.Resize((384, 384), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def apply_clahe(img, clip_limit=3.0, tile_grid_size=(8, 8)):
    """Apply CLAHE contrast enhancement on each color channel separately."""
    img_cv = cv2.cvtColor(np.array(img), cv2.COLOR_RGB2BGR)
    r, g, b = cv2.split(img_cv)
    clahe = cv2.createCLAHE(clipLimit=clip_limit, tileGridSize=tile_grid_size)
    r_clahe = clahe.apply(r)
    g_clahe = clahe.apply(g)
    b_clahe = clahe.apply(b)
    clahe_img_cv = cv2.merge([r_clahe, g_clahe, b_clahe])
    clahe_img = cv2.cvtColor(clahe_img_cv, cv2.COLOR_BGR2RGB)
    return Image.fromarray(clahe_img)

def trim_and_resize(im, output_size):
    """Trim margins, maintain aspect ratio, and resize to the specified output size."""
    percentage = 0.02
    img = np.array(im)
    img_gray = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
    im_binary = img_gray > 0.1 * np.mean(img_gray[img_gray != 0])
    row_sums = np.sum(im_binary, axis=1)
    col_sums = np.sum(im_binary, axis=0)
    rows = np.where(row_sums > img.shape[1] * percentage)[0]
    cols = np.where(col_sums > img.shape[0] * percentage)[0]
    if rows.size and cols.size:
        min_row, min_col = np.min(rows), np.min(cols)
        max_row, max_col = np.max(rows), np.max(cols)
        img = img[min_row:max_row+1, min_col:max_col+1]
    im_pil = Image.fromarray(img)
    old_size = im_pil.size
    ratio = float(output_size) / max(old_size)
    new_size = tuple([int(x * ratio) for x in old_size])
    im_resized = im_pil.resize(new_size, Image.LANCZOS)
    new_im = Image.new("RGB", (output_size, output_size))
    new_im.paste(im_resized, ((output_size - new_size[0]) // 2, (output_size - new_size[1]) // 2))
    return new_im

# Load the YOLO model
yolo_model = YOLO('/kaggle/input/od_segmentation_model/pytorch/default/1/best.pt')

# Load the ViT model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# vit_model = vit_b_16(weights=ViT_B_16_Weights.IMAGENET1K_SWAG_E2E_V1)
# num_features = vit_model.heads.head.in_features
# vit_model.heads.head = nn.Linear(num_features, 1)  # Assuming binary classification
best_model_path = '/kaggle/input/new_vit_roi_best_model/pytorch/default/1/ViT_RG_ROI_best (1).pth'
# vit_model.load_state_dict(torch.load(best_model_path, map_location=device))
# vit_model.to(device)
# vit_model.eval()

def find_encompassing_bbox(bboxes):
    """Find the smallest bounding box that encompasses all given bounding boxes."""
    if not bboxes:
        return None
    x_min = min(box[0] for box in bboxes)
    y_min = min(box[1] for box in bboxes)
    x_max = max(box[2] for box in bboxes)
    y_max = max(box[3] for box in bboxes)
    return [x_min, y_min, x_max, y_max]


def crop_and_resize_image(img, bbox, target_size=(518, 518)):
    """
    Attempts to crop a 518x518 square from the center of a given bounding box. If the square
    exceeds image dimensions, it crops the largest possible square and resizes it to 518x518.
    If a given bounding box exceeds 518x518, segmentation is considered as invalid.

    Parameters:
    - img: The preprocessed image read by CV2.
    - bbox: The bounding box coordinates as [x1, y1, x2, y2].
    - target_size: The target size for cropping and resizing as (width, height).

    Returns:
    - Resized image if cropping is possible, otherwise None.
    """
    x1, y1, x2, y2 = map(int, bbox)
    
    # Calculate the width and height of the bounding box
    bbox_width = x2 - x1
    bbox_height = y2 - y1

    # Check if the bounding box doesn't exceed 518x518
    if bbox_width > target_size[0] or bbox_height > target_size[1]:
        return None 

    # Calculate the center of the bounding box
    center_x = (x1 + x2) // 2
    center_y = (y1 + y2) // 2

    # Calculate half the target size
    half_target = target_size[0] // 2

    # Define initial maximum square crop that can fit within the image boundaries
    start_x = max(0, center_x - half_target)
    start_y = max(0, center_y - half_target)
    end_x = min(img.shape[1], center_x + half_target)
    end_y = min(img.shape[0], center_y + half_target)

    # Validate crop dimensions
    crop_width = end_x - start_x
    crop_height = end_y - start_y

    # Adjust crop dimensions to the largest possible square within the boundary
    if crop_width < target_size[0] or crop_height < target_size[1]:
        # Calculate the largest possible dimension that can be squared within the limits
        max_possible_square = min(crop_width, crop_height)
        start_x = center_x - max_possible_square // 2
        start_y = center_y - max_possible_square // 2
        end_x = start_x + max_possible_square
        end_y = start_y + max_possible_square
        # Re-validate boundaries (important in cases where center is near the image edge)
        start_x = max(0, start_x)
        start_y = max(0, start_y)
        end_x = min(img.shape[1], end_x)
        end_y = min(img.shape[0], end_y)

    # Crop the image
    cropped_img = img[start_y:end_y, start_x:end_x]
    if cropped_img.size == 0:
        return None  # Return None if the cropped image is empty

    # Resize to the desired target size
    resized_img = cv2.resize(cropped_img, target_size, interpolation=cv2.INTER_CUBIC)

    return resized_img

def inference(img_path):
    """
    Performs inference on the given image.

    Parameters:
    - img_path (str): Path to the input image.

    Returns:
    - The predicted label (0 or 1) and the prediction probability.
    """
    try:
        # Load and preprocess the image
        img = Image.open(img_path).convert("RGB")
        output_size = 2000
        img = trim_and_resize(img, output_size)  # Resize to 384 to match the transform
        img_cl = apply_clahe(img)
        img_array = np.array(img_cl)
        if len(img_array.shape) == 3 and img_array.shape[2] == 3:  # If RGB
            img_array = cv2.cvtColor(img_array, cv2.COLOR_RGB2BGR)
#         img_array = cv2.imread(img_cl)
        
        # Detect ROI using YOLO
        results = yolo_model(img_array)
        boxes = results[0].boxes
        if len(boxes) == 0:
            print("ROI not detected in the image.")
            return None, None

        # Find the encompassing bounding box
        bboxes = [box.cpu().numpy().tolist() for box in boxes.xyxy]
        encompassing_bbox = find_encompassing_bbox(bboxes)
        if encompassing_bbox is None:
            print("No bounding boxes found.")
            return None, None

        # Crop and resize the image
        cropped_resized_img = crop_and_resize_image(img_array, encompassing_bbox, target_size=(518, 518))
#         ropped_resized_img = crop_and_resize_image(img_array, encompassing_bbox, target_size=size)
        if cropped_resized_img is None:
            print("Bounding box was too large to crop.")
            return None, None

        # Convert to PIL Image for transformations
        cropped_pil = Image.fromarray(cropped_resized_img)
        probs,binary_pred= load_and_predict_image(cropped_pil,best_model_path,device)

        return binary_pred,probs

    except Exception as e:
        print(f"An error occurred during inference: {e}")
        return None, None

# Example usage
img_path = '/kaggle/input/jraigs-dataset/justRAIGS/1/TRAIN024379.JPG'
pred, prob = inference(img_path)
if pred is not None:
    if pred==1:
        print(f"Predicted label: Referable glaucoma")
    if pred==0:
        print(f"Predicted label: Non Referable glaucoma")
        
    print(f"Prediction probability: {prob:.4f}")
else:
    print("Inference could not be completed.")



0: 640x640 2 disks, 149.7ms
Speed: 5.6ms preprocess, 149.7ms inference, 1.1ms postprocess per image at shape (1, 3, 640, 640)
Predicted label: Referable glaucoma
Prediction probability: 0.9067
