Skip to content

Commit

Permalink
created complexAvgPooling1D
Browse files Browse the repository at this point in the history
  • Loading branch information
NEGU93 committed Sep 1, 2021
1 parent 0c2c4c6 commit 1905d81
Show file tree
Hide file tree
Showing 6 changed files with 201 additions and 9 deletions.
2 changes: 1 addition & 1 deletion cvnn/_version.py
@@ -1 +1 @@
__version__ = '1.1.70'
__version__ = '1.1.71'
4 changes: 2 additions & 2 deletions cvnn/layers/__init__.py
@@ -1,6 +1,6 @@
# https://stackoverflow.com/questions/24100558/how-can-i-split-a-module-into-multiple-files-without-breaking-a-backwards-compa/24100645
from cvnn.layers.pooling import ComplexMaxPooling2D, ComplexAvgPooling2D
from cvnn.layers.pooling import ComplexUnPooling2D, ComplexMaxPooling2DWithArgmax
from cvnn.layers.pooling import ComplexUnPooling2D, ComplexMaxPooling2DWithArgmax, ComplexAvgPooling1D
from cvnn.layers.convolutional import ComplexConv2D, ComplexConv1D, ComplexConv3D
from cvnn.layers.convolutional import ComplexConv2DTranspose
from cvnn.layers.core import ComplexInput, ComplexDense, ComplexFlatten, ComplexDropout, complex_input
Expand All @@ -12,7 +12,7 @@
__copyright__ = 'Copyright 2020, {project_name}'
__credits__ = ['{credit_list}']
__license__ = '{license}'
__version__ = '1.0.12'
__version__ = '1.0.13'
__maintainer__ = 'J. Agustin BARRACHINA'
__email__ = 'joseagustin.barra@gmail.com; jose-agustin.barrachina@centralesupelec.fr'
__status__ = '{dev_status}'
84 changes: 83 additions & 1 deletion cvnn/layers/pooling.py
Expand Up @@ -291,4 +291,86 @@ def get_config(self):
'dynamic': False,
}
base_config = super(ComplexUnPooling2D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))
return dict(list(base_config.items()) + list(config.items()))


"""
1D Pooling
"""


class ComplexPooling1D(Layer, ComplexLayer):
def __init__(self, pool_size=2, strides=None,
padding='valid', data_format='channels_last',
name=None, dtype=DEFAULT_COMPLEX_TYPE, **kwargs):
self.my_dtype = dtype
super(ComplexPooling1D, self).__init__(name=name, **kwargs)
if data_format is None:
data_format = backend.image_data_format()
if strides is None:
strides = pool_size
self.pool_size = conv_utils.normalize_tuple(pool_size, 1, 'pool_size')
self.strides = conv_utils.normalize_tuple(strides, 1, 'strides')
self.padding = conv_utils.normalize_padding(padding)
self.data_format = conv_utils.normalize_data_format(data_format)
self.input_spec = InputSpec(ndim=3)

@abstractmethod
def pool_function(self, inputs, ksize, strides, padding, data_format):
pass

def call(self, inputs, **kwargs):
outputs = self.pool_function(
inputs,
self.pool_size,
strides=self.strides,
padding=self.padding.upper(),
data_format=conv_utils.convert_data_format(self.data_format, 3))
return outputs

def compute_output_shape(self, input_shape):
input_shape = tf.TensorShape(input_shape).as_list()
if self.data_format == 'channels_first':
steps = input_shape[2]
features = input_shape[1]
else:
steps = input_shape[1]
features = input_shape[2]
length = conv_utils.conv_output_length(steps,
self.pool_size[0],
self.padding,
self.strides[0])
if self.data_format == 'channels_first':
return tf.TensorShape([input_shape[0], features, length])
else:
return tf.TensorShape([input_shape[0], length, features])

def get_config(self):
config = {
'strides': self.strides,
'pool_size': self.pool_size,
'padding': self.padding,
'data_format': self.data_format,
}
base_config = super(ComplexPooling1D, self).get_config()
return dict(list(base_config.items()) + list(config.items()))


class ComplexAvgPooling1D(ComplexPooling1D):

