In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# import os; os.environ["CUDA_VISIBLE_DEVICES"]="-1"

In [None]:
# from jax.config import config
# config.update("jax_debug_nans", True)

In [None]:
import tensorflow as tf
# Make TensorFlow see no GPU devices. In this notebook TensorFlow is only used to
# build the tf.data input pipelines.
# NOTE(review): presumably done so TensorFlow does not grab GPU memory that JAX needs — confirm.
tf.config.set_visible_devices([], device_type='GPU')

In [None]:
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from typing import Any, Callable, Sequence, Union
import numpy as np
from fastcore.xtras import Path
from fastprogress.fastprogress import master_bar, progress_bar
import pandas as pd
import cv2

import jax
from jax import lax, random, numpy as jnp
from flax.core import freeze, unfreeze, FrozenDict
from flax import linen as nn
from flax import struct
from flax.training import train_state
from flax.training import orbax_utils

import optax
import orbax.checkpoint

from clu import metrics
from ml_collections import ConfigDict

from einops import reduce, rearrange
import wandb
from iqadatasets.datasets import *
from fxlayers.layers import *
from fxlayers.initializers import mean
from JaxPlayground.utils.constraints import *
from JaxPlayground.utils.wandb import *

# Wandb config

In [None]:
# Hyperparameters and run options. The dict is handed to `wandb.init` and then
# replaced by `wandb.config`, so later cells read the values as attributes.
config = {
        'epochs':500,  # number of passes of the training loop
        'learning_rate':3e-4,  # learning rate passed to optax.adam
        'batch_size':64,  # batch size used by the dataset loaders and the TID2013 pipeline
        'kernel_initializer':'ones',  # NOTE(review): not read by any visible cell
        'gdn_kernel_size':1,  # NOTE(review): not read by any visible cell
        'learnable_undersampling':False,  # NOTE(review): not read by any visible cell
        'verbose': 0,  # NOTE(review): not read by any visible cell
        'dataset': 'imagenette', # one of: imagenet / imagenette / cifar10 / cifar100
        'validation_split': 0.2,  # fraction of the training data held out for validation
        'seed': 42,  # seeds the dataset splits and the model initialisation key
        'GAP': False,  # True: the classifier averages over h, w; False: it flattens h, w, c
        'use_bias': True,  # NOTE(review): not read by any visible cell
        "dropout_rate": 0.0,  # dropout is only applied by the classifier when this is > 0
        "l1": False,  # enables the L1 penalty on the classifier's dense kernel in train_step
        "LAMBDA": 0.0005,  # weight of that L1 penalty
    }

In [None]:
# Start the Weights & Biases run for this experiment, registering the
# hyperparameter dict defined above as the run configuration.
wandb.init(project='PerceptNetClassification_JaX',
            notes="",
            tags=[],
            name = 'NoVisionModel',
            config=config,
            job_type="training",
            mode="online",
            )
# From here on `config` is the run's `wandb.config`, so the rest of the notebook
# uses attribute access (config.epochs, config.batch_size, ...).
config = wandb.config

wandb: Currently logged in as: jorgvt. Use `wandb login --relogin` to force relogin
wandb: wandb version 0.16.3 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade
wandb: Tracking run with wandb version 0.16.0
wandb: Run data is saved locally in /home/jorge/perceptnet/Notebooks/11_Classification/wandb/run-20240214_190307-3g72mp5c
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run NoVisionModel
wandb:  View project at https://wandb.ai/jorgvt/PerceptNetClassification_JaX
wandb:  View run at https://wandb.ai/jorgvt/PerceptNetClassification_JaX/runs/3g72mp5c


# Load the data

In [None]:
def load_imagenet():
    """Build the ImageNet training and validation datasets from an image folder.

    Both datasets are created with `tf.keras.utils.image_dataset_from_directory`
    over the same directory. `config.validation_split` and `config.seed` define
    the two subsets and `config.batch_size` the batch size. Only the training
    subset is shuffled.

    Returns:
        Tuple `(dst_train, dst_val)`.
    """
    root = Path("/lustre/ific.uv.es/ml/uv075/Databases/imagenet_images/")
    # Arguments common to both subsets; only `subset` and `shuffle` differ.
    shared = dict(
        validation_split=config.validation_split,
        seed=config.seed,
        batch_size=config.batch_size,
    )
    train_ds = tf.keras.utils.image_dataset_from_directory(
        root, subset="training", shuffle=True, **shared)
    val_ds = tf.keras.utils.image_dataset_from_directory(
        root, subset="validation", shuffle=False, **shared)
    return train_ds, val_ds

