<a href="https://colab.research.google.com/github/LeograndeCode/Neural-Highlighting-of-Affordance-Regions/blob/Parte-3/Notebook3D_AffordanceNet_v5.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
from google.colab import drive
drive.mount('/content/drive') #replace with drive.mount('/content/drive/', force_remount=True) if the drive has changed since last mount in order to force the remount
%cd /content/drive/MyDrive/Neural-Highlighting-of-Affordance-Regions/


Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
/content/drive/MyDrive/Neural-Highlighting-of-Affordance-Regions


In [2]:
!apt-get update
!apt-get install -y xvfb ffmpeg libsm6 libxext6
!pip install git+https://github.com/openai/CLIP.git
!pip install kaolin==0.17.0 -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.5.1_cu121.html
!pip install open3d pyvirtualdisplay

0% [Working]            Get:1 https://cloud.r-project.org/bin/linux/ubuntu jammy-cran40/ InRelease [3,626 B]
Hit:2 https://developer.download.nvidia.com/compute/cuda/repos/ubuntu2204/x86_64  InRelease
Get:3 http://security.ubuntu.com/ubuntu jammy-security InRelease [129 kB]
Hit:4 http://archive.ubuntu.com/ubuntu jammy InRelease
Hit:5 https://ppa.launchpadcontent.net/deadsnakes/ppa/ubuntu jammy InRelease
Get:6 http://archive.ubuntu.com/ubuntu jammy-updates InRelease [128 kB]
Hit:7 https://ppa.launchpadcontent.net/graphics-drivers/ppa/ubuntu jammy InRelease
Hit:8 https://ppa.launchpadcontent.net/ubuntugis/ppa/ubuntu jammy InRelease
Hit:9 https://r2u.stat.illinois.edu/ubuntu jammy InRelease
Hit:10 http://archive.ubuntu.com/ubuntu jammy-backports InRelease
Fetched 261 kB in 2s (108 kB/s)
Reading package lists... Done
W: Skipping acquire of configured file 'main/source/Sources' as repository 'https://r2u.stat.illinois.edu/ubuntu jammy InRelease' does not seem to provide it (sources.list 

###AffordanceNet Class

AffordanceNet Class modified with capability of choosing only some objects, in our case household objects with hand-object affordances



In [3]:
import os
from os.path import join as opj
import numpy as np
from torch.utils.data import Dataset
import h5py
import json
import pickle as pkl

def pc_normalize(pc):
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
    pc = pc / m
    return pc, centroid, m


class AffordNetDataset(Dataset):
    def __init__(self, data_dir):
        super().__init__()
        self.data_dir = data_dir
        self.semantic_class = 'Knife'
        self.label_name = 'grasp'
        self.load_data()
        return

    def load_data(self):
        self.all_data = []

        # Open the dataset file
        with open('full_shape_train_data.pkl', 'rb') as f:
            data = pkl.load(f)


        # Check if the data is a list
        if isinstance(data, list):
            # Filter the point clouds based on the 'semantic class' attribute
            point_clouds = [
                pc for pc in data
                if pc.get('semantic class') == self.semantic_class
            ]
        print(f"Number of point clouds with 'semantic class' equal to {self.semantic_class}: {len(point_clouds)}")

        self.all_data = point_clouds

    def __getitem__(self, index):
        data_dict = self.all_data[index]

        coordinates = np.array(data_dict['full_shape']['coordinate'])
        label = np.array(data_dict['full_shape']['label'][self.label_name])

        data, _, _ = pc_normalize(coordinates)

        return data, label

    def __len__(self):
        return len(self.all_data)


### Model

In [4]:
import clip
import copy
import json
import kaolin as kal
import kaolin.ops.mesh
import numpy as np
import os
import random
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import torchvision.transforms.functional as F

from itertools import permutations, product
from Normalization import MeshNormalizer
from render import Renderer
from mesh import Mesh
from pathlib import Path
from tqdm import tqdm
from torch.autograd import grad
from torchvision import transforms
from utils import device, color_mesh
import open3d as o3d
from pyvirtualdisplay import Display

width = 256
depth = 6      #default is 4
out_dim = 2
input_dim = 3
n_augs = 1      #default is 1

class NeuralHighlighter(nn.Module):
    def __init__(self):
        super(NeuralHighlighter, self).__init__()
        input_size = 3 #Dimension of the vertex
        output_size = 2 #Dimension of the output
                        #for the standard highlighter task there are only 2 classes: target region and not target region.
                        #we use the element of the output vector corresponding to the probability of belonging to the target
                        #region as the highlight probability described in the main paper.
        layers = []

        #See Appendix B (page 13)
        #first linear layer followed by ReLU and LayerNorm
        layers.append(nn.Linear(input_dim, width))
        layers.append(nn.ReLU())
        layers.append(nn.LayerNorm([width]))
        #other [depth] linear layers followed by ReLU and LayerNorm
        # -> changing the depth hyperparameter results in a deeper/shallower net
        # -> total depth (in terms of modules[Linear+ReLU+LayerNorm]) = [depth] + 1
        for i in range(depth):
            layers.append(nn.Linear(width, width))
            layers.append(nn.ReLU())
            layers.append(nn.LayerNorm([width]))
        #last linear layer followed by softmax in order to output probability-like values
        layers.append(nn.Linear(width, out_dim))
        layers.append(nn.Softmax(dim=1))

        self.mlp = nn.ModuleList(layers)
        self.model = self.mlp
        print(self.mlp)

    def forward(self, x):
        for layer in self.model:
            x = layer(x)
        return x

def get_clip_model(clipmodel):
    model, preprocess = clip.load(clipmodel, device=device)
    return model, preprocess

# ================== HELPER FUNCTIONS =============================
def save_final_results(log_dir, name, mesh, mlp, vertices, colors, render, background):
    mlp.eval()
    with torch.no_grad():
        probs = mlp(vertices)
        max_idx = torch.argmax(probs, 1, keepdim=True)
        # for renders
        one_hot = torch.zeros(probs.shape).to(device)
        one_hot = one_hot.scatter_(1, max_idx, 1)
        sampled_mesh = mesh

        highlight = torch.tensor([204, 255, 0]).to(device)
        gray = torch.tensor([180, 180, 180]).to(device)
        colors = torch.stack((highlight/255, gray/255)).to(device)
        color_mesh(one_hot, sampled_mesh, colors)
        rendered_images, _, _ = render.render_views(sampled_mesh, num_views=5,
                                                                        show=False,
                                                                        center_azim=0,
                                                                        center_elev=0,
                                                                        std=1,
                                                                        return_views=True,
                                                                        lighting=True,
                                                                        background=background)
        # for mesh
        final_color = torch.zeros(vertices.shape[0], 3).to(device)
        final_color = torch.where(max_idx==0, highlight, gray)
        mesh.export(os.path.join(log_dir, f"{name}.ply"), extension="ply", color=final_color)
        save_renders(log_dir, 0, rendered_images, name='final_render.jpg')

#TODO: fix the generation of the point cloud subsequently
#      now the point cloud generation is possible only by executing the PC_rendering.ipynb

def save_point_cloud_results(log_dir, name):
        #now i load the highlighted mesh and transpose it back to the point cloud
        display = Display(visible=0, size=(1400, 900))
        display.start()
        mesh_o3d = o3d.io.read_triangle_mesh(os.path.join(log_dir, f"{name}.ply"))

        if not mesh_o3d.has_vertex_normals():
          mesh_o3d.compute_vertex_normals()

        point_cloud = mesh_o3d.sample_points_poisson_disk(number_of_points=2048)

        width_final_render, height_final_render = 1400, 900
        render_final_pc = o3d.visualization.rendering.OffscreenRenderer(width_final_render, height_final_render)
        material = o3d.visualization.rendering.MaterialRecord()
        material.shader = "defaultUnlit"
        render_final_pc.scene.add_geometry("point_cloud", point_cloud, material)

        zoom_out_factor = 0.5
        bounding_box = point_cloud.get_axis_aligned_bounding_box()
        center = bounding_box.get_center()
        extent = bounding_box.get_extent()
        render_final_pc.scene.camera.look_at(center, center + [0, 0, 1], [0, 1, 0])
        render_final_pc.scene.camera.set_projection(60 / zoom_out_factor, width_final_render / height_final_render, 0.1, 100.0,
                                      o3d.visualization.rendering.Camera.FovType.Horizontal)

        pc_img = render_final_pc.render_to_image()
        output_file = os.path.join(log_dir, f"{name}_final_render.jpg")
        o3d.io.write_image(output_file, pc_img)
        display.stop()


def clip_loss(rendered_images, encoded_text, clip_transform, augment_transform, clip_model):
    if n_augs == 0:
        clip_image = clip_transform(rendered_images)
        encoded_renders = clip_model.encode_image(clip_image)
        encoded_renders = encoded_renders / encoded_renders.norm(dim=1, keepdim=True)
        if encoded_text.shape[0] > 1:
            loss = torch.cosine_similarity(torch.mean(encoded_renders, dim=0),
                                                torch.mean(encoded_text, dim=0), dim=0)
        else:
            loss = torch.cosine_similarity(torch.mean(encoded_renders, dim=0, keepdim=True),
                                                encoded_text)

    elif n_augs > 0:
        loss = 1.0 #original 0.0
        for _ in range(n_augs):
            augmented_image = augment_transform(rendered_images)
            encoded_renders = clip_model.encode_image(augmented_image)
            if encoded_text.shape[0] > 1:
                loss -= torch.cosine_similarity(torch.mean(encoded_renders, dim=0),
                                                    torch.mean(encoded_text, dim=0), dim=0)
            else:
                loss -= torch.cosine_similarity(torch.mean(encoded_renders, dim=0, keepdim=True),
                                                    encoded_text)
    return loss

def save_renders(dir, i, rendered_images, name=None):
    if name is not None:
        torchvision.utils.save_image(rendered_images, os.path.join(dir, name))
    else:
        torchvision.utils.save_image(rendered_images, os.path.join(dir, 'renders/iter_{}.jpg'.format(i)))

Warp 1.5.1 initialized:
   CUDA Toolkit 12.6, Driver 12.2
   Devices:
     "cpu"      : "x86_64"
     "cuda:0"   : "Tesla T4" (15 GiB, sm_75, mempool enabled)
   Kernel cache:
     /root/.cache/warp/1.5.1


#Dataset

In [5]:
import open3d as o3d
import numpy as np
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset, DataLoader
import torch  # If dataset uses PyTorch tensors

# Assuming AffordNetDataset is correctly initialized
data_dir = "."
dataset = AffordNetDataset(data_dir=data_dir)
print(len(dataset))
# Split dataset into test and validation indices
indexes = list(range(len(dataset)))
val_indexes, test_indexes = train_test_split(indexes, test_size=0.9, shuffle=True)

# Create Subsets
val_dataset = Subset(dataset, val_indexes)
test_dataset = Subset(dataset, test_indexes)

# DataLoaders
batch_size = 1  # Load one point cloud at a time for visualization
val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# Check dataset sizes
print(f"Valid Dataset Size: {len(val_dataset)}")
print(f"Test Dataset Size: {len(test_dataset)}")


# Access a sample
point_cloud, affordance_labels = dataset[0]

print("Point Cloud Shape:", point_cloud.shape)
print("Affordance Labels Shape:", affordance_labels.shape)



Number of point clouds with 'semantic class' equal to Knife: 225
225
Valid Dataset Size: 22
Test Dataset Size: 203
Point Cloud Shape: (2048, 3)
Affordance Labels Shape: (2048, 1)




# Training

In [7]:
# Constrain most sources of randomness
# (some torch backwards functions within CLIP are non-determinstic)
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

render_res = 224
learning_rate = 0.0001
n_iter = 2500
res = 224
obj_path = 'data/candle.obj'
#output_dir = './output/'
clip_model_name = 'ViT-B/32'

device = "cuda" if torch.cuda.is_available() else "cpu"

#Path(os.path.join(output_dir, 'renders')).mkdir(parents=True, exist_ok=True)

#objbase, extension = os.path.splitext(os.path.basename(obj_path))

render = Renderer(dim=(render_res, render_res))

#---------- WARNING---------
#---------- now i read the mesh from the obj of the reconstructed mesh
#mesh = Mesh(obj_path)
#MeshNormalizer(mesh)()
#---------- WARNING---------

#------------ MESH TO POINT CLOUD INIT---------------
from pyvirtualdisplay import Display
display = Display(visible=0, size=(1400, 900))
display.start()

#here we retrieve first the point cloud from the mesh (only for test purpose)
mesh_o3d = o3d.io.read_triangle_mesh(obj_path)
mesh_o3d.compute_vertex_normals()
pcd = mesh_o3d.sample_points_poisson_disk(2048)
o3d.io.write_point_cloud("candle.pcd", pcd)
display.stop()
#------------ MESH TO POINT CLOUD END---------------

# Load the point cloud using Open3D
pcd = o3d.io.read_point_cloud("candle.pcd")

# Estimate normals for the point cloud
pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.1, max_nn=30))
#------------ POINT CLOUD TO MESH INIT---------------
radii = [0.005, 0.01, 0.02, 0.04]
rec_mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(pcd, o3d.utility.DoubleVector(radii))

