In [None]:
# GPU: CUDA 11.8, PyTorch 2.0.1 on Kaggle
!pip install --upgrade pip
!pip install git+https://github.com/openai/CLIP.git

!pip uninstall -y kaolin


!pip install torch==2.0.1+cu118 torchvision==0.15.2+cu118 torchaudio==2.0.2+cu118 \
  -f https://download.pytorch.org/whl/cu118/torch_stable.html


!pip install kaolin==0.17.0 \
  -f https://nvidia-kaolin.s3.us-east-2.amazonaws.com/torch-2.0.1_cu118.html

!pip install tqdm pillow
!rm -rf /kaggle/working/Affordance_Highlighting_Project_2024
!rm -rf /kaggle/working/output
!git clone https://github.com/MirkoDiMa/Affordance_Highlighting_Project_2024.git
%cd Affordance_Highlighting_Project_2024
import sys

sys.path.append('/kaggle/working/Affordance_Highlighting_Project_2024')

In [None]:
# ─── EXPERIMENT CONFIG ──────────────────────────────────────────────────────────
# Single dictionary holding every tunable of the run. The training cell reads
# these keys, appends the final metrics to this same dict, and writes it to disk.
exp_config = {
    # Text prompt encoded with CLIP and used as the optimisation target
    "prompt": "A 3D render of a gray candle with highlighted hat",

    # Seed applied to torch, random and numpy in the training cell
    "seed": 45,

    # Data & paths (obj_path is relative to the current working directory)
    "obj_path":        "data/candle.obj",
    "output_dir":      "/kaggle/working/output",

    # CLIP model name passed to get_clip_model
    "clip_model_name": "ViT-L/14",

    # MLP (NeuralHighlighter) settings
    "mlp_input_dim":   3,
    "mlp_hidden_dim":  256,    # passed as `width`
    "mlp_num_layers":  6,      # passed as `depth` (number of hidden stages)
    "mlp_out_dim":     2,
    "positional_encoding": False,
    "sigma":           5.0,    # only used when positional_encoding is True

    # Training
    "render_res":      224,
    "n_views":         6,
    "learning_rate":   1e-4,
    "n_iter":          2500,
    "n_augs":          5,       # augmented CLIP passes per iteration (0 = single un-augmented pass)
    "clipavg":         "view",  # "view" averages image embeddings over views before the cosine similarity

    # Augmentation
    # NOTE: "aug_type" is not read by the code visible in this notebook; the
    # training cell always builds a RandomPerspective transform from aug_params.
    "aug_type":        "RandomPerspective",
    "aug_params": {
        "distortion_scale": 0.5,
        "p":               0.8,
    },
    # Extension: random background compositing
    "bg_mode":           "noise",     # "none" | "solid" | "noise" | "image" | "mixed"
    "bg_prob":           0.75,        # probability of applying a background per batch of views
    "bg_key_color":      [1.0, 1.0, 1.0],  # colour of rendering used as "green screen"
    "bg_key_tol":        0.02,        # key color matching tolerance (0..1)
    "bg_dir":            "data/backgrounds",  # image folder path (if bg_mode="image" or "mixed")
    "bg_min_resize":     256,         # background images are rescaled so their shorter side is about this many pixels before cropping
    "bg_eval_color":     [1.0, 1.0, 1.0],    # background used for final renders

}


In [None]:
import clip
import copy
import json
import kaolin as kal
import kaolin.ops.mesh
import numpy as np
import os
import random
import torch
import torch.nn as nn
import torchvision
import time

from itertools import permutations, product
from Normalization import MeshNormalizer
from mesh import Mesh
from pathlib import Path
from render import Renderer
from tqdm import tqdm
from torch.autograd import grad
from torchvision import transforms
from utils import device, color_mesh
from utils import FourierFeatureTransform

