In [13]:
import numpy as np
from typing import Dict, Optional, Any
from cvnn.losses import ComplexAverageCrossEntropy, ComplexWeightedAverageCrossEntropy
from cvnn.metrics import ComplexCategoricalAccuracy, ComplexAverageAccuracy, ComplexPrecision, ComplexRecall
from cvnn.layers import complex_input, ComplexConv2D, ComplexDropout, \
    ComplexMaxPooling2DWithArgmax, ComplexUnPooling2D, ComplexInput, ComplexBatchNormalization, ComplexDense, \
    ComplexUpSampling2D, ComplexConv2DTranspose, ComplexAvgPooling2D, ComplexPolarAvgPooling2D, ComplexMaxPooling2D
from cvnn.activations import cart_softmax, cart_relu
from cvnn.initializers import ComplexHeNormal

In [14]:
import tensorflow as tf
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.initializers import HeNormal
from tensorflow.keras.layers import Concatenate, Add, Activation, Input
from tensorflow.keras.layers import Conv2D, Dropout, Conv2DTranspose, BatchNormalization, MaxPooling2D, \
    UpSampling2D, AvgPool2D
from tensorflow.keras.losses import CategoricalCrossentropy
from tensorflow.keras import Model, Sequential
from tensorflow.keras.metrics import Recall, Precision, CategoricalAccuracy

In [15]:
# Global architecture/training configuration. The model builders read it at call time,
# and get_my_unet_model(hyper_dict=...) overwrites matching entries in place.
hyper_params = dict(
    padding='same',
    consecutive_conv_layers=0,          # extra convolutions stacked after the first one in each block
    kernel_shape=(3, 3),
    block6_kernel_shape=(1, 1),         # not read by the complex builders visible in this section
    max_pool_kernel=(2, 2),
    concat=Add,                         # how skip connections are merged: Add or Concatenate
    upsampling_layer=ComplexUnPooling2D,
    stride=2,
    pooling=ComplexMaxPooling2DWithArgmax,
    activation=cart_relu,
    kernels=[12, 24, 48, 96, 192],      # filters per encoder level
    output_function=cart_softmax,
    init=ComplexHeNormal,               # initializer class; instantiated once per conv layer
    optimizer=Adam,
    learning_rate=0.0001,
    depth=5,                            # number of leading entries of `kernels` actually used
)

In [None]:
# Default dropout rate for each network stage; None means no dropout layer for that stage.
DROPOUT_DEFAULT = dict.fromkeys(("downsampling", "bottle_neck", "upsampling"))

In [16]:

# Spatial input size used in the default input_shape; None leaves the dimension unspecified.
# Set both to 128 for a fixed 128x128 input.
IMG_HEIGHT = IMG_WIDTH = None

In [17]:
def _get_downsampling_block(input_to_block, num: int, dtype=np.complex64, dropout: Optional[float] = False):
    """Build one encoder level: conv stack -> batch-norm -> activation -> pooling -> optional dropout.

    Args:
        input_to_block: Tensor fed into this level.
        num: Index of the encoder level; selects the filter count from
            ``hyper_params['kernels'][:hyper_params['depth']]``.
        dtype: Complex dtype for the conv, batch-norm and dropout layers.
        dropout: Dropout rate applied after pooling; any falsy value (``False``, ``None``, ``0``)
            disables the dropout layer.

    Returns:
        ``(pool, pool_argmax)``. ``pool_argmax`` is produced only by ``ComplexMaxPooling2DWithArgmax``;
        for the average-pooling variants it is ``None``.

    Raises:
        ValueError: If ``hyper_params['pooling']`` is not one of the supported pooling layers.
    """
    filters = hyper_params['kernels'][:hyper_params['depth']][num]

    def _conv(x):
        # A fresh initializer instance per layer, as each conv gets its own weights.
        return ComplexConv2D(filters, hyper_params['kernel_shape'],
                             activation='linear', padding=hyper_params['padding'],
                             kernel_initializer=hyper_params['init'](), dtype=dtype)(x)

    conv = _conv(input_to_block)
    # 'consecutive_conv_layers' counts the convolutions stacked after the first one.
    for _ in range(hyper_params['consecutive_conv_layers']):
        conv = _conv(conv)
    conv = ComplexBatchNormalization(dtype=dtype)(conv)
    conv = Activation(hyper_params['activation'])(conv)

    pooling = hyper_params['pooling']
    if pooling == ComplexMaxPooling2DWithArgmax:
        # The argmax indices are returned so the decoder can unpool with ComplexUnPooling2D.
        pool, pool_argmax = ComplexMaxPooling2DWithArgmax(hyper_params['max_pool_kernel'],
                                                          strides=hyper_params['stride'])(conv)
    elif pooling in (ComplexAvgPooling2D, ComplexPolarAvgPooling2D):
        pool = pooling(hyper_params['max_pool_kernel'], strides=hyper_params['stride'])(conv)
        pool_argmax = None
    else:
        raise ValueError(f"Unknown pooling {hyper_params['pooling']}")
    if dropout:
        pool = ComplexDropout(rate=dropout, dtype=dtype)(pool)
    return pool, pool_argmax

