In [None]:
import json
import pandas as pd
import numpy as np

JSON_PATH = "vit_chest_xray.json"

# JSON is UTF-8 by spec; pass the encoding explicitly so the load does not
# depend on the platform's default locale encoding.
with open(JSON_PATH, "r", encoding="utf-8") as f:
    data = json.load(f)

# Class names as stored in the JSON; used as the default class list when the
# long-format frames below are built.
classes = data["classes"]

# ---- loss & acc to long df ----
def metric_to_long(metric_dict, metric_name, class_names=None):
    """Flatten a ``{rho_str: {class: value}}`` mapping into a long DataFrame.

    Example input: ``{"1.0000": {"Cardiomegaly": 0.1636, ...}, ...}``.

    Parameters
    ----------
    metric_dict : dict
        Outer keys are rho values as strings (parsed with ``float``); inner
        dicts map class name to the metric value.
    metric_name : str
        Written into the ``metric`` column of every row.
    class_names : iterable of str, optional
        Classes to extract, in this order. Defaults to the notebook-level
        ``classes`` list, so existing two-argument calls behave as before.

    Returns
    -------
    pandas.DataFrame
        Columns ``rho``, ``class``, ``metric``, ``value``; one row per
        (rho, class). The columns exist even when ``metric_dict`` is empty.

    Raises
    ------
    KeyError
        If a requested class is missing from one of the inner dicts.
    """
    if class_names is None:
        class_names = classes  # notebook-level default
    rows = [
        {"rho": float(rho_str), "class": c, "metric": metric_name, "value": float(cls_map[c])}
        for rho_str, cls_map in metric_dict.items()
        for c in class_names
    ]
    # Explicit columns keep downstream column access working on empty input.
    return pd.DataFrame(rows, columns=["rho", "class", "metric", "value"])

df_loss = metric_to_long(data["loss"], "loss")
df_acc = metric_to_long(data["accuracy"], "acc")

# Stack both metrics into one long frame and order each (metric, class)
# series from the full resolution (rho = 1.0) down to the smallest rho.
df_metrics = (
    pd.concat([df_loss, df_acc], ignore_index=True)
    .astype({"rho": float})
    .sort_values(["metric", "class", "rho"], ascending=[True, True, False])
)

# confidence_dominance_rate is nested as {rho1: {rho2: {class: rate}}},
# e.g. {"1.0000": {"0.9375": {"Cardiomegaly": 0.79, ...}, ...}, ...}.
# Flatten it to one row per (rho1, rho2, class).
rows = [
    {
        "rho1": float(outer_key),
        "rho2": float(inner_key),
        "class": c,
        "rate": float(per_class[c]),
    }
    for outer_key, by_rho2 in data["confidence_dominance_rate"].items()
    for inner_key, per_class in by_rho2.items()
    for c in classes
]
df_rate = pd.DataFrame(rows).sort_values(
    ["class", "rho1", "rho2"], ascending=[True, False, False]
)

df_metrics.head(), df_rate.head()


(        rho         class metric   value
 80   1.0000  Cardiomegaly    acc  0.9786
 85   0.9375  Cardiomegaly    acc  0.9487
 90   0.8750  Cardiomegaly    acc  0.9701
 95   0.8125  Cardiomegaly    acc  0.9573
 100  0.7500  Cardiomegaly    acc  0.9573,
     rho1    rho2         class    rate
 0    1.0  0.9375  Cardiomegaly  0.7949
 5    1.0  0.8750  Cardiomegaly  0.7735
 10   1.0  0.8125  Cardiomegaly  0.7821
 15   1.0  0.7500  Cardiomegaly  0.7778
 20   1.0  0.6875  Cardiomegaly  0.7906)

In [None]:
import numpy as np
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