class NeuralHighlighter(nn.Module):
    """MLP that maps each input row to a softmax distribution over `out_dim` classes.

    The network is an input stage (optionally preceded by a
    FourierFeatureTransform), `depth` hidden stages, and a final linear layer
    followed by a softmax over dimension 1. Every stage is
    Linear -> ReLU -> LayerNorm with `width` features.
    """

    def __init__(self, depth, width, out_dim, input_dim=3, positional_encoding=False, sigma=5.0):
        super().__init__()

        def stage(n_in):
            # One Linear -> ReLU -> LayerNorm group producing `width` features.
            return [nn.Linear(n_in, width), nn.ReLU(), nn.LayerNorm([width])]

        if positional_encoding:
            modules = [FourierFeatureTransform(input_dim, width, sigma)]
            # The first Linear is sized for `width * 2 + input_dim` input features.
            modules += stage(width * 2 + input_dim)
        else:
            modules = stage(input_dim)

        for _ in range(depth):
            modules += stage(width)

        modules += [nn.Linear(width, out_dim), nn.Softmax(dim=1)]

        self.mlp = nn.ModuleList(modules)
        print(self.mlp)

    def forward(self, x):
        """Apply every layer of `self.mlp` in order and return the result."""
        out = x
        for module in self.mlp:
            out = module(out)
        return out

def get_clip_model(clipmodel):
    """Call `clip.load` for `clipmodel` on `device` with `jit=False`.

    Returns the two values produced by `clip.load` as a `(model, preprocess)` tuple.
    """
    clip_net, image_preprocess = clip.load(clipmodel, device=device, jit=False)
    return clip_net, image_preprocess

# ================== HELPER FUNCTIONS =============================
def save_final_results(log_dir, name, mesh, mlp, vertices, colors, render, background):
    """Export the hard-labelled mesh as a PLY file and save a final multi-view render.

    The MLP is switched to eval mode and evaluated on `vertices`; each vertex is
    assigned to its arg-max class (class 0 = highlight colour, otherwise gray).

    Args:
        log_dir: directory receiving `<name>.ply` and `final_render.jpg`.
        name: base file name (without extension) of the exported mesh.
        mesh: mesh object; it is recoloured in place through `color_mesh` and
            exported through its own `export` method.
        mlp: the trained highlighter network.
        vertices: network input, one row per vertex.
        colors: accepted for call compatibility but not used; the fixed
            highlight/gray palette below is applied instead.
        render: renderer exposing `render_views`.
        background: accepted for call compatibility but not used; the final
            render background comes from `exp_config["bg_eval_color"]`
            (white when the key is absent).
    """
    mlp.eval()
    with torch.no_grad():
        probs = mlp(vertices)
        max_idx = torch.argmax(probs, 1, keepdim=True)

        # One-hot class assignment, so the render shows hard labels rather than soft blends.
        one_hot = torch.zeros(probs.shape, device=device).scatter_(1, max_idx, 1)

        highlight = torch.tensor([204, 255, 0]).to(device)
        gray = torch.tensor([180, 180, 180]).to(device)
        # Palette in the 0..1 range for rendering (row 0 = highlight, row 1 = gray).
        palette = torch.stack((highlight / 255, gray / 255)).to(device)
        color_mesh(one_hot, mesh, palette)

        eval_bg = torch.tensor(exp_config.get("bg_eval_color", [1., 1., 1.]), device=device).float()
        rendered_images, _, _ = render.render_views(mesh, num_views=5,
                                                    show=False,
                                                    center_azim=0,
                                                    center_elev=0,
                                                    std=4,
                                                    return_views=True,
                                                    lighting=True,
                                                    background=eval_bg)

        # Per-vertex colour for the exported mesh, kept as the integer 0..255 triplets.
        final_color = torch.where(max_idx == 0, highlight, gray)
        mesh.export(os.path.join(log_dir, f"{name}.ply"), extension="ply", color=final_color)
        save_renders(log_dir, 0, rendered_images, name='final_render.jpg')
