Skip to content

Commit

Permalink
docs added plotly results of circularity simuls
Browse files Browse the repository at this point in the history
  • Loading branch information
NEGU93 committed Sep 29, 2020
1 parent d5a747c commit 8e063a5
Show file tree
Hide file tree
Showing 64 changed files with 2,012 additions and 11,706 deletions.
2 changes: 1 addition & 1 deletion cvnn/_version.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = '0.3.7'
__version__ = '0.3.8'
7 changes: 4 additions & 3 deletions cvnn/cvnn_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ def __init__(self, name, shape, loss_fun, optimizer='Adam',
:param name: Name of the model. It will be used to distinguish models
:param shape: List of cvnn.layers.ComplexLayer objects
:param loss_fun: tensorflow.python.keras.losses to be used.
:param optimizer: Optimizer to be used. Keras optimizers are not allowed. Only cvnn.optimizers modules.
:param optimizer: Optimizer to be used. Keras optimizers are not allowed.
Can be either cvnn.optimizers.Optimizer or a string listed in opt_dispatcher.
:param verbose: if True it will print information of np.prod(w_vals.shape)the model just created
:param tensorboard: If true it will save tensorboard information inside log/.../tensorboard_logs/
- Loss and accuracy
Expand Down Expand Up @@ -301,7 +302,7 @@ def get_real_equivalent(self, classifier: bool = True, capacity_equivalent: bool
if name is None:
name = self.name + "_real_equiv"
# set_trace()
return CvnnModel(name=name, shape=real_shape, loss_fun=self.loss_fun,
return CvnnModel(name=name, shape=real_shape, loss_fun=self.loss_fun, optimizer=self.optimizer.__deepcopy__(),
tensorboard=self.tensorboard, verbose=False)

# ====================
Expand Down Expand Up @@ -933,7 +934,7 @@ def training_param_summary(self):
__copyright__ = 'Copyright 2020, {project_name}'
__credits__ = ['{credit_list}']
__license__ = '{license}'
__version__ = '0.2.45'
__version__ = '0.2.46'
__maintainer__ = 'J. Agustin BARRACHINA'
__email__ = 'joseagustin.barra@gmail.com; jose-agustin.barrachina@centralesupelec.fr'
__status__ = '{dev_status}'
4 changes: 2 additions & 2 deletions cvnn/data_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -853,7 +853,7 @@ def do_all(self, extension=".svg"):
self.plot_histogram(key=key, library=lib, showfig=False, savefig=True, extension=extension)
except np.linalg.LinAlgError:
logger.warning("Could not plot Histogram with " + str(lib) + " because matrix was singular")
self.monte_carlo_plotter.plot_line_confidence_interval(key=key, x_axis='epoch', library=lib)
self.monte_carlo_plotter.plot_line_confidence_interval(key=key, x_axis='step', library=lib)

def box_plot(self, step=-1, library='plotly', key='test accuracy', showfig=False, savefig=True, extension='.svg'):
if library == 'plotly':
Expand Down Expand Up @@ -1113,6 +1113,6 @@ def _plot_histogram_seaborn(self, key='test accuracy', step=-1,
monte.monte_carlo_plotter.plot_line_confidence_interval(key='test loss', x_axis='epochs')

__author__ = 'J. Agustin BARRACHINA'
__version__ = '0.1.27'
__version__ = '0.1.28'
__maintainer__ = 'J. Agustin BARRACHINA'
__email__ = 'joseagustin.barra@gmail.com; jose-agustin.barrachina@centralesupelec.fr'
4 changes: 2 additions & 2 deletions cvnn/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -224,7 +224,7 @@ def __init__(self, output_size, input_size=None, activation=None, input_dtype=No
apply_activation(self.activation,
tf.cast(tf.complex([[1., 1.], [1., 1.]], [[1., 1.], [1., 1.]]), self.input_dtype)
).numpy().dtype
self.dropout = dropout
self.dropout = dropout # TODO: I don't find the verification that it is between 0 and 1. I think I omitted
if weight_initializer is None:
weight_initializer = initializers.GlorotUniform()
self.weight_initializer = weight_initializer
Expand Down Expand Up @@ -376,7 +376,7 @@ def trainable_variables(self):
__copyright__ = 'Copyright 2020, {project_name}'
__credits__ = ['{credit_list}']
__license__ = '{license}'
__version__ = '0.0.26'
__version__ = '0.0.27'
__maintainer__ = 'J. Agustin BARRACHINA'
__email__ = 'joseagustin.barra@gmail.com; jose-agustin.barrachina@centralesupelec.fr'
__status__ = '{dev_status}'
54 changes: 46 additions & 8 deletions cvnn/montecarlo.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,18 +152,55 @@ def run_montecarlo(models, dataset, open_dataset=None, iterations=500,


def run_gaussian_dataset_montecarlo(iterations=1000, m=10000, n=128, param_list=None,
epochs=150, batch_size=100, display_freq=1,
epochs=150, batch_size=100, display_freq=1, optimizer='Adam',
shape_raw=None, activation='cart_relu', debug=False, polar=False, do_all=True,
dropout=None):
"""
This function is used to compare CVNN vs RVNN performance over statistical non-circular data.
1. Generates a complex-valued gaussian correlated noise with the characteristics given by the inputs.
2. It then runs a monte carlo simulation of several iterations of both CVNN and an equivalent RVNN model.
3. Saves several files into ./log/montecarlo/date/of/run/
3.1. run_summary.txt: Summary of the run models and data
3.2. run_data.csv: Full information of performance of iteration of each model at each epoch
3.3. complex_network_statistical_result.csv: Statistical results of all iterations of CVNN per epoch
3.4. complex_network_statistical_result.csv: Statistical results of all iterations of RVNN per epoch
3.5. (Optional) `plot/` folder with the corresponding plots generated by MonteCarloAnalyzer.do_all()
:param iterations: Number of iterations to be done for each model
:param m: Total size of the dataset (number of examples)
:param n: Number of features / input vector
:param param_list: A list of len = number of classes.
Each element of the list is another list of len = 3 with values: [correlation_coeff, sigma_x, sigma_y]
Example for dataset type A of paper https://arxiv.org/abs/2009.08340:
param_list = [
[0.5, 1, 1],
[-0.5, 1, 1]
]
Default: None will default to the example.
:param epochs: Number of epochs for each iteration
:param batch_size: Batch size at each iteration
:param display_freq: Frequency in terms of epochs of when to do a checkpoint.
:param optimizer: Optimizer to be used. Keras optimizers are not allowed.
Can be either cvnn.optimizers.Optimizer or a string listed in opt_dispatcher.
:param shape_raw: List of sizes of each hidden layer.
For example [64] will generate a CVNN with one hidden layer of size 64.
Default None will default to example.
:param activation: Activation function to be used at each hidden layer
:param debug:
:param polar: Boolean weather the RVNN should receive real and imaginary part (False) or amplitude and phase (True)
:param do_all: If true (default) it creates a `plot/` folder with the plots generated by MonteCarloAnalyzer.do_all()
:param dropout: (float) Dropout to be used at each hidden layer. If None it will not use any dropout.
:return: (string) Full path to the run_data.csv generated file.
It can be used by cvnn.data_analysis.SeveralMonteCarloComparison to compare several runs.
"""
# Get parameters
if param_list is None:
param_list = [
[0.5, 1, 1],
[-0.5, 1, 1]
]
dataset = dp.CorrelatedGaussianCoeffCorrel(m, n, param_list, debug=False)
mlp_run_real_comparison_montecarlo(dataset, None, iterations, epochs, batch_size, display_freq,
shape_raw, activation, debug, polar, do_all, dropout=dropout)
return mlp_run_real_comparison_montecarlo(dataset, None, iterations, epochs, batch_size, display_freq, optimizer,
shape_raw, activation, debug, polar, do_all, dropout=dropout)


def mlp_run_real_comparison_montecarlo(dataset, open_dataset=None, iterations=1000,
Expand Down Expand Up @@ -197,7 +234,7 @@ def mlp_run_real_comparison_montecarlo(dataset, open_dataset=None, iterations=10
shape.append(Dense(output_size=output_size, activation='softmax_real', dropout=None))

complex_network = CvnnModel(name="complex_network", shape=shape, loss_fun=categorical_crossentropy,
verbose=False, tensorboard=False)
optimizer=optimizer, verbose=False, tensorboard=False)

# Monte Carlo
monte_carlo = RealVsComplex(complex_network,
Expand All @@ -223,6 +260,7 @@ def mlp_run_real_comparison_montecarlo(dataset, open_dataset=None, iterations=10
real_median_train = real_last_epochs['train accuracy'].median()
_save_rvnn_vs_cvnn_montecarlo_log(path=str(monte_carlo.monte_carlo_analyzer.path),
dataset_name=dataset.dataset_name,
optimizer=optimizer, loss=categorical_crossentropy,
hl=str(len(shape_raw)), shape=str(shape_raw),
dropout=str(dropout), num_classes=str(dataset.y.shape[1]),
polar_mode='Yes' if polar else 'No',
Expand Down Expand Up @@ -266,24 +304,24 @@ def _create_excel_file(fieldnames, row_data, filename=None, percentage_cols=None


def _save_rvnn_vs_cvnn_montecarlo_log(path, dataset_name, hl, shape, dropout, num_classes, polar_mode,
activation,
activation, optimizer, loss,
dataset_size, feature_size, epochs, batch_size, winner,
complex_median, real_median, complex_iqr, real_iqr,
complex_median_train, real_median_train,
comments='', filename=None):
fieldnames = ['dataset', '# Classes', "Dataset Size", 'Feature Size', "Polar Mode",
fieldnames = ['dataset', '# Classes', "Dataset Size", 'Feature Size', "Polar Mode", "Optimizer", "Loss",
'HL', 'Shape', 'Dropout', "Activation Function", 'epochs', 'batch size',
"Winner", "CVNN median", "RVNN median", 'CVNN IQR', 'RVNN IQR',
"CVNN train median", "RVNN train median",
'path', "cvnn version", "Comments"
]
row_data = [dataset_name, num_classes, dataset_size, feature_size, polar_mode, # Dataset information
hl, shape, dropout, activation, epochs, batch_size, # Model information
optimizer, str(loss), hl, shape, dropout, activation, epochs, batch_size, # Model information
winner, complex_median, real_median, complex_iqr, real_iqr, # Preliminary results
complex_median_train, real_median_train,
path, cvnn.__version__, comments # Library information
]
percentage_cols = ['N', 'O', 'P', 'Q', 'R', 'S']
percentage_cols = ['P', 'Q', 'R', 'S', 'T', 'U']
_create_excel_file(fieldnames, row_data, filename, percentage_cols=percentage_cols)


Expand Down
45 changes: 43 additions & 2 deletions cvnn/optimizers.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,18 @@ def compile(self, shape):
def optimize(self, variables, gradients):
pass

def __deepcopy__(self, memodict=None):
pass


class SGD(Optimizer):
def __init__(self, learning_rate: float = 0.01, momentum: float = 0.0, name: str = 'SGD'):
"""
:param learning_rate: The learning rate. Defaults to 0.001.
:param momentum: float hyperparameter between [0, 1) that accelerates gradient descent in the relevant
direction and dampens oscillations. Defaults to 0, i.e., vanilla gradient descent.
:param name: Optional name for the operations created when applying gradients. Defaults to "Adam".
"""
self.name = name
self.learning_rate = learning_rate
if momentum > 1 or momentum < 0:
Expand All @@ -42,6 +51,11 @@ def __init__(self, learning_rate: float = 0.01, momentum: float = 0.0, name: str
self.first_time = True
super().__init__()

def __deepcopy__(self, memodict={}):
if memodict is None:
memodict = {}
return SGD(learning_rate=self.learning_rate, momentum=self.momentum, name=self.name)

def compile(self, shape):
for layer in shape:
for elem in layer.trainable_variables():
Expand All @@ -59,7 +73,14 @@ def optimize(self, variables, gradients):


class RMSprop(Optimizer):
def __init__(self, learning_rate=0.001, rho=0.9, momentum=0.0, epsilon=1e-07, name="Adam"):
def __init__(self, learning_rate=0.001, rho=0.9, momentum=0.0, epsilon=1e-07, name="RMSprop"):
"""
:param learning_rate: The learning rate. Defaults to 0.001.
:param rho: Discounting factor for the history/coming gradient. Defaults to 0.9.
:param momentum: The exponential decay rate for the 1st moment estimates. Defaults to 0.9.
:param epsilon: A small constant for numerical stability. Default 1e-07.
:param name: Optional name for the operations created when applying gradients. Defaults to "Adam".
"""
self.name = name
self.learning_rate = learning_rate
if rho > 1 or rho < 0:
Expand All @@ -75,6 +96,12 @@ def __init__(self, learning_rate=0.001, rho=0.9, momentum=0.0, epsilon=1e-07, na
self.sdw = []
super().__init__()

def __deepcopy__(self, memodict={}):
if memodict is None:
memodict = {}
return RMSprop(learning_rate=self.learning_rate, rho=self.rho, momentum=self.momentum, epsilon=self.epsilon,
name=self.name)

def compile(self, shape):
for layer in shape:
for elem in layer.trainable_variables():
Expand All @@ -90,7 +117,15 @@ def optimize(self, variables, gradients):


class Adam(Optimizer):
def __init__(self, learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-07, name="Adam"):
def __init__(self, learning_rate: float = 0.001, beta_1: float = 0.9, beta_2: float = 0.999,
epsilon: float = 1e-07, name="Adam"):
"""
:param learning_rate: The learning rate. Defaults to 0.001.
:param beta_1: The exponential decay rate for the 1st moment estimates. Defaults to 0.9.
:param beta_2: The exponential decay rate for the 2nd moment estimates. Defaults to 0.999.
:param epsilon: A small constant for numerical stability. Default 1e-07.
:param name: Optional name for the operations created when applying gradients. Defaults to "Adam".
"""
self.name = name
self.learning_rate = learning_rate
if beta_1 >= 1 or beta_1 < 0:
Expand All @@ -107,6 +142,12 @@ def __init__(self, learning_rate=0.001, beta_1=0.9, beta_2=0.999, epsilon=1e-07,
self.iter = 1
super().__init__()

def __deepcopy__(self, memodict={}):
if memodict is None:
memodict = {}
return Adam(learning_rate=self.learning_rate, beta_1=self.beta_1, beta_2=self.beta_2, epsilon=self.epsilon,
name=self.name)

def compile(self, shape):
for layer in shape:
for elem in layer.trainable_variables():
Expand Down
Binary file modified docs/_build/doctrees/cvnn.doctree
Binary file not shown.
Binary file removed docs/_build/doctrees/data_processing.doctree
Binary file not shown.
Binary file modified docs/_build/doctrees/environment.pickle
Binary file not shown.
Binary file modified docs/_build/doctrees/index.doctree
Binary file not shown.
Binary file removed docs/_build/doctrees/mnist_example.doctree
Binary file not shown.
Binary file modified docs/_build/doctrees/results.doctree
Binary file not shown.
Binary file removed docs/_build/html.rar
Binary file not shown.
Binary file modified docs/_build/html/.doctrees/cvnn.doctree
Binary file not shown.
Binary file modified docs/_build/html/.doctrees/environment.pickle
Binary file not shown.
Binary file modified docs/_build/html/.doctrees/index.doctree
Binary file not shown.
Binary file removed docs/_build/html/.doctrees/mnist_example.doctree
Binary file not shown.
Binary file modified docs/_build/html/.doctrees/results.doctree
Binary file not shown.
4 changes: 2 additions & 2 deletions docs/_build/html/_sources/cvnn.rst.txt
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ Results
:param y: Labels
:return: tuple (loss, accuracy)
.. py:method:: get_confusion_matrix(self, x, y, save_result=False):
.. py:method:: get_confusion_matrix(self, x, y, save_result=False)
Generates a pandas data-frame with the confusion matrix of result of x and y (labels)
Expand Down Expand Up @@ -134,7 +134,7 @@ Others
if not model.is_complex():
x = cvnn.utils.transform_to_real(x)
.. py:method:: get_real_equivalent(self, classifier=True, name=None):
.. py:method:: get_real_equivalent(self, classifier=True, name=None)
Creates a new model equivalent of current model. If model is already real throws and error.
Expand Down
63 changes: 0 additions & 63 deletions docs/_build/html/_sources/data_processing.rst.txt

This file was deleted.

2 changes: 1 addition & 1 deletion docs/_build/html/_sources/index.rst.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Complex-Valued Neural Network (CVNN)
====================================

:Author: J. Agustin Barrachina
:Version: 1.0 of 07/10/2019
:Version: 1.1 of 25/09/2020

Content
=======
Expand Down
15 changes: 0 additions & 15 deletions docs/_build/html/_sources/mnist_example.rst.txt

This file was deleted.

0 comments on commit 8e063a5

Please sign in to comment.