# Zero-shot Prompt Ensembling for Text-Image Models Results

*Licensed under the Apache License, Version 2.0.*

<a href="https://githubtocolab.com/google/uncertainty-baselines/blob/main/experimental/multimodal/Zero_shot_prompt_ensembling_for_text_image_models_results.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook produces all of the CLIP results for the ICML 2023 paper ["A Simple Zero-shot Prompt Weighting Technique to Improve Prompt Ensembling in Text-Image Models"](https://arxiv.org/abs/2302.06235).

If you find this notebook or our implementations in Uncertainty Baselines useful, please cite:
```
@InProceedings{allingham2023simple,
  title = 	 {A Simple Zero-shot Prompt Weighting Technique to Improve Prompt Ensembling in Text-Image Models},
  author =       {Allingham, James Urquhart and Ren, Jie and Dusenberry, Michael W and Gu, Xiuye and Cui, Yin and Tran, Dustin and Liu, Jeremiah Zhe and Lakshminarayanan, Balaji},
  booktitle = 	 {Proceedings of the 40th International Conference on Machine Learning},
  pages = 	 {547--568},
  year = 	 {2023},
  editor = 	 {Krause, Andreas and Brunskill, Emma and Cho, Kyunghyun and Engelhardt, Barbara and Sabato, Sivan and Scarlett, Jonathan},
  volume = 	 {202},
  series = 	 {Proceedings of Machine Learning Research},
  month = 	 {23--29 Jul},
  publisher =    {PMLR},
}
```

Note that this code uses a lot of memory. If the kernel crashes, either run the notebook on a machine with more memory or free memory manually with `del` statements.

In [None]:
#@title Imports

import jax
print(jax.local_devices())

import tensorflow as tf
tf.config.experimental.set_visible_devices([], "TPU")
print(tf.config.get_visible_devices())

import tensorflow_datasets as tfds
import ml_collections
from importlib import reload

import functools
import itertools
from typing import Sequence
import multiprocessing
import os
from tqdm import tqdm
import pickle
from scipy import stats
import pandas as pd

import jax.numpy as jnp
from jax import random

from absl import app
from absl import flags
from absl import logging
from clu import metric_writers
from clu import parameter_overview
from clu import periodic_actions
from clu import preprocess_spec
import flax
import flax.linen as nn
from flax.training import train_state
import optax
import ml_collections
import ml_collections.config_flags
import numpy as np

import matplotlib
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
# Figure-size constants and global matplotlib styling used for the plots.
# NOTE(review): the two widths are presumably in inches (matplotlib's figure
# size unit) -- confirm against the ICML template.
text_width = 6.75133  # From the ICML LaTeX template.
line_width = 3.25063  # From the ICML LaTeX template.
matplotlib.rc('font', size=7)  # Controls default text sizes.
matplotlib.rc('axes', titlesize=7)
matplotlib.rc('axes', labelsize=7)
matplotlib.rc('xtick', labelsize=6)
matplotlib.rc('ytick', labelsize=6)
matplotlib.rc('legend', fontsize=6)
matplotlib.rc('figure', titlesize=8)
# Serif font family, with Palatino as the requested serif face.
matplotlib.rc('font', **{'family':'serif', 'serif': ['Palatino']})
# Render all text through LaTeX; this needs a working LaTeX installation.
matplotlib.rc('text', usetex=True)


from uncertainty_baselines.models import clip
import robustness_metrics as rm
import uncertainty_baselines as ub
# TODO(jallingham): Fork remaining utils once imports below merged into UB API.
# import train_utils  # local file import from baselines.jft
# NOTE: Usually we do not allow cross-imports between subdirectories. We are
# doing so here because this is an experimental directory and the offending
# utils are soon to have much of their functionality merged into the UB API.
from experimental.multimodal import input_utils
from experimental.multimodal import checkpoint_utils
from experimental.multimodal import multimodal_utils
from experimental.multimodal import preprocess_utils
from experimental.multimodal import simple_tokenizer
from experimental.multimodal.configs import clip_common

preprocess_utils = reload(preprocess_utils)
multimodal_utils = reload(multimodal_utils)

In [None]:
#@title Define config
MAX_TEXT_LENGTH = 77


def get_config(model_name='vit_b16'):
  """Builds the experiment configuration for one CLIP variant.

  Args:
    model_name: Key used to look up the image resolution and model kwargs in
      `clip_common` and the checkpoint path below ('vit_b16' or 'vit_b32').

  Returns:
    A `(config, input_res)` tuple: the populated `ml_collections.ConfigDict`
    and the image resolution taken from `clip_common.IMAGE_RESOLUTION`.
  """
  input_res = clip_common.IMAGE_RESOLUTION[model_name]
  # Value-range + normalisation step shared by train and zero-shot pipelines.
  normalize_pp = (
      f'|value_range(0, 1)'
      f'|normalize({clip_common.CLIP_IMAGE_MEAN}, {clip_common.CLIP_IMAGE_STD})')

  config = ml_collections.ConfigDict()

  # Data section.
  config.data_dir = '/mnt/disks/persist/data/'
  config.model_name = model_name
  config.dataset = 'laion400m'
  config.train_split = 'all[10000:210000]'
  config.batch_size = 5000
  config.tokenizer_max_len = MAX_TEXT_LENGTH

  # Training preprocessing: image ops, then caption tokenisation, then keep
  # only the fields that are consumed downstream.
  config.pp_train = ''.join([
      f'decode|inception_crop({input_res})',
      normalize_pp,
      f'|clip_tokenize({MAX_TEXT_LENGTH}, key="caption", key_result="text", bpe_path="uncertainty-baselines/experimental/multimodal/bpe_simple_vocab_16e6.txt.gz")',
      '|keep(["image", "text"])',
  ])

  config.shuffle_buffer_size = 250_000  # Per host, so small-ish is ok.
  config.prefetch_to_device = 2
  config.seed = 0

  # Model section.
  checkpoint_paths = {
      'vit_b16': 'ADD_PATH_HERE/clip_vit-b16.npy',
      'vit_b32': 'ADD_PATH_HERE/clip_vit-b32.npy'
  }
  config.model_init = checkpoint_paths[model_name]
  config.convert_pytorch = True
  config.model = ml_collections.config_dict.create(
      **clip_common.CONFIGS[model_name])

  # Optimizer section.
  config.optim_name = 'Adam'
  config.optim = ml_collections.ConfigDict()
  config.grad_clip_norm = 1.0
  config.weight_decay = 1e-5

  config.lr = ml_collections.ConfigDict()
  config.lr.base = 1e-4

  # Zero-shot section.
  def zeroshot_pp(n_classes, resize_method='bicubic'):
    """Returns the eval preprocessing spec for a dataset with `n_classes`."""
    return ''.join([
        f'decode|resize_small({input_res}, method="{resize_method}")|central_crop({input_res})',
        normalize_pp,
        f'|onehot({n_classes}, key="label", key_result="label")',
        '|keep(["image", "label"])',
    ])

  # One row per evaluation dataset:
  # (config key, dataset, split, classnames_key, prompts_key, n_classes).
  zeroshot_specs = [
      ('imagenet', 'imagenet2012', 'validation', 'imagenet', 'imagenet', 1000),
      ('imagenet_a', 'imagenet_a', 'test', 'imagenet_a', 'imagenet', 1000),
      ('imagenet_r', 'imagenet_r', 'test', 'imagenet_r', 'imagenet', 1000),
      ('imagenet_sketch', 'imagenet_sketch', 'test', 'imagenet', 'imagenet', 1000),
      ('imagenet_v2', 'imagenet_v2', 'test', 'imagenet', 'imagenet', 1000),
      ('caltech101', 'caltech101', 'test', 'caltech101', 'caltech101', 102),
      ('cars196', 'cars196', 'test', 'cars196', 'cars196', 196),
      ('cifar10', 'cifar10', 'test', 'cifar10', 'cifar10', 10),
      ('cifar100', 'cifar100', 'test', 'cifar100', 'cifar100', 100),
      ('dtd', 'dtd', 'test', 'dtd', 'dtd', 47),
      ('eurosat', 'eurosat', 'train', 'eurosat', 'eurosat', 10),
      ('food101', 'food101', 'validation', 'food101', 'food101', 101),
      ('oxford_flowers102', 'oxford_flowers102', 'test', 'oxford_flowers102', 'oxford_flowers102', 102),
      ('oxford_iiit_pet', 'oxford_iiit_pet', 'test', 'oxford_iiit_pet', 'oxford_iiit_pet', 37),
      ('resisc45', 'resisc45', 'train', 'resisc45', 'resisc45', 45),
      ('sun397', 'sun397', 'test', 'sun397', 'sun397', 397),
  ]
  config.zeroshot_eval_datasets = {
      name: {
          'dataset': dataset,
          'split': split,
          'classnames_key': classnames_key,
          'prompts_key': prompts_key,
          'pp_spec': zeroshot_pp(n_classes),
      }
      for name, dataset, split, classnames_key, prompts_key, n_classes
      in zeroshot_specs
  }

  return config, input_res


config, image_resolution = get_config('vit_b16')

In [None]:
#@title Create model
# Root PRNG key for the notebook; later cells split fresh keys off `rng`.
seed = config.get('seed', 0)
rng = jax.random.PRNGKey(seed)

# Create the model; its parameters are initialised by `init` below.
clip_model = ub.models.clip(**config.model)
@functools.partial(jax.jit, backend='cpu')
def init(rng):
    """Initialises the CLIP variables on CPU from all-zero dummy inputs.

    Args:
        rng: PRNG key passed to `clip_model.init`.

    Returns:
        A `(params, states)` tuple: the unfrozen 'params' collection and
        whatever remains of the variables once 'params' is popped off.
    """
    # Single-example dummy batch: one image and one token sequence.
    image_shape = (1, image_resolution, image_resolution, 3)
    text_shape = (1, MAX_TEXT_LENGTH)
    variables = clip_model.init(
        rng,
        jnp.zeros(image_shape, jnp.float32),
        jnp.zeros(text_shape, jnp.int32),
    )
    states, params = variables.pop('params')
    return flax.core.unfreeze(params), states

# Initialise fresh parameters on CPU; they serve as the template that the
# checkpoint loader fills in.
rng, rng_init = jax.random.split(rng)
params_cpu, states_cpu = init(rng_init)
# Load the optimizer from flax. We need to create an optimizer because our
# checkpoint loader assumes that the optimizer is storing the params.
# NOTE(review): `flax.optim` is a legacy API -- confirm that the installed
# Flax version still provides it.
opt_name = config.get('optim_name')
opt_def = getattr(flax.optim, opt_name)(**config.get('optim', {}))
# We jit this, such that the arrays that are created are created on the same
# device as the input is, in this case the CPU. Else they'd be on device[0].
opt_cpu = jax.jit(opt_def.create)(params_cpu)

# Load the checkpoint.
checkpoint_data = checkpoint_utils.maybe_load_checkpoint(
    train_loop_rngs=rng,
    save_checkpoint_path=None,
    init_optimizer=opt_cpu,
    init_params=params_cpu,
    init_fixed_model_states=states_cpu,
    default_reinit_params=[],
    config=config)
