In [1]:
import os
import cv2
import torch
import shutil
import numpy as np
import torchvision.transforms as transforms
import torchvision.models.detection as detection
from torchvision.io import read_image
from torchvision.transforms.functional import to_pil_image
from sklearn.cluster import DBSCAN
from torchvision.models import inception_v3, Inception_V3_Weights
from torchvision.models.detection import MaskRCNN_ResNet50_FPN_Weights

In [2]:
# Paths
INPUT_FOLDER = r"C:\Users\nguye\OneDrive\Desktop\2024\Spring 24\ResreachWithGollum\input_images"  # Images to segment
FEATURE_FOLDER = r"C:\Users\nguye\OneDrive\Desktop\2024\Spring 24\ResreachWithGollum\features"  # Clustered segments are written here
DEBUG_FOLDER = r"C:\Users\nguye\OneDrive\Desktop\2024\Spring 24\ResreachWithGollum\debug"  # Every kept crop is also dumped here

# Start each run with empty output folders: delete whatever a previous run
# left behind (this is destructive), then recreate the folder.
for output_dir in (FEATURE_FOLDER, DEBUG_FOLDER):
    if os.path.exists(output_dir):
        shutil.rmtree(output_dir)
    os.makedirs(output_dir)

In [3]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU
device = "cuda" if torch.cuda.is_available() else "cpu"

# Pretrained Mask R-CNN (segmentation), moved to the device and put in
# inference mode; .eval() returns the module itself, so it can be chained.
mask_rcnn_weights = MaskRCNN_ResNet50_FPN_Weights.DEFAULT
mask_rcnn = detection.maskrcnn_resnet50_fpn(weights=mask_rcnn_weights).to(device).eval()

# Pretrained Inception-V3 (feature extraction), same treatment
cnn_weights = Inception_V3_Weights.DEFAULT
cnn_model = inception_v3(weights=cnn_weights).to(device).eval()

# Preprocessing applied to each crop before it is fed to Inception-V3
_preprocess_steps = [
    transforms.Resize((299, 299)),  # fixed 299x299 input, aspect ratio is not preserved
    transforms.ToTensor(),
    transforms.Normalize(
        mean=[0.485, 0.456, 0.406],
        std=[0.229, 0.224, 0.225],
    ),
]
transform = transforms.Compose(_preprocess_steps)

In [4]:
def extract_features(image):
    """Run one image segment through Inception-V3 and return a flat vector.

    Args:
        image: The segment to describe, in a form accepted by the module-level
            ``transform`` pipeline (the caller passes a PIL image).

    Returns:
        1-D NumPy array holding the output of ``cnn_model`` for this single
        image, flattened.

    NOTE(review): ``cnn_model`` is called as-is, so the vector is whatever the
    network's final layer emits (presumably the ImageNet class scores rather
    than a pooled embedding) -- confirm this is the intended feature.
    """
    # Preprocess, add the batch dimension, and move to the model's device
    batch = transform(image).unsqueeze(0).to(device)

    # Inference only: no autograd bookkeeping needed
    with torch.no_grad():
        output = cnn_model(batch)

    return output.cpu().numpy().flatten()

