# Game GAN

This is an attempt to re-implement GameGAN in TensorFlow.

**Note: Some parts of this re-implementation may differ from the original implementation.**

For more information about GameGAN, refer to:

Paper: https://arxiv.org/pdf/2005.12126.pdf
    
Original implementation: https://github.com/nv-tlabs/GameGAN_code

This code borrows heavily from https://github.com/nv-tlabs/GameGAN_code

In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.python.eager import def_function

In [2]:
class InstanceNormalization(tf.keras.layers.Layer):
    """Normalize each sample over axes 1 and 2, then apply a learned affine map.

    The mean and variance are reduced over axes ``[1, 2]`` of the input
    (kept as size-1 dimensions), so every sample/channel pair is normalized
    independently. The result is multiplied by the trainable ``scale`` weight
    and offset by the trainable ``shift`` weight, both of shape
    ``(1, 1, 1, channels)``.

    Args:
        epsilon: Small constant added to the variance before the reciprocal
            square root is taken.
    """

    def __init__(self, epsilon = 1e-8, **kwargs):
        super().__init__(**kwargs)
        self.epsilon = epsilon

    def build(self, input_shape):
        """Create the per-channel scale (``ys``) and shift (``yb``) weights."""
        param_shape = (1, 1, 1, input_shape[-1])
        scale_init = tf.keras.initializers.RandomNormal(mean = 0.0, stddev = 1.0)

        self.ys = self.add_weight(shape = param_shape, initializer = scale_init,
                                  trainable = True, name = 'scale')
        self.yb = self.add_weight(shape = param_shape, initializer = 'zeros',
                                  trainable = True, name = 'shift')

    def call(self, inputs):
        """Return ``scale * (inputs - mean) / sqrt(var + epsilon) + shift``."""
        reduce_axes = [1, 2]
        mean = tf.math.reduce_mean(inputs, axis = reduce_axes, keepdims = True)
        variance = tf.math.reduce_variance(inputs, axis = reduce_axes, keepdims = True)

        inv_std = tf.math.rsqrt(variance + self.epsilon)
        normalized = (inputs - mean) * inv_std
        return self.ys * normalized + self.yb

In [3]:
# referenced: spectral norm
# https://gist.github.com/FloydHsiu/828eea345e1ca6950e05bb42f0a75b50
# https://gist.github.com/FloydHsiu/ab33c7d98d78f9873757810e2f8db50d
class SpectralNormalization(tf.keras.layers.Wrapper):
    """Wrapper that rescales the wrapped layer's weight by its spectral norm.

    On every training call one step of power iteration estimates the largest
    singular value ``sigma`` of the wrapped layer's weight (flattened to a
    matrix whose columns are the output units), and the weight variable is
    divided by ``sigma`` in place before the wrapped layer is applied.

    The weight variable is looked up on the wrapped layer under the attribute
    names listed in ``_WEIGHT_ATTRS``; this covers the stock Keras layers
    (``kernel`` / ``embeddings``) as well as the custom ``Linear`` (``W``) and
    ``Conv2D`` (``kernel``) layers defined in this file.

    Args:
        layer: The layer whose weight should be spectrally normalized.
    """

    # Attribute names under which a wrapped layer may expose its main weight.
    _WEIGHT_ATTRS = ('weight', 'kernel', 'W', 'embeddings')

    def __init__(self, layer, **kwargs):
        super().__init__(layer, **kwargs)

    def build(self, input_shape):
        """Build the wrapped layer and create the power-iteration vector ``u``.

        Raises:
            ValueError: If the wrapped layer exposes no assignable weight
                under any of the names in ``_WEIGHT_ATTRS``.
        """
        if not self.layer.built:
            self.layer.build(input_shape)
            # Mark the layer as built; otherwise its first __call__ would run
            # build() again and replace the variable we are about to capture.
            self.layer.built = True

        # Keep a reference to the *variable* itself (not a tensor copy) so the
        # normalization always sees, and writes back to, the live weight.
        for attr in self._WEIGHT_ATTRS:
            candidate = getattr(self.layer, attr, None)
            if candidate is not None and hasattr(candidate, 'assign'):
                self.w = candidate
                break
        else:
            raise ValueError(
                f'`{type(self.layer).__name__}` has no weight variable under '
                f'any of the attributes {self._WEIGHT_ATTRS}.')

        self.w_shape = self.w.shape.as_list()

        init = tf.keras.initializers.RandomNormal(mean = 0.0, stddev = 0.02)
        self.u = self.add_weight(shape = (1, self.w_shape[-1]), initializer = init,
                                 trainable = False, name = 'spectral_norm_u')

    @def_function.function
    def call(self, inputs, training = None):
        """Normalize the weight (training only) and apply the wrapped layer."""
        if training is None:
            training = tf.keras.backend.learning_phase()

        # `== True` is deliberate: `training` may be a bool, an int or a
        # tensor, and the comparison yields a usable condition for all three.
        if training == True:
            self._compute_weights()

        output = self.layer(inputs)
        return output

    def _compute_weights(self):
        """Run one power-iteration step and divide the weight by ``sigma``."""
        w_reshaped = tf.reshape(self.w, (-1, self.w_shape[-1]))
        eps = 1e-12

        # One round of power iteration: v <- normalize(u W^T), u <- normalize(v W).
        _u = tf.identity(self.u)
        _v = tf.matmul(_u, tf.transpose(w_reshaped, perm = [1, 0]))
        _v /= tf.maximum(tf.math.reduce_sum(_v ** 2) ** 0.5, eps)
        _u = tf.matmul(_v, w_reshaped)
        _u /= tf.maximum(tf.math.reduce_sum(_u ** 2) ** 0.5, eps)

        # The singular vectors are treated as constants by the optimizer.
        _u = tf.stop_gradient(_u)
        _v = tf.stop_gradient(_v)

        self.u.assign(_u)
        # sigma = v W u^T, a (1, 1) tensor that broadcasts over the weight.
        sigma = tf.matmul(tf.matmul(_v, w_reshaped), tf.transpose(_u, perm = [1, 0]))

        self.w.assign(self.w / sigma)

In [4]:
class Linear(tf.keras.layers.Layer):
    """Fully connected layer with a run-time scaling factor.

    The weight matrix is drawn from a unit normal distribution and the input
    is multiplied by ``gain / sqrt(fan_in)`` before the matrix product, where
    ``fan_in`` is the size of the last input dimension.

    Args:
        neurons: Number of output units.
        use_bias: Whether a zero-initialized bias is added to the output.
        gain: Multiplier applied on top of the ``1 / sqrt(fan_in)`` scale.
    """

    def __init__(self, neurons, use_bias = True, gain = 1.0, **kwargs):
        super().__init__(**kwargs)
        self.neurons = neurons
        self.use_bias = use_bias
        self.gain = gain

    def build(self, input_shape):
        """Create the weight (and optional bias) and the input scale factor."""
        fan_in = input_shape[-1]
        weight_init = tf.keras.initializers.RandomNormal(mean = 0.0, stddev = 1.0)

        self.W = self.add_weight(shape = (fan_in, self.neurons), initializer = weight_init,
                                 trainable = True, name = 'weight')
        if self.use_bias:
            self.B = self.add_weight(shape = (1, self.neurons), initializer = 'zeros',
                                     trainable = True, name = 'bias')

        self.w_scale = self.gain * tf.math.rsqrt(tf.cast(fan_in, tf.float32))

    def call(self, inputs):
        """Return ``(inputs * w_scale) @ W``, plus the bias when enabled."""
        projected = tf.matmul(inputs * self.w_scale, self.W)
        if not self.use_bias:
            return projected
        return tf.add(projected, self.B)

In [5]:
class ReflectionPadding2D(tf.keras.layers.Layer):
    """Reflection-pad axes 1 and 2 of a 4-D input; axes 0 and 3 are untouched.

    Args:
        padding: One of
            * an ``int`` - the same amount on all four sides;
            * a pair of ints ``(rows, cols)`` - symmetric padding per axis;
            * a pair of pairs ``((top, bottom), (left, right))``.

    Raises:
        ValueError: If ``padding`` matches none of the supported forms.
    """

    def __init__(self, padding, **kwargs):
        super().__init__(**kwargs)
        self.padding = self._normalize_padding(padding)

    @staticmethod
    def _normalize_padding(padding):
        """Convert any supported spec to ``((top, bottom), (left, right))``."""
        if isinstance(padding, int):
            return (padding, padding), (padding, padding)

        if isinstance(padding, (tuple, list)) and len(padding) >= 2:
            rows, cols = padding[0], padding[1]
            if isinstance(rows, (tuple, list)) and isinstance(cols, (tuple, list)):
                return (rows[0], rows[1]), (cols[0], cols[1])
            if isinstance(rows, int) and isinstance(cols, int):
                return (rows, rows), (cols, cols)

        # Fail here rather than with an AttributeError on the first call().
        raise ValueError(
            'padding must be an int, a pair of ints, or a pair of int pairs; '
            f'got {padding!r}')

    def call(self, inputs):
        """Return ``inputs`` padded in 'REFLECT' mode on axes 1 and 2."""
        return tf.pad(inputs, ((0, 0), self.padding[0], self.padding[1], (0, 0)), 'REFLECT')

