# Choose settings

##### Choose your settings here

In [14]:
# choose dataset
DATASET_NAME = 'cats-vs-dogs-large'  # needs to match folder name in FM/datasets
LOAD_AND_EMBED_DATASET_IN_BATCHES = True  # True for large datasets, False for small ones

# Cache behaviour:
#   ''                -> load the dataset normally
#   'CREATE__{x}.pkl' -> create the cache file {x}.pkl
#   '{x}.pkl'         -> load the cache file {x}.pkl
USE_CACHED_EMBEDDINGS = 'CREATE__mislabeled_instances_cats-vs-dogs.pkl'

# Images known to carry the wrong label, written as '<label folder>/<file name>'.
# Every path is listed exactly once ('cat/8100.jpg' used to appear twice), so
# len(MISLABELED_INSTANCES) equals the number of distinct files.
MISLABELED_INSTANCES = [
    'cat/92.jpg',
    'cat/835.jpg',
    'cat/3216.jpg',
    'cat/3672.jpg',
    'cat/4085.jpg',
    'cat/4338.jpg',
    'cat/4688.jpg',
    'cat/5351.jpg',
    'cat/5418.jpg',
    'cat/5673.jpg',
    'cat/6987.jpg',
    'cat/7377.jpg',
    'cat/7564.jpg',
    'cat/8100.jpg',
    'cat/10029.jpg',
    'cat/10539.jpg',
    'cat/10712.jpg',
    'cat/10827.jpg',
    'cat/11184.jpg',
    'cat/11565.jpg',
    'cat/12272.jpg',
    'cat/8456.jpg',
    'cat/9171.jpg',
    'cat/9626.jpg',
    'cat/9770.jpg',
    'dog/1043.jpg',
    'dog/1194.jpg',
    'dog/1773.jpg',
    'dog/2614.jpg',
    'dog/2877.jpg',
    'dog/4334.jpg',
    'dog/4367.jpg',
    'dog/5490.jpg',
    'dog/5604.jpg',
    'dog/6475.jpg',
    'dog/7413.jpg',
    'dog/10161.jpg',
    'dog/10237.jpg',
    'dog/10401.jpg',
    'dog/10747.jpg',
    'dog/10797.jpg',
    'dog/10801.jpg',
    'dog/11299.jpg',
    'dog/11702.jpg',
    'dog/11731.jpg',
    'dog/12223.jpg',
    'dog/12376.jpg',
    'dog/8736.jpg',
    'dog/9090.jpg',
    'dog/9517.jpg'
]

##### This part is calculated automatically

In [15]:
# Root folder of all datasets on the mounted drive, and the folder of the selected dataset.
datasets_path = '/content/drive/My Drive/FM/datasets/'
dataset_path = datasets_path + DATASET_NAME + '/'

# Pick the class labels that belong to the selected dataset; unknown names are rejected.
if DATASET_NAME in ('cats-vs-dogs-large', 'train-small'):
  LABELS = ['cat', 'dog']
elif DATASET_NAME == 'jellyfish-classification':
  LABELS = [
    'barrel jellyfish',
    'compass jellyfish',
    'lions mane jellyfish',
    'moon jellyfish',
  ]
elif DATASET_NAME == 'traffic-signs':
  LABELS = [
    '30 kilometers per hour speed limit traffic sign',
    '80 kilometers per hour speed limit traffic sign',
    '100 kilometers per hour speed limit traffic sign',
    'give way traffic sign',
    'no entry traffic sign',
    'no overtaking traffic sign',
    'priority over oncoming traffic sign',
    'stop sign',
  ]
else:
  raise ValueError('Invalid dataset selected or labels not set!')

# Load libraries

In [16]:
# Detect whether the import cell already ran in this kernel session, so that
# hitting "run all" does not execute the imports again.  A plain namespace
# lookup is used instead of a try/bare-except probe: a bare `except:` would
# also swallow KeyboardInterrupt and hide any genuine torch error.
libraries_already_loaded = 'torch' in globals()

