<a href="https://colab.research.google.com/github/LeograndeCode/Neural-Highlighting-of-Affordance-Regions/blob/Parte-3/Notebook3D_AffordanceNet_v1.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
from google.colab import drive
drive.mount('/content/drive') #replace with drive.mount('/content/drive/', force_remount=True) if the drive has changed since last mount in order to force the remount
%cd /content/drive/MyDrive/Neural-Highlighting-of-Affordance-Regions/


In [None]:
!apt-get update
!apt-get install -y xvfb ffmpeg libsm6 libxext6
!pip install git+https://github.com/openai/CLIP.git
!pip install kaolin==0.17.0 -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.5.1_cu121.html
!pip install open3d pyvirtualdisplay

### AffordanceNet Class

The AffordanceNet dataset class, modified so that only selected object categories are loaded; in our case, household objects with hand-object affordances.



In [None]:
import os
from os.path import join as opj
import numpy as np
from torch.utils.data import Dataset
import h5py
import json
from utils.provider import rotate_point_cloud_SO3, rotate_point_cloud_y
import pickle as pkl


# Object categories ("semantic class" values) that the dataset classes keep.
valid_objects = [
    "Bowl",
    "Bottle",
    "Door",
    "Faucet",
    "Knife",
    "Mug",
    "Scissors",
    "Vase",
]

# Affordance names that survive the per-record affordance filter.
valid_affordances = [
    "grasp",
    "push",
    "wrap",
    "pour",
    "wrap-grasp",
    "open",
    "pull",
]

def pc_normalize(pc):
    """Center a point cloud on its centroid and scale it to unit radius.

    Args:
        pc: 2-D array with one point per row.

    Returns:
        Tuple ``(pc, centroid, m)``: the normalized points, the centroid that
        was subtracted, and the largest point-to-centroid distance used as the
        scale factor.

    If every point coincides with the centroid (``m == 0``) the centered
    points are returned unscaled, so the result is all zeros instead of NaN.
    """
    centroid = np.mean(pc, axis=0)
    pc = pc - centroid
    m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
    # Guard the division: a degenerate cloud has zero radius.
    if m > 0:
        pc = pc / m
    return pc, centroid, m


def semi_points_transform(points):
    """Return ``points`` perturbed by small Gaussian jitter.

    The noise for each column is scaled by 2e-3 times that column's
    max-minus-min extent. Randomness comes from the global NumPy RNG.
    """
    n_rows, n_cols = points.shape[0], points.shape[1]
    extent = np.max(points, axis=0) - np.min(points, axis=0)
    scale = 2e-3 * extent[np.newaxis, :]
    noise = np.random.randn(n_rows, n_cols)
    return points + scale * noise


