In [None]:
import sys
from itertools import combinations
from pathlib import Path

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import json

from constants import (
    anchor_name_mapping, 
    available_data, 
    exclude_models, 
    exclude_models_w_mae, 
    cat_name_mapping, 
    ds_info_file, 
    model_config_file,
    fontsizes
)
from helper import load_model_configs_and_allowed_models, load_similarity_matrices, save_or_show, get_fmt_name, load_ds_info

sys.path.append('..')
from scripts.helper import parse_datasets
from constants import sim_metric_name_mapping

In [None]:
# Root directory holding the similarity matrices, the aggregated results and the plots.
BASE_DIR = Path('/home/space/diverse_priors')

base_path_similarity_matrices = BASE_DIR / 'model_similarities'
sim_metrics = [
    'cka_kernel_rbf_unbiased_sigma_0.4',
    'cka_kernel_linear_unbiased',
]
sim_metrics_mapped = [sim_metric_name_mapping[k] for k in sim_metrics]

# parse_datasets may yield names containing '/'; normalise them to '_' (the form used
# for dataset names elsewhere in this notebook, e.g. 'wds_vtab_flowers').
ds_list = [ds.replace('/', '_') for ds in parse_datasets('../scripts/webdatasets_w_insub10k.txt')]

ds_info = load_ds_info(ds_info_file)

# Inches per centimetre (≈ 1 / 2.54): matplotlib figsize is given in inches, so
# `k * cm` expresses a size of k centimetres.
cm = 0.393701

# Raw model ids (keys of anchor_name_mapping) used as anchors.
anchors = [
    'OpenCLIP_RN50_openai',
    'simclr-rn50',
    'resnet50',
    'OpenCLIP_ViT-L-14_openai',
    'dinov2-vit-large-p14',
    'vit_large_patch16_224',
]

# Index into constants.available_data selecting the aggregated r-coefficient file to plot.
curr_data = available_data[3]
curr_data_wo_ext = curr_data.split('.')[0]
agg_data_path = BASE_DIR / 'results' / 'aggregated' / 'r_coeff_dist' / curr_data

# Sanity check: the aggregated file can be read; report its size.
r_coeff_df_preview = pd.read_csv(agg_data_path)
print(curr_data, r_coeff_df_preview.shape)

SAVE = True
storing_path = BASE_DIR / 'results' / 'plots' / 'scatter_similarity_v2' / curr_data_wo_ext
if SAVE:
    storing_path.mkdir(parents=True, exist_ok=True)

In [None]:
# Pick the model exclusion list matching the data file: file names containing 'mae'
# use `exclude_models_w_mae`, all others use `exclude_models`.
if 'mae' in curr_data:
    curr_excl_models = exclude_models_w_mae
else:
    curr_excl_models = exclude_models

# Load the per-model configuration table and the list of models kept for the analysis.
model_configs, allowed_models = load_model_configs_and_allowed_models(
    path=model_config_file,
    exclude_alignment=True,
    exclude_models=curr_excl_models,
)

In [None]:
# Load the similarity matrices for all datasets/metrics and re-key the outer dict from the
# raw metric identifiers to their display names (e.g. as used in the plotting loop).
sim_mats = {
    sim_metric_name_mapping[metric_key]: per_metric_mats
    for metric_key, per_metric_mats in load_similarity_matrices(
        path=base_path_similarity_matrices,
        ds_list=ds_list,
        sim_metrics=sim_metrics,
        allowed_models=allowed_models,
    ).items()
}

In [None]:
# Named groups of datasets that are plotted together; one group is selected below.
ds_lists = {
    'ds_row_1_v2': ['imagenet-subset-10k', 'wds_vtab_flowers', 'wds_vtab_pcam'],
}
curr_ds_list = ds_lists['ds_row_1_v2']
curr_ds_list

In [None]:
# Reverse lookup: display name -> raw model id.
anchor_named_2_orig = dict(zip(anchor_name_mapping.values(), anchor_name_mapping.keys()))
# Display names of all known anchors, and of the anchors selected for this notebook.
anchor_nm_val_list = [*anchor_name_mapping.values()]
anchor_nm_val_list_v2 = [anchor_name_mapping[model_id] for model_id in anchors]

# Column names of the aggregated r-coefficient table.
anchor_col = 'Anchor Model'
sim_metric_col = 'Similarity metric'
comp_cat_col = 'Comparison category'
comp_cat_orig_col = 'Comparison category (orig. name)'
comp_val_col = 'Comparison values'


In [None]:
from matplotlib.lines import Line2D

