In [None]:
import os
import warnings
from os.path import join

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import tqdm

from config import CYCLE_GAN_DIR_RESULTS, SERENGETI_NIR_INCANDESCENT_DATASET_OUT, CALTECH_NIR_DATASET_OUT, \
    CALTECH_GRAY_DATASET_OUT
from evaluation.fid import calculate_fid

# Plot/display configuration: seaborn theme, large default figures, long tables.
sns.set_theme()
plt.rcParams["figure.figsize"] = (16, 9)

pd.set_option("display.max_rows", 500)

In [None]:
# Glob patterns passed to calculate_fid: generated images inside an epoch's
# "images" directory (layout differs between CycleGAN and CUT results, see
# get_matcher) and the real images of a dataset's testB split.
CYCLE_GAN_RESULT_MATCHER = "*_fake.png"
CUT_RESULT_MATCHER = "fake_B/*.png"
DATASET_TESTB_MATCHER = "*.jpg"

# Display label -> (results directory name under CYCLE_GAN_DIR_RESULTS,
#                   dataset root whose "testB" split is the FID reference).
# Insertion order is the order in which the FID loop processes the networks.
NETWORK_NAME_MAP = {
    # Caltech NIR / grayscale runs.
    "CycleGAN (Incandescent)": ("cycle_gan_nir_incadescent_caltech", CALTECH_NIR_DATASET_OUT),
    "CycleGAN (Random)": (
        "nir_cyclegan_unet_ralsgan_sampling_ssim_ttur_2_cyc_spectral_normalization_reduced_cycle_detach",
        CALTECH_NIR_DATASET_OUT,
    ),
    "CycleGAN (Grayscale)": ("gray_cycle_gan", CALTECH_GRAY_DATASET_OUT),
    "CUT (Grayscale)": ("gray_cut", CALTECH_GRAY_DATASET_OUT),
    # Serengeti CycleGAN sweep; the two label values are encoded in the
    # directory suffix ("0_00003_0_0009" -> 3e-5, 9e-4).
    "CycleGAN (3e-5,9e-4)": ("cycle_gan_serengeti_inc_0_00003_0_0009", SERENGETI_NIR_INCANDESCENT_DATASET_OUT),
    "CycleGAN (3e-5,9e-5)": ("cycle_gan_serengeti_inc_0_00003_0_00009", SERENGETI_NIR_INCANDESCENT_DATASET_OUT),
    "CycleGAN (1.5e-5,4.5e-5)": ("cycle_gan_serengeti_inc_0_000015_0_000045", SERENGETI_NIR_INCANDESCENT_DATASET_OUT),
    "CycleGAN (3e-6,9e-5)": ("cycle_gan_serengeti_inc_0_000003_0_00009", SERENGETI_NIR_INCANDESCENT_DATASET_OUT),
    "CycleGAN (3e-6,9e-6)": ("cycle_gan_serengeti_inc_0_000003_0_000009", SERENGETI_NIR_INCANDESCENT_DATASET_OUT),
    # Serengeti CUT sweep.
    # NOTE(review): with the same suffix encoding as the CycleGAN entries, these
    # directories decode to 2e-4, 2e-5, 2e-6 and 2e-7 ("0_0002" == 0.0002) --
    # one order of magnitude below their labels. Confirm which side is correct.
    "CUT (2e-3)": ("cut_serengeti_inc_0_0002", SERENGETI_NIR_INCANDESCENT_DATASET_OUT),
    "CUT (2e-4)": ("cut_serengeti_inc_0_00002", SERENGETI_NIR_INCANDESCENT_DATASET_OUT),
    "CUT (2e-5)": ("cut_serengeti_inc_0_000002", SERENGETI_NIR_INCANDESCENT_DATASET_OUT),
    "CUT (2e-6)": ("cut_serengeti_inc_0_0000002", SERENGETI_NIR_INCANDESCENT_DATASET_OUT),
}


def map_epoch_name(name):
    """Turn a results directory name such as ``"test_40"`` into its integer epoch.

    A missing ``"test_"`` prefix is tolerated; a non-integer remainder
    (e.g. ``"test_latest"``) raises ``ValueError`` from ``int``.
    """
    prefix = "test_"
    epoch_text = name[len(prefix):] if name.startswith(prefix) else name
    return int(epoch_text)


def map_network_name(name):
    """Look ``name`` up in NETWORK_NAME_MAP, returning ``name`` unchanged if absent.

    A hit returns the mapped ``(directory name, dataset path)`` tuple, not a string.
    NOTE(review): not referenced anywhere else in the visible notebook.
    """
    return NETWORK_NAME_MAP.get(name, name)


results = list()  # rows of [network label, epoch, FID], appended by the FID loop in this cell


