In [1]:
import copy
import os
import pickle
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import tensorflow as tf

relative_path = '../'
sys.path.append(relative_path)

from Utils.generic_utils import get_reduced_name_list, tensor_to_images, plot_to_projector
from Utils.trainer_utils import DataGenerator, compose_input
from LatentCA.model import AutoEncoderDown2, NAFCA
from NAFNet.model import NAFNet
from Restormer.model import Restormer
from ViTCA.model import ViTCA

# Let TensorFlow allocate GPU memory on demand instead of reserving it all up front.
# set_memory_growth raises RuntimeError if the GPUs have already been initialized;
# in that case the error is printed instead of stopping the notebook.
gpus = tf.config.list_physical_devices('GPU')
if len(gpus) > 0:
    try:
        for gpu in gpus:
            tf.config.experimental.set_memory_growth(device=gpu, enable=True)
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as e:
        print(e)

1 Physical GPUs, 1 Logical GPUs


# Utilities to load and plot training logs

In [10]:
# (name, matplotlib dash pattern) pairs used by plot_dataset:
# entry 0 is used when real=True, entry 1 when real=False.
line_style_tuple = [
    ('dotted', (0, (1, 1))),
    ('dashed', (0, (5, 5))),
]

# load data grouping by epoch and computing mean over it, then computing the total weighted loss
def load_epoch_summary(filename, weights_dict):
    """Load a pickled log DataFrame and summarize it per epoch.

    Args:
        filename: path to a pickle file holding a pandas DataFrame with an ``epoch``
            column (one or more rows per epoch).
        weights_dict: mapping ``column name -> weight`` used to build ``total_loss``.

    Returns:
        DataFrame indexed by epoch with the mean of every numeric column, plus a
        ``total_loss`` column equal to ``sum(weight * column)`` over ``weights_dict``
        (all zeros when ``weights_dict`` is empty).

    Raises:
        KeyError: if a key of ``weights_dict`` is not a numeric column of the log.
    """
    # NOTE: pickle.load can execute arbitrary code; only open logs written by this project's own runs.
    with open(filename, 'rb') as f:
        log = pickle.load(f)
    # numeric_only=True: non-numeric log columns are dropped instead of raising TypeError (pandas >= 2.0)
    d = log.groupby(by='epoch').mean(numeric_only=True)
    d['total_loss'] = 0.
    for key, value in weights_dict.items():
        d['total_loss'] += value * d[key]
    return d

# plot single metric
def plot_dataset(d, y, ax, x_label, y_label, real=True):
    """Draw column ``y`` of DataFrame ``d`` as a line on the matplotlib axes ``ax``.

    The line style comes from ``line_style_tuple``: entry 0 (dotted) when ``real``
    is truthy, entry 1 (dashed) otherwise.
    """
    style_index = 0 if real else 1
    dash_pattern = line_style_tuple[style_index][1]
    d.plot(y=y, ax=ax, kind='line', linestyle=dash_pattern, xlabel=x_label, ylabel=y_label)

