# In [15]:
import pandas as pd
import numpy as np
import cv2
from PIL import Image
import plotly.graph_objects as go
from plotly.subplots import make_subplots
from transformers import pipeline
import torch
import os

def load_processed_csv(csv_path):
    """Read the processed CSV and expand its ``other_images`` column.

    Args:
        csv_path: Anything ``pandas.read_csv`` accepts (path or file-like).

    Returns:
        The loaded DataFrame, with every ``other_images`` cell replaced by a
        list: non-missing cells are split on ``'|'``, missing cells become
        an empty list.
    """
    def _to_path_list(cell):
        # Missing cells (NaN/None) carry no extra images.
        if pd.notna(cell):
            return cell.split('|')
        return []

    frame = pd.read_csv(csv_path)
    frame['other_images'] = frame['other_images'].apply(_to_path_list)
    return frame

def load_model():
    """Create the depth-estimation pipeline used by the rest of this file.

    The pipeline is placed on ``"cuda"`` when ``torch.cuda.is_available()``
    reports a GPU, otherwise on ``"cpu"``.

    Returns:
        The object returned by ``transformers.pipeline`` for the
        ``"depth-estimation"`` task and the Depth-Anything-V2 checkpoint.
    """
    gpu_available = torch.cuda.is_available()
    return pipeline(
        "depth-estimation",
        model="depth-anything/Depth-Anything-V2-base-hf",
        device="cuda" if gpu_available else "cpu",
    )

def estimate_depth(model, img):
    """Run ``model`` on one image and return its depth map as an array.

    Args:
        model: Callable that takes the image and returns a mapping with a
            ``"depth"`` entry.
        img: Either a NumPy array, which is converted from BGR to RGB and
            wrapped in a PIL image, or any other object, which is passed to
            the model unchanged (it must expose ``.size``; presumably a PIL
            image).

    Returns:
        The ``"depth"`` output converted to a squeezed NumPy array and
        resized with ``cv2.resize`` to ``(size[0], size[1])`` of the image
        that was fed to the model.
    """
    # NumPy input is treated as an OpenCV-style BGR frame.
    model_input = (
        Image.fromarray(cv2.cvtColor(img, cv2.COLOR_BGR2RGB))
        if isinstance(img, np.ndarray)
        else img
    )

    raw_depth = np.array(model(model_input)["depth"]).squeeze()
    target_size = (model_input.size[0], model_input.size[1])
    return cv2.resize(raw_depth, target_size)

def process_image(model, image_path):
    """Load one image from disk and estimate its depth.

    Args:
        model: Depth model forwarded to ``estimate_depth``.
        image_path: Path handed to ``cv2.imread``.

    Returns:
        ``None`` when ``cv2.imread`` yields ``None`` for the path; otherwise
        the tuple ``(image, depth_map)``.
    """
    loaded = cv2.imread(image_path)
    # imread signals an unreadable file by returning None instead of raising.
    return None if loaded is None else (loaded, estimate_depth(model, loaded))

def create_point_cloud(images, depth_maps, sample_rate=10):
    """Sample pixels from image/depth pairs into one combined point cloud.

    Every ``sample_rate``-th row and column of each depth map is visited and
    pixels whose depth is greater than zero are kept. Each kept pixel yields
    a point ``[x, y, depth]`` and the image value at that pixel as its
    colour. Points from all pairs are concatenated in input order; within a
    pair they are ordered row by row, left to right.

    Args:
        images: Sequence of image arrays indexable as ``img[y, x]``.
        depth_maps: Sequence of 2-D depth arrays, paired with ``images``.
        sample_rate: Step, in pixels, between sampled rows and columns.

    Returns:
        ``(points, colors)`` as NumPy arrays. ``points`` always has three
        columns; when no pixel qualifies (or no pairs are given) both arrays
        have zero rows rather than being flat, so column slicing such as
        ``points[:, 0]`` still works.
    """
    point_chunks = []
    color_chunks = []

    for img, depth in zip(images, depth_maps):
        h, w = depth.shape
        sampled_depth = depth[::sample_rate, ::sample_rate]
        keep = sampled_depth > 0
        # nonzero() walks the mask in row-major order, i.e. the same order as
        # a y-outer / x-inner scan over the sampled grid.
        rows, cols = np.nonzero(keep)
        point_chunks.append(
            np.column_stack(
                (cols * sample_rate, rows * sample_rate, sampled_depth[keep])
            )
        )
        # Restrict the image to the depth map's extent before sampling so the
        # mask lines up with the same pixel coordinates.
        color_chunks.append(np.asarray(img)[:h:sample_rate, :w:sample_rate][keep])

    if not point_chunks:
        # No input pairs at all: still hand back correctly shaped empties.
        return np.empty((0, 3)), np.empty((0, 3), dtype=np.uint8)

    return np.concatenate(point_chunks), np.concatenate(color_chunks)

