Inference on the test set

In [None]:
import torch
from PIL import Image, ImageDraw
import os
import numpy as np
import matplotlib.pyplot as plt
from transformers import SamProcessor, SamModel, SamConfig
from ultralytics import YOLO

def extract_results_info(results):
    """Collect the source path, original shape and xyxy boxes of the first YOLO result.

    Args:
        results: Sequence returned by ``YOLO.predict``; only ``results[0]`` is read.

    Returns:
        dict with keys ``'path'``, ``'shape'`` (the result's ``orig_shape``) and
        ``'bounding_boxes'`` (a list of ``[x_min, y_min, x_max, y_max]`` lists taken
        from ``boxes.xyxy``).
    """
    first = results[0]
    # Dict values are evaluated in order: path, shape, then the boxes.
    return {
        'path': first.path,
        'shape': first.orig_shape,
        'bounding_boxes': [
            [x0, y0, x1, y1] for x0, y0, x1, y1 in first.boxes.xyxy.tolist()
        ],
    }

def expand_bounding_boxes(bounding_boxes, image_shape, expansion_ratio=0.1):
    """Grow each xyxy box by ``expansion_ratio`` of its size on every side, clipped to the image.

    Args:
        bounding_boxes: Iterable of ``[x_min, y_min, x_max, y_max]`` boxes.
        image_shape: ``(height, width)``; index 1 bounds x and index 0 bounds y.
        expansion_ratio: Fraction of the box width/height added on each side.

    Returns:
        List of expanded ``[x_min, y_min, x_max, y_max]`` boxes (mins clipped at 0).
    """
    def _grow(box):
        left, top, right, bottom = box
        pad_x = (right - left) * expansion_ratio
        pad_y = (bottom - top) * expansion_ratio
        return [
            max(0, left - pad_x),
            max(0, top - pad_y),
            min(image_shape[1], right + pad_x),
            min(image_shape[0], bottom + pad_y),
        ]

    return [_grow(box) for box in bounding_boxes]

