In [1]:
import torch
import torchvision.models as models
from torchvision import transforms

# Load an ImageNet-pretrained ResNet-50 and put it in inference mode.
# torchvision >= 0.13 selects checkpoints through the `weights` enum and warns
# on the deprecated `pretrained=True`; older releases only accept the boolean.
if hasattr(models, "ResNet50_Weights"):
    # IMAGENET1K_V1 is the checkpoint that `pretrained=True` resolves to
    # (resnet50-0676ba61.pth).
    model = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
else:
    model = models.resnet50(pretrained=True)
model.eval()  # inference mode: batch-norm uses its running statistics

# Preprocessing applied to every input image: resize to the 224x224 network
# input, convert to a float tensor, then normalise each RGB channel with the
# mean / std values given below.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
])

def extract_features(image, model, transform):
    """Run a single image through ``model`` and return its un-batched output.

    Args:
        image: raw input accepted by ``transform``.
        model: callable applied to the batched tensor.
        transform: callable turning ``image`` into a tensor (without a batch
            dimension).

    Returns:
        The output of ``model`` with the leading batch dimension removed.
        The forward pass runs under ``torch.no_grad()``, so no autograd graph
        is recorded.
    """
    # The model is fed a batch of exactly one image.
    single_item_batch = transform(image).unsqueeze(dim=0)
    with torch.no_grad():
        return model(single_item_batch).squeeze(0)


Downloading: "https://download.pytorch.org/models/resnet50-0676ba61.pth" to /home/junior/.cache/torch/hub/checkpoints/resnet50-0676ba61.pth
100%|██████████| 97.8M/97.8M [00:00<00:00, 113MB/s] 


In [2]:
import cv2
from skimage.segmentation import slic
from skimage.color import label2rgb

def superpixel_segmentation(image, num_segments=100):
    """Over-segment ``image`` into superpixels with SLIC.

    Args:
        image: image array handed to ``skimage.segmentation.slic``.
        num_segments: approximate number of superpixels requested
            (forwarded as ``n_segments``).

    Returns:
        The integer label map produced by ``slic``.
    """
    # Fixed SLIC settings used by this notebook.
    slic_options = {"n_segments": num_segments, "compactness": 10, "sigma": 1}
    return slic(image, **slic_options)


In [3]:
import numpy as np
from sklearn.cluster import AgglomerativeClustering

def construct_affinity_graph(features, segments):
    """Build the dense pairwise-distance matrix between superpixels.

    Every superpixel is summarised by the mean of the feature vectors of the
    pixels it contains. Entry ``[i, j]`` is the Euclidean distance between the
    means of the i-th and j-th superpixel, with superpixels ordered by
    ascending label value (the order of ``np.unique(segments)``).

    Despite the name, the values are distances (larger means less similar),
    which is the form ``graph_based_clustering`` feeds to a precomputed
    clusterer.

    Args:
        features: array whose leading dimensions match ``segments`` (for a
            2-D label map, e.g. ``(H, W, C)``), i.e. one feature per pixel.
        segments: integer label map. Labels do not need to start at 0 or be
            contiguous.

    Returns:
        ``(n, n)`` float array with a zero diagonal, where ``n`` is the number
        of distinct labels in ``segments``.
    """
    features = np.asarray(features)
    segments = np.asarray(segments)

    # Use the labels that actually occur. Assuming 0..n-1 breaks with label
    # maps that start at 1: label 0 would be empty (NaN mean) and the last
    # superpixel would be ignored.
    segment_ids = np.unique(segments)
    num_segments = len(segment_ids)
    if num_segments == 0:
        return np.zeros((0, 0))

    # Compute each superpixel mean exactly once instead of once per (i, j) pair.
    # Flattening keeps scalar or multi-dimensional per-pixel features working.
    mean_features = np.stack([
        np.mean(features[segments == segment_id], axis=0)
        for segment_id in segment_ids
    ]).reshape(num_segments, -1)

    affinity_matrix = np.zeros((num_segments, num_segments))
    # Row-wise distances keep the temporary at (n, C) rather than (n, n, C).
    for row, mean_i in enumerate(mean_features):
        affinity_matrix[row] = np.linalg.norm(mean_features - mean_i, axis=1)
    np.fill_diagonal(affinity_matrix, 0.0)

    return affinity_matrix

def graph_based_clustering(affinity_matrix, num_clusters):
    """Group superpixels by average-linkage agglomerative clustering.

    Args:
        affinity_matrix: square matrix of pairwise distances between
            superpixels, used as a precomputed metric.
        num_clusters: number of clusters to produce.

    Returns:
        The cluster labels returned by ``fit_predict``, one per row of
        ``affinity_matrix``.
    """
    shared_options = dict(n_clusters=num_clusters, linkage='average')
    try:
        # scikit-learn >= 1.2 names the parameter `metric`; `affinity` was
        # deprecated there and removed in 1.4.
        clustering = AgglomerativeClustering(metric='precomputed', **shared_options)
    except TypeError:
        # Older scikit-learn releases only accept the original keyword.
        clustering = AgglomerativeClustering(affinity='precomputed', **shared_options)
    return clustering.fit_predict(affinity_matrix)


In [None]:
! pip install --use-pep517 pydensecrf
import pydensecrf.densecrf as dcrf
from pydensecrf.utils import unary_from_labels, create_pairwise_gaussian, create_pairwise_bilateral

