In [34]:
from os.path import join
import os

import numpy as np
import pandas as pd
import pathlib
import torch
from pytorch_fid.fid_score import compute_statistics_of_path, calculate_frechet_distance, \
    calculate_activation_statistics
from pytorch_fid.inception import InceptionV3
import matplotlib.pyplot as plt
import seaborn as sns
from config import CYCLE_GAN_DIR_RESULTS, CALTECH_GRAY_DATASET_OUT, \
    CALTECH_NIR_DATASET_OUT

sns.set_theme()

plt.rcParams["figure.figsize"] = (16, 9)

In [35]:
num_avail_cpus = len(os.sched_getaffinity(0))
num_workers = min(num_avail_cpus, 8)
batch_size = 50

device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')

dims = 2048
block_idx = InceptionV3.BLOCK_INDEX_BY_DIM[dims]
model = InceptionV3([block_idx]).to(device)



In [36]:
def calculate_fid_for_file_list(result_files, ground_truth_stat):
    m1, s1 = ground_truth_stat
    m2, s2 = calculate_activation_statistics(sorted(result_files), model, batch_size=batch_size, device=device,
                                             num_workers=num_workers)
    return calculate_frechet_distance(m1, s1, m2, s2)

In [37]:
test_b_dir = join(CALTECH_GRAY_DATASET_OUT, "testA")
inception_values_files = join(test_b_dir, "inception-values.npz")
if os.path.exists(inception_values_files):
    m1, s1 = compute_statistics_of_path(inception_values_files, model, dims=dims, batch_size=batch_size, device=device,
                                        num_workers=num_workers)
else:
    m1, s1 = compute_statistics_of_path(test_b_dir, model, dims=dims, batch_size=batch_size, device=device,
                                        num_workers=num_workers)
    np.savez(inception_values_files, mu=m1, sigma=s1)



ValueError: batch_size should be a positive integer value, but got batch_size=0

In [None]:
NETWORK_NAME_MAP = {
    "nir_cyclegan_unet_ralsgan_sampling": "Verbessertes Sampling",
    "nir_cyclegan_unet_ralsgan_sampling_ssim": "SSIM",
    "nir_cyclegan_unet_ralsgan_sampling_ssim_ttur": "TTUR",
    "nir_cyclegan_unet_ralsgan_sampling_ssim_ttur_2_cyc": "Verbessertes Cycle Consistency",
    "nir_cyclegan_unet_ralsgan_sampling_ssim_ttur_2_cyc_spectral_normalization": "Spectral Normalization",
    "nir_cyclegan_unet_ralsgan_sampling_ssim_ttur_2_cyc_spectral_normalization_reduced_cycle": "Spectral Normalization (geringeres $\lambda$)",
    "cut": "CUT",
    "gray_cut": "CUT (Grauwert)",
    "nir_cyclegan_unet_ralsgan_sampling_ssim_ttur_2_cyc_spectral_normalization_reduced_cycle_detach": "Detach Fix"
}

comparable_networks = [
    "CUT",
    "CUT (Grauwert)"
]


def map_epoch_name(name):
    name = name.removeprefix("test_")
    if name == "latest":
        return 200
    return int(name)


def map_network_name(name):
    if name in NETWORK_NAME_MAP.keys():
        return NETWORK_NAME_MAP[name]
    return name


results = []

for network_name in os.listdir(CYCLE_GAN_DIR_RESULTS):
    if map_network_name(network_name) not in comparable_networks:
        continue

    network_dir = join(CYCLE_GAN_DIR_RESULTS, network_name)
    for epoch_result in os.listdir(network_dir):
        epoch_result_dir = pathlib.Path(join(network_dir, epoch_result, "images"))

        if "cut" in network_name:
            glob = epoch_result_dir.glob("fake_B/*.png")
        else:
            glob = epoch_result_dir.glob("*_fake.png")
        fid = calculate_fid_for_file_list(glob, (m1, s1))
        results.append([map_network_name(network_name), map_epoch_name(epoch_result), fid])

df = pd.DataFrame(results, columns=["Network", "Epoch", "FID"])
df

In [None]:
df.groupby(by=["Network", "Epoch"]).min()

In [None]:
df = df.sort_values(by=["Network", "Epoch"])
df

In [None]:
epoch_per_network_fid = df[df["Network"].isin(comparable_networks)].groupby(by=["Epoch", "Network"]).min()
epoch_per_network_fid = epoch_per_network_fid.sort_index()
epoch_per_network_fid: pd.DataFrame = epoch_per_network_fid.reset_index()
epoch_per_network_fid

In [None]:
epoch_per_network_fid = epoch_per_network_fid.rename(columns={"Network": "Netzwerk", "Epoch": "Epoche"})
plt.title("FID der Test Ergebnisse pro Epoche")
sns.barplot(x="Epoche", y="FID", hue="Netzwerk", data=epoch_per_network_fid, hue_order=comparable_networks)

In [None]:
x = np.linspace(0, 200, num=400)
y = np.minimum(-1 * x + 200, 100)

ax: plt.Axes = sns.lineplot(x=x, y=y)
ax.set_xlabel("Epochen")
ax.set_ylabel("$\lambda$")