In [17]:
# Import everything in one place, unless the preceding cell determined that
# the libraries are already loaded in this kernel session.
if not libraries_already_loaded:
  import torch
  from torchvision import transforms
  import numpy as np
  from matplotlib import pyplot as plt
  from PIL import Image
  from pkg_resources import packaging
  import os
  from google.colab import drive
  from sklearn.metrics.pairwise import cosine_similarity
  from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score, confusion_matrix
  import glob
  import pickle
  from scipy.spatial.distance import cosine

# torch is referenced unconditionally here, so it has to be in scope on both paths.
print("Torch version:", torch.__version__)

Torch version: 2.1.0+cu121


# Mounting storage

In [18]:
# Mount Google Drive at /content/drive, then list the datasets folder
# (IPython shell escape) as a quick check that the expected files are visible.
drive.mount('/content/drive')
!ls "{datasets_path}"

Drive already mounted at /content/drive; to attempt to forcibly remount, call drive.mount("/content/drive", force_remount=True).
AlexNet_cats-vs-dogs-large.pkl	     imagenet_one-of-each-class-except-cats-and-dogs
AlexNet_imagenet-subset.pkl	     indizes_clean_Anni.txt
AlexNet_traffic-signs.pkl	     jellyfish-classification
ALIGN_cats-vs-dogs-large.pkl	     mislabeled_instances_cats-vs-dogs_CATS_ONLY.pkl
ALIGN_imagenet-subset.pkl	     mislabeled_instances_cats-vs-dogs.pkl
ALIGN_traffic-signs.pkl		     note.txt
cats-dogs-big_ids.pkl		     text_dog_embeddings.pkl
cats-dogs-big.pkl		     text_random_embeddings.pkl
cats-vs-dogs-large		     traffic-signs
CLIP_cats-vs-dogs-large.pkl	     train-small
CLIP_imagenet-subset.pkl	     ViT-CLS_cats-vs-dogs-large.pkl
CLIP_traffic-signs.pkl		     ViT-CLS_imagenet-subset.pkl
dog_wrong_2_12.txt		     ViT-CLS_traffic-signs.pkl
dog_wrong.txt			     ViT-pooling_cats-vs-dogs-large.pkl
image_embeddings__cats-vs-dogs.pkl   ViT-pooling_imagenet-subset.pkl
imag

# Define dataset loader

In [19]:
def _match_label(relative_name, labels, filename):
  """Return the single label contained in relative_name; raise ValueError if none or several match."""
  matches = [label for label in labels if label in relative_name]
  if len(matches) > 1:
    raise ValueError(f"Label of {filename} is ambiguous.")
  if not matches:
    raise ValueError(f"No label for {filename} found.")
  return matches[0]


def get_embeddings_dict_batchwise(folder_path, labels, batch_size=64):
  """Collect, per label, one mislabeled-flag for every readable image below folder_path.

  Despite its name, this function computes no embeddings: the value stored
  for each image is a bool telling whether '<parent folder>/<file name>' is
  listed in the module-level MISLABELED_INSTANCES.

  Args:
    folder_path: Root folder that is searched recursively for files.
    labels: Label strings. An image is assigned to the one label that occurs
      as a substring of '<parent folder>/<file name>'.
    batch_size: Interval (in globbed paths) at which progress is printed.

  Returns:
    Dict mapping every label to a list of bools, in the order the files were
    returned by glob.

  Raises:
    ValueError: If folder_path does not exist, batch_size is smaller than 1,
      or an image matches no label or more than one label.
  """
  # Checking if the provided folder path exists
  if not os.path.exists(folder_path):
    raise ValueError("Folder path does not exist.")
  if batch_size < 1:
    raise ValueError("batch_size must be a positive integer.")

  # Set membership keeps the per-file lookup constant-time.
  mislabeled = set(MISLABELED_INSTANCES)
  image_embeddings = {label: [] for label in labels}

  # os.path.join works with and without a trailing separator on folder_path.
  all_files = glob.glob(os.path.join(folder_path, '**', '*'), recursive=True)
  last_index = len(all_files) - 1

  for n_instances_processed, filename in enumerate(all_files):

    # Directories and files that cannot be decoded as an image are skipped
    # (best effort).  `except Exception` keeps that behaviour while letting
    # KeyboardInterrupt through; the context manager closes the file handle.
    try:
      with Image.open(filename) as img:
        img.convert('RGB')
      readable = True
    except Exception:
      readable = False

    if readable:
      relative_name = '/'.join(filename.replace(os.sep, '/').split('/')[-2:])
      label = _match_label(relative_name, labels, filename)
      # Flags are stored directly in the result, so nothing is lost when the
      # last globbed path turns out to be unreadable.
      image_embeddings[label].append(relative_name in mislabeled)

    # Progress report at every batch boundary and at the very last path,
    # independent of whether the current path was readable.
    at_batch_boundary = n_instances_processed > 0 and n_instances_processed % batch_size == 0
    if at_batch_boundary or n_instances_processed == last_index:
      print(n_instances_processed, 'loaded and encoded')

  return image_embeddings

