In [None]:
from config import CALTECH_GRAY_DATASET_OUT, CALTECH_NIR_DATASET_OUT
%flow mode reactive

import pathlib
from os.path import join
from pathlib import PosixPath, Path
from typing import List

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
from PIL import Image

from kornia.filters import sobel
from nircoloring.evaluation.fid import calculate_fid_from_images, NumpyDataset

sns.set_theme()

In [None]:
def is_image_file(file: PosixPath):
    """Tell whether ``file`` is an existing regular file named ``*.jpg`` or ``*.png``.

    The extension check is case-insensitive. Directories and missing paths yield False.
    """
    return file.is_file() and file.name.lower().endswith((".jpg", ".png"))


def list_image_files_from_directory(directory, matcher="*") -> list[pathlib.Path]:
    """Return the image files under ``directory`` that match ``matcher``, sorted by path.

    Args:
        directory: Directory to search.
        matcher: pathlib glob pattern relative to ``directory``; the default ``"*"`` only looks at
            the directory's direct children.

    Returns:
        Sorted ``Path`` objects (not strings) for which ``is_image_file`` is true.
    """
    candidates = pathlib.Path(directory).glob(matcher)
    return sorted(path for path in candidates if is_image_file(path))


def load_images(all_image_filenames, load_size) -> list[np.ndarray]:
    """Read image files into numpy arrays, shrinking each to fit within ``load_size``.

    Args:
        all_image_filenames: Iterable of paths accepted by ``PIL.Image.open``.
        load_size: Maximum (width, height). PIL's ``thumbnail`` keeps the aspect ratio and never
            enlarges an image.

    Returns:
        One array per file, in input order.
    """
    def read_image(filename):
        # The context manager releases the file handle deterministically, including when
        # thumbnail() raises or for multi-frame files that PIL keeps open after loading.
        with Image.open(filename) as image:
            image.thumbnail(load_size, Image.Resampling.LANCZOS)
            # Convert while the image is still open.
            return np.asarray(image)

    return [read_image(filename) for filename in all_image_filenames]


def load_npz_images(file, load_size) -> list[np.ndarray]:
    """Read the ``arr_0`` image stack of an ``.npz`` archive, shrinking each image to fit ``load_size``.

    Args:
        file: Path or file-like object of an ``.npz`` archive whose first positional array
            (``arr_0``, the key ``np.savez`` uses) holds the images.
        load_size: Maximum (width, height) passed to PIL's ``thumbnail``.

    Returns:
        One resized array per image, in archive order.
    """
    images = []
    # np.load returns an NpzFile that holds the archive open until closed; `with` releases it.
    with np.load(file) as archive:
        for image in archive["arr_0"]:
            pil_image = Image.fromarray(image)
            pil_image.thumbnail(load_size, Image.Resampling.LANCZOS)
            images.append(np.array(pil_image))
    return images


class Result:
    """A titled set of images on disk, read lazily and then cached.

    Args:
        directory: Directory holding the images.
        title: Label used for plot titles and DataFrame indices.
        load_size: Maximum (width, height) each image is shrunk to when loaded.
        filename_matcher: Glob pattern, relative to ``directory``, that selects candidate files.
    """

    def __init__(self, directory, title, load_size=(64, 64), filename_matcher="*"):
        self.directory = directory
        self.title = title
        self.load_size = load_size
        # The file list is resolved now; pixel data is only read by load_images().
        self.image_filenames = list_image_files_from_directory(directory, filename_matcher)
        self.images = None
        # Present on every Result; only subclasses ever compute it.
        self.fid = None

    def load_images(self) -> List[np.array]:
        """Return the images, reading them from disk on the first call only."""
        if self.images is None:
            # This name resolves to the module-level load_images function, not to this method.
            self.images = load_images(self.image_filenames, self.load_size)
        return self.images


class EvaluationResult(Result):
    """A generated result whose FID is measured against a reference ``Result``.

    Args:
        directory, title, **args: Forwarded to ``Result``.
        fid_reference: The result the FID is computed against.
    """

    def __init__(self, directory, title, fid_reference: Result, **args):
        super().__init__(directory, title, **args)
        self.fid = None
        self.fid_reference = fid_reference

    def load_fid(self):
        """Compute the FID on first use and cache it on the instance."""
        if self.fid is None:
            reference = self.fid_reference
            # NOTE(review): both directories are passed along with the images; how
            # calculate_fid_from_images uses them (e.g. as a cache key) isn't visible here.
            self.fid = calculate_fid_from_images(
                self.directory,
                self.load_images(),
                reference.directory,
                reference.load_images(),
            )
        return self.fid


class NpzEvaluationResult(EvaluationResult):
    """An evaluation result whose images come from an ``.npz`` archive rather than image files.

    Args:
        file: Path of the ``.npz`` archive; images are read from its ``arr_0`` entry.
        **args: Forwarded to ``EvaluationResult`` (``title``, ``fid_reference``, ``load_size``, ...).
    """

    def __init__(self, file, **args):
        self.file = file
        # The archive's folder stands in for ``directory``. Result.__init__ still globs that folder
        # for image files, but load_images below never uses that list.
        # NOTE(review): archives in the same folder therefore share ``directory``, which load_fid
        # hands to calculate_fid_from_images — confirm that is harmless.
        super().__init__(Path(file).parent.absolute(), **args)

    def load_images(self) -> List[np.array]:
        """Return the archive's images, reading and resizing them on the first call only."""
        if self.images is None:
            self.images = load_npz_images(self.file, self.load_size)
        return self.images

In [None]:
def plot_grid(results_to_plot: list[Result], columns, rows=4, column_titles=None):
    """Show images of several results side by side in one figure.

    Each figure row contains ``columns`` groups, and each group holds one image per result.
    Images are taken in order from each result. The number shown is capped by
    ``columns * rows`` and by the number of images in the first result.

    Args:
        results_to_plot: Results to compare; each gets one sub-column per group.
        columns: Number of image groups per figure row.
        rows: Number of figure rows.
        column_titles: Sub-column titles (default: each result's ``title``); one per result.
    """
    column_size = len(results_to_plot)

    if column_titles is None:
        column_titles = [result.title for result in results_to_plot]
    # Validate before drawing so a mismatch doesn't leave a half-built figure behind.
    assert len(column_titles) == column_size

    image_count_to_plot = min(columns * rows, len(results_to_plot[0].load_images()))
    image_columns = column_size * columns

    # squeeze=False keeps `axes` 2-D even for rows == 1 or a single column, so the reshape and
    # `axes[0]` below work for every grid size.
    fig, axes = plt.subplots(nrows=rows, ncols=image_columns, figsize=(image_columns * 2, rows * 2),
                             squeeze=False)
    # Hide every axis up front so cells left empty (fewer images than slots) stay blank.
    for ax in axes.flat:
        ax.axis('off')

    # The row-major reshape groups each run of `column_size` adjacent axes. Transposing then gives
    # one sequence of axes per result, ordered row by row and group by group.
    axes_matrix = axes.reshape((rows * columns, column_size)).T

    for result, column_axis in zip(results_to_plot, axes_matrix):
        for image, ax in zip(result.load_images()[:image_count_to_plot], column_axis):
            # 2-D arrays are single-channel; without cmap='gray' matplotlib would apply a colormap.
            ax.imshow(image, cmap='gray' if image.ndim == 2 else None)
            ax.set_xticklabels([])
            ax.set_yticklabels([])

    for column_title, ax in zip(column_titles * columns, axes[0]):
        ax.set_title(column_title)

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()