class AffordNetDataset(Dataset):
    """Labelled point-cloud dataset restricted to ``valid_objects``.

    Records are read from pickle files in ``data_dir``. Only shapes whose
    ``"semantic class"`` is listed in ``valid_objects`` are kept, and each
    record's affordance collection is reduced to ``valid_affordances``.

    Args:
        data_dir: Directory containing the ``*.pkl`` dataset files.
        split: Split name interpolated into the pickle file names.
        partial: If true, load ``partial_<split>_data.pkl`` and emit one
            record per view.
        rotate: ``'None'`` disables rotation. For the ``'train'`` split,
            ``'so3'`` and ``'z'`` rotate each item on access; for any other
            split the value selects the stored rotation matrices to apply.
        semi: If true, load ``semi_label_1.pkl`` instead of the split files.

    Raises:
        IndexError: if no record survives the object filter.

    NOTE: the files are deserialised with ``pickle.load``, which can execute
    arbitrary code. Only point ``data_dir`` at trusted files.
    """

    def __init__(self, data_dir, split, partial=False, rotate='None', semi=False):
        super().__init__()
        self.data_dir = data_dir
        self.split = split

        self.partial = partial
        self.rotate = rotate
        self.semi = semi

        self.load_data()

        if not self.all_data:
            raise IndexError(
                "AffordNetDataset is empty: no record in %r matches valid_objects"
                % (data_dir,))
        # Every record is filtered the same way; the first one defines the
        # affordance names (and their order) used to build the targets.
        self.affordance = self.all_data[0]["affordance"]

    @staticmethod
    def _read_pickle(path):
        """Load and return the object stored in the pickle file at ``path``."""
        with open(path, 'rb') as f:
            return pkl.load(f)

    @staticmethod
    def _filter_affordances(affordance):
        """Keep only the entries named in ``valid_affordances``.

        Mappings are filtered by key and stay mappings; any other iterable
        (e.g. a list of affordance names) is filtered element-wise into a
        list. Supporting both avoids an ``AttributeError`` when the dataset
        stores the affordances as a plain list.
        """
        if hasattr(affordance, "items"):
            return {key: value for key, value in affordance.items()
                    if key in valid_affordances}
        return [key for key in affordance if key in valid_affordances]

    def _base_record(self, info):
        """Build the fields shared by every record created from ``info``."""
        return {
            "shape_id": info["shape_id"],
            "semantic class": info["semantic class"],
            "affordance": self._filter_affordances(info["affordance"]),
        }

    def load_data(self):
        """Populate ``self.all_data`` with the filtered records."""
        self.all_data = []

        # Stored rotations are only used when evaluating with rotation enabled.
        rotate_eval = self.split != 'train' and self.rotate != 'None'
        temp_data_rotate = None

        # Pick the dataset file(s) according to the semi/partial/rotate flags.
        if self.semi:
            temp_data = self._read_pickle(opj(self.data_dir, 'semi_label_1.pkl'))
        elif self.partial:
            temp_data = self._read_pickle(
                opj(self.data_dir, 'partial_%s_data.pkl' % self.split))
        else:
            if rotate_eval:
                temp_data_rotate = self._read_pickle(
                    opj(self.data_dir, 'rotate_%s_data.pkl' % self.split))
            temp_data = self._read_pickle(
                opj(self.data_dir, 'full_shape_%s_data.pkl' % self.split))

        for index, info in enumerate(temp_data):
            # Only load data for valid objects
            if info["semantic class"] not in valid_objects:
                continue

            if self.partial:
                # One record per partial view of the shape.
                for view, data_info in info["partial"].items():
                    record = self._base_record(info)
                    record["view_id"] = view
                    record["data_info"] = data_info
                    self.all_data.append(record)
            elif rotate_eval:
                if temp_data_rotate is None:
                    # Previously this combination failed on an unbound local.
                    raise ValueError(
                        "rotated evaluation needs 'rotate_%s_data.pkl', which "
                        "is not loaded when semi=True" % self.split)
                # `index` is the position in the unfiltered file, so it lines
                # up with the rotation file.
                rotate_info = temp_data_rotate[index]["rotate"][self.rotate]
                full_shape_info = info["full_shape"]
                # One record per stored rotation matrix.
                for r_data in rotate_info.values():
                    record = self._base_record(info)
                    record["data_info"] = full_shape_info
                    record["rotate_matrix"] = r_data.astype(np.float32)
                    self.all_data.append(record)
            else:
                record = self._base_record(info)
                record["data_info"] = info["full_shape"]
                self.all_data.append(record)

    def __getitem__(self, index):
        """Return ``(datas, datas, targets, modelid, modelcat)`` for one record.

        ``datas`` holds the first three coordinate columns after optional
        rotation and normalization with ``pc_normalize``; ``targets`` holds
        one float32 column per affordance in ``self.affordance``.
        """
        data_dict = self.all_data[index]
        modelid = data_dict["shape_id"]
        modelcat = data_dict["semantic class"]

        data_info = data_dict["data_info"]
        labels = data_info["label"]

        # Coordinates followed by one label column per affordance, assembled
        # with a single concatenate instead of one per affordance.
        columns = [data_info["coordinate"].astype(np.float32)]
        columns.extend(
            labels[aff].astype(np.float32).reshape(-1, 1)
            for aff in self.affordance)
        model_data = np.concatenate(columns, axis=1)

        datas = model_data[:, :3]
        targets = model_data[:, 3:]

        if self.rotate != 'None':
            if self.split == 'train':
                if self.rotate == 'so3':
                    datas = rotate_point_cloud_SO3(
                        datas[np.newaxis, :, :]).squeeze()
                elif self.rotate == 'z':
                    # NOTE(review): 'z' is served by rotate_point_cloud_y;
                    # confirm the axis naming is intended.
                    datas = rotate_point_cloud_y(
                        datas[np.newaxis, :, :]).squeeze()
            else:
                r_matrix = data_dict["rotate_matrix"]
                datas = (np.matmul(r_matrix, datas.T)).T

        datas, _, _ = pc_normalize(datas)

        return datas, datas, targets, modelid, modelcat

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.all_data)


