In [64]:
import json
import warnings
from collections import defaultdict
from glob import glob
from pathlib import Path

import matplotlib.pyplot as plt
import torch
from PIL import Image
from torchvision.transforms.functional import pil_to_tensor

In [51]:
def get_bounding_box(binary_mask):
    """Compute the tight bounding box around the non-zero entries of a mask.

    Args:
        binary_mask: Tensor indexed as (row, col); every non-zero entry is
            treated as foreground. A mask that is not 2-D raises ValueError
            when the per-axis extrema are unpacked.

    Returns:
        ``(x_min, y_min, x_max, y_max)`` as Python ints, where x is the column
        index and y the row index. Both corners are inclusive pixel indices.
        Returns ``None`` when the mask has no non-zero entries.
    """
    # One row per foreground element, one column per mask dimension.
    foreground = torch.nonzero(binary_mask)
    if not foreground.numel():
        return None

    # Per-axis extrema, unpacked as (row, col).
    top, left = foreground.min(dim=0).values.tolist()
    bottom, right = foreground.max(dim=0).values.tolist()
    return (left, top, right, bottom)


def visualize_bounding_box(image, bounding_box, ax=None):
    """Display an image with a bounding box outlined in red.

    Args:
        image: Anything ``Axes.imshow`` accepts (e.g. a NumPy array), or a
            torch tensor. Tensors are detached, squeezed and moved to CPU; a
            channels-first ``(3|4, H, W)`` tensor is transposed to
            ``(H, W, 3|4)`` so ``imshow`` can render it. A 2-D image is
            shown with a gray colormap.
        bounding_box: ``(x_min, y_min, x_max, y_max)`` with inclusive pixel
            indices (the convention of ``get_bounding_box``). ``None`` (what
            ``get_bounding_box`` returns for an empty mask) draws only the image.
        ax: Matplotlib Axes to draw on. A new 10x10-inch figure is created
            when omitted.

    Returns:
        The Axes that was drawn on.
    """
    if torch.is_tensor(image):
        image = image.detach().squeeze().cpu()
        # Torch images are usually channels-first (C, H, W), but imshow wants
        # channels-last. Only transpose when the layout is unambiguous, i.e.
        # when imshow would have rejected the tensor as-is.
        if image.ndim == 3 and image.shape[0] in (3, 4) and image.shape[-1] not in (3, 4):
            image = image.permute(1, 2, 0)
        image = image.numpy()

    if ax is None:
        _, ax = plt.subplots(figsize=(10, 10))

    ax.imshow(image, cmap='gray' if image.ndim == 2 else None)

    if bounding_box is not None:
        x_min, y_min, x_max, y_max = bounding_box

        # With imshow's default extent, pixel i is centred on integer coordinate i
        # and spans [i - 0.5, i + 0.5]. The corners are inclusive pixel indices, so
        # shift by half a pixel and add 1 to the extent. The outline then encloses the
        # edge pixels instead of cutting through their centres, and a single-pixel
        # box gets a 1x1 outline instead of zero size.
        rect = plt.Rectangle(
            (x_min - 0.5, y_min - 0.5),  # top-left on screen: imshow puts row 0 at the top by default
            x_max - x_min + 1,
            y_max - y_min + 1,
            fill=False,
            edgecolor='red',
            linewidth=2,
        )
        ax.add_patch(rect)

    ax.set_title('Image with Bounding Box')
    plt.tight_layout()

    return ax

In [61]:
# Grayscale level above which a mask pixel counts as foreground. The masks are
# JPEGs, so this threshold presumably absorbs compression noise around 0/255.
# NOTE(review): confirm the value against how the masks were produced.
MASK_THRESHOLD = 100

# Sorted so the annotation order (and thus the JSON output) does not depend on
# filesystem listing order.
mask_paths = sorted(glob("./processed_data/*/*_mask.jpg"))
bbox_annotations = defaultdict(list)
empty_masks = []

for path in mask_paths:
    # Image.open is lazy and keeps the file handle open; `with` closes it.
    with Image.open(path) as mask_img:
        mask = pil_to_tensor(mask_img.convert('L')).squeeze(0)
    mask = (mask > MASK_THRESHOLD).to(torch.uint8)

    bbox = get_bounding_box(mask)
    if bbox is None:
        # get_bounding_box returns None for an all-background mask, and
        # list(None) would raise TypeError. Record the path and move on.
        empty_masks.append(path)
        continue

    # label: the directory containing the mask.
    # eval_id: the filename prefix before the first underscore.
    # pathlib copes with OS-specific separators and an optional "./" prefix.
    mask_file = Path(path)
    label = mask_file.parent.name
    eval_id = mask_file.name.split("_")[0]

    bbox_annotations[label].append((eval_id, list(bbox)))

if empty_masks:
    warnings.warn(
        f"Skipped {len(empty_masks)} mask(s) with no foreground pixels: {empty_masks}"
    )

In [68]:
# Write one JSON file per label, mapping eval_id -> [x_min, y_min, x_max, y_max].
for label, annotations in bbox_annotations.items():
    label_annotations = {eval_id: bbox for eval_id, bbox in annotations}

    # A repeated eval_id keeps only its last bbox in the dict. Surface that
    # instead of silently dropping annotations.
    n_dropped = len(annotations) - len(label_annotations)
    if n_dropped:
        warnings.warn(
            f"{label}: {n_dropped} annotation(s) overwritten by duplicate eval_id; "
            "only the last bounding box per eval_id was saved"
        )

    # Distinct name so this cell does not clobber `path` from the mask-loading cell.
    out_path = f"./processed_data/{label}/bbox_annotations.json"

    with open(out_path, "w") as f:
        json.dump(label_annotations, f, indent=4)