In [4]:
import os
import gzip
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from PIL import Image
from transformers import CLIPProcessor, CLIPModel

Downloading data:   8%|▊         | 41.9M/494M [24:57<4:28:56, 28.0kB/s]


In [16]:
import struct
import numpy as np
from PIL import Image

In [17]:
# Function to load the MNIST images from IDX format
def load_mnist_images(filename):
    """Read an uncompressed IDX3 image file into a uint8 array.

    The file starts with a 16-byte big-endian header (magic number, image
    count, row count, column count) followed by one unsigned byte per pixel.

    Args:
        filename: Path to the uncompressed ``*-images.idx3-ubyte`` file.

    Returns:
        ``np.ndarray`` of dtype ``uint8`` with shape ``(num_images, rows, cols)``.

    Raises:
        ValueError: If the magic number is not 2051 (e.g. a label file or a
            still-gzipped file was passed) or if the file holds fewer pixel
            bytes than the header announces.
    """
    with open(filename, 'rb') as f:
        magic, num_images, rows, cols = struct.unpack('>IIII', f.read(16))
        if magic != 2051:
            raise ValueError(
                f"{filename}: not an IDX3 image file (magic number {magic}, expected 2051)"
            )
        expected_pixels = num_images * rows * cols
        # Read exactly what the header promises so trailing bytes cannot
        # corrupt the reshape below.
        images = np.fromfile(f, dtype=np.uint8, count=expected_pixels)
    if images.size != expected_pixels:
        raise ValueError(
            f"{filename}: header announces {expected_pixels} pixel bytes "
            f"but only {images.size} could be read"
        )
    return images.reshape(num_images, rows, cols)

In [18]:
# Function to load the MNIST labels from IDX format
def load_mnist_labels(filename):
    """Read an uncompressed IDX1 label file into a 1-D uint8 array.

    The file starts with an 8-byte big-endian header (magic number, label
    count) followed by one unsigned byte per label.

    Args:
        filename: Path to the uncompressed ``*-labels.idx1-ubyte`` file.

    Returns:
        ``np.ndarray`` of dtype ``uint8`` with shape ``(num_labels,)``.

    Raises:
        ValueError: If the magic number is not 2049 (e.g. an image file or a
            still-gzipped file was passed) or if the file holds fewer labels
            than the header announces.
    """
    with open(filename, 'rb') as f:
        magic, num_labels = struct.unpack('>II', f.read(8))
        if magic != 2049:
            raise ValueError(
                f"{filename}: not an IDX1 label file (magic number {magic}, expected 2049)"
            )
        # Bound the read by the header count so trailing bytes are never
        # mistaken for extra labels.
        labels = np.fromfile(f, dtype=np.uint8, count=num_labels)
    if labels.size != num_labels:
        raise ValueError(
            f"{filename}: header announces {num_labels} labels "
            f"but only {labels.size} could be read"
        )
    return labels

In [26]:
# Load the dataset from the raw MNIST IDX files.
# The loaders read the files with a plain open(), so these must be the
# uncompressed IDX files, not the .gz archives.
DATASET_DIR = './dataset'
train_images = load_mnist_images(f'{DATASET_DIR}/train-images.idx3-ubyte')
train_labels = load_mnist_labels(f'{DATASET_DIR}/train-labels.idx1-ubyte')
test_images = load_mnist_images(f'{DATASET_DIR}/t10k-images.idx3-ubyte')
test_labels = load_mnist_labels(f'{DATASET_DIR}/t10k-labels.idx1-ubyte')

# Each image needs exactly one label; fail here rather than deep inside training.
assert len(train_images) == len(train_labels), "train image/label count mismatch"
assert len(test_images) == len(test_labels), "test image/label count mismatch"


In [20]:
# Turn a raw image array into an RGB PIL image sized for CLIP
def preprocess_image(image):
    """Convert an image array to a 224x224 three-channel PIL image.

    Args:
        image: Array accepted by ``Image.fromarray``.

    Returns:
        A PIL image in ``'RGB'`` mode resized to ``(224, 224)``.
    """
    # array -> PIL -> RGB -> 224x224, as a single chained expression
    return Image.fromarray(image).convert('RGB').resize((224, 224))

In [21]:
# Example usage: preprocess first image from the training set and render it
# inline. Image.show() would launch an external viewer, which does not work on
# remote/headless kernels and leaves nothing in the notebook output.
sample_image = preprocess_image(train_images[0])
sample_image

In [22]:
from torch.utils.data import Dataset, DataLoader
from transformers import CLIPProcessor, CLIPModel

