In [16]:
import torch
import json
from imagenetv2_pytorch import ImageNetV2Dataset
from data.ImageNetV2.superclassing_dataset import SuperclassingImageNetDataset
from torch.utils.data import DataLoader
import matplotlib.pyplot as plt
from collections import Counter

IMAGENET_CLASS_INDEX_PATH = "./data/ImageNetV2/imagenet_class_index.json"

SUPERCLASS_DATASET_PATH="./data/ImageNetV2/raw/"

mode = {
    "tick": 17,
    "fig": [5, 4],
    "legend": 13,
    "label": 14
}

marker_list = [
    ".",
    "o",
    "v",
    "8",
    "s",
    "p",
    "P",
    "*",
    "x",
    4,
    5,
    6,
    7,
    8,
    None,
    None
]

## load ImageNet V2

In [19]:
with open(IMAGENET_CLASS_INDEX_PATH, 'r') as f:
    imagenet_class_index = json.load(f)

# Map label indices to class descriptions, replacing spaces with underscores
idx_to_class_desc = {}
for key, value in imagenet_class_index.items():
    label = int(key)
    class_desc = value[1]
    if " " in class_desc:
        class_desc = class_desc.replace(" ", "_")
    idx_to_class_desc[label] = class_desc

superclass_names = [
    "Bird",
    "Boat",
    "Car",
    "Cat",
    "Dog",
    "Fruit",
    "Fungus",
    "Insect",
    "Monkey"
]

superclass_to_idx = {name: idx for idx, name in enumerate(superclass_names)}

label_to_superclass = {}

