In [None]:
from pathlib import Path
import mediapy
import os
from tqdm import tqdm

import numpy as np
import torch

from p3d.losses import calc_l2_losses, calc_lpips_losses
from transformers import AutoImageProcessor, AutoModel, CLIPProcessor, CLIPModel


# Pretrained backbones used to embed images for the CLIP / DINOv2 feature distances.
# The variable names are kept as-is because later cells reference
# `processor`, `dino`, `CLIP` and `clip_processor` directly.
DINO_CHECKPOINT = 'facebook/dinov2-base'
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"

processor = AutoImageProcessor.from_pretrained(DINO_CHECKPOINT)
dino = AutoModel.from_pretrained(DINO_CHECKPOINT)

clip_processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)
CLIP = CLIPModel.from_pretrained(CLIP_CHECKPOINT)

In [None]:
name = "shapegen_00"

# Reference renders, and the test ("rand") renders that are scored against them.
# Both lists are sorted so that file order, and therefore any index derived
# from it, is deterministic across runs.
ref_dir, test_dir = (Path(f"data/{name}/{sub}") for sub in ("renders", "rand_renders"))

renders = sorted(ref_dir.glob("*.png"))
tests = sorted(test_dir.glob("*.png"))

In [None]:
# One output directory per distance metric, holding the per-test loss tensors
# that are saved later in the notebook. Creation is idempotent (exist_ok=True).
l2_dir, lpips_dir, clip_dir, dino_dir = (
    Path(f"data/{name}/{metric}") for metric in ("l2", "lpips", "clip", "dino")
)

for metric_dir in (l2_dir, lpips_dir, clip_dir, dino_dir):
    metric_dir.mkdir(parents=True, exist_ok=True)

In [None]:
# Score every test render against every reference render with four distances
# (L2, LPIPS, CLIP-cosine, DINO-cosine). Results are saved as `{i:05d}.pt` in the
# per-metric directories, where `i` is the position of the test image in the
# sorted `tests` list.


def _load_rgb(path):
    """Read the image at ``path`` as a float tensor divided by 255, keeping only the first 3 channels of the last axis.

    Dropping the extra channels discards e.g. an alpha channel.
    NOTE(review): assumes 8-bit images, because the code divides by 255. Confirm if any 16-bit PNGs are present.
    """
    return torch.from_numpy(mediapy.read_image(path))[..., :3] / 255


def _clip_features(images):
    """Embed ``images`` with the global CLIP model.

    The images are already scaled to [0, 1], hence ``do_rescale=False``.
    """
    inputs = clip_processor(images=images, return_tensors="pt", do_rescale=False)
    return CLIP.get_image_features(**inputs)


def _dino_features(images):
    """Embed ``images`` (already scaled to [0, 1]) with the global DINOv2 model and return its ``pooler_output``."""
    inputs = processor(images=images, return_tensors="pt", do_rescale=False)
    return dino(**inputs).pooler_output


def _cosine_loss(query, refs):
    """Return ``1 - cosine_similarity(query, refs)`` along dim 1.

    The similarity is clamped to [0, 1] first, so any negatively-correlated pair
    gets the maximum loss of 1.
    """
    return 1 - torch.nn.functional.cosine_similarity(query, refs, dim=1).clamp(0, 1)


# Fail loudly instead of hitting a cryptic stack error (no renders) or silently
# writing nothing (no tests), e.g. when `name` points at a missing dataset.
assert renders, f"no reference renders found in {ref_dir}"
assert tests, f"no test renders found in {test_dir}"

# Load each reference render on its own, then stack along a new leading axis.
# Per-image channel slicing also tolerates a mix of RGB and RGBA files.
rendered_images = torch.stack([_load_rgb(render) for render in renders])

with torch.no_grad():
    # Reference embeddings are computed once and reused for every test image.
    print("processing clip features")
    render_features_clip = _clip_features(rendered_images)
    print("processing dino features")
    render_features_dino = _dino_features(rendered_images)

    # Wrap the list (not the enumerate iterator) in tqdm so the bar knows its total.
    for i, test in enumerate(tqdm(tests, desc="scoring test renders")):
        test_image = _load_rgb(test)

        test_features_clip = _clip_features(test_image)
        test_features_dino = _dino_features(test_image)

        # The clip/dino losses broadcast one test embedding against all reference
        # embeddings. The l2/lpips output shapes depend on p3d.losses.
        # Presumably each is one value per reference render; TODO confirm.
        l2_loss = calc_l2_losses(test_image, rendered_images)
        lpips_loss = calc_lpips_losses(test_image, rendered_images).flatten()
        clip_loss = _cosine_loss(test_features_clip, render_features_clip)
        dino_loss = _cosine_loss(test_features_dino, render_features_dino)

        for loss_dir, loss in (
            (l2_dir, l2_loss),
            (lpips_dir, lpips_loss),
            (clip_dir, clip_loss),
            (dino_dir, dino_loss),
        ):
            torch.save(loss, loss_dir / f"{i:05d}.pt")