# Dataset that pairs every MNIST image with a short caption for CLIP
class MNISTTextDataset(Dataset):
    """Map-style dataset yielding processor outputs for (image, caption) pairs.

    Each item is the mapping returned by ``processor`` for one preprocessed
    image and the caption ``"This is a <label>"``, with ``squeeze(0)`` applied
    to every tensor in it.
    """

    def __init__(self, images, labels, processor):
        # Only references are stored; per-sample work happens in __getitem__
        self.images, self.labels, self.processor = images, labels, processor

    def __len__(self):
        # The image collection defines the dataset size
        return len(self.images)

    def __getitem__(self, idx):
        pil_image = preprocess_image(self.images[idx])
        caption = f"This is a {self.labels[idx]}"
        encoded = self.processor(text=[caption], images=pil_image, return_tensors="pt", padding=True)
        # squeeze(0) drops the leading axis when it has size 1 (one caption and
        # one image go in), leaving batching to whatever consumes the items
        return {key: tensor.squeeze(0) for key, tensor in encoded.items()}

# Initialize the processor and the dataset
BATCH_SIZE = 8      # samples per batch for both loaders
SHUFFLE_SEED = 42   # fixes the training shuffle order so runs are repeatable

processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
train_dataset = MNISTTextDataset(train_images, train_labels, processor)
test_dataset = MNISTTextDataset(test_images, test_labels, processor)

# Create DataLoaders
# A dedicated, seeded generator makes the shuffle independent of whatever else
# has consumed the global torch RNG before this cell runs.
train_dataloader = DataLoader(
    train_dataset,
    batch_size=BATCH_SIZE,
    shuffle=True,
    generator=torch.Generator().manual_seed(SHUFFLE_SEED),
)
test_dataloader = DataLoader(test_dataset, batch_size=BATCH_SIZE, shuffle=False)


In [24]:
# Model, device and optimizer setup
import torch.nn.functional as F

# Prefer CUDA when torch reports it as available, otherwise stay on the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained CLIP checkpoint and place it on the selected device
model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").to(device)

# AdamW over every model parameter
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-6)

# Training and validation functions sharing one contrastive-loss implementation
def _clip_contrastive_loss(logits_per_image, logits_per_text):
    """Average of the image->text and text->image cross-entropy losses.

    Row ``i`` of each logit matrix is scored against target index ``i``, i.e.
    the i-th image and the i-th text of the batch are treated as the only
    matching pair.

    NOTE(review): this assumes every text in a batch is distinct. The captions
    built by MNISTTextDataset take only as many values as there are distinct
    labels, so a batch can contain identical captions that are then treated as
    negatives for each other - confirm this is acceptable.
    """
    targets = torch.arange(logits_per_image.size(0), device=logits_per_image.device)
    loss_image = F.cross_entropy(logits_per_image, targets)
    loss_text = F.cross_entropy(logits_per_text, targets)
    return (loss_image + loss_text) / 2


def _clip_batch_loss(model, batch, device):
    """Move one batch to ``device``, run the model, and return the scalar loss tensor."""
    inputs = {k: v.to(device) for k, v in batch.items()}
    outputs = model(**inputs)
    return _clip_contrastive_loss(outputs.logits_per_image, outputs.logits_per_text)


def _mean_loss(total_loss, num_batches):
    """Per-batch mean; NaN when no batch was seen, instead of ZeroDivisionError."""
    return total_loss / num_batches if num_batches else float("nan")


def train_clip_model(model, dataloader, optimizer, device):
    """Run one training pass over ``dataloader`` and return the mean batch loss.

    Args:
        model: Module whose forward output exposes ``logits_per_image`` and
            ``logits_per_text``.
        dataloader: Iterable of dict batches mapping argument names to tensors.
        optimizer: Optimizer stepping the model's parameters.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        Mean of the per-batch losses as a float, or NaN if the dataloader
        yielded no batches.
    """
    model.train()
    total_loss = 0.0
    num_batches = 0  # counted while iterating so loaders without __len__ work too
    for batch in dataloader:
        loss = _clip_batch_loss(model, batch, device)

        # Backpropagation and optimization
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        num_batches += 1
    return _mean_loss(total_loss, num_batches)


def validate_clip_model(model, dataloader, device):
    """Evaluate the same contrastive loss without gradient tracking or updates.

    Args:
        model: Module whose forward output exposes ``logits_per_image`` and
            ``logits_per_text``.
        dataloader: Iterable of dict batches mapping argument names to tensors.
        device: Device the batch tensors are moved to before the forward pass.

    Returns:
        Mean of the per-batch losses as a float, or NaN if the dataloader
        yielded no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0
    with torch.no_grad():
        for batch in dataloader:
            total_loss += _clip_batch_loss(model, batch, device).item()
            num_batches += 1
    return _mean_loss(total_loss, num_batches)
