In [3]:
# Imports 

import sys
import pickle 
import seaborn as sns
import matplotlib.pyplot as plt
import pandas as pd
import dassl
import copy
from torchvision import transforms
import torch
import csv
import random
import numpy as np
from munch import DefaultMunch 
from PIL import Image 
from tqdm import tqdm

sys.path.append(r'/home/afroehli/coding/OOD-X-Benchmarks')
from datasets import imagenet_x as i_x, imagenet_r_x as r_x, imagenet_c_x as c_x, imagenet_v2_x as v2_x

sys.path.append(r'/home/afroehli/coding/util_scripts')
from utils_dataloading.dassl_datum_mod import DatumWithWnid 
from utils_dataloading.imagenet_tree import ImagenetSemanticInfo, ImagenetSemanticSubtree

In [None]:
# Build every benchmark dataset from its raw files and cache each instance as a pickle.

# Configuration shared by all four dataset constructors.
cfg = {
    "DATASET": {
        "ROOT": r"/home/afroehli/coding/OOD-X-Benchmarks/data",
        "SUBSAMPLE_CLASSES": "custom",
        "ID_CLASSES_FILE": r"/home/afroehli/coding/OOD-X-Benchmarks/data/class_splits/imagenet/x/id_data.txt",
        "OOD_CLASSES_FILE": r"/home/afroehli/coding/OOD-X-Benchmarks/data/class_splits/imagenet/x/ood_data.txt",
        "NUM_SHOTS": 0,
    }
}

# The constructors receive the Munch object produced from the nested dict above.
cfg_obj = DefaultMunch.fromDict(cfg)


def _build_and_store(start_message, dataset_cls, pickle_path):
    """Print ``start_message``, construct ``dataset_cls(cfg_obj)``, pickle it and return it."""
    print(start_message)
    instance = dataset_cls(cfg_obj)
    with open(pickle_path, 'wb') as out_file:
        pickle.dump(instance, out_file, pickle.HIGHEST_PROTOCOL)
    return instance


# Imagenet-C
inet_c_x = _build_and_store(
    'Start loading Imagenet-C',
    c_x.ImageNetCX,
    '/home/afroehli/coding/pickle_data/imagenet_c.pkl',
)

# Imagenet-V2
inet_v2_x = _build_and_store(
    'Start loading Imagenet-V2',
    v2_x.ImageNetV2X,
    '/home/afroehli/coding/pickle_data/imagenet_v2.pkl',
)

# Imagenet-R
inet_r_x = _build_and_store(
    'Start loading Imagenet-R',
    r_x.ImageNetRX,
    '/home/afroehli/coding/pickle_data/imagenet_r.pkl',
)

# Imagenet-X
inet_x = _build_and_store(
    'Start loading Imagenet-X',
    i_x.ImageNetX,
    '/home/afroehli/coding/pickle_data/imagenet_x.pkl',
)