# The loaded parameters are read back from the optimizer's target.
loaded_params = checkpoint_data.optimizer.target
loaded_states = checkpoint_data.fixed_model_states
# Sanity check to make sure we loaded params:
print('params_cpu[logit_scale]=%s, loaded_params[logit_scale]=%s' % (params_cpu['logit_scale'], loaded_params['logit_scale']))

# Drop the freshly initialised copies; only the loaded values are used below.
del opt_cpu
del params_cpu
del states_cpu

# Variables dict in the form expected by `clip_model.apply`.
clip_vars = {'params': flax.core.freeze(loaded_params), **loaded_states}

In [None]:
#@title Create tokenizer
# Path to the BPE vocabulary file; the same path is embedded in the
# `clip_tokenize` op of the training preprocessing spec in `get_config`.
bpe_path='uncertainty-baselines/experimental/multimodal/bpe_simple_vocab_16e6.txt.gz'

tokenizer = simple_tokenizer.SimpleTokenizer(bpe_path=bpe_path)
# Tokenisation function built with the configured maximum text length.
tokenize_fn = simple_tokenizer.make_tokenize_fn(tokenizer, config.tokenizer_max_len)

## Create helper functions

In [None]:
#@title encode_texts & encode_images

# For batches of texts/images.
@jax.jit
def encode_texts(texts):
    """Encodes a batch of tokenised texts with the CLIP text encoder.

    The encoder is called with `normalize=False` and `scale_logits=False`, so
    any normalisation or logit scaling is left to the caller.
    """
    encoder_kwargs = dict(normalize=False, scale_logits=False)
    return clip_model.apply(
        clip_vars, texts, method=clip_model.encode_text, **encoder_kwargs)

@jax.jit
def encode_images(images):
    """Encodes a batch of images with the CLIP image encoder."""
    image_encoder = clip_model.encode_image
    return clip_model.apply(clip_vars, images, method=image_encoder)


# For a single text/image.
def encode_text(text):
    """Encodes a single tokenised text by adding a leading batch axis."""
    batched_text = jnp.expand_dims(text, axis=0)
    return encode_texts(batched_text)

def encode_image(image):
    """Encodes a single image by adding a leading batch axis."""
    batched_image = jnp.expand_dims(image, axis=0)
    return encode_images(batched_image)

In [None]:
#@title load_xxx_dataset

def _get_split(dataset, split, pp, rng, data_dir, batch_size=None, drop_remainder=False):
    """Builds an unshuffled, single-epoch input pipeline for one split.

    Args:
        dataset: Dataset name forwarded to `input_utils.get_data`.
        split: Split string forwarded to `input_utils.get_data`.
        pp: Preprocessing function, or a spec string that is parsed into one.
        rng: PRNG key; the process index is folded in before use.
        data_dir: Data directory forwarded to `input_utils.get_data`.
        batch_size: Per-process batch size; falls back to the global
            `BATCH_SIZE` when None.
        drop_remainder: Forwarded to `input_utils.get_data`.

    Returns:
        Whatever `input_utils.get_data` returns for these settings.
    """
    preprocess_fn = pp
    if isinstance(preprocess_fn, str):
        preprocess_fn = preprocess_spec.parse(
            spec=preprocess_fn, available_ops=preprocess_utils.all_ops())

    if batch_size is None:
        # NOTE(review): `BATCH_SIZE` is not defined in the visible part of
        # this notebook; all visible callers pass `batch_size` explicitly.
        batch_size = BATCH_SIZE

    process_rng = jax.random.fold_in(rng, jax.process_index())

    # Cache and prefetch settings are read from the global `config`.
    return input_utils.get_data(
        dataset=dataset,
        split=split,
        rng=process_rng,
        process_batch_size=batch_size,
        preprocess_fn=preprocess_fn,
        cache=config.get('val_cache', 'batched'),
        num_epochs=1,
        repeat_after_batching=True,
        shuffle=False,
        prefetch_size=config.get('prefetch_to_host', 2),
        drop_remainder=drop_remainder,
        data_dir=data_dir,
    )

def load_zeroshot_dataset(config, rng, dataset_name, zs_batch_size=5000):
    """Loads the evaluation split for one zero-shot dataset.

    Args:
        config: Config whose `zeroshot_eval_datasets[dataset_name]` entry
            provides the 'dataset', 'split' and 'pp_spec' fields.
        rng: PRNG key; a sub-key is split off for the input pipeline.
        dataset_name: Key into `config.zeroshot_eval_datasets`.
        zs_batch_size: Batch size passed on to `_get_split`.

    Returns:
        The dataset returned by `_get_split`.
    """
    _, zeroshot_ds_rng = jax.random.split(rng)
    ds_cfg = config.zeroshot_eval_datasets[dataset_name]
    parsed_pp = preprocess_spec.parse(
        spec=ds_cfg['pp_spec'], available_ops=preprocess_utils.all_ops())

    return _get_split(
        ds_cfg['dataset'],
        split=ds_cfg['split'],
        pp=parsed_pp,
        rng=zeroshot_ds_rng,
        data_dir=config.get('data_dir'),
        batch_size=zs_batch_size,
    )

def load_train_dataset(config, train_ds_rng):
    """Loads the training split described by `config`.

    Reads `config.pp_train`, `config.dataset`, `config.train_split`,
    `config.batch_size` and the optional `data_dir`, and builds the pipeline
    through `_get_split`.
    """
    parsed_pp = preprocess_spec.parse(
        spec=config.pp_train, available_ops=preprocess_utils.all_ops())
    return _get_split(
        config.dataset,
        split=config.train_split,
        pp=parsed_pp,
        rng=train_ds_rng,
        data_dir=config.get('data_dir'),
        batch_size=config.batch_size,
    )

In [None]:
#@title compute_text_embeddings
def compute_text_embeddings(templates, tokenize_fn, logit_scale, dataset, use_l2_norm=True, classnames=None):
    """Embeds every (template, class name) pair with the text encoder.

    Args:
        templates: Prompt templates, each with one `{}` slot for the class name.
        tokenize_fn: Callable applied to a `tf.string` constant of the prompt.
        logit_scale: Log-scale value; every embedding is multiplied by
            `sqrt(exp(logit_scale))`.
        dataset: Key into the global `config.zeroshot_eval_datasets`, used
            only to look up class names when `classnames` is None.
        use_l2_norm: Whether to L2-normalise each embedding before scaling.
        classnames: Optional explicit list of class names.

    Returns:
        A numpy array of shape [n_prompts, n_classes, emb_dim].
    """
    if classnames is None:
        classnames_key = config.zeroshot_eval_datasets[dataset]['classnames_key']
        classnames = multimodal_utils._ZEROSHOT_CLASS_NAMES[classnames_key]

    def tokenize(prompt):
        return tokenize_fn(tf.constant(prompt, dtype=tf.string))

    per_class_embeddings = []
    for name in tqdm(classnames):
        tokens = jnp.array([tokenize(t.format(name)) for t in templates])
        embedding = encode_texts(tokens)  # [n_prompts, emb_dim]
        if use_l2_norm:
            inverse_norm = jax.lax.rsqrt(
                jnp.sum(embedding**2, axis=-1, keepdims=True))
            embedding = embedding * inverse_norm
        embedding = embedding * jnp.sqrt(jnp.exp(logit_scale))
        per_class_embeddings.append(embedding)

    # Stack on axis 1 so that prompts stay on the leading axis.
    return np.stack(per_class_embeddings, axis=1)  # [n_prompts, n_classes, emb_dim]

In [None]:
#@title compute_image_embeddings
def compute_image_embeddings(ds_iter, image_resolution):
    """Embeds all images from `ds_iter` and collects their integer labels.

    Each batch provides 'image', a one-hot 'label' and a 'mask'; only the
    examples with a non-zero mask are kept. `image_resolution` is accepted
    but not used.

    Returns:
        A `(embeddings, labels)` tuple of numpy arrays.
    """
    all_embeddings, all_labels = [], []
    for batch in tqdm(ds_iter):
        embeddings = jax.pmap(encode_images)(batch['image'])
        # Merge the leading (device, per-device batch) axes into one.
        embeddings = embeddings.reshape(-1, embeddings.shape[-1])

        onehot = batch['label']
        class_ids = onehot.reshape(-1, onehot.shape[-1]).argmax(-1)

        keep = np.where(batch['mask'].reshape(-1).astype(np.int32))
        all_embeddings.append(embeddings[keep])
        all_labels.append(class_ids[keep])

    return np.vstack(all_embeddings), np.hstack(all_labels)

In [None]:
#@title compute_metrics

def compute_accuracy(logits, labels):
    """Computes top-1 and top-5 accuracy, both as percentages.

    Args:
        logits: Array of shape [n_examples, n_classes].
        labels: Integer class ids of shape [n_examples].

    Returns:
        A `(top1, top5)` tuple. With fewer than five classes, "top-5" uses
        all classes and is therefore 100.
    """
    # `top_k` requires k to be no larger than the number of classes.
    k = min(5, logits.shape[-1])
    _, top_labels = jax.lax.top_k(logits, k)
    top1 = 100 * jnp.mean(top_labels[:, 0] == labels)
    # Each row has at most one match, so the sum counts the top-k hits.
    top5 = 100 * jnp.sum(top_labels == labels[:, None]) / labels.shape[0]
    return top1, top5

def compute_metrics(labels, logits, print_out=True):
    """Returns top-1 and top-5 accuracy, optionally printing the top-1 value.

    Args:
        labels: Integer class ids of shape [n_examples].
        logits: Array of shape [n_examples, n_classes].
        print_out: Whether to print the top-1 accuracy.

    Returns:
        The `(top1, top5)` tuple produced by `compute_accuracy`.
    """
    # Softmax probabilities, argmax predictions and confidences were computed
    # here before but never used; they are dropped to avoid materialising
    # extra arrays the size of the logits.
    acc, acc5 = compute_accuracy(logits, labels)

    if print_out:
        print(f'top1_acc: {acc:5.2f}')

    return acc, acc5

In [None]:
#@title get_logits
def get_logits(
    ztxts,  # [n_prompts, n_classes, emb_dim]
    zimgs   # [n_imgs, emb_dim]
):
    """Calculate the zero-shot classifier's logits.

    Returns an array of shape [n_imgs, n_prompts, n_classes] holding the dot
    product between every image embedding and every text embedding.
    """
    @functools.partial(jax.jit, backend='cpu')
    def logits_for_one_prompt(image_embeddings, class_embeddings):
        return jnp.dot(image_embeddings, class_embeddings.T)

    # Map over the prompt axis of `ztxts` only; the prompt axis of the output
    # is placed at position 1.
    logits_for_all_prompts = jax.vmap(
        logits_for_one_prompt, in_axes=(None, 0), out_axes=1)
    return logits_for_all_prompts(zimgs, ztxts)


In [None]:
#@title agg_logits
@functools.partial(jax.jit, backend='cpu')
def agg_logits(
    logits, # [n_imgs, n_prompts, n_classes]
    weights=None
):
    """Calculate (optionally weighted) average of logits.

    The logits are multiplied by `weights` (all ones when omitted) and then
    averaged over the prompt axis. The mean divides by the number of prompts,
    not by the sum of the weights.
    """
    n_imgs, n_prompts, n_classes = logits.shape
    del n_imgs, n_classes  # Only the prompt count is needed.

    if weights is None:
        weights = jnp.ones((1, n_prompts, 1))

    return jnp.mean(logits * weights, axis=1)  # [n_imgs, n_classes]

In [None]:
#@title get_weights
@functools.partial(jax.jit, static_argnums=(2, 3, 4), backend='cpu')
def get_weights(
    logits,  # [n_img, n_prompt, n_cls]
    random_logits=None,  # [n_rand, n_prompt, n_cls]
    debias_mode='both',
    img_mean=True,
    frac_test=1.,
):
    """Scores each prompt by its maximum (optionally debiased) logit.

    Args:
        logits: Logits of the images being classified.
        random_logits: Logits of reference images; required for the 'both',
            'pretrain' and 'pretrain_star' modes.
        debias_mode: Which mean logits to subtract before taking the max:
            'test' (mean over the first `frac_test` of `logits`), 'pretrain'
            (mean of `random_logits` over images), 'pretrain_star' (mean of
            `random_logits` over images and classes), 'both' (average of the
            'pretrain' and 'test' means) or 'none'.
        img_mean: If True, average the scores over images.
        frac_test: Fraction of `logits` used for the 'test' mean.

    Returns:
        Scores of shape [1, n_prompt, 1] if `img_mean`, otherwise
        [n_img, n_prompt, 1].

    Raises:
        RuntimeError: If `debias_mode` is not one of the modes above.
    """
    needs_pretrain_mean = debias_mode in ('both', 'pretrain', 'pretrain_star')
    needs_test_mean = debias_mode in ('both', 'test')

    if needs_pretrain_mean:
        assert random_logits is not None
        reduce_axes = (0, 2) if debias_mode == 'pretrain_star' else (0,)
        pretrain_mean = random_logits.mean(reduce_axes, keepdims=True)  # [1, n_prompt, n_cls or 1]

    if needs_test_mean:
        n_test = round(logits.shape[0] * frac_test)
        test_mean = logits[:n_test].mean(0, keepdims=True)  # [1, n_prompt, n_cls]

    if debias_mode == 'none':
        debiased = logits
    elif debias_mode == 'both':
        debiased = logits - 0.5*(pretrain_mean + test_mean)
    elif debias_mode == 'test':
        debiased = logits - test_mean
    elif needs_pretrain_mean:
        debiased = logits - pretrain_mean
    else:
        raise RuntimeError(f'Unknown "debias_mode" type {debias_mode}')

    scores = debiased.max(-1, keepdims=True)  # [n_img, n_prompt, 1]
    if not img_mean:
        return scores
    return scores.mean(0, keepdims=True)  # [1, n_prompt, 1]

In [None]:
#@title mad_method
def mad_method(data, threshold=3):
    """Median Absolute Deviation outlier detection.

    Returns the int32 indices of the entries whose signed distance from the
    median, in units of the MAD, is greater than `threshold`. Only values
    above the median can be flagged. With `threshold == 'NA'` every index is
    returned.
    """
    if threshold == 'NA':
        return np.arange(len(data))

    center = np.median(data)
    spread = np.abs(stats.median_abs_deviation(data))
    flagged = [
        idx for idx, value in enumerate(data)
        if (value - center) / spread > threshold
    ]
    return np.array(flagged, dtype=np.int32)

## Collect tabular results

In [None]:
# Embed the training-split images (`config.dataset` / `config.train_split`)
# into `zimgs_laion`.
rng, train_ds_rng = jax.random.split(rng, 2)
train_ds = load_train_dataset(config, train_ds_rng)
train_iter = input_utils.start_input_pipeline(
      train_ds, config.get('prefetch_to_device', 1))

zimgs_laion = []

for i, batch in enumerate(tqdm(train_iter)):
  zimg = jax.pmap(encode_images)(batch['image'])
  # Copy to a numpy array and merge the two leading axes of the pmap output.
  zimg = np.array(zimg).reshape(-1, *zimg.shape[2:])
  zimgs_laion.append(zimg)

zimgs_laion = np.concatenate(zimgs_laion)

zimgs_laion.shape

In [None]:
#@title Extra templates generated by GPT
def _build_extra_templates():
    """Returns the set of extra prompt templates, built from three families."""
    # 'A photo of a {}, a type of <kind>.'
    kinds = (
        'insect', 'fish', 'tree', 'fruit', 'car', 'dog', 'mammal', 'reptile',
        'food', 'vegetable', 'landscape', 'cityscape', 'seascape',
        'architecture', 'monument', 'painting', 'sculpture',
        'musical instrument', 'weapon', 'clothing', 'jewelry',
        'household item', 'electronic device', 'tool', 'transportation',
        'recreational activity', 'game', 'sport', 'musical genre',
        'movie genre', 'book genre', 'historical event',
        'mythological creature', 'fantasy creature', 'planet',
        'constellation', 'comet', 'galaxy', 'meteor', 'asteroid', 'star',
        'black hole', 'neutron star', 'quasar', 'pulsar', 'supernova',
        'brown dwarf', 'white dwarf', 'red giant', 'butterfly', 'amphibian',
        'berry', 'motorcycle', 'cat', 'rodent', 'dinosaur', 'pasta', 'grain',
        'mountain range', 'waterfall', 'lake', 'bridge', 'lighthouse',
        'pottery', 'tapestry', 'drum', 'sword', 'hat', 'watch',
        'kitchen appliance', 'camera', 'power tool', 'boat',
        'adventure sport', 'board game', 'ball sport', 'folk music',
        'action movie', 'mystery novel', 'war', 'mythical king',
        'fantasy race',
    )
    # 'A <style> photo of a {}.'
    photo_styles = (
        'panoramic', 'close-up', 'wide-angle', 'high-resolution', 'low-light',
        'time-lapse', 'long-exposure', 'night', 'sunset', 'sunrise',
        'silhouette', 'sepia-toned', 'colored', 'watercolor', 'sketch',
        'hyperlapse', 'tilt-shift', 'motion-blurred', 'double-exposure',
        'HDR', '360-degree', 'black-and-white negative', 'split-tone',
        'film-grain', 'thermal', 'infrared', 'ultraviolet', 'x-ray', '3D',
        'stop-motion', 'bokeh', 'miniature', 'light-painted', 'composite',
        'polarized', 'photomontage', 'digital-art', 'abstract',
        'selective-focus', 'black-and-white film', 'cross-processed',
        'cyanotype', 'lomography', 'pinhole', 'high-dynamic-range',
        'low-dynamic-range', 'multiexposure', 'high-speed', 'underwater',
    )
    # 'A <form> of a {}.'
    art_forms = (
        'sculpture', 'print', 'sketch', 'engraving', 'etching', 'lithograph',
        'watercolor', 'pastel', 'charcoal', 'oil painting',
        'acrylic painting', 'digital painting', 'fresco', 'mosaic', 'collage',
        'graffiti', 'stained glass', 'quilt', 'tapestry', 'batik',
        'calligraphy', 'wood carving', 'metal sculpture', 'glass sculpture',
        'clay sculpture', 'ice sculpture', 'sand sculpture', 'paper mache',
        'sculptural installation', 'mural', 'street art', 'digital art',
        'film', 'animation', 'stop motion animation', 'motion graphics',
        '3D animation', 'VR', 'AR', 'hologram', 'laser show', 'light show',
        'pyrotechnics', 'performance', 'sound sculpture', 'kinetic sculpture',
        'land art', 'environmental art',
    )
    return (
        {f'A photo of a {{}}, a type of {kind}.' for kind in kinds}
        | {f'A {style} photo of a {{}}.' for style in photo_styles}
        | {f'A {form} of a {{}}.' for form in art_forms}
    )


# A set, so each template appears once.
EXTRA_TEMPLATES = _build_extra_templates()
len(EXTRA_TEMPLATES)



In [None]:
#@title Make shared ztxt

# Pool of unique templates across every entry of `_ZEROSHOT_TEMPLATES`.
# NOTE(review): `list(set(...))` gives no stable ordering across Python
# processes; later cells look indices up with `.index`, so the ordering only
# has to stay fixed within one session.
pool_templates = list(set(sum(multimodal_utils._ZEROSHOT_TEMPLATES.values(), [])))
pool_n_prompts = len(pool_templates)
print("pool_n_prompts", pool_n_prompts)

USE_EXTRA_TEMPLATE = True #@param
if USE_EXTRA_TEMPLATE:
  # Append only the extra templates that are not already in the pool.
  extras_templates = EXTRA_TEMPLATES - set(pool_templates)
  all_templates = pool_templates + list(extras_templates)
else:
  all_templates = pool_templates

print("len(all_templates)", len(all_templates))

# Unique class names across every entry of `_ZEROSHOT_CLASS_NAMES`.
all_classnames = list(set(sum(multimodal_utils._ZEROSHOT_CLASS_NAMES.values(), [])))
print("len(all_classnames)", len(all_classnames))

# Embed every (template, class name) pair once. The dataset argument is ''
# because `classnames` is given explicitly, so it is never looked up.
ztxts_all_prompts_all_class = compute_text_embeddings(
    all_templates, tokenize_fn, loaded_params['logit_scale'], '',
    classnames=all_classnames
)

ztxts_all_prompts_all_class.shape

In [None]:
OVERWRITE_RESULTS = False #@param
LOAD_RESULTS = False #@param

# Location of the pickled results dataframe.
base_path = ''
df_name = 'final_results_dataframe.pkl'
df_path = os.path.join(base_path, df_name)

# Start from an empty results table with one column per recorded field.
results_df = pd.DataFrame(columns = [
    'dataset_name', 'debias_mode', 'img_mean', 'prompt_set', 'weighting',
    'select_threshold', 'num_pretrain', 'frac_test', 'top1_acc', 'top5_acc',
])

if OVERWRITE_RESULTS:
    # Pickle data is bytes, so open the file in binary mode; this matches the
    # 'rb' mode used when loading below.
    with tf.io.gfile.GFile(df_path, 'wb') as f:
        f.write(pickle.dumps(results_df, protocol=4))

if LOAD_RESULTS:
    # NOTE: unpickling can execute arbitrary code; only load result files
    # that this notebook wrote itself.
    with tf.io.gfile.GFile(df_path, 'rb') as f:
        results_df = pickle.load(f)

results_df

In [None]:
#@title Specify datasets and batch sizes
# Evaluation datasets mapped to the batch size used when embedding them.
# Every dataset uses 5000 except the two overridden below.
ds_inet = dict.fromkeys(
    ['imagenet', 'imagenet_a', 'imagenet_r', 'imagenet_sketch', 'imagenet_v2'],
    5000,
)
ds_fine = dict.fromkeys(
    [
        'caltech101', 'cars196', 'cifar10', 'cifar100', 'dtd', 'eurosat',
        'food101', 'oxford_iiit_pet', 'oxford_flowers102', 'resisc45', 'sun397',
    ],
    5000,
)
ds_fine.update(dtd=1880, oxford_iiit_pet=3669)
# ImageNet variants first, then the fine-grained datasets.
ds_list = {**ds_inet, **ds_fine}

