In [1]:
import os
import time
import numpy as np
import cv2
import torch
import torch.nn.functional as F
from PIL import Image
import open3d as o3d
from pathlib import Path
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
import multiprocessing as mp
from typing import List, Dict, Tuple, Optional
import json

# Silence every warning for the rest of the session.
warnings.filterwarnings('ignore')

# Folder that holds the input images (point this at your own directory).
image_directory = "public"

# Image files to process, given relative to ``image_directory``.
image_filenames = [
    "numbers-stonks1.png",
    "dovenest.png",
    "meme-snap.png",
    "doveart1.png",
]

# Full path of every image, in the same order as ``image_filenames``.
image_paths = list(
    map(lambda image_name: os.path.join(image_directory, image_name), image_filenames)
)

class BatchFastDepthPipeline:
    def __init__(self, model_type='midas_small', input_size=256, use_metal=True,
                 max_batch_size=8, prefetch_buffer=16):
        """
        Initialize the batch depth estimation pipeline

        Args:
            model_type: 'midas_small' or 'zoe_nano'
            input_size: Input image size (256, 384, or 512)
            use_metal: Use Apple Metal backend for acceleration
            max_batch_size: Maximum images to process simultaneously
            prefetch_buffer: Number of images to prefetch for processing
        """
        self.model_type = model_type
        self.input_size = input_size
        self.max_batch_size = max_batch_size
        self.prefetch_buffer = prefetch_buffer
        self.device = self._setup_device(use_metal)
        self.model = self._load_model()

        # Setup parallel processing
        self.num_workers = min(8, mp.cpu_count())

        print(f"🚀 Batch Pipeline initialized:")
        print(f"   Model: {model_type}")
        print(f"   Input size: {input_size}x{input_size}")
        print(f"   Device: {self.device}")
        print(f"   Max batch size: {max_batch_size}")
        print(f"   Workers: {self.num_workers}")

    def _setup_device(self, use_metal):
        """Setup optimal device for M2 Mac"""
        if use_metal and torch.backends.mps.is_available():
            return torch.device('mps')
        elif torch.cuda.is_available():
            return torch.device('cuda')
        else:
            return torch.device('cpu')

    def _load_model(self):
        """Load lightweight depth estimation model"""
        if self.model_type == 'midas_small':
            model = torch.hub.load('intel-isl/MiDaS', 'DPT_Hybrid', pretrained=True)
            model.to(self.device)
            model.eval()
            return model
        else:
            raise ValueError(f"Unknown model type: {self.model_type}")

    def preprocess_batch(self, image_paths: List[str]) -> Tuple[torch.Tensor, List[Tuple], List[float]]:
        """Preprocess multiple images in parallel"""
        def process_single_image(image_path):
            try:
                img = cv2.imread(str(image_path))
                if img is None:
                    return None, None, None

                # Convert BGR to RGB
                img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

                # Store original dimensions
                h, w = img.shape[:2]
                scale = self.input_size / max(h, w)
                new_h, new_w = int(h * scale), int(w * scale)
                img = cv2.resize(img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)

                # Pad to square
                pad_h = self.input_size - new_h
                pad_w = self.input_size - new_w
                img = np.pad(img, ((0, pad_h), (0, pad_w), (0, 0)), mode='constant')

                # Convert to tensor
                img = img.astype(np.float32) / 255.0
                img = torch.from_numpy(img).permute(2, 0, 1)

                return img, (h, w), scale
            except Exception as e:
                print(f"Error processing {image_path}: {e}")
                return None, None, None

        # Process images in parallel
        with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
            results = list(executor.map(process_single_image, image_paths))

        # Filter out failed images and stack tensors
        valid_results = [(img, dims, scale) for img, dims, scale in results if img is not None]

        if not valid_results:
            return torch.empty(0), [], []

        images = torch.stack([img for img, _, _ in valid_results])
        dimensions = [dims for _, dims, _ in valid_results]
        scales = [scale for _, _, scale in valid_results]

        return images.to(self.device), dimensions, scales

    def estimate_depth_batch(self, image_batch: torch.Tensor) -> List[np.ndarray]:
        """Estimate depth for a batch of images"""
        if image_batch.size(0) == 0:
            return []

        with torch.no_grad():
            depth_batch = self.model(image_batch)
            depth_maps = []

            for i in range(depth_batch.size(0)):
                depth = depth_batch[i].squeeze().cpu().numpy()
                depth = (depth - depth.min()) / (depth.max() - depth.min())
                depth_maps.append(depth)

        return depth_maps

    def create_sparse_point_cloud_batch(self, depth_maps: List[np.ndarray],
                                      rgb_images: List[np.ndarray],
                                      sample_rate: int = 4) -> List[Tuple[np.ndarray, np.ndarray]]:
        """Create sparse point clouds for multiple images in parallel"""
        def process_single_pointcloud(args):
            depth_map, rgb_image = args
            h, w = depth_map.shape

            # Sample every Nth pixel for speed
            y_coords, x_coords = np.meshgrid(
                np.arange(0, h, sample_rate),
                np.arange(0, w, sample_rate),
                indexing='ij'
            )

            # Get depth and color values
            depths = depth_map[y_coords, x_coords]
            colors = rgb_image[y_coords, x_coords]

            # Create 3D coordinates
            focal_length = max(h, w)
            cx, cy = w // 2, h // 2

            z = depths * 10
            x = (x_coords - cx) * z / focal_length
            y = (y_coords - cy) * z / focal_length

            points = np.stack([x.flatten(), y.flatten(), z.flatten()], axis=1)
            colors = colors.reshape(-1, 3) / 255.0

            # Remove invalid points
            valid_mask = ~np.isnan(points).any(axis=1)
            points = points[valid_mask]
            colors = colors[valid_mask]

            return points, colors

        # Process point clouds in parallel
        with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
            results = list(executor.map(process_single_pointcloud, zip(depth_maps, rgb_images)))

        return results

    def analyze_3d_shape_batch(self, point_cloud_batch: List[Tuple[np.ndarray, np.ndarray]]) -> List[Dict]:
        """Analyze 3D shapes for multiple point clouds in parallel"""
        def analyze_single_shape(args):
            points, colors = args
            if len(points) < 10:
                return {"type": "unknown", "confidence": 0.0}

            # Calculate basic statistics
            bbox_min = np.min(points, axis=0)
            bbox_max = np.max(points, axis=0)
            bbox_size = bbox_max - bbox_min

            z_values = points[:, 2]
            z_std = np.std(z_values)

            aspect_ratio_xy = bbox_size[0] / bbox_size[1] if bbox_size[1] > 0 else 1.0
            aspect_ratio_z = bbox_size[2] / max(bbox_size[0], bbox_size[1]) if max(bbox_size[0], bbox_size[1]) > 0 else 1.0

            # Simple classification
            if z_std < 0.5 and aspect_ratio_z < 0.1:
                return {"type": "flat_surface", "confidence": 0.8}
            elif aspect_ratio_xy > 3.0 or aspect_ratio_xy < 0.3:
                return {"type": "elongated_object", "confidence": 0.7}
            elif z_std > 1.0:
                return {"type": "3d_object", "confidence": 0.6}
            else:
                return {"type": "unknown", "confidence": 0.3}

        # Analyze shapes in parallel
        with ThreadPoolExecutor(max_workers=self.num_workers) as executor:
            results = list(executor.map(analyze_single_shape, point_cloud_batch))

        return results

    def process_batch(self, image_paths: List[str], sample_rate: int = 4,
                     save_results: bool = False, output_dir: str = "batch_results") -> List[Dict]:
        """Process a batch of images efficiently"""
        start_time = time.time()

        # Create output directory if saving results
        if save_results:
            os.makedirs(output_dir, exist_ok=True)

        # Split into manageable batches
        all_results = []

        for i in range(0, len(image_paths), self.max_batch_size):
            batch_paths = image_paths[i:i + self.max_batch_size]
            batch_start = time.time()

            print(f"📦 Processing batch {i//self.max_batch_size + 1}/{(len(image_paths) + self.max_batch_size - 1)//self.max_batch_size}")

            # Preprocess batch
            image_batch, dimensions, scales = self.preprocess_batch(batch_paths)

            if image_batch.size(0) == 0:
                print("⚠️  No valid images in batch")
                continue

            # Estimate depth for batch
            depth_maps = self.estimate_depth_batch(image_batch)

            # Load RGB images for point cloud generation
            rgb_images = []
            for path in batch_paths:
                try:
                    img = cv2.imread(str(path))
                    img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
                    img = cv2.resize(img, (self.input_size, self.input_size))
                    rgb_images.append(img)
                except:
                    rgb_images.append(np.zeros((self.input_size, self.input_size, 3), dtype=np.uint8))

            # Create point clouds
            point_clouds = self.create_sparse_point_cloud_batch(depth_maps, rgb_images, sample_rate)

            # Analyze shapes
            analyses = self.analyze_3d_shape_batch(point_clouds)

            # Compile results
            for j, (path, analysis, (points, colors), depth_map) in enumerate(zip(batch_paths, analyses, point_clouds, depth_maps)):
                result = {
                    'image_path': str(path),
                    'analysis': analysis,
                    'points': points,
                    'colors': colors,
                    'depth_map': depth_map,
                    'processing_time': time.time() - batch_start,
                    'point_count': len(points)
                }

                # Save individual results if requested
                if save_results:
                    result_path = os.path.join(output_dir, f"{Path(path).stem}_result.json")
                    self._save_result(result, result_path)

                all_results.append(result)

            batch_time = time.time() - batch_start
            print(f"   ✅ Batch completed in {batch_time:.2f}s ({len(batch_paths)} images)")

        total_time = time.time() - start_time
        print(f"\n🎯 Batch processing complete:")
        print(f"   Total images: {len(image_paths)}")
        print(f"   Successfully processed: {len(all_results)}")
        print(f"   Total time: {total_time:.2f}s")
        print(f"   Average per image: {total_time/len(all_results):.2f}s")

        return all_results

    def _save_result(self, result: Dict, output_path: str):
        """Save result to JSON file (without numpy arrays)"""
        # Create a serializable version
        serializable_result = {
            'image_path': result['image_path'],
            'analysis': result['analysis'],
            'processing_time': result['processing_time'],
            'point_count': result['point_count'],
            'depth_map_shape': result['depth_map'].shape,
            'points_shape': result['points'].shape if len(result['points']) > 0 else (0, 0)
        }

        with open(output_path, 'w') as f:
            json.dump(serializable_result, f, indent=2)

    def process_directory(self, input_dir: str, extensions: List[str] = None,
                         sample_rate: int = 4, save_results: bool = False,
                         output_dir: str = "batch_results") -> List[Dict]:
        """Process all images in a directory"""
        if extensions is None:
            extensions = ['.jpg', '.jpeg', '.png', '.bmp', '.tiff']

        # Find all image files
        image_paths = []
        for ext in extensions:
            image_paths.extend(Path(input_dir).glob(f"*{ext}"))
            image_paths.extend(Path(input_dir).glob(f"*{ext.upper()}"))

        print(f"📁 Found {len(image_paths)} images in {input_dir}")

        if not image_paths:
            print("⚠️  No images found")
            return []

        return self.process_batch(image_paths, sample_rate, save_results, output_dir)

    def get_batch_summary(self, results: List[Dict]) -> Dict:
        """Generate summary statistics for batch processing"""
        if not results:
            return {}

        # Aggregate statistics
        shape_types = {}
        total_points = 0
        processing_times = []

        for result in results:
            shape_type = result['analysis']['type']
            shape_types[shape_type] = shape_types.get(shape_type, 0) + 1
            total_points += result['point_count']
            processing_times.append(result['processing_time'])

        return {
            'total_images': len(results),
            'shape_distribution': shape_types,
            'total_points_generated': total_points,
            'average_points_per_image': total_points / len(results),
            'average_processing_time': np.mean(processing_times),
            'total_processing_time': sum(processing_times)
        }

