In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# import os; os.environ["CUDA_VISIBLE_DEVICES"]="-1"

In [3]:
# from jax.config import config
# config.update("jax_debug_nans", True)

In [4]:
import tensorflow as tf
tf.config.set_visible_devices([], device_type='GPU')

2023-09-21 13:05:58.557677: I tensorflow/core/platform/cpu_feature_guard.cc:182] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [5]:
import os
# os.environ["CUDA_VISIBLE_DEVICES"] = "1"
from typing import Any, Callable, Sequence, Union
import numpy as np
from fastcore.xtras import Path
from fastprogress.fastprogress import master_bar, progress_bar
import pandas as pd
import cv2

import jax
from jax import lax, random, numpy as jnp
from flax.core import freeze, unfreeze, FrozenDict
from flax import linen as nn
from flax import struct
from flax.training import train_state
from flax.training import orbax_utils

import optax
import orbax.checkpoint

from clu import metrics
from ml_collections import ConfigDict

from einops import reduce, rearrange
import wandb
from iqadatasets.datasets import *
from fxlayers.layers import *
from fxlayers.initializers import mean
from JaxPlayground.utils.constraints import *
from JaxPlayground.utils.wandb import *

2023-09-21 13:07:05.463732: W external/xla/xla/service/platform_util.cc:198] unable to create StreamExecutor for CUDA:1: failed initializing StreamExecutor for CUDA device ordinal 1: INTERNAL: failed call to cuDevicePrimaryCtxRetain: CUDA_ERROR_DEVICE_UNAVAILABLE: CUDA-capable device(s) is/are busy or unavailable
2023-09-21 13:07:05.463819: W external/xla/xla/service/platform_util.cc:198] unable to create StreamExecutor for CUDA:0: failed initializing StreamExecutor for CUDA device ordinal 0: INTERNAL: failed call to cuDevicePrimaryCtxRetain: CUDA_ERROR_DEVICE_UNAVAILABLE: CUDA-capable device(s) is/are busy or unavailable
No GPU/TPU found, falling back to CPU. (Set TF_CPP_MIN_LOG_LEVEL=0 and rerun for more info.)


# Wandb config

In [6]:
config = {
        'epochs':500,
        'learning_rate':3e-4,
        'batch_size':64,
        'kernel_initializer':'ones',
        'gdn_kernel_size':1,
        'learnable_undersampling':False,
        'verbose': 0,
        'dataset': 'imagenet', # imagenet / imagenette / cifar10 / cifar100,
        'validation_split': 0.2,
        'seed': 42,
        'GAP': False,
        'use_bias': True,
    }

In [7]:
wandb.init(project='PerceptNetClassification_JaX',
            notes="",
            tags=[],
            name = 'Baseline-Flatten-NoBias',
            config=config,
            job_type="training",
            mode="online",
            )
config = wandb.config

wandb: Currently logged in as: jorgvt. Use `wandb login --relogin` to force relogin
wandb: wandb version 0.15.10 is available!  To upgrade, please run:
wandb:  $ pip install wandb --upgrade
wandb: Tracking run with wandb version 0.15.1
wandb: Run data is saved locally in /lhome/ext/uv075/uv0752/perceptnet/Notebooks/11_Classification/wandb/run-20230921_132217-tu6scbey
wandb: Run `wandb offline` to turn off syncing.
wandb: Syncing run Baseline-Flatten-NoBias
wandb:  View project at https://wandb.ai/jorgvt/PerceptNetClassification_JaX
wandb:  View run at https://wandb.ai/jorgvt/PerceptNetClassification_JaX/runs/tu6scbey


# Load the data

In [8]:
def load_imagenet():
    path_data = Path("/lustre/ific.uv.es/ml/uv075/Databases/imagenet_images/")
    dst_train = tf.keras.utils.image_dataset_from_directory(
                path_data,
                validation_split=config.validation_split,
                subset="training",
                seed=config.seed,
                shuffle=True,
                # image_size=(img_height, img_width),
                batch_size=config.batch_size)
    dst_val = tf.keras.utils.image_dataset_from_directory(
                path_data,
                validation_split=config.validation_split,
                subset="validation",
                seed=config.seed,
                shuffle=False,
                # image_size=(img_height, img_width),
                batch_size=config.batch_size)
    return dst_train, dst_val

In [9]:
def load_imagenette():
    import tensorflow_datasets as tfds

    dst_train, info = tfds.load("imagenette/320px-v2", split=f"train[:{(1-config.validation_split)*100:.0f}%]", with_info=True, shuffle_files=True)
    dst_val = tfds.load("imagenette/320px-v2", split=f"train[{(1-config.validation_split)*100:.0f}%:]", with_info=False, shuffle_files=False)
    def prepare_tfds(item):
        x, y = item["image"], item["label"]
        x = tf.image.resize_with_crop_or_pad(x, 256, 256)
        return x, y
    dst_train = dst_train.map(prepare_tfds)
    dst_val = dst_val.map(prepare_tfds)

    return dst_train.batch(config.batch_size), dst_val.batch(config.batch_size), info.features["label"].num_classes