class AffordNetDataset_Unlabel(Dataset):
    """Unlabelled point-cloud dataset read from ``semi_unlabel_1.pkl``.

    Only shapes whose ``"semantic class"`` is listed in ``valid_objects`` are
    kept, and each record's affordance collection is reduced to
    ``valid_affordances``.

    Args:
        data_dir: Directory containing ``semi_unlabel_1.pkl``.

    Raises:
        IndexError: if no record survives the object filter.

    NOTE: the file is deserialised with ``pickle.load``, which can execute
    arbitrary code. Only point ``data_dir`` at trusted files.
    """

    def __init__(self, data_dir):
        super().__init__()
        self.data_dir = data_dir
        self.load_data()
        if not self.all_data:
            raise IndexError(
                "AffordNetDataset_Unlabel is empty: no record in %r matches "
                "valid_objects" % (data_dir,))
        self.affordance = self.all_data[0]["affordance"]

    @staticmethod
    def _filter_affordances(affordance):
        """Keep only the entries named in ``valid_affordances``.

        Mappings are filtered by key and stay mappings; any other iterable
        (e.g. a list of affordance names) is filtered element-wise into a
        list, which avoids an ``AttributeError`` on list-valued records.
        """
        if hasattr(affordance, "items"):
            return {key: value for key, value in affordance.items()
                    if key in valid_affordances}
        return [key for key in affordance if key in valid_affordances]

    def load_data(self):
        """Populate ``self.all_data`` with the filtered records."""
        with open(opj(self.data_dir, 'semi_unlabel_1.pkl'), 'rb') as f:
            temp_data = pkl.load(f)

        self.all_data = [
            {
                "shape_id": info["shape_id"],
                "semantic class": info["semantic class"],
                "affordance": self._filter_affordances(info["affordance"]),
                "data_info": info["full_shape"],
            }
            for info in temp_data
            # Only load data for valid objects
            if info["semantic class"] in valid_objects
        ]

    def __getitem__(self, index):
        """Return ``(datas, datas, modelid, modelcat)`` for one record.

        ``datas`` is the float32 coordinate array normalized with
        ``pc_normalize``.
        """
        data_dict = self.all_data[index]
        modelid = data_dict["shape_id"]
        modelcat = data_dict["semantic class"]

        data_info = data_dict["data_info"]
        datas = data_info["coordinate"].astype(np.float32)

        datas, _, _ = pc_normalize(datas)

        return datas, datas, modelid, modelcat

    def __len__(self):
        """Return the number of loaded records."""
        return len(self.all_data)


### Model

In [None]:
import clip
import copy
import json
import kaolin as kal
import kaolin.ops.mesh
import numpy as np
import os
import random
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import torchvision.transforms.functional as F

from itertools import permutations, product
from Normalization import MeshNormalizer
from render import Renderer
from mesh import Mesh
from pathlib import Path
from tqdm import tqdm
from torch.autograd import grad
from torchvision import transforms
from utils import device, color_mesh
import open3d as o3d
from pyvirtualdisplay import Display

# Hyperparameters shared by the highlighter network and the training code.
width = 256     # hidden width of every linear layer
depth = 4       # number of hidden modules after the first one (default is 4)
out_dim = 2     # output classes: target region / not target region
input_dim = 3   # dimension of a vertex
n_augs = 1      # augmentations per CLIP loss evaluation (default is 1)