In [None]:
def load_imagenette():
    """Load Imagenette (320px-v2) through TFDS, split into train and validation.

    The TFDS "train" split is cut at `(1 - config.validation_split) * 100`
    percent: the first part becomes the training set, the rest the validation
    set. Every image is centrally cropped or zero-padded to 256x256 and both
    sets are batched with `config.batch_size`.

    Returns:
        Tuple `(dst_train, dst_val, num_classes)`.
    """
    import tensorflow_datasets as tfds

    # Percentage boundary between the training and validation parts, e.g. "80".
    cut = f"{(1-config.validation_split)*100:.0f}"
    train_ds, info = tfds.load("imagenette/320px-v2", split=f"train[:{cut}%]", with_info=True, shuffle_files=True)
    val_ds = tfds.load("imagenette/320px-v2", split=f"train[{cut}%:]", with_info=False, shuffle_files=False)

    def to_image_label(sample):
        """Turn a TFDS feature dict into an `(image, label)` pair of fixed image size."""
        image = tf.image.resize_with_crop_or_pad(sample["image"], 256, 256)
        return image, sample["label"]

    n_classes = info.features["label"].num_classes
    train_batches = train_ds.map(to_image_label).batch(config.batch_size)
    val_batches = val_ds.map(to_image_label).batch(config.batch_size)
    return train_batches, val_batches, n_classes

In [None]:
def load_cifar10():
    """Load CIFAR-10 and split its training data into train and validation sets.

    The dataset's own test split is loaded but not used. The training data is
    divided with `train_test_split` using `config.validation_split` and
    `config.seed`, and both parts are batched with `config.batch_size`.

    Returns:
        Tuple `(dst_train, dst_val)` of batched `tf.data.Dataset` objects.
    """
    from tensorflow.keras.datasets import cifar10
    from sklearn.model_selection import train_test_split

    (images, labels), _ = cifar10.load_data()
    train_images, val_images, train_labels, val_labels = train_test_split(
        images, labels, test_size=config.validation_split, random_state=config.seed
    )

    train_ds = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
    val_ds = tf.data.Dataset.from_tensor_slices((val_images, val_labels))
    return train_ds.batch(config.batch_size), val_ds.batch(config.batch_size)

In [None]:
def load_cifar100():
    """Load CIFAR-100 and split its training data into train and validation sets.

    The dataset's own test split is loaded but not used. The training data is
    divided with `train_test_split` using `config.validation_split` and
    `config.seed`, and both parts are batched with `config.batch_size`.

    Returns:
        Tuple `(dst_train, dst_val)` of batched `tf.data.Dataset` objects.
    """
    from tensorflow.keras.datasets import cifar100
    from sklearn.model_selection import train_test_split

    (images, labels), _ = cifar100.load_data()
    train_images, val_images, train_labels, val_labels = train_test_split(
        images, labels, test_size=config.validation_split, random_state=config.seed
    )

    train_ds = tf.data.Dataset.from_tensor_slices((train_images, train_labels))
    val_ds = tf.data.Dataset.from_tensor_slices((val_images, val_labels))
    return train_ds.batch(config.batch_size), val_ds.batch(config.batch_size)

In [None]:
# Load the dataset named by `config.dataset`, cast its images to float32 and
# divide them by 255, and record the number of classes in N_CLASSES.
if config.dataset == "imagenet":
    dst_train, dst_val = load_imagenet()
    # `class_names` is only attached to the dataset returned by
    # `image_dataset_from_directory`; the dataset produced by `.map` does not
    # carry it, so it has to be read before mapping.
    N_CLASSES = len(dst_train.class_names)
    dst_train = dst_train.map(lambda x,y: (tf.cast(x, tf.float32)/255.0, y))
    dst_val = dst_val.map(lambda x,y: (tf.cast(x, tf.float32)/255.0, y))
