In [4]:
#!/usr/bin/env python3
"""
Image Processor - Selects optimal RGB images and pairs with depth maps

This script:
1. Analyzes all images in the RGB image directory configured in main()
2. Detects blur using multiple methods
3. Extracts features using EfficientNet for view diversity analysis
4. Selects top images with minimal blur and diverse views
5. Copies selected images to the output directory configured in main()
6. Matches selected RGB images with depth maps from ../Depth
"""

import cv2
import numpy as np
import os
import re
import shutil
from tqdm import tqdm
from sklearn.cluster import KMeans
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from efficientnet_pytorch import EfficientNet
import multiprocessing
from concurrent.futures import ProcessPoolExecutor, as_completed

# Run on the GPU when CUDA is usable, otherwise stay on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

# Build the EfficientNet-B0 feature extractor (or signal that it is unavailable)
def load_efficientnet():
    """Return a pretrained EfficientNet-B0 in eval mode on ``device``.

    Any failure while loading is reported on stdout and ``None`` is returned,
    which callers treat as "use the OpenCV-based extractor instead".
    """
    try:
        net = EfficientNet.from_pretrained('efficientnet-b0')
        net.eval()
        return net.to(device)
    except Exception as err:
        print(f"Failed to load EfficientNet: {err}")
        print("Falling back to OpenCV-based feature extraction")
    return None

# Per-channel mean / std used to normalize inputs (the usual ImageNet statistics)
_NORMALIZE_MEAN = [0.485, 0.456, 0.406]
_NORMALIZE_STD = [0.229, 0.224, 0.225]

# Preprocessing pipeline applied to an RGB array before it is fed to EfficientNet:
# array -> PIL image -> 224x224 -> tensor -> channel-wise normalization
transform = transforms.Compose([
    transforms.ToPILImage(),
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=_NORMALIZE_MEAN, std=_NORMALIZE_STD),
])

def calculate_blur_metrics(image_path):
    """Compute three sharpness indicators for a single image.

    For every metric a larger value means more fine detail (less blur).

    Args:
        image_path: Path of the image file to analyse.

    Returns:
        A tuple ``(laplacian_var, high_freq_content, sobel_mean)``:
        the variance of the Laplacian, the mean log-magnitude of the FFT
        spectrum outside a central low-frequency window, and the mean Sobel
        gradient magnitude. ``(None, None, None)`` is returned when the image
        cannot be read or any step raises.
    """
    try:
        # Read image
        img = cv2.imread(image_path)
        if img is None:
            return None, None, None

        # Convert to grayscale
        gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)

        # Method 1: Laplacian variance (lower values indicate more blur)
        laplacian_var = cv2.Laplacian(gray, cv2.CV_64F).var()

        # Method 2: FFT-based blur detection
        f = np.fft.fft2(gray)
        fshift = np.fft.fftshift(f)
        magnitude_spectrum = 20 * np.log(np.abs(fshift) + 1)

        h, w = gray.shape
        center_y, center_x = h // 2, w // 2

        # Zero out the low-frequency window around the spectrum centre.
        # The start indices are clamped at 0: for images smaller than the
        # window a negative start would wrap around and mask the wrong region.
        mask = np.ones((h, w), np.uint8)
        center_region = 20
        top = max(center_y - center_region, 0)
        left = max(center_x - center_region, 0)
        mask[top:center_y + center_region, left:center_x + center_region] = 0

        # A tiny image can be covered entirely by the window; report 0.0
        # instead of dividing by zero and propagating NaN into the scores.
        kept_bins = np.sum(mask)
        if kept_bins > 0:
            high_freq_content = np.sum(magnitude_spectrum * mask) / kept_bins
        else:
            high_freq_content = 0.0

        # Method 3: Sobel gradients
        sobelx = cv2.Sobel(gray, cv2.CV_64F, 1, 0, ksize=3)
        sobely = cv2.Sobel(gray, cv2.CV_64F, 0, 1, ksize=3)
        sobel_mag = np.sqrt(sobelx**2 + sobely**2)
        sobel_mean = np.mean(sobel_mag)

        return laplacian_var, high_freq_content, sobel_mean

    except Exception as e:
        print(f"Error processing {image_path}: {e}")
        return None, None, None

