## Setup

In [None]:
!pip install torch==2.6.0 torchvision torchaudio


In [None]:
!pip install https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.3.12/flash_attn-2.8.0+cu124torch2.6-cp310-cp310-linux_x86_64.whl


In [None]:
!pip install -U trl>=0.9.6 transformers>=4.42 peft>=0.12.0 accelerate>=0.33.0 bitsandbytes>=0.43.3 datasets>=2.18 trl>=0.9.6 qwen-vl-utils pillow


In [None]:
!env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,max_split_size_mb:64

In [None]:
import os, torch
from datasets import load_dataset
import torch.nn.functional as F
from transformers import (AutoProcessor, AutoTokenizer, EarlyStoppingCallback, LlavaNextForConditionalGeneration, LlavaNextProcessor)
from qwen_vl_utils import process_vision_info
from peft import LoraConfig, get_peft_model
from trl import SFTConfig, SFTTrainer
from PIL import Image, ImageOps

## Preprocessing

In [None]:
# --------- Config ---------
# Base checkpoint, dataset repository and output folder used by every later cell.
MODEL_ID = "llava-hf/llama3-llava-next-8b-hf"  # LLaVA-NeXT-Llama3-8B (HF port)
DATASET_REPO = "AI-4-Everyone/Visual-TableQA"
OUTPUT_DIR = "llava-hf-sft-lora-tableqa"

# --------- Load data ---------
ds = load_dataset(DATASET_REPO)

# .get() yields None for a missing split instead of raising.
train, evald = (ds.get(split_name) for split_name in ("train", "validation"))

Generating train split: 100%|██████████| 3617/3617 [00:13<00:00, 258.38 examples/s]
Generating validation split: 100%|██████████| 779/779 [00:00<00:00, 1027.62 examples/s]
Generating test split: 100%|██████████| 769/769 [00:00<00:00, 1245.79 examples/s]


In [None]:
# Precision switch shared by model loading (torch_dtype) and SFTConfig (bf16 / fp16).
use_bf16 = torch.cuda.is_bf16_supported()

In [None]:
# Load the base model in half precision: bf16 when use_bf16 is True, fp16 otherwise.
model = LlavaNextForConditionalGeneration.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16 if use_bf16 else torch.float16,
    low_cpu_mem_usage=True,
)

processor = AutoProcessor.from_pretrained(MODEL_ID)
# The training log reports use_cache=True as incompatible with gradient checkpointing, so turn it off up front.
model.config.use_cache = False
model.gradient_checkpointing_enable()
# NOTE(review): pretraining_tp looks like a Llama-specific config field; confirm it has any effect on this config object.
model.config.pretraining_tp = 1
# collate_fn masks the first N label positions of every row as "prompt", which is only valid with right-side padding.
processor.tokenizer.padding_side = "right"

Fetching 5 files: 100%|██████████| 5/5 [00:59<00:00, 11.95s/it]
Loading checkpoint shards: 100%|██████████| 5/5 [00:05<00:00,  1.20s/it]
The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release.
You have video processor config saved in `preprocessor.json` file which is deprecated. Video processor configs should be saved in their own `video_preprocessor.json` file. You can rename the file or load and save the processor back which renames it automatically. Loading from `preprocessor.json` will be removed in v5.0.


In [None]:
# System prompt describing the assistant's role.
# NOTE(review): build_messages never adds a system turn, so this string is not
# referenced by the training or inference code in this notebook; confirm intent.
system_message = """You are a Vision Language Model specialized in interpreting visual data from charts and diagrams images.
Answer the questions strictly from the image, with clear, rigorous step-by-step justification. Stay concise, but include all reasoning that’s relevant."""


from typing import List, Dict

def build_messages(q: str, a: str):
    """Build chat-template message lists for one question/answer pair.

    Args:
        q: Question text, placed after an image placeholder in the user turn.
        a: Answer text used as the assistant turn.

    Returns:
        A tuple ``(prompt_only, with_answer)``: the first list holds just the
        user turn, the second holds the same user turn followed by the
        assistant turn.
    """
    user_turn = dict(
        role="user",
        content=[dict(type="image"), dict(type="text", text=q)],
    )
    assistant_turn = dict(role="assistant", content=[dict(type="text", text=a)])

    prompt_only = [user_turn]
    with_answer = [user_turn, assistant_turn]
    return prompt_only, with_answer

# Side length (pixels) of the square bounding box images are resized into; default for clamp_long_edge.
LONG_EDGE = 1024

def to_pil(img):
    """Return ``img`` as a PIL image.

    PIL images are passed through untouched; anything else is handed to
    ``Image.fromarray``.
    """
    if isinstance(img, Image.Image):
        return img
    return Image.fromarray(img)

