##### Copyright 2020 Google LLC.

Licensed under the Apache License, Version 2.0 (the "License");

In [1]:
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Meta-Dataset leaderboard

This notebook computes leaderboard tables for different models on Meta-Dataset.

Results for each model in each training setting (ImageNet-only or all datasets) are defined in a different DataFrame. This script aggregates the data in one DataFrame, ranks the models in each setting (using a statistical test for equality), and produces the final tables.

In [78]:
import re
import textwrap

import numpy as np
import pandas as pd
from IPython import display

In [79]:
# Evaluation datasets, in the order used as the index of every results table.
# "ILSVRC (valid)" is kept for completeness only; it need not be reported.
datasets = [
    "ILSVRC (valid)", "ILSVRC (test)",
    "Omniglot", "Aircraft", "Birds", "Textures", "QuickDraw",
    "Fungi", "VGG Flower", "Traffic signs", "MSCOCO",
]

# (short citation, full Markdown reference) pairs, appended to as each paper
# is introduced further down in the notebook.
references = list()

## Results from Triantafillou et al. (2020)

In [80]:
# Citation entry: (short key, full reference formatted as Markdown).
ref = (
    "Triantafillou et al. (2020)",
    "Eleni Triantafillou, Tyler Zhu, Vincent Dumoulin, Pascal Lamblin, "
    "Utku Evci, Kelvin Xu, Ross Goroshin, Carles Gelada, Kevin Swersky, "
    "Pierre-Antoine Manzagol, Hugo Larochelle; "
    "[_Meta-Dataset: A Dataset of Datasets for Learning to Learn from Few "
    "Examples_](https://arxiv.org/abs/1903.03096); ICLR 2020.",
)
references.append(ref)
# Uncomment to preview the rendered reference:
# display.display(display.Markdown(ref[1]))

### k-NN (`baseline`)

In [81]:
# k-NN baseline trained on ImageNet only: empty table with one row per
# evaluation dataset, filled in by the next cell.
baseline_imagenet_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
# Every row is recorded with 600 evaluation episodes.
baseline_imagenet_df['# episodes'] = 600

In [82]:
# Fill in mean accuracy and 95% CI for the k-NN baseline (ImageNet only).
# Rows follow datasets[1:], so "ILSVRC (valid)" is left unset.
baseline_imagenet_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [41.03, 1.01],  # ILSVRC (test)
    [37.07, 1.15],  # Omniglot
    [46.81, 0.89],  # Aircraft
    [50.13, 1.00],  # Birds
    [66.36, 0.75],  # Textures
    [32.06, 1.08],  # QuickDraw
    [36.16, 1.02],  # Fungi
    [83.10, 0.68],  # VGG Flower
    [44.59, 1.19],  # Traffic signs
    [30.38, 0.99],  # MSCOCO
]

In [83]:
# Show the k-NN (ImageNet-only) results table as the cell output.
baseline_imagenet_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),41.03,1.01,600
Omniglot,37.07,1.15,600
Aircraft,46.81,0.89,600
Birds,50.13,1.0,600
Textures,66.36,0.75,600
QuickDraw,32.06,1.08,600
Fungi,36.16,1.02,600
VGG Flower,83.1,0.68,600
Traffic signs,44.59,1.19,600


In [84]:
# k-NN baseline trained on all datasets (Triantafillou et al., 2020).
# One row per evaluation dataset; "ILSVRC (valid)" is left unset.
baseline_all_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
baseline_all_df['# episodes'] = 600
baseline_all_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [38.55, 0.94],  # ILSVRC (test)
    [74.60, 1.08],  # Omniglot
    [64.98, 0.82],  # Aircraft
    [66.35, 0.92],  # Birds
    [63.58, 0.79],  # Textures
    [44.88, 1.05],  # QuickDraw
    [37.12, 1.06],  # Fungi
    [83.47, 0.61],  # VGG Flower
    [40.11, 1.10],  # Traffic signs
    [29.55, 0.96],  # MSCOCO
]
baseline_all_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),38.55,0.94,600
Omniglot,74.6,1.08,600
Aircraft,64.98,0.82,600
Birds,66.35,0.92,600
Textures,63.58,0.79,600
QuickDraw,44.88,1.05,600
Fungi,37.12,1.06,600
VGG Flower,83.47,0.61,600
Traffic signs,40.11,1.1,600


### Finetune (`baselinefinetune`)

In [85]:
# Finetune baseline trained on ImageNet only (Triantafillou et al., 2020).
# One row per evaluation dataset; "ILSVRC (valid)" is left unset.
baselineft_imagenet_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
baselineft_imagenet_df['# episodes'] = 600
baselineft_imagenet_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [45.78, 1.10],  # ILSVRC (test)
    [60.85, 1.58],  # Omniglot
    [68.69, 1.26],  # Aircraft
    [57.31, 1.26],  # Birds
    [69.05, 0.90],  # Textures
    [42.60, 1.17],  # QuickDraw
    [38.20, 1.02],  # Fungi
    [85.51, 0.68],  # VGG Flower
    [66.79, 1.31],  # Traffic signs
    [34.86, 0.97],  # MSCOCO
]
baselineft_imagenet_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),45.78,1.1,600
Omniglot,60.85,1.58,600
Aircraft,68.69,1.26,600
Birds,57.31,1.26,600
Textures,69.05,0.9,600
QuickDraw,42.6,1.17,600
Fungi,38.2,1.02,600
VGG Flower,85.51,0.68,600
Traffic signs,66.79,1.31,600


In [86]:
# Finetune baseline trained on all datasets (Triantafillou et al., 2020).
# One row per evaluation dataset; "ILSVRC (valid)" is left unset.
baselineft_all_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
baselineft_all_df['# episodes'] = 600
baselineft_all_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [43.08, 1.08],  # ILSVRC (test)
    [71.11, 1.37],  # Omniglot
    [72.03, 1.07],  # Aircraft
    [59.82, 1.15],  # Birds
    [69.14, 0.85],  # Textures
    [47.05, 1.16],  # QuickDraw
    [38.16, 1.04],  # Fungi
    [85.28, 0.69],  # VGG Flower
    [66.74, 1.23],  # Traffic signs
    [35.17, 1.08],  # MSCOCO
]
baselineft_all_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),43.08,1.08,600
Omniglot,71.11,1.37,600
Aircraft,72.03,1.07,600
Birds,59.82,1.15,600
Textures,69.14,0.85,600
QuickDraw,47.05,1.16,600
Fungi,38.16,1.04,600
VGG Flower,85.28,0.69,600
Traffic signs,66.74,1.23,600


### MatchingNet (`matching`)

In [87]:
# MatchingNet trained on ImageNet only (Triantafillou et al., 2020).
# One row per evaluation dataset; "ILSVRC (valid)" is left unset.
matching_imagenet_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
matching_imagenet_df['# episodes'] = 600
matching_imagenet_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [45.00, 1.10],  # ILSVRC (test)
    [52.27, 1.28],  # Omniglot
    [48.97, 0.93],  # Aircraft
    [62.21, 0.95],  # Birds
    [64.15, 0.85],  # Textures
    [42.87, 1.09],  # QuickDraw
    [33.97, 1.00],  # Fungi
    [80.13, 0.71],  # VGG Flower
    [47.80, 1.14],  # Traffic signs
    [34.99, 1.00],  # MSCOCO
]
matching_imagenet_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),45.0,1.1,600
Omniglot,52.27,1.28,600
Aircraft,48.97,0.93,600
Birds,62.21,0.95,600
Textures,64.15,0.85,600
QuickDraw,42.87,1.09,600
Fungi,33.97,1.0,600
VGG Flower,80.13,0.71,600
Traffic signs,47.8,1.14,600


In [88]:
# MatchingNet trained on all datasets (Triantafillou et al., 2020).
# One row per evaluation dataset; "ILSVRC (valid)" is left unset.
matching_all_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
matching_all_df['# episodes'] = 600
matching_all_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [36.08, 1.00],  # ILSVRC (test)
    [78.25, 1.01],  # Omniglot
    [69.17, 0.96],  # Aircraft
    [56.40, 1.00],  # Birds
    [61.80, 0.74],  # Textures
    [60.81, 1.03],  # QuickDraw
    [33.70, 1.04],  # Fungi
    [81.90, 0.72],  # VGG Flower
    [55.57, 1.08],  # Traffic signs
    [28.79, 0.96],  # MSCOCO
]
matching_all_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),36.08,1.0,600
Omniglot,78.25,1.01,600
Aircraft,69.17,0.96,600
Birds,56.4,1.0,600
Textures,61.8,0.74,600
QuickDraw,60.81,1.03,600
Fungi,33.7,1.04,600
VGG Flower,81.9,0.72,600
Traffic signs,55.57,1.08,600


### ProtoNet (`prototypical`)

In [89]:
# ProtoNet trained on ImageNet only (Triantafillou et al., 2020).
# One row per evaluation dataset; "ILSVRC (valid)" is left unset.
prototypical_imagenet_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
prototypical_imagenet_df['# episodes'] = 600
prototypical_imagenet_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [50.50, 1.08],  # ILSVRC (test)
    [59.98, 1.35],  # Omniglot
    [53.10, 1.00],  # Aircraft
    [68.79, 1.01],  # Birds
    [66.56, 0.83],  # Textures
    [48.96, 1.08],  # QuickDraw
    [39.71, 1.11],  # Fungi
    [85.27, 0.77],  # VGG Flower
    [47.12, 1.10],  # Traffic signs
    [41.00, 1.10],  # MSCOCO
]
prototypical_imagenet_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),50.5,1.08,600
Omniglot,59.98,1.35,600
Aircraft,53.1,1.0,600
Birds,68.79,1.01,600
Textures,66.56,0.83,600
QuickDraw,48.96,1.08,600
Fungi,39.71,1.11,600
VGG Flower,85.27,0.77,600
Traffic signs,47.12,1.1,600


In [90]:
# ProtoNet trained on all datasets (Triantafillou et al., 2020).
# One row per evaluation dataset; "ILSVRC (valid)" is left unset.
prototypical_all_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
prototypical_all_df['# episodes'] = 600
prototypical_all_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [44.50, 1.05],  # ILSVRC (test)
    [79.56, 1.12],  # Omniglot
    [71.14, 0.86],  # Aircraft
    [67.01, 1.02],  # Birds
    [65.18, 0.84],  # Textures
    [64.88, 0.89],  # QuickDraw
    [40.26, 1.13],  # Fungi
    [86.85, 0.71],  # VGG Flower
    [46.48, 1.00],  # Traffic signs
    [39.87, 1.06],  # MSCOCO
]
prototypical_all_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),44.5,1.05,600
Omniglot,79.56,1.12,600
Aircraft,71.14,0.86,600
Birds,67.01,1.02,600
Textures,65.18,0.84,600
QuickDraw,64.88,0.89,600
Fungi,40.26,1.13,600
VGG Flower,86.85,0.71,600
Traffic signs,46.48,1.0,600


### fo-MAML (`maml`)

In [91]:
# fo-MAML trained on ImageNet only (Triantafillou et al., 2020).
# One row per evaluation dataset; "ILSVRC (valid)" is left unset.
maml_imagenet_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
maml_imagenet_df['# episodes'] = 600
maml_imagenet_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [45.51, 1.11],  # ILSVRC (test)
    [55.55, 1.54],  # Omniglot
    [56.24, 1.11],  # Aircraft
    [63.61, 1.06],  # Birds
    [68.04, 0.81],  # Textures
    [43.96, 1.29],  # QuickDraw
    [32.10, 1.10],  # Fungi
    [81.74, 0.83],  # VGG Flower
    [50.93, 1.51],  # Traffic signs
    [35.30, 1.23],  # MSCOCO
]
maml_imagenet_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),45.51,1.11,600
Omniglot,55.55,1.54,600
Aircraft,56.24,1.11,600
Birds,63.61,1.06,600
Textures,68.04,0.81,600
QuickDraw,43.96,1.29,600
Fungi,32.1,1.1,600
VGG Flower,81.74,0.83,600
Traffic signs,50.93,1.51,600


In [92]:
# fo-MAML trained on all datasets (Triantafillou et al., 2020).
# One row per evaluation dataset; "ILSVRC (valid)" is left unset.
maml_all_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
maml_all_df['# episodes'] = 600
maml_all_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [37.83, 1.01],  # ILSVRC (test)
    [83.92, 0.95],  # Omniglot
    [76.41, 0.69],  # Aircraft
    [62.43, 1.08],  # Birds
    [64.16, 0.83],  # Textures
    [59.73, 1.10],  # QuickDraw
    [33.54, 1.11],  # Fungi
    [79.94, 0.84],  # VGG Flower
    [42.91, 1.31],  # Traffic signs
    [29.37, 1.08],  # MSCOCO
]
maml_all_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),37.83,1.01,600
Omniglot,83.92,0.95,600
Aircraft,76.41,0.69,600
Birds,62.43,1.08,600
Textures,64.16,0.83,600
QuickDraw,59.73,1.1,600
Fungi,33.54,1.11,600
VGG Flower,79.94,0.84,600
Traffic signs,42.91,1.31,600


### RelationNet (`relationnet`)

In [93]:
# RelationNet trained on ImageNet only (Triantafillou et al., 2020).
# One row per evaluation dataset; "ILSVRC (valid)" is left unset.
relationnet_imagenet_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
relationnet_imagenet_df['# episodes'] = 600
relationnet_imagenet_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [34.69, 1.01],  # ILSVRC (test)
    [45.35, 1.36],  # Omniglot
    [40.73, 0.83],  # Aircraft
    [49.51, 1.05],  # Birds
    [52.97, 0.69],  # Textures
    [43.30, 1.08],  # QuickDraw
    [30.55, 1.04],  # Fungi
    [68.76, 0.83],  # VGG Flower
    [33.67, 1.05],  # Traffic signs
    [29.15, 1.01],  # MSCOCO
]
relationnet_imagenet_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),34.69,1.01,600
Omniglot,45.35,1.36,600
Aircraft,40.73,0.83,600
Birds,49.51,1.05,600
Textures,52.97,0.69,600
QuickDraw,43.3,1.08,600
Fungi,30.55,1.04,600
VGG Flower,68.76,0.83,600
Traffic signs,33.67,1.05,600


