In [1]:
import os
import json
import argparse
from datasets import load_dataset
import torch
from transformers import Qwen2_5_VLForConditionalGeneration, Qwen2_5_VLProcessor
from peft import get_peft_model, prepare_model_for_kbit_training
from trl import SFTConfig, SFTTrainer
from qwen_vl_utils import process_vision_info
from peft import LoraConfig
from transformers.utils.import_utils import is_flash_attn_2_available
from PIL import Image
# System prompt for MPDocVQA
SYSTEM_MESSAGE = """
You are a vision-language assistant specialized in answering questions based on document page images.
Given a question about the document, use the provided page images to only generate accurate, short and concise answers.
"""

def load_candidates(cands_path: str) -> dict:
    with open(cands_path, 'r') as f:
        return json.load(f)




  from .autonotebook import tqdm as notebook_tqdm


In [2]:
def format_data(example: dict, image_dir: str, candidates: dict, top_k: int) -> dict:
    # Extract QID, question, and answer
    qid = example['questionId']
    question = example['question']
    answers = example.get('answers', [])
    answer = answers[0] if answers else ""

    # Construct chat messages

    # Select top-k candidate page IDs
    cand_pages = candidates.get(str(qid), [])[:top_k]
    # print(cand_pages, qid)
    image_paths = [os.path.join(image_dir, 'images', f"{pid}.jpg") for pid, _ in cand_pages]
    # print(image_paths)
    user_messages_content = [
        {
                "type": "text",
                "text": question
        }
    ]
    for image_path in image_paths:
        user_messages_content.append(
            {
                "type": "image",
                "image": image_path
            }
        )

    messages = [
        {"role": "system",    "content": {"type": "text", "text": SYSTEM_MESSAGE}},
        {
            "role": "user",
            "content": user_messages_content
        },
        {"role": "assistant", "content": {"type": "text", "text": answer}},
    ]
    return {"messages": messages}

def collate_fn(batch, processor):
    # Separate raw multimodal content and chat msgs
    # print(batch)
    messages = [msg["messages"] for msg in batch]
    texts = [
        processor.apply_chat_template(msg, tokenize=False, add_generation_prompt=True)
        for msg in messages
    ]
    

    # Process vision inputs
    image_inputs, video_inputs = process_vision_info(messages)

    # Tokenize text + align with vision features
    batch_enc = processor(
        text=texts,
        images=image_inputs,
        videos=video_inputs,
        return_tensors="pt",
        padding=True
    )
    labels = batch_enc['input_ids'].clone()
    labels[labels == processor.tokenizer.pad_token_id] = -100 # Mask padding tokens in labels
    image_tokens = [processor.tokenizer.convert_tokens_to_ids(processor.image_token)]
    for image_token_id in image_tokens:
        labels[labels == image_token_id] = -100  # Mask image token IDs in labels
    
    batch_enc["labels"] = labels


    return batch_enc



In [3]:
from dataclasses import dataclass

In [4]:
@dataclass
class SFTArgs:
    train_json: str = "mpdocvqa/question_answers/val.json"
    candidates_json: str = "close_vanilla_colqwen_val_eval_4.json"
    root_dir: str= "mpdocvqa"
    eval_json  = None
    model_id: str = "Qwen/Qwen2.5-VL-7B-Instruct"
    output_dir: str = "./qwen2vl-sft-mpdocvqa"
    num_epochs: int = 1
    batch_size: int = 1
    lr: float = 2e-4
    top_k: int = 1
args = SFTArgs()

In [5]:
raw_train = load_dataset('json', data_files={'train': args.train_json}, field='data')['train']
raw_eval  = load_dataset('json', data_files={'eval': args.eval_json}, field='data')['eval'] if args.eval_json else None
print(raw_train)
# Load candidate-page mapping
# Load candidate mappings and prepare samples
candidates = load_candidates(args.candidates_json)
# print(candidates['49153'])
train_samples = [format_data(ex, args.root_dir, candidates, args.top_k) for ex in raw_train]
eval_samples  = [format_data(ex, args.root_dir, candidates, args.top_k) for ex in raw_eval] if raw_eval else None
# # Load model and processor with 4-bit quantization
# # bnb_config = BitsAndBytesConfig(
# #     load_in_4bit=True,
# #     bnb_4bit_use_double_quant=True,
# #     bnb_4bit_quant_type='nf4',
# #     bnb_4bit_compute_dtype=torch.bfloat16
# # )