elif config.dataset == "cifar10":
    dst_train, dst_val = load_cifar10()
    # `y[:,0]` keeps the first column of the label batch.
    # NOTE(review): presumably the labels have shape (batch, 1) — confirm.
    dst_train = dst_train.map(lambda x,y: (tf.cast(x, tf.float32)/255.0, y[:,0]))
    dst_val = dst_val.map(lambda x,y: (tf.cast(x, tf.float32)/255.0, y[:,0]))
    N_CLASSES = 10
elif config.dataset == "cifar100":
    dst_train, dst_val = load_cifar100()
    dst_train = dst_train.map(lambda x,y: (tf.cast(x, tf.float32)/255.0, y[:,0]))
    dst_val = dst_val.map(lambda x,y: (tf.cast(x, tf.float32)/255.0, y[:,0]))
    N_CLASSES = 100
elif config.dataset == "imagenette":
    # This loader also returns the number of classes.
    dst_train, dst_val, N_CLASSES = load_imagenette()
    dst_train = dst_train.map(lambda x,y: (tf.cast(x, tf.float32)/255.0, y))
    dst_val = dst_val.map(lambda x,y: (tf.cast(x, tf.float32)/255.0, y))
else:
    raise ValueError("Dataset parameter not allowed.")
print(f"Training on {config.dataset} with {N_CLASSES} classes.")

Training on imagenette with 10 classes.


In [1]:
# Take the first training batch to obtain the per-image input shape and to
# check the shape of the labels. `x` is used again further down to build the
# dummy input shape for model initialisation.
x, y = next(iter(dst_train.as_numpy_iterator()))
input_shape = x[0].shape
input_shape, y.shape

((256, 256, 3), (64,))

In [None]:
# Store the dataset facts in the run summary so they show up next to the results.
wandb.run.summary["N_CLASSES"] = N_CLASSES
wandb.run.summary["Input_Shape"] = tuple(input_shape)

In [None]:
# TID2013 dataset, batched with `config.batch_size`. It is iterated after every
# epoch to correlate the backbone's distances with the third element of each batch.
# The commented-out lines are the same pipeline pointed at a different data path.
# dst_tid2013 = TID2013("/lustre/ific.uv.es/ml/uv075/Databases/IQA/TID/TID2013").dataset\
#                                                                               .batch(config.batch_size)\
#                                                                               .prefetch(1)
dst_tid2013 = TID2013("/media/databases/IQA/TID/TID2013").dataset\
                                                         .batch(config.batch_size)\
                                                         .prefetch(1)

### Performance

In [None]:
# NOTE(review): AUTOTUNE is defined here but not used; both pipelines below
# prefetch with a fixed buffer of 1.
AUTOTUNE = tf.data.AUTOTUNE

# Cache the already-batched, already-normalised datasets and prefetch one batch.
dst_train_rdy = dst_train.cache().prefetch(buffer_size=1)
dst_val_rdy = dst_val.cache().prefetch(buffer_size=1)

# Define the model

In [None]:
# class Identity(nn.Module):

#     @nn.compact
#     def __call__(self,
#                  inputs,
#                  **kwargs,
#                  ):
#         return inputs

In [None]:
class PerceptNet(nn.Module):
    """IQA model inspired by the visual system.

    In this notebook the backbone is an identity mapping: it has no layers,
    creates no parameters and returns its input unchanged.
    """

    @nn.compact
    def __call__(self, inputs, **kwargs):
        """Return `inputs` as is; keyword arguments such as `train` are accepted and ignored."""
        return inputs

In [None]:
class Classifier(nn.Module):
    """Classification head: pool or flatten, optional dropout, then one dense layer."""
    # Number of output units of the dense layer.
    N_CLASSES: int
    # True: average over the h and w axes; False: flatten h, w and c together.
    GAP: bool = False
    # Dropout rate; the dropout layer is only created when this is > 0.
    dropout_rate: float = 0.5

    @nn.compact
    def __call__(self, inputs, train=False):
        """Map a `b h w c` input to `b N_CLASSES` outputs.

        Args:
            inputs: array with axes `b h w c`.
            train: when True dropout is active, otherwise it is deterministic.
        """
        if self.GAP:
            features = reduce(inputs, "b h w c -> b c", reduction="mean")
        else:
            features = rearrange(inputs, "b h w c -> b (h w c)")

        if self.dropout_rate > 0.0:
            features = nn.Dropout(rate=self.dropout_rate, deterministic=not train)(features)

        return nn.Dense(self.N_CLASSES)(features)

