In [1]:
# Mount Google Drive and make the project folder the working directory.
# Later cells import local modules (Normalization, render, mesh, utils) and load
# 'data/horse.obj' by relative path -- presumably both live under AML_V2; confirm.
from google.colab import drive
drive.mount('/content/drive') # if Drive changed since the last mount, use drive.mount('/content/drive', force_remount=True) to force a fresh remount
%cd /content/drive/MyDrive/AML_V2/

Mounted at /content/drive
/content/drive/MyDrive/AML_V2


In [2]:
!pip install git+https://github.com/openai/CLIP.git
!pip install kaolin==0.17.0 -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.5.1_cu121.html

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-2jfnjmbz
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-2jfnjmbz
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... done
Collecting ftfy (from clip==1.0)
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
   ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 44.8/44.8 kB 4.7 MB/s eta 0:00:00
Building wheels for collected packages: clip
  Building wheel for clip (setup.py) ... done
  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369489 sha256=dd00da2f6ba0af62a652ac4ffebf7fbbc383789cda6bb9d2c14d127841f93a0d
  Stored in directory: /tmp/pip-ephem-wheel-cache-u67w3eok/wheels/da/2b/4c/d6691fa9597aac8bb

In [3]:
# Available CLIP models: lists the names accepted by clip.load(); the list is
# displayed as this cell's output ('ViT-L/14' is the one used for training below).
import clip
clip.available_models()


['RN50',
 'RN101',
 'RN50x4',
 'RN50x16',
 'RN50x64',
 'ViT-B/32',
 'ViT-B/16',
 'ViT-L/14',
 'ViT-L/14@336px']

In [4]:
import clip
import copy
import json
import kaolin as kal
import kaolin.ops.mesh
import numpy as np
import os
import random
import torch
import torch.nn as nn
import torchvision
from torchvision import transforms
import torchvision.transforms.functional as F

from itertools import permutations, product
from Normalization import MeshNormalizer
from render import Renderer
from mesh import Mesh
from pathlib import Path
from tqdm import tqdm
from torch.autograd import grad
from torchvision import transforms
from utils import device, color_mesh

# Hyperparameters read as module globals by NeuralHighlighter (width, depth,
# out_dim, input_dim) and by clip_loss (n_augs).
width = 256     # width of every hidden Linear layer
depth = 4       # number of extra hidden Linear+ReLU+LayerNorm stages after the first one
out_dim = 2     # output classes: target (highlighted) region vs. everything else
input_dim = 3   # per-vertex input size (the vertex dimension)
n_augs = 2      # augmented CLIP passes per iteration in clip_loss (0 = single un-augmented pass); default is 1

class NeuralHighlighter(nn.Module):
    """Per-vertex MLP that maps vertex features to class probabilities.

    See Appendix B (page 13) of the paper. Layout:
      * Linear(in_features -> hidden_width) + ReLU + LayerNorm
      * ``n_hidden`` x [Linear(hidden_width -> hidden_width) + ReLU + LayerNorm]
      * Linear(hidden_width -> n_classes) + Softmax(dim=1)

    For the standard highlighter task there are only 2 classes: target region
    and not target region. The output column for the target region is used as
    the highlight probability described in the main paper.

    Every argument defaults to the notebook-level hyperparameter with the same
    meaning (``input_dim``, ``width``, ``depth``, ``out_dim``), so
    ``NeuralHighlighter()`` builds exactly the same network as before.

    Args:
        in_features: size of each input row (the vertex dimension).
        hidden_width: width of every hidden layer.
        n_hidden: number of extra hidden stages after the first one; the total
            number of Linear+ReLU+LayerNorm stages is ``n_hidden + 1``.
        n_classes: number of output classes.
    """

    def __init__(self, in_features=None, hidden_width=None, n_hidden=None, n_classes=None):
        super(NeuralHighlighter, self).__init__()
        # Resolve defaults lazily: the globals are only read when an argument
        # is omitted, so explicit arguments work without them.
        in_features = input_dim if in_features is None else in_features
        hidden_width = width if hidden_width is None else hidden_width
        n_hidden = depth if n_hidden is None else n_hidden
        n_classes = out_dim if n_classes is None else n_classes

        # first linear layer followed by ReLU and LayerNorm
        layers = [nn.Linear(in_features, hidden_width), nn.ReLU(), nn.LayerNorm([hidden_width])]
        # n_hidden further stages -> changing it gives a deeper/shallower net
        for _ in range(n_hidden):
            layers += [nn.Linear(hidden_width, hidden_width), nn.ReLU(), nn.LayerNorm([hidden_width])]
        # last linear layer followed by softmax so each row is a probability distribution
        layers += [nn.Linear(hidden_width, n_classes), nn.Softmax(dim=1)]

        self.mlp = nn.ModuleList(layers)
        # alias kept for backward compatibility; forward() iterates over it
        self.model = self.mlp
        print(self.mlp)

    def forward(self, x):
        """Apply the layers in order.

        ``x`` must be 2-D, (num_vertices, in_features), because the final
        Softmax normalizes over dim=1. Returns (num_vertices, n_classes).
        """
        for layer in self.model:
            x = layer(x)
        return x