In [None]:
def plot_diff_heatmap(first: Result, second: Result, columns, rows=4, column_titles=None, cbar_with=0.08,
                      intensity=False, title=None):
    """Show paired images of two results next to a heatmap of their absolute pixel difference.

    Each figure row holds ``columns`` triples (first image, second image, diff heatmap) plus one
    colorbar axis. All heatmaps share one colour scale, running from the minimum difference to its
    99.9th percentile.

    Args:
        first, second: Results compared image-by-image in load order; paired images must have
            equal shapes.
        columns: Number of (first, second, diff) triples per row.
        rows: Number of figure rows.
        column_titles: Three titles for each triple (default: the two result titles and "Diff").
        cbar_with: Relative width of the colorbar column (parameter name kept for compatibility).
        intensity: If True, diff the per-pixel channel means. Otherwise sum the absolute
            per-channel differences.
        title: Optional figure title.
    """
    column_size = 3
    if column_titles is None:
        column_titles = [first.title, second.title, "Diff"]
    assert len(column_titles) == column_size

    first_loaded = first.load_images()
    second_loaded = second.load_images()
    # Cap by both results so the stacked arrays below always have matching lengths.
    image_count_to_plot = min(columns * rows, len(first_loaded), len(second_loaded))

    image_columns = column_size * columns

    # squeeze=False keeps `axes` 2-D so the slicing below also works for rows == 1.
    fig, axes = plt.subplots(nrows=rows, ncols=image_columns + 1,
                             figsize=(image_columns * 2 + cbar_with, rows * 2 + 0.5),
                             gridspec_kw={'width_ratios': [1.] * image_columns + [cbar_with]},
                             squeeze=False)
    if title is not None:
        fig.suptitle(title)

    cbar_axes = axes[:, -1]
    # Row 0: first-result axes, row 1: second-result axes, row 2: heatmap axes.
    axes_matrix = axes[:, :image_columns].reshape((rows * columns, column_size)).T

    def show_image(ax, image):
        ax.imshow(image, cmap='gray' if image.ndim == 2 else None)
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])

    for image, ax in zip(first_loaded[:image_count_to_plot], axes_matrix[0]):
        show_image(ax, image)
    for image, ax in zip(second_loaded[:image_count_to_plot], axes_matrix[1]):
        show_image(ax, image)

    # Cast before subtracting. The loaded images are presumably uint8 (PIL / npz data), where
    # `a - b` wraps around instead of going negative, which corrupted the colour diffs.
    first_images = np.asarray(first_loaded[:image_count_to_plot], dtype=np.float64)
    second_images = np.asarray(second_loaded[:image_count_to_plot], dtype=np.float64)

    # A stack of single-channel images is (N, H, W) and has no channel axis to reduce.
    has_channels = first_images.ndim == 4
    if intensity and has_channels:
        first_images = np.average(first_images, axis=3)
        second_images = np.average(second_images, axis=3)

    images_diff = np.abs(first_images - second_images)

    if not intensity and has_channels:
        images_diff = np.sum(images_diff, axis=3)

    vmin = np.min(images_diff)
    # A high percentile instead of the max keeps a few outlier pixels from washing out the scale.
    vmax = np.percentile(images_diff, 99.9)

    for i, (diff, ax) in enumerate(zip(images_diff, axes_matrix[2])):
        # Each row's colorbar is drawn once, alongside that row's last heatmap.
        is_last_column = i % columns == columns - 1
        current_cbar = cbar_axes[i // columns]

        sns.heatmap(diff, ax=ax, vmin=vmin, vmax=vmax, cbar=is_last_column, cbar_ax=current_cbar, square=True)
        ax.axis('off')
        ax.set_xticklabels([])
        ax.set_yticklabels([])

    for column_title, ax in zip(column_titles * columns, axes[0]):
        ax.set_title(column_title)

    plt.show()

In [None]:
def result_fid_to_df(results_to_plot: list[EvaluationResult]):
    """Tabulate each result's (cached) FID in a one-column "FID" DataFrame indexed by result title."""
    titles = []
    fids = []
    for result in results_to_plot:
        fids.append(result.load_fid())
        titles.append(result.title)
    return pd.DataFrame({"FID": fids}, index=titles)


def plot_fid_bars(df):
    """Draw a bar chart of a FID table (e.g. the output of result_fid_to_df) and return its Axes.

    Args:
        df: DataFrame whose index labels the bars and whose columns provide the bar heights.

    Returns:
        The matplotlib Axes created by ``DataFrame.plot.bar``, for further customization.
    """
    ax = df.plot.bar()
    ax.set_ylabel("FID")
    # Result titles can be long (LaTeX, line breaks); slant them and anchor them at the tick.
    for label in ax.get_xticklabels():
        label.set_rotation(30)
        label.set_horizontalalignment('right')
    return ax

In [None]:
ROOT_DIR = "../../"
EGSDE_RUNS_BASE_PATH = join(ROOT_DIR, "../EGSDE/runs/nir2rgb/")
SERENGETI_NIR_INCANDESCENT_DATASET = join(ROOT_DIR, "cycle-gan/datasets/serengeti-incandescent/")

# Real-image reference sets: testA holds the source domain (NIR / gray), testB the RGB target.
# Names without a size suffix use Result's default load_size of (64, 64).
nirReference = Result(join(SERENGETI_NIR_INCANDESCENT_DATASET, "testA"), "NIR")
nirReference_128x128 = Result(join(SERENGETI_NIR_INCANDESCENT_DATASET, "testA"), "NIR", load_size=(128, 128))
rgbReference = Result(join(SERENGETI_NIR_INCANDESCENT_DATASET, "testB"), "RGB")
rgbReference_128x128 = Result(join(SERENGETI_NIR_INCANDESCENT_DATASET, "testB"), "RGB", load_size=(128, 128))
caltech_gray_rgbReference_128x128 = Result(join(CALTECH_GRAY_DATASET_OUT, "testB"), "RGB", load_size=(128, 128))
caltech_gray_grayReference_128x128 = Result(join(CALTECH_GRAY_DATASET_OUT, "testA"), "Gray", load_size=(128, 128))
caltech_nir_nirReference_128x128 = Result(join(CALTECH_NIR_DATASET_OUT, "testA"), "NIR", load_size=(128, 128))
caltech_nir_rgbReference_128x128 = Result(join(CALTECH_NIR_DATASET_OUT, "testB"), "RGB", load_size=(128, 128))

# EGSDE nir2rgb runs loaded at Result's default size and scored against the default-size RGB reference.
# Labels are matplotlib mathtext. Raw strings are used where there is no line break; otherwise "\\"
# keeps the LaTeX backslash while "\n" stays a real newline (bare "\l", "\m", "\s" are invalid escapes).
downN_00 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "downN_00"), r"$\mathrm{downN}=0$", fid_reference=rgbReference)
downN_02 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "downN_02"), r"$\mathrm{downN}=2$", fid_reference=rgbReference)
downN_32 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "downN_32"), r"$\mathrm{downN}=32$", fid_reference=rgbReference)
li_0 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "li_0"), r"$\lambda_i=0$", fid_reference=rgbReference)
ls_0_li_0 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "ls_0_li_0"), r"$\lambda_i=0, \lambda_s=0$",
                             fid_reference=rgbReference)
ls_0_li_500 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "ls_0_li_500"), r"$\lambda_i=500, \lambda_s=0$",
                               fid_reference=rgbReference)
initial_random = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "initial_random"),
                                  r"$t=500, y \sim \mathcal{N}(0,\mathbf{I})$", fid_reference=rgbReference)
initial_random_li_0_ls_0 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "initial_random_li_0_ls_0"),
                                            "$t=500$\n$y \\sim \\mathcal{N}(0,\\mathbf{I})$\n$\\lambda_i=0, \\lambda_s=0$",
                                            fid_reference=rgbReference)
initial_random_li_0_ls_0_t_4000 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "initial_random_li_0_ls_0_t_4000"),
                                                   "$t=4000$\n$y \\sim \\mathcal{N}(0,\\mathbf{I})$\n$\\lambda_i=0, \\lambda_s=0$",
                                                   fid_reference=rgbReference)
# FIXME: t4000 loads the same run directory as initial_random_li_0_ls_0_t_4000 above. The t=4000
# comparison plot, however, labels it as a different configuration (non-random start,
# lambda_s=500, lambda_i=2). Point it at the matching run once that directory is known.
t4000 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "initial_random_li_0_ls_0_t_4000"), "$t=4000$",
                         fid_reference=rgbReference)

highpass_t_2000 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "highpass_t_2000_li_02"),
                                   "$t=2000, \\lambda_i=2$,\n Highpass",
                                   fid_reference=rgbReference)

highpass_t_2000_li_500 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "highpass_t_2000_li_500"),
                                          "$t=2000, \\lambda_i=500$,\n Highpass", fid_reference=rgbReference)
highpass_t_2000_li_120 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "highpass_t_2000_li_120"),
                                          "$t=2000, \\lambda_i=120$,\n Highpass", fid_reference=rgbReference)
highpass_t_2000_li_120_ls_0 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "highpass_t_2000_li_120_ls_0"),
                                               "$t=2000$\n$\\lambda_i=120, \\lambda_s=0$,\n Highpass",
                                               fid_reference=rgbReference)
highpass_mean_t_2000_li_120_ls_0 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "highpass_mean_t_2000_li_120_ls_0"),
                                                    "$t=2000$\n$\\lambda_i=120, \\lambda_s=0$,\n Highpass & Mean",
                                                    fid_reference=rgbReference)
highpass_mean_t_2000_li_500_ls_0 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "highpass_mean_t_2000_li_500_ls_0"),
                                                    "$t=2000$\n$\\lambda_i=500, \\lambda_s=0$,\n Highpass & Mean",
                                                    fid_reference=rgbReference)
mean_t_2000_li_500_ls_0 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "mean_t_2000_li_500_ls_0"),
                                           "$t=2000$\n$\\lambda_i=500, \\lambda_s=0$,\n Mean",
                                           fid_reference=rgbReference)

downN_0_li_10_t_2000 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "downN_0_li_10_t_2000"),
                                        "$t=2000, \\lambda_i=10$,\n Identity", fid_reference=rgbReference)

# EGSDE runs loaded at 128x128 and scored against the 128x128 RGB reference.
# Labels write LaTeX backslashes as "\\" because bare "\l", "\d" etc. are invalid string escapes.
highpass_li_120_t_500_0_128x128 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "128x128_highpass_li_120_t_500_0"),
                                                   "128x128, \n$t=500, \\lambda_i=120$, \n0",
                                                   fid_reference=rgbReference_128x128, load_size=(128, 128))

highpass_li_120_t_500_1_128x128 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "128x128_highpass_li_120_t_500_1"),
                                                   "128x128, \n$t=500, \\lambda_i=120$, \n1",
                                                   fid_reference=rgbReference_128x128, load_size=(128, 128))

highpass_li_120_t_600_0_128x128 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "128x128_highpass_li_120_t_600_0"),
                                                   "128x128, \n$t=600, \\lambda_i=120$, \n0",
                                                   fid_reference=rgbReference_128x128, load_size=(128, 128))

highpass_li_120_t_600_1_128x128 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "128x128_highpass_li_120_t_600_1"),
                                                   "128x128, \n$t=600, \\lambda_i=120$, \n1",
                                                   fid_reference=rgbReference_128x128, load_size=(128, 128))

highpass_li_1000_t_500_log_0_128x128 = EvaluationResult(
    join(EGSDE_RUNS_BASE_PATH, "128x128_highpass_li_1000_t_500_log_0"),
    "128x128, \n$t=500, \\lambda_i=1000$, \nlog, 0", fid_reference=rgbReference_128x128, load_size=(128, 128))

