## Setup

In [None]:
!pip install torch==2.6.0 torchvision torchaudio


In [None]:
!pip install https://github.com/mjun0812/flash-attention-prebuild-wheels/releases/download/v0.3.12/flash_attn-2.8.0+cu124torch2.6-cp310-cp310-linux_x86_64.whl


In [None]:
!pip install -U trl>=0.9.6 transformers>=4.42 peft>=0.12.0 accelerate>=0.33.0 bitsandbytes>=0.43.3 datasets>=2.18 qwen-vl-utils pillow


In [None]:
!pip install -U scipy

In [None]:
!env PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True,max_split_size_mb:64

In [6]:
import os, torch
from datasets import load_dataset
import torch.nn.functional as F
from transformers import (AutoProcessor, AutoTokenizer, BitsAndBytesConfig, EarlyStoppingCallback, LlavaNextForConditionalGeneration, LlavaNextProcessor)
from qwen_vl_utils import process_vision_info
from peft import LoraConfig, get_peft_model, PeftModel
from trl import SFTConfig, SFTTrainer
from PIL import Image, ImageOps
import re

  from .autonotebook import tqdm as notebook_tqdm


## Preprocessing

In [2]:
# --------- Config ---------
MODEL_ID = "llava-hf/llama3-llava-next-8b-hf"    # checkpoint passed to from_pretrained()
DATASET_REPO = "AI-4-Everyone/Visual-TableQA"    # dataset passed to load_dataset()
OUTPUT_DIR = "llava-hf-sft-lora-tableqa"         # training output / adapter directory

# --------- Load data ---------
ds = load_dataset(DATASET_REPO)

# Look both splits up with .get(), so a split that is absent comes back as None
# instead of raising here.
train, evald = (ds.get(split_name) for split_name in ("train", "validation"))


In [3]:
# System prompt inserted as the first turn of every conversation built by build_messages().
# Two lines joined by a single newline; the text is sent to the model verbatim.
system_message = (
    "You are a Vision Language Model specialized in interpreting visual data from charts and diagrams images.\n"
    "Answer the questions strictly from the image, with clear, rigorous step-by-step justification. "
    "Stay concise, but include all reasoning that’s relevant."
)

In [4]:
def to_pil(img):
    """Return `img` as a PIL image.

    PIL images are passed through unchanged; anything else is handed to
    `Image.fromarray`.
    """
    if isinstance(img, Image.Image):
        return img
    return Image.fromarray(img)

In [1]:
def to_pil(img):
    """Pass PIL images through untouched; convert anything else with `Image.fromarray`."""
    # NOTE(review): this repeats the `to_pil` defined in the preceding cell word for word;
    # consider keeping only one of the two definitions.
    needs_conversion = not isinstance(img, Image.Image)
    return Image.fromarray(img) if needs_conversion else img

def build_messages(sample):
    """Build the chat conversations for one dataset row.

    Reads `sample["answer"]`, `sample["image"]` and `sample["question"]`.

    Returns:
        A pair `(prompt, full)`: `prompt` is `[system, user]` and `full` is
        `[system, user, assistant]`, where the assistant turn carries the answer.
    """
    answer_text = sample["answer"]

    def text_part(text):
        # One text entry of a chat message's content list.
        return {"type": "text", "text": text}

    system_turn = {"role": "system", "content": [text_part(system_message)]}
    image_part = {"type": "image", "image": to_pil(sample["image"])}
    user_turn = {"role": "user", "content": [image_part, text_part(sample["question"])]}
    assistant_turn = {"role": "assistant", "content": [text_part(answer_text)]}

    prompt_only = [system_turn, user_turn]
    return prompt_only, [*prompt_only, assistant_turn]