In [None]:
#@title Main loop

RUN_ABLATIONS = False #@param
# ^ Whether to run all of the ablations (True) or only the key results (False).
RUN_FULL_THRESHOLD_SWEEP = False #@param
# ^ Whether to run the full sweep over threshold values (True) or only the values we found to be best (False).
IMG_MEANS = [True] # [True, False] #@param
# ^ Values of `img_mean` to loop over: per-dataset (True) or per-example (False) prompt scores.


In [None]:
for dataset_name, batch_size in ds_list.items():
    print("dataset_name", dataset_name)

    # Collect prompts.
    # 'ds' stands for 'dataset specific'. I.e., the 'hand-crafted' prompts.
    ds_templates = multimodal_utils._ZEROSHOT_TEMPLATES[config.zeroshot_eval_datasets[dataset_name]['prompts_key']]
    ds_n_prompts = len(ds_templates)

    pool_idxs = np.array([all_templates.index(p) for p in pool_templates])
    ds_idxs = np.array([all_templates.index(p) for p in ds_templates])
    inet_idxs = np.array([all_templates.index(p) for p in multimodal_utils._ZEROSHOT_TEMPLATES['imagenet']])
    none_idx = np.array([all_templates.index(p) for p in multimodal_utils._ZEROSHOT_TEMPLATES['none']])
    photo_idx = np.array([all_templates.index(p) for p in multimodal_utils._ZEROSHOT_TEMPLATES['photo']])

    # Get prompt embeddings.
    classnames = multimodal_utils._ZEROSHOT_CLASS_NAMES[config.zeroshot_eval_datasets[dataset_name]['classnames_key']]
    classname_idxs = np.array([all_classnames.index(classname) for classname in classnames])
    ztxts_all_prompts = ztxts_all_prompts_all_class[:, classname_idxs, :]

    ztxts_classname = compute_text_embeddings(multimodal_utils._ZEROSHOT_TEMPLATES['none'],
                                                tokenize_fn, loaded_params['logit_scale'], dataset_name)

    ztxts_photo = compute_text_embeddings(multimodal_utils._ZEROSHOT_TEMPLATES['photo'],
                                                tokenize_fn, loaded_params['logit_scale'], dataset_name)

    # Get image embeddings.
    zs_split = load_zeroshot_dataset(config, rng, dataset_name, zs_batch_size=batch_size)
    ds_iter = input_utils.start_input_pipeline(zs_split, config.get('prefetch_to_device', 1))
    zimgs, labels = compute_image_embeddings(ds_iter, image_resolution)

    if dataset_name == 'imagenet_r':
        labels = np.array([multimodal_utils._IMAGENET_R_LABELSET.index(l) for l in labels])
    elif dataset_name == 'imagenet_a':
        labels = np.array([multimodal_utils._IMAGENET_A_LABELSET.index(l) for l in labels])

    # Get logits.
    classname_logits = get_logits(ztxts_classname, zimgs)
    photo_logits = get_logits(ztxts_photo, zimgs)
    all_prompts_logits = get_logits(ztxts_all_prompts, zimgs)
    pool_logits = all_prompts_logits[:, pool_idxs, :]
    ds_logits = all_prompts_logits[:, ds_idxs, :]
    inet_logits = all_prompts_logits[:, inet_idxs, :]
    del ztxts_classname, ztxts_photo

    def add_row_(
        df, dataset_name, debias_mode, img_mean, prompt_set, weighting, select_threshold,
        num_pretrain, frac_test, top1_acc, top5_acc
    ):
        """Returns a copy of `df` with one result row appended.

        The input frame is not modified, and the index of the returned frame
        is renumbered from zero.
        """
        record = dict(
            dataset_name=dataset_name,
            debias_mode=debias_mode,
            img_mean=img_mean,
            prompt_set=prompt_set,
            weighting=weighting,
            select_threshold=select_threshold,
            num_pretrain=num_pretrain,
            frac_test=frac_test,
            top1_acc=top1_acc,
            top5_acc=top5_acc,
        )
        # Each value is wrapped in a list so the frame has exactly one row.
        new_row = pd.DataFrame.from_dict(
            {column: [value] for column, value in record.items()})
        return pd.concat([df, new_row], ignore_index=True)

    # Class name.
    top1_acc, top5_acc = compute_metrics(labels, agg_logits(classname_logits));
    results_df = add_row_(results_df, dataset_name, 'NA', 'NA', 'classname', 'NA', 'NA',
                          'NA', 'NA', top1_acc, top5_acc)

    # 'A photo of {}'.
    top1_acc, top5_acc = compute_metrics(labels, agg_logits(photo_logits));
    results_df = add_row_(results_df, dataset_name, 'NA', 'NA', 'photo', 'NA', 'NA',
                          'NA', 'NA', top1_acc, top5_acc)

    # Dataset specific prompts with equal weighting.
    top1_acc, top5_acc = compute_metrics(labels, agg_logits(ds_logits));
    results_df = add_row_(results_df, dataset_name, 'NA', 'NA', 'dataset', 'equal', 'NA',
                          'NA', 'NA', top1_acc, top5_acc)

    # Pool prompts with equal weighting.
    top1_acc, top5_acc = compute_metrics(labels, agg_logits(pool_logits));
    results_df = add_row_(results_df, dataset_name, 'NA', 'NA', 'pool', 'equal', 'NA',
                          'NA', 'NA', top1_acc, top5_acc)

    # INet prompts with equal weighting.
    if RUN_ABLATIONS:
        top1_acc, top5_acc = compute_metrics(labels, agg_logits(inet_logits));
        results_df = add_row_(results_df, dataset_name, 'NA', 'NA', 'inet', 'equal', 'NA',
                              'NA', 'NA', top1_acc, top5_acc)

        # All prompts with equal weighting.
        top1_acc, top5_acc = compute_metrics(labels, agg_logits(all_prompts_logits));
        results_df = add_row_(results_df, dataset_name, 'NA', 'NA', 'all', 'equal', 'NA',
                              'NA', 'NA', top1_acc, top5_acc)

    # Fallback (top1, top5) pair recorded when thresholding selects no prompts. It has
    # exactly two entries because every call site unpacks the result into two names.
    empty_metrics = (0, 0)

    for img_mean in IMG_MEANS:
        debias_modes = ['both', 'test', 'pretrain_star', 'pretrain', 'none'] if RUN_ABLATIONS else ['both']
        for debias_mode in debias_modes:
            print(debias_mode)
            # Pick the (num_pretrain, frac_test) settings to sweep for this mode; 'NA'
            # marks a value that is not varied for the mode.
            if debias_mode == 'both':
                if RUN_ABLATIONS:
                    num_pretrains_frac_test = [(5_000, 1.), (10_000, 1.), (20_000, 1.), (20_000, .5), (20_000, .2), (20_000, .1)]
                else:
                    num_pretrains_frac_test = [(20_000, 1.)]
            elif debias_mode in ['pretrain_star', 'pretrain']:
                num_pretrains_frac_test = [(20_000, 'NA')]
            else:
                num_pretrains_frac_test = [('NA', 1.)]

            for num_pretrain, frac_test in num_pretrains_frac_test:
                # Logits of the prompts against the first `num_pretrain` entries of
                # `zimgs_laion`; the modes 'test' and 'none' do not use them.
                if debias_mode in ('both', 'pretrain'):
                    random_logits = get_logits(ztxts_all_prompts, zimgs_laion[:num_pretrain])  # [n_pretrain, n_prompts, n_classes_ds]
                elif debias_mode == 'pretrain_star':
                    random_logits = get_logits(ztxts_all_prompts_all_class, zimgs_laion[:num_pretrain])  # [n_pretrain, n_prompts, n_classes_all]
                else:
                    random_logits = None
                # NOTE: `ztxts_all_prompts` must NOT be deleted here: it is needed again
                # on the next iteration of any of the enclosing loops. It is released
                # once, after the whole sweep has finished (see the end of this block).

                all_weights = get_weights(all_prompts_logits, random_logits, debias_mode=debias_mode, img_mean=img_mean, frac_test=frac_test)

                # The extra prompt-set ablations are only run for the main configuration.
                run_prompt_set_ablations = (
                    RUN_ABLATIONS and debias_mode == 'both' and num_pretrain == 20_000 and frac_test == 1.
                )

                # ds prompts with score weighting.
                top1_acc, top5_acc = compute_metrics(labels, agg_logits(ds_logits, all_weights[:, ds_idxs, :]))
                results_df = add_row_(results_df, dataset_name, debias_mode, img_mean, 'dataset', 'scores', 'NA',
                                      num_pretrain, frac_test, top1_acc, top5_acc)

                top1_acc, top5_acc = compute_metrics(labels, agg_logits(ds_logits, jax.nn.softmax(all_weights, axis=1)[:, ds_idxs, :]))
                results_df = add_row_(results_df, dataset_name, debias_mode, img_mean, 'dataset', 'softmax_scores', 'NA',
                                      num_pretrain, frac_test, top1_acc, top5_acc)

                # Thresholds are only swept when img_mean is truthy; otherwise only the
                # 'NA' (no threshold) setting is evaluated.
                if img_mean:
                    if RUN_FULL_THRESHOLD_SWEEP:
                        thresholds = ['NA', 0.1, 0.2, 0.3, 0.4, 0.5, 1.0, 1.5, 1.8, 2.0, 2.5]
                    else:
                        thresholds = ['NA', 0.5, 2.0]
                else:
                    thresholds = ['NA']

                for select_threshold in thresholds:

                    # Pool prompts with score weighting / thresholding.
                    selected_prompt_idxs = mad_method(all_weights[0, pool_idxs, 0], select_threshold)
                    # Note: ^ select_threshold == 'NA' is a special case for mad_method equivalent to select_threshold == 0.

                    top1_acc, top5_acc = compute_metrics(
                        labels, agg_logits(pool_logits[:, selected_prompt_idxs, :], all_weights[:, pool_idxs[selected_prompt_idxs], :])
                    ) if len(selected_prompt_idxs) > 0 else empty_metrics
                    results_df = add_row_(results_df, dataset_name, debias_mode, img_mean, 'pool', 'scores', select_threshold,
                                          num_pretrain, frac_test, top1_acc, top5_acc)

                    top1_acc, top5_acc = compute_metrics(
                        labels, agg_logits(pool_logits[:, selected_prompt_idxs, :], jax.nn.softmax(all_weights, axis=1)[:, pool_idxs[selected_prompt_idxs], :])
                    ) if len(selected_prompt_idxs) > 0 else empty_metrics
                    results_df = add_row_(results_df, dataset_name, debias_mode, img_mean, 'pool', 'softmax_scores', select_threshold,
                                          num_pretrain, frac_test, top1_acc, top5_acc)

                    if run_prompt_set_ablations:
                        top1_acc, top5_acc = compute_metrics(
                            labels, agg_logits(pool_logits[:, selected_prompt_idxs, :], (all_weights**10)[:, pool_idxs[selected_prompt_idxs], :])
                        ) if len(selected_prompt_idxs) > 0 else empty_metrics
                        results_df = add_row_(results_df, dataset_name, debias_mode, img_mean, 'pool', 'scores^10', select_threshold,
                                              num_pretrain, frac_test, top1_acc, top5_acc)

                    if select_threshold != 'NA':
                        # Selected pool prompts with equal weighting.
                        top1_acc, top5_acc = compute_metrics(
                            labels, agg_logits(pool_logits[:, selected_prompt_idxs, :])
                        ) if len(selected_prompt_idxs) > 0 else empty_metrics
                        results_df = add_row_(results_df, dataset_name, debias_mode, img_mean, 'pool', 'equal', select_threshold,
                                              num_pretrain, frac_test, top1_acc, top5_acc)

                    if run_prompt_set_ablations:
                        # INet prompts with score weighting / thresholding.
                        selected_prompt_idxs = mad_method(all_weights[0, inet_idxs, 0], select_threshold)

                        top1_acc, top5_acc = compute_metrics(
                            labels, agg_logits(inet_logits[:, selected_prompt_idxs, :], all_weights[:, inet_idxs[selected_prompt_idxs], :])
                        ) if len(selected_prompt_idxs) > 0 else empty_metrics
                        results_df = add_row_(results_df, dataset_name, debias_mode, img_mean, 'inet', 'scores', select_threshold,
                                              num_pretrain, frac_test, top1_acc, top5_acc)

                        top1_acc, top5_acc = compute_metrics(
                            labels, agg_logits(inet_logits[:, selected_prompt_idxs, :], jax.nn.softmax(all_weights, axis=1)[:, inet_idxs[selected_prompt_idxs], :])
                        ) if len(selected_prompt_idxs) > 0 else empty_metrics
                        results_df = add_row_(results_df, dataset_name, debias_mode, img_mean, 'inet', 'softmax_scores', select_threshold,
                                              num_pretrain, frac_test, top1_acc, top5_acc)

                        if select_threshold != 'NA':
                            top1_acc, top5_acc = compute_metrics(
                                labels, agg_logits(inet_logits[:, selected_prompt_idxs, :])
                            ) if len(selected_prompt_idxs) > 0 else empty_metrics
                            results_df = add_row_(results_df, dataset_name, debias_mode, img_mean, 'inet', 'equal', select_threshold,
                                                  num_pretrain, frac_test, top1_acc, top5_acc)

                        # All prompts with score weighting / thresholding.
                        selected_prompt_idxs = mad_method(all_weights[0, :, 0], select_threshold)

                        top1_acc, top5_acc = compute_metrics(
                            labels, agg_logits(all_prompts_logits[:, selected_prompt_idxs, :], all_weights[:, selected_prompt_idxs, :])
                        ) if len(selected_prompt_idxs) > 0 else empty_metrics
                        results_df = add_row_(results_df, dataset_name, debias_mode, img_mean, 'all', 'scores', select_threshold,
                                              num_pretrain, frac_test, top1_acc, top5_acc)

                        top1_acc, top5_acc = compute_metrics(
                            labels, agg_logits(all_prompts_logits[:, selected_prompt_idxs, :], jax.nn.softmax(all_weights, axis=1)[:, selected_prompt_idxs, :])
                        ) if len(selected_prompt_idxs) > 0 else empty_metrics
                        results_df = add_row_(results_df, dataset_name, debias_mode, img_mean, 'all', 'softmax_scores', select_threshold,
                                              num_pretrain, frac_test, top1_acc, top5_acc)

                        if select_threshold != 'NA':
                            top1_acc, top5_acc = compute_metrics(
                                labels, agg_logits(all_prompts_logits[:, selected_prompt_idxs, :])
                            ) if len(selected_prompt_idxs) > 0 else empty_metrics
                            results_df = add_row_(results_df, dataset_name, debias_mode, img_mean, 'all', 'equal', select_threshold,
                                                  num_pretrain, frac_test, top1_acc, top5_acc)

                # Release the large per-configuration arrays before the next setting.
                del all_weights
                del random_logits

                # Checkpoint the accumulated results after every configuration.
                with tf.io.gfile.GFile(df_path, 'w') as f:
                    f.write(pickle.dumps(results_df, protocol=4))

    # Every configuration has now used the prompt embeddings, so they can be freed.
    del ztxts_all_prompts

