# Choose settings

##### Choose your settings here

In [1]:
# choose dataset
DATASET_NAME = 'traffic-signs'  # needs to match folder name in FM/datasets
LOAD_AND_EMBED_DATASET_IN_BATCHES = True  # True for large datasets, False for small ones
BATCH_SIZE = 256
USE_CACHED_EMBEDDINGS = 'CREATE__ALIGN_traffic-signs.pkl'  # '' for loading the dataset normally, 'CREATE__{x}.pkl' for creating the cache file {x}.pkl, '{x}.pkl' for loading the cache file {x}.pkl

# choose which model to use to generate the embeddings
MODEL = 'ALIGN'  # Can be 'CLIP', 'AlexNet', 'ViT-pooling' or 'ViT-CLS'

# settings for pickling instances
START_AT_INDEX = 12289
BUFFER_EACH = 2048

assert BUFFER_EACH % BATCH_SIZE == 0

##### This part is calculated automatically

In [2]:
datasets_path = '/content/drive/My Drive/FM/datasets/'
dataset_path = datasets_path + DATASET_NAME + '/'

if DATASET_NAME == 'cats-vs-dogs-large' or DATASET_NAME == 'train-small':
  LABELS = ['cat', 'dog']
elif DATASET_NAME == 'jellyfish-classification':
  LABELS = ['barrel jellyfish', 'compass jellyfish', 'lions mane jellyfish', 'moon jellyfish']
elif DATASET_NAME == 'traffic-signs':
  LABELS = ['30 kilometers per hour speed limit traffic sign', '80 kilometers per hour speed limit traffic sign', '100 kilometers per hour speed limit traffic sign', 'give way traffic sign', 'no entry traffic sign', 'no overtaking traffic sign', 'priority over oncoming traffic sign', 'stop sign']
else:
  raise ValueError('Invalid dataset selected or labels not set!')

assert MODEL in ['CLIP', 'AlexNet', 'ViT-pooling', 'ViT-CLS', 'ALIGN'], f'Selected model {MODEL} not implemented!'

# Load libraries

In [3]:
! pip install ftfy regex tqdm
! pip install git+https://github.com/openai/CLIP.git

