# ControlNet-Based Pipeline

In [None]:
import os
import urllib.request
from io import BytesIO
import requests
import json

import torch

from PIL import Image
import numpy as np
import cv2
import matplotlib.pyplot as plt
import scipy
import h5py

from diffusers import (ControlNetModel, DiffusionPipeline,
                       StableDiffusionControlNetPipeline,
                       UniPCMultistepScheduler)

import open3d as o3d

import matplotlib.image as mpimg
import re
from csv import writer



In [None]:
# Release any cached CUDA memory from earlier runs, then choose the compute device.
torch.cuda.empty_cache()
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Data Utils

In [None]:
def prepare_nyu_data(rgb_img, depth_map):
    """Convert one NYU Depth v2 sample from its stored layout to image layout.

    Every channel plane is transposed, so ``rgb_img[c, j, i]`` ends up at
    ``image[i, j, c]``, and ``depth_map[j, i]`` ends up at ``depth[i, j]``.
    Only the first three channels of ``rgb_img`` are used.

    Args:
        rgb_img: channel-first array indexed ``[channel, dim1, dim2]``
            (presumably the raw h5py layout of the .mat file — verify).
        depth_map: 2-D array matching one channel plane of ``rgb_img``.

    Returns:
        ``(image, depth)``: a ``uint8`` array of shape
        ``(rgb_img.shape[2], rgb_img.shape[1], 3)`` and a C-contiguous
        ``float32`` copy of ``depth_map.T``.
    """
    n_rows, n_cols = rgb_img.shape[2], rgb_img.shape[1]

    # Fill a float64 buffer (np.empty's default dtype) and cast to uint8 once.
    image = np.empty([n_rows, n_cols, 3])
    for channel in range(3):
        image[:, :, channel] = rgb_img[channel, :, :].T

    depth = np.asarray(depth_map.T, dtype=np.float32, order="C")
    return image.astype(np.uint8), depth.astype(np.float32)

In [None]:
def align_midas(midas_pred, ground_truth):
    """Fit a MiDaS prediction to a ground-truth depth map and return depth.

    A least-squares line ``scale * pred + shift`` is fitted over every pixel
    against ``1 / ground_truth`` (inverse depth). The fitted inverse depth is
    then inverted back to depth.

    Args:
        midas_pred: MiDaS output array, same shape as ``ground_truth``.
        ground_truth: depth map (the original comment says metres).

    Returns:
        Aligned depth map with the shape of ``midas_pred``.
    """
    eps = 10e-6  # keeps both reciprocals finite (note: 10e-6 == 1e-5)
    target_inv = 1 / (ground_truth + eps)

    pred_flat = midas_pred.ravel()
    target_flat = target_inv.ravel()
    design = np.column_stack([pred_flat, np.ones(pred_flat.shape[0])])
    scale, shift = np.linalg.lstsq(design, target_flat, rcond=None)[0]

    return 1 / (midas_pred * scale + shift + eps)

In [None]:
# for nicer looking plot titles 
def break_up_string(text, line_limit=50):
    """Greedily word-wrap ``text`` so long plot titles stay readable.

    Words are joined with single spaces. A newline goes before a word when
    appending it, plus its separating space, would push the current line past
    ``line_limit`` characters. A word longer than the limit is kept whole on
    its own line.

    Args:
        text: string to wrap; runs of whitespace collapse to single spaces.
        line_limit: maximum line length in characters.

    Returns:
        The wrapped string, or ``""`` for empty or whitespace-only input.
    """
    new_text = ""
    char_count = 0  # length of the line being built, separating spaces included
    for word in text.split():
        if not new_text:
            new_text = word
            char_count = len(word)
        elif char_count + len(word) < line_limit:
            # Equivalent to char_count + 1 + len(word) <= line_limit:
            # the word plus its leading space still fits.
            new_text = f"{new_text} {word}"
            char_count += len(word) + 1
        else:
            new_text = f"{new_text}\n{word}"
            char_count = len(word)
    return new_text
        
        

In [None]:
# Quick sanity check of the title wrapper on a sample prompt.
break_up_string("baroque style palace room with ornate marble decorations and statues, landscape paintings on the walls")

## Loading ControlNet

In [None]:
# Hugging Face Hub identifiers: the depth-conditioned ControlNet and its
# Stable Diffusion backbone.
CONTROLNET_MODEL_ID = 'lllyasviel/sd-controlnet-depth'
BASE_MODEL_ID = "runwayml/stable-diffusion-v1-5"

# Free cached GPU memory before loading the models.
torch.cuda.empty_cache()

