# Imports

In [1]:
import torch
import torch.nn.functional as F
import logging
import os
import os.path as osp
import gc

import cupy

import sys

try:
    from mmcv.utils import Config, DictAction
except:
    from mmengine import Config, DictAction
from mono.utils.logger import setup_logger
import glob
from mono.utils.comm import init_env
from mono.model.monodepth_model import get_configured_monodepth_model
from mono.utils.running import load_ckpt
from mono.utils.do_test import transform_test_data_scalecano, get_prediction
from mono.utils.custom_data import load_from_annos, load_data

from mono.utils.avg_meter import MetricAverageMeter
from mono.utils.visualization import save_val_imgs, create_html, save_raw_imgs, save_normal_val_imgs
import cv2
from tqdm import tqdm
import numpy as np
from PIL import Image, ExifTags
import matplotlib.pyplot as plt

from mono.utils.unproj_pcd import reconstruct_pcd, save_point_cloud, ply_to_obj
from mono.utils.transform import gray_to_colormap
from mono.utils.visualization import vis_surface_normal
import gradio as gr
import plotly.graph_objects as go

# Functions

In [1]:
def predict_depth_normal(img, model_selection="vit-large", fx=1000.0, fy=1000.0, state_cache=None):
    """Predict a depth map and a surface-normal map for a single image.

    Only the checkpoint selected by ``model_selection`` is loaded; it is moved
    to CUDA for inference and released again before returning.

    Args:
        img: Input image, converted with ``np.array(img)``. ``None`` yields an
            error message instead of a prediction.
        model_selection: ``"vit-large"`` or ``"vit-small"``.
        fx: Focal length along x, in pixels.
        fy: Focal length along y, in pixels.
        state_cache: Dict that receives the results under the keys ``'depth'``,
            ``'normal'``, ``'img'``, ``'intrinsic'``, ``'confidence'`` and
            ``'save_dir'``. A fresh dict is created when omitted.

    Returns:
        A 6-tuple ``(depth_image, depth_file, normal_image, normal_file,
        state_cache, message)``. The two images are PIL images of the
        colour-mapped depth and normals, the two files are the paths of the
        saved ``.npy`` arrays. On invalid input the first four entries are
        ``None`` and ``message`` explains why.
    """
    # A fresh dict per call: a shared mutable default would carry the previous
    # call's depth, image and save_dir over into unrelated calls.
    if state_cache is None:
        state_cache = {}

    # Validate the request before any checkpoint is read, so bad input fails
    # fast and never touches the GPU. The order of the checks is unchanged.
    if model_selection == "vit-small":
        cfg_path = '/home/hydra/workspace/tamp_warm_start/notebooks/metric3d/mono/configs/HourglassDecoder/vit.raft5.small.py'
        ckpt_path = '/home/hydra/workspace/tamp_warm_start/notebooks/metric3d/weight/metric_depth_vit_small_800k.pth'
    elif model_selection == "vit-large":
        cfg_path = '/home/hydra/workspace/tamp_warm_start/notebooks/metric3d/mono/configs/HourglassDecoder/vit.raft5.large.py'
        ckpt_path = '/home/hydra/workspace/tamp_warm_start/notebooks/metric3d/weight/metric_depth_vit_large_800k.pth'
    else:
        return None, None, None, None, state_cache, "Not implemented model."

    if img is None:
        return None, None, None, None, state_cache, "Please upload an image and wait for the upload to complete."

    # Load only the model that will actually run.
    cfg = Config.fromfile(cfg_path)
    model = get_configured_monodepth_model(cfg, )
    model, _, _, _ = load_ckpt(ckpt_path, model, strict_match=False)
    model.eval()

    device = "cuda"
    model.to(device)

    cv_image = np.array(img)
    # Swaps the first and third colour channels.
    # NOTE(review): presumably img is an RGB PIL image, so img_cv is BGR — confirm.
    img_cv = cv2.cvtColor(cv_image, cv2.COLOR_BGR2RGB)
    height, width = img_cv.shape[0], img_cv.shape[1]
    # Principal point is assumed to be the image centre.
    intrinsic = [fx, fy, width/2, height/2]
    rgb_input, cam_models_stacks, pad, label_scale_factor = transform_test_data_scalecano(img_cv, intrinsic, cfg.data_basic)

    with torch.no_grad():
        pred_depth, pred_depth_scale, scale, output, confidence = get_prediction(
                    model = model,
                    input = rgb_input,
                    cam_model = cam_models_stacks,
                    pad_info = pad,
                    scale_info = label_scale_factor,
                    gt_depth = None,
                    normalize_scale = cfg.data_basic.depth_range[1],
                    ori_shape=[height, width],
                )

        # Keep the first three channels and crop away the padding.
        pred_normal = output['normal_out_list'][0][:, :3, :, :]
        H, W = pred_normal.shape[2:]
        pred_normal = pred_normal[:, :, pad[0]:H-pad[1], pad[2]:W-pad[3]]

    pred_depth = pred_depth.squeeze().cpu().numpy()
    pred_depth[pred_depth<0] = 0  # negative depths are clamped to zero
    pred_color = gray_to_colormap(pred_depth)

    # Resize the normals back to the input resolution and make them channel-last.
    pred_normal = torch.nn.functional.interpolate(pred_normal, [height, width], mode='bilinear').squeeze()
    pred_normal = pred_normal.permute(1,2,0)
    pred_color_normal = vis_surface_normal(pred_normal)
    pred_normal = pred_normal.cpu().numpy()

    # Storing depth and normal map in state for potential 3D reconstruction
    state_cache['depth'] = pred_depth
    state_cache['normal'] = pred_normal
    state_cache['img'] = img_cv
    state_cache['intrinsic'] = intrinsic
    state_cache['confidence'] = confidence

    # save depth and normal map to .npy file
    if 'save_dir' not in state_cache:
        # Draw random ids until one does not collide with an existing directory.
        cache_id = np.random.randint(0, 100000000000)
        while osp.exists(f'recon_cache/{cache_id:08d}'):
            cache_id = np.random.randint(0, 100000000000)
        state_cache['save_dir'] = f'recon_cache/{cache_id:08d}'
        os.makedirs(state_cache['save_dir'], exist_ok=True)
    depth_file = f"{state_cache['save_dir']}/depth.npy"
    normal_file = f"{state_cache['save_dir']}/normal.npy"
    np.save(depth_file, pred_depth)
    np.save(normal_file, pred_normal)

    depth_image = Image.fromarray(pred_color)
    normal_image = Image.fromarray(pred_color_normal)

    # Drop the model reference and release cached GPU memory.
    del model
    gc.collect()
    torch.cuda.empty_cache()

    return depth_image, depth_file, normal_image, normal_file, state_cache, "Success!"