def collate_fn(examples):
    """Collate raw dataset rows into one padded training batch.

    Every row is rendered as its full system/user/assistant conversation with the
    processor's chat template, then text and images are encoded together.
    `labels` is a copy of `input_ids` in which padding positions and image
    placeholder tokens are replaced by -100.

    Relies on the module-level `processor`.
    """
    tokenizer = processor.tokenizer

    conversations = [build_messages(row)[1] for row in examples]
    prompts = [processor.apply_chat_template(conv, tokenize=False) for conv in conversations]
    pil_images, _ = process_vision_info(conversations)

    batch = processor(
        text=prompts,
        images=pil_images,
        return_tensors="pt",
        padding=True,
        truncation=False,
        max_length=None,
    )

    targets = batch["input_ids"].clone()
    # NOTE(review): if pad_token_id happens to equal eos_token_id this also masks the
    # end-of-sequence token — confirm the two ids differ for this tokenizer.
    targets[targets == tokenizer.pad_token_id] = -100

    # Mask image placeholder tokens; only names the tokenizer actually knows are kept.
    candidate_tokens = ("<image>", "<im_patch>", "<im_start>", "<im_end>")
    looked_up = (tokenizer.convert_tokens_to_ids(tok) for tok in candidate_tokens)
    placeholder_ids = [
        tid for tid in looked_up
        if tid != tokenizer.unk_token_id and tid is not None
    ]
    if placeholder_ids:
        is_placeholder = torch.isin(targets, torch.tensor(placeholder_ids, device=targets.device))
        targets[is_placeholder] = -100

    batch["labels"] = targets
    return batch


In [7]:
# Use bf16 only when a CUDA device is present and reports support for it. Availability is
# checked first so a CPU-only session yields False rather than depending on how the
# installed torch build handles the probe without a device.
use_bf16 = bool(torch.cuda.is_available() and torch.cuda.is_bf16_supported())

## Training

In [None]:
# Processor for the base checkpoint; pad on the right for training batches.
processor = AutoProcessor.from_pretrained(MODEL_ID)
processor.tokenizer.padding_side = "right"

# Base model in half precision: bf16 when use_bf16 is set, fp16 otherwise.
model = LlavaNextForConditionalGeneration.from_pretrained(
    MODEL_ID,
    low_cpu_mem_usage=True,
    device_map="auto",
    torch_dtype=torch.float16 if not use_bf16 else torch.bfloat16,
)

# Training-time switches: no generation cache, gradient checkpointing on.
model.config.pretraining_tp = 1
model.config.use_cache = False
model.gradient_checkpointing_enable()

In [15]:
# --------- LoRA config: rank 16, alpha 8 ---------
TARGETS = "all-linear"
r = 16
lora_alpha = 8

peft_cfg = LoraConfig(
    task_type="CAUSAL_LM",
    target_modules=TARGETS,
    r=r,
    lora_alpha=lora_alpha,
    lora_dropout=0.05,
    bias="none",
)

In [None]:
# Wrap the loaded base model with the LoRA adapters described by peft_cfg.
peft_model = get_peft_model(model=model, peft_config=peft_cfg)

# Report the trainable parameter count of the wrapped model.
peft_model.print_trainable_parameters()

In [17]:
# --------- SFT training args  ---------
# TF32 needs an Ampere-or-newer GPU (compute capability >= 8). The Trainer rejects
# tf32=True on older hardware, which would make the fp16 fallback selected through
# use_bf16 unreachable — so enable TF32 only where the device can actually provide it.
use_tf32 = torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8

args = SFTConfig(
    output_dir=OUTPUT_DIR,

    # schedule: one pass over the data, effective batch = 3 * 8 per device
    num_train_epochs=1,
    per_device_train_batch_size=3,
    per_device_eval_batch_size=4,
    gradient_accumulation_steps=8,

    # memory & precision
    gradient_checkpointing=True,
    gradient_checkpointing_kwargs={"use_reentrant": False},
    bf16=use_bf16,
    fp16=not use_bf16,
    tf32=use_tf32,

    # optimization
    max_grad_norm=0.5,
    learning_rate=2e-5,
    lr_scheduler_type="cosine",
    warmup_ratio=0.03,
    optim="adamw_torch_fused",
    adam_beta1=0.9, adam_beta2=0.999,
    adam_epsilon=1e-8,
    weight_decay=0.1,

    # logging / eval / save: evaluate and checkpoint on the same step interval so the
    # best checkpoint (lowest eval_loss) can be restored when training ends
    logging_steps=40,
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    load_best_model_at_end=True,
    metric_for_best_model="eval_loss",
    greater_is_better=False,
    save_total_limit=3,

    # the custom collate_fn consumes the raw rows, so keep every column and skip
    # TRL's own dataset preparation
    remove_unused_columns=False,
    dataset_kwargs={"skip_prepare_dataset": True},
    dataset_text_field="",  # stops TRL from looking for a "text" column
)

