In [1]:
# Load IPython's autoreload extension so edits to imported modules are picked up
# without restarting the kernel.
%load_ext autoreload
# Mode 2: reload all modules (except those excluded with %aimport) before each cell runs.
%autoreload 2

In [2]:
# Optional CPU-only switch: uncomment the line below to hide every CUDA device
# from this process. NOTE(review): it has to run before TensorFlow initializes
# CUDA, so keep this cell ahead of the import cell.
# import os; os.environ["CUDA_VISIBLE_DEVICES"]="-1"

In [3]:
import numpy as np
import tensorflow as tf

import wandb
from wandb.keras import WandbCallback
from wandb.keras import WandbMetricsLogger, WandbModelCheckpoint

from perceptnet.networks import *
from perceptnet.pearson_loss import PearsonCorrelation

from iqadatasets.datasets.tid2008 import TID2008
from iqadatasets.datasets.tid2013 import TID2013
from flayers.callbacks import *

# Load the data

In [4]:
# Training source: TID2008, built with image 25 excluded.
cuac = TID2008(
    "/lustre/ific.uv.es/ml/uv075/Databases/IQA/TID/TID2008",
    exclude_imgs=[25],
)
# Validation source: TID2013, built with the same image excluded.
cuac_val = TID2013(
    "/lustre/ific.uv.es/ml/uv075/Databases/IQA/TID/TID2013",
    exclude_imgs=[25],
)

In [5]:
# Pull the dataset objects out of the two loaders (train first, then validation);
# they are shuffled/batched later, right before training.
dst_train, dst_val = cuac.dataset, cuac_val.dataset

2023-01-30 00:07:52.741883: E tensorflow/stream_executor/cuda/cuda_driver.cc:271] failed call to cuInit: CUDA_ERROR_NO_DEVICE: no CUDA-capable device is detected
2023-01-30 00:07:52.741959: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:169] retrieving CUDA diagnostic information for host: mlui02.ific.uv.es
2023-01-30 00:07:52.741978: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:176] hostname: mlui02.ific.uv.es
2023-01-30 00:07:52.742220: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:200] libcuda reported version is: 520.61.5
2023-01-30 00:07:52.742276: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:204] kernel reported version is: 520.61.5
2023-01-30 00:07:52.742289: I tensorflow/stream_executor/cuda/cuda_diagnostics.cc:310] kernel version seems to match DSO: 520.61.5
2023-01-30 00:07:52.743860: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CP

# Wandb config

In [6]:
# Hyperparameters and run metadata; this dict is handed to wandb.init as the run config.
config = {
    # Which datasets are used for training and for evaluation.
    "train_dataset": "TID2008",
    "test_dataset": "TID2013",
    # Optimisation settings.
    "epochs": 500,
    "learning_rate": 3e-4,
    "batch_size": 64,
    # Model construction settings.
    "kernel_initializer": "ones",
    "gdn_kernel_size": 1,
    "learnable_undersampling": False,
    # Verbosity level passed to model.fit.
    "verbose": 0,
    # Currently disabled entries:
    # "test_images": ["20", "21", "22", "23", "24"],
    # "test_dists": ["05", "10", "15", "20", "24"],
}

In [7]:
# Start the W&B run for this experiment, passing the current `config` as the run config.
wandb.init(
    project="PerceptNet2",
    notes="",
    tags=[
        "full",
        "norm",
        "min",
        "excluded non-natural",
        "Train_TID2008",
        "Test_TID2013",
    ],
    name="ExpGaborLast",
    config=config,
    job_type="training",
    mode="online",
)
# From here on `config` is the run's wandb.config object; later cells read it
# through attribute access (config.epochs, config.batch_size, ...).
config = wandb.config

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: jorgvt. Use `wandb login --relogin` to force relogin


# Define the model

In [8]:
# Alternative architecture, kept for reference (takes the same constructor arguments):
# model = PerceptNetExpGDNGaussian(kernel_initializer=config.kernel_initializer, gdn_kernel_size=config.gdn_kernel_size)
model = PerceptNetExpGaborLast(
    kernel_initializer=config.kernel_initializer,
    gdn_kernel_size=config.gdn_kernel_size,
)

In [9]:
# Adam with the configured learning rate; the training objective is the
# project's PearsonCorrelation loss.
model.compile(
    optimizer=tf.optimizers.Adam(learning_rate=config.learning_rate),
    loss=PearsonCorrelation(),
)

Log the number of trainable and total weights:

In [10]:
# Build the model for a fixed input shape so that all of its variables exist
# before they are counted.
# NOTE(review): (None, 384, 512, 3) is presumably (batch, height, width, channels)
# for the TID images — confirm against the dataset.
model.build((None, 384, 512, 3))

# np.prod over a shape can come back as a NumPy scalar (and as a float for
# scalar-shaped variables), so both totals are converted to plain Python ints.
# That keeps the two summary entries the same type and prints them without a
# trailing ".0".
num_trainable_vars = int(np.sum([np.prod(v.shape) for v in model.trainable_variables]))
num_vars = int(np.sum([np.prod(v.shape) for v in model.weights]))

# Record both counts in the W&B run summary.
wandb.run.summary["trainable_parameters"] = num_trainable_vars
wandb.run.summary["parameters"] = num_vars
print(f"Trainable: {num_trainable_vars} | Vars: {num_vars}")

Trainable: 22544 | Vars: 329755.0


In [11]:
# Train on the shuffled, batched training set and validate on the batched
# validation set. Metrics are logged to W&B once per epoch, and only the weights
# with the lowest validation loss are checkpointed.
# NOTE(review): a shuffle buffer of 100 only mixes nearby samples; confirm it is
# large enough for the size and ordering of the training set.
history = model.fit(
    dst_train.shuffle(
        buffer_size=100,
        reshuffle_each_iteration=True,
        seed=42,
    ).batch(config.batch_size),
    epochs=config.epochs,
    validation_data=dst_val.batch(config.batch_size),
    callbacks=[
        WandbMetricsLogger(log_freq="epoch"),
        WandbModelCheckpoint(
            filepath="model-best",
            monitor="val_loss",
            save_best_only=True,
            save_weights_only=True,
            mode="min",
        ),
    ],
    verbose=config.verbose,
)



In [None]:
# Mark the W&B run as finished once training is done.
wandb.finish()

<class 'TypeError'>: get_range() missing 1 required positional argument: 'session'