# CLIP Implementation and Fine-tuning

This notebook demonstrates how to use and fine-tune CLIP (Contrastive Language-Image Pre-training) for multimodal tasks.

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import clip
from PIL import Image
import numpy as np
import matplotlib.pyplot as plt
from transformers import CLIPProcessor, CLIPModel
import requests
from io import BytesIO

## 1. Loading Pre-trained CLIP Model

In [None]:
# Choose the compute device: CUDA when PyTorch can see a GPU, else the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Load the pretrained preprocessor and model weights for the same checkpoint,
# then place the model on the chosen device (Module.to returns the module).
processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)

print(f"Model loaded on {device}")

## 2. Basic Image-Text Similarity

In [None]:
def compute_similarity(image, texts):
    """Score one image against several candidate text descriptions.

    The image and texts are preprocessed together with the module-level
    ``processor``, run through the module-level ``model`` without gradient
    tracking, and the image-to-text logits are turned into probabilities
    with a softmax over the text axis.

    Args:
        image: Image accepted by the processor (the examples pass a PIL image).
        texts: List of text descriptions to compare against the image.

    Returns:
        A NumPy array holding the first row of the softmax output, i.e. one
        probability per entry of ``texts`` for the first image. Because of
        the softmax the values are relative to the supplied texts only.
    """
    encoded = processor(text=texts, images=image, return_tensors="pt", padding=True)
    on_device = {name: tensor.to(device) for name, tensor in encoded.items()}

    # Inference only: no autograd graph is needed.
    with torch.no_grad():
        image_logits = model(**on_device).logits_per_image

    text_probs = image_logits.softmax(dim=-1)
    return text_probs.cpu().numpy()[0]

# Example usage
# Load a sample image (you can replace with your own)
url = "https://upload.wikimedia.org/wikipedia/commons/thumb/7/7d/A_golden_retriever_sitting_in_the_snow.jpg/320px-A_golden_retriever_sitting_in_the_snow.jpg"
# A timeout keeps the cell from hanging forever on a stalled connection, and
# raise_for_status() turns an HTTP error (e.g. 403/404) into a clear exception
# instead of letting PIL fail while trying to decode an error page.
response = requests.get(url, timeout=30)
response.raise_for_status()
# Convert to RGB so the image has the same mode as the other loaders in this
# notebook, regardless of how the source file is encoded.
image = Image.open(BytesIO(response.content)).convert('RGB')

# Define text descriptions
texts = [
    "a photo of a dog",
    "a photo of a cat",
    "a photo of a car",
    "a golden retriever in snow"
]

# Compute similarities (one softmax probability per text)
similarities = compute_similarity(image, texts)

# Show the input image next to a bar chart of its per-text scores.
fig, (image_ax, score_ax) = plt.subplots(1, 2, figsize=(12, 4))

# Left panel: the image itself, without axis ticks.
image_ax.imshow(image)
image_ax.axis('off')
image_ax.set_title('Input Image')

# Right panel: one bar per text description, labelled with the text.
bar_positions = range(len(texts))
score_ax.bar(bar_positions, similarities)
score_ax.set_xticks(bar_positions)
score_ax.set_xticklabels(texts, rotation=45, ha='right')
score_ax.set_ylabel('Similarity Score')
score_ax.set_title('Text-Image Similarities')

fig.tight_layout()
plt.show()

# Also print the scores as plain text.
for text, sim in zip(texts, similarities):
    line = f"{text}: {sim:.3f}"
    print(line)

## 3. Zero-Shot Image Classification

In [None]:
def zero_shot_classification(image, class_names):
    """Classify an image by comparing it with one text prompt per class.

    Each class name is wrapped in the template ``"a photo of a <name>"`` and
    the prompts are scored against the image with ``compute_similarity``.

    Args:
        image: Image to classify (forwarded unchanged to ``compute_similarity``).
        class_names: Sequence of candidate class names.

    Returns:
        A tuple ``(predicted_class, confidence, similarities)``: the class
        name with the highest score, that score, and the full array of
        scores in the same order as ``class_names``.
    """
    prompts = []
    for name in class_names:
        prompts.append(f"a photo of a {name}")

    scores = compute_similarity(image, prompts)

    # The winning class is the one whose prompt received the largest score.
    best = np.argmax(scores)
    return class_names[best], scores[best], scores