def extract_features_cv(image_path):
    """Describe an image by the mean of its ORB descriptors (fallback extractor).

    The image is resized to 224x224 and up to 500 ORB keypoints are detected.

    Returns:
        A ``(1, N)`` float32 array holding the column-wise mean descriptor,
        a ``(1, 32)`` float32 zero array when no descriptors are found or an
        error occurs, or ``None`` when the image cannot be read.
    """
    try:
        image = cv2.imread(image_path)
        if image is None:
            return None

        resized = cv2.resize(image, (224, 224))
        detector = cv2.ORB_create(nfeatures=500)
        _, descriptors = detector.detectAndCompute(resized, None)

        has_descriptors = descriptors is not None and len(descriptors) != 0
        if not has_descriptors:
            return np.zeros((1, 32), dtype=np.float32)

        mean_descriptor = np.mean(descriptors, axis=0)
        return mean_descriptor.reshape(1, -1).astype(np.float32)

    except Exception as err:
        print(f"Error extracting CV features from {image_path}: {err}")
        return np.zeros((1, 32), dtype=np.float32)

def extract_features_efficientnet(image_path, model):
    """Extract a global feature vector for an image with EfficientNet.

    Args:
        image_path: Path of the image file.
        model: Network exposing ``extract_features``; ``None`` means no
            network is available.

    Returns:
        A ``(1, D)`` float32 array (the globally average-pooled feature map),
        or ``None`` when the image cannot be read or extraction fails.
        When ``model`` is ``None`` the result of ``extract_features_cv`` is
        returned instead.
    """
    if model is None:
        # Without a network the ORB descriptor is the only option.
        return extract_features_cv(image_path)

    try:
        img = cv2.imread(image_path)
        if img is None:
            return None

        # OpenCV loads BGR; the preprocessing pipeline expects RGB.
        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        img_tensor = transform(img).unsqueeze(0).to(device)

        with torch.no_grad():
            features = model.extract_features(img_tensor)
            features = torch.nn.functional.adaptive_avg_pool2d(features, 1)
            features = features.squeeze().cpu().numpy()

        return features.reshape(1, -1).astype(np.float32)

    except Exception as e:
        print(f"Error extracting EfficientNet features from {image_path}: {e}")
        # Do not substitute the 32-dimensional ORB vector here: mixing it with
        # EfficientNet vectors of a different width makes the later np.vstack
        # of all features fail. Returning None lets the caller skip the image.
        return None

def process_image(args):
    """Analyse one image and bundle its blur metrics with its feature vector.

    Args:
        args: ``(img_path, model)`` tuple, packed so the function can be
            mapped over a worker pool. ``model`` may be ``None``, in which
            case the OpenCV extractor is used.

    Returns:
        A dict with keys ``path``, ``laplacian_var``, ``high_freq_content``,
        ``sobel_mean`` and ``features``, or ``None`` when the blur metrics or
        the features could not be computed.
    """
    img_path, model = args

    laplacian_var, high_freq_content, sobel_mean = calculate_blur_metrics(img_path)
    if laplacian_var is None:
        return None

    if model is None:
        features = extract_features_cv(img_path)
    else:
        features = extract_features_efficientnet(img_path, model)

    if features is None:
        return None

    return dict(
        path=img_path,
        laplacian_var=laplacian_var,
        high_freq_content=high_freq_content,
        sobel_mean=sobel_mean,
        features=features,
    )

# Convenience wrapper for callers that do not go through a worker pool
def process_single_image(img_path, model=None):
    """Run ``process_image`` for one path, packing the arguments into its tuple form."""
    packed_args = (img_path, model)
    return process_image(packed_args)

def process_image_directory_sequential(directory_path='.', model=None):
    """Analyse every image file in a directory, one at a time, in this process.

    Only regular files with a recognised image extension (case-insensitive)
    are considered. Images for which ``process_image`` returns ``None`` are
    left out of the result.

    Returns:
        List of the per-image result dicts produced by ``process_image``.
    """
    image_extensions = ('.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.webp')

    image_files = []
    for entry in os.listdir(directory_path):
        if not os.path.isfile(os.path.join(directory_path, entry)):
            continue
        if entry.lower().endswith(image_extensions):
            image_files.append(entry)

    print(f"Found {len(image_files)} images in {directory_path}")
    print("Processing images sequentially...")

    outcomes = (
        process_image((os.path.join(directory_path, name), model))
        for name in tqdm(image_files, desc="Processing images")
    )
    image_data = [outcome for outcome in outcomes if outcome is not None]

    print(f"Successfully processed {len(image_data)} images")
    return image_data

