<a href="https://colab.research.google.com/github/TensorCruncher/animal-image-search/blob/main/embeddings.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Mount Drive, get image paths

In [None]:
import json
from pathlib import Path
from google.colab import drive
from google.colab import files

# Mount Google Drive at /content/drive so the image folder under
# /content/drive/MyDrive/animals (read in the next cell) is reachable.
drive.mount('/content/drive')

In [None]:
# Root image folder on the mounted Drive; images live one sub-folder deep
# (presumably one sub-folder per animal class — confirm against the dataset).
root_dir = Path("/content/drive/MyDrive/animals")

# Glob pattern for the images to index. pathlib matches case-sensitively on
# Linux, so e.g. `.JPG` / `.jpeg` files are NOT picked up by the default.
IMAGE_PATTERN = "*/*.jpg"

# Sort the Path objects so the order is deterministic across runs: the
# embedding rows computed later are produced in exactly this order.
image_paths = [str(p) for p in sorted(root_dir.glob(IMAGE_PATTERN))]

# Fail here with an actionable message rather than writing an empty
# image_paths.json and crashing later inside the embedding loop.
if not image_paths:
    raise FileNotFoundError(
        f"No images matching {IMAGE_PATTERN!r} under {root_dir}; "
        "check that Drive is mounted and the folder exists."
    )

In [None]:
# Spot-check the first ten collected paths (cell output is the slice).
image_paths[:10]

In [None]:
# Number of images that the embedding loop below will process.
len(image_paths)

In [None]:
# Persist the ordered path list. Entry i here corresponds to row i of the
# embedding matrix computed below, so the two files belong together.
Path("image_paths.json").write_text(json.dumps(image_paths))

In [None]:
# Fetch the path list written above out of the Colab runtime via `files.download`.
files.download("image_paths.json")

# Create CLIP image embeddings

In [None]:
!pip install open-clip-torch torchvision -q

In [None]:
import torch
import open_clip
import numpy as np

from PIL import Image
from tqdm import tqdm

In [None]:
# Use the GPU when the runtime exposes one through CUDA; otherwise run on CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
device

In [None]:
# Load the ViT-B-32 CLIP model with LAION-2B weights. The discarded middle
# element is open_clip's training-time transform (per its README);
# `preprocess` is the transform applied to every image in the embedding loop.
model, _, preprocess = open_clip.create_model_and_transforms(
    'ViT-B-32',
    pretrained='laion2b_s34b_b79k',
)

In [None]:
# Encode every image in `image_paths` with the CLIP image encoder, `batch_size`
# images at a time. Row i of the final matrix corresponds to image_paths[i].
# (`tqdm` is already imported in the imports cell above.)
batch_size = 64
image_embeddings_list = []

# Fail fast with a clear message instead of torch.cat's error on an empty list.
if not image_paths:
    raise ValueError("image_paths is empty: nothing to embed.")

model.to(device)
model.eval()  # eval mode for inference

with torch.no_grad():  # no autograd graph is needed for pure inference
    for i in tqdm(range(0, len(image_paths), batch_size)):
        batch_paths = image_paths[i:i + batch_size]

        batch_tensors = []
        for p in batch_paths:
            # `with` releases the file handle PIL opened, even if decoding fails.
            with Image.open(p) as img:
                batch_tensors.append(preprocess(img.convert("RGB")))

        # Equivalent to unsqueeze(0) on each tensor followed by cat along dim 0.
        image_input = torch.stack(batch_tensors, dim=0).to(device)
        batch_embeddings = model.encode_image(image_input)
        # Unit-normalise along the last dim so dot products are cosine similarities.
        batch_embeddings = batch_embeddings / batch_embeddings.norm(dim=-1, keepdim=True)

        image_embeddings_list.append(batch_embeddings.cpu())

image_embeddings = torch.cat(image_embeddings_list, dim=0)
image_embeddings_np = image_embeddings.numpy()


In [None]:
# Write the unit-normalised embedding matrix (row order == image_paths order)
# and pull it out of the Colab runtime alongside image_paths.json.
np.save(file="image_embeddings.npy", arr=image_embeddings_np)
files.download("image_embeddings.npy")