# Load a pretrained model and keep training it

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import os
from functools import partial

import jax
from jax import random, numpy as jnp
import flax
from flax.training import orbax_utils
import optax
import orbax
import orbax.checkpoint

from paramperceptnet.models import PerceptNet
from paramperceptnet.training import create_train_state, pearson_correlation, train_step, compute_metrics
from paramperceptnet.constraints import *

## Download a set of pretrained weights

Currently, the weights are stored in W&B (they will be in HF at some point), so we have to download them before putting them into our model.
The good part is that the configuration is stored alongside the weights, so we don't need to load it separately.

In [3]:
import wandb
from ml_collections import ConfigDict

# Identifier of the W&B run whose files (weights + config) are downloaded.
# NOTE(review): `id` shadows the Python builtin; the name is kept so that any
# other cell relying on it keeps working.
id = "2ploco2u"

api = wandb.Api()
run = api.run(f"jorgvt/PerceptNet_v15/{id}")
save_path = f"./{id}/"

# The hyperparameters are looked up under the "_fields" key first; when that
# key is absent the top-level run config is used as-is. Only the missing-key
# case triggers the fallback, so network/auth errors are no longer swallowed.
try:
    config = ConfigDict(run.config["_fields"])
except KeyError:
    config = ConfigDict(run.config)

# Download every file attached to the run, overwriting existing local copies.
for file in run.files():
    file.download(root=save_path, replace=True)

## Prepare `TrainState` & load a `state`

In other examples we omitted the existence of the `TrainState` to reduce complexity, but it can be really handy when we want to train our model because it holds the parameters, the state, the optimizer and its parameters, and the metrics of interest. This makes it very easy to continue training an already trained model because we provide the whole state.