def dominance_matrix_data(target_class, rate_df=None, base_size=224):
    """Build a square confidence-dominance matrix for one class.

    Rows are indexed by rho1 and columns by rho2. Both axes cover every rho
    that appears in either column, sorted ascending. Cell (rho1, rho2) holds
    the ``rate`` value from ``rate_df``.

    For display, cells are then completed in two passes:
      1. an empty off-diagonal cell (a, b) is mirrored from (b, a) when only
         that side has a value (the same rate is copied, not ``1 - rate``);
      2. each diagonal cell is set to the mean of the non-NaN cells directly
         adjacent to it in its row and column, when any exist.

    Pass 1 finishes before any diagonal is computed, so every diagonal sees
    the same neighbour set regardless of iteration order. Interleaving the
    passes would count one neighbour twice.

    Parameters
    ----------
    target_class : str
        Value of the ``class`` column to select.
    rate_df : pandas.DataFrame, optional
        Long frame with columns ``rho1``, ``rho2``, ``class``, ``rate``.
        Defaults to the notebook-level ``df_rate``.
    base_size : int, default 224
        Pixel side length at rho = 1.0, used to turn rho into an axis label.

    Returns
    -------
    tuple of (numpy.ndarray, list of str)
        The ``len(rhos) x len(rhos)`` float matrix (NaN where nothing is
        known), and labels ``"{size}×{size}"`` in the same ascending order.
    """
    if rate_df is None:
        rate_df = df_rate  # notebook-level frame built in the first cell
    d = rate_df[rate_df["class"] == target_class]

    # Union of every rho seen as either rho1 or rho2, ascending.
    rhos = sorted(set(d["rho1"]).union(d["rho2"]))
    n = len(rhos)

    mat = pd.DataFrame(np.nan, index=rhos, columns=rhos)
    for rho1, rho2, rate in zip(d["rho1"], d["rho2"], d["rate"]):
        mat.loc[rho1, rho2] = rate
    vals = mat.to_numpy(dtype=float, copy=True)

    # Pass 1: fill empty off-diagonal cells from their transpose.
    # The diagonal is never selected: vals[i, i] and vals.T[i, i] are equal.
    # NOTE(review): if rate means P(conf[rho1] > conf[rho2]), the transposed
    # cell would be 1 - rate; copying keeps the heatmap symmetric. Confirm
    # this is the intended visual.
    mirror = np.isnan(vals) & ~np.isnan(vals.T)
    vals[mirror] = vals.T[mirror]

    # Pass 2: the diagonal has no data of its own; fill it with the mean of
    # its adjacent cells purely for display.
    for i in range(n):
        neighbours = [
            vals[r, c]
            for r, c in ((i - 1, i), (i + 1, i), (i, i - 1), (i, i + 1))
            if 0 <= r < n and 0 <= c < n and not np.isnan(vals[r, c])
        ]
        if neighbours:
            vals[i, i] = np.mean(neighbours)

    def rho_to_label(rho):
        size = int(round(base_size * rho))
        return f"{size}×{size}"

    resolution_labels = [rho_to_label(r) for r in rhos]

    return vals, resolution_labels


# Panel order (left to right) for the dominance heatmaps.
# NOTE(review): this order has Consolidation before Edema, while the class
# list used for the accuracy line plots in the next cell has Edema first.
# Align the two if the figures are meant to be read side by side.
classes = [
    "Cardiomegaly",
    "Consolidation",
    "Edema",
    "Pneumonia",
    "No Finding"
]

fig = make_subplots(
    rows=1,
    cols=len(classes),
    horizontal_spacing=0.0,
)

for col, cls in enumerate(classes, start=1):
    z, labels = dominance_matrix_data(cls)

    # All panels reference the same coloraxis, so they share one 0.5-1.0
    # color scale and a single colorbar.
    fig.add_trace(
        go.Heatmap(
            z=z,
            x=labels,
            y=labels,
            zmin=0.5,
            zmax=1.0,
            coloraxis="coloraxis",
        ),
        row=1,
        col=col,
    )

    # Rotate x tick labels for readability (45 or 60 both work).
    fig.update_xaxes(constrain="domain", tickangle=45, row=1, col=col)

    # Force square cells: tie each panel's y scale to its own x axis.
    fig.update_yaxes(scaleanchor=f"x{col}", scaleratio=1, row=1, col=col)

# Shared colorbar. The original (commented-out) colorbar title described the
# value as P(conf[ρ1] > conf[ρ2]); the title is left off in this layout.
fig.update_layout(
    coloraxis=dict(
        cmin=0.5,
        cmax=1.0,
        colorbar=dict(
            len=1.2,
            thickness=20,
        ),
    ),
    height=450,
    width=2000,
)

fig.show()

fig.write_html(
    "fig_plotly1.html",
    include_plotlyjs="cdn",
    config={
        "displayModeBar": False
    }
)


In [None]:
import json
import pandas as pd
import plotly.graph_objects as go
from plotly.subplots import make_subplots

JSON_PATH = "vit_chest_xray.json"

