In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# import os; os.environ["CUDA_VISIBLE_DEVICES"]="-1"

In [None]:
import os

import numpy as np
import tensorflow as tf
from fastcore.xtras import Path

import wandb
from wandb.keras import WandbMetricsLogger, WandbModelCheckpoint
import scipy.stats as stats

from perceptnet.networks import *
from iqadatasets.datasets.tid2013 import TID2013

from flayers.callbacks import *

In [None]:
class EvaluatePerceptuality(tf.keras.callbacks.Callback):
    """Evaluates a perceptual model that is part of another model."""

    def __init__(self, 
                 dst, # Dataset to be evaluated.
                 model, # Model to be evaluated.
                 name=None, # Name to prepend to the logged metrics.
                 ):
        self.dst = dst
        self.eval_model = model
        self.name = "" if name is None else name+"_"
        
    def on_epoch_end(self,
                     epoch, 
                     logs=None):
        distances, moses = [], []
        for i, data in enumerate(self.dst):
            img, dist_img, mos = data
            features_original = self.eval_model(img, training=False)
            features_distorted = self.eval_model(dist_img, training=False)
            l2 = (features_original-features_distorted)**2
            l2 = tf.reduce_sum(l2, axis=[1,2,3])
            l2 = tf.sqrt(l2)
            distances.extend(l2)
            moses.extend(mos)
        pearson = stats.pearsonr(distances, moses)[0]
        spearman = stats.spearmanr(distances, moses)[0]
        wandb.log({f"{self.name}Pearson": pearson,
                   f"{self.name}Spearman": spearman}, commit=False)

# Wandb config

In [None]:
config = {
        'epochs':500,
        'learning_rate':3e-4,
        'batch_size':64,
        'kernel_initializer':'ones',
        'gdn_kernel_size':1,
        'learnable_undersampling':False,
        'verbose': 0,
        'dataset': 'cifar10', # imagenet / imagenette / cifar10 / cifar100,
        'validation_split': 0.2,
        'seed': 42
    }

In [None]:
wandb.init(project='PerceptNetClassification',
            notes="",
            tags=[],
            name = 'Baseline',
            config=config,
            job_type="training",
            mode="online",
            )
config = wandb.config

Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
wandb: Currently logged in as: jorgvt. Use `wandb login --relogin` to force relogin


# Load the data

In [None]:
def load_imagenet():
    path_data = Path("/lustre/ific.uv.es/ml/uv075/Databases/imagenet_images/")
    dst_train = tf.keras.utils.image_dataset_from_directory(
                path_data,
                validation_split=config.validation_split,
                subset="training",
                seed=config.seed,
                shuffle=True,
                # image_size=(img_height, img_width),
                batch_size=config.batch_size)
    dst_val = tf.keras.utils.image_dataset_from_directory(
                path_data,
                validation_split=config.validation_split,
                subset="validation",
                seed=config.seed,
                shuffle=False,
                # image_size=(img_height, img_width),
                batch_size=config.batch_size)
    return dst_train, dst_val

In [None]:
def load_imagenette():
    import tensorflow_datasets as tfds

    dst_train, info = tfds.load("imagenette/320px-v2", split=f"train[:{config.validation_split*100:.0f}%]", with_info=True, shuffle_files=True)
    dst_val = tfds.load("imagenette/320px-v2", split=f"train[{config.validation_split*100:.0f}%:]", with_info=False, shuffle_files=False)
    def prepare_tfds(item):
        x, y = item["image"], item["label"]
        x = tf.image.resize_with_crop_or_pad(x, 256, 256)
        return x, y
    dst_train = dst_train.map(prepare_tfds)
    dst_val = dst_val.map(prepare_tfds)

    return dst_train.batch(config.batch_size), dst_val.batch(config.batch_size), info.features["label"].num_classes

In [None]:
def load_cifar10():
    from tensorflow.keras.datasets import cifar10
    from sklearn.model_selection import train_test_split

    (X_train, Y_train), (X_test, Y_test) = cifar10.load_data()
    X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=config.validation_split, random_state=config.seed)
    dst_train = tf.data.Dataset.from_tensor_slices((X_train, Y_train))
    dst_val = tf.data.Dataset.from_tensor_slices((X_val, Y_val))

    return dst_train.batch(config.batch_size), dst_val.batch(config.batch_size)

