In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Layer, Conv2D, Add

In [None]:
mnist = tf.keras.datasets.mnist
(X, y), (X_test, y_test) = mnist.load_data()


def preprocess(imgs_int, size=(16, 16), num_levels=4):
    """Downsample greyscale images and quantise them into a few intensity levels.

    Args:
        imgs_int: stack of greyscale images with pixel values in [0, 255]
            (here MNIST's ``X``, shape ``(N, H, W)``).
        size: ``(height, width)`` the images are resized to.
        num_levels: number of discrete intensity levels. This must match the
            number of output classes of the PixelCNN head (4 in ``get_model``).

    Returns:
        A pair ``(imgs, imgs_int)``:
          * ``imgs``: float32 array ``(N, *size, 1)`` holding
            ``level / num_levels``, i.e. values in ``[0, 1)``, used as model input.
          * ``imgs_int``: int array of the same shape holding the level index
            ``0 .. num_levels - 1``, used as sparse-categorical targets.
    """
    # Add a trailing channel axis so tf.image.resize sees (N, H, W, 1).
    resized = tf.image.resize(np.expand_dims(imgs_int, -1), size).numpy()
    # Bucket [0, 256) into `num_levels` equal-width bins. Bilinear resizing never
    # exceeds the input maximum (255), so the top index is num_levels - 1.
    levels = (resized / (256 / num_levels)).astype(int)
    imgs = levels.astype("float32") / num_levels
    return imgs, levels


input_data, output_data = preprocess(X)

In [3]:
class MaskedLayer(Layer):
    """Causally masked ``Conv2D`` used by the PixelCNN.

    On every forward pass the wrapped convolution's kernel is multiplied by a
    fixed 0/1 mask. Each output pixel therefore only depends on input pixels in
    rows above it, or to its left in the same row.

    Args:
        mask_type: ``0`` hides the centre tap as well (used for the first layer,
            so a pixel cannot see itself). ``1`` keeps the centre tap (used for
            all later layers).
        **kwargs: forwarded verbatim to ``Conv2D``.
    """

    def __init__(self, mask_type, **kwargs):
        super(MaskedLayer, self).__init__()
        self.mask_type = mask_type
        self.conv = Conv2D(**kwargs)
        self.mask = None

    def build(self, input_shape):
        """Build the inner convolution and derive the causal mask from its kernel."""
        self.conv.build(input_shape)
        # Bind the *class* implementation, not the instance attribute. After the
        # first build the instance attribute is our masked wrapper, and capturing
        # it again on a rebuild would make the op call itself forever.
        self.convolution_op = type(self.conv).convolution_op.__get__(self.conv)

        kh, kw, in_ch, out_ch = tuple(self.conv.kernel.shape)
        # Tap aligned with the output pixel under TF "same" padding. This equals
        # k // 2 for odd k and is also correct for even k.
        ch, cw = (kh - 1) // 2, (kw - 1) // 2

        # Build the mask from the real kernel shape so non-square and
        # even-sized kernels get a mask of matching shape.
        mask = np.zeros((kh, kw, in_ch, out_ch), dtype="float32")
        mask[:ch] = 1.0          # every row strictly above the centre row
        mask[ch, :cw] = 1.0      # centre row, strictly left of the centre
        if self.mask_type == 1:
            mask[ch, cw] = 1.0   # centre-visible mask keeps the current pixel
        self.mask = tf.constant(mask)

        # Route the inner layer's convolution through the masked op once here,
        # instead of re-assigning it on every call.
        self.conv.convolution_op = self.masked_convolution_op

    def masked_convolution_op(self, filters, kernel):
        """Run the unmasked convolution with ``kernel * mask``.

        Args:
            filters: the input tensor passed in by ``Conv2D.call``. The name is
                kept for backward compatibility; it is not a filter count.
            kernel: the convolution kernel, same shape as ``self.mask``.
        """
        kernel = tf.convert_to_tensor(kernel)
        # Match the kernel's dtype so the product also works when the kernel is
        # read in reduced precision (e.g. under a mixed-precision policy).
        mask = tf.cast(self.mask, kernel.dtype)
        return self.convolution_op(filters, tf.math.multiply(kernel, mask))

    def call(self, inputs):
        """Apply the masked convolution (plus the configured activation)."""
        return self.conv.call(inputs)

In [4]:
class ResidualBlock(Layer):
    """PixelCNN residual block: 1x1 squeeze, masked 3x3, 1x1 expand, plus skip.

    The block halves the channel count and applies a centre-visible
    (``mask_type=1``) masked 3x3 convolution, so causality is preserved. It then
    restores the channel count and adds the result back onto the input.
    ``inputs`` must already have ``filters`` channels for the addition to be
    shape-compatible.
    """

    def __init__(self, filters, **kwargs):
        super().__init__(**kwargs)
        bottleneck = filters // 2
        # Attribute names and creation order are kept: they determine the
        # tracked sub-layer names used for the block's weights.
        self.conv1 = Conv2D(filters=bottleneck, kernel_size=(1, 1), activation="relu")
        self.mask = MaskedLayer(
            1,
            filters=bottleneck,
            kernel_size=(3, 3),
            activation="relu",
            padding="same",
        )
        self.conv2 = Conv2D(filters=filters, kernel_size=(1, 1), activation="relu")
        self.add = Add()

    def call(self, inputs):
        """Return ``inputs + conv2(mask(conv1(inputs)))``."""
        residual = self.conv2(self.mask(self.conv1(inputs)))
        return self.add([inputs, residual])