# JSON is UTF-8 by spec; pass the encoding explicitly so the load does not
# depend on the platform's default locale encoding.
with open(JSON_PATH, "r", encoding="utf-8") as f:
    data = json.load(f)

# Explicit panel order for the line plots below; overrides data["classes"].
# NOTE(review): Edema/Consolidation are swapped relative to the heatmap panel
# order in the previous cell. Align them if the figures are compared.
classes = ['Cardiomegaly', 'Edema', 'Consolidation', 'Pneumonia', 'No Finding']

# ---- Convert to long df ----
def metric_to_long(metric_dict, metric_name, class_names=None):
    """Flatten a ``{rho_str: {class: value}}`` mapping into a long DataFrame.

    This cell re-loads the JSON and redefines this helper, so it does not
    depend on the first cell having been run.

    Parameters
    ----------
    metric_dict : dict
        Outer keys are rho values as strings (parsed with ``float``); inner
        dicts map class name to the metric value.
    metric_name : str
        Written into the ``metric`` column of every row.
    class_names : iterable of str, optional
        Classes to extract, in this order. Defaults to the notebook-level
        ``classes`` list, so existing two-argument calls behave as before.

    Returns
    -------
    pandas.DataFrame
        Columns ``rho``, ``class``, ``metric``, ``value``; one row per
        (rho, class). The columns exist even when ``metric_dict`` is empty.

    Raises
    ------
    KeyError
        If a requested class is missing from one of the inner dicts.
    """
    if class_names is None:
        class_names = classes  # notebook-level default
    rows = [
        {"rho": float(rho_str), "class": c, "metric": metric_name, "value": float(cls_map[c])}
        for rho_str, cls_map in metric_dict.items()
        for c in class_names
    ]
    # Explicit columns keep downstream column access working on empty input.
    return pd.DataFrame(rows, columns=["rho", "class", "metric", "value"])

def extract_confidence_dominance(conf_dict, reference_rho=1.0, class_names=None):
    """Extract the dominance rates of one reference rho over every other rho.

    ``conf_dict`` is nested as ``{rho1_str: {rho2_str: {class: rate}}}``.
    Only the entry whose rho1 equals ``reference_rho`` is used, and each
    rate is reported at its rho2.

    Parameters
    ----------
    conf_dict : dict
        Nested confidence-dominance mapping as described above.
    reference_rho : float or str, default 1.0
        Reference rho1. Keys are compared numerically, so ``"1.0000"``,
        ``"1.0"`` and ``1.0`` all select the same entry.
    class_names : iterable of str, optional
        Classes to extract, in this order. Defaults to the notebook-level
        ``classes`` list.

    Returns
    -------
    pandas.DataFrame
        Columns ``rho`` (the rho2 value), ``class``, ``metric`` (always
        ``"ConfidenceDominance"``) and ``value``. Rows are sorted by class
        ascending, then rho descending. The columns exist even when empty.

    Raises
    ------
    KeyError
        If no key of ``conf_dict`` matches ``reference_rho``, or a requested
        class is missing from an inner dict.
    """
    if class_names is None:
        class_names = classes  # notebook-level default

    ref = float(reference_rho)
    matching_keys = [k for k in conf_dict if float(k) == ref]
    if not matching_keys:
        raise KeyError(f"no confidence-dominance entry for reference rho {reference_rho!r}")
    by_rho2 = conf_dict[matching_keys[0]]

    rows = [
        {
            "rho": float(rho2_str),
            "class": c,
            "metric": "ConfidenceDominance",
            "value": float(cls_map[c]),
        }
        for rho2_str, cls_map in by_rho2.items()
        for c in class_names
    ]
    return (
        pd.DataFrame(rows, columns=["rho", "class", "metric", "value"])
        .sort_values(["class", "rho"], ascending=[True, False])
    )

df_acc = metric_to_long(data["accuracy"], "Accuracy")
df_loss = metric_to_long(data["loss"], "Loss")
df_conf = extract_confidence_dominance(data["confidence_dominance_rate"])
df = pd.concat([df_acc, df_loss, df_conf], ignore_index=True)

# Every rho present in any metric, from 1.0 (full resolution) down to the
# smallest. Cast to plain floats so the printout is readable: under NumPy 2,
# unique() values print as np.float64(...).
rho_order = sorted((float(r) for r in df["rho"].unique()), reverse=True)
print(f"rho_order: {rho_order}")

