# Enhancing Semantic Image Understanding with Fine-Tuned CLIP

## Libraries

In [14]:
import json
from PIL import Image

import torch
import torch.nn as nn
from torch.utils.data import DataLoader

import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader, Dataset, Subset
from torchvision.datasets import CIFAR10
from PIL import Image
import random

from transformers import CLIPProcessor, CLIPModel
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

## Model

In [None]:
# Load the pretrained CLIP weights and the matching processor from the same
# Hugging Face checkpoint.
clip_checkpoint = "openai/clip-vit-base-patch32"
model = CLIPModel.from_pretrained(clip_checkpoint)
processor = CLIPProcessor.from_pretrained(clip_checkpoint)

## Dataset creation 

This dataset class facilitates the creation of pairs of images for training a model, particularly for tasks like contrastive learning or siamese networks, where pairs of similar and dissimilar samples are required. It's particularly useful for training models like CLIP, which benefit from learning to associate images based on their semantic content.

In [3]:
def resize_image(image, size):
    """Resize ``image`` to ``size`` using Lanczos resampling.

    ``image`` must expose a PIL-style ``resize`` method; ``size`` is forwarded
    to it unchanged and the value returned by ``resize`` is handed back.
    """
    resized = image.resize(size, Image.LANCZOS)
    return resized

class CIFAR10ClipDataset(Dataset):
    """Dataset of image pairs labelled by whether both images share a class.

    Each item is a tuple ``(img1, img2, same)`` where ``same`` is a scalar
    integer tensor: 1 when the two underlying labels are equal, 0 otherwise.

    Args:
        cifar_dataset: Indexable dataset yielding ``(image, label)`` tuples.
        num_samples: When given, the reported length of this dataset; both
            images of every pair are then drawn at random. When ``None``,
            the length matches ``cifar_dataset`` and the first image of the
            pair is the one at the requested index.
        transform: Optional callable applied to both images of a pair.
    """

    def __init__(self, cifar_dataset, num_samples=None, transform=None):
        self.cifar_dataset = cifar_dataset
        self.transform = transform
        self.num_samples = num_samples

    def __len__(self):
        # Compare against None explicitly so that num_samples=0 means an
        # empty dataset instead of silently falling back to the full length.
        if self.num_samples is not None:
            return self.num_samples
        return len(self.cifar_dataset)

    def __getitem__(self, idx):
        last_index = len(self.cifar_dataset) - 1

        # With a fixed sample budget the requested index is ignored and the
        # first image is random too; otherwise it is the image at ``idx``.
        if self.num_samples is not None:
            idx1 = random.randint(0, last_index)
        else:
            idx1 = idx
        idx2 = random.randint(0, last_index)

        img1, label1 = self.cifar_dataset[idx1]
        img2, label2 = self.cifar_dataset[idx2]

        # Apply transforms if provided
        if self.transform:
            img1 = self.transform(img1)
            img2 = self.transform(img2)

        # 1 -> pair from the same category, 0 -> pair from different categories
        same_category = 1 if label1 == label2 else 0
        return img1, img2, torch.tensor(same_category)

# Resize and normalize images.
# The per-channel mean/std are the OpenAI CLIP preprocessing statistics (the
# defaults of the Hugging Face CLIP image processor). Normalizing here keeps
# the fine-tuning inputs on the same scale as the processor output that the
# evaluation cells feed to the model.
CLIP_IMAGE_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_IMAGE_STD = (0.26862954, 0.26130258, 0.27577711)

transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Resize images to 224x224
    transforms.ToTensor(),           # Convert images to PyTorch tensors in [0, 1]
    transforms.Normalize(mean=CLIP_IMAGE_MEAN, std=CLIP_IMAGE_STD),
])

# Load the CIFAR-10 training split. No transform is attached here because the
# pair dataset applies `transform` itself to both images of each pair.
cifar_dataset = CIFAR10(root='./data', train=True, download=True, transform=None)




Files already downloaded and verified


## Finetuning

This loss function effectively encourages the model to learn embeddings such that similar pairs have high cosine similarity values (close to 1) and dissimilar pairs have low cosine similarity values (close to -1). It's suitable for tasks like contrastive learning, where the model learns to distinguish between pairs of samples based on their semantic similarity.

The loss is computed as follows:

