In [None]:
import sys
from pathlib import Path

import polars as pl
import numpy as np
import matplotlib.pyplot as plt
from cycler import cycler
import seaborn as sns

# Locate the enclosing "research" directory so that the sibling "torch-tools"
# checkout (from which run_manager and pl_utils are imported below) can be put
# on sys.path. When __file__ is not defined (typically inside a Jupyter kernel),
# the search starts from a placeholder file in the current working directory.
# Both branches are resolved so that a relative __file__ still yields real parents.
this_path = (Path(__file__) if '__file__' in globals() else Path("<unknown>.ipynb")).resolve()
work_path = next((p for p in this_path.parents if p.name == "research"), None)
if work_path is None:
    raise FileNotFoundError(
        f"No ancestor directory named 'research' above {this_path}; cannot locate torch-tools"
    )
tools_path = work_path / Path("../torch-tools")
if str(tools_path) not in sys.path:  # keep the cell idempotent when re-run
    sys.path.append(str(tools_path))

from run_manager import RunViewer, cat_results
import pl_utils


In [None]:
pl_utils.Config()  # project-level setup from pl_utils (see pl_utils.Config for what it configures)

# Results are collected from the directory that contains this notebook.
p_path = this_path.parent
paths = [p_path]
# Alternative: collect from experiment sub-directories instead
# (presumably glob-expanded by cat_results — confirm):
# paths = [p_path / "exp_*"]

df_base = cat_results(paths)
# Optional: keep only rows whose last logged epoch equals `epochs`:
# df_base = df_base.filter(pl.col("epochs") == pl.col("epoch").list.last())
# Optional: export to CSV after pl_utils.resolve_nested:
# df_base.pipe(pl_utils.resolve_nested).write_csv(p_path / "df_base.csv")

print(df_base.columns, df_base, sep="\n")

In [None]:
df_b = df_base  # no filtering by default

# Optional reshaping / filters — uncomment to narrow the runs being analysed:
# df_b = df_b.pipe(pl_utils.add_iter_epoch).pipe(pl_utils.unnest_iter)
# df_b = df_b.filter(pl.col("optimizer") == "SGD")
# df_b = df_b.filter(pl.col("train_ndata").is_in([10000, 5000]))
# df_b = df_b.filter(pl.col("model_arc").str.contains("cifar"))



In [None]:
# Table that all plots below read from (see pl_utils.get_stats for its exact columns).
df_e = df_b.pipe(pl_utils.get_stats)

# --- what to plot -----------------------------------------------------------
ext_columns = ["model_arc", "optimizer", "wd"]  # one plot per unique combination of these columns
piv_values = ["val_acc"]  # metric column(s) to plot
# piv_values = ["train_acc", "val_acc", "train_loss", "val_loss"]
piv_indices = ["div"]  # pivot rows (vertical axis)
piv_on = "train_ndata"  # pivot columns (horizontal axis)

# --- heatmap ----------------------------------------------------------------
agg = "first"  # aggregate_function for DataFrame.pivot when a cell receives several rows
axis = 0  # None: no normalization; 0: normalize each column (stats over rows); 1: normalize each row
norm = "zscore"  # "minmax", "zscore", or None (no normalization)

# --- line graphs ------------------------------------------------------------
graph_size = (4, 3)  # (width, height) in inches of each subplot
ylim = True  # if True, fix the y range of every subplot to [0, 1.005]
x_col = "step"  # column whose values are used as the x coordinates of each curve
# x_col = "iter_step"



In [None]:
# Allow a single column name in the config cell as well as a list of names.
if not isinstance(ext_columns, list):
    ext_columns = [ext_columns]
if not isinstance(piv_values, list):
    piv_values = [piv_values]