def clamp_long_edge(img, longest=LONG_EDGE):
    """Resize ``img`` to fit a ``longest`` x ``longest`` box, keeping aspect ratio.

    Args:
        img: PIL image, or anything ``to_pil`` can convert.
        longest: Side length of the square bounding box, in pixels.

    Returns:
        The resized PIL image (bicubic resampling keeps text/lines readable).

    NOTE(review): ``ImageOps.contain`` appears to also enlarge images that are
    smaller than the box, so this is not a pure "clamp"; confirm that
    upscaling small images is intended.
    """
    bounding_box = (longest, longest)
    pil_img = to_pil(img)
    return ImageOps.contain(pil_img, bounding_box, Image.Resampling.BICUBIC)

In [None]:
def collate_fn(examples: List[Dict]):
    """Collate raw dataset rows into one padded training batch.

    Every example must provide the keys ``"image"``, ``"question"`` and
    ``"answer"``. Each image is resized with ``clamp_long_edge`` and the
    question/answer pair is rendered through the processor's chat template.

    Only the assistant answer is supervised: label positions covering the
    prompt (user turn plus generation prompt, including expanded image tokens)
    and all padding positions are set to -100.

    Assumes the tokenizer pads on the right (set where the processor is
    loaded); the prompt mask is applied to the first ``prompt_len`` positions
    of each row.

    Args:
        examples: List of dataset rows.

    Returns:
        Dict with ``input_ids``, ``attention_mask``, ``pixel_values``,
        ``labels`` and ``image_sizes``.
    """
    ignore_index = -100  # label value excluded from the loss

    images, prompts_user, prompts_full = [], [], []
    for ex in examples:
        user_msg, full_msg = build_messages(ex["question"], ex["answer"])
        images.append(clamp_long_edge(ex["image"]))  # PIL.Image from datasets
        prompts_user.append(
            processor.apply_chat_template(user_msg, add_generation_prompt=True, tokenize=False)
        )
        prompts_full.append(
            processor.apply_chat_template(full_msg, add_generation_prompt=False, tokenize=False)
        )

    # Batched tokenize + image processing (handles image token expansion & padding).
    # Keyword arguments keep the call independent of the processor's positional
    # parameter order and match the call in generate_text_from_sample.
    batch_user = processor(images=images, text=prompts_user, return_tensors="pt", padding=True)
    batch_full = processor(images=images, text=prompts_full, return_tensors="pt", padding=True)

    input_ids = batch_full["input_ids"]
    attention_mask = batch_full["attention_mask"]

    # Per-sample prompt length = number of real tokens in the prompt-only sequence
    # (including expanded image tokens).
    user_lengths = batch_user["attention_mask"].sum(dim=1)

    # Build labels: supervise only assistant tokens.
    labels = input_ids.clone()
    for row, prompt_len in enumerate(user_lengths.tolist()):
        labels[row, :int(prompt_len)] = ignore_index  # ignore user part
    # Padding must not contribute to the loss either (matters once batch size > 1).
    labels[attention_mask == 0] = ignore_index

    batch_dict = {
        "input_ids": input_ids,
        "attention_mask": attention_mask,
        "pixel_values": batch_full["pixel_values"],
        "labels": labels,
    }

    # image_sizes is required by LlavaNext.
    if "image_sizes" in batch_full:
        batch_dict["image_sizes"] = batch_full["image_sizes"]
    else:
        # Fallback: derive sizes from the resized images. The model expects
        # (height, width) per image, whereas PIL's .size is (width, height).
        batch_dict["image_sizes"] = torch.tensor(
            [[img.height, img.width] for img in images]
        )

    return batch_dict

In [None]:
# --------- LoRA config (r=16, alpha=8) ---------

# "all-linear" asks PEFT to pick the linear layers to adapt instead of listing module names.
TARGETS = "all-linear"
r = 16
lora_alpha = 8

peft_cfg = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=TARGETS,
    r=r,
    lora_alpha=lora_alpha,
    lora_dropout=0.05,
    bias="none",
)

In [None]:
# Apply PEFT model adaptation: wrap the base model with the LoRA adapters described by peft_cfg

peft_model = get_peft_model(model, peft_cfg)

# Print trainable parameters (trainable vs. total count, as shown in the output below)

peft_model.print_trainable_parameters()

trainable params: 51,824,640 || all params: 8,343,991,296 || trainable%: 0.6211