Collecting ftfy
  Downloading ftfy-6.1.3-py3-none-any.whl (53 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m53.4/53.4 kB[0m [31m701.5 kB/s[0m eta [36m0:00:00[0m
Installing collected packages: ftfy
Successfully installed ftfy-6.1.3
Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-m2jxew1l
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-m2jxew1l
  Resolved https://github.com/openai/CLIP.git to commit a1d071733d7111c9c014f024669f959182114e33
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: clip
  Building wheel for clip (setup.py) ... [?25l[?25hdone
  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369497 sha256=1efd307eeda1bff7a33cf22c13097297d91204697e76006c1191d110c97fdfc4
  Stored in directory: /tmp/pip-ephem-wheel-cache-v0gayifz/wheels/da/2b/4c/d6691fa9597aac8bb85d2ac1

In [4]:
import torch
from torchvision import transforms
import clip
from transformers import AutoImageProcessor, ViTModel, AlignProcessor, AlignModel, AutoTokenizer
from transformers.tokenization_utils_base import BatchEncoding
import numpy as np
from matplotlib import pyplot as plt
from PIL import Image
from pkg_resources import packaging
import os
from google.colab import drive
from sklearn.metrics.pairwise import cosine_similarity
from sklearn.metrics import f1_score, precision_score, recall_score, accuracy_score, confusion_matrix
import glob
import pickle
from scipy.spatial.distance import cosine

print("Torch version:", torch.__version__)

Torch version: 2.1.0+cu121


# Load model

In [5]:
device = "cuda" if torch.cuda.is_available() else "cpu"

class AlexNetEmbedder(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.alexnet = torch.hub.load('pytorch/vision:v0.10.0', 'alexnet', pretrained=True)

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    x = self.alexnet.features(x)
    x = self.alexnet.avgpool(x)
    return self.alexnet.classifier[:5](torch.flatten(x, 1))

  def encode_image(self, img: torch.Tensor) -> torch.Tensor:
    return self(img)


class VisionTransformer(torch.nn.Module):

  def __init__(self, use_pooler_output_instead_of_last_hidden_state=False):
    super().__init__()
    self.vit = ViTModel.from_pretrained("google/vit-base-patch16-224-in21k")
    self.use_pooler_output = use_pooler_output_instead_of_last_hidden_state

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    out = self.vit(x, return_dict=True)
    if self.use_pooler_output:
      return out.pooler_output
    else:
      return out.last_hidden_state[:, 0, :]

  def encode_image(self, img: torch.Tensor) -> torch.Tensor:
    return self(img)


def vit_preprocessor_with_memory_fix(img) -> torch.Tensor:
  processor = AutoImageProcessor.from_pretrained("google/vit-base-patch16-224-in21k")
  with torch.no_grad():
    processed = processor(img, return_tensors="pt").to(device).pixel_values.squeeze(0)
  del processor
  return processed


class Align(torch.nn.Module):

  def __init__(self):
    super().__init__()
    self.align = AlignModel.from_pretrained("kakaobrain/align-base")

  def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.align(x)

  def encode_image(self, img: torch.Tensor) -> torch.Tensor:
    return self.align.get_image_features(img)

  def encode_text(self, text: BatchEncoding) -> torch.Tensor:
    return self.align.get_text_features(**text)


def align_preprocessor_with_memory_fix(img) -> torch.Tensor:
  processor = AlignProcessor.from_pretrained("kakaobrain/align-base")
  with torch.no_grad():
    processed = processor(images=img, return_tensors="pt").to(device).pixel_values.squeeze(0)
  del processor
  return processed


if MODEL == 'CLIP':
  model, preprocess = clip.load("ViT-B/32", device=device)
  tokenize = clip.tokenize
elif MODEL == 'AlexNet':
  model = AlexNetEmbedder().to(device)
  preprocess = transforms.Compose([
      transforms.Resize(256),
      transforms.CenterCrop(224),
      transforms.ToTensor(),
      transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
  ])
elif MODEL[:3] == 'ViT':
  model = VisionTransformer(MODEL[4:] == 'pooling').to(device)
  preprocess = vit_preprocessor_with_memory_fix
elif MODEL == 'ALIGN':
  model = Align().to(device)
  preprocess = align_preprocessor_with_memory_fix
  tokenizer = AutoTokenizer.from_pretrained("kakaobrain/align-base")
  tokenize = lambda s: tokenizer([s], padding=True, return_tensors="pt")
else:
  raise ValueError(f'Invalid model {MODEL} selected!')

model.eval()
print(model)

config.json:   0%|          | 0.00/5.25k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/690M [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/399 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Align(
  (align): AlignModel(
    (text_model): AlignTextModel(
      (embeddings): AlignTextEmbeddings(
        (word_embeddings): Embedding(30522, 768, padding_idx=0)
        (position_embeddings): Embedding(512, 768)
        (token_type_embeddings): Embedding(2, 768)
        (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
        (dropout): Dropout(p=0.1, inplace=False)
      )
      (encoder): AlignTextEncoder(
        (layer): ModuleList(
          (0-11): 12 x AlignTextLayer(
            (attention): AlignTextAttention(
              (self): AlignTextSelfAttention(
                (query): Linear(in_features=768, out_features=768, bias=True)
                (key): Linear(in_features=768, out_features=768, bias=True)
                (value): Linear(in_features=768, out_features=768, bias=True)
                (dropout): Dropout(p=0.1, inplace=False)
              )
              (output): AlignTextSelfOutput(
                (dense): Linear(in_features=768, out_

# Mounting storage

In [6]:
drive.mount('/content/drive')
!ls "{datasets_path}"

Mounted at /content/drive
AlexNet_cats-vs-dogs-large.pkl		   imagenet_one-of-each-class-except-cats-and-dogs
AlexNet_imagenet-subset.pkl		   indizes_clean_Anni.txt
AlexNet_traffic-signs.pkl		   jellyfish-classification
ALIGN_cats-vs-dogs-large.pkl		   mislabeled_instances_cats-vs-dogs.pkl
ALIGN_imagenet-subset.pkl		   note.txt
buffer_ALIGN_traffic-signs_upto-12288.pkl  text_dog_embeddings.pkl
buffer_ALIGN_traffic-signs_upto-8192.pkl   text_random_embeddings.pkl
cats-dogs-big_ids.pkl			   traffic-signs
cats-dogs-big.pkl			   train-small
cats-vs-dogs-large			   ViT-CLS_cats-vs-dogs-large.pkl
CLIP_cats-vs-dogs-large.pkl		   ViT-CLS_cats-vs-dogs-large-SUBSET.pkl
CLIP_imagenet-subset.pkl		   ViT-CLS_imagenet-subset.pkl
CLIP_traffic-signs.pkl			   ViT-CLS_traffic-signs.pkl
dog_wrong_2_12.txt			   ViT-pooling_cats-vs-dogs-large.pkl
dog_wrong.txt				   ViT-pooling_imagenet-subset.pkl
image_embeddings__cats-vs-dogs.pkl	   ViT-pooling_traffic-signs.pkl
image_embeddings__traffic-signs.pkl


# Define dataset loader

In [9]:
def load_dataset(folder_path, labels):

    # Checking if the provided folder path exists
    if not os.path.exists(folder_path):
        raise ValueError("Folder path does not exist.")

    images = {}
    for label in labels:
      images[label] = []

    # Looping through all files in the folder
    for i, filename in enumerate(glob.glob(folder_path + '**/*', recursive=True)):

      if i % 1000 == 0:
        print(i, 'files loaded')

      try:
        img = Image.open(filename).convert('RGB')
      except:
        continue

      label_found = False
      for label in labels:
        if label in '/'.join(filename.split('/')[-2:]):
          if label_found:
            raise ValueError(f"Label of {filename} is ambiguous.")
          label_found = True
          images[label].append(img)

      if not label_found:
        raise ValueError(f"No label for {filename} found.")

    print(i+1, 'files loaded')

    return images

def get_embeddings_dict_batchwise(folder_path, labels, model, preprocess, batch_size=64):

  # Checking if the provided folder path exists
  if not os.path.exists(folder_path):
    raise ValueError("Folder path does not exist.")

  image_embeddings = {}
  for label in labels:
    image_embeddings[label] = []

  images = {}
  for label in labels:
    images[label] = []

  # Looping through all files in the folder
  all_files = glob.glob(folder_path + '**/*', recursive=True)
  for n_instances_processed, filename in enumerate(all_files):

    if n_instances_processed <= START_AT_INDEX:
      continue

    try:
      img = Image.open(filename).convert('RGB')
    except:
      continue

    # Find label of the image
    label_found = False
    for label in labels:
      if label in '/'.join(filename.split('/')[-2:]):
        if label_found:
          raise ValueError(f"Label of {filename} is ambiguous.")
        label_found = True
        images[label].append(img)
    if not label_found:
      raise ValueError(f"No label for {filename} found.")

    # Get embeddings if already a batch is full
    if n_instances_processed % batch_size == 0 and n_instances_processed > 0 or n_instances_processed == len(all_files) - 1:
      for label in labels:
        if len(images[label]) == 0:
          continue
        with torch.no_grad():
          processed_images = torch.cat(([preprocess(img).unsqueeze(0) for img in images[label]]))
          image_embeddings[label].append(model.encode_image(processed_images.to(device)).cpu().type(torch.float))
          del processed_images
      images = {}
      for label in labels:
        images[label] = []
      print(n_instances_processed, 'loaded and encoded')

      if n_instances_processed > 0 and n_instances_processed % BUFFER_EACH == 0 or n_instances_processed == len(all_files) - 1:
        pickle_file = datasets_path + f'buffer_{MODEL}_{DATASET_NAME}_upto-{n_instances_processed}.pkl'
        with open(pickle_file, 'wb') as f:
          pickle.dump(image_embeddings, f)
        print(f'First {n_instances_processed} Embeddings buffered in', pickle_file)

  # Convert list of embeddings to tensor
  for label in labels:
    image_embeddings[label] = torch.cat((image_embeddings[label]))

  return image_embeddings

# Load and embed dataset

In [10]:
if USE_CACHED_EMBEDDINGS != '' and USE_CACHED_EMBEDDINGS.split('__')[0] != 'CREATE':

  # load embeddings of previous execution from pickle file
  pickle_file = datasets_path + USE_CACHED_EMBEDDINGS
  with open(pickle_file, 'rb') as f:
    image_embeddings = pickle.load(f)

  print('Embeddings loaded from', pickle_file)

else:

  if LOAD_AND_EMBED_DATASET_IN_BATCHES:

    # load and embed images in batches (to save GPU memory and especially RAM)
    image_embeddings = get_embeddings_dict_batchwise(dataset_path, LABELS, model, preprocess, batch_size=BATCH_SIZE)

  else:

    # load images
    images = load_dataset(dataset_path, LABELS)

    # embed images and text
    image_embeddings = {}
    for label in LABELS:
      with torch.no_grad():
        processed_images = torch.cat(([preprocess(img).unsqueeze(0) for img in images[label]])).to(device)
        image_embeddings[label] = model.encode_image(processed_images)
        del processed_images

  # move embeddings to cpu and convert to suitable datatype for further analysis
  for key in image_embeddings:
    image_embeddings[key] = image_embeddings[key].cpu().type(torch.float)

  # save embeddings in pickle file if desired (enabled to reload them later on)
  if USE_CACHED_EMBEDDINGS != '' and USE_CACHED_EMBEDDINGS.split('__')[0] == 'CREATE':
    pickle_filename = '__'.join(USE_CACHED_EMBEDDINGS.split('__')[1:])  # remove prefix 'CREATE__'
    pickle_file = datasets_path + pickle_filename
    with open(pickle_file, 'wb') as f:
      pickle.dump(image_embeddings, f)
    print('Embeddings stored in', pickle_file)

12544 loaded and encoded
12800 loaded and encoded
13056 loaded and encoded
13312 loaded and encoded
13417 loaded and encoded
First 13417 Embeddings buffered in /content/drive/My Drive/FM/datasets/buffer_ALIGN_traffic-signs_upto-13417.pkl


RuntimeError: ignored