In [10]:
def load_cifar10():
    from tensorflow.keras.datasets import cifar10
    from sklearn.model_selection import train_test_split

    (X_train, Y_train), (X_test, Y_test) = cifar10.load_data()
    X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=config.validation_split, random_state=config.seed)
    dst_train = tf.data.Dataset.from_tensor_slices((X_train, Y_train))
    dst_val = tf.data.Dataset.from_tensor_slices((X_val, Y_val))

    return dst_train.batch(config.batch_size), dst_val.batch(config.batch_size)

In [11]:
def load_cifar100():
    from tensorflow.keras.datasets import cifar100
    from sklearn.model_selection import train_test_split

    (X_train, Y_train), (X_test, Y_test) = cifar100.load_data()
    X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=config.validation_split, random_state=config.seed)
    dst_train = tf.data.Dataset.from_tensor_slices((X_train, Y_train))
    dst_val = tf.data.Dataset.from_tensor_slices((X_val, Y_val))

    return dst_train.batch(config.batch_size), dst_val.batch(config.batch_size)

In [12]:
if config.dataset == "imagenet":
    dst_train, dst_val = load_imagenet()
    N_CLASSES = len(dst_train.class_names)
elif config.dataset == "cifar10":
    dst_train, dst_val = load_cifar10()
    N_CLASSES = 10
elif config.dataset == "cifar100":
    dst_train, dst_val = load_cifar100()
    N_CLASSES = 100
elif config.dataset == "imagenette":
    dst_train, dst_val, N_CLASSES = load_imagenette()
else:
    raise ValueError("Dataset parameter not allowed.")
print(f"Training on {config.dataset} with {N_CLASSES} classes.")

Found 202050 files belonging to 409 classes.
Using 161640 files for training.
Found 202050 files belonging to 409 classes.
Using 40410 files for validation.
Training on imagenet with 409 classes.


In [1]:
x, y = next(iter(dst_train))
input_shape = x[0].shape
input_shape

TensorShape([256, 256, 3])

In [14]:
wandb.run.summary["N_CLASSES"] = N_CLASSES
wandb.run.summary["Input_Shape"] = tuple(input_shape)

In [15]:
dst_tid2013 = TID2013("/lustre/ific.uv.es/ml/uv075/Databases/IQA/TID/TID2013").dataset\
                                                                              .batch(config.batch_size)\
                                                                              .prefetch(1)

### Normalize the data

In [20]:
if len(y.shape) != 1:
    dst_train = dst_train.map(lambda x,y: (tf.cast(x, tf.float32)/255.0, y[:,0]))
    dst_val = dst_val.map(lambda x,y: (tf.cast(x, tf.float32)/255.0, y[:,0]))

### Performance

In [21]:
AUTOTUNE = tf.data.AUTOTUNE

dst_train_rdy = dst_train.cache().prefetch(buffer_size=1)
dst_val_rdy = dst_val.cache().prefetch(buffer_size=1)

# Define the model

In [22]:
class GDN(nn.Module):
    """Generalized Divisive Normalization."""
    kernel_size: Union[int, Sequence[int]]
    strides: int = 1
    padding: str = "SAME"
    apply_independently: bool = False
    # kernel_init: Callable = nn.initializers.lecun_normal()
    kernel_init: Callable = mean()
    bias_init: Callable = nn.initializers.ones_init()
    alpha: float = 2.
    epsilon: float = 1/2 # Exponential of the denominator
    eps: float = 1e-6 # Numerical stability in the denominator

    @nn.compact
    def __call__(self,
                 inputs,
                 ):
        denom = nn.Conv(features=inputs.shape[-1], # Same output channels as input
                        kernel_size=self.kernel_size if isinstance(self.kernel_size, Sequence) else [self.kernel_size]*2, 
                        strides=self.strides, 
                        padding=self.padding,
                        feature_group_count=inputs.shape[-1] if self.apply_independently else 1,
                        kernel_init=self.kernel_init, 
                        bias_init=self.bias_init)(inputs**self.alpha)
        return inputs / (jnp.clip(denom, a_min=1e-5)**self.epsilon + self.eps)

