# RDL Big Paper Plots

*Licensed under the Apache License, Version 2.0.*

To run this in a public Colab, change the GitHub link: replace github.com with [githubtocolab.com](http://githubtocolab.com).

This colab loads raw measurements from disk and analyzes the results.

## Choosing optimal hyperparameters
We automatically detect hyperparameter sweeps by selecting fields that don't correspond to dataset metrics but that have more than one chosen value. We choose the hyperparameters that achieve the best value according to a given metric (see `dataset_metric`) after averaging over random seeds. For example, if the model is trained on CIFAR-10, we use CIFAR-10's validation loss.

## Plots
All plots report the performance of a given model according to its optimal hyperparameters chosen above. When there are runs with multiple seeds, we show the mean and standard deviation.

In [None]:
from typing import Dict
import itertools
import numpy as np
import pandas as pd
import pickle
import tensorflow as tf
from IPython import display

# `colab_utils` starts out as None, so the branch below always runs here and
# downloads the helper module.
colab_utils = None

if colab_utils is None:
  # Clone a fresh copy of the repository and copy the helper module next to the
  # notebook so that it can be imported.
  !rm -rf uncertainty-baselines
  !git clone https://github.com/google/uncertainty-baselines.git
  !cp uncertainty-baselines/experimental/big_paper/colab_utils.py .
  import colab_utils

## Functions

In [None]:
#@title Choosing optimal hyperparameters

# The finetuning deterministic jobs use a fixed random seed but different
# upstream checkpoints, which themselves correspond to different random seeds.
# In this case, we thus marginalize over upstream checkpoints
# (`config.model_init`) rather than the random seed.

# Metric used to tune hyperparameters for each training dataset, in the format
# `{dataset: metric}`. This is the default `dataset_metric` of
# `get_optimal_results`.
DATASET_METRIC = {
    'cifar10': 'val_loss',
    'cifar100': 'val_loss',
    'imagenet2012': 'val_loss',
    'imagenet21k': 'val_loss',
    'jft/entity:1.0.0': 'val_loss',
    'retina_country': 'in_domain_validation/auroc',
    'retina_severity': 'in_domain_validation/auroc',
}


def get_optimal_results(
    measurements: Dict[str, pd.DataFrame],
    dataset_metric: Dict[str, str] = DATASET_METRIC) -> pd.DataFrame:
  """Collects the tuned results of every model on every dataset.

  Each model's measurements are split by training dataset and passed to
  `colab_utils.get_tuned_results`. A model type may keep several rows per
  dataset (e.g., random seeds); those are averaged over later, when plotting.

  Args:
    measurements: Dictionary of dataframes to obtain best results for, keyed
      by model name.
    dataset_metric: Each dataset's metric to tune for, in the format
      `{dataset: metric}`.

  Returns:
    The concatenation of the tuned results of all (model, dataset) pairs for
    which tuning did not raise a `KeyError`.
  """
  tuned_frames = []
  for model_name, model_df in measurements.items():
    # Always marginalize over the random seed; the models listed below are
    # additionally marginalized over `config.model_init`.
    marginalize_over = (colab_utils.random_seed_col(),)
    if model_name in ('Det', 'Det I21K', 'DE'):
      marginalize_over = marginalize_over + ('config.model_init',)

    for dataset in model_df[colab_utils.dataset_col()].unique():
      dataset_rows = model_df[model_df[colab_utils.dataset_col()] == dataset]
      try:
        tuned = colab_utils.get_tuned_results(
            dataset_rows,
            tuning_metric=dataset_metric[dataset],
            marginalization_hparams=marginalize_over)
      except KeyError:
        # Best effort: report the failing pair and carry on with the rest.
        print(f'Could not get optimal results for {model_name}, {dataset}.')
      else:
        tuned_frames.append(tuned)
    print()
  return pd.concat(tuned_frames)

In [None]:
#@title Obtain reliability score

# Metric suffixes that are reported for an evaluation split.
SPLIT_METRICS = ['loss', 'prec@1', 'ece', 'calib_auc']
# In-distribution metrics: the split metrics measured on the test split.
IND_METRICS = ['test_' + suffix for suffix in SPLIT_METRICS]

FEWSHOT_DATASETS = [
    'imagenet',
    'pets',
    'birds',
    'col_hist',
    'cifar100',
    'caltech',
    'cars',
    'dtd',
    'uc_merced',
]
# One column per (few-shot dataset, number of shots) pair, dataset-major.
FEWSHOT_METRICS = [
    f'z/{dataset}_{shots}shot'
    for dataset in FEWSHOT_DATASETS
    for shots in (1, 5, 10, 25)
]
OOD_METRICS = [
    f'{dataset}_{suffix}'
    for dataset in ('cifar_10h', 'imagenet_real')
    for suffix in ('loss', 'prec@1', 'ece', 'calib_auc')
]
# We use just MSP following Jie's recommendation (the other candidate methods
# were 'entropy', 'maha' and 'rmaha').
OOD_DETECTION_METRICS = [
    f'ood_{dataset}_{method}_auroc'
    for dataset in ('cifar10', 'cifar100', 'svhn_cropped', 'places365_small')
    for method in ('msp',)
]
COMPUTE_METRICS = ['exaflops', 'tpu_days', 'gflops', 'ms_step']
# Every RETINA metric is reported in-domain and out-of-distribution; the list
# is metric-major (both prefixes of a metric are adjacent).
RETINA_METRICS = [
    f'{prefix}/{metric}'
    for metric in (
        'accuracy',
        'negative_log_likelihood',
        'ece',
        'retention_auroc_auc',
        'retention_accuracy_auc',
    )
    for prefix in ('in_domain_test', 'ood_test')
]
METRICS = [
    *IND_METRICS,
    *FEWSHOT_METRICS,
    *OOD_METRICS,
    *OOD_DETECTION_METRICS,
    *COMPUTE_METRICS,
    *RETINA_METRICS,
]
# Metric names that feed each per-category score (`score_{category}`) computed
# by `compute_score` and `compute_relative_score_and_ranks`.
CATEGORIES = {
    'prediction': [
        'test_loss',
        'test_prec@1',
        'cifar_10h_loss',
        'cifar_10h_prec@1',
        'imagenet_real_loss',
        'imagenet_real_prec@1',
        # RETINA
        'in_domain_test/negative_log_likelihood',
        'ood_test/negative_log_likelihood',
        'in_domain_test/accuracy',
        'ood_test/accuracy',
    ],
    'uncertainty': [
        'cifar_10h_calib_auc',
        'cifar_10h_ece',
        'imagenet_real_calib_auc',
        'imagenet_real_ece',
        'ood_cifar100_msp_auroc',
        'ood_cifar10_msp_auroc',
        'ood_places365_small_msp_auroc',
        'ood_svhn_cropped_msp_auroc',
        'test_calib_auc',
        'test_ece',
        # RETINA
        'in_domain_test/ece',
        'ood_test/ece',
        'in_domain_test/retention_auroc_auc',
        'ood_test/retention_auroc_auc',
        'in_domain_test/retention_accuracy_auc',
        'ood_test/retention_accuracy_auc'
    ],
    # Few-shot accuracy columns created by `preprocess`; the 1-shot column is
    # not listed (the score functions delete it before aggregating).
    'adaptation': [
        '10shot_prec@1',
        '25shot_prec@1',
        '5shot_prec@1',
    ],
}
# RETINA negative log-likelihood metrics; rescaled together with the other
# loss metrics when computing scores.
RETINA_NLL_METRICS = [
  metric for metric in RETINA_METRICS if 'negative_log_likelihood' in metric]

# Number of classes per training dataset, used to bound the NLL by the entropy
# of a uniform prediction.
DATASET_CLASSES = {
      'cifar10': 10,
      'cifar100': 100,
      'imagenet2012': 1000,
      'imagenet21k': 21841,
      'jft/entity:1.0.0': 18291,
      'retina_country': 2,
      'retina_severity': 2,
}

def preprocess(df,
               split_metrics=SPLIT_METRICS,
               metrics=METRICS,
               compute_metrics=COMPUTE_METRICS,
               fewshot_datasets=FEWSHOT_DATASETS):
  """Pivots tuned measurements into a model x (metric, dataset) table.

  Args:
    df: Dataframe with `model` and `config.dataset` columns plus one column per
      metric, e.g. the output of `get_optimal_results`.
    split_metrics: Metric suffixes `m` for which `val_{m}` is copied into
      `test_{m}` on the JFT and ImageNet-21k rows.
    metrics: Metric columns to keep and pivot.
    compute_metrics: Metric columns that are collapsed into a single `compute`
      pseudo-dataset.
    fewshot_datasets: Few-shot dataset names `ds` whose `z/{ds}_{f}shot`
      columns are replaced by `{f}shot_prec@1` under the dataset
      `few-shot {ds}`.

  Returns:
    Dataframe indexed by model whose columns are `(metric, dataset)` pairs.
  """
  df = df.copy()
  # Average the rows of every (model, dataset) pair.
  df = df.groupby(['model', 'config.dataset']).agg('mean').reset_index()
  # Set JFT/I21K upstream #s to the test set reporting since we use them that
  # way.
  for m in split_metrics:
    df.loc[df['config.dataset'] == 'jft/entity:1.0.0', f'test_{m}'] = df.loc[
        df['config.dataset'] == 'jft/entity:1.0.0', f'val_{m}']
    df.loc[df['config.dataset'] == 'imagenet21k', f'test_{m}'] = df.loc[
        df['config.dataset'] == 'imagenet21k', f'val_{m}']

  cols = ['model', 'config.dataset'] + metrics
  df = df[cols].copy()
  df = df.pivot(index='model', columns='config.dataset', values=metrics)

  # Drop columns with all NaNs, e.g., ECE for JFT. They aren't measured.
  df = df.dropna(axis=1, how='all')

  # Set few-shot imagenet metrics under a distinct dataset so later, we can
  # aggregate over few-shot metrics while excluding their original
  # config.dataset (the upstream dataset).
  for ds in fewshot_datasets:
    for f in [1, 5, 10, 25]:
      # Average the metric over all upstream datasets it was reported for.
      df[f'{f}shot_prec@1', f'few-shot {ds}'] = df[f'z/{ds}_{f}shot'].mean(axis=1)
      del df[f'z/{ds}_{f}shot']
  # Do same for compute and only keep upstream compute metrics.
  for metric in compute_metrics:
    for ds in df[metric]:
      if ds == 'imagenet21k':
        df[metric, 'compute'] = df[metric, ds]
      elif ds == 'jft/entity:1.0.0':
        # Fill in the JFT value only where no ImageNet-21k value was set. This
        # reads the `compute` column written by the branch above, so it relies
        # on 'imagenet21k' being visited first.
        # NOTE(review): presumably the column order after pivot guarantees
        # that — confirm.
        df[metric, 'compute'] = np.where(df[metric, 'compute'].isnull(), df[metric, ds], df[metric, 'compute'])
      del df[metric, ds]
  return df

def compute_score(df, datasets=None, categories=CATEGORIES):
  """Compute aggregate score across metrics and per-category scores.

  Args:
    df: Dataframe indexed by model whose columns are `(metric, dataset)`
      pairs, e.g. the output of `preprocess`. It is not modified.
    datasets: Optional list of dataset names. When given, only metrics whose
      flattened name ends in one of these datasets enter the scores.
    categories: Mapping from category name to the metric names that make up
      that category.

  Returns:
    Dataframe indexed by model with a `score` column (mean over all selected
    metrics times 100, only for models that have every metric filled in),
    sorted by decreasing score, plus one `score_{category}` column per entry of
    `categories`.
  """
  df = df.copy()

  # Compute metrics (FLOPs, TPU days, ...) are not reliability metrics and are
  # not scaled to [0, 1], so they must not be averaged into the score. The
  # previous check compared the literal 'dataset' against 'compute', which is
  # always False, so these columns were never removed.
  df = df.drop(columns='compute', level=1, errors='ignore')

  # Scale all metrics in range [0.0, 1.0], and where higher is better.
  for column in df.columns:
    metric, _ = column
    if 'ece' in metric:
      df[column] = 1. - df[column]
  # Remove 1-shot for now as its #s are unreliable due to high variance.
  del df['1shot_prec@1']

  for metric, dataset in df[['test_loss', 'cifar_10h_loss', 'imagenet_real_loss'] + RETINA_NLL_METRICS]:
    # Rescale NLL under its bound [0.0, uniform entropy]. Technically I21K &
    # JFT's uniform entropy should be computed on multiclass sigmoid NLL, but
    # unlike categorical uniform, multiclass sigmoid uniform is so large it's
    # meaningless as a bound.
    num_classes = DATASET_CLASSES[dataset]
    p = 1./num_classes
    max_value = -num_classes * p * np.log(p)
    df.loc[:, (metric, dataset)] = 1. - df[metric][dataset] / max_value

  # Flatten multiindexes into '{metric}_{dataset}' names.
  df.columns = ['_'.join(col).strip() for col in df.columns.values]
  if datasets is not None:
    # NOTE(review): only the text after the last '_' is compared, so dataset
    # names that themselves contain '_' (e.g. 'retina_country',
    # 'few-shot uc_merced') never match and are silently left out — confirm
    # whether that exclusion is intended.
    metrics = [m for m in df.columns if any(d == m.split('_')[-1] for d in datasets)]
    df = df[metrics]
  # Compute the score only for models that have filled in all metrics.
  subset_df = df.dropna(how='any')
  score = subset_df.mean(axis=1) * 100.0
  df_scores = score.sort_values(ascending=False).to_frame(name='score')
  for key, value in categories.items():
    # A flattened column belongs to the category when everything before its
    # last '_' is one of the category's metric names.
    metrics = [m for m in df.columns if '_'.join(m.split('_')[:-1]) in value]
    subset_df = df[metrics]
    subset_df = subset_df.dropna(how='any')
    score = subset_df.mean(axis=1) * 100.0
    df_scores[f'score_{key}'] = score

  return df_scores

def compute_relative_score_and_ranks(
    df, datasets=None, categories=CATEGORIES, baseline_model='Det'):
  """Compute scores relative to a baseline model, plus per-metric rankings.

  Args:
    df: Dataframe indexed by model whose columns are `(metric, dataset)`
      pairs, e.g. the output of `preprocess`. It is not modified.
    datasets: Optional list of dataset names. When given, only metrics whose
      flattened name ends in one of these datasets are used.
    categories: Mapping from category name to the metric names that make up
      that category.
    baseline_model: Index label of the model every metric is divided by. It
      must have all selected metrics filled in.

  Returns:
    A tuple `(df_scores, df_ranks, ranks_by_category)`:
      df_scores: Dataframe indexed by model with `rel_score` (mean ratio to the
        baseline over all metrics, sorted descending) and, per category,
        `rel_score_{key}`, `#_best_{key}` (number of metrics the model wins)
        and `mean_rank_{key}`.
      df_ranks: Per-metric rank of every model (1 is best) over all metrics.
      ranks_by_category: Dict from category name to the per-metric ranks
        restricted to that category.
  """
  df = df.copy()

  # Compute metrics (FLOPs, TPU days, ...) are not reliability metrics, so they
  # must not enter the scores or ranks. The previous check compared the
  # literal 'dataset' against 'compute', which is always False.
  df = df.drop(columns='compute', level=1, errors='ignore')

  # Flip ECE so that higher is better.
  for column in df.columns:
    metric, _ = column
    if 'ece' in metric:
      df[column] = 1. - df[column]
  # Remove 1-shot for now as its #s are unreliable due to high variance.
  del df['1shot_prec@1']

  for metric, dataset in df[['test_loss', 'cifar_10h_loss', 'imagenet_real_loss'] + RETINA_NLL_METRICS]:
    # Flip NLL so that higher is better. Unlike `compute_score`, the loss is
    # not rescaled by the uniform-prediction entropy here.
    df.loc[:, (metric, dataset)] = 1. - df[metric][dataset]

  # Flatten multiindexes into '{metric}_{dataset}' names.
  df.columns = ['_'.join(col).strip() for col in df.columns.values]
  if datasets is not None:
    # NOTE(review): only the text after the last '_' is compared, so dataset
    # names that themselves contain '_' never match — confirm that is intended.
    metrics = [m for m in df.columns if any(d == m.split('_')[-1] for d in datasets)]
    df = df[metrics]
  # Compute the score only for models that have filled in all metrics.
  subset_df = df.dropna(how='any')
  baseline = subset_df.loc[baseline_model, :].to_numpy()[None, :]
  score = subset_df.div(baseline, axis=1).mean(axis=1)
  df_scores = score.sort_values(ascending=False).to_frame(name='rel_score')
  df_ranks = subset_df.rank(axis=0, ascending=False)

  ranks_by_category = {}
  for key, value in categories.items():
    metrics = [m for m in df.columns if '_'.join(m.split('_')[:-1]) in value]
    subset_df = df[metrics]
    subset_df = subset_df.dropna(how='any')
    # Use the requested baseline here as well (this used to be hard-coded to
    # 'Det', ignoring `baseline_model`).
    baseline = subset_df.loc[baseline_model, :].to_numpy()[None, :]
    subset_df = subset_df.div(baseline, axis=1)
    rank_df = subset_df.rank(axis=0, ascending=False)
    ranks_by_category[key] = rank_df
    # Count, per model, on how many metrics it has the highest relative value.
    subset_np = subset_df.to_numpy()
    winners = np.argmax(subset_np, axis=0).astype(np.int64)
    wincount = np.bincount(winners.flatten(),
                           minlength=subset_np.shape[0])
    df_scores[f'rel_score_{key}'] = subset_df.mean(axis=1)
    df_scores[f'#_best_{key}'] = pd.Series(wincount, index=subset_df.index)
    df_scores[f'mean_rank_{key}'] = rank_df.mean(axis=1)

  return df_scores, df_ranks, ranks_by_category

def pprint(df, models=None, exclude_models=None):
  """Display the dataframe as a styled table with highlighted best cells.

  The table is shown transposed: one row per (dataset, metric) pair and one
  column per model.

  Args:
    df: Dataframe.
    models: Optional list of models to only show. Useful for comparing specific
      models to see which performs better (highlighted cells).
    exclude_models: Optional list of models to exclude. Only used when `models`
      is None.
  """
  # Applied in order. The order matters: e.g. removing 'ood' turns
  # 'negative log likelihood' into 'negative log likelih', which the last rule
  # then maps to 'NLL'.
  replacements = (
      ('cifar_10h', 'cifar10h'),
      ('places365_small', 'places365'),
      ('_', ' '),
      ('cropped ', ''),
      ('ood', ''),
      ('ece', 'ECE'),
      ('auc', 'AUC'),
      ('auroc', 'AUROC'),
      ('loss', 'NLL'),
      ('negative log likelih', 'NLL'),
  )
  compute_names = ('exaflops', 'tpu days', 'gflops', 'ms step')

  def _rename(label):
    for old, new in replacements:
      label = label.replace(old, new)
    return label

  def _formatter(metric):
    """Returns the cell formatter for a (renamed) metric name."""
    def _mentions(*needles):
      return any(needle in metric for needle in needles)

    if _mentions('AUROC', 'AUC'):
      return '{:.2f}'.format
    if _mentions('prec', 'ECE', 'accuracy'):
      return lambda value: '{:.1f}%'.format(value * 100)
    if _mentions('score', *compute_names):
      return lambda value: '{:.1f}'.format(value)
    if 'NLL' in metric:
      return '{:.3f}'.format
    return lambda value: value

  def _highlight(data, color='#90EE90'):
    """Returns one CSS string per cell of a row, marking the best cells."""
    style = 'background-color: {}'.format(color)
    numbers = data.replace('%', '', regex=True).astype(float)
    metric = data.name[1]
    if 'NLL' in metric or 'ECE' in metric:
      # Lower is better.
      winners = numbers == numbers.min()
    elif any(name in metric for name in compute_names):
      # Compute rows are never highlighted.
      return ['' for _ in numbers]
    else:
      winners = numbers == numbers.max()
    return [style if is_winner else '' for is_winner in winners]

  table = df.copy().rename(columns=_rename)
  for column in table:
    table[column] = table[column].apply(_formatter(column[0]))

  # Swap order of column's multiindex to be dataset first, then put models in
  # the columns.
  table.columns = table.columns.swaplevel(0, 1)
  table = table.sort_index(axis=1, level=0).T

  if models is not None:
    table = table[[name for name in table.columns if name in models]]
  elif exclude_models is not None:
    table = table[[name for name in table.columns if name not in exclude_models]]

  return display.display(table.style.apply(_highlight, axis=1))

In [None]:
#@title RETINA
# When True, RETINA results are re-downloaded from Weights & Biases and the
# cached TSV file is rewritten; `wandb` is only installed and imported in that
# case.
REBUILD_RETINA_RESULTS_CACHE = False

if REBUILD_RETINA_RESULTS_CACHE:
  import os
  os.system('pip install wandb')
  import wandb

# Weights & Biases project name for every (shift, uncertainty method) pair.
# TODO(nband): add grid search results (currently random search).
RETINA_SHIFT_AND_UQ_METHOD_TO_WANDB = {
  ('aptos', 'deterministic'): 'vit32-finetune-aptos-deterministic-focused-3',
  ('aptos', 'batchensemble'): 'vit32-finetune-aptos-batchensemble',
  ('severity', 'deterministic'): 'vit32-finetune-severity-deterministic',
  ('severity', 'batchensemble'): 'vit32-finetune-severity-batchensemble-focused-1'
}

RETINA_SHIFTS = ['aptos', 'severity']
RETINA_UQ_METHODS = ['deterministic', 'batchensemble']
# Row label (model name) under which each method's results are stored in the
# preprocessed dataframe.
RETINA_UQ_METHOD_TO_DF_NAME = {
    'deterministic': 'Det I21K',
    'batchensemble': 'BE L/32 (I21K)'
}

# Metrics copied into the preprocessed dataframe for each shift. The part after
# the '.' is the metric name used in that dataframe.
RETINA_SHIFT_TO_METRICS = {
  'aptos': [
    # In-Domain
    'in_domain_test.in_domain_test/accuracy',
    'in_domain_test.in_domain_test/negative_log_likelihood',
    'in_domain_test.in_domain_test/ece',
    'in_domain_test.in_domain_test/retention_auroc_auc',
    # OOD
    'ood_test.ood_test/accuracy',
    'ood_test.ood_test/negative_log_likelihood',
    'ood_test.ood_test/ece',
    'ood_test.ood_test/retention_auroc_auc'
  ],
  'severity': [
    # In-Domain
    'in_domain_test.in_domain_test/accuracy',
    'in_domain_test.in_domain_test/negative_log_likelihood',
    'in_domain_test.in_domain_test/ece',
    'in_domain_test.in_domain_test/retention_auroc_auc',
    # OOD
    'ood_test.ood_test/accuracy',
    'ood_test.ood_test/negative_log_likelihood',
    'ood_test.ood_test/ece',
    # NOTE(review): every other list entry uses `retention_auroc_auc`; confirm
    # that `retention_accuracy_auc` is intended here.
    'ood_test.ood_test/retention_accuracy_auc'
  ]
}
# Metric whose maximum selects the best run and step of a project.
RETINA_MODEL_SELECTION_METRIC = 'in_domain_validation.in_domain_validation/auroc'

# Split RETINA results into the two distributional shifts: Country Shift and
# Severity Shift.

SHIFT_MAP = {'aptos': 'country', 'severity': 'severity'}


def select_top_model_from_project(project_name):
  """Returns the best logged step among all runs of a W&B project.

  Args:
    project_name: Name of the Weights & Biases project to query.

  Returns:
    The row (a `pd.Series`) of the concatenated run histories that maximizes
    `RETINA_MODEL_SELECTION_METRIC`, with a `run_name` entry identifying the
    run it came from.
  """
  project_runs = wandb.Api(timeout=100000000).runs(project_name)
  print(f'Retrieved run results from Weights & Biases project {project_name}.')

  # Full history of every run, each tagged with the run's name.
  run_histories = [
      pd.DataFrame(run._full_history()).assign(run_name=run.name)
      for run in project_runs
  ]
  # Renumber the rows 0..n-1 so that the label returned by `idxmax` is also a
  # valid position for `iloc`.
  all_steps = pd.concat(run_histories).reset_index()

  # Best performing step of the best performing model
  best_position = all_steps[RETINA_MODEL_SELECTION_METRIC].idxmax()
  return all_steps.iloc[best_position]


def get_retina_i21k_results_df():
  """Builds a dataframe with the top result of every RETINA configuration.

  Returns:
    Dataframe with one row per (shift, uncertainty quantification method)
    pair, holding the best step selected by `select_top_model_from_project`
    plus `shift` and `uq_method` columns.
  """
  best_rows = []
  for shift, uq_method in itertools.product(RETINA_SHIFTS, RETINA_UQ_METHODS):
    print(f'Retrieving results from shift {shift}, '
          f'uncertainty quantification method {uq_method}.')
    project = RETINA_SHIFT_AND_UQ_METHOD_TO_WANDB[(shift, uq_method)]
    # Turn the selected series into a one-row frame and tag it.
    row = select_top_model_from_project(project).to_frame().T
    row['shift'] = shift
    row['uq_method'] = uq_method
    best_rows.append(row)

  return pd.concat(best_rows)


def add_retina_i21k_results(retina_results_df, preprocessed_df, shift_map=SHIFT_MAP):
  """Writes the cached RETINA I21K results into the preprocessed dataframe.

  Args:
    retina_results_df: Dataframe with exactly one row per (shift, uq_method)
      pair, e.g. the output of `get_retina_i21k_results_df`.
    preprocessed_df: Dataframe with `(metric, dataset)` columns, e.g. the
      output of `preprocess`. It is updated in place.
    shift_map: Mapping from shift name to the suffix of the `retina_{suffix}`
      dataset name used in `preprocessed_df`.

  Returns:
    `preprocessed_df`, with the RETINA metrics of the I21K models filled in.
  """
  for shift, uq_method in itertools.product(RETINA_SHIFTS, RETINA_UQ_METHODS):
    print(f'Adding results from shift {shift}, '
          f'uncertainty quantification method {uq_method}.')
    is_match = ((retina_results_df['shift'] == shift) &
                (retina_results_df['uq_method'] == uq_method))
    matching_rows = retina_results_df[is_match]
    n_results = len(matching_rows)
    assert n_results == 1, f'Found {n_results} model results, expected 1.'
    best_run = matching_rows.iloc[0]

    model_name = RETINA_UQ_METHOD_TO_DF_NAME[uq_method]
    dataset_name = f'retina_{shift_map[shift]}'
    for wandb_metric in RETINA_SHIFT_TO_METRICS[shift]:
      # The part after the '.' is the metric name used in `preprocessed_df`.
      column = (wandb_metric.split('.')[1], dataset_name)
      value = best_run[wandb_metric]
      # Read the column, set the model's entry, and write the column back.
      column_values = preprocessed_df[column]
      column_values[model_name] = value
      preprocessed_df[column] = column_values

  return preprocessed_df

# Optionally refresh the cached RETINA results file that later cells read.
if REBUILD_RETINA_RESULTS_CACHE:
  # Retrieve RETINA I21K results from Weights & Biases
  retina_i21k_results_df = get_retina_i21k_results_df()

  # Store RETINA results in gs bucket
  retina_ub_gs_file_path = 'gs://retina-i21k-results-df/retina-i21k-results.tsv'
  with tf.io.gfile.GFile(retina_ub_gs_file_path, 'w') as f:
    # Tab-separated, without the index column.
    retina_i21k_results_df.to_csv(f, sep='\t', index=None)


def add_distribution_shift_to_retina_ds_name(row):
  """Renames the dataset of a `retina` row to `retina_{shift}`.

  Args:
    row: Mapping-like row with a `config.dataset` entry and, for `retina` rows,
      a `config.distribution_shift` entry that is a key of `SHIFT_MAP`.

  Returns:
    The same `row` object; only rows whose dataset is exactly `retina` are
    modified (in place).
  """
  dataset_name = str(row['config.dataset'])
  if dataset_name != 'retina':
    return row

  shift_suffix = SHIFT_MAP[str(row['config.distribution_shift'])]
  row['config.dataset'] = f'{dataset_name}_{shift_suffix}'
  return row

def split_retina_results_by_shifts(raw_dict):
  """Splits every model's `retina` rows by distribution shift.

  Args:
    raw_dict: Dictionary from model name to that model's raw measurements
      dataframe (with a `config.dataset` column). It is updated in place.

  Returns:
    `raw_dict`, where the dataframe of every model that has `retina` rows has
    been replaced by a copy whose `retina` rows are renamed by
    `add_distribution_shift_to_retina_ds_name`.
  """
  for model in raw_dict:
    model_df = raw_dict[model]
    has_retina_rows = (model_df['config.dataset'] == 'retina').any()
    if not has_retina_rows:
      continue

    print(f'Splitting RETINA results for model {model} by distribution shift.')

    # Row-wise rename; replacing the value of an existing key is safe while
    # iterating over the dictionary.
    raw_dict[model] = model_df.apply(
        add_distribution_shift_to_retina_ds_name, axis='columns')

  return raw_dict

## Load and preprocess measurements

In [None]:
from google.colab import auth
# Authenticate the Colab user before running the gcloud/gsutil commands below.
auth.authenticate_user()

project_id = 'marginalization-external-xgcp'
!gcloud config set project {project_id}

# Copy the raw measurements and the cached RETINA results to local disk.
measurements_path = '/tmp/big-paper-raw-measurements.pkl'
!gsutil cp gs://ub-checkpoints/big-paper-raw-measurements.pkl {measurements_path}

retina_path = '/tmp/retina-i21k-results.tsv'
!gsutil cp gs://retina-i21k-results-df/retina-i21k-results.tsv {retina_path}

In [None]:
# SECURITY NOTE: unpickling executes arbitrary code embedded in the file. Only
# load measurement files that come from a trusted bucket.
with tf.io.gfile.GFile(measurements_path, 'rb') as f:
  raw_measurements = pickle.load(f)

# Cached RETINA I21K results, stored as a tab-separated file.
with tf.io.gfile.GFile(retina_path, 'r') as f:
  retina_i21k_results_df = pd.read_csv(f, sep='\t')

In [None]:
# Rename `retina` rows to `retina_{shift}` so each shift is its own dataset.
raw_measurements = split_retina_results_by_shifts(raw_measurements)

# Keep the tuned results of every (model, dataset) pair.
measurements = get_optimal_results(raw_measurements)

# Pivot into a model x (metric, dataset) table and fill in the RETINA results
# of the I21K models from the cached file.
df = preprocess(measurements)
df = add_retina_i21k_results(retina_results_df=retina_i21k_results_df,
                             preprocessed_df=df)

## Compute reliability score and generate table

In [None]:
# Datasets requested for the reliability score: the downstream training
# datasets plus every few-shot pseudo-dataset created by `preprocess`.
datasets = [
    'cifar10',
    'cifar100',
    'imagenet2012',
    'retina_country',
    'retina_severity',
]
datasets += [f'few-shot {d}' for d in FEWSHOT_DATASETS]
scores = compute_score(df, datasets=datasets)
display.display(scores)

In [None]:
# Attach the score columns to a copy of the metrics table and display it.
df_with_scores = df.copy()
for column in scores.columns:
  df_with_scores[column] = scores[column]

pprint(
    df_with_scores,
    # models=['BE L/32', 'Det'],
    # exclude_models=['DE', 'Det->DE'],
)

In [None]:
# Show a subset of the table's metrics + models
metrics = ['score', 'score_prediction', 'score_uncertainty', 'score_adaptation',
           'exaflops', 'test_loss', 'tpu_days']
models = ['BE L/32', 'Det', 'GP', 'Het', 'BE L/32 (I21K)', 'Det I21K',
          'BE->BE+Het']
# Rename the `compute` label to `z/compute` before display.
# NOTE(review): presumably so that it sorts last in the printed table — confirm.
pprint(df_with_scores.loc[models][metrics].rename(
    columns={'compute': 'z/compute'}))

## Plot reliability score

In [None]:
import colabtools.fileedit
import matplotlib
import matplotlib.pyplot as plt
%matplotlib inline
%config InlineBackend.figure_format = 'retina'
import seaborn as sns
sns.reset_orig()
# Figure resolution and line width used by the plots below.
matplotlib.rcParams['figure.dpi'] = 1000
matplotlib.rcParams['lines.linewidth'] = 1.25
# sns.set_style("whitegrid")
# NOTE(review): `sns.set()` runs after the rcParams overrides above and may
# replace some of them — confirm the intended order.
sns.set()

In [None]:
def pareto_plot(df, x, y, ax, filename=None, **kwargs):
  """Scatter-plots `y` against `x` per model and draws the Pareto frontier.

  Higher `y` is treated as better and lower `x` as cheaper: a point is on the
  frontier when no point with smaller-or-equal `x` has a strictly larger `y`.

  Args:
    df: Dataframe indexed by model name.
    x: Column of `df` for the (log-scaled) x axis.
    y: Column of `df` for the y axis.
    ax: Matplotlib axes to draw on.
    filename: Optional file to save the current figure to; the saved file is
      then offered for download.
    **kwargs: Forwarded to `ax.set` (e.g. `xlabel`, `ylabel`, `title`).
  """
  def _pareto_frontier(xs, ys):
    candidates = list(zip(xs, ys))

    def _is_dominated(p):
      return any(q[0] <= p[0] and q[1] > p[1] for q in candidates)

    undominated = [p for p in candidates if not _is_dominated(p)]
    undominated.sort(key=lambda p: p[0])
    return undominated

  # Label every point with its model name.
  for model_name, row in df.iterrows():
    ax.annotate(
        '  ' + model_name,
        xy=(row[x], row[y]),
        ha='left',
        va='bottom',
    )
  sns.scatterplot(x=df[x], y=df[y], ax=ax)

  frontier_x, frontier_y = zip(*_pareto_frontier(df[x], df[y]))
  sns.lineplot(x=frontier_x, y=frontier_y, linestyle='--', ax=ax)
  ax.set(xscale='log', **kwargs)

  if filename is None:
    return
  plt.tight_layout()
  plt.savefig(filename)
  colabtools.fileedit.download_file(filename)

# All plots in this cell only show the models whose name starts with 'BE'.
be_scores = df_with_scores[
    [name.startswith('BE') for name in df_with_scores.index.values]]

# Overall reliability score against compute.
fig, ax = plt.subplots(figsize=(10.0, 5.0))
pareto_plot(
    be_scores,
    x=('tpu_days', 'compute'),
    y='score',
    ax=ax,
    xlabel='Compute (TPUv3 core days)',
    ylabel='Reliability Score',
    filename='reliability.png',
)

# One panel per score component, saved together as a single figure.
fig, axes = plt.subplots(1, 3, figsize=(3.5 * 3, 3.5))
component_panels = (
    ('score_prediction', 'Reliability Score (Prediction)'),
    ('score_uncertainty', 'Reliability Score (Uncertainty)'),
    ('score_adaptation', 'Reliability Score (Adaptation)'),
)
for panel_index, (score_column, panel_title) in enumerate(component_panels):
  pareto_plot(
      be_scores,
      x=('tpu_days', 'compute'),
      y=score_column,
      ax=axes[panel_index],
      xlabel=None,
      ylabel=None,
      title=panel_title,
  )
filename = 'reliability_components.png'
plt.tight_layout()
plt.savefig(filename)
colabtools.fileedit.download_file(filename)

## Analyze correlation of metrics

In [None]:
# Rebuild the metrics table with the training metrics included, and attach the
# scores computed on the non-RETINA datasets.
temp_df = preprocess(
    measurements,
    metrics=METRICS + ['training_loss', 'training_prec@1'])
datasets = [
    'cifar10',
    'cifar100',
    'imagenet2012',
]
datasets += [f'few-shot {d}' for d in FEWSHOT_DATASETS]
temp_scores = compute_score(temp_df, datasets=datasets)
for column in temp_scores.columns:
  temp_df[column] = temp_scores[column]

# scores correlation matrix
columns = ['score', 'score_prediction', 'score_uncertainty', 'score_adaptation']
corr_matrix = temp_df[columns]
# Flatten the two-level column labels into plain strings.
corr_matrix.columns = [''.join(col) for col in corr_matrix.columns.values]
corr_matrix = corr_matrix.corr()
display.display(corr_matrix)

# upstream test metrics
metrics = ['score', 'score_prediction', 'score_uncertainty', 'score_adaptation']
corr_matrix = temp_df.corr()[['test_loss', 'test_prec@1']].T.xs(
    'jft/entity:1.0.0', level='config.dataset')
corr_matrix = corr_matrix[metrics]
corr_matrix.columns = [''.join(col) for col in corr_matrix.columns.values]
display.display(corr_matrix)

# imagenet 10-shot. It doesn't correlate well with reliability, mostly due to
# it not correlating well surprisingly on other few-shot tasks.
corr_matrix = temp_df.corr()[['10shot_prec@1']].T.xs(
    'few-shot imagenet', level='config.dataset')
corr_matrix = corr_matrix[metrics]
corr_matrix.columns = [''.join(col) for col in corr_matrix.columns.values]
display.display(corr_matrix)

# downstream training loss. The correlation is not nearly as tight as on
# upstream.
corr_matrix = temp_df.corr()[['training_loss']].T
corr_matrix = corr_matrix[metrics + ['test_loss']]
corr_matrix = corr_matrix.drop(index=('training_loss', 'retina_country'))
corr_matrix = corr_matrix.drop(index=('training_loss', 'retina_severity'))
corr_matrix = corr_matrix.drop(index=('training_loss', 'imagenet21k'))
corr_matrix = corr_matrix.drop(columns=('test_loss', 'imagenet21k'))
# Display test loss only for training loss' same downstream dataset. Looking at
# cifar10's train loss correlation with I1K's test loss isn't meaningful.
# NOTE(review): taking the diagonal assumes the remaining training-loss rows
# and test-loss columns list the same datasets in the same order — confirm.
test_loss = pd.Series(np.diag(corr_matrix['test_loss']),
                      index=corr_matrix['test_loss'].index)
corr_matrix = corr_matrix.drop(columns='test_loss')
corr_matrix['test_loss'] = test_loss
corr_matrix.columns = [''.join(col) for col in corr_matrix.columns.values]
display.display(corr_matrix)

# Similar to old plot in go/rdl-big-meeting, even generalization gap decreases.
# And downstream is not very indicative, but upstream is.
temp_df2 = temp_df.copy()
# `reg_loss` is the gap between test loss and training loss per dataset.
for d in temp_df2['test_loss'].columns:
  temp_df2['reg_loss', d] = temp_df2['test_loss', d] - temp_df2['training_loss', d]

corr_matrix = temp_df2.corr()[['reg_loss']].T
corr_matrix = corr_matrix[metrics + ['training_loss']]
corr_matrix = corr_matrix.drop(index=('reg_loss', 'imagenet21k'))
display.display(corr_matrix)

In [None]:
# Correlation of every metric with the upstream (JFT) test loss, test accuracy
# and training loss: one row per upstream metric, one column per
# (metric, dataset) pair.
corr_matrix = temp_df.corr()[['test_loss', 'test_prec@1', 'training_loss']].T.xs('jft/entity:1.0.0', level='config.dataset')

# Rename certain task metrics to be under their generic metric name. This way,
# we can average values across that metric.
corr_matrix.columns = corr_matrix.columns.values
corr_matrix.columns = pd.MultiIndex.from_tuples(corr_matrix.rename(columns={
    ('imagenet_real_calib_auc', 'imagenet2012'): ('test_calib_auc', 'imagenet_real'),
    ('imagenet_real_ece', 'imagenet2012'): ('test_ece', 'imagenet_real'),
    ('imagenet_real_loss', 'imagenet2012'): ('test_loss', 'imagenet_real'),
    ('imagenet_real_prec@1', 'imagenet2012'): ('test_prec@1', 'imagenet_real'),
    ('cifar_10h_calib_auc', 'cifar10'): ('test_calib_auc', 'cifar_10h'),
    ('cifar_10h_ece', 'cifar10'): ('test_ece', 'cifar_10h'),
    ('cifar_10h_loss', 'cifar10'): ('test_loss', 'cifar_10h'),
    ('cifar_10h_prec@1', 'cifar10'): ('test_prec@1', 'cifar_10h'),
    ('ood_cifar100_msp_auroc', 'cifar10'): ('msp_auroc', 'cifar10->cifar100'),
    ('ood_cifar10_msp_auroc', 'cifar100'): ('msp_auroc', 'cifar100->cifar10'),
    ('ood_places365_small_msp_auroc', 'imagenet2012'): ('msp_auroc', 'imagenet2012->places365'),
    ('ood_svhn_cropped_msp_auroc', 'cifar10'): ('msp_auroc', 'cifar10->svhn'),
    ('ood_svhn_cropped_msp_auroc', 'cifar100'): ('msp_auroc', 'cifar100->svhn'),
}))

corr_matrix = corr_matrix.sort_index(axis=1)
# Average over all datasets that share a metric name. `DataFrame.mean(level=0,
# axis='columns')` was removed in pandas 2.0; grouping the transposed frame by
# the first column level is the equivalent that works across versions.
corr_matrix = corr_matrix.T.groupby(level=0, sort=False).mean().T
corr_matrix = abs(corr_matrix)
# Order the metrics by their mean absolute correlation.
corr_matrix = corr_matrix.reindex(
    corr_matrix.mean().sort_values().index, axis=1)
# Compute metrics and the scores themselves are not shown.
for metric in corr_matrix.columns:
  if metric in COMPUTE_METRICS or metric.startswith('score'):
    del corr_matrix[metric]
# One row per metric; the metric name ends up in the `index` column.
corr_matrix = corr_matrix.T.reset_index()

fig, ax = plt.subplots(figsize=(20.0, 5.0))
sns.barplot(x='index', y='test_loss', data=corr_matrix)
ax.set(xlabel=None)
ax.set(ylabel=r'$\rho(\cdot,$ test_loss)')

filename = 'correlation.png'
plt.tight_layout()
plt.savefig(filename)
colabtools.fileedit.download_file(filename)

## Plot Relative Score and Rankings

In [None]:
# Relative scores and rankings are computed on the non-RETINA datasets.
datasets = [
    'cifar10',
    'cifar100',
    'imagenet2012',
]
datasets += [f'few-shot {d}' for d in FEWSHOT_DATASETS]
rel_scores, ranks, ranks_by_category = compute_relative_score_and_ranks(df, datasets=datasets)
print("Average relative score and ranks across categories")
display.display(rel_scores)

# Plot rank distribution
# After the transpose every model is one column, i.e. one violin.
ax = sns.violinplot(data=ranks.T)
ax.set_xticklabels(ax.get_xticklabels(),rotation = 45)
ax.set_ylabel('Ranking')
print("==" * 50)
print("Rankings")
display.display(ranks)

# Same plot, restricted to the metrics of each category, one figure each.
for key in CATEGORIES:
  plt.figure()
  ax = sns.violinplot(data=ranks_by_category[key].T)
  ax.set_xticklabels(ax.get_xticklabels(),rotation = 45)
  ax.set_ylabel('Ranking - %s' % key)

# Plotting helpers

In [None]:
#@title Bar plots
def plot_in_distribution(df, train_dataset, split):
  """Draws one bar chart per split metric, with one bar per model.

  Args:
    df: Dataframe with `model` and `config.dataset` columns and
      `{split}_{metric}` columns.
    train_dataset: Value of `config.dataset` to select.
    split: Prefix of the metric columns, e.g. `test`.
  """
  metric_columns = [
      f'{split}_{suffix}' for suffix in ('loss', 'prec@1', 'ece', 'calib_auc')
  ]
  selected = df[df['config.dataset'] == train_dataset].copy()
  # Long format: one row per (model, metric) pair.
  long_df = selected[['model'] + metric_columns].melt(
      id_vars='model', var_name='metric', value_name='value')
  sns.catplot(
      data=long_df, x='model', y='value', col='metric', kind='bar',
      sharey=False)

def plot_ood(df, train_dataset):
  """Draws OOD-detection AUROC bar charts, one panel per OOD dataset.

  Args:
    df: Dataframe with `model` and `config.dataset` columns and
      `ood_{dataset}_{method}_auroc` columns.
    train_dataset: Value of `config.dataset` to select. It determines which
      OOD datasets and detection methods are looked up.
  """
  selected = df[df['config.dataset'] == train_dataset].copy()
  if train_dataset == 'imagenet2012':
    ood_datasets = {'places365_small'}
    methods = ['msp', 'entropy', 'mlogit']
  else:
    # Every CIFAR/SVHN dataset except the one the model was trained on.
    ood_datasets = {'svhn_cropped', 'cifar100', 'cifar10'} - {train_dataset}
    methods = ['msp', 'entropy', 'mlogit', 'maha', 'rmaha']

  wanted = [
      f'ood_{dataset}_{method}_auroc'
      for dataset in ood_datasets
      for method in methods
  ]
  # Only keep the columns that were actually measured.
  available = list(set(wanted).intersection(selected.columns))
  long_df = selected[['model'] + available].melt(
      id_vars='model', var_name='metric', value_name='AUROC')
  # Column names look like 'ood_{dataset}_{method}_auroc': the second token is
  # the first word of the dataset, the second-to-last token is the method.
  name_tokens = long_df['metric'].apply(lambda name: name.split('_'))
  long_df['dataset'] = name_tokens.apply(lambda tokens: tokens[1])
  long_df['metric'] = name_tokens.apply(lambda tokens: tokens[-2])

  sns.catplot(
      data=long_df, x='metric', y='AUROC', hue='model', kind='bar',
      col='dataset')
  plt.ylim((0.5, 1))


def plot_corrupted(df, train_dataset):
  """Draws one bar chart per relabeled-test-set metric, one bar per model.

  The metrics of `imagenet_real` are used when `train_dataset` is
  `imagenet2012`, and those of `cifar_10h` otherwise.

  Args:
    df: Dataframe with `model` and `config.dataset` columns and the
      corresponding metric columns.
    train_dataset: Value of `config.dataset` to select.
  """
  if train_dataset == 'imagenet2012':
    eval_dataset = 'imagenet_real'
  else:
    eval_dataset = 'cifar_10h'
  metric_columns = [
      f'{eval_dataset}_{suffix}'
      for suffix in ('loss', 'prec@1', 'ece', 'calib_auc')
  ]
  selected = df[df['config.dataset'] == train_dataset].copy()
  long_df = selected[['model'] + metric_columns].melt(
      id_vars='model', var_name='metric', value_name='value')
  sns.catplot(
      data=long_df, x='model', y='value', col='metric', kind='bar',
      sharey=False)

In [None]:
#@title Pareto plots

def is_on_pareto_front(p, points, higher_is_better):
  """Returns whether no point in `points` dominates `p`.

  A point dominates `p` when its first coordinate is smaller than or equal to
  that of `p` and its second coordinate is strictly better: larger if
  `higher_is_better`, smaller otherwise.

  Args:
    p: Pair `(x, y)` to test.
    points: Iterable of `(x, y)` pairs to compare against.
    higher_is_better: Whether larger second coordinates are better.
  """
  if higher_is_better:
    return not any(q[0] <= p[0] and q[1] > p[1] for q in points)
  return not any(q[0] <= p[0] and q[1] < p[1] for q in points)


def get_pareto_points(x, y, higher_is_better):
  """Returns the Pareto-optimal `(x, y)` pairs sorted by increasing `x`.

  Args:
    x: Iterable of first coordinates.
    y: Iterable of second coordinates, paired with `x` by position.
    higher_is_better: Whether larger `y` values are better.
  """
  candidates = list(zip(x, y))
  frontier = []
  for candidate in candidates:
    if is_on_pareto_front(candidate, candidates, higher_is_better):
      frontier.append(candidate)
  frontier.sort(key=lambda point: point[0])
  return frontier


def plot_fn(data, x, y, **kws):
  """Draws one facet: labeled scatter points plus their Pareto frontier.

  Args:
    data: Dataframe of the facet with `model` and `metric` columns and the
      columns named by `x` and `y`.
    x: Column for the x axis.
    y: Column for the y axis.
    **kws: Accepted for compatibility with the caller; unused.
  """
  axes = plt.gca()
  sns.scatterplot(data=data, x=x, y=y, hue='model')
  # Label every point with its model name.
  for _, row in data.iterrows():
    axes.annotate(
        '  ' + row['model'],
        xy=(row[x], row[y]),
        ha='left',
        va='bottom',
    )

  # Accuracy- and AUC-like metrics are maximized; everything else is minimized.
  metric_name = data['metric'].iloc[0]
  prefers_high = any(token in metric_name for token in ('prec', 'auc'))
  frontier = get_pareto_points(
      data[x], data[y], higher_is_better=prefers_high)
  frontier_x, frontier_y = zip(*frontier)
  sns.lineplot(x=frontier_x, y=frontier_y, linestyle='--')

def pareto_plot(df, metrics, train_dataset=None,
                xmetric='num_params', xlabel='Log # Params'):
  """Draws one Pareto plot per metric against a cost metric.

  Args:
    df: Dataframe with `model` and `config.dataset` columns, the `xmetric`
      column and one column per entry of `metrics`.
    metrics: Metric columns to plot, one facet each.
    train_dataset: Value of `config.dataset` to select.
    xmetric: Column used for the (log-scaled) x axis.
    xlabel: Label of the x axis.

  Returns:
    The `sns.FacetGrid` holding the plots. Callers assign the result, so the
    grid is returned instead of the implicit None.
  """
  df = df[df['config.dataset'] == train_dataset].copy()
  # Per-column mean of every (model, dataset, xmetric) group. `.mean()` is
  # used instead of `.apply(np.mean)` because the result of applying `np.mean`
  # to a whole group frame depends on the pandas version.
  df = df.groupby(['model', 'config.dataset', xmetric]
                  )[metrics].mean().reset_index()
  df = df.melt(
      id_vars=['model', 'config.dataset', xmetric],
      var_name='metric',
      value_name='value')

  # `height` is the current name of the facet-size argument (formerly `size`).
  g = sns.FacetGrid(data=df, col='metric', sharey=False, height=5)
  g.map_dataframe(plot_fn, x=xmetric, y='value')
  g.set_xlabels(xlabel)
  g.set(xscale='log')
  return g

# Results

In [None]:
#@title Upstream JFT
# Bar charts of the upstream JFT validation metrics and ImageNet 10-shot
# accuracy per model.
df = measurements.copy()
df = df[df['config.dataset'] == 'jft/entity:1.0.0']
df = df[['model', 'val_loss', 'val_prec@1', 'a/imagenet_10shot']].melt(
    id_vars='model', var_name='metric', value_name='value')
sns.catplot(
    col='metric', data=df, x='model', y='value', kind='bar', sharey=False)
# Same metrics against the default x metric.
g = pareto_plot(
    measurements,
    train_dataset='jft/entity:1.0.0',
    metrics=['val_loss', 'val_prec@1', 'a/imagenet_10shot'],
)
# Same metrics against training compute.
g = pareto_plot(
    measurements,
    train_dataset='jft/entity:1.0.0',
    metrics=['val_loss', 'val_prec@1', 'a/imagenet_10shot'],
    xmetric='tpu_days',
    xlabel='Compute (TPUv3 core days)',
)

## Cifar 10

In [None]:
#@title In-distribution
# CIFAR-10 test metrics: bar charts per model, then Pareto plots.
plot_in_distribution(measurements, train_dataset='cifar10', split='test')
g = pareto_plot(
    measurements,
    train_dataset='cifar10',
    metrics=['test_loss', 'test_prec@1', 'test_ece', 'test_calib_auc'])

In [None]:
#@title Cifar10h
# CIFAR-10H metrics of the CIFAR-10 models: bar charts, then Pareto plots.
plot_corrupted(measurements, train_dataset='cifar10')
g = pareto_plot(
    measurements,
    train_dataset='cifar10',
    metrics=['cifar_10h_loss', 'cifar_10h_prec@1', 'cifar_10h_ece', 'cifar_10h_calib_auc'])

In [None]:
#@title OOD
# OOD-detection AUROC of the CIFAR-10 models.
plot_ood(measurements, train_dataset='cifar10')

## Cifar100

In [None]:
#@title In-distribution
# CIFAR-100 test metrics: bar charts per model, then Pareto plots.
plot_in_distribution(measurements, train_dataset='cifar100', split='test')
g = pareto_plot(
    measurements,
    train_dataset='cifar100',
    metrics=['test_loss', 'test_prec@1', 'test_ece', 'test_calib_auc'])

In [None]:
#@title OOD
# OOD-detection AUROC of the CIFAR-100 models.
plot_ood(measurements, train_dataset='cifar100')

## Imagenet

In [None]:
#@title In-distribution
# ImageNet test metrics: bar charts per model, then Pareto plots.
plot_in_distribution(measurements, train_dataset='imagenet2012', split='test')
g = pareto_plot(
    measurements,
    train_dataset='imagenet2012',
    metrics=['test_loss', 'test_prec@1', 'test_ece', 'test_calib_auc'])

In [None]:
#@title Imagenet Real
# ImageNet ReaL metrics of the ImageNet models: bar charts, then Pareto plots.
plot_corrupted(measurements, train_dataset='imagenet2012')
g = pareto_plot(
    measurements,
    train_dataset='imagenet2012',
    metrics=[
        'imagenet_real_loss', 'imagenet_real_prec@1', 'imagenet_real_ece',
        'imagenet_real_calib_auc'
    ])