In [94]:
# RelationNet trained on all datasets (Triantafillou et al., 2020).
# One row per evaluation dataset; "ILSVRC (valid)" is left unset.
relationnet_all_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
relationnet_all_df['# episodes'] = 600
relationnet_all_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [30.89, 0.93],  # ILSVRC (test)
    [86.57, 0.79],  # Omniglot
    [69.71, 0.83],  # Aircraft
    [54.14, 0.99],  # Birds
    [56.56, 0.73],  # Textures
    [61.75, 0.97],  # QuickDraw
    [32.56, 1.08],  # Fungi
    [76.08, 0.76],  # VGG Flower
    [37.48, 0.93],  # Traffic signs
    [27.41, 0.89],  # MSCOCO
]
relationnet_all_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),30.89,0.93,600
Omniglot,86.57,0.79,600
Aircraft,69.71,0.83,600
Birds,54.14,0.99,600
Textures,56.56,0.73,600
QuickDraw,61.75,0.97,600
Fungi,32.56,1.08,600
VGG Flower,76.08,0.76,600
Traffic signs,37.48,0.93,600


### fo-Proto-MAML (`maml_init_with_proto`)

In [95]:
# fo-Proto-MAML trained on ImageNet only (Triantafillou et al., 2020).
# One row per evaluation dataset; "ILSVRC (valid)" is left unset.
protomaml_imagenet_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
protomaml_imagenet_df['# episodes'] = 600
protomaml_imagenet_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [49.53, 1.05],  # ILSVRC (test)
    [63.37, 1.33],  # Omniglot
    [55.95, 0.99],  # Aircraft
    [68.66, 0.96],  # Birds
    [66.49, 0.83],  # Textures
    [51.52, 1.00],  # QuickDraw
    [39.96, 1.14],  # Fungi
    [87.15, 0.69],  # VGG Flower
    [48.83, 1.09],  # Traffic signs
    [43.74, 1.12],  # MSCOCO
]
protomaml_imagenet_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),49.53,1.05,600
Omniglot,63.37,1.33,600
Aircraft,55.95,0.99,600
Birds,68.66,0.96,600
Textures,66.49,0.83,600
QuickDraw,51.52,1.0,600
Fungi,39.96,1.14,600
VGG Flower,87.15,0.69,600
Traffic signs,48.83,1.09,600


In [96]:
# fo-Proto-MAML trained on all datasets (Triantafillou et al., 2020).
# One row per evaluation dataset; "ILSVRC (valid)" is left unset.
protomaml_all_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
protomaml_all_df['# episodes'] = 600
protomaml_all_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [46.52, 1.05],  # ILSVRC (test)
    [82.69, 0.97],  # Omniglot
    [75.23, 0.76],  # Aircraft
    [69.88, 1.02],  # Birds
    [68.25, 0.81],  # Textures
    [66.84, 0.94],  # QuickDraw
    [41.99, 1.17],  # Fungi
    [88.72, 0.67],  # VGG Flower
    [52.42, 1.08],  # Traffic signs
    [41.74, 1.13],  # MSCOCO
]
protomaml_all_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),46.52,1.05,600
Omniglot,82.69,0.97,600
Aircraft,75.23,0.76,600
Birds,69.88,1.02,600
Textures,68.25,0.81,600
QuickDraw,66.84,0.94,600
Fungi,41.99,1.17,600
VGG Flower,88.72,0.67,600
Traffic signs,52.42,1.08,600


## Results from Requeima et al. (2019)

In [97]:
# Citation entry: (short key, full reference formatted as Markdown).
ref = (
    "Requeima et al. (2019)",
    "James Requeima, Jonathan Gordon, John Bronskill, Sebastian Nowozin, "
    "Richard E. Turner; "
    "[_Fast and Flexible Multi-Task Classification Using Conditional Neural "
    "Adaptive Processes_](https://arxiv.org/abs/1906.07697); "
    "NeurIPS 2019.",
)
references.append(ref)
# Uncomment to preview the rendered reference:
# display.display(display.Markdown(ref[1]))

### CNAPs (`cnaps`)

In [98]:
# CNAPs trained on all datasets (Requeima et al., 2019).
# One row per evaluation dataset; "ILSVRC (valid)" is left unset.
cnaps_all_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
cnaps_all_df['# episodes'] = 600
cnaps_all_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [50.8, 1.1],  # ILSVRC (test)
    [91.7, 0.5],  # Omniglot
    [83.7, 0.6],  # Aircraft
    [73.6, 0.9],  # Birds
    [59.5, 0.7],  # Textures
    [74.7, 0.8],  # QuickDraw
    [50.2, 1.1],  # Fungi
    [88.9, 0.5],  # VGG Flower
    [56.5, 1.1],  # Traffic signs
    [39.4, 1.0],  # MSCOCO
]
cnaps_all_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),50.8,1.1,600
Omniglot,91.7,0.5,600
Aircraft,83.7,0.6,600
Birds,73.6,0.9,600
Textures,59.5,0.7,600
QuickDraw,74.7,0.8,600
Fungi,50.2,1.1,600
VGG Flower,88.9,0.5,600
Traffic signs,56.5,1.1,600


## Results from Baik et al. (2020)

In [99]:
# Citation entry: (short key, full reference formatted as Markdown).
ref = (
    "Baik et al. (2020)",
    "Sungyong Baik, Myungsub Choi, Janghoon Choi, Heewon Kim, Kyoung Mu Lee; "
    "[_Meta-Learning with Adaptive Hyperparameters_]"
    "(https://papers.nips.cc/paper/2020/hash/ee89223a2b625b5152132ed77abbcc79-Abstract.html); "
    "NeurIPS 2020.",
)
references.append(ref)
# Uncomment to preview the rendered reference:
# display.display(display.Markdown(ref[1]))

### ALFA + fo-Proto-MAML (`alfa_protomaml`)

In [100]:
# ALFA + fo-Proto-MAML trained on ImageNet only (Baik et al., 2020).
# One row per evaluation dataset; "ILSVRC (valid)" is left unset.
alfa_protomaml_imagenet_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
alfa_protomaml_imagenet_df['# episodes'] = 600
alfa_protomaml_imagenet_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [52.80, 1.11],  # ILSVRC (test)
    [61.87, 1.51],  # Omniglot
    [63.43, 1.10],  # Aircraft
    [69.75, 1.05],  # Birds
    [70.78, 0.88],  # Textures
    [59.17, 1.16],  # QuickDraw
    [41.49, 1.17],  # Fungi
    [85.96, 0.77],  # VGG Flower
    [60.78, 1.29],  # Traffic signs
    [48.11, 1.14],  # MSCOCO
]
alfa_protomaml_imagenet_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),52.8,1.11,600
Omniglot,61.87,1.51,600
Aircraft,63.43,1.1,600
Birds,69.75,1.05,600
Textures,70.78,0.88,600
QuickDraw,59.17,1.16,600
Fungi,41.49,1.17,600
VGG Flower,85.96,0.77,600
Traffic signs,60.78,1.29,600


### ALFA + fo-MAML (`alfa_maml`)
Not included in the global table as it performs worse than ALFA + fo-Proto-MAML overall, but provided here for reference.

In [101]:
# ALFA + fo-MAML trained on ImageNet only (Baik et al., 2020).
# One row per evaluation dataset; "ILSVRC (valid)" is left unset.
alfa_maml_imagenet_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
alfa_maml_imagenet_df['# episodes'] = 600
alfa_maml_imagenet_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [51.09, 1.17],  # ILSVRC (test)
    [67.89, 1.43],  # Omniglot
    [66.34, 1.17],  # Aircraft
    [67.67, 1.06],  # Birds
    [65.34, 0.95],  # Textures
    [60.53, 1.13],  # QuickDraw
    [37.41, 1.00],  # Fungi
    [84.28, 0.97],  # VGG Flower
    [60.86, 1.43],  # Traffic signs
    [40.05, 1.14],  # MSCOCO
]
alfa_maml_imagenet_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),51.09,1.17,600
Omniglot,67.89,1.43,600
Aircraft,66.34,1.17,600
Birds,67.67,1.06,600
Textures,65.34,0.95,600
QuickDraw,60.53,1.13,600
Fungi,37.41,1.0,600
VGG Flower,84.28,0.97,600
Traffic signs,60.86,1.43,600


## Results from Doersch et al. (2020)
Carl Doersch, Ankush Gupta, Andrew Zisserman,
_CrossTransformers: spatially-aware few-shot transfer_,
NeurIPS 2020

In [102]:
# Citation entry: (short key, full reference formatted as Markdown).
ref = (
    "Doersch et al. (2020)",
    "Carl Doersch, Ankush Gupta, Andrew Zisserman; "
    "[_CrossTransformers: spatially-aware few-shot transfer_]"
    "(https://arxiv.org/abs/2007.11498); "
    "NeurIPS 2020.",
)
references.append(ref)
# Uncomment to preview the rendered reference:
# display.display(display.Markdown(ref[1]))

### ProtoNet large (`protonet_large`)
Larger-scale prototypical networks, including:
- 224x224 input size
- ResNet-34 backbone
- SimCLR episodes

In [103]:
# ProtoNet (large) trained on ImageNet only (Doersch et al., 2020).
# One row per evaluation dataset; "ILSVRC (valid)" is left unset.
protonet_large_imagenet_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
protonet_large_imagenet_df['# episodes'] = 600
protonet_large_imagenet_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [53.69, 1.07],  # ILSVRC (test)
    [68.50, 1.27],  # Omniglot
    [58.04, 0.96],  # Aircraft
    [74.07, 0.92],  # Birds
    [68.76, 0.77],  # Textures
    [53.30, 1.06],  # QuickDraw
    [40.73, 1.15],  # Fungi
    [86.96, 0.73],  # VGG Flower
    [58.11, 1.05],  # Traffic signs
    [41.70, 1.08],  # MSCOCO
]
protonet_large_imagenet_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),53.69,1.07,600
Omniglot,68.5,1.27,600
Aircraft,58.04,0.96,600
Birds,74.07,0.92,600
Textures,68.76,0.77,600
QuickDraw,53.3,1.06,600
Fungi,40.73,1.15,600
VGG Flower,86.96,0.73,600
Traffic signs,58.11,1.05,600


### CrossTransformers (`ctx`)

CrossTransformers network with:
- 224x224 input size
- ResNet-34 backbone
- SimCLR episodes
- 14x14 feature grid
- BOHB-inspired data augmentation

In [104]:
# CrossTransformers (CTX) trained on ImageNet only (Doersch et al., 2020).
# One row per evaluation dataset; "ILSVRC (valid)" is left unset.
ctx_imagenet_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
ctx_imagenet_df['# episodes'] = 600
ctx_imagenet_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [62.76, 0.99],  # ILSVRC (test)
    [82.21, 1.00],  # Omniglot
    [79.49, 0.89],  # Aircraft
    [80.63, 0.88],  # Birds
    [75.57, 0.64],  # Textures
    [72.68, 0.82],  # QuickDraw
    [51.58, 1.11],  # Fungi
    [95.34, 0.37],  # VGG Flower
    [82.65, 0.76],  # Traffic signs
    [59.90, 1.02],  # MSCOCO
]
ctx_imagenet_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),62.76,0.99,600
Omniglot,82.21,1.0,600
Aircraft,79.49,0.89,600
Birds,80.63,0.88,600
Textures,75.57,0.64,600
QuickDraw,72.68,0.82,600
Fungi,51.58,1.11,600
VGG Flower,95.34,0.37,600
Traffic signs,82.65,0.76,600


## Results from Saikia et al. (2020)
Tonmoy Saikia, Thomas Brox, Cordelia Schmid, _Optimized Generic Feature Learning for Few-shot Classification across Domains_, arXiv 2020

In [105]:
# Citation entry: (short key, full reference formatted as Markdown).
ref = (
    "Saikia et al. (2020)",
    "Tonmoy Saikia, Thomas Brox, Cordelia Schmid; "
    "[_Optimized Generic Feature Learning for Few-shot Classification "
    "across Domains_]"
    "(https://arxiv.org/abs/2001.07926); "
    "arXiv 2020.",
)
references.append(ref)
# Uncomment to preview the rendered reference:
# display.display(display.Markdown(ref[1]))

### BOHB (`bohb`)
Validated on _S1_ (ImageNet) only, nearest-centroid classifier (NC).

In [106]:
# BOHB trained on ImageNet only (Saikia et al., 2020).
# One row per evaluation dataset; "ILSVRC (valid)" is left unset.
bohb_imagenet_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
bohb_imagenet_df['# episodes'] = 600
bohb_imagenet_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [51.92, 1.05],  # ILSVRC (test)
    [67.57, 1.21],  # Omniglot
    [54.12, 0.90],  # Aircraft
    [70.69, 0.90],  # Birds
    [68.34, 0.76],  # Textures
    [50.33, 1.04],  # QuickDraw
    [41.38, 1.12],  # Fungi
    [87.34, 0.59],  # VGG Flower
    [51.80, 1.04],  # Traffic signs
    [48.03, 0.99],  # MSCOCO
]
bohb_imagenet_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),51.92,1.05,600
Omniglot,67.57,1.21,600
Aircraft,54.12,0.9,600
Birds,70.69,0.9,600
Textures,68.34,0.76,600
QuickDraw,50.33,1.04,600
Fungi,41.38,1.12,600
VGG Flower,87.34,0.59,600
Traffic signs,51.8,1.04,600


## Results from Dvornik et al. (2020)


In [107]:
# Citation entry: (short key, full reference formatted as Markdown).
ref = (
    "Dvornik et al. (2020)",
    "Nikita Dvornik, Cordelia Schmid, Julien Mairal; "
    "[_Selecting Relevant Features from a Multi-domain Representation for "
    "Few-shot Classification_](https://arxiv.org/abs/2003.09338); "
    "ECCV 2020.",
)
references.append(ref)
# Uncomment to preview the rendered reference:
# display.display(display.Markdown(ref[1]))

### SUR (`sur`)

In [108]:
# SUR trained on all datasets (Dvornik et al., 2020).
# One row per evaluation dataset; "ILSVRC (valid)" is left unset.
sur_all_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
sur_all_df['# episodes'] = 600
sur_all_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [56.1, 1.1],  # ILSVRC (test)
    [93.1, 0.5],  # Omniglot
    [84.6, 0.7],  # Aircraft
    [70.6, 1.0],  # Birds
    [71.0, 0.8],  # Textures
    [81.3, 0.6],  # QuickDraw
    [64.2, 1.1],  # Fungi
    [82.8, 0.8],  # VGG Flower
    [53.4, 1.0],  # Traffic signs
    [50.1, 1.0],  # MSCOCO
]
sur_all_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),56.1,1.1,600
Omniglot,93.1,0.5,600
Aircraft,84.6,0.7,600
Birds,70.6,1.0,600
Textures,71.0,0.8,600
QuickDraw,81.3,0.6,600
Fungi,64.2,1.1,600
VGG Flower,82.8,0.8,600
Traffic signs,53.4,1.0,600


