In [1]:
%cd ..

/home/jovyan/HyperDomainNet


In [2]:
import clip
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np

from tqdm.auto import tqdm
from omegaconf import OmegaConf
from core.utils.class_registry import ClassRegistry
from core.utils.common import load_clip, mixing_noise
from core.utils.text_templates import imagenet_templates, imagenet_templates_small
from core.utils.example_utils import Inferencer
from core.utils.loss_utils import get_tril_elements_mask

In [3]:
class Evaluator:
    def __init__(self, visual_encoder, device, bs=12, data_size=500):
        self.device = device
        self.batch_size = bs
        self.data_size = data_size
        self.clip_models = {
            visual_encoder: load_clip(visual_encoder, device)
        }

    @torch.no_grad()
    def _encode_text(
        self, clip_model: nn.Module, text: str, templates=imagenet_templates
    ):
        tokens = clip.tokenize(t.format(text) for t in templates).to(self.device)
        text_features = clip_model.encode_text(tokens).detach()
        text_features /= text_features.norm(dim=-1, keepdim=True)
        return text_features

    @torch.no_grad()
    def _encode_image(self, clip_model: nn.Module, preprocess, imgs: torch.Tensor):
        images = preprocess(imgs).to(self.device)
        image_features = clip_model.encode_image(images).detach()
        image_features /= image_features.clone().norm(dim=-1, keepdim=True)
        return image_features

    def _mean_cosine_sim(self, imgs_encoded: torch.Tensor, mean_vector: torch.Tensor):
        return (imgs_encoded.unsqueeze(1) * mean_vector).sum(dim=-1).mean().item()

    def _std_cosine_sim(self, imgs_encoded: torch.Tensor, mean_vector: torch.Tensor):
        return nn.CosineSimilarity()(imgs_encoded, mean_vector).std().item()

    def _diversity_from_embeddings_pairwise_cosines(self, imgs_encoded: torch.Tensor):
        data = (imgs_encoded @ imgs_encoded.T).cpu().numpy()
        mask = get_tril_elements_mask(data.shape[0])
        return np.mean(1 - data[mask])

    @torch.no_grad()
    def _generate_data(
        self,
        clip_model,
        preprocess,
        model
    ):
        answer = []

        for idx in tqdm(range(self.data_size // self.batch_size)):
            sample_z = mixing_noise(
                self.batch_size,
                512,
                0,
                self.device
            )

            _, trg_imgs = model(sample_z)
            trg_imgs = trg_imgs.detach()
            
            image_features = self._encode_image(clip_model, preprocess, trg_imgs)
            answer.append(image_features)

        return torch.cat(answer, dim=0)

    def get_metrics(
        self, model, text_description
    ):

        model.eval()
        metrics = {}
        
        for key, (clip_model, preprocess) in self.clip_models.items():
            domain_mean_vector = self._encode_text(clip_model, text_description).unsqueeze(0)
            imgs_encoded = self._generate_data(
                clip_model,
                preprocess,
                model
            )

            key_quality = f"quality/{text_description}/{key.replace('/', '-')}"
            key_diversity = f"diversity/{text_description}/{key.replace('/', '-')}"

            metrics[key_quality] = self._mean_cosine_sim(imgs_encoded, domain_mean_vector)
            metrics[key_diversity] = self._diversity_from_embeddings_pairwise_cosines(imgs_encoded)

        return metrics

In [5]:
device = 'cuda:0'

ckpt_path = 'td_checkpoints/td_anime.pt'
ckpt = torch.load(ckpt_path, map_location='cpu')

model = Inferencer(ckpt, device)

In [6]:
evaluator = Evaluator('ViT-B/16', device)

In [7]:
metrics = evaluator.get_metrics(model, 'Anime Painting')

  0%|          | 0/41 [00:00<?, ?it/s]

In [8]:
metrics

{'quality/Anime Painting/ViT-B-16': 0.2890625,
 'diversity/Anime Painting/ViT-B-16': 0.2585}