class NeuralHighlighter(nn.Module):
    """MLP mapping each vertex to class probabilities.

    The layer sizes come from the module-level ``input_dim``, ``width``,
    ``depth`` and ``out_dim`` constants. For the standard highlighter task
    there are only two classes (target region / not target region); the
    output element for the target region is used as the highlight
    probability.

    Architecture (see Appendix B, page 13, of the paper this notebook follows):
    one Linear+ReLU+LayerNorm module, then ``depth`` more such modules, then
    a final Linear followed by a softmax over dimension 1. The total number
    of Linear+ReLU+LayerNorm modules is therefore ``depth + 1``.
    """

    def __init__(self):
        super().__init__()

        # First module: lifts the vertex to the hidden width.
        layers = [nn.Linear(input_dim, width), nn.ReLU(), nn.LayerNorm([width])]

        # Hidden modules: changing `depth` gives a deeper/shallower net.
        for _ in range(depth):
            layers.extend(
                [nn.Linear(width, width), nn.ReLU(), nn.LayerNorm([width])])

        # Output head: softmax so the outputs are probability-like values.
        layers.extend([nn.Linear(width, out_dim), nn.Softmax(dim=1)])

        self.mlp = nn.ModuleList(layers)
        # `model` is a second name for the same ModuleList.
        self.model = self.mlp
        print(self.mlp)

    def forward(self, x):
        """Apply every layer in order and return the class probabilities."""
        for layer in self.model:
            x = layer(x)
        return x

def get_clip_model(clipmodel):
    """Load the CLIP model named ``clipmodel`` onto ``device``.

    Returns:
        The ``(model, preprocess)`` pair produced by ``clip.load``.
    """
    loaded_model, loaded_preprocess = clip.load(clipmodel, device=device)
    return loaded_model, loaded_preprocess

# ================== HELPER FUNCTIONS =============================
def save_final_results(log_dir, name, mesh, mlp, vertices, colors, render, background):
    """Render and export the final highlighted mesh.

    Each vertex is assigned to its most probable class, the mesh is coloured
    with the highlight/gray palette, five views are rendered and saved as
    ``final_render.jpg`` and the coloured mesh is exported as ``<name>.ply``,
    both inside ``log_dir``.

    Args:
        log_dir: Output directory.
        name: Base name of the exported ``.ply`` file.
        mesh: Mesh object to colour, render and export.
        mlp: Trained network; it is switched to eval mode and left that way.
        vertices: Vertex tensor fed to ``mlp``.
        colors: Kept for interface compatibility; not used, the palette is
            always the fixed highlight/gray pair defined below.
        render: Renderer providing ``render_views``.
        background: Background colour passed to the renderer.
    """
    mlp.eval()
    with torch.no_grad():
        probs = mlp(vertices)
        max_idx = torch.argmax(probs, 1, keepdim=True)

        # Hard (one-hot) class assignment used for the renders.
        one_hot = torch.zeros(probs.shape).to(device)
        one_hot = one_hot.scatter_(1, max_idx, 1)

        highlight = torch.tensor([204, 255, 0]).to(device)
        gray = torch.tensor([180, 180, 180]).to(device)
        palette = torch.stack((highlight/255, gray/255)).to(device)
        color_mesh(one_hot, mesh, palette)
        rendered_images, _, _ = render.render_views(mesh, num_views=5,
                                                    show=False,
                                                    center_azim=0,
                                                    center_elev=0,
                                                    std=1,
                                                    return_views=True,
                                                    lighting=True,
                                                    background=background)

        # Per-vertex colours (0-255 values) for the exported mesh:
        # class 0 is the highlighted region.
        final_color = torch.where(max_idx==0, highlight, gray)
        mesh.export(os.path.join(log_dir, f"{name}.ply"), extension="ply", color=final_color)
        save_renders(log_dir, 0, rendered_images, name='final_render.jpg')

#TODO: fix the point cloud generation step that follows.
#      For now, the point cloud can only be generated by running PC_rendering.ipynb.

