In [None]:
!pip install git+https://github.com/openai/CLIP.git
!pip install kaolin -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.5.0_cu111.html

In [None]:
import clip
import copy
import json
import kaolin as kal
import kaolin.ops.mesh
import numpy as np
import os
import random
import torch
import torch.nn as nn
import torchvision

from itertools import permutations, product
from Normalization import MeshNormalizer
from mesh import Mesh
from pathlib import Path
from render import Renderer
from tqdm import tqdm
from torch.autograd import grad
from torchvision import transforms
from utils import device, color_mesh


def get_clip_model(clipmodel="ViT-L/14"):
    """
    Load a CLIP model, freeze it, and report the input resolution it expects.

    Args:
        clipmodel (str): Checkpoint name handed to ``clip.load``. Defaults to
            "ViT-L/14". The non-224px checkpoints recognized here are
            "ViT-L/14@336px", "RN50x4", "RN50x16" and "RN50x64"; any other
            name is assumed to take 224x224 input.

    Returns:
        tuple: (clip_model, preprocess, resolution)
            - clip_model: the loaded model, all parameters frozen, in eval mode
            - preprocess: the preprocessing transform returned by ``clip.load``
            - resolution: square input side length (pixels) for this model
    """
    import clip
    from utils import device

    model, preprocess_fn = clip.load(clipmodel, device=device, jit=False)

    # Checkpoints whose expected input size differs from the 224px default.
    non_default_sizes = {
        "ViT-L/14@336px": 336,
        "RN50x4": 288,
        "RN50x16": 384,
        "RN50x64": 448,
    }
    input_size = non_default_sizes.get(clipmodel, 224)

    # CLIP is only used to score renders; its weights are never trained.
    model.requires_grad_(False)
    model.eval()

    return model, preprocess_fn, input_size

# ================== HELPER FUNCTIONS =============================
def save_final_results(log_dir, name, mesh, mlp, vertices, colors, render, background):
    """
    Export the final hard (argmax) highlight assignment as a PLY mesh and a render.

    Writes ``{log_dir}/{name}.ply`` with per-vertex 0-255 colors and
    ``{log_dir}/final_render.jpg`` with 5 rendered views.

    Args:
        log_dir: Output directory for the mesh and the render image.
        name: Base filename (no extension) for the exported PLY mesh.
        mesh: Mesh to color, render and export. ``color_mesh`` is called on it
            for its side effect (its return value is not used).
        mlp: Highlighter network; it is switched to eval mode and left there.
        vertices: Vertex positions fed to ``mlp``.
        colors: Unused; kept only for call compatibility. This function always
            uses its own highlight (204, 255, 0) / gray (180, 180, 180) palette.
        render: Renderer object providing ``render_views``.
        background: Background color forwarded to ``render_views``.
    """
    mlp.eval()
    with torch.no_grad():
        probs = mlp(vertices)
        # Hard per-vertex label: index 0 -> highlight, index 1 -> gray
        # (matches the palette order below).
        max_idx = torch.argmax(probs, 1, keepdim=True)

        highlight = torch.tensor([204, 255, 0], device=device)
        gray = torch.tensor([180, 180, 180], device=device)

        # Renders: one-hot labels with a [0, 1]-scaled palette. Named `palette`
        # rather than reusing the `colors` argument, which is intentionally ignored.
        one_hot = torch.zeros(probs.shape, device=device).scatter_(1, max_idx, 1)
        palette = torch.stack((highlight / 255, gray / 255))
        color_mesh(one_hot, mesh, palette)
        rendered_images, _, _ = render.render_views(mesh, num_views=5,
                                                    show=False,
                                                    center_azim=0,
                                                    center_elev=0,
                                                    std=1,
                                                    return_views=True,
                                                    lighting=True,
                                                    background=background)

        # Exported mesh: 0-255 colors; the (N, 1) mask broadcasts against the (3,) colors.
        final_color = torch.where(max_idx == 0, highlight, gray)
        mesh.export(os.path.join(log_dir, f"{name}.ply"), extension="ply", color=final_color)
        save_renders(log_dir, 0, rendered_images, name='final_render.jpg')
        

def _encode_normalized(images, clip_model):
    """Encode images with CLIP and L2-normalize each embedding row."""
    encoded = clip_model.encode_image(images)
    return encoded / encoded.norm(dim=1, keepdim=True)


