<a href="https://colab.research.google.com/github/MeenakshiRajpurohit/CMPE-252-AI-and-Data-Engineering/blob/main/VQA.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
# ============================================================
#  Finance Domain-Specific Fine-Tuning of VQA Models
#  Dataset: sujet-ai/Sujet-Finance-QA-Vision-100k
#  Model:   Qwen/Qwen2-VL-2B-Instruct
#  Runs on: Google Colab (T4 GPU)
# ============================================================
# HOW TO USE:
#   1. New Colab notebook -> Runtime -> T4 GPU
#   2. Run Section 2 first, then Runtime -> Restart session
#   3. Run all remaining sections in order
# ============================================================


# ────────────────────────────────────────────────────────────
# SECTION 1 — Check GPU
# ────────────────────────────────────────────────────────────
import torch
# Report whether CUDA is usable and stop immediately when it is not.
has_cuda = torch.cuda.is_available()
print(f"CUDA available : {has_cuda}")
if not has_cuda:
    raise SystemError("No GPU! Go to Runtime > Change runtime type > T4 GPU")

# A device is present: show its name and total memory in GB.
print(f"GPU            : {torch.cuda.get_device_name(0)}")
gpu_props = torch.cuda.get_device_properties(0)
print(f"GPU Memory     : {gpu_props.total_memory / 1e9:.2f} GB")


# ────────────────────────────────────────────────────────────
# SECTION 2 — Install Dependencies
# *** RESTART SESSION after this cell ***
# ────────────────────────────────────────────────────────────
import subprocess, sys

# Invoke pip through the interpreter that is running this notebook so the
# packages are installed into the same environment the kernel imports from
# (a bare `pip` is resolved via PATH and may belong to another Python).
PIP_INSTALL = [sys.executable, "-m", "pip", "install", "-q"]

# One pip invocation per group, so a failure points at the offending group.
cmds = [
    PIP_INSTALL + ["transformers==4.49.0"],
    PIP_INSTALL + ["qwen-vl-utils"],
    PIP_INSTALL + ["peft", "accelerate", "bitsandbytes"],
    PIP_INSTALL + ["datasets", "pillow", "torchvision"],
]
for cmd in cmds:
    # check=True aborts the cell on the first failed install.
    subprocess.run(cmd, check=True)

print("\nDone! >>> Runtime -> Restart session, then run from Section 3 <<<")


# ────────────────────────────────────────────────────────────
# SECTION 3 — HF Login (optional)
# ────────────────────────────────────────────────────────────
# Log in to the Hugging Face Hub when a Colab secret named HF_TOKEN holds a
# value. Any exception raised along the way falls back to anonymous access.
try:
    from google.colab import userdata
    from huggingface_hub import login
    token = userdata.get("HF_TOKEN")
    if not token:
        print("Anonymous access")
    else:
        login(token=token)
        print("Logged in")
except Exception:
    print("Anonymous access")


# ────────────────────────────────────────────────────────────
# SECTION 4 — Load and Flatten Dataset
# ────────────────────────────────────────────────────────────
from datasets import load_dataset
import matplotlib.pyplot as plt
import json, ast, torch

# Fetch the training split and show how the raw qa_pairs field looks.
load_kwargs = {"split": "train", "trust_remote_code": True}
print("Loading dataset ...")
dataset = load_dataset("sujet-ai/Sujet-Finance-QA-Vision-100k", **load_kwargs)
num_docs = len(dataset)
print(f"Loaded {num_docs:,} documents")
first_raw = str(dataset[0]["qa_pairs"])
print("Raw qa_pairs sample:", first_raw[:300])


def _first_text(record, keys):
    """Return the first non-empty value among *keys* as stripped text, else ""."""
    for key in keys:
        value = record.get(key)
        if value is None:
            continue
        # str() so numeric / boolean values (e.g. an answer of 42 or 0) are
        # kept instead of crashing on .strip() or being dropped as falsy.
        text = str(value).strip()
        if text:
            return text
    return ""