In [5]:
def get_model(input_shape=(16, 16, 1), filters=128, num_residual_blocks=5,
              num_pixel_layers=2, num_levels=4):
    """Build the PixelCNN as a Keras functional model.

    The defaults reproduce the original architecture exactly.

    Args:
        input_shape: ``(H, W, C)`` of the pre-processed images.
        filters: channel width of every hidden layer.
        num_residual_blocks: number of ``ResidualBlock`` layers.
        num_pixel_layers: number of 1x1 ``mask_type=1`` masked convolutions
            placed before the output head.
        num_levels: number of intensity classes predicted per pixel. This must
            equal the ``num_levels`` used when pre-processing the data.

    Returns:
        A ``tf.keras.Model`` mapping ``(batch, H, W, C)`` inputs to per-pixel
        softmax probabilities of shape ``(batch, H, W, num_levels)``.
    """
    inputs = tf.keras.layers.Input(shape=input_shape)
    # mask_type=0 on the first layer: a pixel must not see its own value.
    x = MaskedLayer(mask_type=0, filters=filters, kernel_size=7,
                    activation="relu", padding="same")(inputs)

    for _ in range(num_residual_blocks):
        x = ResidualBlock(filters=filters)(x)
    for _ in range(num_pixel_layers):
        x = MaskedLayer(mask_type=1, filters=filters, kernel_size=1, strides=1,
                        activation="relu", padding="valid")(x)

    # One softmax over the intensity levels for every pixel.
    out = tf.keras.layers.Conv2D(filters=num_levels, kernel_size=1, strides=1,
                                 activation="softmax", padding="valid")(x)

    return tf.keras.models.Model(inputs, out)

In [None]:
# Build a fresh PixelCNN and train it to predict each pixel's quantised level
# (`output_data`) from the image itself (`input_data`). The causal masks stop a
# pixel from seeing itself or anything after it. 150 epochs over the full
# training set is an expensive cell.
model = get_model()
adam = tf.keras.optimizers.Adam(learning_rate=0.0005)
model.compile(loss="sparse_categorical_crossentropy", optimizer=adam)
model.fit(x=input_data, y=output_data, epochs=150, batch_size=128)

In [20]:
def sample_from(probs, temperature, rng=None):
    """Draw one class index from a temperature-scaled categorical distribution.

    Args:
        probs: 1-D array of non-negative class probabilities (one softmax row).
        temperature: must be > 0. ``1`` samples from ``probs`` as-is, ``< 1``
            sharpens the distribution towards its mode, ``> 1`` flattens it.
        rng: optional ``numpy.random.Generator``. Defaults to the global
            ``np.random`` state, as before.

    Returns:
        The sampled index in ``range(len(probs))``.

    Raises:
        ValueError: if ``temperature`` is not positive.
    """
    if temperature <= 0:
        raise ValueError(f"temperature must be positive, got {temperature}")
    # Scale in float64 log space. Raising small float32 probabilities to a large
    # power (low temperature) underflows to all zeros, and renormalising would
    # then divide 0 by 0 and produce NaNs.
    with np.errstate(divide="ignore"):  # log(0) -> -inf is fine here
        logits = np.log(np.asarray(probs, dtype=np.float64)) / temperature
    logits -= logits.max()
    weights = np.exp(logits)
    weights /= weights.sum()

    sampler = np.random if rng is None else rng
    return sampler.choice(len(weights), p=weights)

def generate(temperature, num_images=12):
    """Autoregressively sample images from the trained global ``model``.

    Pixels are filled in row-major order. Each pixel is drawn from the model's
    predicted distribution given the pixels generated so far (the rest of the
    buffer is still zero).

    Args:
        temperature: sampling temperature forwarded to ``sample_from``.
        num_images: number of images generated in parallel (default 12).

    Returns:
        float32 array of shape ``(num_images,) + model.input_shape[1:]``. Values
        are ``level / num_levels``, the same encoding used for the model's
        training inputs.
    """
    # Number of predicted classes; the divisor that maps a level index back to
    # the model's input encoding.
    num_levels = model.output_shape[-1]
    generated_images = np.zeros(
        (num_images,) + tuple(model.input_shape[1:]), dtype="float32"
    )
    _, rows, cols, channels = generated_images.shape

    for row in range(rows):
        for col in range(cols):
            # NOTE(review): the model predicts one distribution per pixel, not
            # per channel, so this loop only makes sense for channels == 1.
            for channel in range(channels):
                # Calling the model directly avoids `predict`'s per-call setup
                # cost and the progress bar it would print for every pixel.
                probs = np.asarray(model(generated_images, training=False))[:, row, col, :]
                levels = [sample_from(p, temperature) for p in probs]
                generated_images[:, row, col, channel] = np.asarray(levels) / num_levels

    return generated_images

In [None]:
# Sample a batch at temperature 1 (the model's own distribution) and show it in
# a grid sized to the batch, rather than a fixed 4x4 grid with empty slots.
images = generate(1)
n_cols = 4
n_rows = -(-len(images) // n_cols)  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(3 * n_cols, 3 * n_rows), squeeze=False)
for ax in axes.flat:
    ax.axis("off")
for ax, image in zip(axes.flat, images):
    # Drop the single channel axis. Values are intensities, so use a grey map
    # with a shared 0..1 scale across panels for comparable brightness.
    ax.imshow(image.squeeze(-1), cmap="gray", vmin=0.0, vmax=1.0)
plt.show()