#rec_mesh.vertices = o3d.utility.Vector3dVector(np.asarray(rec_mesh.vertices))
#rec_mesh.triangles = o3d.utility.Vector3iVector(np.asarray(rec_mesh.triangles))

# Optionally, check if the reconstructed mesh is valid
if rec_mesh.is_empty():
    print("Mesh reconstruction failed.")
else:
    print("Mesh reconstruction successful!")



#Export the reconstructed mesh to an obj file (allows to reuse helper functions)
output_mesh_file = "mug.obj"
o3d.io.write_triangle_mesh(output_mesh_file, rec_mesh)

#Import the mesh from the exported obj file
mesh = Mesh(output_mesh_file)
MeshNormalizer(mesh)()

#------------ POINT CLOUD TO MESH END---------------

#then we approximate a mesh so we can still use the previously defined helper functions
#also the loss minimization should converge better


# Initialize variables
background = torch.tensor((1., 1., 1.)).to(device)

#log_dir = output_dir

# CLIP and Augmentation Transforms
clip_normalizer = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

clip_transform = transforms.Compose([
        transforms.Resize((res, res)),
        clip_normalizer
])

augment_transform = transforms.Compose([
        transforms.RandomResizedCrop(res, scale=(1, 1)),
        transforms.RandomPerspective(fill=1, p=0.8, distortion_scale=0.5),
        clip_normalizer
])

