# Choose settings

##### Choose your settings here

In [10]:
# choose dataset
DATASET_NAME = 'cats-vs-dogs-large'  # needs to match folder name in FM/datasets
USE_CACHED_EMBEDDINGS = 'ALIGN_cats-vs-dogs-large.pkl'  # '' for loading the dataset normally, 'CREATE__{x}.pkl' for creating the cache file {x}.pkl, '{x}.pkl' for loading the cache file {x}.pkl
MISLABELED_INSTANCES_LIST = 'mislabeled_instances_cats-vs-dogs.pkl'
MISLABELED_INSTANCES = [
    'cat/92.jpg',
    'cat/835.jpg',
    'cat/3216.jpg',
    'cat/3672.jpg',
    'cat/4085.jpg',
    'cat/4338.jpg',
    'cat/4688.jpg',
    'cat/5351.jpg',
    'cat/5418.jpg',
    'cat/5673.jpg',
    'cat/6987.jpg',
    'cat/7377.jpg',
    'cat/7564.jpg',
    'cat/8100.jpg',
    'cat/10029.jpg',
    'cat/10539.jpg',
    'cat/10712.jpg',
    'cat/10827.jpg',
    'cat/11184.jpg',
    'cat/11565.jpg',
    'cat/12272.jpg',
    'cat/8100.jpg',
    'cat/8456.jpg',
    'cat/9171.jpg',
    'cat/9626.jpg',
    'cat/9770.jpg',
    'dog/1043.jpg',
    'dog/1194.jpg',
    'dog/1773.jpg',
    'dog/2614.jpg',
    'dog/2877.jpg',
    'dog/4334.jpg',
    'dog/4367.jpg',
    'dog/5490.jpg',
    'dog/5604.jpg',
    'dog/6475.jpg',
    'dog/7413.jpg',
    'dog/10161.jpg',
    'dog/10237.jpg',
    'dog/10401.jpg',
    'dog/10747.jpg',
    'dog/10797.jpg',
    'dog/10801.jpg',
    'dog/11299.jpg',
    'dog/11702.jpg',
    'dog/11731.jpg',
    'dog/12223.jpg',
    'dog/12376.jpg',
    'dog/8736.jpg',
    'dog/9090.jpg',
    'dog/9517.jpg'
]

# choose which model to use to generate the embeddings
MODEL = 'ALIGN'  # Can be 'CLIP', 'AlexNet', 'ViT-pooling' or 'ViT-CLS'

##### This part is calculated automatically

In [12]:
datasets_path = '/content/drive/My Drive/FM/datasets/'
dataset_path = datasets_path + DATASET_NAME + '/'

if DATASET_NAME == 'cats-vs-dogs-large' or DATASET_NAME == 'train-small':
  LABELS = ['cat', 'dog']
elif DATASET_NAME == 'jellyfish-classification':
  LABELS = ['barrel jellyfish', 'compass jellyfish', 'lions mane jellyfish', 'moon jellyfish']
elif DATASET_NAME == 'traffic-signs':
  LABELS = ['30 kilometers per hour speed limit traffic sign', '80 kilometers per hour speed limit traffic sign', '100 kilometers per hour speed limit traffic sign', 'give way traffic sign', 'no entry traffic sign', 'no overtaking traffic sign', 'priority over oncoming traffic sign', 'stop sign']
elif DATASET_NAME == 'imagenet_one-of-each-class-except-cats-and-dogs':
  LABELS = ['val']
else:
  raise ValueError('Invalid dataset selected or labels not set!')

assert MODEL in ['CLIP', 'AlexNet', 'ViT-pooling', 'ViT-CLS',"ALIGN"], f'Selected model {MODEL} not implemented!'

# Load libraries

