## 1. Reproducing CLIP zero-shot ImageNet classification performance

Replicating the latest OpenAI [Colab](https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb).


In [1]:
%pylab inline
import tensorflow_datasets as tfds
import jax
import jax.numpy as jnp
from tqdm import tqdm
import tensorflow as tf
import random
import os
import json
from scipy.special import softmax
from PIL import Image
import pandas as pd
from scenic.projects.baselines.clip import model as clip
from scenic.projects.baselines.clip import tokenizer as clip_tokenizer

%pylab is deprecated, use %matplotlib inline and import the required libraries.
Populating the interactive namespace from numpy and matplotlib


In [2]:
#@title ImageNet classNames
# https://github.com/openai/CLIP/blob/main/notebooks/Prompt_Engineering_for_ImageNet.ipynb
clip.IMAGENET_CLASSES = ['apple',
 'aquarium_fish',
 'baby',
 'bear',
 'beaver',
 'bed',
 'bee',
 'beetle',
 'bicycle',
 'bottle',
 'bowl',
 'boy',
 'bridge',
 'bus',
 'butterfly',
 'camel',
 'can',
 'castle',
 'caterpillar',
 'cattle',
 'chair',
 'chimpanzee',
 'clock',
 'cloud',
 'cockroach',
 'couch',
 'crab',
 'crocodile',
 'cup',
 'dinosaur',
 'dolphin',
 'elephant',
 'flatfish',
 'forest',
 'fox',
 'girl',
 'hamster',
 'house',
 'kangaroo',
 'keyboard',
 'lamp',
 'lawn_mower',
 'leopard',
 'lion',
 'lizard',
 'lobster',
 'man',
 'maple_tree',
 'motorcycle',
 'mountain',
 'mouse',
 'mushroom',
 'oak_tree',
 'orange',
 'orchid',
 'otter',
 'palm_tree',
 'pear',
 'pickup_truck',
 'pine_tree',
 'plain',
 'plate',
 'poppy',
 'porcupine',
 'possum',
 'rabbit',
 'raccoon',
 'ray',
 'road',
 'rocket',
 'rose',
 'sea',
 'seal',
 'shark',
 'shrew',
 'skunk',
 'skyscraper',
 'snail',
 'snake',
 'spider',
 'squirrel',
 'streetcar',
 'sunflower',
 'sweet_pepper',
 'table',
 'tank',
 'telephone',
 'television',
 'tiger',
 'tractor',
 'train',
 'trout',
 'tulip',
 'turtle',
 'wardrobe',
 'whale',
 'willow_tree',
 'wolf',
 'woman',
 'worm']

In [3]:
model_name = 'vit_b16' # we could change different backbone

model = clip.MODELS[model_name]()
vars = clip.load_model_vars(model_name)

encode_text = jax.jit(lambda texts: model.apply(vars, texts, method=model.encode_text))
encode_image = jax.jit(lambda x: model.apply(vars, x, method=model.encode_image))

tokenize_fn = clip_tokenizer.build_tokenizer()

100%|█████████████████████████████████████| 1.29M/1.29M [00:00<00:00, 4.75MiB/s]


In [4]:
def permute_words(text):
  words = text.split(' ')
  random.shuffle(words)
  return ' '.join(words)

def zeroshot_classifier(classnames, templates, permute=False):
  zeroshot_weights = []
  permute_fn = permute_words if permute else lambda x: x
  for classname in tqdm(classnames):
    texts = [permute_fn(template.format(classname)) for template in templates]
    class_embeddings = encode_text(tokenize_fn(texts))
    class_embedding = class_embeddings.mean(0)
    class_embedding /= jnp.linalg.norm(class_embedding)
    zeroshot_weights.append(class_embedding)
  return jnp.stack(zeroshot_weights, axis=1)


In [5]:
# Readout weights with prompt engineering
weights_prompteng = zeroshot_classifier(clip.IMAGENET_CLASSES, clip.PROMPTS)

# Readout weights with modified ImageNet class names only
weights_name = zeroshot_classifier(clip.IMAGENET_CLASSES, ['{}'])


100%|██████████| 100/100 [01:08<00:00,  1.46it/s]
100%|██████████| 100/100 [00:09<00:00, 10.27it/s]


In [6]:
def preprocess(batch, size=224):
  batch = tf.image.convert_image_dtype(batch, dtype=tf.float32)
  return central_crop(resize_small(batch, size), (size, size))

