<a href="https://colab.research.google.com/github/arminwitte/FoolsUNet/blob/main/foolsunet.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# Setup: clone the FoolsUNet repository into the current working directory.
!git clone https://github.com/arminwitte/FoolsUNet.git

Cloning into 'FoolsUNet'...
remote: Enumerating objects: 8, done.[K
remote: Counting objects: 100% (8/8), done.[K
remote: Compressing objects: 100% (8/8), done.[K
remote: Total 8 (delta 1), reused 0 (delta 0), pack-reused 0[K
Receiving objects: 100% (8/8), 7.68 KiB | 7.68 MiB/s, done.
Resolving deltas: 100% (1/1), done.


In [None]:
import tensorflow as tf
# Print the installed TensorFlow version so the notebook records which
# release the cells below were run against.
print(tf.__version__)

2.15.0


In [None]:
from tensorflow.keras import layers

In [None]:
class SqueezeExcite(layers.Layer):
    """Squeeze-and-excitation channel attention.

    Adapted from https://keras.io/examples/vision/patch_convnet/ and
    implementing the mechanism described in https://arxiv.org/abs/1709.01507.

    The feature map is globally average-pooled ("squeeze"), passed through a
    two-layer bottleneck ending in a sigmoid ("excite"), and the resulting
    per-channel gate is multiplied back onto the input.

    Args:
        ratio: Positive number by which the channel count is divided to get
            the width of the bottleneck (reduction) layer. The bottleneck is
            never narrower than one unit.

    Inputs:
        Convolutional features; the channel axis is taken to be the last one.

    Outputs:
        Attention modified feature maps with the same shape as the input.

    Raises:
        ValueError: If ``ratio`` is not positive, or (at build time) if the
            channel dimension of the input is unknown.
    """

    def __init__(self, ratio, **kwargs):
        super().__init__(**kwargs)
        # A zero ratio would divide by zero in build(); a negative one would
        # ask for a negative layer width. Fail early with a clear message.
        if ratio <= 0:
            raise ValueError(f"`ratio` must be positive, got {ratio!r}.")
        self.ratio = ratio

    def get_config(self):
        """Return the layer config including ``ratio`` for serialisation."""
        config = super().get_config()
        config.update({"ratio": self.ratio})
        return config

    def build(self, input_shape):
        """Create the sub-layers once the number of input channels is known."""
        filters = input_shape[-1]
        if filters is None:
            raise ValueError(
                "SqueezeExcite requires a statically known channel dimension, "
                f"got input shape {input_shape}."
            )

        # With fewer channels than `ratio`, plain integer division yields 0,
        # i.e. a bottleneck with no units. Keep at least one unit so the gate
        # still depends on the input.
        bottleneck_units = max(1, int(filters // self.ratio))

        # keepdims=True keeps the pooled axes with size 1 so the gate can be
        # broadcast against the full feature map in the final multiply.
        self.squeeze = layers.GlobalAveragePooling2D(keepdims=True)
        self.reduction = layers.Dense(
            units=bottleneck_units,
            activation="relu",
            use_bias=False,
        )
        self.excite = layers.Dense(units=filters, activation="sigmoid", use_bias=False)
        self.multiply = layers.Multiply()

    def call(self, x):
        """Scale every channel of ``x`` by its learned attention weight."""
        shortcut = x
        x = self.squeeze(x)
        x = self.reduction(x)
        x = self.excite(x)
        x = self.multiply([shortcut, x])
        return x

In [None]:

class InverseResidualBlock(layers.Layer):
    """Inverted residual block in the style of MobileNetV2 and MobileNetV3.

    https://stackoverflow.com/a/61334159

    The block applies, in order: a 1x1 expansion convolution, a 3x3 depthwise
    convolution, squeeze-and-excitation, and a 1x1 linear projection. When the
    block changes neither the spatial resolution (``strides == 1``) nor the
    channel count (input channels == ``features``), the input is added back
    onto the result.

    Args:
        features: Number of output channels produced by the final projection.
        expand_factor: Factor by which the number of *input* channels is
            multiplied to get the width of the expanded representation.
        strides: Stride of the depthwise convolution (int or pair of ints).
        batch_norm: Flag selecting whether each convolution is followed by
            Batch Normalisation.

    Inputs:
        Convolutional features; the channel axis is taken to be the last one.

    Outputs:
        Modified feature maps with ``features`` channels.

    Raises:
        ValueError: At build time, if the channel dimension of the input is
            unknown.
    """

    def __init__(self, features=16, expand_factor=4, strides=1, batch_norm=True, **kwargs):
        super().__init__(**kwargs)
        self.features = features
        self.expand_factor = expand_factor
        self.strides = strides
        self.batch_norm = batch_norm

    def get_config(self):
        """Return the layer config including all constructor arguments."""
        config = super().get_config()
        config.update({"features": self.features,
                       "expand_factor": self.expand_factor,
                       "strides": self.strides,
                       "batch_norm": self.batch_norm, })
        return config

    def build(self, input_shape):
        """Create the sub-layers once the number of input channels is known."""
        in_channels = input_shape[-1]
        if in_channels is None:
            raise ValueError(
                "InverseResidualBlock requires a statically known channel "
                f"dimension, got input shape {input_shape}."
            )
        expanded_channels = int(in_channels * self.expand_factor)

        # Expansion: 1x1 convolution widening the representation.
        self.conv1 = layers.Conv2D(expanded_channels, (1, 1), strides=1, padding='same')
        if self.batch_norm:
            self.bn1 = layers.BatchNormalization()
        # ReLU capped at 6 is spelled as a ReLU layer with max_value rather
        # than through the 'relu6' string alias, which may not be registered
        # in every Keras release.
        self.activation1 = layers.ReLU(max_value=6.0)

        # Spatial filtering: depthwise 3x3 convolution; this is the only place
        # where the block may reduce resolution.
        self.dwise = layers.DepthwiseConv2D(3, padding='same', strides=self.strides)
        if self.batch_norm:
            self.bn2 = layers.BatchNormalization()
        self.activation2 = layers.ReLU(max_value=6.0)

        # Channel attention on the expanded representation.
        self.squeeze_excite = SqueezeExcite(ratio=4)

        # Projection: 1x1 convolution down to `features` channels, left
        # without an activation (linear bottleneck).
        self.conv2 = layers.Conv2D(self.features, (1, 1), strides=1, padding='same')
        if self.batch_norm:
            self.bn3 = layers.BatchNormalization()

        # The residual connection is only valid when input and output have
        # identical shapes: unit stride and matching depth.
        strides = self.strides
        if isinstance(strides, int):
            strides = (strides, strides)
        self.use_residual = tuple(strides) == (1, 1) and in_channels == self.features
        if self.use_residual:
            self.add = layers.Add()

    def call(self, x, training=None):
        """Run the block on ``x``.

        Args:
            x: Input feature maps.
            training: Forwarded to the batch-normalisation layers so they can
                distinguish training from inference; ``None`` lets Keras
                decide from the calling context.

        Returns:
            The projected feature maps, plus the input when the residual
            connection is active.
        """
        shortcut = x

        x = self.conv1(x)
        if self.batch_norm:
            x = self.bn1(x, training=training)
        x = self.activation1(x)

        x = self.dwise(x)
        if self.batch_norm:
            x = self.bn2(x, training=training)
        x = self.activation2(x)

        x = self.squeeze_excite(x)

        x = self.conv2(x)
        if self.batch_norm:
            x = self.bn3(x, training=training)

        if self.use_residual:
            x = self.add([shortcut, x])
        return x

SyntaxError: unterminated string literal (detected at line 26) (<ipython-input-1-b2faf6397421>, line 26)