In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys
# sys.path.append(os.getcwd() + "/submodules/GenerativeModels")
# from scripts.sampling.simple_video_sample import sample

import torch
import torch.nn.functional as F
from PIL import Image
import numpy as np
import mediapy
from pathlib import Path


from p3d.models.openlrm import OpenLRM

In [None]:
def load_image(image_path, source_size=266):
    """Load an image file as a square, white-padded float tensor.

    Args:
        image_path: Path (or file object) accepted by ``PIL.Image.open``.
        source_size: Side length, in pixels, of the square output.

    Returns:
        A ``torch`` tensor of shape ``(1, 3, source_size, source_size)`` with
        values clamped to ``[0, 1]``.

    Processing steps:
        1. Images that are not already RGB/RGBA (e.g. grayscale or palette
           files) are converted so the array always has a channel axis.
        2. Transparent pixels are composited onto a white background.
        3. The shorter side is padded with white so the image is exactly
           square before resizing.
        4. The result is resized with bicubic interpolation and clamped,
           because bicubic filtering can overshoot the ``[0, 1]`` range.
    """
    # The context manager guarantees the underlying file handle is released.
    with Image.open(image_path) as pil_image:
        if pil_image.mode not in ("RGB", "RGBA"):
            has_alpha = (
                pil_image.mode in ("LA", "PA") or "transparency" in pil_image.info
            )
            pil_image = pil_image.convert("RGBA" if has_alpha else "RGB")
        pixels = np.array(pil_image)

    # (H, W, C) uint8 -> (1, C, H, W) float in [0, 1]
    image = torch.from_numpy(pixels).permute(2, 0, 1).unsqueeze(0) / 255.0
    if image.shape[1] == 4:  # RGBA: alpha-blend onto white
        alpha = image[:, 3:, ...]
        image = image[:, :3, ...] * alpha + (1 - alpha)

    height, width = image.shape[2:]
    missing_rows = max(0, width - height)
    missing_cols = max(0, height - width)
    # Split the padding so both sides sum to the full deficit; with an odd
    # deficit the extra pixel goes to the bottom/right edge.
    pad_top = missing_rows // 2
    pad_left = missing_cols // 2
    padding = (pad_left, missing_cols - pad_left, pad_top, missing_rows - pad_top)

    image = F.pad(image, padding, "constant", value=1)
    image = F.interpolate(
        image, size=(source_size, source_size), mode='bicubic', align_corners=True
    )
    return torch.clamp(image, 0, 1)

def generate_camera_coords(step=60):
    """Enumerate camera orientations ``[theta, phi, psi]`` in degrees.

    Args:
        step: Angular spacing in degrees used for the ``theta`` sweep
            (``0, step, ...`` up to at least 180) and for the roll angle
            ``psi`` (``0, step, ...`` below 360). Must be positive.

    Returns:
        A NumPy array with one ``[theta, phi, psi]`` row per orientation.

    Raises:
        ValueError: If ``step`` is not strictly positive.

    Notes:
        * The number of ``phi`` samples per ``theta`` ring scales with
          ``sin(theta)``, so rings near the poles get fewer samples.
        * At ``theta == 0`` and ``theta == 180`` only the single orientation
          with ``phi == 0`` and ``psi == 0`` is emitted.
        * NOTE(review): ``np.linspace(0, 360, num_phis)[:-1]`` yields
          ``num_phis - 1`` azimuths (e.g. 0/120/240 for ``step=90`` at the
          equator). This is kept as-is; confirm it is the intended density.
        * NOTE(review): when ``step`` does not divide 180 the sweep contains a
          ``theta`` value above 180; confirm callers only use divisors of 180.
    """
    if step <= 0:
        raise ValueError(f"step must be a positive number of degrees, got {step!r}")

    thetas = np.arange(0, 180 + step, step)
    psis = np.arange(0, 360, step)  # roll

    camera_coords = []
    for theta_ in thetas:
        # Ring density follows the ring circumference (proportional to sin(theta)).
        num_phis = int(round(360 / step * np.sin(theta_ * np.pi / 180)))
        # At least two linspace points so one azimuth survives dropping the
        # duplicated 360-degree endpoint.
        num_phis = max(num_phis, 2)
        phis = np.linspace(0, 360, num_phis)[:-1]

        at_pole = theta_ == 0 or theta_ == 180
        for phi_ in phis:
            for psi_ in psis:
                # Keep only the first orientation at the poles.
                if at_pole and (phi_ > 0 or psi_ > 0):
                    break
                camera_coords.append([theta_, phi_, psi_])

    return np.array(camera_coords)


