In [5]:
from load_results import load_result_dataset
import pandas as pd

# Each stanza below loads one fine-tuning configuration, tags every record with
# a human-readable `ft_strategy` label, and builds a DataFrame from the records.
# NOTE(review): pn1/pn2/pn3 look like result-directory names under ../results
# (judging by the paths the loader prints) — confirm in load_results.

# --- full fine-tuning, 50 epochs ---
pn1, pn2, pn3 = (
    'full_fine_tuning_50epochs_edge_paper_final2',
    'full_fine_tuning_50epochs_paper_final2',
    'none',
)
final_data1 = [
    dict(record, ft_strategy='FFT (50 epochs)')
    for record in load_result_dataset(pn1, pn2, pn3)
]
df1 = pd.DataFrame(final_data1)

# --- full fine-tuning, 5 epochs ---
pn1, pn2, pn3 = (
    'full_fine_tuning_5epochs_edge_article1',
    'full_fine_tuning_5epochs_article1',
    'none',
)
final_data2 = [
    dict(record, ft_strategy='FFT (5 epochs)')
    for record in load_result_dataset(pn1, pn2, pn3)
]
df2 = pd.DataFrame(final_data2)

# --- linear probe, 50 epochs ---
pn1, pn2, pn3 = (
    'linearprobe_50epochs_edge_paper_final2',
    'linearprobe_50epochs_paper_final2',
    'none',
)
final_data3 = [
    dict(record, ft_strategy='LP (50 epochs)')
    for record in load_result_dataset(pn1, pn2, pn3)
]
df3 = pd.DataFrame(final_data3)

# Stack the three frames row-wise under a fresh 0..n-1 index.
df = pd.concat([df1, df2, df3], ignore_index=True)

# Flat list holding the tagged records of all three configurations, in order.
final_data = [*final_data1, *final_data2, *final_data3]

../results/none/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K_uc-merced-land-use-dataset_TRADES_v2.pkl
HEY
../results/full_fine_tuning_50epochs_paper_final2/CLIP-convnext_base_w-laion_aesthetic-s13B-b82K_uc-merced-land-use-dataset_TRADES_v2.pkl
../results/none/CLIP-convnext_base_w-laion2B-s13B-b82K_uc-merced-land-use-dataset_TRADES_v2.pkl
HEY
../results/full_fine_tuning_50epochs_paper_final2/CLIP-convnext_base_w-laion2B-s13B-b82K_uc-merced-land-use-dataset_TRADES_v2.pkl
../results/none/deit_small_patch16_224.fb_in1k_uc-merced-land-use-dataset_TRADES_v2.pkl
HEY
../results/full_fine_tuning_50epochs_paper_final2/deit_small_patch16_224.fb_in1k_uc-merced-land-use-dataset_TRADES_v2.pkl
../results/none/robust_resnet50_uc-merced-land-use-dataset_TRADES_v2.pkl
HEY
../results/full_fine_tuning_50epochs_paper_final2/robust_resnet50_uc-merced-land-use-dataset_TRADES_v2.pkl
../results/none/vit_small_patch16_224.augreg_in21k_uc-merced-land-use-dataset_TRADES_v2.pkl
HEY
../results/full_fine_tuning_50

In [6]:
from pathlib import Path

import numpy as np
import pandas as pd

# ------------------------- helpers -------------------------
def mean_absolute_correlation(corr_df: pd.DataFrame) -> float:
    """Average of the absolute values of all off-diagonal cells of a square matrix.

    The diagonal is excluded via a boolean identity mask. Any NaN among the
    off-diagonal cells makes the result NaN (plain ``mean`` is used).
    """
    size = len(corr_df)
    off_diagonal = np.logical_not(np.eye(size, dtype=bool))
    magnitudes = np.abs(corr_df.values[off_diagonal])
    return magnitudes.mean()

def compute_spearman_corr(df: pd.DataFrame) -> pd.DataFrame:
    """Spearman correlation matrix of the five accuracy columns of ``df``.

    Only ``clean_acc``, ``Linf_acc``, ``L2_acc``, ``L1_acc`` and ``common_acc``
    are used; any other column in ``df`` is ignored. The result is a 5x5 frame
    indexed and labelled by those column names, in that order.
    """
    accuracy_columns = ['clean_acc', 'Linf_acc', 'L2_acc', 'L1_acc', 'common_acc']
    accuracy_frame = df.loc[:, accuracy_columns]
    return accuracy_frame.corr(method='spearman')