def _clip_similarity(encoded_renders, encoded_text, clipavg):
    """
    Cosine similarity between normalized render and text embeddings.

    With ``clipavg == "view"`` the render embeddings are averaged across views
    first (and the text embeddings too, when there is more than one) and the
    means are compared. Any other value compares every view with the text and
    averages the per-view similarities.
    """
    if clipavg == "view":
        if encoded_text.shape[0] > 1:
            # Multiple text embeddings: compare mean image vs. mean text embedding.
            return torch.cosine_similarity(
                torch.mean(encoded_renders, dim=0),
                torch.mean(encoded_text, dim=0),
                dim=0,
            )
        # Single text embedding: compare the mean image embedding against it.
        return torch.cosine_similarity(
            torch.mean(encoded_renders, dim=0, keepdim=True),
            encoded_text,
        )
    return torch.mean(torch.cosine_similarity(encoded_renders, encoded_text))


def clip_loss(rendered_images, encoded_text, clip_transform, augment_transform, clip_model, n_augs=5, clipavg="view"):
    """
    Calculate the CLIP-based loss between rendered images and a text description.

    The loss is the NEGATED cosine similarity between CLIP embeddings of the
    renders and the text, so lower loss means better alignment in every mode.

    Args:
        rendered_images: Tensor of rendered views of the mesh.
        encoded_text: Normalized CLIP text embedding(s).
        clip_transform: Preprocessing used when ``n_augs == 0``.
        augment_transform: Random augmentation applied once per augmentation pass.
        clip_model: CLIP model exposing ``encode_image``.
        n_augs: Number of augmentation passes (default 5). 0 means a single
            pass through ``clip_transform``.
        clipavg: "view" averages embeddings across views before comparing. Any
            other value (e.g. "embedding") averages per-view similarities.

    Returns:
        torch.Tensor: ``-similarity`` when ``n_augs == 0``. Otherwise the SUM
        (not the mean) of ``-similarity`` over the ``n_augs`` augmented passes.
        For a single text embedding with ``clipavg == "view"`` this is a
        1-element tensor; otherwise it is 0-dim.

    Raises:
        ValueError: If ``n_augs`` is negative.
    """
    if n_augs < 0:
        raise ValueError(f"n_augs must be >= 0, got {n_augs}")

    if n_augs == 0:
        encoded_renders = _encode_normalized(clip_transform(rendered_images), clip_model)
        # Negate so that minimizing the loss maximizes similarity, matching the
        # augmented branch below (this branch used to return +similarity).
        return -_clip_similarity(encoded_renders, encoded_text, clipavg)

    # Accumulate (sum) the negated similarity over independent random augmentations.
    loss = 0.0
    for _ in range(n_augs):
        encoded_renders = _encode_normalized(augment_transform(rendered_images), clip_model)
        loss = loss - _clip_similarity(encoded_renders, encoded_text, clipavg)
    return loss
    
def save_renders(dir, i, rendered_images, name=None):
    """
    Save a batch of rendered images as a single image file under ``dir``.

    Args:
        dir: Output directory.
        i: Iteration number; used only for the default filename.
        rendered_images: Image tensor(s) passed to ``torchvision.utils.save_image``.
        name: Optional filename relative to ``dir``. When omitted, the file is
            written to ``renders/iter_{i}.jpg`` inside ``dir``.
    """
    filename = 'renders/iter_{}.jpg'.format(i) if name is None else name
    torchvision.utils.save_image(rendered_images, os.path.join(dir, filename))


In [None]:
from torchvision import transforms

def setup_clip_transforms(resolution=224):
    """
    Build the image pipelines applied to renders before they reach CLIP.

    Args:
        resolution (int): Square side length expected by the chosen CLIP model.

    Returns:
        tuple: (clip_transform, augment_transform)
            - clip_transform: resize to (resolution, resolution), then CLIP
              normalization.
            - augment_transform: RandomResizedCrop(resolution, scale=(1, 1)),
              RandomPerspective(fill=1, p=0.8, distortion_scale=0.5), then
              CLIP normalization.
    """
    # Per-channel RGB mean / std used by CLIP's own preprocessing.
    normalize = transforms.Normalize(
        (0.48145466, 0.4578275, 0.40821073),
        (0.26862954, 0.26130258, 0.27577711),
    )

    resize_then_normalize = transforms.Compose([
        transforms.Resize((resolution, resolution)),
        normalize,
    ])

    # fill=1 pads perspective-warped borders with white, presumably to match
    # the white render background.
    augment_then_normalize = transforms.Compose([
        transforms.RandomResizedCrop(resolution, scale=(1, 1)),
        transforms.RandomPerspective(fill=1, p=0.8, distortion_scale=0.5),
        normalize,
    ])

    return resize_then_normalize, augment_then_normalize

