In [1]:
from sklearn.metrics.pairwise import cosine_similarity
from transformers import CLIPProcessor, CLIPModel
from torch.utils.data import DataLoader, Dataset
import torch.nn.functional as F
import torch.optim as optim
import torch.nn as nn
from tqdm import tqdm
from PIL import Image
import numpy as np
import torch
import json
import time
import os

  from .autonotebook import tqdm as notebook_tqdm


# Utils

Dataset class for the TRAIN split. It expects one sub-directory per class under the root directory and labels each image with the index of its class folder.

In [2]:
class CLIPImageDatasetTrain(Dataset):
    """Image dataset for the TRAIN split, with one sub-directory per class.

    Expected layout::

        root_dir/
            class_a/img1.jpg
            class_b/img2.png
            ...

    Each item is ``(pixel_values, label)``. ``pixel_values`` is the processor
    output for one image with its batch dimension removed. ``label`` is the
    integer index of the image's class folder in ``class_to_idx``.
    """

    def __init__(self, root_dir, processor):
        """
        root_dir: main folder containing one sub-folder per class.
        processor: Hugging Face CLIPProcessor instance (called with
            ``images=..., return_tensors="pt"``).
        """
        self.processor = processor
        self.image_paths = []
        self.labels = []

        # Only directories are classes. Building the mapping from the raw
        # listing would also give indices to stray files (e.g. README.txt) and
        # leave gaps in the label range. Sorting keeps the indices stable.
        class_names = sorted(
            name for name in os.listdir(root_dir)
            if os.path.isdir(os.path.join(root_dir, name))
        )
        self.class_to_idx = {name: idx for idx, name in enumerate(class_names)}

        for class_name in class_names:
            class_path = os.path.join(root_dir, class_name)
            # os.listdir order is arbitrary; sort for a reproducible dataset order.
            for fname in sorted(os.listdir(class_path)):
                if fname.lower().endswith(('.png', '.jpg', '.jpeg')):
                    self.image_paths.append(os.path.join(class_path, fname))
                    self.labels.append(self.class_to_idx[class_name])

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        label = self.labels[idx]
        # Close the file handle promptly. convert() returns an independent
        # in-memory copy, so the image stays usable after the `with` block.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        inputs = self.processor(images=image, return_tensors="pt")
        pixel_values = inputs["pixel_values"].squeeze(0)  # drop the batch dim

        return pixel_values, label

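Dataset class for the TEST split (gallery and query). It reads a flat directory of images and returns each image together with its file path, so retrieval results can be traced back to files.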

In [3]:
class CLIPImageDatasetTest(Dataset):
    """Dataset over a flat directory of images (used for gallery and query).

    Each item is ``(pixel_values, img_path)``: the processor output for one
    image with its batch dimension removed, plus the path it was loaded from,
    so retrieval results can be mapped back to file names.
    """

    def __init__(self, image_dir, processor):
        """
        image_dir: directory containing the images (no class sub-folders).
        processor: Hugging Face CLIPProcessor instance.
        """
        self.image_dir = image_dir
        # os.listdir order is arbitrary. Sort so the dataset order, and the
        # embedding rows built from it, is reproducible across runs.
        self.image_paths = [
            os.path.join(image_dir, fname)
            for fname in sorted(os.listdir(image_dir))
            if fname.lower().endswith(('.png', '.jpg', '.jpeg'))
        ]
        self.processor = processor

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        img_path = self.image_paths[idx]
        # Context manager closes the file handle right away; convert()
        # returns an independent in-memory copy.
        with Image.open(img_path) as img:
            image = img.convert("RGB")

        # Run the CLIP processor to obtain pixel_values
        inputs = self.processor(images=image, return_tensors="pt")
        pixel_values = inputs["pixel_values"].squeeze(0)  # Remove batch dim

        return pixel_values, img_path  # tensor + path for traceability

Batch sampler that draws N distinct classes per batch (e.g. 8) and K images per class (e.g. 4), so each batch holds N × K images (e.g. 8 × 4 = 32 images per batch).

In [4]:
from torch.utils.data import Sampler
import random
from collections import defaultdict

class BalancedBatchSampler(Sampler):
    """Yield batches of ``n_classes * n_samples_per_class`` dataset indices.

    Each batch contains ``n_classes`` distinct classes, drawn without
    replacement. From each of them it takes ``n_samples_per_class`` distinct
    indices, laid out class by class. Classes with fewer than
    ``n_samples_per_class`` samples are never drawn.
    """

    def __init__(self, labels, n_classes, n_samples_per_class, seed=None):
        """
        labels: list of labels (same length as the dataset).
        n_classes: number of distinct classes per batch.
        n_samples_per_class: number of images per class per batch.
        seed: optional seed for a private RNG, which makes the batch sequence
            reproducible. When None (the default), the global ``random``
            module is used, as before.

        Raises ValueError if either count is not positive, or if fewer than
        ``n_classes`` classes have at least ``n_samples_per_class`` samples.
        """
        if n_classes < 1 or n_samples_per_class < 1:
            raise ValueError(
                f"n_classes ({n_classes}) and n_samples_per_class "
                f"({n_samples_per_class}) must both be >= 1"
            )
        self.labels = labels
        self.n_classes = n_classes
        self.n_samples_per_class = n_samples_per_class
        # Both random.Random and the random module expose .sample().
        self._rng = random.Random(seed) if seed is not None else random

        # Build the map: label -> list of dataset indices
        label_to_indices = defaultdict(list)
        for idx, label in enumerate(labels):
            label_to_indices[label].append(idx)

        # Drop classes that cannot supply a full per-class quota
        self.label_to_indices = {
            label: idxs for label, idxs in label_to_indices.items()
            if len(idxs) >= n_samples_per_class
        }

        self.labels_set = list(self.label_to_indices.keys())

        if len(self.labels_set) < self.n_classes:
            raise ValueError(f"Numero di classi valide ({len(self.labels_set)}) "
                             f"inferiore a n_classes richieste per batch ({self.n_classes})")

        # Estimated epoch length: about one pass over the samples that can
        # actually be drawn. Filtered-out classes are not counted.
        n_usable = sum(len(idxs) for idxs in self.label_to_indices.values())
        self.num_batches = n_usable // (n_classes * n_samples_per_class)

    def __iter__(self):
        for _ in range(self.num_batches):
            batch = []
            for cls in self._rng.sample(self.labels_set, self.n_classes):
                indices = self.label_to_indices[cls]
                batch.extend(self._rng.sample(indices, self.n_samples_per_class))
            yield batch

    def __len__(self):
        return self.num_batches

SupConLoss extends the classical contrastive loss with class information. For each anchor (embedding), it treats all other examples of the same class as positives and the rest as negatives. The objective is to:

Pull together embeddings of the same class.

Push apart embeddings of different classes.

It is designed for balanced batches like those built by BalancedBatchSampler. With `n_samples_per_class ≥ 2`, every anchor in such a batch has at least one positive.

Breakdown of the code:
| Step                    | Purpose                                                                 |
| ----------------------- | ----------------------------------------------------------------------- |
| Mask `mask`             | Marks which pairs are positives (same label, excluding self)            |
| `features @ features.T` | Pairwise similarities between embeddings (divided by the temperature)   |
| `log_prob`              | Log-softmax of each similarity over all other samples in the batch      |
| `mean_log_prob_pos`     | Mean log-probability over each anchor's true positives                  |
| `loss`                  | Negative mean over anchors → to be minimized                            |

In [5]:
class SupConLoss(nn.Module):
    """Supervised contrastive loss (single-view variant).

    For every anchor, all other samples in the batch that share its label are
    positives and every remaining sample is a negative. The anchor itself is
    excluded from both.
    """

    def __init__(self, temperature=0.07):
        """temperature: divisor applied to the pairwise dot products."""
        super(SupConLoss, self).__init__()
        self.temperature = temperature

    def forward(self, features, labels):
        """
        features: tensor (B, D), expected to be L2-normalised already, so the
            dot products below are cosine similarities.
        labels: tensor with B entries (reshaped to (B, 1) internally).

        Returns a scalar tensor: the negative mean log-probability of the
        positives, averaged over the anchors that have at least one positive.
        If no anchor has a positive (e.g. all labels are distinct), it returns
        a zero that stays connected to ``features``, so ``backward()`` works.

        Raises ValueError if labels and features disagree on the batch size.
        """
        device = features.device
        batch_size = features.shape[0]
        labels = labels.contiguous().view(-1, 1).to(device)
        if labels.shape[0] != batch_size:
            raise ValueError(
                f"labels has {labels.shape[0]} entries but features has batch size {batch_size}"
            )
        mask = torch.eq(labels, labels.T).float()  # (B, B): 1 where same class

        # Similarity: scaled dot product of the (normalised) embeddings
        anchor_dot_contrast = torch.matmul(features, features.T) / self.temperature

        # Subtract the per-row max for numerical stability (softmax-invariant)
        logits_max, _ = torch.max(anchor_dot_contrast, dim=1, keepdim=True)
        logits = anchor_dot_contrast - logits_max.detach()

        # Exclude self-comparisons (the diagonal) from positives and from the denominator
        logits_mask = torch.ones_like(mask) - torch.eye(batch_size, device=device, dtype=mask.dtype)
        mask = mask * logits_mask

        # Log-softmax over all non-self samples
        exp_logits = torch.exp(logits) * logits_mask
        log_prob = logits - torch.log(exp_logits.sum(1, keepdim=True) + 1e-12)

        pos_count = mask.sum(1)
        mean_log_prob_pos = (mask * log_prob).sum(1) / pos_count.clamp(min=1.0)

        # Average only over anchors that actually have positives. Anchors with
        # none would otherwise contribute 0 and artificially shrink the loss.
        has_pos = pos_count > 0
        if not bool(has_pos.any()):
            return features.sum() * 0.0
        return -mean_log_prob_pos[has_pos].mean()

CLIP encoder that:

Uses CLIP's vision tower,

Returns L2-normalized embeddings (required by SupConLoss).

In [6]:
class CLIPEncoder(nn.Module):
    """Wrap CLIP's vision tower and emit unit-length image embeddings.

    Only the ``vision_model`` of the pretrained CLIP checkpoint is kept. The
    forward pass returns its ``pooler_output`` rescaled to unit L2 norm along
    the feature dimension, which is what SupConLoss expects.
    """

    def __init__(self, model_name="openai/clip-vit-large-patch14-336"):
        super().__init__()
        full_clip = CLIPModel.from_pretrained(model_name)
        # Keep the visual tower only; the rest of the checkpoint is dropped.
        self.vision_encoder = full_clip.vision_model

    def forward(self, pixel_values):
        pooled = self.vision_encoder(pixel_values=pixel_values).pooler_output  # (B, D)
        return F.normalize(pooled, p=2, dim=1)

# Main

In a nutshell: 

Fine-tune CLIP with a supervised contrastive loss. CLIP's vision encoder produces the embeddings, and the loss uses the similarity matrix between the samples of each training batch. How: every batch contains N classes with K images each (see BalancedBatchSampler; here N = 4, K = 2). Images of the same class form positive pairs, and images of different classes form negative pairs.

Main idea: backpropagating the loss tunes the CLIP vision encoder (a Vision Transformer) so that embeddings of positive pairs move closer together and embeddings of negative pairs move apart. This happens in the vision encoder's embedding space, i.e. the pooled output, before CLIP's projection into the shared image–text space.

### CLIP loading

Larger, more powerful variant (ViT-L/14 at 336 px)

In [7]:
# ViT-L/14 at 336 px input resolution; only the vision tower is kept.
# This `model` is the frozen, pre-trained baseline used for the
# "no fine-tuning" evaluation; the trainable copy is created separately.
clip_checkpoint = "openai/clip-vit-large-patch14-336"
device = "cuda" if torch.cuda.is_available() else "cpu"
processor = CLIPProcessor.from_pretrained(clip_checkpoint)
model = CLIPModel.from_pretrained(clip_checkpoint).vision_model
model = model.to(device)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Base version (ViT-B/32, smaller and faster)

In [8]:
# Alternative backbone: the smaller/faster ViT-B/32 checkpoint.
# Uncomment these lines (and skip the ViT-L/14-336 cell above) to use it.
# NOTE(review): CLIPEncoder's default model_name is still the ViT-L/14-336
# checkpoint, so pass model_name="openai/clip-vit-base-patch32" there too.
# processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32")
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32").vision_model.to(device)

### Train loading

In [9]:
train_dir = "train"  # root folder: one sub-directory per class

In [10]:
# Index every image under train_dir; labels come from the sub-folder names.
train_dataset = CLIPImageDatasetTrain(root_dir=train_dir, processor=processor)

### Sampling train for SupConLoss

In [11]:
# Parametri del batch
n_classes_per_batch = 4
n_samples_per_class = 2

sampler = BalancedBatchSampler(
    labels=train_dataset.labels,
    n_classes=n_classes_per_batch,
    n_samples_per_class=n_samples_per_class
)

### Sampled train loading

In [12]:
train_loader = DataLoader(
    dataset=train_dataset,
    batch_sampler=sampler,  # sampler yields whole batches; batch_size/shuffle stay at defaults
    num_workers=4,          # load + preprocess images in 4 worker processes
    pin_memory=True,        # page-locked host memory for faster host-to-GPU copies
)

### Encoder: CLIP encoder (ViT – Vision Transformer)

In [13]:
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

# Trainable copy of the CLIP vision tower (separate from the baseline `model`).
encoder = CLIPEncoder()
encoder = encoder.to(device)

### Optimizer: Adam

In [14]:
optimizer = optim.Adam(
    params=encoder.parameters(),
    lr=3e-5,            # small learning rate: fine-tuning a pretrained backbone
    weight_decay=1e-4,  # Adam adds this as an L2 term to the gradient (not decoupled like AdamW)
)

### Loss declaration: Supervised Contrastive Loss

In [15]:
loss_temperature = 0.07  # lower temperature -> sharper softmax over similarities
criterion = SupConLoss(temperature=loss_temperature)

### Train Loop

A simple training loop that uses:

- train_loader (with BalancedBatchSampler)

- CLIPEncoder (with normalized embeddings)

- SupConLoss

- Adam optimizer

We use a small number of epochs (1 here) just to check that everything runs end to end.

In [None]:
num_epochs = 1  # Start small: a quick end-to-end check of the pipeline
encoder.train()  # Set model to training mode

for epoch in range(num_epochs):
    total_loss = 0.0
    num_batches = 0

    for pixel_values, labels in tqdm(train_loader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        pixel_values = pixel_values.to(device)
        labels = labels.to(device)

        # Forward pass: CLIPEncoder already returns L2-normalised embeddings
        embeddings = encoder(pixel_values)

        # Supervised contrastive loss over the balanced batch
        loss = criterion(embeddings, labels)

        # Backward pass + parameter update
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        # No per-step torch.cuda.empty_cache(). It only hands cached blocks
        # back to the driver, does not give PyTorch more memory, and forces
        # costly re-allocation on the next step.

        # Logging
        total_loss += loss.item()
        num_batches += 1

    # An empty loader (len(sampler) == 0) would otherwise raise ZeroDivisionError.
    avg_loss = total_loss / num_batches if num_batches else float("nan")
    print(f"Epoch [{epoch+1}/{num_epochs}] - Average Loss: {avg_loss:.4f}")

Epoch 1/1:  25%|██▍       | 153/624 [02:28<07:33,  1.04it/s]

Save the fine-tuned encoder's weights

In [None]:
# Persist only the parameters (state_dict), not the whole module object.
save_path = "fine_tuned_clip_encoder.pth"
torch.save(obj=encoder.state_dict(), f=save_path)
print(f"Model saved to {save_path}")

# Forward pass

Test set directories

In [None]:
# Retrieval split: each query image is matched against the gallery images.
gallery_dir, query_dir = 'test-2/gallery', 'test-2/query'

In [None]:
# Crea istanze del dataset aggiornato
gallery_dataset = CLIPImageDatasetTest(gallery_dir, processor=processor)
query_dataset = CLIPImageDatasetTest(query_dir, processor=processor)

gallery_loader = DataLoader(gallery_dataset, batch_size=8, shuffle=False)  # era 32
query_loader = DataLoader(query_dataset, batch_size=8, shuffle=False)

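Forward evaluation with the pre-trained model (no fine-tuning). Note: this cell and the next one write the same variables (`gallery_embeddings`, `query_embeddings`, `gallery_paths`, `query_paths`), so the evaluation cell below uses whichever of them ran last.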

In [None]:
def _extract_pretrained_embeddings(loader):
    """Run the baseline CLIP vision tower `model` over `loader`.

    Returns (embeddings, paths): a list of per-batch numpy arrays holding the
    raw (un-normalised) `pooler_output`, and the matching image paths in
    loader order. Must be called under torch.no_grad().
    """
    batch_embeddings, all_paths = [], []
    for i, (pixel_values, paths) in enumerate(tqdm(loader)):
        start_time = time.time()

        outputs = model(pixel_values.to(device))
        batch_embeddings.append(outputs.pooler_output.cpu().numpy())  # (batch_size, hidden_dim)
        all_paths.extend(paths)

        if i % 10 == 0:
            print(f"Batch {i}: {time.time() - start_time:.2f}s")
    return batch_embeddings, all_paths


model.eval()  # explicit inference mode for the baseline encoder
with torch.no_grad():
    print("Extracting gallery embeddings...")
    gallery_embeddings, gallery_paths = _extract_pretrained_embeddings(gallery_loader)

    print("Extracting query embeddings...")
    query_embeddings, query_paths = _extract_pretrained_embeddings(query_loader)

# Release cached GPU blocks once, after extraction, rather than per batch.
if torch.cuda.is_available():
    torch.cuda.empty_cache()


Forward evaluation with the fine-tuned encoder

In [None]:
def _extract_finetuned_embeddings(loader):
    """Run the fine-tuned `encoder` over `loader`.

    Returns (embeddings, paths): a list of per-batch numpy arrays of
    L2-normalised embeddings (CLIPEncoder normalises them) and the matching
    image paths in loader order. Must be called under torch.no_grad().
    """
    batch_embeddings, all_paths = [], []
    for i, (pixel_values, paths) in enumerate(tqdm(loader)):
        start_time = time.time()

        emb = encoder(pixel_values.to(device))  # already L2-normalised
        batch_embeddings.append(emb.cpu().numpy())
        all_paths.extend(paths)

        # Print timing every 10 batches
        if i % 10 == 0:
            print(f"Batch {i}: {time.time() - start_time:.2f}s")
    return batch_embeddings, all_paths


# The training cells leave the encoder in train mode; switch to inference mode.
encoder.eval()
with torch.no_grad():
    print("Extracting gallery embeddings...")
    gallery_embeddings, gallery_paths = _extract_finetuned_embeddings(gallery_loader)

    print("Extracting query embeddings...")
    query_embeddings, query_paths = _extract_finetuned_embeddings(query_loader)

# Release cached GPU blocks once, after extraction, rather than per batch.
if torch.cuda.is_available():
    torch.cuda.empty_cache()

Evaluation and results.json

In [None]:
# Stack all embedding batches into single numpy arrays
gallery_embeddings = np.vstack(gallery_embeddings)  # shape: (N_gallery, D)
query_embeddings = np.vstack(query_embeddings)      # shape: (N_query, D)

# Compute cosine similarity between each query and all gallery embeddings
similarity_matrix = cosine_similarity(query_embeddings, gallery_embeddings)

# For each query, find the index of the most similar gallery image
retrieved_indices = np.argmax(similarity_matrix, axis=1)

top_k = 10
top_k_indices = np.argsort(similarity_matrix, axis=1)[:, -top_k:][:, ::-1]

# Build results dictionary in the required format
results = {}

for i, indices in enumerate(top_k_indices):
    # Extract just the filename from the full path
    query_filename = os.path.basename(query_paths[i])
    
    # Get the top-k gallery filenames
    retrieved_filenames = [os.path.basename(gallery_paths[idx]) for idx in indices]
    
    results[query_filename] = retrieved_filenames

# Save results to JSON file
output_file = "retrieval_results.json"
with open(output_file, 'w') as f:
    json.dump(results, f, indent=2)

print(f"Results saved to {output_file}")

# Optional: Print a few examples to verify format
print("\nFirst 3 results:")
for i, (query, retrieved) in enumerate(results.items()):
    if i >= 3:
        break
    print(f"Query: {query}")
    print(f"Top-3 Retrieved: {retrieved[:3]}")
    print("-" * 50)

# Multi-epoch fine-tuning with per-epoch evaluation (results.json + weights saved every epoch)

In [None]:
# Create results and weights directories if they don't exist
results_dir = "results"
weights_dir = "weights"
os.makedirs(results_dir, exist_ok=True)
os.makedirs(weights_dir, exist_ok=True)

num_epochs = 10
encoder.train()  # Set model to training mode

for epoch in range(num_epochs):
    print(f"\n{'='*50}")
    print(f"EPOCH {epoch+1}/{num_epochs}")
    print(f"{'='*50}")
    
    # Training phase
    print("Training...")
    total_loss = 0
    num_batches = 0

    for pixel_values, labels in tqdm(train_loader, desc=f"Training Epoch {epoch+1}/{num_epochs}"):
        pixel_values = pixel_values.to(device)
        labels = labels.to(device)

        # Forward pass
        embeddings = encoder(pixel_values)  # Already normalized

        # Compute loss
        loss = criterion(embeddings, labels)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        torch.cuda.empty_cache()

        # Logging
        total_loss += loss.item()
        num_batches += 1

    avg_loss = total_loss / num_batches
    print(f"Epoch [{epoch+1}/{num_epochs}] - Average Training Loss: {avg_loss:.4f}")
    
    # Evaluation phase
    print(f"\nEvaluating after epoch {epoch+1}...")
    encoder.eval()  # Set to evaluation mode
    
    with torch.no_grad():
        gallery_embeddings = []
        query_embeddings = []
        gallery_paths = []
        query_paths = []

        # Extract gallery embeddings con progress bar
        print("Extracting gallery embeddings...")
        for i, (pixel_values, paths) in enumerate(tqdm(gallery_loader, desc="Gallery")):
            start_time = time.time()
            
            pixel_values = pixel_values.to(device)
            # Usa l'encoder fine-tuned invece del model completo
            emb = encoder(pixel_values)  # L'encoder restituisce già embeddings normalizzati
            gallery_embeddings.append(emb.cpu().numpy())
            gallery_paths.extend(paths)
            
            # Stampa timing ogni 10 batch
            if i % 10 == 0:
                print(f"Batch {i}: {time.time() - start_time:.2f}s")
            
            # Libera memoria GPU
            del pixel_values, emb
            torch.cuda.empty_cache() if torch.cuda.is_available() else None

        # Extract query embeddings
        print("Extracting query embeddings...")
        for i, (pixel_values, paths) in enumerate(tqdm(query_loader, desc="Query")):
            start_time = time.time()
            
            pixel_values = pixel_values.to(device)
            # Usa l'encoder fine-tuned invece del model completo
            emb = encoder(pixel_values)  # L'encoder restituisce già embeddings normalizzati
            query_embeddings.append(emb.cpu().numpy())
            query_paths.extend(paths)
            
            if i % 10 == 0:
                print(f"Batch {i}: {time.time() - start_time:.2f}s")
                
            # Libera memoria GPU
            del pixel_values, emb
            torch.cuda.empty_cache() if torch.cuda.is_available() else None

        # Stack all embedding batches into single numpy arrays
        gallery_embeddings = np.vstack(gallery_embeddings)  # shape: (N_gallery, D)
        query_embeddings = np.vstack(query_embeddings)      # shape: (N_query, D)

        # Compute cosine similarity between each query and all gallery embeddings
        print("Computing similarities...")
        similarity_matrix = cosine_similarity(query_embeddings, gallery_embeddings)

        # For each query, find the top-k most similar gallery images
        top_k = 10
        top_k_indices = np.argsort(similarity_matrix, axis=1)[:, -top_k:][:, ::-1]

        # Build results dictionary in the required format
        results = {}

        for i, indices in enumerate(top_k_indices):
            # Extract just the filename from the full path
            query_filename = os.path.basename(query_paths[i])
            
            # Get the top-k gallery filenames
            retrieved_filenames = [os.path.basename(gallery_paths[idx]) for idx in indices]
            
            results[query_filename] = retrieved_filenames

        # Save results to JSON file with epoch number
        output_file = os.path.join(results_dir, f"retrieval_results_epoch_{epoch+1:02d}.json")
        with open(output_file, 'w') as f:
            json.dump(results, f, indent=2)

        print(f"Results saved to {output_file}")
        
        # Save model weights after evaluation
        weights_file = os.path.join(weights_dir, f"encoder_epoch_{epoch+1:02d}.pth")
        torch.save({
            'epoch': epoch + 1,
            'model_state_dict': encoder.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'loss': avg_loss,
        }, weights_file)
        print(f"Model weights saved to {weights_file}")

        # Optional: Print a few examples to verify format
        print(f"\nFirst 3 results for epoch {epoch+1}:")
        for i, (query, retrieved) in enumerate(results.items()):
            if i >= 3:
                break
            print(f"Query: {query}")
            print(f"Top-3 Retrieved: {retrieved[:3]}")
            print("-" * 30)
    
    # Set back to training mode for next epoch
    encoder.train()
    
    print(f"\nEpoch {epoch+1} completed!")

print(f"\n{'='*50}")
print("TRAINING COMPLETED!")
print(f"All results saved in '{results_dir}' directory")
print(f"All model weights saved in '{weights_dir}' directory")
print(f"{'='*50}")