def parse_qa_pairs(qa_str):
    """
    Extract (question, answer) tuples from one document's qa_pairs field.

    Args:
        qa_str: a JSON string, a Python-literal string, or an already decoded
            list / dict. A single dict is treated as a one-element list.

    Returns:
        List of (question, answer) string tuples. Entries that are not dicts,
        or that lack a non-empty question or answer, are skipped. Empty or
        unparseable input yields an empty list.
    """
    if not qa_str:
        return []

    if isinstance(qa_str, (list, dict)):
        # Already decoded upstream; nothing to parse.
        pairs = qa_str
    else:
        try:
            pairs = json.loads(qa_str)
        except (TypeError, ValueError, RecursionError):
            # Not valid JSON: fall back to Python-literal syntax
            # (e.g. single-quoted dicts).
            try:
                pairs = ast.literal_eval(qa_str)
            except (ValueError, SyntaxError, TypeError, MemoryError, RecursionError):
                return []

    if isinstance(pairs, dict):
        pairs = [pairs]
    if not isinstance(pairs, list):
        return []

    results = []
    for p in pairs:
        if not isinstance(p, dict):
            continue
        q = _first_text(p, ("question", "Q", "q"))
        a = _first_text(p, ("answer", "A", "a"))
        if q and a:
            results.append((q, a))
    return results


# Expand every document into one entry per QA pair. The three lists stay
# index-aligned; a document's image object is repeated for each of its pairs.
flat_images, flat_questions, flat_answers = [], [], []
for record in dataset:
    pairs = parse_qa_pairs(record["qa_pairs"])
    if not pairs:
        continue
    page = record["image"]
    flat_images.extend([page] * len(pairs))
    flat_questions.extend(question for question, _ in pairs)
    flat_answers.extend(answer for _, answer in pairs)

print(f"Total QA pairs: {len(flat_questions):,}")
first_q, first_a = flat_questions[0], flat_answers[0]
print(f"Sample Q: {first_q[:120]}")
print(f"Sample A: {first_a[:120]}")


# ────────────────────────────────────────────────────────────
# SECTION 5 — Load Model & Processor
# ────────────────────────────────────────────────────────────
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor

MODEL_NAME = "Qwen/Qwen2-VL-2B-Instruct"
DEVICE     = "cuda"
DTYPE      = torch.bfloat16

# Per-image pixel budget handed to the processor.
pixel_unit = 28 * 28
processor = AutoProcessor.from_pretrained(
    MODEL_NAME,
    min_pixels=128 * pixel_unit,
    max_pixels=256 * pixel_unit,  # smaller = less GPU memory per image
)

model_kwargs = dict(torch_dtype=DTYPE, device_map="auto")
model = Qwen2VLForConditionalGeneration.from_pretrained(MODEL_NAME, **model_kwargs)

# Report model size and how much GPU memory is allocated after loading.
param_count = sum(p.numel() for p in model.parameters())
print(f"Params: {param_count/1e9:.2f}B")
gpu_bytes = torch.cuda.memory_allocated()
print(f"GPU   : {gpu_bytes/1e9:.2f} GB used")


# ────────────────────────────────────────────────────────────
# SECTION 6 — Apply LoRA
# ────────────────────────────────────────────────────────────
from peft import LoraConfig, get_peft_model, TaskType