# MLP Settings
mlp = NeuralHighlighter().to(device)
optim = torch.optim.Adam(mlp.parameters(), learning_rate)

#introducing learning rate decay
#with the prompt horse/saddle the loss plateaus
#scheduler = StepLR(optim, step_size=300, gamma=0.1)

#scheduler = MultiStepLR(optim, milestones=[300, 1800], gamma=0.1)  # Decay a epoch 300 e 1800


# list of possible colors
rgb_to_color = {(204/255, 1., 0.): "highlighter", (180/255, 180/255, 180/255): "gray"}
color_to_rgb = {"highlighter": [204/255, 1., 0.], "gray": [180/255, 180/255, 180/255]}
full_colors = [[204/255, 1., 0.], [180/255, 180/255, 180/255]]
colors = torch.tensor(full_colors).to(device)

name = 'candlePC_d_{}_augs_{}'.format(depth, n_augs)

# --- Prompt ---
# encode prompt with CLIP
clip_model, preprocess = get_clip_model(clip_model_name)
#prompts = ['A 3D render of a gray horse with highlighted hat',
#           'A 3D render of a gray horse with highlighted shoes',
#           'A 3D render of a gray horse with highlighted saddle']
prompts = ['Identify the regions of the candle where a human can interact by grasping, wrapping, or pushing. Consider areas where the candle can be held for lighting or moved, where it can be wrapped in decorative materials, and areas that are pushed to adjust its position or extinguish the flame.']
#prompts = ['Highlight regions on the grey candle for grasping (base/sides), wrapping (body/base), pushing (sides/top), pouring (top/sides), containing (holders/containers), cutting (grooves), and stabbing (top). Focus on areas of human interaction.']

