A PyTorch implementation of CLIP (Contrastive Language-Image Pre-training) for learning joint embeddings of images and text. This project trains a model to align visual and textual representations, enabling powerful zero-shot image retrieval from text queries.
This implementation uses the Flickr8k dataset to train a CLIP-style model that learns to match images with their corresponding text descriptions. The model learns a shared embedding space where semantically similar images and texts are close together.
- Image Encoder: ResNet-50 for visual feature extraction
- Text Encoder: DistilBERT for language understanding
- Cross-modal Alignment: Contrastive learning to align image and text embeddings
- Text-to-Image Search: Find images using natural language queries
- Experiment Tracking: SwanLab integration for monitoring training metrics
- Purpose: Extracts visual features from images
- Architecture: Deep convolutional neural network with 50 layers
- Output: 2048-dimensional feature vectors
- Pretrained: Can use ImageNet weights for transfer learning
- Advantages:
- Proven performance on visual recognition tasks
- Skip connections prevent vanishing gradients
- Efficient feature extraction at 224×224 resolution
- Purpose: Encodes text captions into semantic representations
- Architecture: Distilled version of BERT (6 layers vs 12)
- Output: 768-dimensional feature vectors
- Pretrained: Always uses pretrained weights from Hugging Face
- Advantages:
- 40% smaller and 60% faster than BERT
- Retains 97% of BERT's language understanding
- Handles variable-length text up to 200 tokens
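The two encoders live in modules.py; the exact class names there may differ, but a minimal sketch of how such encoders are typically built (timm for ResNet-50 and Hugging Face transformers for DistilBERT, both listed in requirements.txt) looks like this:

import timm
import torch.nn as nn
from transformers import DistilBertModel

class ImageEncoder(nn.Module):
    """ResNet-50 backbone that returns a 2048-dim feature vector per image."""
    def __init__(self, model_name="resnet50", pretrained=True, trainable=True):
        super().__init__()
        # num_classes=0 + global average pooling drop the classification head
        self.model = timm.create_model(model_name, pretrained=pretrained,
                                       num_classes=0, global_pool="avg")
        for p in self.model.parameters():
            p.requires_grad = trainable

    def forward(self, x):
        return self.model(x)  # (B, 2048)

class TextEncoder(nn.Module):
    """DistilBERT backbone; the [CLS] token state is used as the caption embedding."""
    def __init__(self, model_name="distilbert-base-uncased", trainable=True):
        super().__init__()
        self.model = DistilBertModel.from_pretrained(model_name)
        for p in self.model.parameters():
            p.requires_grad = trainable

    def forward(self, input_ids, attention_mask):
        out = self.model(input_ids=input_ids, attention_mask=attention_mask)
        return out.last_hidden_state[:, 0, :]  # (B, 768)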
Both image and text features are projected into a shared 256-dimensional embedding space:
- Input: 2048-dim (image) or 768-dim (text)
- Hidden Layers: Configurable (default: 1 layer)
- Output: 256-dimensional normalized embeddings
- Activation: GELU with dropout (0.1)
- Normalization: L2 normalization for cosine similarity
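A minimal sketch of such a projection head (illustrative; the actual module in modules.py may add a residual connection or layer norm):

import torch.nn as nn
import torch.nn.functional as F

class ProjectionHead(nn.Module):
    """Maps encoder features (2048-dim image / 768-dim text) into the shared 256-dim space."""
    def __init__(self, embedding_dim, projection_dim=256, dropout=0.1):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.projection(x)                 # 2048/768 -> 256
        x = self.dropout(self.fc(self.gelu(x)))
        return F.normalize(x, p=2, dim=-1)     # L2 normalize for cosine similarity

Both projection heads share this structure; only embedding_dim differs (2048 for images, 768 for text).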
The model learns by maximizing similarity between matching image-text pairs while minimizing similarity between non-matching pairs.
How it works:

- Compute Embeddings:
  - Process a batch of N image-text pairs
  - Get normalized embeddings: image_emb (N×256), text_emb (N×256)
- Similarity Matrix:
  - logits = (text_emb @ image_emb.T) / temperature
  - Creates an N×N matrix of all pairwise similarities
  - Temperature (τ = 1.0) controls the sharpness of the similarity distribution
- Contrastive Loss:
  - Positive pairs: same image ID (each image has 5 captions)
  - Negative pairs: different image IDs
  - Loss: cross-entropy over the similarity distribution
  - Symmetric loss: average of image→text and text→image losses
- Training Signal:
  - Pull matching pairs closer in the embedding space
  - Push non-matching pairs further apart
  - The model learns semantic alignment across modalities
Example:

Image: [dog playing in grass] → Embedding: [0.2, 0.8, ..., 0.5]
Text: "a dog on the lawn" → Embedding: [0.3, 0.7, ..., 0.4]
→ High similarity (positive pair)
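Putting the steps above together, here is a hedged sketch of the symmetric loss (not necessarily the exact code in CLIP.py) in which every caption sharing an image ID counts as a positive:

import torch
import torch.nn.functional as F

def clip_contrastive_loss(image_emb, text_emb, image_ids, temperature=1.0):
    """Symmetric contrastive loss over a batch of N L2-normalized embeddings.

    image_emb, text_emb: (N, 256) tensors; image_ids: (N,) tensor so that
    captions sharing an image ID are treated as positives.
    """
    # N x N matrix of pairwise cosine similarities, scaled by temperature
    logits = (text_emb @ image_emb.T) / temperature

    # Positive-pair mask: entry (i, j) = 1 if caption i and image j share an image ID
    positives = (image_ids.unsqueeze(0) == image_ids.unsqueeze(1)).float()
    targets = positives / positives.sum(dim=1, keepdim=True)

    # Cross-entropy in both directions (text -> image and image -> text), averaged
    texts_loss = -(targets * F.log_softmax(logits, dim=1)).sum(dim=1).mean()
    images_loss = -(targets.T * F.log_softmax(logits.T, dim=1)).sum(dim=1).mean()
    return (texts_loss + images_loss) / 2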
- Images: 8,091 images
- Captions: 40,455 captions (5 per image)
- Training Split: 6,473 images (~32,365 pairs)
- Validation Split: 1,618 images (~8,090 pairs)
- Content: Diverse everyday scenes with people, animals, activities
Data Preprocessing:
python data_preprocessing.py

This creates captions.csv with image IDs for train/val splitting.
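Roughly, the preprocessing step amounts to something like the following (file and column names are assumptions based on the standard Flickr8k layout; the actual script may differ):

import pandas as pd

# The raw Flickr8k captions file has one row per caption: an image filename and a caption
df = pd.read_csv("Datasets/captions.txt")   # assumed columns: image, caption

# Give every image an integer ID (its 5 captions share that ID); the ID is later used
# for the positive-pair mask and the random 80/20 train/validation split
df["id"] = df["image"].factorize()[0]
df.to_csv("Datasets/captions.csv", index=False)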
# Model Architecture
model_name = 'resnet50' # Image encoder
text_encoder_model = 'distilbert-base-uncased'
image_embedding = 2048 # ResNet-50 output dim
text_embedding = 768 # DistilBERT output dim
projection_dim = 256 # Shared embedding space
# Training Hyperparameters
batch_size = 8 # 8 image-text pairs per batch
lr = 1e-3 # Learning rate (Adam)
weight_decay = 1e-3 # L2 regularization
epochs = 150 # Training epochs
temperature = 1.0 # Contrastive learning temperature
# Transfer Learning
pretrained = True # Use ImageNet weights (recommended)
trainable = True # Fine-tune encoders (recommended)
# Image Processing
size = 224 # Input image size
max_length = 200 # Max text tokens
Data Loading:
- Load image-caption pairs from CSV
- Random 80/20 train/validation split
- Apply image augmentations (resize, normalize)
- Tokenize captions with the DistilBERT tokenizer

Forward Pass:

# Encode images and text
image_features = image_encoder(images)        # (B, 2048)
text_features = text_encoder(captions)        # (B, 768)

# Project to shared space
image_emb = image_projection(image_features)  # (B, 256)
text_emb = text_projection(text_features)     # (B, 256)

# Normalize embeddings
image_emb = F.normalize(image_emb, p=2, dim=-1)
text_emb = F.normalize(text_emb, p=2, dim=-1)

Loss Computation:
- Compute the similarity matrix: logits = (text_emb @ image_emb.T) / τ
- Create the positive-pair mask based on image IDs
- Calculate the bidirectional cross-entropy loss
- Backpropagate and update weights

Optimization (see the sketch after the Monitoring list below):
- Optimizer: AdamW with weight decay
- LR Scheduler: ReduceLROnPlateau (reduces LR when validation loss plateaus)
- Gradient Clipping: prevents exploding gradients

Monitoring:
- SwanLab tracks: train loss, validation loss, learning rate
- Best model saved based on lowest validation loss
- Metrics logged every 10 batches
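A hedged sketch of that optimization loop (hyperparameter values come from config.py above; model, train_loader, valid_loader, evaluate, and the scheduler's patience/factor are illustrative names and settings, not the exact code in main.py):

import torch

epochs = 150  # from config.py above

optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-3)
scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
    optimizer, mode="min", patience=2, factor=0.5
)

for epoch in range(epochs):
    model.train()
    for batch in train_loader:
        loss = model(batch)                   # forward pass returns the contrastive loss
        optimizer.zero_grad()
        loss.backward()
        # Gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

    val_loss = evaluate(model, valid_loader)  # hypothetical validation helper
    scheduler.step(val_loss)                  # reduce LR when validation loss plateaus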
# Build Docker image
docker build -t clip-training:latest .
# Train with GPU
docker run --rm --gpus '"device=0"' \
--ipc=host \
-e HF_HOME=/workspace/.cache \
-e SWANLAB_API_KEY=your_api_key \
-v $(pwd):/workspace \
-w /workspace \
clip-training:latest
# Monitor at: https://swanlab.cn/@your_username/CLIP-Training

# Make sure to set pretrained=True, trainable=True in config.py
python main.py

Training Time: ~30-45 minutes per epoch on a GB10 GPU (with batch_size=8)
Once trained, you can search for images using natural language queries:
# Load trained model
model.load_state_dict(torch.load("best.pt"))
# Search with text query
query = "dogs playing on the grass"
matches = find_matches(model, image_embeddings, query, n=9)

# Run inference with GPU
docker run --rm --gpus '"device=0"' \
--ipc=host \
-e HF_HOME=/workspace/.cache \
-v $(pwd):/workspace \
-w /workspace \
clip-training:latest python inference.py
# Results saved to: search_results.png

Example queries:
- "a dog running in the park"
- "children playing with a ball"
- "a person riding a bicycle"
- "sunset over mountains"
- "food on a plate"
OpenAI-CLIP/
├── CLIP.py                 # Main CLIP model definition
├── modules.py              # Image/text encoders and projection heads
├── dataset.py              # Dataset and data loading utilities
├── main.py                 # Training script with SwanLab integration
├── inference.py            # Text-to-image search and visualization
├── data_preprocessing.py   # Prepares captions.csv from raw data
├── config.py               # Hyperparameters and configuration
├── utils.py                # Helper functions (AvgMeter, get_lr)
├── Dockerfile              # Docker environment for GB10 GPU
├── setup.txt               # Docker commands reference
├── requirements.txt        # Python dependencies
└── Datasets/
    ├── Images/             # Flickr8k image files
    └── captions.csv        # Preprocessed image-caption pairs
Clone Repository:

git clone https://github.com/moein-shariatnia/OpenAI-CLIP.git
cd OpenAI-CLIP

Download Flickr8k Dataset:

# Download from Kaggle
curl -L -o ~/Downloads/flickr8k.zip \
  https://www.kaggle.com/api/v1/datasets/download/adityajn105/flickr8k
# Extract to project
unzip ~/Downloads/flickr8k.zip -d Datasets/

Preprocess Data:

python data_preprocessing.py

Install Dependencies:

pip install torch torchvision transformers pandas pillow tqdm \
  albumentations timm swanlab opencv-python matplotlib

Build Docker Image (for GB10 GPU):

docker build -t clip-training:latest .
- Train Loss: Contrastive loss on training set
- Validation Loss: Contrastive loss on held-out validation set
- Learning Rate: Adaptively reduced when the validation loss plateaus
- Convergence: Loss should decrease steadily over 50-100 epochs
- Validation Loss: Should track training loss (gap indicates overfitting)
- Best Model: Saved when validation loss reaches new minimum
- Model can retrieve relevant images for arbitrary text queries
- Performance improves with:
- More training epochs
- Larger batch sizes
- Pretrained encoders
- Data augmentation