In [1]:
%load_ext autoreload
%autoreload 2

# IQA tracking params and variables

> When using parametric layers we have to be able to keep track of the parameters and the variables of the model (which are not going to be trained). We're going to play with this concept using our implementation of the functional layers.

In [2]:
# import os; os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"]=".99"

In [None]:
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "-1"

from typing import Any, Callable, Sequence, Union
import numpy as np

import tensorflow as tf
tf.config.set_visible_devices([], device_type='GPU')

import jax
from jax import lax, random, numpy as jnp
import flax
from flax.core import freeze, unfreeze, FrozenDict
from flax import linen as nn
from flax import struct
from flax.training import train_state
from flax.training import orbax_utils

import optax
import orbax.checkpoint

from clu import metrics
from ml_collections import ConfigDict

from einops import reduce, rearrange, repeat
import wandb
from iqadatasets.datasets import *
from fxlayers.layers import *
from fxlayers.layers import GaussianLayerGamma, FreqGaussianGamma, OrientGaussianGamma, GaborLayerGamma_, GaborLayerGammaRepeat
from fxlayers.initializers import *
from JaxPlayground.utils.constraints import *
from JaxPlayground.utils.wandb import *

## Load pre-trained model

In [4]:
import json

from huggingface_hub import hf_hub_download
import flax
import orbax.checkpoint
from ml_collections import ConfigDict

from paramperceptnet.models import PerceptNet as PerceptNetBioFitted
from paramperceptnet.models import Baseline
from paramperceptnet.training import create_train_state as create_train_state_teacher

In [5]:
# Download the configuration file of the pre-trained model from the
# Hugging Face Hub and wrap the parsed JSON in a ConfigDict. `config_t` is
# later used to build the teacher network.
config_path = hf_hub_download(
    repo_id="Jorgvt/ppnet-bio-fitted",
    filename="config.json",
)
with open(config_path) as f:
    config_t = ConfigDict(json.load(f))

In [6]:
from safetensors.flax import load_file

# Download the pre-trained weights from the Hugging Face Hub.
weights_path = hf_hub_download(
    repo_id="Jorgvt/ppnet-bio-fitted",
    filename="weights.safetensors",
)
# The file is loaded as a flat mapping; rebuild the nested tree by splitting
# the keys on ".".
variables = flax.traverse_util.unflatten_dict(load_file(weights_path), sep=".")
# Split the tree into its two top-level collections.
state, params = variables["state"], variables["params"]

In [7]:
# Build a train state for the teacher network and immediately replace its
# freshly initialised `params` / `state` with the downloaded ones.
# In the visible notebook the teacher is only used for forward passes, so
# the optimizer passed here is just what the helper needs to build the state.
state_t = create_train_state_teacher(
    PerceptNetBioFitted(config_t),
    random.PRNGKey(42),
    optax.adam(3e-4),
    input_shape=(1,384,512,3),
).replace(params=params, state=state)

In [8]:
# jax.config.update("jax_debug_nans", False)

## Load the data

> We're going to employ `iqadatasets` to ease the loading of the data.

In [9]:
# Training data: TID2008; validation data: TID2013. Both are built with
# `exclude_imgs=[25]`.
dst_train = TID2008(
    "/media/disk/vista/BBDD_video_image/Image_Quality//TID/TID2008/",
    exclude_imgs=[25],
)
dst_val = TID2013(
    "/media/disk/vista/BBDD_video_image/Image_Quality//TID/TID2013/",
    exclude_imgs=[25],
)

# Alternative dataset locations (kept for reference):
# dst_train = TID2008("/lustre/ific.uv.es/ml/uv075/Databases/IQA//TID/TID2008/", exclude_imgs=[25])
# dst_train = KADIK10K("/lustre/ific.uv.es/ml/uv075/Databases/IQA/KADIK10K/")
# dst_val = TID2013("/lustre/ific.uv.es/ml/uv075/Databases/IQA//TID/TID2013/", exclude_imgs=[25])
# dst_train = TID2008("/media/databases/IQA/TID/TID2008/", exclude_imgs=[25])
# dst_val = TID2013("/media/databases/IQA/TID/TID2013/", exclude_imgs=[25])