In [None]:
def load_cifar100():
    from tensorflow.keras.datasets import cifar100
    from sklearn.model_selection import train_test_split

    (X_train, Y_train), (X_test, Y_test) = cifar100.load_data()
    X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=config.validation_split, random_state=config.seed)
    dst_train = tf.data.Dataset.from_tensor_slices((X_train, Y_train))
    dst_val = tf.data.Dataset.from_tensor_slices((X_val, Y_val))

    return dst_train.batch(config.batch_size), dst_val.batch(config.batch_size)

In [None]:
if config.dataset == "imagenet":
    dst_train, dst_val = load_imagenet()
    N_CLASSES = len(dst_train.class_names)
elif config.dataset == "cifar10":
    dst_train, dst_val = load_cifar10()
    N_CLASSES = 10
elif config.dataset == "cifar100":
    dst_train, dst_val = load_cifar100()
    N_CLASSES = 100
elif config.dataset == "imagenette":
    dst_train, dst_val, N_CLASSES = load_imagenette()
else:
    raise ValueError("Dataset parameter not allowed.")
print(f"Training on {config.dataset} with {N_CLASSES} classes.")

Training on cifar10 with 10 classes.


In [1]:
x, y = next(iter(dst_train))
input_shape = x[0].shape
input_shape

TensorShape([32, 32, 3])

In [None]:
wandb.run.summary["N_CLASSES"] = N_CLASSES
wandb.run.summary["Input_Shape"] = input_shape

<class 'AttributeError'>: 'TensorShape' object has no attribute 'eval'

In [None]:
dst_tid2013 = TID2013("/lustre/ific.uv.es/ml/uv075/Databases/IQA/TID/TID2013", exclude_imgs=[25]).dataset.batch(config.batch_size)

### Normalize the data

In [None]:
normalization_layer = layers.Rescaling(1./255)

In [None]:
dst_train = dst_train.map(lambda x,y: (normalization_layer(x), y))
dst_val = dst_val.map(lambda x,y: (normalization_layer(x), y))

### Performance

In [None]:
AUTOTUNE = tf.data.AUTOTUNE

dst_train = dst_train.cache().prefetch(buffer_size=AUTOTUNE)
dst_val = dst_val.cache().prefetch(buffer_size=AUTOTUNE)

# Define the model

In [None]:
# model = PerceptNetExpGDNGaussian(kernel_initializer=config.kernel_initializer, gdn_kernel_size=config.gdn_kernel_size)
# model = PerceptNetExpGaborLast(kernel_initializer=config.kernel_initializer, gdn_kernel_size=config.gdn_kernel_size)
feature_extractor = PerceptNet(kernel_initializer=config.kernel_initializer, gdn_kernel_size=config.gdn_kernel_size, learnable_undersampling=config.learnable_undersampling)
model = tf.keras.Sequential([
    feature_extractor,
    layers.GlobalAveragePooling2D(),
    layers.Dense(N_CLASSES, activation="softmax")
])

In [None]:
model.compile(optimizer=tf.optimizers.Adam(learning_rate=config.learning_rate),
              loss="sparse_categorical_crossentropy",
              metrics=["accuracy"])

Log the number of trainable weights:

In [None]:
if config.dataset == "imagenet" or config.dataset == "imagenette":
    model.build((None,256,256,3))
elif config.dataset == "cifar10" or config.dataset == "cifar100":
    model.build((None,32,32,3))
else: # If it isn't a known dataset, just call the model on a batch of data to build the weights.
    pred = model(x)

In [None]:
num_trainable_vars = np.sum([np.prod(v.shape) for v in model.trainable_variables])
wandb.run.summary["trainable_parameters"] = num_trainable_vars
num_vars = np.sum([np.prod(v.shape) for v in model.weights])
wandb.run.summary["parameters"] = int(num_vars)
print(f"Trainable: {num_trainable_vars} | Vars: {num_vars}")

Trainable: 37658 | Vars: 37666


In [None]:
history = model.fit(dst_train, 
                    epochs=config.epochs, 
                    validation_data=dst_val,
                    callbacks=[EvaluatePerceptuality(dst=dst_tid2013, model=feature_extractor, name="TID2013"),
                               WandbMetricsLogger(log_freq="epoch"),
                               WandbModelCheckpoint(filepath=os.path.join(wandb.run.dir, "model-best"),
                                                    monitor="val_loss",
                                                    save_best_only=True,
                                                    save_weights_only=True,
                                                    mode="min")
                               ],
                    verbose=config.verbose)



In [None]:
wandb.finish()

<class 'TypeError'>: get_range() missing 1 required positional argument: 'session'