In [1]:
import os
import sys
import copy
import GPUtil
import tensorflow as tf

relative_path = '../'
sys.path.append(relative_path)

from LatentCA.model import NAFCA
from LatentCA.trainer import LatentTrainerWrapper
from NAFNet.trainer import TrainerNAFNet
from Restormer.trainer import TrainerRestormer
from Utils.generic_utils import get_reduced_name_list
from Utils.trainer_utils import DataGenerator, TestLauncher, create_dummy_dataset, get_elapsed_time
from ViTCA.model import ViTCA
from ViTCA.trainer import TrainerViTCA

# Ask TensorFlow to allocate GPU memory on demand instead of mapping (nearly) all
# of it at start-up, so the GPUtil free-VRAM readings taken later in the notebook
# reflect what the run actually used.
gpus = tf.config.list_physical_devices('GPU')
if gpus:
    try:
        for device in gpus:
            tf.config.experimental.set_memory_growth(device, True)
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPUs")
    except RuntimeError as err:
        # Raised when memory growth is configured after the GPUs were already
        # initialized (e.g. when this cell is re-run in a live kernel).
        print(err)

1 Physical GPUs, 1 Logical GPUs


In [2]:
# Run configuration.
# img_shape appears to be (batch, height, width, channels): the dataset cell passes
# img_shape[1:] as the per-image shape — TODO confirm against DataGenerator.
img_shape = (32, 32, 32, 3)
num_iterations = 3
model_names = ['nafnet', 'restormer', 'vitca', 'latentvitca', 'latentnafca']
# Pick the model by name rather than by list position, so the choice is readable
# and does not silently change if `model_names` is reordered.
model = 'latentnafca'
assert model in model_names, 'Unknown model {!r}; choose one of {}'.format(model, model_names)
# When True, the elapsed-time DataFrame is pickled at the end of the notebook.
save_df = True

In [3]:
# Create the synthetic dataset once and reuse it on later runs.
num_dummy_images = 1250
dataset_directory = relative_path + 'Datasets/Dummy_{}_{}_{}_{}/'.format(*img_shape)
# An empty directory (e.g. left behind by an interrupted creation) counts as missing,
# otherwise the notebook would try to train on no data.
if not os.path.isdir(dataset_directory) or not os.listdir(dataset_directory):
    print('Dataset does not exist, creating...')
    # makedirs also creates the parent 'Datasets/' folder on a fresh checkout,
    # where os.mkdir would raise FileNotFoundError.
    os.makedirs(dataset_directory, exist_ok=True)
    create_dummy_dataset(num_images=num_dummy_images, shape=img_shape[1:], path=dataset_directory)
else:
    print('Dataset already existing, loading...')

Dataset does not exist, creating...


100%|██████████| 1250/1250 [00:00<00:00, 2205.05it/s]


In [4]:
# Data split and generators. The exact meaning of the split fractions is defined by
# get_reduced_name_list; with these values the recorded run reported
# Train 800 / Validation 200 / Test 250 out of 1250 images.
reduction = 1.0
test_split = 0.8
val_split = 0.8
curriculum_learning = True
task_dict = {
    'task': 'Noise',
    'range': 1.,
    'intensity': 0.5,
    'a_min': 0.25,
    'a_max': 0.75,
}
real_dataset = False
train, val, test = get_reduced_name_list(
    path=dataset_directory,
    reduction=reduction,
    test_split=test_split,
    val_split=val_split,
    real_dataset=real_dataset,
)


def _make_generator(filenames, is_train):
    """Build a DataGenerator over ``filenames``.

    Only the training split shuffles, enables curriculum learning and runs in
    train mode; every generator gets its own deep copy of ``task_dict``.
    """
    return DataGenerator(
        directory=dataset_directory,
        filenames=filenames,
        img_size=img_shape,
        task_info=copy.deepcopy(task_dict),
        train_ca=True,
        shuffle=is_train,
        num_iterations=num_iterations,
        curriculum=curriculum_learning if is_train else False,
        train=is_train,
        real_dataset=real_dataset,
    )


train_gen = _make_generator(train, is_train=True)
val_gen = _make_generator(val, is_train=False)
test_gen = _make_generator(test, is_train=False)

