In [2]:
import os
import json
import numpy as np


In [None]:
import tabulate
from collections import OrderedDict

# Root directory holding one sub-directory per experiment run.
RESULTS_DIR = "../results"

# Display name -> run directory under RESULTS_DIR.
experiments = {
    "plucker": "release-2gpus-b8-s1-80k-plucker-none",
    "gta": "release-2gpus-b8-s1-80k-none-gta",
    "prope": "release-2gpus-b8-s1-80k-none-prope",
}

# Display name -> evaluation sub-directory (inside each run) that contains metrics.json.
variations = {
    "zoom1x": "eval-zoom1x",
    "zoom3x": "eval-zoom3x",
    "zoom5x": "eval-zoom5x",
}

# Per-experiment metric lists; entry j corresponds to the j-th key of `variations`.
psnrs = {k: [] for k in experiments}
ssims = {k: [] for k in experiments}
lpipses = {k: [] for k in experiments}

for exp_name, exp_dir in experiments.items():
    for variation_name, variation_dir in variations.items():
        fp = os.path.join(RESULTS_DIR, exp_dir, variation_dir, "metrics.json")
        # Explicit check rather than `assert`: asserts are stripped under `python -O`,
        # and a bare AssertionError does not say which run is missing.
        if not os.path.exists(fp):
            raise FileNotFoundError(
                f"missing metrics for experiment {exp_name!r}, "
                f"variation {variation_name!r}: {fp}"
            )
        with open(fp, "r", encoding="utf-8") as f:
            metrics = json.load(f)
        psnrs[exp_name].append(metrics["psnr"])
        ssims[exp_name].append(metrics["ssim"])
        lpipses[exp_name].append(metrics["lpips"])

# One grid table per metric (previously only PSNR was shown even though SSIM and
# LPIPS are collected). The first column lists the variations; every other column
# holds one experiment's values in the same order.
for metric_label, per_experiment in (("PSNR", psnrs), ("SSIM", ssims), ("LPIPS", lpipses)):
    table = {metric_label: list(variations), **per_experiment}
    print(tabulate.tabulate(table, headers="keys", tablefmt="grid"))

In [None]:
import pandas as pd
import matplotlib.pyplot as plt

# Plot PSNR / SSIM / LPIPS against the evaluation variation, one panel per metric and
# one line per experiment. Reads `variations`, `psnrs`, `ssims`, `lpipses` from the
# previous cell.

# Global style; must be set before the axes are created. rcParams persist for the
# rest of the kernel session.
plt.rcParams.update({"font.size": 16, "font.family": "sans-serif"})

# Set to a path (e.g. "../figures/analysis_focal.png") to also write the figure to disk.
FIGURE_PATH = None

zoom_factors = list(variations.keys())
metric_series = {"psnr": psnrs, "ssim": ssims, "lpips": lpipses}

# One colour per experiment, computed once so each experiment keeps its colour in
# every panel (order reversed as in the original palette choice).
colors = plt.cm.brg(np.linspace(0, 0.85, len(psnrs)))[::-1]

fig, axes = plt.subplots(1, 3, figsize=(20, 4))
for ax, (metric, values) in zip(axes, metric_series.items()):
    df = pd.DataFrame(values, index=zoom_factors)
    for color, col in zip(colors, df.columns):
        ax.plot(df.index, df[col], marker="o", label=col, color=color)
    # LPIPS is lower-is-better; PSNR and SSIM are higher-is-better.
    arrow = "↓" if metric == "lpips" else "↑"
    ax.set_ylabel(metric.upper() + arrow)
    ax.set_xticks(zoom_factors)
    ax.grid(True, axis="both", alpha=0.3)
    ax.spines["top"].set_visible(False)
    ax.spines["right"].set_visible(False)

fig.tight_layout()
# A single figure-level legend centred above all panels, instead of a hand-tuned
# offset relative to the last axes (which also relied on `df` leaking from the loop).
handles, labels = axes[0].get_legend_handles_labels()
fig.legend(handles, labels, loc="lower center", bbox_to_anchor=(0.5, 1.0), ncol=len(labels))
if FIGURE_PATH is not None:
    # Save before plt.show(): once shown, the figure may be closed/cleared and a later
    # savefig writes a blank image. bbox_inches="tight" keeps the outside legend.
    fig.savefig(FIGURE_PATH, bbox_inches="tight")
plt.show()