def save_exp_config(config, output_dir):
    """Persist one experiment configuration as JSON and add it to a CSV summary.

    Writes `experiment_config.json` (overwritten on every call) and adds one row
    to `experiments_summary.csv` inside `output_dir`. The directory is created
    when it does not exist.

    When the CSV already exists with a different set of columns, the file is
    rewritten with the union of the old and new columns (old columns first) so
    that every row stays aligned with the header; missing cells are left empty.

    Args:
        config: mapping of setting/metric names to JSON-serialisable values.
        output_dir: destination directory.
    """
    import csv, json, os

    os.makedirs(output_dir, exist_ok=True)

    # JSON
    with open(os.path.join(output_dir, 'experiment_config.json'), 'w') as f:
        json.dump(config, f, indent=2)

    # CSV
    csv_path = os.path.join(output_dir, 'experiments_summary.csv')
    fieldnames = list(config.keys())

    previous_header, previous_rows = [], []
    if os.path.exists(csv_path):
        with open(csv_path, newline='') as f:
            reader = csv.DictReader(f)
            previous_header = list(reader.fieldnames or [])
            previous_rows = list(reader)

    if not previous_header or previous_header == fieldnames:
        # New (or empty) file, or identical columns: a plain append is safe.
        with open(csv_path, 'a', newline='') as f:
            writer = csv.DictWriter(f, fieldnames=fieldnames)
            if not previous_header:
                writer.writeheader()
            writer.writerow(config)
        return

    # Columns changed between runs: rewrite with a merged header so rows stay aligned.
    merged = previous_header + [key for key in fieldnames if key not in previous_header]
    with open(csv_path, 'w', newline='') as f:
        # extrasaction='ignore' skips surplus cells of malformed pre-existing rows.
        writer = csv.DictWriter(f, fieldnames=merged, restval='', extrasaction='ignore')
        writer.writeheader()
        writer.writerows(previous_rows)
        writer.writerow(config)

def _negative_clip_similarity(image_features, text_features, clipavg):
    """Return minus the cosine similarity between normalised image and text features.

    With `clipavg == "view"` the image features are averaged over dimension 0
    first; if there are several text rows they are averaged too and the result
    is a 0-d tensor, otherwise the result has shape (1,). With any other
    `clipavg` the row-wise similarities are averaged into a 0-d tensor.
    """
    if clipavg == "view":
        if text_features.shape[0] > 1:
            # Similarity between the mean image feature and the mean text feature.
            return -torch.cosine_similarity(image_features.mean(0),
                                            text_features.mean(0), dim=0)
        return -torch.cosine_similarity(image_features.mean(0, keepdim=True),
                                        text_features, dim=1)
    return -torch.mean(torch.cosine_similarity(image_features, text_features, dim=1))


def clip_loss(rendered_images: torch.Tensor,
              text_embedding: torch.Tensor,
              clip_model: nn.Module,
              clip_transform: transforms.Compose,
              augment_transform: transforms.Compose,
              n_augs: int,
              clipavg: str = "view") -> torch.Tensor:
    """CLIP loss: negative cosine similarity between rendered views and the text.

    Args:
        rendered_images: batch of rendered views handed to the transforms.
        text_embedding: text features with one row per prompt; re-normalised
            here along dimension 1.
        clip_model: object exposing `encode_image`.
        clip_transform: transform applied when `n_augs == 0`.
        augment_transform: transform applied once per augmentation when `n_augs > 0`.
        n_augs: number of augmented passes. 0 means a single pass through
            `clip_transform`. For `n_augs > 0` the per-pass losses are summed,
            not averaged. A negative value runs no pass and yields the float 0.0.
        clipavg: "view" averages the image features over the views before the
            cosine similarity; any other value averages the row-wise similarities.

    Returns:
        The loss tensor (0-d, or shape (1,) for "view" with a single text row).
    """
    # The text features do not depend on the augmentation: normalise them once.
    txt = text_embedding / text_embedding.norm(dim=1, keepdim=True)

    def encode(view_transform):
        # transform -> CLIP image encoder -> row-wise L2 normalisation
        features = clip_model.encode_image(view_transform(rendered_images))
        return features / features.norm(dim=1, keepdim=True)

    if n_augs == 0:
        return _negative_clip_similarity(encode(clip_transform), txt, clipavg)

    loss = 0.0
    for _ in range(n_augs):
        loss = loss + _negative_clip_similarity(encode(augment_transform), txt, clipavg)
    return loss


    
def save_renders(dir, i, rendered_images, name=None):
    """Save `rendered_images` as an image file inside `dir`.

    When `name` is given it is used as the file name; otherwise the file is
    written to `renders/iter_<i>.jpg` below `dir`.
    """
    relative_path = 'renders/iter_{}.jpg'.format(i) if name is None else name
    torchvision.utils.save_image(rendered_images, os.path.join(dir, relative_path))


In [None]:
from PIL import Image, ImageFilter
import math
import glob

# --------------------------------------------------------------------------- #
# Background sampler + matting/compositing
# --------------------------------------------------------------------------- #

def _to_4d(x):
    # garantisce shape (B,3,H,W)
    if x.dim() == 3:
        x = x.unsqueeze(0)
    return x