def perform_segmentation_with_box(image_path, box, *, sam_model=None, sam_processor=None,
                                  sam_device=None, threshold=0.5):
    """Segment the region inside one box prompt with SAM and return a binary mask.

    Args:
        image_path: Path of the image to segment.
        box: ``[x_min, y_min, x_max, y_max]`` prompt box in original-image pixel coordinates.
        sam_model: SAM model to run; defaults to the notebook-level ``model_sam``.
        sam_processor: Matching ``SamProcessor``; defaults to the notebook-level ``processor``.
        sam_device: Device the inputs are moved to; defaults to the notebook-level ``device``.
        threshold: Sigmoid probability above which a pixel is foreground (default 0.5).

    Returns:
        uint8 array of 0/1 values derived from SAM's ``pred_masks`` (single-mask output).
        NOTE(review): presumably SAM's low-resolution masks (the evaluation loop resizes
        ground truth to 256x256 to match), not resized to the input image — confirm.
    """
    # Resolve defaults here so existing two-argument calls keep using the notebook
    # globals, while callers (e.g. process_image) can pass explicit objects.
    sam_model = model_sam if sam_model is None else sam_model
    sam_processor = processor if sam_processor is None else sam_processor
    sam_device = device if sam_device is None else sam_device

    # Close the file handle as soon as the RGB copy exists.
    with Image.open(image_path) as raw_image:
        test_image = raw_image.convert("RGB")
    inputs = sam_processor(test_image, input_boxes=[[box]], return_tensors="pt")
    inputs = {k: v.to(sam_device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = sam_model(**inputs, multimask_output=False)
    medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
    return (medsam_seg_prob > threshold).astype(np.uint8)

def merge_masks(predicted_masks):
    """Return the pixel-wise union of the given masks as a 0/1 uint8 array.

    Args:
        predicted_masks: List of same-shaped arrays; any non-zero pixel counts as foreground.

    Returns:
        uint8 array shaped like the first mask, or an all-zero (256, 256) array when the
        list is empty (e.g. no boxes were detected).
    """
    if not predicted_masks:
        return np.zeros((256, 256), dtype=np.uint8)
    union = np.zeros_like(predicted_masks[0], dtype=np.uint8)
    for candidate in predicted_masks:
        union[candidate != 0] = 1
    return union

def process_image(image_path, yolo_model, sam_model, sam_processor, device):
    """Detect boxes with YOLO, expand them, segment each with SAM and merge the masks.

    Args:
        image_path: Path of the image to process.
        yolo_model: Ultralytics YOLO model used for box detection.
        sam_model, sam_processor, device: Accepted for API symmetry but NOT forwarded;
            perform_segmentation_with_box is called with only (path, box), so it relies
            on the notebook-level SAM objects.

    Returns:
        Tuple ``(results_info, merged_mask, expanded_boxes)``.
    """
    results = yolo_model.predict(source=[image_path], imgsz=512, conf=0.3, verbose=False)
    results_info = extract_results_info(results)

    # Image.open is lazy and keeps the file open; only the size is needed here, so
    # close it right away. PIL reports (width, height); reverse to (height, width).
    with Image.open(results_info['path']) as image:
        image_shape = image.size[::-1]

    expanded_boxes = expand_bounding_boxes(results_info['bounding_boxes'], image_shape)

    predicted_masks = [
        perform_segmentation_with_box(results_info['path'], box) for box in expanded_boxes
    ]
    merged_mask = merge_masks(predicted_masks)

    return results_info, merged_mask, expanded_boxes

yolo_model_path = './trained_yolo/best.pt'
sam_model_checkpoint_path = "./SAM_model_checkpoint_test_5epochs.pth"
test_images_folder = "./test/test_images/"
test_masks_folder = "./test/test_masks/"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

yolo_model = YOLO(yolo_model_path).to(device)

# Architecture and pre-processing come from the base SAM checkpoint; the fine-tuned
# weights are loaded on top below.
model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")

model_sam = SamModel(config=model_config).to(device)
# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
model_sam.load_state_dict(torch.load(sam_model_checkpoint_path, map_location=device))
# A module built from a config starts in training mode; switch to inference mode.
model_sam.eval()

total_iou = 0
total_predictions = 0

# Sorted so per-image output appears in a reproducible order (os.listdir order is arbitrary).
for filename in sorted(os.listdir(test_images_folder)):
    if not filename.endswith(".png"):
        continue
    image_path = os.path.join(test_images_folder, filename)
    results_info, merged_mask, expanded_boxes = process_image(image_path, yolo_model, model_sam, processor, device)

    true_mask_path = os.path.join(test_masks_folder, filename)  # Assuming the mask filename matches the image filename
    with Image.open(true_mask_path) as mask_file:
        true_mask = np.array(mask_file.convert("L"))
    true_mask = (true_mask > 0).astype(np.uint8)  # Binarize the mask
    # Resize ground truth to the 256x256 prediction size; NEAREST keeps it binary.
    # NOTE(review): if SAM's low-res masks cover the padded square model input, this plain
    # resize is misaligned for non-square images — consider post-processing the
    # masks to the original size instead; verify.
    true_mask_resized = np.array(Image.fromarray(true_mask).resize((256, 256), Image.NEAREST))

    # NOTE(review): calculate_iou is not defined in this cell; assumed to come from an
    # earlier notebook cell — confirm.
    iou = calculate_iou(merged_mask, true_mask_resized)

    total_iou += iou
    total_predictions += 1

    print("IoU for", filename, ":", "{:.3f}".format(iou))

    # RGB so the red outlines are drawn in colour and imshow does not apply a
    # colormap to single-channel images.
    with Image.open(results_info['path']) as source_image:
        image_expanded = source_image.convert("RGB")
    draw_expanded = ImageDraw.Draw(image_expanded)
    for box in expanded_boxes:
        x_min, y_min, x_max, y_max = box
        draw_expanded.rectangle([x_min, y_min, x_max, y_max], outline="red")

    plt.figure(figsize=(12, 6))
    plt.subplot(1, 2, 1)
    plt.imshow(image_expanded)
    plt.title("Image with expanded bounding boxes")
    plt.axis('off')

    plt.subplot(1, 2, 2)
    plt.imshow(merged_mask, cmap='gray')
    plt.title("Merged Prediction")
    plt.axis('off')

    plt.show()
    # Release the figure so non-interactive backends don't accumulate one per image.
    plt.close()


if total_predictions:
    average_iou = total_iou / total_predictions
    print("Average IoU: {:.3f}".format(average_iou))
else:
    print("No .png images found in", test_images_folder)


Inference on one image

In [None]:
import torch
from PIL import Image, ImageDraw
import os
import numpy as np
import matplotlib.pyplot as plt
from transformers import SamProcessor, SamModel, SamConfig
from ultralytics import YOLO

def extract_results_info(results):
    """Summarise the first YOLO prediction as a plain dict.

    Args:
        results: Sequence returned by ``YOLO.predict``; only ``results[0]`` is read.

    Returns:
        dict with the image ``'path'``, its ``'shape'`` (the result's ``orig_shape``) and
        ``'bounding_boxes'``, a list of ``[x_min, y_min, x_max, y_max]`` lists from
        ``boxes.xyxy``.
    """
    head = results[0]
    info_dict = {'path': head.path, 'shape': head.orig_shape}
    xyxy_rows = head.boxes.xyxy.tolist()
    info_dict['bounding_boxes'] = [[a, b, c, d] for a, b, c, d in xyxy_rows]
    return info_dict

def expand_bounding_boxes(bounding_boxes, image_shape, expansion_ratio=0.1):
    """Pad every xyxy box by ``expansion_ratio`` of its width/height, clipped to the image.

    Args:
        bounding_boxes: Iterable of ``[x_min, y_min, x_max, y_max]`` boxes.
        image_shape: ``(height, width)``; index 1 bounds x and index 0 bounds y.
        expansion_ratio: Fraction of the box size added on each side.

    Returns:
        List of padded ``[x_min, y_min, x_max, y_max]`` boxes (mins clipped at 0).
    """
    grown = []
    for left, top, right, bottom in bounding_boxes:
        dx = (right - left) * expansion_ratio
        dy = (bottom - top) * expansion_ratio
        grown.append([
            max(0, left - dx),
            max(0, top - dy),
            min(image_shape[1], right + dx),
            min(image_shape[0], bottom + dy),
        ])
    return grown

def perform_segmentation_with_box(image_path, box, *, sam_model=None, sam_processor=None,
                                  sam_device=None, threshold=0.5):
    """Segment the region inside one box prompt with SAM and return a binary mask.

    Args:
        image_path: Path of the image to segment.
        box: ``[x_min, y_min, x_max, y_max]`` prompt box in original-image pixel coordinates.
        sam_model: SAM model to run; defaults to the notebook-level ``model_sam``.
        sam_processor: Matching ``SamProcessor``; defaults to the notebook-level ``processor``.
        sam_device: Device the inputs are moved to; defaults to the notebook-level ``device``.
        threshold: Sigmoid probability above which a pixel is foreground (default 0.5).

    Returns:
        uint8 array of 0/1 values derived from SAM's ``pred_masks`` (single-mask output).
        NOTE(review): presumably SAM's low-resolution masks, not resized to the input
        image — confirm.
    """
    # Resolve defaults here so existing two-argument calls keep using the notebook
    # globals, while callers (e.g. process_image) can pass explicit objects.
    sam_model = model_sam if sam_model is None else sam_model
    sam_processor = processor if sam_processor is None else sam_processor
    sam_device = device if sam_device is None else sam_device

    # Close the file handle as soon as the RGB copy exists.
    with Image.open(image_path) as raw_image:
        test_image = raw_image.convert("RGB")
    inputs = sam_processor(test_image, input_boxes=[[box]], return_tensors="pt")
    inputs = {k: v.to(sam_device) for k, v in inputs.items()}
    with torch.no_grad():
        outputs = sam_model(**inputs, multimask_output=False)
    medsam_seg_prob = torch.sigmoid(outputs.pred_masks.squeeze(1))
    medsam_seg_prob = medsam_seg_prob.cpu().numpy().squeeze()
    return (medsam_seg_prob > threshold).astype(np.uint8)

def merge_masks(predicted_masks):
    """Union binary masks into one uint8 mask (1 where any mask is non-zero).

    Args:
        predicted_masks: List of same-shaped arrays, e.g. from perform_segmentation_with_box.

    Returns:
        uint8 array of 0/1 values shaped like the first mask. When the list is empty
        (YOLO found no boxes), an all-zero 256x256 mask is returned instead of raising
        IndexError on ``predicted_masks[0]``.
    """
    if not predicted_masks:
        # Same 256x256 fallback as the test-set cell's merge_masks.
        return np.zeros((256, 256), dtype=np.uint8)
    merged_mask = np.zeros_like(predicted_masks[0], dtype=np.uint8)
    for mask in predicted_masks:
        # Writing into the uint8 buffer keeps the result uint8 whatever the input dtype.
        merged_mask[mask != 0] = 1
    return merged_mask

def process_image(image_path, yolo_model, sam_model, sam_processor, device):
    """Detect boxes with YOLO, expand them, segment each with SAM and merge the masks.

    Args:
        image_path: Path of the image to process.
        yolo_model: Ultralytics YOLO model used for box detection.
        sam_model, sam_processor, device: Accepted for API symmetry but NOT forwarded;
            perform_segmentation_with_box is called with only (path, box), so it relies
            on the notebook-level SAM objects.

    Returns:
        Tuple ``(results_info, merged_mask, expanded_boxes)``.
    """
    results = yolo_model.predict(source=[image_path], imgsz=512, conf=0.3, verbose=False)
    results_info = extract_results_info(results)

    # Image.open is lazy and keeps the file open; only the size is needed here, so
    # close it right away. PIL reports (width, height); reverse to (height, width).
    with Image.open(results_info['path']) as image:
        image_shape = image.size[::-1]

    expanded_boxes = expand_bounding_boxes(results_info['bounding_boxes'], image_shape)

    predicted_masks = [
        perform_segmentation_with_box(results_info['path'], box) for box in expanded_boxes
    ]
    merged_mask = merge_masks(predicted_masks)

    return results_info, merged_mask, expanded_boxes

yolo_model_path = './trained_yolo/best.pt'
sam_model_checkpoint_path = "./SAM_model_checkpoint_test_5epochs.pth"
image_path = "./test/test_images/fusc_0035.png"

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


yolo_model = YOLO(yolo_model_path).to(device)


# Architecture and pre-processing come from the base SAM checkpoint; the fine-tuned
# weights are loaded on top below.
model_config = SamConfig.from_pretrained("facebook/sam-vit-base")
processor = SamProcessor.from_pretrained("facebook/sam-vit-base")


model_sam = SamModel(config=model_config).to(device)

# map_location lets a checkpoint saved on GPU be loaded on a CPU-only machine.
model_sam.load_state_dict(torch.load(sam_model_checkpoint_path, map_location=device))
# A module built from a config starts in training mode; switch to inference mode.
model_sam.eval()

results_info, merged_mask, expanded_boxes = process_image(image_path, yolo_model, model_sam, processor, device)

# RGB so the red outlines are drawn in colour and imshow does not apply a
# colormap to single-channel images.
with Image.open(results_info['path']) as source_image:
    image = source_image.convert("RGB")
draw = ImageDraw.Draw(image)
for box in expanded_boxes:
    x_min, y_min, x_max, y_max = box
    draw.rectangle([x_min, y_min, x_max, y_max], outline="red")

plt.figure(figsize=(12, 6))
plt.subplot(1, 2, 1)
plt.imshow(image)
plt.title("Image with expanded bounding boxes")
plt.axis('off')

plt.subplot(1, 2, 2)
plt.imshow(merged_mask, cmap='gray')
plt.title("Merged Prediction")
plt.axis('off')

# Explicit show so the figure is displayed when run as a script, not only inline.
plt.show()

