# Learned Perceptual Image Patch Similarity

Convert the weights of the official PyTorch implementation of Learned Perceptual Image Patch Similarity (LPIPS) into an equivalent TensorFlow model.

Code from: [https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/lpips.py](https://github.com/richzhang/PerceptualSimilarity/blob/master/lpips/lpips.py)

In [None]:
import hashlib
import os
from collections import namedtuple

import numpy as np
import requests
import tensorflow as tf
import torch
import torch.nn as nn
from torchvision import models
from torchvision.models import vgg16
from tqdm import tqdm

from lpips import LPIPS as LPIPSTF

# Download URL of every supported pretrained checkpoint, keyed by model name.
URL_MAP = dict(
    vgg_lpips="https://heibox.uni-heidelberg.de/f/607503859c864bc1b30b/?dl=1",
)

# File name each checkpoint is stored under inside the local checkpoint root.
CKPT_MAP = dict(
    vgg_lpips="vgg.pth",
)

# Expected MD5 hex digest of each checkpoint file, used to validate downloads.
MD5_MAP = dict(
    vgg_lpips="d507d7349b931f0638a25a48a722f98a",
)


def download(url, local_path, chunk_size=1024):
    """Stream the resource at ``url`` into ``local_path`` with a progress bar.

    Parameters
    ----------
        url : str
            Address of the file to fetch.
        local_path : str
            Destination file. Missing parent directories are created.
        chunk_size : int
            Number of bytes requested from the response per iteration.

    Raises
    ------
        requests.HTTPError if the server answers with an error status, so
        that an error page is never written to disk as if it were the file.
    """
    parent_dir = os.path.dirname(local_path)
    # ``os.makedirs("")`` raises, so skip it for a bare file name.
    if parent_dir:
        os.makedirs(parent_dir, exist_ok=True)
    with requests.get(url, stream=True) as r:
        r.raise_for_status()
        # Servers may omit content-length; tqdm then shows a total of 0.
        total_size = int(r.headers.get("content-length", 0))
        with tqdm(total=total_size, unit="B", unit_scale=True) as pbar, \
                open(local_path, "wb") as f:
            for data in r.iter_content(chunk_size=chunk_size):
                if data:
                    f.write(data)
                    # The last chunk is usually shorter than ``chunk_size``:
                    # advance by what was actually received.
                    pbar.update(len(data))


def md5_hash(path):
    """Return the MD5 hex digest of the file at ``path``.

    The file is hashed in fixed-size chunks so that large checkpoints do not
    have to be held in memory all at once.
    """
    digest = hashlib.md5()
    with open(path, "rb") as f:
        # ``f.read`` returns b"" at end of file, which terminates the iterator.
        for chunk in iter(lambda: f.read(1 << 20), b""):
            digest.update(chunk)
    return digest.hexdigest()


def get_ckpt_path(name, root, check=False):
    """Return the local path of checkpoint ``name``, downloading it if needed.

    Parameters
    ----------
        name : str
            Checkpoint identifier; must be a key of ``URL_MAP``.
        root : str
            Directory the checkpoint file lives in (or is downloaded to).
        check : bool
            If ``True``, an existing file whose MD5 digest does not match
            ``MD5_MAP[name]`` is downloaded again.

    Returns
    -------
        Path to the checkpoint file inside ``root``.

    Raises
    ------
        AssertionError if ``name`` is unknown or if a freshly downloaded file
        does not match the expected MD5 digest.
    """
    # Raised explicitly (rather than via ``assert``) so the checks are not
    # stripped under ``python -O``; the exception type is unchanged.
    if name not in URL_MAP:
        raise AssertionError("Unknown checkpoint name: {!r}".format(name))
    path = os.path.join(root, CKPT_MAP[name])
    if not os.path.exists(path) or (check and not md5_hash(path) == MD5_MAP[name]):
        print("Downloading {} model from {} to {}".format(name, URL_MAP[name], path))
        download(URL_MAP[name], path)
        md5 = md5_hash(path)
        if md5 != MD5_MAP[name]:
            raise AssertionError(md5)
    return path


class KeyNotFoundError(Exception):
    """Raised when a key path cannot be resolved in a nested container.

    The exception message lists the requested keys and the keys visited so
    far (each only when given), followed by the underlying cause.
    """

    def __init__(self, cause, keys=None, visited=None):
        self.cause = cause
        self.keys = keys
        self.visited = visited

        # Optional sections are emitted only for values that were supplied.
        optional_sections = (
            ("Key not found: {}", keys),
            ("Visited: {}", visited),
        )
        lines = [
            template.format(value)
            for template, value in optional_sections
            if value is not None
        ]
        lines.append("Cause:\n{}".format(cause))
        super().__init__("\n".join(lines))


def retrieve(
        list_or_dict, key, splitval="/", default=None, expand=True, pass_success=False
):
    """Given a nested list or dict return the desired value at key expanding
    callable nodes if necessary and :attr:`expand` is ``True``. The expansion
    is done in-place.

    Parameters
    ----------
        list_or_dict : list or dict
            Possibly nested list or dictionary. May itself be a callable
            returning such a container; it is then called (when
            :attr:`expand` is ``True``), but since the root has no parent
            container its expansion cannot be stored in-place.
        key : str
            key/to/value, path like string describing all keys necessary to
            consider to get to the desired value. List indices can also be
            passed here.
        splitval : str
            String that defines the delimiter between keys of the
            different depth levels in `key`.
        default : obj
            Value returned if :attr:`key` is not found.
        expand : bool
            Whether to expand callable nodes on the path or not.
        pass_success : bool
            If ``True``, return a ``(value, success)`` tuple, where
            ``success`` is ``False`` when ``default`` was returned.

    Returns
    -------
        The desired value or if :attr:`default` is not ``None`` and the
        :attr:`key` is not found returns ``default``. With
        :attr:`pass_success` the value is paired with a success flag.

    Raises
    ------
        KeyNotFoundError if ``key`` not in ``list_or_dict`` and
        :attr:`default` is ``None``.
    """

    keys = key.split(splitval)

    success = True
    try:
        visited = []
        parent = None
        last_key = None
        for part in keys:
            if callable(list_or_dict):
                if not expand:
                    raise KeyNotFoundError(
                        ValueError(
                            "Trying to get past callable node with expand=False."
                        ),
                        keys=keys,
                        visited=visited,
                    )
                list_or_dict = list_or_dict()
                # A callable root has no parent to write the expansion back
                # into; indexing ``None`` here used to raise a TypeError.
                if parent is not None:
                    parent[last_key] = list_or_dict

            last_key = part
            parent = list_or_dict

            try:
                if isinstance(list_or_dict, dict):
                    list_or_dict = list_or_dict[part]
                else:
                    # Anything that is not a dict is indexed like a list.
                    list_or_dict = list_or_dict[int(part)]
            except (KeyError, IndexError, ValueError) as e:
                raise KeyNotFoundError(e, keys=keys, visited=visited) from e

            visited.append(part)
        # final expansion of retrieved value
        if expand and callable(list_or_dict):
            list_or_dict = list_or_dict()
            parent[last_key] = list_or_dict
    except KeyNotFoundError:
        if default is None:
            raise
        list_or_dict = default
        success = False

    if pass_success:
        return list_or_dict, success
    return list_or_dict


class LPIPS(nn.Module):
    """Learned perceptual metric (LPIPS) computed on VGG16 features.

    Both inputs are rescaled, passed through the five VGG16 slices,
    channel-normalised, and their squared differences are weighted by one
    1x1 convolution per slice, spatially averaged and summed.
    """

    # Directory the pretrained checkpoint is read from / downloaded to.
    _CKPT_ROOT = "taming/modules/autoencoder/lpips"

    def __init__(self, use_dropout=True):
        super().__init__()
        self.scaling_layer = ScalingLayer()
        self.chns = [64, 128, 256, 512, 512]  # vgg16 features
        self.net = vgg16(pretrained=True, requires_grad=False)
        self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
        self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
        self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
        self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
        self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
        self.load_from_pretrained()
        # The metric is used frozen: no parameter receives gradients.
        for param in self.parameters():
            param.requires_grad = False

    def load_from_pretrained(self, name="vgg_lpips"):
        """Load the checkpoint ``name`` into this module (non-strict)."""
        ckpt = get_ckpt_path(name, self._CKPT_ROOT)
        # NOTE(review): ``torch.load`` unpickles the file. The MD5 digest is
        # only verified right after a download, not for an already present file.
        self.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
        print("loaded pretrained LPIPS loss from {}".format(ckpt))

    @classmethod
    def from_pretrained(cls, name="vgg_lpips"):
        """Build a model and load the checkpoint ``name`` into it.

        Raises
        ------
            NotImplementedError for any name other than ``"vgg_lpips"``.
        """
        if name != "vgg_lpips":
            raise NotImplementedError
        model = cls()
        # ``get_ckpt_path`` requires the checkpoint root; omitting it made
        # this constructor fail with a TypeError.
        ckpt = get_ckpt_path(name, cls._CKPT_ROOT)
        model.load_state_dict(torch.load(ckpt, map_location=torch.device("cpu")), strict=False)
        return model

    def forward(self, input, target):
        """Return the perceptual distance between ``input`` and ``target``.

        The per-slice contributions are spatially averaged with
        ``keepdim=True`` and summed over the five slices.
        """
        in0_input, in1_input = self.scaling_layer(input), self.scaling_layer(target)
        outs0, outs1 = self.net(in0_input), self.net(in1_input)
        lins = [self.lin0, self.lin1, self.lin2, self.lin3, self.lin4]

        val = None
        for lin, out0, out1 in zip(lins, outs0, outs1):
            # Squared difference of the channel-normalised activations.
            diff = (normalize_tensor(out0) - normalize_tensor(out1)) ** 2
            res = spatial_average(lin.model(diff), keepdim=True)
            val = res if val is None else val + res
        return val


class ScalingLayer(nn.Module):
    """Channel-wise input rescaling: ``(inp - shift) / scale``."""

    def __init__(self):
        super().__init__()
        per_channel_shift = torch.Tensor([-.030, -.088, -.188])
        per_channel_scale = torch.Tensor([.458, .448, .450])
        # Stored as (1, 3, 1, 1) buffers so they broadcast over dims 0, 2 and 3.
        self.register_buffer('shift', per_channel_shift.view(1, -1, 1, 1))
        self.register_buffer('scale', per_channel_scale.view(1, -1, 1, 1))

    def forward(self, inp):
        centred = inp - self.shift
        return centred / self.scale


class NetLinLayer(nn.Module):
    """A bias-free 1x1 convolution, optionally preceded by dropout.

    The layers live in ``self.model``; with ``use_dropout=True`` the
    convolution sits at index 1 (after the dropout), otherwise at index 0.
    """

    def __init__(self, chn_in, chn_out=1, use_dropout=False):
        super().__init__()
        stages = []
        if use_dropout:
            stages.append(nn.Dropout())
        stages.append(nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False))
        self.model = nn.Sequential(*stages)