def central_crop(image, crop_size):
    '''
    image
    crop_size: (h, w)
    '''
    h, w = crop_size[0], crop_size[1]
    dy = (tf.shape(image)[0] - h) // 2
    dx = (tf.shape(image)[1] - w) // 2
    return tf.image.crop_to_bounding_box(image, dy, dx, h, w)

def resize_small(image, smaller_size, method="area", antialias=True):
    '''
    image
    smaller_size: an integer, that represents a new size of the smaller side of
      an input image.
    '''
    h, w = tf.shape(image)[0], tf.shape(image)[1]

    # Figure out the necessary h/w.
    ratio = (
        tf.cast(smaller_size, tf.float32) /
        tf.cast(tf.minimum(h, w), tf.float32))
    h = tf.cast(tf.round(tf.cast(h, tf.float32) * ratio), tf.int32)
    w = tf.cast(tf.round(tf.cast(w, tf.float32) * ratio), tf.int32)

    dtype = image.dtype
    image = tf.image.resize(image, (h, w), method, antialias)
    return tf.cast(image, dtype)


def normalize(img):
  return (img - clip.IMAGE_MEAN) / clip.IMAGE_STD

def unnormalize(x):
  return x * clip.IMAGE_STD + clip.IMAGE_MEAN

In [7]:
def load_dataset(dataset='cifar100', split='train', batch_size=1024):
  ds = tfds.load(dataset, split=split)
  def _preprocess(d):
    d['image'] = normalize(preprocess(d['image']))
    return d
  def _prepare(d):
    return jax.tree_map(lambda x: x._numpy(), d)
  batched_dataset = ds.map(_preprocess).batch(batch_size)
  batched_dataset = map(_prepare, batched_dataset)
  return batched_dataset

def load_dataset_info(dataset='cifar100', split='train', batch_size=1024):
  ds, info = tfds.load(dataset, split=split, with_info="true")
  def _preprocess(d):
    d['image'] = normalize(preprocess(d['image']))
    return d
  def _prepare(d):
    return jax.tree_map(lambda x: x._numpy(), d)
  batched_dataset = ds.map(_preprocess).batch(batch_size)
  batched_dataset = map(_prepare, batched_dataset)
  return batched_dataset, info

# def load_dataset_from(data_dir='YOUR/LOCAL/PATH/imagenet2012', dataset='cifar100', split='train', batch_size=1024):
#   # ds = tfds.load(dataset, split=split, data_dir=data_dir)
#   ds = tfds.builder_from_directory(data_dir)
#   ds = ds.as_dataset(split='validation')
#   def _preprocess(d):
#     d['image'] = normalize(preprocess(d['image']))
#     return d
#   def _prepare(d):
#     return jax.tree_map(lambda x: x._numpy(), d)
#   batched_dataset = ds.map(_preprocess).batch(batch_size)
#   batched_dataset = map(_prepare, batched_dataset)
#   return batched_dataset

In [8]:
def compute_image_embeddings(dset, norm_image=True):
  embeddings = []
  labels = []
  for batch in tqdm(dset):
    image_embedding = encode_image(batch['image'])
    if norm_image:
      image_embedding /= jnp.linalg.norm(image_embedding)
    embeddings.append(image_embedding)
    labels.append(batch['label'])
  return jnp.vstack(embeddings), jnp.hstack(labels)

def compute_accuracy(logits, labels):
  top_probs, top_labels = jax.lax.top_k(logits, 5)
  top1 = 100 * jnp.mean(top_labels[:, 0] == labels)
  top5 = 100 * jnp.sum(top_labels == labels[:, None]) / labels.shape[0]
  return top1, top5

In [9]:
dset = load_dataset('cifar100')
# You could also first download ImageNet and then process them with tensorflow_datasets and load them with function:
# dset = load_dataset_from(data_dir='YOUR/LOCAL/PATH/imagenet2012', dataset='imagenet2012', split='validation')
embeddings, labels = compute_image_embeddings(dset)

2023-12-30 12:08:32.426101: W tensorflow/core/common_runtime/gpu/gpu_device.cc:1960] Cannot dlopen some GPU libraries. Please make sure the missing libraries mentioned above are installed properly if you would like to use GPU. Follow the guide at https://www.tensorflow.org/install/gpu for how to download and setup the required libraries for your platform.
Skipping registering GPU devices...
49it [15:37, 19.14s/it]


