# In [None]:
"""
SFT finetuning of the Qwen3-VL 4B-Instruct model on a custom multi-image vision–language dataset, using QLoRA 4-bit quantization and TRL’s SFTTrainer


Designed to run on a Google Colab free-tier T4 GPU, the script showcases a resource-efficient strategy for finetuning large vision–language models for video understanding. It covers the full workflow, including dataset loading, model and PEFT adapter configuration, training, and an inference pipeline that extracts frames from a video and uses the finetuned model to answer video-related questions.

Instead of processing full video sequences—which would normally require FlashAttention 2, which supports only newer GPUs such as A100/Hopper (not Turing GPUs like the T4)—the approach extracts a fixed number of frames per video, creating multi-image inputs that are compatible with T4 GPUs while still enabling meaningful temporal reasoning.

The project uses the custom dataset "SarveshBTelang/SFT_VLA_Dataset_1.0" hosted on the Hugging Face Hub. Each record contains multiple images along with instruction–completion text pairs for supervised finetuning. Images are sourced from BDD100K driving videos and paired with instructions from the vision-assistant-for-driving project (https://github.com/sungyeonparkk/vision-assistant-for-driving).
The codebase is modular and can be adapted to other VL models or image datasets.

"""



# === Imports ===
import os
import torch
from huggingface_hub import notebook_login
from datasets import load_dataset
from transformers import AutoProcessor
from transformers import logging as hf_logging
hf_logging.set_verbosity_info()

# For Qwen3-VL model class, depending on the repo you might need trust_remote_code=True when loading.
from transformers import AutoModelForCausalLM

# Optional: cv2 and PIL for frame extraction
import cv2
from PIL import Image
from tqdm import tqdm

# === 1) Login to Hugging Face (interactive in Colab) ===
# In Colab, call `notebook_login()` here to be prompted for a token. Locally, ensure `huggingface-cli login` has been run.


# === 2) Load dataset ===
# Replace with your dataset path on the Hub or local dataset.
# Hub identifier of the multi-image SFT dataset; only its "train" split is loaded.
DATASET_ID = "SarveshBTelang/SFT_VLA_Dataset_1.0"
print("Loading dataset:", DATASET_ID)
# `train_dataset` is later normalized with `fix_dataset_format` and handed to the trainer.
train_dataset = load_dataset(DATASET_ID, split="train")

# Utility to ensure prompt/completion fields are lists (compatible with trainer code below)
def fix_dataset_format(hf_dataset):
    """Wrap dict-valued ``prompt``/``completion`` fields in single-item lists.

    Despite the parameter name, this is called with one record at a time
    (it is applied through ``Dataset.map`` with ``batched=False``). The record
    is modified in place and returned; fields that are absent or are not
    dicts are left untouched.
    """
    for field in ("prompt", "completion"):
        value = hf_dataset.get(field)
        if isinstance(value, dict):
            hf_dataset[field] = [value]
    return hf_dataset

# Normalize every record so `prompt` and `completion` are lists (one record per call).
train_dataset = train_dataset.map(fix_dataset_format)
print("Sample record:")
# Print the first normalized record as a quick sanity check of the schema.
print(train_dataset[0])

# === 3) Model loading and optional QLoRA quantization ===
# NOTE: Qwen3-VL classes may be provided under custom repo names (e.g. Qwen/Qwen3-VL-...)
MODEL_NAME = "Qwen/Qwen3-VL-4B-Instruct"

# Qwen3-VL checkpoints are vision-language (image-text-to-text) models, so they
# are loaded through the image-text-to-text auto class rather than
# AutoModelForCausalLM; this instantiates the vision tower together with the
# language model.
from transformers import AutoModelForImageTextToText

# If you have bitsandbytes installed and want QLoRA, use BitsAndBytesConfig.
use_qloRA = True


def _load_fp16_model():
    """Load MODEL_NAME unquantized in float16 (used when 4-bit loading is off or fails)."""
    return AutoModelForImageTextToText.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.float16,
        device_map="auto",
        trust_remote_code=True,
    )


if use_qloRA:
    try:
        from transformers import BitsAndBytesConfig
        print("Attempting to load model with 4-bit quantization (QLoRA)...")
        # NF4 weights with double quantization; compute runs in float16.
        bnb_config = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_compute_dtype=torch.float16,
            bnb_4bit_use_double_quant=True,
            bnb_4bit_quant_type="nf4",
        )
        model = AutoModelForImageTextToText.from_pretrained(
            MODEL_NAME,
            device_map="auto",
            quantization_config=bnb_config,
            trust_remote_code=True,
        )
    except Exception as e:
        # Deliberate best-effort fallback: any failure of the 4-bit path
        # (e.g. bitsandbytes missing) is reported and the model is loaded in fp16.
        print("QLoRA/4-bit load failed - falling back to normal fp16 load. Error:\n", e)
        model = _load_fp16_model()