### SUR-pnf (`sur_pnf`)
SUR with parametric network family, also referred to as "SUR-pf".

In [109]:
# SUR-pnf trained on all datasets (Dvornik et al., 2020).
# One row per evaluation dataset; "ILSVRC (valid)" is left unset.
sur_pnf_all_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
sur_pnf_all_df['# episodes'] = 600
sur_pnf_all_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [56.0, 1.1],  # ILSVRC (test)
    [90.0, 0.6],  # Omniglot
    [79.7, 0.8],  # Aircraft
    [75.9, 0.9],  # Birds
    [72.5, 0.7],  # Textures
    [76.7, 0.7],  # QuickDraw
    [49.8, 1.1],  # Fungi
    [90.0, 0.6],  # VGG Flower
    [52.2, 0.8],  # Traffic signs
    [50.2, 1.1],  # MSCOCO
]
sur_pnf_all_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),56.0,1.1,600
Omniglot,90.0,0.6,600
Aircraft,79.7,0.8,600
Birds,75.9,0.9,600
Textures,72.5,0.7,600
QuickDraw,76.7,0.7,600
Fungi,49.8,1.1,600
VGG Flower,90.0,0.6,600
Traffic signs,52.2,0.8,600


## Results from Bateni et al. (2020a)

In [110]:
# Citation entry: (short key, full reference formatted as Markdown).
ref = (
    "Bateni et al. (2020a)",
    "Peyman Bateni, Raghav Goyal, Vaden Masrani, Frank Wood, Leonid Sigal; "
    "[_Improved Few-Shot Visual Classification_]"
    "(https://openaccess.thecvf.com/content_CVPR_2020/html/Bateni_Improved_Few-Shot_Visual_Classification_CVPR_2020_paper.html); "
    "CVPR 2020.",
)
references.append(ref)
# Uncomment to preview the rendered reference:
# display.display(display.Markdown(ref[1]))

### Simple CNAPS (`simple_cnaps`)

In [111]:
# Simple CNAPS trained on all datasets (Bateni et al., 2020a).
# One row per evaluation dataset; "ILSVRC (valid)" is left unset.
simple_cnaps_all_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
simple_cnaps_all_df['# episodes'] = 600
simple_cnaps_all_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [56.5, 1.1],  # ILSVRC (test)
    [91.9, 0.6],  # Omniglot
    [83.8, 0.6],  # Aircraft
    [76.1, 0.9],  # Birds
    [70.0, 0.8],  # Textures
    [78.3, 0.7],  # QuickDraw
    [49.1, 1.2],  # Fungi
    [91.3, 0.6],  # VGG Flower
    [59.2, 1.0],  # Traffic signs
    [42.4, 1.1],  # MSCOCO
]
simple_cnaps_all_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),56.5,1.1,600
Omniglot,91.9,0.6,600
Aircraft,83.8,0.6,600
Birds,76.1,0.9,600
Textures,70.0,0.8,600
QuickDraw,78.3,0.7,600
Fungi,49.1,1.2,600
VGG Flower,91.3,0.6,600
Traffic signs,59.2,1.0,600


## Results from Bateni et al. (2020b)

In [112]:
# Citation entry: (short key, full reference formatted as Markdown).
ref = (
    "Bateni et al. (2020b)",
    "Peyman Bateni, Jarred Barber, Jan-Willem van de Meent, Frank Wood; "
    "[_Enhancing Few-Shot Image Classification with Unlabelled Examples_]"
    "(https://arxiv.org/abs/2006.12245); "
    "arXiv 2020.",
)
references.append(ref)
# Uncomment to preview the rendered reference:
# display.display(display.Markdown(ref[1]))

### Transductive CNAPS (`transductive_cnaps`)

In [113]:
# Transductive CNAPS trained on all datasets (Bateni et al., 2020b).
# One row per evaluation dataset; "ILSVRC (valid)" is left unset.
transductive_cnaps_all_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
transductive_cnaps_all_df['# episodes'] = 600
transductive_cnaps_all_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [57.9, 1.1],  # ILSVRC (test)
    [94.3, 0.4],  # Omniglot
    [84.7, 0.5],  # Aircraft
    [78.8, 0.7],  # Birds
    [66.2, 0.8],  # Textures
    [77.9, 0.6],  # QuickDraw
    [48.9, 1.2],  # Fungi
    [92.3, 0.4],  # VGG Flower
    [59.7, 1.1],  # Traffic signs
    [42.5, 1.1],  # MSCOCO
]
transductive_cnaps_all_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),57.9,1.1,600
Omniglot,94.3,0.4,600
Aircraft,84.7,0.5,600
Birds,78.8,0.7,600
Textures,66.2,0.8,600
QuickDraw,77.9,0.6,600
Fungi,48.9,1.2,600
VGG Flower,92.3,0.4,600
Traffic signs,59.7,1.1,600


## Results from Liu et al. (2021a)

Lu Liu, William Hamilton, Guodong Long, Jing Jiang, Hugo Larochelle,
_A Universal Representation Transformer Layer for Few-Shot Image Classification_, ICLR 2021



In [114]:
# Citation entry: (short key, full reference formatted as Markdown).
# The title carries its leading "A", matching the paper's title as quoted in
# the Markdown cell introducing this section.
ref = ("Liu et al. (2021a)",
       "Lu Liu, William Hamilton, Guodong Long, Jing Jiang, Hugo Larochelle; "
       "[_A Universal Representation Transformer Layer for Few-Shot Image "
       "Classification_](https://arxiv.org/abs/2006.11702); "
       "ICLR 2021.")
references.append(ref)
# Uncomment to preview the rendered reference:
# display.display(display.Markdown(ref[1]))

### URT (`urt`)

In [115]:
# URT trained on all datasets (Liu et al., 2021a).
# One row per evaluation dataset; "ILSVRC (valid)" is left unset.
urt_all_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
urt_all_df['# episodes'] = 600
urt_all_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [55.7, 1.0],  # ILSVRC (test)
    [94.4, 0.4],  # Omniglot
    [85.8, 0.6],  # Aircraft
    [76.3, 0.8],  # Birds
    [71.8, 0.7],  # Textures
    [82.5, 0.6],  # QuickDraw
    [63.5, 1.0],  # Fungi
    [88.2, 0.6],  # VGG Flower
    [51.1, 1.1],  # Traffic signs
    [52.2, 1.1],  # MSCOCO
]
urt_all_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),55.7,1.0,600
Omniglot,94.4,0.4,600
Aircraft,85.8,0.6,600
Birds,76.3,0.8,600
Textures,71.8,0.7,600
QuickDraw,82.5,0.6,600
Fungi,63.5,1.0,600
VGG Flower,88.2,0.6,600
Traffic signs,51.1,1.1,600


### URT-pf (`urt-pf`)

In [116]:
# URT-pf trained on all datasets (Liu et al., 2021a).
# One row per evaluation dataset; "ILSVRC (valid)" is left unset.
urt_pf_all_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
urt_pf_all_df['# episodes'] = 600
urt_pf_all_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [55.5, 1.1],  # ILSVRC (test)
    [90.2, 0.6],  # Omniglot
    [79.8, 0.7],  # Aircraft
    [77.5, 0.8],  # Birds
    [73.5, 0.7],  # Textures
    [75.8, 0.7],  # QuickDraw
    [48.1, 0.9],  # Fungi
    [91.9, 0.5],  # VGG Flower
    [52.0, 1.4],  # Traffic signs
    [52.1, 1.0],  # MSCOCO
]
urt_pf_all_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),55.5,1.1,600
Omniglot,90.2,0.6,600
Aircraft,79.8,0.7,600
Birds,77.5,0.8,600
Textures,73.5,0.7,600
QuickDraw,75.8,0.7,600
Fungi,48.1,0.9,600
VGG Flower,91.9,0.5,600
Traffic signs,52.0,1.4,600


## Results from Triantafillou et al. (2021)

Eleni Triantafillou, Hugo Larochelle, Richard Zemel, Vincent Dumoulin. Learning a Universal Template for Few-shot Dataset Generalization. ICML 2021.

In [117]:
# Citation entry: (short key, full reference formatted as Markdown).
ref = (
    "Triantafillou et al. (2021)",
    "Eleni Triantafillou, Hugo Larochelle, Richard Zemel, Vincent Dumoulin; "
    "[_Learning a Universal Template for Few-shot Dataset Generalization_]"
    "(https://arxiv.org/abs/2105.07029); "
    "ICML 2021.",
)
references.append(ref)
# Uncomment to preview the rendered reference:
# display.display(display.Markdown(ref[1]))

### FLUTE (`FLUTE`)

In [118]:
# FLUTE trained on all datasets (Triantafillou et al., 2021).
# One row per evaluation dataset; "ILSVRC (valid)" is left unset.
flute_all_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
flute_all_df['# episodes'] = 600
flute_all_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [51.8, 1.1],  # ILSVRC (test)
    [93.2, 0.5],  # Omniglot
    [87.2, 0.5],  # Aircraft
    [79.2, 0.8],  # Birds
    [68.8, 0.8],  # Textures
    [79.5, 0.7],  # QuickDraw
    [58.1, 1.1],  # Fungi
    [91.6, 0.6],  # VGG Flower
    [58.4, 1.1],  # Traffic signs
    [50.0, 1.0],  # MSCOCO
]
flute_all_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),51.8,1.1,600
Omniglot,93.2,0.5,600
Aircraft,87.2,0.5,600
Birds,79.2,0.8,600
Textures,68.8,0.8,600
QuickDraw,79.5,0.7,600
Fungi,58.1,1.1,600
VGG Flower,91.6,0.6,600
Traffic signs,58.4,1.1,600


## Results from Li et al. (2021a)
Wei-Hong Li, Xialei Liu, Hakan Bilen. Universal Representation Learning from Multiple Domains for Few-shot Classification. ICCV 2021.

In [119]:
# Citation entry: (short key, full reference formatted as Markdown).
ref = (
    "Li et al. (2021a)",
    "Wei-Hong Li, Xialei Liu, Hakan Bilen; "
    "[_Universal Representation Learning from Multiple Domains for Few-shot Classification_]"
    "(https://arxiv.org/pdf/2103.13841.pdf); "
    "ICCV 2021.",
)
references.append(ref)
# Uncomment to preview the rendered reference:
# display.display(display.Markdown(ref[1]))

In [120]:
# URL trained on all datasets (Li et al., 2021a).
# One row per evaluation dataset; "ILSVRC (valid)" is left unset.
url_all_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
url_all_df['# episodes'] = 600
url_all_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [57.51, 1.08],  # ILSVRC (test)
    [94.51, 0.41],  # Omniglot
    [88.59, 0.46],  # Aircraft
    [80.54, 0.69],  # Birds
    [76.17, 0.67],  # Textures
    [81.94, 0.56],  # QuickDraw
    [68.75, 0.95],  # Fungi
    [92.11, 0.48],  # VGG Flower
    [63.34, 1.19],  # Traffic signs
    [54.03, 0.96],  # MSCOCO
]
url_all_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),57.51,1.08,600
Omniglot,94.51,0.41,600
Aircraft,88.59,0.46,600
Birds,80.54,0.69,600
Textures,76.17,0.67,600
QuickDraw,81.94,0.56,600
Fungi,68.75,0.95,600
VGG Flower,92.11,0.48,600
Traffic signs,63.34,1.19,600




```
# This is formatted as code
```

## Results from Li et al. (2021b)
Wei-Hong Li, Xialei Liu, Hakan Bilen. Improving Task Adaptation for Cross-domain Few-shot Learning. arXiv 2021.

In [121]:
# Citation entry: (short key, full reference formatted as Markdown).
ref = (
    "Li et al. (2021b)",
    "Wei-Hong Li, Xialei Liu, Hakan Bilen; "
    "[_Improving Task Adaptation for Cross-domain Few-shot Learning_]"
    "(https://arxiv.org/pdf/2107.00358.pdf); "
    "arXiv 2021.",
)
references.append(ref)
# Uncomment to preview the rendered reference:
# display.display(display.Markdown(ref[1]))

In [122]:
# ITA trained on all datasets (Li et al., 2021b).
# One row per evaluation dataset; "ILSVRC (valid)" is left unset.
ita_all_df = pd.DataFrame(
    index=datasets,
    columns=['mean (%)', '95% CI', '# episodes'],
)
ita_all_df['# episodes'] = 600
ita_all_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [57.35, 1.05],  # ILSVRC (test)
    [94.96, 0.38],  # Omniglot
    [89.33, 0.44],  # Aircraft
    [81.42, 0.74],  # Birds
    [76.74, 0.72],  # Textures
    [82.01, 0.57],  # QuickDraw
    [67.40, 0.99],  # Fungi
    [92.18, 0.52],  # VGG Flower
    [83.55, 0.90],  # Traffic signs
    [55.75, 1.06],  # MSCOCO
]
ita_all_df

Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),57.35,1.05,600
Omniglot,94.96,0.38,600
Aircraft,89.33,0.44,600
Birds,81.42,0.74,600
Textures,76.74,0.72,600
QuickDraw,82.01,0.57,600
Fungi,67.4,0.99,600
VGG Flower,92.18,0.52,600
Traffic signs,83.55,0.9,600


## Results from Liu et al. (2021b)

Yanbin Liu, Juho Lee, Linchao Zhu, Ling Chen, Humphrey Shi and Yi Yang. A Multi-Mode Modulator for Multi-Domain Few-Shot Classification. ICCV 2021.

In [123]:
# Citation entry: (short key, full reference formatted as Markdown).
ref = (
    "Liu et al. (2021b)",
    "Yanbin Liu, Juho Lee, Linchao Zhu, Ling Chen, Humphrey Shi, Yi Yang; "
    "[_A Multi-Mode Modulator for Multi-Domain Few-Shot Classification_]"
    "(https://openaccess.thecvf.com/content/ICCV2021/papers/Liu_A_Multi-Mode_Modulator_for_Multi-Domain_Few-Shot_Classification_ICCV_2021_paper.pdf); "
    "ICCV 2021.",
)
references.append(ref)

