In [None]:
import torch
import cv2
import numpy as np
from torchvision.transforms import Compose, Resize, ToTensor, Normalize
from PIL import Image
import plotly.graph_objects as go
from transformers import pipeline
import os

def load_model():
    device = "cuda" if torch.cuda.is_available() else "cpu"
    checkpoint = "depth-anything/Depth-Anything-V2-base-hf"
    pipe = pipeline("depth-estimation", model=checkpoint, device=device)
    return pipe

def estimate_depth(model, img):
    if isinstance(img, np.ndarray):
        img_pil = Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
    else:
        img_pil = img

    predictions = model(img_pil)
    depth_map = predictions["depth"]
    depth_map_np = np.array(depth_map).squeeze()
    depth_map_resized = cv2.resize(depth_map_np, (img_pil.size[0], img_pil.size[1]))
    
    return depth_map_resized

def process_image(model, image_path):
    image = cv2.imread(image_path)
    if image is None:
        return None

    depth_map = estimate_depth(model, image)
    return image, depth_map

def combine_depth_maps(depth_maps):
    # Simple averaging of depth maps
    return np.mean(depth_maps, axis=0)

def create_point_cloud(images, depth_maps, sample_rate=10):
    points = []
    colors = []
    
    for img, depth in zip(images, depth_maps):
        h, w = depth.shape
        for y in range(0, h, sample_rate):
            for x in range(0, w, sample_rate):
                z = depth[y, x]
                if z > 0:
                    points.append([x, y, z])
                    colors.append(img[y, x])
    
    return np.array(points), np.array(colors)

def plot_3d_point_cloud(points, colors):
    fig = go.Figure(data=[go.Scatter3d(
        x=points[:, 0],
        y=points[:, 1],
        z=points[:, 2],
        mode='markers',
        marker=dict(
            size=2,
            color=['rgb({},{},{})'.format(r, g, b) for r, g, b in colors],
            opacity=0.8
        )
    )])

    fig.update_layout(
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z',
            aspectmode='data'
        ),
        margin=dict(l=0, r=0, b=0, t=0)
    )

    return fig

def process_multiple_images(image_paths):
    model = load_model()
    images = []
    depth_maps = []

    for path in image_paths:
        result = process_image(model, path)
        if result is not None:
            image, depth_map = result
            images.append(image)
            depth_maps.append(depth_map)

    combined_depth = combine_depth_maps(depth_maps)
    points, colors = create_point_cloud(images, depth_maps)
    fig = plot_3d_point_cloud(points, colors)
    
    return fig

# Example usage
image_paths = ['path/to/image1.jpg', 'path/to/image2.jpg', 'path/to/image3.jpg']
fig = process_multiple_images(image_paths)
fig.show()

### Load Dataset

In [10]:
import pandas as pd
import matplotlib.pyplot as plt
from PIL import Image
import os

def process_and_save_csv(input_csv_path, output_csv_path, max_products=None):
    # Read the CSV file
    df = pd.read_csv(input_csv_path)
    
    # Extract product_id from meta
    df['product_id'] = df['meta'].str.split(':').str[0]
    
    # Create is_main column
    df['is_main'] = df['meta'].str.endswith(':main')
    
    # Add full path to images
    df['full_path'] = "../data/small/" + df['path']
    
    # Group by product_id and aggregate
    grouped = df.groupby('product_id').agg({
        'full_path': lambda x: '|'.join(x),
        'is_main': lambda x: '|'.join(x.astype(str))
    }).reset_index()
    
    # Separate main image and other images
    def split_images(row):
        paths = row['full_path'].split('|')
        is_mains = row['is_main'].split('|')
        main_image = next((path for path, is_main in zip(paths, is_mains) if is_main == 'True'), None)
        other_images = [path for path, is_main in zip(paths, is_mains) if is_main == 'False']
        return pd.Series({'main_image': main_image, 'other_images': '|'.join(other_images)})

    grouped[['main_image', 'other_images']] = grouped.apply(split_images, axis=1)
    
    # Drop unnecessary columns
    grouped = grouped.drop(columns=['full_path', 'is_main'])
    
    if max_products:
        grouped = grouped.head(max_products)
    
    # Save to CSV
    grouped.to_csv(output_csv_path, index=False)
    
    return grouped

def load_processed_csv(csv_path):
    df = pd.read_csv(csv_path)
    df['other_images'] = df['other_images'].apply(lambda x: x.split('|') if pd.notna(x) else [])
    return df

# Main execution
input_csv_path = '../data/metadata/abo-mvr.csv'  # Replace with your actual input CSV file path
processed_csv_path = './metadata.csv'  # Replace with desired output CSV file path
max_products = None  # Set to a number if you want to limit the products processed

# Check if processed CSV exists
if os.path.exists(processed_csv_path):
    print("Loading pre-processed data...")
    dataset = load_processed_csv(processed_csv_path)
else:
    print("Processing and saving data...")
    dataset = process_and_save_csv(input_csv_path, processed_csv_path, max_products)

# Print information about the dataset
print("Dataset structure:")
print(dataset.head(10))
print("\nDataset info:")
print(dataset.info())

Processing and saving data...
Dataset structure:
   product_id                     main_image                   other_images
0  0615916821  ../data/small/20/208bd166.jpg                               
1  B00004SD6V  ../data/small/44/448ff753.jpg  ../data/small/88/88887e1c.jpg
2  B00004TBMT  ../data/small/6d/6d27b089.jpg  ../data/small/aa/aa52d129.jpg
3  B00004W42A  ../data/small/cf/cf93a1e1.jpg                               
4  B000050415  ../data/small/1c/1c4e48da.jpg                               
5  B000050419  ../data/small/a9/a96e15d9.jpg  ../data/small/fc/fc97fc19.jpg
6  B00005041B  ../data/small/d1/d1aed412.jpg                               
7  B00005041G  ../data/small/b8/b8539919.jpg                               
8  B00005041J  ../data/small/0f/0f1a37c5.jpg                               
9  B00005041K  ../data/small/91/913cba1f.jpg                               

Dataset info:
<class 'pandas.core.frame.DataFrame'>
RangeIndex: 52401 entries, 0 to 52400
Data columns (total 3 co