In [None]:
import ast
import sys
from pathlib import Path

import matplotlib.pyplot as plt
import pandas as pd
import seaborn as sns
import torch

from constants import sim_metric_name_mapping, similarity_metrics
from helper import get_model_ids

sys.path.append('..')
from scripts.helper import load_models

#### Global variables

In [None]:
## DATASET AND MODEL CONFIG
# `datasets` and `combiner` are not referenced by the cells shown in this notebook.
datasets = "../scripts/webdatasets_wo_imagenet.txt"
# JSON config passed to `load_models`; its keys define which models are kept in the similarity matrix.
model_config = "../scripts/filtered_models_config.json"
# The anchor selects the results file (`anchor_<name>.pkl`) and the similarity-matrix column
# every other model is compared against.
anchor_model = "OpenCLIP_ViT-L-14_openai"  # ANCHOR MODEL 1
# anchor_model = "resnet50" # ANCHOR MODEL 2
combiner = 'concat'

## SIMILARITY METRICS
# Second entry of `similarity_metrics` (imported from `constants`); it is used as a
# sub-directory name below and as key into `sim_metric_name_mapping` for plot titles.
sim_metric = similarity_metrics[1]

### IMAGENET SUBSET SIMILARITIES
# Layout: <similarities root> / <base_subset> / <sim_metric> / {model_ids.txt, similarity_matrix.pt}
base_subset = 'imagenet-subset-10k'
model_similarities_base_path = Path('/home/space/diverse_priors/model_similarities') / base_subset
model_similarities_path = model_similarities_base_path / sim_metric

### AGGREGATED RESULTS --> PRODUCED BY gather_anchor_exp_results.ipynb
base_path_aggregated_results = Path('/home/space/diverse_priors/results/aggregated')

### SINGLE MODEL BEST PERFORMANCES --> structure path / [L1, L2, weight_decay] / [DATASET].json
# NOTE(review): not referenced by the cells shown in this notebook.
single_model_best_perf_path = Path('/home/space/diverse_priors/results/aggregated/max_performance_per_model_n_ds')

#### Storing information

In [None]:
# Toggle: write the figures to disk (True) or only display them (False).
SAVE = True

# Output directory: one sub-folder per (similarity subset, anchor model, similarity metric).
base_storing_path = Path('/home/space/diverse_priors/results/plots/performance_gap_similarity_value')
storing_path = base_storing_path / f"{base_subset.replace('-', '_')}__{anchor_model}__{sim_metric}"

# Only touch the file system when figures are actually going to be saved.
if SAVE:
    storing_path.mkdir(parents=True, exist_ok=True)

#### Load similarity values

In [None]:
# Files written next to each other for the chosen subset / metric.
model_ids_fn = model_similarities_path / 'model_ids.txt'
sim_mat_fn = model_similarities_path / 'similarity_matrix.pt'

# Label rows and columns of the stored matrix with the model ids so that
# pairs can be looked up by name.
# NOTE(review): torch.load unpickles the file; only load matrices from a trusted location.
model_ids = get_model_ids(model_ids_fn)
sim_mat = pd.DataFrame(torch.load(sim_mat_fn), index=model_ids, columns=model_ids)

# Restrict the matrix to the models listed in the model config (sorted by id).
models, nmodels = load_models(model_config)
allowed_models = sorted(models.keys())
sim_mat = sim_mat.loc[allowed_models, allowed_models]
print(f"{sim_mat.shape=}")

#### Load experiment results

In [None]:
# Aggregated experiment results; the file name is keyed by the anchor model.
# NOTE(review): read_pickle can execute arbitrary code while loading, so the
# file must come from a trusted source.
df = pd.read_pickle(base_path_aggregated_results / f'anchor_{anchor_model}.pkl')

