# In [None]:
import numpy as np
import tensorflow as tf

def gkern(l=5, sig=1.):
    """Build a normalized square Gaussian kernel.

    Args:
        l: Side length of the kernel (number of samples along each axis).
        sig: Standard deviation of the Gaussian, in sample units.

    Returns:
        An ``(l, l)`` float32 NumPy array, normalized so its entries sum to 1
        (up to the float32 cast).
    """
    half = (l - 1) / 2.
    coords = np.linspace(-half, half, l)
    profile = np.exp(-0.5 * np.square(coords) / np.square(sig))
    # The 2-D Gaussian is separable: outer product of two 1-D profiles.
    kernel = profile[:, np.newaxis] * profile[np.newaxis, :]
    normalized = kernel / kernel.sum()
    return normalized.astype(np.float32)

def gkern_tf(l=5, sig=1.):
    """Return the ``gkern(l, sig)`` Gaussian kernel as a TensorFlow tensor.

    Args:
        l: Side length of the kernel.
        sig: Standard deviation of the Gaussian.

    Returns:
        A float32 tensor of shape ``(l, l)``.
    """
    kernel_np = gkern(l, sig)
    return tf.convert_to_tensor(kernel_np)

class LaplacianPyramid(tf.keras.layers.Layer):
    """Keras layer producing a squared, gain-scaled Laplacian-style pyramid.

    The input is repeatedly blurred with a depthwise Gaussian filter and
    downsampled by 2, ``numLayers`` times. Each band-pass level is
    ``gain * (finer - upsample(coarser)) ** 2``. The coarsest downsampled
    image is appended unchanged as the final (low-pass) entry. Entries are
    ordered finest first.

    Args:
        numLayers: Number of blur + stride-2 downsample steps.
        guass_kernel: Side length of the square Gaussian kernel.
        guass_sig: Standard deviation of the Gaussian kernel.
        rgb: If True the filter has 3 channels, otherwise 1. The input's
            channel count must match.
        returnFull: If True, the full-resolution input is used as the finest
            level, so the first band-pass output has the input's spatial size.
        gain: Scale factor applied to every squared band-pass level.
        **kwargs: Forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
    """

    def __init__(self, numLayers=5, guass_kernel=5, guass_sig=1.0, rgb=False,
                 returnFull=False, gain=50, **kwargs):
        super(LaplacianPyramid, self).__init__(**kwargs)
        self.outputDims = 3 if rgb else 1
        kernel = tf.reshape(gkern_tf(guass_kernel, guass_sig),
                            (guass_kernel, guass_kernel, 1, 1))
        # depthwise_conv2d expects (kh, kw, in_channels, channel_multiplier);
        # tiling applies the same Gaussian independently to every channel.
        self.gaussKernel = tf.tile(kernel, [1, 1, self.outputDims, 1])
        # Total padding of k-1 (asymmetric for even k) makes the VALID stride-2
        # convolution produce ceil(H/2) x ceil(W/2) for any kernel size k.
        self._pad_before = (guass_kernel - 1) // 2
        self._pad_after = guass_kernel // 2
        self.numLayers = numLayers
        self.returnFull = returnFull
        self.upSample = tf.keras.layers.UpSampling2D(2, interpolation='gaussian')
        self.gain = gain

    def call(self, x):
        """Return the pyramid levels (finest first) for image batch ``x``.

        Assumes ``x`` is NHWC with ``self.outputDims`` channels — this is what
        the padding spec and the depthwise kernel shape imply.

        Raises:
            ValueError: if the configuration yields no levels at all
                (``numLayers <= 0`` and ``returnFull`` is False).
        """
        levels = [x] if self.returnFull else []
        paddings = [[0, 0],
                    [self._pad_before, self._pad_after],
                    [self._pad_before, self._pad_after],
                    [0, 0]]
        for _ in range(self.numLayers):
            x = tf.pad(x, paddings, mode='SYMMETRIC')
            x = tf.nn.depthwise_conv2d(x, self.gaussKernel, [1, 2, 2, 1],
                                       padding='VALID')
            levels.append(x)

        if not levels:
            raise ValueError(
                "LaplacianPyramid needs numLayers >= 1 or returnFull=True")

        # Build from coarse to fine, then reverse so the finest level comes first.
        pyramid = [levels[-1]]
        for i in range(len(levels) - 1, 0, -1):
            fine = levels[i - 1]
            up = self.upSample(levels[i])
            # Odd-sized levels halve to ceil(n/2), so 2x upsampling overshoots
            # by one pixel. Crop to the finer level's size (no-op when even).
            fine_shape = tf.shape(fine)
            up = up[:, :fine_shape[1], :fine_shape[2], :]
            pyramid.append(self.gain * tf.square(fine - up))
        pyramid.reverse()
        return pyramid

class LaplacianPyramidLoss(tf.keras.losses.Loss):
    """Weighted MSE between the Laplacian pyramids of target and prediction.

    Both tensors pass through the same ``LaplacianPyramid``, which uses its
    default settings except ``rgb``. Band-pass level ``i`` (0 = finest) is
    weighted by ``2**i``. The last entry, the low-pass residual, is weighted
    by 1.

    Args:
        rgb: Forwarded to ``LaplacianPyramid`` (3 channels if True, else 1).
        **kwargs: Forwarded to ``tf.keras.losses.Loss`` (e.g. ``name``).
    """

    def __init__(self, rgb=False, **kwargs):
        super().__init__(**kwargs)
        self.laplacianPyramid = LaplacianPyramid(rgb=rgb)
        self.MSE = tf.keras.losses.MeanSquaredError()

    def call(self, true, pred):
        """Return the weighted sum of per-level MSEs between ``true`` and ``pred``."""
        true_levels = self.laplacianPyramid(true)
        pred_levels = self.laplacianPyramid(pred)
        last = len(true_levels) - 1
        total = 0
        for i, (t, p) in enumerate(zip(true_levels, pred_levels)):
            # The final entry is the low-pass residual and gets unit weight;
            # band-pass levels are weighted 2**i.
            weight = 1 if i == last else 2 ** i
            total += weight * self.MSE(t, p)
        return total