##### Copyright 2020 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");

In [None]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Meta-Dataset leaderboard

This notebook computes leaderboard tables for different models on Meta-Dataset.

Results for each model in each training setting (ImageNet-only or all datasets) are defined in a different DataFrame. This script aggregates the data in one DataFrame, ranks the models in each setting (using a statistical test for equality), and produces the final tables.

In [None]:
import numpy as np
import pandas as pd

In [None]:
# Evaluation datasets, in the row order used by every results table below.
# "ILSVRC (valid)" is listed for completeness only; reporting it is optional.
datasets = [
    'ILSVRC (valid)',
    'ILSVRC (test)',
    'Omniglot',
    'Aircraft',
    'Birds',
    'Textures',
    'QuickDraw',
    'Fungi',
    'VGG Flower',
    'Traffic signs',
    'MSCOCO',
]

## Results from Triantafillou et al. (ICLR 2020)

### k-NN (`baseline`)

In [None]:
# k-NN, ImageNet-only training: one (initially empty) row per evaluation
# dataset; the '# episodes' column is set to 600 for every row.
baseline_imagenet_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
baseline_imagenet_df['# episodes'] = 600

In [None]:
# Fill in every dataset except "ILSVRC (valid)", in the order of `datasets[1:]`.
# Rows are selected by label: an integer slice such as `.loc[1:]` is not a
# valid label slice on this string index.
baseline_imagenet_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [41.03, 1.01],
    [37.07, 1.15],
    [46.81, 0.89],
    [50.13, 1.00],
    [66.36, 0.75],
    [32.06, 1.08],
    [36.16, 1.02],
    [83.10, 0.68],
    [44.59, 1.19],
    [30.38, 0.99]
]

In [None]:
# Display the table to check the values entered above.
baseline_imagenet_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),41.03,1.01,600
Omniglot,37.07,1.15,600
Aircraft,46.81,0.89,600
Birds,50.13,1.0,600
Textures,66.36,0.75,600
QuickDraw,32.06,1.08,600
Fungi,36.16,1.02,600
VGG Flower,83.1,0.68,600
Traffic signs,44.59,1.19,600


In [None]:
# k-NN, trained on all datasets.
baseline_all_df = pd.DataFrame(
    columns=['mean (%)', '95% CI', '# episodes'],
    index=datasets
)
baseline_all_df['# episodes'] = 600
# Fill in every dataset except "ILSVRC (valid)", in the order of `datasets[1:]`.
# Rows are selected by label: an integer slice such as `.loc[1:]` is not a
# valid label slice on this string index.
baseline_all_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [38.55, 0.94],
    [74.60, 1.08],
    [64.98, 0.82],
    [66.35, 0.92],
    [63.58, 0.79],
    [44.88, 1.05],
    [37.12, 1.06],
    [83.47, 0.61],
    [40.11, 1.10],
    [29.55, 0.96]
]
baseline_all_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),38.55,0.94,600
Omniglot,74.6,1.08,600
Aircraft,64.98,0.82,600
Birds,66.35,0.92,600
Textures,63.58,0.79,600
QuickDraw,44.88,1.05,600
Fungi,37.12,1.06,600
VGG Flower,83.47,0.61,600
Traffic signs,40.11,1.1,600


### Finetune (`baselinefinetune`)

In [None]:
# Finetune, ImageNet-only training.
baselineft_imagenet_df = pd.DataFrame(
    columns=['mean (%)', '95% CI', '# episodes'],
    index=datasets
)
baselineft_imagenet_df['# episodes'] = 600
# Fill in every dataset except "ILSVRC (valid)", in the order of `datasets[1:]`.
# Rows are selected by label: an integer slice such as `.loc[1:]` is not a
# valid label slice on this string index.
baselineft_imagenet_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [45.78, 1.10],
    [60.85, 1.58],
    [68.69, 1.26],
    [57.31, 1.26],
    [69.05, 0.90],
    [42.60, 1.17],
    [38.20, 1.02],
    [85.51, 0.68],
    [66.79, 1.31],
    [34.86, 0.97]
]
baselineft_imagenet_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),45.78,1.1,600
Omniglot,60.85,1.58,600
Aircraft,68.69,1.26,600
Birds,57.31,1.26,600
Textures,69.05,0.9,600
QuickDraw,42.6,1.17,600
Fungi,38.2,1.02,600
VGG Flower,85.51,0.68,600
Traffic signs,66.79,1.31,600


In [None]:
# Finetune, trained on all datasets.
baselineft_all_df = pd.DataFrame(
    columns=['mean (%)', '95% CI', '# episodes'],
    index=datasets
)
baselineft_all_df['# episodes'] = 600
# Fill in every dataset except "ILSVRC (valid)", in the order of `datasets[1:]`.
# Rows are selected by label: an integer slice such as `.loc[1:]` is not a
# valid label slice on this string index.
baselineft_all_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [43.08, 1.08],
    [71.11, 1.37],
    [72.03, 1.07],
    [59.82, 1.15],
    [69.14, 0.85],
    [47.05, 1.16],
    [38.16, 1.04],
    [85.28, 0.69],
    [66.74, 1.23],
    [35.17, 1.08]
]
baselineft_all_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),43.08,1.08,600
Omniglot,71.11,1.37,600
Aircraft,72.03,1.07,600
Birds,59.82,1.15,600
Textures,69.14,0.85,600
QuickDraw,47.05,1.16,600
Fungi,38.16,1.04,600
VGG Flower,85.28,0.69,600
Traffic signs,66.74,1.23,600


### MatchingNet (`matching`)

In [None]:
# MatchingNet, ImageNet-only training.
matching_imagenet_df = pd.DataFrame(
    columns=['mean (%)', '95% CI', '# episodes'],
    index=datasets
)
matching_imagenet_df['# episodes'] = 600
# Fill in every dataset except "ILSVRC (valid)", in the order of `datasets[1:]`.
# Rows are selected by label: an integer slice such as `.loc[1:]` is not a
# valid label slice on this string index.
matching_imagenet_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [45.00, 1.10],
    [52.27, 1.28],
    [48.97, 0.93],
    [62.21, 0.95],
    [64.15, 0.85],
    [42.87, 1.09],
    [33.97, 1.00],
    [80.13, 0.71],
    [47.80, 1.14],
    [34.99, 1.00]
]
matching_imagenet_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),45.0,1.1,600
Omniglot,52.27,1.28,600
Aircraft,48.97,0.93,600
Birds,62.21,0.95,600
Textures,64.15,0.85,600
QuickDraw,42.87,1.09,600
Fungi,33.97,1.0,600
VGG Flower,80.13,0.71,600
Traffic signs,47.8,1.14,600


In [None]:
# MatchingNet, trained on all datasets.
matching_all_df = pd.DataFrame(
    columns=['mean (%)', '95% CI', '# episodes'],
    index=datasets
)
matching_all_df['# episodes'] = 600
# Fill in every dataset except "ILSVRC (valid)", in the order of `datasets[1:]`.
# Rows are selected by label: an integer slice such as `.loc[1:]` is not a
# valid label slice on this string index.
matching_all_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [36.08, 1.00],
    [78.25, 1.01],
    [69.17, 0.96],
    [56.40, 1.00],
    [61.80, 0.74],
    [60.81, 1.03],
    [33.70, 1.04],
    [81.90, 0.72],
    [55.57, 1.08],
    [28.79, 0.96]
]
matching_all_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),36.08,1.0,600
Omniglot,78.25,1.01,600
Aircraft,69.17,0.96,600
Birds,56.4,1.0,600
Textures,61.8,0.74,600
QuickDraw,60.81,1.03,600
Fungi,33.7,1.04,600
VGG Flower,81.9,0.72,600
Traffic signs,55.57,1.08,600


### ProtoNet (`prototypical`)

In [None]:
# ProtoNet, ImageNet-only training.
prototypical_imagenet_df = pd.DataFrame(
    columns=['mean (%)', '95% CI', '# episodes'],
    index=datasets
)
prototypical_imagenet_df['# episodes'] = 600
# Fill in every dataset except "ILSVRC (valid)", in the order of `datasets[1:]`.
# Rows are selected by label: an integer slice such as `.loc[1:]` is not a
# valid label slice on this string index.
prototypical_imagenet_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [50.50, 1.08],
    [59.98, 1.35],
    [53.10, 1.00],
    [68.79, 1.01],
    [66.56, 0.83],
    [48.96, 1.08],
    [39.71, 1.11],
    [85.27, 0.77],
    [47.12, 1.10],
    [41.00, 1.10]
]
prototypical_imagenet_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),50.5,1.08,600
Omniglot,59.98,1.35,600
Aircraft,53.1,1.0,600
Birds,68.79,1.01,600
Textures,66.56,0.83,600
QuickDraw,48.96,1.08,600
Fungi,39.71,1.11,600
VGG Flower,85.27,0.77,600
Traffic signs,47.12,1.1,600


In [None]:
# ProtoNet, trained on all datasets.
prototypical_all_df = pd.DataFrame(
    columns=['mean (%)', '95% CI', '# episodes'],
    index=datasets
)
prototypical_all_df['# episodes'] = 600
# Fill in every dataset except "ILSVRC (valid)", in the order of `datasets[1:]`.
# Rows are selected by label: an integer slice such as `.loc[1:]` is not a
# valid label slice on this string index.
prototypical_all_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [44.50, 1.05],
    [79.56, 1.12],
    [71.14, 0.86],
    [67.01, 1.02],
    [65.18, 0.84],
    [64.88, 0.89],
    [40.26, 1.13],
    [86.85, 0.71],
    [46.48, 1.00],
    [39.87, 1.06]
]
prototypical_all_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),44.5,1.05,600
Omniglot,79.56,1.12,600
Aircraft,71.14,0.86,600
Birds,67.01,1.02,600
Textures,65.18,0.84,600
QuickDraw,64.88,0.89,600
Fungi,40.26,1.13,600
VGG Flower,86.85,0.71,600
Traffic signs,46.48,1.0,600


### fo-MAML (`maml`)

In [None]:
# fo-MAML, ImageNet-only training.
maml_imagenet_df = pd.DataFrame(
    columns=['mean (%)', '95% CI', '# episodes'],
    index=datasets
)
maml_imagenet_df['# episodes'] = 600
# Fill in every dataset except "ILSVRC (valid)", in the order of `datasets[1:]`.
# Rows are selected by label: an integer slice such as `.loc[1:]` is not a
# valid label slice on this string index.
maml_imagenet_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [45.51, 1.11],
    [55.55, 1.54],
    [56.24, 1.11],
    [63.61, 1.06],
    [68.04, 0.81],
    [43.96, 1.29],
    [32.10, 1.10],
    [81.74, 0.83],
    [50.93, 1.51],
    [35.30, 1.23]
]
maml_imagenet_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),45.51,1.11,600
Omniglot,55.55,1.54,600
Aircraft,56.24,1.11,600
Birds,63.61,1.06,600
Textures,68.04,0.81,600
QuickDraw,43.96,1.29,600
Fungi,32.1,1.1,600
VGG Flower,81.74,0.83,600
Traffic signs,50.93,1.51,600


In [None]:
# fo-MAML, trained on all datasets.
maml_all_df = pd.DataFrame(
    columns=['mean (%)', '95% CI', '# episodes'],
    index=datasets
)
maml_all_df['# episodes'] = 600
# Fill in every dataset except "ILSVRC (valid)", in the order of `datasets[1:]`.
# Rows are selected by label: an integer slice such as `.loc[1:]` is not a
# valid label slice on this string index.
maml_all_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [37.83, 1.01],
    [83.92, 0.95],
    [76.41, 0.69],
    [62.43, 1.08],
    [64.16, 0.83],
    [59.73, 1.10],
    [33.54, 1.11],
    [79.94, 0.84],
    [42.91, 1.31],
    [29.37, 1.08]
]
maml_all_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),37.83,1.01,600
Omniglot,83.92,0.95,600
Aircraft,76.41,0.69,600
Birds,62.43,1.08,600
Textures,64.16,0.83,600
QuickDraw,59.73,1.1,600
Fungi,33.54,1.11,600
VGG Flower,79.94,0.84,600
Traffic signs,42.91,1.31,600


### RelationNet (`relationnet`)

In [None]:
# RelationNet, ImageNet-only training.
relationnet_imagenet_df = pd.DataFrame(
    columns=['mean (%)', '95% CI', '# episodes'],
    index=datasets
)
relationnet_imagenet_df['# episodes'] = 600
# Fill in every dataset except "ILSVRC (valid)", in the order of `datasets[1:]`.
# Rows are selected by label: an integer slice such as `.loc[1:]` is not a
# valid label slice on this string index.
relationnet_imagenet_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [34.69, 1.01],
    [45.35, 1.36],
    [40.73, 0.83],
    [49.51, 1.05],
    [52.97, 0.69],
    [43.30, 1.08],
    [30.55, 1.04],
    [68.76, 0.83],
    [33.67, 1.05],
    [29.15, 1.01]
]
relationnet_imagenet_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),34.69,1.01,600
Omniglot,45.35,1.36,600
Aircraft,40.73,0.83,600
Birds,49.51,1.05,600
Textures,52.97,0.69,600
QuickDraw,43.3,1.08,600
Fungi,30.55,1.04,600
VGG Flower,68.76,0.83,600
Traffic signs,33.67,1.05,600


In [None]:
# RelationNet, trained on all datasets.
relationnet_all_df = pd.DataFrame(
    columns=['mean (%)', '95% CI', '# episodes'],
    index=datasets
)
relationnet_all_df['# episodes'] = 600
# Fill in every dataset except "ILSVRC (valid)", in the order of `datasets[1:]`.
# Rows are selected by label: an integer slice such as `.loc[1:]` is not a
# valid label slice on this string index.
relationnet_all_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [30.89, 0.93],
    [86.57, 0.79],
    [69.71, 0.83],
    [54.14, 0.99],
    [56.56, 0.73],
    [61.75, 0.97],
    [32.56, 1.08],
    [76.08, 0.76],
    [37.48, 0.93],
    [27.41, 0.89]
]
relationnet_all_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),30.89,0.93,600
Omniglot,86.57,0.79,600
Aircraft,69.71,0.83,600
Birds,54.14,0.99,600
Textures,56.56,0.73,600
QuickDraw,61.75,0.97,600
Fungi,32.56,1.08,600
VGG Flower,76.08,0.76,600
Traffic signs,37.48,0.93,600


### fo-Proto-MAML (`maml_init_with_proto`)

In [None]:
# fo-Proto-MAML, ImageNet-only training.
protomaml_imagenet_df = pd.DataFrame(
    columns=['mean (%)', '95% CI', '# episodes'],
    index=datasets
)
protomaml_imagenet_df['# episodes'] = 600
# Fill in every dataset except "ILSVRC (valid)", in the order of `datasets[1:]`.
# Rows are selected by label: an integer slice such as `.loc[1:]` is not a
# valid label slice on this string index.
protomaml_imagenet_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [49.53, 1.05],
    [63.37, 1.33],
    [55.95, 0.99],
    [68.66, 0.96],
    [66.49, 0.83],
    [51.52, 1.00],
    [39.96, 1.14],
    [87.15, 0.69],
    [48.83, 1.09],
    [43.74, 1.12],
]
protomaml_imagenet_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),49.53,1.05,600
Omniglot,63.37,1.33,600
Aircraft,55.95,0.99,600
Birds,68.66,0.96,600
Textures,66.49,0.83,600
QuickDraw,51.52,1.0,600
Fungi,39.96,1.14,600
VGG Flower,87.15,0.69,600
Traffic signs,48.83,1.09,600


In [None]:
# fo-Proto-MAML, trained on all datasets.
protomaml_all_df = pd.DataFrame(
    columns=['mean (%)', '95% CI', '# episodes'],
    index=datasets
)
protomaml_all_df['# episodes'] = 600
# Fill in every dataset except "ILSVRC (valid)", in the order of `datasets[1:]`.
# Rows are selected by label: an integer slice such as `.loc[1:]` is not a
# valid label slice on this string index.
protomaml_all_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [46.52, 1.05],
    [82.69, 0.97],
    [75.23, 0.76],
    [69.88, 1.02],
    [68.25, 0.81],
    [66.84, 0.94],
    [41.99, 1.17],
    [88.72, 0.67],
    [52.42, 1.08],
    [41.74, 1.13]
]
protomaml_all_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),46.52,1.05,600
Omniglot,82.69,0.97,600
Aircraft,75.23,0.76,600
Birds,69.88,1.02,600
Textures,68.25,0.81,600
QuickDraw,66.84,0.94,600
Fungi,41.99,1.17,600
VGG Flower,88.72,0.67,600
Traffic signs,52.42,1.08,600


## Template to add a new model

In [None]:
# Template only: fill in the <...> placeholders and the `...` values before
# running; as written this cell is not valid Python.
<model_name>_<train_source>_df = pd.DataFrame(
    columns=['mean (%)', '95% CI', '# episodes'],
    index=<eval_datasets>
)
<model_name>_<train_source>_df['# episodes'] = ...
<model_name>_<train_source>_df.loc[..., ['mean (%)', '95% CI']] = [...]

## Aggregate in table

In [None]:
# Display name -> results table, for the ImageNet-only training setting.
# The insertion order is the row order requested for the final tables.
imagenet_dfs = {
    'k-NN': baseline_imagenet_df,
    'Finetune': baselineft_imagenet_df,
    'MatchingNet': matching_imagenet_df,
    'ProtoNet': prototypical_imagenet_df,
    'fo-MAML': maml_imagenet_df,
    'RelationNet': relationnet_imagenet_df,
    'fo-Proto-MAML': protomaml_imagenet_df,
}

In [None]:
# Put the per-model tables side by side: the columns become a
# (model, statistic) MultiIndex while the rows stay the evaluation datasets.
imagenet_df = pd.concat(
    list(imagenet_dfs.values()),
    keys=list(imagenet_dfs.keys()),
    axis=1,
)
imagenet_df

Unnamed: 0_level_0,k-NN,k-NN,k-NN,Finetune,Finetune,Finetune,MatchingNet,MatchingNet,MatchingNet,ProtoNet,ProtoNet,ProtoNet,fo-MAML,fo-MAML,fo-MAML,RelationNet,RelationNet,RelationNet,fo-Proto-MAML,fo-Proto-MAML,fo-Proto-MAML
Unnamed: 0_level_1,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes
ILSVRC (valid),,,600,,,600,,,600,,,600,,,600,,,600,,,600
ILSVRC (test),41.03,1.01,600,45.78,1.1,600,45.0,1.1,600,50.5,1.08,600,45.51,1.11,600,34.69,1.01,600,49.53,1.05,600
Omniglot,37.07,1.15,600,60.85,1.58,600,52.27,1.28,600,59.98,1.35,600,55.55,1.54,600,45.35,1.36,600,63.37,1.33,600
Aircraft,46.81,0.89,600,68.69,1.26,600,48.97,0.93,600,53.1,1.0,600,56.24,1.11,600,40.73,0.83,600,55.95,0.99,600
Birds,50.13,1.0,600,57.31,1.26,600,62.21,0.95,600,68.79,1.01,600,63.61,1.06,600,49.51,1.05,600,68.66,0.96,600
Textures,66.36,0.75,600,69.05,0.9,600,64.15,0.85,600,66.56,0.83,600,68.04,0.81,600,52.97,0.69,600,66.49,0.83,600
QuickDraw,32.06,1.08,600,42.6,1.17,600,42.87,1.09,600,48.96,1.08,600,43.96,1.29,600,43.3,1.08,600,51.52,1.0,600
Fungi,36.16,1.02,600,38.2,1.02,600,33.97,1.0,600,39.71,1.11,600,32.1,1.1,600,30.55,1.04,600,39.96,1.14,600
VGG Flower,83.1,0.68,600,85.51,0.68,600,80.13,0.71,600,85.27,0.77,600,81.74,0.83,600,68.76,0.83,600,87.15,0.69,600
Traffic signs,44.59,1.19,600,66.79,1.31,600,47.8,1.14,600,47.12,1.1,600,50.93,1.51,600,33.67,1.05,600,48.83,1.09,600


In [None]:
# Display name -> results table, for the setting trained on all datasets.
# The insertion order is the row order requested for the final tables.
all_dfs = {
    'k-NN': baseline_all_df,
    'Finetune': baselineft_all_df,
    'MatchingNet': matching_all_df,
    'ProtoNet': prototypical_all_df,
    'fo-MAML': maml_all_df,
    'RelationNet': relationnet_all_df,
    'fo-Proto-MAML': protomaml_all_df,
}
# Put the per-model tables side by side: the columns become a
# (model, statistic) MultiIndex while the rows stay the evaluation datasets.
all_df = pd.concat(
    list(all_dfs.values()),
    keys=list(all_dfs.keys()),
    axis=1,
)
all_df

Unnamed: 0_level_0,k-NN,k-NN,k-NN,Finetune,Finetune,Finetune,MatchingNet,MatchingNet,MatchingNet,ProtoNet,ProtoNet,ProtoNet,fo-MAML,fo-MAML,fo-MAML,RelationNet,RelationNet,RelationNet,fo-Proto-MAML,fo-Proto-MAML,fo-Proto-MAML
Unnamed: 0_level_1,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes
ILSVRC (valid),,,600,,,600,,,600,,,600,,,600,,,600,,,600
ILSVRC (test),38.55,0.94,600,43.08,1.08,600,36.08,1.0,600,44.5,1.05,600,37.83,1.01,600,30.89,0.93,600,46.52,1.05,600
Omniglot,74.6,1.08,600,71.11,1.37,600,78.25,1.01,600,79.56,1.12,600,83.92,0.95,600,86.57,0.79,600,82.69,0.97,600
Aircraft,64.98,0.82,600,72.03,1.07,600,69.17,0.96,600,71.14,0.86,600,76.41,0.69,600,69.71,0.83,600,75.23,0.76,600
Birds,66.35,0.92,600,59.82,1.15,600,56.4,1.0,600,67.01,1.02,600,62.43,1.08,600,54.14,0.99,600,69.88,1.02,600
Textures,63.58,0.79,600,69.14,0.85,600,61.8,0.74,600,65.18,0.84,600,64.16,0.83,600,56.56,0.73,600,68.25,0.81,600
QuickDraw,44.88,1.05,600,47.05,1.16,600,60.81,1.03,600,64.88,0.89,600,59.73,1.1,600,61.75,0.97,600,66.84,0.94,600
Fungi,37.12,1.06,600,38.16,1.04,600,33.7,1.04,600,40.26,1.13,600,33.54,1.11,600,32.56,1.08,600,41.99,1.17,600
VGG Flower,83.47,0.61,600,85.28,0.69,600,81.9,0.72,600,86.85,0.71,600,79.94,0.84,600,76.08,0.76,600,88.72,0.67,600
Traffic signs,40.11,1.1,600,66.74,1.23,600,55.57,1.08,600,46.48,1.0,600,42.91,1.31,600,37.48,0.93,600,52.42,1.08,600


### Add stddev

In [None]:
def add_stddev(df):
  """Adds a 'stddev' statistic for every model.

  The standard deviation is recovered from the 95% confidence interval by
  inverting `CI = 1.96 * stddev / sqrt(# episodes)`.

  Args:
    df: DataFrame with one row per dataset and two-level columns
      (model, statistic), the statistics including '95% CI' and '# episodes'.

  Returns:
    A new DataFrame with the same rows and the models in the same order as in
    `df`, each model gaining a 'stddev' column.
  """
  # Remember the order of the labels as they appear in `df`. The models are
  # read from the column values rather than from `columns.levels[0]`, because
  # the order of MultiIndex levels need not match the order of the columns.
  dataset_order = df.index
  model_order = df.columns.get_level_values(0).unique()
  # Have only one result (mean, CI, ...) per row, indexed by (dataset, model)
  stacked_df = df.stack(0)
  # Add 'stddev' as column
  stacked_df['stddev'] = (
      stacked_df['95% CI'] * np.sqrt(stacked_df['# episodes']) / 1.96)
  # Reshape back to (model, statistic) columns and restore the original order
  new_df = stacked_df.unstack().swaplevel(0, 1, axis=1)
  new_df = new_df.loc[dataset_order][model_order]
  return new_df

In [None]:
# Add a 'stddev' column for every model (ImageNet-only setting), then display.
imagenet_df = add_stddev(imagenet_df)
imagenet_df

Unnamed: 0_level_0,k-NN,k-NN,k-NN,k-NN,Finetune,Finetune,Finetune,Finetune,MatchingNet,MatchingNet,MatchingNet,MatchingNet,ProtoNet,ProtoNet,ProtoNet,ProtoNet,fo-MAML,fo-MAML,fo-MAML,fo-MAML,RelationNet,RelationNet,RelationNet,RelationNet,fo-Proto-MAML,fo-Proto-MAML,fo-Proto-MAML,fo-Proto-MAML
Unnamed: 0_level_1,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev
ILSVRC (valid),600,,,,600,,,,600,,,,600,,,,600,,,,600,,,,600,,,
ILSVRC (test),600,1.01,41.03,12.6224,600,1.1,45.78,13.7471,600,1.1,45.0,13.7471,600,1.08,50.5,13.4972,600,1.11,45.51,13.8721,600,1.01,34.69,12.6224,600,1.05,49.53,13.1223
Omniglot,600,1.15,37.07,14.372,600,1.58,60.85,19.7459,600,1.28,52.27,15.9967,600,1.35,59.98,16.8715,600,1.54,55.55,19.246,600,1.36,45.35,16.9965,600,1.33,63.37,16.6215
Aircraft,600,0.89,46.81,11.1227,600,1.26,68.69,15.7467,600,0.93,48.97,11.6226,600,1.0,53.1,12.4974,600,1.11,56.24,13.8721,600,0.83,40.73,10.3728,600,0.99,55.95,12.3724
Birds,600,1.0,50.13,12.4974,600,1.26,57.31,15.7467,600,0.95,62.21,11.8725,600,1.01,68.79,12.6224,600,1.06,63.61,13.2472,600,1.05,49.51,13.1223,600,0.96,68.66,11.9975
Textures,600,0.75,66.36,9.37305,600,0.9,69.05,11.2477,600,0.85,64.15,10.6228,600,0.83,66.56,10.3728,600,0.81,68.04,10.1229,600,0.69,52.97,8.6232,600,0.83,66.49,10.3728
QuickDraw,600,1.08,32.06,13.4972,600,1.17,42.6,14.622,600,1.09,42.87,13.6222,600,1.08,48.96,13.4972,600,1.29,43.96,16.1216,600,1.08,43.3,13.4972,600,1.0,51.52,12.4974
Fungi,600,1.02,36.16,12.7473,600,1.02,38.2,12.7473,600,1.0,33.97,12.4974,600,1.11,39.71,13.8721,600,1.1,32.1,13.7471,600,1.04,30.55,12.9973,600,1.14,39.96,14.247
VGG Flower,600,0.68,83.1,8.49823,600,0.68,85.51,8.49823,600,0.71,80.13,8.87315,600,0.77,85.27,9.623,600,0.83,81.74,10.3728,600,0.83,68.76,10.3728,600,0.69,87.15,8.6232
Traffic signs,600,1.19,44.59,14.8719,600,1.31,66.79,16.3716,600,1.14,47.8,14.247,600,1.1,47.12,13.7471,600,1.51,50.93,18.8711,600,1.05,33.67,13.1223,600,1.09,48.83,13.6222


In [None]:
# Add a 'stddev' column for every model (all-datasets setting), then display.
all_df = add_stddev(all_df)
all_df

Unnamed: 0_level_0,k-NN,k-NN,k-NN,k-NN,Finetune,Finetune,Finetune,Finetune,MatchingNet,MatchingNet,MatchingNet,MatchingNet,ProtoNet,ProtoNet,ProtoNet,ProtoNet,fo-MAML,fo-MAML,fo-MAML,fo-MAML,RelationNet,RelationNet,RelationNet,RelationNet,fo-Proto-MAML,fo-Proto-MAML,fo-Proto-MAML,fo-Proto-MAML
Unnamed: 0_level_1,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev
ILSVRC (valid),600,,,,600,,,,600,,,,600,,,,600,,,,600,,,,600,,,
ILSVRC (test),600,0.94,38.55,11.7476,600,1.08,43.08,13.4972,600,1.0,36.08,12.4974,600,1.05,44.5,13.1223,600,1.01,37.83,12.6224,600,0.93,30.89,11.6226,600,1.05,46.52,13.1223
Omniglot,600,1.08,74.6,13.4972,600,1.37,71.11,17.1214,600,1.01,78.25,12.6224,600,1.12,79.56,13.9971,600,0.95,83.92,11.8725,600,0.79,86.57,9.87294,600,0.97,82.69,12.1225
Aircraft,600,0.82,64.98,10.2479,600,1.07,72.03,13.3722,600,0.96,69.17,11.9975,600,0.86,71.14,10.7478,600,0.69,76.41,8.6232,600,0.83,69.71,10.3728,600,0.76,75.23,9.49802
Birds,600,0.92,66.35,11.4976,600,1.15,59.82,14.372,600,1.0,56.4,12.4974,600,1.02,67.01,12.7473,600,1.08,62.43,13.4972,600,0.99,54.14,12.3724,600,1.02,69.88,12.7473
Textures,600,0.79,63.58,9.87294,600,0.85,69.14,10.6228,600,0.74,61.8,9.24807,600,0.84,65.18,10.4978,600,0.83,64.16,10.3728,600,0.73,56.56,9.1231,600,0.81,68.25,10.1229
QuickDraw,600,1.05,44.88,13.1223,600,1.16,47.05,14.497,600,1.03,60.81,12.8723,600,0.89,64.88,11.1227,600,1.1,59.73,13.7471,600,0.97,61.75,12.1225,600,0.94,66.84,11.7476
Fungi,600,1.06,37.12,13.2472,600,1.04,38.16,12.9973,600,1.04,33.7,12.9973,600,1.13,40.26,14.1221,600,1.11,33.54,13.8721,600,1.08,32.56,13.4972,600,1.17,41.99,14.622
VGG Flower,600,0.61,83.47,7.62341,600,0.69,85.28,8.6232,600,0.72,81.9,8.99813,600,0.71,86.85,8.87315,600,0.84,79.94,10.4978,600,0.76,76.08,9.49802,600,0.67,88.72,8.37326
Traffic signs,600,1.1,40.11,13.7471,600,1.23,66.74,15.3718,600,1.08,55.57,13.4972,600,1.0,46.48,12.4974,600,1.31,42.91,16.3716,600,0.93,37.48,11.6226,600,1.08,52.42,13.4972


### Add rankings

In [None]:
def is_difference_significant(best_stats, candidate_stats):
  """Tells whether two mean accuracies differ at the 95% confidence level.

  Args:
    best_stats: Mapping (e.g. a Series) with 'mean (%)', 'stddev' and
      '# episodes' entries.
    candidate_stats: Same entries, for the model compared to `best_stats`.

  Returns:
    Whether the absolute difference of the two means is larger than the
    half-width of the 95% confidence interval of that difference.
  """
  # Squared standard error of each mean: variance / number of episodes.
  best_sq_err = (best_stats['stddev'] ** 2) / best_stats['# episodes']
  candidate_sq_err = (
      (candidate_stats['stddev'] ** 2) / candidate_stats['# episodes'])
  half_width = 1.96 * np.sqrt(best_sq_err + candidate_sq_err)
  gap = best_stats['mean (%)'] - candidate_stats['mean (%)']
  return np.abs(gap) > half_width

In [None]:
def compute_ranks(dataset_series):
  """Ranks the models on one dataset, with ties sharing an averaged rank.

  Args:
    dataset_series: Series holding the results of every model on a single
      dataset, indexed by (model, statistic); it is unstacked into one row per
      model, from which 'mean (%)' (and the statistics needed by
      `is_difference_significant`) are read.

  Returns:
    A Series named 'rank' mapping each model to its rank (1 is best, and tied
    models get the average of the ranks they span).
  """
  dataset_df = dataset_series.unstack()
  remaining_models = list(dataset_df.index)
  next_available_rank = 1
  ranks = {}
  # Iteratively pick the best models, then all the ones statistically equivalent
  while remaining_models:
    accuracies = dataset_df.loc[remaining_models, 'mean (%)'].astype('d')
    # `accuracies` is a Series, which only has axis 0: no `axis` argument.
    best_model = accuracies.idxmax()
    best_stats = dataset_df.loc[best_model]
    tied_models = [best_model]
    # NOTE(review): a candidate whose mean is NaN never compares as
    # significantly different, so it would end up tied with the best model.
    for candidate in remaining_models:
      if candidate == best_model:
        continue
      candidate_stats = dataset_df.loc[candidate]
      if not is_difference_significant(best_stats, candidate_stats):
        tied_models.append(candidate)

    n_ties = len(tied_models)
    # All tied models share the same rank, which is the average of the next
    # `n_ties` available ranks (the ranks they would have without the ties), or
    # next_available_rank + (1 + ... + (n_ties - 1)) / n_ties, which gives:
    shared_rank = next_available_rank + (n_ties - 1) / 2
    next_available_rank += n_ties
    for model in tied_models:
      ranks[model] = shared_rank

    # Remove picked models for next iteration
    remaining_models = [model for model in remaining_models
                        if model not in tied_models]
  return pd.Series(ranks, name='rank')

In [None]:
def add_ranks(df):
  """Adds a per-dataset 'rank' column for every model.

  Args:
    df: DataFrame with one row per dataset and (model, statistic) columns.

  Returns:
    A new DataFrame where every model has an extra 'rank' statistic. The first
    row ("ILSVRC (valid)") is not ranked, and the models are ordered as in
    `df.columns.levels[0]`.
  """
  model_names = df.columns.levels[0]
  # One row per dataset and one column per model; the first row is skipped.
  per_dataset_ranks = df.iloc[1:].apply(compute_ranks, axis=1)
  # Turn the model columns into a (model, 'rank') MultiIndex
  rank_columns = pd.concat([per_dataset_ranks], axis=1, keys=['rank'])
  rank_columns = rank_columns.swaplevel(0, 1, axis=1)
  # Append to the original statistics, then group the columns by model
  combined = pd.concat([df, rank_columns], axis=1)
  return combined[model_names]

In [None]:
# Add a 'rank' column for every model (ImageNet-only setting), then display.
imagenet_df = add_ranks(imagenet_df)
imagenet_df

Unnamed: 0_level_0,Finetune,Finetune,Finetune,Finetune,Finetune,MatchingNet,MatchingNet,MatchingNet,MatchingNet,MatchingNet,ProtoNet,ProtoNet,ProtoNet,ProtoNet,ProtoNet,RelationNet,RelationNet,RelationNet,RelationNet,RelationNet,fo-MAML,fo-MAML,fo-MAML,fo-MAML,fo-MAML,fo-Proto-MAML,fo-Proto-MAML,fo-Proto-MAML,fo-Proto-MAML,fo-Proto-MAML,k-NN,k-NN,k-NN,k-NN,k-NN
Unnamed: 0_level_1,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank
ILSVRC (valid),600,,,,,600,,,,,600,,,,,600,,,,,600,,,,,600,,,,,600,,,,
ILSVRC (test),600,1.1,45.78,13.7471,4.0,600,1.1,45.0,13.7471,4.0,600,1.08,50.5,13.4972,1.5,600,1.01,34.69,12.6224,7.0,600,1.11,45.51,13.8721,4.0,600,1.05,49.53,13.1223,1.5,600,1.01,41.03,12.6224,6.0
Omniglot,600,1.58,60.85,19.7459,2.5,600,1.28,52.27,15.9967,5.0,600,1.35,59.98,16.8715,2.5,600,1.36,45.35,16.9965,6.0,600,1.54,55.55,19.246,4.0,600,1.33,63.37,16.6215,1.0,600,1.15,37.07,14.372,7.0
Aircraft,600,1.26,68.69,15.7467,1.0,600,0.93,48.97,11.6226,5.0,600,1.0,53.1,12.4974,4.0,600,0.83,40.73,10.3728,7.0,600,1.11,56.24,13.8721,2.5,600,0.99,55.95,12.3724,2.5,600,0.89,46.81,11.1227,6.0
Birds,600,1.26,57.31,15.7467,5.0,600,0.95,62.21,11.8725,3.5,600,1.01,68.79,12.6224,1.5,600,1.05,49.51,13.1223,6.5,600,1.06,63.61,13.2472,3.5,600,0.96,68.66,11.9975,1.5,600,1.0,50.13,12.4974,6.5
Textures,600,0.9,69.05,11.2477,1.5,600,0.85,64.15,10.6228,6.0,600,0.83,66.56,10.3728,4.0,600,0.69,52.97,8.6232,7.0,600,0.81,68.04,10.1229,1.5,600,0.83,66.49,10.3728,4.0,600,0.75,66.36,9.37305,4.0
QuickDraw,600,1.17,42.6,14.622,4.5,600,1.09,42.87,13.6222,4.5,600,1.08,48.96,13.4972,2.0,600,1.08,43.3,13.4972,4.5,600,1.29,43.96,16.1216,4.5,600,1.0,51.52,12.4974,1.0,600,1.08,32.06,13.4972,7.0
Fungi,600,1.02,38.2,12.7473,3.0,600,1.0,33.97,12.4974,5.0,600,1.11,39.71,13.8721,1.5,600,1.04,30.55,12.9973,7.0,600,1.1,32.1,13.7471,6.0,600,1.14,39.96,14.247,1.5,600,1.02,36.16,12.7473,4.0
VGG Flower,600,0.68,85.51,8.49823,2.5,600,0.71,80.13,8.87315,6.0,600,0.77,85.27,9.623,2.5,600,0.83,68.76,10.3728,7.0,600,0.83,81.74,10.3728,5.0,600,0.69,87.15,8.6232,1.0,600,0.68,83.1,8.49823,4.0
Traffic signs,600,1.31,66.79,16.3716,1.0,600,1.14,47.8,14.247,3.5,600,1.1,47.12,13.7471,5.0,600,1.05,33.67,13.1223,7.0,600,1.51,50.93,18.8711,2.0,600,1.09,48.83,13.6222,3.5,600,1.19,44.59,14.8719,6.0


In [None]:
# Add a 'rank' column for every model (all-datasets setting), then display.
all_df = add_ranks(all_df)
all_df

Unnamed: 0_level_0,Finetune,Finetune,Finetune,Finetune,Finetune,MatchingNet,MatchingNet,MatchingNet,MatchingNet,MatchingNet,ProtoNet,ProtoNet,ProtoNet,ProtoNet,ProtoNet,RelationNet,RelationNet,RelationNet,RelationNet,RelationNet,fo-MAML,fo-MAML,fo-MAML,fo-MAML,fo-MAML,fo-Proto-MAML,fo-Proto-MAML,fo-Proto-MAML,fo-Proto-MAML,fo-Proto-MAML,k-NN,k-NN,k-NN,k-NN,k-NN
Unnamed: 0_level_1,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank
ILSVRC (valid),600,,,,,600,,,,,600,,,,,600,,,,,600,,,,,600,,,,,600,,,,
ILSVRC (test),600,1.08,43.08,13.4972,2.5,600,1.0,36.08,12.4974,6.0,600,1.05,44.5,13.1223,2.5,600,0.93,30.89,11.6226,7.0,600,1.01,37.83,12.6224,4.5,600,1.05,46.52,13.1223,1.0,600,0.94,38.55,11.7476,4.5
Omniglot,600,1.37,71.11,17.1214,7.0,600,1.01,78.25,12.6224,4.5,600,1.12,79.56,13.9971,4.5,600,0.79,86.57,9.87294,1.0,600,0.95,83.92,11.8725,2.5,600,0.97,82.69,12.1225,2.5,600,1.08,74.6,13.4972,6.0
Aircraft,600,1.07,72.03,13.3722,3.5,600,0.96,69.17,11.9975,5.5,600,0.86,71.14,10.7478,3.5,600,0.83,69.71,10.3728,5.5,600,0.69,76.41,8.6232,1.0,600,0.76,75.23,9.49802,2.0,600,0.82,64.98,10.2479,7.0
Birds,600,1.15,59.82,14.372,5.0,600,1.0,56.4,12.4974,6.0,600,1.02,67.01,12.7473,2.5,600,0.99,54.14,12.3724,7.0,600,1.08,62.43,13.4972,4.0,600,1.02,69.88,12.7473,1.0,600,0.92,66.35,11.4976,2.5
Textures,600,0.85,69.14,10.6228,1.5,600,0.74,61.8,9.24807,6.0,600,0.84,65.18,10.4978,3.5,600,0.73,56.56,9.1231,7.0,600,0.83,64.16,10.3728,3.5,600,0.81,68.25,10.1229,1.5,600,0.79,63.58,9.87294,5.0
QuickDraw,600,1.16,47.05,14.497,6.0,600,1.03,60.81,12.8723,3.5,600,0.89,64.88,11.1227,2.0,600,0.97,61.75,12.1225,3.5,600,1.1,59.73,13.7471,5.0,600,0.94,66.84,11.7476,1.0,600,1.05,44.88,13.1223,7.0
Fungi,600,1.04,38.16,12.9973,3.5,600,1.04,33.7,12.9973,6.0,600,1.13,40.26,14.1221,2.0,600,1.08,32.56,13.4972,6.0,600,1.11,33.54,13.8721,6.0,600,1.17,41.99,14.622,1.0,600,1.06,37.12,13.2472,3.5
VGG Flower,600,0.69,85.28,8.6232,3.0,600,0.72,81.9,8.99813,5.0,600,0.71,86.85,8.87315,2.0,600,0.76,76.08,9.49802,7.0,600,0.84,79.94,10.4978,6.0,600,0.67,88.72,8.37326,1.0,600,0.61,83.47,7.62341,4.0
Traffic signs,600,1.23,66.74,15.3718,1.0,600,1.08,55.57,13.4972,2.0,600,1.0,46.48,12.4974,4.0,600,0.93,37.48,11.6226,7.0,600,1.31,42.91,16.3716,5.0,600,1.08,52.42,13.4972,3.0,600,1.1,40.11,13.7471,6.0


In [None]:
# Mean of each model's 'rank' column (ImageNet-only setting).
imagenet_df.xs('rank', axis=1, level=1).mean()

Finetune         2.90
MatchingNet      4.65
ProtoNet         2.65
RelationNet      6.55
fo-MAML          3.70
fo-Proto-MAML    1.85
k-NN             5.70
dtype: float64

In [None]:
# Mean of each model's 'rank' column (all-datasets setting).
all_df.xs('rank', axis=1, level=1).mean()

Finetune         3.60
MatchingNet      4.95
ProtoNet         2.85
RelationNet      5.80
fo-MAML          4.25
fo-Proto-MAML    1.50
k-NN             5.05
dtype: float64