# Example: classify the sample image among a fixed set of animals.
animal_classes = ['dog', 'cat', 'bird', 'fish', 'horse', 'cow', 'elephant']

result = zero_shot_classification(image, animal_classes)
predicted_class, confidence, all_similarities = result

# Report the winning class first, then the score of every candidate.
print(f"Predicted class: {predicted_class}")
print(f"Confidence: {confidence:.3f}")
print("\nAll class probabilities:")
for class_name, sim in zip(animal_classes, all_similarities):
    line = f"{class_name}: {sim:.3f}"
    print(line)

## 4. Custom Dataset for Fine-tuning

In [None]:
class CustomMultimodalDataset(Dataset):
    """Dataset of (image, caption) pairs preprocessed for CLIP fine-tuning.

    Item ``i`` pairs ``image_paths[i]`` with ``captions[i]``; both are run
    through the supplied processor when the item is requested.
    """

    def __init__(self, image_paths, captions, processor, max_length=77):
        """Store the paired data and the preprocessing configuration.

        Args:
            image_paths: Sequence of image file paths.
            captions: Sequence of caption strings, aligned with ``image_paths``.
            processor: Callable preprocessor invoked with ``text=`` and
                ``images=`` (the notebook passes a ``CLIPProcessor``).
            max_length: Length captions are padded/truncated to. Defaults to
                77 (presumably the text context length of the CLIP
                checkpoint used here -- confirm if you swap models).

        Raises:
            ValueError: If ``image_paths`` and ``captions`` differ in length.
        """
        # Fail early: a mismatch would otherwise surface as an IndexError
        # mid-epoch or silently ignore the surplus captions.
        if len(image_paths) != len(captions):
            raise ValueError(
                f"image_paths and captions must have the same length, "
                f"got {len(image_paths)} and {len(captions)}"
            )
        self.image_paths = image_paths
        self.captions = captions
        self.processor = processor
        self.max_length = max_length

    def __len__(self):
        """Return the number of (image, caption) pairs."""
        return len(self.image_paths)

    def __getitem__(self, idx):
        """Load and preprocess the pair at position ``idx``.

        Returns:
            A dict of the processor's output tensors with the leading batch
            dimension squeezed away, so a DataLoader can re-batch them.
        """
        # The context manager closes the file handle as soon as the pixel
        # data has been copied into the converted RGB image.
        with Image.open(self.image_paths[idx]) as raw_image:
            image = raw_image.convert('RGB')
        caption = self.captions[idx]

        # Pad/truncate every caption to the same length so that samples can
        # be stacked into a batch.
        inputs = self.processor(
            text=caption,
            images=image,
            return_tensors="pt",
            padding="max_length",
            max_length=self.max_length,
            truncation=True
        )

        # The processor returns tensors with a batch dimension of 1; drop it.
        return {k: v.squeeze(0) for k, v in inputs.items()}

# Example dataset creation (replace with your data)
# image_paths = ['path/to/image1.jpg', 'path/to/image2.jpg', ...]
# captions = ['caption for image 1', 'caption for image 2', ...]
# dataset = CustomMultimodalDataset(image_paths, captions, processor)

## 5. Fine-tuning CLIP