In [124]:
# TriM trained on all datasets (Liu et al., 2021b).
# One row per evaluation dataset; "ILSVRC (valid)" is left unset.
triM_all_df = pd.DataFrame(
    columns=['mean (%)', '95% CI', '# episodes'],
    index=datasets
)
triM_all_df['# episodes'] = 600
# Select rows by label (datasets[1:]), as every other table does. A positional
# slice such as `.loc[1:]` is not valid on this string-labelled index: older
# pandas only warns, newer pandas raises TypeError.
triM_all_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [
    [58.6, 1.0],  # ILSVRC (test)
    [92.0, 0.6],  # Omniglot
    [82.8, 0.7],  # Aircraft
    [75.3, 0.8],  # Birds
    [71.2, 0.8],  # Textures
    [77.3, 0.7],  # QuickDraw
    [48.5, 1.0],  # Fungi
    [90.5, 0.5],  # VGG Flower
    [63.0, 1.0],  # Traffic signs
    [52.8, 1.1],  # MSCOCO
]
triM_all_df

  app.launch_new_instance()


Unnamed: 0,mean (%),95% CI,# episodes
ILSVRC (valid),,,600
ILSVRC (test),58.6,1.0,600
Omniglot,92.0,0.6,600
Aircraft,82.8,0.7,600
Birds,75.3,0.8,600
Textures,71.2,0.8,600
QuickDraw,77.3,0.7,600
Fungi,48.5,1.0,600
VGG Flower,90.5,0.5,600
Traffic signs,63.0,1.0,600


## Template to add a new paper

```
ref = ("Author et al. (year)",
       "First Author, Second Author, Last Author; "
       "[_Title of Paper_](https://paper.url/); "
       "Venue year.")
references.append(ref)
# display.display(display.Markdown(ref[1]))
```

### Template to add a new model

```
<model_name>_<train_source>_df = pd.DataFrame(
    columns=['mean (%)', '95% CI', '# episodes'],
    index=datasets
)
<model_name>_<train_source>_df['# episodes'] = ...
<model_name>_<train_source>_df.loc[datasets[1:], ['mean (%)', '95% CI']] = [...]
```

## Aggregate in table

In [125]:
# Models trained on ImageNet only, keyed by the name shown in the leaderboard.
# Insertion order determines the column order of the aggregated table.
imagenet_dfs = {
    # Triantafillou et al. (2020)
    "k-NN": baseline_imagenet_df,
    "Finetune": baselineft_imagenet_df,
    "MatchingNet": matching_imagenet_df,
    "ProtoNet": prototypical_imagenet_df,
    "fo-MAML": maml_imagenet_df,
    "RelationNet": relationnet_imagenet_df,
    "fo-Proto-MAML": protomaml_imagenet_df,
    # Baik et al. (2020)
    "ALFA+fo-Proto-MAML": alfa_protomaml_imagenet_df,
    # Doersch et al. (2020)
    "ProtoNet (large)": protonet_large_imagenet_df,
    "CTX": ctx_imagenet_df,
    # Saikia et al. (2020)
    "BOHB": bohb_imagenet_df,
}

In [126]:
# Place the per-model tables side by side: the top column level is the model
# name, the second level holds that model's 'mean (%)', '95% CI' and
# '# episodes' columns.
imagenet_df = pd.concat(
    list(imagenet_dfs.values()),
    keys=list(imagenet_dfs),
    axis=1,
)
imagenet_df

Unnamed: 0_level_0,k-NN,k-NN,k-NN,Finetune,Finetune,Finetune,MatchingNet,MatchingNet,MatchingNet,ProtoNet,ProtoNet,ProtoNet,fo-MAML,fo-MAML,fo-MAML,RelationNet,RelationNet,RelationNet,fo-Proto-MAML,fo-Proto-MAML,fo-Proto-MAML,ALFA+fo-Proto-MAML,ALFA+fo-Proto-MAML,ALFA+fo-Proto-MAML,ProtoNet (large),ProtoNet (large),ProtoNet (large),CTX,CTX,CTX,BOHB,BOHB,BOHB
Unnamed: 0_level_1,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes
ILSVRC (valid),,,600,,,600,,,600,,,600,,,600,,,600,,,600,,,600,,,600,,,600,,,600
ILSVRC (test),41.03,1.01,600,45.78,1.1,600,45.0,1.1,600,50.5,1.08,600,45.51,1.11,600,34.69,1.01,600,49.53,1.05,600,52.8,1.11,600,53.69,1.07,600,62.76,0.99,600,51.92,1.05,600
Omniglot,37.07,1.15,600,60.85,1.58,600,52.27,1.28,600,59.98,1.35,600,55.55,1.54,600,45.35,1.36,600,63.37,1.33,600,61.87,1.51,600,68.5,1.27,600,82.21,1.0,600,67.57,1.21,600
Aircraft,46.81,0.89,600,68.69,1.26,600,48.97,0.93,600,53.1,1.0,600,56.24,1.11,600,40.73,0.83,600,55.95,0.99,600,63.43,1.1,600,58.04,0.96,600,79.49,0.89,600,54.12,0.9,600
Birds,50.13,1.0,600,57.31,1.26,600,62.21,0.95,600,68.79,1.01,600,63.61,1.06,600,49.51,1.05,600,68.66,0.96,600,69.75,1.05,600,74.07,0.92,600,80.63,0.88,600,70.69,0.9,600
Textures,66.36,0.75,600,69.05,0.9,600,64.15,0.85,600,66.56,0.83,600,68.04,0.81,600,52.97,0.69,600,66.49,0.83,600,70.78,0.88,600,68.76,0.77,600,75.57,0.64,600,68.34,0.76,600
QuickDraw,32.06,1.08,600,42.6,1.17,600,42.87,1.09,600,48.96,1.08,600,43.96,1.29,600,43.3,1.08,600,51.52,1.0,600,59.17,1.16,600,53.3,1.06,600,72.68,0.82,600,50.33,1.04,600
Fungi,36.16,1.02,600,38.2,1.02,600,33.97,1.0,600,39.71,1.11,600,32.1,1.1,600,30.55,1.04,600,39.96,1.14,600,41.49,1.17,600,40.73,1.15,600,51.58,1.11,600,41.38,1.12,600
VGG Flower,83.1,0.68,600,85.51,0.68,600,80.13,0.71,600,85.27,0.77,600,81.74,0.83,600,68.76,0.83,600,87.15,0.69,600,85.96,0.77,600,86.96,0.73,600,95.34,0.37,600,87.34,0.59,600
Traffic signs,44.59,1.19,600,66.79,1.31,600,47.8,1.14,600,47.12,1.1,600,50.93,1.51,600,33.67,1.05,600,48.83,1.09,600,60.78,1.29,600,58.11,1.05,600,82.65,0.76,600,51.8,1.04,600


In [127]:
# Results of the models trained on all datasets, keyed by the name displayed
# in the leaderboard. The insertion order is the row order later requested
# from the display functions.
all_dfs = {
    'k-NN': baseline_all_df,
    'Finetune': baselineft_all_df,
    'MatchingNet': matching_all_df,
    'ProtoNet': prototypical_all_df,
    'fo-MAML': maml_all_df,
    'RelationNet': relationnet_all_df,
    'fo-Proto-MAML': protomaml_all_df,
    'CNAPs': cnaps_all_df,
    'SUR': sur_all_df,
    'SUR-pnf': sur_pnf_all_df,
    'SimpleCNAPS': simple_cnaps_all_df,
    'TransductiveCNAPS': transductive_cnaps_all_df,
    'URT': urt_all_df,
    'URT-pf': urt_pf_all_df,
    'FLUTE': flute_all_df,
    'URL': url_all_df,
    'ITA': ita_all_df,
    'TriM': triM_all_df,
}
# Same layout as imagenet_df: datasets as rows, (model, statistic) as columns.
all_df = pd.concat(
    all_dfs.values(),
    axis=1,
    keys=all_dfs.keys())
all_df

Unnamed: 0_level_0,k-NN,k-NN,k-NN,Finetune,Finetune,Finetune,MatchingNet,MatchingNet,MatchingNet,ProtoNet,ProtoNet,ProtoNet,fo-MAML,fo-MAML,fo-MAML,RelationNet,RelationNet,RelationNet,fo-Proto-MAML,fo-Proto-MAML,fo-Proto-MAML,CNAPs,CNAPs,CNAPs,SUR,SUR,SUR,SUR-pnf,SUR-pnf,SUR-pnf,SimpleCNAPS,SimpleCNAPS,SimpleCNAPS,TransductiveCNAPS,TransductiveCNAPS,TransductiveCNAPS,URT,URT,URT,URT-pf,URT-pf,URT-pf,FLUTE,FLUTE,FLUTE,URL,URL,URL,ITA,ITA,ITA,TriM,TriM,TriM
Unnamed: 0_level_1,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes,mean (%),95% CI,# episodes
ILSVRC (valid),,,600,,,600,,,600,,,600,,,600,,,600,,,600,,,600,,,600,,,600,,,600,,,600,,,600,,,600,,,600,,,600,,,600,,,600
ILSVRC (test),38.55,0.94,600,43.08,1.08,600,36.08,1.0,600,44.5,1.05,600,37.83,1.01,600,30.89,0.93,600,46.52,1.05,600,50.8,1.1,600,56.1,1.1,600,56.0,1.1,600,56.5,1.1,600,57.9,1.1,600,55.7,1.0,600,55.5,1.1,600,51.8,1.1,600,57.51,1.08,600,57.35,1.05,600,58.6,1.0,600
Omniglot,74.6,1.08,600,71.11,1.37,600,78.25,1.01,600,79.56,1.12,600,83.92,0.95,600,86.57,0.79,600,82.69,0.97,600,91.7,0.5,600,93.1,0.5,600,90.0,0.6,600,91.9,0.6,600,94.3,0.4,600,94.4,0.4,600,90.2,0.6,600,93.2,0.5,600,94.51,0.41,600,94.96,0.38,600,92.0,0.6,600
Aircraft,64.98,0.82,600,72.03,1.07,600,69.17,0.96,600,71.14,0.86,600,76.41,0.69,600,69.71,0.83,600,75.23,0.76,600,83.7,0.6,600,84.6,0.7,600,79.7,0.8,600,83.8,0.6,600,84.7,0.5,600,85.8,0.6,600,79.8,0.7,600,87.2,0.5,600,88.59,0.46,600,89.33,0.44,600,82.8,0.7,600
Birds,66.35,0.92,600,59.82,1.15,600,56.4,1.0,600,67.01,1.02,600,62.43,1.08,600,54.14,0.99,600,69.88,1.02,600,73.6,0.9,600,70.6,1.0,600,75.9,0.9,600,76.1,0.9,600,78.8,0.7,600,76.3,0.8,600,77.5,0.8,600,79.2,0.8,600,80.54,0.69,600,81.42,0.74,600,75.3,0.8,600
Textures,63.58,0.79,600,69.14,0.85,600,61.8,0.74,600,65.18,0.84,600,64.16,0.83,600,56.56,0.73,600,68.25,0.81,600,59.5,0.7,600,71.0,0.8,600,72.5,0.7,600,70.0,0.8,600,66.2,0.8,600,71.8,0.7,600,73.5,0.7,600,68.8,0.8,600,76.17,0.67,600,76.74,0.72,600,71.2,0.8,600
QuickDraw,44.88,1.05,600,47.05,1.16,600,60.81,1.03,600,64.88,0.89,600,59.73,1.1,600,61.75,0.97,600,66.84,0.94,600,74.7,0.8,600,81.3,0.6,600,76.7,0.7,600,78.3,0.7,600,77.9,0.6,600,82.5,0.6,600,75.8,0.7,600,79.5,0.7,600,81.94,0.56,600,82.01,0.57,600,77.3,0.7,600
Fungi,37.12,1.06,600,38.16,1.04,600,33.7,1.04,600,40.26,1.13,600,33.54,1.11,600,32.56,1.08,600,41.99,1.17,600,50.2,1.1,600,64.2,1.1,600,49.8,1.1,600,49.1,1.2,600,48.9,1.2,600,63.5,1.0,600,48.1,0.9,600,58.1,1.1,600,68.75,0.95,600,67.4,0.99,600,48.5,1.0,600
VGG Flower,83.47,0.61,600,85.28,0.69,600,81.9,0.72,600,86.85,0.71,600,79.94,0.84,600,76.08,0.76,600,88.72,0.67,600,88.9,0.5,600,82.8,0.8,600,90.0,0.6,600,91.3,0.6,600,92.3,0.4,600,88.2,0.6,600,91.9,0.5,600,91.6,0.6,600,92.11,0.48,600,92.18,0.52,600,90.5,0.5,600
Traffic signs,40.11,1.1,600,66.74,1.23,600,55.57,1.08,600,46.48,1.0,600,42.91,1.31,600,37.48,0.93,600,52.42,1.08,600,56.5,1.1,600,53.4,1.0,600,52.2,0.8,600,59.2,1.0,600,59.7,1.1,600,51.1,1.1,600,52.0,1.4,600,58.4,1.1,600,63.34,1.19,600,83.55,0.9,600,63.0,1.0,600


In [128]:
# One row per model name, with a single 'ref' column holding the short citation
# key of the article the results come from. md_table looks each key up among
# the first elements of `references`, so a key that has no matching entry
# there raises a KeyError when the MarkDown table is built.
models_df = pd.DataFrame.from_dict(
    orient='index',
    columns=["ref"],
    data={
        'k-NN': 'Triantafillou et al. (2020)',
        'Finetune': 'Triantafillou et al. (2020)',
        'MatchingNet': 'Triantafillou et al. (2020)',
        'ProtoNet': 'Triantafillou et al. (2020)',
        'fo-MAML': 'Triantafillou et al. (2020)',
        'RelationNet': 'Triantafillou et al. (2020)',
        'fo-Proto-MAML': 'Triantafillou et al. (2020)',
        'CNAPs': 'Requeima et al. (2019)',
        'ALFA+fo-Proto-MAML': 'Baik et al. (2020)',
        'ProtoNet (large)': 'Doersch et al. (2020)',
        'CTX': 'Doersch et al. (2020)',
        'BOHB': 'Saikia et al. (2020)',
        'SUR': 'Dvornik et al. (2020)',
        'SUR-pnf': 'Dvornik et al. (2020)',
        'SimpleCNAPS': 'Bateni et al. (2020a)',
        'TransductiveCNAPS': 'Bateni et al. (2020b)',
        'URT': 'Liu et al. (2021a)',
        'URT-pf': 'Liu et al. (2021a)',
        'FLUTE': 'Triantafillou et al. (2021)',
        'URL': 'Li et al. (2021a)',
        'ITA': 'Li et al. (2021b)',
        'TriM': 'Liu et al. (2021b)',
        })