### Display in HTML
This section uses the DataFrame's "styler" object, which renders nicely within the notebook.

Unfortunately, the HTML it outputs is not compatible with GitHub's markdown (as it relies on the `<style>` tag).

In [None]:
def str_summary(series):
  """Formats one (dataset, model) result as 'mean±CI (rank)'.

  Args:
    series: Mapping (e.g. a Series) with 'mean (%)', '95% CI' and 'rank'.

  Returns:
    A single-cell string; a non-breaking space precedes the rank so that the
    whole cell stays on one line.
  """
  accuracy = series['mean (%)']
  interval = series['95% CI']
  rank = series['rank']
  return '%s±%s\u00A0(%g)' % (accuracy, interval, rank)

In [None]:
def display_table(df, models=None):
  """Builds the styled leaderboard table, with one row per model.

  Args:
    df: DataFrame with one row per dataset and (model, statistic) columns,
      the statistics including 'mean (%)', '95% CI' and 'rank'. Its first row
      is left out of the table.
    models: Optional iterable of model names (e.g. `dict.keys()`) giving the
      order of the rows. When missing or empty, the order is left as is.

  Returns:
    A pandas Styler showing the average rank and one summary cell per dataset,
    with the best rank(s) of each column in bold.
  """
  # Materialize `models` once, as `md_table` does, so that any iterable can be
  # used both as a truth value and as a `.loc` indexer.
  models = list(models) if models is not None else []
  eval_datasets = df.index[1:]

  accuracies_df = df.stack(0).apply(str_summary, axis=1).unstack(0)[eval_datasets]
  rank_df = df.xs('rank', axis=1, level=1).loc[eval_datasets]
  avg_rank_df = pd.DataFrame(rank_df.mean(), columns=['Avg rank'])
  display_df = pd.concat([avg_rank_df, accuracies_df], axis=1)
  if models:
    # Try and force a particular order of models
    display_df = display_df.loc[models]

  # Bold cells corresponding to the best rank
  best_acc_mask = rank_df.T == rank_df.min(axis=1)
  best_avg_mask = avg_rank_df == avg_rank_df.min()
  best_mask = pd.concat([best_avg_mask, best_acc_mask], axis=1)
  if models:
    best_mask = best_mask.loc[models]
  # Element-wise mapping to CSS, done with numpy rather than
  # `DataFrame.applymap`, which recent pandas versions deprecate.
  bold_mask = pd.DataFrame(
      np.where(best_mask, 'font-weight: bold', ''),
      index=best_mask.index,
      columns=best_mask.columns)

  display_style = display_df.style.apply(lambda f: bold_mask, axis=None)
  display_style = display_style.format({'Avg rank': '{:g}'})
  return display_style

In [None]:
# Styled table for the ImageNet-only setting, rows ordered as in `imagenet_dfs`.
imagenet_display = display_table(imagenet_df, models=imagenet_dfs.keys())
imagenet_display

Unnamed: 0,Avg rank,ILSVRC (test),Omniglot,Aircraft,Birds,Textures,QuickDraw,Fungi,VGG Flower,Traffic signs,MSCOCO
k-NN,5.7,41.03±1.01 (6),37.07±1.15 (7),46.81±0.89 (6),50.13±1.0 (6.5),66.36±0.75 (4),32.06±1.08 (7),36.16±1.02 (4),83.1±0.68 (4),44.59±1.19 (6),30.38±0.99 (6.5)
Finetune,2.9,45.78±1.1 (4),60.85±1.58 (2.5),68.69±1.26 (1),57.31±1.26 (5),69.05±0.9 (1.5),42.6±1.17 (4.5),38.2±1.02 (3),85.51±0.68 (2.5),66.79±1.31 (1),34.86±0.97 (4)
MatchingNet,4.65,45.0±1.1 (4),52.27±1.28 (5),48.97±0.93 (5),62.21±0.95 (3.5),64.15±0.85 (6),42.87±1.09 (4.5),33.97±1.0 (5),80.13±0.71 (6),47.8±1.14 (3.5),34.99±1.0 (4)
ProtoNet,2.65,50.5±1.08 (1.5),59.98±1.35 (2.5),53.1±1.0 (4),68.79±1.01 (1.5),66.56±0.83 (4),48.96±1.08 (2),39.71±1.11 (1.5),85.27±0.77 (2.5),47.12±1.1 (5),41.0±1.1 (2)
fo-MAML,3.7,45.51±1.11 (4),55.55±1.54 (4),56.24±1.11 (2.5),63.61±1.06 (3.5),68.04±0.81 (1.5),43.96±1.29 (4.5),32.1±1.1 (6),81.74±0.83 (5),50.93±1.51 (2),35.3±1.23 (4)
RelationNet,6.55,34.69±1.01 (7),45.35±1.36 (6),40.73±0.83 (7),49.51±1.05 (6.5),52.97±0.69 (7),43.3±1.08 (4.5),30.55±1.04 (7),68.76±0.83 (7),33.67±1.05 (7),29.15±1.01 (6.5)
fo-Proto-MAML,1.85,49.53±1.05 (1.5),63.37±1.33 (1),55.95±0.99 (2.5),68.66±0.96 (1.5),66.49±0.83 (4),51.52±1.0 (1),39.96±1.14 (1.5),87.15±0.69 (1),48.83±1.09 (3.5),43.74±1.12 (1)


In [None]:
# Print the HTML source of the table. `Styler.render()` was removed in
# pandas 2.0 in favour of `Styler.to_html()`, so use whichever is available.
if hasattr(imagenet_display, 'to_html'):
  print(imagenet_display.to_html())
else:
  print(imagenet_display.render())

In [None]:
# Styled table for the all-datasets setting, rows ordered as in `all_dfs`.
all_display = display_table(all_df, models=all_dfs.keys())
all_display

Unnamed: 0,Avg rank,ILSVRC (test),Omniglot,Aircraft,Birds,Textures,QuickDraw,Fungi,VGG Flower,Traffic signs,MSCOCO
k-NN,5.05,38.55±0.94 (4.5),74.6±1.08 (6),64.98±0.82 (7),66.35±0.92 (2.5),63.58±0.79 (5),44.88±1.05 (7),37.12±1.06 (3.5),83.47±0.61 (4),40.11±1.1 (6),29.55±0.96 (5)
Finetune,3.6,43.08±1.08 (2.5),71.11±1.37 (7),72.03±1.07 (3.5),59.82±1.15 (5),69.14±0.85 (1.5),47.05±1.16 (6),38.16±1.04 (3.5),85.28±0.69 (3),66.74±1.23 (1),35.17±1.08 (3)
MatchingNet,4.95,36.08±1.0 (6),78.25±1.01 (4.5),69.17±0.96 (5.5),56.4±1.0 (6),61.8±0.74 (6),60.81±1.03 (3.5),33.7±1.04 (6),81.9±0.72 (5),55.57±1.08 (2),28.79±0.96 (5)
ProtoNet,2.85,44.5±1.05 (2.5),79.56±1.12 (4.5),71.14±0.86 (3.5),67.01±1.02 (2.5),65.18±0.84 (3.5),64.88±0.89 (2),40.26±1.13 (2),86.85±0.71 (2),46.48±1.0 (4),39.87±1.06 (2)
fo-MAML,4.25,37.83±1.01 (4.5),83.92±0.95 (2.5),76.41±0.69 (1),62.43±1.08 (4),64.16±0.83 (3.5),59.73±1.1 (5),33.54±1.11 (6),79.94±0.84 (6),42.91±1.31 (5),29.37±1.08 (5)
RelationNet,5.8,30.89±0.93 (7),86.57±0.79 (1),69.71±0.83 (5.5),54.14±0.99 (7),56.56±0.73 (7),61.75±0.97 (3.5),32.56±1.08 (6),76.08±0.76 (7),37.48±0.93 (7),27.41±0.89 (7)
fo-Proto-MAML,1.5,46.52±1.05 (1),82.69±0.97 (2.5),75.23±0.76 (2),69.88±1.02 (1),68.25±0.81 (1.5),66.84±0.94 (1),41.99±1.17 (1),88.72±0.67 (1),52.42±1.08 (3),41.74±1.13 (1)