In [None]:
# --------- SFT training args  ---------
args = SFTConfig(
    output_dir=OUTPUT_DIR,

    # schedule: one pass over the data; 1 sample per step x 8 accumulation steps
    # gives an effective batch of 8 per device
    num_train_epochs=1,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    gradient_accumulation_steps=8,

    # stability & speed: mixed precision follows use_bf16 (bf16 if supported, else fp16)
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    bf16=use_bf16,
    fp16=not use_bf16,
    tf32=True,

    # optimization
    max_grad_norm=0.5,
    learning_rate=2e-5,               # peak LR; cosine decay after a 3% warmup
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    optim="adamw_torch_fused",          # fallback: "adamw_torch" if not supported
    adam_beta1=0.9, adam_beta2=0.999,
    adam_epsilon=1e-8,
    weight_decay=0.1,

    # logging/eval/save: evaluate and checkpoint on the same 100-step cadence so the
    # best checkpoint (lowest eval_loss) can be restored at the end
    logging_steps=10,
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    save_total_limit=2,

    # keep the raw image/question/answer columns: collate_fn reads them directly
    remove_unused_columns=False,
    # presumably leaves the dataset untouched so collate_fn does all preprocessing; verify against the TRL docs
    dataset_kwargs={"skip_prepare_dataset": True},
    dataset_text_field="",  # stops TRL from looking for a "text" column
)

## Training

In [None]:
# Stop if validation loss shows no improvement for 2 consecutive evaluations.
early_stopping = EarlyStoppingCallback(
    early_stopping_patience=2,
    early_stopping_threshold=0.0,
)

# peft_model is already LoRA-wrapped, so no peft_config is handed to the trainer.
trainer = SFTTrainer(
    data_collator=collate_fn,
    eval_dataset=evald,
    train_dataset=train,
    args=args,
    model=peft_model,
)
trainer.add_callback(early_stopping)

In [None]:
# Run fine-tuning; logging, evaluation, checkpointing and early stopping follow `args` and the registered callback.
trainer.train()

`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`...
`loss_type=None` was set in the config but it is unrecognised.Using the default loss: `ForCausalLMLoss`.


Step,Training Loss,Validation Loss
100,11.5614,1.34895
200,7.1859,0.862599
300,6.3366,0.793658
400,6.4442,0.775869
500,5.9035,0.768302
600,6.7524,0.763234
700,6.1018,0.760473
800,6.2966,0.75951
900,6.7847,0.759124


TrainOutput(global_step=906, training_loss=8.325606478760575, metrics={'train_runtime': 15187.6903, 'train_samples_per_second': 0.476, 'train_steps_per_second': 0.06, 'total_flos': 3.7215529310029824e+17, 'train_loss': 8.325606478760575})

In [None]:
# Persist the training result and the processor side by side in OUTPUT_DIR.
trainer.save_model(OUTPUT_DIR)      # saves the adapters
processor.save_pretrained(OUTPUT_DIR)

[]

## Inference

In [None]:
# Reload the processor from the base checkpoint (MODEL_ID), not from the copy saved in OUTPUT_DIR.
processor = AutoProcessor.from_pretrained(MODEL_ID)

In [None]:
# Reload a fresh copy of the base model for inference, then attach the saved LoRA adapter.
# NOTE(review): `model` is rebound here; if this runs in the same kernel as training,
# peft_model / trainer may still hold the previous model on the GPU. Confirm there is
# enough memory headroom, or free those objects first.
model = LlavaNextForConditionalGeneration.from_pretrained(
    MODEL_ID,
    device_map="auto",
    torch_dtype=torch.bfloat16 if use_bf16 else torch.float16,
    low_cpu_mem_usage=True,
)

# The adapter was written to OUTPUT_DIR by trainer.save_model above.
adapter_path = OUTPUT_DIR
model.load_adapter(adapter_path)

In [None]:
def generate_text_from_sample(model, processor, sample, max_new_tokens=5000, device="cuda"):
    """Greedy-decode the model's answer for a single dataset sample.

    Args:
        model: Model exposing ``generate``.
        processor: Processor supplying the chat template, tokenizer and image
            preprocessing.
        sample: Mapping with a ``"question"`` string and an ``"image"``.
        max_new_tokens: Upper bound on the number of generated tokens.
        device: Device the processed inputs are moved to before generation.

    Returns:
        The decoded continuation only (prompt tokens removed, special tokens
        skipped), stripped of surrounding whitespace.
    """
    # Only the prompt-only message list is needed; the answer slot stays empty.
    prompt_messages, _ = build_messages(q=sample["question"], a="")
    prompt = processor.apply_chat_template(
        prompt_messages, tokenize=False, add_generation_prompt=True
    )

    # Same image preprocessing as in training.
    resized_image = clamp_long_edge(sample["image"])
    inputs = processor(text=prompt, images=resized_image, return_tensors="pt").to(device)

    tokenizer = processor.tokenizer
    with torch.inference_mode():
        output_ids = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=tokenizer.eos_token_id,
        )

    # generate() returns prompt + continuation; drop the prompt part.
    prompt_len = inputs["input_ids"].shape[1]
    answer_ids = output_ids[0, prompt_len:]

    answer = processor.decode(
        answer_ids,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return answer.strip()

In [None]:
# Quick smoke test on one sample. train[0] belongs to the split used for fine-tuning,
# so this checks the pipeline end to end rather than generalization.
output = generate_text_from_sample(model, processor, train[0])
output

'The city with the highest parks density is Sydney, with 6.3 parks per square mile.'