In [None]:
# Pull a single element from the training dataset and display the shapes of
# its three components (sanity check before batching).
img, img_dist, mos = next(iter(dst_train.dataset))
img.shape, img_dist.shape, mos.shape

In [None]:
# Same sanity check for the validation dataset. Note that this rebinds
# `img`, `img_dist` and `mos` from the previous cell.
img, img_dist, mos = next(iter(dst_val.dataset))
img.shape, img_dist.shape, mos.shape

In [None]:
# Hyper-parameters of the run, built directly as a ConfigDict.
#
# Keys read by the visible notebook code: BATCH_SIZE, EPOCHS, LEARNER_LR,
# SEED, CS_KERNEL_SIZE, GDNGAUSSIAN_KERNEL_SIZE, GDNSPATIOFREQ_KERNEL_SIZE,
# GABOR_KERNEL_SIZE and N_GABORS. The remaining keys are not referenced in
# the visible cells; they are still attached to the wandb run as
# configuration.
config = ConfigDict({
    "TEACHER_ID": "bio-fitted",
    "BATCH_SIZE": 64,
    "EPOCHS": 500,
    "LEARNER_LR": 3e-4,
    # "INITIAL_LR": 1e-2,
    # "PEAK_LR": 4e-2,
    # "END_LR": 5e-3,
    # "WARMUP_EPOCHS": 15,
    "SEED": 42,
    "GDN_CLIPPING": True,
    "NORMALIZE_PROB": False,
    "NORMALIZE_ENERGY": True,
    "ZERO_MEAN": True,
    "USE_BIAS": False,
    "CS_KERNEL_SIZE": 21,
    "GDNGAUSSIAN_KERNEL_SIZE": 11,
    "GDNSPATIOFREQ_KERNEL_SIZE": 11,
    "GABOR_KERNEL_SIZE": 31,
    "N_SCALES": 4,
    "N_ORIENTATIONS": 16,
    "N_GABORS": 128,
    "USE_GAMMA": True,
    "INIT_JH": True,
    "INIT_GABOR": True,
    "TRAIN_JH": False,
    "TRAIN_CS": False,
    "TRAIN_GABOR": False,
    "A_GABOR": False,
    "A_GDNSPATIOFREQORIENT": False,
    "TRAIN_ONLY_LAST_GDN": True,
})
config

In [13]:
# if config.TEACHER_ID is not None:
#     id = config.TEACHER_ID
#     api = wandb.Api()
#     prev_run = api.run(f"jorgvt/PerceptNet_v15/{id}")
#     config_t = ConfigDict(prev_run.config["_fields"])
#     for file in prev_run.files():
#         file.download(root=prev_run.dir, replace=True)

In [None]:
# Start the Weights & Biases run and attach the hyper-parameters to it.
wandb.init(project="ParametricKnowledgeDistillation",
           name="FreeToBioFitted",
           job_type="training",
           config=config,
           mode="online",
           )
# `config` keeps pointing at the local ConfigDict: the former
# `config = config` self-assignment was a no-op and has been dropped.
config

In [15]:
# Training pipeline: shuffle (re-shuffled on every pass, seeded) and batch.
# Incomplete final batches are dropped so every batch has the same size.
dst_train_rdy = (
    dst_train.dataset
    .shuffle(buffer_size=100, reshuffle_each_iteration=True, seed=config.SEED)
    .batch(config.BATCH_SIZE, drop_remainder=True)
)
# Validation pipeline: batched only, no shuffling.
dst_val_rdy = dst_val.dataset.batch(config.BATCH_SIZE, drop_remainder=True)

## Define the model we're going to use

> It's going to be a very simple model just for demonstration purposes.