In [None]:
class PerceptNetClassifier(nn.Module):
    """Classifier with a PerceptNet backbone.

    The head is configured from the notebook-level `N_CLASSES` and `config`
    (`GAP`, `dropout_rate`) when the module is set up.
    """

    def setup(self):
        """Create the backbone and the classification head."""
        self.perceptnet = PerceptNet()
        self.cls = Classifier(
            N_CLASSES=N_CLASSES,
            GAP=config.GAP,
            dropout_rate=config.dropout_rate,
        )

    def __call__(self, inputs, train=False):
        """Run the backbone and then the head, forwarding `train` to both."""
        features = self.perceptnet(inputs, train=train)
        return self.cls(features, train=train)

In [None]:
@struct.dataclass
class Metrics(metrics.Collection):
    """Collection of metrics tracked during training and validation.

    Both steps feed it through `single_from_model_output(loss=..., logits=..., labels=...)`.
    """
    # Classification accuracy.
    accuracy: metrics.Accuracy
    # Average of the values passed under the "loss" keyword.
    loss: metrics.Average.from_output("loss")

In [None]:
class TrainState(train_state.TrainState):
    """Flax `TrainState` extended with the extra fields used by this notebook."""
    # Metric collection accumulated over the steps and emptied after each phase of an epoch.
    metrics: Metrics
    # What remains of the initialised variables once "params" has been popped out.
    state: FrozenDict
    # Base PRNG key for dropout; `train_step` folds the step counter into it.
    key: jax.Array

In [None]:
def create_train_state(module, key, tx, input_shape):
    """Creates the initial `TrainState`.

    Args:
        module: Flax module to initialise; its `apply` becomes `apply_fn`.
        key: PRNG key used to initialise the variables and to derive the
            dropout key stored in the state.
        tx: optax gradient transformation.
        input_shape: shape of the all-ones dummy input (batch axis included)
            used to initialise the module.

    Returns:
        A `TrainState` holding the parameters, the remaining variable
        collections, the dropout key and an empty metric collection.
    """
    variables = module.init(key, jnp.ones(input_shape), train=False)
    # Derive the dropout key from the caller's key rather than a hard-coded
    # PRNGKey(42), so the configured seed also governs dropout. For a key built
    # from seed 42 this yields exactly the same dropout key as before.
    _, dropout_key = random.split(key)
    state, params = variables.pop('params')

    # Force an empty "perceptnet" entry: `forward_pass` indexes
    # `params["perceptnet"]`, and any existing entry is overwritten here.
    params = unfreeze(params)
    params["perceptnet"] = {}
    params = freeze(params)

    return TrainState.create(
        apply_fn=module.apply,
        params=params,
        state=state,
        key=dropout_key,
        tx=tx,
        metrics=Metrics.empty()
    )

In [None]:
# Build the train state: model initialised from `config.seed`, Adam optimiser,
# and a dummy input with batch size 1 and the image shape of the peeked batch.
state = create_train_state(PerceptNetClassifier(), random.PRNGKey(config.seed), optax.adam(config.learning_rate), input_shape=(1,*(x.shape[1:])))
# NOTE(review): `clip_layer` comes from JaxPlayground; presumably it clips the
# parameters of layers named "GDN" to be >= 0. The model defined in this
# notebook has no such layer — confirm this is a no-op here.
state = state.replace(params=clip_layer(state.params, "GDN", a_min=0))

In [None]:
# params = unfreeze(state.params)
# params["perceptnet"] = {}
# state = state.replace(params=freeze(params))
# state.params

Log the number of trainable weights:

In [2]:
# Total number of scalar entries over all leaves of the parameter tree.
param_count = sum(x.size for x in jax.tree_util.tree_leaves(state.params))
param_count

1966090

In [None]:
# Record the parameter count in the run summary.
wandb.run.summary["trainable_parameters"] = param_count

In [None]:
# Checkpointing helpers: the orbax checkpointer and the save arguments derived
# from the train state; both are reused for the best and the final checkpoint.
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(state)

## Train the model!