def plot_3d_point_cloud_and_images(points, colors, images, product_id):
    """Build a two-panel Plotly figure: 3D point cloud plus source images.

    Args:
        points: Array-like whose columns 0, 1 and 2 are plotted as x, y, z.
        colors: One colour triple per point, in OpenCV's B, G, R channel
            order (the same order as ``images``, from which they are
            sampled elsewhere in this file).
        images: BGR image arrays; each is converted to RGB for display.
        product_id: Value interpolated into the figure title.

    Returns:
        The assembled Plotly figure (not yet shown).
    """
    points = np.asarray(points)
    if points.size == 0:
        # An empty cloud may arrive as a flat array; give it three columns so
        # the column slices below do not raise.
        points = points.reshape(0, 3)

    fig = make_subplots(
        rows=1, cols=2,
        column_widths=[0.7, 0.3],
        specs=[[{"type": "scatter3d"}, {"type": "image"}]],
        subplot_titles=("3D Point Cloud", "Original Images")
    )

    # Colours come in B, G, R order while CSS-style rgb() expects R, G, B, so
    # the channels are swapped while formatting.
    marker_colors = [f'rgb({r},{g},{b})' for b, g, r in colors]

    # 3D Point Cloud
    fig.add_trace(
        go.Scatter3d(
            x=points[:, 0],
            y=points[:, 1],
            z=points[:, 2],
            mode='markers',
            marker=dict(
                size=2,
                color=marker_colors,
                opacity=0.8
            )
        ),
        row=1, col=1
    )

    # Original Images
    # NOTE(review): every image is added to the same subplot cell, so with
    # several images later traces are likely drawn over earlier ones.
    for img in images:
        img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
        fig.add_trace(
            go.Image(z=img_rgb),
            row=1, col=2
        )

    fig.update_layout(
        title=f"3D Point Cloud and Original Images for Product ID: {product_id}",
        scene=dict(
            xaxis_title='X',
            yaxis_title='Y',
            zaxis_title='Z',
            aspectmode='data'
        ),
        height=800,
        width=1400
    )

    return fig

def process_product_images(row, model):
    """Turn one product row into a point-cloud figure.

    Args:
        row: Mapping/Series with ``'product_id'``, ``'main_image'`` (may be
            missing/NaN) and ``'other_images'`` (a list of paths).
        model: Depth model forwarded to ``process_image``.

    Returns:
        The figure from ``plot_3d_point_cloud_and_images`` built from every
        image that ``process_image`` could handle, or ``None`` when no image
        produced a result.
    """
    product_id = row['product_id']

    main_image = row['main_image']
    leading = [main_image] if pd.notna(main_image) else []
    image_paths = leading + row['other_images']

    # Keep only the paths for which loading and depth estimation succeeded.
    outcomes = (process_image(model, path) for path in image_paths)
    usable = [outcome for outcome in outcomes if outcome is not None]
    if not usable:
        return None

    images = [image for image, _ in usable]
    depth_maps = [depth_map for _, depth_map in usable]

    points, colors = create_point_cloud(images, depth_maps)
    return plot_3d_point_cloud_and_images(points, colors, images, product_id)

# Main execution
csv_path = '../metadata.csv'  # Path to your processed CSV file
dataset = load_processed_csv(csv_path)

# Load the model once to be reused for all products
model = load_model()

# Process a single product (or loop through all products instead).
# NOTE: positional row 3 is the fourth row of the CSV, not the first. The
# variable name `first_product` is kept for compatibility with any code that
# refers to it.
product_index = 3
first_product = dataset.iloc[product_index]
fig = process_product_images(first_product, model)

if fig is not None:
    fig.show()
else:
    # Report the row that was actually processed rather than "the first".
    print(
        f"No valid images found for product {first_product['product_id']} "
        f"(row {product_index})."
    )

# Uncomment the following lines to process all products
# for index, row in dataset.iterrows():
#     fig = process_product_images(row, model)
#     if fig is not None:
#         fig.show()
#     else: