In [1]:
cd ..

/Users/yanndubois/SSL-Risk-Decomposition


In [2]:
# Render inline figures at 2x resolution ("retina").
%config InlineBackend.figure_format = 'retina'
# Display matplotlib figures inline below each cell.
%matplotlib inline

import os

# Allow more than one OpenMP runtime to be loaded in this process.
# NOTE(review): presumably works around the "OMP: Error #15" duplicate-libiomp abort that can
# occur when torch and another OpenMP/MKL-linked library are imported together — confirm.
os.environ.update(KMP_DUPLICATE_LIB_OK='True')

# Collect all results

In [3]:
import torch
import pandas as pd
from utils.plotting import *
from utils.collect_results import *
from utils.causal import *
import numpy as np
import warnings
import logging
import hubconf

 No module named 'clip'


In [None]:
# Evaluation setting: full ImageNet with a torch linear probe, or a 1% subset with sklearn logistic regression.
is_subset = False
# True: rebuild results from the raw result files and refresh the CSV cache.
# False: load the CSV cache written by a previous run.
is_read_files = True

if is_subset:
    DATA = "imagenet-S0.01"
    subset = 0.01
    base_pred = 'sk_logistic_hypopt'
    threshold_kwargs = dict(threshold_bad_ifnew=70,
                            threshold_delta=100  # don't use
                            )
else:
    DATA = "imagenet"
    subset = None
    base_pred = 'torch_linear_hypopt'
    threshold_kwargs = dict()

# Encoders whose results look suspiciously bad; they are excluded from every analysis below.
# Defined outside the `if` so both the recompute and the cache paths can use it.
to_del = ["selav2_rn50_ep400_2x160_4x96"]
results_path = f"notebooks/saved/results_{DATA}_{base_pred}.csv"

metadata_df = hubconf.metadata_df(is_multiindex=False)

if is_read_files:
    results = load_all_results(pattern=f"**/data_{DATA}/**/*hypopt*/**/results_all.csv",
                               skip_ifneq=dict(data=DATA.lower(), pred=base_pred))
    check_missing(results, metadata_df)
    results = format_approx_results(results, metadata_df, f_replace_arch=f_replace_arch, subset=subset)
    results = make_risk_decomposition(results, traverse_path=["down", "right", "down"], is_print=True, subset=subset)
    results, metadata_df = clean_results(results, metadata_df, predictor=base_pred)
    validate_results(results, metadata_df, **threshold_kwargs)

    # Filter out encoders whose values are suspiciously bad.
    to_keep = [i for i in results.index.get_level_values("enc") if i not in to_del]
    results = results.loc[to_keep]

    results.to_csv(results_path)
else:
    results = pd.read_csv(results_path, index_col=0)
    # `to_keep` must exist on this path too: it is used below to align `metadata_df`.
    # The cache is normally already filtered, so this is a no-op filter in practice.
    to_keep = [i for i in results.index if i not in to_del]
    results = results.loc[to_keep]
    # NOTE(review): on this path `metadata_df` has not gone through `clean_results`;
    # confirm its index still matches the cached encoder names.

# Keep only metadata rows for the retained encoders, then join results and metadata column-wise.
metadata_df = metadata_df.loc[to_keep]
df = pd.concat([results, metadata_df], axis=1)
df = df.replace(dict(ssl_mode={"hierarchical contrastive": "hierarchical"}))

# coarsens many different parameters for cleaner plots
df = preprocess_features(df, round_dict={}, pow_dict={})

# Long-format view of `df` used by most of the plots below.
melted = melt(df)

In [None]:
# Not used as index levels: date_published, top1acc_in1k_official, projection_nparameters.
core_params = [
    'objective', 'version', 'architecture', 'patch_size', 'z_dim', 'z_layer',
    "epochs", "batch_size", "optimizer", "learning_rate", "weight_decay", "scheduler",
    "pretraining_data", 'img_size', 'year', "nviews", 'finetuning_data', 'projection_arch',
]
# The six leading core params plus the pretraining / finetuning data.
minimal_params = core_params[:6] + ["pretraining_data", 'finetuning_data']

# Result columns only, indexed by the chosen hyperparameter columns.
result_cols = results.columns
df_core = df.set_index(core_params)[result_cols]
df_minimal = df.set_index(minimal_params)[result_cols]

# Long format over the individual components only
# (presumably excludes the aggregate `agg_risk` — confirm against COMPONENTS_ONLY).
melted_onlyComp = melt(df, components=COMPONENTS_ONLY)


### Examples of models
Let's plot a few models for comparison, keeping only those without any NaNs.

**TODO:** once all results are in, consider shortening names (`resnet` → `rn`, `mmselfsup` → `mmss`) to save space.

In [None]:
from pathlib import Path

from IPython.display import IFrame, display

# True: re-render the radar grid and save it to `save_path`.
# False: display the previously saved PDF.
is_plot_radar = False
save_path = "figures/all_radar.pdf"

if is_plot_radar:
    # Only plot models that have every component available (no NaNs).
    isna = df[COMPONENTS].isna().any(axis=1)
    plot_radar_grid(df, ncols=7, components=["agg_risk"] + COMPONENTS_ONLY, models=~isna,
                    save_path=save_path, config_kwargs=dict(font_scale=1, is_despine=False))
else:
    # The kernel's cwd is the repo root, but the browser resolves the IFrame `src` relative to
    # the notebook's own location — presumably one level down, hence the "..".
    out = IFrame(Path("..") / save_path, width=1000, height=500)
    display(out)

### Effect over time
Let's see how each component has changed over time, aggregating per year as a rough estimate.

In [None]:
with plot_config(font_scale=1):
    # One line per component against `year`; seaborn aggregates repeated x values (mean by default).
    sns.lineplot(
        data=melted,
        x="year",
        y="value",
        hue="component",
    )

In [None]:
with plot_config(font_scale=1):
    # Stacked area of the per-year mean of each component.
    curr_df = df[COMPONENTS_ONLY + ["year"]].copy()
    mean_per_year = curr_df.groupby("year").mean()
    mean_per_year.plot.area()

We see that most of the gains over time have come from usability, while the other components haven't changed much. We are now at the point where usability ≈ probe generalization ≈ approximation error. We also see that the generalization error is really not a problem at this scale. It might be surprising that the probe doesn't generalize so well, despite having hundreds of thousands of training points and being only linear.

Now let's treat models that share the same objective but have changed over time (hyperparameter tuning) as different versions.

In [None]:
with plot_config():
    # Restrict to models where only `version` varies within an (objective, architecture) group.
    # NOTE(review): exact semantics of `get_only_vary` are defined elsewhere — confirm.
    df_version = melt(get_only_vary(df, ["version"], ["objective", "version", "architecture"]))

    # One panel per component; independent y-axes because components have different scales.
    # (err_style="bars" is an alternative to the default error band.)
    sns.relplot(
        data=df_version,
        kind="line",
        x="version",
        y="value",
        col="component",
        hue="objective",
        style="objective",
        facet_kws={'sharey': False},
    )

Again, in all three cases usability has been the main driver of progress.


### Effect of type of method

Now let's consider the effect of the type of method

**Hypothesis:** ISSL showed that contrastive, distillation and clustering methods all optimize nearly the right objective and can be made correct with minor changes, so we hypothesize that their usability should be good.

In [None]:
with plot_config(font_scale=1):
    # Mean value of each component for every SSL mode.
    ax = sns.barplot(
        data=melted_onlyComp,
        x="ssl_mode",
        y="value",
        hue="component",
    )
    ax.figure.tight_layout()

The SSL mode seems to mostly affect the usability component, with the transform and generative modes being very bad.

## Naive plots
The following are naive plots, where we plot everything without thinking about confounders. As a result they are likely meaningless for interpretation and should rather be used to understand our data.

### Effect of epochs

In [None]:
with plot_config(font_scale=1):
    # Stacked area of the mean of each component per number of epochs.
    # `.copy()` gives the same independent frame as `copy.deepcopy` without relying on the
    # `copy` module, which this notebook never imports explicitly.
    curr_df = df[COMPONENTS_ONLY + ["epochs"]].copy()
    curr_df.groupby("epochs").mean().plot.area()

This doesn't show much.

In [None]:
with plot_config():
    # Scatter of every model, one panel per component, independent y-axes.
    sns.relplot(
        data=melted,
        x="epochs",
        y="value",
        col="component",
        facet_kws={'sharey': False},
    )

### Effect of batch size