# Utility functions for batch processing
def quick_batch_test(num_images: int = 5):
    """Smoke-test the batch pipeline on synthetic images.

    Writes ``num_images`` random-noise JPEGs, each with a filled circle,
    rectangle or ellipse drawn on top, into a fresh temporary directory, runs
    the pipeline over that directory and prints the batch summary. The
    temporary directory is removed afterwards, even if the pipeline fails.

    Args:
        num_images: Number of synthetic images to generate.

    Returns:
        A tuple ``(results, summary)`` from ``process_directory`` and
        ``get_batch_summary``.
    """
    import shutil
    import tempfile

    # A unique temporary directory: a fixed name in the working directory
    # could collide with, and then delete, a directory the user already has.
    test_dir = tempfile.mkdtemp(prefix="test_batch_")
    try:
        for i in range(num_images):
            # Upper bound is exclusive, so 256 covers the full uint8 range.
            test_img = np.random.randint(0, 256, (400, 400, 3), dtype=np.uint8)
            # Cycle through three different shapes.
            if i % 3 == 0:
                cv2.circle(test_img, (200, 200), 80, (255, 0, 0), -1)
            elif i % 3 == 1:
                cv2.rectangle(test_img, (50, 50), (350, 150), (0, 255, 0), -1)
            else:
                cv2.ellipse(test_img, (200, 200), (100, 50), 45, 0, 360, (0, 0, 255), -1)

            cv2.imwrite(os.path.join(test_dir, f"test_{i:03d}.jpg"), test_img)

        pipeline = BatchFastDepthPipeline(input_size=256, max_batch_size=4)
        results = pipeline.process_directory(test_dir, sample_rate=8, save_results=True)
        summary = pipeline.get_batch_summary(results)
    finally:
        # Always clean up the synthetic inputs, including on failure.
        shutil.rmtree(test_dir, ignore_errors=True)

    print("\n📊 Batch Test Summary:")
    for key, value in summary.items():
        print(f"   {key}: {value}")

    return results, summary

