Add confusion matrix plotting feature #20
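If there are very many classes (100+), PyG2Plot becomes sluggish; seaborn.heatmap is worth trying instead. Example code:
import numpy as np
import pandas as pd
import seaborn as sns

a = np.arange(4).reshape((2, 2))
y_label = pd.Index(['a', 'b'], name='y')
x_label = pd.Index(['a', 'b'], name='x')
df = pd.DataFrame(a, index=y_label, columns=x_label)
# The output directory must already exist
sns.heatmap(df).get_figure().savefig('./data/output/submit/test.jpg')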
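The toy example above plots a hardcoded 2x2 array. For single-label predictions given as class indices, the counts can be built with NumPy first. This is a minimal sketch with made-up indices; `confusion_counts` is a hypothetical helper, not part of the project.

```python
import numpy as np
import pandas as pd


def confusion_counts(true_idx, pred_idx, n_classes):
    """Count (true, predicted) pairs for single-label class indices (hypothetical helper)."""
    counts = np.zeros((n_classes, n_classes), dtype=np.int64)
    # np.add.at is unbuffered, so repeated (true, pred) pairs accumulate correctly
    np.add.at(counts, (true_idx, pred_idx), 1)
    return counts


labels = ['a', 'b']
true_idx = np.array([0, 0, 1, 1, 1])  # made-up ground truth
pred_idx = np.array([0, 1, 1, 1, 0])  # made-up predictions
counts = confusion_counts(true_idx, pred_idx, len(labels))
df = pd.DataFrame(counts,
                  index=pd.Index(labels, name='y'),
                  columns=pd.Index(labels, name='x'))
print(counts.tolist())  # [[1, 1], [1, 2]]
```

The resulting `df` has the same layout as the toy example (rows `y`, columns `x`) and can be passed directly to `sns.heatmap(df)`.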
Implementation details with seaborn:
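import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns

# Config is the project's own settings object (labels, batch_size, output paths),
# assumed to be defined elsewhere in the project.

# Very large canvas so that 100+ labels remain legible
sns.set(rc={'figure.figsize': (40, 40)})
sns.set_style('whitegrid', {'font.sans-serif': ['SimHei', 'Arial']})
# SimHei so that Chinese label names render instead of empty boxes
plt.rcParams['font.sans-serif'] = ['SimHei']
# SimHei lacks the Unicode minus glyph, so fall back to the ASCII hyphen
plt.rcParams['axes.unicode_minus'] = False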
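
def get_true_pred_data(model, ds_test):
    """Run the model over the test set and return (y_true, y_pred) as multi-hot arrays."""
    y_true, y_pred = [], []
    for item in ds_test:
        # item[0] is the input batch, item[1] the multi-hot label batch
        y_pred.append(model.predict(item[0], batch_size=Config.batch_size))
        y_true.append(item[1])
    # Multi-label: a label counts as predicted only when its score is strictly above 0.5
    y_pred = (np.vstack(y_pred) > 0.5).astype(np.int64)
    y_true = np.vstack(y_true)
    # Cache both arrays so the matrix can be rebuilt without re-running inference
    np.save(Config.y_true_path, y_true)
    np.save(Config.y_pred_path, y_pred)
    return y_true, y_pred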
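
def get_matrix(y_true, y_pred):
    """Row-normalized confusion matrix for multi-hot labels.

    Each sample spreads a total weight of 1 evenly over every
    (true label, predicted label) pair it contains.
    """
    matrix = np.zeros((len(Config.labels), len(Config.labels)))
    for _y_true, _y_pred in zip(y_true, y_pred):
        _y_true = np.argwhere(_y_true == 1).flatten()  # indices of the true labels
        _y_pred = np.argwhere(_y_pred == 1).flatten()  # indices of the predicted labels
        num = len(_y_true) * len(_y_pred)
        # If either index set is empty the loops never run, so 1/num is never
        # evaluated with num == 0; such samples simply drop out of the matrix.
        for _y_true_ in _y_true:
            for _y_pred_ in _y_pred:
                matrix[_y_true_, _y_pred_] += (1 / num)
    # Normalize each row (true label) to sum to ~1; the epsilon guards empty rows
    matrix /= (np.sum(matrix, axis=1, keepdims=True) + 1.0e-5)
    return matrix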

def get_plot_confusion_data(matrix):
    # Rows are true labels, columns are predicted labels
    y_label = pd.Index(Config.labels, name='y_true')
    x_label = pd.Index(Config.labels, name='y_pred')
    df = pd.DataFrame(matrix, index=y_label, columns=x_label)
    return df

def plot_confusion_matrix(plot_data):
    # vmax=1 because rows are normalized; fmt='.0%' prints each cell as a whole percentage
    heat_map = sns.heatmap(plot_data, vmax=1, cmap='GnBu', annot=True, fmt='.0%', annot_kws={'size': 6})
    heat_map.get_figure().savefig(Config.confusion_matrix_path, dpi=400)
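The non-plotting part of the pipeline can be dry-run without a trained model or seaborn. The sketch below uses made-up labels and scores as hypothetical stand-ins for `Config.labels` and `model.predict` output, applies the same 0.5 threshold, and checks the loop in `get_matrix` against a vectorized equivalent. The vectorized form is an alternative formulation, not part of the original code. It then wraps the result in the same labeled DataFrame and shows how `fmt='.0%'` renders one row. seaborn builds each annotation roughly as `("{:" + fmt + "}").format(value)`, so plain `format` reproduces it.

```python
import numpy as np
import pandas as pd

# Hypothetical stand-ins for Config.labels, the stacked y_true, and the raw
# sigmoid scores model.predict would return (4 samples x 3 labels).
labels = ['a', 'b', 'c']
y_true = np.array([[1, 1, 0],
                   [0, 1, 0],
                   [0, 0, 1],
                   [1, 0, 0]])
scores = np.array([[0.3, 0.8, 0.1],
                   [0.2, 0.7, 0.6],
                   [0.1, 0.3, 0.5],   # 0.5 is not > 0.5, so nothing is predicted
                   [0.6, 0.2, 0.4]])

# Same thresholding as get_true_pred_data
y_pred = (scores > 0.5).astype(np.int64)


def matrix_loop(y_true, y_pred, n_labels):
    """Loop version mirroring get_matrix, with the label count passed in."""
    matrix = np.zeros((n_labels, n_labels))
    for t, p in zip(y_true, y_pred):
        t_idx = np.argwhere(t == 1).flatten()
        p_idx = np.argwhere(p == 1).flatten()
        num = len(t_idx) * len(p_idx)
        for i in t_idx:
            for j in p_idx:
                matrix[i, j] += 1 / num
    return matrix / (matrix.sum(axis=1, keepdims=True) + 1.0e-5)


def matrix_vectorized(y_true, y_pred):
    """Each sample adds outer(t, p) / (|t| * |p|); rows are then normalized."""
    t = y_true.astype(float)
    p = y_pred.astype(float)
    num = t.sum(axis=1) * p.sum(axis=1)  # |t| * |p| for every sample
    # Samples with no true or no predicted label get weight 0, as in the loop
    w = np.divide(1.0, num, out=np.zeros_like(num), where=num > 0)
    matrix = (t * w[:, None]).T @ p      # sum of weighted outer products
    return matrix / (matrix.sum(axis=1, keepdims=True) + 1.0e-5)


loop_result = matrix_loop(y_true, y_pred, len(labels))
vec_result = matrix_vectorized(y_true, y_pred)
assert np.allclose(loop_result, vec_result)

# Same labeling as get_plot_confusion_data: rows = true, columns = predicted
df = pd.DataFrame(vec_result,
                  index=pd.Index(labels, name='y_true'),
                  columns=pd.Index(labels, name='y_pred'))

print([format(v, '.0%') for v in df.loc['a']])  # ['67%', '33%', '0%']
print(df.loc['c'].tolist())  # [0.0, 0.0, 0.0]: the only 'c' sample predicted nothing
```

The vectorized version replaces the per-sample Python loop with a single matrix product, which may help once the test set is large. The all-zero row for `c` also shows that a sample with no score above 0.5 contributes nothing to the matrix.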
PyG2Plot is the recommended library for plotting.