highpass_li_1000_t_500_log_1_128x128 = EvaluationResult(
    join(EGSDE_RUNS_BASE_PATH, "128x128_highpass_li_1000_t_500_log_1"),
    "128x128, \n$t=500, \\lambda_i=1000$, \nlog, 1", fid_reference=rgbReference_128x128, load_size=(128, 128))

mean_li_120_t_500_128x128 = EvaluationResult(
    join(EGSDE_RUNS_BASE_PATH, "128x128_mean_li_120_t_500"),
    "128x128, \n$t=500, \\lambda_i=120$, \n Mean", fid_reference=rgbReference_128x128, load_size=(128, 128))

mean_li_120_t_500_128x128_png = EvaluationResult(
    join(EGSDE_RUNS_BASE_PATH, "128x128_mean_li_120_t_500_png"),
    "128x128, \n$t=500, \\lambda_i=120$, \n Mean", fid_reference=rgbReference_128x128, load_size=(128, 128))

mean_lin_li_120_t_500_128x128 = EvaluationResult(
    join(EGSDE_RUNS_BASE_PATH, "128x128_mean_lin_li_120_t_500"),
    "128x128, \n$t=500, \\lambda_i=120$, \n Mean $\\downarrow$", fid_reference=rgbReference_128x128, load_size=(128, 128))

mean_highpass_li_170_t_850_128x128 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "128x128_mean_highpass_li_170_t_850"),
                                                      "128x128, \n$t=850, \\lambda_i=170$, \nHighpass $\\leftarrow$ Mean",
                                                      fid_reference=rgbReference_128x128, load_size=(128, 128))

t_1000_random_start_128x128 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "128x128_t_1000_random_start"),
                                               "128x128, Unconditional", fid_reference=rgbReference_128x128,
                                               load_size=(128, 128))

# Titled differently from t_1000_random_start_128x128 so the two get distinct DataFrame index labels.
t_1000_random_start_240_128x128 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "128x128_t_1000_random_start_240"),
                                                   "128x128, Unconditional 240", fid_reference=rgbReference_128x128,
                                                   load_size=(128, 128))

t_1000_random_start_240_gd_thv_128x128 = EvaluationResult(
    join(EGSDE_RUNS_BASE_PATH, "128x128_t_1000_random_start_240_gd"),
    "128x128, Unconditional, THV", fid_reference=rgbReference_128x128,
    load_size=(128, 128))

t_1000_random_start_240_png_128x128 = EvaluationResult(
    join(EGSDE_RUNS_BASE_PATH, "128x128_t_1000_random_start_240_png"),
    "128x128, Unconditional, PNG", fid_reference=rgbReference_128x128,
    load_size=(128, 128))

t_1000_random_start_240_randlike_128x128 = EvaluationResult(
    join(EGSDE_RUNS_BASE_PATH, "128x128_t_1000_random_start_240_randlike"),
    "128x128, Unconditional\n randlike", fid_reference=rgbReference_128x128,
    load_size=(128, 128))

# Sampling runs: the NIR samples are scored against the NIR reference, the RGB samples against RGB.
nir_out = EvaluationResult(
    join(EGSDE_RUNS_BASE_PATH, "nir_out"),
    "NIR Sampled", fid_reference=nirReference_128x128,
    load_size=(128, 128)
)

learn_sigma_out = EvaluationResult(
    join(EGSDE_RUNS_BASE_PATH, "learn_sigma_out"),
    "RGB Sampled (with learn_sigma)", fid_reference=rgbReference_128x128,
    load_size=(128, 128)
)

# Outputs of the ilvr_adm colorization runs. With ROOT_DIR = "../../" this resolves to
# "../../../ilvr_adm/output/".
ILVR_OUTPUT_BASE_PATH = join(ROOT_DIR, "../ilvr_adm/output/")

sc_default = EvaluationResult(
    join(ILVR_OUTPUT_BASE_PATH, "default"),
    "SC Default", fid_reference=rgbReference_128x128,
    load_size=(128, 128)
)

sc_alternative_matrix = EvaluationResult(
    join(ILVR_OUTPUT_BASE_PATH, "colorize_nir_matrix_alternative"),
    "SC Alternative Matrix", fid_reference=rgbReference_128x128,
    load_size=(128, 128)
)

sc_alternative_matrix_mask = EvaluationResult(
    join(ILVR_OUTPUT_BASE_PATH, "colorize_nir_matrix_alternative_mask"),
    "SC Alt Matrix\n+ Mask", fid_reference=rgbReference_128x128,
    load_size=(128, 128)
)

sc_approx_matrix = EvaluationResult(
    join(ILVR_OUTPUT_BASE_PATH, "colorize_nir_approx_matrix"),
    "SC NIR Approx", fid_reference=rgbReference_128x128,
    load_size=(128, 128)
)

sc_repaint = EvaluationResult(
    join(ILVR_OUTPUT_BASE_PATH, "colorize_nir_repaint"),
    "SC Repaint", fid_reference=rgbReference_128x128,
    load_size=(128, 128)
)
sc_repaint_u_10 = EvaluationResult(
    join(ILVR_OUTPUT_BASE_PATH, "colorize_nir_repaint_u_10"),
    "SC Repaint $U=10$", fid_reference=rgbReference_128x128,
    load_size=(128, 128)
)
sc_repaint_schedule = EvaluationResult(
    join(ILVR_OUTPUT_BASE_PATH, "colorize_nir_repaint_schedule"),
    "SC Repaint Schedule", fid_reference=rgbReference_128x128,
    load_size=(128, 128)
)
sc_repaint_schedule_opt = EvaluationResult(
    join(ILVR_OUTPUT_BASE_PATH, "colorize_nir_repaint_schedule_opt"),
    "SC Repaint Schedule Opt", fid_reference=rgbReference_128x128,
    load_size=(128, 128)
)
sc_mask_lin_dec = EvaluationResult(
    join(ILVR_OUTPUT_BASE_PATH, "colorize_nir_mask_lin_dec"),
    "SC Mask Linear Dec", fid_reference=rgbReference_128x128,
    load_size=(128, 128)
)

sc_two_step_test = EvaluationResult(
    join(ILVR_OUTPUT_BASE_PATH, "colorize_nir_two_step_test"),
    "SC Two Step Test", fid_reference=rgbReference_128x128,
    load_size=(128, 128)
)

sc_two_step = EvaluationResult(
    join(ILVR_OUTPUT_BASE_PATH, "colorize_nir_two_step"),
    "SC Two Step", fid_reference=rgbReference_128x128,
    load_size=(128, 128)
)

sc_hsv = EvaluationResult(
    join(ILVR_OUTPUT_BASE_PATH, "colorize_nir_hsv"),
    "SC HSV", fid_reference=rgbReference_128x128,
    load_size=(128, 128)
)
sc_lab = EvaluationResult(
    join(ILVR_OUTPUT_BASE_PATH, "colorize_nir_lab"),
    "SC LAB", fid_reference=rgbReference_128x128,
    load_size=(128, 128)
)
sc_cycle_gan = EvaluationResult(
    join(ILVR_OUTPUT_BASE_PATH, "colorize_cycle_gan"),
    "SC Cycle GAN", fid_reference=rgbReference_128x128,
    load_size=(128, 128)
)
# Titled apart from sc_cycle_gan so both can appear in one FID table with distinct index labels.
sc_cycle_gan_lab = EvaluationResult(
    join(ILVR_OUTPUT_BASE_PATH, "colorize_cycle_gan_lab"),
    "SC Cycle GAN LAB", fid_reference=rgbReference_128x128,
    load_size=(128, 128)
)

# Caltech NIR colorizations (diffusion and CycleGAN) are scored against the Caltech RGB test split,
# as caltech_nir_gd is, rather than the Serengeti RGB reference used by the runs above.
caltech_nir_sc_approx_matrix = EvaluationResult(
    join(ILVR_OUTPUT_BASE_PATH, "colorize_nir_approx_matrix_caltech"),
    "Diffusion Caltech 200000", fid_reference=caltech_nir_rgbReference_128x128,
    load_size=(128, 128)
)

caltech_nir_sc_approx_matrix_150000 = EvaluationResult(
    join(ILVR_OUTPUT_BASE_PATH, "colorize_nir_approx_matrix_caltech_150000"),
    "Diffusion Caltech 150000", fid_reference=caltech_nir_rgbReference_128x128,
    load_size=(128, 128)
)

caltech_gray_128x128_mean_li_120 = EvaluationResult(
    join(EGSDE_RUNS_BASE_PATH, "caltech_gray_128x128_mean_li_120"), load_size=(128, 128),
    fid_reference=caltech_gray_rgbReference_128x128, title="Caltech Gray Diffusion")

# CycleGAN test outputs. "*_fake.png" presumably selects the translated images among the saved files.
CYCLE_GAN_RESULTS_BASE_PATH = join(ROOT_DIR, "cycle-gan/results/")
SERENGETI_CYCLE_GAN_IMAGES = join(CYCLE_GAN_RESULTS_BASE_PATH,
                                  "cycle_gan_serengeti_inc_0_000015_0_000045/test_200/images")

cycle_gan = EvaluationResult(SERENGETI_CYCLE_GAN_IMAGES,
                             "CycleGAN", fid_reference=rgbReference, filename_matcher="*_fake.png")

cycle_gan_128x128 = EvaluationResult(
    SERENGETI_CYCLE_GAN_IMAGES,
    "CycleGAN", fid_reference=rgbReference_128x128, filename_matcher="*_fake.png", load_size=(128, 128))

caltech_nir_cycle_gan_128x128 = EvaluationResult(
    join(CYCLE_GAN_RESULTS_BASE_PATH,
         "nir_cyclegan_unet_ralsgan_sampling_ssim_ttur_2_cyc_spectral_normalization_reduced_cycle_detach/test_200/images"),
    "Caltech NIR CycleGAN", fid_reference=caltech_nir_rgbReference_128x128, filename_matcher="*_fake.png",
    load_size=(128, 128))