def pool_function(self, inputs, ksize, strides, padding, data_format):
inputs_r = tf.math.real(inputs)
inputs_i = tf.math.imag(inputs)
output_r = tf.nn.avg_pool1d(input=inputs_r, ksize=ksize, strides=strides,
padding=padding, data_format=data_format)
output_i = tf.nn.avg_pool1d(input=inputs_i, ksize=ksize, strides=strides,
padding=padding, data_format=data_format)
if inputs.dtype.is_complex:
output = tf.complex(output_r, output_i)
else:
output = output_r
return output

def get_real_equivalent(self):
return ComplexAvgPooling1D(pool_size=self.pool_size, strides=self.strides, padding=self.padding,
data_format=self.data_format, name=self.name + "_real_equiv")
2 changes: 1 addition & 1 deletion docs/index.rst
Expand Up @@ -3,7 +3,7 @@ Complex-Valued Neural Network (CVNN)
====================================

:Author: J. Agustin Barrachina
:Version: 1.1.70 of 09/01/2021
:Version: 1.1.71 of 09/01/2021


Content
Expand Down
70 changes: 70 additions & 0 deletions docs/layers/complex_pooling_1d.rst
@@ -0,0 +1,70 @@
Complex Pooling 2D
------------------

.. py:class:: ComplexPooling2D
Pooling layer for arbitrary pooling functions, for 1D inputs.
Abstract class. This class only exists for code reuse. It will never be an exposed API.

.. py:method:: __init__(self, pool_size=(2, 2), strides=None, padding='valid', data_format=None, name=None, **kwargs)
:param pool_size: An integer: Specifying the size of the pooling window.
:param strides: Integer, or None. Factor by which to downscale. E.g. 2 will halve the input. If None, it will default to pool_size.
:param padding: One of "valid" or "same" (case-insensitive). "valid" means no padding. "same" results in padding evenly to the left/right or up/down of the input such that output has the same height/width dimension as the input.
:param data_format: A string, one of `channels_last` (default) or `channels_first`. The ordering of the dimensions in the inputs. :code:`channels_last` corresponds to inputs with shape :code:`(batch, steps, features)` while :code:`channels_first` corresponds to inputs with shape :code:`(batch, features, steps)`.
:param name: A string, the name of the layer.

.. _complex-max-pooling-label:

Complex Average Pooling 1D
^^^^^^^^^^^^^^^^^^^^^^^^^^

.. py:class:: ComplexAvgPooling1D
Downsamples the input representation by taking the average value over the window defined by :code:`pool_size`.
The window is shifted by strides. The resulting output when using "valid" padding option has a shape of: :code:`output_shape = (input_shape - pool_size + 1) / strides)`
The resulting output shape when using the "same" padding option is: :code:`output_shape = input_shape / strides`

**Complex dtype example**

First, let's create a complex image

.. code-block:: python
img_r = np.array([[
[0, 1, 2, 0, 2, 2, 0, 5, 7]
], [
[0, 4, 5, 3, 7, 9, 4, 5, 3]
]]).astype(np.float32)
img_i = np.array([[
[0, 4, 5, 3, 7, 9, 4, 5, 3]
], [
[0, 4, 5, 3, 2, 2, 4, 8, 9]
]]).astype(np.float32)
img = img_r + 1j * img_i
img = np.reshape(img, (2, 9, 1))
print(img[...,0])
This outputs

.. code-block:: python
[[0.+0.j 1.+4.j 2.+5.j 0.+3.j 2.+7.j 2.+9.j 0.+4.j 5.+5.j 7.+3.j]
[0.+0.j 4.+4.j 5.+5.j 3.+3.j 7.+2.j 9.+2.j 4.+4.j 5.+8.j 3.+9.j]]
Now let's run the :code:`ComplexAvgPooling1D` layer

.. code-block:: python
avg_pool = ComplexAvgPooling1D(strides=1)
res = avg_pool(img.astype(np.complex64))
print(res[...,0])
The results is then