In [18]:
def _get_upsampling_block(input_to_block, pool_argmax, kernels, num: int, activation,
                          dropout: Optional[float] = False, dtype=np.complex64):
    """Build one decoder level: upsample -> conv stack -> batch-norm -> activation -> optional dropout.

    Args:
        input_to_block: Tensor to upsample.
        pool_argmax: Argmax indices from the matching ``ComplexMaxPooling2DWithArgmax``; only used when
            ``hyper_params['upsampling_layer']`` is ``ComplexUnPooling2D``.
        kernels: Number of filters of the convolutions applied after upsampling.
        num: Number of filters of the transposed convolution; only used when
            ``hyper_params['upsampling_layer']`` is ``ComplexConv2DTranspose``.
        activation: Activation applied after batch normalisation.
        dropout: Dropout rate applied at the end of the block; any falsy value disables it.
        dtype: Complex dtype for the layers of this block.

    Returns:
        The output tensor of the block.

    Raises:
        ValueError: If ``hyper_params['upsampling_layer']`` is not a supported upsampling layer.
    """
    upsampling_layer = hyper_params['upsampling_layer']
    if upsampling_layer == ComplexUnPooling2D:
        unpool = ComplexUnPooling2D(upsampling_factor=2)([input_to_block, pool_argmax])
    elif upsampling_layer == ComplexUpSampling2D:
        unpool = ComplexUpSampling2D(size=2)(input_to_block)
    elif upsampling_layer == ComplexConv2DTranspose:
        # dtype is forwarded so a non-default complex dtype is used consistently across the block.
        unpool = ComplexConv2DTranspose(filters=num, kernel_size=3, strides=(2, 2), padding='same',
                                        dilation_rate=(1, 1), dtype=dtype)(input_to_block)
    else:
        # Report the configured value itself: `.name` is an instance attribute of Keras layers, so it
        # gives no readable name for a class and fails outright for non-layer values.
        raise ValueError(f"Upsampling method {upsampling_layer} not supported")

    def _conv(x):
        return ComplexConv2D(kernels, hyper_params['kernel_shape'],
                             activation='linear', padding=hyper_params['padding'],
                             kernel_initializer=hyper_params['init'](), dtype=dtype)(x)

    conv = _conv(unpool)
    for _ in range(hyper_params['consecutive_conv_layers']):
        conv = _conv(conv)
    conv = ComplexBatchNormalization(dtype=dtype)(conv)
    conv = Activation(activation)(conv)
    if dropout:
        conv = ComplexDropout(rate=dropout, dtype=dtype)(conv)
    return conv