caltech_gray_cycle_gan_128x128 = EvaluationResult(
    join(CYCLE_GAN_RESULTS_BASE_PATH, "gray_cycle_gan/test_200/images"),
    "Caltech Gray CycleGAN", fid_reference=caltech_gray_rgbReference_128x128, filename_matcher="*_fake.png",
    load_size=(128, 128))

# Guided-diffusion samples stored as .npz archives.
# NOTE(review): both archives are in the same folder, so both results get the same `directory`,
# which load_fid passes to calculate_fid_from_images — confirm that function does not key a cache on it.
GUIDED_DIFFUSION_SAMPLES_PATH = join(ROOT_DIR, "../guided-diffusion/samples/128x128/")

t_1000_random_start_240_gd_128x128 = NpzEvaluationResult(
    join(GUIDED_DIFFUSION_SAMPLES_PATH, "samples_250x128x128x3.npz"), load_size=(128, 128),
    fid_reference=rgbReference_128x128, title="128x128, Unconditional, GD")

caltech_nir_gd = NpzEvaluationResult(
    join(GUIDED_DIFFUSION_SAMPLES_PATH, "caltech_samples_250_128x128x3.npz"), load_size=(128, 128),
    fid_reference=caltech_nir_rgbReference_128x128, title="Unconditional Caltech GD")

In [None]:
plot_grid([caltech_nir_gd, caltech_nir_rgbReference_128x128], columns=4, column_titles=["Guided Diffusion", "Caltech RGB"])

In [None]:
result_fid_to_df([caltech_nir_gd])

In [None]:
plot_grid([caltech_nir_nirReference_128x128, caltech_nir_cycle_gan_128x128, caltech_nir_sc_approx_matrix, caltech_nir_sc_approx_matrix_150000], columns=4, rows=5, column_titles=["NIR", "CycleGAN", "Diffusion", "Diffusion 15000"])

In [None]:
result_fid_to_df([caltech_nir_cycle_gan_128x128, caltech_nir_sc_approx_matrix, caltech_nir_sc_approx_matrix_150000])

In [None]:
plot_grid([nirReference_128x128, cycle_gan_128x128, sc_cycle_gan, sc_approx_matrix], columns=1,rows=3, column_titles=["NIR", "CycleGAN", "Diffusion\nCycleGAN Input", "Diffusion\nNIR Input"])

In [None]:
result_fid_to_df([cycle_gan_128x128, sc_cycle_gan, sc_cycle_gan_lab, sc_approx_matrix])

In [None]:
plot_diff_heatmap(cycle_gan_128x128, sc_cycle_gan, columns=2, intensity=True, column_titles=["CycleGAN", "Diffusion via CycleGAN", "Diff"], title="Heatmap CycleGAN vs Diffusion via CycleGAN - Intensity")
plot_diff_heatmap(cycle_gan_128x128, sc_cycle_gan, columns=2, intensity=False, column_titles=["CycleGAN", "Diffusion via CycleGAN", "Diff"], title="Heatmap CycleGAN vs Diffusion via CycleGAN - Color")

In [None]:
plot_diff_heatmap(cycle_gan_128x128, sc_two_step_test, columns=2, rows=3, intensity=True, column_titles=["CycleGAN", "Two Step (2)", "Diff"], title="Heatmap CycleGAN vs Diffusion - Intensity")
plot_diff_heatmap(cycle_gan_128x128, sc_two_step_test, columns=2, rows=3, intensity=False, column_titles=["CycleGAN", "Two Step (2)", "Diff"], title="Heatmap CycleGAN vs Diffusion - Color")

In [None]:
plot_diff_heatmap(cycle_gan_128x128, sc_approx_matrix, columns=3, rows=3, intensity=True, column_titles=["CycleGAN", "Best Diffusion", "Diff"], title="Heatmap CycleGAN vs Diffusion - Intensity")
plot_diff_heatmap(cycle_gan_128x128, sc_approx_matrix, columns=3, rows=3, intensity=False, column_titles=["CycleGAN", "Best Diffusion", "Diff"], title="Heatmap CycleGAN vs Diffusion - Color")

In [None]:
plot_diff_heatmap(sc_approx_matrix, sc_two_step, columns=2, rows=3, intensity=False, column_titles=["Best", "Two Step\n(1. Variant)", "Diff"], title="Heatmap Two Step (1. Variant) - Color")


In [None]:
plot_diff_heatmap(sc_approx_matrix, sc_two_step, columns=2, rows=3, intensity=True, column_titles=["Best", "Two Step\n(1. Variant)", "Diff"], title="Heatmap Two Step (1. Variant) - Intensity")

In [None]:
plot_diff_heatmap(sc_approx_matrix, sc_two_step_test, columns=2, rows=3, intensity=False, column_titles=["Best", "Two Step\n(2. Variant)", "Diff"], title="Heatmap Two Step (2. Variant) - Color")

In [None]:
plot_diff_heatmap(sc_approx_matrix, sc_two_step_test, columns=2, rows=3, intensity=True, column_titles=["Best", "Two Step\n(2. Variant)", "Diff"], title="Heatmap Two Step (2. Variant) - Intensity")

In [None]:
plot_grid([nirReference_128x128, sc_approx_matrix, sc_hsv, sc_lab], columns=2, rows=2, column_titles=["NIR", "Best", "HSV", "LAB"])

In [None]:
result_fid_to_df([sc_approx_matrix, sc_hsv, sc_lab])

In [None]:
plot_grid([nirReference_128x128, sc_approx_matrix, sc_two_step, sc_two_step_test], columns=1, column_titles=["NIR", "Best", "1. Variant", "2. Variant"], rows=3)

In [None]:
result_fid_to_df([sc_approx_matrix, sc_two_step, sc_two_step_test])

In [None]:
plot_grid([nirReference_128x128, sc_alternative_matrix_mask, sc_mask_lin_dec], columns=2, rows=3, column_titles=["NIR", "Curent Best", "Mask Weight\nDecrease"])

In [None]:
result_fid_to_df([sc_alternative_matrix_mask, sc_mask_lin_dec])

In [None]:
plot_grid([nirReference_128x128, sc_alternative_matrix_mask, sc_approx_matrix], columns=2, rows=3, column_titles=["NIR", "Current Best", "NIR Approx"])

In [None]:
result_fid_to_df([sc_alternative_matrix_mask, sc_approx_matrix])

In [None]:
plot_grid([nirReference_128x128, sc_alternative_matrix_mask, sc_repaint_u_10, sc_repaint_schedule_opt], columns=2, rows=3, column_titles=["NIR", "Current Best", "Repaint (U)", "Repaint Schedule"])

In [None]:
result_fid_to_df([sc_alternative_matrix_mask, sc_repaint_u_10, sc_repaint_schedule_opt])

In [None]:
plot_grid([nirReference_128x128, sc_alternative_matrix_mask, sc_mask_lin_dec], columns=3, rows=6)

In [None]:
result_fid_to_df([sc_alternative_matrix_mask, sc_mask_lin_dec])

In [None]:
plot_grid([nirReference_128x128, sc_alternative_matrix_mask, sc_repaint_schedule, sc_repaint_schedule_opt, sc_repaint_u_10], columns=2, rows=6)

In [None]:
result_fid_to_df([sc_alternative_matrix_mask, sc_repaint_schedule, sc_repaint_schedule_opt, sc_repaint_u_10])

In [None]:
plot_grid([nirReference_128x128, sc_alternative_matrix_mask, sc_approx_matrix], columns=2, rows=3)

In [None]:
result_fid_to_df([sc_alternative_matrix_mask, sc_approx_matrix])

In [None]:
plot_grid([nirReference_128x128, sc_default, sc_alternative_matrix_mask], columns=2, column_titles=["NIR", "SC", "SC\nSample Last Row"], rows=3)

In [None]:
result_fid_to_df([sc_default, sc_alternative_matrix, sc_alternative_matrix_mask, sc_repaint, sc_repaint_u_10, cycle_gan_128x128])

In [None]:
plot_grid([rgbReference_128x128, learn_sigma_out, t_1000_random_start_240_png_128x128], columns=2)

In [None]:
result_fid_to_df([learn_sigma_out, t_1000_random_start_240_png_128x128])

In [None]:
plot_grid([nirReference_128x128, nir_out], columns=4, rows=7)

In [None]:
result_fid_to_df([nir_out])

In [None]:
plot_grid([nirReference_128x128, sc_default, mean_li_120_t_500_128x128_png, cycle_gan_128x128], columns=2, rows=3)

In [None]:
result_fid_to_df([sc_default, mean_li_120_t_500_128x128_png, cycle_gan_128x128])

In [None]:
plot_grid([caltech_gray_grayReference_128x128, caltech_gray_cycle_gan_128x128, caltech_gray_128x128_mean_li_120], columns=2,
          column_titles=["Gray", "CycleGAN", "Diffusion"], rows=3)

In [None]:
result_fid_to_df([caltech_gray_cycle_gan_128x128, caltech_gray_128x128_mean_li_120])

In [None]:
plot_grid([nirReference, mean_li_120_t_500_128x128_png, cycle_gan_128x128], columns=2, rows=3,
          column_titles=["NIR", "Diffusion", "CycleGAN"])

In [None]:
result_fid_to_df([mean_li_120_t_500_128x128_png, cycle_gan_128x128])

In [None]:
plot_grid([rgbReference_128x128, t_1000_random_start_240_128x128, t_1000_random_start_240_png_128x128],
          column_titles=["Real RGB", "Generated (JPEG)", "Generated (PNG)"], columns=2, rows=3)