# Load and embed dataset

In [20]:
# Honour the three modes of USE_CACHED_EMBEDDINGS:
#   ''                -> compute only, nothing is written
#   'CREATE__{x}.pkl' -> compute, then store the result in {x}.pkl
#   '{x}.pkl'         -> load the previously stored result from {x}.pkl
# NOTE(review): LOAD_AND_EMBED_DATASET_IN_BATCHES is not consulted here; only
# the batch-wise loader is used in this cell.
create_prefix = 'CREATE__'

if USE_CACHED_EMBEDDINGS and not USE_CACHED_EMBEDDINGS.startswith(create_prefix):
  # load cached result; pickle can execute arbitrary code, so only load cache
  # files that were written by this notebook
  pickle_file = datasets_path + USE_CACHED_EMBEDDINGS
  with open(pickle_file, 'rb') as f:
    image_embeddings = pickle.load(f)
  print('Embeddings loaded from', pickle_file)
else:
  # load and embed images in batches (to save GPU memory and especially RAM)
  image_embeddings = get_embeddings_dict_batchwise(dataset_path, LABELS, batch_size=512)

  # save embeddings in pickle file if desired (enables to reload them later on)
  if USE_CACHED_EMBEDDINGS:
    pickle_filename = USE_CACHED_EMBEDDINGS[len(create_prefix):]  # remove prefix 'CREATE__'
    if not pickle_filename:
      raise ValueError("USE_CACHED_EMBEDDINGS must name a file after the 'CREATE__' prefix.")
    pickle_file = datasets_path + pickle_filename
    with open(pickle_file, 'wb') as f:
      pickle.dump(image_embeddings, f)
    print('Embeddings stored in', pickle_file)

512 loaded and encoded
1024 loaded and encoded
1536 loaded and encoded




2048 loaded and encoded
2560 loaded and encoded
3072 loaded and encoded
3584 loaded and encoded
4096 loaded and encoded
4608 loaded and encoded
5120 loaded and encoded
5632 loaded and encoded
6144 loaded and encoded
6656 loaded and encoded
7168 loaded and encoded
7680 loaded and encoded
8192 loaded and encoded
8704 loaded and encoded
9216 loaded and encoded
9728 loaded and encoded
10240 loaded and encoded
10752 loaded and encoded
11264 loaded and encoded
11776 loaded and encoded
12288 loaded and encoded
12800 loaded and encoded
13312 loaded and encoded
13824 loaded and encoded
14336 loaded and encoded
14848 loaded and encoded
15360 loaded and encoded
15872 loaded and encoded
16384 loaded and encoded
16896 loaded and encoded
17408 loaded and encoded
17920 loaded and encoded
18432 loaded and encoded
18944 loaded and encoded
19456 loaded and encoded
19968 loaded and encoded
20480 loaded and encoded
20992 loaded and encoded
21504 loaded and encoded
22016 loaded and encoded
22528 loaded and

In [21]:
# Sanity check: per label, print the number of flagged (mislabeled) images and the total image count.
# NOTE(review): the recorded run reports 24 flagged dogs although MISLABELED_INSTANCES
# lists 25 dog paths - presumably one listed file is missing or unreadable; verify.
for label, flags in image_embeddings.items():
  print(label, sum(flags), len(flags))

cat 25 12502
dog 24 12499