else:
    model = _load_fp16_model()

# LoRA config: adapters on the attention and MLP projection layers.
from peft import LoraConfig
peft_config = LoraConfig(
    r=32,
    lora_alpha=32,
    target_modules=["down_proj","o_proj","k_proj","q_proj","gate_proj","up_proj","v_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)

# === 4) Setup TRL SFT trainer ===
from trl import SFTConfig, SFTTrainer

# Output directory passed to the trainer; the same path is reused further down
# as the local adapter directory for inference.
OUTPUT_DIR = "SFT_VLA_Qwen3-VL-4B-Instruct-multimage-trl"
training_args = SFTConfig(
    max_steps=10,  # NOTE(review): looks like a short demo run - raise for real training
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,  # 2 samples x 8 accumulation steps = 16 samples per optimizer step per device
    warmup_steps=5,
    learning_rate=2e-4,
    optim="adamw_8bit",
    max_length=None,  # NOTE(review): presumably disables truncation so image tokens are not cut off - confirm against TRL docs
    output_dir=OUTPUT_DIR,
    logging_steps=1,
    report_to="none",  # set to trackio or wandb if available
    push_to_hub=False,   # set True if you want to push adapters
)

# The trainer receives the (optionally 4-bit) base model together with the LoRA
# config; presumably it wraps the model with the adapters itself - TODO confirm.
trainer = SFTTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    peft_config=peft_config,
)

# View GPU info (if GPU present)
if torch.cuda.is_available():
    try:
        print("GPU available:", torch.cuda.get_device_name(0))
    except Exception as err:
        # The device name is purely informational, so a failed lookup must not
        # stop the script - but it is reported instead of being swallowed.
        print("GPU detected, but its name could not be queried:", err)
else:
    print("No CUDA GPU detected. Training will be very slow or will fail for 4-bit loads.")

# === 5) Train (UNCOMMENT to run) ===
# trainer_stats = trainer.train()
# trainer.save_model(OUTPUT_DIR)
# trainer.push_to_hub(dataset_name=DATASET_ID)  # dataset_name is the training dataset recorded in the model card

# === 6) Save adapter locally (after training) and later load for inference ===
# For demonstration we'll assume the adapter is saved to OUTPUT_DIR locally.

# === 7) Inference pipeline ===
# This section demonstrates how to run the multi-image V+L assistant on frames extracted from a video.

# Load model + adapter for inference (example):
# Inference reuses the same base checkpoint that was configured for training.
BASE_MODEL = MODEL_NAME
ADAPTER_DIR = OUTPUT_DIR  # local path to adapter after training

# If the model object above already exists with adapter loaded, you can skip re-loading.
# Here we show a clean reload path.

def load_inference_model(base_model: str, adapter_path: str = None, device: str = None):
    """Load the base vision-language model (plus an optional PEFT adapter) for inference.

    Args:
        base_model: Hub ID or local path of the base checkpoint.
        adapter_path: Local directory holding a saved PEFT adapter. If it is
            ``None``, does not exist, or fails to load, the plain base model
            is used and a warning is printed.
        device: Device the model should live on. When omitted, ``"cuda"`` is
            chosen if available (``"cpu"`` otherwise) and weight placement is
            left to ``device_map="auto"``. An explicit value pins the whole
            model to that device so it matches where callers move their inputs.

    Returns:
        Tuple ``(model, processor, device)`` with the model in eval mode.
    """
    # Qwen3-VL is an image-text-to-text model, so it is loaded through the
    # matching auto class rather than AutoModelForCausalLM.
    from transformers import AutoModelForImageTextToText

    # Decide placement before defaulting `device`, so an explicit request is honored.
    device_map = "auto" if device is None else {"": device}
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"

    # Processor: contains image transforms + tokenizer
    processor = AutoProcessor.from_pretrained(base_model, trust_remote_code=True)

    # Half precision on any CUDA device ("cuda", "cuda:0", ...), full precision on CPU.
    dtype = torch.float16 if str(device).startswith("cuda") else torch.float32
    model = AutoModelForImageTextToText.from_pretrained(
        base_model,
        torch_dtype=dtype,
        device_map=device_map,
        trust_remote_code=True,
    )

    if adapter_path is not None:
        if os.path.exists(adapter_path):
            try:
                from peft import PeftModel
                model = PeftModel.from_pretrained(model, adapter_path, device_map={"": device})
                print("Loaded PEFT adapter from", adapter_path)
            except Exception as e:
                # Best effort: inference can still run on the base model.
                print("Warning: failed to load adapter. Continuing with base model. Error:\n", e)
        else:
            print(f"Warning: adapter path not found: {adapter_path}. Continuing with base model.")

    model.eval()
    return model, processor, device


# --- Frame extraction helper ---

def extract_frames(video_path, output_image_folder, max_frames=8):
    """Extract up to `max_frames` evenly spaced frames from a video as JPEG files.

    Args:
        video_path: Path of the video to read.
        output_image_folder: Directory for the extracted frames (created if missing).
        max_frames: Number of evenly spaced positions to sample. Videos with
            fewer frames than this yield each available position only once.

    Returns:
        List of saved image paths in temporal order. Frames that cannot be
        read are skipped with a warning, so the list may be shorter than
        `max_frames`.

    Raises:
        RuntimeError: If the video cannot be opened.
    """
    os.makedirs(output_image_folder, exist_ok=True)
    cap = cv2.VideoCapture(video_path)
    try:
        if not cap.isOpened():
            raise RuntimeError(f"Error opening video: {video_path}")

        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        # Without a frame count, read frames back-to-back instead of seeking.
        sequential = total_frames <= 0
        if sequential:
            print("Warning: video reports 0 total frames. Attempting sequential extraction up to max_frames.")
            total_frames = max_frames

        # Evenly spaced positions; the set drops duplicates that appear when the
        # video has fewer frames than max_frames.
        frame_indices = sorted({int(total_frames * i / max_frames) for i in range(max_frames)})

        saved = []
        for idx, frame_num in enumerate(tqdm(frame_indices, desc="Extracting frames")):
            if not sequential:
                cap.set(cv2.CAP_PROP_POS_FRAMES, frame_num)
            ret, frame = cap.read()
            if not ret:
                print(f"Warning: could not read frame {frame_num} (idx {idx}). Skipping.")
                continue
            # frame is BGR (cv2). Convert to RGB and save
            rgb = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            save_path = os.path.join(output_image_folder, f"frame_{idx+1:03d}.jpg")
            Image.fromarray(rgb).save(save_path)
            saved.append(save_path)
        return saved
    finally:
        # Release the capture handle even if opening, decoding or saving fails.
        cap.release()


# --- Preprocess frames for model ---

def load_images_as_pil(image_paths):
    """Open every path in `image_paths` and return the images converted to RGB, in order."""
    return [Image.open(path).convert("RGB") for path in image_paths]


# --- Compose prompt for multi-image V+L model ---
PROMPT_SYSTEM = (
    "You are an autonomous driving vision-language assistant. Answer concisely and focus on the driving scene. "
)


# Prompt template: system text, optional extra instructions, then the question.
def build_prompt(question: str, extra_instructions: str = None):
    """Assemble the text prompt that accompanies the images.

    The sections are joined by blank lines in this order: ``PROMPT_SYSTEM``,
    ``extra_instructions`` (only when non-empty), and a
    ``Question: ...`` / ``Answer:`` pair that ends the prompt.
    """
    sections = [PROMPT_SYSTEM]
    if extra_instructions:
        sections.append(extra_instructions)
    sections.append(f"Question: {question}\nAnswer:")
    return "\n\n".join(sections)


# --- Run a single inference ---

def _move_to_device(value, device):
    """Move a tensor, or every tensor nested inside a dict, to `device`."""
    if isinstance(value, torch.Tensor):
        return value.to(device)
    if isinstance(value, dict):
        return {key: _move_to_device(item, device) for key, item in value.items()}
    return value


def _with_image_placeholders(processor, num_images, prompt):
    """Wrap `prompt` in the processor's chat template with one image slot per image."""
    if getattr(processor, "chat_template", None) is None:
        # No chat template available: pass the prompt through unchanged.
        return prompt
    content = [{"type": "image"} for _ in range(num_images)]
    content.append({"type": "text", "text": prompt})
    messages = [{"role": "user", "content": content}]
    return processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)