In [None]:
@jax.jit
def train_step(state, batch):
    """Train for a single step.

    Args:
        state: current `TrainState`.
        batch: tuple `(img, label)` with the images and their integer labels.

    Returns:
        The `TrainState` after one optimiser update, with this batch's loss
        and accuracy merged into `state.metrics`.
    """
    # A different dropout key per step, derived from the stored base key.
    dropout_train_key = random.fold_in(key=state.key, data=state.step)
    img, label = batch

    def loss_fn(params):
        """Return `(loss, logits)` for `params` on the current batch."""
        ## Forward pass through the model
        logits = state.apply_fn({"params": params, **state.state}, img, train=True, rngs={"dropout": dropout_train_key})

        ## Calculate crossentropy
        loss = optax.softmax_cross_entropy_with_integer_labels(logits, label).mean()

        ## Add L1 regularization
        # The penalty has to be computed from `params`, the argument being
        # differentiated. Reading `state.params` here would add the penalty to
        # the loss value while contributing no gradient at all.
        # `config.l1` is a Python value, so this branch is resolved at trace time.
        if config.l1:
            loss += config.LAMBDA*jnp.abs(params["cls"]["Dense_0"]["kernel"]).mean()

        return loss, logits

    (loss, logits), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    state = state.apply_gradients(grads=grads)
    metrics_updates = state.metrics.single_from_model_output(loss=loss, logits=logits, labels=jnp.round(label).astype(int))
    metrics = state.metrics.merge(metrics_updates)
    state = state.replace(metrics=metrics)
    return state

In [None]:
@jax.jit
def val_step(state, batch):
    """Evaluate a single batch without updating the parameters.

    Args:
        state: current `TrainState`.
        batch: tuple `(img, label)` with the images and their integer labels.

    Returns:
        The same `TrainState` with this batch's loss and accuracy merged into
        `state.metrics`.
    """
    images, labels = batch

    ## Forward pass through the model (dropout disabled)
    variables = {"params": state.params, **state.state}
    logits = state.apply_fn(variables, images, train=False)

    ## Calculate crossentropy
    loss = optax.softmax_cross_entropy_with_integer_labels(logits, labels).mean()

    batch_metrics = state.metrics.single_from_model_output(
        loss=loss, logits=logits, labels=jnp.round(labels).astype(int)
    )
    return state.replace(metrics=state.metrics.merge(batch_metrics))

In [None]:
def forward_pass(state, img):
    """Apply only the PerceptNet backbone to `img`, using the "perceptnet" parameters stored in `state`."""
    backbone_variables = {"params": state.params["perceptnet"]}
    return PerceptNet().apply(backbone_variables, img)

In [None]:
def rmse(a, b):
    """Per-sample distance between two 4-axis batches.

    Computes the square root of the squared differences summed (not averaged)
    over axes 1, 2 and 3, which leaves one value per entry of axis 0.
    """
    difference = a - b
    summed_squares = jnp.sum(difference**2, axis=(1, 2, 3))
    return jnp.sqrt(summed_squares)

In [None]:
@jax.jit
def obtain_distances(state, batch):
    """Distance between the backbone outputs of the first two elements of `batch`.

    Args:
        state: `TrainState` providing the backbone parameters.
        batch: 3-tuple; the first two elements are passed through the backbone,
            the third one is not used here.

    Returns:
        One distance per sample, as computed by `rmse`.
    """
    reference, distorted, _ = batch
    reference_out = forward_pass(state, reference)
    distorted_out = forward_pass(state, distorted)
    return rmse(reference_out, distorted_out)

In [None]:
import scipy.stats as stats

In [None]:
def obtain_correlation(state, dst):
    """Pearson correlation between the model distances and the third element of each batch.

    Args:
        state: `TrainState` handed to `obtain_distances`.
        dst: iterable of 3-element batches.

    Returns:
        The correlation coefficient (first element of `stats.pearsonr`'s result).
    """
    predicted, targets = [], []
    for triplet in dst:
        predicted.extend(obtain_distances(state, triplet))
        targets.extend(triplet[2])
    pearson = stats.pearsonr(predicted, targets)
    return pearson[0]

In [None]:
# Per-epoch history of every tracked quantity; the training loop appends one
# value per epoch to each list.
metrics_history = {
    name: []
    for name in ("train_loss", "train_accuracy", "val_loss", "val_accuracy", "correlation")
}

