In [1]:
import os

import jax
from jax import random, numpy as jnp
from flax.core import freeze, unfreeze
import orbax
import orbax.checkpoint
import wandb
from ml_collections import ConfigDict

from iqadatasets.datasets import *
from pnetkld.models import IndependentMeanStd, DependentStd
from pnetkld.training import create_train_state, kld, js

## Load model weights

In [2]:
# Identifier of the Weights & Biases run whose weights are loaded in this notebook.
# NOTE(review): `id` shadows the Python builtin `id()` for the rest of the session;
# consider renaming it (e.g. `run_id`) once every cell that reads it is updated.
id = "s10iigeo"
# Path handed to `api.run(...)`; presumably "<entity>/<project>/<run id>" — confirm.
run_path = f"Jorgvt/PerceptNet_KLD/{id}"

In [3]:
# Look up the run identified by `run_path` through the W&B API client.
api = wandb.Api()
run = api.run(run_path)
# Wrap the run's logged configuration in a ConfigDict; it is passed to the model constructor later.
config = ConfigDict(run.config)
# Download every file attached to the run into `run.dir`.
# `replace=True` presumably overwrites files already present locally — confirm against the W&B docs.
for file in run.files():
    file.download(root=run.dir, replace=True)

In [4]:
# Display the configuration retrieved from the run.
config

BATCH_SIZE: 32
CS_KERNEL_SIZE: 5
DISTANCE: kld
EPOCHS: 500
GABOR_KERNEL_SIZE: 5
GDNGAUSSIAN_KERNEL_SIZE: 1
GDNSPATIOFREQ_KERNEL_SIZE: 1
LAMBDA: 0
LEARNING_RATE: 0.0003
MODEL: independent
N_GABORS: 128
SEED: 42

In [5]:
# Location of the "model-best" checkpoint inside the directory the run files were downloaded to.
best_ckpt_path = os.path.join(run.dir, "model-best")
# Restore the checkpoint contents; `state["params"]` is consumed when the model variables are built.
ckpt = orbax.checkpoint.PyTreeCheckpointer()
state = ckpt.restore(best_ckpt_path)

In [6]:
# Build the model from the run configuration.
model = IndependentMeanStd(config)
# Initialise the variable collections on an all-ones (1, 384, 512, 3) input, then swap the
# freshly initialised "params" entry for the one restored from the checkpoint. Any other
# collection produced by `init` is kept as initialised.
variables = freeze(
    {
        **unfreeze(model.init(random.PRNGKey(42), jnp.ones((1,384,512,3)))),
        "params": state["params"],
    }
)

## Load data

In [7]:
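# TODO: dataset loading is currently disabled and the path argument is empty; fill it in before enabling.
# dst = TID2008("")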

## Obtain all predictions

In [8]:
def forward(variables, img):
    """Run the notebook-level `model` on `img`.

    Args:
        variables: Variable collections passed to `model.apply`.
        img: Input passed to `model.apply`.

    Returns:
        Whatever `model.apply(variables, img)` returns, unchanged.
    """
    return model.apply(variables, img)

In [9]:
def calculate_distance(variables, img, img_dist):
    """Compute the `kld` distance between the model outputs for two inputs.

    Both inputs go through `forward` with the same `variables`; each call must
    yield exactly two values, which are unpacked here.
    NOTE(review): the two values are treated as (mean, log-variance) based on
    the naming used in this notebook — confirm against the model definition.

    Args:
        variables: Variable collections forwarded to `forward`.
        img: First input (presumably the reference image — confirm).
        img_dist: Second input (presumably the distorted image — confirm).

    Returns:
        The result of `kld` called with the first input's two outputs followed
        by the second input's two outputs.
    """
    ref_mean, ref_logvar = forward(variables, img)
    dist_mean, dist_logvar = forward(variables, img_dist)
    return kld(ref_mean, ref_logvar, dist_mean, dist_logvar)

In [10]:
# Sanity check: feed the same all-ones input as both arguments and show the
# resulting distance together with its shape.
ones_img = jnp.ones((1,384,512,3))
dist = calculate_distance(variables, ones_img, ones_img)
dist, dist.shape

(Array([0.], dtype=float32), (1,))