def segment_image(image_path, score_threshold=0.25, min_size=50):
    """Detect objects with Mask R-CNN and return their bounding-box crops.

    The image is decoded exactly once; both the model input and the returned
    crops come from those same pixels, so the predicted boxes always line up
    with the array they are cut from. The predicted masks are not applied:
    each segment is the full rectangular bounding box.

    Args:
        image_path: Path of the image file to segment.
        score_threshold: Detections whose score is not strictly greater than
            this value are dropped.
        min_size: A crop is kept only if both its height and its width are
            strictly greater than this many pixels.

    Returns:
        List of ``(segment, bbox)`` tuples. ``segment`` is a BGR ``uint8``
        array of shape (height, width, 3); ``bbox`` is an integer array
        ``[x1, y1, x2, y2]`` in pixel coordinates of the decoded image.

    Side effects:
        Writes every kept crop to ``DEBUG_FOLDER`` as
        ``debug_<image name>_<detection index>.jpg``.
    """
    from torchvision.io import ImageReadMode  # only needed in this function

    # Force 3-channel RGB: an RGBA PNG would otherwise reach the model with
    # four channels and fail in its 3-channel normalisation step.
    rgb_u8 = read_image(image_path, mode=ImageReadMode.RGB)  # CHW, uint8, on CPU
    model_input = rgb_u8.to(device) / 255.0  # float tensor scaled to [0, 1]

    with torch.no_grad():
        prediction = mask_rcnn([model_input])[0]  # one image in, one result out

    # Build the OpenCV-style (HWC, BGR) crop source from the tensor the model
    # saw instead of decoding the file a second time with cv2.imread, which
    # may orient the image differently (EXIF) and returns None on failure.
    hwc_rgb = np.ascontiguousarray(rgb_u8.permute(1, 2, 0).numpy())
    img = cv2.cvtColor(hwc_rgb, cv2.COLOR_RGB2BGR)

    # Move scores and boxes off the device once rather than per detection
    scores = prediction['scores'].cpu().numpy()
    boxes = prediction['boxes'].cpu().numpy().astype(int)

    # The image name goes into the debug file names so that crops from
    # different input images do not overwrite each other.
    image_stem = os.path.splitext(os.path.basename(image_path))[0]

    segments = []
    for i, (score, bbox) in enumerate(zip(scores, boxes)):
        if not score > score_threshold:
            continue

        x1, y1, x2, y2 = bbox
        segment = img[y1:y2, x1:x2]

        # Ignore very small segments
        if segment.shape[0] > min_size and segment.shape[1] > min_size:
            # Copy so the crop does not keep the whole decoded image alive
            segment = segment.copy()
            segments.append((segment, bbox))

            # DEBUG: Save detected segments
            debug_path = os.path.join(DEBUG_FOLDER, f"debug_{image_stem}_{i}.jpg")
            cv2.imwrite(debug_path, segment)

    return segments

In [5]:
# Step 1: Process all images and extract segments
all_features = []    # one feature vector per kept segment
segment_images = []  # BGR crops, index-aligned with all_features

# Iterate in sorted order: os.listdir() returns entries in arbitrary order,
# and a segment's position in these lists becomes its `segment_{idx}.jpg`
# file name when the clusters are saved, so the order must be reproducible.
for filename in sorted(os.listdir(INPUT_FOLDER)):
    if not filename.lower().endswith((".jpg", ".png", ".jpeg")):
        continue  # not an image type this pipeline handles

    img_path = os.path.join(INPUT_FOLDER, filename)
    segments = segment_image(img_path)

    if not segments:
        print(f"[WARNING] No segments found in {filename}. Try lowering score threshold.")

    for segment, bbox in segments:
        segment_images.append(segment)
        # Crops are OpenCV BGR arrays; the feature extractor takes an RGB PIL image
        pil_image = to_pil_image(cv2.cvtColor(segment, cv2.COLOR_BGR2RGB))
        features = extract_features(pil_image)
        all_features.append(features)

In [6]:
# Step 2: Cluster similar feature segments using DBSCAN
all_features = np.array(all_features)

# With min_samples=1 every point counts as a core point (the count includes
# the point itself), so each segment ends up in a cluster and none is
# labelled as noise (-1).
dbscan = DBSCAN(eps=10, min_samples=1)  # Adjust `eps` based on dataset

if all_features.size == 0:
    # Nothing was extracted: np.array([]) is 1-D and empty, which
    # fit_predict() rejects with an opaque ValueError. Produce an empty
    # label array instead so the following cells still run.
    print("[WARNING] No segments were extracted; nothing to cluster.")
    labels = np.empty(0, dtype=int)
else:
    labels = dbscan.fit_predict(all_features)

In [7]:
# Step 3: Save segments into corresponding feature folders
unique_labels = set(labels)

# One output folder per cluster label; label -1 (DBSCAN noise) gets none.
# Cluster label k is stored under "feature{k+1}".
cluster_dirs = {
    cluster: os.path.join(FEATURE_FOLDER, f"feature{cluster+1}")
    for cluster in unique_labels
    if cluster != -1
}
for cluster_dir in cluster_dirs.values():
    os.makedirs(cluster_dir, exist_ok=True)

# Write each segment into its cluster's folder, named by its position in
# `segment_images`.
for idx, (segment, lbl) in enumerate(zip(segment_images, labels)):
    target_dir = cluster_dirs.get(lbl)
    if target_dir is None:
        continue  # Skip noise points
    cv2.imwrite(os.path.join(target_dir, f"segment_{idx}.jpg"), segment)

print(f"Segmentation and clustering complete! Results saved in '{FEATURE_FOLDER}/'.")


Segmentation and clustering complete! Results saved in 'C:\Users\nguye\OneDrive\Desktop\2024\Spring 24\ResreachWithGollum\features/'.