def save_point_cloud_results(log_dir, name):
    """Render the exported highlighted mesh as a point cloud image.

    Reads ``<name>.ply`` from ``log_dir``, samples 10000 points from it with
    Poisson-disk sampling, renders them off-screen inside a virtual display
    and writes ``<name>_final_render.jpg`` to ``log_dir``.

    Args:
        log_dir: Directory holding the ``.ply`` file and receiving the image.
        name: Base name of the ``.ply`` file and of the output image.
    """
    display = Display(visible=0, size=(1400, 900))
    display.start()
    try:
        # Load the highlighted mesh and turn it back into a point cloud.
        mesh_o3d = o3d.io.read_triangle_mesh(os.path.join(log_dir, f"{name}.ply"))

        if not mesh_o3d.has_vertex_normals():
            mesh_o3d.compute_vertex_normals()

        point_cloud = mesh_o3d.sample_points_poisson_disk(number_of_points=10000)

        width_final_render, height_final_render = 1400, 900
        render_final_pc = o3d.visualization.rendering.OffscreenRenderer(width_final_render, height_final_render)
        material = o3d.visualization.rendering.MaterialRecord()
        material.shader = "defaultUnlit"
        render_final_pc.scene.add_geometry("point_cloud", point_cloud, material)

        # Aim the camera at the centre of the cloud from one unit along +z,
        # with a widened field of view.
        zoom_out_factor = 0.5
        bounding_box = point_cloud.get_axis_aligned_bounding_box()
        center = bounding_box.get_center()
        render_final_pc.scene.camera.look_at(center, center + [0, 0, 1], [0, 1, 0])
        render_final_pc.scene.camera.set_projection(60 / zoom_out_factor, width_final_render / height_final_render, 0.1, 100.0,
                                                    o3d.visualization.rendering.Camera.FovType.Horizontal)

        pc_img = render_final_pc.render_to_image()
        output_file = os.path.join(log_dir, f"{name}_final_render.jpg")
        o3d.io.write_image(output_file, pc_img)
    finally:
        # Always shut the virtual display down, even if rendering fails.
        display.stop()


def clip_loss(rendered_images, encoded_text, clip_transform, augment_transform, clip_model):
    """Return the negative CLIP similarity between renders and the prompt.

    The number of augmentations is read from the module-level ``n_augs``:

    * ``n_augs == 0``: the renders go through ``clip_transform`` once, their
      embeddings are normalized, and the loss is minus the cosine similarity.
    * ``n_augs > 0``: the renders go through ``augment_transform`` ``n_augs``
      times and the loss is minus the sum of the cosine similarities.

    In both cases the render embeddings are averaged over the views; when
    ``encoded_text`` holds more than one embedding those are averaged too.
    Lower is better in every branch, so the value can be minimized directly.
    (Previously the ``n_augs == 0`` branch returned the positive similarity,
    which made minimization push the renders away from the prompt.)

    Raises:
        ValueError: if ``n_augs`` is negative.
    """
    if n_augs < 0:
        raise ValueError("n_augs must be >= 0, got {}".format(n_augs))

    def _similarity(encoded_renders):
        # Cosine similarity between view-averaged renders and the text.
        if encoded_text.shape[0] > 1:
            return torch.cosine_similarity(torch.mean(encoded_renders, dim=0),
                                           torch.mean(encoded_text, dim=0), dim=0)
        return torch.cosine_similarity(torch.mean(encoded_renders, dim=0, keepdim=True),
                                       encoded_text)

    if n_augs == 0:
        clip_image = clip_transform(rendered_images)
        encoded_renders = clip_model.encode_image(clip_image)
        encoded_renders = encoded_renders / encoded_renders.norm(dim=1, keepdim=True)
        return -_similarity(encoded_renders)

    loss = 0.0
    for _ in range(n_augs):
        augmented_image = augment_transform(rendered_images)
        encoded_renders = clip_model.encode_image(augmented_image)
        loss -= _similarity(encoded_renders)
    return loss

def save_renders(dir, i, rendered_images, name=None):
    """Save ``rendered_images`` as an image file under ``dir``.

    The file is ``dir/<name>`` when ``name`` is given, otherwise
    ``dir/renders/iter_<i>.jpg``.
    """
    relative_path = 'renders/iter_{}.jpg'.format(i) if name is None else name
    target = os.path.join(dir, relative_path)
    torchvision.utils.save_image(rendered_images, target)

TODO: Make sure the following cell is correct.


In [None]:
from sklearn.model_selection import train_test_split
from torch.utils.data import Subset, DataLoader

def split_dataset(dataset, train_size=0.7, val_size=0.1, test_size=0.2):
    """Split ``dataset`` into train, validation and test subsets.

    The sample indices are split (with a fixed ``random_state`` of 42) and
    wrapped in ``Subset`` objects, so no sample is loaded until it is
    actually indexed.

    Args:
        dataset: Any object supporting ``len()`` and integer indexing.
        train_size: Currently not used; the training share is whatever
            remains after ``val_size + test_size`` has been held out.
        val_size: Fraction of the whole dataset used for validation.
        test_size: Fraction of the whole dataset used for testing.

    Returns:
        Tuple ``(train_data, val_data, test_data)`` of ``Subset`` objects.
    """
    holdout_fraction = val_size + test_size
    all_indices = list(range(len(dataset)))

    # First split: training indices versus held-out (validation + test) indices.
    train_idx, holdout_idx = train_test_split(
        all_indices, test_size=holdout_fraction, random_state=42)

    # Second split: divide the held-out indices so that validation and test
    # keep their requested proportions relative to each other.
    val_size_adjusted = val_size / holdout_fraction
    val_idx, test_idx = train_test_split(
        holdout_idx, test_size=(1 - val_size_adjusted), random_state=42)

    return (Subset(dataset, train_idx),
            Subset(dataset, val_idx),
            Subset(dataset, test_idx))