In [16]:
class PerceptNetBaseline(nn.Module):
    """IQA model inspired by the visual system (the learner of this notebook).

    Layer stack, in call order:
    GDN -> 1x1 Conv (3 features) -> 2x2 max-pool -> GDN -> Conv (6 features)
    -> 2x2 max-pool -> GDN -> Conv (`N_GABORS` features) -> GDN.

    Kernel sizes and the number of features of the last convolution are read
    from the module-level `config` object, not from module attributes.
    """

    @nn.compact
    def __call__(self, inputs, **kwargs):
        # `kwargs` is accepted and ignored, so callers may pass flags such as
        # `train=False` without this module having to know about them.
        cs_kernel = (config.CS_KERNEL_SIZE, config.CS_KERNEL_SIZE)
        gabor_kernel = (config.GABOR_KERNEL_SIZE, config.GABOR_KERNEL_SIZE)

        # NOTE: sub-modules are created in exactly this order on purpose;
        # reordering them would change their automatically assigned names.
        x = GDN(kernel_size=1, strides=1, padding="SAME", apply_independently=True)(inputs)
        x = nn.Conv(features=3, kernel_size=(1,1), strides=1, padding="SAME")(x)
        x = nn.max_pool(x, window_shape=(2,2), strides=(2,2))
        x = GDN(kernel_size=1, strides=1, padding="SAME", apply_independently=False)(x)
        x = nn.Conv(features=6, kernel_size=cs_kernel, strides=1, padding="SAME")(x)
        x = nn.max_pool(x, window_shape=(2,2), strides=(2,2))
        x = GDN(kernel_size=config.GDNGAUSSIAN_KERNEL_SIZE, strides=1, padding="SAME", apply_independently=False)(x)
        x = nn.Conv(features=config.N_GABORS, kernel_size=gabor_kernel, strides=1, padding="SAME")(x)
        x = GDN(kernel_size=config.GDNSPATIOFREQ_KERNEL_SIZE, strides=1, padding="SAME", apply_independently=False)(x)
        return x

## Define the metrics with `clu`

In [17]:
@struct.dataclass
class Metrics(metrics.Collection):
    """Collection of metrics to be tracked during training.

    Both entries are averages fed from the keyword outputs passed to
    `single_from_model_output` in `train_step` / `compute_metrics`.
    """
    # Average of the `loss` output (the distillation loss).
    loss: metrics.Average.from_output("loss")
    # Average of the `correlation_l` output (learner's Pearson correlation).
    correlation_l: metrics.Average.from_output("correlation_l")

In [18]:
@struct.dataclass
class MetricsT(metrics.Collection):
    """Metric collection that tracks only the average of the `loss` output.

    NOTE(review): this class is not referenced anywhere else in the visible
    notebook; the `T` suffix presumably stands for "teacher" -- confirm, or
    remove the class if it is unused.
    """
    loss: metrics.Average.from_output("loss")

By default, `TrainState` doesn't include metrics, but it's very easy to subclass it so that it does:

In [19]:
class TrainState(train_state.TrainState):
    """Flax `TrainState` extended with metrics and non-trainable variables."""
    # Metric accumulators, merged after each batch and reset after each epoch.
    metrics: Metrics
    # Every variable collection returned by `module.init` except "params".
    state: FrozenDict

We'll define a function that initializes the `TrainState` from a module, a rng key and some optimizer:

In [20]:
def create_train_state(module, key, tx, input_shape):
    """Creates the initial `TrainState`.

    The module is initialised on an all-ones input of shape `input_shape`.
    The resulting variables are split into the trainable "params" collection
    and everything else, which is stored in `TrainState.state`.

    Args:
        module: module exposing `init(key, inputs)` and `apply`.
        key: PRNG key forwarded to `module.init`.
        tx: optimizer (gradient transformation) stored in the train state.
        input_shape: shape of the dummy input used for initialisation.

    Returns:
        A `TrainState` with empty metrics.
    """
    variables = module.init(key, jnp.ones(input_shape))
    if isinstance(variables, FrozenDict):
        # `FrozenDict.pop` is non-mutating and returns (remaining, popped).
        state, params = variables.pop('params')
    else:
        # Plain dict (returned by `init` in newer Flax releases): `dict.pop`
        # returns only the popped value, so split a shallow copy by hand.
        state = dict(variables)
        params = state.pop('params')
    return TrainState.create(
        apply_fn=module.apply,
        params=params,
        state=state,
        tx=tx,
        metrics=Metrics.empty()
    )