def format_pivot_columns(df: pl.DataFrame, group_cols: list | str | None = None) -> pl.DataFrame:
    """Order pivot-table columns numerically and put a single label column first.

    Pivoted value columns are named after the ``on`` values as strings, so they
    end up in string order (e.g. "10000" before "5000"). Columns whose names
    parse as numbers are re-ordered by numeric value. Columns whose names do not
    parse (such as the index columns) become null after the non-strict cast and
    sort first, keeping their original relative order.

    Args:
        df: A pivoted DataFrame.
        group_cols: The pivot ``index`` column(s), as a list or a single name.
            If given, rows are sorted by them in descending order and they are
            collapsed into one leading label column. With one column, that
            column is used as-is. With several, a string column "(v1, v2, ...)"
            named "(col1, col2, ...)" is used.

    Returns:
        The reformatted DataFrame.
    """
    if isinstance(group_cols, str):
        group_cols = [group_cols]

    # Float64 (not Float32) so that large or finely spaced numeric labels do not
    # collapse to the same value and lose their ordering.
    new_columns = (
        pl.DataFrame({"cols": df.columns})
        .with_columns(pl.col("cols").cast(pl.Float64, strict=False).alias("numeric"))
        .sort("numeric", maintain_order=True)["cols"]
        .to_list()
    )
    df = df.select(new_columns)

    if group_cols:
        if len(group_cols) == 1:
            idx_alias = group_cols[0]
        else:
            # Combine the index values into one readable "(a, b, ...)" label.
            idx_alias = f"({', '.join(group_cols)})"
            df = df.with_columns(
                (
                    pl.lit("(")
                    + pl.concat_str([pl.col(c).cast(pl.String) for c in group_cols], separator=", ")
                    + pl.lit(")")
                ).alias(idx_alias)
            )
        # Put the label column first, followed by the numerically ordered value columns.
        df = df.sort(group_cols, descending=True).select(idx_alias, pl.col(new_columns).exclude(group_cols))
    return df

# Unique combinations of the ext_columns values. For each combination we draw
# either one heatmap per metric (scalar cells) or one grid of line plots
# shared by all metrics (sequence cells).
ext_l_df = df_e.select(ext_columns).unique()

fontname, fontweight = "Lato", 300
square_size = 0.75  # heatmap figure inches per cell
modern_colors = ['#368DFF', '#FF6495', '#3DD598', '#FFC542', '#8C52FF', '#FF6E4A']


def _match_row(row: dict) -> pl.Expr:
    """Build a filter selecting rows equal to every key/value of `row`.

    A None value matches nulls; a plain `== None` evaluates to null and would
    silently drop that combination.
    """
    return pl.all_horizontal(
        [pl.col(k).is_null() if v is None else (pl.col(k) == v) for k, v in row.items()]
    )


def _pivot(df: pl.DataFrame, value_col: str) -> pl.DataFrame:
    """Pivot `value_col` into a piv_indices x piv_on table with a leading label column."""
    return df.pivot(
        values=value_col, index=piv_indices, on=piv_on, sort_columns=True, aggregate_function=agg
    ).pipe(format_pivot_columns, group_cols=piv_indices)


def _normalize(values: np.ndarray) -> np.ndarray:
    """Normalize `values` along the configured `axis` using the configured `norm`."""
    if axis is None:
        return values
    if norm == "minmax":
        lo = np.nanmin(values, axis=axis, keepdims=True)
        hi = np.nanmax(values, axis=axis, keepdims=True)
        return (values - lo) / (hi - lo + 1e-8)  # epsilon avoids division by zero
    if norm == "zscore":
        mean = np.nanmean(values, axis=axis, keepdims=True)
        std = np.nanstd(values, axis=axis, keepdims=True)
        return (values - mean) / (std + 1e-8)
    return values


def _draw_heatmap(values: np.ndarray, col_labels: list, row_labels: pl.Series, value_name: str, title: str) -> None:
    """Draw and show one heatmap: colors from normalized values, annotations from raw ones."""
    annot = values.copy()
    if value_name.endswith("_acc"):
        annot *= 100  # *_acc metrics are annotated scaled by 100 (percent)
    colors = _normalize(values)

    fig, ax = plt.subplots(figsize=(len(col_labels) * square_size, len(row_labels) * square_size))
    # Pass the labels to seaborn so every cell gets exactly one tick label
    # (seaborn may otherwise thin the ticks, misaligning set_*ticklabels).
    sns.heatmap(
        colors,
        annot=annot,
        square=False,
        ax=ax,
        cmap="Blues_r",
        cbar=False,
        fmt=".2f",
        annot_kws={"size": 11, "fontname": fontname, "fontweight": 500},
        xticklabels=list(col_labels),
        yticklabels=row_labels.to_list(),
    )
    ax.set_title(title, fontsize=14, fontname=fontname, fontweight=fontweight)
    ax.set_xlabel(piv_on, fontsize=12, fontname=fontname, fontweight=fontweight)
    ax.set_ylabel(row_labels.name, fontsize=12, rotation=90, fontname=fontname, fontweight=fontweight)
    ax.tick_params(axis="both", labelsize=11, labelrotation=0)
    for label in ax.get_xticklabels() + ax.get_yticklabels():
        label.set_fontname(fontname)
        label.set_fontweight(fontweight)
    plt.show()