class vgg16(torch.nn.Module):
    """torchvision VGG16 feature stack split into five sequential slices.

    ``forward`` returns the activation after each slice as a ``VggOutputs``
    namedtuple with fields ``relu1_2``, ``relu2_2``, ``relu3_3``,
    ``relu4_3`` and ``relu5_3``.

    Parameters
    ----------
        requires_grad : bool
            If ``False``, every parameter is frozen.
        pretrained : bool
            Forwarded to ``torchvision.models.vgg16``.
    """

    # Built once here instead of re-creating the tuple class on every forward pass.
    _VggOutputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
    # Slice ``k`` holds the feature layers with index in [bounds[k-1], bounds[k]).
    _SLICE_BOUNDS = (0, 4, 9, 16, 23, 30)

    def __init__(self, requires_grad=False, pretrained=True):
        super(vgg16, self).__init__()
        vgg_pretrained_features = models.vgg16(pretrained=pretrained).features
        bounds = self._SLICE_BOUNDS
        for number, (start, stop) in enumerate(zip(bounds, bounds[1:]), start=1):
            block = torch.nn.Sequential()
            for x in range(start, stop):
                # The original feature index is kept as the sub-module name so
                # state-dict keys stay of the form ``slice2.5.weight``.
                block.add_module(str(x), vgg_pretrained_features[x])
            setattr(self, "slice{}".format(number), block)
        self.N_slices = len(bounds) - 1
        if not requires_grad:
            for param in self.parameters():
                param.requires_grad = False

    def forward(self, X):
        h = X
        activations = []
        for block in (self.slice1, self.slice2, self.slice3, self.slice4, self.slice5):
            # Each slice consumes the output of the previous one.
            h = block(h)
            activations.append(h)
        return self._VggOutputs(*activations)