In [6]:
class Conv2D(tf.keras.layers.Layer):
    """2-D convolution with a run-time scaling factor and optional reflection padding.

    The kernel is drawn from a unit normal distribution and the input is
    multiplied by ``gain / sqrt(kh * kw * in_channels)`` before convolving.

    Args:
        filters: Number of output channels.
        kernel_size: ``(kh, kw)`` tuple/list, or a single int for a square kernel.
        strides: Strides forwarded unchanged to ``tf.nn.conv2d``.
        padding: ``'same'`` / ``'valid'`` (case-insensitive), or an int / tuple /
            list accepted by ``ReflectionPadding2D``; in the latter case the
            input is reflection-padded and then convolved with ``'VALID'``.
        use_bias: Whether a zero-initialized bias is added to the output.
        gain: Multiplier applied on top of the ``1 / sqrt(fan_in)`` scale.

    Raises:
        ValueError: If ``padding`` is an unknown string or an unsupported type.
    """

    def __init__(self, filters, kernel_size, strides, padding, use_bias = True, gain = 1.0, **kwargs):
        super().__init__(**kwargs)
        self.filters = filters
        # Stored as a tuple because build() concatenates it with the channel dims.
        if isinstance(kernel_size, int):
            kernel_size = (kernel_size, kernel_size)
        self.kernel_size = tuple(kernel_size)
        self.strides = strides
        self.use_bias = use_bias
        self.gain = gain

        # Check the type first: calling .upper() on an int/tuple would raise
        # and make the reflection-padding branch unreachable.
        if isinstance(padding, str):
            if padding.upper() not in ('SAME', 'VALID'):
                raise ValueError(f"padding must be 'same' or 'valid', got {padding!r}")
            self.padding = padding.upper()
        elif isinstance(padding, (int, tuple, list)):
            self.padding = ReflectionPadding2D(padding)
        else:
            raise ValueError(f'Unsupported padding specification: {padding!r}')

    def build(self, input_shape):
        """Create the kernel (and optional bias) and the input scale factor."""
        inp_filters = input_shape[-1]

        init = tf.keras.initializers.RandomNormal(mean = 0.0, stddev = 1.0)
        self.kernel = self.add_weight(shape = self.kernel_size + (inp_filters, self.filters), initializer = init,
                                      trainable = True, name = 'weight')

        if self.use_bias:
            self.bias = self.add_weight(shape = (1, 1, 1, self.filters), initializer = 'zeros',
                                        trainable = True, name = 'bias')

        fan_in = tf.cast(self.kernel_size[0] * self.kernel_size[1] * inp_filters, tf.float32)
        self.w_scale = self.gain * tf.math.rsqrt(fan_in)

    def call(self, inputs):
        """Convolve the scaled (and, if configured, reflection-padded) input."""
        if isinstance(self.padding, str):
            out = tf.nn.conv2d(inputs * self.w_scale, self.kernel, self.strides, self.padding)
        else:
            # Explicit reflection padding, so the convolution itself adds none.
            out = tf.nn.conv2d(self.padding(inputs) * self.w_scale, self.kernel, self.strides, 'VALID')

        if self.use_bias:
            return tf.add(out, self.bias)
        return out

In [7]:
class SPADE(tf.keras.layers.Layer):
    """Normalization whose scale and shift are predicted from a conditioning map.

    ``call`` takes ``[x, mask]``. The mask is resized (nearest neighbour) to
    the height/width of ``x``, passed through a shared 3x3 convolution + ReLU,
    and two further 3x3 convolutions produce ``gamma`` and ``beta`` with as
    many channels as ``x``. The output is ``InstanceNorm(x) * gamma + beta``.

    Args:
        epsilon: Stability constant forwarded to the instance normalization.
    """

    def __init__(self, epsilon = 1e-8, **kwargs):
        super().__init__(**kwargs)
        self.norm = InstanceNormalization(epsilon = epsilon)

        self.conv = Conv2D(filters = 128, kernel_size = (3, 3), strides = (1, 1), padding = 'same')
        self.act = tf.keras.layers.ReLU()

    def build(self, input_shape):
        """Create the gamma/beta convolutions sized to the channels of ``x``.

        ``input_shape`` is the pair of shapes ``[shape_of_x, shape_of_mask]``.
        """
        inp_filters = input_shape[0][-1]

        self.conv_gamma = Conv2D(filters = inp_filters, kernel_size = (3, 3), strides = (1, 1), padding = 'same')
        self.conv_beta = Conv2D(filters = inp_filters, kernel_size = (3, 3), strides = (1, 1), padding = 'same')

    def call(self, inputs):
        """Return the normalized ``x`` modulated by the mask-derived gamma and beta."""
        assert len(inputs) == 2
        x, mask = inputs
        # Bring the conditioning map to the spatial size of the features.
        mask = tf.image.resize(mask, tf.shape(x)[1:3], tf.image.ResizeMethod.NEAREST_NEIGHBOR)

        out = self.act(self.conv(mask))
        gamma = self.conv_gamma(out)
        beta = self.conv_beta(out)

        return self.norm(x) * gamma + beta

In [8]:
class ImageEncoder(tf.keras.layers.Layer):
    """Convolutional encoder that maps an image to a 512-dimensional vector.

    A stack of 64-filter convolutions, each followed by ``LeakyReLU(0.2)``,
    is flattened and projected by ``Linear(512)``. The convolution layout
    depends on the target game.

    Args:
        encoder_for: ``'pacman'`` or ``'vizdoom'``.

    Raises:
        Exception: If ``encoder_for`` names an unsupported target.
    """

    def __init__(self, encoder_for = 'pacman', **kwargs):
        super().__init__(**kwargs)
        encoder = tf.keras.models.Sequential()

        # Each entry is (kernel_size, strides, padding) of one 64-filter convolution.
        if encoder_for == 'pacman':
            strided = ((3, 3), (2, 2), 'valid')
            unstrided = ((3, 3), (1, 1), 'valid')
            conv_specs = [strided, unstrided, strided, unstrided, strided]
        elif encoder_for == 'vizdoom':
            conv_specs = [((4, 4), (1, 1), 'same')] + [((3, 3), (2, 2), 'valid')] * 3
        else:
            raise Exception(f'Image Encoder for `{encoder_for}` is still not defined.')

        for kernel, stride, pad in conv_specs:
            encoder.add(Conv2D(filters = 64, kernel_size = kernel, strides = stride, padding = pad))
            encoder.add(tf.keras.layers.LeakyReLU(alpha = 0.2))

        encoder.add(tf.keras.layers.Flatten())
        encoder.add(Linear(512))

        self.model = encoder

    def call(self, inputs):
        """Encode ``inputs`` with the convolutional stack and final projection."""
        return self.model(inputs)

In [9]:
class ActionLSTM(tf.keras.layers.Layer):
    """LSTM-style cell whose gates are driven by two inputs, ``v`` and ``s``.

    Every gate (forget, input, candidate, output) is the activation of the sum
    of two bias-free ``Linear`` projections, one applied to ``v`` and one to
    ``s``.

    Args:
        hidden_dim: Width of every gate and of the hidden/cell states.
    """

    def __init__(self, hidden_dim = 512, **kwargs):
        super().__init__(**kwargs)
        self.hidden_dim = hidden_dim

        self.W_fv = Linear(neurons = hidden_dim, use_bias = False)
        self.W_fs = Linear(neurons = hidden_dim, use_bias = False)

        self.W_iv = Linear(neurons = hidden_dim, use_bias = False)
        self.W_is = Linear(neurons = hidden_dim, use_bias = False)

        self.W_cv = Linear(neurons = hidden_dim, use_bias = False)
        self.W_cs = Linear(neurons = hidden_dim, use_bias = False)

        self.W_ov = Linear(neurons = hidden_dim, use_bias = False)
        self.W_os = Linear(neurons = hidden_dim, use_bias = False)

    def initialize_states(self, batch_size):
        """Return two zero tensors of shape ``(batch_size, hidden_dim)``."""
        return [tf.zeros((batch_size, self.hidden_dim)), tf.zeros((batch_size, self.hidden_dim))]

    def call(self, inputs):
        """Advance the cell by one step.

        Args:
            inputs: ``[v, s, prev_cell_state]``.

        Returns:
            ``(h_state, new_cell_state)``.
        """
        v, s, prev_cell_state = inputs

        # The W_* members are layers, so they are called on their input
        # rather than handed to tf.matmul as if they were matrices.
        f = tf.nn.sigmoid(tf.add(self.W_fv(v), self.W_fs(s)))
        i = tf.nn.sigmoid(tf.add(self.W_iv(v), self.W_is(s)))
        c = tf.nn.tanh(tf.add(self.W_cv(v), self.W_cs(s)))
        o = tf.nn.sigmoid(tf.add(self.W_ov(v), self.W_os(s)))

        new_cell_state = tf.multiply(f, prev_cell_state) + tf.multiply(i, c)
        h_state = tf.multiply(o, tf.nn.tanh(new_cell_state))

        return h_state, new_cell_state