controlnet = ControlNetModel.from_pretrained(
    CONTROLNET_MODEL_ID,
    torch_dtype=torch.float32,
)

# NOTE: the misspelled name `coltrolnet_pipe` is referenced by later cells.
coltrolnet_pipe = StableDiffusionControlNetPipeline.from_pretrained(
    BASE_MODEL_ID,
    safety_checker=None,  # safety checker disabled
    controlnet=controlnet,
    torch_dtype=torch.float32,
)

# Replace the pipeline's scheduler with UniPC, built from the existing config.
coltrolnet_pipe.scheduler = UniPCMultistepScheduler.from_config(
    coltrolnet_pipe.scheduler.config
)
# Optional memory saver (requires xformers):
# coltrolnet_pipe.enable_xformers_memory_efficient_attention()

coltrolnet_pipe.to(device)
torch.cuda.empty_cache()

## Inference on ControlNet

### Constants

In [None]:
# Defaults ------------
# Inference defaults ---------------------------------------------------------
# Target length of the shorter image side, passed to resize_image (which
# rounds both sides to multiples of 64).
image_resolution = 512
depth_resolution = 512

# Suffix appended to every prompt, plus the shared negative prompt.
additional_prompt = 'best quality, extremely detailed'
negative_prompt = (
    'longbody, lowres, bad anatomy, bad hands, missing fingers, '
    'extra digit, fewer digits, cropped, worst quality, low quality'
)

# Sampler settings.
num_steps = 20
guidance_scale = 9
seed = 1825989188  # fixed seed for reproducible generations

### Prompt Ideas

In [None]:
# Prompt 1: ornate, cosy night-time library scene.
interior_design_prompt_1 = (
    "Intricate, Ornate, Embellished, Elaborate, Detailed, Decorative, "
    "Intricately-crafted, Luxurious, Ornamented, and Artistic cloak, "
    "open book, sparks, cozy library in background, furniture, fire place, "
    "food, wine, pet, chandelier, High Definition, Night time, "
    "Photorealism, realistic"
)

# Prompt 2: high-end futuristic indoor/outdoor living space.
interior_design_prompt_2 = (
    "Residential home high end futuristic interior, olson kundig, "
    "Interior Design by Dorothy Draper, maison de verre, axel vervoordt, "
    "award winning photography of an indoor-outdoor living library space, "
    "minimalist modern designs, high end indoor/outdoor residential living space, "
    "rendered in vray, rendered in octane, rendered in unreal engine, "
    "architectural photography, photorealism, featured in dezeen, "
    "cristobal palma. 5 chaparral landscape outside, "
    "black surfaces/textures for furnishings in outdoor space"
)

# interior_design_prompt_3 = ...  (placeholder for another prompt)

### Inference functions

In [None]:
def HWC3(x):
    """Return ``x`` as a 3-channel ``(H, W, 3)`` image.

    * 2-D input is treated as single-channel.
    * 1 channel: the channel is repeated three times.
    * 3 channels: ``x`` is returned unchanged (the same object).
    * 4 channels: RGB is alpha-composited over white and returned as uint8,
      with alpha read as 0..255.

    Raises:
        AssertionError: if ``x`` is not 2-D/3-D or has another channel count.
    """
    if x.ndim == 2:
        x = x[:, :, None]
    assert x.ndim == 3
    channels = x.shape[2]
    assert channels in (1, 3, 4)

    if channels == 1:
        return np.concatenate([x] * 3, axis=2)
    if channels == 4:
        rgb = x[:, :, :3].astype(np.float32)
        alpha = x[:, :, 3:].astype(np.float32) / 255.0
        composited = rgb * alpha + 255.0 * (1.0 - alpha)
        return composited.clip(0, 255).astype(np.uint8)
    return x
    

def resize_image(input_image, resolution):
    """Rescale an image so its shorter side is about ``resolution`` pixels.

    Both output sides are rounded to the nearest multiple of 64. Lanczos
    interpolation is used when enlarging and area interpolation when shrinking.

    Args:
        input_image: ``(H, W)`` or ``(H, W, C)`` array.
        resolution: target length of the shorter side before rounding.

    Returns:
        The array produced by ``cv2.resize``.
    """
    if input_image.ndim == 3:
        height, width, _ = input_image.shape
    else:
        height, width = input_image.shape

    scale = float(resolution) / min(float(height), float(width))
    new_h = int(np.round(float(height) * scale / 64.0)) * 64
    new_w = int(np.round(float(width) * scale / 64.0)) * 64

    interp = cv2.INTER_LANCZOS4 if scale > 1 else cv2.INTER_AREA
    return cv2.resize(input_image, (new_w, new_h), interpolation=interp)

