In [1]:
from itertools import chain
from pathlib import Path
from typing import List

import mml.interactive

# initialize the MML environment before any of the mml_tf imports below
mml.interactive.init(Path('~/.config/mml.env').expanduser())

import pandas as pd
import plotly.express as px
from PIL import Image
from rich.progress import track

from mml_tf.representations import FullFeatureRepresentations, AveragedFeatureRepresentations, \
    MeanAndCovarianceRepresentations, BinnedFeatureRepresentations
from mml_tf.distances import EMDDistances, KLDDistances, FIDDistances, MMDDistances, LoadCachedDistances, map_dist2printable
from mml_tf.evaluation import get_evaluations
from mml_tf.aggregate import AggregateStrategy
from mml_tf.visualization import init_colors, get_dist_measure_color
from mml_tf.experiments import EXPERIMENTS
from mml_tf.ranking import BootstrapRanking
from mml_tf.paths import FIG_PATH

 _____ ______   _____ ______   ___
|\   _ \  _   \|\   _ \  _   \|\  \
\ \  \\\__\ \  \ \  \\\__\ \  \ \  \
 \ \  \\|__| \  \ \  \\|__| \  \ \  \
  \ \  \    \ \  \ \  \    \ \  \ \  \____
   \ \__\    \ \__\ \__\    \ \__\ \_______\
    \|__|     \|__|\|__|     \|__|\|_______|
         ____  _  _    __  _  _  ____  _  _
        (  _ \( \/ )  (  )( \/ )/ ___)( \/ )
         ) _ ( )  /    )( / \/ \\___ \ )  /
        (____/(__/    (__)\_)(_/(____/(__/
Interactive MML API initialized.


In [3]:
# Probe networks (feature extractors) whose distances are compared in this notebook.
# NOTE: order matters - Step 1 pairs these names with the representation lists by position.
other_probes = [
    'CLIP',
    'MAE',
    'DINO',
]

## Step 1: Compute distances for the other probe networks

(This step may be skipped if the distances have already been computed and cached.)

In [25]:
# Full feature representations, one per probe network. The order must follow `other_probes`,
# because the computation cell pairs these lists with the probe names by position.
full_reps = [
    FullFeatureRepresentations(folder=folder, probe_network=probe)
    for folder, probe in (('clip_features', 'CLIP'), ('mae_features', 'MAE'), ('dino_features', 'DINO'))
]
clip_rep, mae_rep, dino_rep = full_reps
for rep in full_reps:
    rep.load_representations()

# Derived representations, each built from an already loaded full representation.
avg_reps = [AveragedFeatureRepresentations(full_features=full) for full in full_reps]
few_bins_reps = [BinnedFeatureRepresentations(full_features=full, n_bins=100) for full in full_reps]
lot_bins_reps = [BinnedFeatureRepresentations(full_features=full, n_bins=1000) for full in full_reps]
mean_cov_reps = [MeanAndCovarianceRepresentations(full_features=full) for full in full_reps]
for rep in track([*avg_reps, *few_bins_reps, *lot_bins_reps, *mean_cov_reps]):
    rep.load_representations()

Output()

In [26]:
# Actual computation. Step 2 only reloads results via `LoadCachedDistances`, so constructing the
# distance objects presumably computes and caches the distances - TODO confirm in mml_tf.distances.
# The objects are kept in `computed_distances` (probe name -> {variant: distance}) for inspection.
rep_lists = (full_reps, avg_reps, few_bins_reps, lot_bins_reps, mean_cov_reps)
# The representation lists are paired with `other_probes` by position; fail loudly if they diverge.
assert all(len(reps) == len(other_probes) for reps in rep_lists), 'representation lists do not match other_probes'
computed_distances = {}
for probe_name, full_rep, avg_rep, few_bins_rep, lot_bins_rep, mean_cov_rep in zip(other_probes, *rep_lists):
    # Dict literal entries are evaluated top to bottom, i.e. in the original construction order.
    computed_distances[probe_name] = {
        'kld_small_target': KLDDistances(representations=few_bins_rep, source_pp='soft', target_pp='norm',
                                         invert=False,
                                         weighing_by='target', weights_rep=avg_rep, weights_pp='soft',
                                         clip=False),
        'kld_large_source': KLDDistances(representations=lot_bins_rep, source_pp='soft', target_pp='norm',
                                         invert=False,
                                         weighing_by='source', weights_rep=avg_rep, weights_pp='norm',
                                         clip=False),
        'kld_large_unweighted': KLDDistances(representations=lot_bins_rep, source_pp='soft', target_pp='norm',
                                             invert=False,
                                             weights_rep=None, clip=False),
        'vdna': EMDDistances(representations=lot_bins_rep, soft_features=False),
        'p2l': KLDDistances(representations=avg_rep, source_pp='norm', target_pp='norm'),
        'fid': FIDDistances(representations=mean_cov_rep),
        'mmd': MMDDistances(representations=full_rep, kernel='geo-sinkhorn', blur=0.01),
    }

## Step 2: Conduct evaluation

In [4]:
# Distance measures to compare. Each one is loaded per probe network under the name '<name>-<probe>'.
# ('KLD-PP:NS-W:SN-1000-BINS' is deliberately left out.)
cached_distance_names = [
    'KLD-PP:NS-W:TS-100-BINS',
    'KLD-PP:NS-1000-BINS',
    'MMD-geo-sinkhorn-0.01',
    'FID',
    'KLD-PP:NN',
    'VDNA-PP:NN-1000-BINS',
]
probe_variants = {}
for probe_name in other_probes:
    probe_variants[probe_name] = [LoadCachedDistances(f'{name}-{probe_name}') for name in cached_distance_names]

In [5]:
# All probe-specific distances (grouped by probe, in `other_probes` order), followed by the
# probe-independent semantic distance.
all_distances = [dist for probe_name in other_probes for dist in probe_variants[probe_name]]
all_distances.append(LoadCachedDistances('SEMANTIC'))

# Evaluate every distance measure; see mml_tf.evaluation.get_evaluations for the exact
# meaning of the aggregation and top-k options.
base_evals = get_evaluations(
    all_distances=all_distances,
    aggregates=[AggregateStrategy.FIRST, AggregateStrategy.SECOND, AggregateStrategy.THIRD],
    top_meta_metrics=['regret', 'rank', 'delta'],
    top_mode='avg',
    top_k=3,
)

Calculating...: 100%|██████████| 19608/19608 [01:05<00:00, 299.61it/s]


## Step 3: Perform bootstrapping

In [6]:
# Settings, identical to the ones used for Figure 5.

def _strip_probe_tag(dist_name: str) -> str:
    """Remove every '-<probe>' tag of a probe in `other_probes` from `dist_name` (e.g. 'FID-CLIP' -> 'FID')."""
    for probe in other_probes:
        dist_name = dist_name.replace(f'-{probe}', '')
    return dist_name


# Probe-specific distance name -> probe-agnostic distance name, derived from `other_probes`
# so that adding a probe network needs no change here.
map_probe_dist_to_orig = {d.name: _strip_probe_tag(d.name) for d in all_distances}

# Printable selector names in legend order. 'FED' is not loaded in this notebook; it is presumably
# kept so colors match Figure 5. 'KLD-PP:NS-W:SN-1000-BINS' is excluded, as in the cached
# distances loaded in Step 2.
plot_order = [map_dist2printable[d] for d in
              ['SEMANTIC', 'VDNA-PP:NN-1000-BINS', 'FED', 'FID', 'KLD-PP:NN', 'MMD-geo-sinkhorn-0.01',
               'KLD-PP:NS-W:TS-100-BINS', 'KLD-PP:NS-1000-BINS']]
init_colors(exp=EXPERIMENTS, distance_measures=plot_order)
color_map = {dist: get_dist_measure_color(dist) for dist in plot_order}

# All selectors get circles / solid lines, except the 1000-bin KLD variant without a 'W:' tag,
# which is highlighted with a hexagram / dotted line.
kld_1000_bins = map_dist2printable['KLD-PP:NS-1000-BINS']
symbol_map = {d: 'hexagram' if d == kld_1000_bins else 'circle' for d in plot_order}
dash_map = {d: 'dot' if d == kld_1000_bins else 'solid' for d in plot_order}

In [7]:
def get_probe_exp_meta_statistics(
        distances: List[str],
        meta_metric: str,
        exp: str,
) -> pd.DataFrame:
    """Bootstrap-rank a set of distance measures on one experiment and one meta metric.

    Selects the rows of the global ``base_evals`` for the given distances, meta metric and
    experiment. It renames the probe-specific distance names to their printable, probe-agnostic
    form and ranks them with ``BootstrapRanking`` (``use_median=False``).

    :param distances: names of the distance measures (values of ``base_evals['distances']``) to compare
    :param meta_metric: meta metric to rank by, e.g. ``'regret'``
    :param exp: experiment to restrict the evaluations to
    :return: the ``statistics`` frame of the bootstrap ranking (Step 4 reads its ``algorithm``
        and ``mean_rank`` columns)
    """
    sub_evals = base_evals[
        base_evals['distances'].isin(distances)
        & (base_evals['meta metric'] == meta_metric)
        & (base_evals['exp'] == exp)
    ]
    # One ranking "case" per (seed, metric, target, exp) combination. The separator keeps the
    # key unambiguous ('ab' + 'c' vs 'a' + 'bc'), and astype(str) also tolerates numeric seeds.
    case_ids = sub_evals['seed'].astype(str).str.cat(
        [sub_evals[col].astype(str) for col in ('metric', 'target', 'exp')], sep='|')
    # Probe-specific name (e.g. 'FID-CLIP') -> original distance name -> printable name, so that all
    # probes share the same algorithm labels. Only the algorithm column is renamed.
    algorithms = sub_evals['distances'].replace(map_probe_dist_to_orig).replace(map_dist2printable)
    rr_df = pd.DataFrame({
        'case': case_ids,
        'task': 'dummy',  # everything is ranked as one single task
        'algorithm': algorithms,
        'value': sub_evals['score'],
    })
    bsr = BootstrapRanking(data=rr_df, use_median=False)
    return bsr.statistics

In [8]:
# Meta metrics from the paper. 'gain' is left out, because it is a less granular version of the
# improvement metric.
_meta_metrics = ['regret', 'rank', 'delta', 'weightedtau']
# Internal meta metric names -> names used in the paper.
mm_display_map = {'regret': 'Regret', 'rank': 'Percentile', 'delta': 'Improvement', 'weightedtau': 'Weightedtau'}

# Distances ranked against each other per probe: the probe's own distances plus the semantic one.
compared_by_probe = {probe: [d.name for d in variants] + ['SEMANTIC'] for probe, variants in probe_variants.items()}
# One bootstrap ranking per (experiment, meta metric, probe network).
statistics_collector = []
for exp in EXPERIMENTS:
    for meta_metric in _meta_metrics:
        for probe, comp_distances in compared_by_probe.items():
            stats = get_probe_exp_meta_statistics(distances=comp_distances, meta_metric=meta_metric, exp=exp)
            statistics_collector.append(stats)

## Step 4: Merge results across backbones and plot

In [9]:
# One row per bootstrap ranking (experiment x meta metric x probe network), one column per task
# selector, holding that ranking's mean rank.
stats_df = pd.DataFrame([frame.set_index('algorithm')['mean_rank'] for frame in statistics_collector])
# Mean and (sample) standard deviation of the mean rank across all rankings, i.e. across backbones.
# The index name is set explicitly because the plot uses the 'algorithm' column.
plot_df = (
    stats_df.agg(['mean', 'std'])
    .T
    .rename_axis('algorithm')
    .sort_values(by='mean', ascending=False)
    .reset_index()
)

In [11]:
# Mean rank (error bars: std across experiments, meta metrics and probe networks) per task selector.
# Derive the number of selectors from the plotted data instead of a loop variable leaked from an
# earlier cell.
n_selectors = len(plot_df)
fig = px.scatter(plot_df, x='algorithm', y='mean', error_y='std', color_discrete_map=color_map, color='algorithm',
                 symbol='algorithm', symbol_map=symbol_map, size=[7] * n_selectors,
                 labels={'algorithm': 'Task selector', 'mean': 'Mean rank'})
fig.update_layout(template='plotly', font_size=20, width=1200, height=500,
                  showlegend=True,  # the legend identifies the selectors, as x tick labels are hidden
                  margin={'l': 150})
# Reverse the y axis so that rank 1 is at the top, with one tick per possible rank.
fig.update_yaxes(tickvals=list(range(1, n_selectors + 1)), autorange='reversed')
# Hide the x-axis tick labels (and their grid lines); color and symbol identify the selectors.
fig.update_xaxes(showticklabels=False, showgrid=False)
# Order the legend entries as in `plot_order`.
for trace in fig.data:
    if trace.showlegend:
        trace.update(legendrank=1000 + plot_order.index(trace.name))
# Panel label "a" above the top-left corner.
fig.add_annotation(
    showarrow=False,
    xanchor='center',
    xref='paper', yref='paper', text='a', x=-0.14, y=1.15, font=dict(size=40))
fig.write_image(FIG_PATH / 'fig_all_probes.png')
fig.write_image(FIG_PATH / 'fig_all_probes.pdf')
fig.show()

## Step 5: Merge generalization plots

(Requires `fig_msd.png`, which is produced by notebook `13_msd.ipynb`.)

In [12]:
# Stack the backbone comparison (top, panel "a") onto the MSD figure (bottom). The top image
# overlaps the bottom one by OVERLAY pixels, presumably to tighten the gap between the panels.
OVERLAY = 50  # px
msd_fig_path = FIG_PATH / 'fig_msd.png'
if not msd_fig_path.exists():
    raise FileNotFoundError(f'{msd_fig_path} not found - run notebook 13_msd.ipynb first.')
# Context managers close the underlying files once the pixels have been pasted.
with Image.open(msd_fig_path) as img_bottom, Image.open(FIG_PATH / 'fig_all_probes.png') as img_top:
    # White canvas wide enough for both images: a narrower panel is padded with white instead of
    # black, and a wider one is no longer cropped.
    merged = Image.new('RGB',
                       size=(max(img_top.width, img_bottom.width),
                             img_top.height + img_bottom.height - OVERLAY),
                       color='white')
    # Paste the bottom image first, so the top image is drawn over the overlapping strip.
    merged.paste(im=img_bottom, box=(0, img_top.height - OVERLAY))
    merged.paste(im=img_top, box=(0, 0))
merged.save(FIG_PATH / 'fig_7.png')
merged.save(FIG_PATH / 'fig_7.pdf')