models_df

Unnamed: 0,ref
k-NN,Triantafillou et al. (2020)
Finetune,Triantafillou et al. (2020)
MatchingNet,Triantafillou et al. (2020)
ProtoNet,Triantafillou et al. (2020)
fo-MAML,Triantafillou et al. (2020)
RelationNet,Triantafillou et al. (2020)
fo-Proto-MAML,Triantafillou et al. (2020)
CNAPs,Requeima et al. (2019)
ALFA+fo-Proto-MAML,Baik et al. (2020)
ProtoNet (large),Doersch et al. (2020)


### Add stddev

In [129]:
def add_stddev(df):
  """Returns a new table with a 'stddev' statistic added for every model.

  The standard deviation of the per-episode accuracies is recovered from the
  reported 95% confidence interval, assuming the interval was computed as
  `1.96 * stddev / sqrt(# episodes)`.

  Args:
    df: DataFrame indexed by dataset, whose columns are a two-level MultiIndex
      (model, statistic) containing at least '95% CI' and '# episodes'.

  Returns:
    A new DataFrame with the rows of `df` in their original order and the
    models in their order of first appearance in `df.columns`, each model
    having an additional 'stddev' column. `df` itself is not modified.
  """
  # Extract original order of labels. The order of appearance is used for the
  # models rather than `df.columns.levels[0]`: levels are an internal
  # vocabulary that can be sorted, or contain entries that are not used.
  dataset_order = df.index
  model_order = df.columns.get_level_values(0).unique()
  # Have only one result (mean, CI, ...) per row
  stacked_df = df.stack(0)
  # Add 'stddev' as column. The statistics may be stored with object dtype,
  # which NumPy ufuncs such as sqrt do not handle, so cast explicitly.
  ci = stacked_df['95% CI'].astype(float)
  n_episodes = stacked_df['# episodes'].astype(float)
  stacked_df['stddev'] = ci * np.sqrt(n_episodes) / 1.96
  # Reshape and put back in original order
  new_df = stacked_df.unstack().swaplevel(0, 1, axis=1)
  return new_df.loc[dataset_order][model_order]

In [130]:
# Rebind imagenet_df to a new table that also carries a 'stddev' statistic for
# each model (needed by the significance test used for ranking).
imagenet_df = add_stddev(imagenet_df)
imagenet_df

Unnamed: 0_level_0,k-NN,k-NN,k-NN,k-NN,Finetune,Finetune,Finetune,Finetune,MatchingNet,MatchingNet,MatchingNet,MatchingNet,ProtoNet,ProtoNet,ProtoNet,ProtoNet,fo-MAML,fo-MAML,fo-MAML,fo-MAML,RelationNet,RelationNet,RelationNet,RelationNet,fo-Proto-MAML,fo-Proto-MAML,fo-Proto-MAML,fo-Proto-MAML,ALFA+fo-Proto-MAML,ALFA+fo-Proto-MAML,ALFA+fo-Proto-MAML,ALFA+fo-Proto-MAML,ProtoNet (large),ProtoNet (large),ProtoNet (large),ProtoNet (large),CTX,CTX,CTX,CTX,BOHB,BOHB,BOHB,BOHB
Unnamed: 0_level_1,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev
ILSVRC (valid),600,,,,600,,,,600,,,,600,,,,600,,,,600,,,,600,,,,600,,,,600,,,,600,,,,600,,,
ILSVRC (test),600,1.01,41.03,12.6224,600,1.1,45.78,13.7471,600,1.1,45.0,13.7471,600,1.08,50.5,13.4972,600,1.11,45.51,13.8721,600,1.01,34.69,12.6224,600,1.05,49.53,13.1223,600,1.11,52.8,13.8721,600,1.07,53.69,13.3722,600,0.99,62.76,12.3724,600,1.05,51.92,13.1223
Omniglot,600,1.15,37.07,14.372,600,1.58,60.85,19.7459,600,1.28,52.27,15.9967,600,1.35,59.98,16.8715,600,1.54,55.55,19.246,600,1.36,45.35,16.9965,600,1.33,63.37,16.6215,600,1.51,61.87,18.8711,600,1.27,68.5,15.8717,600,1.0,82.21,12.4974,600,1.21,67.57,15.1218
Aircraft,600,0.89,46.81,11.1227,600,1.26,68.69,15.7467,600,0.93,48.97,11.6226,600,1.0,53.1,12.4974,600,1.11,56.24,13.8721,600,0.83,40.73,10.3728,600,0.99,55.95,12.3724,600,1.1,63.43,13.7471,600,0.96,58.04,11.9975,600,0.89,79.49,11.1227,600,0.9,54.12,11.2477
Birds,600,1.0,50.13,12.4974,600,1.26,57.31,15.7467,600,0.95,62.21,11.8725,600,1.01,68.79,12.6224,600,1.06,63.61,13.2472,600,1.05,49.51,13.1223,600,0.96,68.66,11.9975,600,1.05,69.75,13.1223,600,0.92,74.07,11.4976,600,0.88,80.63,10.9977,600,0.9,70.69,11.2477
Textures,600,0.75,66.36,9.37305,600,0.9,69.05,11.2477,600,0.85,64.15,10.6228,600,0.83,66.56,10.3728,600,0.81,68.04,10.1229,600,0.69,52.97,8.6232,600,0.83,66.49,10.3728,600,0.88,70.78,10.9977,600,0.77,68.76,9.623,600,0.64,75.57,7.99833,600,0.76,68.34,9.49802
QuickDraw,600,1.08,32.06,13.4972,600,1.17,42.6,14.622,600,1.09,42.87,13.6222,600,1.08,48.96,13.4972,600,1.29,43.96,16.1216,600,1.08,43.3,13.4972,600,1.0,51.52,12.4974,600,1.16,59.17,14.497,600,1.06,53.3,13.2472,600,0.82,72.68,10.2479,600,1.04,50.33,12.9973
Fungi,600,1.02,36.16,12.7473,600,1.02,38.2,12.7473,600,1.0,33.97,12.4974,600,1.11,39.71,13.8721,600,1.1,32.1,13.7471,600,1.04,30.55,12.9973,600,1.14,39.96,14.247,600,1.17,41.49,14.622,600,1.15,40.73,14.372,600,1.11,51.58,13.8721,600,1.12,41.38,13.9971
VGG Flower,600,0.68,83.1,8.49823,600,0.68,85.51,8.49823,600,0.71,80.13,8.87315,600,0.77,85.27,9.623,600,0.83,81.74,10.3728,600,0.83,68.76,10.3728,600,0.69,87.15,8.6232,600,0.77,85.96,9.623,600,0.73,86.96,9.1231,600,0.37,95.34,4.62404,600,0.59,87.34,7.37346
Traffic signs,600,1.19,44.59,14.8719,600,1.31,66.79,16.3716,600,1.14,47.8,14.247,600,1.1,47.12,13.7471,600,1.51,50.93,18.8711,600,1.05,33.67,13.1223,600,1.09,48.83,13.6222,600,1.29,60.78,16.1216,600,1.05,58.11,13.1223,600,0.76,82.65,9.49802,600,1.04,51.8,12.9973


In [131]:
# Same conversion for the models trained on all datasets.
all_df = add_stddev(all_df)
all_df

Unnamed: 0_level_0,k-NN,k-NN,k-NN,k-NN,Finetune,Finetune,Finetune,Finetune,MatchingNet,MatchingNet,MatchingNet,MatchingNet,ProtoNet,ProtoNet,ProtoNet,ProtoNet,fo-MAML,fo-MAML,fo-MAML,fo-MAML,RelationNet,RelationNet,RelationNet,RelationNet,fo-Proto-MAML,fo-Proto-MAML,fo-Proto-MAML,fo-Proto-MAML,CNAPs,CNAPs,CNAPs,CNAPs,SUR,SUR,SUR,SUR,SUR-pnf,SUR-pnf,SUR-pnf,SUR-pnf,SimpleCNAPS,SimpleCNAPS,SimpleCNAPS,SimpleCNAPS,TransductiveCNAPS,TransductiveCNAPS,TransductiveCNAPS,TransductiveCNAPS,URT,URT,URT,URT,URT-pf,URT-pf,URT-pf,URT-pf,FLUTE,FLUTE,FLUTE,FLUTE,URL,URL,URL,URL,ITA,ITA,ITA,ITA,TriM,TriM,TriM,TriM
Unnamed: 0_level_1,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev,# episodes,95% CI,mean (%),stddev
ILSVRC (valid),600,,,,600,,,,600,,,,600,,,,600,,,,600,,,,600,,,,600,,,,600,,,,600,,,,600,,,,600,,,,600,,,,600,,,,600,,,,600,,,,600,,,,600,,,
ILSVRC (test),600,0.94,38.55,11.7476,600,1.08,43.08,13.4972,600,1.0,36.08,12.4974,600,1.05,44.5,13.1223,600,1.01,37.83,12.6224,600,0.93,30.89,11.6226,600,1.05,46.52,13.1223,600,1.1,50.8,13.7471,600,1.1,56.1,13.7471,600,1.1,56.0,13.7471,600,1.1,56.5,13.7471,600,1.1,57.9,13.7471,600,1.0,55.7,12.4974,600,1.1,55.5,13.7471,600,1.1,51.8,13.7471,600,1.08,57.51,13.4972,600,1.05,57.35,13.1223,600,1.0,58.6,12.4974
Omniglot,600,1.08,74.6,13.4972,600,1.37,71.11,17.1214,600,1.01,78.25,12.6224,600,1.12,79.56,13.9971,600,0.95,83.92,11.8725,600,0.79,86.57,9.87294,600,0.97,82.69,12.1225,600,0.5,91.7,6.2487,600,0.5,93.1,6.2487,600,0.6,90.0,7.49844,600,0.6,91.9,7.49844,600,0.4,94.3,4.99896,600,0.4,94.4,4.99896,600,0.6,90.2,7.49844,600,0.5,93.2,6.2487,600,0.41,94.51,5.12393,600,0.38,94.96,4.74901,600,0.6,92.0,7.49844
Aircraft,600,0.82,64.98,10.2479,600,1.07,72.03,13.3722,600,0.96,69.17,11.9975,600,0.86,71.14,10.7478,600,0.69,76.41,8.6232,600,0.83,69.71,10.3728,600,0.76,75.23,9.49802,600,0.6,83.7,7.49844,600,0.7,84.6,8.74818,600,0.8,79.7,9.99792,600,0.6,83.8,7.49844,600,0.5,84.7,6.2487,600,0.6,85.8,7.49844,600,0.7,79.8,8.74818,600,0.5,87.2,6.2487,600,0.46,88.59,5.7488,600,0.44,89.33,5.49885,600,0.7,82.8,8.74818
Birds,600,0.92,66.35,11.4976,600,1.15,59.82,14.372,600,1.0,56.4,12.4974,600,1.02,67.01,12.7473,600,1.08,62.43,13.4972,600,0.99,54.14,12.3724,600,1.02,69.88,12.7473,600,0.9,73.6,11.2477,600,1.0,70.6,12.4974,600,0.9,75.9,11.2477,600,0.9,76.1,11.2477,600,0.7,78.8,8.74818,600,0.8,76.3,9.99792,600,0.8,77.5,9.99792,600,0.8,79.2,9.99792,600,0.69,80.54,8.6232,600,0.74,81.42,9.24807,600,0.8,75.3,9.99792
Textures,600,0.79,63.58,9.87294,600,0.85,69.14,10.6228,600,0.74,61.8,9.24807,600,0.84,65.18,10.4978,600,0.83,64.16,10.3728,600,0.73,56.56,9.1231,600,0.81,68.25,10.1229,600,0.7,59.5,8.74818,600,0.8,71.0,9.99792,600,0.7,72.5,8.74818,600,0.8,70.0,9.99792,600,0.8,66.2,9.99792,600,0.7,71.8,8.74818,600,0.7,73.5,8.74818,600,0.8,68.8,9.99792,600,0.67,76.17,8.37326,600,0.72,76.74,8.99813,600,0.8,71.2,9.99792
QuickDraw,600,1.05,44.88,13.1223,600,1.16,47.05,14.497,600,1.03,60.81,12.8723,600,0.89,64.88,11.1227,600,1.1,59.73,13.7471,600,0.97,61.75,12.1225,600,0.94,66.84,11.7476,600,0.8,74.7,9.99792,600,0.6,81.3,7.49844,600,0.7,76.7,8.74818,600,0.7,78.3,8.74818,600,0.6,77.9,7.49844,600,0.6,82.5,7.49844,600,0.7,75.8,8.74818,600,0.7,79.5,8.74818,600,0.56,81.94,6.99854,600,0.57,82.01,7.12352,600,0.7,77.3,8.74818
Fungi,600,1.06,37.12,13.2472,600,1.04,38.16,12.9973,600,1.04,33.7,12.9973,600,1.13,40.26,14.1221,600,1.11,33.54,13.8721,600,1.08,32.56,13.4972,600,1.17,41.99,14.622,600,1.1,50.2,13.7471,600,1.1,64.2,13.7471,600,1.1,49.8,13.7471,600,1.2,49.1,14.9969,600,1.2,48.9,14.9969,600,1.0,63.5,12.4974,600,0.9,48.1,11.2477,600,1.1,58.1,13.7471,600,0.95,68.75,11.8725,600,0.99,67.4,12.3724,600,1.0,48.5,12.4974
VGG Flower,600,0.61,83.47,7.62341,600,0.69,85.28,8.6232,600,0.72,81.9,8.99813,600,0.71,86.85,8.87315,600,0.84,79.94,10.4978,600,0.76,76.08,9.49802,600,0.67,88.72,8.37326,600,0.5,88.9,6.2487,600,0.8,82.8,9.99792,600,0.6,90.0,7.49844,600,0.6,91.3,7.49844,600,0.4,92.3,4.99896,600,0.6,88.2,7.49844,600,0.5,91.9,6.2487,600,0.6,91.6,7.49844,600,0.48,92.11,5.99875,600,0.52,92.18,6.49865,600,0.5,90.5,6.2487
Traffic signs,600,1.1,40.11,13.7471,600,1.23,66.74,15.3718,600,1.08,55.57,13.4972,600,1.0,46.48,12.4974,600,1.31,42.91,16.3716,600,0.93,37.48,11.6226,600,1.08,52.42,13.4972,600,1.1,56.5,13.7471,600,1.0,53.4,12.4974,600,0.8,52.2,9.99792,600,1.0,59.2,12.4974,600,1.1,59.7,13.7471,600,1.1,51.1,13.7471,600,1.4,52.0,17.4964,600,1.1,58.4,13.7471,600,1.19,63.34,14.8719,600,0.9,83.55,11.2477,600,1.0,63.0,12.4974


