In [2]:
import open_clip
import clip
from torchvision.datasets import CIFAR10, CIFAR100, ImageNet
from open_clip import tokenizer
import torch
from PIL import Image
import matplotlib.pyplot as plt
import numpy as np
import os
from sentence_transformers import SentenceTransformer
import requests
import json
import math
import random
from tqdm import tqdm
import nltk
nltk.download("punkt")
from nltk.tokenize import WordPunctTokenizer
os.environ["CUDA_VISIBLE_DEVICES"] = "0"

In [3]:
# Datasets handled by this notebook; the per-dataset paths below are derived
# from these names.
dataset = ["cifar10", "cifar100", "cub", "imagenet"]

# Class-name file for each dataset.
LABEL_FILES = {name: "data/{}_classes.txt".format(name) for name in dataset}

# Output file for each dataset's filtered concept list.
save_name = ["data/conceptnet_{}_filtered_new.txt".format(name) for name in dataset]

# ConceptNet query settings: `limit` value sent with each request and the
# relation labels that are kept.
LIMIT = 200
RELATIONS = ["HasA", "IsA", "PartOf", "HasProperty", "MadeOf", "AtLocation"]

# Filtering thresholds.
CLASS_SIM_CUTOFF = 0.85  # concept-vs-class similarity at or above this is dropped
OTHER_SIM_CUTOFF = 0.9   # concept-vs-concept similarity at or above this is dropped
MAX_LEN = 30             # maximum concept length in characters
PRINT_PROB = 0.2         # probability of logging any individual deletion

In [8]:
def get_init_conceptnet(classes, limit=200, relations=["HasA", "IsA", "PartOf", "HasProperty", "MadeOf", "AtLocation"]):
    """Collect candidate concept strings for ``classes`` from the ConceptNet web API.

    Each class name is stripped of commas and split on whitespace; every
    distinct word is looked up once at ``api.conceptnet.io/c/en/<word>``.
    Edges whose start or end node is not tagged as English are ignored, as are
    edges whose relation label is not in ``relations``. For "IsA" edges only
    the end label is kept; for every other accepted relation both the start
    and the end label are kept.

    Args:
        classes: iterable of class-name strings.
        limit: value of the ``limit`` query parameter sent with each request.
        relations: relation labels to keep (only read, never modified).

    Returns:
        A set of concept label strings.

    A word whose request fails (network error, HTTP error status, or a body
    that is not valid JSON) is reported and skipped instead of aborting the
    whole crawl.
    """
    from urllib.parse import quote

    concepts = set()
    queried = set()  # words already looked up; a repeat lookup cannot add anything new

    for cls in tqdm(classes):
        # split() rather than split(' ') so blank class names or doubled spaces
        # never produce an empty word (which would query the bare /c/en/ path).
        for word in cls.replace(',', '').split():
            if word in queried:
                continue
            queried.add(word)

            url = 'http://api.conceptnet.io/c/en/{}?limit={}'.format(quote(word, safe=''), limit)
            try:
                response = requests.get(url, timeout=30)
                response.raise_for_status()
                obj = response.json()
            except (requests.RequestException, ValueError) as err:
                tqdm.write("Skipping '{}' (class '{}'): {}".format(word, cls, err))
                continue

            for edge in obj.get('edges', []):
                try:
                    rel = edge['rel']['label']
                    if edge['start']['language'] != 'en' or edge['end']['language'] != 'en':
                        continue
                    start_label = edge['start']['label']
                    end_label = edge['end']['label']
                except KeyError:
                    # Malformed edge, or a node without a language tag.
                    continue

                if rel not in relations:
                    continue
                if rel == "IsA":
                    concepts.add(end_label)
                else:
                    concepts.add(start_label)
                    concepts.add(end_label)
    return concepts