When training the model we employ an `optax.multi_transform` optimizer to be able to set some parameters to non-trainable. Because of this, if we want to load the same `TrainState`, we have to define the same optimizer here (if the optimizers are different, `optax` won't load the state). Another option would be loading the state as a Python `dict` and then putting the loaded parameters and states into our `TrainState`. This would allow us to change the optimizer.

Let's define it and load a pretrained one:

In [4]:
# Provisional train state with a plain Adam optimizer; its `params` tree is
# what the trainable / non-trainable labelling below is computed from.
state = create_train_state(
    PerceptNet(config),
    key=random.PRNGKey(42),
    tx=optax.adam(3e-4),
    input_shape=(1, 384, 512, 3),
)

In [5]:
def check_trainable(path):
    """Tell whether the parameter at `path` has to be frozen.

    Despite its name, the function returns `True` for parameters that must
    NOT be trained: the caller in this notebook labels a `True` result as
    "non_trainable".

    Args:
        path: tuple of string keys locating a parameter in the params tree.

    Returns:
        `True` if any freezing rule (driven by the global `config`) matches
        the path, `False` otherwise.
    """
    last_gdn = "GDNSpatioChromaFreqOrient_0"

    # The "A" parameter of the last GDN is frozen unless the flag is set.
    if not config.A_GDNSPATIOFREQORIENT and last_gdn in path and "A" in path:
        return True
    # Paths with a "Color" element are frozen when TRAIN_JH is off.
    if "Color" in path and not config.TRAIN_JH:
        return True
    # Center-surround layer is frozen when TRAIN_CS is off.
    if "CenterSurroundLogSigmaK_0" in path and not config.TRAIN_CS:
        return True
    # Substring match over the concatenated path (not an exact element match).
    if "Gabor" in "".join(path) and not config.TRAIN_GABOR:
        return True
    # When only the last GDN is trained, everything outside it is frozen.
    if last_gdn not in path and config.TRAIN_ONLY_LAST_GDN:
        return True
    return False
                                                                                                                                                     
def _label_param(path, _value):
    """Map a parameter path to the optimizer label it should receive."""
    return "non_trainable" if check_trainable(path) else "trainable"


# Tree with the same structure as `state.params` whose leaves are the labels.
trainable_tree = flax.core.freeze(
    flax.traverse_util.path_aware_map(_label_param, state.params)
)

In [6]:
# One transformation per label: Adam for the "trainable" leaves, zeroed
# updates for the "non_trainable" ones.
optimizers = dict(
    trainable=optax.adam(learning_rate=config.LEARNING_RATE),
    non_trainable=optax.set_to_zero(),
)
tx = optax.multi_transform(optimizers, trainable_tree)

In [7]:
# Rebuild the train state, this time with the label-aware optimizer `tx`.
state = create_train_state(
    PerceptNet(config),
    key=random.PRNGKey(42),
    tx=tx,
    input_shape=(1, 384, 512, 3),
)

In [8]:
orbax_checkpointer = orbax.checkpoint.PyTreeCheckpointer()
# Save arguments are derived from the freshly built state, before it is
# replaced by the restored one.
save_args = orbax_utils.save_args_from_target(state)

# Restore the downloaded "model-best" checkpoint, using the current state as
# the target structure.
checkpoint_dir = os.path.join(save_path, "model-best")
state = orbax_checkpointer.restore(checkpoint_dir, item=state)

## Define the loss and `train_step`

Both the loss function (Pearson correlation) and the `train_step` function are provided in `paramperceptnet.training`, but we will define them explicitly here as well for completeness.

In [9]:
def pearson_correlation(vec1, vec2):
    """Pearson correlation coefficient between two arrays.

    Both inputs are squeezed first, so singleton dimensions are ignored.

    Args:
        vec1: first array of values.
        vec2: second array of values, same shape as `vec1` after squeezing.

    Returns:
        The correlation as a scalar array. No guard is applied against a
        zero denominator (constant inputs).
    """
    x = vec1.squeeze()
    y = vec2.squeeze()
    x_centered = x - x.mean()
    y_centered = y - y.mean()
    covariance = (x_centered * y_centered).sum()
    x_norm = jnp.sqrt(jnp.sum(x_centered ** 2))
    y_norm = jnp.sqrt(jnp.sum(y_centered ** 2))
    return covariance / (x_norm * y_norm)

In [10]:
@partial(jax.jit, static_argnums=2)
def train_step(state, batch, return_grads=False):
    """Run one optimisation step on a batch.

    Args:
        state: train state providing `apply_fn`, `params`, `state`, `metrics`,
            `apply_gradients` and `replace`.
        batch: tuple `(img, img_dist, mos)`.
        return_grads: static flag; when `True` the gradients are returned
            together with the new state.

    Returns:
        The updated state, or `(state, grads)` when `return_grads` is `True`.
    """
    reference, distorted, mos = batch

    def forward(params, images):
        # Training-mode pass; every collection held in `state.state` is mutable.
        return state.apply_fn(
            {"params": params, **state.state},
            images,
            mutable=list(state.state.keys()),
            train=True,
        )

    def loss_fn(params):
        # Both passes start from the same `state.state`; only the model state
        # produced by the second (distorted) pass is kept.
        reference_out, _ = forward(params, reference)
        distorted_out, new_model_state = forward(params, distorted)

        # Euclidean distance between the two outputs, one value per sample.
        distance = ((reference_out - distorted_out) ** 2).sum(axis=(1, 2, 3)) ** (1 / 2)

        # The loss is the correlation between the distances and the scores.
        return pearson_correlation(distance, mos), new_model_state

    grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
    (loss, new_model_state), grads = grad_fn(state.params)

    state = state.apply_gradients(grads=grads)
    merged_metrics = state.metrics.merge(
        state.metrics.single_from_model_output(loss=loss)
    )
    state = state.replace(metrics=merged_metrics, state=new_model_state)
    return (state, grads) if return_grads else state

In [11]:
@jax.jit
def compute_metrics(*, state, batch):
    """Accumulate the loss of a batch into the state's metrics.

    No gradients are computed and the parameters are left untouched.

    Args:
        state: train state providing `apply_fn`, `params`, `state`, `metrics`
            and `replace`.
        batch: tuple `(img, img_dist, mos)`.

    Returns:
        The state with its metrics merged with this batch's loss.
    """
    reference, distorted, mos = batch

    def forward(images):
        # Inference-mode pass; the returned collections are discarded.
        outputs, _ = state.apply_fn(
            {"params": state.params, **state.state},
            images,
            mutable=list(state.state.keys()),
            train=False,
        )
        return outputs

    reference_out = forward(reference)
    distorted_out = forward(distorted)

    # Euclidean distance between the two outputs, one value per sample.
    distance = ((reference_out - distorted_out) ** 2).sum(axis=(1, 2, 3)) ** (1 / 2)
    loss = pearson_correlation(distance, mos)

    batch_metrics = state.metrics.single_from_model_output(loss=loss)
    return state.replace(metrics=state.metrics.merge(batch_metrics))

## Getting some data

We will fetch the TID2008 dataset from HuggingFace as an example.

In [12]:
from datasets import load_dataset

In [32]:
%%time
# NOTE(review): `trust_remote_code=True` lets the dataset repository run its
# own loading script locally; only keep it for sources you trust.
dataset = load_dataset(
    "Jorgvt/TID2008",
    num_proc=8,
    trust_remote_code=True,
).with_format("jax")

CPU times: user 276 ms, sys: 21.6 ms, total: 298 ms
Wall time: 5.4 s


In [34]:
dst_train = dataset["train"]
# Batched iterator over the training split. It is consumed as it is read, so
# a second loop over the same object resumes where the previous one stopped.
dst_train_rdy = dst_train.iter(
    batch_size=config.BATCH_SIZE,
)

## Write a simple training loop

With both of these functions defined, we can write a simple training loop example. Notice we are clipping some of the parameters after every update.

In [38]:
# Per-epoch history of the losses, one independent list per split.
metrics_history = {f"{split}_loss": [] for split in ("train", "val")}

In [None]:
%%time
# Quick-run switch: True reproduces the smoke test (one training batch, one
# evaluation batch, a single epoch). Set it to False for a full training run.
QUICK_RUN = True

# `wandb.run` is None unless `wandb.init()` has been called. In that case the
# checkpoints go to a local folder next to the downloaded weights, without
# overwriting the pretrained "model-best" checkpoint.
if wandb.run is not None:
    checkpoint_root = wandb.run.dir
else:
    checkpoint_root = os.path.abspath(os.path.join(save_path, "finetuned"))
os.makedirs(checkpoint_root, exist_ok=True)

for epoch in range(config.EPOCHS):
    ## Training
    # A fresh iterator is built for every pass: the object returned by `iter`
    # is consumed as it is read and would be empty after the first full pass.
    for batch in dst_train.iter(batch_size=config.BATCH_SIZE):
        # Scale the images by 1/255 before feeding them to the model.
        batch = (batch["reference"] / 255.0, batch["distorted"] / 255.0, batch["mos"])
        state, grads = train_step(state, batch, return_grads=True)
        # Re-apply the parameter constraints (lower bounds) after each update.
        state = state.replace(params=clip_layer(state.params, "GDN", a_min=0))
        state = state.replace(params=clip_param(state.params, "A", a_min=0))
        state = state.replace(params=clip_param(state.params, "K", a_min=1 + 1e-5))
        if QUICK_RUN:
            break

    ## Log the metrics
    for name, value in state.metrics.compute().items():
        metrics_history[f"train_{name}"].append(value)

    ## Empty the metrics
    state = state.replace(metrics=state.metrics.empty())

    ## Evaluation
    # NOTE: the training split is reused here; this notebook builds no
    # separate validation split.
    for batch in dst_train.iter(batch_size=config.BATCH_SIZE):
        batch = (batch["reference"] / 255.0, batch["distorted"] / 255.0, batch["mos"])
        state = compute_metrics(state=state, batch=batch)
        if QUICK_RUN:
            break

    for name, value in state.metrics.compute().items():
        metrics_history[f"val_{name}"].append(value)
    state = state.replace(metrics=state.metrics.empty())

    ## Checkpointing
    # Save whenever the latest validation loss is the lowest seen so far.
    if metrics_history["val_loss"][-1] <= min(metrics_history["val_loss"]):
        orbax_checkpointer.save(
            os.path.join(checkpoint_root, "model-best"),
            state,
            save_args=save_args,
            force=True,  # force=True allows overwriting an existing checkpoint.
        )

    print(f'Epoch {epoch} -> [Train] Loss: {metrics_history["train_loss"][-1]} [Val] Loss: {metrics_history["val_loss"][-1]}')
    if QUICK_RUN:
        break