In [9]:
import os, urllib.request, urllib.error
import tensorflow as tf

_URL = "https://rail.eecs.berkeley.edu/models/lpips"

# --------------------------------------------------------------------
def _get_pb(model='net-lin', net='alex', version='0.1'):
    """Return the local path of a frozen LPIPS graph, downloading it once.

    The file ``{model}_{net}_v{version}.pb`` is cached under ``~/.lpips`` and
    fetched from ``_URL`` only when it is not already present on disk.

    Args:
        model: LPIPS model variant; used verbatim in the file name.
        net: backbone network name; used verbatim in the file name.
        version: model version string; used verbatim in the file name.

    Returns:
        Path to the cached ``.pb`` file.

    Raises:
        FileNotFoundError: if the file is not cached and the download fails.
    """
    fname = f"{model}_{net}_v{version}.pb"
    cache_dir = os.path.expanduser("~/.lpips")
    os.makedirs(cache_dir, exist_ok=True)

    dst = os.path.join(cache_dir, fname)
    if os.path.isfile(dst):                              # already cached
        return dst

    url = f"{_URL}/{fname}"
    try:
        # Keras' get_file catches HTTPError/URLError itself, removes any
        # partial download and re-raises a plain ``Exception``. Catching
        # HTTPError here would therefore never trigger, so convert whatever
        # get_file raises into the documented error type and keep the cause.
        path = tf.keras.utils.get_file(fname, origin=url, cache_dir=cache_dir,
                                       cache_subdir='')
    except Exception as e:
        raise FileNotFoundError(f"Could not download {url}\n{e}") from e
    return path
# --------------------------------------------------------------------


# (model, net, version) -> pruned eager callable; filled lazily by _get_lpips_fn
_lpips_cache = {}


def _get_lpips_fn(model='net-lin', net='alex', version='0.1'):
    """Return a cached eager callable ``fn(x, y) -> [dist]`` for one LPIPS graph.

    On first use for a given (model, net, version) the frozen graph is located
    via ``_get_pb``. It is parsed into a ``GraphDef`` and imported (with an
    empty name scope) inside ``tf.compat.v1.wrap_function``. It is then pruned
    so that the tensors ``0:0`` and ``1:0`` are the feeds and output 0 of the
    graph's last operation is the fetch. The pruned function is stored in
    ``_lpips_cache``, and later calls with the same arguments return that
    same object.
    """
    cache_key = (model, net, version)
    cached = _lpips_cache.get(cache_key)
    if cached is not None:
        return cached

    frozen = tf.compat.v1.GraphDef()
    with tf.io.gfile.GFile(_get_pb(model, net, version), "rb") as fh:
        frozen.ParseFromString(fh.read())

    def _import_frozen():
        # Runs under wrap_function's graph-mode tracing; no inputs are declared
        # here because the graph's own placeholders become the feeds below.
        tf.import_graph_def(frozen, name="")

    wrapped = tf.compat.v1.wrap_function(_import_frozen, [])

    graph = wrapped.graph
    feeds = [graph.get_tensor_by_name(tname) for tname in ("0:0", "1:0")]
    last_op = graph.get_operations()[-1]
    fetch = graph.get_tensor_by_name(last_op.name + ":0")

    pruned = wrapped.prune(feeds, [fetch])
    _lpips_cache[cache_key] = pruned
    return pruned
# --------------------------------------------------------------------


def _lpips_to_float32(images):
    """Convert ``images`` to a tensor, casting any floating dtype to float32."""
    t = tf.convert_to_tensor(images)
    # Non-floating inputs are passed through unchanged so they fail exactly
    # as before. Integer images are outside the documented [0, 1] contract.
    return tf.cast(t, tf.float32) if t.dtype.is_floating else t


def _lpips_prepare(images):
    """Flatten leading dims to one batch axis, NHWC -> NCHW, [0, 1] -> [-1, 1]."""
    flat = tf.reshape(images,
                      tf.concat([[-1], tf.shape(images)[-3:]], axis=0))  # [N,H,W,C]
    return tf.transpose(flat, [0, 3, 1, 2]) * 2. - 1.                    # [N,C,H,W]


def lpips_tf2(input0, input1, model='net-lin', net='alex', version='0.1'):
    """LPIPS distance between two image batches, computed in TF2 eager mode.

    Args:
        input0: images laid out as ``[..., H, W, C]`` (NHWC with any number of
            leading dims) and valued in ``[0, 1]``. Tensors and array-likes are
            accepted. Any floating dtype (e.g. NumPy float64) is cast to
            float32, the dtype the frozen graph is fed with.
        input1: second image batch, same layout and value range as ``input0``.
        model: LPIPS model variant, forwarded to ``_get_lpips_fn``.
        net: backbone network name, forwarded to ``_get_lpips_fn``.
        version: model version string, forwarded to ``_get_lpips_fn``.

    Returns:
        Distances reshaped to the leading dims of ``input0`` (everything except
        the last three axes).
    """
    x = _lpips_to_float32(input0)
    y = _lpips_to_float32(input1)
    leading_shape = tf.shape(x)[:-3]

    lpips_fn = _get_lpips_fn(model, net, version)
    dist = lpips_fn(_lpips_prepare(x), _lpips_prepare(y))[0]  # pruned fn returns a list

    # A 4-D output carries three trailing singleton axes; drop them to get [N].
    if dist.shape.ndims == 4:
        dist = tf.squeeze(dist, axis=[-3, -2, -1])
    return tf.reshape(dist, leading_shape)
# --------------------------------------------------------------------


# --------------------------- quick test --------------------------------------
if __name__ == "__main__":
    # Smoke test: two seeded random batches (32 x 64 x 64 x 3) with values in [0, 1).
    tf.random.set_seed(0)
    batch_shape = (32, 64, 64, 3)
    a = tf.random.uniform(batch_shape, dtype=tf.float32)
    b = tf.random.uniform(batch_shape, dtype=tf.float32)

    dist = lpips_tf2(a, b)  # one distance per image pair
    print("distance shape:", dist.shape)
    print("first 3 values :", dist.numpy()[:3])


2025-06-25 21:38:17.003367: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:117] Plugin optimizer for device_type GPU is enabled.


distance shape: (32,)
first 3 values : [0.14032242 0.11136191 0.13181928]