In [9]:
def _clip_dot_prods(list1, list2, device="cuda", clip_name="ViT-B/16", batch_size=500):
    """Dot products between L2-normalised CLIP text embeddings of two string lists.

    Loads the CLIP model ``clip_name`` on ``device``, tokenizes both lists,
    encodes them ``batch_size`` rows at a time, normalises every embedding to
    unit length and multiplies the two feature matrices.

    Returns: numpy array with dot products; entry [i, j] is the product for
    list1[i] and list2[j].
    """
    clip_model, _ = clip.load(clip_name, device=device)
    tokens_a = clip.tokenize(list1).to(device)
    tokens_b = clip.tokenize(list2).to(device)

    def _unit_text_features(tokens):
        # Encode in slices of batch_size rows, then scale each row to unit norm.
        with torch.no_grad():
            chunks = [
                clip_model.encode_text(tokens[start:start + batch_size])
                for start in range(0, len(tokens), batch_size)
            ]
            stacked = torch.cat(chunks, dim=0)
            stacked /= stacked.norm(dim=1, keepdim=True)
        return stacked

    feats_a = _unit_text_features(tokens_a)
    feats_b = _unit_text_features(tokens_b)

    similarity = feats_a @ feats_b.T
    return similarity.cpu().numpy()

In [10]:
def filter_too_similar_to_cls(concepts, classes, sim_cutoff, device="cuda", print_prob=0):
    """Drop concepts that name, or are nearly identical in meaning to, a class.

    Two passes run over the sorted concept list:

    1. Exact text matches: each class name as-is, preceded by an article
       ("a ", "A ", "an ", "An ", "the ", "The "), fully upper-cased, or with
       only its first letter upper-cased.
    2. Embedding similarity: a concept is dropped when its blended similarity
       to any class is >= ``sim_cutoff``. The blend is
       (mpnet dot product + 3 * CLIP text dot product) / 4.

    Args:
        concepts: iterable of concept strings (not modified).
        classes: iterable of class-name strings.
        sim_cutoff: blended-similarity threshold for pass 2.
        device: torch device forwarded to ``_clip_dot_prods``.
        print_prob: probability of printing any individual deletion.

    Returns:
        A new sorted list with the surviving concepts.

    The concept count is printed before pass 1, after pass 1 and after pass 2.
    """
    print(len(concepts))
    concepts = sorted(concepts)
    classes = list(classes)

    # first check simple text matches
    remaining = set(concepts)
    for cls in classes:
        for prefix in ["", "a ", "A ", "an ", "An ", "the ", "The "]:
            candidate = prefix + cls
            if candidate in remaining:
                remaining.remove(candidate)
                if random.random() < print_prob:
                    print("Class:{} - Deleting {}".format(cls, candidate))
        remaining.discard(cls.upper())
        # cls[:1] rather than cls[0]: an empty class name must not raise.
        remaining.discard(cls[:1].upper() + cls[1:])
    concepts = [concept for concept in concepts if concept in remaining]
    print(len(concepts))

    # Nothing to compare if either side is empty; skip loading the models.
    if concepts and classes:
        mpnet_model = SentenceTransformer('all-mpnet-base-v2')
        class_features_m = mpnet_model.encode(classes)
        concept_features_m = mpnet_model.encode(concepts)
        dot_prods_m = class_features_m @ concept_features_m.T
        dot_prods_c = _clip_dot_prods(classes, concepts, device=device)
        # weighted since mpnet has higher variance
        dot_prods = (dot_prods_m + 3*dot_prods_c)/4

        # Rows index classes and columns index concepts, so every (i, j) pair
        # at or above the cutoff counts; the two indices are unrelated and
        # must not be compared with each other.
        to_delete = set()
        for i, j in np.argwhere(dot_prods >= sim_cutoff):
            i, j = int(i), int(j)
            if j in to_delete:
                continue
            to_delete.add(j)
            if random.random() < print_prob:
                print("Class:{} - Concept:{}, sim:{:.3f} - Deleting {}".format(classes[i], concepts[j], dot_prods[i,j], concepts[j]))

        concepts = [concept for idx, concept in enumerate(concepts) if idx not in to_delete]

    print(len(concepts))
    return concepts