In [23]:
class PerceptNet(nn.Module):
    """IQA model inspired by the visual system."""

    @nn.compact
    def __call__(self,
                 inputs,
                 **kwargs,
                 ):
        outputs = GDN(kernel_size=1, strides=1, padding="SAME", apply_independently=True)(inputs)
        outputs = nn.Conv(features=3, kernel_size=(1,1), strides=1, padding="SAME", use_bias=config.use_bias)(outputs)
        outputs = nn.max_pool(outputs, window_shape=(2,2), strides=(2,2))
        outputs = GDN(kernel_size=1, strides=1, padding="SAME", apply_independently=False)(outputs)
        outputs = nn.Conv(features=6, kernel_size=(5,5), strides=1, padding="SAME", use_bias=config.use_bias)(outputs)
        outputs = nn.max_pool(outputs, window_shape=(2,2), strides=(2,2))
        outputs = GDN(kernel_size=1, strides=1, padding="SAME", apply_independently=False)(outputs)
        outputs = nn.Conv(features=128, kernel_size=(5,5), strides=1, padding="SAME", use_bias=config.use_bias)(outputs)
        outputs = GDN(kernel_size=1, strides=1, padding="SAME", apply_independently=False)(outputs)
        return outputs

In [24]:
class Classifier(nn.Module):

    @nn.compact
    def __call__(self,
                 inputs,
                 ):
        outputs = nn.Dense()

In [25]:
class PerceptNetClassifier(nn.Module):
    """Classifier with a PerceptNet backbone."""

    def setup(self):
        self.perceptnet = PerceptNet()
        self.cls = nn.Dense(N_CLASSES)

    def __call__(self,
                 inputs,
                 ):
        outputs = self.perceptnet(inputs)
        # outputs = nn.max_pool(outputs, window_shape=(2,2), strides=(2,2))
        outputs = reduce(outputs, "b h w c -> b c", reduction="mean") if config.GAP else rearrange(outputs, "b h w c -> b (h w c)")
        outputs = self.cls(outputs)
        return outputs

In [26]:
@struct.dataclass
class Metrics(metrics.Collection):
    """Collection of metrics to be tracked during training."""
    accuracy: metrics.Accuracy
    loss: metrics.Average.from_output("loss")

In [27]:
class TrainState(train_state.TrainState):
    metrics: Metrics
    state: FrozenDict

In [28]:
def create_train_state(module, key, tx, input_shape):
    """Creates the initial `TrainState`."""
    variables = module.init(key, jnp.ones(input_shape))
    state, params = variables.pop('params')
    return TrainState.create(
        apply_fn=module.apply,
        params=params,
        state=state,
        tx=tx,
        metrics=Metrics.empty()
    )

In [44]:
state = create_train_state(PerceptNetClassifier(), random.PRNGKey(config.seed), optax.adam(config.learning_rate), input_shape=(1,*(x.shape[1:])))
state = state.replace(params=clip_layer(state.params, "GDN", a_min=0))

Log the number of trainable weights:

In [2]:
param_count = sum(x.size for x in jax.tree_util.tree_leaves(state.params))
param_count

214470569

In [46]:
wandb.run.summary["trainable_parameters"] = param_count

In [47]:
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
save_args = orbax_utils.save_args_from_target(state)

## Train the model!

In [48]:
@jax.jit
def train_step(state, batch):
    """Train for a single step."""
    img, label = batch
    def loss_fn(params):
        ## Forward pass through the model
        img_pred = state.apply_fn({"params": params, **state.state}, img)

        ## Calculate crossentropy
        return optax.softmax_cross_entropy_with_integer_labels(img_pred, label).mean(), img_pred
    
    (loss, dist_diff), grads = jax.value_and_grad(loss_fn, has_aux=True)(state.params)
    state = state.apply_gradients(grads=grads)
    metrics_updates = state.metrics.single_from_model_output(loss=loss, logits=dist_diff, labels=jnp.round(label).astype(int))
    metrics = state.metrics.merge(metrics_updates)
    state = state.replace(metrics=metrics)
    return state

In [49]:
@jax.jit
def val_step(state, batch):
    """Train for a single step."""
    img, label = batch
    def loss_fn(params):
        ## Forward pass through the model
        img_pred = state.apply_fn({"params": params, **state.state}, img)

        ## Calculate crossentropy
        return optax.softmax_cross_entropy_with_integer_labels(img_pred, label).mean(), img_pred
    
    loss, dist_diff = loss_fn(state.params)
    metrics_updates = state.metrics.single_from_model_output(loss=loss, logits=dist_diff, labels=jnp.round(label).astype(int))
    metrics = state.metrics.merge(metrics_updates)
    state = state.replace(metrics=metrics)
    return state

In [50]:
def forward_pass(state, img):
    img_pred = PerceptNet().apply({"params": state.params["perceptnet"]}, img)
    return img_pred

