This repository was archived by the owner on Jul 18, 2024. It is now read-only.
Merged
59 changes: 2 additions & 57 deletions ptp/utils/statistics_aggregator.py
@@ -17,7 +17,6 @@

 __author__ = "Vincent Marois, Tomasz Kornuta"

-import numpy as np
 from ptp.utils.statistics_collector import StatisticsCollector


@@ -128,6 +127,7 @@ def __iter__(self):
         """
         return self.aggregators.__iter__()

+
     def initialize_csv_file(self, log_dir, filename):
         """
         This method creates a new `csv` file and initializes it with a header produced \
@@ -142,25 +142,8 @@ def initialize_csv_file(self, log_dir, filename):
         :return: File stream opened for writing.

         """
-        header_str = ''
-
-        # Iterate through keys and concatenate them.
-        for key in self.aggregators.keys():
-            # If formatting is set to '' - ignore this key.
-            if self.formatting.get(key) is not None:
-                header_str += key + ","
-
-        # Remove last coma.
-        if len(header_str) > 0:
-            header_str = header_str[:-1]
-        # Add \n.
-        header_str = header_str + '\n'
-
-        # Open file for writing.
-        self.csv_file = open(log_dir + filename, 'w', 1)
-        self.csv_file.write(header_str)
+        return super().base_initialize_csv_file(log_dir, filename, self.aggregators.keys())

-        return self.csv_file

     def export_to_csv(self, csv_file=None):
         """
@@ -273,41 +256,3 @@ def export_to_tensorboard(self, tb_writer = None):
             # If formatting is set to None - ignore this key.
             if self.formatting.get(key) is not None:
                 tb_writer.add_scalar(key, value, episode)
-
-
-if __name__ == "__main__":
-
-    stat_col = StatisticsCollector()
-    stat_agg = StatisticsAggregator()
-
-
-    # Add default statistics with formatting.
-    stat_col.add_statistics('loss', '{:12.10f}')
-    stat_col.add_statistics('episode', '{:06d}')
-    stat_col.add_statistics('batch_size', None)
-
-    import random
-    # create some random values
-    loss_values = random.sample(range(100), 100)
-    # "Collect" basic statistics.
-    for episode, loss in enumerate(loss_values):
-        stat_col['episode'] = episode
-        stat_col['loss'] = loss
-        stat_col['batch_size'] = 1
-        # print(stat_col.export_statistics_to_string())
-
-    print(stat_agg.export_to_string())
-
-    # Add new aggregator (a simulation of "additional statistics collected by model")
-    # Add default statistical aggregators for the loss (indicating a formatting).
-    #stat_agg.add_aggregator('loss', '{:12.10f}')
-    # add 'aggregators' for the episode.
-    #stat_agg.add_aggregator('episode', '{:06d}')
-    # Number of aggregated episodes.
-    #stat_agg.add_aggregator('episodes_aggregated', '{:06d}')
-    stat_agg.add_aggregator('acc_mean', '{:2.5f}')
-    collected_loss_values = stat_col['loss']
-    batch_sizes = stat_col['batch_size']
-    stat_agg['acc_mean'] = np.mean(collected_loss_values) / np.sum(batch_sizes)
-
-    print(stat_agg.export_to_string('[Epoch 1]'))
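
Note on the change above: StatisticsAggregator.initialize_csv_file now delegates header construction to the shared base_initialize_csv_file introduced in statistics_collector.py (next file). A minimal usage sketch of the refactored call path, assuming an aggregator with a single registered key (the key name and file name here are illustrative, not part of the diff):

# Illustrative sketch only - not part of this PR.
from ptp.utils.statistics_aggregator import StatisticsAggregator

stat_agg = StatisticsAggregator()
stat_agg.add_aggregator('loss_mean', '{:12.10f}')

# Writes a "loss_mean" header via the inherited base_initialize_csv_file()
# and returns the file stream opened for writing.
csv_file = stat_agg.initialize_csv_file('./', 'aggregator_test.csv')
stat_agg['loss_mean'] = 0.5
stat_agg.export_to_csv(csv_file)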
75 changes: 34 additions & 41 deletions ptp/utils/statistics_collector.py
@@ -119,24 +119,27 @@ def empty(self):
         for key in self.statistics.keys():
             del self.statistics[key][:]

-    def initialize_csv_file(self, log_dir, filename):
+
+    def base_initialize_csv_file(self, log_dir, filename, keys):
         """
-        Method creates new csv file and initializes it with a header produced
-        on the base of statistics names.
+        This method creates a new `csv` file and initializes it with a header produced \
+        on the base of the statistical aggregators names.

         :param log_dir: Path to file.
         :type log_dir: str

         :param filename: Filename to be created.
         :type filename: str

+        :param keys: Names of keys that will be used as header of columns in csv file.
+
         :return: File stream opened for writing.

         """
         header_str = ''

         # Iterate through keys and concatenate them.
-        for key in self.statistics.keys():
+        for key in keys:
             # If formatting is set to '' - ignore this key.
             if self.formatting.get(key) is not None:
                 header_str += key + ","
@@ -151,7 +154,25 @@ def initialize_csv_file(self, log_dir, filename):
         self.csv_file = open(log_dir + filename, 'w', 1)
         self.csv_file.write(header_str)

-        return self.csv_file
+        return self.csv_file
+
+
+    def initialize_csv_file(self, log_dir, filename):
+        """
+        This method creates a new `csv` file and initializes it with a header produced \
+        on the base of the statistical aggregators names.
+
+        :param log_dir: Path to file.
+        :type log_dir: str
+
+        :param filename: Filename to be created.
+        :type filename: str
+
+        :return: File stream opened for writing.
+
+        """
+        return self.base_initialize_csv_file(log_dir, filename, self.statistics.keys())
+

     def export_to_csv(self, csv_file=None):
         """
@@ -176,7 +197,9 @@ def export_to_csv(self, csv_file=None):
             format_str = self.formatting.get(key, '{}')

             # Add value to string using formatting.
-            values_str += format_str.format(value[-1]) + ","
+            if len(value) > 0:
+                values_str += format_str.format(value[-1])
+            values_str += ","

         # Remove last coma.
         if len(values_str) > 1:
@@ -201,7 +224,8 @@ def export_to_checkpoint(self):
             format_str = self.formatting.get(key, '{}')

             # Add to dict.
-            chkpt[key] = format_str.format(value[-1])
+            if len(value) > 0:
+                chkpt[key] = format_str.format(value[-1])

         return chkpt

@@ -226,7 +250,9 @@ def export_to_string(self, additional_tag=''):
             # Get formatting - using '{}' as default.
             format_str = self.formatting.get(key, '{}')
             # Add value to string using formatting.
-            stat_str += format_str.format(value[-1]) + "; "
+            if len(value) > 0:
+                stat_str += format_str.format(value[-1])
+            stat_str += "; "

         # Remove last two elements.
         if len(stat_str) > 2:
@@ -268,36 +294,3 @@ def export_to_tensorboard(self, tb_writer=None):
             # If formatting is set to None - ignore this key.
             if self.formatting.get(key) is not None:
                 tb_writer.add_scalar(key, value[-1], episode)
-
-
-if __name__ == "__main__":
-
-    stat_col = StatisticsCollector()
-    stat_col.add_statistics('loss', '{:12.10f}')
-    stat_col.add_statistics('episode', '{:06d}')
-    stat_col.add_statistics('acc', '{:2.3f}')
-    stat_col.add_statistics('acc_help', None)
-
-    stat_col['episode'] = 0
-    stat_col['loss'] = 0.7
-    stat_col['acc'] = 100
-    stat_col['acc_help'] = 121
-
-    csv_file = stat_col.initialize_csv_file('./', 'collector_test.csv')
-    stat_col.export_to_csv(csv_file)
-    print(stat_col.export_to_string())
-
-    stat_col['episode'] = 1
-    stat_col['loss'] = 0.7
-    stat_col['acc'] = 99.3
-
-    stat_col.add_statistics('seq_length', '{:2.0f}')
-    stat_col['seq_length'] = 5
-
-    stat_col.export_to_csv(csv_file)
-    print(stat_col.export_to_string('[Validation]'))
-
-    stat_col.empty()
-
-    for k in stat_col:
-        print('key: {} - value {}:'.format(k, stat_col[k]))
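
Note on the guards above: empty() deletes the contents of every statistic's value list, after which value[-1] in export_to_csv, export_to_checkpoint and export_to_string would raise an IndexError. The new len(value) > 0 checks let an emptied collector still export its keys with blank values, as exercised by test_collector_string in tests/statistics_tests.py below. A short sketch of the resulting behavior (illustrative only):

# Illustrative sketch only - not part of this PR.
from ptp.utils.statistics_collector import StatisticsCollector

stat_col = StatisticsCollector()
stat_col.add_statistics('loss', '{:12.10f}')
stat_col['loss'] = 0.7
print(stat_col.export_to_string())  # key followed by its formatted value

stat_col.empty()  # clears the collected value lists
# Before this change, value[-1] raised IndexError here; now the key is
# exported with an empty value instead.
print(stat_col.export_to_string())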
4 changes: 2 additions & 2 deletions tests/pipeline_tests.py
@@ -119,5 +119,5 @@ def test_priorities(self):
         self.assertEqual(pipe[1].name, 'bow_encoder2')


-if __name__ == "__main__":
-    unittest.main()
+#if __name__ == "__main__":
+#    unittest.main()
4 changes: 2 additions & 2 deletions tests/samplers_tests.py
@@ -166,5 +166,5 @@ def test_kfold_weighed_random_sampler_current_fold(self):
         self.assertIn(ix, [4,7])


-if __name__ == "__main__":
-    unittest.main()
+#if __name__ == "__main__":
+#    unittest.main()
101 changes: 101 additions & 0 deletions tests/statistics_tests.py
@@ -0,0 +1,101 @@
# -*- coding: utf-8 -*-
#
# Copyright (C) tkornuta, IBM Corporation 2019
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

__author__ = "Tomasz Kornuta"

import unittest
import random
import numpy as np

from ptp.utils.statistics_collector import StatisticsCollector
from ptp.utils.statistics_aggregator import StatisticsAggregator

class TestStatistics(unittest.TestCase):

    def __init__(self, *args, **kwargs):
        super(TestStatistics, self).__init__(*args, **kwargs)

    def test_collector_string(self):
        """ Tests whether the collector is collecting and producing the right string. """

        stat_col = StatisticsCollector()
        stat_col.add_statistics('loss', '{:12.10f}')
        stat_col.add_statistics('episode', '{:06d}')
        stat_col.add_statistics('acc', '{:2.3f}')
        stat_col.add_statistics('acc_help', None)

        # Episode 0.
        stat_col['episode'] = 0
        stat_col['loss'] = 0.7
        stat_col['acc'] = 100
        stat_col['acc_help'] = 121

        # Export.
        #csv_file = stat_col.initialize_csv_file('./', 'collector_test.csv')
        #stat_col.export_to_csv(csv_file)
        self.assertEqual(stat_col.export_to_string(), "loss 0.7000000000; episode 000000; acc 100.000 ")

        # Episode 1.
        stat_col['episode'] = 1
        stat_col['loss'] = 0.7
        stat_col['acc'] = 99.3

        stat_col.add_statistics('seq_length', '{:2.0f}')
        stat_col['seq_length'] = 5

        # Export.
        #stat_col.export_to_csv(csv_file)
        self.assertEqual(stat_col.export_to_string('[Validation]'), "loss 0.7000000000; episode 000001; acc 99.300; seq_length 5 [Validation]")

        # Empty.
        stat_col.empty()
        self.assertEqual(stat_col.export_to_string(), "loss ; episode ; acc ; seq_length ")

    def test_aggregator_string(self):
        """ Tests whether the collector is aggregating and producing the right string. """

        stat_col = StatisticsCollector()
        stat_agg = StatisticsAggregator()

        # Add default statistics with formatting.
        stat_col.add_statistics('loss', '{:12.10f}')
        stat_col.add_statistics('episode', '{:06d}')
        stat_col.add_statistics('batch_size', None)

        # create some random values
        loss_values = random.sample(range(100), 100)
        # "Collect" basic statistics.
        for episode, loss in enumerate(loss_values):
            stat_col['episode'] = episode
            stat_col['loss'] = loss
            stat_col['batch_size'] = 1
            # print(stat_col.export_statistics_to_string())

        # Empty before aggregation.
        self.assertEqual(stat_agg.export_to_string(), " ")

        # Number of aggregated episodes.
        stat_agg.add_aggregator('acc_mean', '{:2.5f}')
        collected_loss_values = stat_col['loss']
        batch_sizes = stat_col['batch_size']
        stat_agg['acc_mean'] = np.mean(collected_loss_values) / np.sum(batch_sizes)

        # Aggregated result.
        self.assertEqual(stat_agg.export_to_string('[Epoch 1]'), "acc_mean 0.49500 [Epoch 1]")


#if __name__ == "__main__":
# unittest.main()
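
With the if __name__ == "__main__" guards commented out in all three test modules, the suites are presumably meant to be launched through a test runner rather than executed as scripts. A minimal sketch of one way to run the new tests programmatically, assuming the repository root is the working directory:

# Illustrative sketch only - not part of this PR.
import unittest

suite = unittest.defaultTestLoader.loadTestsFromName("tests.statistics_tests")
unittest.TextTestRunner(verbosity=2).run(suite)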