def preprocess_for_controlnet(ground_depth_map, src_image=None):
    """Prepare an RGB image and a depth map as PIL images for ControlNet.

    Both inputs are made 3-channel with ``HWC3`` and resized with
    ``resize_image`` to the notebook-level ``image_resolution``.

    Args:
        ground_depth_map: 2-D (or channel-last) depth map. Non-``uint8`` maps
            are min-max scaled to 0..255 first, because ``Image.fromarray``
            cannot build an RGB image from float data.
        src_image: RGB image as a ``uint8`` array. Defaults to the notebook
            global ``src_img_np``.

    Returns:
        ``(image, depth)`` as ``PIL.Image.Image`` objects.
    """
    if src_image is None:
        src_image = src_img_np  # fall back to the currently loaded notebook sample

    depth = np.asarray(ground_depth_map)
    if depth.dtype != np.uint8:
        depth = depth.astype(np.float32)
        depth = depth - depth.min()
        peak = depth.max()
        if peak > 0:  # a constant map would otherwise divide by zero
            depth = depth / peak
        depth = (depth * 255.0).clip(0, 255).astype(np.uint8)

    input_image = resize_image(HWC3(np.asarray(src_image)), image_resolution)
    input_depth = resize_image(HWC3(depth), image_resolution)
    return Image.fromarray(input_image), Image.fromarray(input_depth)
   
#src_img_np, ground_depth_map 
#plt.imshow(ground_depth_map, cmap='RdBu')
#plt.colorbar()

In [None]:
def prepare_nyu_controlnet(x, is_nyu_ground=True, num_samples=1):
    """Turn a depth map into the 3-channel ControlNet conditioning image.

    Args:
        x: 2-D depth map. With ``is_nyu_ground=True`` it is scaled by 25.5 and
            inverted (presumably metric NYU depth, so nearer = brighter —
            confirm). Otherwise it is used as-is. It is then min-max normalised
            to 0..255.
        is_nyu_ground: whether ``x`` needs the scale-and-invert step.
        num_samples: unused; kept for call compatibility.

    Returns:
        ``(control_image, H, W)``: a ``uint8`` RGB PIL image of size ``(W, H)``.
        ``H`` and ``W`` are the dimensions ``resize_image`` gives at
        ``image_resolution``; the map is resized to them bilinearly.
    """
    # Floating-point copy so the in-place ops below never modify the caller's
    # array (float32 input stays float32, integers are promoted).
    depth = np.array(x, dtype=np.result_type(np.asarray(x).dtype, np.float32))

    if is_nyu_ground:
        depth = 1 / (depth * 25.5 + 10e-6)

    depth -= depth.min()
    peak = depth.max()
    if peak > 0:  # a constant map would otherwise produce NaNs
        depth /= peak
    depth_u8 = (depth * 255.0).clip(0, 255).astype(np.uint8)

    depth_rgb = HWC3(depth_u8)

    # resize_image only supplies the target size; the map itself is resized
    # with bilinear interpolation.
    H, W, _ = resize_image(depth_rgb, image_resolution).shape
    detected_map = cv2.resize(depth_rgb, (W, H), interpolation=cv2.INTER_LINEAR)

    control_image = Image.fromarray(np.uint8(detected_map))
    return control_image, H, W

In [None]:
def infer_controlnet(source_image, prompt, H, W, guidance_scale=7.5, depth_image=None,
                     num_inference_steps=50, save_name='', comparison_save_name=''):
    """Generate an image with the depth-conditioned ControlNet pipeline.

    The prompt is extended with the global ``additional_prompt``. Sampling uses
    the global ``negative_prompt`` and a ``torch.Generator()`` seeded with the
    global ``seed``. The result is shown next to ``source_image`` in a
    two-panel figure titled with the wrapped prompt.

    Args:
        source_image: image shown in the left panel.
        prompt: text prompt.
        H, W: output height and width passed to the pipeline.
        guidance_scale: classifier-free guidance scale.
        depth_image: conditioning image passed as the pipeline's ``image``.
        num_inference_steps: number of denoising steps.
        save_name: when non-empty, the generated image is saved here.
        comparison_save_name: when non-empty, the comparison figure is saved here.

    Returns:
        The first generated image (``results.images[0]``).
    """
    generator = torch.Generator().manual_seed(seed)
    full_prompt = f'{prompt}, {additional_prompt}'

    results = coltrolnet_pipe(prompt=full_prompt, negative_prompt=negative_prompt,
                              image=depth_image,
                              height=H,
                              width=W,
                              guidance_scale=guidance_scale,
                              num_inference_steps=num_inference_steps,
                              generator=generator)
    generated = results.images[0]

    title = break_up_string(full_prompt)
    # Titles over 150 characters get a smaller font.
    fontsize = 10 if len(title) / 50 > 3 else 12

    fig, ax = plt.subplots(1, 2)
    ax[0].imshow(source_image)
    ax[1].imshow(generated)
    fig.suptitle(title, fontsize=fontsize, y=0.9)
    for a in ax:
        a.set_xticks([])
        a.set_yticks([])
    fig.tight_layout()

    if save_name:
        generated.save(save_name)
    if comparison_save_name:
        fig.savefig(comparison_save_name)

    plt.show()
    # Close explicitly: this runs once per dataset image, and open figures
    # would otherwise accumulate outside inline notebook backends.
    plt.close(fig)

    return generated