for i, prompt in enumerate(prompts):

  output_dir = './output_{}_{}/'.format(name, i)
  Path(os.path.join(output_dir, 'renders')).mkdir(parents=True, exist_ok=True)
  log_dir = output_dir

  #here we compute the text encoding only once
  #if we put it inside the loss, we repeat n_iter times the same computation
  with torch.no_grad():
    text_input = clip.tokenize([prompt]).to(device)
    encoded_text = clip_model.encode_text(text_input)
    encoded_text = encoded_text / encoded_text.norm(dim=1, keepdim=True)

  vertices = copy.deepcopy(mesh.vertices)
  #vertices = torch.tensor(np.asarray(rec_mesh.vertices), dtype=torch.float32, device=device) # Convert vertices to a PyTorch tensor
  n_views = 5

  losses = []


  # Optimization loop
  for i in tqdm(range(n_iter)):
    optim.zero_grad()



    # predict highlight probabilities
    pred_class = mlp(vertices)

    # color and render mesh
    sampled_mesh = mesh
    color_mesh(pred_class, sampled_mesh, colors)
    rendered_images, elev, azim = render.render_views(sampled_mesh, num_views=n_views,
                                                            show=False,
                                                            center_azim=0,
                                                            center_elev=0,
                                                            std=1,
                                                            return_views=True,
                                                            lighting=True,
                                                            background=background)

    # Calculate CLIP Loss
    loss = clip_loss(rendered_images, encoded_text, clip_transform, augment_transform, clip_model)

    print(loss.shape)

    loss.backward(retain_graph=True)

    optim.step()

    #LR decay
    #scheduler.step()

    # update variables + record loss
    with torch.no_grad():
        losses.append(loss.item())

    # report results
    if i % 100 == 0:
        print("Last 100 CLIP score: {}".format(np.mean(losses[-100:])))
        save_renders(log_dir, i, rendered_images)
        with open(os.path.join(log_dir, "training_info.txt"), "a") as f:
            f.write(f"For iteration {i}... Prompt: {prompt}, Last 100 avg CLIP score: {np.mean(losses[-100:])}, CLIP score {losses[-1]}\n")


  # save results
  save_final_results(log_dir, name, mesh, mlp, vertices, colors, render, background)

  # save point cloud results
  save_point_cloud_results(log_dir, name)



  # Save prompts
  with open(os.path.join(output_dir, 'prompt.txt'), "w") as f:
    f.write(prompt)
    f.write("\n")
    f.write("initial learning rate:")
    f.write(str(learning_rate))
    f.write("\n")
    f.write("n_iter:")
    f.write(str(n_iter))
    f.write("\n")
    f.write("n_augs:")
    f.write(str(n_augs))
    f.write("\n")
    f.write("n_views:")
    f.write(str(n_views))
    f.write("\n")
    f.write("clip_model:")
    f.write(clip_model_name)
    f.write("\n")
    f.write("depth:")
    f.write(str(depth))