## Defining the training step

> We want to write a function that takes the `TrainState` and a batch of data and performs an optimization step.

In [21]:
def pearson_correlation(vec1, vec2):
    """Pearson correlation coefficient between two arrays.

    Both inputs are squeezed first, so e.g. shapes (N, 1) and (N,) can be
    mixed. There is no guard for zero-variance inputs: the denominator is
    then zero and the division result is returned as-is.
    """
    a, b = vec1.squeeze(), vec2.squeeze()
    a_centered = a - a.mean()
    b_centered = b - b.mean()
    covariance = (a_centered * b_centered).sum()
    norm_a = jnp.sqrt(jnp.sum(a_centered**2))
    norm_b = jnp.sqrt(jnp.sum(b_centered**2))
    return covariance / (norm_a * norm_b)

In [22]:
from functools import partial

In [40]:
@partial(jax.jit, static_argnums=3)
def train_step(state, state_t, batch, return_grads=False):
    """Run one distillation step of the learner against the teacher.

    The loss is the batch-mean Euclidean distance between teacher and learner
    outputs, summed over the two images of each pair. The Pearson correlation
    between the learner's own distances and `mos` is computed as an auxiliary
    value and is not differentiated.

    Args:
        state: learner train state (params, state, metrics, optimizer).
        state_t: teacher train state; only used for forward passes, called
            with `train=False`.
        batch: tuple unpacked as `(img, img_dist, mos)`.
        return_grads: static (jit) flag; when true the gradients are returned
            alongside the new state.

    Returns:
        The updated learner state, or `(state, grads)` if `return_grads`.
    """
    reference, distorted, mos = batch
    teacher_vars = {"params": state_t.params, **state_t.state}

    def l2_distance(a, b):
        # Per-sample Euclidean distance; reduces axes 1-3, so the model
        # outputs are expected to be 4-D.
        return ((a - b)**2).sum(axis=(1,2,3))**(1/2)

    def loss_fn(params):
        learner_vars = {"params": params, **state.state}

        ## Teacher outputs (do not depend on `params`)
        reference_t = state_t.apply_fn(teacher_vars, reference, train=False)
        distorted_t = state_t.apply_fn(teacher_vars, distorted, train=False)

        ## Learner outputs
        reference_l = state.apply_fn(learner_vars, reference)
        distorted_l = state.apply_fn(learner_vars, distorted)

        ## Distillation loss: the learner's outputs should match the teacher's
        loss = (l2_distance(reference_t, reference_l).mean()
                + l2_distance(distorted_t, distorted_l).mean())

        ## Learner's distance between the two images of each pair, used only
        ## for the correlation metric
        learner_dist = l2_distance(reference_l, distorted_l)
        return loss, pearson_correlation(learner_dist, mos)

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, correlation_l), grads = grad_fn(state.params)
    state = state.apply_gradients(grads=grads)

    batch_metrics = state.metrics.single_from_model_output(loss=loss, correlation_l=correlation_l)
    state = state.replace(metrics=state.metrics.merge(batch_metrics))
    return (state, grads) if return_grads else state

In their example, they don't calculate the metrics at the same time as the training step. I think this is somewhat wasteful because it requires an extra forward pass, but we'll follow their approach for now. Let's define a function to perform the metric calculation:

