# 1. INIT - Import packages

In [1]:
import torch
import os
import sys
from pathlib import Path

# Put the parent of the current working directory on sys.path so that
# modules living one level above this notebook can be imported below.
file_dir = Path.cwd()
workspace_dir = str(file_dir.parent)
sys.path.append(workspace_dir)

import nltk
from nltk.corpus import wordnet as wn
from collections import Counter
from classes import IMAGENET2012_CLASSES

# Restrict this process to the first GPU; set here, before any CUDA call below.
os.environ['CUDA_VISIBLE_DEVICES'] = '0'

# Environment report.
print('Pytorch version :', torch.__version__)
print('CUDA version\t:', torch.version.cuda)
# get_device_name() raises when no CUDA device is usable; report that instead
# of aborting the cell, so the notebook can still be run on a CPU-only machine.
gpu_name = torch.cuda.get_device_name() if torch.cuda.is_available() else 'not available'
print('GPU\t\t:', gpu_name)

Pytorch version : 2.0.1
CUDA version	: 11.7
GPU		: NVIDIA A40


# 2. INIT - Downloading the WordNet corpus

In [2]:
# nltk.download() only writes into download_dir; it does not add that folder to
# the NLTK search path. Register it first so the WordNet lookups below can find
# the corpus even when it is not installed in one of NLTK's default locations.
if '../Deps/' not in nltk.data.path:
    nltk.data.path.append('../Deps/')
nltk.download('wordnet', download_dir='../Deps/')

[nltk_data] Downloading package wordnet to ../Deps/...


True

# 3. EXECUTIONS - wordnet analysis for hypercategories

In [3]:
# Class codes (the keys of IMAGENET2012_CLASSES), in the mapping's own order.
class_labels = list(IMAGENET2012_CLASSES.keys())

# imagenet1k_hypercategory_v2
# Hypercategory label -> name of the WordNet synset acting as its root.
# Insertion order is significant: the assignment loop further down takes the
# first entry whose root appears among a class's hypernyms, so the narrower
# roots (mammal, instrumentality) are listed before the broader ones.
top_categories = dict(
    mammal='mammal.n.01',
    others_animal='animal.n.01',
    instrumentality='instrumentality.n.03',
    others_artifact='artifact.n.01',
)

def get_ordered_hypernyms(synset):
    """Return the chain of ancestors of ``synset``, nearest ancestor first.

    Only a single path is followed: at every level the first element of
    ``hypernyms()`` is taken and any other parents are ignored. The walk
    stops at the first node whose ``hypernyms()`` is empty. ``synset``
    itself is not part of the returned list.
    """
    chain = []
    parents = synset.hypernyms()
    while parents:
        node = parents[0]
        chain.append(node)
        parents = node.hypernyms()
    return chain

# class code -> assigned hypercategory label
class_categories = {}
# hypernym names of every class, concatenated in class order (duplicates kept)
all_synset_names = []

for synset_id in class_labels:
    # Drop the leading character of the class code and parse the remainder as
    # the offset of a noun synset.
    synset = wn.synset_from_pos_and_offset('n', int(synset_id[1:]))
    ordered_hypernyms = get_ordered_hypernyms(synset)
    ordered_hypernym_names = [hypernym.name() for hypernym in ordered_hypernyms]

    # The first entry of top_categories (in insertion order) whose root synset
    # shows up in the hypernym chain wins; classes matching none of the roots
    # keep the fallback label.
    assigned_category = 'others_entity'
    for category, top_synset_name in top_categories.items():
        if top_synset_name in ordered_hypernym_names:
            assigned_category = category
            break

    class_categories[synset_id] = assigned_category
    all_synset_names.extend(ordered_hypernym_names)

# Tally how many classes ended up in each hypercategory and print one
# "label: count" line per category, ordered by label.
value_counts = Counter(class_categories.values())
sorted_value_counts = sorted(value_counts.items())
for value, count in sorted_value_counts:
    print(value, count, sep=": ")

# Persist the class-code -> hypercategory mapping. Create the output folder
# first so the save does not fail when the results directory does not exist yet.
os.makedirs("../Results/hypercategory", exist_ok=True)
torch.save(class_categories, "../Results/hypercategory/imagenet1k_hypercategory_v2.pt")

instrumentality: 350
mammal: 218
others_animal: 180
others_artifact: 172
others_entity: 80
