In [None]:
import torch
import cv2
import numpy as np
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from PIL import Image
import plotly.graph_objects as go
from transformers import pipeline
import os

def load_model():
    """Build the Depth Anything V2 (base) depth-estimation pipeline.

    The pipeline runs on CUDA when PyTorch reports a GPU, otherwise on CPU.

    Returns:
        A Hugging Face ``pipeline("depth-estimation", ...)`` callable.
    """
    model_id = "depth-anything/Depth-Anything-V2-base-hf"
    target_device = "cpu"
    if torch.cuda.is_available():
        target_device = "cuda"
    return pipeline("depth-estimation", model=model_id, device=target_device)

def estimate_depth(model, img):
    """Run the depth pipeline on one image and return its depth as a numpy array.

    Args:
        model: a depth-estimation pipeline, such as the one from ``load_model()``.
            It is called with a PIL image and must return a mapping with a
            ``"depth"`` entry.
        img: either a numpy array in OpenCV BGR order (converted to an RGB PIL
            image here) or an already-prepared PIL image.

    Returns:
        The depth map as a numpy array resized with ``cv2.resize`` to the
        input image's width x height.
    """
    # OpenCV arrays are BGR; the pipeline is fed an RGB PIL image.
    pil_image = (
        Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        if isinstance(img, np.ndarray)
        else img
    )

    raw_depth = model(pil_image)["depth"]
    depth = np.array(raw_depth).squeeze()

    # PIL reports size as (width, height), which is the dsize order cv2 wants.
    width, height = pil_image.size
    return cv2.resize(depth, (width, height))

def process_image(model, image_path):
    """Load an image from disk and estimate its depth.

    Args:
        model: depth-estimation pipeline passed through to ``estimate_depth``.
        image_path: file path handed to ``cv2.imread``.

    Returns:
        ``(image, depth_map)``, where image is the BGR array from ``cv2.imread``,
        or ``None`` when ``cv2.imread`` could not read the file.
    """
    bgr = cv2.imread(image_path)
    if bgr is not None:
        return bgr, estimate_depth(model, bgr)
    return None

def combine_depth_maps(depth_maps):
    """Average equally-shaped depth maps pixel by pixel.

    Args:
        depth_maps: non-empty sequence of array-likes that all have the same shape.

    Returns:
        A float array of that shape holding the per-pixel mean.

    Raises:
        ValueError: if ``depth_maps`` is empty or the maps differ in shape.
    """
    # np.stack rejects an empty sequence and mismatched shapes with a clear
    # ValueError. np.mean on an empty list would instead silently return nan.
    stacked = np.stack([np.asarray(depth) for depth in depth_maps])
    return stacked.mean(axis=0)

def create_point_cloud(images, depth_maps, sample_rate=10, *, bgr=True):
    """Sample coloured 3-D points from paired images and depth maps.

    Every ``sample_rate``-th pixel along each axis whose depth is strictly
    positive becomes a point ``[x, y, depth]`` in pixel coordinates. Points
    from all image/depth pairs are concatenated into one cloud in the same
    pixel frame; no cross-view alignment is performed.

    Args:
        images: sequence of H x W x 3 images. Each is indexed at the same pixel
            positions as its paired depth map, so shapes must agree.
        depth_maps: sequence of H x W depth arrays, paired with ``images`` by zip.
        sample_rate: pixel stride used for subsampling; must be >= 1.
        bgr: if True (the default, matching ``cv2.imread`` output), channels are
            reversed so the returned colours are RGB. Pass False for images
            that are already RGB.

    Returns:
        ``(points, colors)``. ``points`` is an (N, 3) array of ``[x, y, z]``.
        ``colors`` is an (N, 3) array of RGB values. Both have shape (0, 3)
        when nothing is sampled.

    Raises:
        ValueError: if ``sample_rate`` is less than 1.
    """
    if sample_rate < 1:
        raise ValueError(f"sample_rate must be >= 1, got {sample_rate}")

    point_chunks = []
    color_chunks = []
    for img, depth in zip(images, depth_maps):
        h, w = depth.shape
        # Row-major grid of sampled coordinates. Boolean masking preserves this
        # order, i.e. the same order as nested `for y: for x:` loops.
        ys, xs = np.mgrid[0:h:sample_rate, 0:w:sample_rate]
        z = depth[::sample_rate, ::sample_rate]
        keep = z > 0
        xs, ys, z = xs[keep], ys[keep], z[keep]

        point_chunks.append(np.column_stack((xs, ys, z)))
        pixel_colors = img[ys, xs]
        # Colours are consumed as 'rgb(r,g,b)' strings, so BGR input must be flipped.
        color_chunks.append(pixel_colors[:, ::-1] if bgr else pixel_colors)

    if not point_chunks:
        # Keep a 2-D shape so callers can still slice points[:, 0] safely.
        return np.empty((0, 3)), np.empty((0, 3), dtype=np.uint8)
    return np.concatenate(point_chunks), np.concatenate(color_chunks)

def plot_3d_point_cloud(points, colors):
    """Render a coloured point cloud as an interactive Plotly 3-D scatter.

    Args:
        points: (N, 3) array; its columns are used as the X, Y and Z coordinates.
        colors: iterable of N ``(r, g, b)`` triples, formatted as
            ``'rgb(r,g,b)'`` marker colours.

    Returns:
        A ``plotly.graph_objects.Figure`` with data-proportional axes and no margins.
    """
    marker_colors = ['rgb({},{},{})'.format(r, g, b) for r, g, b in colors]

    scatter = go.Scatter3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        mode='markers',
        marker={'size': 2, 'color': marker_colors, 'opacity': 0.8},
    )
    figure = go.Figure(data=[scatter])

    figure.update_layout(
        scene={
            'xaxis_title': 'X',
            'yaxis_title': 'Y',
            'zaxis_title': 'Z',
            'aspectmode': 'data',
        },
        margin={'l': 0, 'r': 0, 'b': 0, 't': 0},
    )
    return figure