def run_inference_on_images(model, processor, device, pil_images, prompt, max_length=512, temperature=0.0):
    """Run the multimodal model on a set of images and return the generated answer.

    Args:
        model: Model exposing ``generate``.
        processor: Callable processor accepting ``images``, ``text`` and
            ``return_tensors``. If it carries a chat template, the prompt is
            wrapped in a single user turn with one image placeholder per image,
            which chat-style vision-language processors need to align image
            features with the text.
        device: Device the input tensors are moved to.
        pil_images: Sequence of images; an empty sequence results in a text-only call.
        prompt: User text or question.
        max_length: Maximum number of newly generated tokens.
        temperature: ``0`` (default) decodes greedily; a positive value enables
            sampling at that temperature.

    Returns:
        The decoded text of the first sequence in the batch. For decoder-only
        models the prompt tokens are stripped, so only the generated
        continuation is returned. If no decoder is available, ``str(outputs)``
        is returned instead.
    """
    images = list(pil_images) if pil_images is not None else []
    text = _with_image_placeholders(processor, len(images), prompt)

    # Convert to tensors with processor, then move them next to the model.
    inputs = processor(images=images or None, text=text, return_tensors="pt")
    inputs = {key: _move_to_device(value, device) for key, value in inputs.items()}

    # Only forward `temperature` when sampling; with greedy decoding it is
    # ignored and merely triggers a warning.
    sample = bool(temperature) and temperature > 0
    gen_kwargs = {"max_new_tokens": max_length, "do_sample": sample}
    if sample:
        gen_kwargs["temperature"] = temperature

    with torch.no_grad():
        outputs = model.generate(**inputs, **gen_kwargs)

    # Decoder-only models echo the prompt at the start of the output; drop it.
    prompt_ids = inputs.get("input_ids")
    is_encoder_decoder = getattr(getattr(model, "config", None), "is_encoder_decoder", False)
    if isinstance(prompt_ids, torch.Tensor) and not is_encoder_decoder:
        outputs = outputs[:, prompt_ids.shape[-1]:]

    # Prefer the processor's tokenizer; fall back to the processor itself.
    decoder = getattr(processor, "tokenizer", None)
    if decoder is None:
        decoder = processor
    if not hasattr(decoder, "batch_decode"):
        # Nothing can decode the ids, so return their raw representation.
        return str(outputs)

    decoded = decoder.batch_decode(outputs, skip_special_tokens=True)
    # usually batch size is 1 -> return first element
    return decoded[0]