def get_matcher(network_name):
    """Return the glob for generated images of a results directory.

    Directory names containing ``"cut"`` use the CUT layout; all others use the
    CycleGAN layout.
    NOTE(review): plain substring test -- a CycleGAN directory whose name merely
    contains "cut" (e.g. "shortcut") would be treated as CUT.
    """
    is_cut_run = "cut" in network_name
    return CUT_RESULT_MATCHER if is_cut_run else CYCLE_GAN_RESULT_MATCHER


# Compute the FID of every numbered test epoch of every network against the
# real images of the corresponding dataset's testB split.
for network_name, (network_dir_name, dataset_path) in tqdm.tqdm(NETWORK_NAME_MAP.items(), desc="Processing networks"):
    network_dir = join(CYCLE_GAN_DIR_RESULTS, network_dir_name)
    dataset_b_dir = join(dataset_path, "testB")
    # The generated-image glob depends only on the network, so resolve it once.
    matcher = get_matcher(network_dir_name)
    # os.listdir order is arbitrary; sort for a reproducible processing order.
    epoch_results = sorted(os.listdir(network_dir))
    for epoch_result in tqdm.tqdm(epoch_results, desc=f"Processing epochs of network \"{network_name}\""):
        epoch_dir = join(network_dir, epoch_result)
        # Only "test_<int>" directories can be mapped to an integer epoch. Skip
        # anything else (e.g. "test_latest", stray files) *before* paying for an
        # FID computation instead of crashing in map_epoch_name afterwards.
        if not os.path.isdir(epoch_dir) or not epoch_result.removeprefix("test_").isdecimal():
            tqdm.tqdm.write(f"Skipping \"{epoch_result}\" of network \"{network_name}\": not a numbered epoch directory")
            continue
        epoch_result_dir = join(epoch_dir, "images")

        fid = calculate_fid(dataset_b_dir, epoch_result_dir, DATASET_TESTB_MATCHER, matcher)

        results.append([network_name, map_epoch_name(epoch_result), fid])

df = pd.DataFrame(results, columns=["network", "epoch", "FID"])
df

In [None]:
df.groupby(["network", "epoch"]).agg("min")  # minimum FID per (network, epoch), indexed by both

In [None]:
# Order the raw results by network label, then by epoch, for readability.
df = df.sort_values(["network", "epoch"])
df

In [None]:
def plot_fid(df, networks):
    """Display a table and a grouped bar chart of FID per epoch for selected networks.

    Args:
        df: Results frame with "network", "epoch" and "FID" columns.
        networks: Network labels to include. Their order also fixes the legend
            (hue) order. Labels that have no rows in ``df`` are reported with a
            warning and left out; if none remain, nothing is displayed.
    """
    present = set(df["network"])
    missing = [name for name in networks if name not in present]
    if missing:
        # Without this check an unknown label silently yields an empty or partial chart.
        warnings.warn(f"No FID results for networks {missing}; they are not plotted.")
    networks = [name for name in networks if name in present]
    if not networks:
        return

    epoch_per_network_fid = (
        df[df["network"].isin(networks)]
        .groupby(by=["epoch", "network"])
        .min()
        .sort_index()
        .reset_index()
    )
    # Table ranked by FID, lowest (best) first.
    display(epoch_per_network_fid.sort_values("FID"))

    _, ax = plt.subplots()
    ax.set_title("FID per network per epoch")
    sns.barplot(x="epoch", y="FID", hue="network", data=epoch_per_network_fid, hue_order=networks, ax=ax)

In [None]:
# Caltech NIR: the two CycleGAN variants.
plot_fid(df, networks=["CycleGAN (Random)", "CycleGAN (Incandescent)"])

In [None]:
# FIXME(review): neither label below is a key of NETWORK_NAME_MAP, so `df` holds
# no rows for them and nothing meaningful is plotted. Add the matching entries
# to NETWORK_NAME_MAP or remove this cell.
plot_fid(df, networks=[
    "CUT (night $\\rightarrow$ day)",
    "CycleGAN (night $\\rightarrow$ day)",
])

In [None]:
# Caltech grayscale: CUT vs. CycleGAN.
plot_fid(df, networks=["CUT (Grayscale)", "CycleGAN (Grayscale)"])

In [None]:
# Serengeti CUT sweep.
plot_fid(df, networks=[
    "CUT (2e-3)",
    "CUT (2e-4)",
    "CUT (2e-5)",
    "CUT (2e-6)",
])

In [None]:
# Serengeti CycleGAN sweep.
plot_fid(df, networks=[
    "CycleGAN (3e-5,9e-4)",
    "CycleGAN (3e-5,9e-5)",
    "CycleGAN (1.5e-5,4.5e-5)",
    "CycleGAN (3e-6,9e-5)",
    "CycleGAN (3e-6,9e-6)",
])

In [None]:
# Serengeti: one selected CycleGAN run vs. one selected CUT run.
plot_fid(df, networks=["CycleGAN (1.5e-5,4.5e-5)", "CUT (2e-5)"])