In [None]:
import sys
from pathlib import Path

import polars as pl
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

# Locate this notebook on disk. `__file__` only exists when the code runs as a
# script; inside a notebook kernel, fall back to the notebook file name resolved
# against the current working directory. Resolving in both cases guarantees an
# absolute path, so `.parents` really walks up the directory tree.
this_path = (Path(__file__) if '__file__' in globals() else Path("<unknown>.ipynb")).resolve()

# Nearest ancestor directory called "research"; the shared tools are looked up
# relative to it.
work_path = next((p for p in this_path.parents if p.name == "research"), None)
if work_path is None:
    # Fail with an actionable message instead of the opaque
    # "unsupported operand type(s) for /: 'NoneType' and 'PosixPath'".
    raise FileNotFoundError(
        f"No ancestor directory named 'research' found above {this_path}; "
        "cannot locate the torch-tools directory."
    )

tools_path = work_path / Path("../torch-tools")

# Make the tools importable. Guard against duplicates so that re-running this
# cell does not keep growing sys.path.
if str(tools_path) not in sys.path:
    sys.path.append(str(tools_path))

from run_manager import RunViewer


In [None]:
# Collect the results of the experiment stored next to this notebook.
rv = RunViewer(exp_path=this_path.parent)
df_base = rv.fetch_results(met_listed=False)

# Names of all columns whose dtype is nested.
# NOTE(review): is_nested() is also true for non-list nested dtypes, while the
# reduction below uses the `.list` namespace — presumably every nested column
# returned here is a list; confirm against RunViewer.
nested_columns = [name for name, dtype in df_base.schema.items() if dtype.is_nested()]

# Replace each nested column by the last element of its list. The expression
# keeps the source column's name, so the column is overwritten in place.
last_element_exprs = [pl.col(name).list.last() for name in nested_columns]
df_base = df_base.with_columns(last_element_exprs)

print(df_base.columns)
# Uncomment to inspect the full frame:
# display(df_base)

In [None]:
df = df_base

# Optional row filter applied before pivoting.
# df = df.filter(pl.col("fils").is_in([32, 16, 8, 4]))

piv_values = ["val_acc"]            # value column(s) aggregated into each heatmap cell
piv_index = "fils"                  # heatmap rows (y axis)
piv_on = "max_lr"                   # heatmap columns (x axis)

# Aggregation used when several rows fall into the same cell.
# agg = "len"
agg = "mean"

ext_column = "train_ndata"          # one heatmap is drawn per distinct value of this column

# --- Plot settings (identical for every heatmap, so defined once) ---
square_size = 0.75                  # figure inches per heatmap cell
axis = 1                            # normalisation axis: 0 = statistics per column, 1 = statistics per row
norm_method = "zscore"              # "zscore" or "minmax"; affects the colours only, not the annotations
fontname, fontweight = "Lato", 300  # font family and weight for title, axis labels and tick labels

hm_kwargs = {
    "cmap": "Blues_r",
    "cbar": False,
    "fmt": ".1f",
    "annot_kws": {"size": 11, "fontname": fontname, "fontweight": 500}
}


def _as_float(name):
    """Return ``name`` parsed as a float, or None if it is not a number (or is NaN)."""
    try:
        value = float(name)
    except (TypeError, ValueError):
        return None
    # NaN has no defined sort position, so treat it like a non-numeric name.
    return None if value != value else value


def _sort_numeric_columns(columns):
    """Return ``columns`` with the numeric-looking names ordered by numeric value.

    Non-numeric names keep their positions; the numeric names are redistributed
    over the remaining positions in ascending order. The original strings are
    reused, so names such as "0.001" or "008" are preserved exactly.
    """
    numeric_sorted = iter(sorted((c for c in columns if _as_float(c) is not None), key=_as_float))
    return [next(numeric_sorted) if _as_float(c) is not None else c for c in columns]


def _normalize(values, axis, method="zscore"):
    """Normalise ``values`` along ``axis`` for colouring.

    "zscore" gives (x - mean) / (std + 1e-8) and "minmax" gives
    (x - min) / (max - min + 1e-8); the epsilon avoids division by zero.
    NaN-aware reductions are used so that a missing pivot cell stays NaN on its
    own instead of turning its whole row/column into NaN.

    Raises:
        ValueError: if ``method`` is neither "zscore" nor "minmax".
    """
    values = np.asarray(values, dtype=float)
    if method == "zscore":
        offset = np.nanmean(values, axis=axis, keepdims=True)
        scale = np.nanstd(values, axis=axis, keepdims=True)
    elif method == "minmax":
        offset = np.nanmin(values, axis=axis, keepdims=True)
        scale = np.nanmax(values, axis=axis, keepdims=True) - offset
    else:
        raise ValueError(f"unknown normalisation method: {method!r}")
    return (values - offset) / (scale + 1e-8)


# unique() does not guarantee an order; sort so the figures always appear in
# the same sequence.
ext_l = df[ext_column].unique().sort()
for ext in ext_l:
    # Build the pivot table for this slice of the data.
    df_ext = df.filter(pl.col(ext_column) == ext)
    df_piv = df_ext.pivot(values=piv_values, index=piv_index, on=piv_on, sort_columns=True, aggregate_function=agg)

    # Pivoted column names are strings; put the numeric ones in numeric order.
    # Column 0 is the index column and always stays first.
    df_piv = df_piv.select([df_piv.columns[0], *_sort_numeric_columns(df_piv.columns[1:])])

    # Column 0 holds the y-axis labels; the names of the remaining columns are
    # the x-axis labels and their contents are the heatmap cells.
    hm_x = df_piv.columns[1:]
    hm_y = df_piv[df_piv.columns[0]]
    data = df_piv.select(hm_x).to_numpy()

    # Annotations show the raw aggregated values scaled by 100
    # (presumably a fraction displayed as a percentage — TODO confirm),
    # while the colours use the normalised values.
    annot = data * 100
    data = _normalize(data, axis, norm_method)

    fig, ax = plt.subplots(figsize=(len(hm_x)*square_size, len(hm_y)*square_size))

    # Draw on the axes created above and hand the tick labels to seaborn
    # directly, so labels and ticks cannot get out of sync.
    sns.heatmap(
        data,
        annot=annot,
        square=True,
        xticklabels=hm_x,
        yticklabels=hm_y.to_list(),
        ax=ax,
        **hm_kwargs,
    )

    ax.set_title(f"{ext_column}: {ext}", fontsize=14, fontname=fontname, fontweight=fontweight)
    ax.set_xlabel(piv_on, fontsize=12, fontname=fontname, fontweight=fontweight)
    ax.set_ylabel(hm_y.name, fontsize=12, rotation=90, fontname=fontname, fontweight=fontweight)

    # Same styling for the tick labels on both axes.
    for label in (*ax.get_xticklabels(), *ax.get_yticklabels()):
        label.set_fontsize(11)
        label.set_rotation(0)
        label.set_fontname(fontname)
        label.set_fontweight(fontweight)

    plt.show()
    # Release the figure so that looping over many values does not accumulate
    # open figures.
    plt.close(fig)