# Qwen2-VL + (Q)LoRA on ROCOv2

Colab-ready notebook you can run from VS Code. It downloads ROCOv2 via Hugging Face `datasets`, sets up Qwen2-VL, and fine-tunes with LoRA/QLoRA for radiology captioning.

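Storage recommendations (controlled by the `USE_DRIVE` flag in the Drive cell):
- Quick experiments: set `USE_DRIVE = False` to keep everything in `/content/roco` (fastest, but lost on runtime reset).
- Persistent: keep `USE_DRIVE = True` to mount Google Drive; `BASE_DIR` then points to `/content/drive/MyDrive/roco_v2`.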

In [1]:
# Install deps (pin transformers to a recent version that supports Qwen2-VL)
# Pin numpy/pandas/datasets to avoid ABI mismatches on Colab base image.
!pip install -q --upgrade torch torchvision torchaudio transformers==4.42.3 accelerate peft bitsandbytes \
    numpy==1.26.4 pandas==2.2.2 datasets==2.19.1 pillow tqdm

[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m43.6/43.6 kB[0m [31m4.3 MB/s[0m eta [36m0:00:00[0m
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m61.0/61.0 kB[0m [31m6.4 MB/s[0m eta [36m0:00:00[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m9.3/9.3 MB[0m [31m86.6 MB/s[0m eta [36m0:00:00[0m:00:01[0m0:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m899.7/899.7 MB[0m [31m2.2 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m594.3/594.3 MB[0m [31m2.3 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m10.2/10.2 MB[0m [31m153.6 MB/s[0m eta [36m0:00:00[0m [36m0:00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m88.0/88.0 MB[0m [31m31.1 MB/s[0m eta [36m0:00:00[0m:00:01[0m00:01[0m
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m954.8/954.

In [None]:
# Optional: mount Google Drive for persistent storage.
USE_DRIVE = True  # set to False to keep data/checkpoints under /content/roco (lost on runtime reset)
if USE_DRIVE:
    try:
        from google.colab import drive
    except ImportError:
        # Not a Colab runtime (e.g. a local kernel): there is no Drive to mount,
        # so fall back to local storage instead of crashing the notebook.
        print('google.colab not available; falling back to local storage (USE_DRIVE=False).')
        USE_DRIVE = False
    else:
        drive.mount('/content/drive')

In [None]:
from pathlib import Path
import torch

# Resolve the storage root: Drive when mounted, otherwise the ephemeral Colab disk.
if USE_DRIVE:
    BASE_DIR = Path('/content/drive/MyDrive/roco_v2')
else:
    BASE_DIR = Path('/content/roco')

# Dataset cache and LoRA checkpoints live side by side under the root.
DATA_DIR, CKPT_DIR = BASE_DIR / 'data', BASE_DIR / 'checkpoints'
DATA_DIR.mkdir(exist_ok=True, parents=True)
CKPT_DIR.mkdir(exist_ok=True, parents=True)
print('Using base dir:', BASE_DIR)

In [None]:
# Download ROCO data via HF datasets.
# The official 'roco' loader on HF requires manually downloaded files (license) and fails with a
# FileNotFoundError subclass (e.g. EmptyDatasetError) if the data isn't present.
# Use a hosted mirror (default below) or download manually to DATA_DIR and set DATASET_NAME='roco'.
# If you have an HF token, set the HF_TOKEN env var to avoid rate limits; when it is unset,
# `datasets` falls back to any cached login token (or anonymous access).
import os

from datasets import load_dataset

DATASET_NAME = 'flaviagiammarino/roco-dataset'  # change to 'roco' if you place official files manually
DATASET_CONFIG = None  # for the mirror; use 'en' for the official script
# Split strings use the `datasets` slicing syntax. Slices are only understood by
# `load_dataset(split=...)`; they are NOT keys of the DatasetDict returned without `split`.
TRAIN_SPLIT = 'train[:30]'  # trim for first experiments; bump when ready (None/'' -> full 'train')
VAL_SPLIT = 'validation[:2]'  # None/'' -> full 'validation' (assumes the dataset has that split)

load_kwargs = {'cache_dir': str(DATA_DIR), 'token': os.environ.get('HF_TOKEN')}
try:
    # Each call downloads/caches the dataset once; the second call reuses the cache.
    train_ds = load_dataset(DATASET_NAME, DATASET_CONFIG, split=TRAIN_SPLIT or 'train', **load_kwargs)
    val_ds = load_dataset(DATASET_NAME, DATASET_CONFIG, split=VAL_SPLIT or 'validation', **load_kwargs)
except FileNotFoundError as e:
    raise RuntimeError('ROCO official script needs local files. Download the dataset to DATA_DIR and set DATASET_NAME="roco", or point DATASET_NAME to a hosted mirror that includes images.') from e

print(train_ds)
print(val_ds)

# Inspect fields to confirm caption key
print('Sample keys:', train_ds.column_names)
print('Example:', train_ds[0])


If the caption column is not `caption` or `text`, update `CAPTION_KEY` below accordingly.

In [None]:
from PIL import Image

# Prefer a 'caption' column; otherwise assume 'text' (select_caption tries further fallbacks).
if 'caption' in train_ds.column_names:
    CAPTION_KEY = 'caption'
else:
    CAPTION_KEY = 'text'

def select_caption(ex, caption_key=None):
    """Return the caption text of one dataset example.

    The column ``caption_key`` is tried first. When it is ``None``, the
    module-level ``CAPTION_KEY`` is used instead. After that, the fallback
    columns 'caption_en', 'description' and 'report' are tried in order.
    Missing columns and falsy values (empty string, ``None``) are skipped.

    Args:
        ex: One dataset row, as a mapping of column name to value.
        caption_key: Column to try first. ``None`` keeps the original behaviour
            of reading the global ``CAPTION_KEY``.

    Returns:
        The first non-empty candidate value.

    Raises:
        ValueError: If no candidate column holds a non-empty value.
    """
    # Only touch the global when no explicit key is given, so callers can
    # use this helper without CAPTION_KEY being defined.
    primary = CAPTION_KEY if caption_key is None else caption_key
    for key in (primary, 'caption_en', 'description', 'report'):
        if key in ex and ex[key]:
            return ex[key]
    raise ValueError('Caption field not found; please set CAPTION_KEY manually.')

# Echo the chosen column so a wrong guess is visible before the dataset is mapped.
print(f'Using caption field: {CAPTION_KEY}')

In [None]:
from transformers import AutoProcessor

MODEL_ID = 'Qwen/Qwen2-VL-7B-Instruct'
# Bound the per-image pixel budget. Each 28x28 area becomes one visual token
# (values as suggested on the Qwen2-VL model card), so this keeps an image to
# roughly 256-1280 visual tokens. With the processor's much larger default,
# high-resolution radiographs can expand into thousands of tokens and run out of
# GPU memory on Colab. Raise MAX_PIXELS if you have memory to spare.
MIN_PIXELS = 256 * 28 * 28
MAX_PIXELS = 1280 * 28 * 28
processor = AutoProcessor.from_pretrained(MODEL_ID, min_pixels=MIN_PIXELS, max_pixels=MAX_PIXELS)

def make_prompt(example):
    """Attach a chat-formatted prompt and the reference caption to a dataset row.

    The row is mutated in place and also returned, which is the shape
    ``Dataset.map`` expects. Two keys are added:

    - 'prompt': a single user turn (the image plus a fixed instruction), rendered
      as text by the processor's chat template with the generation prompt appended.
    - 'target_text': the caption picked by ``select_caption``.
    """
    reference = select_caption(example)
    instruction = 'Provide a concise radiology caption for this image.'
    user_turn = {
        'role': 'user',
        'content': [
            {'type': 'image', 'image': example['image']},
            {'type': 'text', 'text': instruction},
        ],
    }
    rendered = processor.apply_chat_template(
        [user_turn],
        tokenize=False,
        add_generation_prompt=True,
    )
    example['prompt'], example['target_text'] = rendered, reference
    return example

# Add 'prompt' / 'target_text' columns to both splits (train first, then validation).
train_ds, val_ds = train_ds.map(make_prompt), val_ds.map(make_prompt)

In [None]:
import torch
from torch.utils.data import DataLoader
from transformers import AutoModelForVision2Seq, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

from peft import prepare_model_for_kbit_training

# 4-bit NF4 quantization with double quantization (QLoRA); matmuls run in bfloat16.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_use_double_quant=True,
    bnb_4bit_quant_type='nf4',
    bnb_4bit_compute_dtype=torch.bfloat16,
)

model = AutoModelForVision2Seq.from_pretrained(
    MODEL_ID,
    quantization_config=bnb_config,
    device_map='auto',
)
# Standard QLoRA preparation. It enables gradient checkpointing and also makes
# the input-embedding outputs require grad. With gradient_checkpointing_enable()
# alone on a frozen 4-bit model, checkpointed segments can receive no
# grad-requiring inputs, and the LoRA adapters then get no gradient.
model = prepare_model_for_kbit_training(model, use_gradient_checkpointing=True)
# The KV cache is unused during training and conflicts with gradient checkpointing.
model.config.use_cache = False

# Attention/MLP projection names; presumably these match only the language-model
# layers. Verify against model.named_modules() if you also want to adapt the
# vision tower.
target_modules = ['q_proj', 'k_proj', 'v_proj', 'o_proj', 'gate_proj', 'up_proj', 'down_proj']
lora_config = LoraConfig(
    r=8,
    lora_alpha=32,
    target_modules=target_modules,
    lora_dropout=0.05,
    bias='none',
    # Qwen2-VL is decoder-only. CAUSAL_LM wraps it in PeftModelForCausalLM.
    # SEQ_2_SEQ_LM would forward encoder-decoder kwargs (e.g. decoder_input_ids)
    # that the model's forward does not accept.
    task_type='CAUSAL_LM',
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

In [None]:
def collate_fn(batch):
    """Build a padded training batch whose loss covers only the caption tokens.

    For each row, the rendered ``prompt`` is concatenated with ``target_text``
    and the chat turn terminator ``<|im_end|>``. The prompt already ends with
    the assistant header, because it was rendered with
    ``add_generation_prompt=True``. Including the terminator teaches the model
    where a caption stops.

    Labels are -100 everywhere except the target span. As a result, the prompt
    text, any image placeholder tokens, and padding contribute nothing to the loss.

    Args:
        batch: List of dataset rows with 'prompt', 'target_text' and 'image' keys.

    Returns:
        Dict of processor outputs plus a 'labels' tensor shaped like 'input_ids'.
    """
    end_of_turn = '<|im_end|>'
    targets = [ex['target_text'] + end_of_turn for ex in batch]
    texts = [ex['prompt'] + target for ex, target in zip(batch, targets)]
    images = [ex['image'] for ex in batch]
    model_inputs = processor(text=texts, images=images, padding=True, return_tensors='pt')

    input_ids = model_inputs['input_ids']
    attention_mask = model_inputs['attention_mask']
    labels = torch.full_like(input_ids, -100)
    for row, target in enumerate(targets):
        # The target is the tail of every sequence, so its tokens are the last
        # `n_target` non-padding positions whichever side the tokenizer pads on.
        # NOTE(review): assumes the caption's first tokens don't merge with the
        # prompt's trailing newline (e.g. a caption starting with a newline);
        # confirm for your data.
        n_target = len(processor.tokenizer(target, add_special_tokens=False)['input_ids'])
        if n_target:
            real_positions = attention_mask[row].nonzero(as_tuple=True)[0]
            keep = real_positions[-n_target:]
            labels[row, keep] = input_ids[row, keep]
    return {**model_inputs, 'labels': labels}

# Both loaders share the collator and batch size; only training data is shuffled.
train_loader = DataLoader(train_ds, shuffle=True, batch_size=2, collate_fn=collate_fn)
val_loader = DataLoader(val_ds, shuffle=False, batch_size=2, collate_fn=collate_fn)

# Smoke-test the collator on one batch: print each field's shape/dtype (or type).
batch = next(iter(train_loader))
for k, v in batch.items():
    details = (v.shape, v.dtype) if torch.is_tensor(v) else (type(v),)
    print(k, *details)

In [None]:
# transformers.AdamW is deprecated (and removed in recent releases); use the PyTorch one.
from torch.optim import AdamW
from transformers import get_cosine_schedule_with_warmup
from tqdm.auto import tqdm

epochs = 1
lr = 2e-4
# Only the LoRA adapter weights require grad; the frozen quantized base weights
# have no business in the optimizer.
trainable_params = [p for p in model.parameters() if p.requires_grad]
# weight_decay/eps match the defaults of the transformers AdamW used previously.
optimizer = AdamW(trainable_params, lr=lr, weight_decay=0.0, eps=1e-6)
num_steps = epochs * len(train_loader)
warmup_steps = int(0.03 * num_steps)
scheduler = get_cosine_schedule_with_warmup(optimizer, warmup_steps, num_steps)

model.train()
for epoch in range(epochs):
    pbar = tqdm(train_loader, desc=f'Epoch {epoch+1}/{epochs}')
    for batch in pbar:
        batch = {k: v.to(model.device) if torch.is_tensor(v) else v for k, v in batch.items()}
        outputs = model(**batch)
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        scheduler.step()
        optimizer.zero_grad()
        pbar.set_postfix({'loss': loss.item()})

    # Saves the PEFT adapter (not the full base model) plus the processor config.
    ckpt_path = CKPT_DIR / f'epoch_{epoch+1}'
    model.save_pretrained(ckpt_path)
    processor.save_pretrained(ckpt_path)
    print('Saved checkpoint to', ckpt_path)

In [None]:
# Quick qualitative check: caption the first validation image with the fine-tuned model.
model.eval()
sample = val_ds[0]
messages = [
    {
        'role': 'user',
        'content': [
            {'type': 'image', 'image': sample['image']},
            {'type': 'text', 'text': 'Provide a concise radiology caption for this image.'},
        ],
    }
]
prompt = processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
inputs = processor(text=prompt, images=sample['image'], return_tensors='pt').to(model.device)
with torch.no_grad():
    # use_cache=True re-enables the KV cache in case training turned it off in the config.
    generated_ids = model.generate(**inputs, max_new_tokens=64, use_cache=True)
    # generate() returns prompt + continuation; keep only the newly generated tokens
    # so the decoded text is the caption, not the echoed prompt.
    new_token_ids = generated_ids[:, inputs['input_ids'].shape[1]:]
    output_text = processor.batch_decode(new_token_ids, skip_special_tokens=True)[0]
print('Prediction:', output_text)
print('Reference:', sample['target_text'])