def rho_to_label(rho, base_size=224):
    """Format a resolution scale factor as a square pixel-size label.

    ``rho`` multiplies ``base_size`` (the full-resolution side length, 224 by
    default). The result is rounded to the nearest pixel, for example
    ``rho_to_label(0.5) == "112×112"``.
    """
    size = int(round(base_size * rho))
    return f"{size}×{size}"

labels = [rho_to_label(r) for r in rho_order]
print(labels)

# Set to True to also draw Loss on each panel's secondary (right) y-axis.
# Loss is loaded above but is not drawn by default.
SHOW_LOSS = False

# Fixed per-metric line styles so every panel encodes a metric the same way.
# Without explicit colors, plotly cycles its colorway across all traces in the
# figure, so each panel would get different colors. Accuracy and dominance use
# the first two entries of plotly's default colorway (their panel-1 colors).
METRIC_LINE = {
    "Accuracy": dict(color="#636efa"),
    "ConfidenceDominance": dict(color="#EF553B", dash="dot"),
    "Loss": dict(color="#00cc96", dash="dash"),
}


def _metric_series(class_df, metric, order):
    """Return one metric's values from ``class_df``, indexed by rho in ``order`` (NaN where missing)."""
    return (
        class_df[class_df["metric"] == metric]
        .set_index("rho")
        .reindex(order)["value"]
    )


# ---- 1 x len(classes) subplots; the secondary y-axis is used only for Loss ----
fig = make_subplots(
    rows=1,
    cols=len(classes),
    specs=[[{"secondary_y": True} for _ in classes]],
    horizontal_spacing=0.03,
)

for i, c in enumerate(classes, start=1):
    sub = df[df["class"] == c]
    first_panel = i == 1  # only one legend entry per metric if the legend is enabled

    # Accuracy (left y-axis)
    fig.add_trace(
        go.Scatter(
            x=labels,
            y=_metric_series(sub, "Accuracy", rho_order),
            mode="lines+markers",
            name="Accuracy",
            legendgroup="Accuracy",
            line=METRIC_LINE["Accuracy"],
            showlegend=first_panel,
        ),
        row=1, col=i, secondary_y=False
    )

    # Dominance of rho = 1.0 over each rho (shares the Accuracy axis)
    fig.add_trace(
        go.Scatter(
            x=labels,
            y=_metric_series(sub, "ConfidenceDominance", rho_order),
            mode="lines+markers",
            name="Conf. Dominance (1.0 → ρ)",
            legendgroup="ConfidenceDominance",
            line=METRIC_LINE["ConfidenceDominance"],
            showlegend=first_panel,
        ),
        row=1, col=i, secondary_y=False
    )

    if SHOW_LOSS:
        fig.add_trace(
            go.Scatter(
                x=labels,
                y=_metric_series(sub, "Loss", rho_order),
                mode="lines+markers",
                name="Loss",
                legendgroup="Loss",
                line=METRIC_LINE["Loss"],
                showlegend=first_panel,
            ),
            row=1, col=i, secondary_y=True
        )

    # labels run from 224x224 down; pinning the category order and reversing
    # the axis puts the lowest resolution on the left.
    fig.update_xaxes(
        tickangle=45,
        autorange="reversed",
        categoryorder="array",
        categoryarray=labels,
        row=1, col=i
    )

# ---- Global layout ----
fig.update_layout(
    height=420,
    width=1600,
    margin=dict(l=40, r=40, t=80, b=40),
    template="plotly_white",
    plot_bgcolor="rgba(225,225,255,1)",
    paper_bgcolor="white",
    showlegend=False,
)

fig.show()

fig.write_html(
    "fig_plotly2.html",
    include_plotlyjs="cdn",
    config={
        "displayModeBar": False
    }
)

rho_order: [np.float64(1.0), np.float64(0.9375), np.float64(0.875), np.float64(0.8125), np.float64(0.75), np.float64(0.6875), np.float64(0.625), np.float64(0.5625), np.float64(0.5), np.float64(0.4375), np.float64(0.375), np.float64(0.3125), np.float64(0.25), np.float64(0.1875), np.float64(0.125), np.float64(0.0625)]
['224×224', '210×210', '196×196', '182×182', '168×168', '154×154', '140×140', '126×126', '112×112', '98×98', '84×84', '70×70', '56×56', '42×42', '28×28', '14×14']
