In [None]:
%%capture
import tensorflow as tf
import tensorflow_datasets as tfds
import numpy as np
import importlib as imp

from collections import namedtuple
from random import sample, shuffle
from functools import reduce
from itertools import accumulate
from math import floor, ceil, sqrt, log, pi
from matplotlib import pyplot as plt
from tensorflow.keras import layers, utils, losses, models as mds, optimizers

if imp.util.find_spec('aggdraw'): import aggdraw
if imp.util.find_spec('tensorflow_addons'): from tensorflow_addons import layers as tfa_layers
if imp.util.find_spec('tensorflow_models'): from official.vision.beta.ops import augment as visaugment
if imp.util.find_spec('tensorflow_probability'): from tensorflow_probability import distributions as tfd

## Scale Invariant Block

In [None]:
def sic_block(input_tensor, filters, strides, padding, activation):
    """Run parallel convolutions with several kernel sizes and fuse the results.

    One ``Conv2D`` branch is created per kernel size (3, 9, 27, ...). The
    number of branches is derived from the smallest entry of
    ``input_tensor.shape[1:-1]``. Branch outputs are concatenated along the
    last axis, batch-normalized and passed through ``activation``.

    Args:
        input_tensor: Tensor whose ``shape[1:-1]`` entries are known numbers
            (``min`` is taken over them).
        filters: Number of filters used by every convolution branch.
        strides: Forwarded unchanged to every ``Conv2D``.
        padding: Forwarded unchanged to every ``Conv2D``.
        activation: Forwarded to ``layers.Activation``.

    Returns:
        The activated, batch-normalized concatenation of all branch outputs.
    """
    smallest_dim = min(input_tensor.shape[1:-1])

    # Branch count: ceil(log base pi of ceil(sqrt(smallest_dim))).
    n_branches = ceil(log(ceil(sqrt(smallest_dim)), pi))
    kernel_sizes = [3**exponent for exponent in range(1, n_branches + 1)]

    # NOTE(review): the branches use different kernel sizes, so their outputs
    # presumably only line up for concatenation with padding='same' -- confirm
    # against the callers.
    branch_outputs = [
        layers.Conv2D(filters, kernel, strides=strides, padding=padding)(input_tensor)
        for kernel in kernel_sizes
    ]

    fused = tf.concat(branch_outputs, axis=-1)
    fused = layers.BatchNormalization()(fused)
    return layers.Activation(activation)(fused)

## SEBlock