# Scatter marker per anchor display name.
_ANCHOR_MARKERS = {
    'OpenCLIP RN50': 'o',  # Circle
    'OpenCLIP ViT-L': 's',  # Square
    'ResNet-50': '^',  # Triangle Up
    'ViT-L': 'v',  # Triangle Down
    'SimCLR RN50': 'D',  # Diamond
    'DINO ViT-B': '*',  # Star
    'DINOv2 ViT-L': 'p',  # Pentagon
    'MAE ViT-L': 'H',  # Hexagon
}


def _anchor_similarity_row(sim_mat, anchor):
    """Return row ``anchor`` of a model-by-model similarity DataFrame without its self-entry."""
    return sim_mat.loc[anchor].drop(index=anchor)


def _pair_mask(df, ds1, ds2):
    """Boolean mask of rows whose ('DS 1', 'DS 2') columns hold ``ds1`` and ``ds2`` in either order."""
    return (
        ((df['DS 1'] == ds1) & (df['DS 2'] == ds2))
        | ((df['DS 1'] == ds2) & (df['DS 2'] == ds1))
    )


def get_scatter_for_pp_r_df(r_df, sim_metric, with_reg_line=False, ds_list=None):
    """Scatter an anchor's similarities to all other models on one dataset vs. another.

    Grid layout: one row per comparison category in ``r_df`` (in groupby order) and one
    column per unordered pair of datasets from ``ds_list``. Each point is a model; x is
    its similarity to the anchor on the first dataset of the pair, y on the second.
    Points are coloured by the model's value for the row's category (looked up in the
    global ``model_configs``). The legend lists the ``'r coeff'`` values from ``r_df``.

    Args:
        r_df: Aggregated r-coefficient rows. Must contain the columns ``anchor_col``,
            ``comp_cat_col``, ``comp_cat_orig_col``, ``comp_val_col``, ``'DS 1'``,
            ``'DS 2'`` and ``'r coeff'``. Anchor names must be display names (keys of
            ``anchor_named_2_orig`` and ``_ANCHOR_MARKERS``).
        sim_metric: Display name of the similarity metric; key into ``sim_mats``.
        with_reg_line: If True, draw one seaborn regression line per category value;
            otherwise draw a plain scatter.
        ds_list: Datasets to pair up. Defaults to the module-level ``curr_ds_list``.

    Returns:
        The matplotlib Figure.

    Raises:
        ValueError: If ``r_df`` is empty or ``ds_list`` has fewer than two datasets.

    Also reads the module-level globals ``sim_mats``, ``model_configs``, ``ds_info``,
    ``anchor_named_2_orig``, ``cm``, ``fontsizes`` and ``cat_name_mapping``.
    """
    if ds_list is None:
        ds_list = curr_ds_list
    ds_pairs = list(combinations(ds_list, 2))

    if r_df.empty:
        raise ValueError("r_df is empty; nothing to plot")
    if not ds_pairs:
        raise ValueError("ds_list must contain at least two datasets")

    nrows = r_df[comp_cat_col].nunique()
    # One column per dataset *pair*; len(ds_list) only matches this for 3 datasets.
    ncols = len(ds_pairs)

    # squeeze=False keeps `axes` 2-D even for a single row or column.
    fig, axes = plt.subplots(
        nrows, ncols,
        figsize=(ncols * 10 * cm, nrows * 9 * cm),
        sharex=True, sharey=True,
        squeeze=False,
    )

    sim_mat_metric = sim_mats[sim_metric]

    for i, (comp_cat, data) in enumerate(r_df.groupby(comp_cat_col)):
        for j, (ds1, ds2) in enumerate(ds_pairs):
            ax = axes[i, j]
            subset_data = data[_pair_mask(data, ds1, ds2)]
            if subset_data.empty:
                # No r-coefficients for this pair in this category: leave the panel blank.
                ax.set_axis_off()
                continue

            n_anchors = subset_data[anchor_col].nunique()
            # Name of the model_configs column holding this category's per-model values.
            model_cat_col = subset_data[comp_cat_orig_col].unique()[0]

            # One colour per comparison value. The tab10 palette is rotated by one, so
            # the first value gets the palette's last colour.
            cat_vals = subset_data[comp_val_col].unique()
            palette = sns.color_palette('tab10', len(cat_vals)).as_hex()
            palette = palette[-1:] + palette[:-1]
            color_of = dict(zip(cat_vals, palette))

            handles, labels = [], []
            for anchor, anch_data in subset_data.groupby(anchor_col):
                anchor_id = anchor_named_2_orig[anchor]
                marker = _ANCHOR_MARKERS[anchor]

                sims_ds1 = _anchor_similarity_row(sim_mat_metric[ds1], anchor_id)
                # Align to ds1's model order: x/y are consumed positionally by the
                # plotting calls, so differing orders would pair the wrong models.
                sims_ds2 = _anchor_similarity_row(sim_mat_metric[ds2], anchor_id).reindex(sims_ds1.index)

                cat_values = np.array([model_configs.loc[m, model_cat_col] for m in sims_ds1.index])

                if with_reg_line:
                    for val in np.unique(cat_values):
                        idxs = cat_values == val
                        sns.regplot(
                            x=sims_ds1[idxs],
                            y=sims_ds2[idxs],
                            marker=marker,
                            ax=ax,
                            color=color_of[val],
                            line_kws=dict(alpha=0.75, ls='--', lw=2),
                            scatter_kws=dict(alpha=0.5),
                            ci=None,
                        )
                else:
                    # A single coloured scatter: a second, uncoloured call would overdraw
                    # every point in the default colour cycle.
                    ax.scatter(
                        x=sims_ds1, y=sims_ds2,
                        c=[color_of[c] for c in cat_values],
                        marker=marker, alpha=0.5,
                    )

                # One legend entry per comparison value; 'All' is shown as a marker-less entry.
                for _, row in anch_data[[comp_val_col, 'r coeff']].iterrows():
                    cur_marker = None if row[comp_val_col] == 'All' else marker
                    handles.append(
                        Line2D([0], [0],
                               color=color_of[row[comp_val_col]],
                               marker=cur_marker, linestyle='None',
                               markersize=7, alpha=0.5))
                    labels.append(f"r {cat_name_mapping[row[comp_val_col]]}: {row['r coeff']:.2f}")

            ax.legend(handles=handles, labels=labels, fontsize=fontsizes['ticks'], framealpha=0.5, frameon=True,
                      title='', loc='lower right', ncols=2 if n_anchors > 1 else 1)

            ax.tick_params(axis='both', which='major', labelsize=fontsizes['ticks'])

            ax.set_xlabel(get_fmt_name(ds1, ds_info), fontsize=fontsizes['label'])
            if j == 0:
                ax.set_ylabel(f"{comp_cat}\n{get_fmt_name(ds2, ds_info)}", fontsize=fontsizes['label'])
            else:
                ax.set_ylabel(get_fmt_name(ds2, ds_info), fontsize=fontsizes['label'])

    anchor_names = ", ".join(map(str, r_df[anchor_col].unique()))
    fig.suptitle(f"{anchor_names} – {sim_metric}", fontsize=fontsizes['title'])
    fig.tight_layout()
    return fig

