# Appendix B: Fine-tuning ColPali

Edited version of [Merve Noyan's notebook](https://github.com/merveenoyan/smol-vision/blob/main/Finetune_ColPali.ipynb)

This notebook is a minimal example of fine-tuning ColPali on [UFO documents and queries](https://huggingface.co/datasets/davanstrien/ufo-Colpali), a synthetically generated dataset.

Then we will show a minimal example of how to retrieve documents with the fine-tuned model.

In [None]:
import os

os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"] = "1"
os.environ["NCCL_P2P_DISABLE"] = "1"
os.environ["NCCL_IB_DISABLE"] = "1"

import pandas as pd
import torch
from colpali_engine.loss import ColbertPairwiseCELoss
from datasets import DatasetDict, load_dataset
from huggingface_hub import login
from peft import LoraConfig, get_peft_model
from transformers import BitsAndBytesConfig, ColPaliForRetrieval, ColPaliProcessor, Trainer, TrainingArguments

In [None]:
secret = pd.read_csv("secret.config", header=None)
HF_TOKEN = secret[1][1]
del secret
login(HF_TOKEN)

## Loading the Model

Fine-tuning ColPali takes around 48 GB of VRAM, far more than the 24 GB available on an RTX 4090. To stay within the memory limit, we apply QLoRA: the base model is loaded in lower precision (4-bit) and only a LoRA adapter is trained.
We also reduce the batch size and compensate for it with gradient accumulation.
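
To get a feel for why 4-bit loading helps, the memory taken by the weights alone can be estimated from the parameter count. The sketch below assumes roughly 3 billion parameters (an assumption based on the PaliGemma-3B backbone, not a figure from this notebook) and ignores activations, gradients, optimizer state, and quantization overhead, so it is only a rough lower bound.

```python
# Back-of-the-envelope estimate of weight memory at different precisions.
# ASSUMPTION: ~3e9 parameters; the real count differs somewhat.
ASSUMED_NUM_PARAMS = 3_000_000_000

BYTES_PER_PARAM = {
    "float32": 4.0,
    "bfloat16": 2.0,
    "nf4 (4-bit)": 0.5,
}


def weight_memory_gb(num_params: int, bytes_per_param: float) -> float:
    """Return the memory used by the weights alone, in GB (10**9 bytes)."""
    return num_params * bytes_per_param / 1e9


for name, nbytes in BYTES_PER_PARAM.items():
    print(f"{name:>12}: {weight_memory_gb(ASSUMED_NUM_PARAMS, nbytes):.1f} GB")
```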

In [None]:
model_name = "vidore/colpali-v1.2-hf"

bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = ColPaliForRetrieval.from_pretrained(
    model_name,
    torch_dtype=torch.bfloat16,
    quantization_config=bnb_config,
    device_map="cuda:0",
).eval()

lora_config = LoraConfig(
    r=8,
    lora_alpha=8,
    lora_dropout=0.1,
    target_modules=["down_proj", "o_proj", "k_proj", "q_proj", "gate_proj", "up_proj", "v_proj"],
    init_lora_weights="gaussian",
)
lora_config.inference_mode = False
model = get_peft_model(model, lora_config)
processor = ColPaliProcessor.from_pretrained(model_name)
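
To illustrate why training only an adapter is cheap, here is a small sketch of how many trainable parameters LoRA adds to a single linear layer. The layer size is hypothetical and chosen only for illustration; `r` and `lora_alpha` are the values from the `LoraConfig` above.

```python
# Sketch: extra trainable parameters that LoRA adds to one linear layer.
# LoRA keeps the weight W (d_out x d_in) frozen and learns a low-rank update
# B @ A, where A has shape (r, d_in) and B has shape (d_out, r).


def lora_extra_params(d_in: int, d_out: int, r: int) -> int:
    """Number of trainable parameters in the A and B matrices."""
    return r * d_in + d_out * r


def lora_scaling(lora_alpha: int, r: int) -> float:
    """Factor applied to the low-rank update in standard LoRA: alpha / r."""
    return lora_alpha / r


# HYPOTHETICAL layer size, not read from the actual model.
d_in, d_out = 2048, 2048
r, lora_alpha = 8, 8  # values from the LoraConfig above

full = d_in * d_out
extra = lora_extra_params(d_in, d_out, r)
print(f"frozen weights: {full:,}")
print(f"LoRA weights:   {extra:,} ({100 * extra / full:.2f}% of the layer)")
print(f"scaling factor: {lora_scaling(lora_alpha, r)}")
```

On the actual model, PEFT's `model.print_trainable_parameters()` reports the real numbers.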

## Load the dataset

We will use [this dataset](https://huggingface.co/datasets/davanstrien/ufo-ColPali) created by Daniel van Strien. In [this blog post](https://danielvanstrien.xyz/posts/post-with-code/colpali/2024-09-23-generate_colpali_dataset.html) he explains very thoroughly how to create a dataset for retrieval tasks.

In [None]:
ds = load_dataset("davanstrien/ufo-ColPali")
ds = ds["train"].train_test_split(test_size=0.1, seed=42)
train_ds = ds["train"]
test_ds = ds["test"]

ds

We need to get rid of dataset items where our text query column is None.

In [None]:
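# Drop examples without a query from both splits, since the test split is used for retrieval later.
test_ds = test_ds.filter(lambda example: example["specific_detail_query"] is not None)
train_ds = train_ds.filter(lambda example: example["specific_detail_query"] is not None)
train_ds  # should be less than 2018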

The dataset contains documents about UFOs and queries that might be related to each document. Let's take a quick look at an example.

In [None]:
print(train_ds[0]["specific_detail_query"])
display(train_ds[0]["image"])

We will use the following columns from this dataset to build the documents and the queries:

- `image` contains our documents.
- `specific_detail_query` contains the textual queries.

In [None]:
def collate_fn(examples):
    texts = []
    images = []

    for example in examples:
        texts.append(example["specific_detail_query"])
        images.append(example["image"].convert("RGB"))

    batch_images = processor(images=images, return_tensors="pt").to(model.device)
    batch_queries = processor(text=texts, return_tensors="pt").to(model.device)
    return (batch_queries, batch_images)

## Trainer

The trainer uses a ColBERT-style pairwise contrastive loss (`ColbertPairwiseCELoss`), which is implemented in the ColPali engine. The loss expects a batch of query embeddings and a batch of document embeddings, so we process the queries and the documents separately, pass each through the model, and send both sets of embeddings to the loss.

Because the loss is custom and every batch is a (queries, documents) pair, we have to subclass the transformers `Trainer` and override `compute_loss` and `prediction_step`.
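
The idea behind this kind of loss can be sketched in plain Python. This is not the ColPali engine implementation (which works on batched tensors). It is a minimal illustration, assuming late-interaction (MaxSim) scoring and a softplus penalty on the gap between the hardest in-batch negative and the matching document. All function names and the toy embeddings are made up for the example.

```python
import math


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def maxsim_score(query_tokens, doc_tokens):
    """Late interaction: each query token picks its best-matching doc token; sum the results."""
    return sum(max(dot(q, d) for d in doc_tokens) for q in query_tokens)


def softplus(x):
    """Numerically stable log(1 + exp(x))."""
    return max(x, 0.0) + math.log1p(math.exp(-abs(x)))


def pairwise_loss(queries, docs):
    """queries[i] matches docs[i]; every other doc in the batch acts as a negative."""
    n = len(queries)
    scores = [[maxsim_score(q, d) for d in docs] for q in queries]
    losses = []
    for i in range(n):
        positive = scores[i][i]
        hardest_negative = max(scores[i][j] for j in range(n) if j != i)
        losses.append(softplus(hardest_negative - positive))
    return sum(losses) / n


# TOY multi-vector embeddings (2-dimensional), invented for illustration.
toy_queries = [[[1.0, 0.0], [0.0, 1.0]], [[-1.0, 0.0]]]
toy_docs = [[[1.0, 0.0], [0.0, 1.0], [0.5, 0.5]], [[-1.0, 0.0], [0.0, -1.0]]]

print(maxsim_score(toy_queries[0], toy_docs[0]))  # score of a matching pair
print(pairwise_loss(toy_queries, toy_docs))
```

The loss shrinks as each matching pair outscores its hardest negative by a wider margin, which is what pushes queries towards their own documents during training.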

In [None]:
class ContrastiveTrainer(Trainer):
    def __init__(self, loss_func, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.loss_func = loss_func

    def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
        # The collator returns a (queries, documents) tuple; embed each side separately.
        query_inputs, doc_inputs = inputs
        query_outputs = model(**query_inputs)
        doc_outputs = model(**doc_inputs)
        loss = self.loss_func(query_outputs.embeddings, doc_outputs.embeddings)
        return (loss, (query_outputs, doc_outputs)) if return_outputs else loss

    def prediction_step(self, model, inputs, prediction_loss_only=True, ignore_keys=None):
        query_inputs, doc_inputs = inputs  # unpack from data collator
        with torch.no_grad():
            query_outputs = model(**query_inputs)
            doc_outputs = model(**doc_inputs)
            loss = self.loss_func(query_outputs.embeddings, doc_outputs.embeddings)
        # Only the loss is reported; logits and labels are not used.
        return loss, None, None

In [None]:
training_args = TrainingArguments(
    output_dir="./colpali_ufo",
    num_train_epochs=1,
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,
    gradient_checkpointing=False,
    logging_steps=20,
    warmup_steps=100,
    learning_rate=5e-5,
    save_total_limit=1,
    report_to="wandb",
    dataloader_pin_memory=False,
)
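
As a quick sanity check on these settings, the effective batch size and the approximate number of optimizer steps can be worked out by hand. The dataset size below is a placeholder (the real number is whatever remains after filtering), and the exact step count computed by `Trainer` may differ slightly because of rounding.

```python
import math

per_device_train_batch_size = 2  # from TrainingArguments above
gradient_accumulation_steps = 8  # from TrainingArguments above
num_devices = 1                  # a single visible GPU
ASSUMED_NUM_EXAMPLES = 2000      # PLACEHOLDER: a little under 2k examples in practice

# Gradients from several small batches are accumulated before each optimizer step.
effective_batch_size = per_device_train_batch_size * gradient_accumulation_steps * num_devices

micro_batches = math.ceil(ASSUMED_NUM_EXAMPLES / (per_device_train_batch_size * num_devices))
optimizer_steps = math.ceil(micro_batches / gradient_accumulation_steps)

print(f"effective batch size: {effective_batch_size}")
print(f"optimizer steps per epoch (approx.): {optimizer_steps}")
```

Under these assumptions, the 100 warmup steps would cover most of the single epoch, which may be worth keeping in mind when adapting the configuration to other datasets.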

In [None]:
trainer = ContrastiveTrainer(
    model=model, train_dataset=train_ds, args=training_args, loss_func=ColbertPairwiseCELoss(), data_collator=collate_fn
)

trainer.args.remove_unused_columns = False

We are training on a small dataset (a little under 2k examples) for one epoch, so training is fairly short (around 8 minutes).

In [None]:
trainer.train()

## Load and test fine-tuned model

Let's try the fine-tuned model. A simple test is to score a few queries against a few images: the matching text-image pairs should get high scores, and all other (irrelevant) combinations should score lower.

In [None]:
print(test_ds[0]["specific_detail_query"])
display(test_ds[0]["image"])
print(test_ds[1]["specific_detail_query"])
display(test_ds[1]["image"])
print(test_ds[2]["specific_detail_query"])
display(test_ds[2]["image"])

In [None]:
images = [test_ds[0]["image"], test_ds[1]["image"], test_ds[2]["image"]]
texts = [test_ds[0]["specific_detail_query"], test_ds[1]["specific_detail_query"], test_ds[2]["specific_detail_query"]]

# process
batch_images = processor(images=images, return_tensors="pt").to(model.device)
batch_queries = processor(text=texts, return_tensors="pt").to(model.device)

# infer (switch to eval mode so that LoRA dropout is disabled)
model.eval()
with torch.no_grad():
    image_embeddings = model(**batch_images).embeddings
    query_embeddings = model(**batch_queries).embeddings

# Score the queries against the images
scores = processor.score_retrieval(query_embeddings, image_embeddings)

The scores of the matching text-image pairs are on the diagonal of the score matrix below. As you can see, they are matched correctly!

In [None]:
scores
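
The diagonal check can also be done programmatically. The sketch below uses made-up scores purely for illustration, not output from the model. With the real tensor, `scores.tolist()` would give a nested list of the same shape. The helper name `top1_matches` is hypothetical.

```python
def top1_matches(score_rows):
    """For each query (row), return the index of the highest-scoring image (column)."""
    return [max(range(len(row)), key=row.__getitem__) for row in score_rows]


# MADE-UP scores: rows are queries, columns are images.
example_scores = [
    [15.2, 9.1, 8.7],
    [8.9, 14.8, 9.3],
    [9.0, 9.4, 16.1],
]

predicted = top1_matches(example_scores)
correct = [pred == i for i, pred in enumerate(predicted)]
print("top-1 image per query:", predicted)
print("all queries matched:", all(correct))
```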