# plot all metrics
def plot_all_dataset_wrapper(d_dict, y_list, x_label, y_label_list, size, max_x, save=False, filename=None, cols=2):
    """Plot every metric of ``y_list`` on its own subplot, overlaying all datasets.

    Args:
        d_dict: mapping ``dataset name -> (DataFrame, real_flag)``; the flag is forwarded
            to ``plot_dataset`` as ``real`` and selects the line style.
        y_list: column names to plot, one subplot per column (filled row by row).
        x_label: label shared by every x axis.
        y_label_list: y-axis label for each entry of ``y_list`` (same length).
        size: figure size passed to ``plt.subplots``.
        max_x: exclusive upper bound of the x ticks, which are placed every 2 units.
        save: if True, the figure is also written as SVG to ``filename``.
        filename: destination path; required when ``save`` is True.
        cols: number of subplot columns; rows are derived from ``len(y_list)``.

    Raises:
        ValueError: if ``y_list`` is empty, its length differs from ``y_label_list``,
            ``cols`` is not positive, or ``save`` is True without a ``filename``.
    """
    if not y_list:
        raise ValueError('y_list must contain at least one metric to plot')
    if len(y_list) != len(y_label_list):
        raise ValueError(f'y_list has {len(y_list)} entries but y_label_list has {len(y_label_list)}')
    if cols < 1:
        raise ValueError('cols must be a positive integer')
    if save and filename is None:
        raise ValueError('filename is required when save=True')

    rows = -(-len(y_list) // cols)  # ceiling division: enough rows for every metric
    # squeeze=False keeps `ax` 2-D even when rows == 1 or cols == 1
    fig, ax = plt.subplots(rows, cols, figsize=size, squeeze=False)
    legend_labels = list(d_dict.keys())
    # ax.flat walks the grid row by row, i.e. index i * cols + j as before
    for idx, axis in enumerate(ax.flat):
        if idx >= len(y_list):
            # leftover grid cell when len(y_list) is not a multiple of cols
            axis.axis('off')
            continue
        for value in d_dict.values():
            plot_dataset(value[0], y_list[idx], axis, x_label, y_label_list[idx], value[1])
        axis.legend(legend_labels, prop={'size': 10})
        axis.grid(visible=True, alpha=0.2)
        axis.set_xticks(np.arange(0, max_x, step=2))
    fig.tight_layout()
    if save:
        fig.savefig(filename, format='svg', dpi=1200)
    plt.show()

# Datasets available for each restoration task, keyed by task name.
dataset_dict = dict(
    Blur=['CelebA', 'CIFAR-10', 'GoPro', 'RealBlur', 'TinyImageNet'],
    Noise=['CelebA', 'CIFAR-10', 'Renoir', 'SID', 'TinyImageNet'],
)
# Results root, joined after `relative_path` when building log/checkpoint paths.
results_folder = 'Results'
# Model names: used as result-folder names and as table/figure labels.
# Order matters: later cells index this list positionally (e.g. model_name_list[1]).
model_name_list = ['LatentNAFCA', 'LatentViTCA', 'NAFNet', 'Restormer', 'ViTCA']

# Autoencoder plots

In [None]:
task_dict = {'task': 'Noise', 'range': 1., 'intensity': 0.5, 'a_min': 0.25, 'a_max': 0.75}
# alternative settings for the Blur task:
# task_dict = {'task': 'Blur', 'range': 15, 'intensity': 7, 'a_min': 4, 'a_max': 10}
model_name = model_name_list[1]
file_name = 'val_log.pickle'
# weight of each autoencoder loss term in the total loss computed by load_epoch_summary
weights = {'rec_loss': 100, 'dist_loss': 10, 'task_loss': 10, 'equiv_loss': 100}
d_ae = {
    data: (
        load_epoch_summary(
            relative_path + os.path.join(results_folder, task_dict['task'], data, model_name, 'Logs', file_name),
            weights,
        ),
        False,  # forwarded to plot_dataset as `real`: False selects the dashed line style
    )
    for data in dataset_dict[task_dict['task']]
}

plot_all_dataset_wrapper(
    d_dict=d_ae,
    y_list=['rec_loss', 'dist_loss', 'task_loss', 'equiv_loss', 'total_loss', 'SSIM'],
    x_label='epoch',
    y_label_list=['Reconstruction Loss', 'Distance Loss', 'Task Loss', 'Equivalent Loss', 'Total Loss', 'SSIM'],
    size=(20, 15),
    # ticks span the epochs logged for the first dataset
    max_x=len(list(d_ae.values())[0][0]),
    save=False,
    filename=None,
)

# CA plots

In [None]:
# CA validation curves for the same task/model as the autoencoder plots above
# (reuses task_dict and model_name defined in that cell).
file_name = 'val_ca_log.pickle'
# weight of each CA loss term in the total loss computed by load_epoch_summary
weights = {'rec_loss': 100, 'latent_loss': 100, 'out_loss': 1, 'hid_loss': 1}
d_ca = {data: (load_epoch_summary(relative_path + os.path.join(results_folder, task_dict['task'], data, model_name, 'Logs', file_name), weights), False)
        for data in dataset_dict[task_dict['task']]}

plot_all_dataset_wrapper(d_dict=d_ca, y_list=['rec_loss', 'latent_loss', 'out_loss', 'hid_loss', 'total_loss', 'SSIM'], x_label='epoch',
                         y_label_list=['Reconstruction Loss', 'Latent Loss', 'Output Overflow Loss', 'Hidden Overflow Loss', 'Total Loss', 'SSIM'], size=(20, 15),
                         # x ticks must follow the CA logs being plotted, not the autoencoder logs (d_ae)
                         max_x=len(list(d_ca.values())[0][0]), save=False, filename=None)

# Latent space visualization

In [None]:
# load the generator for the specific dataset
# The names defined here (dataset, img_shape, test_gen, ...) are reused by later cells.
dataset = 'CelebA'
dataset_directory = relative_path + 'Datasets/' + dataset
img_shape = (8, 32, 32, 3)  # NOTE(review): presumably (batch, height, width, channels) — confirm with DataGenerator
reduction = .01
test_split = 0.8
val_split = 0.8
real_dataset = False
train, val, test = get_reduced_name_list(
    path=dataset_directory,
    reduction=reduction,
    test_split=test_split,
    val_split=val_split,
    real_dataset=real_dataset,
)
# Evaluation generator: no shuffling, no curriculum, not in training mode.
test_gen = DataGenerator(
    directory=dataset_directory,
    filenames=test,
    img_size=img_shape,
    task_info=task_dict,
    train_ca=False,
    shuffle=False,
    num_iterations=0,
    curriculum=False,
    train=False,
    real_dataset=real_dataset,
)

In [None]:
# set up the parameters for the model
max_depth = 16  # channel depth of the latent representation (last dim of latent_shape)
num_down_sampling = 2  # the latent grid is 2 ** num_down_sampling times smaller per spatial dim
cell_in_channels = cell_out_channels = max_depth
cell_hidden_channels = 32
# NOTE(review): assumes img_shape is (batch, height, width, channels); height and width are
# scaled independently so non-square inputs get a correct latent grid (the original code
# reused img_shape[1] for the width).
latent_height = img_shape[1] // (2 ** num_down_sampling)
latent_width = img_shape[2] // (2 ** num_down_sampling)
# CA pool: one cell per latent position, holding input + output + hidden channels
pool_shape = (img_shape[0], latent_height, latent_width,
              cell_in_channels + cell_out_channels + cell_hidden_channels)
latent_shape = (latent_height, latent_width, max_depth)
# CA hyper-parameters for the two latent variants. Both dicts end with the same three
# cell-channel entries, so those are built once and unpacked into each dict
# (unpacking copies the values and keeps the original key order).
_cell_channels = {
    'cell_in_channels': cell_in_channels,
    'cell_out_channels': cell_out_channels,
    'cell_hidden_channels': cell_hidden_channels
}
latent_nafca_params = {
    'class': NAFCA,
    'localized_attention_neighbourhood': (3, 3),
    'ffn_expand': 4,
    'dropout': 0.1,
    'embed_dim': 128,
    **_cell_channels
}
latent_vitca_params = {
    'class': ViTCA,
    'localized_attention_neighbourhood': (3, 3),
    'patch_size': 1,
    'pos_encoding_max_freq': 5,
    'depth': 1,
    'heads': 4,
    'mlp_dim': 64,
    'dropout': 0.0,
    'embed_cells': True,
    'embed_dim': 128,
    'embed_dropout': 0.0,
    'pool_shape': pool_shape,
    **_cell_channels
}
del _cell_channels  # temporary helper only; keep the notebook namespace unchanged
def _conv_layer_params(filters, strides=(2, 2), activation='swish'):
    """Return one (transposed) convolution layer config for AutoEncoderDown2.

    Every layer in this configuration uses a 3x3 kernel, 'same' padding, batch
    normalization and no dropout; only filters, strides and activation vary.
    A fresh dict is returned on every call, so layers never share state.
    """
    return {
        'filters': filters,
        'kernel_size': (3, 3),
        'strides': strides,
        'padding': 'same',
        'activation': activation,
        'batch_normalization': True,
        'dropout_rate': 0.0
    }


# Autoencoder configuration shared by both latent CA variants; 'CAParams' is filled in below.
latent_nafca_params_complete = {
    'input_shape': img_shape,
    'pool_shape': pool_shape,
    'latent_shape': latent_shape,
    'CAParams': None,
    'Conv2DParams1': _conv_layer_params(max_depth // 2),
    'Conv2DParams2': _conv_layer_params(max_depth),
    'PassThroughParams1': _conv_layer_params(max_depth // 2),
    'Conv2DTransposeParams2': _conv_layer_params(max_depth // 2),
    'MixParams': _conv_layer_params(max_depth // 2, strides=(1, 1)),
    # presumably the output layer: 3 filters with a sigmoid, so values lie in (0, 1)
    'Conv2DTransposeParams1': _conv_layer_params(3, activation='sigmoid'),
}
# deep-copy before attaching CAParams so the two configurations share no nested dicts
latent_vitca_params_complete = copy.deepcopy(latent_nafca_params_complete)
latent_nafca_params_complete['CAParams'] = latent_nafca_params
latent_vitca_params_complete['CAParams'] = latent_vitca_params

# load the pre-trained model (LatentViTCA, 'Best' checkpoint) for the current task and dataset
model = AutoEncoderDown2(latent_vitca_params_complete)
model.load_weights(relative_path + os.path.join(results_folder, task_dict['task'], dataset, model_name_list[1], 'Checkpoints', 'Best', 'weights'))

In [None]:
# create the latent space
# Export the model's latent representation of test samples for the TensorBoard projector.
projector_dir = 'Projector/'
# makedirs(exist_ok=True) replaces the non-atomic exists()/mkdir() check-then-create
os.makedirs(projector_dir, exist_ok=True)
plot_to_projector(model=model, generator=test_gen, path=projector_dir, num_sample=7, noise=task_dict['task'] == 'Noise',
                  min_value=task_dict['a_min'], max_value=task_dict['a_max'])

To visualize the embeddings, open a terminal in the repository root (the parent of `Notebooks/`) and run **tensorboard --logdir=Notebooks/Projector/**,
then select the **Projector** tab.

# Test-metric comparison across all models

In [None]:
# load all test results for the selected dataset
dataset = 'CelebA'
metrics = ['PSNR', 'SSIM', 'REL_PSNR', 'REL_SSIM']
df_list = []
for name in model_name_list:
    test_log_path = relative_path + os.path.join(results_folder, task_dict['task'], dataset, name, 'Logs', 'test_log.pickle')
    # NOTE: pickle.load can execute arbitrary code; only open logs written by this project's own runs.
    with open(test_log_path, 'rb') as f:
        # select the metric columns before averaging so non-numeric log columns cannot break .mean()
        df_list.append(pickle.load(f)[metrics].mean())
# one row per model, one column per metric (mean over the test log)
df = pd.DataFrame(df_list, index=model_name_list)
df

# Visual comparison between all models

In [21]:
# set up the parameters for the remaining models
# Each configuration takes input_shape=img_shape; the keyword form keeps the same keys in the same order.
nafnet_params = dict(
    input_shape=img_shape,
    width=64,
    enc_block_nums=[2, 2, 4, 8],
    middle_block_num=12,
    dec_block_nums=[2, 2, 2, 2],
)
restormer_params = dict(
    input_shape=img_shape,
    out_channels=3,
    dim=48,
    num_blocks=[4, 6, 6, 8],
    num_refinement_blocks=4,
    heads=[1, 2, 4, 8],
    ffn_expansion_factor=2.66,
    bias=False,
)
# Image-space ViTCA: cell input/output channels match the image channels (img_shape[-1]).
vitca_params = dict(
    input_shape=img_shape,
    localized_attention_neighbourhood=[3, 3],
    patch_size=1,
    overlapping_patches=False,
    pos_encoding_method='vit_handcrafted',
    pos_encoding_basis='raw_xy',
    pos_encoding_max_freq=5,
    depth=1,
    heads=4,
    mlp_dim=64,
    dropout=0.0,
    cell_init='constant',
    cell_in_channels=img_shape[-1],
    cell_out_channels=img_shape[-1],
    cell_hidden_channels=cell_hidden_channels,
    embed_cells=True,
    embed_dim=128,
    embed_dropout=0.0,
)

In [None]:
# load all the pre-trained models
def _best_weights_path(name):
    """Return the path of the 'Best' checkpoint weights of model ``name`` for the current task and dataset."""
    return relative_path + os.path.join(results_folder, task_dict['task'], dataset, name, 'Checkpoints', 'Best', 'weights')


latent_nafca_model = AutoEncoderDown2(latent_nafca_params_complete)
latent_nafca_model.load_weights(_best_weights_path(model_name_list[0]))
latent_vitca_model = AutoEncoderDown2(latent_vitca_params_complete)
latent_vitca_model.load_weights(_best_weights_path(model_name_list[1]))
nafnet_model = NAFNet(nafnet_params)
nafnet_model.load_weights(_best_weights_path(model_name_list[2]))
restormer_model = Restormer(restormer_params)
restormer_model.load_weights(_best_weights_path(model_name_list[3]))
vitca_model = ViTCA(vitca_params)
vitca_model.load_weights(_best_weights_path(model_name_list[4]))

In [24]:
# sample a batch from the generator and perform inference
# NOTE(review): switch_mode(True) is defined in Utils.trainer_utils — confirm what mode it selects.
test_gen.switch_mode(True)
anchor, positive = test_gen[0]
# The same (64, 0.5) pair is passed to compose_input and to ViTCA; presumably the number of
# CA steps and the cell update rate — TODO confirm. Naming them once keeps all CA models aligned.
ca_steps = 64
ca_rate = 0.5
# Both latent autoencoders consume the same composed input, so it is built once. If
# compose_input is stochastic, this also guarantees both models see the identical input.
latent_input = compose_input(positive, pool_shape, ca_steps, ca_rate, None, False, True)
latent_nafca_images = tensor_to_images(latent_nafca_model(latent_input, training=False)[0])
latent_vitca_images = tensor_to_images(latent_vitca_model(latent_input, training=False)[0])
nafnet_images = tensor_to_images(nafnet_model(positive, training=False))
restormer_images = tensor_to_images(restormer_model(positive, training=False))
vitca_inputs = [vitca_model.seed(positive),
                tf.constant(ca_steps, dtype=tf.float32, shape=(1,)),
                tf.constant(ca_rate, dtype=tf.float32, shape=(1,))]
vitca_images = tensor_to_images(vitca_model.get_rgb_out(vitca_model(vitca_inputs, training=False)))
anchor_images = tensor_to_images(anchor)
positive_images = tensor_to_images(positive)
# order must match model_name_list, which labels the columns of the comparison figure
model_images = [latent_nafca_images, latent_vitca_images, nafnet_images, restormer_images, vitca_images]

In [None]:
# visualize the first "rows" results: one row per sample, one column per image source
rows = min(5, len(anchor_images))  # never request more samples than the batch holds
cols = len(model_name_list) + 2
col_names = ['Ground Truth', 'Corrupted Image', *model_name_list]
# squeeze=False keeps `axes` 2-D even if rows == 1
fig, axes = plt.subplots(rows, cols, figsize=(cols * 10, rows * 10), squeeze=False)
for i in range(rows):
    row_images = [anchor_images[i], positive_images[i], *(images[i] for images in model_images)]
    for j, image in enumerate(row_images):
        axes[i, j].imshow(image)
        axes[i, j].axis('off')
# a distinct loop variable, so the Axes grid is not shadowed (the original rebound `ax` here)
for axis, col in zip(axes[0], col_names):
    axis.set_title(col, fontsize=48)
fig.tight_layout()
plt.show()

# End