def load_bg_pool(bg_dir):
    """Return the sorted paths of the image files directly inside `bg_dir`.

    A file qualifies when its name ends with .jpg, .jpeg, .png, .bmp or .webp
    (case-insensitive) and does not start with a dot. The list is sorted so that
    a seeded `random.choice` over it picks the same files on every machine;
    directory listing order alone is filesystem-dependent.

    Returns an empty list when `bg_dir` is not an existing directory.
    """
    if not os.path.isdir(bg_dir):
        return []
    exts = (".jpg", ".jpeg", ".png", ".bmp", ".webp")
    candidates = (
        os.path.join(bg_dir, entry)
        for entry in os.listdir(bg_dir)
        if not entry.startswith(".") and entry.lower().endswith(exts)
    )
    # Keep regular files only: a directory named like an image cannot be opened later.
    return sorted(path for path in candidates if os.path.isfile(path))

# Background image paths gathered once, when this cell runs, from exp_config["bg_dir"];
# the list is empty when that folder does not exist.
_BG_POOL = load_bg_pool(exp_config["bg_dir"])

def _rand_solid(res):
    """Return a (3, res, res) tensor on `device` showing one uniformly random RGB colour.

    The result is an expanded view of the three sampled channel values.
    """
    rgb = torch.rand(3, device=device)
    return rgb[:, None, None].expand(3, res, res)

def _rand_noise(res):
    """Return a (3, res, res) tensor on `device` filled with uniform random values in [0, 1)."""
    shape = (3, res, res)
    return torch.rand(shape, device=device)

def _rand_image(res):
    """Return a random (3, res, res) crop of a random background image, on `device`.

    A path is drawn from `_BG_POOL`; the image is rescaled so its shorter side is
    about `exp_config["bg_min_resize"]` pixels, then a random `res` x `res`
    window is cropped and converted to a float tensor in [0, 1]. Falls back to
    `_rand_noise(res)` when the pool is empty.
    """
    if len(_BG_POOL) == 0:
        # fallback
        return _rand_noise(res)
    p = random.choice(_BG_POOL)
    # Image.open keeps the file open until the pixels are read; convert() returns
    # an independent in-memory image, so the handle can be released right away.
    with Image.open(p) as src:
        img = src.convert("RGB")
    # rescale along the shorter side
    s = exp_config["bg_min_resize"]
    scale = s / min(img.size)
    # max(1, ...) guards against a zero-sized side for extreme aspect ratios
    img = img.resize((max(1, int(img.size[0]*scale)), max(1, int(img.size[1]*scale))))
    # random crop to (res, res)
    if img.size[0] == res and img.size[1] == res:
        crop = img
    else:
        if img.size[0] < res or img.size[1] < res:
            # enlarge any side that is still shorter than the crop window
            img = img.resize((max(res,img.size[0]), max(res,img.size[1])))
        x0 = random.randint(0, img.size[0]-res)
        y0 = random.randint(0, img.size[1]-res)
        crop = img.crop((x0, y0, x0+res, y0+res))
    bg = torch.from_numpy(np.array(crop)).float().permute(2,0,1) / 255.0
    return bg.to(device)

def sample_background_batch(batch, res, mode):
    """Stack `batch` random backgrounds of size `res` x `res` into a (B, 3, H, W) tensor.

    `mode` is "solid", "noise", "image" or "mixed"; "mixed" draws one of the
    other three at random for every element. Any value other than "solid" or
    "noise" (after the "mixed" draw) is served by `_rand_image`.
    """
    generators = {"solid": _rand_solid, "noise": _rand_noise}

    def draw_one():
        kind = random.choice(["solid", "noise", "image"]) if mode == "mixed" else mode
        return generators.get(kind, _rand_image)(res)

    return torch.stack([draw_one() for _ in range(batch)], dim=0)  # (B,3,H,W)