In [None]:
result_fid_to_df([t_1000_random_start_240_128x128, t_1000_random_start_240_png_128x128])

In [None]:
plot_grid([rgbReference_128x128, cycle_gan_128x128, t_1000_random_start_240_png_128x128],
          column_titles=["Real RGB", "CycleGAN", "Diffusion 128x128"], columns=2, rows=3)

In [None]:
result_fid_to_df([cycle_gan_128x128, t_1000_random_start_240_png_128x128])

In [None]:
plot_grid([cycle_gan_128x128,
           t_1000_random_start_240_128x128, t_1000_random_start_240_randlike_128x128,
           t_1000_random_start_240_gd_128x128, t_1000_random_start_240_gd_thv_128x128,
           t_1000_random_start_240_png_128x128], columns=3, rows=8,
          column_titles=["CycleGAN", "EGSDE 240", "randlike", "GD 240", "GD 240 THV", "PNG"])

In [None]:
result_fid_to_df([
    cycle_gan_128x128, t_1000_random_start_128x128,
    t_1000_random_start_240_128x128,
    t_1000_random_start_240_randlike_128x128,
    t_1000_random_start_240_gd_128x128,
    t_1000_random_start_240_gd_thv_128x128,
    t_1000_random_start_240_png_128x128
])

In [None]:
plot_grid(
    [nirReference_128x128, highpass_li_120_t_500_0_128x128, mean_li_120_t_500_128x128, mean_lin_li_120_t_500_128x128,
     mean_highpass_li_170_t_850_128x128, cycle_gan_128x128],
    column_titles=["NIR", "Highpass Only", "Mean Only", "Mean $\downarrow$", r"Mean $\rightarrow$ Highpass",
                   "CycleGAN"],
    columns=1,
    rows=3)

In [None]:
result_fid_to_df([highpass_li_120_t_500_0_128x128, mean_li_120_t_500_128x128, mean_lin_li_120_t_500_128x128,
                  mean_highpass_li_170_t_850_128x128])

In [None]:
plot_grid([nirReference_128x128, highpass_li_120_t_500_0_128x128, mean_li_120_t_500_128x128, cycle_gan_128x128],
          columns=2)

In [None]:
result_fid_to_df([cycle_gan_128x128, highpass_li_120_t_500_0_128x128, mean_li_120_t_500_128x128])

In [None]:
plot_grid([nirReference, mean_t_2000_li_500_ls_0, cycle_gan], columns=2)

In [None]:
result_fid_to_df([mean_t_2000_li_500_ls_0, cycle_gan])

In [None]:
plot_grid(
    [nirReference, highpass_t_2000_li_120_ls_0, highpass_mean_t_2000_li_120_ls_0, highpass_mean_t_2000_li_500_ls_0,
     mean_t_2000_li_500_ls_0], columns=1)

In [None]:
result_fid_to_df([highpass_t_2000_li_120_ls_0, highpass_mean_t_2000_li_120_ls_0, highpass_mean_t_2000_li_500_ls_0,
                  mean_t_2000_li_500_ls_0])

In [None]:
plot_grid([nirReference_128x128, rgbReference_128x128, cycle_gan_128x128, t_1000_random_start_128x128,
           t_1000_random_start_240_128x128], columns=2,
          column_titles=["NIR", "Real RGB", "CycleGAN", "Diffusion 90", "Diffusion 240"], rows=20)

In [None]:
result_fid_to_df([cycle_gan_128x128, t_1000_random_start_128x128, t_1000_random_start_240_128x128])

In [None]:
plot_grid([nirReference_128x128, highpass_li_120_t_500_0_128x128, highpass_li_120_t_500_1_128x128,
           highpass_li_120_t_600_0_128x128, highpass_li_120_t_600_1_128x128], columns=1)

In [None]:
result_fid_to_df([highpass_li_120_t_500_0_128x128, highpass_li_120_t_500_1_128x128, highpass_li_120_t_600_0_128x128,
                  highpass_li_120_t_600_1_128x128])

In [None]:
plot_grid([nirReference_128x128, highpass_li_120_t_500_0_128x128, highpass_li_120_t_500_1_128x128,
           highpass_li_1000_t_500_log_0_128x128, highpass_li_1000_t_500_log_1_128x128], columns=1)

In [None]:
result_fid_to_df(
    [highpass_li_120_t_500_0_128x128, highpass_li_120_t_500_1_128x128, highpass_li_1000_t_500_log_0_128x128,
     highpass_li_1000_t_500_log_1_128x128])

In [None]:
plot_grid([nirReference, downN_0_li_10_t_2000, highpass_t_2000_li_120, highpass_t_2000_li_120_ls_0], columns=2, rows=2)

In [None]:
result_fid_to_df([downN_0_li_10_t_2000, highpass_t_2000_li_120, highpass_t_2000_li_120_ls_0])

In [None]:
plot_grid([nirReference, downN_0_li_10_t_2000, highpass_t_2000_li_500, highpass_t_2000_li_120], columns=2, rows=2)

In [None]:
result_fid_to_df([downN_0_li_10_t_2000, highpass_t_2000_li_500, highpass_t_2000_li_120])

In [None]:
plot_grid([nirReference, downN_00, downN_0_li_10_t_2000], columns=3, rows=3,
          column_titles=["NIR", "$t=500, \lambda_i=2$", "$t=2000, \lambda_i=10$"])

In [None]:
result_fid_to_df([downN_00, downN_0_li_10_t_2000])

In [None]:
result_fid_to_df([downN_32, downN_0_li_10_t_2000, highpass_t_2000_li_500, cycle_gan])

In [None]:
plot_grid([nirReference, cycle_gan, highpass_t_2000_li_500, downN_32], 2, rows=16)

In [None]:
result_fid_to_df([cycle_gan, highpass_t_2000_li_500])

In [None]:
plot_grid([nirReference, downN_32], 3)

In [None]:
plot_grid([nirReference, downN_32, ls_0_li_0], 2)

In [None]:
plot_grid([nirReference, downN_32, initial_random], 2)

In [None]:
plot_grid([nirReference, ls_0_li_500], columns=3)

In [None]:
plot_grid([nirReference, downN_00, downN_02, downN_32], columns=2)

In [None]:
plot_grid([nirReference, ls_0_li_0, li_0, downN_32], columns=1)
result_fid_to_df([ls_0_li_0, li_0, downN_32])

In [None]:
plot_grid([initial_random_li_0_ls_0, initial_random_li_0_ls_0_t_4000], columns=4, rows=3, )

In [None]:
result_fid_to_df([initial_random_li_0_ls_0, initial_random_li_0_ls_0_t_4000])

In [None]:
plot_grid([cycle_gan, initial_random_li_0_ls_0_t_4000], columns=3, rows=3, column_titles=["CycleGAN", "Diffusion"])

In [None]:
result_fid_to_df([cycle_gan, initial_random_li_0_ls_0_t_4000])

In [None]:
plot_grid([nirReference, cycle_gan, highpass_t_2000_li_500], columns=2, rows=2,
          column_titles=["NIR", "CycleGAN", "Diffusion"])

In [None]:
result_fid_to_df([cycle_gan, highpass_t_2000_li_500])

In [None]:
plot_grid([nirReference, t4000, initial_random_li_0_ls_0_t_4000], columns=2, rows=2,
          column_titles=["NIR", "$t=4000$\n$y \sim p_{M|0}(y\mid x_0)$\n$\lambda_s=500, \lambda_i=2$",
                         "$t=4000$\n$y \sim \mathcal{N}(0, \mathbf{I})$\n$\lambda_s=0, \lambda_i=0$"])

In [None]:
plot_grid([rgbReference, cycle_gan, initial_random_li_0_ls_0_t_4000], columns=2, rows=3,
          column_titles=["RGB", "CycleGAN", "Diffusion"])

In [None]:
plot_grid([rgbReference, cycle_gan, highpass_t_2000_li_120], columns=2, rows=3,
          column_titles=["RGB", "CycleGAN", "Diffusion"])

In [None]:
import torch
import torchvision.transforms as TF

sim = torch.nn.CosineSimilarity()


def _sobel_gradients(result, max_images=100):
    """Return the Sobel gradient magnitudes of the first ``max_images`` images of ``result``.

    The images from ``result.load_images()`` are wrapped in ``NumpyDataset``; its items are
    stacked into one batch tensor before filtering.
    """
    # assumes NumpyDataset yields float tensors shaped (C, H, W), since kornia's sobel expects
    # (B, C, H, W) — TODO confirm
    dataset = NumpyDataset(result.load_images()[:max_images])
    batch = torch.tensor(np.array([item.numpy() for item in dataset]))
    return sobel(batch)


def band_pass_loss(x1, x2, max_images=100):
    """Mean cosine similarity between the Sobel edge maps of two results.

    Despite the name, this is a similarity: larger values mean more similar edge maps.

    Args:
        x1, x2: Results providing ``load_images()``. Both must yield the same number of
            equally sized images within the first ``max_images``.
        max_images: How many leading images of each result to compare (default 100).

    Returns:
        A scalar tensor: the mean over images and pixels of the similarity computed by ``sim``.
    """
    grad1 = _sobel_gradients(x1, max_images)
    grad2 = _sobel_gradients(x2, max_images)
    # `sim` uses CosineSimilarity's default dim=1, comparing channel vectors at each pixel.
    return sim(grad1, grad2).mean()

In [None]:
band_pass_loss(nirReference, cycle_gan)

In [None]:
dataset1 = NumpyDataset(cycle_gan.load_images()[:100])
batch1 = torch.tensor(np.array([i.numpy() for i in dataset1]))
grad1 = sobel(batch1)