- For similar pairs (where labels is 1), the loss encourages the similarity value to be close to 1. Hence, (1 - similarities) penalizes low similarity scores.
- For dissimilar pairs (where labels is 0), the loss encourages the similarity value to be close to -1. Here the penalty is the similarity itself (the `similarities` term), so high similarity scores are penalized.
- The loss is computed as the mean of these penalties across all pairs.

In [None]:
import torch.optim as optim
# Build the pair dataset (1000 pairs per epoch) and wrap it in a DataLoader
batch_size = 128
cifar_clip_dataset = CIFAR10ClipDataset(
    cifar_dataset,
    num_samples=1000,
    transform=transform,
)
dataloader = DataLoader(
    cifar_clip_dataset,
    batch_size=batch_size,
    shuffle=True,
)

class CustomLoss(nn.Module):
    """Pairwise cosine-similarity loss.

    For a pair labelled 1 the penalty is ``1 - cos``; for a pair labelled 0
    the penalty is ``cos`` itself. The result is the mean penalty over all
    pairs in the batch.
    """

    def __init__(self):
        super().__init__()

    def forward(self, emb1, emb2, labels):
        # Cosine similarity of each embedding pair
        cos = torch.nn.functional.cosine_similarity(emb1, emb2)

        # Split the penalty into its two label-dependent parts
        different_term = (1 - labels) * cos
        same_term = labels * (1 - cos)
        return torch.mean(different_term + same_term)

# Define the loss function
criterion = CustomLoss()

# Define the optimizer (every CLIP parameter is updated, with a very small
# learning rate)
optimizer = optim.Adam(model.parameters(), lr=1e-6)

# Move model to appropriate device
model.to(device)

# from_pretrained() hands the model back in evaluation mode, so switch to
# training mode explicitly before fine-tuning.
model.train()

epochs = 150

# Training loop
for epoch in range(epochs):  # You can adjust the number of epochs
    running_loss = 0.0

    for img1, img2, labels in dataloader:
        # Move images and pair labels to the training device
        img1, img2, labels = img1.to(device), img2.to(device), labels.to(device)

        # Generate image embeddings for both images of each pair
        emb1 = model.get_image_features(img1)
        emb2 = model.get_image_features(img2)

        # Calculate loss
        loss = criterion(emb1, emb2, labels)

        # Zero gradients, backward pass, and optimization step
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        running_loss += loss.item()

    # Print average loss for the epoch
    print(f"Epoch {epoch+1}/{epochs}, Loss: {running_loss/len(dataloader)}")

# Put the model back in evaluation mode so the evaluation cells that reuse
# `model` see the same mode they did before this change.
model.eval()

## Evaluation

In [None]:
import torch
from PIL import Image
from transformers import CLIPProcessor, CLIPModel
import matplotlib.pyplot as plt
import numpy as np

def resize_image(image_path, target_size=(224, 224)):
    """Load the image at ``image_path`` as RGB and resize it.

    Args:
        image_path: Path of the image file to open.
        target_size: ``(width, height)`` passed to PIL's ``resize``.

    Returns:
        The resized RGB image.
    """
    # Image.open keeps the underlying file open; the context manager closes
    # it as soon as the pixels have been converted into a separate RGB image.
    with Image.open(image_path) as img:
        rgb_img = img.convert("RGB")
    # Resize the image to the target size
    return rgb_img.resize(target_size, Image.Resampling.LANCZOS)

def load_and_preprocess_images(image_paths):
    """Return a list with one ``resize_image`` result per path, in order."""
    return list(map(resize_image, image_paths))

# Image files to compare (replace with paths to your images)
image_paths = [
    "/workspace/data/Untitled.jpg",
    "/workspace/data/avion-800_1.png",
    "/workspace/data/2021-best-cars-ford-mustang-hero-desktop.jpg",
    "/workspace/data/kitty-cat-kitten-pet-45201.jpeg",
    "/workspace/data/Puerto_rican-Paso-Fino-Horse-chestnut.jpg",
    "/workspace/data/P-51_Mustang_edit1.jpg",
    "/workspace/data/Comment-le-secteur-automobile-prepare-la-voiture-durable.png",
    "/workspace/data/Asana3808_Dashboard_Standard.jpg",
]

# Load the original (not fine-tuned) CLIP model from Hugging Face Transformers
# as the baseline. The shared `device` is used instead of a hard-coded "cuda"
# so this cell also runs on machines without a GPU.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")