### Add rankings

In [132]:
def is_difference_significant(best_stats, candidate_stats):
  """Tells whether two mean accuracies differ beyond a 95% confidence interval.

  Args:
    best_stats: mapping (e.g. a Series) with the 'mean (%)', 'stddev' and
      '# episodes' entries of the first model.
    candidate_stats: same entries for the model it is compared to.

  Returns:
    True if the absolute difference of the two means is strictly larger than
    the half-width of the 95% confidence interval of that difference.
  """
  # Variance of each sample mean: stddev^2 / number of episodes.
  best_mean_var = (best_stats['stddev'] ** 2) / best_stats['# episodes']
  candidate_mean_var = ((candidate_stats['stddev'] ** 2) /
                        candidate_stats['# episodes'])
  # The variance of a difference is the sum of the two variances.
  half_width = 1.96 * np.sqrt(best_mean_var + candidate_mean_var)
  gap = best_stats['mean (%)'] - candidate_stats['mean (%)']
  return np.abs(gap) > half_width

In [133]:
def compute_ranks(dataset_series):
  """Ranks all models on one dataset, sharing ranks between statistical ties.

  The best remaining model is picked repeatedly, together with every remaining
  model whose mean is not significantly different from it according to
  `is_difference_significant`. The models of such a group all receive the
  average of the ranks the group occupies.

  Args:
    dataset_series: Series holding the results of a single dataset, indexed by
      a two-level MultiIndex (model, statistic). The statistics must include
      'mean (%)' plus whatever `is_difference_significant` reads.

  Returns:
    A Series named 'rank', indexed by model, where 1 is the best rank. Models
    that have no mean for this dataset are not ranked and get NaN.
  """
  dataset_df = dataset_series.unstack()
  # Values coming out of a mixed row can have object dtype: cast once.
  accuracies = dataset_df['mean (%)'].astype('d')
  # A missing mean makes every comparison come out as "not significant"
  # (NaN > x is False), which would tie that model with the best one. Such
  # models are set aside instead, and reported with a NaN rank.
  has_result = accuracies.notna().values
  remaining_models = list(accuracies.index[has_result])
  next_available_rank = 1
  ranks = {}
  # Iteratively pick the best models, then all the ones statistically equivalent
  while remaining_models:
    # No `axis` argument: a Series only has axis 0, and current pandas
    # versions reject `axis=1` here.
    best_model = accuracies.loc[remaining_models].idxmax()
    best_stats = dataset_df.loc[best_model]
    tied_models = [best_model]
    for candidate in remaining_models:
      if candidate == best_model:
        continue
      candidate_stats = dataset_df.loc[candidate]
      if not is_difference_significant(best_stats, candidate_stats):
        tied_models.append(candidate)

    n_ties = len(tied_models)
    # All tied models share the same rank, which is the average of the next
    # `n_ties` available ranks (the ranks they would have without the ties), or
    # next_available_rank + (1 + ... + (n_ties - 1)) / n_ties, which gives:
    shared_rank = next_available_rank + (n_ties - 1) / 2
    next_available_rank += n_ties
    for model in tied_models:
      ranks[model] = shared_rank

    # Remove picked models for next iteration
    remaining_models = [model for model in remaining_models
                        if model not in tied_models]

  for model in accuracies.index[~has_result]:
    ranks[model] = np.nan
  return pd.Series(ranks, name='rank', dtype=float)

In [134]:
def add_ranks(df):
  """Returns `df` extended with a 'rank' statistic for every model.

  Ranks are computed independently for each dataset (row) by `compute_ranks`.
  The first row of `df` is left out of the ranking, so its 'rank' entries are
  missing in the result.

  Args:
    df: DataFrame indexed by dataset, with (model, statistic) MultiIndex
      columns, as expected row by row by `compute_ranks`.

  Returns:
    A new DataFrame where the columns of each model are grouped together, the
    models following the order of `df.columns.levels[0]`.
  """
  model_order = df.columns.levels[0]
  # One row per ranked dataset, one column per model (ignore "ILSVRC (valid)")
  per_dataset_ranks = df[1:].apply(compute_ranks, axis=1)
  # Relabel the columns as (model, 'rank') so they line up with those of `df`
  labelled_ranks = pd.concat([per_dataset_ranks], axis=1, keys=['rank'])
  labelled_ranks = labelled_ranks.swaplevel(0, 1, axis=1)
  # Append to the original statistics, then regroup the columns model by model
  combined = pd.concat([df, labelled_ranks], axis=1)
  return combined[model_order]

In [135]:
# Rebind imagenet_df to the table extended with per-dataset ranks.
# NOTE(review): add_ranks concatenates new 'rank' columns onto its input, so
# running this cell again on the already-ranked frame would presumably
# duplicate them; re-run from the aggregation cell instead.
imagenet_df = add_ranks(imagenet_df)
imagenet_df

Unnamed: 0_level_0,ALFA+fo-Proto-MAML,ALFA+fo-Proto-MAML,ALFA+fo-Proto-MAML,ALFA+fo-Proto-MAML,ALFA+fo-Proto-MAML,BOHB,BOHB,BOHB,BOHB,BOHB,CTX,CTX,CTX,CTX,CTX,Finetune,Finetune,Finetune,Finetune,Finetune,MatchingNet,MatchingNet,MatchingNet,MatchingNet,MatchingNet,ProtoNet,ProtoNet,ProtoNet,ProtoNet,ProtoNet,ProtoNet (large),ProtoNet (large),ProtoNet (large),ProtoNet (large),ProtoNet (large),RelationNet,RelationNet,RelationNet,RelationNet,RelationNet,fo-MAML,fo-MAML,fo-MAML,fo-MAML,fo-MAML,fo-Proto-MAML,fo-Proto-MAML,fo-Proto-MAML,fo-Proto-MAML,fo-Proto-MAML,k-NN,k-NN,k-NN,k-NN,k-NN
Unnamed: 0_level_1,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank
ILSVRC (valid),600,,,,,600,,,,,600,,,,,600,,,,,600,,,,,600,,,,,600,,,,,600,,,,,600,,,,,600,,,,,600,,,,
ILSVRC (test),600,1.11,52.8,13.8721,2.5,600,1.05,51.92,13.1223,4.5,600,0.99,62.76,12.3724,1.0,600,1.1,45.78,13.7471,8.0,600,1.1,45.0,13.7471,8.0,600,1.08,50.5,13.4972,4.5,600,1.07,53.69,13.3722,2.5,600,1.01,34.69,12.6224,11.0,600,1.11,45.51,13.8721,8.0,600,1.05,49.53,13.1223,6.0,600,1.01,41.03,12.6224,10.0
Omniglot,600,1.51,61.87,18.8711,4.5,600,1.21,67.57,15.1218,2.5,600,1.0,82.21,12.4974,1.0,600,1.58,60.85,19.7459,6.5,600,1.28,52.27,15.9967,9.0,600,1.35,59.98,16.8715,6.5,600,1.27,68.5,15.8717,2.5,600,1.36,45.35,16.9965,10.0,600,1.54,55.55,19.246,8.0,600,1.33,63.37,16.6215,4.5,600,1.15,37.07,14.372,11.0
Aircraft,600,1.1,63.43,13.7471,3.0,600,0.9,54.12,11.2477,7.5,600,0.89,79.49,11.1227,1.0,600,1.26,68.69,15.7467,2.0,600,0.93,48.97,11.6226,9.0,600,1.0,53.1,12.4974,7.5,600,0.96,58.04,11.9975,4.0,600,0.83,40.73,10.3728,11.0,600,1.11,56.24,13.8721,5.5,600,0.99,55.95,12.3724,5.5,600,0.89,46.81,11.1227,10.0
Birds,600,1.05,69.75,13.1223,3.5,600,0.9,70.69,11.2477,3.5,600,0.88,80.63,10.9977,1.0,600,1.26,57.31,15.7467,9.0,600,0.95,62.21,11.8725,7.5,600,1.01,68.79,12.6224,5.5,600,0.92,74.07,11.4976,2.0,600,1.05,49.51,13.1223,10.5,600,1.06,63.61,13.2472,7.5,600,0.96,68.66,11.9975,5.5,600,1.0,50.13,12.4974,10.5
Textures,600,0.88,70.78,10.9977,2.0,600,0.76,68.34,9.49802,4.5,600,0.64,75.57,7.99833,1.0,600,0.9,69.05,11.2477,4.5,600,0.85,64.15,10.6228,10.0,600,0.83,66.56,10.3728,8.0,600,0.77,68.76,9.623,4.5,600,0.69,52.97,8.6232,11.0,600,0.81,68.04,10.1229,4.5,600,0.83,66.49,10.3728,8.0,600,0.75,66.36,9.37305,8.0
QuickDraw,600,1.16,59.17,14.497,2.0,600,1.04,50.33,12.9973,4.5,600,0.82,72.68,10.2479,1.0,600,1.17,42.6,14.622,8.5,600,1.09,42.87,13.6222,8.5,600,1.08,48.96,13.4972,6.0,600,1.06,53.3,13.2472,3.0,600,1.08,43.3,13.4972,8.5,600,1.29,43.96,16.1216,8.5,600,1.0,51.52,12.4974,4.5,600,1.08,32.06,13.4972,11.0
Fungi,600,1.17,41.49,14.622,3.5,600,1.12,41.38,13.9971,3.5,600,1.11,51.58,13.8721,1.0,600,1.02,38.2,12.7473,7.0,600,1.0,33.97,12.4974,9.0,600,1.11,39.71,13.8721,6.0,600,1.15,40.73,14.372,3.5,600,1.04,30.55,12.9973,11.0,600,1.1,32.1,13.7471,10.0,600,1.14,39.96,14.247,3.5,600,1.02,36.16,12.7473,8.0
VGG Flower,600,0.77,85.96,9.623,6.0,600,0.59,87.34,7.37346,3.0,600,0.37,95.34,4.62404,1.0,600,0.68,85.51,8.49823,6.0,600,0.71,80.13,8.87315,10.0,600,0.77,85.27,9.623,6.0,600,0.73,86.96,9.1231,3.0,600,0.83,68.76,10.3728,11.0,600,0.83,81.74,10.3728,9.0,600,0.69,87.15,8.6232,3.0,600,0.68,83.1,8.49823,8.0
Traffic signs,600,1.29,60.78,16.1216,3.0,600,1.04,51.8,12.9973,5.5,600,0.76,82.65,9.49802,1.0,600,1.31,66.79,16.3716,2.0,600,1.14,47.8,14.247,7.5,600,1.1,47.12,13.7471,9.0,600,1.05,58.11,13.1223,4.0,600,1.05,33.67,13.1223,11.0,600,1.51,50.93,18.8711,5.5,600,1.09,48.83,13.6222,7.5,600,1.19,44.59,14.8719,10.0


In [136]:
# Same ranking step for the models trained on all datasets; like the cell
# above, it rebinds its input and is meant to run once per aggregation.
all_df = add_ranks(all_df)
all_df

