In [None]:
!pip install datasets transformers torch

# Load dataset and models

In [None]:
from datasets import load_dataset
from transformers import CLIPProcessor, CLIPModel
import torch

# Fall back to the CPU when no GPU is visible so the notebook still runs
# (slowly) on CUDA-less machines instead of crashing while loading the model.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Single source of truth for the checkpoint so the model and its processor
# can never drift apart.
CLIP_CHECKPOINT = "openai/clip-vit-base-patch32"

# Dataset is fetched from the Hugging Face Hub by repository id.
ds = load_dataset("IGNF/FLAIR_1_osm_clip")

model = CLIPModel.from_pretrained(CLIP_CHECKPOINT).to(device)
processor = CLIPProcessor.from_pretrained(CLIP_CHECKPOINT)

Resolving data files:   0%|          | 0/54 [00:00<?, ?it/s]

Resolving data files:   0%|          | 0/54 [00:00<?, ?it/s]

Loading dataset shards:   0%|          | 0/51 [00:00<?, ?it/s]

2024-09-10 22:18:43.104857: E external/local_xla/xla/stream_executor/cuda/cuda_fft.cc:485] Unable to register cuFFT factory: Attempting to register factory for plugin cuFFT when one has already been registered
2024-09-10 22:18:43.237175: E external/local_xla/xla/stream_executor/cuda/cuda_dnn.cc:8454] Unable to register cuDNN factory: Attempting to register factory for plugin cuDNN when one has already been registered
2024-09-10 22:18:43.274197: E external/local_xla/xla/stream_executor/cuda/cuda_blas.cc:1452] Unable to register cuBLAS factory: Attempting to register factory for plugin cuBLAS when one has already been registered
2024-09-10 22:18:43.482276: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


In [3]:
# Keep the model on the GPU when one is available; otherwise stay on the CPU
# rather than raising on a machine without CUDA. Moving a model to the device
# it already lives on is a no-op.
model = model.to("cuda" if torch.cuda.is_available() else "cpu")

# Compute the embeddings

In [45]:
def clip_embed(images=None, text=None):
    """Embed images and/or text with the module-level CLIP ``model``.

    Args:
        images: An image or list of images accepted by ``processor``. ``None``
            or an empty list/tuple means "no images".
        text: A string or list of strings. ``None`` or an empty
            string/list/tuple means "no text".

    Returns:
        The result of ``model.get_text_features`` when only ``text`` is given,
        the result of ``model.get_image_features`` when only ``images`` is
        given, and the full ``model(...)`` output when both are given.
        Tensors stay on the model's device.

    Raises:
        ValueError: If neither ``images`` nor ``text`` is provided.
    """
    def _missing(value):
        # Explicit emptiness test: plain truthiness (``not value``) raises for
        # arrays/tensors holding more than one element.
        return value is None or (
            isinstance(value, (str, list, tuple)) and len(value) == 0
        )

    no_images = _missing(images)
    no_text = _missing(text)
    if no_images and no_text:
        raise ValueError("Specify either text or image")

    # Pass None for the absent modality so the processor only emits the
    # tensors that the selected model method accepts. Inputs follow the
    # model's current device instead of assuming CUDA.
    inputs = processor(
        text=None if no_text else text,
        images=None if no_images else images,
        return_tensors="pt",
        padding=True,
    ).to(model.device)

    # Inference only: no autograd graph needed.
    with torch.no_grad():
        if no_images:
            return model.get_text_features(**inputs)
        if no_text:
            return model.get_image_features(**inputs)
        return model(**inputs)

def clip_embed_batch(images):
    """Return the CLIP image embeddings for a batch of images as a NumPy array.

    Args:
        images: List of images accepted by the module-level ``processor``.

    Returns:
        NumPy array holding the output of ``model.get_image_features`` for the
        whole batch (one row per input image), copied to host memory.
    """
    # The inputs must live on the same device as the model weights.
    inputs = processor(images=images, return_tensors="pt", padding=True).to(model.device)
    with torch.no_grad():
        outputs = model.get_image_features(**inputs)
    # Keep every row (not just the first) so callers get one embedding per
    # image, and move to the CPU first because .numpy() rejects CUDA tensors.
    return outputs.cpu().numpy()