# Build the dataset from the "train" split files: full shapes, no rotation,
# fully labelled.
data_dir = "./dataset"
dataset = AffordNetDataset(
    data_dir=data_dir,
    split="train",
    partial=False,
    rotate="None",
    semi=False,
)

# Carve it into train / validation / test portions (roughly 70% / 10% / 20%).
train_dataset, val_dataset, test_dataset = split_dataset(
    dataset,
    train_size=0.7,
    val_size=0.1,
    test_size=0.2,
)

# One DataLoader per portion; only the training loader shuffles.
batch_size = 32

train_loader = DataLoader(
    train_dataset, batch_size=batch_size, shuffle=True, num_workers=4)
val_loader = DataLoader(
    val_dataset, batch_size=batch_size, shuffle=False, num_workers=4)
test_loader = DataLoader(
    test_dataset, batch_size=batch_size, shuffle=False, num_workers=4)

# Report how many samples ended up in each portion.
print('Train Dataset: {}'.format(len(train_dataset)))
print('Valid Dataset: {}'.format(len(val_dataset)))
print('Test Dataset: {}'.format(len(test_dataset)))


TODO: Upsample the point clouds to match the 10,000 points used by the 3D Highlighter, and make sure both are normalized.

## TRAIN


TO-DO Modify Training

In [None]:
from torch.optim.lr_scheduler import MultiStepLR
# Constrain most sources of randomness
# (some torch backward functions within CLIP are non-deterministic)
seed = 0
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Run configuration. `render_res` sizes the renderer output and `res` sizes
# the images handed to CLIP; both are 224 here.
render_res = 224
learning_rate = 0.0001
n_iter = 2500
res = 224
obj_path = 'data/dog.obj'
#output_dir = './output/'
clip_model_name = 'ViT-B/32'

# NOTE: this rebinds the `device` name that was imported from `utils`.
device = "cuda" if torch.cuda.is_available() else "cpu"

#Path(os.path.join(output_dir, 'renders')).mkdir(parents=True, exist_ok=True)

# Base name and extension of the input file.
objbase, extension = os.path.splitext(os.path.basename(obj_path))

render = Renderer(dim=(render_res, render_res))

#then we approximate a mesh so we can still use the previously defined helper functions
#also the loss minimization should converge better


# Initialize variables
# White background colour passed to the renderer.
background = torch.tensor((1., 1., 1.)).to(device)

#log_dir = output_dir

# CLIP and Augmentation Transforms
# Per-channel mean / std handed to Normalize
# (presumably CLIP's preprocessing statistics; TODO confirm).
clip_normalizer = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

# Deterministic path: resize, then normalize.
clip_transform = transforms.Compose([
        transforms.Resize((res, res)),
        clip_normalizer
])

# Augmented path: full-scale random resized crop plus a random perspective
# warp (applied with probability 0.8), then normalize.
augment_transform = transforms.Compose([
        transforms.RandomResizedCrop(res, scale=(1, 1)),
        transforms.RandomPerspective(fill=1, p=0.8, distortion_scale=0.5),
        clip_normalizer
])

# MLP Settings
mlp = NeuralHighlighter().to(device)
optim = torch.optim.Adam(mlp.parameters(), learning_rate)

# Learning rate decay (currently disabled):
# with the prompt horse/saddle the loss plateaus
#scheduler = StepLR(optim, step_size=300, gamma=0.1)

#scheduler = MultiStepLR(optim, milestones=[300, 1800], gamma=0.1)  # Decay at epochs 300 and 1800