def process_image_directory(directory_path='.', model=None):
    """Analyse every image in a directory, in parallel when that is safe.

    When CUDA is available the work is delegated to
    ``process_image_directory_sequential``. Otherwise a pool of up to 8
    spawned worker processes maps ``process_image`` over the image files.
    An ``AttributeError`` or ``RuntimeError`` raised along the way triggers
    a fallback to sequential processing.

    Args:
        directory_path: Directory containing the images.
        model: Feature-extraction network, or ``None`` for the OpenCV fallback.

    Returns:
        List of the per-image result dicts produced by ``process_image``
        (images that yielded ``None`` are omitted).
    """
    try:
        # Decide on the GPU path before doing anything else, so the directory
        # is neither listed nor announced twice when delegating.
        if torch.cuda.is_available():
            print("CUDA detected. Using sequential processing to avoid CUDA initialization issues.")
            return process_image_directory_sequential(directory_path, model)

        image_files = [f for f in os.listdir(directory_path)
                       if os.path.isfile(os.path.join(directory_path, f)) and
                       f.lower().endswith(('.png', '.jpg', '.jpeg', '.bmp', '.tiff', '.webp'))]

        print(f"Found {len(image_files)} images in {directory_path}")

        # Use a local 'spawn' context rather than set_start_method(force=True),
        # which would change the start method for the whole process.
        mp_context = multiprocessing.get_context('spawn')

        num_workers = min(multiprocessing.cpu_count(), 8)
        print(f"Using {num_workers} worker processes")

        # NOTE(review): the model is part of every task tuple and is therefore
        # serialized once per image; consider a pool initializer if this
        # becomes a bottleneck.
        args_list = [(os.path.join(directory_path, img_file), model) for img_file in image_files]
        with mp_context.Pool(processes=num_workers) as pool:
            results = list(tqdm(pool.imap(process_image, args_list), total=len(args_list), desc="Processing images"))

        image_data = [result for result in results if result is not None]

        print(f"Successfully processed {len(image_data)} images")
        return image_data

    except (AttributeError, RuntimeError) as e:
        print(f"Multiprocessing error: {e}")
        print("Falling back to sequential processing...")
        return process_image_directory_sequential(directory_path, model)

def select_top_images(image_data, num_to_select=300):
    """Pick sharp images that cover diverse views.

    The three blur metrics are min-max normalized and averaged into one
    sharpness score per image (higher is sharper). The feature vectors are
    clustered with KMeans into ``num_to_select`` groups and the sharpest
    image of each group is chosen. If clustering yields fewer distinct
    groups than requested, the remaining slots are filled with the sharpest
    images not yet chosen.

    Args:
        image_data: List of dicts as produced by ``process_image``
            (keys ``path``, ``laplacian_var``, ``high_freq_content``,
            ``sobel_mean``, ``features``).
        num_to_select: Number of images wanted.

    Returns:
        List of selected image paths. All paths are returned when there are
        no more than ``num_to_select`` images; an empty list is returned when
        ``num_to_select`` is not positive.
    """
    # A non-positive request would otherwise reach KMeans with n_clusters <= 0.
    if num_to_select <= 0:
        return []

    if len(image_data) <= num_to_select:
        return [data['path'] for data in image_data]

    laplacian_vars = np.array([data['laplacian_var'] for data in image_data])
    high_freq_contents = np.array([data['high_freq_content'] for data in image_data])
    sobel_means = np.array([data['sobel_mean'] for data in image_data])

    def normalize(x):
        # The epsilon keeps the division defined when all values are equal.
        return (x - x.min()) / (x.max() - x.min() + 1e-10)

    # Higher score = sharper image.
    sharpness = (normalize(laplacian_vars)
                 + normalize(high_freq_contents)
                 + normalize(sobel_means)) / 3

    feature_matrix = np.vstack([data['features'] for data in image_data])

    # len(image_data) > num_to_select is guaranteed at this point.
    n_clusters = num_to_select
    print(f"Clustering images into {n_clusters} groups...")

    kmeans = KMeans(n_clusters=n_clusters, random_state=42, n_init=10)
    clusters = kmeans.fit_predict(feature_matrix)
    unique_clusters = np.unique(clusters)

    # One representative per cluster: its sharpest member.
    selected_indices = []
    for cluster_id in unique_clusters:
        members = np.where(clusters == cluster_id)[0]
        selected_indices.append(int(members[np.argmax(sharpness[members])]))

    if len(selected_indices) < num_to_select:
        print(f"Warning: Only {len(selected_indices)} unique clusters found")
        # Top up with the sharpest images that were not picked as a
        # representative, so the caller still gets the requested count.
        already_chosen = set(selected_indices)
        for idx in np.argsort(-sharpness):
            if len(selected_indices) == num_to_select:
                break
            idx = int(idx)
            if idx not in already_chosen:
                selected_indices.append(idx)
                already_chosen.add(idx)

    return [image_data[idx]['path'] for idx in selected_indices]

def main(image_directory="/home/aaronmcafee/Documents/bigVid/RGB",
         num_to_select=200,
         output_dir="/home/aaronmcafee/Documents/bigVid/good"):
    """Select the best RGB images and copy them with their depth maps.

    Steps: load the feature extractor, analyse every image in
    ``image_directory``, select ``num_to_select`` of them, copy the selection
    into ``output_dir`` (prefixed with a 4-digit index), then look in the
    sibling ``Depth`` folder for ``depth_<serial>.png`` files matching RGB
    names that end in ``v2_<serial>.png`` and copy those as well. RGB files
    whose depth map is missing are listed in ``missing_depth_maps.txt``
    inside ``output_dir``.

    Args:
        image_directory: Directory containing the RGB images.
        num_to_select: Number of images to keep.
        output_dir: Destination directory; created if it does not exist.
    """
    print("Loading EfficientNet for feature extraction...")
    model = load_efficientnet()

    print("Starting image analysis...")
    image_data = process_image_directory(image_directory, model)

    if len(image_data) == 0:
        print("No images were successfully processed. Please check the images in the directory.")
        return

    print(f"Selecting top {num_to_select} images...")
    selected_images = select_top_images(image_data, num_to_select)

    print(f"Selected {len(selected_images)} out of {len(image_data)} images")

    if not os.path.exists(output_dir):
        # exist_ok avoids a failure if the directory appears between the
        # check above and the creation.
        os.makedirs(output_dir, exist_ok=True)
        print(f"Created output directory: {output_dir}")

    print(f"Copying selected images to {output_dir}...")
    for i, img_path in enumerate(tqdm(selected_images, desc="Copying files")):
        filename = os.path.basename(img_path)
        dst = os.path.join(output_dir, f"{i:04d}_{filename}")
        shutil.copy(img_path, dst)

    # Depth map matching logic
    print("\nMatching depth maps...")
    depth_folder = os.path.abspath(os.path.join(image_directory, '../Depth'))
    if not os.path.exists(depth_folder):
        print(f"Depth folder not found: {depth_folder}")
    else:
        paired_count = 0
        missing_depths = []
        # Compiled once; the serial number is the digit run after "v2_".
        serial_pattern = re.compile(r'v2_(\d+)\.png$')

        for img_path in tqdm(selected_images, desc="Matching depth maps"):
            filename = os.path.basename(img_path)
            match = serial_pattern.search(filename)
            if not match:
                print(f"Warning: Could not extract serial number from {filename}, skipping")
                continue
            serial = match.group(1)
            depth_file = f"depth_{serial}.png"
            depth_path = os.path.join(depth_folder, depth_file)

            if not os.path.exists(depth_path):
                missing_depths.append(filename)
                continue

            # NOTE: depth maps keep their original name; unlike the RGB
            # copies they are not given the index prefix.
            depth_dest = os.path.join(output_dir, depth_file)
            shutil.copy2(depth_path, depth_dest)
            paired_count += 1

        print(f"Successfully paired {paired_count} RGB images with depth maps")
        if missing_depths:
            print(f"Warning: {len(missing_depths)} depth maps missing")
            missing_file = os.path.join(output_dir, "missing_depth_maps.txt")
            with open(missing_file, 'w') as f:
                for name in missing_depths:
                    f.write(f"{name}\n")
            print(f"Missing depth maps listed in {missing_file}")

    print(f"Done! {len(selected_images)} images have been copied to {output_dir}")
    print(f"Selected {len(selected_images)/len(image_data):.1%} of the original images")

if __name__ == "__main__":
    # Script entry point: run the pipeline and turn Ctrl-C or any other
    # failure into a short message instead of a traceback.
    try:
        main()
    except KeyboardInterrupt:
        print("\nProcess interrupted by user")
    except Exception as err:
        print(f"Error: {err}")

Loading EfficientNet for feature extraction...
Loaded pretrained weights for efficientnet-b0
Starting image analysis...
Found 2934 images in /home/aaronmcafee/Documents/bigVid/RGB
CUDA detected. Using sequential processing to avoid CUDA initialization issues.
Found 2934 images in /home/aaronmcafee/Documents/bigVid/RGB
Processing images sequentially...


Processing images: 100%|████████████████████| 2934/2934 [01:35<00:00, 30.87it/s]


Successfully processed 2934 images
Selecting top 200 images...
Clustering images into 200 groups...
Selected 200 out of 2934 images
Copying selected images to /home/aaronmcafee/Documents/bigVid/good...


Copying files: 100%|████████████████████████| 200/200 [00:00<00:00, 2490.95it/s]



Matching depth maps...


Matching depth maps: 100%|██████████████████| 200/200 [00:00<00:00, 2258.23it/s]

Successfully paired 200 RGB images with depth maps
Done! 200 images have been copied to /home/aaronmcafee/Documents/bigVid/good
Selected 6.8% of the original images