def add_clip_embeddings_batch(batch):
    """Store the CLIP embeddings of ``batch['image']`` under ``batch['clip_embedding']``.

    The incoming ``batch`` mapping is updated in place and then returned.
    """
    batch['clip_embedding'] = clip_embed_batch(batch['image'])
    return batch

In [51]:
from tqdm.notebook import tqdm

# Per-image CLIP embeddings, keyed by the image's row index in ds['train'].
embeddings = {}

# Rough cost: about 20 minutes for ~60k images.
batch_size = 10
train_split_size = len(ds['train'])
for start in tqdm(range(0, train_split_size, batch_size)):
    images = ds['train'][start:start + batch_size]['image']
    image_features = clip_embed(images)
    # One entry per image, converted to a host-side NumPy array.
    embeddings.update(
        (start + offset, feature.cpu().numpy())
        for offset, feature in enumerate(image_features)
    )

  0%|          | 0/6172 [00:00<?, ?it/s]

In [None]:
# Sanity check: how many embeddings the loop above has collected.
len(embeddings)

61712

In [None]:
# Embeddings are stored under integer row indices (0 .. len-1), so rebuild
# them in row order. Indexing by position (rather than iterating the
# container) works whether `embeddings` is a dict or a list; reshape(-1)
# flattens a possible leading batch axis of size 1.
clip_column = [
    embeddings[row].reshape(-1).tolist() for row in range(len(embeddings))
]

# Dataset.add_column returns a NEW dataset instead of modifying in place, so
# the result must be stored back. The guard keeps the cell re-runnable, since
# adding a column name that already exists raises.
if "clip_embeddings" not in ds["train"].column_names:
    ds["train"] = ds["train"].add_column("clip_embeddings", clip_column)

# The train split now carries the extra 'clip_embeddings' column.
print(ds["train"])

KeyboardInterrupt: 

# Save embeddings locally

# Build the vector database 

In [None]:
!pip install chromadb 

In [None]:
import chromadb
# Client handle used by the cells below to create and query the collection.
# NOTE(review): with default settings this client is presumably in-memory
# only -- confirm, and consider chromadb.PersistentClient(path=...) if the
# index must survive a kernel restart.
chroma_client = chromadb.Client()

# Create ChromaDB

In [None]:
# NOTE(review): the collection is created with Chroma's default distance
# metric -- confirm that this is the intended metric for CLIP embeddings.
collection = chroma_client.get_or_create_collection(name="FLAIR_CLIP")

from math import ceil

batch_size = 1000
num_embeddings = len(embeddings)
num_batches = ceil(num_embeddings / batch_size)

# Insert the embeddings in fixed-size chunks.
for batch_idx in range(num_batches):
    start_idx = batch_idx * batch_size
    end_idx = min((batch_idx + 1) * batch_size, num_embeddings)
    batch_indices = range(start_idx, end_idx)

    # Look embeddings up by row index: `embeddings` is keyed by integer index,
    # so it cannot be sliced. Each entry is converted to a flat list of floats
    # (reshape(-1) also flattens a possible leading batch axis of size 1).
    batch_embeddings = [embeddings[i].reshape(-1).tolist() for i in batch_indices]
    # Chroma documents must be strings; the row index is stored as text.
    batch_documents = [str(i) for i in batch_indices]
    batch_ids = [f"image_{i}" for i in batch_indices]

    # Add the batch to the collection.
    collection.add(
        embeddings=batch_embeddings,
        documents=batch_documents,
        metadatas=None,  # no per-item metadata is stored
        ids=batch_ids
    )



In [None]:
# Free-text query to look up in the image collection.
search_string = "treehouse"

# Embed the query text with CLIP and keep the first row as a plain list.
text_features = clip_embed(images=None, text=search_string)
query_embedding = text_features[0].tolist()

# Ask the collection for the 5 stored entries closest to the query vector.
results = collection.query(query_embeddings=[query_embedding], n_results=5)
print(results)

{'ids': [['image_4442', 'image_4673', 'image_4525', 'image_4131', 'image_4134']], 'distances': [[147.169677734375, 147.169677734375, 147.169677734375, 147.169677734375, 147.169677734375]], 'metadatas': [[None, None, None, None, None]], 'embeddings': None, 'documents': [['4442', '4673', '4525', '4131', '4134']], 'uris': None, 'data': None, 'included': ['metadatas', 'documents', 'distances']}