for label, description in idx_to_class_desc.items():
    description_lower = description.lower()
    if "bird" in description_lower or any(
        bird in description_lower
        for bird in [
            "cock",
            "hen",
            "duck",
            "goose",
            "owl",
            "parrot",
            "swan",
            "flamingo",
            "penguin",
            "eagle",
            "vulture",
            "peacock",
            "crane",
            "heron",
            "kingfisher",
            "woodpecker",
            "hummingbird",
            "sparrow",
            "finch",
            "swallow",
            "gull",
            "kite",
            "robin",
            "magpie",
            "chickadee",
            "jay",
            "turkey",
            "pigeon",
            "ostrich",
            "quail",
            "ptarmigan",
        ]
    ):
        assigned_superclass = "Bird"
    elif "boat" in description_lower or any(
        boat in description_lower
        for boat in [
            "boat",
            "ship",
            "canoe",
            "gondola",
            "yawl",
            "liner",
            "schooner",
            "trimaran",
            "barge",
            "lifeboat",
            "submarine",
            "raft",
            "kayak",
            "pirate",
            "aircraft_carrier",
            "speedboat",
            "bobsled",
            "catamaran",
            "sailboat",
            "dinghy",
            "paddlewheel",
            "dock",
        ]
    ):
        assigned_superclass = "Boat"
    elif "car" in description_lower or any(
        car in description_lower
        for car in [
            "car",
            "vehicle",
            "taxi",
            "cab",
            "limousine",
            "jeep",
            "minivan",
            "convertible",
            "wagon",
            "bus",
            "truck",
            "ambulance",
            "pickup",
            "trailer",
            "van",
            "moped",
            "motor_scooter",
            "snowmobile",
            "trolleybus",
            "fire_engine",
            "school_bus",
            "garbage_truck",
            "police_van",
            "racing_car",
            "sports_car",
            "go-kart",
            "golfcart",
            "forklift",
            "bicycle",
            "motorcycle",
            "bicycle-built-for-two",
            "mountain_bike",
            "streetcar",
        ]
    ):
        assigned_superclass = "Car"
    elif "cat" in description_lower or any(
        cat in description_lower
        for cat in [
            "cat",
            "kitten",
            "lion",
            "tiger",
            "leopard",
            "jaguar",
            "cheetah",
            "cougar",
            "panther",
            "lynx",
            "bobcat",
            "ocelot",
            "caracal",
            "wildcat",
            "tiger_cat",
            "persian_cat",
            "siamese_cat",
            "egyptian_cat",
            "tabby",
        ]
    ):
        assigned_superclass = "Cat"
    elif "dog" in description_lower or any(
        dog in description_lower
        for dog in [
            "dog",
            "puppy",
            "wolf",
            "fox",
            "coyote",
            "hound",
            "dingo",
            "dhole",
            "jackal",
            "hyena",
            "poodle",
            "terrier",
            "retriever",
            "bulldog",
            "beagle",
            "spaniel",
            "sheepdog",
            "collie",
            "pinscher",
            "dalmatian",
            "husky",
            "greyhound",
            "chihuahua",
            "labrador",
            "boxer",
            "great_dane",
            "newfoundland",
            "samoyed",
            "pomeranian",
            "keeshond",
            "malamute",
            "shih-tzu",
            "papillon",
            "basenji",
            "pug",
            "leonberg",
            "eskimo_dog",
            "rottweiler",
            "doberman",
            "bloodhound",
            "schipperke",
            "griffon",
        ]
    ):
        assigned_superclass = "Dog"
    elif "fruit" in description_lower or any(
        fruit in description_lower
        for fruit in [
            "fruit",
            "apple",
            "banana",
            "orange",
            "lemon",
            "pineapple",
            "strawberry",
            "mango",
            "melon",
            "grape",
            "pear",
            "peach",
            "plum",
            "cherry",
            "fig",
            "pomegranate",
            "custard_apple",
            "jackfruit",
            "papaya",
            "guava",
            "kiwi",
            "apricot",
            "berry",
            "raspberry",
            "blackberry",
            "blueberry",
            "pineapple",
            "bell_pepper",
            "cucumber",
            "zucchini",
            "artichoke",
            "cauliflower",
            "broccoli",
            "mushroom",
            "potato",
            "squash",
            "corn",
            "pumpkin",
            "eggplant",
            "tomato",
            "olive",
            "avocado",
        ]
    ):
        assigned_superclass = "Fruit"
    elif "fungus" in description_lower or any(
        fungus in description_lower
        for fungus in [
            "fungus",
            "mushroom",
            "morel",
            "truffle",
            "bolete",
            "hen-of-the-woods",
            "earthstar",
            "gyromitra",
            "stinkhorn",
            "agaric",
            "polypore",
            "coral_fungus",
            "edible_mushroom",
            "toadstool",
            "bracket_fungus",
        ]
    ):
        assigned_superclass = "Fungus"
    elif "insect" in description_lower or any(
        insect in description_lower
        for insect in [
            "insect",
            "bug",
            "bee",
            "ant",
            "fly",
            "beetle",
            "butterfly",
            "grasshopper",
            "cockroach",
            "mosquito",
            "dragonfly",
            "mantis",
            "wasp",
            "cricket",
            "ladybug",
            "cicada",
            "locust",
            "termite",
            "firefly",
            "stick_insect",
            "lacewing",
            "damselfly",
            "weevil",
            "centipede",
            "spider",
            "scorpion",
            "tick",
            "tarantula",
            "silkworm",
            "flea",
            "mite",
            "aphid",
            "leafhopper",
            "praying_mantis",
            "leaf_beetle",
            "earwig",
            "cockchafer",
            "mantid",
            "stonefly",
            "dung_beetle",
            "black_widow",
        ]
    ):
        assigned_superclass = "Insect"
    elif "monkey" in description_lower or any(
        monkey in description_lower
        for monkey in [
            "monkey",
            "ape",
            "chimpanzee",
            "baboon",
            "gorilla",
            "orangutan",
            "gibbon",
            "lemur",
            "macaque",
            "mandrill",
            "capuchin",
            "howler",
            "titi",
            "squirrel_monkey",
            "colobus",
            "guenon",
            "proboscis_monkey",
            "langur",
            "spider_monkey",
            "siamang",
            "indri",
            "patas",
        ]
    ):
        assigned_superclass = "Monkey"
    else:
        continue
    label_to_superclass[label] = superclass_to_idx[assigned_superclass]