def get_clip_model(clipmodel):
    """Load the CLIP model named ``clipmodel`` onto ``device``.

    Returns:
        (model, preprocess): the pair produced by ``clip.load``.
    """
    clip_net, image_preprocess = clip.load(clipmodel, device=device)
    return (clip_net, image_preprocess)

# ================== HELPER FUNCTIONS =============================
def save_final_results(log_dir, name, mesh, mlp, vertices, colors, render, background):
    """Render and export the final hard (argmax) segmentation of ``mesh``.

    Each vertex gets the class with the highest predicted probability. The mesh
    is colored with the matching row of ``colors`` and rendered from 5 views
    into ``<log_dir>/final_render.jpg``. It is also exported as
    ``<log_dir>/<name>.ply`` with per-vertex integer 0-255 colors taken from
    the same palette.

    Args:
        log_dir: existing output directory.
        name: base file name of the exported .ply.
        mesh: mesh to color, render and export (recolored in place via color_mesh).
        mlp: network mapping ``vertices`` to per-vertex class probabilities.
        vertices: network input, one row per vertex.
        colors: (num_classes, 3) float palette in [0, 1]; row k colors class k.
            This argument used to be ignored in favor of a hard-coded palette.
        render: renderer exposing ``render_views``.
        background: background color forwarded to the renderer.
    """
    # Remember the caller's mode so evaluation does not leave the net in eval mode.
    was_training = mlp.training
    mlp.eval()
    try:
        with torch.no_grad():
            probs = mlp(vertices)
            max_idx = torch.argmax(probs, 1, keepdim=True)

            # Hard one-hot class assignment used to color the mesh for the renders.
            one_hot = torch.zeros(probs.shape, device=device)
            one_hot.scatter_(1, max_idx, 1)

            palette = colors.to(device)
            color_mesh(one_hot, mesh, palette)
            rendered_images, _, _ = render.render_views(mesh, num_views=5,
                                                        show=False,
                                                        center_azim=0,
                                                        center_elev=0,
                                                        std=1,
                                                        return_views=True,
                                                        lighting=True,
                                                        background=background)

            # .ply colors: the same palette scaled to integer 0-255, indexed per vertex.
            ply_palette = (palette * 255).round().long()
            final_color = ply_palette[max_idx.squeeze(1)]
            mesh.export(os.path.join(log_dir, f"{name}.ply"), extension="ply", color=final_color)
            save_renders(log_dir, 0, rendered_images, name='final_render.jpg')
    finally:
        mlp.train(was_training)

def _view_similarity(encoded_renders, encoded_text):
    """Cosine similarity between the view-averaged image embedding and the text embedding.

    With several text rows the text embeddings are averaged as well and a
    0-dim tensor is returned. With a single text row the result has shape (1,).
    """
    if encoded_text.shape[0] > 1:
        return torch.cosine_similarity(torch.mean(encoded_renders, dim=0),
                                       torch.mean(encoded_text, dim=0), dim=0)
    return torch.cosine_similarity(torch.mean(encoded_renders, dim=0, keepdim=True),
                                   encoded_text)


