In [None]:
import ast
import os
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns

from constants import cat_name_mapping, exclude_models_w_mae

# Global seaborn theme ("ticks") applied to every figure drawn below.
sns.set_style(style='ticks')

In [None]:
from helper import load_model_configs_and_allowed_models

# Model metadata plus the ids of the models admitted to this analysis, read from the
# config JSON with `exclude_models=exclude_models_w_mae` and `exclude_alignment=True`.
model_configs, allowed_models = load_model_configs_and_allowed_models(
    exclude_alignment=True,
    exclude_models=exclude_models_w_mae,
    path='../scripts/models_config_wo_barlowtwins_n_alignment.json',
)

# Display name of each pair column -> column of `model_configs` that holds the
# per-model category key for it (used to look up model categories further below).
orig_cols = {
    'Objective pair': 'objective',
    'Architecture pair': 'architecture_class',
    'Dataset pair': 'dataset_class',
    'Model size pair': 'size_class',
}

In [None]:
# Root directory of the aggregated results. It defaults to the original local checkout.
# Set the DIVERSE_PRIORS_AGGREGATED environment variable to use another location,
# e.g. the cluster copy at /home/space/diverse_priors/results/aggregated.
base_path_aggregated = Path(
    os.environ.get(
        'DIVERSE_PRIORS_AGGREGATED',
        '/Users/lciernik/Documents/TUB/projects/divers_prios/results/aggregated',
    )
)

### Config similarity data
# Model-pair similarities. Columns used below: 'Model 1', 'Model 2',
# 'Similarity metric', 'Similarity value' and the four '... pair' columns.
sim_data = pd.read_csv(base_path_aggregated / 'model_sims/all_metric_ds_model_pair_similarity.csv')

In [None]:
# Keep only the pairs in which both models belong to the allowed set.
sim_data = sim_data.loc[
    lambda frame: frame['Model 1'].isin(allowed_models) & frame['Model 2'].isin(allowed_models)
]

In [None]:
# The four model-pair category columns, in alphabetical order.
pair_columns = sorted((
    'Objective pair',
    'Architecture pair',
    'Dataset pair',
    'Model size pair',
))

In [None]:
# Each pair column arrives from the CSV as the string repr of a 2-tuple of raw
# category keys (presumably like "('a', 'b')" - TODO confirm against the CSV).
# The tuples are parsed with ast.literal_eval instead of eval, so no code contained
# in the data file is ever executed. The loop then derives, via cat_name_mapping:
#   * per-model display labels ("M1 <col>", "M2 <col>");
#   * a readable "A – B" label that replaces the raw column.
for col in pair_columns:
    raw_pairs = sim_data[col].apply(ast.literal_eval)
    sim_data[f"M1 {col}"] = raw_pairs.apply(lambda pair: cat_name_mapping[pair[0]])
    sim_data[f"M2 {col}"] = raw_pairs.apply(lambda pair: cat_name_mapping[pair[1]])
    sim_data[col] = raw_pairs.apply(
        lambda pair: f"{cat_name_mapping[pair[0]]} – {cat_name_mapping[pair[1]]}"
    )

In [None]:
# Coerce every pair column to plain Python strings. The previous cell already
# writes f-string labels into these columns, so this acts as a safeguard.
for col in pair_columns:
    as_text = sim_data[col].map(str)
    sim_data[col] = as_text

In [None]:
from matplotlib import ticker


def get_box_plt_sim_distributions(all_data, curr_pair_columns):
    """Plot within-category similarity distributions as a grid of boxen plots.

    The grid has one row per distinct value of ``all_data['Similarity metric']``
    (in order of first appearance) and one column per entry of
    ``curr_pair_columns``. Each panel keeps only the rows whose two models fall
    into the same category for that column (``M1 <col> == M2 <col>``). It then
    draws the ``'Similarity value'`` distribution for each ``<col>`` label.

    Args:
        all_data: DataFrame with the columns 'Similarity metric' and
            'Similarity value'. For every ``col`` in ``curr_pair_columns`` it must
            also contain ``col``, ``f"M1 {col}"`` and ``f"M2 {col}"``.
        curr_pair_columns: Pair columns to plot, one subplot column each.

    Returns:
        The created matplotlib Figure.

    Raises:
        ValueError: If ``all_data`` contains no similarity metric or
            ``curr_pair_columns`` is empty (there is nothing to plot).
    """
    # Use the same sequence for the row count and the row loop, so that they
    # always agree (nunique() would drop NaN while unique() keeps it).
    metrics = all_data['Similarity metric'].unique()
    n_rows, n_cols = len(metrics), len(curr_pair_columns)
    if n_rows == 0 or n_cols == 0:
        raise ValueError(
            f"Nothing to plot: {n_rows} similarity metric(s) and {n_cols} pair column(s)."
        )

    cm = 0.393701  # inches per centimetre; each panel is 10 cm x 6 cm
    # squeeze=False keeps `axes` 2-D even with a single row or column, so that
    # axes[i, j] works for one metric or one pair column as well.
    fig, axes = plt.subplots(
        n_rows, n_cols,
        figsize=(10 * cm * n_cols, 6 * cm * n_rows),
        sharey=True, sharex='col', squeeze=False,
    )
    for i, metric in enumerate(metrics):
        metric_data = all_data[all_data['Similarity metric'] == metric]
        for j, col in enumerate(curr_pair_columns):
            ax = axes[i, j]

            # Only pairs of models that share the same category in this column.
            data = metric_data[metric_data[f"M1 {col}"] == metric_data[f"M2 {col}"]]
            # Sort by label so the categories are drawn in a stable order.
            data = data.sort_values(by=col)
            sns.boxenplot(
                data=data,
                x=col,
                y='Similarity value',
                ax=ax,
                hue=col,
                palette='tab10',
            )
            ax.tick_params(axis='x', which='major', rotation=90)  # vertical x tick labels
            ax.yaxis.set_major_formatter(ticker.FormatStrFormatter('%.2f'))

            # Column titles only on the top row; metric names only on the first column.
            ax.set_title(col if i == 0 else '')
            ax.set_xlabel('')
            ax.set_ylabel(f'{metric}\nSimilarity value' if j == 0 else '')

    fig.subplots_adjust(wspace=0.1, hspace=0.1)
    return fig


