In [1]:
# Import required modules and methods

import hashlib
import os

import pandas as pd
import torch
from PIL import Image
from sentence_transformers import SentenceTransformer, util
from transformers import ViTImageProcessor, ViTForImageClassification

In [2]:
# Search term for this run. It is compared against the ImageNet class names to
# select the related classes, and it is also used as the name of the directory
# whose images get cleaned.
keyword = "your_keyword"

## Finding Related Classes from ImageNet Classes

In [3]:
# Build the sentence-embedding model once. It has its own name so that binding
# `model` to a different network elsewhere in the notebook cannot silently
# change what `semantic_similarity` encodes with.
sentence_model = SentenceTransformer('sentence-transformers/all-mpnet-base-v2')

def semantic_similarity(source_word, words_to_compare):
    """Score how semantically close `source_word` is to each candidate text.

    Parameters
    ----------
    source_word : str
        The text every candidate is compared against.
    words_to_compare : str or iterable of str
        Candidate texts (for example a list or a pandas Series of strings).

    Returns
    -------
    list
        One cosine-similarity score per candidate, in the same order as
        `words_to_compare`.
    """
    # Materialise non-string iterables as a plain list so the encoder indexes
    # the candidates by position, independent of any pandas index labels.
    if not isinstance(words_to_compare, str):
        words_to_compare = list(words_to_compare)

    # Compute embeddings for the source text and for the candidates
    source_embedding = sentence_model.encode(source_word, convert_to_tensor=True)
    candidate_embeddings = sentence_model.encode(words_to_compare, convert_to_tensor=True)

    # Cosine similarities, shape (1, number of candidates) for a single source text
    cosine_scores = util.cos_sim(source_embedding, candidate_embeddings)

    # Move the scores to host memory first: `.numpy()` raises for a tensor
    # that lives on a GPU. Then flatten the single row into a simple list.
    output = list(cosine_scores.cpu().numpy()[0])

    # Return results
    return output

In [4]:
# Read the tab-separated class-name table, taking its first column as the index...
data = pd.read_csv("ImageNet ClassNames.tsv", sep="\t", index_col=0)
# ...then discard that index in favour of a fresh 0..n-1 one
data = data.reset_index(drop=True)
data

Unnamed: 0,Class Name
0,"tench, Tinca tinca"
1,"goldfish, Carassius auratus"
2,"great white shark, white shark, man-eater, man..."
3,"tiger shark, Galeocerdo cuvieri"
4,"hammerhead, hammerhead shark"
...,...
995,earthstar
996,"hen-of-the-woods, hen of the woods, Polyporus ..."
997,bolete
998,"ear, spike, capitulum"


In [5]:
# Score every ImageNet class name against the keyword
data["Similarity"] = semantic_similarity(str(keyword), data["Class Name"])

# Keep the names of the classes whose similarity is at least 0.5
is_related = data["Similarity"] >= 0.5
related_classes = data.loc[is_related, "Class Name"].values

## Data Cleaning

In [6]:
# The image preprocessor and the classifier are loaded from the same checkpoint
vit_checkpoint = 'google/vit-base-patch16-224'
processor = ViTImageProcessor.from_pretrained(vit_checkpoint)
model = ViTForImageClassification.from_pretrained(vit_checkpoint)

# The images to clean live in a directory named after the keyword
directory = keyword

### Removing Irrelevant Images

In [7]:
# Safety guard: with no related classes every prediction counts as
# "irrelevant", so the loop below would delete the whole directory.
if len(related_classes) == 0:
    raise ValueError(
        f"No related ImageNet classes were selected; refusing to delete every file in {directory!r}"
    )

# Exact-match lookup of class names, built once instead of scanning the array per image
related_class_names = set(related_classes)

for filename in os.listdir(directory):

    # Path of the image
    path = os.path.join(directory, filename)

    # Sub-directories and other non-regular entries are not images; leave them alone
    if not os.path.isfile(path):
        continue

    # Open the image. The context manager releases the file handle before a
    # possible os.remove below; `convert` returns an independent in-memory copy.
    try:
        with Image.open(path) as raw_image:
            image = raw_image.convert("RGB")
    except OSError as error:
        # Pillow reports unreadable or non-image files with OSError subclasses.
        # Such files are reported and kept rather than aborting the whole run.
        print(f"Skipping unreadable file {path}: {error}")
        continue

    inputs = processor(images=image, return_tensors="pt")
    # Inference only: no gradients are needed, so do not build the autograd graph
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits
    # Predict class
    predicted_class_idx = logits.argmax(-1).item()
    # Remove irrelevant images
    if model.config.id2label[predicted_class_idx] not in related_class_names:
        os.remove(path)

### Removing Duplicate Images

In [8]:
# SHA-1 digests of the file contents seen so far (used only to detect
# byte-identical files, not for security)
hashes = set()

# os.listdir returns entries in arbitrary order; sorting makes the surviving
# copy of each duplicate group deterministic (the first name in sorted order).
for filename in sorted(os.listdir(directory)):
    # Path of the image
    path = os.path.join(directory, filename)

    # Only regular files can be hashed; skip sub-directories and the like
    if not os.path.isfile(path):
        continue

    # Hash the raw bytes, closing the file before it may be removed
    with open(path, 'rb') as image_file:
        digest = hashlib.sha1(image_file.read()).digest()

    # Remove duplicate images
    if digest in hashes:
        os.remove(path)
    else:
        hashes.add(digest)