In [None]:
results_df

## Make tables

In [None]:
# Load the pickled results DataFrame from "final_results_dataframe.pkl".
# NOTE(review): unpickling can execute arbitrary code, so only load a file you trust.
with tf.io.gfile.GFile("final_results_dataframe.pkl", 'rb') as f:
  results_df = pickle.load(f)

In [None]:
results_df

In [None]:
def df_to_latex(df, apply_formatting=True):
    df_rounded = df.round(2)

    # Convert dataframe to LaTeX with formatting for largest and second largest values.
    latex_string = df_rounded.to_latex(escape=False, header=True)

    # Iterate over the floating-point columns.
    if apply_formatting:
        for column in df_rounded.select_dtypes(include=['float64']).columns:
            # Find the largest and second largest values in the column.
            largest = df_rounded[column].max()
            second_largest = df_rounded[column].nlargest(2).min()

            # Apply formatting to the largest value (bold) and second largest value (underline).
            latex_string = latex_string.replace(f'{largest:.2f}', r'\textbf{' + f'{largest:.2f}' + '}')
            latex_string = latex_string.replace(f'{second_largest:.2f}', r'\ul{' + f'{second_largest:.2f}' + '}')

    print(latex_string)

In [None]:
#@title Table 1

# Take only ImageNet datasets.
table1_df = results_df[results_df.dataset_name.isin(ds_inet.keys())]

# Remove most of the ablation rows.
table1_df = table1_df[
    table1_df.num_pretrain.isin(['NA', 20_000]) &
    table1_df.frac_test.isin(['NA', 1.]) &
    table1_df.img_mean.isin(['NA', True]) &
    table1_df.debias_mode.isin(['NA', 'both', 'none'])
]

# Construct table rows: each entry pairs a row label with the mask selecting its results.
table1_rows = [
    ('class name', table1_df.prompt_set == 'classname'),
    # Raw string so the LaTeX backslashes are not parsed as (invalid) escape sequences.
    (r"`\emph{A photo of \{\}.}'", table1_df.prompt_set == 'photo'),
    ('hand-crafted, equal average',
     (table1_df.prompt_set == 'dataset') & (table1_df.weighting == 'equal')),
    ('pool set, equal average',
     (table1_df.prompt_set == 'pool') & (table1_df.weighting == 'equal') & (table1_df.select_threshold == 'NA')),
    ('max-logit scoring',
     (table1_df.prompt_set == 'pool') & (table1_df.weighting == 'scores') & (table1_df.debias_mode == 'none') & (table1_df.select_threshold == 'NA')),
    ('ZPE (weighted average)',
     (table1_df.prompt_set == 'pool') & (table1_df.weighting == 'softmax_scores') & (table1_df.debias_mode == 'both') & (table1_df.select_threshold == 'NA')),
    ('ZPE (prompt selection, ours)',
     (table1_df.prompt_set == 'pool') & (table1_df.weighting == 'softmax_scores') & (table1_df.debias_mode == 'both') & (table1_df.select_threshold == .5)),
]
table1_df = pd.concat([table1_df[mask].assign(Name=name) for name, mask in table1_rows])

# Drop columns.
table1_df = table1_df[['Name', 'dataset_name', 'top1_acc']]

# Pivot the table (sort=False keeps the row order defined above).
table1_df = table1_df.pivot_table(index='Name', columns='dataset_name', values='top1_acc', sort=False)

# Drop extra levels.
table1_df.columns.name = None
table1_df.index.name = None

# Sort columns.
table1_df = table1_df.sort_index(axis=1)

# Add the averaged column.
table1_df['avg'] = table1_df.mean(axis=1)

table1_df

In [None]:
df_to_latex(table1_df)

In [None]:
#@title Table 2

# Take only fine datasets.
table2_df = results_df[results_df.dataset_name.isin(ds_fine.keys())]

# Remove most of the ablation rows.
table2_df = table2_df[
    table2_df.num_pretrain.isin(['NA', 20_000]) &
    table2_df.frac_test.isin(['NA', 1.]) &
    table2_df.img_mean.isin(['NA', True]) &
    table2_df.debias_mode.isin(['NA', 'both', 'none'])
]

# Construct table rows: each entry pairs a row label with the mask selecting its results.
table2_rows = [
    ('class name', table2_df.prompt_set == 'classname'),
    # Raw string so the LaTeX backslashes are not parsed as (invalid) escape sequences.
    (r"`\emph{A photo of \{\}.}'", table2_df.prompt_set == 'photo'),
    ('hand-crafted, equal average',
     (table2_df.prompt_set == 'dataset') & (table2_df.weighting == 'equal')),
    ('pool set, equal average',
     (table2_df.prompt_set == 'pool') & (table2_df.weighting == 'equal') & (table2_df.select_threshold == 'NA')),
    ('max-logit scoring',
     (table2_df.prompt_set == 'pool') & (table2_df.weighting == 'scores') & (table2_df.debias_mode == 'none') & (table2_df.select_threshold == 'NA')),
    ('ZPE (weighted average)',
     (table2_df.prompt_set == 'pool') & (table2_df.weighting == 'softmax_scores') & (table2_df.debias_mode == 'both') & (table2_df.select_threshold == 'NA')),
    ('ZPE (prompt selection, ours)',
     (table2_df.prompt_set == 'pool') & (table2_df.weighting == 'softmax_scores') & (table2_df.debias_mode == 'both') & (table2_df.select_threshold == 2.)),
]
table2_df = pd.concat([table2_df[mask].assign(Name=name) for name, mask in table2_rows])

# Drop columns.
table2_df = table2_df[['Name', 'dataset_name', 'top1_acc']]

# Pivot the table (sort=False keeps the row order defined above).
table2_df = table2_df.pivot_table(index='Name', columns='dataset_name', values='top1_acc', sort=False)

# Drop extra levels.
table2_df.columns.name = None
table2_df.index.name = None

# Sort columns.
table2_df = table2_df.sort_index(axis=1)

# Add the averaged column.
table2_df['avg'] = table2_df.mean(axis=1)

table2_df

In [None]:
df_to_latex(table2_df)

In [None]:
#@title Table 3

table3_df = results_df

# Remove most of the ablation rows.
table3_df = table3_df[
    table3_df.num_pretrain.isin([20_000, 'NA']) &
    table3_df.frac_test.isin([1., 'NA']) &
    table3_df.img_mean.isin([True]) &
    table3_df.weighting.isin(['softmax_scores']) &
    (
        (table3_df.dataset_name.isin(ds_inet.keys()) & (table3_df.select_threshold.isin(['NA', .5]))) |
        (table3_df.dataset_name.isin(ds_fine.keys()) & (table3_df.select_threshold.isin(['NA', 2.])))
    ) &
    table3_df.prompt_set.isin(['pool'])
]