Unnamed: 0_level_0,CNAPs,CNAPs,CNAPs,CNAPs,CNAPs,FLUTE,FLUTE,FLUTE,FLUTE,FLUTE,Finetune,Finetune,Finetune,Finetune,Finetune,ITA,ITA,ITA,ITA,ITA,MatchingNet,MatchingNet,MatchingNet,MatchingNet,MatchingNet,ProtoNet,ProtoNet,ProtoNet,ProtoNet,ProtoNet,RelationNet,RelationNet,RelationNet,RelationNet,RelationNet,SUR,SUR,SUR,SUR,SUR,...,TransductiveCNAPS,TransductiveCNAPS,TransductiveCNAPS,TransductiveCNAPS,TransductiveCNAPS,TriM,TriM,TriM,TriM,TriM,URL,URL,URL,URL,URL,URT,URT,URT,URT,URT,URT-pf,URT-pf,URT-pf,URT-pf,URT-pf,fo-MAML,fo-MAML,fo-MAML,fo-MAML,fo-MAML,fo-Proto-MAML,fo-Proto-MAML,fo-Proto-MAML,fo-Proto-MAML,fo-Proto-MAML,k-NN,k-NN,k-NN,k-NN,k-NN
Unnamed: 0_level_1,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,...,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank,# episodes,95% CI,mean (%),stddev,rank
ILSVRC (valid),600,,,,,600,,,,,600,,,,,600,,,,,600,,,,,600,,,,,600,,,,,600,,,,,...,600,,,,,600,,,,,600,,,,,600,,,,,600,,,,,600,,,,,600,,,,,600,,,,
ILSVRC (test),600,1.1,50.8,13.7471,10.5,600,1.1,51.8,13.7471,10.5,600,1.08,43.08,13.4972,13.5,600,1.05,57.35,13.1223,2.5,600,1.0,36.08,12.4974,17.0,600,1.05,44.5,13.1223,13.5,600,0.93,30.89,11.6226,18.0,600,1.1,56.1,13.7471,7.0,...,600,1.1,57.9,13.7471,2.5,600,1.0,58.6,12.4974,2.5,600,1.08,57.51,13.4972,2.5,600,1.0,55.7,12.4974,7.0,600,1.1,55.5,13.7471,7.0,600,1.01,37.83,12.6224,15.5,600,1.05,46.52,13.1223,12.0,600,0.94,38.55,11.7476,15.5
Omniglot,600,0.5,91.7,6.2487,8.0,600,0.5,93.2,6.2487,5.5,600,1.37,71.11,17.1214,18.0,600,0.38,94.96,4.74901,1.5,600,1.01,78.25,12.6224,15.5,600,1.12,79.56,13.9971,15.5,600,0.79,86.57,9.87294,12.0,600,0.5,93.1,6.2487,5.5,...,600,0.4,94.3,4.99896,3.5,600,0.6,92.0,7.49844,8.0,600,0.41,94.51,5.12393,1.5,600,0.4,94.4,4.99896,3.5,600,0.6,90.2,7.49844,10.5,600,0.95,83.92,11.8725,13.5,600,0.97,82.69,12.1225,13.5,600,1.08,74.6,13.4972,17.0
Aircraft,600,0.6,83.7,7.49844,7.5,600,0.5,87.2,6.2487,3.0,600,1.07,72.03,13.3722,14.5,600,0.44,89.33,5.49885,1.0,600,0.96,69.17,11.9975,16.5,600,0.86,71.14,10.7478,14.5,600,0.83,69.71,10.3728,16.5,600,0.7,84.6,8.74818,5.5,...,600,0.5,84.7,6.2487,5.5,600,0.7,82.8,8.74818,9.0,600,0.46,88.59,5.7488,2.0,600,0.6,85.8,7.49844,4.0,600,0.7,79.8,8.74818,10.5,600,0.69,76.41,8.6232,12.0,600,0.76,75.23,9.49802,13.0,600,0.82,64.98,10.2479,18.0
Birds,600,0.9,73.6,11.2477,10.0,600,0.8,79.2,9.99792,3.5,600,1.15,59.82,14.372,16.0,600,0.74,81.42,9.24807,1.5,600,1.0,56.4,12.4974,17.0,600,1.02,67.01,12.7473,13.5,600,0.99,54.14,12.3724,18.0,600,1.0,70.6,12.4974,11.5,...,600,0.7,78.8,8.74818,3.5,600,0.8,75.3,9.99792,7.5,600,0.69,80.54,8.6232,1.5,600,0.8,76.3,9.99792,7.5,600,0.8,77.5,9.99792,5.0,600,1.08,62.43,13.4972,15.0,600,1.02,69.88,12.7473,11.5,600,0.92,66.35,11.4976,13.5
Textures,600,0.7,59.5,8.74818,17.0,600,0.8,68.8,9.99792,10.5,600,0.85,69.14,10.6228,8.5,600,0.72,76.74,8.99813,1.5,600,0.74,61.8,9.24807,16.0,600,0.84,65.18,10.4978,12.5,600,0.73,56.56,9.1231,18.0,600,0.8,71.0,9.99792,6.5,...,600,0.8,66.2,9.99792,12.5,600,0.8,71.2,9.99792,6.5,600,0.67,76.17,8.37326,1.5,600,0.7,71.8,8.74818,4.5,600,0.7,73.5,8.74818,3.0,600,0.83,64.16,10.3728,14.5,600,0.81,68.25,10.1229,10.5,600,0.79,63.58,9.87294,14.5
QuickDraw,600,0.8,74.7,9.99792,11.0,600,0.7,79.5,8.74818,5.0,600,1.16,47.05,14.497,17.0,600,0.57,82.01,7.12352,2.0,600,1.03,60.81,12.8723,14.5,600,0.89,64.88,11.1227,13.0,600,0.97,61.75,12.1225,14.5,600,0.6,81.3,7.49844,4.0,...,600,0.6,77.9,7.49844,6.5,600,0.7,77.3,8.74818,8.5,600,0.56,81.94,6.99854,2.0,600,0.6,82.5,7.49844,2.0,600,0.7,75.8,8.74818,10.0,600,1.1,59.73,13.7471,16.0,600,0.94,66.84,11.7476,12.0,600,1.05,44.88,13.1223,18.0
Fungi,600,1.1,50.2,13.7471,7.5,600,1.1,58.1,13.7471,5.0,600,1.04,38.16,12.9973,14.5,600,0.99,67.4,12.3724,1.5,600,1.04,33.7,12.9973,17.0,600,1.13,40.26,14.1221,13.0,600,1.08,32.56,13.4972,17.0,600,1.1,64.2,13.7471,3.5,...,600,1.2,48.9,14.9969,7.5,600,1.0,48.5,12.4974,10.5,600,0.95,68.75,11.8725,1.5,600,1.0,63.5,12.4974,3.5,600,0.9,48.1,11.2477,10.5,600,1.11,33.54,13.8721,17.0,600,1.17,41.99,14.622,12.0,600,1.06,37.12,13.2472,14.5
VGG Flower,600,0.5,88.9,6.2487,10.0,600,0.6,91.6,7.49844,3.0,600,0.69,85.28,8.6232,13.0,600,0.52,92.18,6.49865,3.0,600,0.72,81.9,8.99813,16.0,600,0.71,86.85,8.87315,12.0,600,0.76,76.08,9.49802,18.0,600,0.8,82.8,9.99792,14.5,...,600,0.4,92.3,4.99896,3.0,600,0.5,90.5,6.2487,7.5,600,0.48,92.11,5.99875,3.0,600,0.6,88.2,7.49844,10.0,600,0.5,91.9,6.2487,3.0,600,0.84,79.94,10.4978,17.0,600,0.67,88.72,8.37326,10.0,600,0.61,83.47,7.62341,14.5
Traffic signs,600,1.1,56.5,13.7471,8.5,600,1.1,58.4,13.7471,6.0,600,1.23,66.74,15.3718,2.0,600,0.9,83.55,11.2477,1.0,600,1.08,55.57,13.4972,8.5,600,1.0,46.48,12.4974,15.0,600,0.93,37.48,11.6226,18.0,600,1.0,53.4,12.4974,11.5,...,600,1.1,59.7,13.7471,6.0,600,1.0,63.0,12.4974,3.5,600,1.19,63.34,14.8719,3.5,600,1.1,51.1,13.7471,14.0,600,1.4,52.0,17.4964,11.5,600,1.31,42.91,16.3716,16.0,600,1.08,52.42,13.4972,11.5,600,1.1,40.11,13.7471,17.0


In [137]:
# Average rank of each model over the datasets; rows without a rank (the
# validation split) are NaN and skipped by mean().
imagenet_df.xs('rank', axis=1, level=1).mean()

ALFA+fo-Proto-MAML     3.25
BOHB                   4.15
CTX                    1.00
Finetune               6.15
MatchingNet            8.65
ProtoNet               6.45
ProtoNet (large)       3.45
RelationNet           10.55
fo-MAML                7.45
fo-Proto-MAML          5.20
k-NN                   9.70
dtype: float64

In [138]:
# Average rank of each model trained on all datasets (NaN rows are skipped).
all_df.xs('rank', axis=1, level=1).mean()

CNAPs                10.25
FLUTE                 5.90
Finetune             13.10
ITA                   1.65
MatchingNet          15.40
ProtoNet             13.50
RelationNet          16.80
SUR                   7.65
SUR-pnf               8.20
SimpleCNAPS           7.45
TransductiveCNAPS     6.05
TriM                  6.60
URL                   2.15
URT                   6.05
URT-pf                7.55
fo-MAML              15.25
fo-Proto-MAML        11.60
k-NN                 15.85
dtype: float64

### Display in HTML
This section uses the DataFrame's "styler" object, which renders nicely within the notebook.

Unfortunately, the HTML it outputs is not compatible with GitHub's markdown (as it relies on the `<style>` tag).

In [139]:
def str_summary(series):
  """Formats the results of one (dataset, model) pair as a single cell.

  Args:
    series: mapping (e.g. a Series) with 'mean (%)', '95% CI' and 'rank'.

  Returns:
    A string of the form 'mean±ci (rank)'. The space before the rank is a
    non-breaking one, to keep the whole cell on the same line.
  """
  fields = dict(
      acc=series['mean (%)'],
      ci=series['95% CI'],
      rank=series['rank'],
      nbsp='\u00A0')
  return '%(acc)s±%(ci)s%(nbsp)s(%(rank)g)' % fields

In [140]:
def display_table(df, models=None):
  """Builds the styled leaderboard table shown in the notebook.

  Each row is a model and each column a dataset, preceded by an 'Avg rank'
  column. Cells are formatted by `str_summary`. The cells holding the best
  rank of each dataset, and the best average rank, are displayed in bold.

  Args:
    df: DataFrame indexed by dataset, with (model, statistic) MultiIndex
      columns including 'rank'. Its first row is left out of the table.
    models: optional iterable of model names giving the order of the rows.
      Any iterable is accepted, including dict views such as `d.keys()`.

  Returns:
    The pandas Styler wrapping the table.
  """
  # Materialize once, as md_table does: callers pass dict views, which are used
  # several times below and are not plain list-like indexers for `.loc`.
  if models is not None:
    models = list(models)

  shown_datasets = df.index[1:]
  accuracies_df = df.stack(0).apply(str_summary, axis=1).unstack(0)[shown_datasets]
  rank_df = df.xs('rank', axis=1, level=1).loc[shown_datasets]
  avg_rank_df = pd.DataFrame(rank_df.mean(), columns=['Avg rank'])
  display_df = pd.concat([avg_rank_df, accuracies_df], axis=1)
  if models:
    # Try and force a particular order of models
    display_df = display_df.loc[models]

  # Bold cells corresponding to the best rank
  best_acc_mask = rank_df.T == rank_df.min(axis=1)
  best_avg_mask = avg_rank_df == avg_rank_df.min()
  best_mask = pd.concat([best_avg_mask, best_acc_mask], axis=1)
  if models:
    best_mask = best_mask.loc[models]
  # Element-wise CSS for the Styler, built with np.where instead of the
  # deprecated DataFrame.applymap.
  bold_mask = pd.DataFrame(
      np.where(best_mask.values.astype(bool), 'font-weight: bold', ''),
      index=best_mask.index,
      columns=best_mask.columns)

  display_style = display_df.style.apply(lambda f: bold_mask, axis=None)
  display_style = display_style.format({'Avg rank': '{:g}'})
  return display_style

In [141]:
# Styled leaderboard for the ImageNet-only setting, with rows in the order in
# which the models were listed in imagenet_dfs.
imagenet_display = display_table(imagenet_df, models=imagenet_dfs.keys())
imagenet_display

Unnamed: 0,Avg rank,ILSVRC (test),Omniglot,Aircraft,Birds,Textures,QuickDraw,Fungi,VGG Flower,Traffic signs,MSCOCO
k-NN,9.7,41.03±1.01 (10),37.07±1.15 (11),46.81±0.89 (10),50.13±1.0 (10.5),66.36±0.75 (8),32.06±1.08 (11),36.16±1.02 (8),83.1±0.68 (8),44.59±1.19 (10),30.38±0.99 (10.5)
Finetune,6.15,45.78±1.1 (8),60.85±1.58 (6.5),68.69±1.26 (2),57.31±1.26 (9),69.05±0.9 (4.5),42.6±1.17 (8.5),38.2±1.02 (7),85.51±0.68 (6),66.79±1.31 (2),34.86±0.97 (8)
MatchingNet,8.65,45.0±1.1 (8),52.27±1.28 (9),48.97±0.93 (9),62.21±0.95 (7.5),64.15±0.85 (10),42.87±1.09 (8.5),33.97±1.0 (9),80.13±0.71 (10),47.8±1.14 (7.5),34.99±1.0 (8)
ProtoNet,6.45,50.5±1.08 (4.5),59.98±1.35 (6.5),53.1±1.0 (7.5),68.79±1.01 (5.5),66.56±0.83 (8),48.96±1.08 (6),39.71±1.11 (6),85.27±0.77 (6),47.12±1.1 (9),41.0±1.1 (5.5)
fo-MAML,7.45,45.51±1.11 (8),55.55±1.54 (8),56.24±1.11 (5.5),63.61±1.06 (7.5),68.04±0.81 (4.5),43.96±1.29 (8.5),32.1±1.1 (10),81.74±0.83 (9),50.93±1.51 (5.5),35.3±1.23 (8)
RelationNet,10.55,34.69±1.01 (11),45.35±1.36 (10),40.73±0.83 (11),49.51±1.05 (10.5),52.97±0.69 (11),43.3±1.08 (8.5),30.55±1.04 (11),68.76±0.83 (11),33.67±1.05 (11),29.15±1.01 (10.5)
fo-Proto-MAML,5.2,49.53±1.05 (6),63.37±1.33 (4.5),55.95±0.99 (5.5),68.66±0.96 (5.5),66.49±0.83 (8),51.52±1.0 (4.5),39.96±1.14 (3.5),87.15±0.69 (3),48.83±1.09 (7.5),43.74±1.12 (4)
ALFA+fo-Proto-MAML,3.25,52.8±1.11 (2.5),61.87±1.51 (4.5),63.43±1.1 (3),69.75±1.05 (3.5),70.78±0.88 (2),59.17±1.16 (2),41.49±1.17 (3.5),85.96±0.77 (6),60.78±1.29 (3),48.11±1.14 (2.5)
ProtoNet (large),3.45,53.69±1.07 (2.5),68.5±1.27 (2.5),58.04±0.96 (4),74.07±0.92 (2),68.76±0.77 (4.5),53.3±1.06 (3),40.73±1.15 (3.5),86.96±0.73 (3),58.11±1.05 (4),41.7±1.08 (5.5)
CTX,1.0,62.76±0.99 (1),82.21±1.0 (1),79.49±0.89 (1),80.63±0.88 (1),75.57±0.64 (1),72.68±0.82 (1),51.58±1.11 (1),95.34±0.37 (1),82.65±0.76 (1),59.9±1.02 (1)


In [142]:
# print(imagenet_display.render())

In [143]:
# Styled leaderboard for the all-datasets setting, with rows in the order in
# which the models were listed in all_dfs.
all_display = display_table(all_df, models=all_dfs.keys())
all_display