In [19]:
def _get_my_model(in1, get_downsampling_block, get_upsampling_block, dtype=np.complex64, name="my_own_model",
                  dropout_dict=None, num_classes=4, weights=None):
    """Assemble and compile the complex-valued encoder/decoder (U-Net-style) model.

    Args:
        in1: Model input tensor.
        get_downsampling_block: Encoder-level builder with the signature of ``_get_downsampling_block``.
        get_upsampling_block: Decoder-level builder with the signature of ``_get_upsampling_block``.
        dtype: Complex dtype for the layers.
        name: Keras model name.
        dropout_dict: Dropout rates keyed by ``"downsampling"``, ``"bottle_neck"`` and ``"upsampling"``.
            Missing keys fall back to ``DROPOUT_DEFAULT`` (no dropout).
        num_classes: Number of output channels of the final block.
        weights: Optional class weights; when given, ``ComplexWeightedAverageCrossEntropy`` is used,
            otherwise ``ComplexAverageCrossEntropy``.

    Returns:
        A compiled Keras ``Model``.

    Raises:
        KeyError: If ``hyper_params['concat']`` is neither ``Concatenate`` nor ``Add``.
    """
    # Merge with the defaults so a partial dict (e.g. only "downsampling") does not raise KeyError,
    # and so the shared DROPOUT_DEFAULT dict is never handed out.
    dropout_dict = {**DROPOUT_DEFAULT, **(dropout_dict or {})}
    kernels = hyper_params['kernels'][:hyper_params['depth']]

    # Downsampling: keep every level's output (skip connections) and pooling argmax (for unpooling).
    pool = in1
    pools = []
    argmax_pools = []
    for index in range(len(kernels)):
        pool, pool_argmax = get_downsampling_block(pool, index, dtype=dtype, dropout=dropout_dict["downsampling"])
        pools.append(pool)
        argmax_pools.append(pool_argmax)

    # Bottleneck on the deepest encoder output.
    # NOTE(review): hyper_params['block6_kernel_shape'] is not read here; the kernel is fixed at (1, 1).
    index = -1
    conv = ComplexConv2D(kernels[index], (1, 1),
                         activation=hyper_params['activation'], padding=hyper_params['padding'],
                         dtype=dtype)(pools.pop())
    if dropout_dict["bottle_neck"] is not None:
        conv = ComplexDropout(rate=dropout_dict["bottle_neck"], dtype=dtype)(conv)

    # Upsampling: argmax_pools still holds the deepest level's indices, so the argmax popped at each
    # step matches the resolution of `conv`, and the popped pool matches the upsampled resolution.
    while pools:
        index -= 1
        pool = pools.pop()
        pool_argmax = argmax_pools.pop()
        # `num` is the transposed-convolution filter count (only used with ComplexConv2DTranspose);
        # it takes this level's filter count instead of a fixed 4.
        conv = get_upsampling_block(conv, pool_argmax, kernels[index], num=kernels[index],
                                    activation=hyper_params['activation'],
                                    dropout=dropout_dict["upsampling"], dtype=dtype)
        if hyper_params['concat'] == Concatenate:
            conv = Concatenate()([conv, pool])
        elif hyper_params['concat'] == Add:
            conv = Add()([conv, pool])
        else:
            raise KeyError(f"Concatenation {hyper_params['concat']} not known")
    # Final block back to input resolution. num=kernels[0] keeps the channel count, instead of the
    # previous num=0, which asked for a zero-filter transposed convolution.
    out = get_upsampling_block(conv, argmax_pools.pop(), activation=hyper_params['output_function'], dropout=False,
                               num=kernels[0], kernels=num_classes, dtype=dtype)

    if weights is not None:
        loss = ComplexWeightedAverageCrossEntropy(weights=weights)
    else:
        loss = ComplexAverageCrossEntropy()

    model = Model(inputs=[in1], outputs=[out], name=name)
    model.compile(optimizer=hyper_params['optimizer'](learning_rate=hyper_params['learning_rate']), loss=loss,
                  metrics=[ComplexCategoricalAccuracy(name='accuracy'),
                           ComplexAverageAccuracy(name='average_accuracy'),
                           ComplexPrecision(name='precision'),
                           ComplexRecall(name='recall')
                           ])
    return model

In [20]:
def get_my_unet_model(input_shape=(IMG_HEIGHT, IMG_WIDTH, 3), num_classes=4, dtype=np.complex64,
                      tensorflow: bool = False,
                      name="my_model", dropout_dict=None, weights=None, hyper_dict: Optional[Dict] = None):
    """Build and compile the U-Net-style model (public entry point).

    Args:
        input_shape: Input shape without the batch dimension.
        num_classes: Number of output channels.
        dtype: Complex dtype for the cvnn model; not forwarded when ``tensorflow`` is True.
        tensorflow: If True, build via ``_get_my_model_with_tf`` from a plain Keras ``Input``;
            otherwise build the cvnn model from ``complex_input``.
        name: Model name; the TensorFlow variant is prefixed with ``"tf_"``.
        dropout_dict: Per-stage dropout rates; ``None`` selects ``DROPOUT_DEFAULT``.
        weights: Optional class weights forwarded to the model builder.
        hyper_dict: Overrides for entries of the global ``hyper_params``; unknown keys are reported
            with a printed warning and ignored.

    Returns:
        The compiled model returned by the selected builder.
    """
    if hyper_dict is not None:
        for key, value in hyper_dict.items():
            if key in hyper_params:
                # NOTE(review): this mutates the module-level hyper_params, so overrides persist into
                # later calls that do not pass hyper_dict.
                hyper_params[key] = value
            else:
                print(f"WARNING: parameter {key} is not used")
    if dropout_dict is None:
        dropout_dict = DROPOUT_DEFAULT
    if not tensorflow:
        in1 = complex_input(shape=input_shape, dtype=dtype)
        return _get_my_model(in1, _get_downsampling_block, _get_upsampling_block, dtype=dtype, name=name,
                             dropout_dict=dropout_dict, num_classes=num_classes, weights=weights)
    else:
        in1 = Input(shape=input_shape)
        return _get_my_model_with_tf(in1, _tf_get_downsampling_block, _get_tf_upsampling_block, name="tf_" + name,
                                     dropout_dict=dropout_dict, num_classes=num_classes, weights=weights)