# Construct table rows: each entry pairs a row label with the mask selecting its results.
table3_rows = [
    ('none (weighted average)',
     (table3_df.debias_mode == 'none') & (table3_df.select_threshold == 'NA')),
    ('pretrain (weighted average)',
     (table3_df.debias_mode == 'pretrain') & (table3_df.select_threshold == 'NA')),
    ('pretrain_star (weighted average)',
     (table3_df.debias_mode == 'pretrain_star') & (table3_df.select_threshold == 'NA')),
    ('test (weighted average)',
     (table3_df.debias_mode == 'test') & (table3_df.select_threshold == 'NA')),
    ('both (weighted average)',
     (table3_df.debias_mode == 'both') & (table3_df.select_threshold == 'NA')),
    ('none (prompt selection, ours)',
     (table3_df.debias_mode == 'none') & (table3_df.select_threshold != 'NA')),
    ('pretrain (prompt selection)',
     (table3_df.debias_mode == 'pretrain') & (table3_df.select_threshold != 'NA')),
    ('pretrain_star (prompt selection)',
     (table3_df.debias_mode == 'pretrain_star') & (table3_df.select_threshold != 'NA')),
    ('test (prompt selection)',
     (table3_df.debias_mode == 'test') & (table3_df.select_threshold != 'NA')),
    ('both (prompt selection)',
     (table3_df.debias_mode == 'both') & (table3_df.select_threshold != 'NA')),
]
table3_df = pd.concat([table3_df[mask].assign(Name=name) for name, mask in table3_rows])

# Pivot the table (sort=False keeps the row order defined above).
table3_df = table3_df.pivot_table(index='Name', columns='dataset_name', values='top1_acc', sort=False)

# Drop extra levels.
table3_df.columns.name = None
table3_df.index.name = None

# Create the new columns. All three means are evaluated on the per-dataset columns
# before any aggregate is added, so 'all' averages datasets only and does not also
# average in the 'variants' and 'fine' aggregates.
table3_df = table3_df.assign(
    variants=table3_df[list(ds_inet.keys() - {'imagenet'})].mean(axis=1),
    fine=table3_df[list(ds_fine.keys())].mean(axis=1),
    all=table3_df.mean(axis=1),
)

# Drop columns.
table3_df = table3_df[['imagenet', 'variants', 'fine', 'all']]

table3_df

In [None]:
df_to_latex(table3_df[:5])
df_to_latex(table3_df[5:])

In [None]:
#@title Table 4

table4_df = results_df

# Remove most of the ablation rows.
table4_df = table4_df[
    table4_df.num_pretrain.isin([20_000]) &
    table4_df.frac_test.isin([1.]) &
    table4_df.img_mean.isin([True]) &
    table4_df.debias_mode.isin(['both']) &
    table4_df.prompt_set.isin(['pool']) &
    (
        (table4_df.dataset_name.isin(ds_inet.keys()) & (table4_df.select_threshold.isin(['NA', .5]))) |
        (table4_df.dataset_name.isin(ds_fine.keys()) & (table4_df.select_threshold.isin(['NA', 2.])))
    )
]

# Construct table rows: each entry pairs a row label with the mask selecting its results.
table4_rows = [
    ('scores (weighted average)',
     (table4_df.weighting == 'scores') & (table4_df.select_threshold == 'NA')),
    ('scores^10 (weighted average)',
     (table4_df.weighting == 'scores^10') & (table4_df.select_threshold == 'NA')),
    ('softmax_scores (weighted average)',
     (table4_df.weighting == 'softmax_scores') & (table4_df.select_threshold == 'NA')),
    ('scores (prompt selection)',
     (table4_df.weighting == 'scores') & (table4_df.select_threshold != 'NA')),
    ('scores^10 (prompt selection)',
     (table4_df.weighting == 'scores^10') & (table4_df.select_threshold != 'NA')),
    ('softmax_scores (prompt selection)',
     (table4_df.weighting == 'softmax_scores') & (table4_df.select_threshold != 'NA')),
]
table4_df = pd.concat([table4_df[mask].assign(Name=name) for name, mask in table4_rows])

# Drop columns.
table4_df = table4_df[['Name', 'dataset_name', 'top1_acc']]

# Pivot the table (sort=False keeps the row order defined above).
table4_df = table4_df.pivot_table(index='Name', columns='dataset_name', values='top1_acc', sort=False)

# Drop extra levels.
table4_df.columns.name = None
table4_df.index.name = None

# Create the new columns. All three means are evaluated on the per-dataset columns
# before any aggregate is added, so 'all' averages datasets only and does not also
# average in the 'variants' and 'fine' aggregates.
table4_df = table4_df.assign(
    variants=table4_df[list(ds_inet.keys() - {'imagenet'})].mean(axis=1),
    fine=table4_df[list(ds_fine.keys())].mean(axis=1),
    all=table4_df.mean(axis=1),
)

# Drop columns.
table4_df = table4_df[['imagenet', 'variants', 'fine', 'all']]

table4_df

In [None]:
df_to_latex(table4_df[:3], apply_formatting=True)
df_to_latex(table4_df[3:], apply_formatting=True)

In [None]:
#@title Table 5

table5_df = results_df

# Remove most of the ablation rows.
table5_df = table5_df[
    table5_df.num_pretrain.isin(['NA', 20_000]) &
    table5_df.frac_test.isin(['NA', 1.]) &
    table5_df.img_mean.isin(['NA', True]) &
    table5_df.debias_mode.isin(['NA', 'both']) &
    (
        (table5_df.dataset_name.isin(ds_inet.keys()) & (table5_df.select_threshold.isin(['NA', .5]))) |
        (table5_df.dataset_name.isin(ds_fine.keys()) & (table5_df.select_threshold.isin(['NA', 2.])))
    )
]

# Construct table rows: each entry pairs a row label with the mask selecting its results.
table5_rows = [
    ('hand-crafted, equal average',
     (table5_df.prompt_set == 'dataset') & (table5_df.weighting == 'equal')),
    ('pool set, equal average',
     (table5_df.prompt_set == 'pool') & (table5_df.weighting == 'equal') & (table5_df.select_threshold == 'NA')),
    ('ZPE (weighted average)',
     (table5_df.prompt_set == 'pool') & (table5_df.weighting == 'softmax_scores') & (table5_df.select_threshold == 'NA')),
    ('ZPE (prompt selection, ours)',
     (table5_df.prompt_set == 'pool') & (table5_df.weighting == 'softmax_scores') & (table5_df.select_threshold != 'NA')),
]
table5_df = pd.concat([table5_df[mask].assign(Name=name) for name, mask in table5_rows])

# Drop columns.
table5_df = table5_df[['Name', 'dataset_name', 'top1_acc']]

# Pivot the table (sort=False keeps the row order defined above).
table5_df = table5_df.pivot_table(index='Name', columns='dataset_name', values='top1_acc', sort=False)

# Drop extra levels.
table5_df.columns.name = None
table5_df.index.name = None

# Create the new columns. All three means are evaluated on the per-dataset columns
# before any aggregate is added, so 'all' averages datasets only and does not also
# average in the 'variants' and 'fine' aggregates.
table5_df = table5_df.assign(
    variants=table5_df[list(ds_inet.keys() - {'imagenet'})].mean(axis=1),
    fine=table5_df[list(ds_fine.keys())].mean(axis=1),
    all=table5_df.mean(axis=1),
)

# Drop columns.
table5_df = table5_df[['imagenet', 'variants', 'fine', 'all']]

table5_df

In [None]:
df_to_latex(table5_df)

In [None]:
#@title Table 6

table6_df = results_df

# Remove most of the ablation rows.
table6_df = table6_df[
    table6_df.num_pretrain.isin(['NA', 20_000]) &
    table6_df.frac_test.isin(['NA', 1.]) &
    table6_df.img_mean.isin(['NA', True]) &
    table6_df.debias_mode.isin(['NA', 'both']) &
    (
        (table6_df.dataset_name.isin(ds_inet.keys()) & (table6_df.select_threshold.isin(['NA', .5]))) |
        (table6_df.dataset_name.isin(ds_fine.keys()) & (table6_df.select_threshold.isin(['NA', 2.])))
    )
]

# Construct table rows: each entry pairs a row label with the mask selecting its results.
table6_rows = [
    ('hand-crafted, equal average',
     (table6_df.prompt_set == 'dataset') & (table6_df.weighting == 'equal')),
    ('hand-crafted, ZPE weights',
     (table6_df.prompt_set == 'dataset') & (table6_df.weighting == 'softmax_scores')),
    ('ZPE (weighted average, 80 prompts)',
     (table6_df.prompt_set == 'inet') & (table6_df.weighting == 'softmax_scores') & (table6_df.select_threshold == 'NA')),
    ('ZPE (weighted average, 247 prompts)',
     (table6_df.prompt_set == 'pool') & (table6_df.weighting == 'softmax_scores') & (table6_df.select_threshold == 'NA')),
    ('ZPE (prompt selection, 80 prompts)',
     (table6_df.prompt_set == 'inet') & (table6_df.weighting == 'softmax_scores') & (table6_df.select_threshold != 'NA')),
    ('ZPE (prompt selection, 247 prompts)',
     (table6_df.prompt_set == 'pool') & (table6_df.weighting == 'softmax_scores') & (table6_df.select_threshold != 'NA')),
]
table6_df = pd.concat([table6_df[mask].assign(Name=name) for name, mask in table6_rows])

# Drop columns.
table6_df = table6_df[['Name', 'dataset_name', 'top1_acc']]

# Pivot the table (sort=False keeps the row order defined above).
table6_df = table6_df.pivot_table(index='Name', columns='dataset_name', values='top1_acc', sort=False)

# Drop extra levels.
table6_df.columns.name = None
table6_df.index.name = None

# Create the new columns. All three means are evaluated on the per-dataset columns
# before any aggregate is added, so 'all' averages datasets only and does not also
# average in the 'variants' and 'fine' aggregates.
table6_df = table6_df.assign(
    variants=table6_df[list(ds_inet.keys() - {'imagenet'})].mean(axis=1),
    fine=table6_df[list(ds_fine.keys())].mean(axis=1),
    all=table6_df.mean(axis=1),
)

# Drop columns.
table6_df = table6_df[['imagenet', 'variants', 'fine', 'all']]

table6_df

In [None]:
df_to_latex(table6_df[:2], apply_formatting=False)
df_to_latex(table6_df[2:5], apply_formatting=True)
df_to_latex(table6_df[5:], apply_formatting=True)

In [None]:
#@title Table 7

table7_df = results_df

# Remove most of the ablation rows.
table7_df = table7_df[
    table7_df.num_pretrain.isin([20_000, 10_000, 5_000]) &
    table7_df.frac_test.isin([1.]) &
    table7_df.img_mean.isin([True]) &
    table7_df.weighting.isin(['softmax_scores']) &
    table7_df.debias_mode.isin(['both']) &
    table7_df.prompt_set.isin(['pool']) &
    (
        (table7_df.dataset_name.isin(ds_inet.keys()) & (table7_df.select_threshold.isin(['NA', .5]))) |
        (table7_df.dataset_name.isin(ds_fine.keys()) & (table7_df.select_threshold.isin(['NA', 2.])))
    )
]