In [11]:
def filter_too_similar(concepts, sim_cutoff, device="cuda", print_prob=0):
    """Remove near-duplicate concepts, keeping one concept from each too-similar pair.

    Pairwise similarity is the blend
    (mpnet dot product + 3 * CLIP text dot product) / 4. Pairs are visited in
    row-major (i, j) order; when a pair is at or above ``sim_cutoff`` and
    neither member has been dropped yet, the member whose similarity row has
    the smaller sum is dropped.

    Args:
        concepts: list of concept strings. Modified in place.
        sim_cutoff: blended-similarity threshold.
        device: torch device forwarded to ``_clip_dot_prods``.
        print_prob: probability of printing any individual deletion.

    Returns:
        The same ``concepts`` list, with the dropped entries removed.

    The final concept count is printed.
    """
    # With fewer than two concepts there is no pair to compare, so the models
    # are never loaded.
    if len(concepts) < 2:
        print(len(concepts))
        return concepts

    mpnet_model = SentenceTransformer('all-mpnet-base-v2')
    concept_features = mpnet_model.encode(concepts)

    dot_prods_m = concept_features @ concept_features.T
    dot_prods_c = _clip_dot_prods(concepts, concepts, device=device)

    dot_prods = (dot_prods_m + 3*dot_prods_c)/4

    # Row sums are loop-invariant; compute them once instead of once per pair.
    row_sums = dot_prods.sum(axis=1)
    too_close = dot_prods >= sim_cutoff
    np.fill_diagonal(too_close, False)  # a concept is never compared with itself

    to_delete = set()
    for i, j in np.argwhere(too_close):
        i, j = int(i), int(j)
        if i in to_delete or j in to_delete:
            continue
        to_print = random.random() < print_prob
        #Deletes the concept with lower average similarity to other concepts - idea is to keep more general concepts
        loser = i if row_sums[i] < row_sums[j] else j
        to_delete.add(loser)
        if to_print:
            print("{} - {} , sim:{:.4f} - Deleting {}".format(concepts[i], concepts[j], dot_prods[i,j], concepts[loser]))

    # Pop from the highest index down so earlier indices stay valid.
    for idx in sorted(to_delete, reverse=True):
        concepts.pop(idx)
    print(len(concepts))
    return concepts

In [12]:
def remove_too_long(concepts, max_len, print_prob=0):
    """Return a list of the concepts that are at most ``max_len`` characters long.

    Each rejected concept is printed (length, then text) with probability
    ``print_prob``. The input size and the output size are printed at the end.
    """
    kept = []
    for text in concepts:
        n_chars = len(text)
        if n_chars > max_len:
            if random.random() < print_prob:
                print(n_chars, text)
            continue
        kept.append(text)
    print(len(concepts), len(kept))
    return kept

In [13]:
cls_cifar10 = "data/cifar10_classes.txt"
cls_cifar100 = "data/cifar100_classes.txt"
cls_imagenet = "data/imagenet_classes.txt"
cls_cub = "data/cub_classes.txt"


def _read_class_names(path):
    """Read one class name per line from ``path``.

    Trailing blank lines are dropped, so a file that ends with a newline does
    not contribute an empty class name. Interior lines are left untouched so
    that list positions still match line numbers in the file.
    """
    with open(path, 'r') as f:
        names = f.read().split('\n')
    while names and not names[-1].strip():
        names.pop()
    return names


classes_cifar10 = _read_class_names(cls_cifar10)
classes_cifar100 = _read_class_names(cls_cifar100)
classes_imagenet = _read_class_names(cls_imagenet)
classes_cub = _read_class_names(cls_cub)

classes = classes_cifar10 + classes_cifar100 + classes_imagenet + classes_cub

# Case-sensitive de-duplication in first-seen order (dict keys keep insertion order).
unique = list(dict.fromkeys(item for item in classes if item.strip()))

# Lower-case, then de-duplicate again: names that differ only in case would
# otherwise survive as duplicates once they are lower-cased.
classes = list(dict.fromkeys(item.lower() for item in unique))

save_classes = "data/all_classes.txt"
with open(save_classes, "w") as f:
    f.writelines(item + "\n" for item in classes)

