Skip to content

Commit

Permalink
Apparently fixed ComplexMaxPooling!
Browse files Browse the repository at this point in the history
  • Loading branch information
NEGU93 committed Jan 20, 2021
1 parent b1a7959 commit d6da458
Show file tree
Hide file tree
Showing 4 changed files with 9 additions and 11 deletions.
2 changes: 1 addition & 1 deletion cvnn/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '1.0.1'
__version__ = '1.0.2'
6 changes: 2 additions & 4 deletions cvnn/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -947,9 +947,7 @@ def pool_function(self, inputs, ksize, strides, padding, data_format):
output, argmax = tf.nn.max_pool_with_argmax(input=abs_in, ksize=ksize, strides=strides,
padding=padding, data_format=data_format,
include_batch_in_index=True)
if inputs.shape[0] is None:
return output
shape = output.shape
shape = tf.shape(output)
tf_res = tf.reshape(tf.gather(tf.reshape(inputs, [-1]), argmax), shape)
# assert np.all(tf_res == output) # For debugging when the input is real only!
assert tf_res.dtype == inputs.dtype
Expand Down Expand Up @@ -984,7 +982,7 @@ def get_real_equivalent(self):
__copyright__ = 'Copyright 2020, {project_name}'
__credits__ = ['{credit_list}']
__license__ = '{license}'
__version__ = '1.0.1'
__version__ = '1.0.2'
__maintainer__ = 'J. Agustin BARRACHINA'
__email__ = 'joseagustin.barra@gmail.com; jose-agustin.barrachina@centralesupelec.fr'
__status__ = '{dev_status}'
2 changes: 1 addition & 1 deletion docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Complex-Valued Neural Network (CVNN)
====================================

:Author: J. Agustin Barrachina
:Version: 1.0.1 of 01/20/2021
:Version: 1.0.2 of 01/20/2021


Content
Expand Down
10 changes: 5 additions & 5 deletions tests/test_several_datasets.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from cvnn.layers import ComplexDense, ComplexFlatten, ComplexInput
import cvnn.layers as complex_layers
from cvnn import layers
from pdb import set_trace
from cvnn.montecarlo import run_gaussian_dataset_montecarlo


Expand Down Expand Up @@ -130,17 +131,16 @@ def random_dataset():
# run_eagerly=True
)
model.summary()

# Train and evaluate
history = model.fit(x_train, y_train, epochs=100, validation_data=(x_test, y_test))
history = model.fit(x_train, y_train, epochs=2, validation_data=(x_test, y_test))
test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)


def test_datasets():
cifar10_test()
random_dataset()
# fashion_mnist_example()
# mnist_example()
cifar10_test()
fashion_mnist_example()
mnist_example()
# run_gaussian_dataset_montecarlo(epochs=2, iterations=1)


Expand Down

0 comments on commit d6da458

Please sign in to comment.