Data Split:
Train: 800
Validation: 200
Test: 250


100%|██████████| 800/800 [00:00<00:00, 5414.10it/s]
100%|██████████| 200/200 [00:00<00:00, 5427.90it/s]
100%|██████████| 250/250 [00:00<00:00, 5986.01it/s]


In [5]:
# Settings shared by every model, handed to TestLauncher.
# 'model_params_dicts' stays None here; the model-selection cell fills it in.
common_params = {
    'base_dir': relative_path + 'Results/',
    # Name of the dataset folder. normpath drops any trailing separator, so this
    # does not depend on dataset_directory ending in '/'.
    'dataset': os.path.basename(os.path.normpath(dataset_directory)),
    'task': task_dict['task'],
    'epochs': num_iterations,
    'train': train_gen,
    'val': val_gen,
    'test': test_gen,
    'monitor_metric': 'SSIM',
    # One more than the epoch count — presumably so periodic display never fires
    # during this latency benchmark (TODO confirm in TestLauncher).
    'display_frequency': num_iterations + 1,
    'test_length': 4,
    'model_params_dicts': None,
    # Keep credentials out of the notebook; leave None unless supplied securely.
    'telegram_token': None,
    'telegram_chat_id': None
}

# Shared configuration for the latent cellular-automaton models (LatentNAFCA / LatentViTCA).
max_depth = 16
num_down_sampling = 2
cell_in_channels = cell_out_channels = max_depth
cell_hidden_channels = 32

# img_shape appears to be (batch, height, width, channels) — TODO confirm.
# Each spatial side is divided by 2 ** num_down_sampling. Width is taken from
# img_shape[2] (not img_shape[1]) so non-square inputs get the correct latent grid.
_down_factor = 2 ** num_down_sampling
latent_height = img_shape[1] // _down_factor
latent_width = img_shape[2] // _down_factor
# Pool channel count is the sum of the cell input, output and hidden channels.
pool_shape = (img_shape[0], latent_height, latent_width,
              cell_in_channels + cell_out_channels + cell_hidden_channels)
latent_shape = (latent_height, latent_width, max_depth)


def _conv_params(filters, strides=(2, 2), activation='swish'):
    """Return a fresh 3x3 'same'-padded conv layer config with batch norm and no dropout.

    Args:
        filters: number of output filters.
        strides: convolution strides; (2, 2) for the down/up-sampling layers.
        activation: activation name passed through to the layer builder.
    """
    return {
        'filters': filters,
        'kernel_size': (3, 3),
        'strides': strides,
        'padding': 'same',
        'activation': activation,
        'batch_normalization': True,
        'dropout_rate': 0.0
    }