In [None]:
# One figure per (similarity metric, anchor subset). Only single anchors are plotted;
# add `list(combinations(anchors, 2))` to also plot anchor pairs.
anchor_combinations = list(combinations(anchors, 1))
SIM_METRICS_TO_PLOT = ['CKA RBF 0.4', 'CKA linear']
COMP_CATS_TO_PLOT = ['Objective', 'Dataset diversity']
DRAW_REG_OPTIONS = [True]  # use [True, False] to also get plots without regression lines

# Loop-invariant filtering: read the aggregated table once and keep only the selected
# datasets and comparison categories.
r_all = pd.read_csv(agg_data_path)
r_base = r_all[
    r_all['DS 1'].isin(curr_ds_list)
    & r_all['DS 2'].isin(curr_ds_list)
    & r_all[comp_cat_col].isin(COMP_CATS_TO_PLOT)
]
ds_suf = "_".join(curr_ds_list)

for sim_metric in SIM_METRICS_TO_PLOT:
    sm_name = sim_metric.replace(" ", "_").lower()
    for curr_anchors in anchor_combinations:
        r_df = r_base[
            r_base[anchor_col].isin(curr_anchors) & (r_base[sim_metric_col] == sim_metric)
        ].copy().reset_index(drop=True)
        # Filter on raw model ids first, then switch to display names for plotting.
        r_df[anchor_col] = r_df[anchor_col].map(anchor_name_mapping)

        if r_df.empty:
            print(f"No rows for metric {sim_metric!r} and anchors {curr_anchors}; skipping.")
            continue

        model_suf = "_".join(curr_anchors)
        for draw_reg in DRAW_REG_OPTIONS:
            fig = get_scatter_for_pp_r_df(r_df, sim_metric, with_reg_line=draw_reg)

            # NOTE(review): '_wih_reg' looks like a typo for '_with_reg'; left unchanged so
            # output file names stay consistent with previously saved plots.
            reg_suf = "_wih_reg" if draw_reg else ""
            save_or_show(fig,
                         storing_path / f'scatter{reg_suf}_SM{sm_name}_M{model_suf}_DS{ds_suf}.pdf',
                         SAVE)
            # Release the figure; otherwise every figure of the loop stays open.
            plt.close(fig)