In [None]:

import torch
# Report, one value per line: the PyTorch version, whether CUDA is usable,
# and the CUDA version reported by torch.version.cuda.
print(torch.__version__, torch.cuda.is_available(), torch.version.cuda, sep="\n")

# Prefer the GPU when one is available; otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

### Load dataset
The dataset from Part 3, used here to test mesh generation.

In [7]:
import pickle
def load_dataset(path):
    """Read a pickled collection of shape records and keep only the fields used here.

    Args:
        path: Location of the pickle file. It must unpickle to an iterable of
            mappings that each provide the keys ``"shape_id"``,
            ``"semantic class"``, ``"affordance"`` and ``"full_shape"``.

    Returns:
        A list with one dict per record, holding ``"shape_id"``,
        ``"semantic class"``, ``"affordance"`` and ``"data_info"`` (the
        record's ``"full_shape"`` entry).

    Raises:
        KeyError: If a record lacks one of the required keys.
    """
    # NOTE: unpickling can execute arbitrary code -- only open trusted files.
    with open(path, 'rb') as handle:
        records = pickle.load(handle)
    print("Loaded train_data")

    return [
        {
            "shape_id": record["shape_id"],
            "semantic class": record["semantic class"],
            "affordance": record["affordance"],
            "data_info": record["full_shape"],
        }
        for record in records
    ]

## Part 2 implementation - Mesh generation

In [None]:
import torch
import kaolin
import trimesh
import trimesh.convex  

def create_mesh(point_cloud, mesh_path, smooth=True, start_resolution=20,
                min_resolution=1, max_resolution=256):
    """Voxelize a point cloud, convert it to a triangle mesh and export it.

    The point cloud is normalized per axis to the unit cube, voxelized with
    Kaolin and converted to a triangle mesh. The voxel resolution is searched
    one step at a time, starting from ``start_resolution``, for the smallest
    resolution whose mesh has at least as many vertices as the point cloud has
    points. The mesh vertices are then mapped back to the original coordinate
    frame, optionally smoothed, and written to ``mesh_path``.

    Args:
        point_cloud: Array-like of points, one point per row.
        mesh_path: Destination file handed to ``Trimesh.export``.
        smooth: When true, apply Laplacian smoothing before exporting.
        start_resolution: First voxel resolution tried (20 was chosen
            empirically as a good starting point).
        min_resolution: Lower bound of the search; keeps the resolution from
            reaching zero for very small point clouds.
        max_resolution: Upper bound of the search; keeps the loop from growing
            the voxel grid without limit when the vertex count never reaches
            the point count.

    Raises:
        ValueError: If the point cloud is empty or not 2-D, or if the
            generated mesh has no vertices or faces.
    """
    point_cloud = torch.as_tensor(point_cloud).cpu()
    if point_cloud.dim() != 2 or point_cloud.shape[0] == 0:
        raise ValueError("point_cloud must be a non-empty 2-D array of points.")

    min_coords, _ = point_cloud.min(dim=0)
    max_coords, _ = point_cloud.max(dim=0)
    original_translation = min_coords
    extent = max_coords - min_coords
    # An axis on which every point shares the same coordinate has zero extent;
    # dividing by it would fill the normalized cloud with NaNs, so use 1 there.
    original_scale = torch.where(extent > 0, extent, torch.ones_like(extent))

    # Normalize the point cloud to the [0, 1] range on every axis.
    normalized_point_cloud = (point_cloud - original_translation) / original_scale

    n_points = len(point_cloud)
    resolution = max(min_resolution, min(start_resolution, max_resolution))
    went_under = False

    # Walk the resolution down until the mesh has fewer vertices than the
    # point cloud, then back up one step at a time until it has enough again.
    while True:
        voxel_grid = kaolin.ops.conversions.pointclouds_to_voxelgrids(
            normalized_point_cloud.unsqueeze(0), resolution=resolution
        ).cuda()

        # Convert voxel grid to triangle mesh
        triangle_mesh = kaolin.ops.conversions.voxelgrids_to_trianglemeshes(
            voxel_grid, iso_value=0.95
        )
        n_verts = len(triangle_mesh[0][0])

        if n_verts < n_points:
            if resolution >= max_resolution:
                # Give up growing: keep the densest mesh we are allowed to build.
                print("Reached max resolution", resolution,
                      "with fewer mesh vertices than points")
                break
            went_under = True
            resolution += 1
            continue

        if went_under or resolution <= min_resolution:
            print("Choosen res", resolution)
            break

        resolution -= 1

    # Extract vertices and faces of the (single) mesh in the batch.
    verts, faces = triangle_mesh
    verts = verts[0].cpu()
    faces = faces[0].cpu()

    if verts.numel() == 0 or faces.numel() == 0:
        raise ValueError("Vertices or faces are empty. Cannot create a mesh.")

    # Scale vertices from voxel units to the unit cube, then map them back to
    # the original point-cloud coordinates.
    verts = verts / resolution
    verts = verts * original_scale + original_translation

    mesh = trimesh.Trimesh(vertices=verts.cpu().numpy(), faces=faces.cpu().numpy())

    # Fix alignment issue
    # NOTE(review): these in-place offsets run after the Trimesh was built, so
    # they reach the exported mesh only if trimesh still shares memory with
    # `verts` -- confirm whether they were meant to be applied before
    # constructing the mesh.
    verts[:, 1] -= 0.04
    verts[:, 0] -= 0.01
    verts[:, 2] -= 0.017

    # Smooth the mesh
    if smooth:
        mesh = trimesh.smoothing.filter_laplacian(mesh, lamb=0.2, iterations=8,
                                    implicit_time_integration=False,
                                    volume_constraint=True,
                                    laplacian_operator=None)

    # Export to OBJ file
    mesh.export(mesh_path)
    