len(label_to_superclass)

309

In [21]:
custom_dataset = SuperclassingImageNetDataset(label_to_superclass, variant="matched-frequency", location=SUPERCLASS_DATASET_PATH)

Dataset matched-frequency not found on disk, downloading....


100%|██████████| 1.26G/1.26G [00:58<00:00, 21.5MiB/s]


Extracting....


In [22]:
dataloader = DataLoader(custom_dataset, batch_size=32, shuffle=True)

In [24]:
import matplotlib.pyplot as plt
from torchvision.transforms import ToPILImage

# Function to display images
def display_images(images, labels, class_names, superclass_name):
    fig, axes = plt.subplots(1, len(images), figsize=(15, 5))
    fig.suptitle(superclass_name, fontsize=16)
    for img, label, ax in zip(images, labels, axes):
        ax.imshow(ToPILImage()(img))
        ax.set_title(class_names[label])
        ax.axis('off')
    plt.show()

# Iterate over the dataset to find images for each superclass
for superclass_name, (count, class_names) in superclass_to_classes_with_count.items():
    images_to_display = []
    labels_to_display = []
    class_indices = [idx for idx, name in enumerate(class_names)]
    
    for images, labels in dataloader:
        for img, label in zip(images, labels):
            if label in class_indices:
                images_to_display.append(img)
                labels_to_display.append(label)
                if len(images_to_display) == 3:
                    break
        if len(images_to_display) == 3:
            break
    
    if images_to_display:
        display_images(images_to_display, labels_to_display, class_names, superclass_name)

RuntimeError: stack expects each tensor to be equal size, but got [3, 400, 400] at entry 0 and [3, 500, 500] at entry 1

In [11]:
superclass_to_classes = {name: [] for name in superclass_names}

for label, superclass_idx in label_to_superclass.items():
    class_desc = idx_to_class_desc[label]
    superclass_name = superclass_names[superclass_idx]
    superclass_to_classes[superclass_name].append(class_desc)

superclass_to_classes_with_count = {name: (len(classes), classes) for name, classes in superclass_to_classes.items()}

superclass_to_classes_with_count


{'Bird': (35,
  ['cock',
   'hen',
   'ostrich',
   'goldfinch',
   'house_finch',
   'robin',
   'jay',
   'magpie',
   'chickadee',
   'kite',
   'bald_eagle',
   'vulture',
   'great_grey_owl',
   'ptarmigan',
   'peacock',
   'quail',
   'sulphur-crested_cockatoo',
   'hummingbird',
   'goose',
   'black_swan',
   'flamingo',
   'little_blue_heron',
   'crane',
   'king_penguin',
   'beagle',
   'cocker_spaniel',
   'mongoose',
   'cockroach',
   'howler_monkey',
   'birdhouse',
   'cocktail_shaker',
   'crane',
   'mixing_bowl',
   'soup_bowl',
   'hen-of-the-woods']),
 'Boat': (20,
  ['aircraft_carrier',
   'airliner',
   'airship',
   'boathouse',
   'bobsled',
   'canoe',
   'catamaran',
   'container_ship',
   'dock',
   'fireboat',
   'gondola',
   'lifeboat',
   'liner',
   'paddlewheel',
   'pirate',
   'schooner',
   'speedboat',
   'submarine',
   'trimaran',
   'yawl']),
 'Car': (50,
  ['bustard',
   'Cardigan',
   'cabbage_butterfly',
   'colobus',
   'Madagascar_cat',


## class distribution analysis

In [None]:
# Class Distribution Analysis
class_distribution = analyze_class_distribution(dataset)