In [13]:
! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.git

Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-b12cib0h
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-b12cib0h
  Resolved https://github.com/openai/CLIP.git to commit a1d071733d7111c9c014f024669f959182114e33
  Preparing metadata (setup.py) ... [?25l[?25hdone


In [16]:
import torch
from torchvision import transforms
import clip
from transformers import AutoImageProcessor, ViTModel, AlignProcessor, AlignModel, AutoTokenizer
from transformers.tokenization_utils_base import BatchEncoding
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
from pkg_resources import packaging
import os
from google.colab import drive
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score, confusion_matrix
import glob
import pickle
from scipy.spatial.distance import cosine

print("Torch version:", torch.__version__)

Torch version: 2.1.0+cu121


# Load model

In [17]:
device = "cuda" if torch.cuda.is_available() else "cpu"

class AlexNetEmbedder(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.alexnet = torch.hub.load('pytorch/vision:v0.10.0', 'alexnet', pretrained=True)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    x = self.alexnet.features(x)
    x = self.alexnet.avgpool(x)
    return self.alexnet.classifier[:5](torch.flatten(x, 1))

  def encode_image(self, img: torch.Tensor) -> torch.Tensor:
    return self(img)

class VisionTransformer(torch.nn.Module):

  def __init__(self, use_pooler_output_instead_of_last_hidden_state=False):
    super().__init__()
    self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
    self.use_pooler_output = use_pooler_output_instead_of_last_hidden_state

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    out = self.vit(x, return_dict=True)
    if self.use_pooler_output:
      return out.pooler_output
    else:
      return out.last_hidden_state[:, 0, :]

  def encode_image(self, img: torch.Tensor) -> torch.Tensor:
    return self(img)


def vit_preprocessor_with_memory_fix(img) -> torch.Tensor:
  processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
  with torch.no_grad():
    processed = processor(img, return_tensors="pt").to(device).pixel_values.squeeze(0)
  del processor
  return processed


class Align(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.align = AlignModel.from_pretrained("kakaobrain/align-base")

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.align(x)

  def encode_image(self, img: torch.Tensor) -> torch.Tensor:
    return self.align.get_image_features(img)

  def encode_text(self, text: BatchEncoding) -> torch.Tensor:
    return self.align.get_text_features(**text)


def align_preprocessor_with_memory_fix(img) -> torch.Tensor:
  processor = AlignProcessor.from_pretrained("kakaobrain/align-base")
  with torch.no_grad():
    processed = processor(images=img, return_tensors="pt").to(device).pixel_values.squeeze(0)
  del processor
  return processed


if MODEL == 'CLIP':
  model, preprocess = clip.load("ViT-B/32", device=device)
  tokenize = clip.tokenize
elif MODEL == 'AlexNet':
  model = AlexNetEmbedder().to(device)
  preprocess = transforms.Compose([
      transforms.Resize(256),
      transforms.CenterCrop(224),
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ])
elif MODEL[:3] == 'ViT':
  model = VisionTransformer(MODEL[4:] == 'pooling').to(device)
  preprocess = vit_preprocessor_with_memory_fix
elif MODEL == 'ALIGN':
  model = Align().to(device)
  preprocess = align_preprocessor_with_memory_fix
  tokenizer = AutoTokenizer.from_pretrained("kakaobrain/align-base")
  tokenize = lambda s: tokenizer([s], padding=True, return_tensors="pt")
else:
  raise ValueError(f'Invalid model {MODEL} selected!')

model.eval()
print(model)

The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


config.json:   0%|          | 0.00/5.25k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/690M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/399 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Align(
  (align): AlignModel(
    (text_model): AlignTextModel(
      (embeddings): AlignTextEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): AlignTextEncoder(
        (layer): ModuleList(
          (0-11): 12 x AlignTextLayer(
            (attention): AlignTextAttention(
              (self): AlignTextSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): AlignTextSelfOutput(
                (dense): Linear(in_features=768, out_

# Mounting storage

In [18]:
drive.mount('/content/drive')
!ls "{datasets_path}"

Mounted at /content/drive
AlexNet_cats-vs-dogs-large.pkl	     imagenet_one-of-each-class-except-cats-and-dogs
AlexNet_imagenet-subset.pkl	     indizes_clean_Anni.txt
AlexNet_traffic-signs.pkl	     jellyfish-classification
ALIGN_cats-vs-dogs-large.pkl	     mislabeled_instances_cats-vs-dogs_CATS_ONLY.pkl
ALIGN_imagenet-subset.pkl	     mislabeled_instances_cats-vs-dogs.pkl
align_text_dog_embeddings.pkl	     note.txt
align_text_random_embeddings.pkl     text_dog_embeddings_more.pkl
ALIGN_traffic-signs.pkl		     text_dog_embeddings.pkl
cats-dogs-big_ids.pkl		     text_random_embeddings_more.pkl
cats-dogs-big.pkl		     text_random_embeddings.pkl
cats-vs-dogs-large		     traffic-signs
CLIP_cats-vs-dogs-large.pkl	     train-small
CLIP_imagenet-subset.pkl	     ViT-CLS_cats-vs-dogs-large.pkl
CLIP_traffic-signs.pkl		     ViT-CLS_imagenet-subset.pkl
dog_wrong_2_12.txt		     ViT-CLS_traffic-signs.pkl
dog_wrong.txt			     ViT-pooling_cats-vs-dogs-large.pkl
image_embeddings__cats-vs-dogs.pkl   ViT-po

# Define dataset loader

In [19]:
def load_dataset(folder_path, labels):

    # Checking if the provided folder path exists
    if not os.path.exists(folder_path):
        raise ValueError("Folder path does not exist.")

    images = {}
    for label in labels:
      images[label] = []

    # Looping through all files in the folder
    for i, filename in enumerate(glob.glob(folder_path + '**/*', recursive=True)):

      if '/'.join(filename.split('/')[-2:]) not in MISLABELED_INSTANCES:
        continue

      if i % 1000 == 0:
        print(i, 'files loaded')

      try:
        img = Image.open(filename).convert('RGB')
      except:
        continue

      label_found = False
      for label in labels:
        if label in '/'.join(filename.split('/')[-2:]):
          if label_found:
            raise ValueError(f"Label of {filename} is ambiguous.")
          label_found = True
          images[label].append(img)

      if not label_found:
        raise ValueError(f"No label for {filename} found.")

    print(i+1, 'files loaded')

    return images

# Load and embed dataset

In [20]:
# import cached embedded images
with open(datasets_path + USE_CACHED_EMBEDDINGS, 'rb') as f:
  image_embeddings = pickle.load(f)
for label in image_embeddings:
  print(f'{len(image_embeddings[label])} images loaded for class {label}')

# filter with imported cached pollution indices
with open(datasets_path + MISLABELED_INSTANCES_LIST, 'rb') as f:
  mislabeled_indices = pickle.load(f)
for label in mislabeled_indices:
  if len(mislabeled_indices[label]) == 0 or sum(mislabeled_indices[label]) == 0:
    continue
  image_embeddings[label + '_pollution'] = image_embeddings[label][mislabeled_indices[label]]
  print(f"{len(image_embeddings[label + '_pollution'])} images removed for class {label}")

# re-embedd polluted images based on file names
images = load_dataset(dataset_path, LABELS)
image_embeddings_pollution = {}
for label in LABELS:
  if len(images[label]) == 0:
    continue
  with torch.no_grad():
    processed_images = torch.cat(([preprocess(img).unsqueeze(0) for img in images[label]])).to(device)
    image_embeddings_pollution[label] = model.encode_image(processed_images)
for label in image_embeddings_pollution:
  print(f'{len(image_embeddings_pollution[label])} images loaded for class {label}')

12502 images loaded for class cat
12499 images loaded for class dog
25 images removed for class cat
24 images removed for class dog
25004 files loaded


preprocessor_config.json:   0%|          | 0.00/508 [00:00<?, ?B/s]

25 images loaded for class cat
24 images loaded for class dog


In [21]:
# compare both tensors
for label in mislabeled_indices:
  if len(mislabeled_indices[label]) == 0 or sum(mislabeled_indices[label]) == 0:
    continue
  print(f"class {label}: {image_embeddings[label + '_pollution'].shape} for imported pollution, {image_embeddings_pollution[label].shape} for newly embedded pollution")
  tensors_equal = torch.equal(image_embeddings[label + '_pollution'], image_embeddings_pollution[label])
  print('tensors equal:', tensors_equal)

  if not tensors_equal:

    missing = 0
    for i, t in enumerate(image_embeddings[label + '_pollution']):
      if t not in image_embeddings_pollution[label]:
        missing += 1
        print(f"Embedding at index {i} of imported tensor not in embedded tensor!")
    print('Every embedding of imported tensor is also in embedded tensor')

    missing = 0
    for i, t in enumerate(image_embeddings_pollution[label]):
      if t not in image_embeddings[label + '_pollution']:
        missing += 1
        print(f"Embedding at index {i} of embedded tensor not in imported tensor!")
    print('Every embedding of embedded tensor is also in imported tensor')

class cat: torch.Size([25, 640]) for imported pollution, torch.Size([25, 640]) for newly embedded pollution
tensors equal: False
Every embedding of imported tensor is also in embedded tensor
Every embedding of embedded tensor is also in imported tensor
class dog: torch.Size([24, 640]) for imported pollution, torch.Size([24, 640]) for newly embedded pollution
tensors equal: False
Every embedding of imported tensor is also in embedded tensor
Every embedding of embedded tensor is also in imported tensor