def clip_loss(rendered_images, encoded_text, clip_transform, augment_transform, clip_model, num_augs=None):
    """CLIP-guided loss for a batch of rendered views (lower is better).

    Args:
        rendered_images: rendered views stacked along dim 0.
        encoded_text: text embedding(s), one row per prompt.
        clip_transform: preprocessing used for the single un-augmented pass.
        augment_transform: random augmentation (+ normalization), resampled every pass.
        clip_model: object exposing ``encode_image``.
        num_augs: number of augmented passes. Defaults to the module-level
            ``n_augs``; 0 means one un-augmented pass.

    Returns:
        ``1 - s`` when ``num_augs == 0``, and ``1 - sum_k s_k`` over the passes
        otherwise. Here ``s`` is the cosine similarity between the view-averaged
        image embedding and the (averaged) text embedding. The result has shape
        (1,) for a single text row and is 0-dim otherwise.

    Raises:
        ValueError: if ``num_augs`` is negative.
    """
    if num_augs is None:
        num_augs = n_augs
    if num_augs < 0:
        raise ValueError("n_augs must be >= 0, got {}".format(num_augs))

    if num_augs == 0:
        clip_image = clip_transform(rendered_images)
        encoded_renders = clip_model.encode_image(clip_image)
        encoded_renders = encoded_renders / encoded_renders.norm(dim=1, keepdim=True)
        # 1 - similarity, so minimizing the loss *increases* CLIP agreement,
        # consistent with the augmented branch. Returning the raw similarity
        # made gradient descent push the agreement down.
        return 1.0 - _view_similarity(encoded_renders, encoded_text)

    loss = 1.0 #original 0.0
    for _ in range(num_augs):
        augmented_image = augment_transform(rendered_images)
        encoded_renders = clip_model.encode_image(augmented_image)
        loss = loss - _view_similarity(encoded_renders, encoded_text)
    return loss


def save_renders(dir, i, rendered_images, name=None):
    """Write ``rendered_images`` to a single image file under ``dir``.

    With ``name`` given, the file is ``dir/name``. Otherwise it is the
    per-iteration snapshot ``dir/renders/iter_<i>.jpg``; the training loop
    creates the ``renders`` subfolder beforehand.
    """
    file_name = 'renders/iter_{}.jpg'.format(i) if name is None else name
    torchvision.utils.save_image(rendered_images, os.path.join(dir, file_name))


Warp 1.5.1 initialized:
   CUDA Toolkit 12.6, Driver 12.2
   Devices:
     "cpu"      : "x86_64"
     "cuda:0"   : "Tesla T4" (15 GiB, sm_75, mempool enabled)
   Kernel cache:
     /root/.cache/warp/1.5.1


In [6]:
from torch.optim.lr_scheduler import MultiStepLR