TF.ToPILImage()(grad1.numpy()[0])


In [None]:
band_pass_loss(nirReference, downN_32)

In [None]:
plot_grid([nirReference, cycle_gan, downN_32], columns=3)  #%%


def plot_grid(results_to_plot: list[Result], columns, rows=4, column_titles=None):
    """Plot sample images of several results side by side in one grid figure.

    The figure has ``rows`` rows and ``columns`` column groups. Each group holds one image
    from every entry of ``results_to_plot`` (in list order), so results can be compared
    image by image.

    Args:
        results_to_plot: Results whose ``load_images()`` supply the images.
        columns: Number of column groups, i.e. images per row for each result.
        rows: Number of image rows.
        column_titles: Titles for the results, one per entry of ``results_to_plot``.
            Defaults to each result's ``title``. They are repeated over every group in
            the first row.

    Raises:
        ValueError: If ``results_to_plot`` is empty or ``column_titles`` does not have
            exactly one entry per result.
    """
    column_size = len(results_to_plot)
    if column_size == 0:
        raise ValueError("results_to_plot must contain at least one result")

    if column_titles is None:
        column_titles = [result.title for result in results_to_plot]
    # Validate before creating the figure so a bad call does not leave an empty figure behind.
    if len(column_titles) != column_size:
        raise ValueError(f"expected {column_size} column titles, got {len(column_titles)}")

    # The number of images per result is capped by how many images the first result has.
    image_count_to_plot = min(columns * rows, len(results_to_plot[0].load_images()))
    image_columns = column_size * columns

    # squeeze=False keeps `axes` 2-D even when rows == 1 or there is a single image column.
    # Otherwise plt.subplots returns a bare Axes or a 1-D array, and the reshape and
    # axes[0] below break.
    fig, axes = plt.subplots(nrows=rows, ncols=image_columns, figsize=(image_columns * 2, rows * 2),
                             squeeze=False)

    # Hide every frame first, including slots left empty when there are fewer images than slots.
    for ax in axes.flat:
        ax.axis('off')

    # A row-major reshape groups each run of `column_size` neighbouring axes into one slot;
    # transposing yields one sequence of axes per result.
    axes_per_result = axes.reshape((rows * columns, column_size)).T

    for result, result_axes in zip(results_to_plot, axes_per_result):
        for image, ax in zip(result.load_images()[:image_count_to_plot], result_axes):
            ax.imshow(image)

    for title, ax in zip(column_titles * columns, axes[0]):
        ax.set_title(title)

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()

In [None]:
def result_fid_to_df(results_to_plot: list["EvaluationResult"]):
    """Tabulate the FID score of each result.

    The annotation is a forward reference, so this cell can run before ``EvaluationResult``
    is defined; the class is only needed if the hint is inspected.

    Args:
        results_to_plot: Objects exposing a ``title`` attribute and a ``load_fid()`` method.

    Returns:
        A DataFrame with a single ``"FID"`` column, indexed by each result's ``title``
        in input order.
    """
    fids = [result.load_fid() for result in results_to_plot]
    titles = [result.title for result in results_to_plot]
    return pd.DataFrame({"FID": fids}, index=titles)


def plot_fid_bars(df):
    """Draw a bar chart of a FID table (e.g. from ``result_fid_to_df``) with slanted, right-aligned labels."""
    df.plot.bar()
    plt.xticks(rotation=30, ha='right')

In [None]:
# Every data path below is derived from ROOT_DIR, relative to the notebook's working directory.
ROOT_DIR = "../../"
EGSDE_RUNS_BASE_PATH = join(ROOT_DIR, "../EGSDE/runs/nir2rgb/")
SERENGETI_NIR_INCANDESCENT_DATASET = join(ROOT_DIR, "cycle-gan/datasets/serengeti-incandescent/")
# Resolves to "../../cycle-gan/results/...", the same string that was previously hard-coded twice.
CYCLE_GAN_RESULTS_DIR = join(ROOT_DIR, "cycle-gan/results/cycle_gan_serengeti_inc_0_000015_0_000045/test_200/images")

# Reference image sets: NIR inputs (testA) and real RGB images (testB). The RGB sets are the FID references below.
nirReference = Result(join(SERENGETI_NIR_INCANDESCENT_DATASET, "testA"), "NIR")
nirReference_128x128 = Result(join(SERENGETI_NIR_INCANDESCENT_DATASET, "testA"), "NIR", load_size=(128, 128))
rgbReference = Result(join(SERENGETI_NIR_INCANDESCENT_DATASET, "testB"), "RGB")
rgbReference_128x128 = Result(join(SERENGETI_NIR_INCANDESCENT_DATASET, "testB"), "RGB", load_size=(128, 128))

# EGSDE runs without an explicit load_size. Titles are matplotlib mathtext: backslashes are doubled
# ("\\lambda") so Python does not treat them as invalid escape sequences, which has been a
# SyntaxWarning since Python 3.12. "\n" remains a real line break in the title.
downN_00 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "downN_00"), "$\\mathrm{downN}=0$", fid_reference=rgbReference)
downN_02 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "downN_02"), "$\\mathrm{downN}=2$", fid_reference=rgbReference)
downN_32 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "downN_32"), "$\\mathrm{downN}=32$", fid_reference=rgbReference)
li_0 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "li_0"), "$\\lambda_i=0$", fid_reference=rgbReference)
ls_0_li_0 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "ls_0_li_0"), "$\\lambda_i=0, \\lambda_s=0$",
                             fid_reference=rgbReference)
ls_0_li_500 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "ls_0_li_500"), "$\\lambda_i=500, \\lambda_s=0$",
                               fid_reference=rgbReference)
initial_random = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "initial_random"),
                                  "$t=500, y \\sim \\mathcal{N}(0,\\mathbf{I})$", fid_reference=rgbReference)
initial_random_li_0_ls_0 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "initial_random_li_0_ls_0"),
                                            "$t=500$\n$y \\sim \\mathcal{N}(0,\\mathbf{I})$\n$\\lambda_i=0, \\lambda_s=0$",
                                            fid_reference=rgbReference)
initial_random_li_0_ls_0_t_4000 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "initial_random_li_0_ls_0_t_4000"),
                                                   "$t=4000$\n$y \\sim \\mathcal{N}(0,\\mathbf{I})$\n$\\lambda_i=0, \\lambda_s=0$",
                                                   fid_reference=rgbReference)
# NOTE(review): `t4000` reads the same run directory as `initial_random_li_0_ls_0_t_4000`. The comparison
# plots nevertheless title it as a different setting (y ~ p_{M|0}(y | x_0), lambda_s=500, lambda_i=2) —
# confirm which run directory was intended.
t4000 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "initial_random_li_0_ls_0_t_4000"), "$t=4000$",
                         fid_reference=rgbReference)

highpass_t_2000 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "highpass_t_2000_li_02"),
                                   "$t=2000, \\lambda_i=2$,\n Highpass",
                                   fid_reference=rgbReference)

highpass_t_2000_li_500 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "highpass_t_2000_li_500"),
                                          "$t=2000, \\lambda_i=500$,\n Highpass", fid_reference=rgbReference)
highpass_t_2000_li_120 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "highpass_t_2000_li_120"),
                                          "$t=2000, \\lambda_i=120$,\n Highpass", fid_reference=rgbReference)
highpass_t_2000_li_120_ls_0 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "highpass_t_2000_li_120_ls_0"),
                                               "$t=2000$\n$\\lambda_i=120, \\lambda_s=0$,\n Highpass",
                                               fid_reference=rgbReference)
highpass_mean_t_2000_li_120_ls_0 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "highpass_mean_t_2000_li_120_ls_0"),
                                                    "$t=2000$\n$\\lambda_i=120, \\lambda_s=0$,\n Highpass & Mean",
                                                    fid_reference=rgbReference)
highpass_mean_t_2000_li_500_ls_0 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "highpass_mean_t_2000_li_500_ls_0"),
                                                    "$t=2000$\n$\\lambda_i=500, \\lambda_s=0$,\n Highpass & Mean",
                                                    fid_reference=rgbReference)
mean_t_2000_li_500_ls_0 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "mean_t_2000_li_500_ls_0"),
                                           "$t=2000$\n$\\lambda_i=500, \\lambda_s=0$,\n Mean",
                                           fid_reference=rgbReference)

downN_0_li_10_t_2000 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "downN_0_li_10_t_2000"),
                                        "$t=2000, \\lambda_i=10$,\n Identity", fid_reference=rgbReference)

# 128x128 runs: loaded at 128x128 and scored against the 128x128 RGB reference.
highpass_li_120_t_500_0_128x128 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "128x128_highpass_li_120_t_500_0"),
                                                   "128x128, \n$t=500, \\lambda_i=120$, \n0",
                                                   fid_reference=rgbReference_128x128, load_size=(128, 128))

highpass_li_120_t_500_1_128x128 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "128x128_highpass_li_120_t_500_1"),
                                                   "128x128, \n$t=500, \\lambda_i=120$, \n1",
                                                   fid_reference=rgbReference_128x128, load_size=(128, 128))

highpass_li_120_t_600_0_128x128 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "128x128_highpass_li_120_t_600_0"),
                                                   "128x128, \n$t=600, \\lambda_i=120$, \n0",
                                                   fid_reference=rgbReference_128x128, load_size=(128, 128))

highpass_li_120_t_600_1_128x128 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "128x128_highpass_li_120_t_600_1"),
                                                   "128x128, \n$t=600, \\lambda_i=120$, \n1",
                                                   fid_reference=rgbReference_128x128, load_size=(128, 128))