def normalize_tensor(x, eps=1e-10):
    norm_factor = torch.sqrt(torch.sum(x ** 2, dim=1, keepdim=True))
    return x / (norm_factor + eps)


def spatial_average(x, keepdim=True):
    return x.mean([2, 3], keepdim=keepdim)


In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Constructing LPIPS loads the pretrained checkpoint in its __init__.
pytorch_model = LPIPS().to(device)
# Switch to evaluation mode: LPIPS builds its linear layers with dropout by default.
pytorch_model.eval()



loaded pretrained LPIPS loss from taming/modules/autoencoder/lpips/vgg.pth


LPIPS(
  (scaling_layer): ScalingLayer()
  (net): vgg16(
    (slice1): Sequential(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): ReLU(inplace=True)
      (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (3): ReLU(inplace=True)
    )
    (slice2): Sequential(
      (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (6): ReLU(inplace=True)
      (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (8): ReLU(inplace=True)
    )
    (slice3): Sequential(
      (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
      (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace=True)
      (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (13): ReLU(inplace=True)
      (14): Conv2d(256, 256, 

## Instantiate Tensorflow Model

In [None]:
# Create the equivalent TensorFlow model
# (LPIPSTF is the `LPIPS` class imported from the local `lpips` module).
tensorflow_model = LPIPSTF()

Before loading the weights, we first have to build the TensorFlow model by calling it on an input with a concrete shape. Only then can we proceed to load the weights into it.

## Test model to build it

In [None]:
# Two random float32 batches of shape (1, 3, 224, 224) with values in [0, 255).
i1 = np.random.rand(1, 3, 224, 224).astype(np.float32) * 255
i2 = np.random.rand(1, 3, 224, 224).astype(np.float32) * 255

pytorch_result = pytorch_model(
    torch.from_numpy(i1).to(device),  # base PyTorch works with B,C,H,W format
    torch.from_numpy(i2).to(device),
)
# This call is what builds the TensorFlow model; its weights are still the
# initial ones at this point, so the two results are not compared here.
tensorflow_result = tensorflow_model([
    np.moveaxis(i1, 1, -1),  # base TensorFlow works with B,H,W,C format.
    np.moveaxis(i2, 1, -1),
])

Let's create the weights dictionary of the PyTorch model.

In [None]:
# Snapshot of every parameter and buffer of the PyTorch model, keyed by name.
state_dict = pytorch_model.state_dict()
# Display dtype and shape per entry, reusing the snapshot instead of building
# the state dict a second time.
{k: (v.dtype, v.shape) for k, v in state_dict.items()}

{'scaling_layer.shift': (torch.float32, torch.Size([1, 3, 1, 1])),
 'scaling_layer.scale': (torch.float32, torch.Size([1, 3, 1, 1])),
 'net.slice1.0.weight': (torch.float32, torch.Size([64, 3, 3, 3])),
 'net.slice1.0.bias': (torch.float32, torch.Size([64])),
 'net.slice1.2.weight': (torch.float32, torch.Size([64, 64, 3, 3])),
 'net.slice1.2.bias': (torch.float32, torch.Size([64])),
 'net.slice2.5.weight': (torch.float32, torch.Size([128, 64, 3, 3])),
 'net.slice2.5.bias': (torch.float32, torch.Size([128])),
 'net.slice2.7.weight': (torch.float32, torch.Size([128, 128, 3, 3])),
 'net.slice2.7.bias': (torch.float32, torch.Size([128])),
 'net.slice3.10.weight': (torch.float32, torch.Size([256, 128, 3, 3])),
 'net.slice3.10.bias': (torch.float32, torch.Size([256])),
 'net.slice3.12.weight': (torch.float32, torch.Size([256, 256, 3, 3])),
 'net.slice3.12.bias': (torch.float32, torch.Size([256])),
 'net.slice3.14.weight': (torch.float32, torch.Size([256, 256, 3, 3])),
 'net.slice3.14.bias': (

# Move weights to TensorFlow model

Scaling layer: transferring these weights is not strictly necessary, as the 6 values used as mean/std to scale the input could also be defined manually.

In [None]:
# Scaling layer
# The PyTorch buffers have shape (1, 3, 1, 1); they are reshaped to
# (1, 1, 1, 3) to match the channels-last layout used on the TensorFlow side.
# NOTE(review): assumes the TF scaling layer stores its weights in the order
# [shift, scale] — confirm against the `lpips` module.
tensorflow_model.scaling_layer.set_weights([
    np.reshape(state_dict['scaling_layer.shift'].cpu().numpy(), (1, 1, 1, 3)),
    np.reshape(state_dict['scaling_layer.scale'].cpu().numpy(), (1, 1, 1, 3)),
])

VGG16 layers: while the default pre-trained layers available in Keras could be used, I prefer to use exactly the same configuration as the original implementation.

In [70]:
# VGG
# Copy every convolution of the five PyTorch VGG slices into the TensorFlow
# model. `vgg.layers[k]` is paired with PyTorch `slice{k+1}`.
# NOTE(review): the inner `.layers[...]` indices depend on how the TF `lpips`
# module builds each block — confirm against that module.
tensorflow_model.vgg.layers[0].layers[1].kernel.assign(  # Weights can be moved either by assigning a tf.Variable...
    tf.Variable(
        # Note: PyTorch stores convolution kernels as (out, in, kH, kW), whereas the
        # Keras kernel is laid out as (kH, kW, in, out): therefore, we need to
        # transpose the kernel before assigning it!
        state_dict['net.slice1.0.weight'].cpu().numpy().transpose(2, 3, 1, 0),
        dtype=tf.float32,
    )
)
tensorflow_model.vgg.layers[0].layers[1].bias.assign(
    tf.Variable(
        state_dict['net.slice1.0.bias'].cpu().numpy(),
        dtype=tf.float32,
    )
)
tensorflow_model.vgg.layers[0].layers[3].set_weights([  # Or by assigning a list of [kernel, bias] weights
    np.transpose(state_dict['net.slice1.2.weight'].cpu().numpy(), (2, 3, 1, 0)),
    state_dict['net.slice1.2.bias'].cpu().numpy(),
])

# Slice 2 (PyTorch feature indices 5 and 7).
tensorflow_model.vgg.layers[1].layers[2].set_weights([
    np.transpose(state_dict['net.slice2.5.weight'].cpu().numpy(), (2, 3, 1, 0)),
    state_dict['net.slice2.5.bias'].cpu().numpy(),
])
tensorflow_model.vgg.layers[1].layers[4].set_weights([
    np.transpose(state_dict['net.slice2.7.weight'].cpu().numpy(), (2, 3, 1, 0)),
    state_dict['net.slice2.7.bias'].cpu().numpy(),
])

# Slice 3 (PyTorch feature indices 10, 12 and 14).
tensorflow_model.vgg.layers[2].layers[2].set_weights([
    np.transpose(state_dict['net.slice3.10.weight'].cpu().numpy(), (2, 3, 1, 0)),
    state_dict['net.slice3.10.bias'].cpu().numpy(),
])
tensorflow_model.vgg.layers[2].layers[4].set_weights([
    np.transpose(state_dict['net.slice3.12.weight'].cpu().numpy(), (2, 3, 1, 0)),
    state_dict['net.slice3.12.bias'].cpu().numpy(),
])
tensorflow_model.vgg.layers[2].layers[6].set_weights([
    np.transpose(state_dict['net.slice3.14.weight'].cpu().numpy(), (2, 3, 1, 0)),
    state_dict['net.slice3.14.bias'].cpu().numpy(),
])

# Slice 4 (PyTorch feature indices 17, 19 and 21).
tensorflow_model.vgg.layers[3].layers[2].set_weights([
    np.transpose(state_dict['net.slice4.17.weight'].cpu().numpy(), (2, 3, 1, 0)),
    state_dict['net.slice4.17.bias'].cpu().numpy(),
])
tensorflow_model.vgg.layers[3].layers[4].set_weights([
    np.transpose(state_dict['net.slice4.19.weight'].cpu().numpy(), (2, 3, 1, 0)),
    state_dict['net.slice4.19.bias'].cpu().numpy(),
])
tensorflow_model.vgg.layers[3].layers[6].set_weights([
    np.transpose(state_dict['net.slice4.21.weight'].cpu().numpy(), (2, 3, 1, 0)),
    state_dict['net.slice4.21.bias'].cpu().numpy(),
])

# Slice 5 (PyTorch feature indices 24, 26 and 28).
tensorflow_model.vgg.layers[4].layers[2].set_weights([
    np.transpose(state_dict['net.slice5.24.weight'].cpu().numpy(), (2, 3, 1, 0)),
    state_dict['net.slice5.24.bias'].cpu().numpy(),
])
tensorflow_model.vgg.layers[4].layers[4].set_weights([
    np.transpose(state_dict['net.slice5.26.weight'].cpu().numpy(), (2, 3, 1, 0)),
    state_dict['net.slice5.26.bias'].cpu().numpy(),
])
tensorflow_model.vgg.layers[4].layers[6].set_weights([
    np.transpose(state_dict['net.slice5.28.weight'].cpu().numpy(), (2, 3, 1, 0)),
    state_dict['net.slice5.28.bias'].cpu().numpy(),
])

Linear layers: the weights in these layers must all be positive! Otherwise, the final loss value could be negative. This happens because the squared differences pass through these layers and are then averaged. Therefore, if any of the weights are negative, the output loss value can become negative.


In [71]:
# Copy the five 1x1 convolutions. In the PyTorch model the convolution is at
# index 1 of each `lin{i}.model`, because index 0 is the dropout layer.
# The layers have no bias, so each weight list holds only the transposed kernel.
for i in range(5):
    tensorflow_model.linear_layers.layers[i].set_weights(
        [np.transpose(state_dict[f'lin{i}.model.1.weight'].cpu().numpy(), (2, 3, 1, 0))]
    )

In [72]:
# Print the smallest weight of each linear layer to check that none is negative.
for i in range(5):
    print(np.min(state_dict[f'lin{i}.model.1.weight'].cpu().numpy()))

0.021017654
0.009468972
0.037710946
0.039098237
0.019278834


# Compare the results between the two models

In [73]:
# Random float32 inputs of shape (1, 3, 224, 224) with values in [0, 1).
i1 = np.random.rand(1, 3, 224, 224).astype(np.float32)
i2 = np.random.rand(1, 3, 224, 224).astype(np.float32)

# Feed the same data to both models: channels-first for PyTorch ...
pytorch_result = pytorch_model(
    torch.from_numpy(i1).to(device),
    torch.from_numpy(i2).to(device),
)
# ... and channels-last (axis 1 moved to the end) for TensorFlow.
tensorflow_result = tensorflow_model([
    np.moveaxis(i1, 1, -1),
    np.moveaxis(i2, 1, -1),
])

In [74]:
# Compare the first (and only) batch element of the PyTorch output, squeezed
# to a scalar, with the TensorFlow output.
np.allclose(pytorch_result[0].cpu().numpy().squeeze(), tensorflow_result)

True

In [75]:
# Repeat the comparison on ten random image pairs with values in [0, 255).
for _ in range(10):
    i1 = np.random.rand(1, 3, 224, 224).astype(np.float32) * 255
    i2 = np.random.rand(1, 3, 224, 224).astype(np.float32) * 255

    pytorch_result = pytorch_model(
        torch.from_numpy(i1).to(device),
        torch.from_numpy(i2).to(device),
    )
    tensorflow_result = tensorflow_model([
        np.moveaxis(i1, 1, -1),
        np.moveaxis(i2, 1, -1)]).numpy()

    print("pytorch=", pytorch_result)
    print("tensorflow=", tensorflow_result)
    # The comparison result used to be computed and thrown away; assert it so
    # that a mismatch actually stops the notebook.
    assert np.allclose(pytorch_result.cpu().numpy().squeeze(), tensorflow_result)

pytorch= tensor([[[[0.2380]]]], device='cuda:0')
tensorflow= 0.23802666
pytorch= tensor([[[[0.2398]]]], device='cuda:0')
tensorflow= 0.23982744
pytorch= tensor([[[[0.2345]]]], device='cuda:0')
tensorflow= 0.23448285
pytorch= tensor([[[[0.2383]]]], device='cuda:0')
tensorflow= 0.23829569
pytorch= tensor([[[[0.2317]]]], device='cuda:0')
tensorflow= 0.23170039
pytorch= tensor([[[[0.2382]]]], device='cuda:0')
tensorflow= 0.23820105
pytorch= tensor([[[[0.2323]]]], device='cuda:0')
tensorflow= 0.23226494
pytorch= tensor([[[[0.2399]]]], device='cuda:0')
tensorflow= 0.2399226
pytorch= tensor([[[[0.2374]]]], device='cuda:0')
tensorflow= 0.23743632
pytorch= tensor([[[[0.2359]]]], device='cuda:0')
tensorflow= 0.23594017


In [76]:
# Same check as before, but the second image is the element-wise product
# `i1 * i2` instead of an independent random image.
for i in range(10):
    i1 = np.random.rand(1, 3, 224, 224).astype(np.float32) * 255
    i2 = np.random.rand(1, 3, 224, 224).astype(np.float32) * 255

    pytorch_result = pytorch_model(
        torch.from_numpy(i1).to(device),
        torch.from_numpy(i1 * i2).to(device),
    )
    tensorflow_result = tensorflow_model([
        np.moveaxis(i1, 1, -1),
        np.moveaxis(i1 * i2, 1, -1)]).numpy()

    print(pytorch_result)
    print(tensorflow_result)
    # Stop at the first pair on which the two implementations disagree.
    assert np.allclose(pytorch_result.cpu().numpy().squeeze(), tensorflow_result, )

tensor([[[[0.1988]]]], device='cuda:0')
0.19875127
tensor([[[[0.2047]]]], device='cuda:0')
0.20469661
tensor([[[[0.1980]]]], device='cuda:0')
0.19804001
tensor([[[[0.2132]]]], device='cuda:0')
0.21316913
tensor([[[[0.2019]]]], device='cuda:0')
0.20190895
tensor([[[[0.1955]]]], device='cuda:0')
0.1955429
tensor([[[[0.2011]]]], device='cuda:0')
0.20108235
tensor([[[[0.1892]]]], device='cuda:0')
0.18921155
tensor([[[[0.1958]]]], device='cuda:0')
0.1958212
tensor([[[[0.1880]]]], device='cuda:0')
0.18803968


# Save weights

In [77]:
# Export the whole model to the `./lpips` directory ...
tensorflow_model.save("./lpips")
# ... and additionally store only the weights in a single HDF5 file.
tensorflow_model.save_weights("weights.h5")



# Check that there are no problems with the loaded model

In [None]:
# Reload the model exported to `./lpips` to verify that it can be restored.
new_model = tf.keras.models.load_model("./lpips")



In [None]:
# Compare the PyTorch model against the reloaded TensorFlow model
# (`new_model`) on ten random pairs (i1, i1 * i2).
for i in range(10):
    i1 = np.random.rand(1, 3, 224, 224).astype(np.float32) * 255
    i2 = np.random.rand(1, 3, 224, 224).astype(np.float32) * 255

    pytorch_result = pytorch_model(
        torch.from_numpy(i1).to(device),
        torch.from_numpy(i1 * i2).to(device),
    )
    tensorflow_result = new_model([
        np.moveaxis(i1, 1, -1),
        np.moveaxis(i1 * i2, 1, -1)]).numpy()

    print(pytorch_result)
    print(tensorflow_result)
    # Stop at the first pair on which the two implementations disagree.
    assert np.allclose(pytorch_result.cpu().numpy().squeeze(), tensorflow_result, )

tensor([[[[0.0071]]]], device='cuda:0')
0.0070816744
tensor([[[[0.0079]]]], device='cuda:0')
0.007931068
tensor([[[[0.0087]]]], device='cuda:0')
0.00866513
tensor([[[[0.0083]]]], device='cuda:0')
0.008283554
tensor([[[[0.0084]]]], device='cuda:0')
0.008445184
tensor([[[[0.0069]]]], device='cuda:0')
0.006898164
tensor([[[[0.0064]]]], device='cuda:0')
0.006368886
tensor([[[[0.0060]]]], device='cuda:0')
0.0060019363
tensor([[[[0.0073]]]], device='cuda:0')
0.007324224
tensor([[[[0.0078]]]], device='cuda:0')
0.007751484


In [None]:
!zip -r archive.zip ./lpips

  adding: lpips/ (stored 0%)
  adding: lpips/variables/ (stored 0%)
  adding: lpips/variables/variables.data-00000-of-00001 (deflated 7%)
  adding: lpips/variables/variables.index (deflated 66%)
  adding: lpips/assets/ (stored 0%)
  adding: lpips/fingerprint.pb (stored 0%)
  adding: lpips/saved_model.pb (deflated 91%)
  adding: lpips/keras_metadata.pb (deflated 95%)
