In [2]:
# Check if necessary packages are installed, if not, install them
try:
    import torch
    import torchvision
    from torchvision import transforms
    from transformers import ViTModel
    import torch.nn as nn
    import torch.nn.functional as F
    import requests
    from PIL import Image
    from io import BytesIO
except ImportError:
    !pip install torch torchvision transformers

    # After installing, import again
    import torch
    import torchvision
    from torchvision import transforms
    from transformers import ViTModel
    import torch.nn as nn
    import torch.nn.functional as F
    import requests
    from PIL import Image
    from io import BytesIO

# Define transformations for the dataset.
# The images are fed to the pretrained ViT backbone
# ('google/vit-base-patch16-224-in21k'), so they are resized to 224x224 and
# normalised to [-1, 1] (mean = std = 0.5 per channel), matching that
# checkpoint's published preprocessing.
transform = transforms.Compose([
    transforms.Resize((224, 224)),  # Upscale the 32x32 CIFAR images while still a PIL Image
    transforms.ToTensor(),  # Convert PIL Image to a float tensor in [0, 1]
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),  # [0, 1] -> [-1, 1]
])

# Load CIFAR-10 dataset (downloaded into ./data on first use)
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                             download=True, transform=transform)

# Batched, shuffled loader over the training split
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=32,
                                           shuffle=True, num_workers=2)

# CLIP Model class
class CLIPModel(nn.Module):
    """CLIP-style dual encoder: one ViT backbone plus a projection MLP per modality.

    The "text" backbone is the *same module object* as the image backbone, so
    both inputs go through identical ViT weights and only the projection heads
    differ. As a consequence ``texts`` must be something the ViT can consume
    (image-shaped pixel input), not token ids.
    NOTE(review): a genuine text tower (tokenizer + text encoder) is needed
    before this can embed real captions.

    Args:
        embed_dim: Size of the shared embedding space produced by both
            projection heads.
        backbone_name: Hugging Face checkpoint passed to
            ``ViTModel.from_pretrained``.
    """

    def __init__(self, embed_dim=512, backbone_name='google/vit-base-patch16-224-in21k'):
        super().__init__()

        # Image backbone (ViT)
        self.image_backbone = ViTModel.from_pretrained(backbone_name)

        # Text backbone: deliberately the very same ViT instance (weights shared)
        self.text_backbone = self.image_backbone

        # Take the feature width from the checkpoint instead of hard-coding 768,
        # so other ViT sizes work without touching the projection heads.
        hidden_size = self.image_backbone.config.hidden_size

        # Projection MLPs (separate weights per modality)
        self.image_projection = self._make_projection(hidden_size, embed_dim)
        self.text_projection = self._make_projection(hidden_size, embed_dim)

    @staticmethod
    def _make_projection(in_dim, out_dim):
        """Build the two-layer MLP that maps backbone features to the embedding space."""
        return nn.Sequential(
            nn.Linear(in_dim, out_dim),
            nn.ReLU(),
            nn.Linear(out_dim, out_dim),
        )

    def forward(self, images, texts):
        """Embed a batch of images and a batch of "texts".

        Args:
            images: Batch passed to the ViT image backbone.
            texts: Batch passed to the text backbone, which is the same ViT,
                so it must have a layout that ViT accepts.

        Returns:
            Tuple ``(image_features, text_features)`` of projected embeddings,
            one row per batch element. They are not L2-normalised here.
        """
        # Index 0 of the sequence axis is the [CLS] token, used as the pooled feature
        image_features = self.image_backbone(images).last_hidden_state[:, 0]
        text_features = self.text_backbone(texts).last_hidden_state[:, 0]

        # Project image and text features into the shared embedding space
        image_features = self.image_projection(image_features)
        text_features = self.text_projection(text_features)

        return image_features, text_features

# Function to load an image from URL
def load_image_from_url(url, timeout=10):
    """Download ``url`` and open the response body as a PIL image.

    Args:
        url: Address of the image to fetch.
        timeout: Seconds to wait for the server before giving up. Without a
            timeout ``requests.get`` can block indefinitely.

    Returns:
        The ``PIL.Image.Image`` decoded from the response body.

    Raises:
        requests.RequestException: On connection errors, timeouts, or a
            4xx/5xx response.
    """
    response = requests.get(url, timeout=timeout)
    # Fail on HTTP errors here rather than letting PIL choke on an error page
    response.raise_for_status()
    image = Image.open(BytesIO(response.content))
    return image

