In [1]:
!pip -q install --upgrade colpali-engine[train] peft bitsandbytes transformers accelerate datasets hf_xet torchvision pillow tqdm --progress-bar off

import os, random, math, torch, gc
os.environ["TORCH_CHECKPOINT_USE_REENTRANT"] = "0"
from dataclasses import dataclass
from pathlib import Path
from accelerate import Accelerator
from tqdm.auto import tqdm
from torch.nn import functional as F
from datasets import load_dataset, concatenate_datasets
from torchvision.transforms import (Compose, RandomRotation, RandomAffine, ColorJitter, Resize, InterpolationMode)
from torch.utils.data import DataLoader
from peft import (LoraConfig, prepare_model_for_kbit_training, get_peft_model)
from transformers import BitsAndBytesConfig, get_cosine_schedule_with_warmup
from colpali_engine.models import ColQwen2_5, ColQwen2_5_Processor

[31mERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
gcsfs 2025.3.2 requires fsspec==2025.3.2, but you have fsspec 2024.12.0 which is incompatible.[0m[31m
[0m

In [3]:
# Hyper-parameters
@dataclass
class CFG:
    """Hyper-parameters for the LoRA fine-tuning run.

    The notebook reads these as class attributes (``CFG.lr``), so every field
    carries a default. New fields are appended at the end to keep the
    positional order of the existing ones stable.
    """
    model_name: str = "nomic-ai/colnomic-embed-multimodal-7b"
    image_size: int = 224          # images are resized to image_size x image_size
    max_txt_len: int = 96          # transcriptions are cut to this many characters
    batch_size: int = 32
    epochs: int = 3
    lr: float = 1e-4
    warmup_steps: int = 500
    lora_rank: int = 32
    out_dir: str = "colnomic_lora_handwriting"
    augment: bool = True
    # Seed for the train/validation shuffle. The data cell reads `CFG.seed`;
    # without this field a fresh-kernel run fails with AttributeError.
    seed: int = 42

# 1. Dataloader

In [4]:
# Deterministic resizing and random augmentation are kept apart on purpose:
# `Dataset.map` runs once and caches its result, so anything random applied
# there is frozen for all epochs (and would also distort the validation
# split). Only the resize is baked in by `map`; the random pipeline is
# attached to the *training* split and re-sampled on every access.
base_tf = [Resize((CFG.image_size, CFG.image_size), InterpolationMode.BICUBIC)]
aug_tf  = [RandomRotation(3, fill=255),
           RandomAffine(3, translate=(0.03,0.03), fill=255),
           ColorJitter(brightness=0.2, contrast=0.2)] if CFG.augment else []
resize_only = Compose(base_tf)            # deterministic, used inside map()
transform = Compose(aug_tf + base_tf)     # random train-time pipeline

print("Downloading IAM & RIMES …")
iam   = load_dataset("Teklia/IAM-line", split="train")
rimes = load_dataset("Teklia/RIMES-2011-line", split="train")
full_ds = concatenate_datasets([iam, rimes])

# Fall back to a fixed seed when CFG does not define one, so the split is
# reproducible either way.
split_seed = getattr(CFG, "seed", 42)

# Shuffle once and carve both splits out of the same permutation: the first
# 10 % becomes validation, the remainder is used for training.
val_size = int(0.10 * len(full_ds))
shuffled_ds = full_ds.shuffle(seed=split_seed)
val_ds = shuffled_ds.select(range(val_size))
train_ds = shuffled_ds.select(range(val_size, len(shuffled_ds)))


def preprocess(ex):
    """Deterministic per-example preparation applied through ``Dataset.map``.

    Converts the image to RGB, resizes it with ``resize_only`` and truncates
    the transcription to ``CFG.max_txt_len`` characters.

    Args:
        ex: one dataset row with an ``"image"`` and a ``"text"`` entry.

    Returns:
        A new dict holding only the processed ``"image"`` and ``"text"``.
    """
    img = ex["image"].convert("RGB")
    return {"image": resize_only(img), "text": ex["text"][:CFG.max_txt_len]}


def augment_batch(batch):
    """On-the-fly formatting hook for the training split.

    Receives a dict of column lists and re-samples the random augmentation
    for every image each time it is accessed. Images are already at the
    target size, so the trailing resize inside ``transform`` does not change
    their dimensions.
    """
    batch["image"] = [transform(img) for img in batch["image"]]
    return batch


train_ds = train_ds.map(preprocess, remove_columns=train_ds.column_names, num_proc=4)
val_ds = val_ds.map(preprocess, remove_columns=val_ds.column_names, num_proc=4)

# Validation always stays un-augmented; training gets fresh augmentations
# per access when augmentation is enabled.
if aug_tf:
    train_ds.set_transform(augment_batch)
else:
    train_ds.set_format(type="python")
val_ds.set_format(type="python")

print(f"Train {len(train_ds)}  Val {len(val_ds)}")


def collate_fn(batch):
    """Split a list of samples into two parallel lists.

    Args:
        batch: sequence of mappings, each with an ``"image"`` and a ``"text"``
            entry.

    Returns:
        ``(images, texts)`` — two lists in batch order. Nothing is stacked or
        converted here; the items are passed through untouched.
    """
    images, texts = [], []
    for sample in batch:
        images.append(sample["image"])
        texts.append(sample["text"])
    return images, texts

# Settings shared by both loaders; only shuffling and the worker count differ
# between the training and the validation split.
_loader_common = dict(batch_size=CFG.batch_size, pin_memory=True, collate_fn=collate_fn)
train_loader = DataLoader(train_ds, shuffle=True, num_workers=4, **_loader_common)
val_loader = DataLoader(val_ds, shuffle=False, num_workers=2, **_loader_common)

Downloading IAM & RIMES …


The secret `HF_TOKEN` does not exist in your Colab secrets.
To authenticate with the Hugging Face Hub, create a token in your settings tab (https://huggingface.co/settings/tokens), set it as secret in your Google Colab and restart your session.
You will be able to reuse this secret in all of your notebooks.
Please note that authentication is recommended but still optional to access public models or datasets.


README.md:   0%|          | 0.00/2.14k [00:00<?, ?B/s]

train.parquet:   0%|          | 0.00/167M [00:00<?, ?B/s]

validation.parquet:   0%|          | 0.00/24.7M [00:00<?, ?B/s]

test.parquet:   0%|          | 0.00/73.6M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/6482 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/976 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/2915 [00:00<?, ? examples/s]

README.md:   0%|          | 0.00/3.07k [00:00<?, ?B/s]

train.parquet:   0%|          | 0.00/212M [00:00<?, ?B/s]

validation.parquet:   0%|          | 0.00/23.6M [00:00<?, ?B/s]

test.parquet:   0%|          | 0.00/16.6M [00:00<?, ?B/s]

Generating train split:   0%|          | 0/10188 [00:00<?, ? examples/s]

Generating validation split:   0%|          | 0/1138 [00:00<?, ? examples/s]

Generating test split:   0%|          | 0/778 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/15003 [00:00<?, ? examples/s]

Map (num_proc=4):   0%|          | 0/1667 [00:00<?, ? examples/s]

Train 15003  Val 1667


# 2. Load model and inject LoRA

In [5]:
# 4-bit NF4 quantisation with double quantisation; compute dtype is fp16.
bnb_cfg = BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_use_double_quant=True, bnb_4bit_compute_dtype=torch.float16, bnb_4bit_quant_type="nf4")

model = ColQwen2_5.from_pretrained(CFG.model_name, device_map="auto", torch_dtype=torch.float16, quantization_config=bnb_cfg).train()
processor = ColQwen2_5_Processor.from_pretrained(CFG.model_name)

# Non-reentrant gradient checkpointing. prepare_model_for_kbit_training()
# enables checkpointing itself, so the same kwargs are passed through it;
# calling it with defaults after the explicit enable would replace the
# `use_reentrant=False` setting with the library default.
gc_kwargs = {"use_reentrant": False}
model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gc_kwargs)
model.enable_input_require_grads()
model = prepare_model_for_kbit_training(
    model,
    use_gradient_checkpointing=True,
    gradient_checkpointing_kwargs=gc_kwargs,
)

# LoRA on the attention projections only; alpha is tied to 2 x rank.
lora_cfg = LoraConfig(r=CFG.lora_rank,
                      lora_alpha=CFG.lora_rank*2,
                      lora_dropout=0.05,
                      bias="none",
                      target_modules=["q_proj","k_proj","v_proj","o_proj"],
                      task_type="FEATURE_EXTRACTION")
model = get_peft_model(model, lora_cfg)
# print_trainable_parameters() prints the summary itself and returns None;
# wrapping it in print() only adds a stray "None" line to the output.
model.print_trainable_parameters()

adapter_config.json:   0%|          | 0.00/805 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/1.19k [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/57.8k [00:00<?, ?B/s]

Fetching 7 files:   0%|          | 0/7 [00:00<?, ?it/s]

model-00004-of-00007.safetensors:   0%|          | 0.00/5.00G [00:00<?, ?B/s]

model-00002-of-00007.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

model-00007-of-00007.safetensors:   0%|          | 0.00/3.39G [00:00<?, ?B/s]

model-00001-of-00007.safetensors:   0%|          | 0.00/4.95G [00:00<?, ?B/s]

model-00003-of-00007.safetensors:   0%|          | 0.00/4.93G [00:00<?, ?B/s]

model-00006-of-00007.safetensors:   0%|          | 0.00/4.93G [00:00<?, ?B/s]

model-00005-of-00007.safetensors:   0%|          | 0.00/4.98G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/7 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/214 [00:00<?, ?B/s]

adapter_model.safetensors:   0%|          | 0.00/323M [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/574 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


tokenizer_config.json:   0%|          | 0.00/7.33k [00:00<?, ?B/s]

vocab.json:   0%|          | 0.00/2.78M [00:00<?, ?B/s]

merges.txt:   0%|          | 0.00/1.67M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/11.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/605 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/613 [00:00<?, ?B/s]

processor_config.json:   0%|          | 0.00/81.0 [00:00<?, ?B/s]

chat_template.json:   0%|          | 0.00/1.05k [00:00<?, ?B/s]



trainable params: 80,859,136 || all params: 8,373,484,672 || trainable%: 0.9657
None


# 3. Train

In [None]:
acc = Accelerator(log_with="tensorboard", project_dir=CFG.out_dir)

# Only parameters that require gradients are handed to the optimizer, rather
# than every frozen base weight as well.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optim = torch.optim.AdamW(trainable_params, lr=CFG.lr, weight_decay=1e-2)
steps = len(train_loader)*CFG.epochs
sched = get_cosine_schedule_with_warmup(optim, CFG.warmup_steps, steps)

model, optim, train_loader, val_loader, sched = acc.prepare(model, optim, train_loader, val_loader, sched)


def _symmetric_contrastive_loss(images, texts):
    """Return the in-batch image<->text contrastive loss for one batch.

    Both modalities are embedded with the same model, the multi-vector outputs
    are mean-pooled over dim 1 and L2-normalised, and a symmetric
    cross-entropy is taken over the scaled similarity matrix, where the i-th
    image and the i-th text form the positive pair.

    Args:
        images: list of images accepted by ``processor.process_images``.
        texts: list of strings, aligned with ``images``.

    Returns:
        Scalar loss tensor (keeps the autograd graph unless called under
        ``torch.no_grad()``).
    """
    batch_img = processor.process_images(images).to(acc.device)
    batch_txt = processor.process_queries(texts).to(acc.device)

    img_emb = model(**batch_img)
    txt_emb = model(**batch_txt)
    # mean-pool the multi-vector outputs (keeps grad path)
    img_emb = F.normalize(img_emb.mean(dim=1), dim=-1)
    txt_emb = F.normalize(txt_emb.mean(dim=1), dim=-1)
    logits = img_emb @ txt_emb.T * 100.0
    labels = torch.arange(logits.size(0), device=logits.device)
    return (F.cross_entropy(logits, labels) + F.cross_entropy(logits.T, labels)) / 2


for ep in range(CFG.epochs):
    # ---- training ---------------------------------------------------------
    model.train()
    running = total = 0
    pbar = tqdm(train_loader, desc=f"epoch {ep + 1}/{CFG.epochs}", disable=not acc.is_local_main_process)
    for images, texts in pbar:
        loss = _symmetric_contrastive_loss(images, texts)

        acc.backward(loss)
        optim.step(); sched.step(); optim.zero_grad()

        running += loss.item() * len(images)
        total += len(images)
        # Surface the sample-weighted running mean so the run is not blind.
        pbar.set_postfix(loss=f"{running / total:.4f}")

    # ---- validation -------------------------------------------------------
    # NOTE: the loss depends on the number of in-batch negatives, so a short
    # last batch is not strictly comparable; the figure is meant for tracking
    # trends across epochs. In a multi-process run each process reports the
    # loss of its own shard.
    model.eval()
    val_running = val_total = 0
    with torch.no_grad():
        for images, texts in tqdm(val_loader, desc="val", leave=False, disable=not acc.is_local_main_process):
            val_loss = _symmetric_contrastive_loss(images, texts)
            val_running += val_loss.item() * len(images)
            val_total += len(images)

    acc.print(
        f"epoch {ep + 1}: "
        f"train loss {running / max(total, 1):.4f}  "
        f"val loss {val_running / max(val_total, 1):.4f}"
    )
    acc.save_state(CFG.out_dir)

# 4. Snippet for reloading the model (for reference)

In [None]:
import os, random, torch
from dataclasses import dataclass
from safetensors.torch import load_file
from peft import LoraConfig, get_peft_model
from transformers import BitsAndBytesConfig
from colqwen2_5 import ColQwen2_5, ColQwen2_5_Processor

@dataclass
class CFG:
    # Subset of the hyper-parameters used by the reload snippet in this cell.
    model_name: str = "nomic-ai/colnomic-embed-multimodal-7b"
    image_size: int = 224
    max_txt_len: int = 96
    lora_rank: int = 32
    # folder contains model.safetensors
    out_dir: str = "drive/MyDrive/colnomic_lora_handwriting"


# Reload snippet: rebuild the quantised base model and wrap it with a LoRA
# config that mirrors the training cell (same rank, alpha, dropout, targets).
#
# NOTE(review): this cell imports ColQwen2_5 / ColQwen2_5_Processor from
# `colqwen2_5`, while the training cell imports them from
# `colpali_engine.models` — confirm that a top-level `colqwen2_5` module
# exists in the environment where this snippet is run.
#
# NOTE(review): nothing in the visible part of this cell reads CFG.out_dir or
# calls `load_file` (both are set up above). Presumably the trained weights
# are loaded in a later step; until then `model` does not contain them —
# confirm.

# rebuild the *exact* empty model (same 4-bit NF4 settings as in training)
bnb_cfg = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_use_double_quant=True,
        bnb_4bit_compute_dtype=torch.float16,
        bnb_4bit_quant_type="nf4",
)

base_model = ColQwen2_5.from_pretrained(
        CFG.model_name,
        device_map="auto",
        torch_dtype=torch.float16,
        quantization_config=bnb_cfg,
).eval()                                   # inference mode

# recreate the LoRA scaffolding (same targets / rank as before)
lora_cfg = LoraConfig(
        r               = CFG.lora_rank,
        lora_alpha      = CFG.lora_rank * 2,
        lora_dropout    = 0.05,
        bias            = "none",
        target_modules  = ["q_proj","k_proj","v_proj","o_proj"],
        task_type       = "FEATURE_EXTRACTION",
)
model = get_peft_model(base_model, lora_cfg)
model.eval()                               # eval mode: dropout off; does NOT disable autograd — wrap inference in torch.no_grad()
processor = ColQwen2_5_Processor.from_pretrained(CFG.model_name)