In [None]:
# Select the trial and locate its PNG stimuli on disk.
trial_name = "familiar_high_screen19"
trial_dir = Path(f"data/barense/{trial_name}")

# Sort the paths so the image order is deterministic across runs.
trial_paths = sorted(trial_dir.glob("*.png"))
trial_imgs = list(map(load_image, trial_paths))

# Preview the first stimulus; mediapy is given the image as (H, W, C).
first_stimulus = trial_imgs[0][0]
mediapy.show_image(first_stimulus.permute(1, 2, 0))

In [None]:
# input_path = str(trial_paths[0])
# version = "sv3d_p"
# elevations_deg = 10.0
# sample(input_path=input_path, version=version, elevations_deg=elevations_deg)

In [None]:
# Instantiate the project's OpenLRM wrapper, requesting the first CUDA device.
# NOTE(review): presumably this loads the model weights onto the GPU and will
# fail on a machine without CUDA — confirm in p3d.models.openlrm.
olrm = OpenLRM(device="cuda:0")

In [None]:
from p3d.losses import calc_l2_losses, calc_lpips_losses
from transformers import AutoImageProcessor, AutoModel, CLIPProcessor, CLIPModel


# Checkpoint identifiers for the two feature extractors used below.
dino_checkpoint = 'facebook/dinov2-base'
clip_checkpoint = "openai/clip-vit-base-patch32"

# DINOv2: image processor and model.
processor = AutoImageProcessor.from_pretrained(dino_checkpoint)
dino = AutoModel.from_pretrained(dino_checkpoint)

# CLIP: model and its matching processor.
CLIP = CLIPModel.from_pretrained(clip_checkpoint)
clip_processor = CLIPProcessor.from_pretrained(clip_checkpoint)

In [None]:
# Candidate camera orientations shared by every trial image (90-degree grid).
camera_coords = generate_camera_coords(90)

rendered_refs = []
rendered_sweeps = []

for stimulus in trial_imgs:
    # Feed the stimulus to OpenLRM, then render from it.
    olrm.process_image(stimulus)

    # Single render from the fixed reference orientation [90, 270, 0].
    reference_render = olrm.generate_images(np.array([[90, 270, 0]]))[0]
    rendered_refs.append(reference_render)

    # Renders for every candidate orientation.
    rendered_sweeps.append(olrm.generate_images(camera_coords))

# One leading axis entry per trial image.
ref_images = torch.stack(rendered_refs)
brute_force_samples = torch.stack(rendered_sweeps)

In [None]:
# Embed every brute-force render with CLIP and DINOv2 (no gradients needed).
with torch.no_grad():
    print("Generating CLIP Features")
    all_clip_features = torch.stack([
        CLIP.get_image_features(
            **clip_processor(images=sweep, return_tensors="pt", do_rescale=False)
        ).detach()
        for sweep in brute_force_samples
    ])

    print("Generating Dino Features")
    all_dino_features = torch.stack([
        dino(
            **processor(images=sweep, return_tensors="pt", do_rescale=False)
        ).pooler_output.detach()
        for sweep in brute_force_samples
    ])

In [None]:
# Cross-match every reference view against every model's brute-force renders.
#
# For each pair (i, j), with i indexing the reference view and j indexing the
# model whose renders are searched, and for each metric, we store:
#   * the best-matching render          -> <metric>_image_matches[i][j]
#   * the camera orientation it used    -> <metric>_coord_matches[i][j]
#   * the (minimal) loss value          -> <metric>_scores[i][j]
# On the diagonal (i == j) the reference view itself is stored with loss 0.0.
#
# The grid size is derived from the data instead of being hard-coded, so the
# cell works for any number of trial images.
num_trials = len(ref_images)
reference_coord = np.array([90, 270, 0])  # orientation used to render ref_images
metric_names = ("l2", "lpips", "clip", "dino")
grid_shape = (num_trials, num_trials) + tuple(ref_images[0].shape)

image_grids = {name: torch.zeros(grid_shape) for name in metric_names}
coord_grids = {
    name: [[None] * num_trials for _ in range(num_trials)] for name in metric_names
}
score_grids = {
    name: [[None] * num_trials for _ in range(num_trials)] for name in metric_names
}