In [None]:
class SEBlock(layers.Layer):
    """Squeeze-and-excitation layer that rescales its input channel-wise.

    The input is globally average-pooled, passed through a bottleneck
    ``Dense(filters // ratio, relu)`` and a ``Dense(filters, sigmoid)``,
    reshaped to ``[1, 1, filters]`` and multiplied with the original input.

    Args:
        filters: Width of the second dense layer and of the reshaped gate.
            Presumably this must equal the channel count of the input so the
            multiplication broadcasts -- confirm against callers.
        ratio: Divisor applied to ``filters`` for the bottleneck dense layer.
        **kwargs: Standard ``layers.Layer`` arguments (``name``, ``dtype``,
            ``trainable``, ...). Accepting them lets ``from_config`` rebuild
            the layer from the dictionary returned by ``get_config``.
    """

    def __init__(self, filters, ratio=2, **kwargs):
        super(SEBlock, self).__init__(**kwargs)
        # Stored so get_config() can report them; previously these attributes
        # were never set and get_config() raised AttributeError.
        self.filters = filters
        self.ratio = ratio
        self.block = tf.keras.Sequential([
            layers.GlobalAveragePooling2D(), # Squeeze
            layers.Dense(filters//ratio, activation='relu'),
            layers.Dense(filters, activation='sigmoid'), # Excite
            layers.Reshape([1, 1, filters]),
        ])

    def call(self, inputs):
        """Return ``inputs`` scaled by the gate computed from ``inputs``."""
        return inputs * self.block(inputs)

    def get_config(self):
        """Return the base layer config extended with ``filters`` and ``ratio``."""
        config = super(SEBlock, self).get_config()
        config.update(filters=self.filters, ratio=self.ratio)
        return config

## UNet Blocks

In [None]:
# Size that resize_block crops axes 1 and 2 of its output to.
IMG_SIZE = 128
# Base number of convolution filters; encoder_block and decoder_block multiply
# it by their width_multiplier argument.
INITIAL_WIDTH = 64
# NOTE(review): not referenced in this part of the notebook; presumably the
# number of encoder stages used when the model is assembled -- confirm.
N_ENCODERS = 2
# Rate passed to the Dropout layer that ends every encoder and decoder block.
DROPOUT_RATE = 0.6
# When True, encoder_block and decoder_block print the shapes of their
# intermediate tensors.
DEBUG = False

def encoder_block(input, width_multiplier, name='block'):
    """Build one encoder stage: two 3x3 convolutions, max pooling, dropout.

    Args:
        input: Tensor fed to the first convolution.
        width_multiplier: Factor applied to ``INITIAL_WIDTH`` to get the
            number of filters of both convolutions.
        name: Label printed in front of the shapes when ``DEBUG`` is set.

    Returns:
        The pooled feature map after dropout with rate ``DROPOUT_RATE``.
    """
    n_filters = INITIAL_WIDTH*width_multiplier
    conv_kwargs = dict(activation="relu", padding="valid")

    conv_1 = layers.Conv2D(n_filters, (3, 3), **conv_kwargs)(input)
    conv_2 = layers.Conv2D(n_filters, (3, 3), **conv_kwargs)(conv_1)
    pool = layers.MaxPooling2D((2, 2))(conv_2)
    regularized = layers.Dropout(DROPOUT_RATE)(pool)

    if DEBUG:
        print(name, input.shape, conv_1.shape, conv_2.shape, pool.shape)

    return regularized

def decoder_block(input, skip_input, width_multiplier, name='block'):
    """Build one decoder stage: upsample, merge the skip connection, convolve.

    ``input`` goes through a stride-2 transposed convolution. ``skip_input``
    is center-cropped to the size of axis 1 of that result, concatenated with
    it, and the merged tensor passes through two 3x3 'valid' convolutions
    followed by dropout.

    Args:
        input: Tensor coming from the previous (deeper) stage.
        skip_input: Tensor from the matching encoder stage.
        width_multiplier: Factor applied to ``INITIAL_WIDTH`` to get the
            number of filters of every layer in this block.
        name: Label printed in front of the shapes when ``DEBUG`` is set.

    Returns:
        The output of the second convolution after dropout with rate
        ``DROPOUT_RATE``.
    """
    n_filters = INITIAL_WIDTH*width_multiplier

    upsampled = layers.Conv2DTranspose(
        n_filters, (3, 3), strides=(2, 2), padding='same')(input)

    # The skip tensor is trimmed to the upsampled size before merging.
    skip = central_crop(skip_input, upsampled.shape[1])
    conv_input = layers.Concatenate()([upsampled, skip])

    conv_kwargs = dict(activation="relu", padding="valid")
    conv_1 = layers.Conv2D(n_filters, (3, 3), **conv_kwargs)(conv_input)
    conv_2 = layers.Conv2D(n_filters, (3, 3), **conv_kwargs)(conv_1)
    regularized = layers.Dropout(DROPOUT_RATE)(conv_2)

    if DEBUG:
        print(name, conv_input.shape, conv_1.shape, conv_2.shape)

    return regularized

def central_crop(x, target_size):
    """Crop axes 1 and 2 of a 4-D array/tensor to a centered square window.

    The offset is computed separately for each axis, so non-square inputs are
    cropped around their own center on both axes. For square inputs the
    result is identical to slicing both axes with the axis-1 offset.

    Args:
        x: Object with a ``shape`` attribute that supports 4-D slicing
            (``x[:, a:b, c:d, :]``).
        target_size: Size of axes 1 and 2 in the result.

    Returns:
        ``x`` sliced to ``target_size`` along axes 1 and 2; axes 0 and 3 are
        left untouched.

    Raises:
        ValueError: If axis 1 or axis 2 of ``x`` is smaller than
            ``target_size``. A negative offset would otherwise silently
            produce a slice of the wrong size.
    """
    height = x.shape[1]
    width = x.shape[2]
    if width is None:
        # Axis-2 size not known statically: fall back to the axis-1 size,
        # which is what this function has always used for both axes.
        width = height

    if target_size > height or target_size > width:
        raise ValueError(
            'central_crop: target_size {} exceeds input size ({}, {})'.format(
                target_size, height, width))

    top = (height - target_size)//2
    left = (width - target_size)//2
    return x[:, top:top+target_size, left:left+target_size, :]

def resize_block(input):
    """Upsample ``input`` with two transposed convolutions and crop it.

    The tensor goes through two stride-2, 'same'-padded 3x3 transposed
    convolutions (8 filters, then 1 filter), and the result is center-cropped
    to ``IMG_SIZE`` along axes 1 and 2.

    Args:
        input: Tensor fed to the first transposed convolution.

    Returns:
        The single-channel, cropped tensor.
    """
    upsample_kwargs = dict(strides=(2, 2), padding='same')

    upsampled = layers.Conv2DTranspose(8, (3, 3), **upsample_kwargs)(input)
    upsampled = layers.Conv2DTranspose(1, (3, 3), **upsample_kwargs)(upsampled)

    return central_crop(upsampled, IMG_SIZE)