def composite_background(rendered, bg, key_rgb=(1.,1.,1.), tol=0.02, blur_px=0):
    """Replace the key-coloured pixels of `rendered` with `bg` (simple chroma-key matting).

    Args:
        rendered: (B,3,H,W) or (3,H,W) render drawn over the uniform colour `key_rgb`.
        bg: (B,3,h,w) or (3,h,w) background; resized bilinearly to (H,W) when needed.
        key_rgb: key colour. A pixel counts as background when every channel is
            within `tol` of it.
        tol: matching tolerance on the per-channel absolute difference.
        blur_px: when > 0, the binary mask is feathered with a separable Gaussian
            of radius `max(1, int(blur_px))`.

    Returns:
        (B,3,H,W) composite `rendered * (1 - mask) + bg * mask`.
    """
    import torch.nn.functional as F

    if rendered.dim() == 3:
        rendered = rendered.unsqueeze(0)
    _, _, H, W = rendered.shape

    key = torch.tensor(key_rgb, device=rendered.device).view(1,3,1,1)

    # Largest per-channel distance from the key colour.
    dist = (rendered - key).abs().amax(dim=1, keepdim=True)  # (B,1,H,W)
    mask_bg = (dist < tol).float()        # 1 = background; 0 = object

    if blur_px > 0:
        k = max(1, int(blur_px))
        x = torch.arange(-k, k+1, device=rendered.device).float()
        w = torch.exp(-(x**2)/(2*(k/2.0+1e-6)**2))
        w = (w / w.sum()).view(1,1,-1,1)
        # Replicate-pad before each pass: zero padding would lower the mask along
        # the image border and let the key colour leak into a frame around it.
        mask_bg = F.conv2d(F.pad(mask_bg, (0, 0, k, k), mode="replicate"), w)
        mask_bg = F.conv2d(F.pad(mask_bg, (k, k, 0, 0), mode="replicate"), w.permute(0,1,3,2))
        mask_bg = mask_bg.clamp(0,1)

    # ensure bg shape and device
    if bg.dim() == 3:
        bg = bg.unsqueeze(0)
    bg = bg.to(rendered.device)
    if bg.shape[-2:] != (H,W):
        bg = F.interpolate(bg, size=(H,W), mode="bilinear", align_corners=False)

    return rendered*(1.0 - mask_bg) + bg*mask_bg


In [None]:
# Constrain most sources of randomness
# (some torch backward functions within CLIP are non-deterministic)
seed = exp_config["seed"]
torch.manual_seed(seed)
torch.cuda.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
random.seed(seed)
np.random.seed(seed)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

# Unpack the settings used by the rest of the notebook.
render_res = exp_config["render_res"]
learning_rate = exp_config["learning_rate"]
n_iter = exp_config["n_iter"]
res = render_res  # same value; the transforms below refer to it as `res`
obj_path = exp_config["obj_path"]
n_augs = exp_config["n_augs"]
output_dir = exp_config["output_dir"]

clip_model, preprocess = get_clip_model(exp_config["clip_model_name"])
# CLIP is only used to score the renders and is not passed to the optimizer:
# freeze its weights so backward() does not compute gradients for them.
# Gradients still flow through CLIP to the rendered images.
clip_model.requires_grad_(False)

Path(os.path.join(output_dir, 'renders')).mkdir(parents=True, exist_ok=True)

objbase, extension = os.path.splitext(os.path.basename(obj_path))

render = Renderer(dim=(render_res, render_res))
mesh = Mesh(obj_path)
MeshNormalizer(mesh)()

# Initialize variables
background = torch.tensor((1., 1., 1.)).to(device)

log_dir = output_dir

# CLIP and Augmentation Transforms
clip_normalizer = transforms.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711))
# Used when n_augs == 0: resize + normalisation only.
clip_transform = transforms.Compose([
    transforms.Resize((res, res), antialias=False),
    clip_normalizer
])
# Used when n_augs > 0: random perspective warp (filled with 1) + normalisation.
augment_transform = transforms.Compose([
    transforms.RandomResizedCrop(res, scale=(1, 1), antialias=False),
    transforms.RandomPerspective(fill=1,
                                 distortion_scale=exp_config["aug_params"]["distortion_scale"],
                                 p=exp_config["aug_params"]["p"]),
    clip_normalizer
])

# MLP Settings
mlp = NeuralHighlighter(depth=exp_config["mlp_num_layers"],
    width=exp_config["mlp_hidden_dim"],
    out_dim=exp_config["mlp_out_dim"],
    input_dim=exp_config["mlp_input_dim"],
    positional_encoding=exp_config["positional_encoding"],
    sigma=exp_config["sigma"]).to(device)
optim = torch.optim.Adam(mlp.parameters(), learning_rate)

# list of possible colors (row 0 = highlighter, row 1 = gray)
rgb_to_color = {(204/255, 1., 0.): "highlighter", (180/255, 180/255, 180/255): "gray"}
color_to_rgb = {"highlighter": [204/255, 1., 0.], "gray": [180/255, 180/255, 180/255]}
full_colors = [[204/255, 1., 0.], [180/255, 180/255, 180/255]]
colors = torch.tensor(full_colors).to(device)


