# CRN (Cascaded Refinement Network)

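This is an attempt to re-implement the Cascaded Refinement Network (CRN) described in the paper "Photographic Image Synthesis with Cascaded Refinement Networks" (Chen & Koltun, 2017).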

Paper: https://arxiv.org/pdf/1707.09405v1.pdf

Other Resources: 
* https://github.com/wojciechmo/crn

In [1]:
import tensorflow as tf

In [2]:
class RefinementModule(tf.keras.layers.Layer):
    """One refinement stage: conv -> layer norm -> LeakyReLU, applied once or twice.

    Each group is a 3x3, stride-1, 'same'-padded convolution followed by
    ``LayerNormalization`` (Keras defaults) and ``LeakyReLU`` with slope 0.2.
    A regular module stacks two such groups; a final module uses only one.

    Args:
        filters: Number of output channels of every convolution in the module.
        final_module: When True, the second conv/norm/activation group is
            not created.
        **kwargs: Forwarded to ``tf.keras.layers.Layer``.
    """

    def __init__(self, filters, final_module = False, **kwargs):
        super().__init__(**kwargs)

        # Remembered so get_config() can reproduce the constructor call.
        self.filters = filters
        self.final_module = final_module

        # intermediate layer
        self.conv_1 = tf.keras.layers.Conv2D(filters = filters, kernel_size = (3, 3), strides = (1, 1), padding = 'same')
        self.norm_1 = tf.keras.layers.LayerNormalization()
        self.act_1 = tf.keras.layers.LeakyReLU(alpha = 0.2)

        if not final_module:
            self.conv_2 = tf.keras.layers.Conv2D(filters = filters, kernel_size = (3, 3), strides = (1, 1), padding = 'same')
            self.norm_2 = tf.keras.layers.LayerNormalization()
            self.act_2 = tf.keras.layers.LeakyReLU(alpha = 0.2)

    def call(self, inputs):
        """Run the input through the first group and, if present, the second."""
        out = self.act_1(self.norm_1(self.conv_1(inputs)))
        if not self.final_module:
            out = self.act_2(self.norm_2(self.conv_2(out)))
        return out

    def get_config(self):
        """Return the layer config including the custom constructor arguments.

        Without this override, Keras cannot rebuild the layer from a saved
        or cloned model because ``filters`` and ``final_module`` would be lost.
        """
        config = super().get_config()
        config.update({'filters': self.filters, 'final_module': self.final_module})
        return config

In [7]:
def generator(inp_shape = (1024, 2048, 19), n_modules = 9, k = 1):
    """Build the cascaded refinement generator as a Keras functional model.

    The label map is average-pooled ``n_modules - 1`` times. The coarsest map
    feeds the first refinement module; every later module receives the
    bilinearly upsampled (x2) output of the previous module concatenated with
    the label map of matching resolution. The last module is a
    ``final_module`` and is followed by a 1x1 convolution with ``3 * k``
    channels and a ``tanh`` activation.

    Args:
        inp_shape: Shape of the label input without the batch dimension,
            ``(height, width, channels)``.
        n_modules: Number of refinement modules. With the default of 9 the
            filter counts are 1024 x5, 512 x2, 128, 32. For other values the
            schedule is aligned to the output end: fewer modules drop the
            coarsest stages, more modules add extra 1024-filter stages at
            the coarse end.
        k: Number of 3-channel images packed into the output channels.

    Returns:
        A ``tf.keras.models.Model`` named ``'Generator'`` mapping the label
        input to a tensor with ``3 * k`` channels at the input resolution.

    Raises:
        ValueError: If ``n_modules`` is less than 1, or if a known spatial
            dimension is not divisible by ``2 ** (n_modules - 1)`` (the
            upsampled features would not match the pooled label maps).
    """
    # Validate before touching TensorFlow so bad arguments fail fast and clearly.
    if n_modules < 1:
        raise ValueError(f'n_modules must be at least 1, got {n_modules}')

    scale = 2 ** (n_modules - 1)
    for dim_name, dim in zip(('height', 'width'), inp_shape[:2]):
        if dim is not None and dim % scale != 0:
            raise ValueError(
                f'Input {dim_name} {dim} must be divisible by {scale} '
                f'(2 ** (n_modules - 1)) so that pooled and upsampled maps align'
            )

    # Filter counts from coarsest to finest stage for the reference 9-module net.
    base_filters = (1024, 1024, 1024, 1024, 1024, 512, 512, 128, 32)
    padded = (1024,) * max(0, n_modules - len(base_filters)) + base_filters
    filter_schedule = padded[-n_modules:]

    inp_label = tf.keras.layers.Input(shape = inp_shape, dtype = tf.float32, name = 'inp_label')

    # Pyramid of label maps, finest first; consumed from the coarsest end via pop().
    down_sampled_labels = [inp_label]
    for _ in range(n_modules - 1):
        down_sampled_labels.append(tf.keras.layers.AveragePooling2D()(down_sampled_labels[-1]))

    x = None
    last = n_modules - 1
    for i, filters in enumerate(filter_schedule):
        labels = down_sampled_labels.pop()
        if x is None:
            # The first module sees only the coarsest label map.
            x = labels
        else:
            x = tf.keras.layers.Concatenate()([x, labels])
        x = RefinementModule(filters = filters, final_module = (i == last))(x)
        if i != last:
            x = tf.keras.layers.UpSampling2D(size = (2, 2), interpolation = 'bilinear')(x)

    x = tf.keras.layers.Conv2D(filters = 3 * k, kernel_size = (1, 1), strides = (1, 1), padding = 'same')(x)
    x = tf.keras.layers.Activation('tanh')(x)

    return tf.keras.models.Model(inp_label, x, name = 'Generator')

In [46]:
class DiversePerceptualLoss(object):
    """Per-class perceptual loss with a best-of-k minimum over generated images.

    Features of the real image and of every generated candidate are taken
    from selected layers of a frozen network. For each label channel the
    absolute feature difference is masked by that channel (resized to the
    feature resolution), averaged over batch and space, and reduced over
    feature channels and layers. The minimum across candidates is then
    taken per label channel and reduced to a scalar.

    Args:
        pt_model: Feature network to use. When None, an ImageNet-pretrained
            VGG19 without the classification head is created. The model is
            marked non-trainable either way. VGG19 input preprocessing is
            applied regardless of which model is supplied.
        pt_layers: Names of the layers whose outputs are compared. When None
            or empty, ``block{1..5}_conv2`` are used.
        avg: When True, reductions over feature channels and over label
            channels use the mean; otherwise they use the sum.
    """

    def __init__(self, pt_model = None, pt_layers = None, avg = True):
        self.avg = avg
        model = tf.keras.applications.VGG19(include_top = False, weights = 'imagenet') if pt_model is None else pt_model
        model.trainable = False

        if pt_layers is None or len(pt_layers) == 0:
            layers = ['block1_conv2', 'block2_conv2', 'block3_conv2', 'block4_conv2', 'block5_conv2']
        else:
            layers = list(pt_layers)

        outs = [model.get_layer(layer).output for layer in layers]
        self.model = tf.keras.models.Model(model.inputs, outs)

        self.preprocess_input = tf.keras.applications.vgg19.preprocess_input

    def __call__(self, real, gens, labels):
        """Compute the loss for one batch.

        Args:
            real: Target images; rescaled via ``(x + 1) * 127.5`` before
                preprocessing, so presumably in [-1, 1] -- confirm with caller.
            gens: Iterable of candidate image batches, rescaled the same way.
            labels: Label maps whose last axis indexes classes; assumed to be
                a 4-D batch (batch, height, width, classes) -- TODO confirm.

        Returns:
            A scalar tensor: the per-class minimum over candidates, averaged
            (``avg=True``) or summed over classes.
        """
        reduce_fn = tf.math.reduce_mean if self.avg else tf.math.reduce_sum

        real_outs = self.model(self.preprocess_input((real + 1)*127.5))

        # One label mask per feature layer, resized to that layer's spatial
        # size, with a singleton axis inserted to broadcast over feature channels.
        all_labels = [
            tf.expand_dims(
                tf.image.resize(labels, [feat.shape[1], feat.shape[2]],
                                method = tf.image.ResizeMethod.NEAREST_NEIGHBOR),
                axis = -2)
            for feat in real_outs
        ]

        losses = []
        for gen in gens:
            gen_outs = self.model(self.preprocess_input((gen + 1)*127.5))
            loss = 0
            for mask, r, g in zip(all_labels, real_outs, gen_outs):
                # Masked absolute difference averaged over batch and space;
                # the remaining axes are (feature channel, class).
                l = tf.math.reduce_mean(tf.math.abs(mask * tf.expand_dims(r - g, axis = -1)), axis = [0, 1, 2])
                # Collapse feature channels, keeping a leading axis for the concat below.
                loss += reduce_fn(l, axis = 0, keepdims = True)
            losses.append(loss)

        # Stack candidates on axis 0 and keep the best candidate per class.
        loss = tf.math.reduce_min(tf.concat(losses, axis = 0), axis = 0)
        return reduce_fn(loss)

In [48]:
class Trainer(object):
    """Couples the generator, the diverse perceptual loss and an Adam optimizer.

    Args:
        label_shape: Shape of the label input passed to ``generator`` as
            ``inp_shape``.
        k: Number of 3-channel images the generator emits per input.
        n_modules: Number of refinement modules passed to ``generator``.
        learning_rate: Learning rate of the Adam optimizer.
        **kwargs: Accepted for call-site compatibility; currently unused.
    """

    def __init__(self, label_shape, k = 2, n_modules = 9, learning_rate = 1e-4, **kwargs):
        self.k = k
        self.optimizer = tf.keras.optimizers.Adam(learning_rate = learning_rate)

        self.generator = generator(inp_shape = label_shape, n_modules = n_modules, k = k)
        self.loss = DiversePerceptualLoss()

    def train_step(self, label, real):
        """Run one optimization step and return the loss for the batch.

        Args:
            label: Batch of label maps fed to the generator and to the loss.
            real: Batch of target images.

        Returns:
            The loss value computed for this batch.
        """
        with tf.GradientTape() as tape:
            gen_imgs = self.generator(label, training = True)
            # The generator packs k RGB images along the channel axis; split them.
            k_imgs = [gen_imgs[:, :, :, i*3:(i+1)*3] for i in range(self.k)]
            loss = self.loss(real, k_imgs, label)

        grads = tape.gradient(loss, self.generator.trainable_variables)
        self.optimizer.apply_gradients(zip(grads, self.generator.trainable_variables))

        return loss

    def train(self, data, epochs = 1):
        """Train for ``epochs`` passes over ``data``.

        Args:
            data: Iterable yielding ``(label, real)`` batches. It is iterated
                once per epoch.
            epochs: Number of passes.

        Returns:
            A list with the loss of the last batch of every epoch that
            processed at least one batch.
        """
        losses = []
        for e in range(epochs):
            print(f'Epoch: {e} Starts')
            # Reset per epoch so an empty or exhausted iterable is never
            # reported with the previous epoch's loss.
            loss = None
            for label, real in data:
                loss = self.train_step(label, real)
                print('.', end='')

            if loss is None:
                print('\nNo batches were produced in this epoch.')
            else:
                losses.append(loss)
                print(f'\nLoss: {loss}')
            print(f'Epoch: {e} Ends.\n')
        return losses