In [None]:
# Print the HTML source of the table. `Styler.render()` was removed in
# pandas 2.0 in favour of `Styler.to_html()`, so use whichever is available.
if hasattr(all_display, 'to_html'):
  print(all_display.to_html())
else:
  print(all_display.render())

### Display in MarkDown
More precisely, in GitHub-flavored Markdown.

In [None]:
def md_render(series):
  """Format one result row as the text of a single Markdown table cell.

  Args:
    series: Mapping-like row (e.g. a pandas Series) with the entries
      'mean (%)', '95% CI', 'rank' and 'best_rank'.

  Returns:
    A string shaped like 'ACC±CI&nbsp;(RANK)'. ACC is printed with two
    decimals and is wrapped in '**' (Markdown bold) when 'best_rank' is truthy.
  """
  accuracy = '%5.2f' % series['mean (%)']
  interval = '%4.2f' % series['95% CI']
  rank = '%g' % series['rank']

  # Only the accuracy is emphasized, and only for best-ranked entries.
  if series['best_rank']:
    accuracy = '**' + accuracy + '**'

  # '&nbsp;' is an HTML non-breaking space placed between the interval and the
  # parenthesized rank.
  return accuracy + '±' + interval + '&nbsp;' + '(' + rank + ')'

In [None]:
def md_table(df, models=None, cell_width=27):
  """Build a pipe-delimited Markdown leaderboard table from a results frame.

  Args:
    df: DataFrame with one row per evaluation dataset and two-level columns.
      Level 0 identifies the model; level 1 must contain 'rank' plus the
      fields read by `md_render` ('mean (%)' and '95% CI'). The first row of
      `df` is skipped: it contributes neither to the ranks nor to the table.
    models: Optional iterable of model names (list, dict keys, pandas Index,
      NumPy array, ...) giving the order of the rows in the output. When it is
      None or empty, the order resulting from the reshaping is kept.
    cell_width: Minimum width, in characters, every cell is left-justified
      and padded to, so that the raw text lines up. Longer cells are not
      truncated. The default of 27 fits the cells `md_render` produces.

  Returns:
    A single string: a header line, a separator line made of dashes, and one
    line per model, joined by newlines. Best ranks are wrapped in '**'.
  """
  # Labels of the rows that are reported, i.e. all but the first one.
  reported = df.index[1:]

  # Whether a model has the best (lowest) rank on a given dataset
  rank_df = df.xs('rank', axis=1, level=1).loc[reported]
  best_rank = pd.concat([rank_df.T == rank_df.min(axis=1)], axis=1,
                        keys=['best_rank']).swaplevel(0, 1, axis=1)
  # Positional row slice, consistent with `df.index[1:]` above.
  accuracies_df = df.iloc[1:].T.unstack(1)
  accuracies_df = pd.concat([accuracies_df, best_rank], axis=1)
  # One Markdown string per (model, dataset), laid out as models x datasets.
  accuracies_md = accuracies_df.stack(0).apply(md_render, axis=1).unstack(1)

  # Average rank (and whether it's the best)
  avg_rank_df = pd.DataFrame(rank_df.mean(), columns=['Avg rank'])
  best_avg_rank = (avg_rank_df == avg_rank_df.min()).rename(
      columns={'Avg rank': 'best_rank'})
  avg_rank_md = pd.concat([avg_rank_df, best_avg_rank], axis=1).apply(
      lambda s: '%(bold)s%(avg_rank)g%(bold)s' % {
          'avg_rank': s['Avg rank'],
          'bold': '**' if s['best_rank'] else ''
      },
      axis=1).rename('Avg rank')

  display_md = pd.concat([avg_rank_md, accuracies_md[reported]], axis=1)

  # Try and force a particular order of models. The explicit None / emptiness
  # checks avoid evaluating the truth value of `models` itself, which raises
  # for a pandas Index or a multi-element NumPy array.
  if models is not None:
    model_order = list(models)
    if model_order:
      display_md = display_md.loc[model_order]

  def pad_cells(cells):
    # Left-justify and pad every cell so the columns align in the raw text.
    return '|'.join(str(cell).ljust(cell_width) for cell in cells)

  columns = list(display_md.columns)
  header_str = pad_cells(['Method'] + columns)
  sep_str = '|'.join(['-' * cell_width] * (len(columns) + 1))
  rows = [
      pad_cells([i] + list(display_md.loc[i]))
      for i in display_md.index
  ]
  return '\n'.join([header_str, sep_str] + rows)

In [None]:
# Print the Markdown table for `imagenet_df`, with rows ordered like the keys
# of `imagenet_dfs`. `print` is used so the newlines in the returned string
# show up as separate lines instead of an escaped repr.
print(md_table(imagenet_df, models=imagenet_dfs.keys()))

Method                     |Avg rank                   |ILSVRC (test)              |Omniglot                   |Aircraft                   |Birds                      |Textures                   |QuickDraw                  |Fungi                      |VGG Flower                 |Traffic signs              |MSCOCO                     
---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------
k-NN                       |5.7                        |41.03±1.01&nbsp;(6)        |37.07±1.15&nbsp;(7)        |46.81±0.89&nbsp;(6)        |50.13±1.00&nbsp;(6.5)      |66.36±0.75&nbsp;(4)        |32.06±1.08&nbsp;(7)        |36.16±1.02&nbsp;(4)        |83.10±0.68&nbsp;(4)        |44.59±1.19&nbsp;(6)        |30.38±0.99&nbsp;(6.5

In [None]:
# Print the Markdown table for `all_df`, with rows ordered like the keys of
# `all_dfs`. `print` is used so the newlines in the returned string show up as
# separate lines instead of an escaped repr.
print(md_table(all_df, models=all_dfs.keys()))

Method                     |Avg rank                   |ILSVRC (test)              |Omniglot                   |Aircraft                   |Birds                      |Textures                   |QuickDraw                  |Fungi                      |VGG Flower                 |Traffic signs              |MSCOCO                     
---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------
k-NN                       |5.05                       |38.55±0.94&nbsp;(4.5)      |74.60±1.08&nbsp;(6)        |64.98±0.82&nbsp;(7)        |66.35±0.92&nbsp;(2.5)      |63.58±0.79&nbsp;(5)        |44.88±1.05&nbsp;(7)        |37.12±1.06&nbsp;(3.5)      |83.47±0.61&nbsp;(4)        |40.11±1.10&nbsp;(6)        |29.55±0.96&nbsp;(5) 