diff --git a/ptp/utils/statistics_aggregator.py b/ptp/utils/statistics_aggregator.py
index 6b35e49..f756472 100644
--- a/ptp/utils/statistics_aggregator.py
+++ b/ptp/utils/statistics_aggregator.py
@@ -17,7 +17,6 @@
 __author__ = "Vincent Marois, Tomasz Kornuta"
 
-import numpy as np
 from ptp.utils.statistics_collector import StatisticsCollector
 
 
@@ -128,6 +127,7 @@ def __iter__(self):
         """
         return self.aggregators.__iter__()
 
+
     def initialize_csv_file(self, log_dir, filename):
         """
         This method creates a new `csv` file and initializes it with a header produced \
@@ -142,25 +142,8 @@ def initialize_csv_file(self, log_dir, filename):
 
         :return: File stream opened for writing.
 
         """
-        header_str = ''
-
-        # Iterate through keys and concatenate them.
-        for key in self.aggregators.keys():
-            # If formatting is set to '' - ignore this key.
-            if self.formatting.get(key) is not None:
-                header_str += key + ","
-
-        # Remove last coma.
-        if len(header_str) > 0:
-            header_str = header_str[:-1]
-        # Add \n.
-        header_str = header_str + '\n'
-
-        # Open file for writing.
-        self.csv_file = open(log_dir + filename, 'w', 1)
-        self.csv_file.write(header_str)
+        return super().base_initialize_csv_file(log_dir, filename, self.aggregators.keys())
 
-        return self.csv_file
 
     def export_to_csv(self, csv_file=None):
         """
@@ -273,41 +256,3 @@ def export_to_tensorboard(self, tb_writer = None):
             # If formatting is set to None - ignore this key.
             if self.formatting.get(key) is not None:
                 tb_writer.add_scalar(key, value, episode)
-
-
-if __name__ == "__main__":
-
-    stat_col = StatisticsCollector()
-    stat_agg = StatisticsAggregator()
-
-
-    # Add default statistics with formatting.
-    stat_col.add_statistics('loss', '{:12.10f}')
-    stat_col.add_statistics('episode', '{:06d}')
-    stat_col.add_statistics('batch_size', None)
-
-    import random
-    # create some random values
-    loss_values = random.sample(range(100), 100)
-    # "Collect" basic statistics.
-    for episode, loss in enumerate(loss_values):
-        stat_col['episode'] = episode
-        stat_col['loss'] = loss
-        stat_col['batch_size'] = 1
-        # print(stat_col.export_statistics_to_string())
-
-    print(stat_agg.export_to_string())
-
-    # Add new aggregator (a simulation of "additional statistics collected by model")
-    # Add default statistical aggregators for the loss (indicating a formatting).
-    #stat_agg.add_aggregator('loss', '{:12.10f}')
-    # add 'aggregators' for the episode.
-    #stat_agg.add_aggregator('episode', '{:06d}')
-    # Number of aggregated episodes.
-    #stat_agg.add_aggregator('episodes_aggregated', '{:06d}')
-    stat_agg.add_aggregator('acc_mean', '{:2.5f}')
-    collected_loss_values = stat_col['loss']
-    batch_sizes = stat_col['batch_size']
-    stat_agg['acc_mean'] = np.mean(collected_loss_values) / np.sum(batch_sizes)
-
-    print(stat_agg.export_to_string('[Epoch 1]'))
diff --git a/ptp/utils/statistics_collector.py b/ptp/utils/statistics_collector.py
index 3e62709..5367b4e 100644
--- a/ptp/utils/statistics_collector.py
+++ b/ptp/utils/statistics_collector.py
@@ -119,10 +119,11 @@ def empty(self):
         for key in self.statistics.keys():
             del self.statistics[key][:]
 
-    def initialize_csv_file(self, log_dir, filename):
+
+    def base_initialize_csv_file(self, log_dir, filename, keys):
         """
-        Method creates new csv file and initializes it with a header produced
-        on the base of statistics names.
+        This method creates a new `csv` file and initializes it with a header produced \
+        on the basis of the provided key names.
 
         :param log_dir: Path to file.
         :type log_dir: str
@@ -130,13 +131,15 @@ def initialize_csv_file(self, log_dir, filename):
         :param filename: Filename to be created.
         :type filename: str
 
+        :param keys: Names of the keys to be used as the CSV column header.
+
         :return: File stream opened for writing.
 
         """
         header_str = ''
 
         # Iterate through keys and concatenate them.
-        for key in self.statistics.keys():
+        for key in keys:
             # If formatting is set to '' - ignore this key.
             if self.formatting.get(key) is not None:
                 header_str += key + ","
@@ -151,7 +154,25 @@ def initialize_csv_file(self, log_dir, filename):
         self.csv_file = open(log_dir + filename, 'w', 1)
         self.csv_file.write(header_str)
 
-        return self.csv_file
+        return self.csv_file
+
+
+    def initialize_csv_file(self, log_dir, filename):
+        """
+        This method creates a new `csv` file and initializes it with a header produced \
+        on the basis of the statistics names.
+
+        :param log_dir: Path to file.
+        :type log_dir: str
+
+        :param filename: Filename to be created.
+        :type filename: str
+
+        :return: File stream opened for writing.
+
+        """
+        return self.base_initialize_csv_file(log_dir, filename, self.statistics.keys())
+
 
     def export_to_csv(self, csv_file=None):
         """
@@ -176,7 +197,9 @@ def export_to_csv(self, csv_file=None):
             format_str = self.formatting.get(key, '{}')
 
             # Add value to string using formatting.
-            values_str += format_str.format(value[-1]) + ","
+            if len(value) > 0:
+                values_str += format_str.format(value[-1])
+            values_str += ","
 
         # Remove last coma.
         if len(values_str) > 1:
@@ -201,7 +224,8 @@ def export_to_checkpoint(self):
             format_str = self.formatting.get(key, '{}')
 
             # Add to dict.
-            chkpt[key] = format_str.format(value[-1])
+            if len(value) > 0:
+                chkpt[key] = format_str.format(value[-1])
 
         return chkpt
 
@@ -226,7 +250,9 @@ def export_to_string(self, additional_tag=''):
             # Get formatting - using '{}' as default.
             format_str = self.formatting.get(key, '{}')
             # Add value to string using formatting.
-            stat_str += format_str.format(value[-1]) + "; "
+            if len(value) > 0:
+                stat_str += format_str.format(value[-1])
+            stat_str += "; "
 
         # Remove last two elements.
         if len(stat_str) > 2:
@@ -268,36 +294,3 @@ def export_to_tensorboard(self, tb_writer=None):
             # If formatting is set to None - ignore this key.
             if self.formatting.get(key) is not None:
                 tb_writer.add_scalar(key, value[-1], episode)
-
-
-if __name__ == "__main__":
-
-    stat_col = StatisticsCollector()
-    stat_col.add_statistics('loss', '{:12.10f}')
-    stat_col.add_statistics('episode', '{:06d}')
-    stat_col.add_statistics('acc', '{:2.3f}')
-    stat_col.add_statistics('acc_help', None)
-
-    stat_col['episode'] = 0
-    stat_col['loss'] = 0.7
-    stat_col['acc'] = 100
-    stat_col['acc_help'] = 121
-
-    csv_file = stat_col.initialize_csv_file('./', 'collector_test.csv')
-    stat_col.export_to_csv(csv_file)
-    print(stat_col.export_to_string())
-
-    stat_col['episode'] = 1
-    stat_col['loss'] = 0.7
-    stat_col['acc'] = 99.3
-
-    stat_col.add_statistics('seq_length', '{:2.0f}')
-    stat_col['seq_length'] = 5
-
-    stat_col.export_to_csv(csv_file)
-    print(stat_col.export_to_string('[Validation]'))
-
-    stat_col.empty()
-
-    for k in stat_col:
-        print('key: {} - value {}:'.format(k, stat_col[k]))
diff --git a/tests/pipeline_tests.py b/tests/pipeline_tests.py
index 2748894..42c89f1 100644
--- a/tests/pipeline_tests.py
+++ b/tests/pipeline_tests.py
@@ -119,5 +119,5 @@ def test_priorities(self):
         self.assertEqual(pipe[1].name, 'bow_encoder2')
 
 
-if __name__ == "__main__":
-    unittest.main()
\ No newline at end of file
+#if __name__ == "__main__":
+#    unittest.main()
\ No newline at end of file
diff --git a/tests/samplers_tests.py b/tests/samplers_tests.py
index 53ca272..323155a 100644
--- a/tests/samplers_tests.py
+++ b/tests/samplers_tests.py
@@ -166,5 +166,5 @@ def test_kfold_weighed_random_sampler_current_fold(self):
         self.assertIn(ix, [4,7])
 
 
-if __name__ == "__main__":
-    unittest.main()
\ No newline at end of file
+#if __name__ == "__main__":
+#    unittest.main()
\ No newline at end of file
diff --git a/tests/statistics_tests.py b/tests/statistics_tests.py
new file mode 100644
index 0000000..c968294
--- /dev/null
+++ b/tests/statistics_tests.py
@@ -0,0 +1,101 @@
+# -*- coding: utf-8 -*-
+#
+# Copyright (C) tkornuta, IBM Corporation 2019
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+__author__ = "Tomasz Kornuta"
+
+import unittest
+import random
+import numpy as np
+
+from ptp.utils.statistics_collector import StatisticsCollector
+from ptp.utils.statistics_aggregator import StatisticsAggregator
+
+class TestStatistics(unittest.TestCase):
+
+    def __init__(self, *args, **kwargs):
+        super(TestStatistics, self).__init__(*args, **kwargs)
+
+    def test_collector_string(self):
+        """ Tests whether the collector is collecting and producing the right string. """
+
+        stat_col = StatisticsCollector()
+        stat_col.add_statistics('loss', '{:12.10f}')
+        stat_col.add_statistics('episode', '{:06d}')
+        stat_col.add_statistics('acc', '{:2.3f}')
+        stat_col.add_statistics('acc_help', None)
+
+        # Episode 0.
+        stat_col['episode'] = 0
+        stat_col['loss'] = 0.7
+        stat_col['acc'] = 100
+        stat_col['acc_help'] = 121
+
+        # Export.
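+        # (CSV export is left commented out here, presumably so that the unit
+        # test does not write files to disk; the exported string is asserted instead.)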
+        #csv_file = stat_col.initialize_csv_file('./', 'collector_test.csv')
+        #stat_col.export_to_csv(csv_file)
+        self.assertEqual(stat_col.export_to_string(), "loss 0.7000000000; episode 000000; acc 100.000 ")
+
+        # Episode 1.
+        stat_col['episode'] = 1
+        stat_col['loss'] = 0.7
+        stat_col['acc'] = 99.3
+
+        stat_col.add_statistics('seq_length', '{:2.0f}')
+        stat_col['seq_length'] = 5
+
+        # Export.
+        #stat_col.export_to_csv(csv_file)
+        self.assertEqual(stat_col.export_to_string('[Validation]'), "loss 0.7000000000; episode 000001; acc 99.300; seq_length  5 [Validation]")
+
+        # Empty.
+        stat_col.empty()
+        self.assertEqual(stat_col.export_to_string(), "loss ; episode ; acc ; seq_length ")
+
+    def test_aggregator_string(self):
+        """ Tests whether the aggregator is aggregating and producing the right string. """
+
+        stat_col = StatisticsCollector()
+        stat_agg = StatisticsAggregator()
+
+        # Add default statistics with formatting.
+        stat_col.add_statistics('loss', '{:12.10f}')
+        stat_col.add_statistics('episode', '{:06d}')
+        stat_col.add_statistics('batch_size', None)
+
+        # Create a random permutation of the values 0..99.
+        loss_values = random.sample(range(100), 100)
+        # "Collect" basic statistics.
+        for episode, loss in enumerate(loss_values):
+            stat_col['episode'] = episode
+            stat_col['loss'] = loss
+            stat_col['batch_size'] = 1
+            # print(stat_col.export_to_string())
+
+        # Empty before aggregation.
+        self.assertEqual(stat_agg.export_to_string(), " ")
+
+        # Add an aggregator for the mean accuracy (with formatting).
+        stat_agg.add_aggregator('acc_mean', '{:2.5f}')
+        collected_loss_values = stat_col['loss']
+        batch_sizes = stat_col['batch_size']
+        # The mean of the values 0..99 is 49.5 and the unit batch sizes sum to 100,
+        # so the aggregated value is deterministic: 49.5 / 100 = 0.495.
+        stat_agg['acc_mean'] = np.mean(collected_loss_values) / np.sum(batch_sizes)
+
+        # Aggregated result.
+        self.assertEqual(stat_agg.export_to_string('[Epoch 1]'), "acc_mean 0.49500 [Epoch 1]")
+
+
+#if __name__ == "__main__":
+#    unittest.main()
\ No newline at end of file
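
For reference, a minimal sketch of how the refactored CSV initialization is intended to be used after this change. It assumes, as the `super()` call above implies, that `StatisticsAggregator` subclasses `StatisticsCollector`; the log directory and file names below are illustrative only.

    from ptp.utils.statistics_collector import StatisticsCollector
    from ptp.utils.statistics_aggregator import StatisticsAggregator

    # Both classes now delegate to the shared base_initialize_csv_file(),
    # differing only in the keys used to build the CSV header
    # (statistics names vs. aggregator names).
    stat_col = StatisticsCollector()
    stat_col.add_statistics('loss', '{:12.10f}')
    col_csv = stat_col.initialize_csv_file('./', 'training_statistics.csv')   # header: "loss"

    stat_agg = StatisticsAggregator()
    stat_agg.add_aggregator('loss_mean', '{:12.10f}')
    agg_csv = stat_agg.initialize_csv_file('./', 'training_aggregations.csv') # header: "loss_mean"

In both cases the returned value is the file stream opened for writing, which can then be passed to the respective export_to_csv() methods.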