# Overview: within-category similarity distributions for all pair columns.
fig = get_box_plt_sim_distributions(
    all_data=sim_data,
    curr_pair_columns=pair_columns,
)

In [None]:
# Category restrictions to explore. Each setting fixes one or more pair columns
# ('fix_cols') to the value at the same position in 'fix_vals'.
settings_to_fix_cat = [
    {'fix_cols': list(cols), 'fix_vals': list(vals)}
    for cols, vals in (
        # single fixed category
        (('Dataset pair',), ('IN1k',)),
        (('Dataset pair',), ('XLarge DS',)),
        (('Architecture pair',), ('TX',)),
        (('Architecture pair',), ('CNN',)),
        (('Objective pair',), ('SSL',)),
        (('Objective pair',), ('Sup',)),
        (('Objective pair',), ('Img-Txt',)),
        (('Model size pair',), ('small',)),
        (('Model size pair',), ('medium',)),
        (('Model size pair',), ('large',)),
        # two fixed categories
        (('Dataset pair', 'Objective pair'), ('IN1k', 'Sup')),
        (('Dataset pair', 'Objective pair'), ('IN21k', 'Sup')),
        (('Architecture pair', 'Objective pair'), ('CNN', 'Sup')),
        (('Architecture pair', 'Objective pair'), ('TX', 'Sup')),
        (('Dataset pair', 'Objective pair'), ('IN1k', 'SSL')),
        (('Dataset pair', 'Architecture pair'), ('IN1k', 'CNN')),
        (('Dataset pair', 'Architecture pair'), ('IN1k', 'TX')),
        (('Dataset pair', 'Architecture pair'), ('Large DS', 'TX')),
    )
]

In [None]:
# For each setting, the loop:
#   * keeps only the model pairs in which BOTH models have the fixed category value(s);
#   * prints the models involved, with their categories for the remaining columns;
#   * plots the similarity distributions over the remaining (non-fixed) pair columns.
for setting in settings_to_fix_cat:
    subset_data = sim_data  # filtered below into new frames; never mutated
    curr_pair_cols = list(pair_columns)
    all_fixed_vals = []
    for col_name, fix_value in zip(setting['fix_cols'], setting['fix_vals']):
        both_fixed = (
            (subset_data[f'M1 {col_name}'] == fix_value)
            & (subset_data[f'M2 {col_name}'] == fix_value)
        )
        subset_data = subset_data[both_fixed]
        curr_pair_cols.remove(col_name)
        all_fixed_vals.append(fix_value)
    title = f"Fixed values: {', '.join(all_fixed_vals)}"

    if subset_data.empty:
        # A figure with zero metric rows cannot be built; report and move on
        # instead of crashing the whole loop.
        print(f"{title}: no model pairs match, skipping.")
        continue

    # Every model in the subset, with its (display) categories for the non-fixed columns.
    config_cols = [orig_cols[c] for c in curr_pair_cols]
    model_ids = sorted(set(subset_data['Model 1']) | set(subset_data['Model 2']))
    model_set = [
        (mid, [cat_name_mapping[k] for k in model_configs.loc[mid, config_cols].to_list()])
        for mid in model_ids
    ]
    print("Models:")
    for mid, categories in model_set:
        print(f"Model: {mid}, {categories}")

    fig = get_box_plt_sim_distributions(subset_data, curr_pair_cols)
    fig.suptitle(title)
    # pyplot.show() does not take a figure argument; it displays the open
    # figures. Close this one afterwards so figures don't pile up across the loop.
    plt.show()
    plt.close(fig)