def refine_segmentation(image, segments, labels):
    """Refine a coarse labelling with a fully connected CRF.

    Args:
        image: ``(H, W, 3)`` image array used for the bilateral pairwise term
            (pydensecrf expects uint8 data here).
        segments: ``(H, W)`` superpixel label map.
        labels: either one label per superpixel (1-D, ordered like
            ``np.unique(segments)``) or one label per pixel (``H * W`` values).

    Returns:
        ``(H, W)`` integer array with the refined label of every pixel, using
        the same label values as ``labels``.

    Raises:
        ValueError: if ``labels`` matches neither the number of superpixels
            nor the number of pixels.
    """
    height, width = image.shape[:2]
    labels = np.asarray(labels)

    # The CRF needs one label per pixel; expand per-superpixel labels through
    # the superpixel map.
    if labels.shape == (height, width):
        pixel_labels = labels
    else:
        segment_ids, segment_index = np.unique(segments, return_inverse=True)
        if labels.ndim == 1 and labels.size == segment_ids.size:
            pixel_labels = labels[segment_index.reshape(height, width)]
        elif labels.size == height * width:
            pixel_labels = labels.reshape(height, width)
        else:
            raise ValueError(
                "labels must hold one entry per superpixel or one entry per pixel"
            )

    # Remap to 0..k-1 so arbitrary label values work; mapped back on return.
    classes, compact_labels = np.unique(pixel_labels, return_inverse=True)
    compact_labels = compact_labels.reshape(-1)
    num_classes = len(classes)

    # With a single class there is nothing to refine, and the unary term would
    # divide by (num_classes - 1) == 0.
    if num_classes < 2:
        return pixel_labels.copy()

    d = dcrf.DenseCRF2D(width, height, num_classes)
    # zero_unsure=False: label 0 is a real class here, not an "unknown" marker.
    unary = unary_from_labels(compact_labels, num_classes, gt_prob=0.7, zero_unsure=False)
    d.setUnaryEnergy(unary)

    d.addPairwiseGaussian(sxy=3, compat=3)
    # The bilateral term requires a C-contiguous image buffer.
    d.addPairwiseBilateral(sxy=80, srgb=13, rgbim=np.ascontiguousarray(image), compat=10)

    marginals = d.inference(10)
    refined_labels = np.argmax(marginals, axis=0).reshape(height, width)
    return classes[refined_labels]


In [None]:
def get_bounding_boxes(segmented_image):
    """Compute one axis-aligned bounding box per non-zero label.

    Args:
        segmented_image: 2-D label array; label 0 is skipped as background.

    Returns:
        List of ``(x0, y0, x1, y1)`` tuples in ascending label order, where
        the corner coordinates are the inclusive min / max column and row of
        the pixels carrying that label.
    """
    def _box_for(label_value):
        # argwhere yields (row, col) pairs, i.e. (y, x).
        rows_cols = np.argwhere(segmented_image == label_value)
        (top, left), (bottom, right) = rows_cols.min(axis=0), rows_cols.max(axis=0)
        return (left, top, right, bottom)

    return [
        _box_for(label_value)
        for label_value in np.unique(segmented_image)
        if label_value != 0
    ]


In [None]:
from PIL import Image
import matplotlib.pyplot as plt
import matplotlib.patches as patches

# Load the input image. Forcing RGB keeps grayscale / CMYK / RGBA files
# compatible with the 3-channel normalisation and the CRF's RGB pairwise term.
image_path = 'images/0_0.jpeg'
image = Image.open(image_path).convert('RGB')
image_np = np.array(image)

# Step 1: Feature Extraction
# The full ResNet-50 ends in global pooling and a classifier, so it yields one
# vector per image, which cannot be indexed by the (H, W) superpixel map in
# Step 3. Drop those last two stages to keep the spatial feature map, then
# upsample it to the image resolution so every pixel has a feature vector.
# NOTE: the dense array holds H * W * C float32 values; for large images,
# downscale the input or truncate the backbone earlier to save memory.
backbone = torch.nn.Sequential(*list(model.children())[:-2]).eval()
feature_map = extract_features(image, backbone, transform)  # (C, h, w)
feature_map = torch.nn.functional.interpolate(
    feature_map.unsqueeze(0),
    size=image_np.shape[:2],
    mode='bilinear',
    align_corners=False,
).squeeze(0)
features = feature_map.permute(1, 2, 0).numpy()  # (H, W, C)

# Step 2: Superpixel Segmentation
segments = superpixel_segmentation(image_np)

# Step 3: Affinity Graph Construction and Clustering
affinity_matrix = construct_affinity_graph(features, segments)
num_clusters = 5  # Adjust based on the complexity of the scene
labels = graph_based_clustering(affinity_matrix, num_clusters)

# Step 4: Instance Segmentation Refinement
refined_labels = refine_segmentation(image_np, segments, labels)

# Step 5: Bounding Box Extraction
bounding_boxes = get_bounding_boxes(refined_labels)

# Display the image with one rectangle per extracted bounding box
fig, ax = plt.subplots(1, figsize=(12, 12))
ax.imshow(image_np)
ax.set_title('Bounding boxes of refined segments')

for box in bounding_boxes:
    x0, y0, x1, y1 = box
    rect = patches.Rectangle((x0, y0), x1 - x0, y1 - y0, linewidth=2, edgecolor='r', facecolor='none')
    ax.add_patch(rect)

plt.show()
