OpenAI CLIP: Contrastive Language-Image Pre-training

A PyTorch implementation of CLIP (Contrastive Language-Image Pre-training) for learning joint embeddings of images and text. This project trains a model to align visual and textual representations, enabling powerful zero-shot image retrieval from text queries.

🎯 Project Overview

This implementation uses the Flickr8k dataset to train a CLIP-style model that learns to match images with their corresponding text descriptions. The model learns a shared embedding space where semantically similar images and texts are close together.

Key Features

  • 🖼️ Image Encoder: ResNet-50 for visual feature extraction
  • 📝 Text Encoder: DistilBERT for language understanding
  • 🔗 Cross-modal Alignment: Contrastive learning to align image and text embeddings
  • 🔍 Text-to-Image Search: Find images using natural language queries
  • 📊 Experiment Tracking: SwanLab integration for monitoring training metrics

πŸ—οΈ Architecture

Image Encoder: ResNet-50

  • Purpose: Extracts visual features from images
  • Architecture: Deep convolutional neural network with 50 layers
  • Output: 2048-dimensional feature vectors
  • Pretrained: Can use ImageNet weights for transfer learning
  • Advantages:
    • Proven performance on visual recognition tasks
    • Skip connections prevent vanishing gradients
    • Efficient feature extraction at 224×224 resolution
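
Below is a minimal sketch of how such a ResNet-50 wrapper can be built with timm (which is listed in the dependencies). Class and argument names are illustrative and not necessarily those used in modules.py:

import timm
import torch.nn as nn

class ImageEncoder(nn.Module):
    """ResNet-50 backbone mapping a (B, 3, 224, 224) batch to (B, 2048) features."""
    def __init__(self, model_name="resnet50", pretrained=True, trainable=True):
        super().__init__()
        # num_classes=0 with average pooling drops the classifier head,
        # so the output is the pooled 2048-dim feature vector.
        self.model = timm.create_model(
            model_name, pretrained=pretrained, num_classes=0, global_pool="avg"
        )
        for p in self.model.parameters():
            p.requires_grad = trainable

    def forward(self, x):
        return self.model(x)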

Text Encoder: DistilBERT

  • Purpose: Encodes text captions into semantic representations
  • Architecture: Distilled version of BERT (6 layers vs 12)
  • Output: 768-dimensional feature vectors
  • Pretrained: Always uses pretrained weights from Hugging Face
  • Advantages:
    • 40% smaller and 60% faster than BERT
    • Retains 97% of BERT's language understanding
    • Handles variable-length text up to 200 tokens
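
A comparable sketch for the text side, using the Hugging Face transformers API (again, names are illustrative; the actual class lives in modules.py):

import torch.nn as nn
from transformers import DistilBertModel

class TextEncoder(nn.Module):
    """DistilBERT encoder mapping tokenized captions to 768-dim sentence features."""
    def __init__(self, model_name="distilbert-base-uncased", trainable=True):
        super().__init__()
        # Pretrained weights are always loaded from Hugging Face.
        self.model = DistilBertModel.from_pretrained(model_name)
        for p in self.model.parameters():
            p.requires_grad = trainable

    def forward(self, input_ids, attention_mask):
        output = self.model(input_ids=input_ids, attention_mask=attention_mask)
        # Use the hidden state of the first ([CLS]) token as the sentence embedding.
        return output.last_hidden_state[:, 0, :]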

Projection Heads

Both image and text features are projected into a shared 256-dimensional embedding space:

  • Input: 2048-dim (image) or 768-dim (text)
  • Hidden Layers: Configurable (default: 1 layer)
  • Output: 256-dimensional normalized embeddings
  • Activation: GELU with dropout (0.1)
  • Normalization: L2 normalization for cosine similarity
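
A minimal sketch of such a projection head. The exact layer layout in modules.py may differ; the L2 normalization listed above is applied to the output, as shown later in the forward-pass snippet:

import torch.nn as nn

class ProjectionHead(nn.Module):
    """Maps encoder features (2048-dim image / 768-dim text) to the shared 256-dim space."""
    def __init__(self, embedding_dim, projection_dim=256, dropout=0.1):
        super().__init__()
        self.projection = nn.Linear(embedding_dim, projection_dim)
        self.gelu = nn.GELU()
        self.fc = nn.Linear(projection_dim, projection_dim)
        self.dropout = nn.Dropout(dropout)

    def forward(self, x):
        x = self.projection(x)                   # e.g. 2048 -> 256
        x = self.dropout(self.fc(self.gelu(x)))  # GELU activation with dropout(0.1)
        return x                                 # L2-normalized afterwards for cosine similarity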

🔄 Text-Image Alignment

Contrastive Learning Objective

The model learns by maximizing similarity between matching image-text pairs while minimizing similarity between non-matching pairs.

How it works:

  1. Compute Embeddings:

    • Process batch of N image-text pairs
    • Get normalized embeddings: image_emb (N×256), text_emb (N×256)
  2. Similarity Matrix:

    logits = (text_emb @ image_emb.T) / temperature
    
    • Creates N×N matrix of all pairwise similarities
    • Temperature (τ=1.0) controls the sharpness of the distribution
  3. Contrastive Loss:

    • Positive pairs: Same image ID (each image has 5 captions)
    • Negative pairs: Different image IDs
    • Loss = Cross-entropy over the similarity distribution
    • Symmetric loss: average of image→text and text→image (see the loss sketch after the example below)
  4. Training Signal:

    • Pull matching pairs closer in embedding space
    • Push non-matching pairs further apart
    • Model learns semantic alignment across modalities

Example:

Image: [dog playing in grass]  →  Embedding: [0.2, 0.8, ..., 0.5]
Text:  "a dog on the lawn"     →  Embedding: [0.3, 0.7, ..., 0.4]
                                    ↓
                            High Similarity (positive pair)
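
A minimal sketch of the symmetric, mask-based loss described above (a simplification; the exact masking logic in CLIP.py may differ):

import torch
import torch.nn.functional as F

def clip_loss(image_emb, text_emb, image_ids, temperature=1.0):
    """Symmetric contrastive loss over a batch of L2-normalized embeddings.

    image_emb, text_emb: (N, 256); image_ids: (N,) so that the 5 captions of one
    image count as positives for each other.
    """
    logits = (text_emb @ image_emb.T) / temperature              # (N, N) similarity matrix
    # Positive-pair mask: 1 where caption i and image j share the same image ID.
    mask = (image_ids.unsqueeze(0) == image_ids.unsqueeze(1)).float()
    targets = mask / mask.sum(dim=1, keepdim=True)               # soft targets over all positives

    # The mask is symmetric, so the same targets serve both directions.
    text_to_image = -(targets * F.log_softmax(logits, dim=1)).sum(dim=1).mean()
    image_to_text = -(targets * F.log_softmax(logits.T, dim=1)).sum(dim=1).mean()
    return (text_to_image + image_to_text) / 2                   # symmetric average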

📊 Dataset

Flickr8k

  • Images: 8,091 images
  • Captions: 40,455 captions (5 per image)
  • Training Split: 6,473 images (~32,365 pairs)
  • Validation Split: 1,618 images (~8,090 pairs)
  • Content: Diverse everyday scenes with people, animals, activities

Data Preprocessing:

python data_preprocessing.py

This creates captions.csv with image IDs for train/val splitting.
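
Roughly, the script reads the raw Kaggle annotation file and attaches a numeric ID per image. The sketch below assumes the Kaggle file is captions.txt with "image" and "caption" columns; column names are assumptions, not necessarily what data_preprocessing.py uses:

import pandas as pd

# Assumed layout of the raw Kaggle file: one row per caption, columns "image", "caption".
df = pd.read_csv("Datasets/captions.txt")
# Give all 5 captions of an image the same numeric id,
# used later for the train/val split and for positive-pair masking.
df["id"] = df["image"].factorize()[0]
df.to_csv("Datasets/captions.csv", index=False)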


🚀 Training Process

Configuration (config.py)

# Model Architecture
model_name = 'resnet50'              # Image encoder
text_encoder_model = 'distilbert-base-uncased'
image_embedding = 2048               # ResNet-50 output dim
text_embedding = 768                 # DistilBERT output dim
projection_dim = 256                 # Shared embedding space

# Training Hyperparameters
batch_size = 8                       # 8 image-text pairs per batch
lr = 1e-3                           # Learning rate (AdamW)
weight_decay = 1e-3                 # L2 regularization
epochs = 150                        # Training epochs
temperature = 1.0                   # Contrastive learning temperature

# Transfer Learning
pretrained = True                   # Use ImageNet weights (recommended)
trainable = True                    # Fine-tune encoders (recommended)

# Image Processing
size = 224                          # Input image size
max_length = 200                    # Max text tokens

Training Pipeline

  1. Data Loading:

    • Load image-caption pairs from CSV
    • Random 80/20 train/validation split
    • Apply image augmentations (resize, normalize)
    • Tokenize captions with DistilBERT tokenizer
  2. Forward Pass:

    # Encode images and text
    image_features = image_encoder(images)      # (B, 2048)
    text_features = text_encoder(captions)       # (B, 768)
    
    # Project to shared space
    image_emb = image_projection(image_features) # (B, 256)
    text_emb = text_projection(text_features)    # (B, 256)
    
    # Normalize embeddings
    image_emb = F.normalize(image_emb, p=2, dim=-1)
    text_emb = F.normalize(text_emb, p=2, dim=-1)
  3. Loss Computation:

    • Compute similarity matrix: logits = (text_emb @ image_emb.T) / τ
    • Create positive pair mask based on image IDs
    • Calculate bidirectional cross-entropy loss
    • Backpropagate and update weights
  4. Optimization:

    • Optimizer: AdamW with weight decay
    • LR Scheduler: ReduceLROnPlateau (reduces LR when validation loss plateaus)
    • Gradient Clipping: Prevents exploding gradients (see the loop sketch after this list)
  5. Monitoring:

    • SwanLab tracks: train loss, validation loss, learning rate
    • Best model saved based on lowest validation loss
    • Metrics logged every 10 batches
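
A minimal sketch of steps 3-5 as one loop. The scheduler settings, the assumption that model(batch) returns the contrastive loss, and the swanlab.log call (swanlab.init is assumed to happen earlier in main.py) are illustrative; see main.py for the actual implementation:

import torch
import swanlab

def train(model, train_loader, valid_loader, epochs=150):
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3, weight_decay=1e-3)
    scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", patience=1, factor=0.8
    )

    best_val_loss = float("inf")
    for epoch in range(epochs):
        model.train()
        for step, batch in enumerate(train_loader):
            loss = model(batch)                              # forward pass + contrastive loss
            optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)  # gradient clipping
            optimizer.step()
            if step % 10 == 0:                               # metrics logged every 10 batches
                swanlab.log({"train_loss": loss.item(), "lr": optimizer.param_groups[0]["lr"]})

        model.eval()
        with torch.no_grad():
            val_loss = sum(model(b).item() for b in valid_loader) / len(valid_loader)
        scheduler.step(val_loss)                             # reduce LR when validation plateaus
        if val_loss < best_val_loss:                         # keep the lowest-validation-loss checkpoint
            best_val_loss = val_loss
            torch.save(model.state_dict(), "best.pt")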

Training Command (Docker)

# Build Docker image
docker build -t clip-training:latest .

# Train with GPU
docker run --rm --gpus '"device=0"' \
  --ipc=host \
  -e HF_HOME=/workspace/.cache \
  -e SWANLAB_API_KEY=your_api_key \
  -v $(pwd):/workspace \
  -w /workspace \
  clip-training:latest

# Monitor at: https://swanlab.cn/@your_username/CLIP-Training

Training Command (Local)

# Make sure to set pretrained=True, trainable=True in config.py
python main.py

Training Time: ~30-45 minutes per epoch on GB10 GPU (with batch_size=8)


πŸ” Inference

Text-to-Image Search

Once trained, you can search for images using natural language queries:

# Load trained model
model.load_state_dict(torch.load("best.pt"))

# Search with text query
query = "dogs playing on the grass"
matches = find_matches(model, image_embeddings, query, n=9)
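
find_matches is defined in inference.py; the sketch below shows what it plausibly does. The attribute names model.text_encoder and model.text_projection, and the assumption that image_embeddings are precomputed, L2-normalized validation embeddings, are guesses about the actual code:

import torch
import torch.nn.functional as F
from transformers import DistilBertTokenizer

def find_matches(model, image_embeddings, query, n=9, device="cuda"):
    """Return indices of the n images whose embeddings best match the text query."""
    tokenizer = DistilBertTokenizer.from_pretrained("distilbert-base-uncased")
    encoded = tokenizer([query], padding=True, truncation=True, max_length=200, return_tensors="pt")
    with torch.no_grad():
        text_features = model.text_encoder(
            input_ids=encoded["input_ids"].to(device),
            attention_mask=encoded["attention_mask"].to(device),
        )
        text_emb = F.normalize(model.text_projection(text_features), p=2, dim=-1)
    # Cosine similarity against the precomputed, L2-normalized image embeddings.
    similarity = text_emb @ image_embeddings.T             # (1, num_images)
    _, indices = similarity.squeeze(0).topk(n)
    return indices                                         # indices into the image list used for plotting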

Running Inference (Docker)

# Run inference with GPU
docker run --rm --gpus '"device=0"' \
  --ipc=host \
  -e HF_HOME=/workspace/.cache \
  -v $(pwd):/workspace \
  -w /workspace \
  clip-training:latest python inference.py

# Results saved to: search_results.png

Example Queries

  • "a dog running in the park"
  • "children playing with a ball"
  • "a person riding a bicycle"
  • "sunset over mountains"
  • "food on a plate"

πŸ“ Project Structure

OpenAI-CLIP/
├── CLIP.py                 # Main CLIP model definition
├── modules.py              # Image/text encoders and projection heads
├── dataset.py              # Dataset and data loading utilities
├── main.py                 # Training script with SwanLab integration
├── inference.py            # Text-to-image search and visualization
├── data_preprocessing.py   # Prepares captions.csv from raw data
├── config.py               # Hyperparameters and configuration
├── utils.py                # Helper functions (AvgMeter, get_lr)
├── Dockerfile              # Docker environment for GB10 GPU
├── setup.txt               # Docker commands reference
├── requirements.txt        # Python dependencies
└── Datasets/
    ├── Images/             # Flickr8k image files
    └── captions.csv        # Preprocessed image-caption pairs

Setup

  1. Clone Repository:

    git clone https://github.com/moein-shariatnia/OpenAI-CLIP.git
    cd OpenAI-CLIP
  2. Download Flickr8k Dataset:

    # Download from Kaggle
    curl -L -o ~/Downloads/flickr8k.zip \
      https://www.kaggle.com/api/v1/datasets/download/adityajn105/flickr8k
    
    # Extract to project
    unzip ~/Downloads/flickr8k.zip -d Datasets/
  3. Preprocess Data:

    python data_preprocessing.py
  4. Install Dependencies:

    pip install torch torchvision transformers pandas pillow tqdm \
      albumentations timm swanlab opencv-python matplotlib
  5. Build Docker Image (for GB10 GPU):

    docker build -t clip-training:latest .

📈 Results & Metrics

Training Metrics

  • Train Loss: Contrastive loss on training set
  • Validation Loss: Contrastive loss on held-out validation set
  • Learning Rate: Adaptively reduced when validation plateaus

Expected Performance

  • Convergence: Loss should decrease steadily over 50-100 epochs
  • Validation Loss: Should track training loss (gap indicates overfitting)
  • Best Model: Saved whenever the validation loss reaches a new minimum

Zero-Shot Retrieval

  • Model can retrieve relevant images for arbitrary text queries
  • Performance improves with:
    • More training epochs
    • Larger batch sizes
    • Pretrained encoders
    • Data augmentation