# Memory-efficient streaming processor for very large batches
class StreamingDepthProcessor:
    """Process a long list of images chunk by chunk, streaming results to disk.

    Each chunk is handed to the wrapped pipeline's ``process_batch`` and a
    compact JSON record per result is appended to a JSON Lines file.
    """

    def __init__(self, pipeline: "BatchFastDepthPipeline", chunk_size: int = 100):
        """
        Args:
            pipeline: Object providing ``process_batch(paths, save_results=...)``.
            chunk_size: Number of image paths handed to the pipeline per call.

        Raises:
            ValueError: If ``chunk_size`` is below 1.
        """
        # Reject a bad chunk size up front: zero would otherwise fail only
        # after the output file has been truncated, and a negative value
        # would silently write an empty file.
        if chunk_size < 1:
            raise ValueError(f"chunk_size must be >= 1, got {chunk_size!r}")
        self.pipeline = pipeline
        self.chunk_size = chunk_size

    def process_large_dataset(self, image_paths: List[str], output_file: str = "streaming_results.jsonl"):
        """Process very large datasets with streaming output.

        Args:
            image_paths: Paths of all images to process.
            output_file: JSON Lines file to (over)write; one record per
                result with ``image_path``, ``analysis``, ``processing_time``
                and ``point_count``.
        """
        total_images = len(image_paths)
        processed = 0

        with open(output_file, 'w', encoding='utf-8') as f:
            for i in range(0, total_images, self.chunk_size):
                chunk_paths = image_paths[i:i + self.chunk_size]
                results = self.pipeline.process_batch(chunk_paths, save_results=False)

                for result in results:
                    # Keep only the JSON-serialisable fields; the numpy
                    # arrays (points, colors, depth map) are left out.
                    json_result = {
                        'image_path': result['image_path'],
                        'analysis': result['analysis'],
                        'processing_time': result['processing_time'],
                        'point_count': result['point_count']
                    }
                    f.write(json.dumps(json_result) + '\n')

                # Push each finished chunk to disk so progress survives a crash.
                f.flush()

                processed += len(results)
                print(f"💾 Streamed {processed}/{total_images} results")

        print(f"✅ Streaming complete. Results saved to {output_file}")

