In [None]:
import os
import sys
import torch
import random
import imageio
import numpy as np
import matplotlib.pyplot as plt
from pathlib import Path

parent_dir = os.path.dirname(os.path.dirname(os.path.realpath("__file__")))
sys.path.insert(0, parent_dir)

from face_lib.datasets import IJBDataset, IJBCTest
from face_lib.utils import cfg
from face_lib.evaluation.feature_extractors import (
    extract_features_head,
    extract_features_gan,
    extract_features_scale,
)
from face_lib import models as mlib, utils
from face_lib.utils.imageprocessing import preprocess
from face_lib.evaluation.distance_uncertainty_funcs import harmonic_mean

In [None]:
# --- Run configuration -------------------------------------------------------
device_id = 0  # CUDA device index; the model cell builds torch.device(f"cuda:{device_id}")

# Uncertainty estimator to visualise: "head" or "scale". The feature-extraction
# cell also has a "gan" branch, which relies on `discriminator_path`.
uncertainty_type = "head"
# uncertainty_type = "scale"

config_path = "../configs/models/iresnet_ms1m_pfe_normalized.yaml"
# config_path = "../configs/scale/02_sigm_mul_coef_selection/32.yaml"

checkpoint_path = "/gpfs/data/gpfs0/k.fedyanin/space/models/pfe/normalized_pfe/sota.pth"
# checkpoint_path = "/gpfs/data/gpfs0/k.fedyanin/space/models/scale/02_sigm_mul_selection/32/checkpoint.pth"

dataset_path = "/gpfs/gpfs0/k.fedyanin/space/IJB/aligned_data_for_fusion/small"
protocol_path = "/gpfs/gpfs0/k.fedyanin/space/IJB/IJB-C/protocols/archive"
# discriminator_path = "/gpfs/data/gpfs0/k.fedyanin/space/GAN/stylegan.pth"
discriminator_path = None
batch_size = 4

# The quantile plots pick their example images with `random.sample`; seeding the
# global `random` module makes the saved figures reproducible on re-run.
SEED = 0
random.seed(SEED)

In [None]:
device = torch.device("cuda:" + str(device_id))

model_args = cfg.load_config(config_path)

# Validate the requested uncertainty type before any weights are read.
# "gan" has no trained head: the extraction cell passes the discriminator instead.
if uncertainty_type not in ("head", "scale", "gan"):
    raise RuntimeError("Choose the right uncertainty_type")
if uncertainty_type == "gan" and not discriminator_path:
    raise RuntimeError("uncertainty_type='gan' needs discriminator_path to be set")

# The backbone is constructed identically for every uncertainty type.
backbone = mlib.model_dict[model_args.backbone["name"]](
    **utils.pop_element(model_args.backbone, "name")
)

# Uncertainty module on top of the backbone, plus the checkpoint key its
# weights are stored under. Stays None for "gan".
head = None
head_checkpoint_key = None
if uncertainty_type == "head":
    head = mlib.heads[model_args.head.name](
        **utils.pop_element(model_args.head, "name")
    )
    head_checkpoint_key = "head"
elif uncertainty_type == "scale":
    head = mlib.scale_predictors[model_args.scale_predictor.name](
        **utils.pop_element(model_args.scale_predictor, "name")
    )
    head_checkpoint_key = "scale_predictor"

checkpoint = torch.load(checkpoint_path, map_location=device)
backbone.load_state_dict(checkpoint["backbone"])
backbone = backbone.eval().to(device)

if head is not None:
    head.load_state_dict(checkpoint[head_checkpoint_key])
    head = head.eval().to(device)

discriminator = None
if discriminator_path:
    discriminator = mlib.StyleGanDiscriminator()
    # map_location mirrors the main checkpoint load: tensors saved on another
    # device are remapped onto `device` rather than their original one.
    discriminator.load_state_dict(
        torch.load(discriminator_path, map_location=device)["d"]
    )
    discriminator = discriminator.eval().to(device)

In [None]:
# Build the IJB-C tester from the absolute image paths listed by the dataset,
# then initialise it with the protocol files found under `protocol_path`.
testset = IJBDataset(dataset_path)
image_abspaths = testset["abspath"].values
tester = IJBCTest(image_abspaths)
tester.init_proto(protocol_path)

In [None]:
def proc_func(images):
    """Preprocess ``images`` to 112x112 with ``is_training=False``.

    Passed as ``proc_func`` to every feature extractor in this notebook.
    """
    return preprocess(images, [112, 112], is_training=False)

In [None]:
# Select the feature extractor and the module it pairs with the backbone.
# "gan" passes the StyleGAN discriminator in place of a head.
if uncertainty_type == "head":
    feature_extractor, uncertainty_model = extract_features_head, head
elif uncertainty_type == "gan":
    if discriminator is None:
        raise RuntimeError("uncertainty_type='gan' requires a loaded discriminator")
    feature_extractor, uncertainty_model = extract_features_gan, discriminator