def process_multiple_images(image_paths):
    """Estimate depth for each readable image and plot all of them as one point cloud.

    Unreadable paths (``process_image`` returns None) are reported and skipped.

    Args:
        image_paths: iterable of image file paths.

    Returns:
        The Plotly figure from ``plot_3d_point_cloud``.

    Raises:
        ValueError: if none of the paths could be read.
    """
    model = load_model()
    images = []
    depth_maps = []

    for path in image_paths:
        result = process_image(model, path)
        if result is None:
            print(f"Skipping unreadable image: {path}")
            continue
        image, depth_map = result
        images.append(image)
        depth_maps.append(depth_map)

    if not images:
        raise ValueError("None of the given image paths could be read.")

    # The pixel-wise average from combine_depth_maps was previously computed
    # here but never used. It also fails whenever the images differ in size,
    # so it is no longer computed. Each image's points are plotted directly.
    points, colors = create_point_cloud(images, depth_maps)
    return plot_3d_point_cloud(points, colors)

# Example usage: point these placeholders at real image files before running.
image_paths = [
    'path/to/image1.jpg',
    'path/to/image2.jpg',
    'path/to/image3.jpg',
]
fig = process_multiple_images(image_paths)
fig.show()

### Load Dataset

In [None]:
import pandas as pd
import matplotlib.pyplot as plt
from collections import defaultdict
import os
from PIL import Image

def process_csv(csv_path, image_root="../data/small"):
    """Group the image paths listed in a metadata CSV by product ID.

    For each row, the text of ``meta`` before the first ':' is the product ID.
    Rows whose ``meta`` ends with ':main' set that product's main image. Every
    other row is appended to its ``other_images`` list in file order. Image
    paths are ``os.path.join(image_root, path)``.

    Args:
        csv_path: CSV file to read; must contain ``meta`` and ``path`` columns.
        image_root: directory the ``path`` column is relative to. The default
            keeps the previously hard-coded location.

    Returns:
        A ``defaultdict`` mapping product ID to
        ``{'main_image': str | None, 'other_images': list[str]}``.
        Looking up an unknown ID creates an empty entry.
    """
    df = pd.read_csv(csv_path)

    dataset = defaultdict(lambda: {'main_image': None, 'other_images': []})

    # Iterating the two columns directly avoids building a pandas Series per
    # row, as DataFrame.iterrows does, which is much slower on large files.
    for meta, rel_path in zip(df['meta'], df['path']):
        product_id = meta.split(':')[0]
        image_path = os.path.join(image_root, rel_path)
        entry = dataset[product_id]
        if meta.endswith(':main'):
            entry['main_image'] = image_path
        else:
            entry['other_images'].append(image_path)

    return dataset

def plot_images(product_id, image_paths):
    """Show a product's images in a grid with three columns.

    The entry at index 0 is titled 'Main' and the entry at index i > 0 is
    titled 'Other {i}'. ``None`` entries (e.g. a product with no main image)
    are skipped but keep their index for labelling.

    Args:
        product_id: shown in the figure title.
        image_paths: list of image file paths; the first is treated as the main image.

    Raises:
        ValueError: if ``image_paths`` contains no non-None path.
    """
    # Label by position, not by comparing paths, so a later duplicate of the
    # main path is not mislabelled 'Main'.
    entries = [
        ('Main' if i == 0 else f'Other {i}', path)
        for i, path in enumerate(image_paths)
        if path is not None
    ]
    if not entries:
        raise ValueError(f"No images to plot for product ID: {product_id}")

    cols = 3
    rows = (len(entries) + cols - 1) // cols

    # squeeze=False always returns a 2-D axes array, so one row needs no special case.
    fig, axes = plt.subplots(rows, cols, figsize=(15, 5 * rows), squeeze=False)
    fig.suptitle(f"Images for Product ID: {product_id}", fontsize=16)

    flat_axes = axes.ravel()
    for ax, (title, path) in zip(flat_axes, entries):
        # Copy the pixels into memory so the file is closed immediately.
        with Image.open(path) as img:
            ax.imshow(img.copy())
        ax.set_title(title)

    # Hide the axes of every cell, including the unused trailing ones.
    for ax in flat_axes:
        ax.axis('off')

    plt.tight_layout()
    plt.show()

# Main execution
csv_path = '../data/metadata/abo-mvr.csv'  # Replace with your actual CSV file path
dataset = process_csv(csv_path)

# Print the first few items in the dataset
print("Dataset structure:")
for product_id, data in list(dataset.items())[:3]:  # Print first 3 items
    print(f"Product ID: {product_id}")
    print(f"  Main image: {data['main_image']}")
    print(f"  Other images: {data['other_images'][:2]}...")  # Print first 2 other images
    print()

# Plot images for one product ID (for testing).
# Pick the first product that has a ':main' image. A product without one
# would put None at the head of the path list, and the first product may be
# such a case. Guard against an empty dataset too, where next() would raise
# StopIteration.
test_product_id = next(
    (pid for pid, data in dataset.items() if data['main_image'] is not None),
    None,
)
if test_product_id is None:
    print("No product with a main image found; nothing to plot.")
else:
    test_image_paths = [dataset[test_product_id]['main_image']] + dataset[test_product_id]['other_images']
    plot_images(test_product_id, test_image_paths)