def mac_by_group_and_dataset(df: pd.DataFrame,
                             group_col: str,
                             dataset_col: str = "dataset") -> pd.DataFrame:
    """
    Mean absolute correlation (MAC) per group and per dataset.

    The result has one row for each distinct value of *group_col* (row label is
    the value converted to ``str``) and one column for each distinct value of
    *dataset_col*, sorted, followed by an "Average MAC" column. All values are
    rounded to 3 decimals.

    A (group, dataset) cell is NaN when that combination has fewer than two
    rows. "Average MAC" is the NaN-ignoring mean of the row's dataset cells.
    """
    datasets = sorted(df[dataset_col].unique())
    table_rows = []

    for group_key, group_frame in df.groupby(group_col):
        per_dataset = {}
        for dataset_name in datasets:
            subset = group_frame.loc[group_frame[dataset_col] == dataset_name]
            if len(subset) >= 2:
                per_dataset[dataset_name] = mean_absolute_correlation(
                    compute_spearman_corr(subset)
                )
            else:
                # Fewer than two rows: no correlation matrix can be formed.
                per_dataset[dataset_name] = np.nan
        # Computed before the key is inserted, so it averages dataset cells only.
        row_average = np.nanmean(list(per_dataset.values()))
        per_dataset["Average MAC"] = row_average
        table_rows.append(pd.Series(per_dataset, name=str(group_key)))

    column_order = [*datasets, "Average MAC"]
    return pd.DataFrame(table_rows)[column_order].round(3)

# ------------------------- compute correlations -------------------------
# Restrict the analysis to the 50-epoch full-fine-tuning runs.
fft_50_df = df[df["ft_strategy"] == "FFT (50 epochs)"]

# ---------- global (un-conditioned) row ----------
dataset_list = sorted(fft_50_df["dataset"].unique())

# MAC for each dataset, using every row of that dataset regardless of loss
# function, model size, model type or pre-training strategy.
global_per_ds = {
    ds: mean_absolute_correlation(
        compute_spearman_corr(fft_50_df[fft_50_df["dataset"] == ds])
    )
    for ds in dataset_list
}

# Average of the per-dataset MACs. np.nanmean (rather than np.mean) so that a
# single undefined per-dataset MAC does not turn the whole average into NaN;
# this is the same rule mac_by_group_and_dataset uses for its "Average MAC".
global_per_ds["Average MAC"] = np.nanmean(list(global_per_ds.values()))

global_row = pd.DataFrame(global_per_ds, index=["Global (Spearman)"]).round(3)

# ---------- category-conditioned tables ----------
loss_table   = mac_by_group_and_dataset(fft_50_df, "loss_function")
size_table   = mac_by_group_and_dataset(fft_50_df, "model_size")
type_table   = mac_by_group_and_dataset(fft_50_df, "model_type")
pretr_table  = mac_by_group_and_dataset(fft_50_df, "pre_training_strategy")

# ------------------------- combine & export -------------------------
# Every piece shares the same columns: one per dataset plus "Average MAC".
all_tables = pd.concat([
    global_row,                         # 1 row: per-dataset MACs + "Average MAC"
    loss_table,                         # one row per loss function
    size_table,                         # one row per model size
    type_table,                         # one row per model type
    pretr_table                         # one row per pre-training strategy
]).round(3)

# Render the combined table to LaTeX, then keep only the tabular environment so
# it can be wrapped in the custom \resizebox layout below.
# NOTE(review): escape=False writes row/column labels verbatim; labels that
# contain LaTeX specials such as "_" would need manual escaping — confirm.
latex_core = all_tables.to_latex(index=True, escape=False, header=True)

# Split once and reuse the list for both lookups and the final slice.
latex_lines = latex_core.splitlines()
tabular_start = next(i for i, line in enumerate(latex_lines) if line.startswith(r"\begin{tabular}"))
tabular_end   = next(i for i, line in enumerate(latex_lines) if line.startswith(r"\end{tabular}")) + 1
tabular_content = "\n".join(latex_lines[tabular_start:tabular_end])

corrected_latex = rf"""
\centering
\label{{tab:spearman_corr}}
\resizebox{{0.45\textwidth}}{{!}}{{%
{tabular_content}
}}
"""

# Create the output directory if needed so the cell also works on a fresh
# checkout where ./latex_tables does not exist yet.
output_path = Path("./latex_tables/correlation_matrix.tex")
output_path.parent.mkdir(parents=True, exist_ok=True)
with open(output_path, "w") as f:
    f.write(corrected_latex)