# --- High level wrapper that accepts a video, question and returns answer ---

def answer_question_from_video(model, processor, device, video_path, question, frames_folder="/tmp/frames", max_frames=6):
    """Answer `question` about a video by sampling frames and querying the model.

    Args:
        model, processor, device: As returned by `load_inference_model`.
        video_path: Path of the video to analyse.
        question: Text passed to the model as the prompt, unchanged.
        frames_folder: Directory where the extracted frames are written.
        max_frames: Number of frames to sample from the video.

    Returns:
        The string produced by `run_inference_on_images`.

    Raises:
        RuntimeError: If the video cannot be opened or no frame could be extracted.
    """
    print("Extracting frames...")
    frames = extract_frames(video_path, frames_folder, max_frames=max_frames)
    if not frames:
        # Fail clearly here rather than inside the processor with an empty image list.
        raise RuntimeError(f"No frames could be extracted from video: {video_path}")
    pil_images = load_images_as_pil(frames)
    print("Running model inference...")
    return run_inference_on_images(model, processor, device, pil_images, question)


# === Example usage ===
if __name__ == "__main__":
    # Demo: answer one question about a short local video.
    # Point `example_video` at a real file (e.g. uploaded to Colab or on a mounted Drive).
    example_video = "/path/to/video.mp4"
    if not os.path.exists(example_video):
        print("Example video not found. To try inference: set example_video to a real path and re-run.")
    else:
        question = "Is there a pedestrian crossing the road in these frames? If yes, where?"
        model_inf, proc, dev = load_inference_model(BASE_MODEL, ADAPTER_DIR)
        answer = answer_question_from_video(
            model_inf,
            proc,
            dev,
            example_video,
            question,
            frames_folder="/tmp/frames",
            max_frames=6,
        )
        print("Model answer:\n", answer)


# === Helpful tips / next steps ===
# - For large-scale evaluation create a dataset of (video, question, reference) and run batch inference.
# - You can adapt build_prompt to include bounding-box requests or structured output tokens if you want
#   the model to return JSON-like outputs for downstream parsing.
# - When pushing adapters to the Hub, ensure you remove any large binaries from the repo, and
#   provide a README explaining how to re-load the adapter with the base model.

# End of notebook script