Pass 0 items
items: []
id_labels: [0, 2, 4, 5, 7, 10, 11, 14, 16, 17, 20, 22, 24, 26, 28, 30, 31, 33, 35, 38, 39, 42, 44, 46, 48, 50, 52, 53, 54, 55, 56, 62, 64, 66, 67, 69, 70, 72, 73, 74, 80, 81, 84, 85, 87, 90, 92, 94, 95, 98, 100, 102, 104, 106, 108, 110, 112, 113, 116, 118, 119, 122, 124, 126, 127, 129, 130, 131, 132, 137, 140, 142, 144, 146, 148, 150, 151, 152, 153, 157, 160, 161, 162, 163, 164, 168, 165, 166, 169, 171, 179, 181, 182, 183, 184, 185, 186, 187, 191, 192, 189, 196, 197, 205, 206, 210, 212, 213, 215, 216, 217, 222, 224, 226, 227, 228, 229, 230, 236, 237, 238, 239, 242, 243, 244, 249, 251, 253, 254, 255, 258, 259, 263, 265, 266, 269, 270, 273, 276, 277, 278, 281, 282, 286, 288, 289, 290, 294, 295, 298, 300, 301, 302, 305, 308, 309, 311, 314, 316, 318, 319, 322, 325, 324, 327, 330, 331, 333, 334, 338, 340, 341, 344, 345, 347, 349, 351, 352, 356, 357, 358, 359, 364, 365, 368, 370, 371, 372, 373, 377, 378, 379, 383, 385, 387, 390, 392, 394, 396, 398, 399, 400, 401, 402, 

In [4]:
# Load dataset instances if possible 

def get_unique_classnames(dataset: "list[DatumWithWnid]") -> list[tuple[str, str]]:
    """Return the distinct ``(classname, wnid)`` pairs that occur in ``dataset``.

    Args:
        dataset: Iterable of items exposing ``classname`` and ``wnid`` attributes.
            The annotation is a forward-reference string so that defining the
            function does not itself require ``DatumWithWnid`` to be in scope.

    Returns:
        A list holding every ``(classname, wnid)`` pair exactly once. The pairs
        pass through a ``set``, so the order of the list is unspecified. An
        empty ``dataset`` yields an empty list.
    """
    # A set comprehension de-duplicates while iterating; the result is handed
    # back as a list because callers index into / re-wrap it.
    return list({(item.classname, item.wnid) for item in dataset})

# Directory that holds the dataset pickles written by the "create dataset instances" cell.
PICKLE_DIR = '/home/afroehli/coding/pickle_data'


def _load_pickled_dataset(filename: str):
    """Unpickle and return the object stored in ``PICKLE_DIR/filename``."""
    # SECURITY NOTE: pickle.load can execute arbitrary code. Only load files that
    # this notebook (or another trusted source) has written.
    with open(f'{PICKLE_DIR}/{filename}', 'rb') as pickle_file:
        return pickle.load(pickle_file)


inet_c_x = _load_pickled_dataset('imagenet_c.pkl')
inet_v2_x = _load_pickled_dataset('imagenet_v2.pkl')
inet_r_x = _load_pickled_dataset('imagenet_r.pkl')
inet_x = _load_pickled_dataset('imagenet_x.pkl')

SPLITS = ('train_x', 'val', 'test')
loaded_datasets = [('Imagenet-C', inet_c_x), ('Imagenet-V2', inet_v2_x),
                   ('Imagenet-R', inet_r_x), ('Imagenet-X', inet_x)]

# dataset name -> split name -> list of unique (classname, wnid) pairs
classnames_dataset = {name: {split: [] for split in SPLITS} for name, _ in loaded_datasets}

print('\tDataset\t\t|\tTrain\t|\tVal\t|\tTest\t|')
for name, dataset in loaded_datasets:
    # Row 1: number of items per split.
    print(f'\t{name}\t|\t{len(dataset.train_x)}\t|\t{len(dataset.val)}\t|\t{len(dataset.test)}\t|')

    per_split = classnames_dataset[name]
    per_split['train_x'] = get_unique_classnames(dataset.train_x)
    per_split['val'] = get_unique_classnames(dataset.val)
    per_split['test'] = get_unique_classnames(dataset.test)

    # The per-split lists are already unique, so set intersections give the same
    # counts as the element-wise membership scans, without the quadratic cost.
    train_set, val_set, test_set = (set(per_split[split]) for split in SPLITS)
    n_train = len(per_split['train_x'])
    n_train_and_val = len(train_set & val_set)
    n_val = len(per_split['val'])
    n_val_and_test = len(val_set & test_set)
    n_test = len(per_split['test'])

    # Row 2: classes in train | train∩val | val | val∩test | test.
    # Values are computed above so the f-string needs no nested quotes
    # (reusing the outer quote inside a replacement field requires Python >= 3.12).
    print(f'\tN-Classnames\t|\t{n_train}'
          f'\t{n_train_and_val}'
          f'\t{n_val}'
          f'\t{n_val_and_test}'
          f'\t{n_test}\t|')

# Classes shared by all four datasets, per split.
intersections = dict()

for split in SPLITS:
    intersections[split] = set.intersection(
        *(set(classnames_dataset[dataset_name][split]) for dataset_name, _ in loaded_datasets)
    )

n_shared_train = len(intersections['train_x'])
n_shared_val = len(intersections['val'])
n_shared_test = len(intersections['test'])
print(f'\tIntersections\t|\t{n_shared_train}\t|\t{n_shared_val}\t|'
      f'\t{n_shared_test}\t|')

# Classes present in the Imagenet-R val split but absent from the Imagenet-C val split.
# The comparison has always been against Imagenet-C; the printed label previously
# claimed "Not in Imagenet-X", which did not match what was computed.
# NOTE(review): if a comparison against Imagenet-X was actually intended, change the
# right-hand operand instead of the label.
imagenet_r_val_not_in_c_val = (
    set(classnames_dataset['Imagenet-R']['val']) - set(classnames_dataset['Imagenet-C']['val'])
)
# Legacy name kept so that any later cell referring to it keeps working.
not_in_inet_x = imagenet_r_val_not_in_c_val
print(f'In Imagenet-R (val) but not in Imagenet-C (val): '
      f'{len(imagenet_r_val_not_in_c_val)} {imagenet_r_val_not_in_c_val}')

	Dataset		|	Train	|	Val	|	Test	|
	Imagenet-C	|	5380	|	5380	|	4620	|
	N-Classnames	|	296	296	296	0	225	|
	Imagenet-V2	|	5000	|	5000	|	5000	|
	N-Classnames	|	500	500	500	0	500	|
	Imagenet-R	|	14755	|	14755	|	15245	|
	N-Classnames	|	100	100	100	0	100	|
	Imagenet-X	|	1281167	|	25000	|	25000	|
	N-Classnames	|	1000	500	500	0	500	|
	Intersections	|	72	|	72	|	63	|
Not in Imagenet-X: 28 {('cabbage', 'n07714571'), ('banana', 'n07753592'), ('spider web', 'n04275548'), ('hammer', 'n03481172'), ('hot dog', 'n07697537'), ('school bus', 'n04146614'), ('fire truck', 'n03345487'), ('Granny Smith apple', 'n07742313'), ('ice cream', 'n07614500'), ('broccoli', 'n07714990'), ('lipstick', 'n03676483'), ('flute', 'n03372029'), ('pickup truck', 'n03930630'), ('jeep', 'n03594945'), ('revolver', 'n04086273'), ('grand piano', 'n03452741'), ('pirate ship', 'n03947888'), ('pizza', 'n07873807'), ('hatchet', 'n03498962'), ('acorn', 'n12267677'), ('submarine', 'n04347754'), ('pretzel', 'n07695742'), ('cucumber', 'n07

In [6]:
# Prepare loaded datasets: group the items of every dataset by wnid.
# (The "*_sorted" names are historical: the mappings are grouped by wnid, not ordered.)


def _group_by_wnid(*data_splits):
    """Merge the given splits into a ``{wnid: [items]}`` mapping.

    Args:
        *data_splits: One or more iterables of items exposing a ``wnid`` attribute.

    Returns:
        A dict whose keys appear in first-seen order and whose values are lists of
        the items carrying that wnid. Items are collected in a ``set`` per wnid
        first, so items that compare equal (per their own hash/equality) are kept
        once and the order inside each list is unspecified.
    """
    grouped = dict()
    for data_split in data_splits:
        for dassl_item in data_split:
            grouped.setdefault(dassl_item.wnid, set()).add(dassl_item)
    # Downstream code expects lists rather than sets.
    return {wnid: list(items) for wnid, items in grouped.items()}


# Imagenet-R: val + test
imagenet_r_sorted = _group_by_wnid(inet_r_x.val, inet_r_x.test)

# Imagenet-X: val + test
imagenet_x_sorted = _group_by_wnid(inet_x.val, inet_x.test)

# Imagenet-X: train part only
imagenet_x_train_sorted = _group_by_wnid(inet_x.train_x)

# Imagenet-V2: val + test
imagenet_v2_sorted = _group_by_wnid(inet_v2_x.val, inet_v2_x.test)

# Imagenet-C: val + test
imagenet_c_sorted = _group_by_wnid(inet_c_x.val, inet_c_x.test)

# Imagenet-Small-X: Imagenet-X (val + test) restricted to the classes that also occur
# in Imagenet-R. The lists are shared with imagenet_x_sorted, not copied. A wnid that
# is in Imagenet-R but missing from Imagenet-X raises KeyError here.
imagenet_r_wnids = list(imagenet_r_sorted.keys())
imagenet_x_small_sorted = {wnid: imagenet_x_sorted[wnid] for wnid in imagenet_r_wnids}

print(f'N-Classes in Imagenet-R-Sorted: {len(imagenet_r_sorted)}')
print(f'N-Classes in Imagenet-X-Sorted: {len(imagenet_x_sorted)}')
print(f'N-Classes in Imagenet-X-Train-Sorted: {len(imagenet_x_train_sorted)}')
print(f'N-Classes in Imagenet-V2-Sorted: {len(imagenet_v2_sorted)}')
print(f'N-Classes in Imagenet-C-Sorted: {len(imagenet_c_sorted)}')
print(f'N-Classes in Imagenet-X-Small-Sorted: {len(imagenet_x_small_sorted)}')

N-Classes in Imagenet-R-Sorted: 200
N-Classes in Imagenet-X-Sorted: 1000
N-Classes in Imagenet-X-Train-Sorted: 1000
N-Classes in Imagenet-V2-Sorted: 1000
N-Classes in Imagenet-C-Sorted: 521
N-Classes in Imagenet-X-Small-Sorted: 200


In [7]:
# Persist every wnid-grouped dataset mapping as its own pickle file.

# (destination path, mapping to store) — written in exactly this order.
sorted_dataset_targets = [
    ('/home/afroehli/coding/pickle_data/dataset_objects/imagenet_r_sorted.pkl', imagenet_r_sorted),
    ('/home/afroehli/coding/pickle_data/dataset_objects/imagenet_x_sorted.pkl', imagenet_x_sorted),
    ('/home/afroehli/coding/pickle_data/dataset_objects/imagenet_x_train_sorted.pkl', imagenet_x_train_sorted),
    ('/home/afroehli/coding/pickle_data/dataset_objects/imagenet_v2_sorted.pkl', imagenet_v2_sorted),
    ('/home/afroehli/coding/pickle_data/dataset_objects/imagenet_c_sorted.pkl', imagenet_c_sorted),
    ('/home/afroehli/coding/pickle_data/dataset_objects/imagenet_x_small_sorted.pkl', imagenet_x_small_sorted),
]

for target_path, grouped_dataset in sorted_dataset_targets:
    with open(target_path, 'wb') as pickle_file:
        pickle.dump(grouped_dataset, pickle_file, pickle.HIGHEST_PROTOCOL)