In [41]:
@jax.jit
def compute_metrics(*, state, state_t, batch):
    """Accumulate the metrics of one batch without updating any parameter.

    Computes the same loss and learner correlation as `train_step`, using the
    current learner parameters, and merges them into `state.metrics`.

    Args:
        state: learner train state.
        state_t: teacher train state; called with `train=False`.
        batch: tuple unpacked as `(img, img_dist, mos)`.

    Returns:
        `state` with the batch metrics merged in; params are unchanged.
    """
    reference, distorted, mos = batch
    teacher_vars = {"params": state_t.params, **state_t.state}
    learner_vars = {"params": state.params, **state.state}

    def l2_distance(a, b):
        # Per-sample Euclidean distance; reduces axes 1-3, so the model
        # outputs are expected to be 4-D.
        return ((a - b)**2).sum(axis=(1,2,3))**(1/2)

    ## Teacher outputs
    reference_t = state_t.apply_fn(teacher_vars, reference, train=False)
    distorted_t = state_t.apply_fn(teacher_vars, distorted, train=False)

    ## Learner outputs
    reference_l = state.apply_fn(learner_vars, reference)
    distorted_l = state.apply_fn(learner_vars, distorted)

    ## Distillation loss: distance between teacher and learner outputs
    loss = (l2_distance(reference_t, reference_l).mean()
            + l2_distance(distorted_t, distorted_l).mean())

    ## Correlation between the learner's own distances and `mos`
    correlation_l = pearson_correlation(l2_distance(reference_l, distorted_l), mos)

    batch_metrics = state.metrics.single_from_model_output(loss=loss, correlation_l=correlation_l)
    return state.replace(metrics=state.metrics.merge(batch_metrics))

## Train the model!

In [25]:
# Optimizer and initial train state of the learner.
# NOTE: `state` previously held the teacher's loaded variables; from here on
# it refers to the learner's TrainState (the teacher lives in `state_t`).
tx = optax.adam(learning_rate=config.LEARNER_LR)
state = create_train_state(
    PerceptNetBaseline(),
    random.PRNGKey(config.SEED),
    tx,
    input_shape=(1,384,512,3),
)
# Apply the same GDN constraint that the training loop enforces after every
# update, so the initial parameters already satisfy it.
state = state.replace(params=clip_layer(state.params, "GDN", a_min=0))

Before actually training the model we're going to set up the checkpointer to be able to save our trained models:

In [26]:
# Checkpointer used for every save below; `save_args` is derived once from
# the learner state and reused for all later saves.
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(state)

In [27]:
# Save the untrained learner ("model-0") inside the wandb run directory.
orbax_checkpointer.save(os.path.join(wandb.run.dir, "model-0"), state, save_args=save_args, force=True) # force=True means allow overwriting.

In [28]:
# Per-epoch history of every tracked metric, one list per split/metric pair.
metrics_history = dict(
    train_loss=[],
    train_correlation_l=[],
    val_loss=[],
    val_correlation_l=[],
)

In [None]:
# Grab one training batch (as NumPy arrays) for the smoke tests below.
batch = next(iter(dst_train_rdy.as_numpy_iterator()))

In [30]:
from functools import partial

In [31]:
@jax.jit
def forward(state, inputs):
    return state.apply_fn({"params": state.params, **state.state}, inputs, train=False)

In [32]:
@jax.jit
def forward_intermediates(state, inputs):
    return state.apply_fn({"params": state.params, **state.state}, inputs, train=False, capture_intermediates=True)

In [None]:
%%time
# Smoke test: run (and time) the learner on the first component of the batch.
outputs = forward(state, batch[0])
outputs.shape

In [None]:
%%time
# Smoke test of a full training step. The result goes to `s1`, so the real
# `state` is left untouched; `grads` can be inspected afterwards.
s1, grads = train_step(state, state_t, batch, return_grads=True)

In [44]:
# jax.config.update("jax_debug_nans", True)

In [45]:
def filter_extra(extra):
    """Trim the captured intermediates of layers whose path contains "Gabor".

    For every leaf of `extra["intermediates"]` whose "/"-joined path contains
    "Gabor", the value `x` is replaced by the 1-tuple `(x[0][0],)`; every
    other leaf is kept unchanged. The result is returned frozen.
    """
    def _first_output_of_gabor(path, value):
        return (value[0][0],) if "Gabor" in "/".join(path) else value

    mutable = unfreeze(extra)
    mutable["intermediates"] = flax.traverse_util.path_aware_map(
        _first_output_of_gabor, mutable["intermediates"]
    )
    return freeze(mutable)