def _embed_images(embedding_model, paths):
    """Return one image-feature tensor per path, computed without gradients."""
    embeddings = []
    for path in paths:
        images = load_and_preprocess_images([path])  # Pass image path as a list
        with torch.no_grad():
            inputs = processor(images=images, return_tensors="pt", padding=True).to(device)
            embeddings.append(embedding_model.get_image_features(**inputs))
    return embeddings


# Generate image embeddings using the original CLIP model
image_embeddings = _embed_images(clip_model, image_paths)

# Generate image embeddings using the fine-tuned model (assuming `model` is your fine-tuned model)
fine_tuned_image_embeddings = _embed_images(model, image_paths)

def compute_similarity_matrix(embeddings):
    """Build the square matrix of pairwise cosine similarities.

    Entry ``[i, j]`` is the cosine similarity of ``embeddings[i]`` and
    ``embeddings[j]``. Each similarity is read with ``.item()``, so every
    pair must reduce to a single value.
    """
    size = len(embeddings)
    matrix = torch.zeros((size, size))

    for row, row_embedding in enumerate(embeddings):
        for col, col_embedding in enumerate(embeddings):
            matrix[row, col] = torch.cosine_similarity(row_embedding, col_embedding).item()

    return matrix

# Pairwise similarity matrices for the original and the fine-tuned embeddings
original_similarities = compute_similarity_matrix(image_embeddings)
fine_tuned_similarities = compute_similarity_matrix(fine_tuned_image_embeddings)

def plot_similarity_matrix(similarity_matrix, image_paths, title):
    """Plot a similarity matrix with the images drawn along the top and left.

    Args:
        similarity_matrix: Tensor of pairwise similarities; it is converted
            with ``.numpy()`` before plotting.
        image_paths: Paths of the images, in the same order as the matrix
            rows/columns.
        title: Title shown above the plot.
    """
    num_images = len(image_paths)
    fig, ax = plt.subplots(figsize=(15, 15))
    similarity = similarity_matrix.numpy()
    ax.imshow(similarity, vmin=0, vmax=1, cmap='viridis')

    ax.set_xticks(np.arange(num_images))
    ax.set_yticks(np.arange(num_images))

    # Plot images on top and left side. Each file is opened once, drawn in
    # both positions, and closed again when the with-block exits.
    for i, path in enumerate(image_paths):
        with Image.open(path) as thumbnail:
            ax.imshow(thumbnail, extent=(i - 0.5, i + 0.5, -1.5, -0.5), origin="lower")
            ax.imshow(thumbnail, extent=(-1.5, -0.5, i - 0.5, i + 0.5), origin="lower")

    # Write each similarity value into its cell
    for x in range(similarity.shape[1]):
        for y in range(similarity.shape[0]):
            ax.text(x, y, f"{similarity[y, x]:.4f}", ha="center", va="center", size=12, color='white')

    for side in ["left", "top", "right", "bottom"]:
        ax.spines[side].set_visible(False)

    # Extend the axes by one cell so the image strips are visible; the
    # reversed y-limits put row 0 at the top.
    ax.set_xlim([-1.5, num_images - 0.5])
    ax.set_ylim([num_images - 0.5, -1.5])
    ax.set_title(title, size=20)

    plt.show()

# Plot and display the similarity matrices of the original and the fine-tuned model
plot_similarity_matrix(original_similarities, image_paths, "Original CLIP Model Similarities")
plot_similarity_matrix(fine_tuned_similarities, image_paths, "Fine-tuned CLIP Model Similarities")


See the results in the readme.

In [None]:
# Define CIFAR-10 classes
cifar10_classes = [
    'airplane',
    'automobile',
    'bird',
    'cat',
    'deer',
    'dog',
    'frog',
    'horse',
    'ship',
    'truck',
]

# Evaluation preprocessing: resize to 224x224, then convert to tensors
transform = transforms.Compose(
    [transforms.Resize((224, 224)), transforms.ToTensor()]
)

# Load the CIFAR-10 test split
test_dataset = CIFAR10(
    root='./data',
    train=False,
    download=True,
    transform=transform,
)

# Keep only the first 1000 test images and iterate over them in order
subset_indices = [*range(1000)]
test_subset = Subset(test_dataset, subset_indices)
test_loader = DataLoader(test_subset, batch_size=64, shuffle=False)

