-
Notifications
You must be signed in to change notification settings - Fork 1
/
plot_learning_curve.py
52 lines (45 loc) · 1.54 KB
/
plot_learning_curve.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
# -*- coding: utf-8 -*-
"""
@File: plot_learning_curve.py
@Author: Chance (Qian Zhen)
@Description: Plot learning curves from the training/test metrics encoded in saved weight file names.
@Date: 2020/12/17
"""
import os
import re
import numpy as np
import matplotlib.pyplot as plt
from utils import plot_curve
# Directory holding one saved weight file per epoch; its last two path
# components (e.g. "lr_0.01" and "resnet101") form the plot title.
weight_path = "C:/Level4Project/model/lr_0.01/resnet101"
# Which pair of curves to plot: "loss" (training/test loss) or "accuracy"
# (training/test accuracy).
plot_metric = "loss"

# Matches unsigned integers and decimals such as "12", "0.5" or "3." (no sign,
# no exponent). Compiled once because it is applied to every file name.
_NUMBER_PATTERN = re.compile(r"\d+\.?\d*")


def _epoch_of(file_name):
    """Return the epoch number encoded in a weight file name.

    Takes the first "_"-separated field and drops its first five characters
    (presumably an "epoch" prefix, i.e. names like "epoch12_..." -- confirm
    against the script that saves the weights).

    Raises:
        ValueError: if the remaining characters are not an integer.
    """
    return int(file_name.split("_")[0][5:])


def _metrics_of(file_name):
    """Return (train_loss, test_loss, train_acc, test_acc) from a file name.

    Every number in the name is extracted in order; the first one (the epoch)
    is skipped and exactly four must remain.

    Raises:
        ValueError: if the name does not contain exactly four numbers after
            the epoch.
    """
    numbers = _NUMBER_PATTERN.findall(file_name)[1:]
    if len(numbers) != 4:
        raise ValueError(
            "expected 4 metrics after the epoch in {!r}, found {}".format(
                file_name, numbers
            )
        )
    return tuple(float(number) for number in numbers)


weight_list = sorted(os.listdir(weight_path), key=_epoch_of)
if not weight_list:
    # min()/np.argmin below would otherwise fail with an unhelpful message.
    raise ValueError("no weight files found in {}".format(weight_path))

# Epoch numbers parsed from the file names, aligned index-for-index with the
# metric lists below. Reporting these instead of list positions keeps the
# printed epoch correct when numbering starts at 1 or has gaps.
epoch_list = [_epoch_of(weight) for weight in weight_list]
metric_rows = [_metrics_of(weight) for weight in weight_list]
train_loss_list, test_loss_list, train_acc_list, test_acc_list = (
    list(column) for column in zip(*metric_rows)
)

# Report the best value of each metric and the epoch it occurred in
# (np.argmin/np.argmax pick the first occurrence on ties).
for label, values, pick_best in (
    ("min train loss", train_loss_list, np.argmin),
    ("min test loss", test_loss_list, np.argmin),
    ("max train acc", train_acc_list, np.argmax),
    ("max test acc", test_acc_list, np.argmax),
):
    best = int(pick_best(values))
    print("{}: {} in epoch: {}".format(label, values[best], epoch_list[best]))

# normpath strips a trailing separator and handles the platform's separators,
# so the title is "<parent dir>_<leaf dir>" regardless of how the path is written.
normalized_path = os.path.normpath(weight_path)
plot_title = "{}_{}".format(
    os.path.basename(os.path.dirname(normalized_path)),
    os.path.basename(normalized_path),
)

if plot_metric == "loss":
    plot_curve(
        train_loss_list,
        test_loss_list,
        "Training loss",
        "Test loss",
        "Epochs",
        "Loss",
        title=plot_title,
    )
elif plot_metric == "accuracy":
    plot_curve(
        train_acc_list,
        test_acc_list,
        "Training accuracy",
        "Test accuracy",
        "Epochs",
        "Accuracy",
        title=plot_title,
    )
else:
    raise ValueError(
        "plot_metric must be 'loss' or 'accuracy', got {!r}".format(plot_metric)
    )