with torch.no_grad():
    for i in range(num_trials):
        ref_image = ref_images[i]
        test_features_clip = CLIP.get_image_features(**clip_processor(images=ref_image, return_tensors="pt", do_rescale=False)).detach()
        test_features_dino = dino(**processor(images=ref_image, return_tensors="pt", do_rescale=False)).pooler_output.detach()

        for j in range(num_trials):
            if i == j:
                # A view trivially matches its own model.
                for name in metric_names:
                    image_grids[name][i][j] = ref_image
                    coord_grids[name][i][j] = reference_coord.copy()
                    score_grids[name][i][j] = 0.0
                continue

            image_row = brute_force_samples[j]
            clip_row = all_clip_features[j]
            dino_row = all_dino_features[j]

            # One loss value per candidate render; lower means more similar.
            losses = {
                "l2": calc_l2_losses(ref_image, image_row),
                "lpips": calc_lpips_losses(ref_image, image_row).flatten(),
                # Cosine distance; similarities are clamped to [0, 1] first.
                "clip": 1 - torch.nn.functional.cosine_similarity(test_features_clip, clip_row, dim=1).clamp(0, 1),
                "dino": 1 - torch.nn.functional.cosine_similarity(test_features_dino, dino_row, dim=1).clamp(0, 1),
            }

            for name, loss in losses.items():
                best = int(loss.argmin())  # computed once, reused for image and coords
                image_grids[name][i][j] = image_row[best]
                coord_grids[name][i][j] = camera_coords[best]
                score_grids[name][i][j] = loss.min().item()

# Expose the results under the per-metric names used by the following cells.
l2_image_matches, l2_coord_matches, l2_scores = image_grids["l2"], coord_grids["l2"], score_grids["l2"]
lpips_image_matches, lpips_coord_matches, lpips_scores = image_grids["lpips"], coord_grids["lpips"], score_grids["lpips"]
clip_image_matches, clip_coord_matches, clip_scores = image_grids["clip"], coord_grids["clip"], score_grids["clip"]
dino_image_matches, dino_coord_matches, dino_scores = image_grids["dino"], coord_grids["dino"], score_grids["dino"]

In [None]:
# Turn the nested per-metric lists into dense NumPy arrays.
l2_coord_matches, lpips_coord_matches, clip_coord_matches, dino_coord_matches = [
    np.stack(grid)
    for grid in (l2_coord_matches, lpips_coord_matches, clip_coord_matches, dino_coord_matches)
]

l2_scores, lpips_scores, clip_scores, dino_scores = [
    np.stack(grid) for grid in (l2_scores, lpips_scores, clip_scores, dino_scores)
]

# Parallel lists, always ordered: L2, LPIPS, CLIP, DINO.
image_matches = [
    l2_image_matches,
    lpips_image_matches,
    clip_image_matches,
    dino_image_matches,
]
coord_matches = [
    l2_coord_matches,
    lpips_coord_matches,
    clip_coord_matches,
    dino_coord_matches,
]
loss_matches = [
    l2_scores,
    lpips_scores,
    clip_scores,
    dino_scores,
]

In [None]:
# Persist the matching results for this trial under data/openlrm/<trial_name>/.
#
# Layout written by this cell:
#   <metric>/coords.npy                      best-match camera orientations
#   <metric>/losses.npy                      best-match loss values
#   <metric>/images/viewpoint{i}_model{j}.png  best-match renders
#   original_viewpoints/image{i}.png         the reference renders
output_path = Path(f"data/openlrm/{trial_name}")

l2_path = output_path / "l2"
lpips_path = output_path / "lpips"
clip_path = output_path / "clip"
dino_path = output_path / "dino"
metric_paths = [l2_path, lpips_path, clip_path, dino_path]

# image_matches / coord_matches / loss_matches are ordered like metric_paths.
for path, metric_images, metric_coords, metric_losses in zip(
    metric_paths, image_matches, coord_matches, loss_matches
):
    images_path = path / "images"
    # parents=True also creates the metric directory itself.
    images_path.mkdir(exist_ok=True, parents=True)

    np.save(path / 'coords.npy', metric_coords)
    np.save(path / 'losses.npy', metric_losses)

    # Iterate over the grid itself rather than a hard-coded size.
    for i, matches_for_view in enumerate(metric_images):
        for j, match in enumerate(matches_for_view):
            mediapy.write_image(images_path / f"viewpoint{i}_model{j}.png", match)

# Reference renders are written once (the original cell repeated this block).
originals_path = output_path / "original_viewpoints"
originals_path.mkdir(exist_ok=True, parents=True)
for i, ref_image in enumerate(ref_images):
    mediapy.write_image(originals_path / f"image{i}.png", ref_image)

In [None]:
# Show the DINO best matches, one row of images per reference view.
for matches_for_view in dino_image_matches:
    mediapy.show_images(matches_for_view, height=200)