# Wrap the base model with LoRA adapters and print how many parameters train.
lora_config = LoraConfig(
    task_type=TaskType.CAUSAL_LM,
    target_modules="all-linear",
    r=8,
    lora_alpha=16,
    lora_dropout=0.05,
    bias="none",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()


# ────────────────────────────────────────────────────────────
# SECTION 7 — Dataset & DataLoader
# ────────────────────────────────────────────────────────────
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import random

# Data-pipeline settings consumed further down in this section.
BATCH_SIZE    = 1        # 1 is safest on T4; increase to 2 if no OOM
TRAIN_SAMPLES = 3000     # upper bound on QA pairs placed in the training split
VAL_SAMPLES   = 200      # upper bound on QA pairs placed in the validation split
SEED          = 42       # seeds `random` before the index shuffle
MAX_SEQ_LEN   = 768      # max_length passed to the processor (with truncation) in collate_fn


class FinanceVQADataset(Dataset):
    """Map-style dataset over three index-aligned lists: images, questions, answers."""

    def __init__(self, images, questions, answers):
        self.images, self.questions, self.answers = images, questions, answers

    def __len__(self):
        # The image list defines the dataset length.
        return len(self.images)

    def __getitem__(self, idx):
        """Return one sample as a dict with an RGB image, its question and its answer."""
        picture = self.images[idx]
        # Anything that is not already a PIL image goes through fromarray.
        picture = picture if isinstance(picture, Image.Image) else Image.fromarray(picture)
        return dict(
            image=picture.convert("RGB"),
            question=self.questions[idx],
            answer=self.answers[idx],
        )


def make_messages(question, answer=None):
    """
    Assemble the chat-message list for one finance VQA sample.

    The user turn holds an image entry followed by the instruction text that
    embeds *question*. The image entry carries the literal string
    "placeholder"; callers are presumably expected to substitute a real image.
    When *answer* is not None it is appended as the assistant turn, which is
    the form needed for training.
    """
    instruction = (
        "You are a financial document expert. "
        "Answer the question based on the document image.\n\n"
        f"Question: {question}"
    )
    user_turn = {
        "role": "user",
        "content": [
            {"type": "image", "image": "placeholder"},
            {"type": "text", "text": instruction},
        ],
    }
    if answer is None:
        return [user_turn]
    return [user_turn, {"role": "assistant", "content": answer}]


def collate_fn(batch):
    """
    Turn a list of dataset items into one padded training batch on DEVICE.

    Each item is rendered through the chat template twice: once as the full
    user + assistant conversation, which the processor encodes together with
    the image so the image token count matches pixel_values, and once as the
    user turn plus generation prompt, which is used only to locate where the
    answer starts. Label positions belonging to the prompt and to padding are
    set to -100 so the loss covers answer tokens only.

    Args:
        batch: list of dicts with "image", "question" and "answer" keys.

    Returns:
        dict with "input_ids", "attention_mask", "labels", "pixel_values" and
        "image_grid_thw" tensors, all moved to DEVICE (pixel_values is also
        cast to DTYPE).
    """
    # Function-scope import, resolved once per call instead of once per sample.
    from qwen_vl_utils import process_vision_info

    tokenizer = processor.tokenizer
    image_token = getattr(processor, "image_token", "<|image_pad|>")
    image_token_id = tokenizer.convert_tokens_to_ids(image_token)

    all_input_ids      = []
    all_attention_mask = []
    all_labels         = []
    all_pixel_values   = []
    all_image_grid_thw = []

    for item in batch:
        # ── Build full conversation (user + assistant) ──────────
        # make_messages supplies the prompt wording; swap its placeholder
        # for the real image so qwen_vl_utils can pick it up.
        full_messages = make_messages(item["question"], item["answer"])
        full_messages[0]["content"][0]["image"] = item["image"]
        prompt_messages = full_messages[:1]  # user turn only

        # Apply chat template to full text and prompt only
        full_text = processor.apply_chat_template(
            full_messages, tokenize=False, add_generation_prompt=False)
        prompt_text = processor.apply_chat_template(
            prompt_messages, tokenize=False, add_generation_prompt=True)

        # Extract image inputs using qwen_vl_utils
        image_inputs, _ = process_vision_info(full_messages)

        # Process full sequence with image
        enc = processor(
            text=[full_text],
            images=image_inputs,
            return_tensors="pt",
            padding=False,
            truncation=True,
            max_length=MAX_SEQ_LEN,
        )

        # Text-only tokenization of the prompt. The image placeholder is NOT
        # expanded here, so this is shorter than the prompt inside `enc`.
        prompt_ids = tokenizer(
            prompt_text,
            return_tensors="pt",
            add_special_tokens=False,
        ).input_ids[0]

        input_ids = enc["input_ids"][0]       # shape: (seq_len,)
        labels    = input_ids.clone()

        # The processor replaces each image placeholder with many image
        # tokens. Add that expansion to the text-only prompt length;
        # otherwise the mask stops short and loss is computed on the tail of
        # the prompt.
        expanded = (int((input_ids == image_token_id).sum())
                    - int((prompt_ids == image_token_id).sum()))
        prompt_len = min(len(prompt_ids) + max(expanded, 0), len(labels))
        labels[:prompt_len] = -100            # don't compute loss on prompt

        all_input_ids.append(input_ids)
        all_attention_mask.append(enc["attention_mask"][0])
        all_labels.append(labels)
        all_pixel_values.append(enc["pixel_values"])
        all_image_grid_thw.append(enc["image_grid_thw"])

    # ── Pad to longest sequence in batch ───────────────────────
    max_len = max(x.shape[0] for x in all_input_ids)
    pad_id  = tokenizer.pad_token_id
    if pad_id is None:
        pad_id = 0

    batch_input_ids = torch.full((len(batch), max_len), pad_id, dtype=torch.long)
    batch_attn_mask = torch.zeros((len(batch), max_len),        dtype=torch.long)
    batch_labels    = torch.full((len(batch), max_len), -100,   dtype=torch.long)

    for i, (ids, mask, lbls) in enumerate(zip(all_input_ids, all_attention_mask, all_labels)):
        seq_len = ids.shape[0]
        batch_input_ids[i, :seq_len] = ids
        batch_attn_mask[i, :seq_len] = mask
        batch_labels[i, :seq_len]    = lbls

    # Per-sample image tensors are concatenated along dim 0.
    pixel_values   = torch.cat(all_pixel_values,   dim=0)
    image_grid_thw = torch.cat(all_image_grid_thw, dim=0)

    return {
        "input_ids":      batch_input_ids.to(DEVICE),
        "attention_mask": batch_attn_mask.to(DEVICE),
        "labels":         batch_labels.to(DEVICE),
        "pixel_values":   pixel_values.to(DEVICE, dtype=DTYPE),
        "image_grid_thw": image_grid_thw.to(DEVICE),
    }


# Split: shuffle all QA-pair indices once (seeded), then carve out train / val.
indices = list(range(len(flat_images)))
random.seed(SEED)
random.shuffle(indices)

n_val   = min(VAL_SAMPLES, int(len(indices) * 0.1))
n_train = min(TRAIN_SAMPLES, len(indices) - n_val)
tr_idx, va_idx = indices[:n_train], indices[n_train:n_train + n_val]


def _gather(chosen):
    """Return the (images, questions, answers) lists restricted to *chosen*."""
    return tuple([column[i] for i in chosen]
                 for column in (flat_images, flat_questions, flat_answers))


train_ds = FinanceVQADataset(*_gather(tr_idx))
val_ds   = FinanceVQADataset(*_gather(va_idx))

# Both loaders share everything except shuffling.
loader_args = dict(batch_size=BATCH_SIZE, collate_fn=collate_fn, num_workers=0)
train_loader = DataLoader(train_ds, shuffle=True, **loader_args)
val_loader   = DataLoader(val_ds, shuffle=False, **loader_args)

print(f"Train: {len(train_ds):,} | Val: {len(val_ds):,}")

# Quick sanity check — confirm one batch works before training
print("\nRunning sanity check on one batch ...")
test_batch = next(iter(train_loader))
with torch.no_grad():
    out = model(**test_batch)
sanity_loss = out.loss.item()
print(f"Sanity check passed! Loss = {sanity_loss:.4f}")
# Drop the model output together with the batch: while `out` is alive it
# keeps the forward-pass tensors referenced, so empty_cache() cannot hand
# that GPU memory back before training starts.
del test_batch, out
torch.cuda.empty_cache()


# ────────────────────────────────────────────────────────────
# SECTION 8 — Optimizer & Scheduler
# ────────────────────────────────────────────────────────────
from torch.optim import AdamW
from torch.optim.lr_scheduler import OneCycleLR
from torch.cuda.amp import GradScaler, autocast
import os, time

NUM_EPOCHS         = 3
LEARNING_RATE      = 2e-4
WEIGHT_DECAY       = 0.01
GRAD_CLIP          = 1.0
ACCUMULATION_STEPS = 8
LOG_EVERY          = 20
SAVE_DIR           = "/content/finance_vqa_ckpt"

# Optimize only the parameters that require gradients.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = AdamW(trainable_params, lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY)

# The scheduler advances once per optimizer update, i.e. once every
# ACCUMULATION_STEPS batches, so its horizon is counted in updates.
micro_batches = len(train_loader) * NUM_EPOCHS
total_steps   = max(1, micro_batches // ACCUMULATION_STEPS)
scheduler = OneCycleLR(
    optimizer,
    max_lr=LEARNING_RATE,
    total_steps=total_steps,
    pct_start=0.1,
)
scaler = GradScaler(enabled=True)

effective_batch = BATCH_SIZE * ACCUMULATION_STEPS
print(f"Steps: {total_steps} | Effective batch: {effective_batch}")


# ────────────────────────────────────────────────────────────
# SECTION 9 — Training Loop
# ────────────────────────────────────────────────────────────

def evaluate(model, loader, max_batches=15):
    model.eval()
    total, n = 0.0, 0
    with torch.no_grad():
        for i, batch in enumerate(loader):
            if i >= max_batches:
                break
            try:
                with autocast(dtype=torch.bfloat16):
                    out = model(**batch)
                if out.loss is not None:
                    total += out.loss.item(); n += 1
            except Exception:
                pass
    return total / n if n else float("inf")


def save_ckpt(model, processor, epoch, step, loss):
    """
    Save *model* and *processor* into a folder under SAVE_DIR and return its path.

    The folder name encodes the epoch, the step and the loss (three decimals).
    """
    folder_name = f"ep{epoch}_s{step}_l{loss:.3f}"
    ckpt_dir = os.path.join(SAVE_DIR, folder_name)
    os.makedirs(ckpt_dir, exist_ok=True)
    for artifact in (model, processor):
        artifact.save_pretrained(ckpt_dir)
    print(f"Saved -> {ckpt_dir}")
    return ckpt_dir


# Bookkeeping shared across epochs: logged running-average train losses, one
# validation loss per epoch, and the optimizer-step count at each log point.
train_losses, val_losses, step_log = [], [], []
global_step = 0
best_val    = float("inf")
best_ckpt   = None

print("=" * 55)
print("Finance VQA Fine-Tuning — Qwen2-VL-2B + LoRA")
print("=" * 55)

for epoch in range(NUM_EPOCHS):
    model.train()
    ep_loss, t0 = 0.0, time.time()
    # Start every epoch with clean gradients.
    # NOTE(review): this also discards gradients accumulated at the end of the
    # previous epoch when len(train_loader) is not a multiple of
    # ACCUMULATION_STEPS, because the optimizer only steps on full windows.
    optimizer.zero_grad()

    for bi, batch in enumerate(train_loader):
        try:
            with autocast(dtype=torch.bfloat16):
                # Divide so gradients summed over one accumulation window
                # correspond to the mean loss of that window.
                loss = model(**batch).loss / ACCUMULATION_STEPS
            scaler.scale(loss).backward()

            # Optimizer update once every ACCUMULATION_STEPS batches.
            if (bi + 1) % ACCUMULATION_STEPS == 0:
                # Unscale first so clipping sees the true gradient magnitudes.
                scaler.unscale_(optimizer)
                torch.nn.utils.clip_grad_norm_(
                    [p for p in model.parameters() if p.requires_grad], GRAD_CLIP)
                scaler.step(optimizer); scaler.update()
                scheduler.step(); optimizer.zero_grad()
                global_step += 1

            # Undo the accumulation division so the logged value is the raw loss.
            ep_loss += loss.item() * ACCUMULATION_STEPS

            if (bi + 1) % LOG_EVERY == 0:
                # Running mean over the batches seen so far in this epoch;
                # batches skipped by the OOM handler still count in (bi + 1).
                avg = ep_loss / (bi + 1)
                print(f"[Ep{epoch+1}/{NUM_EPOCHS}] "
                      f"B{bi+1}/{len(train_loader)} | "
                      f"Loss {avg:.4f} | "
                      f"GPU {torch.cuda.memory_allocated()/1e9:.1f}GB | "
                      f"{time.time()-t0:.0f}s")
                train_losses.append(avg); step_log.append(global_step)

        except RuntimeError as e:
            # Out-of-memory: drop this batch together with any partially
            # accumulated gradients and continue. Other runtime errors propagate.
            if "out of memory" in str(e).lower():
                print(f"OOM at batch {bi} — skipping")
                torch.cuda.empty_cache(); optimizer.zero_grad()
            else:
                raise

    # End of epoch: validation loss plus the mean training loss of the epoch.
    vl = evaluate(model, val_loader)
    at = ep_loss / max(1, len(train_loader))
    print(f"\nEpoch {epoch+1}: train={at:.4f} val={vl:.4f} "
          f"({(time.time()-t0)/60:.1f}min)\n")
    val_losses.append(vl)
    # Checkpoint only when the validation loss improves on the best so far.
    if vl < best_val:
        best_val  = vl
        best_ckpt = save_ckpt(model, processor, epoch+1, global_step, vl)
    # evaluate() switched the model to eval mode; switch back for training.
    model.train()

print(f"Done! Best val loss: {best_val:.4f}  ckpt: {best_ckpt}")


# ────────────────────────────────────────────────────────────
# SECTION 10 — Plot Curves
# ────────────────────────────────────────────────────────────
# Left panel: logged training loss against optimizer steps.
# Right panel: validation loss per epoch.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(13, 4))
if train_losses:
    ax1.plot(step_log, train_losses, "b-")
    ax1.set_title("Train Loss")
    ax1.grid(True, alpha=0.3)
if val_losses:
    epoch_axis = range(1, len(val_losses) + 1)
    ax2.plot(epoch_axis, val_losses, "r-o")
    ax2.set_title("Val Loss")
    ax2.grid(True, alpha=0.3)
plt.suptitle("Finance VQA Fine-Tuning")
plt.tight_layout()
plt.savefig("/content/curves.png", dpi=130)
plt.show()


# ────────────────────────────────────────────────────────────
# SECTION 11 — Inference
# ────────────────────────────────────────────────────────────
from qwen_vl_utils import process_vision_info

def run_inference(model, processor, image, question, max_new_tokens=128):
    """
    Generate an answer for one (image, question) pair with greedy decoding.

    Args:
        model: the vision-language model; it is switched to eval mode.
        processor: processor providing the chat template and tokenizer.
        image: a PIL image, or an array accepted by Image.fromarray.
        question: question text inserted into the prompt.
        max_new_tokens: generation budget.

    Returns:
        The decoded continuation (prompt tokens removed, special tokens
        skipped, whitespace stripped).
    """
    model.eval()
    if not isinstance(image, Image.Image):
        image = Image.fromarray(image)
    image = image.convert("RGB")

    # Build the prompt with make_messages so inference uses exactly the
    # instruction wording the model was fine-tuned on; a differently worded
    # prompt would be evaluated off-distribution.
    messages = make_messages(question)
    messages[0]["content"][0]["image"] = image

    prompt = processor.apply_chat_template(
        messages, tokenize=False, add_generation_prompt=True)
    image_inputs, _ = process_vision_info(messages)

    inputs = processor(text=[prompt], images=image_inputs,
                       return_tensors="pt", padding=True)
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
    if "pixel_values" in inputs:
        # Match the dtype the model weights were loaded in.
        inputs["pixel_values"] = inputs["pixel_values"].to(dtype=DTYPE)

    with torch.no_grad():
        out_ids = model.generate(**inputs, max_new_tokens=max_new_tokens,
                                 do_sample=False, repetition_penalty=1.1)

    # generate() returns prompt + continuation; keep only the continuation.
    prompt_len = inputs["input_ids"].shape[1]
    answer = processor.tokenizer.decode(
        out_ids[0][prompt_len:], skip_special_tokens=True
    ).strip()
    return answer


# Show predictions next to the ground truth for the first few validation items.
print("Inference on 5 validation samples:\n")
demo_count = min(5, len(val_ds))
for shown, item in enumerate((val_ds[k] for k in range(demo_count)), start=1):
    pred = run_inference(model, processor, item["image"], item["question"])
    print(f"--- {shown} ---")
    print(f"Q   : {item['question'][:180]}")
    print(f"GT  : {item['answer'][:180]}")
    print(f"PRED: {pred[:180]}\n")


# ────────────────────────────────────────────────────────────
# SECTION 12 — Metrics
# ────────────────────────────────────────────────────────────
import re, numpy as np
from collections import Counter

def norm(s):
    """Lower-case *s*, drop characters that are neither word nor whitespace, collapse whitespace."""
    text = str(s).lower().strip()
    text = re.sub(r"[^\w\s]", "", text)
    return re.sub(r"\s+", " ", text).strip()


def em(p, g):
    """Return 1 when prediction and reference are equal after norm(), else 0."""
    return 1 if norm(p) == norm(g) else 0


def f1(p, g):
    """Token-level F1 between prediction *p* and reference *g* after norm()."""
    pred_tokens, gold_tokens = norm(p).split(), norm(g).split()
    if not (pred_tokens and gold_tokens):
        return 0.0
    # Multiset intersection counts each shared token at most as often as it
    # occurs on both sides.
    overlap = Counter(pred_tokens) & Counter(gold_tokens)
    common = sum(overlap.values())
    if common == 0:
        return 0.0
    return 2 * common / (len(pred_tokens) + len(gold_tokens))

# Score up to 50 validation items with exact match and token-level F1.
N = min(50, len(val_ds))
ems, f1s = [], []
for done in range(1, N + 1):
    item = val_ds[done - 1]
    pred = run_inference(model, processor, item["image"], item["question"])
    gold = item["answer"]
    ems.append(em(pred, gold))
    f1s.append(f1(pred, gold))
    if done % 10 == 0:
        print(f"{done}/{N} evaluated")

exact_pct = np.mean(ems) * 100
token_f1_pct = np.mean(f1s) * 100
print(f"\nExact Match : {exact_pct:.1f}%")
print(f"Token F1    : {token_f1_pct:.1f}%")


# ────────────────────────────────────────────────────────────
# SECTION 13 — Save & Download
# ────────────────────────────────────────────────────────────
import subprocess
from google.colab import files

# Save the final adapter and processor, archive the folder, and hand the
# archive to the browser for download.
OUT = "/content/finance_vqa_final"
os.makedirs(OUT, exist_ok=True)
model.save_pretrained(OUT)
processor.save_pretrained(OUT)
# check=True: stop here if zip fails, instead of continuing to a download of
# a missing or incomplete archive.
subprocess.run(["zip", "-r", "-q", "/content/finance_vqa.zip", OUT], check=True)
files.download("/content/finance_vqa.zip")
print("Download started!")


# ────────────────────────────────────────────────────────────
# SECTION 14 — Reload Snippet
# ────────────────────────────────────────────────────────────
# Print (not execute) a copy-paste snippet that reloads the saved adapter on
# top of the base model and merges it for inference.
print("""
from transformers import Qwen2VLForConditionalGeneration, AutoProcessor
from peft import PeftModel
import torch

ADAPTER = "/content/finance_vqa_final"
BASE    = "Qwen/Qwen2-VL-2B-Instruct"

processor = AutoProcessor.from_pretrained(ADAPTER)
base  = Qwen2VLForConditionalGeneration.from_pretrained(
    BASE, torch_dtype=torch.bfloat16, device_map="auto")
model = PeftModel.from_pretrained(base, ADAPTER)
model = model.merge_and_unload()
model.eval()
""")

CUDA available : True
GPU            : NVIDIA A100-SXM4-80GB
GPU Memory     : 85.09 GB


KeyboardInterrupt: 