In [None]:
def contrastive_loss(logits, temperature=0.07):
    """Symmetric cross-entropy loss over a square image/text logit matrix.

    The logits are divided by ``temperature``; entry ``(i, i)`` is treated as
    the matching pair for row ``i`` and for column ``i``. The loss is the
    average of the row-wise and column-wise cross-entropy terms.

    Args:
        logits: 2-D tensor of pairwise scores; row ``i`` and column ``i``
            are expected to belong to the same pair.
        temperature: Divisor applied to the logits before the softmax
            (smaller values sharpen the distribution).

    Returns:
        A scalar tensor with the averaged two-direction loss.
    """
    scaled = logits / temperature

    # Target class for row i (and column i) is index i: the diagonal.
    targets = torch.arange(scaled.shape[0], device=scaled.device)

    cross_entropy = nn.functional.cross_entropy
    image_to_text = cross_entropy(scaled, targets)
    text_to_image = cross_entropy(scaled.T, targets)

    return 0.5 * (image_to_text + text_to_image)

def train_clip(model, dataloader, num_epochs=5, learning_rate=1e-5):
    """Fine-tune a CLIP model with the symmetric contrastive loss.

    Args:
        model: CLIP model whose forward output exposes ``logits_per_image``.
        dataloader: Iterable of batches; each batch is a dict of tensors that
            is moved to the module-level ``device`` and unpacked into the
            model call.
        num_epochs: Number of passes over ``dataloader``.
        learning_rate: Learning rate for the AdamW optimizer.

    Returns:
        The same ``model`` object, updated in place and left in training mode.
    """
    optimizer = optim.AdamW(model.parameters(), lr=learning_rate)
    model.train()

    for epoch in range(num_epochs):
        total_loss = 0.0
        num_batches = 0

        for batch in dataloader:
            # Move batch to device
            batch = {k: v.to(device) for k, v in batch.items()}

            # Forward pass
            outputs = model(**batch)

            # The Hugging Face CLIP model already multiplies the cosine
            # similarities by its learned logit scale before returning
            # logits_per_image, so temperature=1.0 avoids applying a second
            # temperature on top of it.
            loss = contrastive_loss(outputs.logits_per_image, temperature=1.0)

            # Backward pass
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
            num_batches += 1

        # An empty dataloader would otherwise raise ZeroDivisionError below.
        if num_batches == 0:
            print(f"Epoch {epoch+1}/{num_epochs}, no batches to train on")
            continue

        avg_loss = total_loss / num_batches
        print(f"Epoch {epoch+1}/{num_epochs}, Average Loss: {avg_loss:.4f}")

    return model

# Example fine-tuning (uncomment when you have data)
# dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
# fine_tuned_model = train_clip(model, dataloader)

## 6. Image Retrieval with Text Queries

In [None]:
def image_retrieval(query_text, image_database, top_k=5):
    """Retrieve the images that best match a text query.

    Each image is scored by the cosine similarity between its CLIP image
    embedding and the CLIP text embedding of the query. A softmax over a
    single text would always be 1.0 and make every image tie, so raw cosine
    similarities are used instead.

    Args:
        query_text: Text query to search for.
        image_database: Iterable of image file paths.
        top_k: Maximum number of results to return.

    Returns:
        A list of ``(image_path, similarity)`` tuples sorted by descending
        similarity, truncated to ``top_k`` entries. Empty if
        ``image_database`` is empty.
    """
    # Encode the query once; it is identical for every image.
    text_inputs = processor(text=[query_text], return_tensors="pt", padding=True)
    text_inputs = {k: v.to(device) for k, v in text_inputs.items()}

    similarities = []

    with torch.no_grad():
        text_features = model.get_text_features(**text_inputs)
        # Unit-normalise so the dot product below is a cosine similarity.
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        for image_path in image_database:
            # Load the image and release the file handle right away.
            with Image.open(image_path) as raw_image:
                image = raw_image.convert('RGB')

            image_inputs = processor(images=image, return_tensors="pt")
            image_inputs = {k: v.to(device) for k, v in image_inputs.items()}

            image_features = model.get_image_features(**image_inputs)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)

            # Single image vs. single text -> one scalar score.
            sim = (image_features @ text_features.T).item()
            similarities.append((image_path, sim))

    # Sort by similarity, best match first
    similarities.sort(key=lambda item: item[1], reverse=True)

    return similarities[:top_k]