# --- Prompt ---
# encode prompt with CLIP
prompt = exp_config["prompt"]

with torch.no_grad():
    prompt_token = clip.tokenize([prompt]).to(device)
    encoded_text = clip_model.encode_text(prompt_token)
    encoded_text = encoded_text / encoded_text.norm(dim=1, keepdim=True)

# Private copy of the vertex positions, used as the MLP input.
vertices = copy.deepcopy(mesh.vertices)
n_views = exp_config["n_views"]

# Best-checkpoint tracking (lowest loss seen so far).
best_loss = float('inf')
best_iter = -1
best_state = None

losses = []
start_time = time.time()
# Optimization loop
# The key colour never changes, so the render background is built once.
background = torch.tensor(exp_config["bg_key_color"], device=device).float()
for i in tqdm(range(n_iter)):
    optim.zero_grad()

    # predict highlight probabilities
    pred_class = mlp(vertices)

    # color and render mesh
    sampled_mesh = mesh
    color_mesh(pred_class, sampled_mesh, colors)

    # Render over the key colour so the background can be swapped afterwards.
    rendered_images, elev, azim = render.render_views(sampled_mesh, num_views=n_views,
                                                      show=False,
                                                      center_azim=0,
                                                      center_elev=0,
                                                      std=4,
                                                      return_views=True,
                                                      lighting=True,
                                                      background=background)

    # --- Compositing background ---
    # With probability bg_prob the whole batch of views gets random backgrounds.
    if exp_config["bg_mode"] != "none" and random.random() < exp_config["bg_prob"]:
        B, C, H, W = rendered_images.shape
        bgs = sample_background_batch(B, res=H, mode=exp_config["bg_mode"])  # (B,3,H,W)
        rendered_images = composite_background(
            rendered_images, bgs,
            key_rgb=exp_config["bg_key_color"],
            tol=exp_config["bg_key_tol"],
            blur_px=1
        )

    # Calculate CLIP Loss
    loss = clip_loss(rendered_images,
        encoded_text,
        clip_model,
        clip_transform,
        augment_transform,
        n_augs,
        clipavg = exp_config["clipavg"])

    loss.backward()

    # record loss (a single host read per iteration)
    loss_value = loss.item()
    losses.append(loss_value)

    # tracking of the best: snapshot BEFORE optim.step(), so the saved weights
    # are the ones that actually produced this loss
    if loss_value < best_loss:
        best_loss  = loss_value
        best_iter  = i
        best_state = copy.deepcopy(mlp.state_dict())

    optim.step()

    # report results
    if i % 100 == 0:
        print("Last 100 CLIP score: {}".format(np.mean(losses[-100:])))
        save_renders(log_dir, i, rendered_images)
        with open(os.path.join(log_dir, "training_info.txt"), "a") as f:
            f.write(f"For iteration {i}... Prompt: {prompt}, Last 100 avg CLIP score: {np.mean(losses[-100:])}, CLIP score {losses[-1]}\n")

# metrics
# `losses` stores the values returned by clip_loss (negative cosine similarity),
# so the sign is flipped to report similarity scores.
# NOTE: this section assumes at least one iteration ran (n_iter > 0); otherwise
# `losses` is empty and `best_state` is still None.
final_loss       = losses[-1]
exp_config["final_clip_score"]       = -final_loss
exp_config["avg_clip_score_last100"] = -float(np.mean(losses[-100:]))
exp_config["runtime_seconds"]        = time.time() - start_time

# Restore the checkpoint with the lowest recorded loss before exporting.
mlp.load_state_dict(best_state)
exp_config["best_iter"] = best_iter
exp_config["best_clip_score"] = -best_loss
# save config + summary (exp_config now also carries the metrics added above)
save_exp_config(exp_config, output_dir)
# save results
# (`colors` and `background` are passed for signature compatibility; save_final_results
# as defined in this notebook uses its own palette and exp_config["bg_eval_color"])
save_final_results(log_dir, f"{objbase}_best_iter{best_iter}", mesh, mlp, vertices, colors, render, background)

# Save prompts
with open(os.path.join(log_dir, "prompt.txt"), "w") as f:
    f.write(prompt)