In [1]:
import sys
import clip
import torch
import os
import tqdm
from argparse import Namespace
import random
import numpy as np

sys.path.append('..')

from viscoin.datasets.cub import Labeled_CUB_200_2011

sys.path.append(os.path.join(os.path.abspath(""), "./../clip/LoRA/"))

from loralib.utils import (
    apply_lora,
    load_lora,
)

In [2]:
# Load clip model
device = "cuda" if torch.cuda.is_available() else "cpu"
model, preprocess = clip.load("ViT-B/32", device=device)

In [3]:
# CUB-200-2011 test split, with images transformed by CLIP's own preprocessing.
dataset_path = "./../datasets/CUB_200_2011/"
dataset = Labeled_CUB_200_2011(
    dataset_path,
    mode='test',
    transform=preprocess,
)

### Apply LoRA to Model

In [4]:
# LoRA configuration read by `apply_lora` / `load_lora`. Each value is annotated with its
# default; presumably these must match the settings the saved weights were trained with.
args = Namespace(
    # CLIP backbone name. Default: 'ViT-B/16'
    # NOTE(review): the model above is loaded with clip.load("ViT-B/32") — confirm which
    # backbone the saved LoRA weights belong to and keep both in sync.
    backbone='ViT-B/16',
    # Where to insert LoRA modules within the model's layers.
    # Options: 'bottom', 'mid', 'up', 'half-up', 'half-bottom', 'all', 'top3'; Default: 'all'
    position='all',
    # Which encoder(s) receive LoRA. Options: 'text', 'vision', 'both'; Default: 'both'
    encoder='both',
    # Attention matrices adapted with LoRA. Default: query (q), key (k) and value (v)
    params=('q', 'k', 'v'),
    # Rank of the low-rank matrices. Default: 2
    r=2,
    # Scaling factor applied to the LoRA matrices (see the LoRA paper). Default: 1
    alpha=1,
    # Dropout rate applied before the LoRA module. Default: 0.25
    dropout_rate=0.25,
    # File name of the LoRA weights, without extension. Default: 'lora_weights'
    filename='lora_weights',
    # Directory of the LoRA weights; will not save if set to None.
    # Default: './../clip/LoRA/weights/'
    save_path="./../clip/LoRA/weights/",
)

In [15]:
# Wrap the layers selected by `args` (position / encoder / params) with LoRA modules,
# then load the saved LoRA weights into them (the cell output reports the file read).
# NOTE(review): the execution counts show the float32-conversion cell ([6]) ran before this
# one ([15]); under Restart & Run All this cell runs first instead. If clip.load returned
# half-precision weights, the LoRA weights may be loaded into float16 tensors before the
# later upcast — confirm, and consider converting the model to float32 before apply_lora.
list_lora_layers = apply_lora(args, model)
load_lora(args, list_lora_layers)

LoRA weights loaded from ./../clip/LoRA/weights//lora_weights.pt


  loaded_data = torch.load(load_path)


In [6]:
def convert_model_to_float32(model):
    """Upcast the half-precision parameters of a PyTorch module to float32.

    Every float16 parameter, and every float16 gradient attached to a parameter,
    gets its data replaced by a float32 copy. Other dtypes are left as they are,
    and buffers are not visited. The module is modified in place and returned.
    """
    half = torch.float16
    for param in model.parameters():
        if param.dtype == half:
            param.data = param.data.float()
        grad = param.grad
        if grad is None or grad.dtype != half:
            continue
        grad.data = grad.data.float()
    return model

# Upcast any half-precision weights of the CLIP (+ LoRA) model to float32, in place.
model = convert_model_to_float32(model=model)

### Classifier

There are 312 possible attributes in the dataset (e.g. "blue wings", "red belly", etc.).
We want CLIP to rank all of these attributes for each image and compare the top-ranked ones with the attributes actually present in the image: if an image has $k$ annotated attributes, its score is the fraction of them found among CLIP's top $k$ predictions.