def seed_everything(seed=42):
    """Seed Python, NumPy and torch RNGs and force deterministic cuDNN.

    Note: assigning PYTHONHASHSEED here only affects child processes; the
    running interpreter's hash seed is fixed at startup.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)
    os.environ['PYTHONHASHSEED'] = str(seed)
    # Disable autotuning so cuDNN picks the same (deterministic) kernels every run.
    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

def get_truly_random_seed_through_os():
    """Return a fresh seed built from 4 bytes of OS randomness (os.urandom).

    The bytes are read as an unsigned big-endian integer, so the result lies
    in [0, 2**32 - 1].
    """
    return int.from_bytes(os.urandom(4), byteorder="big")
# Constrain most sources of randomness
# (some torch backwards functions within CLIP are non-deterministic).
# The seed is drawn fresh from the OS on every run, so it is printed here and
# written to each run's prompt.txt -- without it a run cannot be reproduced.
seed = get_truly_random_seed_through_os()
print("Seed: {}".format(seed))
seed_everything(seed)
torch.cuda.manual_seed_all(seed)

render_res = 224            # resolution passed to Renderer(dim=(render_res, render_res))
learning_rate = 0.0001
n_iter = 2500
res = 224                   # size the CLIP transforms resize / crop to
obj_path = 'data/horse.obj'
clip_model_name = 'ViT-L/14'

device = "cuda" if torch.cuda.is_available() else "cpu"

objbase, extension = os.path.splitext(os.path.basename(obj_path))

render = Renderer(dim=(render_res, render_res))
mesh = Mesh(obj_path)
MeshNormalizer(mesh)()

# White background for every render
background = torch.tensor((1., 1., 1.)).to(device)

# CLIP and Augmentation Transforms
clip_normalizer = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))

clip_transform = transforms.Compose([
        transforms.Resize((res, res)),
        clip_normalizer
])

augment_transform = transforms.Compose([
        transforms.RandomResizedCrop(res, scale=(1, 1)),
        transforms.RandomPerspective(fill=1, p=0.8, distortion_scale=0.5),
        clip_normalizer
])

# list of possible colors
rgb_to_color = {(204/255, 1., 0.): "highlighter", (180/255, 180/255, 180/255): "gray"}
color_to_rgb = {"highlighter": [204/255, 1., 0.], "gray": [180/255, 180/255, 180/255]}
full_colors = [[204/255, 1., 0.], [180/255, 180/255, 180/255]]
colors = torch.tensor(full_colors).to(device)

# Derived from the mesh file name so a different obj_path gets its own output folders.
name = '{}_d_{}_augs_{}'.format(objbase, depth, n_augs)

# --- Prompt ---
# encode prompt with CLIP
clip_model, preprocess = get_clip_model(clip_model_name)
# CLIP is frozen (only the MLP is optimized): don't compute/accumulate gradients
# for its weights. Gradients still flow through it back to the rendered images.
clip_model.requires_grad_(False)
#prompts = ['A 3D render of a gray horse with highlighted hat',
#           'A 3D render of a gray horse with highlighted shoes',
#           'A 3D render of a gray horse with highlighted saddle']
prompts = ['A 3D render of a gray horse with highlighted horseback']

n_views = 5

for prompt_idx, prompt in enumerate(prompts):

    output_dir = './output_{}_{}/'.format(name, prompt_idx)
    Path(os.path.join(output_dir, 'renders')).mkdir(parents=True, exist_ok=True)
    log_dir = output_dir

    # Fresh network + optimizer for every prompt. Otherwise each prompt after
    # the first would start from the previous prompt's weights and Adam state.
    mlp = NeuralHighlighter().to(device)
    optim = torch.optim.Adam(mlp.parameters(), learning_rate)

    # LR decay experiments (disabled): with the prompt horse/saddle the loss plateaus
    #scheduler = StepLR(optim, step_size=300, gamma=0.1)
    #scheduler = MultiStepLR(optim, milestones=[300, 1800], gamma=0.1)  # Decay at epochs 300 and 1800

    # The text encoding is computed once per prompt, not inside the loss,
    # to avoid repeating the same computation n_iter times.
    with torch.no_grad():
        text_input = clip.tokenize([prompt]).to(device)
        encoded_text = clip_model.encode_text(text_input)
        encoded_text = encoded_text / encoded_text.norm(dim=1, keepdim=True)

    vertices = copy.deepcopy(mesh.vertices)
    losses = []

    # Optimization loop
    for i in tqdm(range(n_iter)):
        optim.zero_grad()

        # predict highlight probabilities
        pred_class = mlp(vertices)

        # color and render mesh
        sampled_mesh = mesh
        color_mesh(pred_class, sampled_mesh, colors)
        rendered_images, elev, azim = render.render_views(sampled_mesh, num_views=n_views,
                                                          show=False,
                                                          center_azim=0,
                                                          center_elev=0,
                                                          std=1,
                                                          return_views=True,
                                                          lighting=True,
                                                          background=background)

        # CLIP loss: lower is better (1 minus summed similarities, see clip_loss)
        loss = clip_loss(rendered_images, encoded_text, clip_transform, augment_transform, clip_model)

        # NOTE(review): retain_graph=True looks unnecessary because the graph is
        # rebuilt every iteration. It is kept until color_mesh/render are
        # confirmed not to reuse graph tensors across iterations.
        loss.backward(retain_graph=True)
        optim.step()

        #LR decay
        #scheduler.step()

        # record loss (.item() returns a plain float, no graph is kept)
        losses.append(loss.item())

        # report results
        if i % 100 == 0:
            recent_loss = np.mean(losses[-100:])
            print("Last 100 avg CLIP loss: {}".format(recent_loss))
            save_renders(log_dir, i, rendered_images)
            with open(os.path.join(log_dir, "training_info.txt"), "a") as f:
                f.write(f"For iteration {i}... Prompt: {prompt}, Last 100 avg CLIP loss: {recent_loss}, CLIP loss {losses[-1]}\n")

    # save results
    save_final_results(log_dir, name, mesh, mlp, vertices, colors, render, background)

    # Save prompt and run configuration (including the seed for reproducibility)
    with open(os.path.join(output_dir, 'prompt.txt'), "w") as f:
        f.write(prompt + "\n")
        f.write("initial learning rate:{}\n".format(learning_rate))
        f.write("n_iter:{}\n".format(n_iter))
        f.write("n_augs:{}\n".format(n_augs))
        f.write("clip_model:{}\n".format(clip_model_name))
        f.write("depth:{}\n".format(depth))
        f.write("seed:{}".format(seed))

ModuleList(
  (0): Linear(in_features=3, out_features=256, bias=True)
  (1): ReLU()
  (2): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (3): Linear(in_features=256, out_features=256, bias=True)
  (4): ReLU()
  (5): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (6): Linear(in_features=256, out_features=256, bias=True)
  (7): ReLU()
  (8): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (9): Linear(in_features=256, out_features=256, bias=True)
  (10): ReLU()
  (11): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (12): Linear(in_features=256, out_features=256, bias=True)
  (13): ReLU()
  (14): LayerNorm((256,), eps=1e-05, elementwise_affine=True)
  (15): Linear(in_features=256, out_features=2, bias=True)
  (16): Softmax(dim=1)
)


  0%|          | 1/2500 [00:00<26:55,  1.55it/s]

Last 100 CLIP score: 0.4755859375


  4%|▍         | 101/2500 [00:56<22:58,  1.74it/s]

Last 100 CLIP score: 0.3980859375


  8%|▊         | 201/2500 [01:51<21:42,  1.77it/s]

Last 100 CLIP score: 0.392275390625


 12%|█▏        | 301/2500 [02:46<20:27,  1.79it/s]

Last 100 CLIP score: 0.39135009765625


 16%|█▌        | 401/2500 [03:41<19:33,  1.79it/s]

Last 100 CLIP score: 0.38627197265625


 20%|██        | 501/2500 [04:37<18:35,  1.79it/s]

Last 100 CLIP score: 0.3881787109375


 24%|██▍       | 601/2500 [05:32<17:39,  1.79it/s]

Last 100 CLIP score: 0.38968994140625


 28%|██▊       | 701/2500 [06:27<16:47,  1.79it/s]

Last 100 CLIP score: 0.3852490234375


 32%|███▏      | 801/2500 [07:22<15:45,  1.80it/s]

Last 100 CLIP score: 0.39295166015625


 36%|███▌      | 901/2500 [08:18<14:50,  1.80it/s]

Last 100 CLIP score: 0.38440185546875


 40%|████      | 1001/2500 [09:13<13:54,  1.80it/s]

Last 100 CLIP score: 0.38598388671875


 44%|████▍     | 1101/2500 [10:08<12:56,  1.80it/s]

Last 100 CLIP score: 0.39181884765625


 48%|████▊     | 1201/2500 [11:03<12:04,  1.79it/s]

Last 100 CLIP score: 0.38287353515625


 52%|█████▏    | 1301/2500 [11:59<11:07,  1.80it/s]

Last 100 CLIP score: 0.39218994140625


 56%|█████▌    | 1401/2500 [12:54<10:16,  1.78it/s]

Last 100 CLIP score: 0.3810302734375


 60%|██████    | 1501/2500 [13:49<09:16,  1.79it/s]

Last 100 CLIP score: 0.3913916015625


 64%|██████▍   | 1601/2500 [14:44<08:22,  1.79it/s]

Last 100 CLIP score: 0.3830615234375


 68%|██████▊   | 1701/2500 [15:40<07:29,  1.78it/s]

Last 100 CLIP score: 0.38697265625


 72%|███████▏  | 1801/2500 [16:35<06:33,  1.78it/s]

Last 100 CLIP score: 0.3880859375


 76%|███████▌  | 1901/2500 [17:30<05:37,  1.77it/s]

Last 100 CLIP score: 0.3767333984375


 80%|████████  | 2001/2500 [18:25<04:40,  1.78it/s]

Last 100 CLIP score: 0.3878662109375


 84%|████████▍ | 2101/2500 [19:21<03:44,  1.78it/s]

Last 100 CLIP score: 0.38941650390625


 88%|████████▊ | 2201/2500 [20:16<02:46,  1.79it/s]

Last 100 CLIP score: 0.3910302734375


 92%|█████████▏| 2301/2500 [21:11<01:51,  1.79it/s]

Last 100 CLIP score: 0.3909521484375


 96%|█████████▌| 2401/2500 [22:06<00:55,  1.79it/s]

Last 100 CLIP score: 0.38823486328125


100%|██████████| 2500/2500 [23:01<00:00,  1.81it/s]