def _new_line_grid(nrows: int, ncols: int, title: str):
    """Create the subplot grid that all metrics of one combination are drawn into."""
    fig, axes = plt.subplots(
        nrows, ncols,
        figsize=(ncols * graph_size[0], nrows * graph_size[1]),
        sharex=False, sharey=False, squeeze=False,
    )
    fig.suptitle(title, fontsize=16, fontname=fontname, fontweight=fontweight)
    fig.patch.set_facecolor('#f4f7fc')
    for sub_ax in axes.flat:
        # Per-axes color cycle instead of mutating the global rcParams.
        sub_ax.set_prop_cycle(cycler(color=modern_colors))
    return fig, axes


def _plot_cell(sub_ax, x_data, y_data, label: str) -> None:
    """Draw one metric curve in a subplot, skipping missing cells and None/NaN y values."""
    if x_data is None or y_data is None:
        return
    points = [(x, y) for x, y in zip(x_data, y_data) if y is not None and not np.isnan(y)]
    if len(points) < 2:  # a line needs at least two points
        return
    x_plot, y_plot = zip(*sorted(points, key=lambda point: point[0]))
    sub_ax.plot(x_plot, y_plot, linewidth=1.5, label=label)


def _finish_line_grid(fig, axes, row_labels: pl.Series, col_labels: list) -> None:
    """Style every subplot, add row/column labels, and show the grid figure."""
    for sub_ax in axes.flat:
        if sub_ax.get_legend_handles_labels()[0]:  # avoid empty-legend warnings
            sub_ax.legend(fontsize=11, frameon=False)
        sub_ax.set_facecolor('#f2f2f5')
        if ylim:
            sub_ax.set_ylim([0, 1.005])
        sub_ax.grid(axis='both', linestyle='-', color='#ffffff', alpha=1, linewidth=1.0)
        sub_ax.tick_params(axis='both', which='major', labelsize=11, colors='#333333')
        sub_ax.tick_params(axis='both', which='both', length=0)
        for label in sub_ax.get_xticklabels() + sub_ax.get_yticklabels():
            label.set_fontname(fontname)
        for spine in sub_ax.spines.values():
            spine.set_visible(False)

    # Row labels on the left edge, column labels along the bottom edge.
    for r, row_label in enumerate(row_labels):
        axes[r, 0].set_ylabel(row_label, rotation=0, ha='right', va='center', fontsize=10, fontname=fontname, fontweight=fontweight)
    for c, col_label in enumerate(col_labels):
        axes[-1, c].set_xlabel(col_label, fontsize=10, fontname=fontname, fontweight=fontweight)
    fig.supxlabel(piv_on, fontsize=12, fontname=fontname, fontweight=fontweight)
    fig.supylabel(row_labels.name, fontsize=12, fontname=fontname, fontweight=fontweight)

    fig.tight_layout(rect=[0.02, 0.02, 1, 0.98])  # [left, bottom, right, top]
    plt.show()


for ext_row_dict in ext_l_df.iter_rows(named=True):
    # The filter and title depend only on the combination, not on the metric.
    title_str = ", ".join(f"{k}: {v}" for k, v in ext_row_dict.items())
    df_ext = df_e.filter(_match_row(ext_row_dict))
    # Line grid shared by all piv_values of this combination, so that several
    # metrics (e.g. train_acc and val_acc) are overlaid on the same subplots.
    line_fig = line_axes = None

    for piv_value in piv_values:
        df_piv = _pivot(df_ext, piv_value)
        # Column 0 holds the row labels; the remaining columns are the piv_on values.
        ax_x = df_piv.columns[1:]
        ax_y = df_piv[df_piv.columns[0]]
        try:
            data = df_piv.select(ax_x).to_numpy()
        except ValueError:
            print(f"Skipping invalid pivot table for {ext_row_dict} and {piv_value}")
            continue
        if data.size == 0:
            print(f"Skipping empty pivot table for {ext_row_dict} and {piv_value}")
            continue

        if np.issubdtype(data.dtype, np.number):
            # Scalar per cell -> one heatmap per metric (metric named when several are plotted).
            heat_title = title_str if len(piv_values) == 1 else f"{title_str} [{piv_value}]"
            _draw_heatmap(data, ax_x, ax_y, piv_value, heat_title)
        else:
            # Non-numeric cells are treated as per-cell sequences; the x values come from
            # the pivot of x_col with the same rows/columns.
            step_data = _pivot(df_ext, x_col).select(ax_x).to_numpy()
            nrows, ncols = data.shape
            if line_fig is None:
                line_fig, line_axes = _new_line_grid(nrows, ncols, title_str)
                line_rows, line_cols = ax_y, ax_x
            for r in range(nrows):
                for c in range(ncols):
                    _plot_cell(line_axes[r, c], step_data[r, c], data[r, c], piv_value)

    if line_fig is not None:
        _finish_line_grid(line_fig, line_axes, line_rows, line_cols)
        