In [10]:
class DynamicsEngine(tf.keras.layers.Layer):
    """Recurrent core that updates the hidden/cell state from one time step.

    The action, the stochastic variable and (optionally) a memory read-out are
    each embedded, concatenated and fused by an MLP; the result gates the
    previous hidden state element-wise. That product and the encoded image are
    fed to an ``ActionLSTM`` together with the previous cell state.

    Args:
        dim: Width of every embedding and of the recurrent state.
        use_memory: Whether ``call`` expects a sixth input (the memory
            read-out) that is concatenated with the other embeddings.
    """

    def __init__(self, dim = 512, use_memory = True, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.use_memory = use_memory

        def embedding_block():
            # Linear projection to `dim` followed by LeakyReLU(0.2).
            return tf.keras.models.Sequential([Linear(neurons = dim),
                                               tf.keras.layers.LeakyReLU(alpha = 0.2)])

        self.mlp_action = embedding_block()
        self.mlp_z = embedding_block()

        if use_memory:
            self.mlp_mem = embedding_block()

        self.mlp = tf.keras.models.Sequential([Linear(neurons = dim),
                                               tf.keras.layers.LeakyReLU(alpha = 0.2),
                                               Linear(neurons = dim)])

        self.img_encoder = ImageEncoder()

        self.rnn = ActionLSTM(hidden_dim = dim)

    def initialize_states(self, batch_size):
        """Return the initial states produced by the underlying ``ActionLSTM``."""
        return self.rnn.initialize_states(batch_size)

    def build(self, input_shape):
        """Add a projection for the hidden state when its width is not ``dim``."""
        hidden_width = input_shape[0][-1]

        if hidden_width != self.dim:
            self.mlp_h = tf.keras.models.Sequential([Linear(neurons = self.dim),
                                                     tf.keras.layers.LeakyReLU(alpha = 0.2)])

    def call(self, inputs):
        """Compute the next hidden and cell state.

        Args:
            inputs: ``[hidden_state, cell_state, img, stoch_var, action]``,
                followed by ``mem`` when ``use_memory`` is set.

        Returns:
            ``(new_hidden_state, new_cell_state)``.
        """
        assert len(inputs) == (6 if self.use_memory else 5)
        hidden_state, cell_state, img, stoch_var, action = inputs[:5]

        features = [self.mlp_action(action), self.mlp_z(stoch_var)]

        # mlp_h only exists when build() found a hidden width other than `dim`.
        project_hidden = getattr(self, 'mlp_h', None)
        h = hidden_state if project_hidden is None else project_hidden(hidden_state)

        if self.use_memory:
            features.append(inputs[5])
        x = tf.concat(features, axis = -1)

        v = tf.multiply(h, self.mlp(x))
        s = self.img_encoder(img)

        new_hidden_state, new_cell_state = self.rnn([v, s, cell_state])
        return new_hidden_state, new_cell_state

In [11]:
class Memory(tf.keras.layers.Layer):
    """External memory addressed by a 2-D attention map that is shifted by actions.

    The memory ``M`` has ``mem_h`` slots of width ``dim``; ``mem_h`` must be a
    perfect square because the attention ``alpha`` over the slots is treated
    as a square grid. Each step, a 3x3 shift kernel is predicted from the
    action and applied to the previous attention, a scalar gate blends the
    shifted and previous attention, and the memory is optionally written
    (erase/add) before being read at the new attention.

    Args:
        dim: Width of one memory slot and of the hidden MLP layers.
        use_mem_h: If true the gate is computed from ``hidden`` directly,
            otherwise from the difference of the L2-normalized current and
            previous hidden states.
        memory_for: ``'pacman'`` or ``'vizdoom'``; selects which actions reuse
            the spatially flipped kernel of another action.
        mem_h: Number of memory slots.
    """

    # Per game: (action index, action index whose flipped kernel it reuses).
    # NOTE(review): the vizdoom pair is reconstructed from comparison
    # statements that were evidently meant to be assignments - confirm
    # against the reference implementation.
    _FLIP_PAIRS = {
        'pacman': ((2, 1), (4, 3)),
        'vizdoom': ((0, 1),),
    }

    def __init__(self, dim = 512, use_mem_h = False, memory_for = 'pacman',
                 mem_h = 441, **kwargs):
        super().__init__(**kwargs)
        self.dim = dim
        self.use_mem_h = use_mem_h
        self.memory_for = memory_for
        self.mem_h = mem_h

        # Predicts the 9 logits of the 3x3 shift kernel from the action.
        self.K_block = tf.keras.models.Sequential([Linear(neurons = dim),
                                                   tf.keras.layers.LeakyReLU(alpha = 0.2),
                                                   Linear(neurons = 9)])

        # Predicts the scalar gate in (0, 1).
        self.G_block = tf.keras.models.Sequential([Linear(neurons = dim),
                                                   tf.keras.layers.LeakyReLU(alpha = 0.2),
                                                   Linear(neurons = 1),
                                                   tf.keras.layers.Activation('sigmoid')])

        # Predicts the erase and add vectors (concatenated, `dim` each).
        self.E_block = tf.keras.models.Sequential([Linear(neurons = dim*2)])

        bound = (1.0 / (mem_h + dim)) ** 0.5
        init = tf.keras.initializers.RandomUniform(minval = -bound, maxval = bound)
        # Keep the handle: initialize_memory() tiles this learned initial memory.
        self.init_memory = self.add_weight(shape = (mem_h, dim), initializer = init,
                                           trainable = True, name = 'init_memory')

    def initialize_memory(self, batch_size):
        """Return the learned initial memory tiled to ``(batch_size, mem_h, dim)``."""
        return tf.tile(tf.expand_dims(self.init_memory, axis = 0), [batch_size, 1, 1])

    def merge_K_G(self, prev_alpha, kernels):
        """Shift every sample's attention grid with that sample's own kernel.

        Args:
            prev_alpha: ``(bs, d)`` attention with a statically known, square ``d``.
            kernels: ``(3, 3, 1, bs)`` tensor holding one 3x3 kernel per sample.

        Returns:
            The shifted attention, shape ``(bs, d)``.
        """
        d = int(prev_alpha.shape[-1])
        mem_hw = int(round(d ** 0.5))

        # Move the batch to the channel axis so one depthwise convolution
        # applies a different kernel to every sample.
        p = tf.reshape(prev_alpha, (-1, mem_hw, mem_hw, 1))
        p = tf.transpose(p, perm = [3, 1, 2, 0])                            # (1, hw, hw, bs)
        depthwise_kernels = tf.transpose(kernels, perm = [0, 1, 3, 2])      # (3, 3, bs, 1)

        out = tf.nn.depthwise_conv2d(p, depthwise_kernels, strides = [1, 1, 1, 1], padding = 'SAME')
        alpha = tf.reshape(tf.transpose(out, perm = [3, 0, 1, 2]), (-1, d))

        return alpha

    def write(self, erase, add, alpha, M):
        """Erase then add content at the attended slots; returns the new memory.

        ``alpha`` is ``(bs, mem_h)``; ``erase`` and ``add`` are ``(bs, dim)``.
        """
        alpha_write = tf.expand_dims(alpha, axis = -1)
        erase = tf.matmul(alpha_write, tf.expand_dims(erase, axis = 1))
        add = tf.matmul(alpha_write, tf.expand_dims(add, axis = 1))
        M = M * (1 - erase) + add
        return M

    def read(self, alpha, M):
        """Return the attention-weighted sum of memory slots, shape ``(bs, dim)``."""
        out = tf.matmul(tf.expand_dims(alpha, axis = 1), M)
        # Squeeze only the singleton slot axis so a batch of one keeps its batch axis.
        out = tf.squeeze(out, axis = 1)
        return out

    def _action_kernels(self, action):
        """Return the kernel logits for ``action`` laid out as ``(3, 3, 1, bs)``."""
        return tf.reshape(tf.transpose(self.K_block(action), perm = [1, 0]), (3, 3, 1, -1))

    def _flip_actions(self, action):
        """Swap in the counterpart action for samples that use a flipped kernel.

        Returns:
            ``(new_a, mask)``: ``new_a`` is ``action`` (as float32) with, for
            every affected sample, the counterpart entry set to 1 and the
            original entry set to 0; ``mask`` is a ``(1, 1, 1, bs)`` float
            tensor that is 1 for affected samples and 0 otherwise.
        """
        labels = tf.argmax(action, axis = 1)
        columns = tf.range(tf.shape(action)[1], dtype = labels.dtype)

        new_a = tf.cast(action, tf.float32)
        flip = tf.zeros_like(labels, dtype = tf.bool)

        for src, dst in self._FLIP_PAIRS.get(self.memory_for, ()):
            hit = tf.equal(labels, src)
            hit_col = tf.expand_dims(hit, axis = 1)

            set_one = tf.logical_and(hit_col, tf.equal(columns, dst))
            set_zero = tf.logical_and(hit_col, tf.equal(columns, src))
            new_a = tf.where(set_one, tf.ones_like(new_a), new_a)
            new_a = tf.where(set_zero, tf.zeros_like(new_a), new_a)

            flip = tf.logical_or(flip, hit)

        # The batch sits on the last axis so the mask broadcasts against kernels.
        mask = tf.reshape(tf.cast(flip, tf.float32), (1, 1, 1, -1))
        return new_a, mask

    def call(self, hidden, action, prev_hidden, prev_alpha, M, read_only = False):
        """Update the attention, optionally write to memory, then read from it.

        Args:
            hidden: Current hidden state, ``(bs, features)``.
            action: Action vectors, ``(bs, num_actions)``.
            prev_hidden: Previous hidden state, same shape as ``hidden``.
            prev_alpha: Previous attention, ``(bs, mem_h)``.
            M: Memory, ``(bs, mem_h, dim)``.
            read_only: If true the memory is not modified.

        Returns:
            ``(read, M, alpha)``.
        """
        if self.use_mem_h:
            h = hidden
        else:
            # l2_normalize guards against the all-zero initial state, which a
            # plain division by the norm would turn into NaN.
            h_norm = tf.math.l2_normalize(hidden, axis = 1)
            prev_h_norm = tf.math.l2_normalize(prev_hidden, axis = 1)

            h = h_norm - prev_h_norm

        kernels = self._action_kernels(action)

        new_a, mask = self._flip_actions(action)
        flipped_kernels = tf.keras.backend.reverse(self._action_kernels(new_a), [0, 1])
        kernels = (1 - mask) * kernels + mask * flipped_kernels

        # Softmax over the nine taps of every sample's kernel.
        kernels = tf.transpose(tf.reshape(kernels, (9, -1)), perm = [1, 0])
        kernels = tf.nn.softmax(kernels, axis = 1)
        kernels = tf.reshape(tf.transpose(kernels, perm = [1, 0]), (3, 3, 1, -1))

        g = self.G_block(h)

        alpha = g * self.merge_K_G(prev_alpha, kernels) + (1 - g) * prev_alpha

        if not read_only:
            e = self.E_block(hidden)
            erase = tf.nn.sigmoid(e[:, :self.dim])
            add = e[:, self.dim:]

            M = self.write(erase, add, alpha, M)

        read = self.read(alpha, M)

        return read, M, alpha

In [12]:
class RenderingResidualBlock(tf.keras.layers.Layer):
    """Pre-activation residual block with optional 2x nearest-neighbour upsampling.

    Main path: InstanceNorm -> ReLU -> [upsample] -> 3x3 conv ->
    InstanceNorm -> ReLU -> 3x3 conv. Shortcut: [upsample] -> 1x1 conv.
    The block returns the sum of both paths.

    Args:
        filters: Number of output channels of every convolution in the block.
        up_sample: Whether both paths double the spatial resolution.
    """

    def __init__(self, filters, up_sample = True, **kwargs):
        super().__init__(**kwargs)

        def upsample():
            return tf.keras.layers.UpSampling2D(size = (2, 2), interpolation = 'nearest')

        def conv(size):
            return Conv2D(filters = filters, kernel_size = (size, size), strides = (1, 1), padding = 'same')

        main_layers = [InstanceNormalization(), tf.keras.layers.ReLU()]
        if up_sample:
            main_layers.append(upsample())
        main_layers += [conv(3), InstanceNormalization(), tf.keras.layers.ReLU(), conv(3)]

        shortcut_layers = [upsample()] if up_sample else []
        shortcut_layers.append(conv(1))

        self.model = tf.keras.models.Sequential(main_layers)
        self.res_path = tf.keras.models.Sequential(shortcut_layers)

    def call(self, inputs):
        """Return ``main_path(inputs) + shortcut(inputs)``."""
        transformed = self.model(inputs)
        shortcut = self.res_path(inputs)
        return tf.add(transformed, shortcut)

In [13]:
class SimpleRenderingEngine(tf.keras.layers.Layer):
    """Decoder that renders an image from a vector with transposed convolutions.

    The input is projected to ``7 * 7 * 512`` values, reshaped to
    ``(7, 7, 512)`` and passed through five ``Conv2DTranspose`` layers with
    ``LeakyReLU(0.2)`` between them; the last layer has 3 output channels and
    no activation.

    Args:
        encoder_for: ``'pacman'`` or ``'vizdoom'``; selects the layer layout.
        render_for: Alias of ``encoder_for`` matching the keyword used by
            ``DisentanglingRenderingEngine``. When given it takes precedence.

    Raises:
        ValueError: If the selected target is not supported.
    """

    # Per target: (filters, kernel size, stride, output_padding) of each
    # transposed convolution, in order.
    _DECONV_SPECS = {
        'pacman': ((512, 3, 1, 0), (256, 3, 2, 1), (128, 4, 2, 0), (64, 4, 2, 0), (3, 3, 1, 0)),
        'vizdoom': ((512, 4, 1, 0), (256, 4, 1, 0), (128, 5, 2, 0), (64, 5, 2, 0), (3, 4, 1, 0)),
    }

    def __init__(self, encoder_for = 'pacman', render_for = None, **kwargs):
        super().__init__(**kwargs)

        target = encoder_for if render_for is None else render_for
        if target not in self._DECONV_SPECS:
            # Previously an unknown name silently produced a truncated model.
            raise ValueError(f'Simple rendering engine for `{target}` is not defined.')

        model = tf.keras.models.Sequential()
        model.add(Linear(neurons = 7*7*512))
        model.add(tf.keras.layers.LeakyReLU(alpha = 0.2))
        model.add(tf.keras.layers.Reshape((7, 7, 512)))

        specs = self._DECONV_SPECS[target]
        last = len(specs) - 1
        for position, (filters, kernel, stride, out_pad) in enumerate(specs):
            model.add(tf.keras.layers.Conv2DTranspose(filters = filters, kernel_size = (kernel, kernel),
                                                      strides = (stride, stride),
                                                      padding = 'valid', output_padding = out_pad))
            # No activation after the final, 3-channel layer.
            if position != last:
                model.add(tf.keras.layers.LeakyReLU(alpha = 0.2))

        self.model = model

    def call(self, inputs):
        """Render ``inputs`` with the decoder stack."""
        return self.model(inputs)

In [14]:
class DisentanglingRenderingEngine(tf.keras.layers.Layer):
    """Renderer that draws one image per input component and blends them.

    For every component vector the engine produces a 1-channel object map and
    a 32-channel attribute map, projects the vector to a 7x7x32 tensor,
    decodes the object-masked projection, modulates it with SPADE conditioned
    on the masked attribute map, and upsamples it to an RGB image (``tanh``)
    plus a 1-channel mask logit. The mask logits of all components are
    softmax-normalized and used to blend the per-component images.

    Args:
        res_blocks: Selects between residual blocks and plain transposed
            convolutions in the individual stages.
        relax_dynamic_constraint: If true each component is projected to a
            full 7x7x32 tensor; otherwise to a 32-vector tiled over 7x7.
        con_h: If true only two components are rendered and the projections
            are computed from the last entry of the component list (the
            concatenation of all inputs after the first two).
        render_for: ``'pacman'`` or ``'vizdoom'``.
        apply_mask: Whether the concatenated mask logits are passed through
            two 1x1 convolutions before the softmax.
        sigmoid_maps: Use a sigmoid instead of a softmax across components for
            the object maps.
        base_temperature: Temperature dividing the object-map logits.
    """

    def __init__(self, res_blocks = True, relax_dynamic_constraint = True, con_h = False, render_for = 'pacman',
                 apply_mask = True, sigmoid_maps = False, base_temperature = 0.1, **kwargs):
        super().__init__(**kwargs)
        self.res_blocks = res_blocks
        self.relax_dynamic_constraint = relax_dynamic_constraint
        self.con_h = con_h
        self.render_for = render_for
        self.apply_mask = apply_mask
        self.sigmoid_maps = sigmoid_maps
        self.base_temperature = base_temperature

    def build(self, input_shape):
        """Create one set of rendering sub-networks per rendered component.

        ``input_shape`` is the list of shapes of the component vectors.
        """
        num_components = len(input_shape)
        assert num_components >= 2
        if num_components == 2:
            assert self.con_h == False, 'More than two inputs required for concatenation, first two inputs are not concatenated'

        self.all_get_maps = []
        self.proj_v = []
        self.R_block = []
        self.spade_layers = []
        self.more_layers = []
        self.output_layers = []
        self.fine_masks = []

        d = 2 if self.con_h else num_components
        for ind in range(d):

            ### ROUGH SKETCH STAGE
            # Layers for Extracting attribute map (A) and object map (O) for each c_k vectors.
            map_model = tf.keras.models.Sequential()
            map_model.add(Linear(neurons = 3*3*128))
            map_model.add(tf.keras.layers.Reshape((3, 3, 128)))

            # NOTE(review): this branch looks inverted relative to the R block
            # below (residual blocks are used when res_blocks is False) -
            # confirm the intent before changing it.
            if self.res_blocks:
                map_model.add(tf.keras.layers.LeakyReLU(alpha = 0.2))
                map_model.add(tf.keras.layers.Conv2DTranspose(filters = 512, kernel_size = (3, 3), strides = (1, 1),
                                                              padding = 'valid', output_padding = 0))
                map_model.add(tf.keras.layers.LeakyReLU(alpha = 0.2))
                # 32 attribute channels + 1 object channel
                map_model.add(tf.keras.layers.Conv2DTranspose(filters = 32 + 1, kernel_size = (3, 3), strides = (1, 1),
                                                              padding = 'valid', output_padding = 0))

            else:
                map_model.add(RenderingResidualBlock(filters = 512, up_sample = False))
                map_model.add(RenderingResidualBlock(filters = 32+1, up_sample = False))

            self.all_get_maps.append(map_model)

            # layers for extracting v_k
            v_model = tf.keras.models.Sequential()
            if self.relax_dynamic_constraint:
                v_model.add(Linear(neurons = 7*7*32))
                v_model.add(tf.keras.layers.LeakyReLU(alpha = 0.2))
                v_model.add(tf.keras.layers.Reshape((7, 7, 32)))

            else:
                v_model.add(Linear(neurons = 32))
                v_model.add(tf.keras.layers.LeakyReLU(alpha = 0.2))
                # broadcast the 32-vector over a 7x7 grid
                v_model.add(tf.keras.layers.Lambda(lambda x: tf.tile(x[:, tf.newaxis, tf.newaxis, :],
                                                                     [1, 7, 7, 1])))

            self.proj_v.append(v_model)

            # For R_k
            r_model = tf.keras.models.Sequential()
            if self.res_blocks:
                r_model.add(RenderingResidualBlock(filters = 256, up_sample = False))
                r_model.add(RenderingResidualBlock(filters = 128, up_sample = True))

            else:
                r_model.add(tf.keras.layers.Conv2DTranspose(filters = 256, kernel_size = (3, 3), strides = (1, 1),
                                                            padding = 'valid', output_padding = 0))
                r_model.add(tf.keras.layers.LeakyReLU(alpha = 0.2))

                if self.render_for == 'pacman':
                    r_model.add(tf.keras.layers.Conv2DTranspose(filters = 128, kernel_size = (3, 3), strides = (2, 2),
                                                                padding = 'valid', output_padding = 1))

                elif self.render_for == 'vizdoom':
                    r_model.add(tf.keras.layers.Conv2DTranspose(filters = 128, kernel_size = (3, 3), strides = (2, 2),
                                                                padding = 'same', output_padding = 0))

            self.R_block.append(r_model)

            ### ATTRIBUTE STAGE
            # spade layers
            self.spade_layers.append(SPADE())

            # some more transposed layers
            more_layers_model = tf.keras.models.Sequential()
            if self.render_for == 'pacman':
                if self.res_blocks:
                    more_layers_model.add(RenderingResidualBlock(filters = 64, up_sample = True))
                    more_layers_model.add(RenderingResidualBlock(filters = 32, up_sample = True))

                else:
                    more_layers_model.add(tf.keras.layers.Conv2DTranspose(
                        filters = 64, kernel_size = (4, 4), strides = (2, 2), padding = 'valid', output_padding = 0))
                    more_layers_model.add(tf.keras.layers.LeakyReLU(alpha = 0.2))
                    more_layers_model.add(tf.keras.layers.Conv2DTranspose(
                        filters = 32, kernel_size = (4, 4), strides = (2, 2), padding = 'valid', output_padding = 0))
                    more_layers_model.add(tf.keras.layers.LeakyReLU(alpha = 0.2))

            elif self.render_for == 'vizdoom':
                if self.res_blocks:
                    more_layers_model.add(RenderingResidualBlock(filters = 64, up_sample = True))
                    more_layers_model.add(RenderingResidualBlock(filters = 32, up_sample = True))

                else:
                    more_layers_model.add(tf.keras.layers.Conv2DTranspose(
                        filters = 64, kernel_size = (3, 3), strides = (2, 2), padding = 'same', output_padding = 0))
                    more_layers_model.add(tf.keras.layers.LeakyReLU(alpha = 0.2))
                    more_layers_model.add(tf.keras.layers.Conv2DTranspose(
                        filters = 32, kernel_size = (3, 3), strides = (2, 2), padding = 'same', output_padding = 0))
                    more_layers_model.add(tf.keras.layers.LeakyReLU(alpha = 0.2))

            self.more_layers.append(more_layers_model)

            ### FINAL RENDERING STAGE
            # RGB head of the component
            self.output_layers.append(
                tf.keras.models.Sequential([
                    InstanceNormalization(),
                    tf.keras.layers.ReLU(),
                    Conv2D(filters = 3, kernel_size = (3, 3), strides = (1, 1), padding = 'same')
                ])
            )

            # 1-channel mask-logit head of the component
            self.fine_masks.append(
                tf.keras.models.Sequential([
                    tf.keras.layers.ReLU(),
                    Conv2D(filters = 1, kernel_size = (3, 3), strides = (1, 1), padding = 'same'),
                    tf.keras.layers.LeakyReLU(alpha = 0.2)
                ])
            )

        # does not affect much (as per paper)
        if self.apply_mask:
            self.mask_N = tf.keras.models.Sequential()
            self.mask_N.add(Conv2D(filters = 512, kernel_size = (1, 1), strides = (1, 1), padding = 'valid'))
            self.mask_N.add(tf.keras.layers.LeakyReLU(alpha = 0.2))

            # one output channel per rendered component
            dim = 2 if self.con_h else num_components
            self.mask_N.add(Conv2D(filters = dim, kernel_size = (1, 1), strides = (1, 1), padding = 'valid'))

    def call(self, inputs):
        """Render and blend the component images.

        Args:
            inputs: List of at least two component vectors ``[m, h, ...]``.

        Returns:
            ``(out_img, fine_mask, obj_maps, base_imgs)`` where ``out_img`` is
            the blended image, ``fine_mask`` and ``obj_maps`` are lists with
            one tensor per rendered component, and ``base_imgs`` lists the
            masked component images followed by the unmasked ones.
        """
        assert len(inputs) >= 2
        # len(inputs) == # of components

        if len(inputs) == 2:
            assert self.con_h == False, 'More than two inputs required for concatenation, first two inputs are not concatenated'
            m, h = inputs
            c_vector = [m, h]
        elif len(inputs) > 2:
            m, h = inputs[:2]
            c_vector = [m, h]
            if len(inputs[2:]) == 1:
                v_inp = inputs[-1]
                c_vector.append(v_inp)
            else:
                if self.con_h:
                    v_inp = tf.concat(inputs[2:], axis = -1)
                    c_vector.append(v_inp)
                else:
                    c_vector += inputs[2:]

        attr_maps, obj_maps, vs = [], [], []

        d = 2 if self.con_h else len(inputs)
        for i in range(d):
            maps = self.all_get_maps[i](c_vector[i])
            # channel 0 is the object map, the remaining channels are attributes
            attr_maps.append(maps[:, :, :, 1:])
            obj_maps.append(maps[:, :, :, :1])

            # with con_h every projection reads the last component vector
            idx = -1 if self.con_h else i
            vs.append(self.proj_v[i](c_vector[idx]))

        obj_len = len(obj_maps)
        obj_maps = tf.concat(obj_maps, axis = -1)
        if self.sigmoid_maps:
            obj_maps = tf.nn.sigmoid(obj_maps / self.base_temperature)
        else:
            obj_maps = tf.nn.softmax(obj_maps / self.base_temperature, axis = -1)
        obj_maps = tf.split(obj_maps, obj_len, -1)

        base_singles, unmasked_base_imgs = [], []
        for i, (om, am, v_) in enumerate(zip(obj_maps, attr_maps, vs)):
            r_out = self.R_block[i](om * v_)
            spade_out = self.spade_layers[i]([r_out, om * am])
            more_out = self.more_layers[i](spade_out)

            base_singles.append(self.fine_masks[i](more_out))
            unmasked_base_imgs.append(tf.nn.tanh(self.output_layers[i](more_out)))

        # Normalize the per-component mask logits across components.
        fine_mask = tf.concat(base_singles, axis = -1)
        if self.apply_mask:
            fine_mask = self.mask_N(fine_mask)
        fine_mask = tf.nn.softmax(fine_mask, axis = -1)
        fine_mask = tf.split(fine_mask, len(base_singles), -1)

        out_img, base_imgs = 0, []
        for msk, umsk in zip(fine_mask, unmasked_base_imgs):
            base_imgs.append(msk * umsk)
            out_img += base_imgs[-1]

        for umsk in unmasked_base_imgs:
            base_imgs.append(umsk)

        return out_img, fine_mask, obj_maps, base_imgs

In [15]:
class Generator(tf.keras.models.Model):
    """GameGAN generator: dynamics engine, optional memory and a rendering engine.

    Args:
        gen_for: Environment name forwarded to the memory module and the
            rendering engine.
        dim: Forwarded to the dynamics engine and the memory; also the width
            of the initial memory read vector.
        use_memory: Build and use a `Memory` module.
        use_mem_h: Forwarded to `Memory`.
        mem_h: Forwarded to `Memory`; also the width of the initial attention
            vector, which is treated as a flattened square grid.
        simple_render_block: Use `SimpleRenderingEngine` instead of
            `DisentanglingRenderingEngine`.
        res_blocks, relax_dynamic_constraint, con_h, apply_mask, sigmoid_maps,
            base_temperature: Forwarded to `DisentanglingRenderingEngine`.
        z_dim: Width of the noise vector sampled at every step.
        alpha_loss_multiplier: The mask regulariser is accumulated only when
            this is > 0 (the multiplier itself is not applied here).
        cycle_loss: Enable the reverse (cycle) pass in `call`.
        cycle_start_epoch: First epoch at which the reverse pass is run.
        rev_multiply_map: Multiply each reverse-pass output by the first fine
            mask of the matching forward step.
    """

    def __init__(self, gen_for = 'pacman', dim = 512, use_memory = True, use_mem_h = False, mem_h = 441, 
                 simple_render_block = False, res_blocks = True, relax_dynamic_constraint = True, con_h = False, 
                 apply_mask = True, sigmoid_maps = False, base_temperature = 0.1, 
                 z_dim = 32, alpha_loss_multiplier = -1, cycle_loss = False, 
                 cycle_start_epoch = 0, rev_multiply_map = False, **kwargs):
        super().__init__(**kwargs)
        
        self.gen_for = gen_for
        self.dim = dim
        self.use_memory = use_memory
        self.use_mem_h = use_mem_h
        self.mem_h = mem_h
        self.simple_render_block = simple_render_block
        self.res_blocks = res_blocks
        self.relax_dynamic_constraint = relax_dynamic_constraint
        self.con_h = con_h
        self.apply_mask = apply_mask
        self.sigmoid_maps = sigmoid_maps
        self.base_temperature = base_temperature
        self.z_dim = z_dim
        self.alpha_loss_multiplier = alpha_loss_multiplier
        self.cycle_loss = cycle_loss
        self.cycle_start_epoch = cycle_start_epoch
        self.rev_multiply_map = rev_multiply_map
        
        # PARTS OF GENERATOR
        # dynamics engine
        self.dynamics_engine = DynamicsEngine(dim = dim, use_memory = use_memory)
        
        # memory module
        if use_memory:
            self.memory = Memory(dim = dim, use_mem_h = use_mem_h, memory_for = gen_for, mem_h = mem_h)
            
        # rendering engine
        if simple_render_block:
            self.render_engine = SimpleRenderingEngine(render_for = gen_for)
        else:
            self.render_engine = DisentanglingRenderingEngine(res_blocks = res_blocks, 
                                                              relax_dynamic_constraint = relax_dynamic_constraint, 
                                                              con_h = con_h, render_for = gen_for, 
                                                              apply_mask = apply_mask, sigmoid_maps = sigmoid_maps, 
                                                              base_temperature = base_temperature)

    
    def run_step(self, image, h, c, action, zdist, batch_size, prev_read, 
                 prev_alpha, M, read_only = False):
        """Advance the generator by one time step.

        Args:
            image: Frame fed to the dynamics engine.
            h, c: Previous states of the dynamics engine.
            action: Action for this step.
            zdist: Callable invoked as `zdist(shape = (batch_size, z_dim))`.
            batch_size: Number of samples in the batch.
            prev_read: Previous memory read (only used when `use_memory`).
            prev_alpha: Previous memory attention (only used when `use_memory`).
            M: Memory contents (only used when `use_memory`).
            read_only: Forwarded to the memory module.

        Returns:
            Tuple `(render_out, fine_mask, prev_alpha, alpha_loss, z, M,
            prev_read, new_h, new_c, obj_maps, base_imgs)`. `fine_mask`,
            `obj_maps` and `base_imgs` are `None` with the simple rendering
            engine. Without memory, `prev_alpha`, `M` and `prev_read` are
            returned unchanged.
        """
        # dynamics engine
        z = zdist(shape = (batch_size, self.z_dim))
        de_inp = [h, c, image, z, action]
        if self.use_memory:
            de_inp.append(prev_read)
        
        new_h, new_c = self.dynamics_engine(de_inp)
        
        # memory module
        if self.use_memory:
            # hidden, action, prev_hidden, prev_alpha, M, read_only = False
            read, M, alpha = self.memory(new_h, action, h, prev_alpha, M, read_only)
            
            assert self.simple_render_block == False
            render_inp = [read, new_h]
            
            prev_alpha, prev_read = alpha, read
        else:
            assert self.simple_render_block == True
            render_inp = new_h
            
        # Render Engine
        ### note: input to render engine is limited to at max 2 inputs
        ### other inputs to it are not defined yet.
        alpha_loss = 0
        # The simple engine yields only an image; keep the extra outputs
        # defined so the return statement is valid on both paths.
        fine_mask, obj_maps, base_imgs = None, None, None
        if self.simple_render_block:
            render_out = tf.nn.tanh(self.render_engine(render_inp))
        else:
            render_out, fine_mask, obj_maps, base_imgs = self.render_engine(render_inp)
            if self.alpha_loss_multiplier > 0:
                # memory regularization: L1 penalty on every fine mask but the first
                for mask in fine_mask[1:]:
                    alpha_loss += tf.reduce_sum(tf.abs(mask)) / batch_size
            
        return (render_out, fine_mask, prev_alpha, alpha_loss, z, M, prev_read, new_h, 
                new_c, obj_maps, base_imgs)
    
    def call(self, inputs):
        """Roll the generator out over a sequence of actions.

        Args:
            inputs: Sequence `(actions, zdist, images, warmup_steps, epoch)`.
                `actions` and `images` are indexed per time step; ground-truth
                `images[i]` is fed while `i <= warmup_steps`, afterwards the
                previously generated frame is fed back.

        Returns:
            Dict with keys `alpha_loss`, `rev_outputs`, `rev_alphas`,
            `rev_maps`, `rev_base_imgs_all`, `maps`, `obj_maps`, `zs`,
            `outputs`, `alphas` and `base_imgs_all`.
        """
        # action, zdist, images, warmup_steps = 10, epoch = 0
        actions, zdist, images, warmup_steps, epoch = inputs
        
        batch_size = images[0].shape[0]
        h, c = self.dynamics_engine.initialize_states(batch_size)
        
        if self.use_memory:
            M = self.memory.initialize_memory(batch_size)
            
            # Initial attention is one-hot at the centre of the square grid.
            # Tensors are immutable, so build it with one_hot instead of
            # assigning into a zeros tensor.
            mem_wh = int(np.sqrt(self.mem_h))
            centre = mem_wh * (mem_wh // 2) + mem_wh // 2
            prev_alpha = tf.one_hot(tf.fill([batch_size], centre), depth = self.mem_h)
            
            prev_read = tf.zeros((batch_size, self.dim))
        else:
            # run_step ignores these when no memory module is present
            prev_read, M, prev_alpha = None, None, None
            
        outputs, zs, alphas, base_imgs_all = [], [], [], []
        hiddens, fine_masks, obj_maps = [], [], []
        alpha_losses = 0
        
        num_steps = len(actions) - 1
        out_img = None
        for i in range(num_steps):
            out_img = out_img if i > warmup_steps else images[i]
            
            (out_img, fm, prev_alpha, alpha_loss, z, M, prev_read, h,
             c, obj_map, base_imgs) = self.run_step(out_img, h, c, actions[i], zdist, 
                                                    batch_size, prev_read, prev_alpha, M)
            
            outputs.append(out_img)
            fine_masks.append(fm)
            alphas.append(prev_alpha)
            alpha_losses += alpha_loss
            zs.append(z)
            base_imgs_all.append(base_imgs)
            hiddens.append(h)
            obj_maps.append(obj_map)
            
        # average over steps; guard against an empty roll-out
        if num_steps > 0:
            alpha_losses /= num_steps
        
        rev_outputs, rev_base_imgs, rev_alphas, rev_maps = [], [], [], []
        # Gate on the current epoch (identical to the old `<= 0` test at epoch 0).
        run_cycle = (self.use_memory and self.cycle_loss
                     and epoch >= self.cycle_start_epoch
                     and not self.simple_render_block)
        if run_cycle:
            # read from previously visited locations for cycle loss
            for alpha, forward_masks in zip(reversed(alphas), reversed(fine_masks)):
                cur_read = self.memory.read(alpha, M)
                _, fm, _, bm = self.render_engine([cur_read, tf.zeros_like(cur_read)])
                
                if self.rev_multiply_map:
                    rev_outputs.append(bm[2] * forward_masks[0])
                else:
                    rev_outputs.append(bm[2])
                    
                rev_maps.append(fm)
                rev_base_imgs.append(bm)
                rev_alphas.append(alpha)
                
        return {
            'alpha_loss': alpha_losses,
            'rev_outputs': rev_outputs,
            'rev_alphas': rev_alphas,
            'rev_maps': rev_maps,
            'rev_base_imgs_all': rev_base_imgs,
            'maps': fine_masks,
            'obj_maps': obj_maps,
            'zs': zs,
            'outputs': outputs,
            'alphas': alphas,
            'base_imgs_all': base_imgs_all,
        }

In [16]:
class SNConv2D(tf.keras.layers.Layer):
    """`Conv2D` wrapped in `SpectralNormalization`.

    All constructor arguments except `**kwargs` are forwarded to `Conv2D`;
    `**kwargs` go to the Keras `Layer` base class.
    """

    def __init__(self, filters, kernel_size, strides, padding, use_bias = True, gain = 1.0, **kwargs):
        super().__init__(**kwargs)
        
        conv = Conv2D(filters = filters, kernel_size = kernel_size, strides = strides, 
                      padding = padding, use_bias = use_bias, gain = gain)
        self.layer = SpectralNormalization(conv)
        
    def call(self, inputs):
        """Apply the spectrally-normalised convolution to `inputs`."""
        outputs = self.layer(inputs)
        return outputs
    
class SNLinear(tf.keras.layers.Layer):
    """`Linear` wrapped in `SpectralNormalization`.

    `neurons`, `use_bias` and `gain` are forwarded to `Linear`; `**kwargs` go
    to the Keras `Layer` base class.
    """

    def __init__(self, neurons, use_bias = True, gain = 1.0, **kwargs):
        super().__init__(**kwargs)
        
        dense = Linear(neurons = neurons, use_bias = use_bias, gain = gain)
        self.layer = SpectralNormalization(dense)
        
    def call(self, inputs):
        """Apply the spectrally-normalised linear layer to `inputs`."""
        outputs = self.layer(inputs)
        return outputs

In [17]:
class SimpleDiscriminator(tf.keras.layers.Layer):
    """Four-stage strided-conv feature extractor for single frames.

    Each stage is conv (stride 2, 'valid' padding) -> BatchNormalization ->
    LeakyReLU(0.2); the result is reshaped to `(3, 3, C)` where `C` is the
    filter count of the last stage.

    Args:
        disc_for: `'pacman'` or `'vizdoom'`; selects filter counts and kernel
            sizes.
        apply_sn: Use `SNConv2D` (spectral norm) instead of plain `Conv2D`.

    Raises:
        ValueError: If `disc_for` is not a supported environment.
    """

    def __init__(self, disc_for = 'pacman', apply_sn = True, **kwargs):
        super().__init__(**kwargs)
        self.apply_sn = apply_sn
        
        # (filters, kernel_size) for each of the four stages
        conv_specs = {
            'pacman': [(16, (5, 5)), (32, (5, 5)), (64, (3, 3)), (64, (3, 3))],
            'vizdoom': [(64, (4, 4)), (128, (3, 3)), (256, (3, 3)), (256, (3, 3))],
        }
        if disc_for not in conv_specs:
            # Previously an unknown name left `self.model` undefined and the
            # failure only surfaced later inside `call`.
            raise ValueError(
                f"Unsupported disc_for {disc_for!r}; expected one of {sorted(conv_specs)}"
            )
            
        stages = conv_specs[disc_for]
        layers = []
        for filters, kernel_size in stages:
            layers.append(self.get_conv(filters = filters, kernel_size = kernel_size, 
                                        strides = (2, 2), padding = 'valid'))
            layers.append(tf.keras.layers.BatchNormalization())
            layers.append(tf.keras.layers.LeakyReLU(alpha = 0.2))
        layers.append(tf.keras.layers.Reshape((3, 3, stages[-1][0])))
        
        self.model = tf.keras.models.Sequential(layers)
            
    def get_conv(self, filters, kernel_size, strides, padding):
        """Return a conv layer, spectrally normalised when `apply_sn` is set."""
        if self.apply_sn:
            return SNConv2D(filters = filters, kernel_size = kernel_size, strides = strides, padding = padding)
        return Conv2D(filters = filters, kernel_size = kernel_size, strides = strides, padding = padding)
            
    def call(self, inputs):
        """Run `inputs` through the convolutional stack."""
        return self.model(inputs)

In [18]:
class SingleFrameDiscriminator(tf.keras.layers.Layer):
    """Prediction head applied on top of single-frame features.

    Args:
        disc_type: `'patch'` uses padding `(1, 1)` on the first conv and
            reshapes the output to `(3, 3, 1)`; any other value uses `'valid'`
            padding and reshapes to `(1,)`.
        disc_for: `'pacman'` (64 filters) or `'vizdoom'` (256 filters).
        apply_sn: Use `SNConv2D` (spectral norm) instead of plain `Conv2D`.

    Raises:
        ValueError: If `disc_for` is not a supported environment.
    """

    def __init__(self, disc_type, disc_for = 'pacman', apply_sn = True, **kwargs):
        super().__init__(**kwargs)
        self.apply_sn = apply_sn
        
        feature_dims = {'pacman': 64, 'vizdoom': 256}
        if disc_for not in feature_dims:
            # ValueError is still an Exception, so existing handlers keep
            # working, but the message now says what went wrong.
            raise ValueError(
                f"Unsupported disc_for {disc_for!r}; expected one of {sorted(feature_dims)}"
            )
        dim = feature_dims[disc_for]
            
        if disc_type == 'patch':
            padding = (1, 1)
            reshape = (3, 3, 1)
        else:
            padding = 'valid'
            reshape = (1, )
        
        self.model = tf.keras.models.Sequential([
            self.get_conv(filters = dim, kernel_size = (2, 2), strides = (1, 1), padding = padding),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.LeakyReLU(alpha = 0.2),
            self.get_conv(filters = 1, kernel_size = (1, 1), strides = (2, 2), padding = (1, 1)),
            tf.keras.layers.Reshape(reshape)
        ])
            
            
    def get_conv(self, filters, kernel_size, strides, padding):
        """Return a conv layer, spectrally normalised when `apply_sn` is set."""
        if self.apply_sn:
            return SNConv2D(filters = filters, kernel_size = kernel_size, strides = strides, padding = padding)
        return Conv2D(filters = filters, kernel_size = kernel_size, strides = strides, padding = padding)
    
    def call(self, inputs):
        """Run `inputs` through the prediction head."""
        return self.model(inputs)

In [19]:
class Residual_DBlock(tf.keras.layers.Layer):
    """Residual discriminator block.

    Main path: two (ReLU -> 3x3 conv) pairs. Skip path: one 1x1 conv. With
    `downsample`, each path ends in `AveragePooling2D`. The two paths are
    summed.

    Args:
        filters: Output filters of every conv in the block.
        apply_sn: Use `SNConv2D` (spectral norm) instead of plain `Conv2D`.
        downsample: Append average pooling to both paths.
    """

    def __init__(self, filters, apply_sn = True, downsample = True, **kwargs):
        super().__init__(**kwargs)
        self.apply_sn = apply_sn
        
        # main path: (ReLU -> 3x3 conv) x 2 [-> pool]
        main = tf.keras.models.Sequential()
        for _ in range(2):
            main.add(tf.keras.layers.ReLU())
            main.add(self.get_conv(filters = filters, kernel_size = (3, 3), strides = (1, 1), padding = 'same'))
        if downsample:
            main.add(tf.keras.layers.AveragePooling2D())
        self.main_path = main
            
        # skip path: 1x1 conv [-> pool]
        skip = tf.keras.models.Sequential()
        skip.add(self.get_conv(filters = filters, kernel_size = (1, 1), strides = (1, 1), padding = 'same'))
        if downsample:
            skip.add(tf.keras.layers.AveragePooling2D())
        self.skip_path = skip
        
    def get_conv(self, filters, kernel_size, strides, padding):
        """Return a conv layer, spectrally normalised when `apply_sn` is set."""
        conv_kwargs = dict(filters = filters, kernel_size = kernel_size, strides = strides, padding = padding)
        if not self.apply_sn:
            return Conv2D(**conv_kwargs)
        return SNConv2D(**conv_kwargs)
        
    def call(self, inputs):
        """Return main-path output plus skip-path output."""
        residual = self.main_path(inputs)
        shortcut = self.skip_path(inputs)
        return tf.add(residual, shortcut)

In [20]:
class ResidualDiscriminator(tf.keras.layers.Layer):
    """Stack of `Residual_DBlock`s followed by a spectrally-normalised linear head.

    Args:
        dim: Base number of filters.
        r: Per-block width multipliers; block `k` gets `dim * r[k]` filters.
    """

    def __init__(self, dim = 64, r = (1, 2, 4, 8, 16), **kwargs):
        super().__init__(**kwargs)
        
        self.model = tf.keras.models.Sequential()
        for multiplier in r:
            # Width scales linearly with the multiplier; the previous
            # `dim ** i` asked for 64**16 filters in the last block.
            self.model.add(Residual_DBlock(filters = dim * multiplier))
            
        self.model.add(tf.keras.layers.ReLU())
        
        self.linear = SNLinear(neurons = 1)
        
    def call(self, inputs):
        """Return `(logits, features)`.

        `features` is the output of the residual stack; `logits` is the linear
        head applied to `features` summed over axes 1 and 2.
        """
        x = self.model(inputs)
        out = self.linear(tf.math.reduce_sum(x, axis = [1, 2]))
        return out, x

In [25]:
def choose_netD_temporal(dim, window, first_spatial_filter = 2, simple_block = False, 
                         d_temp_mode = 'sn'):
    """Build the temporal discriminator's Conv3D feature extractors and heads.

    Three families are available: the `simple_block` variant, the
    spectral-norm variant (when `'sn'` is contained in `d_temp_mode`) and a
    plain Conv3D variant. Each family always has a first level; further levels
    are added depending on `window`.

    Args:
        dim: Base filter count; deeper levels use `dim * 2/4/8`.
        window: Temporal window length; controls how many levels are built.
        first_spatial_filter: Spatial kernel size of the very first conv
            (spectral-norm variant only).
        simple_block: Select the `simple_block` variant.
        d_temp_mode: Selects spectral-norm vs. plain variant when
            `simple_block` is False.

    Returns:
        `(extractors, finals)`: two equally long lists of `Sequential` models;
        `finals[k]` is the prediction head paired with `extractors[k]`.
    """
    def conv(filters, kernel, strides = (1, 1, 1)):
        # plain 3-D convolution
        return tf.keras.layers.Conv3D(filters = filters, kernel_size = kernel, strides = strides)

    def sn_conv(filters, kernel, strides = (1, 1, 1)):
        # 3-D convolution wrapped in spectral normalisation
        return SpectralNormalization(conv(filters, kernel, strides))

    def lrelu():
        return tf.keras.layers.LeakyReLU(alpha = 0.2)

    def seq(*layers):
        return tf.keras.models.Sequential(list(layers))

    extractors, finals = [], []

    def push(extractor, head):
        # register one hierarchy level
        extractors.append(extractor)
        finals.append(head)

    if simple_block:
        push(seq(sn_conv(dim, (2, 2, 2)), lrelu(),
                 sn_conv(dim * 2, (3, 2, 2), (2, 1, 1)), lrelu()),
             seq(sn_conv(1, (2, 1, 1), (2, 1, 1))))
        
        if window > 6: # 18
            push(seq(sn_conv(dim * 4, (3, 1, 1), (2, 1, 1)), lrelu()),
                 seq(sn_conv(1, (2, 1, 1))))
            
        if window > 18: #32
            push(seq(sn_conv(dim * 8, (3, 1, 1), (2, 1, 1)), lrelu()),
                 seq(sn_conv(1, (3, 1, 1))))
            
    elif 'sn' in d_temp_mode:
        k = first_spatial_filter
        push(seq(sn_conv(dim, (2, k, k)), lrelu(),
                 sn_conv(dim * 2, (3, 3, 3), (2, 1, 1)), lrelu()),
             seq(sn_conv(1, (2, 1, 1))))
        
        if window >= 12: # 12
            push(seq(sn_conv(dim * 4, (3, 1, 1)), lrelu()),
                 seq(sn_conv(1, (3, 1, 1))))
            
        if window >= 18: # 18
            extractor = seq(sn_conv(dim * 8, (3, 1, 1)), lrelu())
            if window in (18, 28):
                head = seq(sn_conv(1, (2, 1, 1)))
            else:
                head = seq(sn_conv(1, (4, 1, 1), (2, 1, 1)))
            push(extractor, head)
            
    else:
        push(seq(conv(dim, (2, 3, 3)), lrelu(),
                 conv(dim * 2, (3, 3, 3), (2, 1, 1)), lrelu()),
             seq(conv(1, (2, 1, 1))))
        
        if window >= 12: #12
            push(seq(conv(dim * 4, (3, 1, 1)), lrelu()),
                 seq(conv(1, (3, 1, 1))))
            
        if window >= 18: #18
            push(seq(conv(dim * 8, (2, 1, 1), (2, 1, 1)), lrelu()),
                 seq(conv(1, (3, 1, 1))))
            
    return extractors, finals

In [31]:
class Discriminator(tf.keras.models.Model):
    """GameGAN discriminator: single-frame, action-conditioned and temporal parts.

    Args:
        simple_block: Use `SimpleDiscriminator` plus patch/full single-frame
            heads instead of `ResidualDiscriminator`; also forwarded to
            `choose_netD_temporal`.
        d_dim: Base width of the `ResidualDiscriminator`.
        z_dim: Width of the reconstructed noise vector.
        action_space: Width of the reconstructed action vector.
        temporal_window: Forwarded to `choose_netD_temporal` as `window` and
            used to decide how many temporal levels are evaluated.
        temporal_hierarchy: Enable the temporal discriminator.
        temporal_hierarchy_epoch: First epoch at which the 2nd/3rd temporal
            levels are evaluated.
        num_steps: The temporal discriminator only runs when this is > 4.
    """

    def __init__(self, simple_block = False, d_dim = 64, z_dim = 32, action_space = 10, 
                 temporal_window = 18, temporal_hierarchy = True, 
                 temporal_hierarchy_epoch = 0, num_steps = 15, **kwargs):
        super().__init__(**kwargs)
        # `call` branches on this flag, so it has to be stored
        self.simple_block = simple_block
        self.action_space = action_space
        self.z_dim = z_dim
        self.temporal_window = temporal_window
        self.temporal_hierarchy = temporal_hierarchy
        self.temporal_hierarchy_epoch = temporal_hierarchy_epoch
        self.num_steps = num_steps
        
        if simple_block:
            self.DS = SimpleDiscriminator()
            self.single_frame_discriminator_patch = SingleFrameDiscriminator(disc_type = 'patch')
            self.single_frame_discriminator_full = SingleFrameDiscriminator(disc_type = 'full')
            
        else:
            self.DS = ResidualDiscriminator(dim = d_dim)
            
        
        # action-conditioned discriminator
        self.action_to_feat = Linear(neurons = 256)
        self.to_transition_feature = tf.keras.models.Sequential([
            SNConv2D(filters = 256, kernel_size = (4, 4), strides = (1, 1), padding = 'valid'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.LeakyReLU(alpha = 0.2),
            tf.keras.layers.Reshape((256, ))
        ])
        self.action_discriminator = tf.keras.models.Sequential([
            SNLinear(neurons = 512),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.LeakyReLU(alpha = 0.2),
            SNLinear(neurons = 1)
        ])
        self.reconstruction_action_z = tf.keras.models.Sequential([
            Linear(neurons = action_space + z_dim)
        ])
        
        # temporal discriminator
        self.conv3d, self.conv3d_final = choose_netD_temporal(dim = 64, window = temporal_window, 
                                                              simple_block = simple_block)
        
        
    def call(self, inputs):
        """Score frames, transitions and frame sequences.

        Args:
            inputs: Sequence `(images, actions, states, warm_up, neg_actions,
                epoch)`. `actions`, `states` and (if given) `neg_actions` are
                per-step lists; `neg_actions` may be `None`.
                NOTE(review): `images` is concatenated with the ground-truth
                states along axis 0, so it is presumably the generated frames
                already stacked along the batch axis - confirm with caller.

        Returns:
            Dict with keys `disc_features`, `action_predictions`,
            `single_frame_predictions_all`, `content_predictions`,
            `neg_action_predictions`, `action_recon`, `z_recon` and
            `single_frame_predictions_patch`. `content_predictions` is an
            empty list when the temporal discriminator is not run.
        """
        images, actions, states, warm_up, neg_actions, epoch = inputs
        neg_action_predictions = None
        temporal_predictions = []
        
        batch_size = actions[0].shape[0]
        
        if warm_up == 0:
            warm_up = 1 # even if warm_up is 0, the first screen is from GT
            
        # run single frame discriminator
        gt_states = tf.concat(states[:warm_up], axis = 0)
        single_frame_predictions_patch = None
        if self.simple_block:
            tmp_gt = self.DS(gt_states)
            tmp_gen = self.DS(images)
            tmp_features = tf.concat([tmp_gt, tmp_gen], axis = 0)
            single_frame_predictions_patch = self.single_frame_discriminator_patch(tmp_gen)
            single_frame_predictions_full = self.single_frame_discriminator_full(tmp_gen)
        else:
            single_frame_predictions_full, tmp_features = self.DS(tf.concat([gt_states, images], axis = 0))
            single_frame_predictions_full = single_frame_predictions_full[warm_up*batch_size:]
            
        frame_features = tmp_features[warm_up*batch_size:]
        
        # run action-conditioned discriminator and reconstruct action, z
        prev_frames = tf.concat([tmp_features[:warm_up * batch_size],
                                 tmp_features[(warm_up + warm_up - 1) * batch_size:-batch_size]], axis = 0)
        
        transition_features = self.to_transition_feature(tf.concat([prev_frames, frame_features], axis = -1))
        action_features = self.action_to_feat(tf.concat(actions[:-1], axis = 0))
        action_predictions = self.action_discriminator(tf.concat([action_features, transition_features], axis = -1))
        
        if neg_actions is not None:
            neg_action_features = self.action_to_feat(tf.concat(neg_actions[:-1], axis = 0))
            neg_action_predictions = self.action_discriminator(
                tf.concat([neg_action_features, transition_features], axis = -1)
            )
            
        action_z_recon = self.reconstruction_action_z(transition_features)
        action_recon = action_z_recon[:, :self.action_space]
        z_recon = action_z_recon[:, self.action_space:self.action_space+self.z_dim]
        
        # run temporal discriminator
        if self.temporal_hierarchy and self.num_steps > 4:
            # One feature tensor per time step, each with `batch_size` rows.
            gt_part = tmp_features[:warm_up * batch_size]
            gen_part = tmp_features[(warm_up*2 - 1)*batch_size:]
            
            new_l = list(tf.split(gt_part, warm_up, axis = 0))
            # The generated part need not contain exactly `warm_up` frames,
            # so derive the number of per-frame chunks from its length.
            num_gen_frames = gen_part.shape[0] // batch_size
            if num_gen_frames > 0:
                new_l.extend(tf.split(gen_part, num_gen_frames, axis = 0))
                
            # window_size equals len(new_l), so `start` is always 0 here
            window_size = len(new_l)
            start = np.random.randint(0, len(new_l) - window_size + 1)
            stacked = tf.stack(new_l[start:start+window_size], axis = 1)
            
            aa = self.conv3d[0](stacked)
            a_out = self.conv3d_final[0](aa)
            temporal_predictions.append(tf.reshape(a_out, (batch_size, -1)))
            
            # Higher levels are only evaluated when they were actually built
            # (choose_netD_temporal uses different thresholds for simple_block).
            hierarchy_on = epoch >= self.temporal_hierarchy_epoch
            feats = aa
            for level, min_window in ((1, 12), (2, 18)):
                if not (hierarchy_on and self.temporal_window >= min_window
                        and level < len(self.conv3d)):
                    break
                feats = self.conv3d[level](feats)
                level_out = self.conv3d_final[level](feats)
                temporal_predictions.append(tf.reshape(level_out, (batch_size, -1)))
                
        dout = {}
        dout['disc_features'] = frame_features[:(len(states)-1)*batch_size]
        dout['action_predictions'] = action_predictions
        dout['single_frame_predictions_all'] = single_frame_predictions_full
        dout['content_predictions'] = temporal_predictions
        dout['neg_action_predictions'] = neg_action_predictions
        dout['action_recon'] = action_recon
        dout['z_recon'] = z_recon
        dout['single_frame_predictions_patch'] = single_frame_predictions_patch
        return dout