elif uncertainty_type == "scale":
    feature_extractor, uncertainty_model = extract_features_scale, head
else:
    # Without this, `mu` / `sigma_sq` would silently keep values from a previous run.
    raise RuntimeError("Choose the right uncertainty_type")

# Inference only: no autograd graph is needed.
with torch.no_grad():
    mu, sigma_sq = feature_extractor(
        backbone,
        uncertainty_model,
        tester.image_paths,
        batch_size,
        proc_func=proc_func,
        verbose=False,
        device=device,
    )

In [None]:
# One scalar uncertainty per image, reduced from sigma_sq with harmonic_mean.
uncertainties = harmonic_mean(sigma_sq)
# Alternative orderings, handy as sanity checks:
# uncertainties = np.arange(len(tester.image_paths))
# uncertainties = sigma_sq[:, 0]

# The next cell argsorts `uncertainties` and uses the result to index the image
# paths. A non-1-D result would make np.argsort sort along the last axis
# instead, silently scrambling the image order.
assert np.ndim(uncertainties) == 1, (
    f"expected 1-D uncertainties, got shape {np.shape(uncertainties)}"
)
assert len(uncertainties) == len(tester.image_paths), "expected one uncertainty per image"

In [None]:
# Reorder embeddings, variances and image paths by ascending `uncertainties`.
indices = np.argsort(uncertainties)
new_mus, new_sigmas, paths = (
    mu[indices],
    sigma_sq[indices],
    tester.image_paths[indices],
)

In [None]:
def show_picture(path, ax):
    """Read the image file at ``path`` and draw it onto the matplotlib axes ``ax``."""
    ax.imshow(imageio.imread(path))


def show_pics_quantiles(sorted_paths, n_groups=10, n_pics=5, save_path=None, rng=None):
    """Plot a grid of random example images for each quantile of ``sorted_paths``.

    ``sorted_paths`` is cut into ``n_groups`` contiguous slices. Row ``i`` shows up
    to ``n_pics`` images drawn without replacement from slice ``i``. Each row is
    labelled with its percentage range.

    Args:
        sorted_paths: sliceable sequence of image paths, already in the desired
            order (in this notebook: by uncertainty).
        n_groups: number of quantile groups (grid rows).
        n_pics: images per group (grid columns).
        save_path: if truthy, the figure is also saved there at 400 dpi.
        rng: optional ``random.Random`` used for sampling; defaults to the
            global ``random`` module (previous behaviour).
    """
    sampler = random if rng is None else rng
    # squeeze=False keeps `axes` 2-D even when n_groups or n_pics is 1,
    # so `axes[quantile_idx]` and `axes[:, 0]` below always work.
    fig, axes = plt.subplots(n_groups, n_pics, figsize=(30, 30), squeeze=False)
    n_paths = len(sorted_paths)

    for quantile_idx in range(n_groups):
        left_idx = int(quantile_idx / n_groups * n_paths)
        right_idx = int((quantile_idx + 1) / n_groups * n_paths)
        group = list(sorted_paths[left_idx:right_idx])

        # A group may hold fewer than n_pics images (small dataset / many groups);
        # random.sample raises ValueError if k exceeds the population, so cap k.
        picture_paths = sampler.sample(group, k=min(n_pics, len(group)))
        for pic_path, ax in zip(picture_paths, axes[quantile_idx]):
            show_picture(pic_path, ax)

    names = [
        str(quantile_idx / n_groups * 100)
        + "-"
        + str((quantile_idx + 1) / n_groups * 100)
        + "%"
        for quantile_idx in range(n_groups)
    ]
    # Row labels are placed to the left of each row's first y-axis label.
    pad = 20
    for ax, row in zip(axes[:, 0], names):
        ax.annotate(
            row,
            xy=(0, 0.5),
            xytext=(-ax.yaxis.labelpad - pad, 0),
            xycoords=ax.yaxis.label,
            textcoords="offset points",
            size="large",
            ha="right",
            va="center",
        )

    if save_path:
        fig.savefig(save_path, dpi=400)
    plt.show()
    # Release the figure so repeated calls do not accumulate open figures.
    plt.close(fig)

In [None]:
# Five independent random draws of the same quantile grid, one PDF per draw.
# NOTE(review): the "pfe" folder is hard-coded and does not follow
# `uncertainty_type`; figures from other runs would overwrite these.
figures_dir = Path("/beegfs/home/r.kail/faces/figures/14_sorted_faces/pfe")
# savefig does not create missing parent directories.
figures_dir.mkdir(parents=True, exist_ok=True)

for idx in range(5):
    show_pics_quantiles(
        paths,
        n_groups=10,
        n_pics=10,
        save_path=str(figures_dir / f"{idx}.pdf"),
    )