# Construct table rows: each entry pairs a row label with the mask selecting its results.
table7_rows = [
    ('5k (weighted average)',
     (table7_df.num_pretrain == 5_000) & (table7_df.select_threshold == 'NA')),
    ('10k (weighted average)',
     (table7_df.num_pretrain == 10_000) & (table7_df.select_threshold == 'NA')),
    ('20k (weighted average)',
     (table7_df.num_pretrain == 20_000) & (table7_df.select_threshold == 'NA')),
    ('5k (prompt selection)',
     (table7_df.num_pretrain == 5_000) & (table7_df.select_threshold != 'NA')),
    ('10k (prompt selection)',
     (table7_df.num_pretrain == 10_000) & (table7_df.select_threshold != 'NA')),
    ('20k (prompt selection)',
     (table7_df.num_pretrain == 20_000) & (table7_df.select_threshold != 'NA')),
]
table7_df = pd.concat([table7_df[mask].assign(Name=name) for name, mask in table7_rows])

# Drop columns.
table7_df = table7_df[['Name', 'dataset_name', 'top1_acc']]

# Pivot the table (sort=False keeps the row order defined above).
table7_df = table7_df.pivot_table(index='Name', columns='dataset_name', values='top1_acc', sort=False)

# Drop extra levels.
table7_df.columns.name = None
table7_df.index.name = None

# Create the new columns. All three means are evaluated on the per-dataset columns
# before any aggregate is added, so 'all' averages datasets only and does not also
# average in the 'variants' and 'fine' aggregates.
table7_df = table7_df.assign(
    variants=table7_df[list(ds_inet.keys() - {'imagenet'})].mean(axis=1),
    fine=table7_df[list(ds_fine.keys())].mean(axis=1),
    all=table7_df.mean(axis=1),
)

# Drop columns.
table7_df = table7_df[['imagenet', 'variants', 'fine', 'all']]

table7_df

In [None]:
df_to_latex(table7_df, apply_formatting=False)

In [None]:
#@title Table 8

table8_df = results_df

# Remove most of the ablation rows.
table8_df = table8_df[
    table8_df.num_pretrain.isin([20_000]) &
    table8_df.frac_test.isin([1., .5, .2, .1]) &
    table8_df.img_mean.isin([True]) &
    table8_df.weighting.isin(['softmax_scores']) &
    table8_df.debias_mode.isin(['both']) &
    table8_df.prompt_set.isin(['pool']) &
    (
        (table8_df.dataset_name.isin(ds_inet.keys()) & (table8_df.select_threshold.isin(['NA', .5]))) |
        (table8_df.dataset_name.isin(ds_fine.keys()) & (table8_df.select_threshold.isin(['NA', 2.])))
    )
]

# Construct table rows: each entry pairs a row label with the mask selecting its results.
table8_rows = [
    ('10% (weighted average)',
     (table8_df.frac_test == .1) & (table8_df.select_threshold == 'NA')),
    ('20% (weighted average)',
     (table8_df.frac_test == .2) & (table8_df.select_threshold == 'NA')),
    ('50% (weighted average)',
     (table8_df.frac_test == .5) & (table8_df.select_threshold == 'NA')),
    ('100% (weighted average)',
     (table8_df.frac_test == 1.) & (table8_df.select_threshold == 'NA')),
    ('10% (prompt selection)',
     (table8_df.frac_test == .1) & (table8_df.select_threshold != 'NA')),
    ('20% (prompt selection)',
     (table8_df.frac_test == .2) & (table8_df.select_threshold != 'NA')),
    ('50% (prompt selection)',
     (table8_df.frac_test == .5) & (table8_df.select_threshold != 'NA')),
    ('100% (prompt selection)',
     (table8_df.frac_test == 1.) & (table8_df.select_threshold != 'NA')),
]
table8_df = pd.concat([table8_df[mask].assign(Name=name) for name, mask in table8_rows])

# Drop columns.
table8_df = table8_df[['Name', 'dataset_name', 'top1_acc']]

# Pivot the table (sort=False keeps the row order defined above).
table8_df = table8_df.pivot_table(index='Name', columns='dataset_name', values='top1_acc', sort=False)

# Drop extra levels.
table8_df.columns.name = None
table8_df.index.name = None

# Create the new columns. All three means are evaluated on the per-dataset columns
# before any aggregate is added, so 'all' averages datasets only and does not also
# average in the 'variants' and 'fine' aggregates.
table8_df = table8_df.assign(
    variants=table8_df[list(ds_inet.keys() - {'imagenet'})].mean(axis=1),
    fine=table8_df[list(ds_fine.keys())].mean(axis=1),
    all=table8_df.mean(axis=1),
)

# Drop columns.
table8_df = table8_df[['imagenet', 'variants', 'fine', 'all']]

table8_df

In [None]:
df_to_latex(table8_df, apply_formatting=False)

In [None]:
#@title Table 9

# Remove most of the ablation rows: for each column listed here, keep only the
# rows whose value is one of the given options.
table9_allowed_values = {
    'num_pretrain': [20_000, 'NA'],
    'frac_test': [1., 'NA'],
    'debias_mode': ['both', 'none', 'NA'],
    'prompt_set': ['dataset', 'pool'],
    'weighting': ['softmax_scores', 'scores', 'equal'],
    'select_threshold': ['NA'],
}
table9_df = results_df
for column, allowed in table9_allowed_values.items():
    table9_df = table9_df[table9_df[column].isin(allowed)]

# One entry per table row: the row label and the column values that select it.
# The order matters: the pivot below uses `sort=False`, and the LaTeX cell
# slices the rows by position.
table9_row_specs = [
    ('hand crafted, equal average',
     {'prompt_set': 'dataset', 'weighting': 'equal'}),
    ('hand crafted, ZPE weights, per-dataset',
     {'prompt_set': 'dataset', 'weighting': 'softmax_scores', 'debias_mode': 'both', 'img_mean': True}),
    ('hand crafted, ZPE weights, per-example',
     {'prompt_set': 'dataset', 'weighting': 'softmax_scores', 'debias_mode': 'both', 'img_mean': False}),

    ('pool set, equal average',
     {'prompt_set': 'pool', 'weighting': 'equal'}),
    ('pool set, ZPE weights, per-dataset',
     {'prompt_set': 'pool', 'weighting': 'softmax_scores', 'debias_mode': 'both', 'img_mean': True}),
    ('pool set, ZPE weights, per-example',
     {'prompt_set': 'pool', 'weighting': 'softmax_scores', 'debias_mode': 'both', 'img_mean': False}),

    ('pool set, ZPE weights, per-dataset, no softmax',
     {'prompt_set': 'pool', 'weighting': 'scores', 'debias_mode': 'both', 'img_mean': True}),
    ('pool set, ZPE weights, per-example, no softmax',
     {'prompt_set': 'pool', 'weighting': 'scores', 'debias_mode': 'both', 'img_mean': False}),

    ('pool set, ZPE weights, per-dataset, no norm',
     {'prompt_set': 'pool', 'weighting': 'softmax_scores', 'debias_mode': 'none', 'img_mean': True}),
    ('pool set, ZPE weights, per-example, no norm',
     {'prompt_set': 'pool', 'weighting': 'softmax_scores', 'debias_mode': 'none', 'img_mean': False}),
]


def _table9_select(df, criteria):
    """Return the rows of `df` whose columns equal every value in `criteria`."""
    for column, wanted in criteria.items():
        df = df[df[column] == wanted]
    return df


# Construct table rows.
table9_df = pd.concat([
    _table9_select(table9_df, criteria).assign(Name=row_name)
    for row_name, criteria in table9_row_specs
])

# Keep only the columns the pivot needs, then reshape so that each table row
# ('Name') has one accuracy column per dataset.
table9_df = table9_df.loc[:, ['Name', 'dataset_name', 'top1_acc']].pivot_table(
    values='top1_acc', index='Name', columns='dataset_name', sort=False
)

# Remove the axis labels left behind by the pivot.
table9_df.index.name = None
table9_df.columns.name = None

# Summary columns: mean over the ImageNet variants (everything in `ds_inet`
# except 'imagenet' itself) and mean over the datasets in `ds_fine`.
variant_columns = list(ds_inet.keys() - {'imagenet'})
fine_columns = list(ds_fine)
table9_df['variants'] = table9_df[variant_columns].mean(axis=1)
table9_df['fine'] = table9_df[fine_columns].mean(axis=1)
# NOTE(review): this mean runs over every column currently in the frame, which
# includes the 'variants' and 'fine' columns added just above -- confirm that
# this is the intended definition of 'all'.
table9_df['all'] = table9_df.mean(axis=1)

# Keep only the columns reported in the table, in display order.
table9_df = table9_df[['imagenet', 'variants', 'fine', 'all']]

table9_df

In [None]:
# Pass Table 9 to `df_to_latex` in four positional row groups, each with
# `apply_formatting` on. Given the row order built in the Table 9 cell these
# are: hand-crafted prompts (rows 0-2), pool set (rows 3-5), pool set
# "no softmax" (rows 6-7), and pool set "no norm" (rows 8 onwards).
# NOTE(review): presumably formatting is applied per call, which would be why
# the groups are rendered separately -- confirm against `df_to_latex`.
df_to_latex(table9_df[:3], apply_formatting=True)
df_to_latex(table9_df[3:6], apply_formatting=True)
df_to_latex(table9_df[6:8], apply_formatting=True)
df_to_latex(table9_df[8:], apply_formatting=True)

## Collect per-prompt per-dataset scores (Appendix C)

In [None]:
# Flags controlling the on-disk copy of the per-prompt scores:
#   OVERWRITE_RESULTS: replace the saved file with an empty results table.
#   LOAD_RESULTS: replace the in-memory table with the one saved on disk.
OVERWRITE_RESULTS = False
LOAD_RESULTS = False

per_dataset_per_prompt_scores_df_name = 'per_dataset_per_prompt_scores.pkl'
per_dataset_per_prompt_scores_df_path = os.path.join(base_path, per_dataset_per_prompt_scores_df_name)

# Start from an empty table with one row per (dataset, prompt) pair to come.
per_dataset_per_prompt_scores_df = pd.DataFrame(columns=[
    'dataset_name', 'prompt', 'score'
])

if OVERWRITE_RESULTS:
    # The pickle payload is bytes, so open the file in binary mode; this also
    # matches the 'rb' mode used for reading below.
    with tf.io.gfile.GFile(per_dataset_per_prompt_scores_df_path, 'wb') as f:
        f.write(pickle.dumps(per_dataset_per_prompt_scores_df, protocol=4))

if LOAD_RESULTS:
    # NOTE: unpickling can execute arbitrary code -- only load result files
    # that were written by this notebook.
    with tf.io.gfile.GFile(per_dataset_per_prompt_scores_df_path, 'rb') as f:
        per_dataset_per_prompt_scores_df = pickle.load(f)