highpass_li_1000_t_500_log_0_128x128 = EvaluationResult(
    join(EGSDE_RUNS_BASE_PATH, "128x128_highpass_li_1000_t_500_log_0"),
    "128x128, \n$t=500, \\lambda_i=1000$, \nlog, 0", fid_reference=rgbReference_128x128, load_size=(128, 128))

highpass_li_1000_t_500_log_1_128x128 = EvaluationResult(
    join(EGSDE_RUNS_BASE_PATH, "128x128_highpass_li_1000_t_500_log_1"),
    "128x128, \n$t=500, \\lambda_i=1000$, \nlog, 1", fid_reference=rgbReference_128x128, load_size=(128, 128))

mean_li_120_t_500_128x128 = EvaluationResult(
    join(EGSDE_RUNS_BASE_PATH, "128x128_mean_li_120_t_500"),
    "128x128, \n$t=500, \\lambda_i=120$, \n Mean", fid_reference=rgbReference_128x128, load_size=(128, 128))

mean_lin_li_120_t_500_128x128 = EvaluationResult(
    join(EGSDE_RUNS_BASE_PATH, "128x128_mean_lin_li_120_t_500"),
    "128x128, \n$t=500, \\lambda_i=120$, \n Mean $\\downarrow$", fid_reference=rgbReference_128x128,
    load_size=(128, 128))

mean_highpass_li_170_t_850_128x128 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "128x128_mean_highpass_li_170_t_850"),
                                                      "128x128, \n$t=850, \\lambda_i=170$, \nHighpass $\\leftarrow$ Mean",
                                                      fid_reference=rgbReference_128x128, load_size=(128, 128))

t_1000_random_start_128x128 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "128x128_t_1000_random_start"),
                                               "128x128, Unconditional", fid_reference=rgbReference_128x128,
                                               load_size=(128, 128))

# Distinct title: result_fid_to_df indexes its table by title, so the two unconditional runs
# would otherwise produce indistinguishable rows.
t_1000_random_start_240_128x128 = EvaluationResult(join(EGSDE_RUNS_BASE_PATH, "128x128_t_1000_random_start_240"),
                                                   "128x128, Unconditional, 240", fid_reference=rgbReference_128x128,
                                                   load_size=(128, 128))

# CycleGAN baseline: only the generated "*_fake.png" files in the results directory are used.
cycle_gan = EvaluationResult(CYCLE_GAN_RESULTS_DIR, "CycleGAN", fid_reference=rgbReference,
                             filename_matcher="*_fake.png")

cycle_gan_128x128 = EvaluationResult(CYCLE_GAN_RESULTS_DIR, "CycleGAN", fid_reference=rgbReference_128x128,
                                     filename_matcher="*_fake.png", load_size=(128, 128))


In [None]:
# 128x128: NIR input vs. the t=500 Highpass / Mean / Mean-down runs, the t=850 Highpass<-Mean run and CycleGAN.
plot_grid(
    [nirReference_128x128, highpass_li_120_t_500_0_128x128, mean_li_120_t_500_128x128, mean_lin_li_120_t_500_128x128,
     mean_highpass_li_170_t_850_128x128, cycle_gan_128x128],
    columns=1)

In [None]:
# FID for the same 128x128 runs.
result_fid_to_df([highpass_li_120_t_500_0_128x128, mean_li_120_t_500_128x128, mean_lin_li_120_t_500_128x128,
                  mean_highpass_li_170_t_850_128x128, cycle_gan_128x128])

In [None]:
# Shorter 128x128 comparison: Highpass, Mean and CycleGAN.
plot_grid([nirReference_128x128, highpass_li_120_t_500_0_128x128, mean_li_120_t_500_128x128, cycle_gan_128x128],
          columns=2)

In [None]:
# FID: CycleGAN vs. the 128x128 Highpass and Mean runs.
result_fid_to_df([cycle_gan_128x128, highpass_li_120_t_500_0_128x128, mean_li_120_t_500_128x128])

In [None]:
# Mean guidance (t=2000, lambda_i=500, lambda_s=0) vs. CycleGAN.
plot_grid([nirReference, mean_t_2000_li_500_ls_0, cycle_gan], columns=2)

In [None]:
# FID: Mean guidance vs. CycleGAN.
result_fid_to_df([mean_t_2000_li_500_ls_0, cycle_gan])

In [None]:
# t=2000, lambda_s=0: Highpass vs. Highpass & Mean (lambda_i=120 / 500) vs. Mean.
plot_grid(
    [nirReference, highpass_t_2000_li_120_ls_0, highpass_mean_t_2000_li_120_ls_0, highpass_mean_t_2000_li_500_ls_0,
     mean_t_2000_li_500_ls_0], columns=1)

In [None]:
# FID for the same Highpass / Mean variants.
result_fid_to_df([highpass_t_2000_li_120_ls_0, highpass_mean_t_2000_li_120_ls_0, highpass_mean_t_2000_li_500_ls_0,
                  mean_t_2000_li_500_ls_0])

In [None]:
# Unconditional 128x128 runs next to NIR inputs, real RGB images and CycleGAN (explicit column titles).
plot_grid([nirReference_128x128, rgbReference_128x128, cycle_gan_128x128, t_1000_random_start_128x128,
           t_1000_random_start_240_128x128], columns=2,
          column_titles=["NIR", "Real RGB", "CycleGAN", "Diffusion 90", "Diffusion 240"], rows=20)

In [None]:
# The FID table is indexed by each result's `title`; the two unconditional runs need distinct
# titles for their rows to be told apart.
result_fid_to_df([cycle_gan_128x128, t_1000_random_start_128x128, t_1000_random_start_240_128x128])

In [None]:
# 128x128 Highpass (lambda_i=120) runs suffixed 0 and 1, at t=500 and t=600.
plot_grid([nirReference_128x128, highpass_li_120_t_500_0_128x128, highpass_li_120_t_500_1_128x128,
           highpass_li_120_t_600_0_128x128, highpass_li_120_t_600_1_128x128], columns=1)

In [None]:
# FID for the same four Highpass runs.
result_fid_to_df([highpass_li_120_t_500_0_128x128, highpass_li_120_t_500_1_128x128, highpass_li_120_t_600_0_128x128,
                  highpass_li_120_t_600_1_128x128])

In [None]:
# t=500 Highpass: lambda_i=120 vs. lambda_i=1000 "log" runs.
plot_grid([nirReference_128x128, highpass_li_120_t_500_0_128x128, highpass_li_120_t_500_1_128x128,
           highpass_li_1000_t_500_log_0_128x128, highpass_li_1000_t_500_log_1_128x128], columns=1)

In [None]:
# FID for the same lambda_i=120 / lambda_i=1000 runs.
result_fid_to_df(
    [highpass_li_120_t_500_0_128x128, highpass_li_120_t_500_1_128x128, highpass_li_1000_t_500_log_0_128x128,
     highpass_li_1000_t_500_log_1_128x128])

In [None]:
# t=2000: Identity (lambda_i=10) vs. Highpass lambda_i=120, with and without lambda_s=0.
plot_grid([nirReference, downN_0_li_10_t_2000, highpass_t_2000_li_120, highpass_t_2000_li_120_ls_0], columns=2, rows=2)

In [None]:
# FID for the same t=2000 runs.
result_fid_to_df([downN_0_li_10_t_2000, highpass_t_2000_li_120, highpass_t_2000_li_120_ls_0])

In [None]:
# t=2000: Identity vs. Highpass with lambda_i=500 and lambda_i=120.
plot_grid([nirReference, downN_0_li_10_t_2000, highpass_t_2000_li_500, highpass_t_2000_li_120], columns=2, rows=2)

In [None]:
# FID for the same t=2000 runs.
result_fid_to_df([downN_0_li_10_t_2000, highpass_t_2000_li_500, highpass_t_2000_li_120])

In [None]:
# downN=0 (t=500) vs. the Identity run at t=2000. Mathtext backslashes are doubled so Python does
# not see the invalid escape "\l" (a SyntaxWarning since Python 3.12); the title text is unchanged.
plot_grid([nirReference, downN_00, downN_0_li_10_t_2000], columns=3, rows=3,
          column_titles=["NIR", "$t=500, \\lambda_i=2$", "$t=2000, \\lambda_i=10$"])

In [None]:
# NOTE(review): the cells from here down to the band_pass_loss experiment repeat an earlier run of
# identical cells in this notebook; one of the two copies could be removed.
# FID: downN=0 vs. the t=2000, lambda_i=10 "Identity" run.
result_fid_to_df([downN_00, downN_0_li_10_t_2000])

In [None]:
# FID: downN=32, Identity (t=2000), Highpass (t=2000, lambda_i=500) and CycleGAN.
result_fid_to_df([downN_32, downN_0_li_10_t_2000, highpass_t_2000_li_500, cycle_gan])

In [None]:
# 16-row sample grid (2 column groups): NIR input, CycleGAN, Highpass lambda_i=500, downN=32.
plot_grid([nirReference, cycle_gan, highpass_t_2000_li_500, downN_32], 2, rows=16)

In [None]:
# FID: CycleGAN vs. Highpass (t=2000, lambda_i=500).
result_fid_to_df([cycle_gan, highpass_t_2000_li_500])

In [None]:
# NIR inputs next to downN=32 outputs.
plot_grid([nirReference, downN_32], 3)

In [None]:
# downN=32 vs. lambda_i=0, lambda_s=0.
plot_grid([nirReference, downN_32, ls_0_li_0], 2)

In [None]:
# downN=32 vs. random initialisation (t=500, y ~ N(0, I)).
plot_grid([nirReference, downN_32, initial_random], 2)

In [None]:
# lambda_i=500, lambda_s=0 outputs.
plot_grid([nirReference, ls_0_li_500], columns=3)

