Skip to content

Commit

Permalink
Added SegNet and UNet
Browse files Browse the repository at this point in the history
  • Loading branch information
Ritvik19 committed Apr 9, 2021
1 parent 70faefb commit 3bd106d
Show file tree
Hide file tree
Showing 5 changed files with 208 additions and 2 deletions.
3 changes: 2 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,8 @@ NAS Net Large | Customized Implementation of NAS Net Large | 4D tensor with shap
MobileNet | Customized Implementation of MobileNet | 4D tensor with shape (batch_shape, rows, cols, channels) | 4D tensor with shape (batch_shape, new_rows, new_cols, new_channels) | [usage 1](https://github.com/Ritvik19/pyradox-doc/blob/main/usage/MobileNet/MobileNet-1.md) [usage 2](https://github.com/Ritvik19/pyradox-doc/blob/main/usage/MobileNet/MobileNet-2.md)
Mobile Net V2 | Customized Implementation of Mobile Net V2 | 4D tensor with shape (batch_shape, rows, cols, channels) | 4D tensor with shape (batch_shape, new_rows, new_cols, new_channels) | [usage 1](https://github.com/Ritvik19/pyradox-doc/blob/main/usage/MobileNetV2/MobileNetV2-1.md) [usage 2](https://github.com/Ritvik19/pyradox-doc/blob/main/usage/MobileNetV2/MobileNetV2-2.md)
Mobile Net V3 | Customized Implementation of Mobile Net V3 | 4D tensor with shape (batch_shape, rows, cols, channels) | 4D tensor with shape (batch_shape, new_rows, new_cols, new_channels) | [usage 1](https://github.com/Ritvik19/pyradox-doc/blob/main/usage/MobileNetV3/MobileNetV3-1.md) [usage 2](https://github.com/Ritvik19/pyradox-doc/blob/main/usage/MobileNetV3/MobileNetV3-2.md)

Seg Net | Generalised Implementation of SegNet for Image Segmentation Applications | 4D tensor with shape (batch_shape, rows, cols, channels) | 4D tensor with shape (batch_shape, rows, cols, channels) | [check here](https://github.com/Ritvik19/pyradox-doc/blob/main/usage/SegNet/SegNet.md)
U Net | Generalised Implementation of UNet for Image Segmentation Applications | 4D tensor with shape (batch_shape, rows, cols, channels) | 4D tensor with shape (batch_shape, rows, cols, channels) | [check here](https://github.com/Ritvik19/pyradox-doc/blob/main/usage/UNet/UNet.md)
### DenseNets

Module | Description | Input Shape | Output Shape | Usage
Expand Down
173 changes: 173 additions & 0 deletions pyradox/convnets.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math, copy
from functools import reduce
from tensorflow.keras import layers
from pyradox.modules import *
from tensorflow.keras.activations import swish
Expand Down Expand Up @@ -2654,3 +2655,175 @@ def __call__(self, inputs):
x = self.activation(x)

return x


class GeneralizedSegNet(layers.Layer):
"""Generalised Implementation of SegNet for Image Segmentation Applications
encoder_config (list of tuples): configuration of the encoder block as a list of tuples containing
(n_layers, n_filters_conv, kernel_size, pool_size), default: the configuration mentioned in the paper
dropout (float): the dropout rate, default: 0
activation (keras Activation): activation applied to convolutions, default: relu
**kwargs : keyword arguments for convolution layers
"""

def __init__(self, encoder_config=None, activation="relu", dropout=0, **kwargs):
if encoder_config is None:
self.encoder_config = [
(2, 64, 3, 2),
(2, 128, 3, 2),
(3, 256, 3, 2),
(3, 512, 3, 2),
(3, 512, 3, 2),
]
else:
self.encoder_config = encoder_config
self.activation = activation
self.dropout = dropout
self.kwargs = kwargs

def _encoder_block(self, config, inputs):
n_layer, n_filters, kernel, pool_size = config
x = inputs
for i in range(n_layer):
x = layers.Convolution2D(n_filters, kernel, padding="same", **self.kwargs)(
x
)
if self.dropout != 0:
x = layers.Dropout(self.dropout)(x)
x = layers.BatchNormalization()(x)
x = layers.Activation(self.activation)(x)
x = layers.MaxPooling2D(pool_size=pool_size)(x)
return x

def _decoder_block(self, config, inputs):
n_layer, n_filters, kernel, pool_size = config
x = layers.UpSampling2D(size=pool_size, interpolation="bilinear")(inputs)
for i in range(n_layer):
x = layers.Convolution2D(n_filters, kernel, padding="same", **self.kwargs)(
x
)
if self.dropout != 0:
x = layers.Dropout(self.dropout)(x)
x = layers.BatchNormalization()(x)
x = layers.Activation(self.activation)(x)
return x

def _pooling_factor(self):
return (
reduce(lambda a, b: a[3] * b[3], self.encoder_config)
if len(self.encoder_config) > 1
else self.encoder_config[0][3]
)

def __call__(self, inputs):
if (
inputs.shape[1] % self._pooling_factor() == 0
and inputs.shape[2] % self._pooling_factor() == 0
):
x = inputs
for config in self.encoder_config:
x = self._encoder_block(config, x)
for config in self.encoder_config[::-1]:
x = self._decoder_block(config, x)
return x
raise Exception("Image dimensions are not a multiple of pooling factor")


class GeneralizedUNet(layers.Layer):
"""Generalised Implementation of UNet for Image Segmentation Applications
encoder_config (list of tuples): configuration of the encoder block as a list of tuples containing
(n_layers, n_filters_conv, kernel_size, pool_size), default: the configuration mentioned in the paper
bottleneck_conv (list): number of convolutions in bottleneck block and kernel size, default [1024, (3, 3)]
dropout (float): the dropout rate, default: 0
activation (keras Activation): activation applied to convolutions, default: relu
**kwargs : keyword arguments for convolution layers
"""

def __init__(
self,
encoder_config=None,
bottleneck_conv=None,
activation="relu",
dropout=0,
**kwargs,
):
if encoder_config is None:
self.encoder_config = [
(2, 64, 3, 2),
(2, 128, 3, 2),
(2, 256, 3, 2),
(2, 512, 3, 2),
]
else:
self.encoder_config = encoder_config
if bottleneck_conv is None:
self.bottleneck_conv = [1024, 3]
else:
self.bottleneck_conv = bottleneck_conv
self.activation = activation
self.dropout = dropout
self.kwargs = kwargs

def _encoder_block(self, config, inputs):
n_layer, n_filters, kernel, pool_size = config
x = inputs
for i in range(n_layer):
x = layers.Convolution2D(n_filters, kernel, padding="same", **self.kwargs)(
x
)
if self.dropout != 0:
x = layers.Dropout(self.dropout)(x)
x = layers.BatchNormalization()(x)
x = layers.Activation(self.activation)(x)
pooled = layers.MaxPooling2D(pool_size=pool_size)(x)
return x, pooled

def _decoder_block(self, config, inputs):
n_layer, n_filters, kernel, pool_size = config
pooled, x = inputs
upsampled = layers.UpSampling2D(size=pool_size, interpolation="bilinear")(
pooled
)
x = layers.concatenate([x, upsampled])
for i in range(n_layer):
x = layers.Convolution2D(n_filters, kernel, padding="same", **self.kwargs)(
x
)
if self.dropout != 0:
x = layers.Dropout(self.dropout)(x)
x = layers.BatchNormalization()(x)
x = layers.Activation(self.activation)(x)
return x

def _pooling_factor(self):
return (
reduce(lambda a, b: a[3] * b[3], self.encoder_config)
if len(self.encoder_config) > 1
else self.encoder_config[0][3]
)

def __call__(self, inputs):
if (
inputs.shape[1] % self._pooling_factor() == 0
and inputs.shape[2] % self._pooling_factor() == 0
):
x = inputs
encoder_outputs = []
for config in self.encoder_config:
enc_op, x = self._encoder_block(config, x)
encoder_outputs.append(enc_op)
x = layers.Convolution2D(
self.bottleneck_conv[0], self.bottleneck_conv[1], padding="same"
)(x)
if self.dropout != 0:
x = layers.Dropout(self.dropout)(x)
x = layers.BatchNormalization()(x)
x = layers.Activation(self.activation)(x)
for config, enc_op in zip(self.encoder_config[::-1], encoder_outputs[::-1]):
x = self._decoder_block(config, [x, enc_op])
return x
raise Exception("Image dimensions are not a multiple of pooling factor")
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setuptools.setup(
name="pyradox",
version="0.17.10",
version="0.18.10",
author="Ritvik Rastogi",
author_email="rastogiritvik99@gmail.com",
description="State of the Art Neural Networks for Deep Learning",
Expand Down
15 changes: 15 additions & 0 deletions tests/test_seg_net.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@
import sys, os

sys.path.append(os.path.dirname(os.getcwd()))

from tensorflow import keras
import numpy as np
from pyradox import convnets


def test():
inputs = keras.Input(shape=(28, 28, 1))
x = convnets.GeneralizedSegNet(encoder_config=[(2, 32, 3, 7)])(inputs)
outputs = keras.layers.Convolution2D(1, 1, activation="sigmoid")(x)

model = keras.models.Model(inputs=inputs, outputs=outputs)
17 changes: 17 additions & 0 deletions tests/test_u_net.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
import sys, os

sys.path.append(os.path.dirname(os.getcwd()))

from tensorflow import keras
import numpy as np
from pyradox import convnets


def test():
inputs = keras.Input(shape=(28, 28, 1))
x = convnets.GeneralizedUNet(
encoder_config=[(2, 32, 3, 2), (2, 64, 3, 2)], bottleneck_conv=[32, (1, 1)]
)(inputs)
outputs = keras.layers.Convolution2D(1, 1, activation="sigmoid")(x)

model = keras.models.Model(inputs=inputs, outputs=outputs)

0 comments on commit 3bd106d

Please sign in to comment.