# CLIP-based Image Similarity Training Pipeline
## Complete Training Notebook for Google Colab

This notebook implements a two-phase triplet learning approach using CLIP:
- **Phase 1**: Train projection head (CLIP frozen)
- **Phase 2**: Fine-tune CLIP layers + projection head

### Features:
- Automatic dataset discovery from folder structure
- CLIP-based feature extraction
- Triplet loss with margin
- Learning rate scheduling
- Early stopping
- Model checkpointing
- Training visualization

In [1]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


## Setup: Install Dependencies

In [2]:
# Install required packages
!pip install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install ftfy regex tqdm
!pip install git+https://github.com/openai/CLIP.git
!pip install Pillow numpy matplotlib seaborn pandas scikit-learn

Looking in indexes: https://download.pytorch.org/whl/cu118
Collecting ftfy
  Downloading ftfy-6.3.1-py3-none-any.whl.metadata (7.3 kB)
Downloading ftfy-6.3.1-py3-none-any.whl (44 kB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m44.8/44.8 kB[0m [31m3.2 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ftfy
Successfully installed ftfy-6.3.1
Collecting git+https://github.com/openai/CLIP.git
  Cloning https://github.com/openai/CLIP.git to /tmp/pip-req-build-lid0qkgt
  Running command git clone --filter=blob:none --quiet https://github.com/openai/CLIP.git /tmp/pip-req-build-lid0qkgt
  Resolved https://github.com/openai/CLIP.git to commit dcba3cb2e2827b402d2701e7e1c7d9fed8a20ef1
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: clip
  Building wheel for clip (setup.py) ... [?25l[?25hdone
  Created wheel for clip: filename=clip-1.0-py3-none-any.whl size=1369490 sha256=acea49e3d1240f70bd6d95006ce73f60ad2c27152ed13

In [4]:
!unzip "/content/drive/MyDrive/raw-img.zip" -d "/content/drive/MyDrive/raw-img"

Streaming output truncated to the last 5000 lines.
  inflating: /content/drive/MyDrive/raw-img/raw-img/spider/OIP-eBblY9bQyUrRADbLuh7pMQHaE8.jpeg  
  inflating: /content/drive/MyDrive/raw-img/raw-img/spider/OIP-Ebel8CWMjbPSxpiQs-hAKwHaFi.jpeg  
  inflating: /content/drive/MyDrive/raw-img/raw-img/spider/OIP-EbErR8KNlYyBM9jlNrIGIAHaGK.jpeg  
  inflating: /content/drive/MyDrive/raw-img/raw-img/spider/OIP-eBN9F0r9OL4COmgyEYGvaAEVEs.jpeg  
  inflating: /content/drive/MyDrive/raw-img/raw-img/spider/OIP-EbpsFIp42kQg7B467MR5bQHaE7.jpeg  
  inflating: /content/drive/MyDrive/raw-img/raw-img/spider/OIP-ecbjPxEXH_UreyOy1xjZlQHaFH.jpeg  
  inflating: /content/drive/MyDrive/raw-img/raw-img/spider/OIP-EcBwehwooCBORsmW1hDq6gHaGm.jpeg  
  inflating: /content/drive/MyDrive/raw-img/raw-img/spider/OIP-ECg-But0WowvNZcIb8QIeAHaEo.jpeg  
  inflating: /content/drive/MyDrive/raw-img/raw-img/spider/OIP-EcIZuLV7tnHDYHDPO0wVdAHaE7.jpeg  

## Dataset Location

Google Drive is already mounted in the first cell. Set `DATA_DIR` to the folder that contains one subfolder per class:

In [5]:
# Google Drive is already mounted in the first cell; uncomment only if you skipped it
# from google.colab import drive
# drive.mount('/content/drive')

# Set your data directory path (each subfolder is one class)
# DATA_DIR = '/content/drive/MyDrive/your-dataset-folder'
DATA_DIR = '/content/drive/MyDrive/raw-img/raw-img'  # Archive extracted to Google Drive in the previous cell
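
Reading tens of thousands of small files through the Drive mount can be considerably slower than reading from the Colab VM's local disk. A minimal sketch, assuming you prefer to extract the same archive to local storage instead; `extract_archive` and the `/content/raw-img` destination are hypothetical choices, not part of the original pipeline, and local files disappear when the session ends.

```python
import zipfile
from pathlib import Path


def extract_archive(zip_path: str, dest_dir: str) -> Path:
    """Extract a zip archive into dest_dir and return the destination path.

    Hypothetical helper: extraction is skipped when dest_dir already has
    content, so re-running the cell does not extract the archive twice.
    """
    dest = Path(dest_dir)
    if dest.exists() and any(dest.iterdir()):
        return dest
    dest.mkdir(parents=True, exist_ok=True)
    with zipfile.ZipFile(zip_path) as archive:
        archive.extractall(dest)
    return dest


# Example (uncomment to train from local disk instead of Drive):
# local_root = extract_archive("/content/drive/MyDrive/raw-img.zip", "/content/raw-img")
# DATA_DIR = str(local_root / "raw-img")  # the archive has a top-level raw-img/ folder
```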

## Configuration

All hyperparameters and training settings

In [7]:
from pathlib import Path

# ============= PATHS =============
BASE_DIR = Path('/content')
OUTPUT_DIR = BASE_DIR / 'outputs'
CHECKPOINT_DIR = OUTPUT_DIR / 'checkpoints'

# Create directories
OUTPUT_DIR.mkdir(exist_ok=True)
CHECKPOINT_DIR.mkdir(exist_ok=True)

# ============= MODEL CONFIGURATION =============
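CLIP_MODEL_NAME = "ViT-B/32"  # Options: "ViT-B/32", "ViT-B/16", "RN50" (phase-2 unfreezing targets ViT blocks only)
CLIP_EMBEDDING_DIM = 512  # Informational, not read by the model: 512 for ViT-B/32 and ViT-B/16, 1024 for RN50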
PROJECTION_DIM = 128
USE_PROJECTION_HEAD = True
PROJECTION_DROPOUT = 0.1

# ============= TRAINING CONFIGURATION =============
PHASE1_EPOCHS = 10  # Train only projection head
PHASE2_EPOCHS = 5   # Fine-tune CLIP + projection
TOTAL_EPOCHS = PHASE1_EPOCHS + PHASE2_EPOCHS
UNFREEZE_LAST_N_LAYERS = 2  # Number of CLIP layers to unfreeze in phase 2

# Batch size and workers
BATCH_SIZE = 32
NUM_WORKERS = 2  # For Colab

# Optimizer settings
LEARNING_RATE = 1e-4
WEIGHT_DECAY = 1e-4

# ============= TRIPLET LOSS CONFIGURATION =============
TRIPLET_MARGIN = 0.3

# ============= DATASET CONFIGURATION =============
TRAIN_VAL_SPLIT = 0.8  # 80% train, 20% validation
TRIPLETS_PER_IMAGE = 3  # Triplets per image per epoch

# ============= LOGGING CONFIGURATION =============
LOG_INTERVAL = 10
EARLY_STOPPING_PATIENCE = 5

# ============= FILE NAMES =============
BEST_MODEL_PATH = CHECKPOINT_DIR / "best_triplet_model.pth"
LAST_MODEL_PATH = CHECKPOINT_DIR / "last_triplet_model.pth"
TRAINING_LOG = OUTPUT_DIR / "training_log.csv"

# ============= REPRODUCIBILITY =============
RANDOM_SEED = 42

print("✓ Configuration loaded")
print(f"Output directory: {OUTPUT_DIR}")
print(f"Checkpoint directory: {CHECKPOINT_DIR}")

✓ Configuration loaded
Output directory: /content/outputs
Checkpoint directory: /content/outputs/checkpoints


## Import Libraries

In [8]:
import os
import random
import time
from collections import defaultdict
from typing import List, Dict, Tuple, Optional

import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from PIL import Image
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import Dataset, DataLoader
import clip

# Set style
sns.set_style('darkgrid')
plt.rcParams['figure.figsize'] = (12, 6)

print(f"PyTorch version: {torch.__version__}")
print(f"CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    print(f"CUDA device: {torch.cuda.get_device_name(0)}")

PyTorch version: 2.9.0+cu126
CUDA available: True
CUDA device: Tesla T4


## Set Random Seed for Reproducibility

In [9]:
def set_seed(seed: int = 42):
    """Set random seed for reproducibility"""
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    print(f"✓ Random seed set to {seed}")

set_seed(RANDOM_SEED)

✓ Random seed set to 42


## Dataset Classes

Custom dataset for triplet learning with automatic class discovery from folder structure

In [10]:
class TripletImageDataset(Dataset):
    """
    Dataset that generates triplets (anchor, positive, negative) for metric learning.

    Folder structure:
    data_dir/
        class1/
            img1.jpg
            img2.jpg
        class2/
            img1.jpg
            img2.jpg
    """

    def __init__(self, root_dir: str, transform=None, triplets_per_image: int = 3,
                 subset_size: Optional[int] = None):
        self.root_dir = Path(root_dir)
        self.transform = transform
        self.triplets_per_image = triplets_per_image

        # Discover classes and build index
        self.class_to_images = self._build_class_index()
        self.classes = list(self.class_to_images.keys())

        # Flatten all image paths
        self.all_image_paths = []
        self.image_to_class = {}

        for class_name, image_paths in self.class_to_images.items():
            for img_path in image_paths:
                self.all_image_paths.append(img_path)
                self.image_to_class[img_path] = class_name

        # Apply subset if specified
        if subset_size is not None and subset_size < len(self.all_image_paths):
            self.all_image_paths = random.sample(self.all_image_paths, subset_size)

        print(f"Dataset: {len(self.classes)} classes, {len(self.all_image_paths)} images")
        print(f"Classes: {', '.join(self.classes)}")

    def _build_class_index(self) -> Dict[str, List[str]]:
        """Scan directory and build class-to-images mapping"""
        class_to_images = defaultdict(list)
        valid_extensions = {'.jpg', '.jpeg', '.png', '.bmp', '.gif', '.tiff'}

        for class_dir in self.root_dir.iterdir():
            if not class_dir.is_dir():
                continue

            class_name = class_dir.name
            for img_path in class_dir.iterdir():
                if img_path.suffix.lower() in valid_extensions:
                    class_to_images[class_name].append(str(img_path))

        # Filter out classes with too few images
        class_to_images = {k: v for k, v in class_to_images.items() if len(v) >= 2}
        return dict(class_to_images)

    def _load_and_transform_image(self, image_path: str) -> torch.Tensor:
        """Load and preprocess image"""
        image = Image.open(image_path).convert('RGB')
        if self.transform:
            image = self.transform(image)
        return image

    def _get_positive_sample(self, anchor_path: str, anchor_class: str) -> str:
        """Get positive sample (same class, different image)"""
        positive_candidates = [p for p in self.class_to_images[anchor_class] if p != anchor_path]
        if not positive_candidates:
            return anchor_path
        return random.choice(positive_candidates)

    def _get_negative_sample(self, anchor_class: str) -> str:
        """Get negative sample (different class)"""
        negative_classes = [c for c in self.classes if c != anchor_class]
        negative_class = random.choice(negative_classes)
        return random.choice(self.class_to_images[negative_class])

    def __len__(self) -> int:
        return len(self.all_image_paths) * self.triplets_per_image

    def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        # Map idx to anchor image
        anchor_idx = idx // self.triplets_per_image
        anchor_path = self.all_image_paths[anchor_idx]
        anchor_class = self.image_to_class[anchor_path]

        # Get positive and negative samples
        positive_path = self._get_positive_sample(anchor_path, anchor_class)
        negative_path = self._get_negative_sample(anchor_class)

        # Load and transform images
        anchor_img = self._load_and_transform_image(anchor_path)
        positive_img = self._load_and_transform_image(positive_path)
        negative_img = self._load_and_transform_image(negative_path)

        return anchor_img, positive_img, negative_img

print("✓ Dataset class defined")

✓ Dataset class defined
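
The index arithmetic in `__getitem__` maps `triplets_per_image` consecutive indices to the same anchor, and every call draws a fresh positive and negative. A stdlib-only sketch of that sampling logic; `sample_triplet` is a hypothetical helper, and its optional per-index `seed` is an addition the class above does not have (seeding by index would give the validation split the same triplets every epoch).

```python
import random
from typing import Dict, List, Optional, Tuple


def sample_triplet(idx: int, image_paths: List[str], image_to_class: Dict[str, str],
                   class_to_images: Dict[str, List[str]], triplets_per_image: int,
                   seed: Optional[int] = None) -> Tuple[str, str, str]:
    """Return (anchor, positive, negative) paths for dataset index idx.

    Mirrors TripletImageDataset.__getitem__; passing `seed` makes the draw
    depend only on (seed, idx), i.e. fixed across epochs.
    """
    rng = random.Random(seed * 1_000_003 + idx) if seed is not None else random
    anchor = image_paths[idx // triplets_per_image]
    anchor_class = image_to_class[anchor]

    # Positive: same class, different image (fall back to the anchor itself)
    candidates = [p for p in class_to_images[anchor_class] if p != anchor]
    positive = rng.choice(candidates) if candidates else anchor

    # Negative: a random image from a random other class
    negative_class = rng.choice([c for c in class_to_images if c != anchor_class])
    negative = rng.choice(class_to_images[negative_class])
    return anchor, positive, negative


class_to_images = {"cat": ["cat/1.jpg", "cat/2.jpg"], "dog": ["dog/1.jpg", "dog/2.jpg"]}
image_to_class = {p: c for c, ps in class_to_images.items() for p in ps}
paths = sorted(image_to_class)
print(sample_triplet(4, paths, image_to_class, class_to_images, triplets_per_image=3, seed=0))
```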


## Model Architecture

CLIP-based triplet network with optional projection head

In [11]:
class ProjectionHead(nn.Module):
    """Projection head to reduce dimensionality and adapt features"""

    def __init__(self, input_dim: int, output_dim: int, dropout: float = 0.1):
        super().__init__()

        self.projection = nn.Sequential(
            nn.Linear(input_dim, input_dim),
            nn.BatchNorm1d(input_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout),
            nn.Linear(input_dim, output_dim)
        )

        self._init_weights()

    def _init_weights(self):
        """Initialize weights using Xavier initialization"""
        for m in self.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = self.projection(x)
        # L2 normalize to unit hypersphere
        x = F.normalize(x, p=2, dim=1)
        return x


class CLIPTripletModel(nn.Module):
    """Triplet network using CLIP image encoder with projection head"""

    def __init__(self, clip_model_name: str = "ViT-B/32", projection_dim: Optional[int] = None,
                 dropout: float = 0.1, freeze_clip: bool = True):
        super().__init__()

        # Load pretrained CLIP
        self.clip_model, _ = clip.load(clip_model_name, device="cpu")

        # Read the CLIP image-embedding width from the visual encoder itself; a name-based
        # guess gets checkpoints such as RN50 (1024-d) and RN50x4 (640-d) wrong
        self.clip_embedding_dim = self.clip_model.visual.output_dim

        # Freeze CLIP if specified
        if freeze_clip:
            self.freeze_clip_backbone()

        # Projection head
        self.use_projection = projection_dim is not None
        if self.use_projection:
            self.projection_head = ProjectionHead(self.clip_embedding_dim, projection_dim, dropout)
            self.output_dim = projection_dim
        else:
            self.projection_head = None
            self.output_dim = self.clip_embedding_dim

        print(f"Model: {clip_model_name}, CLIP dim: {self.clip_embedding_dim}, Output dim: {self.output_dim}")

    def freeze_clip_backbone(self):
        """Freeze all CLIP parameters"""
        for param in self.clip_model.parameters():
            param.requires_grad = False

    def unfreeze_last_n_layers(self, n: int):
        """Unfreeze last N transformer blocks of CLIP"""
        self.freeze_clip_backbone()

        if hasattr(self.clip_model.visual, 'transformer'):
            total_blocks = len(self.clip_model.visual.transformer.resblocks)
            print(f"Unfreezing last {n} of {total_blocks} transformer blocks")

            for i in range(total_blocks - n, total_blocks):
                for param in self.clip_model.visual.transformer.resblocks[i].parameters():
                    param.requires_grad = True

            # Unfreeze layer norm and projection
            if hasattr(self.clip_model.visual, 'ln_post'):
                for param in self.clip_model.visual.ln_post.parameters():
                    param.requires_grad = True

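            if hasattr(self.clip_model.visual, 'proj'):
                if self.clip_model.visual.proj is not None:
                    self.clip_model.visual.proj.requires_grad = True
        else:
            # ResNet backbones (e.g. "RN50") have no transformer blocks to unfreeze
            print("⚠ No transformer blocks found in CLIP visual encoder; backbone stays frozen")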

    def encode_image(self, image: torch.Tensor) -> torch.Tensor:
        """Encode image using CLIP and optional projection head"""
        with torch.set_grad_enabled(self.training):
            clip_features = self.clip_model.encode_image(image).float()

        if self.use_projection:
            embeddings = self.projection_head(clip_features)
        else:
            embeddings = F.normalize(clip_features, p=2, dim=1)

        return embeddings

    def forward(self, anchor: torch.Tensor, positive: torch.Tensor,
                negative: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Forward pass for triplet"""
        anchor_emb = self.encode_image(anchor)
        positive_emb = self.encode_image(positive)
        negative_emb = self.encode_image(negative)
        return anchor_emb, positive_emb, negative_emb

    def count_parameters(self) -> dict:
        """Count trainable and total parameters"""
        total_params = sum(p.numel() for p in self.parameters())
        trainable_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
        return {
            'total': total_params,
            'trainable': trainable_params,
            'frozen': total_params - trainable_params
        }

print("✓ Model architecture defined")

✓ Model architecture defined
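
The projection head's input size has to match CLIP's image-embedding width, which differs across checkpoints. A small sketch pairing a lookup of the published OpenAI CLIP widths (an offline reference only; the loaded model's encoder is the authoritative source) with the head's parameter count; `CLIP_EMBED_DIMS` and `projection_head_params` are hypothetical names. For ViT-B/32 with a 128-d output it reproduces the "Trainable: 329,344" line in the run output.

```python
# Image-embedding widths of the public OpenAI CLIP checkpoints
CLIP_EMBED_DIMS = {
    "RN50": 1024, "RN101": 512, "RN50x4": 640, "RN50x16": 768, "RN50x64": 1024,
    "ViT-B/32": 512, "ViT-B/16": 512, "ViT-L/14": 768, "ViT-L/14@336px": 768,
}


def projection_head_params(input_dim: int, output_dim: int) -> int:
    """Parameter count of ProjectionHead: Linear(in, in) + BatchNorm1d(in) + Linear(in, out)."""
    first_linear = input_dim * input_dim + input_dim   # weight + bias
    batch_norm = 2 * input_dim                         # affine weight + bias (running stats are buffers)
    second_linear = input_dim * output_dim + output_dim
    return first_linear + batch_norm + second_linear


print(projection_head_params(CLIP_EMBED_DIMS["ViT-B/32"], 128))  # 329344
```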


## Loss Function and Metrics

In [12]:
class TripletLoss(nn.Module):
    """Triplet Margin Loss for metric learning"""

    def __init__(self, margin: float = 0.3):
        super().__init__()
        self.margin = margin

    def forward(self, anchor: torch.Tensor, positive: torch.Tensor,
                negative: torch.Tensor) -> torch.Tensor:
        # Euclidean distance
        pos_dist = F.pairwise_distance(anchor, positive, p=2)
        neg_dist = F.pairwise_distance(anchor, negative, p=2)

        # Triplet loss: max(0, pos_dist - neg_dist + margin)
        losses = F.relu(pos_dist - neg_dist + self.margin)
        return losses.mean()


def compute_triplet_accuracy(anchor: torch.Tensor, positive: torch.Tensor,
                            negative: torch.Tensor, margin: float = 0.0) -> float:
    """Compute triplet accuracy: fraction where d(a,p) < d(a,n)"""
    pos_dist = F.pairwise_distance(anchor, positive, p=2)
    neg_dist = F.pairwise_distance(anchor, negative, p=2)
    correct = (pos_dist + margin < neg_dist).float().sum()
    accuracy = correct / anchor.size(0)
    return accuracy.item()

print("✓ Loss function and metrics defined")

✓ Loss function and metrics defined
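
A worked example of the margin in `TripletLoss`, using 2-D unit vectors so the distances can be checked by hand (stdlib only; the points are arbitrary). Both triplets would count as correct for `compute_triplet_accuracy`, since d(a,p) < d(a,n) in each, but the second negative still sits inside the 0.3 margin and contributes loss. That is why training loss can keep decreasing while triplet accuracy is already close to 1.

```python
import math
from typing import Sequence


def euclidean(u: Sequence[float], v: Sequence[float]) -> float:
    return math.sqrt(sum((a - b) ** 2 for a, b in zip(u, v)))


def triplet_loss(anchor, positive, negative, margin: float = 0.3) -> float:
    """max(0, d(a,p) - d(a,n) + margin) for a single triplet."""
    return max(0.0, euclidean(anchor, positive) - euclidean(anchor, negative) + margin)


anchor = (1.0, 0.0)          # all points lie on the unit circle
positive = (0.8, 0.6)        # d(a,p) = sqrt(0.4) ≈ 0.632
easy_negative = (0.0, 1.0)   # d(a,n) = sqrt(2) ≈ 1.414, margin satisfied
hard_negative = (0.6, 0.8)   # d(a,n) = sqrt(0.8) ≈ 0.894, inside the margin

print(round(triplet_loss(anchor, positive, easy_negative), 4))  # 0.0
print(round(triplet_loss(anchor, positive, hard_negative), 4))  # 0.038
```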


## Trainer Class

Handles training loop, validation, checkpointing, and early stopping

In [13]:
class TripletTrainer:
    """Trainer for triplet network with staged learning"""

    def __init__(self, model: CLIPTripletModel, train_loader: DataLoader,
                 val_loader: DataLoader, device: str = "cuda", learning_rate: float = 1e-4,
                 weight_decay: float = 1e-4, margin: float = 0.3):
        self.model = model.to(device)
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.device = device

        # Loss function
        self.criterion = TripletLoss(margin=margin)

        # Optimizer
        self.optimizer = torch.optim.AdamW(
            self.model.parameters(), lr=learning_rate, weight_decay=weight_decay
        )

        # Learning rate scheduler
        self.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            self.optimizer, T_max=TOTAL_EPOCHS, eta_min=1e-6
        )

        # Training state
        self.current_epoch = 0
        self.best_val_loss = float('inf')
        self.best_val_accuracy = 0.0
        self.epochs_without_improvement = 0

        # Metrics
        self.train_losses = []
        self.val_losses = []
        self.train_accuracies = []
        self.val_accuracies = []

        print(f"Trainer initialized on {device}")

    def train_epoch(self) -> Tuple[float, float]:
        """Train for one epoch"""
        self.model.train()
        epoch_loss = 0.0
        epoch_accuracy = 0.0
        num_batches = 0

        pbar = tqdm(self.train_loader, desc=f"Epoch {self.current_epoch + 1} [Train]")

        for batch_idx, (anchor, positive, negative) in enumerate(pbar):
            anchor = anchor.to(self.device)
            positive = positive.to(self.device)
            negative = negative.to(self.device)

            # Forward pass
            anchor_emb, pos_emb, neg_emb = self.model(anchor, positive, negative)
            loss = self.criterion(anchor_emb, pos_emb, neg_emb)
            accuracy = compute_triplet_accuracy(anchor_emb, pos_emb, neg_emb)

            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), max_norm=1.0)
            self.optimizer.step()

            # Accumulate metrics
            epoch_loss += loss.item()
            epoch_accuracy += accuracy
            num_batches += 1

            # Update progress bar
            if batch_idx % LOG_INTERVAL == 0:
                pbar.set_postfix({
                    'loss': f'{loss.item():.4f}',
                    'acc': f'{accuracy:.4f}',
                    'lr': f'{self.optimizer.param_groups[0]["lr"]:.2e}'
                })

        return epoch_loss / num_batches, epoch_accuracy / num_batches

    @torch.no_grad()
    def validate(self) -> Tuple[float, float]:
        """Validate on validation set"""
        self.model.eval()
        epoch_loss = 0.0
        epoch_accuracy = 0.0
        num_batches = 0

        pbar = tqdm(self.val_loader, desc=f"Epoch {self.current_epoch + 1} [Val]")

        for anchor, positive, negative in pbar:
            anchor = anchor.to(self.device)
            positive = positive.to(self.device)
            negative = negative.to(self.device)

            # Forward pass
            anchor_emb, pos_emb, neg_emb = self.model(anchor, positive, negative)
            loss = self.criterion(anchor_emb, pos_emb, neg_emb)
            accuracy = compute_triplet_accuracy(anchor_emb, pos_emb, neg_emb)

            # Accumulate metrics
            epoch_loss += loss.item()
            epoch_accuracy += accuracy
            num_batches += 1

            pbar.set_postfix({'loss': f'{loss.item():.4f}', 'acc': f'{accuracy:.4f}'})

        return epoch_loss / num_batches, epoch_accuracy / num_batches

    def save_checkpoint(self, is_best: bool = False):
        """Save model checkpoint"""
        checkpoint = {
            'epoch': self.current_epoch,
            'model_state_dict': self.model.state_dict(),
            'optimizer_state_dict': self.optimizer.state_dict(),
            'scheduler_state_dict': self.scheduler.state_dict(),
            'best_val_loss': self.best_val_loss,
            'best_val_accuracy': self.best_val_accuracy,
            'train_losses': self.train_losses,
            'val_losses': self.val_losses,
            'train_accuracies': self.train_accuracies,
            'val_accuracies': self.val_accuracies,
        }

        torch.save(checkpoint, LAST_MODEL_PATH)

        if is_best:
            torch.save(checkpoint, BEST_MODEL_PATH)
            print(f"✓ Best model saved (val_acc: {self.best_val_accuracy:.4f})")

    def train(self, num_epochs: int, phase: int = 1):
        """Train for multiple epochs"""
        print(f"\n{'='*60}\nPhase {phase} Training\n{'='*60}")

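        # Continue numbering after the epochs already completed (one metrics entry per epoch);
        # reusing self.current_epoch would repeat the previous phase's last epoch index
        start_epoch = len(self.train_losses)
        # Reset patience so an early stop in one phase does not end the next phase immediately
        self.epochs_without_improvement = 0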

        for epoch in range(start_epoch, start_epoch + num_epochs):
            self.current_epoch = epoch
            epoch_start_time = time.time()

            # Train and validate
            train_loss, train_acc = self.train_epoch()
            val_loss, val_acc = self.validate()

            # Update learning rate
            self.scheduler.step()

            # Save metrics
            self.train_losses.append(train_loss)
            self.val_losses.append(val_loss)
            self.train_accuracies.append(train_acc)
            self.val_accuracies.append(val_acc)

            # Print summary
            epoch_time = time.time() - epoch_start_time
            print(f"\nEpoch {epoch + 1}/{start_epoch + num_epochs}")
            print(f"  Train Loss: {train_loss:.4f} | Train Acc: {train_acc:.4f}")
            print(f"  Val Loss:   {val_loss:.4f} | Val Acc:   {val_acc:.4f}")
            print(f"  Time: {epoch_time:.1f}s | LR: {self.optimizer.param_groups[0]['lr']:.2e}")

            # Check for improvement
            is_best = val_acc > self.best_val_accuracy

            if is_best:
                self.best_val_accuracy = val_acc
                self.best_val_loss = val_loss
                self.epochs_without_improvement = 0
            else:
                self.epochs_without_improvement += 1

            # Save checkpoint
            self.save_checkpoint(is_best=is_best)

            # Early stopping
            if EARLY_STOPPING_PATIENCE > 0:
                if self.epochs_without_improvement >= EARLY_STOPPING_PATIENCE:
                    print(f"\n⚠ Early stopping triggered after {self.epochs_without_improvement} epochs without improvement")
                    break

        print(f"\nPhase {phase} completed!")
        print(f"  Best Val Accuracy: {self.best_val_accuracy:.4f}")
        print(f"  Val Loss at best-accuracy epoch: {self.best_val_loss:.4f}")

print("✓ Trainer class defined")

✓ Trainer class defined
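
The early-stopping rule in `TripletTrainer.train` counts epochs whose validation accuracy does not strictly beat the best so far. A stdlib simulation of that counter on a hypothetical accuracy history (not from this run); note that an epoch that only ties the best value still counts as "no improvement".

```python
from typing import List, Optional


def early_stop_epoch(val_accuracies: List[float], patience: int) -> Optional[int]:
    """Return the 1-based epoch after which training stops, or None if it never triggers.

    Mirrors TripletTrainer.train: only a strictly higher val accuracy resets the counter.
    """
    best = 0.0
    epochs_without_improvement = 0
    for epoch, acc in enumerate(val_accuracies, start=1):
        if acc > best:
            best = acc
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
        if patience > 0 and epochs_without_improvement >= patience:
            return epoch
    return None


history = [0.90, 0.92, 0.91, 0.92, 0.92, 0.915, 0.92, 0.93]  # hypothetical values
print(early_stop_epoch(history, patience=5))  # 7
```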


## Create DataLoaders

In [14]:
def get_clip_transforms(model_name: str = "ViT-B/32"):
    """Get CLIP preprocessing transforms"""
    _, preprocess = clip.load(model_name, device="cpu")
    return preprocess


def create_dataloaders(data_dir: str, batch_size: int, num_workers: int,
                      train_val_split: float = 0.8, clip_model_name: str = "ViT-B/32",
                      triplets_per_image: int = 3, subset_size: Optional[int] = None):
    """Create train and validation dataloaders"""

    # Get CLIP transforms
    transform = get_clip_transforms(clip_model_name)

    # Create full dataset
    full_dataset = TripletImageDataset(
        root_dir=data_dir, transform=transform,
        triplets_per_image=triplets_per_image, subset_size=subset_size
    )

    # Split into train and validation
    dataset_size = len(full_dataset.all_image_paths)
    train_size = int(train_val_split * dataset_size)

    all_paths = full_dataset.all_image_paths.copy()
    random.shuffle(all_paths)

    train_paths = all_paths[:train_size]
    val_paths = all_paths[train_size:]

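    # Create separate datasets
    train_dataset = TripletImageDataset(
        root_dir=data_dir, transform=transform,
        triplets_per_image=triplets_per_image, subset_size=None
    )
    val_dataset = TripletImageDataset(
        root_dir=data_dir, transform=transform,
        triplets_per_image=triplets_per_image, subset_size=None
    )

    # Restrict each split to its own images. Replacing only all_image_paths would leave
    # class_to_images pointing at the full dataset, so positives and negatives could be
    # drawn from the other split (train/val leakage).
    for dataset, paths in ((train_dataset, train_paths), (val_dataset, val_paths)):
        class_to_images = defaultdict(list)
        for path in paths:
            class_to_images[dataset.image_to_class[path]].append(path)
        dataset.all_image_paths = paths
        dataset.class_to_images = dict(class_to_images)
        dataset.classes = list(dataset.class_to_images.keys())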

    # Create dataloaders
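    # drop_last avoids a final batch of size 1, which BatchNorm1d in the projection head
    # cannot process in training mode
    train_loader = DataLoader(
        train_dataset, batch_size=batch_size, shuffle=True,
        num_workers=num_workers, pin_memory=True, drop_last=True
    )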

    val_loader = DataLoader(
        val_dataset, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True
    )

    print(f"\nDataLoaders created:")
    print(f"  Train: {len(train_dataset)} triplets ({len(train_paths)} images)")
    print(f"  Val: {len(val_dataset)} triplets ({len(val_paths)} images)")

    return train_loader, val_loader

print("✓ DataLoader functions defined")

✓ DataLoader functions defined
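
The image, triplet, and batch counts printed in the run output follow directly from the split and batch arithmetic. A stdlib sketch reproducing them; `split_and_batches` is a hypothetical helper, and the 26179-image total is taken from the run log.

```python
import math


def split_and_batches(num_images: int, split: float, triplets_per_image: int, batch_size: int):
    """Reproduce the split / triplet / batch counts printed by the pipeline."""
    train_images = int(split * num_images)  # same truncation as create_dataloaders
    val_images = num_images - train_images
    train_triplets = train_images * triplets_per_image
    val_triplets = val_images * triplets_per_image
    return {
        "train_images": train_images,
        "val_images": val_images,
        "train_triplets": train_triplets,
        "val_triplets": val_triplets,
        # ceil matches drop_last=False (as in the logged run); use // for drop_last=True
        "train_batches": math.ceil(train_triplets / batch_size),
        "val_batches": math.ceil(val_triplets / batch_size),
    }


print(split_and_batches(26179, 0.8, 3, 32))
```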


## Main Training Function

In [15]:
def train_complete_pipeline(data_dir: str, debug: bool = False):
    """Complete training pipeline with two-phase learning"""

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    print("="*60)
    print("CLIP Triplet Network Training")
    print("="*60)
    print(f"Device: {device}")
    print(f"Data directory: {data_dir}")
    print(f"Total epochs: {TOTAL_EPOCHS}")
    print(f"  Phase 1 (frozen CLIP): {PHASE1_EPOCHS} epochs")
    print(f"  Phase 2 (fine-tune): {PHASE2_EPOCHS} epochs")
    print("="*60)

    # Create dataloaders
    subset_size = 100 if debug else None

    train_loader, val_loader = create_dataloaders(
        data_dir=data_dir,
        batch_size=BATCH_SIZE,
        num_workers=NUM_WORKERS,
        train_val_split=TRAIN_VAL_SPLIT,
        clip_model_name=CLIP_MODEL_NAME,
        triplets_per_image=TRIPLETS_PER_IMAGE,
        subset_size=subset_size
    )

    # Create model
    model = CLIPTripletModel(
        clip_model_name=CLIP_MODEL_NAME,
        projection_dim=PROJECTION_DIM if USE_PROJECTION_HEAD else None,
        dropout=PROJECTION_DROPOUT,
        freeze_clip=True
    )

    # Print model info
    param_counts = model.count_parameters()
    print(f"\nModel parameters:")
    print(f"  Total: {param_counts['total']:,}")
    print(f"  Trainable: {param_counts['trainable']:,}")
    print(f"  Frozen: {param_counts['frozen']:,}")

    # Create trainer
    trainer = TripletTrainer(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        device=str(device),
        learning_rate=LEARNING_RATE,
        weight_decay=WEIGHT_DECAY,
        margin=TRIPLET_MARGIN
    )

    # ========== PHASE 1: Train projection head only ==========
    if PHASE1_EPOCHS > 0:
        print("\n" + "="*60)
        print("PHASE 1: Training projection head (CLIP frozen)")
        print("="*60)
        trainer.train(num_epochs=PHASE1_EPOCHS, phase=1)

    # ========== PHASE 2: Fine-tune CLIP + projection head ==========
    if PHASE2_EPOCHS > 0:
        print("\n" + "="*60)
        print("PHASE 2: Fine-tuning CLIP + projection head")
        print("="*60)

        # Unfreeze last N layers
        model.unfreeze_last_n_layers(UNFREEZE_LAST_N_LAYERS)

        # Recreate optimizer with lower LR
        trainer.optimizer = torch.optim.AdamW(
            model.parameters(),
            lr=LEARNING_RATE * 0.1,
            weight_decay=WEIGHT_DECAY
        )

        trainer.scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
            trainer.optimizer, T_max=PHASE2_EPOCHS, eta_min=1e-7
        )

        param_counts = model.count_parameters()
        print(f"Trainable parameters now: {param_counts['trainable']:,}")

        trainer.train(num_epochs=PHASE2_EPOCHS, phase=2)

    # Save final results
    print("\n" + "="*60)
    print("Training completed!")
    print("="*60)
    print(f"Best model saved at: {BEST_MODEL_PATH}")
    print(f"Last model saved at: {LAST_MODEL_PATH}")
    print(f"Best validation accuracy: {trainer.best_val_accuracy:.4f}")
    print(f"Validation loss at best-accuracy epoch: {trainer.best_val_loss:.4f}")
    print("="*60)

    return trainer

print("✓ Main training function defined")

✓ Main training function defined
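
To see the learning rate each epoch actually uses, here is the closed-form cosine annealing formula given in the PyTorch `CosineAnnealingLR` documentation, evaluated with this notebook's settings (a stdlib sketch; `cosine_lr` is a hypothetical helper). Because the phase-1 scheduler is built with `T_max=TOTAL_EPOCHS` but stepped only `PHASE1_EPOCHS` times, the last phase-1 epoch still trains at about 3.5e-5 rather than `eta_min`, and phase 2 then restarts from `LEARNING_RATE * 0.1`.

```python
import math


def cosine_lr(step: int, base_lr: float, t_max: int, eta_min: float) -> float:
    """Closed form of CosineAnnealingLR after `step` scheduler steps."""
    return eta_min + (base_lr - eta_min) * (1 + math.cos(math.pi * step / t_max)) / 2


# Phase 1: T_max = TOTAL_EPOCHS (15), but only PHASE1_EPOCHS (10) epochs are run
phase1 = [cosine_lr(e, 1e-4, 15, 1e-6) for e in range(10)]
# Phase 2: new scheduler, base LR = LEARNING_RATE * 0.1, T_max = PHASE2_EPOCHS
phase2 = [cosine_lr(e, 1e-5, 5, 1e-7) for e in range(5)]

for epoch, lr in enumerate(phase1 + phase2, start=1):
    print(f"epoch {epoch:2d}: lr = {lr:.2e}")
```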


## Visualization Functions

In [16]:
def plot_training_history(trainer):
    """Plot training and validation metrics"""

    fig, axes = plt.subplots(1, 2, figsize=(15, 5))

    epochs = range(1, len(trainer.train_losses) + 1)

    # Plot losses
    axes[0].plot(epochs, trainer.train_losses, 'b-', label='Train Loss', linewidth=2)
    axes[0].plot(epochs, trainer.val_losses, 'r-', label='Val Loss', linewidth=2)
    axes[0].set_xlabel('Epoch', fontsize=12)
    axes[0].set_ylabel('Loss', fontsize=12)
    axes[0].set_title('Training and Validation Loss', fontsize=14, fontweight='bold')
    axes[0].legend()
    axes[0].grid(True, alpha=0.3)

    # Plot accuracies
    axes[1].plot(epochs, trainer.train_accuracies, 'b-', label='Train Accuracy', linewidth=2)
    axes[1].plot(epochs, trainer.val_accuracies, 'r-', label='Val Accuracy', linewidth=2)
    axes[1].set_xlabel('Epoch', fontsize=12)
    axes[1].set_ylabel('Accuracy', fontsize=12)
    axes[1].set_title('Training and Validation Accuracy', fontsize=14, fontweight='bold')
    axes[1].legend()
    axes[1].grid(True, alpha=0.3)

    plt.tight_layout()
    plt.savefig(OUTPUT_DIR / 'training_history.png', dpi=150, bbox_inches='tight')
    plt.show()

    print(f"✓ Plot saved to {OUTPUT_DIR / 'training_history.png'}")

print("✓ Visualization functions defined")

✓ Visualization functions defined


## Run Training

Execute the complete training pipeline

In [None]:
# Run training
# Set debug=True for a quick test run with small subset
trainer = train_complete_pipeline(data_dir=DATA_DIR, debug=False)

CLIP Triplet Network Training
Device: cuda
Data directory: /content/drive/MyDrive/raw-img/raw-img
Total epochs: 15
  Phase 1 (frozen CLIP): 10 epochs
  Phase 2 (fine-tune): 5 epochs


100%|███████████████████████████████████████| 338M/338M [00:11<00:00, 31.1MiB/s]


Dataset: 10 classes, 26179 images
Classes: butterfly, cat, chicken, cow, dog, elephant, horse, sheep, spider, squirrel
Dataset: 10 classes, 26179 images
Classes: butterfly, cat, chicken, cow, dog, elephant, horse, sheep, spider, squirrel
Dataset: 10 classes, 26179 images
Classes: butterfly, cat, chicken, cow, dog, elephant, horse, sheep, spider, squirrel

DataLoaders created:
  Train: 62829 triplets (20943 images)
  Val: 15708 triplets (5236 images)
Model: ViT-B/32, CLIP dim: 512, Output dim: 128

Model parameters:
  Total: 151,606,657
  Trainable: 329,344
  Frozen: 151,277,313
Trainer initialized on cuda

PHASE 1: Training projection head (CLIP frozen)

Phase 1 Training


Epoch 1 [Train]:   0%|          | 0/1964 [00:00<?, ?it/s]

## Visualize Results

In [None]:
# Plot training history
plot_training_history(trainer)

## Save Training Log

In [None]:
import pandas as pd

# Create training log DataFrame
training_log = pd.DataFrame({
    'Epoch': range(1, len(trainer.train_losses) + 1),
    'Train_Loss': trainer.train_losses,
    'Val_Loss': trainer.val_losses,
    'Train_Accuracy': trainer.train_accuracies,
    'Val_Accuracy': trainer.val_accuracies
})

# Save to CSV
training_log.to_csv(TRAINING_LOG, index=False)
print(f"✓ Training log saved to {TRAINING_LOG}")

# Display summary
print("\nTraining Summary:")
print(training_log.tail(10))

## Download Model Checkpoints

Download the trained model to your local machine

In [None]:
from google.colab import files

print("Downloading best model...")
files.download(str(BEST_MODEL_PATH))

print("Downloading last model...")
files.download(str(LAST_MODEL_PATH))

print("Downloading training log...")
files.download(str(TRAINING_LOG))

print(f"Model files location:")
print(f"  Best model: {BEST_MODEL_PATH}")
print(f"  Last model: {LAST_MODEL_PATH}")
print(f"  Training log: {TRAINING_LOG}")

## Model Inference Example

How to use the trained model for generating embeddings

In [None]:
# Load best model for inference
checkpoint = torch.load(BEST_MODEL_PATH, map_location='cpu')

# Create model and load weights
inference_model = CLIPTripletModel(
    clip_model_name=CLIP_MODEL_NAME,
    projection_dim=PROJECTION_DIM if USE_PROJECTION_HEAD else None,
    dropout=PROJECTION_DROPOUT,
    freeze_clip=False
)
inference_model.load_state_dict(checkpoint['model_state_dict'])
inference_model.eval()

print("✓ Model loaded for inference")
print(f"Best validation accuracy: {checkpoint['best_val_accuracy']:.4f}")

# Example: Generate embedding for a single image
def get_image_embedding(image_path: str, model, transform, device='cpu'):
    """Get embedding for a single image"""
    model = model.to(device)
    model.eval()

    # Load and transform image
    image = Image.open(image_path).convert('RGB')
    image_tensor = transform(image).unsqueeze(0).to(device)

    # Get embedding
    with torch.no_grad():
        embedding = model.encode_image(image_tensor)

    return embedding.cpu().numpy()

# Example usage (uncomment to test with your image)
# transform = get_clip_transforms(CLIP_MODEL_NAME)
# embedding = get_image_embedding('path/to/your/image.jpg', inference_model, transform)
# print(f"Embedding shape: {embedding.shape}")
# print(f"Embedding L2 norm: {np.linalg.norm(embedding):.4f}  (should be ~1.0)")

## Final Summary

### What was trained:
- CLIP-based triplet network for image similarity learning
- Two-phase training: projection head → fine-tuning
- Triplet loss with margin for metric learning

### Model outputs:
- L2-normalized embeddings (dimension: 128)
- Can be used for similarity search with cosine/euclidean distance
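
Because the embeddings are L2-normalized, the two distance choices rank neighbours identically: for unit vectors, the squared Euclidean distance equals 2 - 2·cos(a, b), since ||a - b||² = ||a||² + ||b||² - 2a·b. A small stdlib check of that identity on arbitrary random 128-d vectors:

```python
import math
import random


def normalize(v):
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


rng = random.Random(0)
a = normalize([rng.gauss(0, 1) for _ in range(128)])
b = normalize([rng.gauss(0, 1) for _ in range(128)])

squared_euclidean = sum((x - y) ** 2 for x, y in zip(a, b))
cosine = dot(a, b)  # equals cosine similarity because both vectors have unit norm
print(abs(squared_euclidean - (2 - 2 * cosine)) < 1e-9)  # True
```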

### Next steps:
1. Build FAISS index from all image embeddings
2. Implement similarity search
3. Visualize search results
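
Before building a FAISS index, a brute-force search is a handy correctness baseline. A minimal sketch, assuming the embeddings are already computed (e.g. with `get_image_embedding`) and L2-normalized, so the dot product is the cosine similarity; `top_k_similar` and the toy gallery are hypothetical, and a FAISS inner-product index such as `IndexFlatIP` would play the same role at scale.

```python
import heapq
from typing import Dict, List, Sequence, Tuple


def top_k_similar(query: Sequence[float], gallery: Dict[str, Sequence[float]],
                  k: int = 5) -> List[Tuple[str, float]]:
    """Return the k gallery items with the highest dot product (cosine for unit vectors)."""
    scores = ((name, sum(q * g for q, g in zip(query, emb))) for name, emb in gallery.items())
    return heapq.nlargest(k, scores, key=lambda item: item[1])


gallery = {
    "cat_1.jpg": (1.0, 0.0),
    "cat_2.jpg": (0.8, 0.6),
    "dog_1.jpg": (0.0, 1.0),
}
print(top_k_similar((1.0, 0.0), gallery, k=2))  # [('cat_1.jpg', 1.0), ('cat_2.jpg', 0.8)]
```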

### Files generated:
- `best_triplet_model.pth` - Best model checkpoint
- `last_triplet_model.pth` - Last epoch checkpoint
- `training_log.csv` - Training metrics
- `training_history.png` - Loss and accuracy plots