Unnamed: 0,Avg rank,ILSVRC (test),Omniglot,Aircraft,Birds,Textures,QuickDraw,Fungi,VGG Flower,Traffic signs,MSCOCO
k-NN,15.85,38.55±0.94 (15.5),74.6±1.08 (17),64.98±0.82 (18),66.35±0.92 (13.5),63.58±0.79 (14.5),44.88±1.05 (18),37.12±1.06 (14.5),83.47±0.61 (14.5),40.11±1.1 (17),29.55±0.96 (16)
Finetune,13.1,43.08±1.08 (13.5),71.11±1.37 (18),72.03±1.07 (14.5),59.82±1.15 (16),69.14±0.85 (8.5),47.05±1.16 (17),38.16±1.04 (14.5),85.28±0.69 (13),66.74±1.23 (2),35.17±1.08 (14)
MatchingNet,15.4,36.08±1.0 (17),78.25±1.01 (15.5),69.17±0.96 (16.5),56.4±1.0 (17),61.8±0.74 (16),60.81±1.03 (14.5),33.7±1.04 (17),81.9±0.72 (16),55.57±1.08 (8.5),28.79±0.96 (16)
ProtoNet,13.5,44.5±1.05 (13.5),79.56±1.12 (15.5),71.14±0.86 (14.5),67.01±1.02 (13.5),65.18±0.84 (12.5),64.88±0.89 (13),40.26±1.13 (13),86.85±0.71 (12),46.48±1.0 (15),39.87±1.06 (12.5)
fo-MAML,15.25,37.83±1.01 (15.5),83.92±0.95 (13.5),76.41±0.69 (12),62.43±1.08 (15),64.16±0.83 (14.5),59.73±1.1 (16),33.54±1.11 (17),79.94±0.84 (17),42.91±1.31 (16),29.37±1.08 (16)
RelationNet,16.8,30.89±0.93 (18),86.57±0.79 (12),69.71±0.83 (16.5),54.14±0.99 (18),56.56±0.73 (18),61.75±0.97 (14.5),32.56±1.08 (17),76.08±0.76 (18),37.48±0.93 (18),27.41±0.89 (18)
fo-Proto-MAML,11.6,46.52±1.05 (12),82.69±0.97 (13.5),75.23±0.76 (13),69.88±1.02 (11.5),68.25±0.81 (10.5),66.84±0.94 (12),41.99±1.17 (12),88.72±0.67 (10),52.42±1.08 (11.5),41.74±1.13 (10)
CNAPs,10.25,50.8±1.1 (10.5),91.7±0.5 (8),83.7±0.6 (7.5),73.6±0.9 (10),59.5±0.7 (17),74.7±0.8 (11),50.2±1.1 (7.5),88.9±0.5 (10),56.5±1.1 (8.5),39.4±1.0 (12.5)
SUR,7.65,56.1±1.1 (7),93.1±0.5 (5.5),84.6±0.7 (5.5),70.6±1.0 (11.5),71.0±0.8 (6.5),81.3±0.6 (4),64.2±1.1 (3.5),82.8±0.8 (14.5),53.4±1.0 (11.5),50.1±1.0 (7)
SUR-pnf,8.2,56.0±1.1 (7),90.0±0.6 (10.5),79.7±0.8 (10.5),75.9±0.9 (7.5),72.5±0.7 (4.5),76.7±0.7 (8.5),49.8±1.1 (7.5),90.0±0.6 (7.5),52.2±0.8 (11.5),50.2±1.1 (7)


In [144]:
# print(all_display.render())

### Display in MarkDown
More precisely, in GitHub-flavored Markdown.

In [145]:
def md_render(series):
  """Format one result as the text of a single MarkDown table cell.

  Args:
    series: Mapping-like object (e.g. a DataFrame row) with the keys
      'mean (%)', '95% CI', 'rank' and 'best_rank'.

  Returns:
    A string of the form `acc±ci&nbsp;(rank)`, where the accuracy is wrapped
    in `**` (bold) when `series['best_rank']` is truthy.
  """
  accuracy = series['mean (%)']
  interval = series['95% CI']
  rank = series['rank']
  # Only the accuracy is emphasized, and only for best-ranked entries.
  emphasis = '**' if series['best_rank'] else ''

  # A non-breaking space keeps the rank on the same line as the accuracy.
  cell_template = '%(bold)s%(acc)5.2f%(bold)s±%(ci)4.2f%(nbsp)s(%(rank)g)'
  fields = dict(acc=accuracy, ci=interval, rank=rank, bold=emphasis,
                nbsp='&nbsp;')
  return cell_template % fields

In [146]:
def md_table(df, models=None):
  """Render a results DataFrame as the text of a MarkDown table.

  The output has one line per model. Its columns are the method name (followed
  by a `[[i]]` pointer to its reference), the average rank, and one cell per
  row label of `df` except the first one.

  Besides its arguments, this function reads the module-level `references`
  list and `models_df` DataFrame, and formats each result cell with
  `md_render`.

  Args:
    df: DataFrame whose columns are a MultiIndex with a 'rank' entry at
      level 1. The first row is skipped.
      NOTE(review): from the way the result is indexed below, rows look like
      datasets and column level 0 like model names, with the skipped first row
      presumably being "ILSVRC (valid)" -- confirm against the cell that
      builds `df`.
    models: Optional iterable of model labels. When truthy, the table rows are
      selected and ordered according to it.

  Returns:
    A string made of a header line, a separator line and one line per model,
    joined by newlines. Cells are separated by `|` and left-aligned in
    27-character columns.
  """
  # Whether a model has the best rank on a given dataset
  # `rank_df` keeps only the 'rank' values, without the first row of `df`.
  rank_df = df.xs('rank', axis=1, level=1).loc[df.index[1:]]
  # Compare every rank with the minimum of its `rank_df` row, then label the
  # boolean columns as ('best_rank') so they can sit next to the other values.
  best_rank = pd.concat([rank_df.T == rank_df.min(axis=1)], axis=1,
                        keys=['best_rank']).swaplevel(0, 1, axis=1)
  # Reshape the results (minus the first row) so that they can be joined with
  # `best_rank` column-wise.
  accuracies_df = df[1:].T.unstack(1)
  accuracies_df = pd.concat([accuracies_df, best_rank], axis=1)
  # Collapse each group of values into a single MarkDown string (md_render is
  # applied row by row), then pivot back to a 2-D table of strings.
  accuracies_md = accuracies_df.stack(0).apply(md_render, axis=1).unstack(1)

  # Average rank (and whether it's the best)
  avg_rank_df = pd.DataFrame(rank_df.mean(), columns=['Avg rank'])
  best_avg_rank = (avg_rank_df == avg_rank_df.min()).rename(
      columns={'Avg rank': 'best_rank'})
  # The lowest average rank(s) are wrapped in `**` so they render in bold.
  avg_rank_md = pd.concat([avg_rank_df, best_avg_rank], axis=1).apply(
      lambda s: '%(bold)s%(avg_rank)g%(bold)s' % {
          'avg_rank': s['Avg rank'],
          'bold': '**' if s['best_rank'] else ''
      },
      axis=1).rename('Avg rank')

  # Display method name with a pointer to the reference, defined later.
  # References are numbered from 1, in the order of the `references` list;
  # `ref_list()` emits the matching `[i]: #...` link definitions.
  ref_to_link = {ref[0]: "[[%i]]" % i for i, ref in enumerate(references, 1)}
  # Look up the link of each row of `models_df` through its 'ref' value; a
  # 'ref' that is not in `references` raises KeyError.
  method_md = models_df.apply(lambda r: ref_to_link[r['ref']], axis='columns')

  # Indexing with `df.index[1:]` puts the result columns in the row order of
  # `df`.
  display_md = pd.concat([avg_rank_md, accuracies_md[df.index[1:]]], axis=1)
  if models:
    # Try and force a particular order of models
    display_md = display_md.loc[list(models)]

  # Pad all cells so they align well, 27 chars should be enough
  # ('%-27s' left-aligns and pads, but does not truncate longer cells.)
  header_str = '|'.join(['%-27s' % c
                         for c in ['Method'] + list(display_md.columns)])
  # One dashed cell per column, including the leading 'Method' column.
  sep_str = '|'.join(['-' * 27 for _ in [''] + list(display_md.columns)])  
  # Each row starts with "<model label> <reference link>", which requires the
  # labels of `display_md` to also be present in `method_md`.
  rows = [
      '|'.join(['%-27s' % c for c in ([' '.join((i, method_md.loc[i]))] +
                                      list(display_md.loc[i]))])
      for i in display_md.index
  ]
  return '\n'.join([header_str, sep_str] + rows)

In [147]:
# Preview the MarkDown table for the models trained on ImageNet only.
print(md_table(imagenet_df, models=imagenet_dfs.keys()))

Method                     |Avg rank                   |ILSVRC (test)              |Omniglot                   |Aircraft                   |Birds                      |Textures                   |QuickDraw                  |Fungi                      |VGG Flower                 |Traffic signs              |MSCOCO                     
---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------
k-NN [[1]]                 |9.7                        |41.03±1.01&nbsp;(10)       |37.07±1.15&nbsp;(11)       |46.81±0.89&nbsp;(10)       |50.13±1.00&nbsp;(10.5)     |66.36±0.75&nbsp;(8)        |32.06±1.08&nbsp;(11)       |36.16±1.02&nbsp;(8)        |83.10±0.68&nbsp;(8)        |44.59±1.19&nbsp;(10)       |30.38±0.99&nbsp;(10.

In [148]:
# Preview the MarkDown table for the models trained on all datasets.
print(md_table(all_df, models=all_dfs.keys()))

Method                     |Avg rank                   |ILSVRC (test)              |Omniglot                   |Aircraft                   |Birds                      |Textures                   |QuickDraw                  |Fungi                      |VGG Flower                 |Traffic signs              |MSCOCO                     
---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------
k-NN [[1]]                 |15.85                      |38.55±0.94&nbsp;(15.5)     |74.60±1.08&nbsp;(17)       |64.98±0.82&nbsp;(18)       |66.35±0.92&nbsp;(13.5)     |63.58±0.79&nbsp;(14.5)     |44.88±1.05&nbsp;(18)       |37.12±1.06&nbsp;(14.5)     |83.47±0.61&nbsp;(14.5)     |40.11±1.10&nbsp;(17)       |29.55±0.96&nbsp;(16)

## Export to MarkDown

### Reference list

In [149]:
def sanitize_anchor(string):
  """Turn a section title into the html link anchor MarkDown would give it.

  This tries to mimic the MarkDown function that derives an anchor from a
  section title; it is an approximation, not the renderer's exact algorithm.

  Args:
    string: The section title, as a `str`.

  Returns:
    The anchor: lower-cased, with punctuation removed, and with each run of
    spaces, underscores and dashes replaced by a single dash.
  """
  # Steps:
  # - put it in lower case
  # - remove everything that is not a text character ("\w", which includes "_"),
  #   a space ("\s") or dash ("-")
  # - replace spaces and "_" by "-" (and deduplicate)
  anchor = string.lower()
  # Raw strings: "\w" and "\s" are regex classes, not string escapes. In a
  # regular literal they are invalid escape sequences, which Python reports
  # with a DeprecationWarning (3.6+) or SyntaxWarning (3.12+).
  anchor = re.sub(r'[^\w\s-]', '', anchor)
  anchor = re.sub(r'[\s_-]+', '-', anchor)
  return anchor

In [150]:
def ref_list():
  """Build the MarkDown text of the reference list.

  Entries come from the module-level `references` list, numbered from 1. Each
  entry is indexed as `(short reference, full reference)`.

  Returns:
    A string with, joined by newlines: first one `[i]: #anchor` link
    definition per reference, then one `######` heading per reference followed
    by its full citation.
  """
  # Template of one entry of the reference section; the escaped brackets keep
  # `[i]` in the heading from being interpreted as a link.
  entry_template = textwrap.dedent(r'''
      ###### \[%(i)i\] %(shortref)s

      %(fullref)s
    ''')

  # Link definitions from [i] to the heading of the matching entry. The anchor
  # starts with the number because the heading text does too.
  link_defs = [
      '[%(i)i]: #%(i)i-%(r)s' % dict(i=number, r=sanitize_anchor(entry[0]))
      for number, entry in enumerate(references, 1)
  ]

  # Content of the reference section itself.
  entry_blocks = [
      entry_template % dict(i=number, shortref=entry[0], fullref=entry[1])
      for number, entry in enumerate(references, 1)
  ]

  return '\n'.join(link_defs + entry_blocks)

In [151]:
# Preview the generated link definitions and reference section.
print(ref_list())

[1]: #1-triantafillou-et-al-2020
[2]: #2-requeima-et-al-2019
[3]: #3-baik-et-al-2020
[4]: #4-doersch-et-al-2020
[5]: #5-saikia-et-al-2020
[6]: #6-dvornik-et-al-2020
[7]: #7-bateni-et-al-2020a
[8]: #8-bateni-et-al-2020b
[9]: #9-liu-et-al-2021a
[10]: #10-triantafillou-et-al-2021
[11]: #11-li-et-al-2021a
[12]: #12-li-et-al-2021b
[13]: #13-liu-et-al-2021b

###### \[1\] Triantafillou et al. (2020)

Eleni Triantafillou, Tyler Zhu, Vincent Dumoulin, Pascal Lamblin, Utku Evci, Kelvin Xu, Ross Goroshin, Carles Gelada, Kevin Swersky, Pierre-Antoine Manzagol, Hugo Larochelle; [_Meta-Dataset: A Dataset of Datasets for Learning to Learn from Few Examples_](https://arxiv.org/abs/1903.03096); ICLR 2020.


###### \[2\] Requeima et al. (2019)

James Requeima, Jonathan Gordon, John Bronskill, Sebastian Nowozin, Richard E. Turner; [_Fast and Flexible Multi-Task Classification Using Conditional Neural Adaptive Processes_](https://arxiv.org/abs/1906.07697); NeurIPS 2019.


###### \[3\] Baik et al. (2020)



### Full section

In [152]:
def export_md():
  """Assemble the full generated MarkDown section.

  The section is delimited by html comments naming `Leaderboard.ipynb`, and
  contains the ImageNet-only table, the all-datasets table and the reference
  list, each under its own heading.

  Returns:
    A string in which all the parts are separated by blank lines.
  """
  # Opening marker of the generated content.
  sections = ['<!-- Beginning of content generated by `Leaderboard.ipynb` -->']

  sections.append('## Training on ImageNet only')
  sections.append(md_table(imagenet_df, models=imagenet_dfs.keys()))

  sections.append('## Training on all datasets')
  sections.append(md_table(all_df, models=all_dfs.keys()))

  sections.append('## References')
  sections.append(ref_list())

  # Closing marker of the generated content.
  sections.append('<!-- End of content generated by `Leaderboard.ipynb` -->')

  separator = '\n\n'
  return separator.join(sections)

In [153]:
# Print the complete generated MarkDown section (tables and references).
print(export_md())

<!-- Beginning of content generated by `Leaderboard.ipynb` -->

## Training on ImageNet only

Method                     |Avg rank                   |ILSVRC (test)              |Omniglot                   |Aircraft                   |Birds                      |Textures                   |QuickDraw                  |Fungi                      |VGG Flower                 |Traffic signs              |MSCOCO                     
---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------|---------------------------
k-NN [[1]]                 |9.7                        |41.03±1.01&nbsp;(10)       |37.07±1.15&nbsp;(11)       |46.81±0.89&nbsp;(10)       |50.13±1.00&nbsp;(10.5)     |66.36±0.75&nbsp;(8)        |32.06±1.08&nbsp;(11)       |36.16±1.02