In [15]:
# Concept pipeline: query ConceptNet, drop over-long entries, drop concepts
# too close to a class name, then drop concepts too close to each other.
concepts_raw = get_init_conceptnet(classes, limit=LIMIT, relations=RELATIONS)
concepts_short = remove_too_long(concepts_raw, max_len=MAX_LEN, print_prob=PRINT_PROB)
concepts_no_cls = filter_too_similar_to_cls(
    concepts_short, classes, sim_cutoff=CLASS_SIM_CUTOFF, print_prob=PRINT_PROB
)
concepts = filter_too_similar(concepts_no_cls, sim_cutoff=OTHER_SIM_CUTOFF, print_prob=PRINT_PROB)

100%|██████████| 1000/1000 [15:55<00:00,  1.05it/s]


38 One word that is frequently misspelled
33 the accent on the second syllable
34 a wilderness area away from people
33 a fast pacedand hard hitting game
32 dirt tracked in from the outside
38 naturally occurring tangible substance
31 a Discovery Channel documentary
31 single reed woodwind instrument
31 less cramped than an automobile
31 cooked to make them more edible
6780 6712
6712
Class:goldfish - Deleting goldfish
Class:bald eagle - Deleting a bald eagle
Class:harvestman - Deleting harvestman
Class:scorpion - Deleting scorpion
Class:centipede - Deleting centipede
Class:tusker - Deleting tusker
Class:koala - Deleting a koala
Class:flatworm - Deleting flatworm
Class:spoonbill - Deleting spoonbill
Class:albatross - Deleting albatross
Class:lion - Deleting lion
Class:meerkat - Deleting meerkat
Class:hamster - Deleting hamster
Class:hamster - Deleting a hamster
Class:beaver - Deleting beaver
Class:beaver - Deleting a beaver
Class:gorilla - Deleting gorilla
Class:baboon - Deleting baboon

Downloading (…)a8e1d/.gitattributes:   0%|          | 0.00/1.18k [00:00<?, ?B/s]

Downloading (…)_Pooling/config.json:   0%|          | 0.00/190 [00:00<?, ?B/s]

Downloading (…)b20bca8e1d/README.md:   0%|          | 0.00/10.6k [00:00<?, ?B/s]

Downloading (…)0bca8e1d/config.json:   0%|          | 0.00/571 [00:00<?, ?B/s]

Downloading (…)ce_transformers.json:   0%|          | 0.00/116 [00:00<?, ?B/s]

Downloading (…)e1d/data_config.json:   0%|          | 0.00/39.3k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/438M [00:00<?, ?B/s]

Downloading (…)nce_bert_config.json:   0%|          | 0.00/53.0 [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/239 [00:00<?, ?B/s]

Downloading (…)a8e1d/tokenizer.json:   0%|          | 0.00/466k [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/363 [00:00<?, ?B/s]

Downloading (…)8e1d/train_script.py:   0%|          | 0.00/13.1k [00:00<?, ?B/s]

Downloading (…)b20bca8e1d/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)bca8e1d/modules.json:   0%|          | 0.00/349 [00:00<?, ?B/s]

100%|████████████████████████████████████████| 335M/335M [00:02<00:00, 161MiB/s]


Class:hammerhead shark - Concept:hammerhead, sim:0.887 - Deleting hammerhead

Class:kite (bird of prey) - Concept:kite, sim:0.880 - Deleting kite

Class:bald eagle - Concept:eagle, sim:0.880 - Deleting eagle

Class:tree frog - Concept:A frog, sim:0.881 - Deleting A frog

Class:tree frog - Concept:a frog, sim:0.881 - Deleting a frog

Class:boa constrictor - Concept:constrictor, sim:0.880 - Deleting constrictor

Class:Indian cobra - Concept:cobra, sim:0.875 - Deleting cobra

Class:prairie grouse - Concept:grouse, sim:0.863 - Deleting grouse

Class:duck - Concept:sea duck, sim:0.888 - Deleting sea duck

Class:sea anemone - Concept:anemone, sim:0.869 - Deleting anemone

Class:great egret - Concept:egret, sim:0.861 - Deleting egret

Class:crane bird - Concept:a bird, sim:0.855 - Deleting a bird

Class:ruddy turnstone - Concept:turnstone, sim:0.890 - Deleting turnstone

Class:red fox - Concept:The fox, sim:0.862 - Deleting The fox

Class:cricket insect - Concept:An insect, sim:0.871 - Deleti