# Example usage (replace with your image database)
# image_database = ['path/to/image1.jpg', 'path/to/image2.jpg', ...]
# query = "a dog playing in the park"
# results = image_retrieval(query, image_database)
# 
# print(f"Top results for query: '{query}'")
# for i, (image_path, similarity) in enumerate(results):
#     print(f"{i+1}. {image_path}: {similarity:.3f}")

## 7. Multimodal Embeddings Visualization

In [None]:
from sklearn.decomposition import PCA
from sklearn.manifold import TSNE

def extract_embeddings(images, texts):
    """Compute CLIP feature vectors for a list of images and a list of texts.

    Every image and every text is preprocessed and encoded on its own with
    the module-level ``processor`` and ``model``. The model is switched to
    evaluation mode and gradients are disabled while encoding.

    Args:
        images: Iterable of images accepted by the processor.
        texts: Iterable of text strings.

    Returns:
        A tuple ``(image_embeddings, text_embeddings)`` of NumPy arrays, each
        built by vertically stacking the per-item feature arrays in input
        order.
    """

    def _on_device(encoded):
        # Move every tensor produced by the processor to the model's device.
        return {name: tensor.to(device) for name, tensor in encoded.items()}

    model.eval()
    with torch.no_grad():
        image_rows = [
            model.get_image_features(
                **_on_device(processor(images=picture, return_tensors="pt"))
            ).cpu().numpy()
            for picture in images
        ]
        text_rows = [
            model.get_text_features(
                **_on_device(processor(text=sentence, return_tensors="pt", padding=True))
            ).cpu().numpy()
            for sentence in texts
        ]

    return np.vstack(image_rows), np.vstack(text_rows)

def visualize_embeddings(image_embeddings, text_embeddings, image_labels, text_labels,
                         perplexity=30.0):
    """Plot image and text embeddings together in 2-D using t-SNE.

    Args:
        image_embeddings: Array of image embeddings, one row per image.
        text_embeddings: Array of text embeddings, one row per text.
        image_labels: Labels annotated next to the image points, in row order.
        text_labels: Labels annotated next to the text points, in row order.
        perplexity: Requested t-SNE perplexity. It is capped at
            ``n_samples - 1`` because scikit-learn rejects a perplexity that
            is not smaller than the number of samples, which would otherwise
            make small demo inputs fail with the default of 30.

    Raises:
        ValueError: If fewer than two embeddings are supplied in total.
    """
    # Combine embeddings so both modalities share one projection.
    all_embeddings = np.vstack([image_embeddings, text_embeddings])

    n_samples = all_embeddings.shape[0]
    if n_samples < 2:
        raise ValueError("t-SNE needs at least two embeddings to visualize")

    # Apply t-SNE with a perplexity that is valid for this sample count.
    effective_perplexity = min(perplexity, n_samples - 1)
    tsne = TSNE(n_components=2, random_state=42, perplexity=effective_perplexity)
    embeddings_2d = tsne.fit_transform(all_embeddings)

    # Split back into the image part and the text part.
    n_images = len(image_embeddings)
    image_2d = embeddings_2d[:n_images]
    text_2d = embeddings_2d[n_images:]

    # Plot
    plt.figure(figsize=(12, 8))

    # Plot images; zip() stops at the shorter sequence, so surplus labels
    # cannot index past the available points.
    plt.scatter(image_2d[:, 0], image_2d[:, 1], c='red', marker='o', s=100, alpha=0.7, label='Images')
    for label, point in zip(image_labels, image_2d):
        plt.annotate(label, (point[0], point[1]), xytext=(5, 5), textcoords='offset points')

    # Plot texts
    plt.scatter(text_2d[:, 0], text_2d[:, 1], c='blue', marker='s', s=100, alpha=0.7, label='Texts')
    for label, point in zip(text_labels, text_2d):
        plt.annotate(label, (point[0], point[1]), xytext=(5, 5), textcoords='offset points')

    plt.xlabel('t-SNE Dimension 1')
    plt.ylabel('t-SNE Dimension 2')
    plt.title('CLIP Embeddings Visualization')
    plt.legend()
    plt.grid(True, alpha=0.3)
    plt.show()