In [None]:
%%time
step = 0  # batch counter; only the commented-out learning-rate logging below would read it
for epoch in range(config.EPOCHS):
    ## Training
    grads = None
    for batch in dst_train_rdy.as_numpy_iterator():
        state, grads = train_step(state, state_t, batch, return_grads=True)
        ## Re-apply the parameter constraints after every optimizer update
        state = state.replace(params=clip_layer(state.params, "GDN", a_min=0))
        state = state.replace(params=clip_param(state.params, "A", a_min=0))
        state = state.replace(params=clip_param(state.params, "K", a_min=1+1e-5))
        step += 1

    ## Log the gradients of the last training batch.
    ## This used to run on every batch with commit=False; each call overwrote
    ## the same uncommitted keys, so only the last batch was ever recorded.
    ## Doing it once per epoch records the same data without building
    ## histograms for every batch.
    if grads is not None:
        wandb.log({f"{k}_grad": wandb.Histogram(v) for k, v in flatten_params(grads).items()}, commit=False)

    ## Log the metrics
    for name, value in state.metrics.compute().items():
        metrics_history[f"train_{name}"].append(value)

    ## Empty the metrics
    state = state.replace(metrics=state.metrics.empty())

    ## Evaluation
    for batch in dst_val_rdy.as_numpy_iterator():
        state = compute_metrics(state=state, state_t=state_t, batch=batch)
    for name, value in state.metrics.compute().items():
        metrics_history[f"val_{name}"].append(value)
    state = state.replace(metrics=state.metrics.empty())

    ## Obtain activations of last validation batch
    ## (`batch` is whatever the evaluation loop left behind)
    _, extra = forward_intermediates(state, batch[0])
    extra = filter_extra(extra) ## Needed because the Gabor layer has multiple outputs

    ## Checkpointing
    ## "model-best" is overwritten whenever the latest validation correlation
    ## is the lowest seen so far (always true on the first epoch).
    ## NOTE(review): this treats a more negative correlation between the
    ## learner's distance and MOS as better -- confirm that this matches the
    ## MOS convention of the validation dataset.
    if metrics_history["val_correlation_l"][-1] <= min(metrics_history["val_correlation_l"]):
        orbax_checkpointer.save(os.path.join(wandb.run.dir, "model-best"), state, save_args=save_args, force=True) # force=True means allow overwriting.
    # orbax_checkpointer.save(os.path.join(wandb.run.dir, f"model-{epoch+1}"), state, save_args=save_args, force=False) # force=True means allow overwriting.

    ## Everything logged with commit=False is sent together with the final
    ## (committing) log call of the epoch.
    wandb.log({f"{k}": wandb.Histogram(v) for k, v in flatten_params(state.params).items()}, commit=False)
    wandb.log({f"{k}": wandb.Histogram(v) for k, v in flatten_params(extra["intermediates"]).items()}, commit=False)
    # wandb.log({"epoch": epoch+1, "learning_rate": schedule_lr(step), **{name:values[-1] for name, values in metrics_history.items()}})
    wandb.log({"epoch": epoch+1, **{name:values[-1] for name, values in metrics_history.items()}})
    print(f'Epoch {epoch} -> [Train] Loss: {metrics_history["train_loss"][-1]} [Val] Loss: {metrics_history["val_loss"][-1]}')

Save the final model as well, in case we want to resume training from it later:

In [38]:
# Save the state reached after the last epoch as "model-final" in the wandb
# run directory. Unlike the other saves, `force=True` is not passed here.
orbax_checkpointer.save(os.path.join(wandb.run.dir, "model-final"), state, save_args=save_args)

In [39]:
# Close the Weights & Biases run started earlier with `wandb.init`.
wandb.finish()