In [1]:
from google.colab import drive
drive.mount('/content/drive') #replace with drive.mount('/content/drive/', force_remount=True) if the drive has changed since last mount in order to force the remount
%cd /content/drive/MyDrive/Neural-Highlighting-of-Affordance-Regions/

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/Neural-Highlighting-of-Affordance-Regions


In [2]:
!apt-get update
!apt-get install -y xvfb ffmpeg libsm6 libxext6
!pip install git+https://github.com/openai/CLIP.git
!pip install kaolin==0.17.0 -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.5.1_cu121.html
!pip install open3d pyvirtualdisplay

Get:1 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease [1,581 B]
Get:2 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  Packages [1,199 kB]
Get:3 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]
Hit:4 http://archive.ubuntu.com/ubuntu jammy InRelease
Get:5 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,626 B]
Get:6 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [128 kB]
Hit:7 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Get:8 https://r2u.stat.illinois.edu/ubuntu jammy InRelease [6,555 B]
Hit:9 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease
Get:10 http://security.ubuntu.com/ubuntu jammy-security/main amd64 Packages [2,560 kB]
Hit:11 https://ppa.launchpadcontent.net/ubuntugis/ppa/ubuntu jammy InRelease
Get:12 https://r2u.stat.illinois.edu/ubuntu jammy/main amd64 Packages [2,639 kB]
Get:13 http://archive.ubuntu.com/ubu

**PointCLIP**: config and dependencies


In [3]:
#stuff for importing PointCLIP
%cd PointCLIP
!pip install -r requirements.txt
%cd Dassl3D/
!python setup.py develop
%cd ..
%cd ..

/content/drive/MyDrive/Neural-Highlighting-of-Affordance-Regions/PointCLIP
Collecting flake8==3.7.9 (from -r requirements.txt (line 1))
  Downloading flake8-3.7.9-py2.py3-none-any.whl.metadata (3.6 kB)
Collecting yapf==0.29.0 (from -r requirements.txt (line 2))
  Downloading yapf-0.29.0-py2.py3-none-any.whl.metadata (30 kB)
Collecting isort==4.3.21 (from -r requirements.txt (line 3))
  Downloading isort-4.3.21-py2.py3-none-any.whl.metadata (19 kB)
Collecting yacs (from -r requirements.txt (line 4))
  Downloading yacs-0.1.8-py3-none-any.whl.metadata (639 bytes)
Collecting tb-nightly (from -r requirements.txt (line 6))
  Downloading tb_nightly-2.19.0a20250113-py3-none-any.whl.metadata (1.8 kB)
Collecting entrypoints<0.4.0,>=0.3.0 (from flake8==3.7.9->-r requirements.txt (line 1))
  Downloading entrypoints-0.3-py2.py3-none-any.whl.metadata (1.4 kB)
Collecting pyflakes<2.2.0,>=2.1.0 (from flake8==3.7.9->-r requirements.txt (line 1))
  Downloading pyflakes-2.1.1-py2.py3-none-any.whl.metadat

Here we import from PointCLIP the customized PointCLIP_ZS class, which will be used to perform model inference.

In [2]:
import sys
sys.path.append('/content/drive/MyDrive/Neural-Highlighting-of-Affordance-Regions/PointCLIP/trainers')
from zeroshot import PointCLIP_ZS

#Available clip models
import clip
clip.available_models()