.. code-block:: python
tf.Tensor(
[[0.5+2.j 1. +4.j 2. +8.j 2.5+4.5j]
[2. +2.j 4. +4.j 8. +2.j 4.5+6.j ]],
shape=(2, 4), dtype=complex64)
48 changes: 44 additions & 4 deletions tests/test_custom_layers.py
@@ -1,7 +1,7 @@
import numpy as np
from cvnn.layers import ComplexDense, ComplexFlatten, ComplexInput, ComplexConv2D, ComplexMaxPooling2D, \
ComplexAvgPooling2D, ComplexConv2DTranspose, ComplexUnPooling2D, ComplexMaxPooling2DWithArgmax, \
ComplexUpSampling2D, ComplexBatchNormalization
ComplexUpSampling2D, ComplexBatchNormalization, ComplexAvgPooling1D
import cvnn.layers as complex_layers
from tensorflow.keras.models import Sequential
import tensorflow as tf
Expand Down Expand Up @@ -184,6 +184,41 @@ def get_img():
return img


@tf.autograph.experimental.do_not_convert
def complex_avg_pool_1d():
x = tf.constant([1., 2., 3., 4., 5.])
x = tf.reshape(x, [1, 5, 1])
avg_pool_1d = tf.keras.layers.AveragePooling1D(pool_size=2, strides=1, padding='valid')
tf_res = avg_pool_1d(x)
own_res = ComplexAvgPooling1D(pool_size=2, strides=1, padding='valid')(x)
assert np.all(tf_res.numpy() == own_res.numpy())
avg_pool_1d = tf.keras.layers.AveragePooling1D(pool_size=2, strides=2, padding='valid')
tf_res = avg_pool_1d(x)
own_res = ComplexAvgPooling1D(pool_size=2, strides=2, padding='valid')(x)
assert np.all(tf_res.numpy() == own_res.numpy())
avg_pool_1d = tf.keras.layers.AveragePooling1D(pool_size=2, strides=1, padding='same')
tf_res = avg_pool_1d(x)
own_res = ComplexAvgPooling1D(pool_size=2, strides=1, padding='same')(x)
assert np.all(tf_res.numpy() == own_res.numpy())
img_r = np.array([[
[0, 1, 2, 0, 2, 2, 0, 5, 7]
], [
[0, 4, 5, 3, 7, 9, 4, 5, 3]
]]).astype(np.float32)
img_i = np.array([[
[0, 4, 5, 3, 7, 9, 4, 5, 3]
], [
[0, 4, 5, 3, 2, 2, 4, 8, 9]
]]).astype(np.float32)
img = img_r + 1j * img_i
img = np.reshape(img, (2, 9, 1))
avg_pool = ComplexAvgPooling1D()
res = avg_pool(img.astype(np.complex64))
expected = tf.expand_dims(tf.convert_to_tensor([[0.5 + 2.j, 1. + 4.j, 2. + 8.j, 2.5 + 4.5j],
[2. + 2.j, 4. + 4.j, 8. + 2.j, 4.5 + 6.j]], dtype=tf.complex64), axis=-1)
assert np.all(res.numpy() == expected.numpy())


@tf.autograph.experimental.do_not_convert
def complex_max_pool_2d(test_unpool=True):
img = get_img()
Expand Down Expand Up @@ -455,7 +490,7 @@ def batch_norm():

z = np.random.rand(3, 43, 12, 75) # + np.random.rand(3, 43, 12, 75)*1j
bn = tf.keras.layers.BatchNormalization(epsilon=0)
c_bn = ComplexBatchNormalization(dtype=np.float32) # If I use the complex64 then the init is different
c_bn = ComplexBatchNormalization(dtype=np.float32) # If I use the complex64 then the init is different
input = tf.convert_to_tensor(z.astype(np.float32), dtype=np.float32)
out = bn(input, training=False)
c_out = c_bn(input, training=False)
Expand All @@ -473,15 +508,20 @@ def batch_norm():
assert check_proximity(bn.beta, c_bn.beta, "Beta after training")


def pooling_layers():
complex_max_pool_2d()
complex_avg_pool_1d()
complex_avg_pool()


@tf.autograph.experimental.do_not_convert
def test_layers():
new_max_unpooling_2d_test()
complex_max_pool_2d()
pooling_layers()
batch_norm()
test_upsampling()
complex_conv_2d_transpose()
dropout()
complex_avg_pool()
shape_ad_dtype_of_conv2d()
dense_example()

Expand Down

0 comments on commit 1905d81

Please sign in to comment.