## Loading Midas

In [None]:
# MiDaS variant to load from torch.hub. "DPT_Large"/"DPT_Hybrid" use the DPT
# transform, "MiDaS" the default one, and anything else the small transform.
model_type = "MiDaS"
midas = torch.hub.load("intel-isl/MiDaS", model_type)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
midas.to(device)
midas.eval()

# Pick the input transform that matches the chosen variant.
midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
if model_type in ("DPT_Large", "DPT_Hybrid"):
    transform = midas_transforms.dpt_transform
elif model_type == "MiDaS":
    transform = midas_transforms.default_transform
else:
    transform = midas_transforms.small_transform

## Midas Inference

In [None]:
def infer_depth_map(image, save_name=''):
    """Predict a MiDaS depth map for an image, display it and optionally save it.

    Runs the global ``midas`` model on ``transform(image)``. The prediction is
    upsampled bicubically to ``image.shape[:2]`` and shown with matplotlib.

    Args:
        image: array accepted by the global MiDaS ``transform`` (the notebook
            passes a ``uint8`` RGB array — verify for other inputs).
        save_name: path prefix. When non-empty, the raw prediction is written
            to ``save_name + '.npy'`` and a min-max normalised greyscale preview
            to ``save_name + '.png'``. When empty, nothing is written.

    Returns:
        The prediction as a NumPy array shaped ``image.shape[:2]``.
    """
    input_batch = transform(image).to(device)

    with torch.no_grad():
        prediction = midas(input_batch)
        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=image.shape[:2],
            mode="bicubic",
            align_corners=False,
        ).squeeze()
    output = prediction.cpu().numpy()

    if save_name:
        with open(save_name + '.npy', 'wb') as f:
            np.save(f, output, allow_pickle=True, fix_imports=True)

        # The prediction is a float array with no guaranteed 0..255 range, so
        # rescale it explicitly before writing an 8-bit preview image.
        preview = output - output.min()
        peak = preview.max()
        if peak > 0:
            preview = preview / peak
        preview = (preview * 255.0).astype(np.uint8)
        Image.fromarray(preview).convert('RGB').save(save_name + ".png")

    plt.imshow(output)
    plt.show()

    print("infer depth map done")
    return output

## Depth Heatmap Visualisation