In [None]:
# Positions of the pool-set templates within `all_templates`. These do not
# depend on the dataset, so they are computed once rather than per iteration.
pool_idxs = np.array([all_templates.index(p) for p in pool_templates])

for dataset_name, batch_size in ds_list.items():
    print("dataset_name", dataset_name)

    # Get prompt embeddings for this dataset's classes, restricted to the
    # pool-set prompts.
    classnames = multimodal_utils._ZEROSHOT_CLASS_NAMES[config.zeroshot_eval_datasets[dataset_name]['classnames_key']]
    classname_idxs = np.array([all_classnames.index(classname) for classname in classnames])
    del classnames
    # NOTE: `ztxts_all_prompts_all_class` must stay alive here. It is indexed
    # again on every iteration, so deleting it inside the loop raises a
    # NameError as soon as the second dataset is processed.
    ztxts_all_prompts = ztxts_all_prompts_all_class[:, classname_idxs, :]
    del classname_idxs
    ztxts_pool = ztxts_all_prompts[pool_idxs, :, :]
    del ztxts_all_prompts

    # Get image embeddings.
    zs_split = load_zeroshot_dataset(config, rng, dataset_name, zs_batch_size=batch_size)
    ds_iter = input_utils.start_input_pipeline(zs_split, config.get('prefetch_to_device', 1))
    zimgs, _ = compute_image_embeddings(ds_iter, image_resolution)

    # Get logits.
    pool_logits = get_logits(ztxts_pool, zimgs)
    random_logits = get_logits(ztxts_pool, zimgs_laion)  # [n_pretrain, n_prompts, n_classes_ds]
    # Free the large intermediates as soon as they are no longer needed.
    del zimgs
    del ztxts_pool

    pool_weights = get_weights(pool_logits, random_logits, debias_mode='both', img_mean=True, frac_test=1.)
    del pool_logits
    del random_logits
    # Softmax over the prompt axis, then keep one weight per prompt. Converting
    # to NumPy stores plain NumPy values in the results table.
    pool_weights = np.asarray(jax.nn.softmax(pool_weights, axis=1)[0, :, 0])

    # Order prompts from the largest to the smallest absolute weight.
    descending = np.argsort(np.abs(pool_weights))[::-1]
    weights_list = pool_weights[descending]
    prompts_list = np.array(pool_templates)[descending]

    triplets = zip(itertools.repeat(dataset_name), prompts_list, weights_list)
    per_dataset_per_prompt_scores_df = pd.concat([
        per_dataset_per_prompt_scores_df,
        pd.DataFrame(triplets, columns=per_dataset_per_prompt_scores_df.columns)
    ], ignore_index=True)

    # Save after every dataset so that partial progress survives a crash. The
    # pickle payload is bytes, hence binary mode.
    with tf.io.gfile.GFile(per_dataset_per_prompt_scores_df_path, 'wb') as f:
        f.write(pickle.dumps(per_dataset_per_prompt_scores_df, protocol=4))

In [None]:
# Display the collected (dataset_name, prompt, score) table.
per_dataset_per_prompt_scores_df

In [None]:
# Export the 'prompt' column of the 'caltech101' rows to "pool_set.csv",
# without the index column.
is_caltech101 = per_dataset_per_prompt_scores_df.dataset_name == 'caltech101'
per_dataset_per_prompt_scores_df.loc[is_caltech101, 'prompt'].to_csv("pool_set.csv", index=False)

In [None]:
def _print_latex_rows(rows):
    """Print one LaTeX table line per row: 1-based rank, emphasised prompt, score.

    Args:
        rows: DataFrame slice with 'prompt' and 'score' columns whose index
            holds each row's 0-based position within its dataset group.
    """
    for position, row in rows.iterrows():
        # Escape the braces in the prompt template so LaTeX prints them
        # literally. The backslashes are written as "\\" because "\{" and "\}"
        # are invalid escape sequences in a Python string literal.
        prompt = row["prompt"].replace("{", "\\{").replace("}", "\\}")
        # "{{" / "}}" produce the literal braces of \emph{...} in an f-string.
        print(f"{position + 1} & `\\emph{{{prompt}}}' & {row['score']:0.4f} \\\\")


for dataset_name, group in per_dataset_per_prompt_scores_df.groupby("dataset_name"):
    print(dataset_name)

    # Renumber rows from 0 within the group so the index gives each prompt's rank.
    ranked = group.reset_index(drop=True)

    # First ten and last ten prompts of the group, separated by a \vdots row.
    _print_latex_rows(ranked.head(10))
    print('\\multicolumn{3}{c}{\\vdots} \\\\')
    _print_latex_rows(ranked.tail(10))

    print("")

## Make Figure 3

In [None]:
dataset_name = 'imagenet'

# Index arrays: positions of this dataset's class names within
# `all_classnames`, and of the pool-set templates within `all_templates`.
classnames = multimodal_utils._ZEROSHOT_CLASS_NAMES[
    config.zeroshot_eval_datasets[dataset_name]['classnames_key']
]
classname_idxs = np.array([all_classnames.index(name) for name in classnames])
pool_idxs = np.array([all_templates.index(template) for template in pool_templates])
del classnames

# Prompt embeddings for this dataset's classes, restricted to the pool-set
# prompts. Large arrays are deleted as soon as they are no longer needed in
# order to free memory.
ztxts_all_prompts = ztxts_all_prompts_all_class[:, classname_idxs, :]
del ztxts_all_prompts_all_class, classname_idxs
ztxts_pool = ztxts_all_prompts[pool_idxs, :, :]
del ztxts_all_prompts, pool_idxs

# Image embeddings and labels for the evaluation split.
zs_split = load_zeroshot_dataset(config, rng, dataset_name, zs_batch_size=5000)
ds_iter = input_utils.start_input_pipeline(zs_split, config.get('prefetch_to_device', 1))
zimgs, labels = compute_image_embeddings(ds_iter, image_resolution)

# Logits of the pool-set prompts against the evaluation images and against the
# `zimgs_laion` image embeddings.
pool_logits = get_logits(ztxts_pool, zimgs)
random_logits = get_logits(ztxts_pool, zimgs_laion)  # [n_pretrain, n_prompts, n_classes_ds]
del zimgs, ztxts_pool

# Prompt weights; `pool_logits` is kept because the accuracy cells reuse it.
pool_weights = get_weights(pool_logits, random_logits, debias_mode='both', img_mean=True, frac_test=1.)
del random_logits

# One weight per prompt, with and without a softmax over the prompt axis.
pool_weights_softmax = jax.nn.softmax(pool_weights, axis=1)[0, :, 0]
pool_weights = pool_weights[0, :, 0]

In [None]:
# Grid of fractions 0.98, 0.96, ..., 0.0; `1 - pcts` is the fraction of prompts
# that is kept at each grid point.
pcts = 1 - np.arange(0.02, 1.01, 0.02)
n_prompts = pool_weights.shape[0]

# Number of prompts kept at each grid point (rounded down), de-duplicated,
# with the single-prompt case always included, in ascending order.
kept_counts = np.floor((1 - pcts) * n_prompts).astype(np.int32)
n_prompts_selected = np.sort(np.array(list(set(kept_counts) | {1})))

In [None]:
# Accuracy of the weighted ensemble (raw, non-softmax weights) when only the n
# prompts with the largest absolute weight are kept, for each n on the grid.

# Prompt indices ordered by decreasing |weight|. The ordering does not depend
# on n, so it is computed once instead of on every iteration.
ranked_idxs_no_softmax = np.argsort(1 / np.abs(pool_weights))

accs_no_softmax = []
for n in n_prompts_selected:
    # 0/1 mask over the prompt axis selecting the top-n prompts. It is sized
    # from the weights themselves so it always matches them.
    mask = np.zeros((1, pool_weights.shape[0], 1))
    mask[0, ranked_idxs_no_softmax[:n], 0] = 1
    masked_weights = mask * pool_weights[jnp.newaxis, :, jnp.newaxis]
    logits_pool_weighted = agg_logits(pool_logits, weights=masked_weights)
    acc, _, _, _ = compute_metrics(labels, logits_pool_weighted)
    accs_no_softmax.append(acc)

In [None]:
# Accuracy of the weighted ensemble (softmax weights) when only the n prompts
# with the largest weight are kept, for each n on the grid.

# Prompt indices ordered by decreasing |weight|. The ordering does not depend
# on n, so it is computed once instead of on every iteration.
ranked_idxs_softmax = np.argsort(1 / np.abs(pool_weights_softmax))

accs_softmax = []
for n in n_prompts_selected:
    # 0/1 mask over the prompt axis selecting the top-n prompts. It is sized
    # from the weights themselves so it always matches them.
    mask = np.zeros((1, pool_weights_softmax.shape[0], 1))
    mask[0, ranked_idxs_softmax[:n], 0] = 1
    masked_weights = mask * pool_weights_softmax[jnp.newaxis, :, jnp.newaxis]
    logits_pool_weighted = agg_logits(pool_logits, weights=masked_weights)
    acc, _, _, _ = compute_metrics(labels, logits_pool_weighted)
    accs_softmax.append(acc)

In [None]:
# Figure 3: two side-by-side panels (no softmax / softmax). Each panel shows
# accuracy against the number of prompts kept (left axis, with a star at the
# best accuracy) and the absolute prompt weights in decreasing order (right
# axis).
fig, (ax1, ax2) = plt.subplots(1, 2, dpi=400, figsize=(line_width, text_width/4.5), tight_layout=True, sharey=True)

line_style = dict(alpha=0.8, lw=0.8)

panels = [
    (ax1, accs_no_softmax, pool_weights),
    (ax2, accs_softmax, pool_weights_softmax),
]
score_axes = []
for acc_ax, accs, weights in panels:
    # Accuracy curve, with its maximum marked.
    acc_ax.plot(n_prompts_selected, accs, '-', color='C0', **line_style)
    best = np.argmax(accs)
    acc_ax.plot(n_prompts_selected[best], accs[best], marker='*', color='C0', lw=0.8, ms=3)

    # Absolute weights, largest first, on a secondary y-axis.
    score_ax = acc_ax.twinx()
    magnitudes = np.abs(weights)
    score_ax.plot(magnitudes[np.argsort(1/magnitudes)], color='C2', **line_style)
    score_axes.append(score_ax)
ax12, ax22 = score_axes

# Axis labels and titles; only the outer y-axes are labelled.
ax1.set_ylabel('acc')
ax22.set_ylabel('score')
fig.text(0.5, 0.0, 'prompt index', ha='center')
ax1.set_title('no softmax')
ax2.set_title('softmax')

# Proxy handles so that one legend covers both the accuracy and score curves.
legend_elements = [
    Line2D([0], [0], color=colour, label=label, **line_style)
    for colour, label in (('C0', 'acc'), ('C2', 'score'))
]
ax2.legend(handles=legend_elements, loc='center right')

for acc_ax in (ax1, ax2):
    acc_ax.grid(alpha=0.3)

plt.savefig("acc_score_curves.pdf", dpi=400, format="pdf", bbox_inches='tight', pad_inches=0.01)
plt.show()