<a href="https://colab.research.google.com/github/Papa-Panda/Paper_reading/blob/main/CLIP.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [5]:
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms, models
from transformers import BertTokenizer, BertModel

# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Use the GPU whenever PyTorch can see one; otherwise run on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Optimisation hyper-parameters.
batch_size = 64
num_epochs = 10
learning_rate = 1e-3

# Width of the joint embedding space that both encoders project into.
feature_dim = 512

# Caption for each CIFAR-10 class. The training loop indexes this list with the
# dataset's integer labels, so the order must match torchvision's class order.
cifar10_classes = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]


In [2]:
# Load CIFAR-10 Dataset
# The image encoder uses an ImageNet-pretrained ResNet-18, whose weights were
# trained on inputs standardized with the ImageNet per-channel mean/std below.
# Without Normalize the backbone would see raw [0, 1] pixels, a different input
# distribution than the one its pretrained weights expect.
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

transform = transforms.Compose([
    transforms.Resize(224),  # Resize to match input size of pre-trained models
    transforms.ToTensor(),   # PIL image -> float tensor with values in [0, 1]
    transforms.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD),
])

train_dataset = datasets.CIFAR10(root='./data', train=True, download=True, transform=transform)
# Shuffled mini-batches of (image tensor, integer class label) pairs.
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=batch_size, shuffle=True)

In [3]:
# Image Encoder: ResNet-18 backbone followed by a linear projection head.
class ImageEncoder(nn.Module):
    """Map a batch of images into the shared image/text embedding space.

    The torchvision ResNet-18 is kept intact except for its final
    classification layer, which is swapped for an identity so the backbone
    returns its pooled features. A linear layer then projects those features
    to ``feature_dim`` dimensions.
    """

    def __init__(self):
        super().__init__()
        backbone = models.resnet18(pretrained=True)
        # Width of the features that fed the original classifier (512 for ResNet-18).
        backbone_width = backbone.fc.in_features
        backbone.fc = nn.Identity()  # drop the ImageNet classification head
        self.resnet = backbone
        self.projection = nn.Linear(backbone_width, feature_dim)

    def forward(self, images):
        """Return projected embeddings, one row per image.

        Presumably ``images`` is an (N, 3, H, W) float batch as produced by the
        DataLoader; TODO confirm for other callers.
        """
        return self.projection(self.resnet(images))


# Text Encoder (e.g., BERT for class names)
class TextEncoder(nn.Module):
    """Encode short text prompts into the shared image/text embedding space.

    Raw strings are tokenized inside ``forward`` with the bert-base-uncased
    tokenizer. BERT's pooled output is then linearly projected to
    ``feature_dim`` dimensions.
    """

    def __init__(self):
        super(TextEncoder, self).__init__()
        self.tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        self.bert = BertModel.from_pretrained('bert-base-uncased')
        # Take the hidden width from the model config (768 for bert-base)
        # rather than hard-coding it, so the projection always matches BERT.
        self.projection = nn.Linear(self.bert.config.hidden_size, feature_dim)

    def forward(self, text):
        """Return one projected embedding row per input string.

        Args:
            text: Strings to encode. In this notebook this is a list of
                class names, one per image in the batch.

        Returns:
            Tensor of shape (len(text), feature_dim) for a list input.
        """
        # Prompts longer than 10 tokens are truncated; padding aligns the batch.
        inputs = self.tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=10)
        # Send the token tensors to the device BERT's weights currently live
        # on, not the module-level `device`. This keeps the encoder working
        # after model.to(...) / model.cpu() moves it elsewhere.
        bert_device = next(self.bert.parameters()).device
        inputs = {key: val.to(bert_device) for key, val in inputs.items()}
        outputs = self.bert(**inputs)
        return self.projection(outputs.pooler_output)

# CLIP Model: a dual encoder pairing the image and text towers.
class CLIPModel(nn.Module):
    """Wrap an ``ImageEncoder`` and a ``TextEncoder`` side by side.

    Both towers are moved to the module-level ``device`` when constructed.
    The forward pass returns the raw (un-normalized) projected embeddings of
    each modality; normalization is left to the caller.
    """

    def __init__(self):
        super().__init__()
        self.image_encoder = ImageEncoder().to(device)
        self.text_encoder = TextEncoder().to(device)

    def forward(self, images, text):
        """Return ``(image_features, text_features)`` for a batch.

        The image tower runs first, then the text tower.
        """
        return self.image_encoder(images), self.text_encoder(text)

# Loss Function: Contrastive Loss (Cross-entropy on cosine similarity)
def contrastive_loss(image_features, text_features, temperature=0.07):
    """Symmetric CLIP-style contrastive (InfoNCE) loss for paired embeddings.

    Row ``i`` of ``image_features`` is the positive match for row ``i`` of
    ``text_features``. Every other row in the batch is treated as a negative.

    Args:
        image_features: (N, D) image embeddings; need not be pre-normalized.
        text_features: (N, D) text embeddings, paired row-by-row with
            ``image_features``; need not be pre-normalized.
        temperature: Divisor applied to the cosine similarities before the
            softmax; smaller values sharpen the distribution.

    Returns:
        Scalar tensor: the mean of the image->text and text->image
        cross-entropy losses.

    NOTE: rows that share a caption (e.g. repeated class names within a
    CIFAR-10 batch) get (up to dropout) the same text embedding. They are
    nonetheless scored as negatives of each other here.
    """
    # Normalize here so the logits really are cosine similarities. For
    # callers that already normalize, re-normalizing unit vectors is
    # effectively a no-op.
    image_features = nn.functional.normalize(image_features, dim=1)
    text_features = nn.functional.normalize(text_features, dim=1)
    logits = image_features @ text_features.T / temperature
    # Build the targets on the logits' device rather than a global `device`,
    # so the loss works wherever the features live.
    targets = torch.arange(logits.size(0), device=logits.device)
    loss_i = nn.functional.cross_entropy(logits, targets)    # image -> text
    loss_t = nn.functional.cross_entropy(logits.T, targets)  # text -> image
    return (loss_i + loss_t) / 2

In [None]:

# Build the dual encoder and one Adam optimizer over all of its parameters
# (both pretrained backbones are fine-tuned; nothing is frozen).
model = CLIPModel()
optimizer = optim.Adam(model.parameters(), lr=learning_rate)

# Training loop: each image is paired with the caption made of its class name,
# and matching image/caption pairs are pulled together by the contrastive loss.
# NOTE(review): batch_size (64) exceeds the 10 class names, so every batch
# repeats captions and same-class pairs end up scored as negatives.
for epoch in range(num_epochs):
    model.train()
    total_loss = 0
    for images, labels in train_loader:
        images = images.to(device)
        text_labels = [cifar10_classes[class_idx] for class_idx in labels.tolist()]

        optimizer.zero_grad()
        image_features, text_features = model(images, text_labels)

        # Put both modalities on the unit sphere before scoring them.
        image_features, text_features = (
            nn.functional.normalize(image_features, dim=1),
            nn.functional.normalize(text_features, dim=1),
        )

        loss = contrastive_loss(image_features, text_features)
        loss.backward()
        optimizer.step()
        total_loss += loss.item()

    # Report the mean per-batch loss for the epoch.
    mean_loss = total_loss / len(train_loader)
    print(f'Epoch [{epoch+1}/{num_epochs}], Loss: {mean_loss:.4f}')

print("Training finished.")