# Function to calculate contrastive loss
def contrastive_loss(image_features, text_features, temperature=0.07):
    """Symmetric CLIP-style contrastive (InfoNCE) loss for paired embeddings.

    Row ``i`` of ``image_features`` and row ``i`` of ``text_features`` form the
    positive pair; every other row in the batch acts as a negative. The loss is
    the mean of the image->text and text->image cross-entropies, so both
    directions are trained and swapping the arguments gives the same value.

    Args:
        image_features: ``(N, D)`` image embeddings.
        text_features: ``(N, D)`` text embeddings, paired row-wise with the images.
        temperature: Divisor applied to the cosine similarities.

    Returns:
        Scalar tensor with the averaged two-way loss.

    Raises:
        ValueError: If the two batches do not contain the same number of rows.
    """
    if image_features.shape[0] != text_features.shape[0]:
        raise ValueError(
            "image_features and text_features must have the same batch size, "
            f"got {image_features.shape[0]} and {text_features.shape[0]}"
        )

    # Normalize features so the dot product below is a cosine similarity
    image_features = F.normalize(image_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)

    # Pairwise cosine similarities, scaled by the temperature
    logits_per_image = torch.matmul(image_features, text_features.T) / temperature
    logits_per_text = logits_per_image.T

    # The matching partner of row i is column i, so the targets are the diagonal indices
    targets = torch.arange(logits_per_image.shape[0], device=logits_per_image.device)

    # Cross-entropy in both retrieval directions, averaged
    loss_image_to_text = F.cross_entropy(logits_per_image, targets)
    loss_text_to_image = F.cross_entropy(logits_per_text, targets)

    return (loss_image_to_text + loss_text_to_image) / 2

# Put the CLIP model on the GPU when one is available, otherwise on the CPU
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = CLIPModel()
model.to(device)  # nn.Module.to moves the parameters in place

# Adam at lr 1e-4; StepLR multiplies the learning rate by 0.5 every 5 scheduler steps
optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-4)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer=optimizer, step_size=5, gamma=0.5)

# Training loop
def train_model(model, train_loader, optimizer, scheduler, num_epochs=10, log_interval=100):
    """Train ``model`` with the contrastive loss.

    Args:
        model: Module called as ``model(images, texts)`` that returns an
            ``(image_features, text_features)`` pair.
        train_loader: Iterable yielding ``(images, labels)`` batches; the
            labels are ignored.
        optimizer: Optimizer updating the model's parameters.
        scheduler: Learning-rate scheduler, stepped once per epoch.
        num_epochs: Number of passes over ``train_loader``.
        log_interval: Positive number of steps between loss printouts.
    """
    # Use the device the model already lives on instead of a module-level global,
    # so the function works for any model handed to it.
    first_param = next(model.parameters(), None)
    target_device = first_param.device if first_param is not None else None

    model.train()

    for epoch in range(num_epochs):
        running_loss = 0.0
        for i, (images, _) in enumerate(train_loader):
            if target_device is not None:
                images = images.to(target_device)

            # Random placeholder "text" input for demonstration. The model's text
            # backbone is the same ViT as the image backbone, which only accepts
            # image-shaped input, so the placeholder mirrors the image batch's
            # shape, dtype and device (a (batch, 77) tensor makes the ViT raise).
            # NOTE(review): replace with real tokenised captions once the model
            # has a proper text encoder.
            texts = torch.randn_like(images)

            optimizer.zero_grad()

            # Forward pass
            image_features, text_features = model(images, texts)

            # Calculate contrastive loss
            loss = contrastive_loss(image_features, text_features)

            # Backward pass and optimize
            loss.backward()
            optimizer.step()

            running_loss += loss.item()

            # Report the mean loss over the last `log_interval` steps, then reset
            if (i+1) % log_interval == 0:
                print(f"Epoch [{epoch+1}/{num_epochs}], Step [{i+1}/{len(train_loader)}], Loss: {running_loss / log_interval:.4f}")
                running_loss = 0.0

        # Step the scheduler at the end of each epoch
        scheduler.step()
        print(f"Finished Epoch [{epoch+1}/{num_epochs}], LR: {scheduler.get_last_lr()}")

# Script entry point: run a short 5-epoch training session with the objects built above
if __name__ == "__main__":
    train_model(
        model=model,
        train_loader=train_loader,
        optimizer=optimizer,
        scheduler=scheduler,
        num_epochs=5,
    )


Files already downloaded and verified