In [None]:
from neural_highlighter import NeuralHighlighter

# Constrain most sources of randomness for reproducibility
# (some torch backward functions used inside CLIP are non-deterministic,
# so runs are still not guaranteed to be bit-for-bit identical).
seed = 0  # Any integer value works
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True


render_res = 224  # NOTE(review): not used in this cell; renders use the CLIP `resolution`
learning_rate = 0.0001
n_iter = 2500
res = 224  # NOTE(review): not used in this cell
obj_path = 'data/horse.obj'
n_augs = 5
output_dir = './output/'
clip_model_name = 'ViT-L/14'

Path(os.path.join(output_dir, 'renders')).mkdir(parents=True, exist_ok=True)

# objbase (e.g. "horse") names the exported mesh at the end of training.
objbase, extension = os.path.splitext(os.path.basename(obj_path))


mesh = Mesh(obj_path)
MeshNormalizer(mesh)()

# White background for all renders
background = torch.tensor((1., 1., 1.)).to(device)

log_dir = output_dir


# MLP Settings
mlp = NeuralHighlighter(
    depth=5,           # Number of hidden layers
    width=256,         # Width of each layer
    out_dim=2,         # Binary classification (highlight/no-highlight)
    input_dim=3,       # 3D coordinates (x,y,z)
    positional_encoding=False  # As recommended in the paper
).to(device)
optim = torch.optim.Adam(mlp.parameters(), learning_rate)

# list of possible colors (row 0 = highlighter, row 1 = gray)
rgb_to_color = {(204/255, 1., 0.): "highlighter", (180/255, 180/255, 180/255): "gray"}
color_to_rgb = {"highlighter": [204/255, 1., 0.], "gray": [180/255, 180/255, 180/255]}
full_colors = [[204/255, 1., 0.], [180/255, 180/255, 180/255]]
colors = torch.tensor(full_colors).to(device)


# --- Prompt ---
# encode prompt with CLIP
clip_model, preprocess, resolution = get_clip_model(clipmodel=clip_model_name)
# Create your text prompt
object_name = "horse"
highlight_region = "shoes"
prompt = "A 3D render of a gray {} with highlighted {}".format(object_name, highlight_region)

# Encode the text once; it is a fixed target for the whole optimization.
with torch.no_grad():
    text_tokens = clip.tokenize([prompt]).to(device)
    text_features = clip_model.encode_text(text_tokens)
    text_features = text_features / text_features.norm(dim=1, keepdim=True)

render = Renderer(dim=(resolution, resolution))
vertices = copy.deepcopy(mesh.vertices)
n_views = 5

# Set up the transforms
clip_transform, augment_transform = setup_clip_transforms(resolution)

losses = []

# Optimization loop
for i in tqdm(range(n_iter)):
    optim.zero_grad()

    # predict highlight probabilities
    pred_class = mlp(vertices)

    # color and render mesh
    sampled_mesh = mesh
    color_mesh(pred_class, sampled_mesh, colors)
    rendered_images, elev, azim = render.render_views(sampled_mesh, num_views=n_views,
                                                      show=False,
                                                      center_azim=0,
                                                      center_elev=0,
                                                      std=1,
                                                      return_views=True,
                                                      lighting=True,
                                                      background=background)

    # Calculate CLIP loss (negated similarity: lower is better)
    loss = clip_loss(rendered_images,
                     text_features,
                     clip_transform,
                     augment_transform,
                     clip_model,
                     n_augs=n_augs)
    loss.backward(retain_graph=True)

    optim.step()

    # record loss (.item() detaches, so no no_grad block is needed)
    losses.append(loss.item())

    # report results
    if i % 100 == 0:
        recent_avg = np.mean(losses[-100:])
        print("Last 100 CLIP score: {}".format(recent_avg))
        save_renders(log_dir, i, rendered_images)
        with open(os.path.join(log_dir, "training_info.txt"), "a") as f:
            f.write(f"For iteration {i}... Prompt: {prompt}, Last 100 avg CLIP score: {recent_avg}, CLIP score {losses[-1]}\n")


# save results (the mesh is exported as `{objbase}.ply`)
save_final_results(log_dir, objbase, mesh, mlp, vertices, colors, render, background)

# Save prompt: an empty marker file named after the prompt text
with open(os.path.join(log_dir, prompt), "w") as f:
    f.write('')