In [1]:
import tensorflow as tf

# Hide every GPU from TensorFlow. In this notebook TF is only used for the
# tf.data input pipeline (presumably so it does not reserve GPU memory JAX needs).
tf.config.set_visible_devices(devices=[], device_type="GPU")

2023-12-04 15:08:35.279672: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [44]:
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"
from tqdm.auto import tqdm

from typing import Any, Callable, Sequence, Union
import numpy as np
import scipy.stats as stats

import tensorflow as tf
# Hide GPUs from TensorFlow (only tf.data is used from TF here); presumably
# repeated from the first cell so this cell also works on its own.
tf.config.set_visible_devices([], device_type='GPU')

import jax
from jax import lax, random, numpy as jnp
import flax
from flax.core import freeze, unfreeze, FrozenDict
from flax import linen as nn
from flax import struct
from flax.training import train_state
from flax.training import orbax_utils

import optax
import orbax.checkpoint

from clu import metrics
from ml_collections import ConfigDict

from einops import reduce, rearrange
import wandb

# NOTE(review): the star imports below are the only visible source of several
# names used later (e.g. TID2013, GDN, CenterSurroundLogSigmaK); explicit imports
# would make those dependencies clear.
from iqadatasets.datasets import *
from fxlayers.layers import *
# NOTE(review): GaussianLayerGamma, FreqGaussianGamma and OrientGaussianGamma are
# not used by the model defined in this notebook.
from fxlayers.layers import GaborLayerLogSigma_, GaussianLayerGamma, FreqGaussianGamma, OrientGaussianGamma
from fxlayers.initializers import *
from JaxPlayground.utils.constraints import *
from JaxPlayground.utils.wandb import *

In [3]:
# Other datasets / storage locations used in earlier runs:
# dst_train = TID2008("/lustre/ific.uv.es/ml/uv075/Databases/IQA//TID/TID2008/", exclude_imgs=[25])
# dst_val = TID2013("/lustre/ific.uv.es/ml/uv075/Databases/IQA//TID/TID2013/", exclude_imgs=[25])
# dst_train = TID2008("/media/disk/databases/BBDD_video_image/Image_Quality//TID/TID2008/", exclude_imgs=[25])
# dst_val = TID2013("/media/disk/databases/BBDD_video_image/Image_Quality//TID/TID2013/", exclude_imgs=[25])
# dst = KADIK10K("/media/disk/databases/BBDD_video_image/Image_Quality/KADIK10K/")
# dst = PIPAL("/media/disk/databases/BBDD_video_image/Image_Quality/PIPAL/")

# Evaluation set: TID2013, loaded without `exclude_imgs` (unlike the variants above).
TID2013_PATH = "/lustre/ific.uv.es/ml/uv075/Databases/IQA//TID/TID2013/"
dst = TID2013(TID2013_PATH)

In [1]:
# Peek at a single (reference, distorted, MOS) example to check the shapes.
first_example = next(iter(dst.dataset))
img, img_dist, mos = first_example
tuple(tensor.shape for tensor in (img, img_dist, mos))

(TensorShape([384, 512, 3]), TensorShape([384, 512, 3]), TensorShape([]))

In [2]:
# Default hyperparameters. A later cell replaces them with the config stored
# on the W&B training run whose checkpoint is evaluated.
config = ConfigDict(dict(
    BATCH_SIZE=64,
    EPOCHS=500,
    LEARNING_RATE=3e-3,
    SEED=42,
    GDN_CLIPPING=True,
    NORMALIZE_PROB=False,
    NORMALIZE_ENERGY=True,
    ZERO_MEAN=True,
    USE_BIAS=False,
    CS_KERNEL_SIZE=21,
    GABOR_KERNEL_SIZE=31,
    N_SCALES=4,
    N_ORIENTATIONS=8,
))
config

BATCH_SIZE: 64
CS_KERNEL_SIZE: 21
EPOCHS: 500
GABOR_KERNEL_SIZE: 31
GDN_CLIPPING: true
LEARNING_RATE: 0.003
NORMALIZE_ENERGY: true
NORMALIZE_PROB: false
N_ORIENTATIONS: 8
N_SCALES: 4
SEED: 42
USE_BIAS: false
ZERO_MEAN: true

In [6]:
# W&B id of the training run whose config and checkpoint are evaluated below.
# NOTE(review): `id` shadows Python's builtin `id()`; the next cell reads this
# name, so a rename must be done in both places.
id = "8a4ra7yl"

In [7]:
# Look up the training run through the W&B public API.
api = wandb.Api()
run_path = f"jorgvt/PerceptNet_v15/{id}"
prev_run = api.run(run_path)

In [3]:
# Use the exact hyperparameters of the training run. The ConfigDict's values
# presumably sit under "_fields" because W&B serialized the object's attributes.
run_fields = prev_run.config["_fields"]
config = ConfigDict(run_fields)
config

BATCH_SIZE: 64
CS_KERNEL_SIZE: 21
EPOCHS: 500
GABOR_KERNEL_SIZE: 31
GDN_CLIPPING: true
LEARNING_RATE: 0.003
NORMALIZE_ENERGY: true
NORMALIZE_PROB: false
N_ORIENTATIONS: 8
N_SCALES: 4
SEED: 42
USE_BIAS: false
ZERO_MEAN: true

In [9]:
# Download every file attached to the training run (presumably including the
# "model-best" checkpoint restored later) into the run's local directory,
# overwriting stale copies.
run_files = prev_run.files()
for run_file in run_files:
    run_file.download(root=prev_run.dir, replace=True)

In [4]:
# Open a separate evaluation run, named after the training run. The training
# hyperparameters are attached as the run config (passed the same way as the
# training run, whose config exposes them under "_fields"), so the logged
# correlations can be traced back to the settings that produced them.
wandb.init(project="PerceptNet_JaX_Eval",
           name=prev_run.name,
           job_type="evaluate",
           mode="online",
           config=config,
           )
config

wandb: Currently logged in as: jorgvt. Use `wandb login --relogin` to force relogin
wandb: wandb version 0.16.0 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade
wandb: Tracking run with wandb version 0.15.1
wandb: Run data is saved locally in /lhome/ext/uv075/uv0752/perceptnet/Notebooks/13_JaX/13_04_V18/wandb/run-20231204_200918-0y86zt67
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run DN_OG_G_Fixed_GDNSpatioFreqOrient_Symm_CSPos
wandb:  View project at https://wandb.ai/jorgvt/PerceptNet_JaX_Eval
wandb:  View run at https://wandb.ai/jorgvt/PerceptNet_JaX_Eval/runs/0y86zt67


BATCH_SIZE: 64
CS_KERNEL_SIZE: 21
EPOCHS: 500
GABOR_KERNEL_SIZE: 31
GDN_CLIPPING: true
LEARNING_RATE: 0.003
NORMALIZE_ENERGY: true
NORMALIZE_PROB: false
N_ORIENTATIONS: 8
N_SCALES: 4
SEED: 42
USE_BIAS: false
ZERO_MEAN: true

In [11]:
# Batch the (reference, distorted, MOS) triples. tf.data keeps element order by
# default; the results table below relies on that, since distances are attached
# to `dst.data` by position.
dst_rdy = dst.dataset.batch(batch_size=config.BATCH_SIZE,
                            num_parallel_calls=tf.data.AUTOTUNE)

## Define the model we're going to use

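> PerceptNet is a small (286-parameter) image-quality model inspired by the visual system: GDN normalizations, a color transformation, center-surround filtering and Gabor filtering. Here we load its trained weights and evaluate them on TID2013.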

In [12]:
def pad_same_from_kernel_size(inputs, kernel_size, mode):
    """Pad the spatial axes so that a following VALID convolution keeps the size.

    Args:
        inputs: 4-D array; only axes 1 and 2 are padded (assumed to be
            height and width of an NHWC batch).
        kernel_size: Spatial size of the square kernel applied afterwards.
        mode: Padding mode forwarded to ``jnp.pad`` (e.g. ``"symmetric"``).

    Returns:
        ``inputs`` padded by ``kernel_size - 1`` pixels in total along each of
        axes 1 and 2. For even kernel sizes the extra pixel goes after the data,
        the same split TensorFlow/XLA use for "SAME" padding.
    """
    before = (kernel_size - 1) // 2
    # Equal to `before` for odd kernels; one more for even kernels, so the total
    # is always kernel_size - 1 (the old symmetric split lost a pixel there).
    after = kernel_size // 2
    return jnp.pad(inputs,
                   [[0, 0],
                    [before, after],
                    [before, after],
                    [0, 0]],
                   mode=mode)

In [14]:
class PerceptNet(nn.Module):
    """IQA model inspired by the visual system.

    Cascade applied to the input batch: a per-channel GDN, then three stages
    (1x1 color conv, center-surround filtering, Gabor filtering), each followed
    by a GDN; the first two stages are also followed by a 2x2 max-pooling.
    Kernel sizes, number of Gabor scales/orientations and Gabor normalization
    flags are read from the module-level ``config`` at call time.

    NOTE(review): Flax names submodules from their creation order (GDN_0,
    GDN_1, ...), so reordering the layers below changes the parameter tree and
    breaks restoring existing checkpoints.
    NOTE(review): the evaluated W&B run name mentions "GDNSpatioFreqOrient" while
    the final GDN here is a plain 1x1 GDN - confirm this architecture matches
    the checkpoint.
    """

    @nn.compact
    def __call__(self,
                 inputs, # Assuming fs = 128 (cpd)
                 **kwargs, # forwarded to the center-surround and Gabor layers (e.g. train=False)
                 ):
        ## (Independent) Color equilibration (Gamma correction)
        ## Might need to be the same for each number
        ## bias = 0.1 / kernel = 0.5
        outputs = GDN(kernel_size=(1,1), apply_independently=True)(inputs)
        
        ## Color (ATD) Transformation
        outputs = nn.Conv(features=3, kernel_size=(1,1), use_bias=False)(outputs)
        outputs = nn.max_pool(outputs, window_shape=(2,2), strides=(2,2))
        
        ## GDN Star A - T - D [Separated]
        outputs = GDN(kernel_size=(1,1), apply_independently=True)(outputs)

        ## Center Surround (DoG)
        ## Initialized so that 3 are positives and 3 are negatives and no interaction between channels is present
        ## NOTE(review): only 3 features are created here, which does not match "3 positive and 3 negative" - confirm.
        ## Symmetric padding + VALID convolution keeps the spatial size.
        outputs = pad_same_from_kernel_size(outputs, kernel_size=config.CS_KERNEL_SIZE, mode="symmetric")
        ## NOTE(review): fs=21 here, while the input is assumed to be sampled at 128 cpd and
        ## has been max-pooled once (-> 64); confirm the intended sampling frequency.
        outputs = CenterSurroundLogSigmaK(features=3, kernel_size=config.CS_KERNEL_SIZE, fs=21, use_bias=False, padding="VALID")(outputs, **kwargs)
        outputs = nn.max_pool(outputs, window_shape=(2,2), strides=(2,2))

        ## GDN per channel with mean subtraction in T and D (Spatial Gaussian Kernel)
        ### fs = 32 / kernel_size = (11,11) -> 0.32 > 0.02 --> OK!
        ## TO-DO: - Spatial Gaussian Kernel (0.02 deg) -> fs = 64/2 & 0.02*64/2 = sigma (px) = 0.69
        ## NOTE(review): the GDN below uses a 1x1 kernel; the spatial-kernel notes above describe a planned variant.
        outputs = GDN(kernel_size=(1,1), apply_independently=True)(outputs)

        ## GaborLayer per channel with GDN mixing only same-origin-channel information
        ### [Gaussian] sigma = 0.2 (deg) fs = 32 / kernel_size = (21,21) -> 21/32 = 0.66 --> OK!
        ## NOTE(review): the kernel size actually comes from config.GABOR_KERNEL_SIZE (31 in the shown config), not 21.
        outputs = pad_same_from_kernel_size(outputs, kernel_size=config.GABOR_KERNEL_SIZE, mode="symmetric")
        ## xmean/ymean = kernel_size / fs / 2 put the Gabor center in the middle of the kernel (in degrees at fs=32).
        outputs = GaborLayerLogSigma_(n_scales=config.N_SCALES, n_orientations=config.N_ORIENTATIONS, kernel_size=config.GABOR_KERNEL_SIZE, fs=32, xmean=config.GABOR_KERNEL_SIZE/32/2, ymean=config.GABOR_KERNEL_SIZE/32/2, strides=1, padding="VALID", normalize_prob=config.NORMALIZE_PROB, normalize_energy=config.NORMALIZE_ENERGY, zero_mean=config.ZERO_MEAN, use_bias=config.USE_BIAS)(outputs, **kwargs)
        
        ## Final GDN mixing Gabor information (?)
        ## apply_independently=False: unlike the previous GDNs, channels are allowed to interact here.
        outputs = GDN(kernel_size=(1,1), apply_independently=False)(outputs)
        return outputs

## Define the metrics with `clu`

In [15]:
@struct.dataclass
class Metrics(metrics.Collection):
    """Collection of metrics to be tracked during training.

    Only a running average of the ``"loss"`` output is declared. The evaluation
    loop in this notebook does not update it; it exists so ``TrainState`` has
    the same structure as during training.
    """
    # Averages the "loss" value passed to the collection's update.
    loss: metrics.Average.from_output("loss")

By default, `TrainState` doesn't include metrics, so we subclass it to add a `metrics` collection and a `state` field holding the model's non-parameter variables:

In [16]:
class TrainState(train_state.TrainState):
    """Flax ``TrainState`` extended with a metrics collection and model state.

    NOTE(review): these fields define the pytree orbax restores into, so they
    must match the structure the checkpoint was saved with.
    """
    metrics: Metrics
    # Non-"params" variable collections from `module.init`; they are passed back
    # to `apply_fn` together with the params.
    state: FrozenDict

We'll define a function that initializes the `TrainState` from a module, a rng key and some optimizer:

In [17]:
def create_train_state(module, key, tx, input_shape):
    """Creates the initial `TrainState`.

    Initializes ``module`` on a dummy all-ones input and splits the resulting
    variables into the ``"params"`` collection (optimized by ``tx``) and all
    remaining collections (stored in ``TrainState.state``).

    Args:
        module: Flax module to initialize; its ``apply`` becomes ``apply_fn``.
        key: PRNG key used by ``module.init``.
        tx: Optax gradient transformation.
        input_shape: Shape of the dummy input, e.g. ``(1, 384, 512, 3)``.

    Returns:
        A ``TrainState`` with empty ``Metrics``.
    """
    variables = module.init(key, jnp.ones(input_shape))
    # Split explicitly instead of `variables.pop('params')`: FrozenDict.pop
    # returns (rest, popped), but a plain dict (what `init` returns when Flax is
    # configured not to use FrozenDicts) returns only the popped value, which
    # breaks the tuple unpacking. Both halves are frozen so their type is
    # FrozenDict either way, as before.
    params = freeze(variables["params"])
    state = freeze({name: collection
                    for name, collection in variables.items()
                    if name != "params"})
    return TrainState.create(
        apply_fn=module.apply,
        params=params,
        state=state,
        tx=tx,
        metrics=Metrics.empty()
    )

## Define evaluation step

In [18]:
@jax.jit
def compute_distance(*, state, batch):
    """Perceptual distance between the reference and distorted images of a batch.

    Both images go through the model with ``train=False``. The distance is the
    L2 norm of the response difference, summed over axes 1-3 (assumes 4-D
    model outputs, giving one value per image pair).

    Args:
        state: TrainState providing ``apply_fn``, ``params`` and the extra
            variable collections in ``state.state``.
        batch: ``(img, img_dist, mos)`` tuple; ``mos`` is not used here.

    Returns:
        The per-pair distances (no correlation is computed here).
    """
    reference, distorted, _ = batch
    variables = {"params": state.params, **state.state}

    def respond(images):
        return state.apply_fn(variables, images, train=False)

    squared_diff = (respond(reference) - respond(distorted)) ** 2
    return squared_diff.sum(axis=(1, 2, 3)) ** (1 / 2)

## Load the pretrained model!

In [19]:
# Provisional state, built only to obtain the parameter tree needed to label
# parameters as trainable/frozen; it is rebuilt later with the final optimizer.
state = create_train_state(module=PerceptNet(),
                           key=random.PRNGKey(config.SEED),
                           tx=optax.adam(config.LEARNING_RATE),
                           input_shape=(1, 384, 512, 3))

In [20]:
def check_trainable(path):
    """Return True for parameters that must be FROZEN.

    Despite the name, a True result labels the parameter "non_trainable" (its
    updates are zeroed). It currently never returns True, so every parameter
    is optimized.

    Args:
        path: Key path of the parameter inside the params tree (unused).
    """
    # An earlier experiment froze these parameters instead:
    # return ("A" in path) or ("alpha_achrom" in path) or ("alpha_chrom_rg" in path) or ("alpha_chrom_yb" in path)
    del path  # currently unused
    return False

In [21]:
def _trainability_label(path, _leaf):
    # A True from check_trainable means "freeze this parameter".
    return "non_trainable" if check_trainable(path) else "trainable"

# Same tree structure as the params, with a string label at every leaf.
trainable_tree = freeze(flax.traverse_util.path_aware_map(_trainability_label, state.params))

In [5]:
# Total number of parameters, and how many of them are labelled "trainable".
# trainable_tree mirrors state.params, so their leaves pair up positionally.
param_leaves = jax.tree_util.tree_leaves(state.params)
label_leaves = jax.tree_util.tree_leaves(trainable_tree)
param_count = sum(leaf.size for leaf in param_leaves)
trainable_param_count = sum(
    leaf.size for leaf, label in zip(param_leaves, label_leaves) if label == "trainable"
)
param_count, trainable_param_count

(286, 286)

In [26]:
# Record model size on the evaluation run's summary.
wandb.run.summary.update({
    "total_parameters": param_count,
    "trainable_parameters": trainable_param_count,
})

In [23]:
# One optimizer per label: Adam for "trainable" leaves, zero updates (frozen)
# for "non_trainable" ones.
optimizers = dict(
    trainable=optax.adam(learning_rate=config.LEARNING_RATE),
    non_trainable=optax.set_to_zero(),
)

In [24]:
# Route every parameter to the optimizer named by its label in trainable_tree.
tx = optax.multi_transform(transforms=optimizers, param_labels=trainable_tree)

In [25]:
# Final state: same seed and input shape, now with the label-aware optimizer.
state = create_train_state(module=PerceptNet(),
                           key=random.PRNGKey(config.SEED),
                           tx=tx,
                           input_shape=(1, 384, 512, 3))

In [27]:
# NOTE(review): `save_args` is only needed for saving checkpoints and is unused
# in this evaluation notebook (leftover from the training template).
save_args = orbax_utils.save_args_from_target(state)
# Checkpointer used below to restore the trained weights.
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()

In [30]:
# Load weights
# Restore the training run's "model-best" checkpoint into the structure of the
# freshly built `state`.
checkpoint_path = os.path.join(prev_run.dir, "model-best")
state = orbax_checkpointer.restore(checkpoint_path, item=state)

## Evaluate!

In [31]:
# Per-pair model distances and ground-truth MOS, accumulated in dataset order.
metrics_history = {key: [] for key in ("distance", "mos")}

In [41]:
%%time
for batch in tqdm(dst_rdy.as_numpy_iterator()):
    img, img_dist, mos = batch
    distance = compute_distance(state=state, batch=batch)
    metrics_history["distance"].extend(distance)
    metrics_history["mos"].extend(mos)
    # break

CPU times: user 34.6 s, sys: 22.6 s, total: 57.3 s
Wall time: 26.5 s



0it [00:00, ?it/s]
1it [00:12, 12.37s/it]
3it [00:12,  3.26s/it]
5it [00:12,  1.63s/it]
7it [00:12,  1.02it/s]
9it [00:13,  1.54it/s]
11it [00:13,  2.18it/s]
13it [00:13,  2.95it/s]
15it [00:13,  3.81it/s]
17it [00:13,  4.73it/s]
19it [00:14,  5.65it/s]
21it [00:14,  6.51it/s]
23it [00:14,  7.28it/s]
24it [00:14,  7.65it/s]
25it [00:14,  8.03it/s]
27it [00:14,  8.67it/s]
29it [00:15,  9.10it/s]
31it [00:15,  9.38it/s]
33it [00:15,  9.56it/s]
34it [00:15,  9.63it/s]
35it [00:15,  9.69it/s]
36it [00:15,  9.75it/s]
37it [00:15,  9.80it/s]
38it [00:16,  9.84it/s]
39it [00:16,  9.87it/s]
40it [00:16,  9.89it/s]
41it [00:16,  9.91it/s]
42it [00:16,  9.92it/s]
43it [00:16,  9.92it/s]
44it [00:16,  9.93it/s]
45it [00:16,  9.94it/s]
46it [00:16,  9.93it/s]
47it [00:26,  2.93s/it]
47it [00:26,  1.78it/s]


In [None]:
# The distances are attached to `dst.data` by position, so every pair must
# have been scored exactly once and in the table's row order.
n_rows = len(dst.data)
assert len(metrics_history["distance"]) == n_rows, (
    f"{len(metrics_history['distance'])} distances for {n_rows} rows")
assert len(metrics_history["mos"]) == n_rows, (
    f"{len(metrics_history['mos'])} MOS values for {n_rows} rows")
# NOTE(review): assumes `dst.dataset` yields the raw "MOS" column of `dst.data`;
# adapt this check if the dataset rescales MOS.
assert np.allclose(np.asarray(metrics_history["mos"], dtype=np.float64),
                   dst.data["MOS"].to_numpy(dtype=np.float64),
                   rtol=1e-5, atol=1e-5), "dataset iteration order does not match dst.data"

In [6]:
# Linear and rank correlation between model distance and MOS (a good model
# gives strongly negative values: larger distance means lower quality).
distances = np.asarray(metrics_history["distance"])
mos_values = np.asarray(metrics_history["mos"])
pearson_result = stats.pearsonr(distances, mos_values)
spearman_result = stats.spearmanr(distances, mos_values)
pearson_result, spearman_result

(PearsonRResult(statistic=-0.8709295802064617, pvalue=0.0),
 SignificanceResult(statistic=-0.8506642916805114, pvalue=0.0))

In [7]:
# Copy of the dataset table with each pair's model distance added as a column.
results = dst.data.assign(Distance=metrics_history["distance"])
results.head()

Unnamed: 0,Reference,Distorted,MOS,Reference_ID,Distortion_ID,Distortion_Intensity,Distance
0,I01.BMP,i01_01_1.bmp,5.51429,1,1,1,31.102566
1,I01.BMP,i01_01_2.bmp,5.56757,1,1,2,42.50249
2,I01.BMP,i01_01_3.bmp,4.94444,1,1,3,60.09311
3,I01.BMP,i01_01_4.bmp,4.37838,1,1,4,76.20087
4,I01.BMP,i01_01_5.bmp,3.86486,1,1,5,96.480576


In [None]:
# Log the per-pair table and both correlation coefficients to the evaluation run.
distance_list, mos_list = metrics_history["distance"], metrics_history["mos"]
wandb.log({
    "TID2013": wandb.Table(dataframe=results),
    "TID2013_pearson": stats.pearsonr(distance_list, mos_list)[0],
    "TID2013_spearman": stats.spearmanr(distance_list, mos_list)[0],
})

In [None]:
# Mark the W&B evaluation run as finished.
wandb.finish()