latent_params = {
    'model_params': {
        'input_shape': img_shape,
        'pool_shape': pool_shape,
        'latent_shape': latent_shape,
        # Filled in per model by the "complete" dicts built further down.
        'CAParams': None,
        'Conv2DParams1': _conv_params(max_depth // 2),
        'Conv2DParams2': _conv_params(max_depth),
        'PassThroughParams1': _conv_params(max_depth // 2),
        'Conv2DTransposeParams2': _conv_params(max_depth // 2),
        'MixParams': _conv_params(max_depth // 2, strides=(1, 1)),
        # 3 filters with sigmoid output — presumably img_shape[-1] image channels; confirm.
        'Conv2DTransposeParams1': _conv_params(3, activation='sigmoid')},
    'trainer': LatentTrainerWrapper,
    'learning_rate': 1e-2,
    'num_down_sampling': num_down_sampling,
    'margin': 1.,
    'reconstruction_loss_ae': 'MSE',
    'reconstruction_loss_weight_ae': 100,
    'distance_loss': 'MSE',
    'distance_loss_weight': 10,
    'task_loss': 'MSE',  # PSNR
    'task_loss_weight': 10,
    'equivalent_loss': 'PureNoiseMSE',
    'equivalent_loss_weight': 100,
    'perturbation_intensity': 0.5,
    'total_variation_loss_ae': False,
    'total_variation_loss_weight_ae': 1,
    'reconstruction_loss': 'MSE',
    'reconstruction_loss_weight': 100,
    'latent_loss': 'MSE',
    'latent_loss_weight': 100,
    'output_overflow_loss': True,
    'output_overflow_loss_weight': 1,
    'hidden_overflow_loss': True,
    'hidden_overflow_loss_weight': 1,
    'total_variation_loss': False,
    'total_variation_loss_weight': 1,
    'update_probability': 1.,
    'min_cell_updates': 8,
    'max_cell_updates': 32,
    'pool_length': 1024
}
# CA-core hyper-parameters for the two latent models. Both cores use the same
# cell channel layout, so that part is written once and unpacked into each dict.
_cell_channel_params = {
    'cell_in_channels': cell_in_channels,
    'cell_out_channels': cell_out_channels,
    'cell_hidden_channels': cell_hidden_channels,
}

latent_nafca_params = {
    'class': NAFCA,
    'localized_attention_neighbourhood': (3, 3),
    'ffn_expand': 4,
    'dropout': 0.1,
    'embed_dim': 128,
    **_cell_channel_params,
}

latent_vitca_params = {
    'class': ViTCA,
    'localized_attention_neighbourhood': (3, 3),
    'patch_size': 1,
    'pos_encoding_max_freq': 5,
    'depth': 1,
    'heads': 4,
    'mlp_dim': 64,
    'dropout': 0.0,
    'embed_cells': True,
    'embed_dim': 128,
    'embed_dropout': 0.0,
    'pool_shape': pool_shape,
    **_cell_channel_params,
}
# Image-space baselines. Each dict names the model, its trainer class, the
# constructor arguments ('model_params') and trainer-level options.
nafnet_params = dict(
    model_name='NAFNet',
    trainer=TrainerNAFNet,
    model_params=dict(
        input_shape=img_shape,
        width=64,
        enc_block_nums=[2, 2, 4, 8],
        middle_block_num=12,
        dec_block_nums=[2, 2, 2, 2],
    ),
    learning_rate=1e-3,
)

restormer_params = dict(
    model_name='Restormer',
    trainer=TrainerRestormer,
    model_params=dict(
        input_shape=img_shape,
        out_channels=3,
        dim=48,
        num_blocks=[4, 6, 6, 8],
        num_refinement_blocks=4,
        heads=[1, 2, 4, 8],
        ffn_expansion_factor=2.66,
        bias=False,
    ),
    learning_rate=3e-4,
)

# ViTCA works directly in image space, so its cell input/output channels are the
# image channels rather than the latent depth.
vitca_params = dict(
    model_name='ViTCA',
    trainer=TrainerViTCA,
    model_params=dict(
        input_shape=img_shape,
        localized_attention_neighbourhood=[3, 3],
        patch_size=1,
        overlapping_patches=False,
        pos_encoding_method='vit_handcrafted',
        pos_encoding_basis='raw_xy',
        pos_encoding_max_freq=5,
        depth=1,
        heads=4,
        mlp_dim=64,
        dropout=0.0,
        cell_init='constant',
        cell_in_channels=img_shape[-1],
        cell_out_channels=img_shape[-1],
        cell_hidden_channels=cell_hidden_channels,
        embed_cells=True,
        embed_dim=128,
        embed_dropout=0.0,
    ),
    learning_rate=1e-3,
    output_overflow_loss=True,
    hidden_overflow_loss=True,
    reconstruction_loss_factor=1e2,
    overflow_loss_factor=1e2,
    pool_length=1024,
    update_probability=0.5,
    min_cell_updates=8,
    max_cell_updates=32,
)
def _with_ca_core(ca_params, model_name):
    """Return a deep copy of ``latent_params`` wired to one CA core.

    Only ``latent_params`` is deep-copied; ``ca_params`` is stored by reference
    under ``['model_params']['CAParams']``.
    """
    complete = copy.deepcopy(latent_params)
    complete['model_params']['CAParams'] = ca_params
    complete['model_name'] = model_name
    return complete


latent_nafca_params_complete = _with_ca_core(latent_nafca_params, 'LatentNAFCA')
latent_vitca_params_complete = _with_ca_core(latent_vitca_params, 'LatentViTCA')

In [6]:
# Resolve the configured `model` name to its parameter dict. Every candidate is
# listed explicitly, so an unknown or misspelled name raises instead of silently
# falling back to a default model.
_params_by_model = {
    params['model_name'].lower(): params
    for params in (nafnet_params, restormer_params, vitca_params,
                   latent_vitca_params_complete, latent_nafca_params_complete)
}
_latent_model_names = {latent_vitca_params_complete['model_name'].lower(),
                       latent_nafca_params_complete['model_name'].lower()}
if model not in _params_by_model:
    raise ValueError('Unknown model {!r}; expected one of {}'.format(model, sorted(_params_by_model)))
common_params['model_params_dicts'] = [_params_by_model[model]]
# Passed to get_elapsed_time(latent=...) when collecting timings.
latent_model = model in _latent_model_names

In [7]:
# Snapshot the free VRAM of the first GPU reported by GPUtil before training.
# NOTE(review): GPUtil's device order may differ from TensorFlow's unless
# CUDA_DEVICE_ORDER=PCI_BUS_ID is set — confirm index 0 is the training GPU.
gpus = GPUtil.getGPUs()
primary_gpu = gpus[0]
vram_before = primary_gpu.memoryFree
print('Initial free GPU VRAM: {:.2f} MB'.format(vram_before))

Initial free GPU VRAM: 11181.00 MB


In [8]:
# Run the configured experiment. TestLauncher is project code; judging by the
# recorded output it trains the model(s) in common_params['model_params_dicts'],
# logs per-epoch train/val metrics and saves the best checkpoint.
# NOTE(review): the saved output shows this run was interrupted (KeyboardInterrupt)
# after epoch 1, so the VRAM/latency cells below were never executed.
launcher = TestLauncher(common_params)
launcher.launch()

Start LatentNAFCA's Training...
1/3 epoch - 24/24 batch - train_rec_loss: 0.073 - train_dist_loss: 0.082 - train_task_loss: 0.021 - train_equiv_loss: 0.041 - train_PSNR: 11.438 - train_SSIM: 0.324 || val_rec_loss: 0.074 - val_dist_loss: 0.878 - val_task_loss: 0.043 - val_equiv_loss: 0.081 - val_PSNR: 11.298 - val_SSIM: 0.170
 || elapsed_time: 13.846
Found new Best, saving...



KeyboardInterrupt



Error in callback <function _draw_all_if_interactive at 0x0000029662C42DD0> (for post_execute):



KeyboardInterrupt



Error in callback <function flush_figures at 0x000002979A82BB50> (for post_execute):



KeyboardInterrupt



In [None]:
# Free VRAM lost on the first GPU since the pre-training snapshot.
# NOTE(review): this is a before/after difference, not a sampled peak. Presumably it
# approximates peak usage because memory growth does not return memory to the
# driver; other processes on the same GPU also affect it.
gpus = GPUtil.getGPUs()
vram_after = gpus[0].memoryFree
vram_used = vram_before - vram_after
print('Maximum GPU VRAM usage: {:.2f} MB'.format(vram_used))

In [None]:
# Collect per-run training timings from the results folder of the selected model.
results_dir = os.path.join(common_params['base_dir'], common_params['task'], common_params['dataset'])
selected_model_name = common_params['model_params_dicts'][0]['model_name']
df = get_elapsed_time(base_dir=results_dir, model_name=selected_model_name, latent=latent_model)
if save_df:
    # Written to the current working directory, keyed by model, shape and epoch count.
    latency_file = 'training_latency-{}-{}-{}.pkl'.format(model, img_shape, num_iterations)
    df.to_pickle(latency_file)
df

In [None]:
# Summarize the timings. get_elapsed_time's return type is not visible here:
# DataFrame.mean() yields one value per column (a Series, which '{:.2f}' cannot
# format and would raise TypeError), while Series.mean() yields a scalar.
# Handle both shapes.
mean_elapsed = df.mean()
if df.ndim == 2:
    for column, value in mean_elapsed.items():
        print('Mean elapsed time ({}): {:.2f} s'.format(column, value))
else:
    print('Mean elapsed time: {:.2f} s'.format(mean_elapsed))