In [None]:

import torch

# Report the PyTorch build, whether CUDA is usable, and the CUDA version
# PyTorch was built against (one value per line).
print(torch.__version__, torch.cuda.is_available(), torch.version.cuda, sep="\n")

# Run on the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

#### Define the model

In [None]:
import clip
import copy
import json
import kaolin as kal
import kaolin.ops.mesh
import numpy as np
import os
import random
import torch
import torch.nn as nn
import torchvision

from itertools import permutations, product
from Normalization.MeshNormalizer import MeshNormalizer
from mesh import Mesh
from pathlib import Path
from render import Renderer
from tqdm import tqdm
from torch.autograd import grad
from torchvision import transforms
from utils import device, color_mesh



class NeuralHighlighter(nn.Module):
    """MLP mapping per-vertex input features to class probabilities.

    Architecture: ``depth`` blocks of ``Linear -> ReLU -> LayerNorm`` followed by
    a final ``Linear`` projection to ``out_dim`` and a softmax over the last
    dimension, so every output row is a probability distribution over classes.

    Args:
        depth: Number of hidden ``Linear -> ReLU -> LayerNorm`` blocks. The first
            block is always built, so values below 1 behave like 1.
        width: Hidden layer width.
        out_dim: Number of output classes.
        input_dim: Size of the last dimension of the input.
    """

    def __init__(self, depth=5, width=256, out_dim=2, input_dim=3):
        super().__init__()
        self.depth = depth
        self.width = width
        self.out_dim = out_dim
        self.input_dim = input_dim

        # Layers are created in the same order as before, so a fixed seed yields
        # the same initial weights and the same state_dict keys ("model.0.weight", ...).
        layers = [nn.Linear(input_dim, width), nn.ReLU(), nn.LayerNorm(width)]
        for _ in range(depth - 1):
            layers += [nn.Linear(width, width), nn.ReLU(), nn.LayerNorm(width)]
        # Softmax over the last (class) dimension: identical to dim=1 for (N, C)
        # input, and still correct for inputs with extra leading batch dimensions.
        layers += [nn.Linear(width, out_dim), nn.Softmax(dim=-1)]
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return class probabilities for each row of ``x``.

        Args:
            x: Tensor whose last dimension has size ``input_dim``
                (e.g. ``(num_vertices, input_dim)``).

        Returns:
            Tensor of shape ``x.shape[:-1] + (out_dim,)`` whose last dimension sums to 1.
        """
        return self.model(x)

def get_clip_model(clipmodel):
    """Load the CLIP checkpoint named ``clipmodel`` (e.g. ``'ViT-B/32'``).

    Returns:
        The ``(model, preprocess)`` pair unpacked from ``clip.load``.
    """
    clip_net, clip_preprocess = clip.load(clipmodel)
    return clip_net, clip_preprocess

# ================== HELPER FUNCTIONS =============================
def save_final_results(log_dir, name, mesh, mlp, vertices, colors, render, background):
    """Render and export the final hard segmentation predicted by ``mlp``.

    Every vertex is assigned its argmax class from ``mlp(vertices)``. The mesh is
    colored with the matching row of ``colors`` and rendered from 5 views. Two
    files are written to ``log_dir``:

    * ``{name}.ply``: the mesh exported with per-vertex colors scaled to 0-255 ints;
    * ``final_render.jpg``: the rendered views (via ``save_renders``).

    Args:
        log_dir: Output directory.
        name: Base filename (without extension) of the exported mesh.
        mesh: Mesh passed to ``color_mesh``, ``render.render_views`` and exported.
        mlp: Trained highlighter; it is switched to eval mode and left there.
        vertices: Input tensor fed to ``mlp``.
        colors: ``(num_classes, 3)`` RGB palette in [0, 1], one row per class --
            normally the same palette used during optimization. If ``None``,
            falls back to highlighter yellow-green (class 0) and gray (class 1).
        render: Renderer exposing ``render_views``.
        background: Background color forwarded to the renderer.
    """
    mlp.eval()
    with torch.no_grad():
        probs = mlp(vertices)
        max_idx = torch.argmax(probs, 1, keepdim=True)
        # Hard one-hot assignment so the render shows the final segmentation
        # instead of a blend of class probabilities.
        one_hot = torch.zeros_like(probs).scatter_(1, max_idx, 1)

        # Use the caller's palette (previously it was silently overwritten, so
        # the final outputs did not match the colors the mesh was optimized with).
        if colors is None:
            colors = torch.tensor([[204, 255, 0], [180, 180, 180]], device=device) / 255
        palette = colors.to(device)

        color_mesh(one_hot, mesh, palette)
        rendered_images, _, _ = render.render_views(mesh, num_views=5,
                                                    show=False,
                                                    center_azim=0,
                                                    center_elev=0,
                                                    std=1,
                                                    return_views=True,
                                                    lighting=True,
                                                    background=background)

        # Per-vertex 0-255 integer colors for the PLY: class k -> palette[k].
        palette_255 = (palette * 255).round().long()
        final_color = palette_255[max_idx.squeeze(1)]
        mesh.export(os.path.join(log_dir, f"{name}.ply"), extension="ply", color=final_color)
        save_renders(log_dir, 0, rendered_images, name='final_render.jpg')


def clip_loss(embedding, images, clip_model, augmentations, augmentation_number):
    """Negative CLIP text-image similarity averaged over random augmentations.

    The text is encoded once. Then, ``augmentation_number`` times, a fresh
    augmentation of ``images`` is encoded, and the mean cosine similarity between
    the image features and the text features is subtracted from a running total.
    The total is divided by ``augmentation_number``, so lower means more similar.

    Args:
        embedding: Tokenized text passed to ``clip_model.encode_text``.
        images: Batch of rendered images, passed through ``augmentations``.
        clip_model: Object exposing ``encode_text`` and ``encode_image``.
        augmentations: Callable applied to ``images`` before each encoding.
        augmentation_number: Number of augmented passes; 0 raises ``ZeroDivisionError``.

    Returns:
        The averaged negative similarity (a tensor when at least one pass runs).
    """
    text_features = clip_model.encode_text(embedding)
    running = 0.0
    for _ in range(augmentation_number):
        augmented = augmentations(images)
        image_features = clip_model.encode_image(augmented)
        similarity = torch.cosine_similarity(image_features, text_features)
        running = running - similarity.mean()
    return running / augmentation_number
    


def save_renders(dir, i, rendered_images, name=None):
    """Write ``rendered_images`` to a single image file under ``dir``.

    Args:
        dir: Output directory.
        i: Iteration number; only used to build the default filename.
        rendered_images: Image tensor handed to ``torchvision.utils.save_image``.
        name: Filename relative to ``dir``. When omitted, the file goes to
            ``renders/iter_{i}.jpg``; the ``renders`` subdirectory is not created here.
    """
    filename = 'renders/iter_{}.jpg'.format(i) if name is None else name
    target = os.path.join(dir, filename)
    torchvision.utils.save_image(rendered_images, target)


#### Core loop

In [None]:

import random

# Pin as many sources of randomness as possible. Some torch backward kernels
# used inside CLIP are non-deterministic, so GPU runs may still differ slightly.
seed = 0
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
# Disable cuDNN autotuning and request deterministic cuDNN kernels.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False


# --- Experiment configuration ---
render_res = 224        # side length (pixels) of each square render
n_views = 5             # camera views rendered per optimization step
n_iter = 1800           # number of optimization steps
n_augs = 3              # augmented copies of the renders scored by CLIP per step
learning_rate = 0.0008  # Adam step size for the highlighter MLP

# CLIP backbone. Alternatives tried previously: 'ViT-L/14', 'RN50x4', 'RN50x16'.
clip_version = 'ViT-B/32'

# Input mesh and output location.
object_name = "horse"
mesh_path = f"data/{object_name}.obj"
output_dir = './output/'

# Periodic renders are saved under <output_dir>/renders/.
(Path(output_dir) / 'renders').mkdir(parents=True, exist_ok=True)

objbase, extension = os.path.splitext(os.path.basename(mesh_path))

# Renderer producing render_res x render_res images.
render = Renderer(dim=(render_res, render_res))

# Load the mesh, then run the normalizer on it (its return value is discarded,
# so it presumably modifies `mesh` in place -- confirm in MeshNormalizer).
mesh = Mesh(obj_path=mesh_path)
MeshNormalizer(mesh)()


# Initialize variables
bg = torch.tensor((1., 1., 1.)).to(device)  # render background color (1, 1, 1)
log_dir = output_dir

# MLP Settings: per-vertex highlighter network and its optimizer.
mlp = NeuralHighlighter().to(device)
optim = torch.optim.Adam(mlp.parameters(), learning_rate)

# Named reference colors. NOTE(review): these two lookups are not used by the
# cells shown here, and their "highlighter" entry (204, 255, 0) differs from the
# class-0 color actually used in `full_colors` below (204, 0, 0, i.e. red).
rgb_to_color = {(204/255, 1., 0.): "highlighter", (180/255, 180/255, 180/255): "gray"}
color_to_rgb = {"highlighter": [204/255, 1., 0.], "gray": [180/255, 180/255, 180/255]}
# Palette used to color the mesh while optimizing: row 0 is the highlighted
# class (red -- presumably chosen to match the "red saddle" prompt), row 1 gray.
full_colors = [[204/255, 0, 0], [180/255, 180/255, 180/255]]
colors = torch.tensor(full_colors).to(device)


# --- Prompt ---
# Load CLIP and tokenize the text prompt.
clip_model, preprocess = get_clip_model(clip_version)
# CLIP only scores the renders: gradients must flow *through* it back to the
# MLP, but its weights are never optimized (`optim` holds only mlp parameters,
# so `optim.zero_grad()` never clears CLIP's .grad either). Freezing it avoids
# computing and accumulating unused weight gradients on every step.
clip_model.requires_grad_(False)
prompt = 'A 3D render of a gray horse with red saddle'
tokenized_text = clip.tokenize([prompt]).to(device)

# MLP input: a copy of the mesh vertex coordinates.
vertices = copy.deepcopy(mesh.vertices)

# Per-iteration loss history.
losses = []

# CLIP image normalization statistics (from https://github.com/openai/CLIP/issues/20).
normalize = transforms.Normalize(mean=[0.48145466, 0.4578275, 0.40821073], std=[0.26862954, 0.26130258, 0.27577711])

# Augmentations applied to the renders before CLIP scoring: random crop and
# random perspective (regions exposed by the warp are filled with 1, presumably
# white to match `bg`), followed by CLIP normalization.
augmentations = transforms.Compose([
    transforms.RandomResizedCrop(render_res, scale=(0.5, 1.0)),
    transforms.RandomPerspective(p=0.5, distortion_scale=0.5, fill=1),
    normalize,
])
 
# Optimization loop
for i in tqdm(range(n_iter)):
    optim.zero_grad()

    # Predict per-vertex highlight probabilities (differentiable w.r.t. mlp).
    pred_class = mlp(vertices)

    # Color the mesh with the soft predictions and render it from n_views views.
    sampled_mesh = mesh
    color_mesh(pred_class, sampled_mesh, colors)
    rendered_images, elev, azim = render.render_views(sampled_mesh, num_views=n_views,
                                                      show=False,
                                                      center_azim=0,
                                                      center_elev=0,
                                                      std=1,
                                                      return_views=True,
                                                      lighting=True,
                                                      background=bg)

    # CLIP loss: negative text-image similarity, so lower is better.
    loss = clip_loss(tokenized_text, rendered_images, clip_model, augmentations, augmentation_number=n_augs)
    # The whole graph is rebuilt every iteration (mlp is re-evaluated above), so
    # it need not be retained; retain_graph=True only kept its saved buffers
    # alive longer than necessary.
    loss.backward()

    optim.step()

    # Record the loss; .item() already detaches, so no no_grad block is needed.
    losses.append(loss.item())

    # Report results every 100 iterations.
    if i % 100 == 0:
        recent_avg = np.mean(losses[-100:])
        # tqdm.write prints above the progress bar instead of breaking it.
        tqdm.write("Last 100 CLIP score: {}".format(recent_avg))
        save_renders(log_dir, i, rendered_images)
        with open(os.path.join(log_dir, "training_info.txt"), "a") as f:
            f.write(f"For iteration {i}... Prompt: {prompt}, Last 100 avg CLIP score: {recent_avg}, CLIP score {losses[-1]}\n")



# Save the final segmentation: colored mesh (.ply) and final renders.
save_final_results(log_dir, "3d-render", mesh, mlp, vertices, colors, render, bg)

# Save the prompt as an empty marker file named after it. Path separators are
# replaced so a prompt containing "/" (or "\" on Windows) cannot turn the name
# into a path into a possibly non-existent subdirectory.
safe_prompt = prompt
for sep in (os.sep, os.altsep):
    if sep:
        safe_prompt = safe_prompt.replace(sep, "_")
with open(os.path.join(output_dir, safe_prompt), "w") as f:
    f.write('')