In [10]:
# More accurate and consistent result
logits_prompteng = np.matmul(embeddings, weights_prompteng)
logits_name = np.matmul(embeddings, weights_name)
top1_prompt, top5_prompt = compute_accuracy(logits_prompteng, labels)
top1_name, top5_name = compute_accuracy(logits_name, labels)
print(f'Prompt Engineering: top1={top1_prompt:.2f}%, top5={top5_prompt:.2f}%')
print(f'Class Names: top1={top1_name:.2f}%, top5={top5_name:.2f}%')

Prompt Engineering: top1=38.55%, top5=67.42%
Class Names: top1=32.03%, top5=59.90%


## 2. Build hierarchy from WordNet

The wordnet hierarchy is based on Github repo: https://github.com/niharikajainn/imagenet-ancestors-descendants


### WordNet parse

In [11]:
root_path ='/data/lab/Hierarchy-CLIP/' # VM yunhao

In [12]:
words_map = {}
child_map = {}
parent_map = {}
gloss_map = {} # description
# BEGIN GOOGLE-INTERNAL
words_path = os.path.join(root_path, "imagenet-ancestors-descendants", "words.txt")
gloss_path = os.path.join(root_path, "imagenet-ancestors-descendants", "gloss.txt")
child_map_path = os.path.join(root_path, "imagenet-ancestors-descendants", "is_a_new.txt")
imagenet_label_to_wordnet_file = os.path.join(root_path, "imagenet-ancestors-descendants", "imagenet_label_to_wordnet_synset.txt")
# END GOOGLE-INTERNAL

blank = ' '
comma_blank = ', '

In [13]:
#obtain wordnet_id mappings for all words
with tf.io.gfile.GFile(words_path, mode='r') as f:
	for line in f:
		line_split = line.split() # use blank " " to split
		wnid = line_split[0] # e.g., 'n03200357'
		words = line_split[1:] # e.g., ['electric,', 'electric', 'automobile,', 'electric', 'car'],
		words_map[wnid] = words
f.close()

In [14]:
#obtain wordnet_id mappings for all word description
with tf.io.gfile.GFile(gloss_path, mode='r') as f:
	for line in f:
		line_split = line.split()
		wnid = line_split[0]
		gloss = blank.join(line_split[1:])
		gloss_map[wnid] = gloss
f.close()

In [15]:
#obtain wordnet_id mappings for all parents-children
with tf.io.gfile.GFile(child_map_path, mode='r') as f:
	for line in f:
		parent, child = line.split()
		parent_map[child] = parent
		if parent not in child_map:
			child_map[parent] = [child]
		else:
			child_map[parent].append(child)
f.close()

In [16]:
print(child_map)