Dataset({
    features: ['questionId', 'question', 'doc_id', 'page_ids', 'answers', 'answer_page_idx', 'data_split'],
    num_rows: 5187
})


In [6]:
model = Qwen2_5_VLForConditionalGeneration.from_pretrained(
    args.model_id,
    # quantization_config=bnb_config,
    device_map='auto',
    torch_dtype=torch.bfloat16,
    attn_implementation="flash_attention_2" if is_flash_attn_2_available() else None
)
processor = Qwen2_5_VLProcessor.from_pretrained(args.model_id, use_fast=True)
# Prepare model for k-bit training, then apply LoRA adapters
# model = prepare_model_for_kbit_training(model)
peft_config = LoraConfig(
    task_type='CAUSAL_LM',
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    target_modules=['q_proj','v_proj']
)
model = get_peft_model(model, peft_config)
model.print_trainable_parameters()

Loading checkpoint shards: 100%|██████████| 5/5 [00:03<00:00,  1.65it/s]


trainable params: 5,046,272 || all params: 8,297,212,928 || trainable%: 0.0608


In [7]:
import wandb
wandb.init(project="my-ms-thesis")
wandb.config.update(model.config.to_dict(), allow_val_change=True)

[34m[1mwandb[0m: [32m[41mERROR[0m Failed to detect the name of this notebook. You can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.
[34m[1mwandb[0m: Currently logged in as: [33mak11089[0m to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


In [8]:
# # collate_fn.processor = processor
# candidates = load_candidates(args.candidates_json)
# train_ds = MPDocVQADataset(args.train_json, args.root_dir, candidates, args.top_k, processor)
# eval_ds  = MPDocVQADataset(args.eval_json, args.root_dir, candidates, args.top_k, processor) if args.eval_json else None

In [9]:
# train_ds.__dict__

In [10]:
# Configure SFT training
sft_config = SFTConfig(
    output_dir=args.output_dir,
    num_train_epochs=args.num_epochs,
    per_device_train_batch_size=args.batch_size,
    per_device_eval_batch_size=args.batch_size,
    gradient_accumulation_steps=8,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    gradient_checkpointing=True,
    # gradient_checkpoint=True,
    optim='adamw_torch_fused',
    learning_rate=args.lr,
    lr_scheduler_type='constant',
    logging_steps=1,
    # eval_steps=50,
    # eval_strategy='steps' if eval_samples else 'no',
    report_to=['wandb'],
    label_names = 'labels',
    remove_unused_columns=False,
    dataset_text_field='messages',
    dataset_kwargs={'skip_prepare_dataset': True},
    bf16=True
)
# Initialize SFT trainer
trainer = SFTTrainer(
    model=model,
    args=sft_config,
    train_dataset=train_samples,
    eval_dataset=eval_samples,
    data_collator=lambda exs: collate_fn(exs, processor),
    peft_config=peft_config,
    # tokenizer=processor.tokenizer
)

In [None]:
# Start training
trainer.train()
# Save adapters and tokenizer
# os.makedirs(args.output_dir, exist_ok=True)
# model.save_pretrained(args.output_dir)
# processor.tokenizer.save_pretrained(args.output_dir)

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...


Step,Training Loss
1,11.3778
2,11.3832
3,10.0037
4,9.1482
5,8.4899
6,7.1286
7,7.387
8,6.7641


In [None]:
# odel.save_pretrained(args.output_dir)
processor.save_pretrained(args.output_dir)

In [None]:
processor.image_processor.save_pretrained(args.output_dir)

In [None]:
!rm -rf args.output_dir 