Mesh reconstruction successful!
ModuleList(
  (0): Linear(in_features=3, out_features=256, bias=True)
  (1): ReLU()
  (2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (3): Linear(in_features=256, out_features=256, bias=True)
  (4): ReLU()
  (5): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (6): Linear(in_features=256, out_features=256, bias=True)
  (7): ReLU()
  (8): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (9): Linear(in_features=256, out_features=256, bias=True)
  (10): ReLU()
  (11): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (12): Linear(in_features=256, out_features=256, bias=True)
  (13): ReLU()
  (14): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (15): Linear(in_features=256, out_features=256, bias=True)
  (16): ReLU()
  (17): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (18): Linear(in_features=256, out_features=256, bias=True)
  (19): ReLU()
  (20): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  

  0%|          | 3/2500 [00:00<03:25, 12.16it/s]

torch.Size([1])
Last 100 CLIP score: 0.751953125
torch.Size([1])
torch.Size([1])


  0%|          | 5/2500 [00:00<03:08, 13.26it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  0%|          | 9/2500 [00:00<02:45, 15.06it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  1%|          | 15/2500 [00:01<02:35, 15.98it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  1%|          | 19/2500 [00:01<02:30, 16.54it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  1%|          | 21/2500 [00:01<02:29, 16.57it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  1%|          | 25/2500 [00:01<02:31, 16.35it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  1%|          | 29/2500 [00:01<02:32, 16.21it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  1%|▏         | 35/2500 [00:02<02:27, 16.69it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  2%|▏         | 39/2500 [00:02<02:25, 16.91it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  2%|▏         | 43/2500 [00:02<02:25, 16.91it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  2%|▏         | 45/2500 [00:02<02:30, 16.33it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  2%|▏         | 51/2500 [00:03<02:23, 17.02it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  2%|▏         | 55/2500 [00:03<02:24, 16.92it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  2%|▏         | 57/2500 [00:03<02:29, 16.35it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  2%|▏         | 61/2500 [00:03<02:30, 16.24it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  3%|▎         | 67/2500 [00:04<02:25, 16.70it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  3%|▎         | 71/2500 [00:04<02:24, 16.85it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  3%|▎         | 73/2500 [00:04<02:22, 16.99it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  3%|▎         | 77/2500 [00:04<02:33, 15.82it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  3%|▎         | 83/2500 [00:05<02:24, 16.71it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  3%|▎         | 87/2500 [00:05<02:23, 16.84it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  4%|▎         | 89/2500 [00:05<02:22, 16.90it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  4%|▎         | 93/2500 [00:05<02:27, 16.27it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  4%|▍         | 99/2500 [00:06<02:23, 16.71it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  4%|▍         | 101/2500 [00:06<02:29, 16.06it/s]

torch.Size([1])
torch.Size([1])
Last 100 CLIP score: 0.74421875
torch.Size([1])
torch.Size([1])


  4%|▍         | 105/2500 [00:06<02:24, 16.53it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  4%|▍         | 109/2500 [00:06<02:28, 16.15it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  5%|▍         | 115/2500 [00:07<02:24, 16.48it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  5%|▍         | 119/2500 [00:07<02:22, 16.69it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  5%|▍         | 123/2500 [00:07<02:19, 17.02it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  5%|▌         | 125/2500 [00:07<02:23, 16.55it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  5%|▌         | 131/2500 [00:08<02:18, 17.06it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  5%|▌         | 135/2500 [00:08<02:18, 17.05it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  5%|▌         | 137/2500 [00:08<02:19, 16.90it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  6%|▌         | 141/2500 [00:08<02:19, 16.86it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  6%|▌         | 147/2500 [00:08<02:23, 16.43it/s]

torch.Size([1])
torch.Size([1])
torch.Size([1])
torch.Size([1])


  6%|▌         | 147/2500 [00:08<02:23, 16.34it/s]


KeyboardInterrupt: 

# Validation

In [15]:
import os
from PIL import Image
import numpy as np
import torch
from tqdm import tqdm

def evaluate(net, dataloader, device):

    net.eval()
    with torch.no_grad():
        all_ious = []
        for i, (points, affordance_labels) in enumerate(tqdm(dataloader)):

            # -------------- POINT CLOUD TO MESH ---------------
            points = points.squeeze(0).to(device)  # Ensure points are on the correct device
            affordance_labels = affordance_labels.squeeze(0).to(device)  # Ensure labels are on the correct device

            # Create Open3D point cloud and mesh
            pcd = o3d.geometry.PointCloud()
            pcd.points = o3d.utility.Vector3dVector(points.cpu().numpy())  # Transfer points to CPU for Open3D

            # Estimate normals for the point cloud
            pcd.estimate_normals(search_param=o3d.geometry.KDTreeSearchParamHybrid(radius=0.1, max_nn=30))

            radii = [0.005, 0.01, 0.02, 0.04]
            mesh = o3d.geometry.TriangleMesh.create_from_point_cloud_ball_pivoting(pcd, o3d.utility.DoubleVector(radii))

            # Export mesh to an obj file
            output_mesh_file = "mug.obj"
            o3d.io.write_triangle_mesh(output_mesh_file, mesh)

            mesh = Mesh(output_mesh_file)
            MeshNormalizer(mesh)  # Apply MeshNormalizer on CPU if necessary
            #--------------- END POINT CLOUD TO MESH -------------


            # Predict affordance
            vertices = copy.deepcopy(mesh.vertices)

            pred_class = net(vertices)  # Get predictions


            # Compute IoU

            pred_class = pred_class.cpu().numpy()

            # Keep the first column (index 0) only, resulting in shape (2048, 1)
            score = pred_class[:, 0].reshape(-1, 1)


            target_score = affordance_labels.cpu().numpy()

            gt_th = 0.01
            pred_th = 0.5
            ious = 0

            t_mask = (target_score > gt_th).astype(int)
            p_mask = (score > pred_th).astype(int)
            intersection = np.sum(t_mask & p_mask)
            union = np.sum(p_mask | t_mask)
            iou = 1. * intersection/union
            if union == 0:
                all_ious.append(0.0)
            else:
                all_ious.append(1. * intersection/union)

            print(f"iou: {iou}")

        all_ious = np.array(all_ious)
        mIOU = np.nanmean(all_ious)
        print(f"Mean Intersection over Union (mIOU): {mIOU:.4f}")

        return mIOU



print("Starting evaluation...")
mIOU = evaluate(mlp, val_loader, device)
print(f"Mean Intersection over Union (mIOU): {mIOU:.4f}")




Starting evaluation...


  0%|          | 0/22 [00:00<?, ?it/s]



  5%|▍         | 1/22 [00:01<00:23,  1.11s/it]

iou: 0.0


  9%|▉         | 2/22 [00:01<00:15,  1.30it/s]

iou: 0.0


 14%|█▎        | 3/22 [00:02<00:11,  1.64it/s]

iou: 0.0


 18%|█▊        | 4/22 [00:02<00:09,  1.91it/s]

iou: 0.0


 23%|██▎       | 5/22 [00:02<00:08,  2.03it/s]

iou: 0.0


 27%|██▋       | 6/22 [00:03<00:07,  2.01it/s]

iou: 0.0


 32%|███▏      | 7/22 [00:03<00:07,  2.11it/s]

iou: 0.0


 36%|███▋      | 8/22 [00:04<00:06,  2.01it/s]

iou: 0.0


 41%|████      | 9/22 [00:04<00:06,  2.10it/s]

iou: 0.0


 45%|████▌     | 10/22 [00:05<00:05,  2.10it/s]

iou: 0.0


 50%|█████     | 11/22 [00:05<00:05,  1.95it/s]

iou: 0.0


 55%|█████▍    | 12/22 [00:06<00:04,  2.01it/s]

iou: 0.0


 59%|█████▉    | 13/22 [00:06<00:04,  2.15it/s]

iou: 0.0


 64%|██████▎   | 14/22 [00:07<00:03,  2.10it/s]

iou: 0.0


 68%|██████▊   | 15/22 [00:07<00:03,  2.02it/s]

iou: 0.0


 73%|███████▎  | 16/22 [00:08<00:03,  1.96it/s]

iou: 0.0


 77%|███████▋  | 17/22 [00:08<00:02,  2.01it/s]

iou: 0.0


 82%|████████▏ | 18/22 [00:09<00:02,  2.00it/s]

iou: 0.0


 86%|████████▋ | 19/22 [00:09<00:01,  1.79it/s]

iou: 0.0


 91%|█████████ | 20/22 [00:10<00:01,  1.84it/s]

iou: 0.0


 95%|█████████▌| 21/22 [00:11<00:00,  1.79it/s]

iou: 0.0


100%|██████████| 22/22 [00:11<00:00,  1.88it/s]

iou: 0.0


100%|██████████| 22/22 [00:11<00:00,  1.86it/s]

Mean Intersection over Union (mIOU): 0.0000
Mean Intersection over Union (mIOU): 0.0000





In [16]:
torch.cuda.empty_cache()

# Test


In [12]:

print("Starting test...")
mIOU = evaluate(mlp, test_loader, device)
print(f"Mean Intersection over Union (mIOU): {mIOU:.4f}")


Starting test...


  0%|          | 0/45 [00:00<?, ?it/s]



  2%|▏         | 1/45 [00:00<00:29,  1.51it/s]

iou: 0.0


  4%|▍         | 2/45 [00:01<00:21,  2.00it/s]

iou: 0.0


  7%|▋         | 3/45 [00:01<00:18,  2.30it/s]

iou: 0.0


  9%|▉         | 4/45 [00:01<00:17,  2.36it/s]

iou: 0.0


 11%|█         | 5/45 [00:02<00:16,  2.42it/s]

iou: 0.0


 13%|█▎        | 6/45 [00:02<00:15,  2.47it/s]

iou: 0.0


 16%|█▌        | 7/45 [00:03<00:16,  2.35it/s]

iou: 0.0


 18%|█▊        | 8/45 [00:03<00:15,  2.44it/s]

iou: 0.0


 20%|██        | 9/45 [00:03<00:13,  2.60it/s]

iou: 0.0


 22%|██▏       | 10/45 [00:04<00:13,  2.63it/s]

iou: 0.0


 24%|██▍       | 11/45 [00:04<00:13,  2.46it/s]

iou: 0.0


 27%|██▋       | 12/45 [00:05<00:13,  2.42it/s]

iou: 0.0


 29%|██▉       | 13/45 [00:05<00:13,  2.45it/s]

iou: 0.0


 31%|███       | 14/45 [00:05<00:12,  2.47it/s]

iou: 0.0


 33%|███▎      | 15/45 [00:06<00:12,  2.35it/s]

iou: 0.0033482142857142855


 36%|███▌      | 16/45 [00:06<00:12,  2.37it/s]

iou: 0.0


  iou = 1. * intersection/union
 38%|███▊      | 17/45 [00:07<00:12,  2.33it/s]

iou: nan


 40%|████      | 18/45 [00:07<00:11,  2.27it/s]

iou: 0.0


 42%|████▏     | 19/45 [00:07<00:10,  2.39it/s]

iou: nan


 44%|████▍     | 20/45 [00:08<00:10,  2.41it/s]

iou: 0.0


 47%|████▋     | 21/45 [00:08<00:10,  2.36it/s]

iou: 0.0


 49%|████▉     | 22/45 [00:09<00:09,  2.47it/s]

iou: nan


 51%|█████     | 23/45 [00:09<00:09,  2.44it/s]

iou: 0.0


 53%|█████▎    | 24/45 [00:10<00:08,  2.44it/s]

iou: 0.0


 56%|█████▌    | 25/45 [00:10<00:07,  2.64it/s]

iou: 0.0


 58%|█████▊    | 26/45 [00:10<00:08,  2.18it/s]

iou: 0.0


 60%|██████    | 27/45 [00:11<00:08,  2.25it/s]

iou: 0.0


 62%|██████▏   | 28/45 [00:11<00:07,  2.40it/s]

iou: 0.0


 64%|██████▍   | 29/45 [00:12<00:06,  2.32it/s]

iou: 0.0


 67%|██████▋   | 30/45 [00:12<00:05,  2.60it/s]

iou: 0.0


 69%|██████▉   | 31/45 [00:12<00:05,  2.72it/s]

iou: 0.0


 71%|███████   | 32/45 [00:13<00:05,  2.51it/s]

iou: 0.0


 73%|███████▎  | 33/45 [00:13<00:04,  2.70it/s]

iou: 0.0


 76%|███████▌  | 34/45 [00:14<00:04,  2.55it/s]

iou: 0.0


 78%|███████▊  | 35/45 [00:14<00:03,  2.69it/s]

iou: 0.0


 80%|████████  | 36/45 [00:14<00:03,  2.54it/s]

iou: 0.0


 82%|████████▏ | 37/45 [00:15<00:03,  2.52it/s]

iou: 0.0


 84%|████████▍ | 38/45 [00:15<00:02,  2.63it/s]

iou: 0.0


 87%|████████▋ | 39/45 [00:15<00:02,  2.52it/s]

iou: 0.0


 89%|████████▉ | 40/45 [00:16<00:02,  2.01it/s]

iou: 0.0


 91%|█████████ | 41/45 [00:17<00:02,  2.00it/s]

iou: 0.0


 93%|█████████▎| 42/45 [00:17<00:01,  2.13it/s]

iou: 0.0


 96%|█████████▌| 43/45 [00:17<00:00,  2.32it/s]

iou: 0.0


 98%|█████████▊| 44/45 [00:18<00:00,  2.11it/s]

iou: 0.0


100%|██████████| 45/45 [00:18<00:00,  2.24it/s]

iou: nan


100%|██████████| 45/45 [00:19<00:00,  2.35it/s]

Mean Intersection over Union (mIOU): 0.0001
Mean Intersection over Union (mIOU): 0.0001