In [3]:
def unprojection_pcd(state_cache, name):
    """Unproject the cached depth map into a coloured point cloud.

    Reads ``'depth'``, ``'img'`` and ``'intrinsic'`` from ``state_cache``,
    writes the full point cloud as an ASCII ``.ply`` file and builds a plotly
    figure from a random subsample of the points.

    Args:
        state_cache: Dict holding the prediction results; ``'save_dir'`` is
            added to it when missing.
        name: Suffix used in the output file name ``output_{name}.ply``.

    Returns:
        A 3-tuple ``(figure, ply_path, message)``. When the cache holds no
        prediction, ``figure`` and ``ply_path`` are ``None`` and ``message``
        asks for a prediction first.
    """
    depth_map = state_cache.get('depth', None)
    img = state_cache.get('img', None)
    intrinsic = state_cache.get('intrinsic', None)

    # The failure path has the same arity as the success path so that callers
    # can always unpack three values.
    if depth_map is None or img is None or intrinsic is None:
        return None, None, "Please predict depth and normal first."

    # downsample the depth map if too large
    max_height = 1080
    if depth_map.shape[0] > max_height:
        scale = max_height / depth_map.shape[0]
        depth_map = cv2.resize(depth_map, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
        img = cv2.resize(img, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_LINEAR)
        # Intrinsics scale together with the image.
        intrinsic = [value*scale for value in intrinsic[:4]]

    if 'save_dir' not in state_cache:
        # Draw random ids until one does not collide with an existing directory.
        cache_id = np.random.randint(0, 100000000000)
        while osp.exists(f'recon_cache/{cache_id:08d}'):
            cache_id = np.random.randint(0, 100000000000)
        state_cache['save_dir'] = f'recon_cache/{cache_id:08d}'
        os.makedirs(state_cache['save_dir'], exist_ok=True)

    pcd_ply = f"{state_cache['save_dir']}/output_{name}.ply"

    pcd = reconstruct_pcd(depth_map, intrinsic[0], intrinsic[1], intrinsic[2], intrinsic[3])
    # No confidence-based filtering is applied: every pixel becomes a point.
    pcd_filtered = pcd.reshape(-1, 3)
    img_filtered = img.reshape(-1, 3)

    save_point_cloud(pcd_filtered, img_filtered, pcd_ply, binary=False)

    # downsample the point cloud for visualization
    num_samples = 250000
    if pcd_filtered.shape[0] > num_samples:
        indices = np.random.choice(pcd_filtered.shape[0], num_samples, replace=False)
        pcd_downsampled = pcd_filtered[indices]
        img_downsampled = img_filtered[indices]
    else:
        pcd_downsampled = pcd_filtered
        img_downsampled = img_filtered

    # plotly show
    # Colour triples are unpacked as (b, g, r) and emitted in r,g,b order.
    color_str = np.array([f"rgb({r},{g},{b})" for b,g,r in img_downsampled])
    data=[go.Scatter3d(
        x=pcd_downsampled[:,0],
        y=pcd_downsampled[:,1],
        z=pcd_downsampled[:,2],
        mode='markers',
        marker=dict(
            size=1,
            color=color_str,
            opacity=0.8,
        )
    )]
    layout = go.Layout(
        margin=dict(l=0, r=0, b=0, t=0),
        scene=dict(
            camera = dict(
                eye=dict(x=0, y=0, z=-1),
                up=dict(x=0, y=-1, z=0)
            ),
            xaxis=dict(showgrid=False, showticklabels=False, visible=False),
            yaxis=dict(showgrid=False, showticklabels=False, visible=False),
            zaxis=dict(showgrid=False, showticklabels=False, visible=False),
        )
    )
    fig = go.Figure(data=data, layout=layout)

    return fig, pcd_ply, "Success!"

In [1]:
def generate_pointcloud(img, name):
    """Run depth prediction on ``img`` and unproject it into a point cloud.

    Args:
        img: Input image handed to ``predict_depth_normal``.
        name: Suffix for the point-cloud file written by ``unprojection_pcd``.

    Returns:
        A 4-tuple ``(depth_map, pcd_figure, pcd_ply, depth_image)``: the depth
        array stored in the state cache (``None`` if absent), the plotly
        figure, the path of the ``.ply`` file and the colour-mapped depth image.
    """
    # Hand the predictor its own fresh cache so that separate calls never share
    # a dict, and therefore never share a save directory or stale results.
    depth_image, _depth_file, _normal_image, _normal_file, state_cache, _message = predict_depth_normal(img, state_cache={})
    pcd_figure, pcd_ply, _message = unprojection_pcd(state_cache, name)
    return state_cache.get('depth', None), pcd_figure, pcd_ply, depth_image

# Tests

In [25]:
# img = Image.open("../flux/generated_images/grasp/floor/image0.png")
# depth_output_match, pcd_output_match, pcd_ply_match = generate_pointcloud(img)

In [26]:
# depth_output_match

In [24]:
# pcd_output_match

In [23]:
# img = Image.open("/home/ilyass/Pictures/Screenshots/reference.png")
# pcd_output_ref, pcd_ply_ref = generate_pointcloud(img)

In [22]:
# pcd_output_ref