# Announce that the definitions above are loaded and how to use them.
print(
    "🎯 Batch-optimized setup complete!",
    "📝 Run 'quick_batch_test()' to test batch processing.",
    "📁 Use 'process_directory()' for real datasets.",
    "💾 Use 'StreamingDepthProcessor' for very large datasets.",
    sep="\n",
)

# Build the pipeline used for the sample images.
pipeline = BatchFastDepthPipeline(max_batch_size=4, input_size=256)

# Run the sample images through it, writing a JSON summary per image.
results = pipeline.process_batch(image_paths, save_results=True)

# Dump every result dict (this includes the full point and colour arrays).
for result in results:
    print(result)


🎯 Batch-optimized setup complete!
📝 Run 'quick_batch_test()' to test batch processing.
📁 Use 'process_directory()' for real datasets.
💾 Use 'StreamingDepthProcessor' for very large datasets.
Downloading: "https://github.com/intel-isl/MiDaS/zipball/master" to /Users/adamaslan/.cache/torch/hub/master.zip
Downloading: "https://github.com/isl-org/MiDaS/releases/download/v3/dpt_hybrid_384.pt" to /Users/adamaslan/.cache/torch/hub/checkpoints/dpt_hybrid_384.pt


100%|██████████| 470M/470M [00:32<00:00, 15.2MB/s] 


🚀 Batch Pipeline initialized:
   Model: midas_small
   Input size: 256x256
   Device: mps
   Max batch size: 4
   Workers: 8
📦 Processing batch 1/1
   ✅ Batch completed in 7.44s (4 images)

🎯 Batch processing complete:
   Total images: 4
   Successfully processed: 4
   Total time: 7.44s
   Average per image: 1.86s
{'image_path': 'public/numbers-stonks1.png', 'analysis': {'type': '3d_object', 'confidence': 0.6}, 'points': array([[-0.19432224, -0.19432224,  0.38864449],
       [-0.22272596, -0.22991067,  0.45982134],
       [-0.26137806, -0.27880326,  0.55760652],
       ...,
       [ 4.3046149 ,  4.60148489,  9.49983978],
       [ 4.48104948,  4.63041779,  9.55957222],
       [ 4.67940703,  4.67940703,  9.66071129]], shape=(4096, 3)), 'colors': array([[1.        , 1.        , 1.        ],
       [1.        , 1.        , 1.        ],
       [1.        , 1.        , 1.        ],
       ...,
       [0.9372549 , 0.9372549 , 0.9372549 ],
       [0.93333333, 0.93333333, 0.93333333],
       [0