# Matching to ImageNet

MetaShift is a flexible framework for generating a large number of well-annotated, controlled real-world distribution shifts, so it can be used to construct new datasets over specific classes and subpopulations. Because we have matched MetaShift labels to the ImageNet hierarchy, MetaShift can also be used directly to evaluate any ImageNet pre-trained vision model.

## Generate Wordnet ID for MetaShift and ImageNet

We use [WordNet](https://www.nltk.org/howto/wordnet.html), accessed through NLTK, to do the matching.

In [5]:
! pip install nltk
import nltk
nltk.download('wordnet')
nltk.download('omw-1.4')

Example WordNet queries: all synsets of "dog", then only its verb senses.

In [54]:
from nltk.corpus import wordnet as wn
print(wn.synsets('dog'))
print(wn.synsets('dog', pos=wn.VERB))

[Synset('dog.n.01'), Synset('frump.n.01'), Synset('dog.n.03'), Synset('cad.n.01'), Synset('frank.n.02'), Synset('pawl.n.01'), Synset('andiron.n.01'), Synset('chase.v.01')]
[Synset('chase.v.01')]


### Generate wordnet id for ImageNet

We use [imagenet](https://observablehq.com/@mbostock/imagenet-hierarchy) for the ImageNet-1k class hierarchy information.

For each MetaShift meta-data tag (class names, context subsets, and attributes), we search the ImageNet-1k hierarchy for a node with the same WordNet ID. A MetaShift tag may cover a broader domain than an ImageNet leaf node. For example, MetaShift has a single general "cat" class, whereas ImageNet places "domestic cat" and "wildcat" under "cat", and each of those has several breeds. In the matching procedure, every breed under the "cat" node is matched to the MetaShift "cat" class.
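
To make the "broader tag absorbs all leaves beneath it" rule concrete, here is a minimal sketch on a toy hierarchy. The tree, its ids, and the `name` field are made up for illustration (the real file is only known to carry `id` and `children` keys), and `map_leaves_to_tags` is a hypothetical helper, not part of MetaShift.

```python
# Toy hierarchy; every id and name below is invented for illustration.
toy_tree = {
    "id": "t0", "name": "animal",
    "children": [
        {"id": "t1", "name": "cat", "children": [
            {"id": "t2", "name": "domestic cat", "children": [
                {"id": "t3", "name": "tabby"},
                {"id": "t4", "name": "Persian cat"},
            ]},
            {"id": "t5", "name": "wildcat", "children": [
                {"id": "t6", "name": "lynx"},
            ]},
        ]},
        {"id": "t7", "name": "dog", "children": [
            {"id": "t8", "name": "beagle"},
        ]},
    ],
}


def map_leaves_to_tags(node, tags, current=None, out=None):
    """Map each leaf name to the nearest enclosing node whose name is in `tags`."""
    if out is None:
        out = {}
    if node["name"] in tags:
        current = node["name"]  # this tag now applies to everything below
    children = node.get("children")
    if not children:
        if current is not None:  # leaves outside any tagged subtree are skipped
            out[node["name"]] = current
    else:
        for child in children:
            map_leaves_to_tags(child, tags, current, out)
    return out


leaf_to_tag = map_leaves_to_tags(toy_tree, {"cat"})
print(leaf_to_tag)
```

With only "cat" as a tag, the three cat leaves map to "cat" and "beagle" is left unmapped. A mapping in this direction is what one would need to score a leaf-level ImageNet prediction against a MetaShift class.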

In [5]:
import json
imagenet_file = "imagenet1k_node_names.json"

In [22]:
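def collect_children(dt, ls=None):
    # Walk the hierarchy depth-first and collect the id of every node.
    if ls is None:  # a mutable default ([]) would be shared across calls
        ls = []
    for k in dt:
        if k == 'children':
            if isinstance(dt[k], list):
                for i in dt[k]:
                    collect_children(i, ls)
            else:
                collect_children(dt[k], ls)
        elif k == 'id':
            ls.append(dt[k])
    return ls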

def generate_imagenet_wn(filename):
    with open(filename) as f:  # close the file once it is parsed
        data = json.load(f)
    imagenet_id = collect_children(data)
    imagenet_id.pop(0) # remove the root node: ImageNet 2011 Fall Release
    imagenet_wn = []
    for i in imagenet_id:
        wn_id = wn.synset_from_pos_and_offset('n', int(i[1:]))
        imagenet_wn.append(wn_id)
    return imagenet_wn

In [23]:
imagenet_wn = generate_imagenet_wn(imagenet_file)
print('imagenet-1k node num:', len(imagenet_wn))

imagenet-1k node num: 2153
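
The call `wn.synset_from_pos_and_offset('n', int(i[1:]))` above relies on ImageNet-style ids being a part-of-speech letter followed by a zero-padded WordNet offset. A minimal, NLTK-free sketch of that convention follows; `split_wnid` and `join_wnid` are hypothetical helper names, and the eight-digit padding is inferred from ids such as `n02084071`.

```python
def split_wnid(wnid):
    """Split an ImageNet-style id such as 'n02084071' into (pos, offset)."""
    pos, digits = wnid[0], wnid[1:]
    if not digits.isdigit():
        raise ValueError(f"not a WordNet-style id: {wnid!r}")
    return pos, int(digits)


def join_wnid(pos, offset):
    """Inverse of split_wnid: zero-pad the offset back to eight digits."""
    return f"{pos}{offset:08d}"


pos, offset = split_wnid("n02084071")
print(pos, offset)             # n 2084071
print(join_wnid(pos, offset))  # n02084071
```

The `ValueError` branch also suggests why the root entry is popped before conversion: an id that is not a letter plus digits cannot be turned into an offset.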


### Generate wordnet id for MetaShift

In [42]:
def generate_metashift_wn(filename):
    with open(filename) as f:  # close the file once it is parsed
        metashift = json.load(f)
    metashift_wn = {}
    for k, v in metashift.items():
        k_ = k.replace(' ', '_') # wordnet does not support space
        wn_id = wn.synsets(k_, pos=wn.NOUN)
        metashift_wn[k] = wn_id
        for i in v:
            i_ = i.replace(' ', '_')
            wn_id = wn.synsets(i_, pos=wn.NOUN)
            metashift_wn[i] = wn_id
    return metashift_wn

In [43]:
metashift_file = "../meta_data/class_hierarchy.json"

metashift_wn = generate_metashift_wn(metashift_file)
print('metashift len:', len(metashift_wn))

metashift len: 1262


### Matching

Match the WordNet synsets of MetaShift and ImageNet to see how many MetaShift labels correspond to nodes in the ImageNet-1k hierarchy.

In [10]:
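def match(imagenet_wn, metashift_wn):
    # Map each MetaShift tag to its first synset that also appears in ImageNet.
    imagenet_set = set(imagenet_wn)  # constant-time membership tests
    match_list = {}
    for k, v in metashift_wn.items():
        if isinstance(v, list):
            for j in v:
                if j in imagenet_set:
                    match_list[k] = j
                    break
        elif v in imagenet_set:
            match_list[k] = v
    return match_list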

In [11]:
match_list = match(imagenet_wn, metashift_wn)
print('match len:', len(match_list))

match len: 427
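
The matching rule keeps, for each tag, the first candidate synset that also occurs in ImageNet, and drops tags with no such candidate. The sketch below replays that rule on plain strings standing in for `Synset` objects. The candidate lists and the target set are invented for illustration, and `first_match` is a hypothetical name.

```python
def first_match(candidates_by_tag, target_ids):
    """For each tag, keep the first candidate that is present in target_ids."""
    target = set(target_ids)
    matched = {}
    for tag, candidates in candidates_by_tag.items():
        for cand in candidates:
            if cand in target:
                matched[tag] = cand
                break  # later senses are ignored once one matches
    return matched


# Strings stand in for Synset objects; all values here are made up.
candidates_by_tag = {
    "dog": ["dog.n.01", "frump.n.01", "dog.n.03"],
    "sofa": ["sofa.n.01"],
    "happiness": ["happiness.n.01", "happiness.n.02"],
}
target_ids = ["dog.n.01", "dog.n.03", "sofa.n.01"]

matched = first_match(candidates_by_tag, target_ids)
print(matched)  # {'dog': 'dog.n.01', 'sofa': 'sofa.n.01'}
```

Because only the first hit is kept, the order in which candidate senses are listed decides which node a polysemous tag is tied to.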


## Generate Selected Class from matching labels

After matching the labels, we keep only the candidate subsets whose class and context are both matched to ImageNet-1k labels. The classes of these subsets are the ones we use in our dataset.

In [55]:
import pickle
full_subsets = '../meta_data/full-candidate-subsets.pkl'

# load the candidate subsets once; each key has the form "class(context)"
with open(full_subsets, 'rb') as f:
    classes_full = pickle.load(f)

selected = {}

for i in classes_full.keys():
    idx1 = i.find('(')
    idx2 = i.find(')')
    cls1 = i[:idx1]
    cls2 = i[idx1+1:idx2]
    if not (cls1 in match_list and cls2 in match_list):
        continue
    selected[i] = classes_full[i]
print('selected subsets len:', len(selected))

selected_class = []
for i in selected.keys():
    idx1 = i.find('(')
    cls = i[:idx1]
    selected_class.append(cls)
selected_class = list(set(selected_class))
print('selected classes len:', len(selected_class))

with open('selected-candidate-subsets.pkl', 'wb') as pkl:
    pickle.dump(selected, pkl)

selected subsets len: 5040
selected classes len: 261
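
The selection loop splits each subset key at its parentheses. A slightly more defensive version of that parsing is sketched below, assuming keys follow the `class(context)` pattern that the slicing above implies. `split_subset_key` is a hypothetical helper, and `"cat(sofa)"` is just an illustrative key.

```python
def split_subset_key(key):
    """Split 'class(context)' into (class, context); return None if malformed."""
    open_idx = key.find("(")
    close_idx = key.find(")", open_idx)
    # find() returns -1 when a parenthesis is missing
    if open_idx <= 0 or close_idx < open_idx:
        return None
    return key[:open_idx], key[open_idx + 1:close_idx]


print(split_subset_key("cat(sofa)"))  # ('cat', 'sofa')
print(split_subset_key("cat"))        # None
```

Returning `None` for a key without parentheses avoids the silent mis-slicing that `key[:-1]` would produce when `find` returns -1.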


## Check Coverage

We check how much of ImageNet-1k the selected classes cover. For each selected MetaShift class, we locate its tag in the ImageNet hierarchy: if it is a non-leaf node, we mark all leaf nodes beneath it; otherwise we mark the leaf node itself. Each leaf node is counted only once.

In [59]:
leaf_nodes = []

# count the leaf nodes under a node that have not been counted before
def count_children(dt, cnt=0):
    if isinstance(dt, list):
        for i in dt:
            cnt = count_children(i, cnt)
    elif isinstance(dt, dict):
        if 'children' not in dt:
            # leaf node: record it once, even if several matched nodes reach it
            if dt['id'] not in leaf_nodes:
                leaf_nodes.append(dt['id'])
                cnt += 1
        else:
            cnt = count_children(dt['children'], cnt)
    return cnt

def match_wn_id(synset, node_id):
    # True if the ImageNet node id denotes `synset` (or any synset in the list)
    wn_id = wn.synset_from_pos_and_offset('n', int(node_id[1:]))
    if isinstance(synset, list):
        return wn_id in synset
    return synset == wn_id

# find the matched node
def find_node(name, dt, cnt = 0):
    if isinstance(dt, list):
        for i in dt:
            cnt = find_node(name, i, cnt)
    elif isinstance(dt, dict):
        if match_wn_id(name, dt['id']):
            cnt = count_children(dt, cnt)
        elif 'children' in dt:
            cnt = find_node(name, dt['children'], cnt)
    return cnt


In [60]:
with open(imagenet_file) as f:
    imagenet_dict = json.load(f)['children']  # skip the root node
metashift_selected_wn = {}

# generate the wordnet id of selected classes
for k in selected_class:
    k_ = k.replace(' ', '_')
    wn_id = wn.synsets(k_, pos=wn.NOUN)
    metashift_selected_wn[k] = wn_id

# count the number of matched nodes
cnt = 0
for k, v in metashift_selected_wn.items():
    nodes = find_node(v, imagenet_dict)
    cnt += nodes
print('matched nodes num:', cnt)

matched nodes num: 867
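
The same coverage idea can be sketched with sets, which removes the need for a global list and makes it easy to report a fraction rather than a raw count. The toy tree and ids below are invented, and `leaf_ids` and `covered_leaves` are hypothetical helpers. The toy tree deliberately repeats one leaf under two parents to show why de-duplication matters.

```python
def leaf_ids(node):
    """Return the set of leaf ids in the subtree rooted at `node`."""
    children = node.get("children")
    if not children:
        return {node["id"]}
    found = set()
    for child in children:
        found |= leaf_ids(child)
    return found


def covered_leaves(node, matched_ids):
    """Leaves under any node whose id is in matched_ids (search stops at a match)."""
    if node["id"] in matched_ids:
        return leaf_ids(node)
    covered = set()
    for child in node.get("children", []):
        covered |= covered_leaves(child, matched_ids)
    return covered


# Toy tree: leaf "t4" sits under both "t1" and "t2"; "t6" is never matched.
toy_tree = {"id": "t0", "children": [
    {"id": "t1", "children": [{"id": "t3"}, {"id": "t4"}]},
    {"id": "t2", "children": [{"id": "t4"}, {"id": "t5"}]},
    {"id": "t6"},
]}

covered = covered_leaves(toy_tree, {"t1", "t2"})
coverage = len(covered) / len(leaf_ids(toy_tree))
print(len(covered), coverage)  # 3 0.75
```

Applied to the real hierarchy, dividing the matched leaf count by the total number of leaves would give the covered share of ImageNet-1k.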