#### Define the model

In [9]:
import clip
import copy
import json
import kaolin as kal
import kaolin.ops.mesh
import numpy as np
import os
import random
import torch
import torch.nn as nn
import torchvision

from itertools import permutations, product
from Normalization.MeshNormalizer import MeshNormalizer
from mesh import Mesh
from pathlib import Path
from render import Renderer
from tqdm import tqdm
from torch.autograd import grad
from torchvision import transforms
from utils import device, color_mesh



class NeuralHighlighter(nn.Module):
    """MLP that maps each input row to a softmax distribution over classes.

    The network is ``depth`` hidden blocks (at least one) of
    ``Linear -> ReLU -> LayerNorm`` with ``width`` units each, followed by a
    ``Linear`` projection to ``out_dim`` and a softmax over dimension 1.

    Args:
        depth: Number of hidden blocks; values below 1 still yield one block.
        width: Number of units in every hidden block.
        out_dim: Number of output classes.
        input_dim: Size of each input row.
    """

    def __init__(self, depth=2, width=256, out_dim=2,input_dim=3):
        super().__init__()
        self.depth = depth
        self.width = width
        self.out_dim = out_dim

        def hidden_block(in_features):
            # One Linear -> ReLU -> LayerNorm stage.
            return [nn.Linear(in_features, width), nn.ReLU(), nn.LayerNorm(width)]

        # First block adapts the input size; the rest are width -> width.
        layers = hidden_block(input_dim)
        for _ in range(depth - 1):
            layers.extend(hidden_block(width))

        # Output head: class scores turned into probabilities per row.
        layers.extend([nn.Linear(width, out_dim), nn.Softmax(dim=1)])

        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the per-row class probabilities for ``x``."""
        return self.model(x)

def get_clip_model(clipmodel):
    """Load a CLIP model by name.

    Args:
        clipmodel: Model identifier handed to ``clip.load``.

    Returns:
        The ``(model, preprocess)`` pair produced by ``clip.load``.
    """
    loaded = clip.load(clipmodel)
    model, preprocess = loaded
    return model, preprocess

# ================== HELPER FUNCTIONS =============================
def save_final_results(log_dir, name, mesh, mlp, vertices, colors, render, background):
    """Render and export the final hard highlight assignment of a mesh.

    The MLP is switched to eval mode and evaluated on ``vertices`` without
    gradients. Every vertex is assigned to its most probable class; class 0
    is drawn in the highlight color (204, 255, 0), every other class in gray
    (180, 180, 180). The colored mesh is rendered from 5 views, exported as
    ``<log_dir>/<name>.ply`` and the renders are saved as
    ``<log_dir>/final_render.jpg``.

    Args:
        log_dir: Directory receiving the ``.ply`` file and the render image.
        name: File stem of the exported ``.ply``.
        mesh: Mesh object; it is colored in place via ``color_mesh`` and must
            provide ``export(path, extension=..., color=...)``.
        mlp: Model mapping ``vertices`` to per-vertex class probabilities.
        vertices: Input passed to ``mlp``.
        colors: Unused. Kept for backward compatibility; the highlight/gray
            palette is fixed inside this function.
        render: Renderer providing ``render_views``.
        background: Background forwarded to ``render.render_views``.
    """
    mlp.eval()
    with torch.no_grad():
        probs = mlp(vertices)
        max_idx = torch.argmax(probs, 1, keepdim=True)

        # Hard (one-hot) class assignment used for the renders.
        one_hot = torch.zeros(probs.shape, device=device).scatter_(1, max_idx, 1)

        highlight = torch.tensor([204, 255, 0], device=device)
        gray = torch.tensor([180, 180, 180], device=device)
        # color_mesh receives the palette scaled to [0, 1].
        palette = torch.stack((highlight / 255, gray / 255))
        color_mesh(one_hot, mesh, palette)
        rendered_images, _, _ = render.render_views(mesh, num_views=5,
                                                    show=False,
                                                    center_azim=0,
                                                    center_elev=0,
                                                    std=1,
                                                    return_views=True,
                                                    lighting=True,
                                                    background=background)

        # Per-vertex export color in 0-255 integers: highlight where class 0
        # wins, gray elsewhere.
        final_color = torch.where(max_idx == 0, highlight, gray)
        mesh.export(os.path.join(log_dir, f"{name}.ply"), extension="ply", color=final_color)
        save_renders(log_dir, 0, rendered_images, name='final_render.jpg')


def clip_loss(embedding,images,clip_model,augmentations,augmentation_number):
    """Negative mean CLIP similarity between augmented images and a text input.

    Args:
        embedding: Text input handed to ``clip_model.encode_text``.
        images: Batch of images; a fresh augmentation is drawn on every pass.
        clip_model: Object providing ``encode_text`` and ``encode_image``.
        augmentations: Callable applied to ``images`` before encoding.
        augmentation_number: Number of augmentation passes to average over.

    Returns:
        The negated cosine similarity between image and text encodings,
        averaged over the batch and over all augmentation passes.
    """
    text_features = clip_model.encode_text(embedding)

    similarity_sum = 0.0
    for _ in range(augmentation_number):
        image_features = clip_model.encode_image(augmentations(images))
        similarity_sum = similarity_sum + torch.cosine_similarity(image_features, text_features).mean()

    # Higher similarity is better, so the loss is its negation.
    return -similarity_sum / augmentation_number
    


def save_renders(dir, i, rendered_images, name=None):
    """Write rendered images to disk with ``torchvision.utils.save_image``.

    Args:
        dir: Output directory.
        i: Iteration index, used in the default file name.
        rendered_images: Images handed to ``save_image``.
        name: File name inside ``dir``; when omitted the image is written to
            ``renders/iter_<i>.jpg`` inside ``dir``.
    """
    filename = 'renders/iter_{}.jpg'.format(i) if name is None else name
    torchvision.utils.save_image(rendered_images, os.path.join(dir, filename))


#### Core loop

In [None]:

import random
seed = 1
# Constrain most sources of randomness by seeding torch (CPU and CUDA),
# Python's `random` and NumPy, and by making cuDNN deterministic.
# (some torch backward functions within CLIP are non-deterministic)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


# ---- Run configuration ----
# Side length (pixels) of the square renders; also the crop size of the
# augmentations below.
render_res = 224
# learning_rate = 0.0001
learning_rate = 0.0008
# Number of optimization steps.
n_iter = 1800 
# obj_path = '/content/Affordance_Highlighting_Project_2024/data/horse.obj'
# Number of augmentation passes averaged inside clip_loss.
n_augs = 5
output_dir = './output/'
# Alternative CLIP backbones that were tried:
# clip_version = 'ViT-L/14'
# clip_version = 'RN50x4'
# clip_version = 'RN50x16'
# clip_version = 'RN50x16'
clip_version = 'ViT-B/32'
# Index of the dataset entry whose point cloud is turned into a mesh.
object_number = 1000
object_name="vase"
# create_mesh writes the generated mesh here and Mesh reads it back.
mesh_path=f"data/{object_name}.obj"
# Number of views rendered per optimization step.
n_views = 5



# Make sure the folder used by save_renders for per-iteration images exists.
Path(os.path.join(output_dir, 'renders')).mkdir(parents=True, exist_ok=True)

# NOTE(review): objbase / extension are not used anywhere in the visible code.
objbase, extension = os.path.splitext(os.path.basename(mesh_path))

render = Renderer(dim=(render_res, render_res))


# Load dataset
dataset = load_dataset("data_bench/full_shape_train_data.pkl")


# Build a mesh from the selected point cloud and export it to mesh_path.
create_mesh(dataset[object_number]["data_info"]["coordinate"],mesh_path)

# Reload the exported mesh and run the project's MeshNormalizer on it.
mesh = Mesh(obj_path=mesh_path)
MeshNormalizer(mesh)()


# Initialize variables
# `device` is whichever binding ran last in the kernel: the model-definition
# cell imports one from `utils`, replacing the one defined in the first cell.
bg = torch.tensor((1., 1., 1.)).to(device)
log_dir = output_dir



# MLP Settings
mlp = NeuralHighlighter().to(device)
optim = torch.optim.Adam(mlp.parameters(), learning_rate)

# list of possible colors
# NOTE(review): rgb_to_color and color_to_rgb are not read in the visible code;
# only full_colors feeds the `colors` tensor below.
rgb_to_color = {(204/255, 1., 0.): "highlighter", (180/255, 180/255, 180/255): "gray"}
color_to_rgb = {"highlighter": [204/255, 1., 0.], "gray": [180/255, 180/255, 180/255]}
full_colors = [[204/255, 1., 0.], [180/255, 180/255, 180/255]]

# Row 0 is the highlight color, row 1 the gray base color.
colors = torch.tensor(full_colors).to(device)


# --- Prompt ---
# Load CLIP and tokenize the prompt; the text is encoded inside clip_loss.
# NOTE(review): `preprocess` is not used in the visible code.
clip_model,preprocess = get_clip_model(clip_version)
# print(model)
prompt = 'A 3D render of a gray vase with highlighted hat'
tokenized_text = clip.tokenize([prompt]).to(device) 

# Work on a copy of the mesh vertices as the MLP input.
vertices = copy.deepcopy(mesh.vertices)


# Loss value of every iteration, in order.
losses = []

# Per-channel normalization applied to images before CLIP encoding
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073],std=[0.26862954, 0.26130258, 0.27577711]) #from https://github.com/openai/CLIP/issues/20


# Augmentations applied to the renders on every clip_loss pass: random
# resized crop, random perspective warp (filled with 1), then normalization.
augmentations = transforms.Compose([
    transforms.RandomResizedCrop(render_res, scale=(0.5, 1.0)),
    transforms.RandomPerspective(p=0.5,distortion_scale=0.5,fill=1),
    normalize
])
 
# Optimization loop
for i in tqdm(range(n_iter)):
    optim.zero_grad()

    # predict highlight probabilities
    pred_class = mlp(vertices)

    # color and render mesh
    # (sampled_mesh is the same object as mesh, so the mesh is recolored in place)
    sampled_mesh = mesh
    color_mesh(pred_class, sampled_mesh, colors)
    rendered_images, elev, azim = render.render_views(sampled_mesh, num_views=n_views,
                                                            show=False,
                                                            center_azim=0,
                                                            center_elev=0,
                                                            std=1,
                                                            return_views=True,
                                                            lighting=True,
                                                            background=bg)

    # Calculate CLIP Loss
    loss = clip_loss(tokenized_text,rendered_images,clip_model,augmentations,augmentation_number=n_augs)
    # NOTE(review): retain_graph=True keeps the autograd graph alive after the
    # backward pass. Everything visible here is recomputed each iteration, so it
    # looks unnecessary -- confirm Mesh/Renderer do not cache graph-bearing tensors.
    loss.backward(retain_graph=True)

    optim.step()

    # update variables + record loss
    with torch.no_grad():
        losses.append(loss.item())

    # report results
    # Every 100 iterations (including iteration 0): print the mean of the most
    # recent losses, save the current renders and append a line to the log.
    if i % 100 == 0:
        print("Last 100 CLIP score: {}".format(np.mean(losses[-100:])))
        save_renders(log_dir, i, rendered_images)
        with open(os.path.join(log_dir, "training_info.txt"), "a") as f:
            f.write(f"For iteration {i}... Prompt: {prompt}, Last 100 avg CLIP score: {np.mean(losses[-100:])}, CLIP score {losses[-1]}\n")



# save results
save_final_results(log_dir, "3d-render",mesh, mlp, vertices, colors, render, bg)

# Save prompts
# Records the prompt as the *name* of an empty file in output_dir.
# NOTE(review): a prompt containing path separators or other characters that
# are invalid in file names would make this fail.
with open(os.path.join(output_dir, prompt), "w") as f:
    f.write('')