# Example usage (replace with your data)
# images = [image1, image2, image3, ...]  # PIL Images
# texts = ["text1", "text2", "text3", ...]
# image_labels = ["img1", "img2", "img3", ...]
# text_labels = ["txt1", "txt2", "txt3", ...]
# 
# img_emb, txt_emb = extract_embeddings(images, texts)
# visualize_embeddings(img_emb, txt_emb, image_labels, text_labels)

## 8. Evaluation Metrics

In [None]:
def recall_at_k(similarities, k=5):
    """Compute Recall@K for a similarity matrix with matches on the diagonal.

    Row ``i`` holds the scores of query ``i`` against every candidate, and
    candidate ``i`` is taken to be the correct match. A query counts as a hit
    when its own index is among the ``k`` highest-scoring candidates.

    Args:
        similarities: 2-D array-like of scores (queries x candidates).
        k: Number of top candidates to inspect; must be at least 1. Values
            larger than the number of candidates inspect all of them.

    Returns:
        The fraction of queries whose correct match is in their top ``k``.

    Raises:
        ValueError: If ``k`` is smaller than 1. (With ``k == 0`` the slice
            ``[-k:]`` would select every candidate and wrongly report 1.0.)
    """
    if k < 1:
        raise ValueError(f"k must be a positive integer, got {k}")

    similarities = np.asarray(similarities)

    # argsort is ascending, so the last k positions are the k best candidates.
    hits = [
        int(i in np.argsort(row)[-k:])
        for i, row in enumerate(similarities)
    ]

    return np.mean(hits)

def evaluate_retrieval(model, test_images, test_texts):
    """Evaluate image-text retrieval with Recall@1/5/10 in both directions.

    ``test_images[i]`` and ``test_texts[i]`` are treated as a matching pair.
    Embeddings are L2-normalised before the dot product so the similarity
    matrix holds cosine similarities; without that step the ranking would be
    skewed by differences in embedding norm.

    Args:
        model: Kept for interface compatibility. NOTE(review): it is not
            used here -- ``extract_embeddings`` encodes with the module-level
            model.
        test_images: List of images to encode.
        test_texts: List of captions aligned with ``test_images``.

    Returns:
        A dict with keys ``i2t_recall_{1,5,10}`` and ``t2i_recall_{1,5,10}``.
    """
    # Extract embeddings
    image_embeddings, text_embeddings = extract_embeddings(test_images, test_texts)

    # Normalise each row to unit length so the dot product is a cosine similarity.
    image_embeddings = image_embeddings / np.linalg.norm(image_embeddings, axis=1, keepdims=True)
    text_embeddings = text_embeddings / np.linalg.norm(text_embeddings, axis=1, keepdims=True)

    # Similarity matrix: rows are images, columns are texts.
    similarities = np.dot(image_embeddings, text_embeddings.T)

    # Text-to-image retrieval uses the transposed matrix (rows become texts).
    directions = (
        ("i2t", "Image-to-Text Retrieval:", similarities),
        ("t2i", "\nText-to-Image Retrieval:", similarities.T),
    )

    metrics = {}
    for prefix, heading, matrix in directions:
        print(heading)
        for k in (1, 5, 10):
            score = recall_at_k(matrix, k=k)
            print(f"Recall@{k}: {score:.3f}")
            metrics[f"{prefix}_recall_{k}"] = score

    return metrics

# Example evaluation (replace with your test data)
# test_images = [img1, img2, img3, ...]  # PIL Images
# test_texts = ["caption1", "caption2", "caption3", ...]
# metrics = evaluate_retrieval(model, test_images, test_texts)