In [None]:
def get_depth_heat_map(ground_depth_map, predict_depth_map, img_id=None, save_name='',
                       heat_limit=4):
    """Plot ground-truth depth, predicted depth and their signed difference.

    The two depth panels share the value range of ``ground_depth_map``. The
    third panel shows ``ground_depth_map - predict_depth_map`` on the
    diverging ``RdBu_r`` map over ``[-heat_limit, heat_limit]``. Colour-bar
    ticks are labelled with an " m" suffix.

    Args:
        ground_depth_map: reference depth map.
        predict_depth_map: predicted depth map, same shape as the reference.
        img_id: optional identifier shown in the figure title.
        save_name: when non-empty, the figure is saved to this path.
        heat_limit: symmetric range of the difference colour scale.

    Returns:
        None.
    """
    import matplotlib.ticker as mticker

    depth_diff = ground_depth_map - predict_depth_map
    _min, _max = np.amin(ground_depth_map), np.amax(ground_depth_map)

    fig, ax = plt.subplots(1, 3, figsize=(16, 6), layout='constrained')
    ax[0].imshow(ground_depth_map, vmin=_min, vmax=_max)
    ax[1].imshow(predict_depth_map, vmin=_min, vmax=_max)
    diff = ax[2].imshow(depth_diff, cmap='RdBu_r', vmin=-heat_limit, vmax=heat_limit)

    cbar = fig.colorbar(diff, ax=ax[2], shrink=0.6)
    cbar.set_label('Ground truth - Predicted', rotation=90, labelpad=5)
    # Label ticks through a formatter instead of set_yticklabels. That way
    # labels always match the tick values and matplotlib does not warn about a
    # FixedFormatter used without a FixedLocator.
    cbar.ax.yaxis.set_major_formatter(
        mticker.FuncFormatter(lambda value, _pos: "{:.2}".format(value) + " m"))

    ax[0].set_title('Ground Truth', fontsize=16)
    ax[1].set_title('Generated', fontsize=16)
    ax[2].set_title('Difference Heat Map', fontsize=16)
    for a in ax:
        a.set_xticks([])
        a.set_yticks([])

    if img_id is not None:
        fig.suptitle(f"Depth Maps for Image - {img_id}", fontsize=18, y=0.95)
    else:
        fig.suptitle("Depth Maps", fontsize=18, y=0.95)

    if save_name:
        fig.savefig(save_name)

    plt.show()
    plt.close(fig)  # called once per image; avoid accumulating open figures
        

## Point Cloud and Mesh

In [None]:
def get_point_cloud(rgb_image, depth_image, pcd_path, display=False, depth_scale=10):
    """Build a coloured point cloud from an RGB image and a depth map.

    The RGB-D pair is back-projected with Open3D's PrimeSense default pinhole
    intrinsics. Normals are estimated and the cloud is flipped so it is not
    upside down. The cloud is written to ``pcd_path``.

    Args:
        rgb_image: ``(H, W, 3)`` colour image; passed on as C-contiguous uint8.
        depth_image: ``(H, W)`` depth map; passed on as C-contiguous float32.
            The aligned MiDaS map may arrive as float64, or slices may be
            non-contiguous, so both arrays are normalised before they reach
            Open3D.
        pcd_path: output file for ``o3d.io.write_point_cloud``.
        display: open an Open3D viewer window when True.
        depth_scale: ``depth_scale`` argument of
            ``RGBDImage.create_from_color_and_depth``.

    Returns:
        The ``o3d.geometry.PointCloud``.
    """
    color = o3d.geometry.Image(np.ascontiguousarray(rgb_image, dtype=np.uint8))
    depth = o3d.geometry.Image(np.ascontiguousarray(depth_image, dtype=np.float32))

    rgbd_image = o3d.geometry.RGBDImage.create_from_color_and_depth(
        color, depth, depth_scale=depth_scale, convert_rgb_to_intensity=False)

    # NOTE(review): PrimeSense defaults describe a 640x480 sensor; generated
    # images are resized (e.g. by resize_image), so geometry may be skewed.
    intrinsic = o3d.camera.PinholeCameraIntrinsic(
        o3d.camera.PinholeCameraIntrinsicParameters.PrimeSenseDefault)
    pcd = o3d.geometry.PointCloud.create_from_rgbd_image(rgbd_image, intrinsic)
    pcd.estimate_normals(
        search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.0104, max_nn=12))

    # Flip it, otherwise the point cloud will be upside down.
    pcd.transform([[1, 0, 0, 0], [0, -1, 0, 0], [0, 0, -1, 0], [0, 0, 0, 1]])

    o3d.io.write_point_cloud(pcd_path, pcd, write_ascii=False, compressed=False,
                             print_progress=False)

    if display:
        o3d.visualization.draw_geometries([pcd])

    return pcd

def capture_pcd_with_view_params(pcd, pcd_path, view_setting_path):
    """Render ``pcd`` in a hidden window from a saved view and save a screenshot.

    Args:
        pcd: geometry to render.
        pcd_path: output image path for ``capture_screen_image``.
        view_setting_path: JSON file whose ``trajectory[0]`` entry provides
            ``field_of_view``, ``front``, ``lookat``, ``up`` and ``zoom``.
    """
    # Read the view first so a missing or invalid file fails before a window opens.
    with open(view_setting_path, "r") as f:
        view = json.load(f)['trajectory'][0]

    vis = o3d.visualization.Visualizer()
    vis.create_window(visible=False)
    try:
        vis.add_geometry(pcd)

        vc = vis.get_view_control()
        # NOTE(review): Open3D's change_field_of_view appears to take a step
        # rather than an absolute angle — confirm the saved FOV is reproduced.
        vc.change_field_of_view(view['field_of_view'])
        vc.set_front(view['front'])
        vc.set_lookat(view['lookat'])
        vc.set_up(view['up'])
        vc.set_zoom(view['zoom'])

        vis.poll_events()
        vis.update_renderer()
        vis.capture_screen_image(pcd_path)
    finally:
        # Always release the hidden window, even if rendering fails.
        vis.destroy_window()

