Add confusion matrix plotting #20

Open
Jasonsey opened this issue Apr 21, 2023 · 2 comments

@Jasonsey
Owner

Jasonsey commented Apr 21, 2023

I suggest using PyG2Plot for plotting.

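  1. Single-label confusion matrix plotting
  2. Multi-label confusion matrix plotting, with labels matched (true -> predicted) according to these rules: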
    • AB -> AB:
      • A->A ✅
      • B->B ✅
    • AB -> AC:
      • A->A ✅
      • B->C ❌
    • AB -> CD:
      • $\underbrace{A \to C}_{0.5}$
      • $\underbrace{A \to D}_{0.5}$
      • $\underbrace{B \to C}_{0.5}$
      • $\underbrace{B \to D}_{0.5}$
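
The multi-label rules above could be turned into per-sample weights roughly as follows. This is a sketch under an assumed reading of the rules, which the issue does not fully spell out: a label present in both sets is a diagonal hit with weight 1, each remaining true label spreads a total weight of 1 evenly over the remaining predicted labels (which gives 1 for AB -> AC and 0.5 for AB -> CD), and a leftover on only one side is not recorded. The helper names `multilabel_pair_weights` and `multilabel_confusion` are hypothetical.

```python
import numpy as np


def multilabel_pair_weights(true_labels, pred_labels):
    """Return {(true, pred): weight} for one sample.

    Assumed reading of the rules (not spelled out in the issue):
    - a label present in both sets is a diagonal hit with weight 1;
    - each remaining true label spreads a total weight of 1 evenly over
      the remaining predicted labels;
    - a leftover on only one side (pure miss or pure extra) is not recorded.
    """
    true_set, pred_set = set(true_labels), set(pred_labels)
    weights = {(label, label): 1.0 for label in true_set & pred_set}
    missed = sorted(true_set - pred_set)  # true labels not predicted
    extra = sorted(pred_set - true_set)   # predicted labels that are wrong
    for t in missed:
        for p in extra:
            weights[(t, p)] = 1.0 / len(extra)
    return weights


def multilabel_confusion(samples, labels):
    """Accumulate raw (un-normalized) weights; rows = true, columns = predicted."""
    index = {label: i for i, label in enumerate(labels)}
    matrix = np.zeros((len(labels), len(labels)))
    for true_labels, pred_labels in samples:
        for (t, p), w in multilabel_pair_weights(true_labels, pred_labels).items():
            matrix[index[t], index[p]] += w
    return matrix


print(sorted(multilabel_pair_weights(["A", "B"], ["C", "D"]).items()))
# → [(('A', 'C'), 0.5), (('A', 'D'), 0.5), (('B', 'C'), 0.5), (('B', 'D'), 0.5)]
```

Dividing each row by its sum then gives per-true-label rates. Note that `get_matrix` in the seaborn comment below weights every (true, pred) pair of a sample by 1/(|true|·|pred|), so for AB -> AB it also credits A->B and B->A with 0.25 each rather than pairing A->A and B->B only.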
@Jasonsey
Owner Author

Jasonsey commented Apr 24, 2023

If there are very many classes (100+), PyG2Plot becomes sluggish; seaborn.heatmap is an alternative worth trying. Example code:

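```python
import numpy as np
import pandas as pd
import seaborn as sns

a = np.arange(4).reshape((2, 2))
y_label = pd.Index(['a', 'b'], name='y')  # row labels, drawn on the y axis
x_label = pd.Index(['a', 'b'], name='x')  # column labels, drawn on the x axis
df = pd.DataFrame(a, index=y_label, columns=x_label)

# The output directory must already exist, otherwise savefig raises an error.
sns.heatmap(df).get_figure().savefig('./data/output/submit/test.jpg')
```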
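
For item 1 of the first comment (single-label), no pairing rules are needed: each sample adds one count at (true, predicted). A minimal numpy sketch that builds a row-normalized matrix, ready to be wrapped in a DataFrame as in the example above; it assumes labels are already integer-encoded as 0..n_classes-1, and `single_label_confusion` is a hypothetical helper name. `np.add.at` is used because plain fancy-index `+=` would count a repeated (true, pred) pair only once.

```python
import numpy as np


def single_label_confusion(y_true, y_pred, n_classes):
    """Row-normalized confusion matrix: rows = true class, columns = predicted."""
    matrix = np.zeros((n_classes, n_classes))
    # np.add.at accumulates correctly even when (row, col) pairs repeat
    np.add.at(matrix, (np.asarray(y_true), np.asarray(y_pred)), 1)
    row_sums = matrix.sum(axis=1, keepdims=True)
    # divide each row by its total; rows with no samples stay all zero
    return np.divide(matrix, row_sums, out=np.zeros_like(matrix), where=row_sums > 0)
```

Passing the result to `pd.DataFrame` with named `pd.Index` objects, as above, gives the heatmap its axis titles.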

@Jasonsey
Owner Author

Implementation notes for seaborn:

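  1. The required font must be installed. If it is not present on the system, we need a way to prompt the user to download the font file from a website and install it.
  2. Configuration: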
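```python
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# `Config` is the project's configuration object (labels, batch_size,
# output paths); it is not shown in this thread.

# Large figure for many classes; SimHei so that Chinese label names render.
sns.set(rc={'figure.figsize': (40, 40)})
sns.set_style('whitegrid', {'font.sans-serif': ['SimHei', 'Arial']})
plt.rcParams['font.sans-serif'] = ['SimHei']
# Use an ASCII '-' for minus signs (the Unicode minus may be missing from SimHei).
plt.rcParams['axes.unicode_minus'] = False


def get_true_pred_data(model, ds_test):
    """Run the model over the test set and return binarized (y_true, y_pred)."""
    y_true, y_pred = [], []
    for item in ds_test:
        y_pred.append(model.predict(item[0], batch_size=Config.batch_size))
        y_true.append(item[1])
    # Threshold the multi-label probabilities at 0.5.
    y_pred = (np.vstack(y_pred) > 0.5).astype(np.int64)
    y_true = np.vstack(y_true)
    np.save(Config.y_true_path, y_true)
    np.save(Config.y_pred_path, y_pred)
    return y_true, y_pred


def get_matrix(y_true, y_pred):
    """Row-normalized multi-label confusion matrix (rows = true, columns = predicted)."""
    matrix = np.zeros((len(Config.labels), len(Config.labels)))
    for _y_true, _y_pred in zip(y_true, y_pred):
        _y_true = np.argwhere(_y_true == 1).flatten()
        _y_pred = np.argwhere(_y_pred == 1).flatten()
        num = len(_y_true) * len(_y_pred)
        # Spread one unit of weight evenly over every (true, pred) pair of the sample.
        for _y_true_ in _y_true:
            for _y_pred_ in _y_pred:
                matrix[_y_true_, _y_pred_] += (1 / num)
    # Normalize each row; the small epsilon avoids division by zero for empty rows.
    matrix /= (np.sum(matrix, axis=1, keepdims=True) + 1.0e-5)
    return matrix


def get_plot_confusion_data(matrix):
    """Wrap the matrix in a DataFrame whose index/column names become axis titles."""
    y_label = pd.Index(Config.labels, name='y_true')
    x_label = pd.Index(Config.labels, name='y_pred')
    df = pd.DataFrame(matrix, index=y_label, columns=x_label)
    return df


def plot_confusion_matrix(plot_data):
    heat_map = sns.heatmap(plot_data, vmax=1, cmap='GnBu', annot=True, fmt='.0%', annot_kws={'size': 6})
    fig = heat_map.get_figure()
    fig.savefig(Config.confusion_matrix_path, dpi=400)
    plt.close(fig)  # free the large figure so repeated calls don't draw on top of it
```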
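
Item 1 above (checking for the font and prompting the user) could look roughly like the sketch below. It uses matplotlib's `font_manager`; the helper name `ensure_font` and its `font_path` parameter are hypothetical, and registering a downloaded file with `fontManager.addfont` (available since matplotlib 3.2) is swapped in as an alternative to a system-wide install. No download URL is assumed.

```python
from matplotlib import font_manager


def ensure_font(name, font_path=None):
    """Return True if matplotlib can find a font family called `name`.

    If it is missing and `font_path` (hypothetical parameter) points to a
    downloaded .ttf/.otf file, register that file with matplotlib for the
    current process instead of installing it system-wide.
    """
    available = {entry.name for entry in font_manager.fontManager.ttflist}
    if name in available:
        return True
    if font_path is not None:
        font_manager.fontManager.addfont(font_path)  # matplotlib >= 3.2
        available = {entry.name for entry in font_manager.fontManager.ttflist}
        if name in available:
            return True
    # Prompt the user, since the font cannot be used yet.
    print(f"Font {name!r} not found. Please download the font file, then "
          f"install it or pass its path as font_path.")
    return False
```

Calling `ensure_font('SimHei')` before the `sns.set_style` / `plt.rcParams` configuration lets the script warn early instead of silently rendering Chinese labels as empty boxes.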