In [None]:
with plot_config(font_scale=1):
    # Stacked area of the mean of each component per batch size (log x-axis).
    # `.copy()` replaces `copy.deepcopy`, so the `copy` module is not needed.
    curr_df = df[COMPONENTS_ONLY + ["batch_size"]].copy()
    curr_df.groupby("batch_size").mean().plot.area(logx=True)

In [None]:
with plot_config():
    # Scatter per component against batch size; batch sizes span orders of magnitude, hence log x.
    g = sns.relplot(
        data=melted,
        x="batch_size",
        y="value",
        col="component",
        facet_kws={'sharey': False},
    )
    g.set(xscale="log")

### Effect of multicrops

In [None]:
with plot_config(font_scale=1):
    # Stacked area of the mean of each component per number of views (multi-crop).
    # `.copy()` replaces `copy.deepcopy`, so the `copy` module is not needed.
    curr_df = df[COMPONENTS_ONLY + ["nviews"]].copy()
    curr_df.groupby("nviews").mean().plot.area()

We see that nviews seems to have a huge effect on usability but a detrimental effect on probe generalization. That being said, there is probably a big confounder, as few models (usually the more recent ones) use multi-crops.

### Effect of dimensionality

In [None]:
with plot_config(font_scale=1):
    # Stacked area of the mean of each component per representation dimension (log x-axis).
    # `.copy()` replaces `copy.deepcopy`, so the `copy` module is not needed.
    curr_df = df[COMPONENTS_ONLY + ["z_dim"]].copy()
    curr_df.groupby("z_dim").mean().plot.area(logx=True)

In [None]:
with plot_config():
    # Scatter per component against representation dimension, on a log x-axis.
    g = sns.relplot(
        data=melted,
        x="z_dim",
        y="value",
        col="component",
        facet_kws={'sharey': False},
    )
    g.set(xscale="log")

### Effect of architecture

#### N parameters

In [None]:
with plot_config(font_scale=1):
    # Stacked area of the mean of each component per parameter count (log x-axis).
    # `.copy()` replaces `copy.deepcopy`, so the `copy` module is not needed.
    curr_df = df[COMPONENTS_ONLY + ["n_parameters"]].copy()
    curr_df.groupby("n_parameters").mean().plot.area(logx=True)

In [None]:
with plot_config():
    # Scatter per component against parameter count, on a log x-axis.
    g = sns.relplot(
        data=melted,
        x="n_parameters",
        y="value",
        col="component",
        facet_kws={'sharey': False},
    )
    g.set(xscale="log")

#### ViT vs ResNet

In [None]:
with plot_config():
    # Distribution of each component per architecture family (e.g. ViT vs ResNet).
    g = sns.catplot(
        data=melted,
        kind="box",
        x="family",
        y="value",
        col="component",
        facet_kws={'sharey': False},
    )

The approximation gap seems slightly better.

#### Effect of patch size

In [None]:
with plot_config():
    # Patch size is only meaningful for ViTs, so restrict to that family.
    vit_melted = melted.loc[melted["family"] == "vit"]
    g = sns.catplot(
        data=vit_melted,
        kind="box",
        x="patch_size",
        y="value",
        col="component",
        facet_kws={'sharey': False},
    )

Larger patch sizes give a worse approximation error.

### Effect of projection 

In [None]:
with plot_config():
    # Distribution of each component per projection-head architecture.
    g = sns.catplot(
        data=melted,
        kind="box",
        x="projection_arch",
        y="value",
        col="component",
        facet_kws={'sharey': False},
    )

### Effect of pretraining data

In [None]:
with plot_config():
    # Distribution of each component per pretraining dataset; tick labels rotated so long names fit.
    g = sns.catplot(
        data=melted,
        kind="box",
        x="pretraining_data",
        y="value",
        col="component",
        facet_kws={'sharey': False},
    )
    g.set_xticklabels(rotation=30)

### Effect of finetuning

In [None]:
# Flag models whose encoder was finetuned. `finetuning_data` is presumably NaN when no
# finetuning data was used, so finetuned models are those where it is present (`notna`).
# The previous `isna()` inverted the flag.
curr_df = df.copy()
curr_df["is_finetuned"] = curr_df.finetuning_data.notna()

In [None]:
with plot_config():
    # Distribution of each component for finetuned vs non-finetuned encoders.
    g = sns.catplot(
        data=melt(curr_df),
        kind="box",
        x="is_finetuned",
        y="value",
        col="component",
        facet_kws={'sharey': False},
    )