In [None]:
# Effect of downN (0, 2 and 32).
plot_grid([nirReference, downN_00, downN_02, downN_32], columns=2)

In [None]:
# lambda_i / lambda_s ablation: sample grid, then the matching FID table.
plot_grid([nirReference, ls_0_li_0, li_0, downN_32], columns=1)
result_fid_to_df([ls_0_li_0, li_0, downN_32])

In [None]:
# Random start (lambda_i = lambda_s = 0) at t=500 vs. t=4000.
plot_grid([initial_random_li_0_ls_0, initial_random_li_0_ls_0_t_4000], columns=4, rows=3, )

In [None]:
# FID for the same two random-start runs.
result_fid_to_df([initial_random_li_0_ls_0, initial_random_li_0_ls_0_t_4000])

In [None]:
# CycleGAN vs. random-start diffusion at t=4000.
plot_grid([cycle_gan, initial_random_li_0_ls_0_t_4000], columns=3, rows=3, column_titles=["CycleGAN", "Diffusion"])

In [None]:
# FID: CycleGAN vs. random-start diffusion at t=4000.
result_fid_to_df([cycle_gan, initial_random_li_0_ls_0_t_4000])

In [None]:
# NIR input, CycleGAN and Highpass (t=2000, lambda_i=500) with short column titles.
plot_grid([nirReference, cycle_gan, highpass_t_2000_li_500], columns=2, rows=2,
          column_titles=["NIR", "CycleGAN", "Diffusion"])

In [None]:
# FID: CycleGAN vs. Highpass (t=2000, lambda_i=500).
result_fid_to_df([cycle_gan, highpass_t_2000_li_500])

In [None]:
# t=4000 with two different initialisations of y. The titles are matplotlib mathtext; backslashes
# are doubled so Python does not parse "\s", "\m" or "\l" as (invalid) escape sequences, which has
# been a SyntaxWarning since Python 3.12. "\n" stays a real line break.
# NOTE(review): `t4000` is built from the same run directory as `initial_random_li_0_ls_0_t_4000`,
# so both diffusion columns may show identical images despite the different titles — confirm the path.
plot_grid([nirReference, t4000, initial_random_li_0_ls_0_t_4000], columns=2, rows=2,
          column_titles=["NIR", "$t=4000$\n$y \\sim p_{M|0}(y\\mid x_0)$\n$\\lambda_s=500, \\lambda_i=2$",
                         "$t=4000$\n$y \\sim \\mathcal{N}(0, \\mathbf{I})$\n$\\lambda_s=0, \\lambda_i=0$"])

In [None]:
# Real RGB images next to CycleGAN and random-start diffusion (t=4000) outputs.
plot_grid([rgbReference, cycle_gan, initial_random_li_0_ls_0_t_4000], columns=2, rows=3,
          column_titles=["RGB", "CycleGAN", "Diffusion"])

In [None]:
# Real RGB images next to CycleGAN and Highpass (t=2000, lambda_i=120) outputs.
plot_grid([rgbReference, cycle_gan, highpass_t_2000_li_120], columns=2, rows=3,
          column_titles=["RGB", "CycleGAN", "Diffusion"])

In [None]:
import torch
import torchvision.transforms as TF

sim = torch.nn.CosineSimilarity()


def _sobel_gradients(result, max_images=100):
    """Return the Sobel gradient magnitudes of the first ``max_images`` images of ``result``.

    The images from ``result.load_images()`` are wrapped in ``NumpyDataset``; its items are
    stacked into one batch tensor before filtering.
    """
    # assumes NumpyDataset yields float tensors shaped (C, H, W), since kornia's sobel expects
    # (B, C, H, W) — TODO confirm
    dataset = NumpyDataset(result.load_images()[:max_images])
    batch = torch.tensor(np.array([item.numpy() for item in dataset]))
    return sobel(batch)


def band_pass_loss(x1, x2, max_images=100):
    """Mean cosine similarity between the Sobel edge maps of two results.

    Despite the name, this is a similarity: larger values mean more similar edge maps.

    Args:
        x1, x2: Results providing ``load_images()``. Both must yield the same number of
            equally sized images within the first ``max_images``.
        max_images: How many leading images of each result to compare (default 100).

    Returns:
        A scalar tensor: the mean over images and pixels of the similarity computed by ``sim``.
    """
    grad1 = _sobel_gradients(x1, max_images)
    grad2 = _sobel_gradients(x2, max_images)
    # `sim` uses CosineSimilarity's default dim=1, comparing channel vectors at each pixel.
    return sim(grad1, grad2).mean()

In [None]:
band_pass_loss(nirReference, cycle_gan)

In [None]:
dataset1 = NumpyDataset(cycle_gan.load_images()[:100])
batch1 = torch.tensor(np.array([i.numpy() for i in dataset1]))
grad1 = sobel(batch1)

TF.ToPILImage()(grad1.numpy()[0])


In [None]:
band_pass_loss(nirReference, downN_32)

In [None]:
plot_grid([nirReference, cycle_gan, downN_32], columns=3)

In [None]:
from torchvision.transforms import ToTensor, ToPILImage

# Fixed 3x3 channel transform. Each entry of its first column is ~0.577 (~1/sqrt(3)), so decoupled
# channel 0 is proportional to the sum of the input channels. Columns 1 and 2 each sum to ~0, so
# those channels capture differences between the input channels.
M = torch.tensor([
    [5.7735014e-01, -8.1649649e-01, 4.7008697e-08],
    [5.7735026e-01, 4.0824834e-01, 7.0710671e-01],
    [5.7735026e-01, 4.0824822e-01, -7.0710683e-01],
])

# Inverse of `M`; mixing with it maps decoupled channels back to the original space.
invM = torch.inverse(M)


def _apply_channel_matrix(inputs, matrix):
    """Mix the channel axis of a 'bihw' tensor: out[:, j] = sum_i inputs[:, i] * matrix[i, j]."""
    return torch.einsum('bihw,ij->bjhw', inputs, matrix)


def decouple(inputs):
    """Transform the channels of ``inputs`` (indexed 'bihw') with ``M``."""
    return _apply_channel_matrix(inputs, M)


def couple(inputs):
    """Inverse of ``decouple``: transform the channels with ``invM``."""
    return _apply_channel_matrix(inputs, invM)


def get_mask(image):
    """Return a tensor shaped like ``image`` holding ones in channel 0 and zeros in every other channel."""
    mask = torch.zeros_like(image)
    mask[:, :1, ...] = 1
    return mask


# Show the first 128x128 CycleGAN output and turn it into a batch of one.
image_rgb = cycle_gan_128x128.load_images()[0]
plt.imshow(image_rgb)
plt.axis("off")
plt.show()

# ToTensor yields a (C, H, W) float tensor; unsqueeze(0) adds the leading batch axis expected by
# the 'bihw' einsum in decouple/couple. Wrapping a list of ndarrays in torch.tensor gives the
# same values but is slow and triggers a PyTorch UserWarning.
batch_rgb = ToTensor()(image_rgb).unsqueeze(0)

# Same for the first 128x128 NIR reference image.
image_nir = nirReference_128x128.load_images()[0]
plt.imshow(image_nir)
plt.axis("off")
plt.show()

batch_nir = ToTensor()(image_nir).unsqueeze(0)


def plot_batch(batch):
    """Show channels 0, 1 and 2 of the first item of ``batch`` as three side-by-side grayscale panels.

    ``batch`` is indexed as ``batch[0, channel, :, :]``, so it must be 4-D with at least three channels.
    """
    _, channel_axes = plt.subplots(1, 3)

    for channel_index, axis in enumerate(channel_axes):
        axis.imshow(batch[0, channel_index, :, :], cmap="gray")
        axis.axis('off')

    plt.subplots_adjust(wspace=0, hspace=0)
    plt.show()


decoupled_rgb = decouple(batch_rgb)
plot_batch(decoupled_rgb)

# get_mask selects decoupled channel 0. Keep the RGB image's channels 1 and 2 here and take
# channel 0 from the NIR image below.
mask = get_mask(batch_rgb)
masked_decoupled_rgb = decoupled_rgb * (1. - mask)
plot_batch(masked_decoupled_rgb)

decoupled_nir = decouple(batch_nir)
plot_batch(decoupled_nir)

masked_decoupled_nir = decoupled_nir * mask
plot_batch(masked_decoupled_nir)

combined_decoupled = masked_decoupled_rgb + masked_decoupled_nir
plot_batch(combined_decoupled)

mixed = couple(combined_decoupled)
plot_batch(mixed)
# The mixed image is not guaranteed to stay in [0, 1]. ToPILImage multiplies float tensors by 255
# and casts to uint8, so clamp first to avoid wrap-around artefacts.
plt.imshow(ToPILImage()(mixed[0].clamp(0, 1)))
plt.axis("off")
plt.show()

# Sanity check: couple(decouple(x)) should reproduce the original RGB image.
plt.imshow(ToPILImage()(couple(decoupled_rgb)[0].clamp(0, 1)))

plt.show()

In [None]:
# NOTE(review): this cell rebinds `M` and `invM` to NumPy arrays, shadowing the torch tensors that
# `decouple` / `couple` look up when called; calling those functions after this cell will fail.
M = np.array([
    [1 / 3, 1 / 3, 1 / 3],
    [1 / 3, 0, 0],
    [0, 1 / 3, 0]
])

invM = np.linalg.inv(M)

display(np.linalg.qr(M))
# `invM` is not symmetric, so use the general eigensolver. `eigh` assumes a symmetric/Hermitian
# input and reads only one triangle, which silently returns eigenvectors of a different matrix.
# The eigenvectors may be complex.
display(np.linalg.eig(invM)[1])