# CLIP: Connecting text and images


"CLIP" or "Contrastive Language-Image Pretraining" is a powerful model developed by OpenAI that bridges the gap between natural language and image understanding. It can take both images and text as inputs and relate them in meaningful ways, allowing it to perform a variety of tasks such as zero-shot image classification, image search, and more.

The model is trained by learning a joint embedding space where images and their corresponding text descriptions are closely aligned. In this notebook, we will explore how CLIP works, how to use it for various tasks, and how to implement it.

## Why CLIP
Traditional image classification models are typically limited to the categories they were trained on. CLIP, on the other hand, can recognize a wide variety of objects and concepts in images without being explicitly trained on specific tasks. This capability is achieved by learning from a massive dataset of image-text pairs gathered from the internet. As a result, CLIP can generalize to many tasks without needing further fine-tuning.

Some key use cases for CLIP include:

- Zero-shot classification: Classify images based on new categories without additional training.
- Image search: Find images related to specific text descriptions.
- Text-to-image mapping: Generate embeddings for both images and text, enabling cross-modal understanding.

## Basics

The raw product of CLIP is a shared representation (embedding) of two modalities, text and images, learned by training on a large dataset of image-text pairs.

1. Input Data: CLIP is trained on a large set of image-text pairs. Each image is accompanied by a textual description (e.g., a picture of a dog and the text "a dog sitting in a park").

2. Dual Encoder Architecture:
  - Image Encoder: CLIP uses a Vision Transformer (ViT) or a ResNet to process the images and generate an embedding vector for each image.

  - Text Encoder: A Transformer model is used to process the text descriptions and generate an embedding vector for each description.

  - Illustration of the dual encoders' output:
![dual encoders output](https://images.ctfassets.net/kftzwdyauwt9/fbc4f633-9ad4-4dc2-3809c22df5e0/0bd2d5abf90d052731538613e4a42668/overview-a.svg)

3. Contrastive Loss: The optimization objective in CLIP is contrastive learning. After both the image and the text are passed through their respective encoders to produce embeddings, CLIP uses a contrastive loss that encourages the image and its matching text to have similar embeddings, while mismatched pairs (e.g., a dog image and "a cat sitting on a tree") are distinguished in the embedding space.
  
  This similarity is usually formalised as a distance in the embedding space that is closer to zero when embedded elements are more similar.

4. Joint Embedding Space: After training, CLIP has learned a joint embedding space where related images and text are close together, and unrelated ones are far apart. This allows CLIP to perform tasks such as:

  - Zero-Shot Classification: Given a new category (e.g., "a cat"), CLIP can classify images by computing the similarity between the image embeddings and the text embedding of the label.
  - Text-Image Similarity: CLIP can rank images by their similarity to a textual description or rank text by its similarity to an image.
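
To make "close together" concrete, here is a minimal sketch of cosine similarity (and the corresponding cosine distance) between image and text embeddings. The 3-dimensional vectors are made up for illustration and are not real CLIP outputs; the helper name `cosine_similarity_matrix` is also just an assumption for this example.

```python
import numpy as np

def cosine_similarity_matrix(a, b):
    """Cosine similarity between every row of a and every row of b."""
    a = a / np.linalg.norm(a, axis=1, keepdims=True)
    b = b / np.linalg.norm(b, axis=1, keepdims=True)
    return a @ b.T

# Hypothetical embeddings (real CLIP embeddings have hundreds of dimensions)
image_embeddings = np.array([
    [0.9, 0.1, 0.0],   # stands in for an image of a dog
    [0.1, 0.9, 0.1],   # stands in for an image of a cat
])
text_embeddings = np.array([
    [1.0, 0.2, 0.0],   # stands in for "a dog sitting in a park"
    [0.0, 1.0, 0.2],   # stands in for "a cat sitting on a tree"
])

similarity = cosine_similarity_matrix(image_embeddings, text_embeddings)
distance = 1.0 - similarity  # cosine distance: close to 0 for similar pairs

print(np.round(similarity, 3))
print(np.round(distance, 3))
```

Matching pairs sit on the diagonal of `similarity`, so they should have the largest values in their row and the smallest cosine distance.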


## Pretrained Model
Let's first take a look at the pretrained implementation of CLIP from OpenAI.

In [None]:
# Install CLIP library
!pip install git+https://github.com/openai/CLIP.git
!pip install torch torchvision

In [None]:
# Import necessary libraries
import torch
import clip
from PIL import Image
import requests
from io import BytesIO

In [None]:
# Load the model and the preprocessing function
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

Let's create a simple function that will accept an image URL and a list of text descriptions. The function will then calculate the similarity between the image and each text description.

In [None]:
# Function to process image and text, and compute similarity
def match_image_text(image_url, text_descriptions):
    # Load and preprocess the image
    response = requests.get(image_url)
    img = Image.open(BytesIO(response.content))
    image = preprocess(img).unsqueeze(0).to(device)

    # Tokenize and encode the text
    text = clip.tokenize(text_descriptions).to(device)

    # Run the image and the text through the model
    with torch.no_grad():
        image_features = model.encode_image(image)
        text_features = model.encode_text(text)

    # Compute similarity
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarities = (100.0 * image_features @ text_features.T).softmax(dim=-1)

    return similarities

In [None]:
# @title Interactive Section: Enter an image URL and text descriptions
# image_url = input("Enter an image URL: ")
# text_descriptions = input("Enter text descriptions (comma-separated): ").split(',')

# # Compute similarities
# similarities = match_image_text(image_url, text_descriptions)

# # Show results
# print(f"\nImage URL: {image_url}")
# print("Text Descriptions and Similarity Scores:")
# for i, desc in enumerate(text_descriptions):
#     print(f"Description: {desc.strip()} | Similarity: {similarities[0, i].item():.4f}")

This picture gives an idea of what happened in the above code:
![similarity scoring text based on image](https://images.ctfassets.net/kftzwdyauwt9/d9d46e4b-6d6a-4f9e-59a242ea1441/c7b386880f1af005fd02f159de7f4d00/overview-b.svg)

As you (hopefully) saw, the model didn't need to be trained on a dataset containing your text and image. Hence the term zero-shot prediction.

But how do `model.encode_image()` and `model.encode_text()` map images and text into the same embedding space? We will see that soon.


## Contrastive Loss

The contrastive loss encourages the model to bring the embeddings of matching image-text pairs "closer" together and push the embeddings of non-matching pairs "further apart".

To quantify "closer" and "further", we can use the **cosine similarity** between the image and text embeddings as a measure of how similar the original pair is.

The contrastive loss can be formalized using the softmax function applied over the cosine similarities between the image and text embeddings. Here’s the mathematical formula for the image-to-text direction of the contrastive loss in CLIP:

$$
L = - \frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(\text{sim}(x_i,y_i))}{\sum_{j=1}^{N} \exp(\text{sim}(x_i,y_j))}
$$
Where:
- $N$ is the number of image-text pairs in the batch.
- $x_i$ is the image embedding for the $i$-th image.
- $y_i$ is the text embedding for the corresponding $i$-th text.
- $\text{sim}(x_i,y_j)$ is the similarity (usually cosine similarity) between the image embedding $x_i$ and the text embedding $y_j$.

The loss function penalizes when the similarity between matching pairs is low or when mismatching pairs have a high similarity.
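
The paper's pseudocode (shown further down in this section) applies the loss in both directions and scales the similarities by a learned temperature. As a sketch, writing $\tau$ for the temperature (a symbol introduced here as an assumption; the pseudocode's `np.exp(t)` plays the role of $1/\tau$), the two directions and their average can be written as:

$$
L_{\text{image}} = - \frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(\text{sim}(x_i,y_i)/\tau)}{\sum_{j=1}^{N} \exp(\text{sim}(x_i,y_j)/\tau)},
\qquad
L_{\text{text}} = - \frac{1}{N} \sum_{i=1}^{N} \log \frac{\exp(\text{sim}(x_i,y_i)/\tau)}{\sum_{j=1}^{N} \exp(\text{sim}(x_j,y_i)/\tau)}
$$

$$
L_{\text{CLIP}} = \frac{L_{\text{image}} + L_{\text{text}}}{2}
$$

In $L_{\text{image}}$ each image has to pick its own text among all texts in the batch, and in $L_{\text{text}}$ each text has to pick its own image among all images.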

Having seen the contrastive loss function, we can now look at the training process on $N$ image-text pairs using the pseudocode below, taken from the paper:

```
# image_encoder - ResNet or Vision Transformer
# text_encoder - CBOW or Text Transformer
# I[n, h, w, c] - minibatch of aligned images
# T[n, l] - minibatch of aligned texts
# W_i[d_i, d_e] - learned proj of image to embed
# W_t[d_t, d_e] - learned proj of text to embed
# t - learned temperature parameter

# extract feature representations of each modality
I_f = image_encoder(I) #[n, d_i]
T_f = text_encoder(T) #[n, d_t]

# joint multimodal embedding [n, d_e]
I_e = l2_normalize(np.dot(I_f, W_i), axis=1)
T_e = l2_normalize(np.dot(T_f, W_t), axis=1)

# scaled pairwise cosine similarities [n, n]
logits = np.dot(I_e, T_e.T) * np.exp(t)

# symmetric loss function
labels = np.arange(n)
loss_i = cross_entropy_loss(logits, labels, axis=0)
loss_t = cross_entropy_loss(logits, labels, axis=1)
loss = (loss_i + loss_t)/2
```

The logits matrix is the table you saw in the first picture: the diagonal holds the scaled cosine similarities of matching pairs, and the off-diagonal elements hold the similarities of mismatched pairs.
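
The pseudocode above is not directly runnable, so here is a minimal NumPy sketch of the same computation. Random arrays stand in for the encoder outputs and projection matrices, and the helper names (`l2_normalize`, `cross_entropy`, `clip_contrastive_loss`) and the sizes are assumptions made for this example. Instead of an `axis` argument, the cross-entropy is applied row-wise to `logits` and to `logits.T`, which gives the two directions of the symmetric loss.

```python
import numpy as np

def l2_normalize(x, axis=1):
    return x / np.linalg.norm(x, axis=axis, keepdims=True)

def cross_entropy(logits, labels):
    """Mean cross-entropy of the row-wise softmax of logits against integer labels."""
    shifted = logits - logits.max(axis=1, keepdims=True)  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum(axis=1, keepdims=True))
    return -log_probs[np.arange(len(labels)), labels].mean()

def clip_contrastive_loss(I_f, T_f, W_i, W_t, t):
    # Project both modalities into the joint space and L2-normalize
    I_e = l2_normalize(I_f @ W_i)
    T_e = l2_normalize(T_f @ W_t)
    # Scaled pairwise cosine similarities, shape [n, n]
    logits = (I_e @ T_e.T) * np.exp(t)
    # The i-th image matches the i-th text
    labels = np.arange(logits.shape[0])
    loss_i = cross_entropy(logits, labels)    # each image picks its text
    loss_t = cross_entropy(logits.T, labels)  # each text picks its image
    return (loss_i + loss_t) / 2, logits

# Random stand-ins for encoder features and projection weights (illustrative sizes)
rng = np.random.default_rng(0)
n, d_i, d_t, d_e = 4, 8, 6, 5
I_f = rng.normal(size=(n, d_i))
T_f = rng.normal(size=(n, d_t))
W_i = rng.normal(size=(d_i, d_e))
W_t = rng.normal(size=(d_t, d_e))
t = np.log(1 / 0.07)  # the paper initializes the temperature to the equivalent of 0.07

loss, logits = clip_contrastive_loss(I_f, T_f, W_i, W_t, t)
print(logits.shape, float(loss))
```

In real training, `I_f` and `T_f` come from the encoders, and the loss is backpropagated through the encoders, the projections and `t`.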


## Visualise Embeddings

We’ll now use t-SNE to reduce the 512-dimensional embeddings from CLIP to 2D and 3D and visualize the relationship between images and their corresponding text descriptions. t-SNE helps in visualizing how similar or dissimilar image and text embeddings are in the shared embedding space.

Below is the visualisation code using t-SNE and Plotly. Try adjusting the perplexity parameter.

In [None]:
# !pip install transformers torchvision plotly scikit-learn

In [None]:
# import torch
# from transformers import CLIPProcessor, CLIPModel

# # Load the pretrained CLIP model and processor
# device = "cuda" if torch.cuda.is_available() else "cpu"
# clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
# clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

In [None]:
import plotly.express as px
import torch
from sklearn.manifold import TSNE
import requests
from PIL import Image
from io import BytesIO

# Function to fetch and preprocess images
def preprocess_images(image_urls):
    images = []
    for url in image_urls:
        response = requests.get(url)
        img = Image.open(BytesIO(response.content))
        image = preprocess(img).unsqueeze(0).to(device)
        images.append(image)
    return torch.cat(images)

# Function to extract embeddings for images and texts
def extract_embeddings(images, text_descriptions):
    # encode images
    with torch.no_grad():
        image_features = model.encode_image(images)

    # Tokenize and encode texts
    text = clip.tokenize(text_descriptions).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text)

    return image_features, text_features

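# Function to reduce embeddings to 2D or 3D using t-SNE
def reduce_with_tsne(embeddings, dim=2, perplexity=30):
    # t-SNE requires perplexity < number of samples; the iteration count is left at
    # scikit-learn's default because its keyword differs between versions
    tsne = TSNE(n_components=dim, perplexity=perplexity, learning_rate=200)
    # Convert to a float32 NumPy array (CLIP returns float16 tensors on GPU)
    return tsne.fit_transform(embeddings.cpu().float().numpy())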

# Function to visualize image and text embeddings using Plotly (2D)
def visualize_embeddings_plotly(images, text_descriptions, image_labels =None, dim = 2):
    image_features, text_features = extract_embeddings(images, text_descriptions)

    # Normalize embeddings
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # Combine and reduce to 2D or 3D using t-SNE
    all_embeddings = torch.cat([image_features, text_features], dim=0)
    reduced_embeddings = reduce_with_tsne(all_embeddings, dim = dim)

    # Create labels for plotting (image_labels, if given, index into text_descriptions)
    if image_labels is None:
      labels = ["Image " + str(i+1) for i in range(len(images))] + text_descriptions
    else:
      labels = ["Image " + text_descriptions[image_labels[i]] for i in range(len(image_labels))] + text_descriptions

    # Create a DataFrame for Plotly
    import pandas as pd
    if dim == 2:
      df = pd.DataFrame(reduced_embeddings, columns=["x", "y"])
    elif dim == 3:
      df = pd.DataFrame(reduced_embeddings, columns=["x", "y", "z"])
    else:
      raise ValueError("Invalid dimension. Must be 2 or 3.")

    df["label"] = labels
    df["type"] = ["Image"] * len(images) + ["Text"] * len(text_descriptions)

    # # Create an interactive 2D scatter plot using Plotly
    if dim == 2:
      fig = px.scatter(df, x="x", y="y", color="type", text="label", title="Interactive 2D t-SNE Visualization")
      fig.update_traces(textposition='top center')
    else:
      fig = px.scatter_3d(df, x="x", y="y", z="z", color="type", text="label", title="Interactive 3D t-SNE Visualization")
      fig.update_traces(marker=dict(size=5), textposition='top center')

    fig.show()

Visualise your own input images and text. For this, make sure to set the perplexity in `reduce_with_tsne` lower than the number of samples (images plus texts). Don't forget to increase it again for the next part 😀.

In [None]:
image_urls = [
    "https://t4.ftcdn.net/jpg/00/97/58/97/360_F_97589769_t45CqXyzjz0KXwoBZT9PRaWGHRk5hQqQ.jpg",   # Cat image
    "https://cdn.pixabay.com/photo/2023/08/18/15/02/dog-8198719_640.jpg", # Dog image
    "https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcRI2RLOBO8DYvk8aAUNEs6DJzCJzlgHT7HfAg&s" # Car image
]
text_descriptions = ["a cat", "a dog", "a car"]

images = preprocess_images(image_urls)

# Visualize embeddings interactively with Plotly
# visualize_embeddings_plotly(images, text_descriptions)

In [None]:
# @title CIFAR10 dataset
import torch
from torch.utils.data import DataLoader

from torchvision import transforms
from torchvision.datasets.cifar import CIFAR10

torch.manual_seed(0)

# Load the CIFAR-10 training dataset
dataset = CIFAR10(root='./data', train=True, download=True)


Visualise on CIFAR10

In [None]:
import random

classes = ['a photo of a plane', 'a photo of a car', 'a photo of a bird', 'a photo of a cat',
           'a photo of a deer', 'a photo of a dog', 'a photo of a frog', 'a photo of a horse', 'a photo of a ship', 'a photo of a truck']

random_indices = random.sample(range(len(dataset)), 200)
preprocessed_images = torch.stack([preprocess(dataset[i][0]).to(device) for i in random_indices])

image_labels = [dataset[i][1] for i in random_indices]  # Extract corresponding labels

# Use the 10 class prompts as the text descriptions. image_labels holds indices into
# this list, which is what visualize_embeddings_plotly expects when it labels the image points.
text_descriptions = classes

visualize_embeddings_plotly(preprocessed_images, text_descriptions, image_labels)

## Zero-shot Classification

Zero-shot classification is the ability to classify images into categories without having explicitly trained the model on those specific categories. CLIP enables this by understanding images and text in a shared embedding space. This means that once the model has learned general concepts through contrastive learning, it can generalize to entirely new categories just by providing textual labels.

This is one of CLIP's most remarkable capabilities—performing tasks without being explicitly trained for them.

In a traditional model, you would need to fine-tune the model for specific categories. In contrast, CLIP does this out-of-the-box. All you need to do is provide some candidate class names as text and let CLIP predict which one is the most similar to a given image.

- Image Representation: The input image is passed through the image encoder to get its embedding.

- Text Representation: Each of the class labels is passed through the text encoder to get their embeddings.

- Similarity Calculation: CLIP computes the cosine similarity between the image embedding and each text embedding.

- Prediction: The class label with the highest similarity score is chosen as the predicted label.
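
The four steps above can be sketched without loading the model by using made-up embeddings in place of the encoder outputs. This is only an illustration: the vectors and the helper names `softmax` and `zero_shot_predict` are assumptions, and the `scale` argument mirrors the factor `100.0` used in this notebook's code before the softmax.

```python
import numpy as np

def softmax(x):
    e = np.exp(x - x.max())  # subtract the max for numerical stability
    return e / e.sum()

def zero_shot_predict(image_embedding, text_embeddings, class_labels, scale=100.0):
    """Return the label whose text embedding is most similar to the image embedding."""
    image_embedding = image_embedding / np.linalg.norm(image_embedding)
    text_embeddings = text_embeddings / np.linalg.norm(text_embeddings, axis=1, keepdims=True)
    cosine = text_embeddings @ image_embedding   # one cosine similarity per label
    probs = softmax(scale * cosine)              # scaled similarities -> probabilities
    return class_labels[int(np.argmax(probs))], probs

class_labels = ["a dog", "a cat", "a car"]
# Hypothetical embeddings standing in for text-encoder and image-encoder outputs
text_embeddings = np.array([[1.0, 0.0, 0.0],
                            [0.0, 1.0, 0.0],
                            [0.0, 0.0, 1.0]])
image_embedding = np.array([0.2, 0.9, 0.1])

label, probs = zero_shot_predict(image_embedding, text_embeddings, class_labels)
label_unscaled, probs_unscaled = zero_shot_predict(
    image_embedding, text_embeddings, class_labels, scale=1.0
)

print(label, np.round(probs, 4))
print(label_unscaled, np.round(probs_unscaled, 4))
```

The predicted label is the same with and without scaling, since the softmax preserves the ordering. The scaling only sharpens the probabilities, because raw cosine similarities all lie in a narrow range between -1 and 1.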

In [None]:
# Zero-shot classification function with CLIP
def zero_shot_classification(image_url, class_labels):
    # Load and preprocess the image
    response = requests.get(image_url)
    img = Image.open(BytesIO(response.content))
    image = preprocess(img).unsqueeze(0).to(device)

    # Tokenize and encode the class labels (text descriptions)
    text = clip.tokenize(class_labels).to(device)

    # Compute image and text embeddings
    with torch.no_grad():
        image_features = model.encode_image(image)
        text_features = model.encode_text(text)

    # Compute cosine similarity between image and text embeddings
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)
    similarity_scores = (100.0 * image_features @ text_features.T).softmax(dim=-1)

    # Output the most likely class
    class_probabilities = similarity_scores[0]
    best_idx = class_probabilities.argmax().item()

    return class_labels[best_idx], class_probabilities[best_idx].item()

In [None]:
# # Try it out with a sample image and custom categories
# image_url = input("Enter the URL of an image: ")

# # Define class labels (text descriptions)
# class_labels = ["a dog", "a cat", "a car", "a person", "a bird"]

# # Perform zero-shot classification
# predicted_class, probability = zero_shot_classification(image_url, class_labels)

# # Display results
# print(f"\nPredicted Class: {predicted_class}")
# print(f"Confidence: {probability:.4f}")

In [None]:
# Example: Classifying an image of a sports event
# sports_categories = ["soccer", "basketball", "tennis", "swimming", "cycling"]

# # Use an image of a sports event
# image_url = "https://static.vecteezy.com/system/resources/thumbnails/027/829/023/small_2x/close-up-of-many-soccer-players-kicking-a-football-on-a-field-competition-scene-created-with-generative-ai-technology-free-photo.jpg"

# # Perform zero-shot classification
# predicted_class, probability = zero_shot_classification(image_url, sports_categories)

# # Display results
# print(f"\nPredicted Class: {predicted_class}")
# print(f"Confidence: {probability:.4f}")

# Some analysis reported by the paper

### Problems addressed by CLIP

CLIP addresses several key challenges in the traditional deep learning approach to computer vision:

Costly datasets: Traditional vision models require large, manually labeled datasets like ImageNet, which are expensive to create. CLIP avoids this by learning from publicly available text-image pairs, reducing the need for costly, labeled data.

Limited adaptability: Standard models like those trained on ImageNet are restricted to predefined tasks (e.g., 1000 categories). CLIP, however, can adapt to various visual tasks without additional training, simply by providing relevant text prompts for the task's concepts.

Poor real-world performance: Vision models often perform well on benchmarks but struggle in real-world applications due to overfitting to benchmark data. CLIP performs more robustly in real-world settings, as it doesn't require training on specific benchmark datasets, making its performance more generalizable. Testing has shown CLIP's performance remains consistent across multiple datasets, unlike models that "study" for benchmarks.

By not directly optimizing for the benchmark, CLIP becomes much more representative: CLIP closes this “robustness gap” by up to 75% while matching the performance of the original ResNet-50 on ImageNet zero-shot, without using any of the original 1.28M labeled examples.

![comparison with ResNet on ImageNet](https://blog.lancedb.com/content/images/2024/07/Untitled.png)

Although both models have the same accuracy on the ImageNet test set, CLIP’s performance is much more representative of how it will fare on datasets that measure accuracy in different, non-ImageNet settings. For instance, ObjectNet checks a model’s ability to recognize objects in many different poses and with many different backgrounds inside homes while ImageNet Rendition and ImageNet Sketch check a model’s ability to recognize more abstract depictions of objects.

### Key Takeaways:

1. **CLIP is highly efficient**:
   - CLIP learns from highly varied and noisy data and is intended to be used in a zero-shot manner. As with GPT-2 and GPT-3, models trained on such data need significant compute to achieve strong performance. To reduce compute costs, two key algorithmic choices were made:
     - **Contrastive objective**: Connecting text and images through a contrastive learning approach proved 4x to 10x more efficient than image-to-text methods.
     - **Vision Transformer (ViT)**: Adopting ViT resulted in a further 3x gain in efficiency over traditional ResNet models.
   - With these optimizations, the best CLIP model was trained on 256 GPUs for 2 weeks, comparable to other large-scale models.

2. **CLIP is flexible and general**:
   - CLIP learns a wide range of visual concepts from natural language, making it more flexible than models trained on datasets like ImageNet.
   - CLIP demonstrated strong **zero-shot performance** across 30+ datasets, covering tasks like fine-grained object classification, geo-localization, action recognition, and even OCR (optical character recognition), which traditional models struggle with.
   - In a linear probe evaluation, the best CLIP model outperformed the top ImageNet model (Noisy Student EfficientNet-L2) on 20 out of 26 transfer datasets. This underscores CLIP’s generalization capabilities across different tasks.

# Linear Probe
A linear probe is a method used in representation learning to evaluate the quality of learned representations. In this context, the linear probe involves freezing the pretrained model (such as CLIP) and training a simple linear classifier on top of its output features to solve a specific task (e.g., classification).

- Evaluation of Pretrained Representations: Linear probing helps assess how much information is retained in the representations learned by the CLIP model.
- Efficient Transfer Learning: Instead of fine-tuning the whole CLIP model on a new dataset, a linear classifier is trained on top of the frozen CLIP features, saving compute time and resources.
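- Generalization: If a linear probe performs well, it indicates that the frozen representations are rich enough to solve tasks beyond what the model was trained on. Unlike zero-shot classification, a linear probe does use labeled examples from the new task, but only to fit the linear layer.

Let's try this.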

In [None]:
!pip install transformers torchvision

In [None]:
import torch
import torchvision
import torchvision.transforms as T
from transformers import CLIPProcessor, CLIPModel

# Load the pretrained CLIP model and processor
device = "cuda" if torch.cuda.is_available() else "cpu"
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

In [None]:
# @title prepare dataset

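# Load the CIFAR-10 dataset
transform = T.Compose([
    T.Resize((224, 224)),  # Resize to CLIP's input size
    T.ToTensor(),
    # Normalize with the mean/std used by CLIP's own preprocessing
    T.Normalize((0.48145466, 0.4578275, 0.40821073), (0.26862954, 0.26130258, 0.27577711)),
])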

train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
subset_indices = torch.randperm(len(train_dataset))[:500]  # Take only 500 samples for faster processing
subset_dataset = torch.utils.data.Subset(train_dataset, subset_indices)
train_loader = torch.utils.data.DataLoader(subset_dataset, batch_size=64, shuffle=True)

test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=transform, download=True)
subset_indices_test = torch.randperm(len(test_dataset))[:200]  # Take only 200 samples for faster processing
subset_dataset_test = torch.utils.data.Subset(test_dataset, subset_indices_test)
test_loader = torch.utils.data.DataLoader(subset_dataset_test, batch_size=64, shuffle=False)

# train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=64, shuffle=True)
# test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=64, shuffle=False)

Now we extract image features using CLIP's visual encoder. These features will serve as inputs for the linear classifier.

In [None]:
def extract_clip_features(loader, model):
    all_features = []
    all_labels = []

    with torch.no_grad():
        for images, labels in loader:
            images = images.to(device)
            # Process images to get CLIP features
            features = model.get_image_features(images)
            all_features.append(features.cpu())
            all_labels.append(labels)

            print(f"Extracted features for a batch of {len(labels)} images")

    all_features = torch.cat(all_features)
    all_labels = torch.cat(all_labels)

    return all_features, all_labels

# Extract features for training and test sets
train_features, train_labels = extract_clip_features(train_loader, clip_model)
test_features, test_labels = extract_clip_features(test_loader, clip_model)

In [None]:
# Define a simple linear classifier (logistic regression)
class LinearClassifier(torch.nn.Module):
    def __init__(self, input_dim, num_classes):
        super(LinearClassifier, self).__init__()
        self.fc = torch.nn.Linear(input_dim, num_classes)

    def forward(self, x):
        return self.fc(x)

# Initialize the classifier
input_dim = train_features.shape[1]  # CLIP embedding dimension
num_classes = 10  # CIFAR-10 has 10 classes
classifier = LinearClassifier(input_dim, num_classes).to(device)

# Define loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(classifier.parameters(), lr=0.001)

# Training loop
def train_classifier(train_features, train_labels, model, criterion, optimizer, epochs=10):
    for epoch in range(epochs):
        model.train()
        optimizer.zero_grad()

        outputs = model(train_features.to(device))
        loss = criterion(outputs, train_labels.to(device))
        loss.backward()
        optimizer.step()

        print(f'Epoch {epoch+1}/{epochs}, Loss: {loss.item():.4f}')

# Train the classifier
train_classifier(train_features, train_labels, classifier, criterion, optimizer)

In [None]:
# Evaluation
def evaluate_classifier(test_features, test_labels, model):
    model.eval()
    with torch.no_grad():
        outputs = model(test_features.to(device))
        _, predicted = torch.max(outputs, 1)
        accuracy = (predicted == test_labels.to(device)).float().mean().item()
    return accuracy

accuracy = evaluate_classifier(test_features, test_labels, classifier)
print(f'Test Accuracy: {accuracy * 100:.2f}%')

## visualising after linear probe

## Text-Image Similarity Search with CLIP

Text-image similarity search allows us to search for images that are most closely aligned with a given textual description, or vice versa, by comparing the embeddings in a shared space.

Let's create an example where, given a text description, CLIP retrieves the best matching image from a set of images. This will demonstrate CLIP's text-to-image retrieval capability.

In this section, we will:

1. Provide a list of image URLs.
2. Provide a text description as a query.
3. CLIP will find the image that best matches the given text description based on cosine similarity.
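
Retrieval often needs a ranked list rather than a single best match. The sketch below ranks images by cosine similarity to a query and returns the top results. It uses made-up embeddings in place of CLIP outputs, and the function name `rank_images` and the file names are hypothetical.

```python
import numpy as np

def rank_images(text_embedding, image_embeddings, image_ids, top_k=2):
    """Return the top_k (image_id, cosine similarity) pairs for a text query."""
    text_embedding = text_embedding / np.linalg.norm(text_embedding)
    image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)
    scores = image_embeddings @ text_embedding   # cosine similarity per image
    order = np.argsort(-scores)[:top_k]          # indices sorted by descending score
    return [(image_ids[i], float(scores[i])) for i in order]

image_ids = ["cat.jpg", "dog.jpg", "car.jpg"]    # hypothetical file names
# Hypothetical embeddings standing in for image-encoder outputs
image_embeddings = np.array([[0.9, 0.2, 0.1],
                             [0.2, 0.9, 0.1],
                             [0.1, 0.1, 0.9]])
# Stands in for the text embedding of "a picture of a dog"
query_embedding = np.array([0.1, 1.0, 0.0])

results = rank_images(query_embedding, image_embeddings, image_ids, top_k=2)
for image_id, score in results:
    print(f"{image_id}: {score:.3f}")
```

With real data, the image embeddings can be computed once and stored, so each new query only needs one pass through the text encoder and a matrix product.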


In [None]:
import requests
from PIL import Image
from io import BytesIO
import torch

# Function to fetch and preprocess images
def preprocess_images(image_urls):
    images = []
    for url in image_urls:
        response = requests.get(url)
        img = Image.open(BytesIO(response.content))
        image = preprocess(img).unsqueeze(0).to(device)
        images.append(image)
    return torch.cat(images)

# Function for text-to-image search
def text_to_image_search(text_query, image_urls):
    # Preprocess the images
    images = preprocess_images(image_urls)

    # Encode the images
    with torch.no_grad():
        image_features = model.encode_image(images)

    # Tokenize and encode the text query
    text = clip.tokenize([text_query]).to(device)
    with torch.no_grad():
        text_features = model.encode_text(text)

    # Normalize the embeddings
    image_features /= image_features.norm(dim=-1, keepdim=True)
    text_features /= text_features.norm(dim=-1, keepdim=True)

    # Compute similarity between text and each image
    similarities = (100.0 * image_features @ text_features.T).squeeze(1)

    # Find the best matching image
    best_match_idx = similarities.argmax().item()
    best_similarity_score = similarities[best_match_idx].item()

    return image_urls[best_match_idx], best_similarity_score


In [None]:
# Example: Set of image URLs to search from
image_urls = [
    "https://upload.wikimedia.org/wikipedia/commons/3/3a/Cat03.jpg",    # Cat image
    "https://upload.wikimedia.org/wikipedia/commons/6/62/Dog_face.png",  # Dog image
    "https://upload.wikimedia.org/wikipedia/commons/9/9a/Car-Toyota.jpg",# Car image
]

# Example text query
text_query = "a picture of a dog"

# Perform text-to-image search
best_image_url, similarity_score = text_to_image_search(text_query, image_urls)

# Show the result
print(f"Best Matching Image URL: {best_image_url}")
print(f"Similarity Score: {similarity_score:.4f}")

## Fine Tuning

While CLIP is very effective out-of-the-box, you may want to fine-tune it on a specific dataset for specialized tasks. Fine-tuning CLIP on a domain-specific dataset can help it learn the nuances of that domain, thereby improving performance on domain-specific tasks.

Fine-tuning involves continuing to train the model on a smaller, more specialized dataset (related to your domain) after its initial pretraining on a large general dataset. Fine-tuning can help CLIP:

- Focus on specific visual-text relationships relevant to a given domain.
- Adjust its embedding space for tasks with specific characteristics (e.g., medical images, satellite images).

Here’s a high-level outline of the steps you would take to fine-tune CLIP on a custom dataset:

- Dataset Preparation: You need a dataset of image-text pairs related to the task at hand. The dataset should have labeled images and corresponding textual descriptions.

- Modify CLIP for Fine-Tuning: During fine-tuning, it is often beneficial to freeze the early layers of the model that capture general features. We’ll only update the higher layers that are more specific to the task.

- Training Objective: The objective remains the same as in CLIP’s original training: a contrastive loss. Minimizing it increases the similarity between matched image-text pairs and decreases the similarity between mismatched pairs.

- Training Loop: The training loop involves calculating contrastive loss between image and text embeddings (like in the original CLIP training) and updating the model’s weights using a smaller learning rate.

In [None]:
import torch
import clip
from torch.utils.data import DataLoader, Dataset
from PIL import Image
import requests
from io import BytesIO
import torch.optim as optim
import torch.nn.functional as F

Prepare the dataset and dataloaders.

In [None]:
# Custom dataset class
class CustomImageTextDataset(Dataset):
    def __init__(self, image_urls, text_descriptions, preprocess):
        self.image_urls = image_urls
        self.text_descriptions = text_descriptions
        self.preprocess = preprocess

    def __len__(self):
        return len(self.image_urls)

    def __getitem__(self, idx):
        # Fetch image
        response = requests.get(self.image_urls[idx])
        img = Image.open(BytesIO(response.content))
        img = self.preprocess(img)

        # Fetch text
        text = clip.tokenize([self.text_descriptions[idx]])[0]

        return img, text

# Simulated image URLs and text descriptions
image_urls = [
    "https://upload.wikimedia.org/wikipedia/commons/3/3a/Cat03.jpg",
    "https://upload.wikimedia.org/wikipedia/commons/6/62/Dog_face.png",
    "https://upload.wikimedia.org/wikipedia/commons/9/9a/Car-Toyota.jpg"
]
text_descriptions = ["a cat", "a dog", "a car"]

# Initialize the dataset and dataloader
dataset = CustomImageTextDataset(image_urls, text_descriptions, preprocess)
dataloader = DataLoader(dataset, batch_size=2, shuffle=True)


In [None]:
# Load CLIP model and optimizer
model, preprocess = clip.load("ViT-B/32", device="cuda" if torch.cuda.is_available() else "cpu")
# clip.load returns float16 weights on GPU; train in float32 for stable optimizer updates
model = model.float()

# Freeze all parameters first
for param in model.parameters():
    param.requires_grad = False

# Unfreeze only the last transformer block of the image encoder
for param in model.visual.transformer.resblocks[-1].parameters():
    param.requires_grad = True

# Define optimizer (fine-tuning specific parameters)
optimizer = optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=1e-5)

# Fine-tuning loop
num_epochs = 5

def fine_tune_clip(dataloader, model, optimizer, num_epochs):
    model.train()  # Set model to training mode

    for epoch in range(num_epochs):
        running_loss = 0.0
        for images, texts in dataloader:
            images = images.to(device)
            texts = texts.to(device)

            # Zero the gradients
            optimizer.zero_grad()

            # Forward pass: get image and text features
            image_features = model.encode_image(images)
            text_features = model.encode_text(texts)

            # Normalize features
            image_features = F.normalize(image_features, p=2, dim=-1)
            text_features = F.normalize(text_features, p=2, dim=-1)

            # Compute scaled pairwise cosine similarities (temperature learned in pretraining)
            logit_scale = model.logit_scale.exp()
            logits_per_image = logit_scale * image_features @ text_features.T
            logits_per_text = logits_per_image.T

            # Targets are diagonal (correct image-text pairs)
            targets = torch.arange(len(images), device=images.device)

            # Compute contrastive loss in both directions
            loss_image = F.cross_entropy(logits_per_image, targets)
            loss_text = F.cross_entropy(logits_per_text, targets)
            loss = (loss_image + loss_text) / 2

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {running_loss/len(dataloader):.4f}")

In [None]:
# Perform fine-tuning
fine_tune_clip(dataloader, model, optimizer, num_epochs)

# References

- https://openai.com/index/clip/

- https://arxiv.org/abs/2103.00020

- ChatGPT helped as well ⚡