-
Notifications
You must be signed in to change notification settings - Fork 0
/
history.py
92 lines (78 loc) · 2.75 KB
/
history.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
import json
from matplotlib import pyplot as plt
import tensorflow.python.keras as keras
# def plot_history(history, metrics=None):
# if isinstance(metrics, str):
# metrics = [metrics]
# if metrics is None:
# metrics = [x for x in history.history.keys() if x[:4] != 'val_']
# if len(metrics) == 0:
# print('No metrics to display.')
# return
#
# x = history.epoch
#
# rows = 1
# cols = len(metrics)
# count = 0
#
# plt.figure(figsize=(12 * cols, 8))
# for metric in sorted(metrics):
# count += 1
# plt.subplot(rows, cols, count)
# plt.plot(x, history.history[metric], label='Train')
# val_metric = f'val_{metric}'
# if val_metric in history.history.keys():
# plt.plot(x, history.history[val_metric], label='Validation')
# plt.title(metric.capitalize())
# plt.legend()
# plt.show()
def plot_history(history, metrics=None, y_limits=None):
    """Plot training curves from a history object, one subplot per metric.

    Args:
        history: Object exposing ``history`` (mapping of metric name to a
            sequence of values) and ``epoch`` (the x-axis values).
        metrics: A single metric name, a collection of metric names, or
            ``None`` to use every key of ``history.history`` that does not
            start with ``'val_'``.
        y_limits: Optional mapping of metric name to the arguments that are
            unpacked into ``plt.ylim`` for that metric's subplot.

    Returns:
        None. Prints a message and returns early when there is nothing to
        plot; otherwise shows the figure.
    """
    if isinstance(metrics, str):
        metrics = [metrics]
    if metrics is None:
        metrics = [name for name in history.history if name[:4] != 'val_']
    if not len(metrics):
        print('No metrics to display.')
        return

    epochs = history.epoch
    n_cols = len(metrics)

    # All metrics share a single row; the figure widens with their number.
    plt.figure(figsize=(12 * n_cols, 8))
    for position, metric in enumerate(sorted(metrics), start=1):
        plt.subplot(1, n_cols, position)
        plt.plot(epochs, history.history[metric], label='Train')

        # Overlay the validation curve when a matching 'val_' series exists.
        val_metric = f'val_{metric}'
        if val_metric in history.history:
            plt.plot(epochs, history.history[val_metric], label='Validation')

        plt.title(metric.capitalize())
        plt.legend()
        if y_limits and metric in y_limits:
            plt.ylim(*y_limits[metric])
    plt.show()
def add_history(old_hist, new_hist):
    """Append the records of ``new_hist`` onto ``old_hist`` in place.

    The epoch list and every metric series already present in
    ``old_hist.history`` are extended with the matching data from
    ``new_hist``; ``old_hist.params`` is replaced by ``new_hist.params``.

    Args:
        old_hist: History object to extend (mutated).
        new_hist: History object supplying the additional records.

    Returns:
        The same ``old_hist`` object, after mutation.

    Raises:
        KeyError: If a metric in ``old_hist.history`` is absent from
            ``new_hist.history``.
    """
    old_hist.epoch.extend(new_hist.epoch)
    old_hist.params = new_hist.params
    # Only metrics already tracked by old_hist are merged.
    for name, series in old_hist.history.items():
        series.extend(new_hist.history[name])
    return old_hist
def save_history(history, model_name):
    """Write a history object's data to ``<model_name>.history`` as JSON.

    Args:
        history: Object exposing ``epoch``, ``history`` and ``params``.
        model_name: Path prefix; ``.history`` is appended to form the
            output file name.

    Returns:
        None.
    """
    hist_out = {
        'epoch': history.epoch,
        'history': history.history,
        'params': history.params,
    }
    with open(f'{model_name}.history', 'w') as outfile:
        json.dump(hist_out, outfile)
def load_history(model_name, model_format=''):
    """Rebuild a ``keras.callbacks.History`` from ``<model_name>.history``.

    The JSON file is expected to hold the ``epoch``, ``history`` and
    ``params`` keys. The model stored at ``f'{model_name}{model_format}'``
    is also loaded and attached to the returned history object.

    Args:
        model_name: Path prefix shared by the history file and the model.
        model_format: Suffix appended to ``model_name`` to locate the saved
            model (empty by default).

    Returns:
        The populated ``History`` object with its model set.
    """
    with open(f'{model_name}.history', 'r') as f:
        saved = json.load(f)

    restored = keras.callbacks.History()
    restored.epoch = saved['epoch']
    restored.history = saved['history']
    restored.params = saved['params']

    # Attach the saved model so the callback is tied to it again.
    restored.set_model(keras.models.load_model(f'{model_name}{model_format}'))
    return restored