In [None]:
class CPC_Encoder(tf.keras.layers.Layer):
    """Convolutional encoder mapping an input tensor to a latent vector.

    Layer stack, applied in order:
      * 4 x [Conv2D(64 filters, 3x3 kernel, stride 2, linear) -> BatchNormalization -> LeakyReLU]
      * Flatten -> Dense(256, linear) -> BatchNormalization -> LeakyReLU
      * Dense(latent_dim, linear) named 'embedding'

    Args:
        latent_dim: Number of units in the final 'embedding' Dense layer.
    """

    def __init__(self, latent_dim):
        super(CPC_Encoder, self).__init__()
        self.layers = [
            tf.keras.layers.Conv2D(filters=64, kernel_size=3, strides=2, activation='linear'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.LeakyReLU(),
            tf.keras.layers.Conv2D(filters=64, kernel_size=3, strides=2, activation='linear'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.LeakyReLU(),
            tf.keras.layers.Conv2D(filters=64, kernel_size=3, strides=2, activation='linear'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.LeakyReLU(),
            tf.keras.layers.Conv2D(filters=64, kernel_size=3, strides=2, activation='linear'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.LeakyReLU(),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(units=256, activation='linear'),
            tf.keras.layers.BatchNormalization(),
            tf.keras.layers.LeakyReLU(),
            tf.keras.layers.Dense(units=latent_dim, activation='linear', name='embedding'),
        ]

    def call(self, x, training=False):
        """Apply the layer stack to ``x``.

        Args:
            x: Input tensor. Presumably a batch of images (batch, H, W, C),
                since the first layer is a Conv2D — TODO confirm.
            training: Training-mode flag. It matters for the BatchNormalization
                layers.

        Returns:
            The output of the final 'embedding' Dense layer.
        """
        for layer in self.layers:
            # Pass `training` by keyword: Keras forwards it only to layers whose
            # `call` accepts it and drops it for the rest. No try/except is needed,
            # so genuine errors surface on the first attempt.
            x = layer(x, training=training)
        return x


class AutoReg(tf.keras.layers.Layer):
    """Autoregressive context network made of a single GRU layer.

    The GRU is named 'context' and uses ``return_sequences=False``, so only the
    output for the last timestep is returned.

    Args:
        units: Output dimensionality of the GRU. Defaults to 256.
    """

    def __init__(self, units=256):
        super(AutoReg, self).__init__()
        self.layers = [
            tf.keras.layers.GRU(units=units, return_sequences=False, name='context'),
        ]

    def call(self, x, training=False):
        """Run ``x`` through the GRU.

        Args:
            x: Sequence tensor. A Keras GRU expects 3-D input
                (batch, timesteps, features).
            training: Training-mode flag, forwarded to the GRU.

        Returns:
            The GRU output for the final timestep.
        """
        for layer in self.layers:
            # `training` must be passed by keyword. The second positional
            # parameter of `GRU.call` is not `training` (it is `mask` in
            # tf.keras 2.x and `initial_state` in Keras 3), so a positional
            # call binds the flag to the wrong argument.
            x = layer(x, training=training)
        return x


class CPC_Model(tf.keras.Model):
    """Encoder/decoder model built around ``CPC_Encoder``.

    ``call`` encodes the input with ``CPC_Encoder`` and decodes the latent
    vector with ``CNN_Decoder``, which is defined elsewhere. An ``AutoReg``
    context network is constructed but is not used in ``call`` yet.

    Args:
        input_dim: Shape of one input sample, presumably without the batch
            axis — TODO confirm. Also passed to the decoder as ``output_dim``.
        latent_dim: Size of the latent vector produced by the encoder.
    """

    def __init__(self, input_dim, latent_dim):
        super(CPC_Model, self).__init__()
        # Store only the raw shape. `tf.keras.Model.input` is a read-only
        # property, so `self.input = ...` raises. A subclassed model also needs
        # no symbolic `Input` tensor.
        self.input_dim = input_dim
        self.encoder = CPC_Encoder(latent_dim=latent_dim)
        # TODO: not wired into `call` yet. The encoder ends in a Dense layer,
        # so it yields one vector per sample, while the GRU expects
        # (batch, timesteps, features) input.
        self.auto_reg = AutoReg()
        self.decoder = CNN_Decoder(latent_dim=latent_dim,
                                   output_dim=input_dim,
                                   # NOTE(review): hard-coded; must match what
                                   # CNN_Decoder expects — TODO derive from input_dim.
                                   restore_shape=(7, 7, 64))

    def call(self, x, training=False):
        """Encode ``x`` and decode its latent representation.

        Args:
            x: Input batch.
            training: Training-mode flag, forwarded to the encoder and decoder.

        Returns:
            The decoder's output.
        """
        latent = self.encoder(x, training=training)
        # Keep the most recent latent tensor as an attribute for later analysis.
        # NOTE(review): if `call` is traced by tf.function, this stores a
        # symbolic tensor rather than concrete values.
        self.latent_repr = latent
        # `training` stays positional here because CNN_Decoder's call signature
        # is not visible in this cell.
        return self.decoder(latent, training)
    