def get_mesh_from_pcd(pcd, method="ball_rolling"):
    """Reconstruct a triangle mesh from a point cloud and display it.

    Open3D debug-level logging is enabled while the reconstruction runs.

    Args:
        pcd: ``o3d.geometry.PointCloud`` to mesh.
        method: one of
            * ``"poisson"``: Poisson reconstruction (depth 10);
            * ``"ball_pivoting"``: ball pivoting with fixed radii;
            * ``"alpha"``: alpha shapes for four alphas from 0.5 down to 0.1,
              each displayed in turn;
            * ``"ball_rolling"``: ball pivoting with radii of 3x and 6x the
              mean nearest-neighbour distance, then quadric decimation
              (target 100000 triangles) and cleanup.

    Returns:
        ``(mesh, densities)`` for "poisson". For the other methods, the
        resulting ``o3d.geometry.TriangleMesh``; for "alpha" this is the mesh
        of the last (smallest) alpha.

    Raises:
        ValueError: if ``method`` is not one of the names above.
    """
    with o3d.utility.VerbosityContextManager(o3d.utility.VerbosityLevel.Debug):

        if method == "poisson":
            mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(
                pcd, depth=10, width=0, scale=1.1, linear_fit=True)
            print(mesh)
            o3d.visualization.draw_geometries([mesh])
            return mesh, densities

        if method == "ball_pivoting":
            radii = [0.005, 0.01, 0.02, 0.04]
            mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(
                pcd, o3d.utility.DoubleVector(radii))
            o3d.visualization.draw_geometries([pcd, mesh])
            return mesh

        if method == "alpha":
            tetra_mesh, pt_map = o3d.geometry.TetraMesh.create_from_point_cloud(pcd)
            mesh = None
            for alpha in np.logspace(np.log10(0.5), np.log10(0.1), num=4):
                print(f"alpha={alpha:.3f}")
                mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_alpha_shape(
                    pcd, alpha, tetra_mesh, pt_map)
                mesh.compute_vertex_normals()
                o3d.visualization.draw_geometries([mesh], mesh_show_back_face=True)
            return mesh

        if method == "ball_rolling":
            # Scale the ball radii to the cloud's typical point spacing.
            distances = pcd.compute_nearest_neighbor_distance()
            radius = 3 * np.mean(distances)
            mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(
                pcd, o3d.utility.DoubleVector([radius, radius * 2]))

            dec_mesh = mesh.simplify_quadric_decimation(100000)
            dec_mesh.remove_degenerate_triangles()
            dec_mesh.remove_duplicated_triangles()
            dec_mesh.remove_duplicated_vertices()
            dec_mesh.remove_non_manifold_edges()

            o3d.visualization.draw_geometries([dec_mesh], mesh_show_back_face=True)
            return dec_mesh

    raise ValueError(
        f"Unknown meshing method {method!r}; expected 'poisson', "
        "'ball_pivoting', 'alpha' or 'ball_rolling'.")
            
            
    

In [None]:
# Scratch check of the ControlNet depth preprocessing. ``ground_depth_map`` only
# exists once a NYU sample has been loaded (see the evaluation loop below), so
# skip the check when the notebook is run top to bottom.
if "ground_depth_map" in globals():
    imgs, H, W = prepare_nyu_controlnet(ground_depth_map)

# Test Code: Running the Pipeline on NYU Depth V2

In [None]:
# Paths -----------------------------------------------------------------------
# NOTE: machine-specific absolute path to the NYU Depth v2 labeled .mat file.
nyu_path = 'C:/Users/User/Documents/Data_Science/ar_stable_diffusion/data/nyu_depth_v2_labeled.mat'
nyu_result_root = './results/NYU/'
eval_results_path = nyu_result_root + "eval_logs.csv"

save_folder = nyu_result_root

# Generation settings (guidance_scale overrides the defaults cell) -------------
guidance_scale = 7.5
strength = 0.5  # NOTE(review): not referenced by any visible code
num_inference_steps = 30

# The first prompt is kept for reference; the second assignment is the one used.
prompt = "baroque style palace room with ornate marble decorations and statues, landscape paintings on the walls, warm ambient lighting and painted ceiling"
prompt = interior_design_prompt_1