{'n02329401': ['n02363005', 'n02346627', 'n02342885', 'n02330245'], 'n02066707': ['n02068974'], 'n14764061': ['n14765785', 'n14766189'], 'n06855207': ['n06855985'], 'n02062430': ['n02062744'], 'n04388743': ['n02732072'], 'n02552171': ['n02657368'], 'n01900488': ['n01900719'], 'n01482071': ['n01482330'], 'n02534559': ['n02537085'], 'n11669921': ['n12041446', 'n11900569', 'n04971313', 'n11978233'], 'n12425281': ['n12454159'], 'n13756125': ['n13765396', 'n13765531', 'n13765990', 'n13766733', 'n13768850', 'n13770529'], 'n07705931': ['n07739125', 'n07767847'], 'n07707451': ['n07734744'], 'n07747055': ['n07747607'], 'n12900462': ['n12901264'], 'n04437953': ['n03046257'], 'n03163973': ['n03085013'], 'n04263760': ['n03636248'], 'n03278248': ['n04401088'], 'n06276697': ['n06277280'], 'n02821943': ['n02818832'], 'n04161981': ['n03001627', 'n04256520'], 'n03405725': ['n04379243', 'n04550184'], 'n02206270': ['n02206856'], 'n02159955': ['n02164464'], 'n02274024': ['n02274259'], 'n02311060': ['n0230

In [17]:
print(parent_map)

{'n02363005': 'n02329401', 'n02068974': 'n02066707', 'n14765785': 'n14764061', 'n06855985': 'n06855207', 'n02062744': 'n02062430', 'n02732072': 'n04388743', 'n02657368': 'n02552171', 'n01900719': 'n01900488', 'n01482330': 'n01482071', 'n02537085': 'n02534559', 'n12041446': 'n11669921', 'n11900569': 'n11669921', 'n04971313': 'n11669921', 'n11978233': 'n11669921', 'n12454159': 'n12425281', 'n13765396': 'n13756125', 'n13765531': 'n13756125', 'n13765990': 'n13756125', 'n13766733': 'n13756125', 'n13768850': 'n13756125', 'n07739125': 'n07705931', 'n07734744': 'n07707451', 'n07747607': 'n07747055', 'n07767847': 'n07705931', 'n12901264': 'n12900462', 'n03046257': 'n04437953', 'n03085013': 'n03163973', 'n03636248': 'n04263760', 'n04401088': 'n03278248', 'n06277280': 'n06276697', 'n02818832': 'n02821943', 'n03001627': 'n04161981', 'n04256520': 'n04161981', 'n04379243': 'n03405725', 'n04550184': 'n03405725', 'n02206856': 'n02206270', 'n02164464': 'n02159955', 'n02274259': 'n02274024', 'n02309337'

In [18]:
# find details given wordnet_id
category = 'n02363005' #'n02084071' is dog
descendants = []
ancestors = []
# find Descendants and Ancestors
print(words_map[category])
print(gloss_map[category]+"\n")
#list all children
# print("Descendants:\n")
# print(category)
# if category in child_map:
#   search = [child for child in child_map[category]]
# while search: # go over all children (BFS)
#   node = search.pop()
#   print("\t"+ blank.join(words_map[node])+"\n")
#   descendants.append(blank.join(words_map[node])) # keep all descendant
#   if node in child_map: #has children
#     [search.append(child) for child in child_map[node]]

#list all parents
print("Ancestors:\n")
if category in parent_map:
  node = parent_map[category] # only one parent class
else:
  node = category
while node in parent_map: # one way go up
  print("\t"+ blank.join(words_map[node])+"\n")
  ancestors.append(blank.join(words_map[node])) # keep all ancestor
  node = parent_map[node]
print("finish")

['beaver']
large semiaquatic rodent with webbed hind feet and a broad flat tail; construct complex dams and underwater lodges

Ancestors:

finish


In [19]:
# get imagenet_id: wordnet_id mapping. e.g., '0': 'n01440764'
index_wdid = {}
with tf.io.gfile.GFile(imagenet_label_to_wordnet_file, mode='r') as f:
  for line in f:
    if "{'id'" in line:
      index_wdid[line.split(": {'")[0].split("{")[-1].split(" ")[-1]] = 'n' + line.split("-n")[0].split("'")[-1]      # {0: {'id': '01440764-n',

In [20]:
print(index_wdid)

{'0': 'n07739125', '1': 'n02732072', '2': 'n09827683', '3': 'n02131653', '4': 'n02363005', '5': 'n02818832', '6': 'n02206856', '7': 'n02164464', '8': 'n02834778', '9': 'n13765396', '10': 'n13765531', '11': 'n09870926', '12': 'n02898711', '13': 'n02924116', '14': 'n02274259', '15': 'n02437136', '16': 'n13765990', '17': 'n03878066', '18': 'n02309337', '19': 'n02402425', '20': 'n03001627', '21': 'n02481823', '22': 'n03046257', '23': 'n11439690', '24': 'n02233338', '25': 'n04256520', '26': 'n01976957', '27': 'n01697178', '28': 'n13766733', '29': 'n01699831', '30': 'n02581957', '31': 'n02503517', '32': 'n07790400', '33': 'n08438533', '34': 'n02118333', '35': 'n10129825', '36': 'n02342885', '37': 'n03544360', '38': 'n01877134', '39': 'n03085013', '40': 'n03636248', '41': 'n03649909', '42': 'n02128385', '43': 'n02129165', '44': 'n01674464', '45': 'n01982650', '46': 'n10287213', '47': 'n12752205', '48': 'n03790512', '49': 'n09359803', '50': 'n02330245', '51': 'n07734744', '52': 'n12268918', '5

build hierarchy

In [21]:
# find ancestor, descendants and children for each class, may takes minites
an_des_dict_json = {}
for index in range(len(clip.IMAGENET_CLASSES)): # for 1000 classes
  category = index_wdid[str(index)]
  an_des_dict_json[str(index)] = {}
  descendants = []
  ancestors = []
  children = []
  # find Descendants and Ancestors
  #list all children
#   if category in child_map:
#     search = [child for child in child_map[category]] # here is only the child
#     children = [blank.join(words_map[ele]) for ele in search]
#   while search: # go over all children BFS priority queue
#     node = search.pop()
#     descendants.append(blank.join(words_map[node])) # keep all descendant
#     if node in child_map: #has children
#       [search.append(child) for child in child_map[node]]

  #list all parents
  node = parent_map[category] if category in parent_map else category
  while node in parent_map: # one way go up
    ancestors.append(blank.join(words_map[node])) # keep all ancestor
    node = parent_map[node]

  # save
  an_des_dict_json[str(index)]["wdid"] = category
  an_des_dict_json[str(index)]["words_map"] = blank.join(words_map[category])
#   an_des_dict_json[str(index)]["clip_words_map"] = clip.IMAGENET_CLASSES[index]
  an_des_dict_json[str(index)]["ancestors"] = ancestors
  an_des_dict_json[str(index)]["descendants"] = descendants
  an_des_dict_json[str(index)]["children"] = children

### Utility functions for finding descendants and ancestors

In [22]:
# TODO(yunhaoge) Add examples for using the functions

def bottom_up_hierarchy(max_depth_ancestor=1, max_depth_descendant=0, use_children=False):
  """Obtain the class list by considering hierarchy
    here consisder only descendent because max_depth_ancestor=0
  """
  # Consider ancedescendant into clip zero-shot
  class_an_des_mapping = [] # start index of each imagenet class, used for following aggregation
  imagenet_classes_an_des = []
  class_an_des_mapping_list = [] # index details of each imagenet class
  i = 0
  for imagenet_idx, classname in enumerate(clip.IMAGENET_CLASSES):
    class_an_des_mapping.append(i)
    imagenet_classes_an_des.append(classname) # add original class name
    class_an_des_list = [] # relevant id for specific class
    i += 1 # add index first
    class_an_des_list.append(i)
    node_ancestors = an_des_dict_json[str(imagenet_idx)]["ancestors"]
    if use_children: # consider only children
      node_descendants = an_des_dict_json[str(imagenet_idx)]["children"]
    else:
      node_descendants = an_des_dict_json[str(imagenet_idx)]["descendants"]

    for an_idx, ancestor in enumerate(node_ancestors): # select ancestors
      if an_idx < max_depth_ancestor:
        if comma_blank in ancestor: # contains synonym, more than one, keep only one
          imagenet_classes_an_des.append(ancestor.split(comma_blank)[0])
          i += 1
          class_an_des_list.append(i)
        else:
          imagenet_classes_an_des.append(ancestor)
          i += 1
          class_an_des_list.append(i)
    for des_idx, descentant in enumerate(node_descendants): # select descendants
      if des_idx < max_depth_descendant:
        if comma_blank in descentant: # contains synonym, more than one, keep only one
          imagenet_classes_an_des.append(descentant.split(comma_blank)[0])
          i += 1
          class_an_des_list.append(i)
        else:
          imagenet_classes_an_des.append(descentant)
          i += 1
          class_an_des_list.append(i)
    class_an_des_mapping_list.append(class_an_des_list)
  return imagenet_classes_an_des, class_an_des_mapping, class_an_des_mapping_list

In [23]:
def top_down_hierarchy(interest_case, clip_templete, use_LCA = True, use_ancestor = True):
  """Obtain the word list by adding LCA."""
  # interest_case
  logits_ori_all = []
  for idx, interest_case_ele in enumerate(tqdm(interest_case)):
    # original word embedding
    word_list_ori = np.array(clip.IMAGENET_CLASSES)[np.array(interest_case_ele)].tolist()
    if use_LCA: # add LCA/A
      if use_ancestor: # use Ancestor
        word_list =  [word + blank + A_all[idx] for word in word_list_ori]
      else: # use LCA
        word_list =  [word + blank + LCA_all[idx] for word in word_list_ori]
    else:
      word_list =  [word for word in word_list_ori]
  return word_list

In [24]:
# fine the LCA for each image top 5
def find_LCA(top5_ancestor): # general
  """Obtain the Lowest comman ancestor of top 5 classes.

  Args:
      top5_ancestor: List of the ancestors for each candidate class
  Returns:
      LCA: Str, lowest comman ancestor
  """
  LCA = 'physical entity'
  # while (1) still have ancestor
  while min([len(ele) for ele in top5_ancestor]) > 0 :
    # find the highest for each class in topk
    current_roots = [topk_ancestor.pop() for topk_ancestor in top5_ancestor]
    current_roots_freq = Counter(current_roots)
    current_roots_freq = sorted(current_roots_freq.items(), key=lambda x: x[1], reverse=True) # become a list e.g.,[('ee', 2), ('ww', 1), ('cc', 1)]
    majority, majority_freq = current_roots_freq[0]
    if majority_freq == 5:
      LCA = majority if comma_blank not in majority else majority.split(comma_blank)[0]
  return LCA

## Uncertainty Estimation

In [25]:
# with children
# hierarchy ImageNet class work on same h for each class
max_depth_ancestor = 0 # here only consider descendent
max_depth_children = 10 # could change
imagenet_bottom_up_first_index_list = []
imagenet_bottom_up_all_classes = []
imagenet_bottom_up_mapping_index_list = [] # add list for analysis
i = 0
for imagenet_idx, classname in enumerate(clip.IMAGENET_CLASSES):
  imagenet_bottom_up_first_index_list.append(i) # index of each imagenet class
  imagenet_bottom_up_all_classes.append(classname) # add original class name
  class_an_des_list = [] # relevant id for specific class
  i += 1 # add index first
  class_an_des_list.append(i)
  node_ancestors = an_des_dict_json[str(imagenet_idx)]["ancestors"]
  node_children = an_des_dict_json[str(imagenet_idx)]["children"]

  for an_idx, ancestor in enumerate(node_ancestors): # select ancestors
    if an_idx < max_depth_ancestor:
      if ", " in ancestor: # contains synonym, more than one, keep only one
        imagenet_bottom_up_all_classes.append(ancestor.split(", ")[0])
        i += 1
        class_an_des_list.append(i)
      else:
        imagenet_bottom_up_all_classes.append(ancestor)
        i += 1
        class_an_des_list.append(i)
  for des_idx, descentant in enumerate(node_children): # select descendants
    if des_idx < max_depth_children:
      if ", " in descentant: # contains synonym, more than one, keep only one
        imagenet_bottom_up_all_classes.append(descentant.split(", ")[0])
        i += 1
        class_an_des_list.append(i)
      else:
        imagenet_bottom_up_all_classes.append(descentant)
        i += 1
        class_an_des_list.append(i)
  imagenet_bottom_up_mapping_index_list.append(class_an_des_list)

len(imagenet_bottom_up_mapping_index_list) = 1000, each element is the source node and its children

In [26]:
imagenet_classes_an_des, class_an_des_mapping, class_an_des_mapping_list = bottom_up_hierarchy(max_depth_descendant=0)
# len(imagenet_classes_an_des) = 1971. word
# len(class_an_des_mapping) = 1000. starting index of each class among 1971
# len(class_an_des_mapping_list) = 1000.

In [27]:
def compute_accuracy_details(logits, labels):
  top_probs, top_labels = jax.lax.top_k(logits, 5)
  top1 = 100 * jnp.mean(top_labels[:, 0] == labels)
  top5 = 100 * jnp.sum(top_labels == labels[:, None]) / labels.shape[0]
  return top1, top5, top_probs, top_labels

In [28]:
top1_prompt, top5_prompt, top5_prompt_probs, top5_prompt_labels = compute_accuracy_details(logits_prompteng, labels)
top1_name, top5_name, top5_name_probs, top5_name_labels = compute_accuracy_details(logits_name, labels)
print(f'Prompt Engineering: top1={top1_prompt:.2f}%, top5={top5_prompt:.2f}%')
print(f'Class Names: top1={top1_name:.2f}%, top5={top5_name:.2f}%')

Prompt Engineering: top1=38.55%, top5=67.42%
Class Names: top1=32.03%, top5=59.90%


In [29]:
# continuous uncertatinty value calculation
def continuous_uncertainty_estimation(prompts=clip.PROMPTS, name_prediction=top5_name_labels[:, 0]):
  """obtain continuous confidence score for each image"""

  # prompts decision
  prompts_slice_log = []
  for sample_id in range(len(prompts)): # for each different prompts
    prompteng_slice = zeroshot_classifier(clip.IMAGENET_CLASSES, [prompts[sample_id]]) #select one each time
    logits_slice = np.matmul(embeddings, prompteng_slice) # compute the logits
    probs_slice = np.array(jax.nn.softmax(logits_slice, axis=-1))
    preds_slice = jnp.argmax(probs_slice, axis=-1)
    prompts_slice_log.append(preds_slice)

  prompts_slice_log_array = np.array(prompts_slice_log) # (len(prompts), 50000)

  # compare with no-prompts and use consistency to get the confidence score
  consistency_with_nonpromp = prompts_slice_log_array == name_prediction
  consistency_with_nonpromp_log = sum(consistency_with_nonpromp, axis=0)
  confidence_sort_index = numpy.argsort(consistency_with_nonpromp_log) # form small to large

  return prompts_slice_log_array, consistency_with_nonpromp_log, confidence_sort_index

In [30]:
name_prediction=top5_name_labels[:, 0]
prompt_prediction = top5_prompt_labels[:, 0]

In [31]:
prompts_slice_log_array, consistency_with_nonpromp_log, confidence_sort_index = continuous_uncertainty_estimation()

# Calculate accuracy of the low confidence set
low_confident_set_size = 10000
low_confident_set_index = confidence_sort_index[:low_confident_set_size]
# Computed accuracy for the low confidence set
acc = np.mean(name_prediction[np.array(low_confident_set_index)] == labels[low_confident_set_index])
acc

100%|██████████| 100/100 [00:03<00:00, 25.36it/s]
100%|██████████| 100/100 [00:03<00:00, 29.74it/s]
100%|██████████| 100/100 [00:03<00:00, 29.62it/s]
100%|██████████| 100/100 [00:03<00:00, 30.38it/s]
100%|██████████| 100/100 [00:03<00:00, 31.18it/s]
100%|██████████| 100/100 [00:03<00:00, 30.13it/s]
100%|██████████| 100/100 [00:03<00:00, 30.46it/s]
100%|██████████| 100/100 [00:03<00:00, 31.58it/s]
100%|██████████| 100/100 [00:03<00:00, 28.70it/s]
100%|██████████| 100/100 [00:03<00:00, 29.18it/s]
100%|██████████| 100/100 [00:03<00:00, 27.72it/s]
100%|██████████| 100/100 [00:03<00:00, 29.78it/s]
100%|██████████| 100/100 [00:03<00:00, 27.47it/s]
100%|██████████| 100/100 [00:03<00:00, 30.67it/s]
100%|██████████| 100/100 [00:03<00:00, 30.07it/s]
100%|██████████| 100/100 [00:03<00:00, 27.01it/s]
100%|██████████| 100/100 [00:03<00:00, 25.46it/s]
100%|██████████| 100/100 [00:03<00:00, 28.76it/s]
100%|██████████| 100/100 [00:03<00:00, 29.44it/s]
100%|██████████| 100/100 [00:03<00:00, 30.24it/s]


Array(0.0473, dtype=float32)

## Hierarchy-CLIP

### Continuous Solution

In [32]:
an_des_dict = an_des_dict_json

In [33]:
def zeroshot_classifier_hierarchy(classnames, templates, permute=False):
  zeroshot_weights = []
  permute_fn = permute_words if permute else lambda x: x
  for classname in classnames:
    texts = [permute_fn(template.format(classname.split('|')[0]) + template.format(classname.split('|')[1])) for template in templates]
    class_embeddings = encode_text(tokenize_fn(texts))
    class_embedding = class_embeddings.mean(0)
    class_embedding /= jnp.linalg.norm(class_embedding)
    zeroshot_weights.append(class_embedding)
  return jnp.stack(zeroshot_weights, axis=1)

In [34]:
def compute_accuracy_rerank(logits, ori_labels, labels):
  """based on logits, reorder the ori_top5_labels and compare with true labels"""

  top_probs, top_labels_index = jax.lax.top_k(logits, 1)
  top_labels = np.array([ele[top_labels_index[idx]] for idx, ele in enumerate(tqdm(ori_labels))])
  top1 = 100 * jnp.mean(top_labels[:, 0] == labels)
  return top1, top_probs, top_labels

In [35]:
# Bottom-up and top-down augmented inference

from collections import Counter

an_des_dict = an_des_dict_json # simplified wordnet
fill_empty_parent = ''
unstable_num = 1000
# stable set and unstable set
unstable_index = confidence_sort_index[:unstable_num] # number of rejected/unstable images
stable_index = np.array(list(set(range(len(labels))) - set(unstable_index.tolist()))) # stable, np.array

# 1. accuracy of the stable samples:
# Computed accuracy
acc_stable = np.mean(name_prediction[stable_index] == labels[stable_index])

acc_unstable_ori = np.mean(name_prediction[unstable_index] == labels[unstable_index])


# 2. accuracy of the unstable samples: use our hierarchy
# analysis top 5 in all different cases of baseline
baseline_top5 = top5_name_labels[unstable_index]
baseline_top5_probs = top5_name_probs[unstable_index]
diff_image_embeddings = embeddings[unstable_index]
unstable_id_gt_labels = labels[unstable_index]

unstable_case = top5_name_labels[unstable_index]   # original top-5 prediction
unstable_case_label = labels[unstable_index] # top-5 GT labels

# [Top-5 Ancestor collection] calculate Ancestor from simple wordnet for unstable cases
A_all = []
for example in unstable_case: # for each unstable image
  # ancestor of top 5
  top5_ancestor = []
  # top5_ancestor_raw = []
  for ind, top_k in enumerate(example):
    if an_des_dict[str(example[ind])]['ancestors']==[]: # no ancestor
      top5_ancestor.append(fill_empty_parent)
    else:
      top5_ancestor.append(an_des_dict[str(example[ind])]['ancestors'].copy()[0]) # avoid influence the root
  A_all.append(top5_ancestor)


# add ancestor to all candidate classes
logits_ori_all = []
for idx, unstable_case_ele in tqdm(enumerate(unstable_case)): # for each unstable image
  # original word embedding
  word_list_ori = np.array(clip.IMAGENET_CLASSES)[np.array(unstable_case_ele)].tolist()
  word_list_hierarchy = [] # raw words with only [bottom up]
  word_list_hierarchy_with_prompt = [] # words add prompt and ancestor [bottom up + top down]
  wordid_list_hierarchy_list = [] # details of index [[1,2], [3,4,5], [6], [7,8], [9]]
  wordid_list_hierarchy_mapping_list = [] # used for reduceat (index of first element for each top-5) [1,3,6,7,9]
  i = 0 # tracking each class/child class
  for word_id, word in enumerate(word_list_ori): # for each class in top 5
    tempid_list = [] # for specific word in top5
    wordid_list_hierarchy_mapping_list.append(i)
    for child_id in class_an_des_mapping_list[unstable_case_ele[word_id]]: # for each children or element [Bottom-up]
      raw_word = imagenet_classes_an_des[child_id-1]
      word_list_hierarchy.append(raw_word) # child_id start from 1 not 0
      word_list_hierarchy_with_prompt.append(raw_word + ' which is a kind of|' + A_all[idx][word_id]) # [Top-down]
      # word_list_hierarchy_with_prompt.append(raw_word + ' which is a kind of' + A_all[idx][word_id]) # [Top-down]
      tempid_list.append(i)
      i+=1
    wordid_list_hierarchy_list.append(tempid_list)
  word_list = word_list_hierarchy_with_prompt # could change

  # inference
  word_weights_ori = zeroshot_classifier_hierarchy(word_list, ['{}']) # (512, 5) # do not use prompt ensemble
  # word_weights_ori = zeroshot_classifier_hierarchy(word_list, clip.PROMPTS) # (512, 5)
  img_embedding = diff_image_embeddings[idx] # (1, 512) single image
  logits_ori_raw = np.matmul(img_embedding, word_weights_ori)
  logits_ori_raw = logits_ori_raw[np.newaxis, :]
  logits_ori = np.maximum.reduceat(logits_ori_raw, indices=wordid_list_hierarchy_mapping_list, axis=1)
  logits_ori_all.append(logits_ori)
logits_ori_all_array = np.array(logits_ori_all)[:, 0, :]

# Manually computed accuracy
unstable_top1_acc, unstable_top_probs, unstable_top_labels = compute_accuracy_rerank(logits_ori_all_array, unstable_case, unstable_case_label)


# compute overall accuracy
overall_acc = (acc_stable*100*len(stable_index) + unstable_top1_acc*len(unstable_index))/len(labels)

print("acc_stable:", acc_stable)
print("acc_unstable_ori:", acc_unstable_ori)
print("acc_unstable_with_Hierarchy-CLIP:", unstable_top1_acc)
print("overall_acc_with_Hierarchy-CLIP:", overall_acc)

1000it [02:53,  5.78it/s]
100%|██████████| 1000/1000 [00:00<00:00, 1538.04it/s]

acc_stable: 0.32616326
acc_unstable_ori: 0.034
acc_unstable_with_Hierarchy-CLIP: 17.0
overall_acc_with_Hierarchy-CLIP: 32.304