In [None]:
# Columns that identify one experimental setting; results are averaged
# over all rows that agree on every one of them.
HYPER_PARAM_COLS = [
    'task',
    'mode',
    'combiner',
    'dataset',
    'model_ids',
    'fewshot_k',
    'fewshot_epochs',
    'batch_size',
    'regularization',
]

In [None]:
def _to_model_id_tuple(value):
    """Turn one `model_ids` entry into a tuple of model ids.

    Strings are parsed as Python literals; anything else is assumed to be an
    already-parsed sequence. This makes the cell safe to re-run: a second
    execution sees tuples instead of strings and leaves them unchanged.
    """
    if isinstance(value, str):
        # literal_eval only accepts literals (lists, tuples, strings, ...), so a
        # malformed or malicious entry raises instead of executing code as eval would.
        value = ast.literal_eval(value)
    return tuple(value)


# Tuples are hashable, which the later groupby on `model_ids` needs.
df['model_ids'] = df['model_ids'].apply(_to_model_id_tuple)
# Dataset names are normalised by replacing '/' with '_'.
df['dataset'] = df['dataset'].apply(lambda x: x.replace('/', '_'))

In [None]:
# Mean test accuracy over all rows that share the same hyper-parameter setting.
# dropna=False keeps settings whose key columns contain NaN as their own group.
mean_df = (
    df.groupby(HYPER_PARAM_COLS, dropna=False)['test_lp_acc1']
    .mean()
    .reset_index()
)

#### Prepare data for plotting
Steps:
1. Compute performance gap between combined model (concat or ensemble) and single model for each dataset.
2. Add the similarity value for each pair of models.

In [None]:
# Split the averaged results by evaluation mode. Each part is an independent
# copy with a fresh 0..n-1 index.
# NOTE(review): the single-model part is later treated as the anchor's own
# performance -- confirm the results file holds no other single-model rows.
single_performance, concat_performance, ensemble_performance = (
    mean_df[mean_df['mode'] == mode_name].copy().reset_index(drop=True)
    for mode_name in ('single_model', 'combined_models', 'ensemble')
)
print(f"{single_performance.shape=}, {concat_performance.shape=}, {ensemble_performance.shape=}")

In [None]:
def _partner_of_anchor(pair):
    """Pick the non-anchor entry from the first two ids of `pair`.

    Returns the first id when the second one is the anchor model, otherwise
    the second id.
    """
    first, second = pair[0], pair[1]
    if second == anchor_model:
        return first
    return second


# Name of the model that was combined with the anchor, for both combination modes.
concat_performance['other_model'] = concat_performance['model_ids'].apply(_partner_of_anchor)
ensemble_performance['other_model'] = ensemble_performance['model_ids'].apply(_partner_of_anchor)

In [None]:
## Anchor-model accuracy as a lookup table: one row per dataset, one column per regularization.
## pivot_table averages (its default) whenever several rows fall into the same cell.
single_performance_pivot = single_performance.pivot_table(
    values='test_lp_acc1',
    index='dataset',
    columns='regularization',
)
single_performance_pivot

In [None]:
def get_performance_gap_n_sim_metric(row):
    """Compute the accuracy gain over the single model and the pair's similarity.

    Parameters
    ----------
    row : mapping-like (e.g. a DataFrame row)
        Must provide 'other_model', 'test_lp_acc1', 'dataset' and 'regularization'.

    Returns
    -------
    tuple
        ``(gap, sim_val)`` where ``gap`` is the row's accuracy minus the entry of
        ``single_performance_pivot`` for the same dataset and regularization, and
        ``sim_val`` is the ``sim_mat`` entry for (other model, anchor model).
    """
    partner, combined_acc = row['other_model'], row['test_lp_acc1']
    anchor_acc = single_performance_pivot.loc[row['dataset'], row['regularization']]
    return combined_acc - anchor_acc, sim_mat.loc[partner, anchor_model]

