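# In [None]:  (Jupyter cell prompt left over from a notebook export; kept as a comment so the file parses as Python)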
# Simple 5-Image Depth Scanner - Optimized for Small Batches
# Perfect for scanning 5 images quickly and efficiently

import os
import time
import numpy as np
import cv2
import torch
import torch.nn.functional as F
from PIL import Image
import matplotlib.pyplot as plt
from pathlib import Path
import warnings
warnings.filterwarnings('ignore')

class Simple5ImageScanner:
    """Small-batch monocular depth scanner built on the MiDaS ``DPT_Hybrid`` model.

    For each image the scanner resizes and zero-pads it to a square, runs the
    depth model, back-projects a sub-sampled grid of pixels into a coloured
    point cloud with a nominal pinhole camera, and applies a coarse
    bounding-box shape heuristic. Per-image clouds are also shifted along x
    and merged into one combined cloud.

    NOTE(review): the tensor fed to the model is only scaled to [0, 1]; the
    transforms published alongside MiDaS on PyTorch Hub presumably also apply
    mean/std normalisation -- confirm whether skipping it is intentional.
    NOTE(review): MiDaS is documented as predicting relative *inverse* depth,
    so larger ``z`` here is presumably nearer the camera -- confirm before
    treating ``z`` as a distance.
    """

    # Multiplier applied to the normalised [0, 1] depth to obtain z.
    DEPTH_SCALE = 10
    # x-offset between consecutive images in the combined cloud. A single
    # cloud spans at most DEPTH_SCALE units in x, so 15 keeps them apart.
    IMAGE_SPACING = 15

    def __init__(self, input_size=384, use_metal=True):
        """
        Simple scanner optimized for ~5 images

        Args:
            input_size: Side length (pixels) of the square model input;
                384 is a good quality/speed balance
            use_metal: Use the Apple Metal (MPS) backend when available
        """
        self.input_size = input_size
        self.device = self._setup_device(use_metal)
        self.model = self._load_model()

        print("🚀 Simple Scanner Ready:")
        print(f"   Input size: {input_size}x{input_size}")
        print(f"   Device: {self.device}")

    def _setup_device(self, use_metal):
        """Return the MPS device when requested and available, else the CPU."""
        # Older torch builds have no ``torch.backends.mps`` at all.
        mps_backend = getattr(torch.backends, 'mps', None)
        if use_metal and mps_backend is not None and mps_backend.is_available():
            return torch.device('mps')
        return torch.device('cpu')

    def _load_model(self):
        """Fetch MiDaS ``DPT_Hybrid`` via torch.hub, move it to the device, set eval mode."""
        model = torch.hub.load('intel-isl/MiDaS', 'DPT_Hybrid', pretrained=True)
        model.to(self.device)
        model.eval()
        return model

    def scan_images(self, image_paths, show_preview=True, save_results=False):
        """
        Scan a handful of images and return per-image and combined results

        Args:
            image_paths: Iterable of image file paths (a single path is
                also accepted and treated as a one-element list)
            show_preview: Show matplotlib preview of results
            save_results: Save depth maps and point clouds

        Returns:
            Dict with keys ``individual_results`` (one dict per image),
            ``combined_points`` / ``combined_colors`` (``(N, 3)`` arrays,
            ``(0, 3)`` when nothing was produced), ``processing_time``
            (seconds) and ``summary``.

        Raises:
            ValueError: If an image cannot be loaded.
        """
        # A bare path would otherwise be iterated character by character.
        if isinstance(image_paths, (str, os.PathLike)):
            image_paths = [image_paths]
        image_paths = list(image_paths)
        n_images = len(image_paths)

        if n_images > 10:
            print("⚠️  Warning: This scanner is optimized for ~5 images. Consider using batch scanner for more.")

        print(f"📸 Scanning {n_images} images...")
        start_time = time.time()

        results = []
        all_points = []
        all_colors = []

        for i, img_path in enumerate(image_paths):
            print(f"   Processing image {i+1}/{n_images}: {Path(img_path).name}")

            img_tensor, rgb_img = self._preprocess_image(img_path)
            depth_map = self._estimate_depth(img_tensor)
            points, colors = self._create_point_cloud(depth_map, rgb_img)
            analysis = self._analyze_shape(points)

            results.append({
                'image_path': str(img_path),
                'analysis': analysis,
                'points': points,
                'colors': colors,
                'depth_map': depth_map,
                'point_count': len(points)
            })

            if len(points) > 0:
                # Shift a copy along x so clouds sit side by side; the
                # per-image result keeps its unshifted coordinates.
                offset_points = points.copy()
                offset_points[:, 0] += i * self.IMAGE_SPACING
                all_points.append(offset_points)
                all_colors.append(colors)

        if all_points:
            combined_points = np.vstack(all_points)
            combined_colors = np.vstack(all_colors)
        else:
            # Keep the (N, 3) layout so callers can still index columns.
            combined_points = np.empty((0, 3))
            combined_colors = np.empty((0, 3))

        total_time = time.time() - start_time

        if show_preview:
            self._show_results(results, combined_points, combined_colors)

        if save_results:
            self._save_results(results, combined_points, combined_colors)

        # Guard the average: an empty path list must not divide by zero.
        average_time = total_time / len(results) if results else 0.0

        print("\n🎯 Scan Complete:")
        print(f"   Images processed: {len(results)}")
        print(f"   Total processing time: {total_time:.2f}s")
        print(f"   Average per image: {average_time:.2f}s")
        print(f"   Total points: {len(combined_points)}")

        return {
            'individual_results': results,
            'combined_points': combined_points,
            'combined_colors': combined_colors,
            'processing_time': total_time,
            'summary': self._get_summary(results)
        }

    def _preprocess_image(self, image_path):
        """Load an image and turn it into a square model input.

        Returns:
            Tuple ``(img_tensor, rgb_img)``: a ``(1, 3, S, S)`` float tensor
            scaled to [0, 1] on ``self.device`` and the padded ``(S, S, 3)``
            uint8 RGB array it was built from, where ``S`` is ``input_size``.

        Raises:
            ValueError: If OpenCV cannot read the file.
        """
        img = cv2.imread(str(image_path))
        if img is None:
            raise ValueError(f"Could not load image: {image_path}")

        rgb_img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)

        # Scale so the longer side equals input_size, keeping aspect ratio.
        h, w = rgb_img.shape[:2]
        scale = self.input_size / max(h, w)
        # Clamp to 1 px: a very elongated image would otherwise truncate to
        # a zero-sized side, which cv2.resize rejects.
        new_h = max(1, int(h * scale))
        new_w = max(1, int(w * scale))
        rgb_img = cv2.resize(rgb_img, (new_w, new_h))

        # Zero-pad on the bottom/right to reach a square.
        pad_h = self.input_size - new_h
        pad_w = self.input_size - new_w
        rgb_img = np.pad(rgb_img, ((0, pad_h), (0, pad_w), (0, 0)), mode='constant')

        img_tensor = rgb_img.astype(np.float32) / 255.0
        img_tensor = torch.from_numpy(img_tensor).permute(2, 0, 1).unsqueeze(0)

        return img_tensor.to(self.device), rgb_img

    def _estimate_depth(self, img_tensor):
        """Run the model and return its output min-max normalised to [0, 1].

        A constant (or non-finite range) prediction carries no relative depth
        information, so it yields an all-zero map instead of the NaNs a 0/0
        normalisation would produce.
        """
        with torch.no_grad():
            depth = self.model(img_tensor)
        depth = depth.squeeze().cpu().numpy()

        depth_min = depth.min()
        depth_range = depth.max() - depth_min
        if not np.isfinite(depth_range) or depth_range <= 0:
            return np.zeros_like(depth)
        return (depth - depth_min) / depth_range

    def _create_point_cloud(self, depth_map, rgb_img, sample_rate=3):
        """Back-project every ``sample_rate``-th pixel into a coloured 3D point.

        Uses a nominal pinhole camera: focal length ``max(h, w)`` and the
        principal point at the image centre.

        Args:
            depth_map: 2D array of normalised depth values
            rgb_img: Colour image indexed with the same pixel coordinates
            sample_rate: Pixel stride in both directions

        Returns:
            Tuple ``(points, colors)`` of ``(N, 3)`` arrays; colours are
            scaled to [0, 1]. Points with non-finite coordinates are dropped.
        """
        h, w = depth_map.shape

        y_coords, x_coords = np.meshgrid(
            np.arange(0, h, sample_rate),
            np.arange(0, w, sample_rate),
            indexing='ij'
        )

        depths = depth_map[y_coords, x_coords]
        colors = rgb_img[y_coords, x_coords]

        focal_length = max(h, w)
        cx, cy = w // 2, h // 2

        z = depths * self.DEPTH_SCALE
        x = (x_coords - cx) * z / focal_length
        y = (y_coords - cy) * z / focal_length

        points = np.stack([x.flatten(), y.flatten(), z.flatten()], axis=1)
        colors = colors.reshape(-1, 3) / 255.0

        # Drop NaN and +/-inf alike; neither is a usable coordinate.
        valid_mask = np.isfinite(points).all(axis=1)
        return points[valid_mask], colors[valid_mask]

    def _analyze_shape(self, points):
        """Classify a point cloud with a bounding-box / depth-spread heuristic.

        Returns:
            Dict with ``type`` (``unknown``, ``flat_surface``,
            ``elongated_object``, ``3d_object`` or ``rounded_object``) and a
            fixed ``confidence`` for that type. Fewer than 10 points is
            reported as ``unknown``.
        """
        if len(points) < 10:
            return {"type": "unknown", "confidence": 0.0}

        bbox_size = np.max(points, axis=0) - np.min(points, axis=0)
        z_std = np.std(points[:, 2])

        # Fall back to 1.0 when an extent is zero to avoid dividing by it.
        aspect_ratio_xy = bbox_size[0] / bbox_size[1] if bbox_size[1] > 0 else 1.0
        widest_xy = max(bbox_size[0], bbox_size[1])
        aspect_ratio_z = bbox_size[2] / widest_xy if widest_xy > 0 else 1.0

        if z_std < 0.5 and aspect_ratio_z < 0.1:
            return {"type": "flat_surface", "confidence": 0.8}
        if aspect_ratio_xy > 3.0 or aspect_ratio_xy < 0.3:
            return {"type": "elongated_object", "confidence": 0.7}
        if z_std > 1.0:
            return {"type": "3d_object", "confidence": 0.6}
        return {"type": "rounded_object", "confidence": 0.5}

    def _show_results(self, results, combined_points, combined_colors):
        """Show a matplotlib preview: one four-panel row per image, then the combined cloud."""
        n_images = len(results)

        # A zero-row figure has zero height, which matplotlib rejects.
        if n_images > 0:
            plt.figure(figsize=(20, 4 * n_images))

            for i, result in enumerate(results):
                points = result['points']

                # Original image (re-read from disk; results keep only the path)
                plt.subplot(n_images, 4, i*4 + 1)
                orig_img = cv2.imread(result['image_path'])
                if orig_img is not None:
                    # The file may have vanished since the scan; skip the
                    # picture rather than crash in cvtColor.
                    plt.imshow(cv2.cvtColor(orig_img, cv2.COLOR_BGR2RGB))
                plt.title(f"Image {i+1}: {Path(result['image_path']).name}")
                plt.axis('off')

                # Depth map
                plt.subplot(n_images, 4, i*4 + 2)
                plt.imshow(result['depth_map'], cmap='plasma')
                plt.title("Depth Map")
                plt.axis('off')

                # Point cloud top view (x vs y, coloured by z)
                plt.subplot(n_images, 4, i*4 + 3)
                if len(points) > 0:
                    plt.scatter(points[:, 0], points[:, 1],
                                c=points[:, 2], cmap='viridis', s=1)
                    plt.title(f"Point Cloud\n{result['analysis']['type']}")
                plt.axis('equal')

                # Side view (x vs z, coloured by y)
                plt.subplot(n_images, 4, i*4 + 4)
                if len(points) > 0:
                    plt.scatter(points[:, 0], points[:, 2],
                                c=points[:, 1], cmap='viridis', s=1)
                    plt.title(f"Side View\n{result['point_count']} points")
                plt.axis('equal')

            plt.tight_layout()
            plt.show()

        if len(combined_points) > 0:
            plt.figure(figsize=(15, 5))

            # Combined top view, using the sampled image colours
            plt.subplot(1, 3, 1)
            plt.scatter(combined_points[:, 0], combined_points[:, 1],
                        c=combined_colors, s=1)
            plt.title('Combined Scan - Top View')
            plt.axis('equal')

            # Combined side view
            plt.subplot(1, 3, 2)
            plt.scatter(combined_points[:, 0], combined_points[:, 2],
                        c=combined_points[:, 1], cmap='viridis', s=1)
            plt.title('Combined Scan - Side View')
            plt.axis('equal')

            # Combined front view (y vs z, coloured by x)
            plt.subplot(1, 3, 3)
            plt.scatter(combined_points[:, 1], combined_points[:, 2],
                        c=combined_points[:, 0], cmap='plasma', s=1)
            plt.title('Combined Scan - Front View')
            plt.axis('equal')

            plt.tight_layout()
            plt.show()

    def _save_results(self, results, combined_points, combined_colors):
        """Write depth-map PNGs and PLY point clouds into ``scan_results/`` (relative to the cwd)."""
        os.makedirs("scan_results", exist_ok=True)

        for i, result in enumerate(results):
            depth_path = f"scan_results/depth_map_{i+1}.png"
            # Clamp before the uint8 cast so stray values cannot wrap around.
            depth_normalized = np.clip(result['depth_map'] * 255, 0, 255).astype(np.uint8)
            cv2.imwrite(depth_path, depth_normalized)

            if len(result['points']) > 0:
                ply_path = f"scan_results/pointcloud_{i+1}.ply"
                self._save_ply(result['points'], result['colors'], ply_path)

        if len(combined_points) > 0:
            combined_ply_path = "scan_results/combined_scan.ply"
            self._save_ply(combined_points, combined_colors, combined_ply_path)

        print("💾 Results saved to 'scan_results/' folder")

    def _save_ply(self, points, colors, filename):
        """Save a point cloud as an ASCII PLY file.

        Args:
            points: ``(N, 3)`` array of x, y, z coordinates
            colors: ``(N, 3)`` array of RGB values in [0, 1]
            filename: Destination path

        Raises:
            ValueError: If ``points`` and ``colors`` differ in length.
        """
        points = np.asarray(points, dtype=float).reshape(-1, 3)
        colors = np.asarray(colors, dtype=float).reshape(-1, 3)
        if len(points) != len(colors):
            raise ValueError(
                f"points and colors must have the same length "
                f"({len(points)} != {len(colors)})"
            )

        # Round (not truncate) and clamp so each channel is a valid uchar.
        rgb = np.clip(np.rint(colors * 255), 0, 255).astype(np.uint8)

        header = (
            "ply\n"
            "format ascii 1.0\n"
            f"element vertex {len(points)}\n"
            "property float x\n"
            "property float y\n"
            "property float z\n"
            "property uchar red\n"
            "property uchar green\n"
            "property uchar blue\n"
            "end_header\n"
        )

        with open(filename, 'w', encoding='ascii') as f:
            f.write(header)
            # One batched write instead of one write call per point.
            f.writelines(
                f"{x} {y} {z} {r} {g} {b}\n"
                for (x, y, z), (r, g, b) in zip(points.tolist(), rgb.tolist())
            )

    def _get_summary(self, results):
        """Aggregate per-image results into counts per shape type and point totals."""
        shape_types = {}
        total_points = 0

        for result in results:
            shape_type = result['analysis']['type']
            shape_types[shape_type] = shape_types.get(shape_type, 0) + 1
            total_points += result['point_count']

        return {
            'total_images': len(results),
            'shape_distribution': shape_types,
            'total_points': total_points,
            'average_points_per_image': total_points / len(results) if results else 0
        }

# Quick usage functions
def scan_5_images(image_paths, show_preview=True, save_results=False):
    """
    Convenience wrapper: build a 384-px scanner and scan the given images

    Args:
        image_paths: List of image file paths
        show_preview: Show matplotlib preview of results
        save_results: Save depth maps and point clouds

    Returns:
        Whatever ``Simple5ImageScanner.scan_images`` returns.

    Usage:
        # Scan images and show results
        results = scan_5_images(['img1.jpg', 'img2.jpg', 'img3.jpg', 'img4.jpg', 'img5.jpg'])

        # Scan and save results
        results = scan_5_images(image_paths, save_results=True)
    """
    return Simple5ImageScanner(input_size=384).scan_images(
        image_paths,
        show_preview=show_preview,
        save_results=save_results,
    )

def scan_folder(folder_path, max_images=5, show_preview=True, save_results=False):
    """
    Scan the first ``max_images`` images (sorted by path) from a folder

    Only regular files directly inside the folder are considered, and the
    extension match is case-insensitive.

    Args:
        folder_path: Folder to look in (not searched recursively)
        max_images: Maximum number of images to scan
        show_preview: Show matplotlib preview of results
        save_results: Save depth maps and point clouds

    Returns:
        The result of ``Simple5ImageScanner.scan_images``, or ``None`` when
        the folder is missing or contains no images.

    Usage:
        results = scan_folder('my_images_folder/')
    """
    image_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.tiff'}

    folder = Path(folder_path)
    # One directory pass with a lower-cased suffix check: catches mixed-case
    # extensions and cannot list a file twice, unlike separate lower/upper
    # globs. A missing folder is treated the same as an empty one.
    candidates = folder.iterdir() if folder.is_dir() else ()
    image_paths = sorted(
        path for path in candidates
        if path.is_file() and path.suffix.lower() in image_extensions
    )[:max_images]

    if not image_paths:
        print(f"❌ No images found in {folder_path}")
        return None

    print(f"📁 Found {len(image_paths)} images in {folder_path}")

    # Build the scanner only once there is work, since it loads the model.
    scanner = Simple5ImageScanner(input_size=384)
    return scanner.scan_images(image_paths, show_preview=show_preview, save_results=save_results)

def quick_test():
    """Generate five synthetic images, scan them, and return the scan results.

    Each image is random noise with a different filled shape on top. The
    images live in a temporary directory that is removed when the function
    exits -- even if the scan raises -- so nothing in the current directory
    is overwritten or deleted. The scan runs with preview and saving on.

    Returns:
        The result of ``scan_5_images`` for the generated images.

    Raises:
        OSError: If a test image cannot be written.
    """
    import tempfile

    with tempfile.TemporaryDirectory() as tmp_dir:
        test_images = []
        for i in range(5):
            # randint's upper bound is exclusive; 256 covers the uint8 range.
            img = np.random.randint(0, 256, (400, 400, 3), dtype=np.uint8)

            # Add different shapes
            if i == 0:
                cv2.circle(img, (200, 200), 80, (255, 0, 0), -1)
            elif i == 1:
                cv2.rectangle(img, (100, 100), (300, 200), (0, 255, 0), -1)
            elif i == 2:
                cv2.ellipse(img, (200, 200), (100, 50), 45, 0, 360, (0, 0, 255), -1)
            elif i == 3:
                # Triangle
                pts = np.array([[200, 50], [100, 250], [300, 250]], np.int32)
                cv2.fillPoly(img, [pts], (255, 255, 0))
            else:
                # Complex shape
                cv2.circle(img, (150, 150), 50, (255, 0, 255), -1)
                cv2.rectangle(img, (200, 200), (350, 350), (0, 255, 255), -1)

            filename = os.path.join(tmp_dir, f"test_image_{i+1}.jpg")
            # Report a failed write here rather than as a confusing load
            # error inside the scan.
            if not cv2.imwrite(filename, img):
                raise OSError(f"Could not write test image: {filename}")
            test_images.append(filename)

        # Scan inside the ``with`` block: the preview re-reads the files.
        return scan_5_images(test_images, show_preview=True, save_results=True)

# Usage banner shown when the cell/module is executed.
print("\n".join((
    "🎯 Simple 5-Image Scanner Ready!",
    "📝 Use: scan_5_images(['img1.jpg', 'img2.jpg', ...]) to scan specific images",
    "📁 Use: scan_folder('folder_path/') to scan first 5 images from folder",
    "🧪 Use: quick_test() to test with synthetic images",
)))