In [17]:
class ClipAttributeClassifier:
    """Zero-shot CUB attribute prediction with a CLIP model.

    Every attribute label of the dataset is turned into the prompt
    "A photo of a bird with <caption>". For each image, the prompts are ranked by
    cosine similarity to the image embedding, and the top-k are compared with the
    image's annotated attributes, where k is the number of annotated attributes
    (a per-image top-k precision).
    """

    def __init__(self, model, preprocess, dataset, device):
        """
        Args:
            model: CLIP model exposing ``encode_image`` and ``encode_text``.
            preprocess: Image transform; stored only, not applied by this class.
            dataset: Dataset exposing ``attributes_labels`` (a mapping),
                ``get_attribute_caption``, per-index ``attributes`` and images at
                ``dataset[i][0]``.
            device: Device the model and inputs are moved to.
        """
        self.model = model
        self.preprocess = preprocess
        self.dataset = dataset

        self.device = device

        self.model.to(device)
        self.model.eval()

        # Get attribute labels and turn them into natural-language prompts
        self.attribute_labels = list(dataset.attributes_labels.values())
        self.attribute_labels = [f"A photo of a bird with {dataset.get_attribute_caption(attr)}" for attr in self.attribute_labels]

        # Tokenized prompts (kept as an attribute for backward compatibility)
        self.text_inputs = clip.tokenize(self.attribute_labels).to(device)

        # The prompts never change, so encode and normalise them once here instead of
        # on every batch. Rebuild the classifier if the model's weights change later,
        # otherwise these cached features go stale.
        with torch.no_grad():
            text_features = self.model.encode_text(self.text_inputs)
        self.text_features = text_features / text_features.norm(dim=-1, keepdim=True)

    def classify_batch(self, batch_indices):
        """
        Classify a batch of images, predicting the top-k attributes for each image,
        where k is the number of annotated attributes of that image.

        Args:
            batch_indices (list): Indices of the images in the dataset.

        Returns:
            list: One accuracy per image in the batch: the fraction of its k annotated
            attributes found among the k highest-scoring prompts. Images with no
            annotated attribute get ``nan``, since the score is undefined for k = 0.
        """
        images = torch.stack([self.dataset[i][0] for i in batch_indices]).to(self.device)
        targets = [self.dataset.attributes[i] for i in batch_indices]

        with torch.no_grad():
            image_features = self.model.encode_image(images)
            image_features = image_features / image_features.norm(dim=-1, keepdim=True)
            # One row per image, one column per attribute prompt. A softmax is monotonic,
            # so ranking the raw similarities gives the same top-k.
            similarities = image_features @ self.text_features.T

        accs = []
        for row, target in zip(similarities, targets):
            k = len(target)
            if k == 0:
                accs.append(float("nan"))
                continue
            # Compare plain Python ints: 0-d torch tensors hash by identity, so a set
            # built from a tensor target would never intersect the predicted indices.
            true_attributes = {int(t) for t in target}
            predicted = set(torch.topk(row, k).indices.tolist())
            accs.append(len(predicted & true_attributes) / k)

        return accs

    def average_accuracy(self, batch_size=32, n=1000):
        """
        Compute the mean per-image accuracy over the first images of the dataset.

        Args:
            batch_size (int): Number of images per batch.
            n (int | None): Maximum number of images to evaluate, taken from the start
                of the dataset; ``None`` evaluates the whole dataset.
                NOTE(review): if the dataset is ordered by class, the first n images
                may cover only some classes — confirm before comparing scores.

        Returns:
            float: Mean accuracy, ignoring images without annotated attributes
            (``nan`` if no image could be scored).
        """
        num_samples = len(self.dataset) if n is None else min(n, len(self.dataset))
        accs = []

        for start in tqdm.tqdm(range(0, num_samples, batch_size)):
            batch_indices = list(range(start, min(start + batch_size, num_samples)))
            accs.extend(self.classify_batch(batch_indices))

        return np.nanmean(accs)


In [18]:
# Mean top-k attribute precision of the (LoRA-adapted) CLIP model on the first 1000 test images.
classifier = ClipAttributeClassifier(
    model=model,
    preprocess=preprocess,
    dataset=dataset,
    device=device,
)
classifier.average_accuracy(batch_size=32, n=1000)

100%|██████████| 32/32 [00:06<00:00,  4.62it/s]


np.float64(0.0853576100637068)