In [None]:
def _with_gap_and_similarity(perf_df):
    """Return `perf_df` extended with the columns 'gap' and 'sim_value'.

    Both values are computed per row by `get_performance_gap_n_sim_metric`.
    Columns left over from an earlier execution of this cell are dropped first,
    so re-running the cell never produces duplicate column names. An empty
    input simply yields an empty frame with the two extra columns.
    """
    base = perf_df.drop(columns=['gap', 'sim_value'], errors='ignore')
    pairs = [get_performance_gap_n_sim_metric(row) for _, row in base.iterrows()]
    # Reuse the input's index so the new columns line up row by row,
    # whatever index the input frame happens to carry.
    extra = pd.DataFrame(pairs, columns=['gap', 'sim_value'], index=base.index)
    return pd.concat([base, extra], axis=1)


concat_performance = _with_gap_and_similarity(concat_performance)
ensemble_performance = _with_gap_and_similarity(ensemble_performance)

#### Plot scatter plot and add correlation coefficient 

In [None]:
def plot_scatter(df, title):
    """Scatter the performance gap against the similarity value in a facet grid.

    One panel is drawn per (dataset, regularization) combination. Each panel is
    annotated with the correlation between both quantities, and the regions of
    positive / negative gap are shaded green / red.

    Parameters
    ----------
    df : pandas.DataFrame
        Must contain the columns 'sim_value', 'gap', 'regularization' and 'dataset'.
    title : str
        Title placed above the whole figure.

    Returns
    -------
    matplotlib.figure.Figure
        The figure that holds the facet grid.
    """
    g = sns.relplot(
        df,
        x='sim_value',
        y='gap',
        col='regularization',
        row='dataset',
        height=3,
        aspect=1.25,
        # Every panel gets its own axis limits.
        facet_kws={'sharey': False, 'sharex': False}
    )
    g.set_titles("{row_name} – {col_name}")

    def annotate_correlation(data, **kwargs):
        """Add the correlation label and the gap shading to the current panel."""
        gaps = data['gap']
        r = data['sim_value'].corr(gaps)  # pandas default: Pearson
        # map_dataframe calls this once per facet with that facet's axes active.
        ax = plt.gca()
        ax.text(0.05, 0.95, f'r = {r:.2f}', transform=ax.transAxes,
                fontsize=12, verticalalignment='top')
        # Series.max()/min() skip NaN, unlike the builtins, and are evaluated once.
        highest, lowest = gaps.max(), gaps.min()
        if highest > 0:
            ax.axhspan(0, highest, facecolor='lightgreen', alpha=0.2, zorder=-1)
        if lowest < 0:
            ax.axhspan(lowest, 0, facecolor='lightcoral', alpha=0.2, zorder=-1)

    g.map_dataframe(annotate_correlation)

    # `FacetGrid.figure` replaces the deprecated `FacetGrid.fig`.
    figure = g.figure
    figure.suptitle(title, y=1)
    figure.tight_layout()
    return figure

In [None]:
# Gap vs. similarity for the concatenated (combined) models.
fig = plot_scatter(
    concat_performance,
    f"Combined models (Concat) with anchor {anchor_model} and {sim_metric_name_mapping[sim_metric]} similarity values.",
)
if SAVE:
    fig.savefig(storing_path / 'combined_concat.pdf', bbox_inches='tight')
    # Close after saving so the figure is not rendered inline and its memory is released.
    plt.close(fig)
    print('stored concat img')
else:
    # plt.show() takes no figure argument; it displays every open figure.
    plt.show()

In [None]:
# Gap vs. similarity for the ensembles.
fig = plot_scatter(
    ensemble_performance,
    f"Ensemble with anchor {anchor_model} and {sim_metric_name_mapping[sim_metric]} similarity values.",
)
if SAVE:
    fig.savefig(storing_path / 'ensemble.pdf', bbox_inches='tight')
    # Close after saving so the figure is not rendered inline and its memory is released.
    plt.close(fig)
    print('stored ensemble img')
else:
    # plt.show() takes no figure argument; it displays every open figure.
    plt.show()