# Load the original CLIP model from Hugging Face Transformers as the baseline.
# The shared `device` replaces a hard-coded "cuda" so the cell also works on
# CPU-only machines.
clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")

# The fine-tuned model is the `model` object trained earlier in this notebook
# (replace with your own fine-tuned model if it lives elsewhere)
fine_tuned_model = model

def get_classwise_embeddings(model, data_loader, device):
    """Compute the mean image embedding of each of the ten classes.

    Args:
        model: Model exposing ``get_image_features``; it is moved to
            ``device`` and put in evaluation mode.
        data_loader: Iterable of ``(images, labels)`` batches. Images are
            assumed to be float tensors in [0, 1] -- TODO confirm against
            the loader's transform.
        device: Device on which the model runs.

    Returns:
        Dict mapping each class index 0-9 to the mean of its embeddings.
        Every class must occur at least once, because ``torch.stack`` is
        called on each class's list of embeddings.
    """
    classwise_embeddings = {cls: [] for cls in range(10)}
    model.to(device)
    model.eval()

    with torch.no_grad():
        for images, labels in data_loader:
            # Convert images from [0, 1] range to [0, 255] range before they
            # go through the processor. The batch is left on its original
            # device: the processor output is what gets moved to `device`,
            # so a transfer before preprocessing is unnecessary.
            images = images * 255.0
            inputs = processor(images=images, return_tensors="pt", padding=True).to(device)
            outputs = model.get_image_features(**inputs)

            # Collect each embedding on the CPU under its class label
            for output, label in zip(outputs, labels):
                classwise_embeddings[label.item()].append(output.cpu())

    # Compute mean embeddings for each class
    return {
        cls: torch.stack(embeddings).mean(dim=0)
        for cls, embeddings in classwise_embeddings.items()
    }

# Get embeddings for both models on the shared `device` (instead of a
# hard-coded "cuda", which fails on CPU-only machines)
original_classwise_embeddings = get_classwise_embeddings(clip_model, test_loader, device)
fine_tuned_classwise_embeddings = get_classwise_embeddings(fine_tuned_model, test_loader, device)

def compute_mean_similarity_matrix(classwise_embeddings):
    """Return the class-by-class cosine-similarity matrix.

    ``classwise_embeddings`` is indexed with the integers ``0 .. n-1`` where
    ``n`` is its length; entry ``[i, j]`` of the result is the cosine
    similarity (along dim 0) of the embeddings for classes ``i`` and ``j``.
    """
    count = len(classwise_embeddings)
    matrix = torch.zeros((count, count))

    for row in range(count):
        row_vector = classwise_embeddings[row]
        for col in range(count):
            col_vector = classwise_embeddings[col]
            matrix[row, col] = torch.cosine_similarity(row_vector, col_vector, dim=0).item()

    return matrix

# Class-by-class similarity matrices for the original and the fine-tuned model
original_similarity_matrix = compute_mean_similarity_matrix(original_classwise_embeddings)
fine_tuned_similarity_matrix = compute_mean_similarity_matrix(fine_tuned_classwise_embeddings)

def plot_similarity_matrix(similarity_matrix, class_names, title):
    """Plot a class-by-class similarity matrix with value labels and a colorbar.

    ``similarity_matrix`` is converted with ``.numpy()``; ``class_names``
    labels both axes and ``title`` is shown above the plot.
    """
    num_classes = len(class_names)
    similarity = similarity_matrix.numpy()

    fig, ax = plt.subplots(figsize=(10, 10))
    im = ax.imshow(similarity, vmin=0, vmax=1, cmap='viridis')

    # One tick per class on both axes, labelled with the class names
    ax.set_xticks(np.arange(num_classes))
    ax.set_yticks(np.arange(num_classes))
    ax.set_xticklabels(class_names, rotation=45, ha='right')
    ax.set_yticklabels(class_names)

    # Display similarity scores
    for row in range(num_classes):
        for col in range(num_classes):
            ax.text(col, row, f"{similarity[row, col]:.2f}", ha="center", va="center", color="white")

    fig.colorbar(im, ax=ax)
    ax.set_title(title, size=20)
    plt.show()

# Plot and display the class-level similarity matrices of both models
plot_similarity_matrix(original_similarity_matrix, cifar10_classes, "Original CLIP Model Mean Similarities")
plot_similarity_matrix(fine_tuned_similarity_matrix, cifar10_classes, "Fine-tuned CLIP Model Mean Similarities")