# Open the .mat file with h5py explicitly read-only so the dataset can never be
# modified by accident. The datasets below read from this handle, so keep it open.
f = h5py.File(nyu_path, "r")

rgb_images = f['images']
depth_maps = f['depths']

In [None]:


# Run the full pipeline for every NYU sample from ``start_index`` onwards:
# ControlNet generation -> MiDaS depth -> alignment -> heat map -> point clouds.
start_index = 16  # resume point; set to 0 to process the whole dataset

# Create every output sub-folder up front so the saves below cannot fail.
for sub_dir in ("2d_images", "depth_maps", "depth_map_heat_maps", "point_clouds"):
    os.makedirs(os.path.join(nyu_result_root, sub_dir), exist_ok=True)

view_setting_path = nyu_result_root + "view_setting.json"

for i in range(start_index, rgb_images.shape[0]):
    img_id = i

    src_img_np, ground_depth_map = prepare_nyu_data(rgb_images[i], depth_maps[i])
    ground_depth, H, W = prepare_nyu_controlnet(ground_depth_map)

    gen_img_save_name = save_folder + f"2d_images/{img_id}_generated.png"
    comparison_save_name = save_folder + f"2d_images/{img_id}_comparison.png"

    gen_img = infer_controlnet(source_image=src_img_np, prompt=prompt, depth_image=ground_depth,
                               H=H, W=W,
                               guidance_scale=guidance_scale,
                               num_inference_steps=num_inference_steps,
                               save_name=gen_img_save_name,
                               comparison_save_name=comparison_save_name)
    gen_img_np = np.array(gen_img)

    predict_depth_path = nyu_result_root + f"depth_maps/{i}_gen_depth"
    predict_depth_map = infer_depth_map(gen_img_np, save_name=predict_depth_path)

    # Resize the ground truth with the same resize_image call that set the
    # generated image's size, then fit the MiDaS prediction to it.
    groundmap_for_heatmap = resize_image(ground_depth_map, image_resolution)
    predict_depth_map_aligned = align_midas(predict_depth_map, groundmap_for_heatmap)

    heatmap_path = nyu_result_root + f"depth_map_heat_maps/{i}_depth_heatmap.png"
    get_depth_heat_map(groundmap_for_heatmap, predict_depth_map_aligned,
                       img_id=i, save_name=heatmap_path)

    ground_pcd_path = nyu_result_root + f"point_clouds/{i}_ground_pcd"
    gen_pcd_path = nyu_result_root + f"point_clouds/{i}_gen_pcd"

    original_pcd = get_point_cloud(src_img_np, ground_depth_map, pcd_path=ground_pcd_path + ".pcd")
    generated_pcd = get_point_cloud(gen_img_np, predict_depth_map_aligned, pcd_path=gen_pcd_path + ".pcd")

    capture_pcd_with_view_params(pcd=original_pcd, pcd_path=ground_pcd_path + ".png",
                                 view_setting_path=view_setting_path)
    capture_pcd_with_view_params(pcd=generated_pcd, pcd_path=gen_pcd_path + ".png",
                                 view_setting_path=view_setting_path)

    # TODO: append per-image metrics (get_eval_results, not defined in this
    # notebook) to eval_results_path with csv.writer.
    
    

# Rebuilding Point Clouds for Visualisation

In [None]:
def rebuild_point_clouds(i, display=True):
    """Rebuild the ground-truth and generated point clouds for sample ``i``.

    Reloads the saved MiDaS prediction (``depth_maps/{i}_gen_depth.npy``) and
    generated image (``2d_images/{i}_generated.png``, which is also shown).
    The prediction is re-aligned to the resized ground truth, and both
    ``.pcd`` files are rewritten.

    Args:
        i: NYU sample index.
        display: show each cloud in an Open3D viewer.

    Returns:
        ``(original_pcd, generated_pcd)``.
    """
    predict_depth_map = np.load(nyu_result_root + f"depth_maps/{i}_gen_depth.npy")
    src_img_np, ground_depth_map = prepare_nyu_data(rgb_images[i], depth_maps[i])

    with Image.open(save_folder + f"2d_images/{i}_generated.png") as gen_img:
        gen_img.show()
        gen_img_np = np.array(gen_img)

    groundmap_for_heatmap = resize_image(ground_depth_map, image_resolution)
    predict_depth_map_aligned = align_midas(predict_depth_map, groundmap_for_heatmap)

    ground_pcd_path = nyu_result_root + f"point_clouds/{i}_ground_pcd"
    gen_pcd_path = nyu_result_root + f"point_clouds/{i}_gen_pcd"

    original_pcd = get_point_cloud(src_img_np, ground_depth_map,
                                   pcd_path=ground_pcd_path + ".pcd", display=display)
    generated_pcd = get_point_cloud(gen_img_np, predict_depth_map_aligned,
                                    pcd_path=gen_pcd_path + ".pcd", display=display)

    return original_pcd, generated_pcd
    

