parse_log.py modified for multiple test-nets #5697

Open
wants to merge 1 commit into
from
Jump to file or symbol
Failed to load files and symbols.
+10 −4
Split
View
@@ -23,12 +23,14 @@ def parse_log(path_to_log):
"""
regex_iteration = re.compile('Iteration (\d+)')
+ regex_netnum = re.compile('net \(#(\d+)\)')
regex_train_output = re.compile('Train net output #(\d+): (\S+) = ([\.\deE+-]+)')
regex_test_output = re.compile('Test net output #(\d+): (\S+) = ([\.\deE+-]+)')
regex_learning_rate = re.compile('lr = ([-+]?[0-9]*\.?[0-9]+([eE]?[-+]?[0-9]+)?)')
# Pick out lines of interest
iteration = -1
+ netnum = 0
learning_rate = float('NaN')
train_dict_list = []
test_dict_list = []
@@ -44,6 +46,9 @@ def parse_log(path_to_log):
iteration_match = regex_iteration.search(line)
if iteration_match:
iteration = float(iteration_match.group(1))
+ netnum_match = regex_netnum.search(line)
+ if netnum_match:
+ netnum = int(netnum_match.group(1))
if iteration == -1:
# Only start parsing for other stuff if we've found the first
# iteration
@@ -70,11 +75,11 @@ def parse_log(path_to_log):
train_dict_list, train_row = parse_line_for_net_output(
regex_train_output, train_row, train_dict_list,
- line, iteration, seconds, learning_rate
+ line, iteration, netnum, seconds, learning_rate
)
test_dict_list, test_row = parse_line_for_net_output(
regex_test_output, test_row, test_dict_list,
- line, iteration, seconds, learning_rate
+ line, iteration, netnum, seconds, learning_rate
)
fix_initial_nan_learning_rate(train_dict_list)
@@ -84,7 +89,7 @@ def parse_log(path_to_log):
def parse_line_for_net_output(regex_obj, row, row_dict_list,
- line, iteration, seconds, learning_rate):
+ line, iteration, netnum, seconds, learning_rate):
"""Parse a single line for training or test output
Returns a a tuple with (row_dict_list, row)
@@ -95,7 +100,7 @@ def parse_line_for_net_output(regex_obj, row, row_dict_list,
output_match = regex_obj.search(line)
if output_match:
- if not row or row['NumIters'] != iteration:
+ if not row or row['NumIters'] != iteration or row['NetNum'] != netnum :
# Push the last row and start a new one
if row:
# If we're on a new iteration, push the last row
@@ -106,6 +111,7 @@ def parse_line_for_net_output(regex_obj, row, row_dict_list,
row = OrderedDict([
('NumIters', iteration),
+ ('NetNum', netnum),
('Seconds', seconds),
('LearningRate', learning_rate)
])