# Stop once eval_loss has failed to improve for two consecutive evaluations.
callbacks = [EarlyStoppingCallback(early_stopping_patience=2, early_stopping_threshold=0.0)]

In [18]:
# The model handed over is already LoRA-wrapped (peft_model), so no peft_config is
# passed; batches are assembled by the custom collate_fn.
trainer = SFTTrainer(
    model=peft_model,
    args=args,
    data_collator=collate_fn,
    train_dataset=train,
    eval_dataset=evald,
    callbacks=callbacks,
)

In [None]:
# Run fine-tuning with the trainer configured above; as the last expression of the
# cell, the value returned by train() is displayed as the cell output.
trainer.train()

In [None]:
# Write the training result and the processor into the same OUTPUT_DIR folder.
trainer.save_model(OUTPUT_DIR)      # saves the adapters
processor.save_pretrained(OUTPUT_DIR)

#### Resume training

In [None]:
# Continue an interrupted run instead of starting over.
# NOTE(review): presumably this picks up the latest checkpoint under OUTPUT_DIR and
# fails when none exists — confirm before running on a fresh output directory.
trainer.train(resume_from_checkpoint=True)   

In [None]:
# Same save step as after the initial run, repeated once the resumed training finishes.
trainer.save_model(OUTPUT_DIR)      # saves the adapters
processor.save_pretrained(OUTPUT_DIR)

## Final Model Loading and Inference

In [13]:
# Load the processor of the base checkpoint for inference.
# NOTE(review): the next cell assigns `processor` again with the same call, so one of
# the two loads looks redundant.
processor = AutoProcessor.from_pretrained(MODEL_ID)

The image processor of type `Qwen2VLImageProcessor` is now loaded as a fast processor by default, even if the model checkpoint was saved with a slow processor. This is a breaking change and may produce slightly different outputs. To continue using the slow processor, instantiate this class with `use_fast=False`. Note that this behavior will be extended to all models in a future release.


In [14]:
# Processor of the base checkpoint.
processor = AutoProcessor.from_pretrained(MODEL_ID)

# Base model in half precision: bf16 when use_bf16 is set, fp16 otherwise.
model = LlavaNextForConditionalGeneration.from_pretrained(
    MODEL_ID,
    low_cpu_mem_usage=True,
    device_map="auto",
    torch_dtype=torch.float16 if not use_bf16 else torch.bfloat16,
)

# Attach the adapters that training saved into OUTPUT_DIR.
adapter_path = OUTPUT_DIR
model.load_adapter(adapter_path)

`torch_dtype` is deprecated! Use `dtype` instead!
Fetching 5 files: 100%|██████████| 5/5 [00:59<00:00, 11.94s/it]
Loading checkpoint shards: 100%|██████████| 5/5 [00:03<00:00,  1.35it/s]


In [17]:
def generate_text_from_sample(model, sample, max_new_tokens=5000, device="cuda"):
    """Generate the model's answer for a single dataset sample.

    Uses the module-level `processor` for templating, encoding and decoding.

    Args:
        model: Model exposing `generate`.
        sample: Dataset row accepted by `build_messages`.
        max_new_tokens: Upper bound on the number of generated tokens.
        device: Device the encoded inputs are moved to.

    Returns:
        The decoded continuation (prompt tokens excluded), stripped of
        surrounding whitespace.
    """
    # Prompt-only conversation (system + user), ending with the generation prompt.
    prompt_messages = build_messages(sample)[0]
    prompt_text = processor.apply_chat_template(
        prompt_messages, tokenize=False, add_generation_prompt=True
    )
    images, _ = process_vision_info(prompt_messages)
    encoded = processor(text=[prompt_text], images=images, return_tensors="pt").to(device)

    # do_sample=False: deterministic decoding, no sampling.
    with torch.inference_mode():
        sequences = model.generate(
            **encoded,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id,
        )

    # Drop the prompt tokens so only the newly generated part is decoded.
    prompt_len = encoded["input_ids"].shape[1]
    continuation = sequences[0, prompt_len:]
    answer = processor.decode(
        continuation,
        skip_special_tokens=True,
        clean_up_tokenization_spaces=True,
    )
    return answer.strip()

In [None]:
# The signature is generate_text_from_sample(model, sample, ...), and the function reads
# the global `processor` itself. Passing `processor` positionally would bind it to
# `sample` and push the dataset row into `max_new_tokens`.
output = generate_text_from_sample(model, train[0])
output