In [None]:
def rebuild_point_clouds_ground_depth(i, display=True):
    """Rebuild point clouds for sample ``i`` using ground-truth depth for both.

    The ground-truth cloud uses the original RGB image. The second cloud pairs
    the saved generated image (``2d_images/{i}_generated.png``, which is also
    shown) with the ground-truth depth resized by ``resize_image``. This
    isolates the effect of the generated colours from any depth error.

    Args:
        i: NYU sample index.
        display: show each cloud in an Open3D viewer.

    Returns:
        ``(original_pcd, generated_pcd)``.
    """
    src_img_np, ground_depth_map = prepare_nyu_data(rgb_images[i], depth_maps[i])

    with Image.open(save_folder + f"2d_images/{i}_generated.png") as gen_img:
        gen_img.show()
        gen_img_np = np.array(gen_img)

    groundmap_for_heatmap = resize_image(ground_depth_map, image_resolution)

    ground_pcd_path = nyu_result_root + f"point_clouds/{i}_ground_pcd"
    gen_pcd_path = nyu_result_root + f"point_clouds/{i}_gen_pcd"

    original_pcd = get_point_cloud(src_img_np, ground_depth_map,
                                   pcd_path=ground_pcd_path + ".pcd", display=display)
    generated_pcd = get_point_cloud(gen_img_np, groundmap_for_heatmap,
                                    pcd_path=gen_pcd_path + "_ground_depth.pcd", display=display)

    return original_pcd, generated_pcd

In [None]:
import matplotlib as mpl
import matplotlib.cm as cm

def rebuild_point_clouds_heatmap(i, display=True):
    """Build a point cloud of sample ``i`` coloured by depth error.

    The aligned MiDaS prediction gives the geometry. Each point is coloured by
    ``ground truth - prediction`` on the diverging ``RdBu_r`` map over
    [-4, 4], the same scale as the 2-D heat map.

    Args:
        i: NYU sample index.
        display: show the cloud in an Open3D viewer.

    Returns:
        The ``o3d.geometry.PointCloud``, also written to
        ``point_clouds/{i}_heatmap_pcd.pcd``.
    """
    predict_depth_map = np.load(nyu_result_root + f"depth_maps/{i}_gen_depth.npy")
    src_img_np, ground_depth_map = prepare_nyu_data(rgb_images[i], depth_maps[i])

    with Image.open(save_folder + f"2d_images/{i}_generated.png") as gen_img:
        gen_img.show()

    groundmap_for_heatmap = resize_image(ground_depth_map, image_resolution)
    predict_depth_map_aligned = align_midas(predict_depth_map, groundmap_for_heatmap)

    difference_map = groundmap_for_heatmap - predict_depth_map_aligned

    # bytes=True yields uint8 RGBA directly. Scaling floats by 256 would wrap
    # full-intensity channels (1.0 -> 256) around to 0.
    mapper = cm.ScalarMappable(norm=mpl.colors.Normalize(vmin=-4, vmax=4),
                               cmap=plt.get_cmap('RdBu_r'))
    # Drop alpha; the slice is non-contiguous, so copy to a contiguous buffer.
    heatmap = np.ascontiguousarray(mapper.to_rgba(difference_map, bytes=True)[:, :, :3])

    gen_pcd_path = nyu_result_root + f"point_clouds/{i}_heatmap_pcd.pcd"
    return get_point_cloud(heatmap, predict_depth_map_aligned, pcd_path=gen_pcd_path,
                           display=display)

In [None]:
rebuild_point_clouds(i=28, display=True)  # rebuild and show both clouds for sample 28

In [None]:
rebuild_point_clouds_ground_depth(i=28, display=True)  # generated colours on ground-truth depth

In [None]:
rebuild_point_clouds_heatmap(i=28, display=True)  # error-coloured cloud for sample 28

# Creating Meshes

In [None]:
# Point clouds for sample 28, used as meshing input below.
original_pcd, generated_pcd = rebuild_point_clouds(i=28, display=True)

In [None]:
get_mesh_from_pcd(pcd=original_pcd, method="ball_rolling")  # mesh the ground-truth cloud