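## Model Training: Fine-Tuning CLIP on the CIFAR-10 Dataset

Contrastively fine-tunes the pretrained `openai/clip-vit-large-patch14` text and vision encoders so that each CIFAR-10 image scores highest against its class prompt (e.g. "An image of a cat.").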

In [1]:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms, datasets
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel
import numpy as np

In [2]:
# Text prompt for each CIFAR-10 label index (0 = airplane ... 9 = truck).
# The indefinite article is chosen from the first letter of the class name.
_cifar10_class_names = [
    "airplane", "automobile", "bird", "cat", "deer",
    "dog", "frog", "horse", "ship", "truck",
]
labels_to_text = {
    index: f"An image of {'an' if name[0] in 'aeiou' else 'a'} {name}."
    for index, name in enumerate(_cifar10_class_names)
}

# Custom Dataset to integrate text prompts
class CIFAR10WithText(Dataset):
    """Adapt an ``(image, label)`` dataset so each item is ``(image, text prompt)``.

    Args:
        dataset: Indexable, sized dataset whose items unpack to ``(image, label)``.
        labels_to_text: Mapping from a label to its caption string.
        transform: Optional callable applied to each image; skipped when falsy.
    """

    def __init__(self, dataset, labels_to_text, transform=None):
        self.dataset, self.labels_to_text, self.transform = dataset, labels_to_text, transform

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        raw_image, label = self.dataset[idx]
        # Transform first, then look up the caption (same order as before).
        image = self.transform(raw_image) if self.transform else raw_image
        return image, self.labels_to_text[label]

# CLIP's pretraining normalization statistics (per-channel RGB mean / std).
CLIP_MEAN = (0.48145466, 0.4578275, 0.40821073)
CLIP_STD = (0.26862954, 0.26130258, 0.27577711)

# Image preprocessing: upsample CIFAR images to CLIP's 224x224 input size.
# Bicubic resampling matches CLIP's own image preprocessing, so the pretrained
# vision encoder sees inputs closer to what it was trained on (torchvision's
# default would be bilinear).
transform = transforms.Compose([
    transforms.Resize((224, 224), interpolation=transforms.InterpolationMode.BICUBIC),
    transforms.ToTensor(),
    transforms.Normalize(CLIP_MEAN, CLIP_STD),
])

# Raw CIFAR-10 splits (downloaded into ./data if not already present).
cifar10_train_raw = datasets.CIFAR10(root="./data", train=True, download=True)
cifar10_test_raw = datasets.CIFAR10(root="./data", train=False, download=True)

# Wrap each split so items become (preprocessed image tensor, text prompt).
train_dataset = CIFAR10WithText(cifar10_train_raw, labels_to_text, transform=transform)
test_dataset = CIFAR10WithText(cifar10_test_raw, labels_to_text, transform=transform)


Files already downloaded and verified
Files already downloaded and verified


In [3]:
class CLIPFineTuner(nn.Module):
    """Contrastive image-text model on top of separate CLIP text and vision encoders.

    Each encoder's pooled output is L2-normalized, linearly projected into a shared
    ``projection_dim`` space, L2-normalized again, and compared by cosine similarity
    scaled with a learnable temperature.

    Args:
        text_encoder: Module called as ``text_encoder(input_ids=..., attention_mask=...)``
            whose output exposes ``pooler_output`` (e.g. ``CLIPTextModel``).
        vision_encoder: Module called as ``vision_encoder(images)`` whose output
            exposes ``pooler_output`` (e.g. ``CLIPVisionModel``).
        projection_dim: Width of the shared embedding space.
        text_dim: Width of the text encoder's pooled output. Defaults to
            ``text_encoder.config.hidden_size`` when present, else 768.
        image_dim: Width of the vision encoder's pooled output. Defaults to
            ``vision_encoder.config.hidden_size`` when present, else 1024.
    """

    # Upper bound on the logit scale (CLIP clips the temperature at 100 to keep
    # training stable).
    MAX_LOGIT_SCALE = 100.0

    def __init__(self, text_encoder, vision_encoder, projection_dim=512,
                 text_dim=None, image_dim=None):
        super().__init__()
        self.text_encoder = text_encoder
        self.vision_encoder = vision_encoder

        # Infer encoder widths from their configs so other CLIP checkpoints work;
        # the fallbacks are the widths previously hard-coded here.
        if text_dim is None:
            text_dim = self._hidden_size(text_encoder, 768)
        if image_dim is None:
            image_dim = self._hidden_size(vision_encoder, 1024)

        # Projection layers that map both modalities into the shared space.
        self.text_projection = nn.Linear(text_dim, projection_dim)
        self.image_projection = nn.Linear(image_dim, projection_dim)
        # Learnable log-temperature, initialized to log(1 / 0.07).
        self.logit_scale = nn.Parameter(torch.ones([]) * np.log(1 / 0.07))

    @staticmethod
    def _hidden_size(encoder, default):
        """Return ``encoder.config.hidden_size`` if available, otherwise ``default``."""
        config = getattr(encoder, "config", None)
        return getattr(config, "hidden_size", default)

    def forward(self, images, input_ids, attention_mask):
        """Score every image in the batch against every caption.

        Args:
            images: Batch forwarded unchanged to the vision encoder.
            input_ids: Token ids forwarded to the text encoder.
            attention_mask: Attention mask matching ``input_ids``.

        Returns:
            ``(logits_per_image, logits_per_text)``: ``logits_per_image[i, j]`` is
            the scaled similarity of image ``i`` and caption ``j``;
            ``logits_per_text`` is its transpose.
        """
        # Use the encoders' pooled outputs. CLIP's text transformer uses causal
        # self-attention, so the hidden state at position 0 (the start token)
        # never sees the rest of the prompt and is identical for every caption.
        # The text pooler_output is taken at the end-of-text token instead. For
        # vision, pooler_output is the class token after CLIP's final layer norm.
        image_features = self.vision_encoder(images).pooler_output
        text_features = self.text_encoder(
            input_ids=input_ids, attention_mask=attention_mask
        ).pooler_output

        # Normalize encoder features before projection.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # Project both modalities into the same dimension.
        image_features = self.image_projection(image_features)
        text_features = self.text_projection(text_features)

        # Normalize projected features so the dot product is a cosine similarity.
        image_features = image_features / image_features.norm(dim=-1, keepdim=True)
        text_features = text_features / text_features.norm(dim=-1, keepdim=True)

        # Scaled similarity logits; clamping keeps the scale from exploding.
        logit_scale = self.logit_scale.exp().clamp(max=self.MAX_LOGIT_SCALE)
        logits_per_image = logit_scale * image_features @ text_features.t()
        logits_per_text = logits_per_image.t()

        return logits_per_image, logits_per_text


In [4]:
# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Tokenizer, text tower and vision tower all come from the same CLIP checkpoint.
clip_checkpoint = "openai/clip-vit-large-patch14"
tokenizer = CLIPTokenizer.from_pretrained(clip_checkpoint)
text_encoder = CLIPTextModel.from_pretrained(clip_checkpoint).to(device)
vision_encoder = CLIPVisionModel.from_pretrained(clip_checkpoint).to(device)

In [None]:
import os
from torch.utils.data import DataLoader
from tqdm.auto import tqdm


def _soft_contrastive_targets(captions, device):
    """Build row-normalized target distributions for the contrastive loss.

    With only 10 distinct prompts, a batch almost always contains several samples
    that share a caption. ``torch.arange`` targets would push those identical
    prompts apart as negatives. Here every column whose caption equals the row's
    caption receives an equal share of the target probability.

    Args:
        captions: Sequence of caption strings for the batch, in batch order.
        device: Device on which to create the returned tensor.

    Returns:
        Float tensor of shape (len(captions), len(captions)) whose rows sum to 1.
    """
    same_caption = torch.tensor(
        [[row == col for col in captions] for row in captions],
        dtype=torch.float,
        device=device,
    )
    # The diagonal is always 1, so no row sum is zero.
    return same_caption / same_caption.sum(dim=1, keepdim=True)


# Fix the RNG so projection-layer initialization and batch shuffling are reproducible.
torch.manual_seed(42)

# Create DataLoader
train_loader = DataLoader(train_dataset, batch_size=16, shuffle=True)

# Initialize model
model = CLIPFineTuner(text_encoder, vision_encoder).to(device)

# Optimizer and loss (CrossEntropyLoss receives class-probability targets below).
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)
loss_fn = nn.CrossEntropyLoss()

# Directory for the per-epoch checkpoints
save_dir = "saved_models"
os.makedirs(save_dir, exist_ok=True)  # Create directory if it doesn't exist

# Training loop; a checkpoint is saved after every epoch.
num_epochs = 5
for epoch in range(num_epochs):
    model.train()
    total_loss = 0.0

    for images, captions in tqdm(train_loader, desc=f"Epoch {epoch + 1}/{num_epochs}"):
        images = images.to(device)
        # The prompts are short, so pad to the longest caption in the batch
        # rather than the tokenizer's maximum length.
        inputs = tokenizer(list(captions), padding=True, truncation=True, return_tensors="pt")
        input_ids = inputs.input_ids.to(device)
        attention_mask = inputs.attention_mask.to(device)

        # Forward pass
        logits_per_image, logits_per_text = model(images, input_ids, attention_mask)

        # Samples sharing a caption are positives for each other, not negatives.
        # The caption-equality matrix is symmetric, so the same targets serve
        # both the image->text and text->image directions.
        targets = _soft_contrastive_targets(captions, device)
        loss = (loss_fn(logits_per_image, targets) + loss_fn(logits_per_text, targets)) / 2

        # Backpropagation
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()

    # Print epoch loss
    avg_loss = total_loss / len(train_loader)
    print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}")

    # Save the model after each epoch
    model_save_path = os.path.join(save_dir, f"clip_fine_tuned_epoch_{epoch + 1}.pth")
    torch.save(model.state_dict(), model_save_path)
    print(f"Model saved to {model_save_path}")

 68%|██████▊   | 2133/3125 [24:53<17:41,  1.07s/it] 

In [None]:
# Persist the final fine-tuned weights alongside the notebook.
final_model_path = "clip_fine_tuned_cifar10.pth"
torch.save(model.state_dict(), final_model_path)
print(f"Model saved to {final_model_path}")

In [None]:
# Release cached, unused GPU memory held by PyTorch's allocator (no-op without CUDA).
if torch.cuda.is_available():
    torch.cuda.empty_cache()