# list of possible colors
rgb_to_color = {(204/255, 1., 0.): "highlighter", (180/255, 180/255, 180/255): "gray"}
color_to_rgb = {"highlighter": [204/255, 1., 0.], "gray": [180/255, 180/255, 180/255]}
full_colors = [[204/255, 1., 0.], [180/255, 180/255, 180/255]]
# Palette tensor: row 0 is the highlighter colour, row 1 is gray.
colors = torch.tensor(full_colors).to(device)

# Run name; it is used for the output directory and the exported files.
name = 'dogPC_d_{}_augs_{}'.format(depth, n_augs)

# --- Prompt ---
# Load CLIP; the prompts themselves are encoded inside the training loop.
clip_model, preprocess = get_clip_model(clip_model_name)
#prompts = ['A 3D render of a gray horse with highlighted hat',
#           'A 3D render of a gray horse with highlighted shoes',
#           'A 3D render of a gray horse with highlighted saddle']
prompts = ['A three-dimensional picture of a gray dog with highlighted belt']


# Load the mesh that is highlighted below; without this the loop fails with a
# NameError because `mesh` is never defined.
# NOTE(review): assumes the project's Mesh / MeshNormalizer are used as
# `Mesh(path)` and `MeshNormalizer(mesh)()` -- confirm against mesh.py and
# the Normalization package.
mesh = Mesh(obj_path)
MeshNormalizer(mesh)()

# NOTE(review): `mlp` and `optim` are created once, before this loop, so with
# several prompts each run continues from the previous prompt's weights.
for i, prompt in enumerate(prompts):

    # Each prompt gets its own output directory.
    output_dir = './output_{}_{}/'.format(name, i)
    Path(os.path.join(output_dir, 'renders')).mkdir(parents=True, exist_ok=True)
    log_dir = output_dir

    # Compute the text encoding only once: doing it inside the loss would
    # repeat the same computation n_iter times.
    with torch.no_grad():
        text_input = clip.tokenize([prompt]).to(device)
        encoded_text = clip_model.encode_text(text_input)
        encoded_text = encoded_text / encoded_text.norm(dim=1, keepdim=True)

    vertices = copy.deepcopy(mesh.vertices)
    #vertices = torch.tensor(np.asarray(rec_mesh.vertices), dtype=torch.float32, device=device) # Convert vertices to a PyTorch tensor
    n_views = 5

    losses = []

    # Optimization loop. The step counter has its own name so it does not
    # overwrite the prompt index `i`.
    for step in tqdm(range(n_iter)):
        optim.zero_grad()

        # predict highlight probabilities
        pred_class = mlp(vertices)

        # color and render mesh
        color_mesh(pred_class, mesh, colors)
        rendered_images, _, _ = render.render_views(mesh, num_views=n_views,
                                                    show=False,
                                                    center_azim=0,
                                                    center_elev=0,
                                                    std=1,
                                                    return_views=True,
                                                    lighting=True,
                                                    background=background)

        # Calculate CLIP Loss
        loss = clip_loss(rendered_images, encoded_text, clip_transform, augment_transform, clip_model)

        #loss = clip_loss_custom(encoded_text, rendered_images, clip_model, preprocess)
        loss.backward(retain_graph=True)

        optim.step()

        #LR decay
        #scheduler.step()

        # record loss
        losses.append(loss.item())

        # report results every 100 steps
        if step % 100 == 0:
            recent_mean = np.mean(losses[-100:])
            print("Last 100 CLIP score: {}".format(recent_mean))
            save_renders(log_dir, step, rendered_images)
            with open(os.path.join(log_dir, "training_info.txt"), "a") as f:
                f.write(f"For iteration {step}... Prompt: {prompt}, Last 100 avg CLIP score: {recent_mean}, CLIP score {losses[-1]}\n")

    # save results
    save_final_results(log_dir, name, mesh, mlp, vertices, colors, render, background)

    # save point cloud results
    save_point_cloud_results(log_dir, name)

    # Save the prompt and the run settings (one per line, no trailing newline).
    run_summary = [
        prompt,
        "initial learning rate:" + str(learning_rate),
        "n_iter:" + str(n_iter),
        "n_augs:" + str(n_augs),
        "n_views:" + str(n_views),
        "clip_model:" + clip_model_name,
        "depth:" + str(depth),
    ]
    with open(os.path.join(output_dir, 'prompt.txt'), "w") as f:
        f.write("\n".join(run_summary))

### Validation

### Test