In [None]:
%%time
# One iteration per epoch: train, validate, evaluate the TID2013 correlation,
# checkpoint the best model so far and log everything to wandb.
for epoch in range(config.epochs):
    ## Training
    for batch in dst_train_rdy.as_numpy_iterator():
        new_state = train_step(state, batch)
        new_state = new_state.replace(params=clip_layer(new_state.params, "GDN", a_min=0))
        # Mean squared change of every parameter leaf produced by this step.
        # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
        params_diff = jax.tree_util.tree_map(lambda old, new: jnp.mean((old-new)**2), state.params, new_state.params)
        state = new_state
        wandb.log(unfreeze(params_diff), commit=False)

    ## Log the metrics
    for name, value in state.metrics.compute().items():
        metrics_history[f"train_{name}"].append(value)

    ## Empty the metrics
    state = state.replace(metrics=state.metrics.empty())

    ## Evaluation (Classification)
    for batch in dst_val_rdy.as_numpy_iterator():
        state = val_step(state=state, batch=batch)
    for name, value in state.metrics.compute().items():
        metrics_history[f"val_{name}"].append(value)
    state = state.replace(metrics=state.metrics.empty())

    ## Evaluation (Correlation)
    correlation = obtain_correlation(state, dst_tid2013.as_numpy_iterator())
    metrics_history["correlation"].append(correlation)

    ## Checkpointing
    # Save whenever the latest validation loss is the lowest seen so far.
    if metrics_history["val_loss"][-1] <= min(metrics_history["val_loss"]):
        orbax_checkpointer.save(os.path.join(wandb.run.dir, "model-best"), state, save_args=save_args, force=True) # force=True allows overwriting.

    wandb.log({f"{k}": wandb.Histogram(v) for k, v in flatten_params(state.params).items()}, commit=False)
    wandb.log({"epoch": epoch+1, **{name:values[-1] for name, values in metrics_history.items()}})
    print(f'Epoch {epoch} -> [Train] Loss: {metrics_history["train_loss"][-1]:.3f} Acc: {metrics_history["train_accuracy"][-1]:.3f} [Val] Loss: {metrics_history["val_loss"][-1]:.3f} Acc: {metrics_history["val_accuracy"][-1]:.3f} || Corr: {metrics_history["correlation"][-1]:.3f}')

Epoch 0 -> [Train] Loss: 7.648 Acc: 0.199 [Val] Loss: 4.222 Acc: 0.243 || Corr: -0.598
Epoch 1 -> [Train] Loss: 4.365 Acc: 0.253 [Val] Loss: 5.298 Acc: 0.230 || Corr: -0.598
Epoch 2 -> [Train] Loss: 3.393 Acc: 0.312 [Val] Loss: 4.376 Acc: 0.197 || Corr: -0.598
Epoch 3 -> [Train] Loss: 4.685 Acc: 0.279 [Val] Loss: 3.983 Acc: 0.233 || Corr: -0.598
Epoch 4 -> [Train] Loss: 3.418 Acc: 0.329 [Val] Loss: 5.208 Acc: 0.216 || Corr: -0.598
Epoch 5 -> [Train] Loss: 3.936 Acc: 0.329 [Val] Loss: 4.708 Acc: 0.260 || Corr: -0.598
Epoch 6 -> [Train] Loss: 3.061 Acc: 0.388 [Val] Loss: 3.417 Acc: 0.276 || Corr: -0.598
Epoch 7 -> [Train] Loss: 2.554 Acc: 0.422 [Val] Loss: 3.491 Acc: 0.276 || Corr: -0.598
Epoch 8 -> [Train] Loss: 2.820 Acc: 0.421 [Val] Loss: 3.204 Acc: 0.270 || Corr: -0.598
Epoch 9 -> [Train] Loss: 2.545 Acc: 0.444 [Val] Loss: 3.128 Acc: 0.294 || Corr: -0.598
Epoch 10 -> [Train] Loss: 2.670 Acc: 0.440 [Val] Loss: 3.971 Acc: 0.282 || Corr: -0.598
Epoch 11 -> [Train] Loss: 4.194 Acc: 0.390

In [None]:
# Save the state reached after the last epoch next to the best checkpoint, inside the wandb run directory.
orbax_checkpointer.save(os.path.join(wandb.run.dir, "model-final"), state, save_args=save_args)

In [None]:
# Close the wandb run started at the top of the notebook.
wandb.finish()