In [51]:
def rmse(a, b): return jnp.sqrt(jnp.sum((a-b)**2, axis=(1,2,3)))

In [52]:
@jax.jit
def obtain_distances(state, batch):
    ref, dist, mos = batch
    pred_ref = forward_pass(state, ref)
    pred_dist = forward_pass(state, dist)
    distance = rmse(pred_ref, pred_dist)
    return distance

In [53]:
import scipy.stats as stats

In [54]:
def obtain_correlation(state, dst):
    distances, moses = [], []
    for batch in dst:
        distance = obtain_distances(state, batch)
        distances.extend(distance)
        moses.extend(batch[2])
        # break
    return stats.pearsonr(distances, moses)[0]

In [55]:
metrics_history = {
    "train_loss": [],
    "train_accuracy": [],
    "val_loss": [],
    "val_accuracy": [],
    "correlation": [],
}

In [None]:
%%time
for epoch in range(config.epochs):
    ## Training
    for batch in dst_train_rdy.as_numpy_iterator():
        state = train_step(state, batch)
        state = state.replace(params=clip_layer(state.params, "GDN", a_min=0))
        # state = compute_metrics(state=state, batch=batch)
        # break

    ## Log the metrics
    for name, value in state.metrics.compute().items():
        metrics_history[f"train_{name}"].append(value)
    
    ## Empty the metrics
    state = state.replace(metrics=state.metrics.empty())

    ## Evaluation (Classification)
    for batch in dst_val_rdy.as_numpy_iterator():
        state = val_step(state=state, batch=batch)
        # break
    for name, value in state.metrics.compute().items():
        metrics_history[f"val_{name}"].append(value)
    state = state.replace(metrics=state.metrics.empty())

    ## Evaluation (Correlation)
    correlation = obtain_correlation(state, dst_tid2013.as_numpy_iterator())
    metrics_history["correlation"].append(correlation)
    
    ## Checkpointing
    if metrics_history["val_loss"][-1] <= min(metrics_history["val_loss"]):
        orbax_checkpointer.save(os.path.join(wandb.run.dir, "model-best"), state, save_args=save_args, force=True) # force=True means allow overwritting.

    wandb.log({f"{k}": wandb.Histogram(v) for k, v in flatten_params(state.params).items()}, commit=False)
    wandb.log({"epoch": epoch+1, **{name:values[-1] for name, values in metrics_history.items()}})
    print(f'Epoch {epoch} -> [Train] Loss: {metrics_history["train_loss"][-1]:.3f} Acc: {metrics_history["train_accuracy"][-1]:.3f} [Val] Loss: {metrics_history["val_loss"][-1]:.3f} Acc: {metrics_history["val_accuracy"][-1]:.3f} || Corr: {metrics_history["correlation"][-1]:.3f}')
    # break

Epoch 0 -> [Train] Loss: 14.555 Acc: 0.009 [Val] Loss: 5.151 Acc: 0.203 || Corr: -0.635
Epoch 1 -> [Train] Loss: 4.474 Acc: 0.231 [Val] Loss: 3.880 Acc: 0.405 || Corr: -0.590
Epoch 2 -> [Train] Loss: 3.245 Acc: 0.435 [Val] Loss: 3.546 Acc: 0.517 || Corr: -0.543
Epoch 3 -> [Train] Loss: 2.433 Acc: 0.591 [Val] Loss: 3.624 Acc: 0.591 || Corr: -0.512
Epoch 4 -> [Train] Loss: 1.881 Acc: 0.701 [Val] Loss: 3.848 Acc: 0.636 || Corr: -0.498
Epoch 5 -> [Train] Loss: 1.557 Acc: 0.764 [Val] Loss: 4.050 Acc: 0.663 || Corr: -0.486
Epoch 6 -> [Train] Loss: 1.335 Acc: 0.803 [Val] Loss: 4.222 Acc: 0.682 || Corr: -0.482
Epoch 7 -> [Train] Loss: 1.160 Acc: 0.834 [Val] Loss: 4.349 Acc: 0.698 || Corr: -0.473
Epoch 8 -> [Train] Loss: 1.028 Acc: 0.856 [Val] Loss: 4.516 Acc: 0.707 || Corr: -0.462
Epoch 9 -> [Train] Loss: 0.889 Acc: 0.873 [Val] Loss: 4.657 Acc: 0.717 || Corr: -0.456
Epoch 10 -> [Train] Loss: 0.793 Acc: 0.888 [Val] Loss: 4.799 Acc: 0.722 || Corr: -0.442
Epoch 11 -> [Train] Loss: 0.728 Acc: 0.89

In [None]:
orbax_checkpointer.save(os.path.join(wandb.run.dir, "model-final"), state, save_args=save_args)

In [None]:
wandb.finish()