In [1]:
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd

In [None]:
def dotplot_matrix(A, xlabels=None, ylabels=None, title="Dot plot",
                   size_scale=1200, vmin=None, vmax=None, cmap=None,
                   origin="lower", show_colorbar=True, gamma=1.0):
    """
    Draw a 2-D array as a grid of dots whose area and colour both encode the value.

    Parameters
    ----------
    A : array-like, shape (N, M)
        Values to plot; converted with ``np.asarray(A, dtype=float)``.
        Element ``A[i, j]`` is drawn at ``x = j``, ``y = i``. NaN cells get a
        zero-area marker.
    xlabels, ylabels : iterable, optional
        Tick labels for the M columns / N rows. Default: the integer indices.
    title : str
        Axes title.
    size_scale : float
        Marker area (scatter ``s``, in points^2) of a cell whose normalised
        value is 1. Increase it to make every dot larger.
    vmin, vmax : float, optional
        Range used to normalise both colour and size. ``None`` falls back to
        ``np.nanmin(A)`` / ``np.nanmax(A)``.
    cmap : optional
        Colormap forwarded to ``ax.scatter``; ``None`` uses matplotlib's default.
    origin : str
        ``"lower"`` puts row 0 at the bottom; any other value puts row 0 at the top.
    show_colorbar : bool
        Whether to attach a colorbar labelled "value".
    gamma : float
        Exponent applied to the normalised value before it is scaled to a
        marker area. ``1.0`` (default) makes area linear in the value; values
        below 1 enlarge small dots, values above 1 shrink them.

    Returns
    -------
    None
        The figure is displayed with ``plt.show()``.

    Raises
    ------
    ValueError
        If ``A`` is not 2-D or is empty, if ``gamma`` is not positive, or if
        ``xlabels`` / ``ylabels`` do not have exactly M / N entries.
    """
    A = np.asarray(A, dtype=float)

    # Validate everything up front so that a bad call never leaves a
    # half-drawn figure behind in the notebook.
    if A.ndim != 2:
        raise ValueError("A must be a 2-D (N, M) array; got shape {}".format(A.shape))
    if A.size == 0:
        raise ValueError("A must not be empty; got shape {}".format(A.shape))
    if not gamma > 0:
        raise ValueError("gamma must be a positive number; got {!r}".format(gamma))
    N, M = A.shape

    xlabels = [str(j) for j in range(M)] if xlabels is None else list(xlabels)
    ylabels = [str(i) for i in range(N)] if ylabels is None else list(ylabels)
    if len(xlabels) != M:
        raise ValueError(
            "xlabels has {} entries but A has {} columns".format(len(xlabels), M))
    if len(ylabels) != N:
        raise ValueError(
            "ylabels has {} entries but A has {} rows".format(len(ylabels), N))

    if vmin is None:
        vmin = np.nanmin(A)
    if vmax is None:
        vmax = np.nanmax(A)
    # Guard against a constant matrix (vmax == vmin) to avoid dividing by zero.
    denom = (vmax - vmin) if (vmax - vmin) != 0 else 1.0

    # Normalise to [0, 1]; NaN cells are mapped to 0 so they get no marker.
    An = (A - vmin) / denom
    An = np.nan_to_num(An, nan=0.0)
    An = np.clip(An, 0.0, 1.0)

    # scatter's `s` is an area (points^2), so the visible radius grows like sqrt(s).
    sizes = (An ** gamma) * size_scale

    yy, xx = np.indices((N, M))
    x = xx.ravel()
    y = yy.ravel()
    s = sizes.ravel()
    c = A.ravel()  # colour uses the raw values; vmin/vmax are handed to scatter

    fig, ax = plt.subplots(figsize=(max(6, M * 0.35), max(5, N * 0.35)))
    sc = ax.scatter(x, y, s=s, c=c, cmap=cmap, vmin=vmin, vmax=vmax, edgecolors="none")

    # Title and axis labels
    ax.set_title(title)
    ax.set_xlabel("X")
    ax.set_ylabel("Y")

    # One major tick per cell centre
    ax.set_xticks(np.arange(M))
    ax.set_yticks(np.arange(N))
    ax.set_xticklabels(xlabels, rotation=90)
    ax.set_yticklabels(ylabels)

    # Where row 0 sits: bottom for "lower", top otherwise (inverted y axis).
    if origin == "lower":
        ax.set_ylim(-0.5, N - 0.5)
    else:
        ax.set_ylim(N - 0.5, -0.5)

    ax.set_xlim(-0.5, M - 0.5)

    # Minor ticks on the half-integers draw the cell-boundary grid.
    ax.set_xticks(np.arange(-.5, M, 1), minor=True)
    ax.set_yticks(np.arange(-.5, N, 1), minor=True)
    ax.grid(which="minor", linestyle="-", linewidth=0.5)
    ax.tick_params(which="minor", bottom=False, left=False)

    ax.set_aspect("equal")  # square cells

    if show_colorbar:
        cb = fig.colorbar(sc, ax=ax)
        cb.set_label("value", rotation=270, labelpad=15)

    # Lay out this figure explicitly instead of relying on pyplot's current figure.
    fig.tight_layout()
    plt.show()


In [None]:


# Plot matrix A with predicted labels on the x axis and true labels on the y axis.
# NOTE(review): A, pred_labels and true_labels are not defined in the cells shown
# here; presumably A is a confusion matrix built in an earlier cell. Confirm that
# the notebook still runs after Restart & Run All.
dotplot_matrix(
    A,
    xlabels=pred_labels,
    ylabels=true_labels,
    title="CM dotplot",
    origin="lower",
)