['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

In [3]:
import clip
import copy
import json
import kaolin as kal
import kaolin.ops.mesh
import numpy as np
import os
import random
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import torchvision.transforms.functional as F

from itertools import permutations, product
from Normalization import MeshNormalizer
from render import Renderer
from mesh import Mesh
from pathlib import Path
from tqdm import tqdm
from torch.autograd import grad
from torchvision import transforms
from utils import device, color_mesh
import open3d as o3d
from pyvirtualdisplay import Display

width = 256
depth = 8       #default is 4
out_dim = 2
input_dim = 3
n_augs = 1      #default is 1

class NeuralHighlighter(nn.Module):
    def __init__(self):
        super(NeuralHighlighter, self).__init__()
        input_size = 3 #Dimension of the vertex
        output_size = 2 #Dimension of the output
                        #for the standard highlighter task there are only 2 classes: target region and not target region.
                        #we use the element of the output vector corresponding to the probability of belonging to the target
                        #region as the highlight probability described in the main paper.
        layers = []

        #See Appendix B (page 13)
        #first linear layer followed by ReLU and LayerNorm
        layers.append(nn.Linear(input_dim, width))
        layers.append(nn.ReLU())
        layers.append(nn.LayerNorm([width]))
        #other [depth] linear layers followed by ReLU and LayerNorm
        # -> changing the depth hyperparameter results in a deeper/shallower net
        # -> total depth (in terms of modules[Linear+ReLU+LayerNorm]) = [depth] + 1
        for i in range(depth):
            layers.append(nn.Linear(width, width))
            layers.append(nn.ReLU())
            layers.append(nn.LayerNorm([width]))
        #last linear layer followed by softmax in order to output probability-like values
        layers.append(nn.Linear(width, out_dim))
        layers.append(nn.Softmax(dim=1))

        self.mlp = nn.ModuleList(layers)
        self.model = self.mlp
        print(self.mlp)

    def forward(self, x):
        for layer in self.model:
            x = layer(x)
        return x

def get_clip_model(clipmodel):
    model, preprocess = clip.load(clipmodel, device=device)
    return model, preprocess

# ================== HELPER FUNCTIONS =============================
def save_final_results(log_dir, name, mesh, mlp, vertices, colors, render, background):
    mlp.eval()
    with torch.no_grad():
        probs = mlp(vertices)
        max_idx = torch.argmax(probs, 1, keepdim=True)
        # for renders
        one_hot = torch.zeros(probs.shape).to(device)
        one_hot = one_hot.scatter_(1, max_idx, 1)
        sampled_mesh = mesh

        highlight = torch.tensor([204, 255, 0]).to(device)
        gray = torch.tensor([180, 180, 180]).to(device)
        colors = torch.stack((highlight/255, gray/255)).to(device)
        color_mesh(one_hot, sampled_mesh, colors)
        rendered_images, _, _ = render.render_views(sampled_mesh, num_views=5,
                                                                        show=False,
                                                                        center_azim=0,
                                                                        center_elev=0,
                                                                        std=1,
                                                                        return_views=True,
                                                                        lighting=True,
                                                                        background=background)
        # for mesh
        final_color = torch.zeros(vertices.shape[0], 3).to(device)
        final_color = torch.where(max_idx==0, highlight, gray)
        mesh.export(os.path.join(log_dir, f"{name}.ply"), extension="ply", color=final_color)
        save_renders(log_dir, 0, rendered_images, name='final_render.jpg')

#TODO: fix the generation of the point cloud subsequently
#      now the point cloud generation is possible only by executing the PC_rendering.ipynb

def save_point_cloud_results(vertices, log_dir, name, mlp):
    '''
        #now i load the highlighted mesh and transpose it back to the point cloud
        display = Display(visible=0, size=(1400, 900))
        display.start()
        mesh_o3d = o3d.io.read_triangle_mesh(os.path.join(log_dir, f"{name}.ply"))

        if not mesh_o3d.has_vertex_normals():
          mesh_o3d.compute_vertex_normals()

        point_cloud = mesh_o3d.sample_points_poisson_disk(number_of_points=5000)

        output_ply_file = os.path.join(log_dir, f"{name}_point_cloud.ply")
        o3d.io.write_point_cloud(output_ply_file, point_cloud)
    '''
    mlp.eval()
    with torch.no_grad():
      probs = mlp(vertices)

      highlight = torch.tensor([204, 255, 0]).to(device)
      gray = torch.tensor([180, 180, 180]).to(device)
      colors = torch.stack((highlight/255, gray/255)).to(device)

      point_colors = assign_colors(probs, colors, device)
      point_cloud = o3d.geometry.PointCloud()
      point_cloud.points = o3d.utility.Vector3dVector(vertices)
      point_cloud.colors = o3d.utility.Vector3dVector(colors)

      o3d.io.write_point_cloud(f"{name}_colored_point_cloud.ply", point_cloud)

      '''
        width_final_render, height_final_render = 1400, 900
        render_final_pc = o3d.visualization.rendering.OffscreenRenderer(width_final_render, height_final_render)
        material = o3d.visualization.rendering.MaterialRecord()
        material.shader = "defaultUnlit"
        render_final_pc.scene.add_geometry("point_cloud", point_cloud, material)

        zoom_out_factor = 0.5
        bounding_box = point_cloud.get_axis_aligned_bounding_box()
        center = bounding_box.get_center()
        extent = bounding_box.get_extent()
        render_final_pc.scene.camera.look_at(center, center + [0, 0, 1], [0, 1, 0])
        render_final_pc.scene.camera.set_projection(60 / zoom_out_factor, width_final_render / height_final_render, 0.1, 100.0,
                                      o3d.visualization.rendering.Camera.FovType.Horizontal)

        pc_img = render_final_pc.render_to_image()
        output_file = os.path.join(log_dir, f"{name}_final_render.jpg")
        o3d.io.write_image(output_file, pc_img)

        display.stop()
      '''

def clip_loss(rendered_images, encoded_text, clip_transform, augment_transform, clip_model):
    if n_augs == 0:
        clip_image = clip_transform(rendered_images)
        encoded_renders = clip_model.encode_image(clip_image)
        encoded_renders = encoded_renders / encoded_renders.norm(dim=1, keepdim=True)
        if encoded_text.shape[0] > 1:
            loss = torch.cosine_similarity(torch.mean(encoded_renders, dim=0),
                                                torch.mean(encoded_text, dim=0), dim=0)
        else:
            loss = torch.cosine_similarity(torch.mean(encoded_renders, dim=0, keepdim=True),
                                                encoded_text)

    elif n_augs > 0:
        loss = 0.0
        for _ in range(n_augs):
            augmented_image = augment_transform(rendered_images)
            encoded_renders = clip_model.encode_image(augmented_image)
            if encoded_text.shape[0] > 1:
                loss -= torch.cosine_similarity(torch.mean(encoded_renders, dim=0),
                                                    torch.mean(encoded_text, dim=0), dim=0)
            else:
                loss -= torch.cosine_similarity(torch.mean(encoded_renders, dim=0, keepdim=True),
                                                    encoded_text)
    return loss

def pointClip_loss(encoded_text, point_cloud, point_cloud_colors, pointClipZS):

    #If you want to use pointclip you should change the pipeline for computing the loss,
    #instead of differentiable rendering + clip you pass the colored pointcloud straight to pointclip
    #and compute the similarity with the prompt from the output of pointclip.
    #Basically you are changing the pipeline for computing the loss but the input output of your model are the same

    #HP n.1
    # >>> so, does it mean that we have to compute the image encoding with PointCLIP, leaving the rest unchanged?

    #TODO

    #HP n.2
    # >>> we pass the colored point cloud to PointCLIP and we expect as output the value of the loss
    # (dot product between text encoding and renders)
    # problem: the output is not a single value...!? it depends on the textual prompt...
    # note that the textual prompt is now *hardcoded* into PointCLIP/trainers/zeroshot.py
    # also the model channel


    logits = pointClipZS.model_inference(point_cloud, point_cloud_colors)
    #logits_img = pointClipZS.model_inference_img(point_cloud, point_cloud_colors)

    #print(logits.shape) -> depends on the text prompt, we obtain a list of values.
    #print(logits)       -> scores (%) ? here we compute the mean.
    loss = torch.sigmoid(torch.tensor(logits, device=device, requires_grad = True))
    #loss = -logits

    return loss

def save_renders(dir, i, rendered_images, name=None):
    if name is not None:
        torchvision.utils.save_image(rendered_images, os.path.join(dir, name))
    else:
        torchvision.utils.save_image(rendered_images, os.path.join(dir, 'renders/iter_{}.jpg'.format(i)))


Warp 1.5.1 initialized:
   CUDA Toolkit 12.6, Driver 12.2
   Devices:
     "cpu"      : "x86_64"
     "cuda:0"   : "Tesla T4" (15 GiB, sm_75, mempool enabled)
   Kernel cache:
     /root/.cache/warp/1.5.1


In [4]:
from torch.optim.lr_scheduler import MultiStepLR
from utils import assign_colors

# Constrain most sources of randomness
# (some torch backwards functions within CLIP are non-determinstic)
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

render_res = 224
learning_rate = 0.01
n_iter = 2500
res = 224
obj_path = 'data/dog.obj'
#output_dir = './output/'
clip_model_name = 'ViT-B/16'

device = "cuda" if torch.cuda.is_available() else "cpu"

#Path(os.path.join(output_dir, 'renders')).mkdir(parents=True, exist_ok=True)

objbase, extension = os.path.splitext(os.path.basename(obj_path))

render = Renderer(dim=(render_res, render_res))

#---------- WARNING---------
#---------- now i read the mesh from the obj of the reconstructed mesh
#mesh = Mesh(obj_path)
#MeshNormalizer(mesh)()
#---------- WARNING---------

width_render, height_render = 1400, 900
zoom_out_factor = 0.4

#------------ MESH TO POINT CLOUD INIT---------------
from pyvirtualdisplay import Display
display = Display(visible=0, size=(width_render, height_render))
display.start()

#------- here we retrieve first the point cloud from the mesh (only for test purpose)
mesh_o3d = o3d.io.read_triangle_mesh(obj_path)
mesh_o3d.compute_vertex_normals()
pcd = mesh_o3d.sample_points_poisson_disk(10000)
pcd_points = np.asarray(pcd.points)
print("pcd_points shape")
print(pcd_points.shape)
#-------

#-------here we retrieve the point cloud from the ply file (extracted from AffordanceNet pkls)
#pcd = o3d.io.read_point_cloud('./pointClouds/validation/bottle/Bottle_point_cloud_0.ply')

render_pc = o3d.visualization.rendering.OffscreenRenderer(width_render, height_render)
material = o3d.visualization.rendering.MaterialRecord()
material.shader = "defaultUnlit"
render_pc.scene.add_geometry("point_cloud", pcd, material)

bounding_box = pcd.get_axis_aligned_bounding_box()
center = bounding_box.get_center()
extent = bounding_box.get_extent()
render_pc.scene.camera.look_at(center, center + [0, 0, 1], [0, 1, 0])
render_pc.scene.camera.set_projection(60 / zoom_out_factor, width_render / height_render, 0.1, 100.0,
                                      o3d.visualization.rendering.Camera.FovType.Horizontal)

img = render_pc.render_to_image()
output_file = "point_cloud_render.jpg"
o3d.io.write_image(output_file, img)

#------------ MESH TO POINT CLOUD END---------------

#then we approximate a mesh so we can still use the previously defined helper functions
#also the loss minimization should converge better

#------------ POINT CLOUD TO MESH INIT---------------

#--- TO USE WITH AFFORDANCENET PCs
#pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.5, max_nn=25))
#radii = [0.05, 0.1, 0.2]

#rec_mesh, densities = o3d.geometry.TriangleMesh.create_from_point_cloud_poisson(pcd, depth=9, scale=1.1, linear_fit=False)
#rec_mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(pcd, o3d.utility.DoubleVector(radii))

#---

radii = [0.005, 0.01, 0.02, 0.04]
rec_mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(pcd, o3d.utility.DoubleVector(radii))

# Optionally, check if the reconstructed mesh is valid
if rec_mesh.is_empty():
    print("Mesh reconstruction failed.")
else:
    print("Mesh reconstruction successful!")

# Render the reconstructed mesh
render_mesh = o3d.visualization.rendering.OffscreenRenderer(width_render, height_render)
material_mesh = o3d.visualization.rendering.MaterialRecord()
material_mesh.shader = "defaultUnlit"
render_mesh.scene.add_geometry("reconstructed_mesh", rec_mesh, material_mesh)

# Set zoom-out factor and camera for reconstructed mesh
bounding_box = pcd.get_axis_aligned_bounding_box()
center = bounding_box.get_center()
render_mesh.scene.camera.look_at(center, center + [0, 0, 1], [0, 1, 0])
render_mesh.scene.camera.set_projection(60 / zoom_out_factor, width_render / height_render, 0.1, 100.0,
                                        o3d.visualization.rendering.Camera.FovType.Horizontal)

# Render and save the reconstructed mesh image
img_mesh = render_mesh.render_to_image()
output_file_mesh = "reconstructed_mesh_render.jpg"
o3d.io.write_image(output_file_mesh, img_mesh)

#Export the reconstructed mesh to an obj file (allows to reuse helper functions)
output_mesh_file = "reconstructed_mesh.obj"
o3d.io.write_triangle_mesh(output_mesh_file, rec_mesh)

#Import the mesh from the exported obj file
mesh = Mesh(output_mesh_file)
MeshNormalizer(mesh)()

display.stop()
#------------ POINT CLOUD TO MESH END---------------

# Initialize variables
background = torch.tensor((1., 1., 1.)).to(device)

#log_dir = output_dir

# CLIP and Augmentation Transforms
clip_normalizer = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

clip_transform = transforms.Compose([
        transforms.Resize((res, res)),
        clip_normalizer
])

augment_transform = transforms.Compose([
        transforms.RandomResizedCrop(res, scale=(1, 1)),
        transforms.RandomPerspective(fill=1, p=0.8, distortion_scale=0.5),
        clip_normalizer
])

# MLP Settings
mlp = NeuralHighlighter().to(device)
optim = torch.optim.Adam(mlp.parameters(), learning_rate)

#introducing learning rate decay
#with the prompt horse/saddle the loss plateaus
#scheduler = StepLR(optim, step_size=300, gamma=0.1)

#scheduler = MultiStepLR(optim, milestones=[300, 1800], gamma=0.1)  # Decay at epoch 300 and 1800


# list of possible colors
rgb_to_color = {(204/255, 1., 0.): "highlighter", (180/255, 180/255, 180/255): "gray"}
color_to_rgb = {"highlighter": [204/255, 1., 0.], "gray": [180/255, 180/255, 180/255]}
full_colors = [[204/255, 1., 0.], [180/255, 180/255, 180/255]]
colors = torch.tensor(full_colors).to(device)

name = 'dogPointCLIP_d_{}_augs_{}'.format(depth, n_augs)
depth_maps_views = 6

# --- Prompt ---
# encode prompt with CLIP
clip_model, preprocess = get_clip_model(clip_model_name)
prompts = ['A 3D picture of a gray and white dog with highlighted hat'] #WARNING: with PointCLIP the prompt is hardcoded


for i, prompt in enumerate(prompts):

  pointClipZS = PointCLIP_ZS()
  output_dir = './output_{}_{}/'.format(name, i)
  Path(os.path.join(output_dir, 'renders')).mkdir(parents=True, exist_ok=True)
  log_dir = output_dir

  #here we compute the text encoding only once
  #if we put it inside the loss, we repeat n_iter times the same computation
  with torch.no_grad():
    text_input = clip.tokenize([prompt]).to(device)
    print(text_input.shape)
    encoded_text = clip_model.encode_text(text_input)
    encoded_text = encoded_text / encoded_text.norm(dim=1, keepdim=True)

  vertices = copy.deepcopy(mesh.vertices)
  #vertices = torch.tensor(np.asarray(rec_mesh.vertices), dtype=torch.float32, device=device) # Convert vertices to a PyTorch tensor
  point_cloud_points = torch.tensor(pcd_points, dtype=torch.float32, device=device) # Convert point cloud points to a PyTorch tensor
  n_views = 5

  losses = []

  # Optimization loop
  for i in tqdm(range(n_iter)):
    optim.zero_grad()

    # predict highlight probabilities
    pred_class = mlp(vertices)
    #print("pred_class shape")
    #print(pred_class.shape)
    print(pred_class)

    #point cloud coloring
    #point_colors = assign_colors(pred_class, colors, device)

    # color and render mesh
    sampled_mesh = mesh
    color_mesh(pred_class, sampled_mesh, colors)       # EDIT: utils.py/color_mesh edited in order to return colors used later for mesh export


    max_idx = torch.argmax(pred_class, 1, keepdim=True)
    highlight = torch.tensor([204, 255, 0]).to(device)
    gray = torch.tensor([180, 180, 180]).to(device)
    final_color = torch.zeros(vertices.shape[0], 3).to(device)
    final_color = torch.where(max_idx==0, highlight, gray)
    mesh.export(os.path.join(log_dir, f"{name}.ply"), extension="ply", color=final_color)

    mesh_o3d = o3d.io.read_triangle_mesh(os.path.join(log_dir, f"{name}.ply"))
    mesh_o3d.compute_vertex_normals()
    pcd_train = mesh_o3d.sample_points_poisson_disk(5000)

    points_train = torch.tensor(np.asarray(pcd_train.points), dtype=torch.float32).unsqueeze(0).cuda()  # Shape: [1, N, 3]
    colors_train = torch.tensor(np.asarray(pcd_train.colors), dtype=torch.float32).unsqueeze(0).cuda()  # Shape: [1, N, 3]

    '''
    mesh_dir = os.path.join(log_dir, 'meshes')
    os.makedirs(mesh_dir, exist_ok=True)
    mesh_path = os.path.join(mesh_dir, f"mesh_{i}.ply")

    #here I take the colored mesh as input and I export it to a ply file
    mesh.export(mesh_path, extension="ply", color=colors)

    #then I obtain the point cloud from the exported mesh (check if we can lower the number of sampled points)
    mesh_o3d = o3d.io.read_triangle_mesh(mesh_path)
    mesh_o3d.compute_vertex_normals()
    pcd = mesh_o3d.sample_points_poisson_disk(10000)

    #----- now we generate the 2D depth maps which will be the inputs to the new loss function
    #first we extract points and colors
    points = torch.tensor(np.asarray(pcd.points), dtype=torch.float32).unsqueeze(0).cuda()  # Shape: [1, N, 3]
    colors = torch.tensor(np.asarray(pcd.colors), dtype=torch.float32).unsqueeze(0).cuda()  # Shape: [1, N, 3]

    #depth_maps = pc_views.get_img(points)
    color_tensors = []

    for view_idx in range(depth_maps_views):
      depth_map, color_tensor = points2depth_with_color(points, colors, render_res, render_res)
      color_tensors.append(color_tensor)

    #depth_map_tensor = torch.stack(depth_maps)  # [num_views, B, RESOLUTION, RESOLUTION]  -> gray scale maps
    color_map_tensor = torch.stack(color_tensors)  # [num_views, B, RESOLUTION, RESOLUTION, 3] -> rgb maps

    color_tensor_final = color_map_tensor.squeeze(1) #collapse batch dimension (=1) then make of size [n_views, C, H, W]

    #color_tensor_final = color_tensor_final.permute(0, 2, 3, 1) #[n_views, H, W, C]
    #print("Color tensor final permuted shape")
    #print(color_tensor_final.shape)

    '''

    '''
    rendered_images, elev, azim = render.render_views(sampled_mesh, num_views=n_views,
                                                            show=False,
                                                            center_azim=0,
                                                            center_elev=0,
                                                            std=1,
                                                            return_views=True,
                                                             lighting=True,
                                                            background=background)


    # Calculate CLIP Loss
    loss = clip_loss(rendered_images, encoded_text, clip_transform, augment_transform, clip_model)
    '''

    #Calculate PointCLIP Loss
    loss = pointClip_loss(encoded_text, points_train, colors_train, pointClipZS)

    loss.backward(retain_graph=True)

    optim.step()

    #LR decay
    #scheduler.step()

    # update variables + record loss
    with torch.no_grad():
        losses.append(loss.item())

    # report results
    if i % 100 == 0:
        print("Last 100 CLIP score: {}".format(np.mean(losses[-100:])))
        #save_renders(log_dir, i, rendered_images)
        #save_renders(log_dir, i, color_tensor_final)

        with open(os.path.join(log_dir, "training_info.txt"), "a") as f:
            f.write(f"For iteration {i}... Prompt: {prompt}, Last 100 avg CLIP score: {np.mean(losses[-100:])}, CLIP score {losses[-1]}\n")


  # save results
  #save_final_results(log_dir, name, mesh, mlp, vertices, colors, render, background)

  # save point cloud results
  save_point_cloud_results(vertices, log_dir, name, mlp)



  # Save prompts
  with open(os.path.join(output_dir, 'prompt.txt'), "w") as f:
    f.write(prompt)
    f.write("\n")
    f.write("initial learning rate:")
    f.write(str(learning_rate))
    f.write("\n")
    f.write("n_iter:")
    f.write(str(n_iter))
    f.write("\n")
    f.write("n_augs:")
    f.write(str(n_augs))
    f.write("\n")
    f.write("n_views:")
    f.write(str(n_views))
    f.write("\n")
    f.write("clip_model:")
    f.write(clip_model_name)
    f.write("\n")
    f.write("depth:")
    f.write(str(depth))


pcd_points shape
(10000, 3)
[Open3D INFO] EGL headless mode enabled.
Mesh reconstruction successful!
ModuleList(
  (0): Linear(in_features=3, out_features=256, bias=True)
  (1): ReLU()
  (2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (3): Linear(in_features=256, out_features=256, bias=True)
  (4): ReLU()
  (5): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (6): Linear(in_features=256, out_features=256, bias=True)
  (7): ReLU()
  (8): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (9): Linear(in_features=256, out_features=256, bias=True)
  (10): ReLU()
  (11): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (12): Linear(in_features=256, out_features=256, bias=True)
  (13): ReLU()
  (14): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (15): Linear(in_features=256, out_features=256, bias=True)
  (16): ReLU()
  (17): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (18): Linear(in_features=256, out_features=256, bias=True)
  (19): Re

  0%|          | 0/2500 [00:00<?, ?it/s]

tensor([[0.3683, 0.6317],
        [0.1028, 0.8972],
        [0.3529, 0.6471],
        ...,
        [0.2518, 0.7482],
        [0.2408, 0.7592],
        [0.5064, 0.4936]], device='cuda:0', grad_fn=<SoftmaxBackward0>)


  img = torch.nn.functional.upsample(img, size=(224, 224), mode='bilinear', align_corners=True)
  0%|          | 1/2500 [00:03<2:09:31,  3.11s/it]

torch.Size([6, 128, 128, 3])
tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         ...,

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0

  0%|          | 2/2500 [00:06<2:04:55,  3.00s/it]

torch.Size([6, 128, 128, 3])
tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         ...,

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0

  0%|          | 3/2500 [00:08<2:01:55,  2.93s/it]

torch.Size([6, 128, 128, 3])
tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         ...,

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0

  0%|          | 4/2500 [00:12<2:06:10,  3.03s/it]

torch.Size([6, 128, 128, 3])
tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         ...,

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0

  0%|          | 5/2500 [00:14<2:03:17,  2.96s/it]

torch.Size([6, 128, 128, 3])
tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         ...,

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0

  0%|          | 6/2500 [00:17<2:02:47,  2.95s/it]

torch.Size([6, 128, 128, 3])
tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         ...,

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0

  0%|          | 7/2500 [00:20<2:01:24,  2.92s/it]

torch.Size([6, 128, 128, 3])
tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         ...,

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0

  0%|          | 8/2500 [00:23<2:00:26,  2.90s/it]

torch.Size([6, 128, 128, 3])
tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         ...,

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0

  0%|          | 9/2500 [00:26<1:59:37,  2.88s/it]

torch.Size([6, 128, 128, 3])
tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         ...,

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0

  0%|          | 10/2500 [00:29<1:59:54,  2.89s/it]

torch.Size([6, 128, 128, 3])
tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         ...,

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0

  0%|          | 11/2500 [00:32<1:59:32,  2.88s/it]

torch.Size([6, 128, 128, 3])
tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         ...,

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0

  0%|          | 12/2500 [00:35<1:59:21,  2.88s/it]

torch.Size([6, 128, 128, 3])
tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         ...,

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0

  1%|          | 13/2500 [00:37<1:58:58,  2.87s/it]

torch.Size([6, 128, 128, 3])
tensor([[[[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         ...,

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.]],

         [[0., 0., 0.],
          [0., 0., 0.],
          [0., 0., 0.],
          ...,
          [0., 0., 0.],
          [0., 0., 0.],
          [0., 0

  1%|          | 13/2500 [00:39<2:06:11,  3.04s/it]


KeyboardInterrupt: 

In [None]:
torch.cuda.empty_cache()