In [None]:
# Install the notebook's dependencies. Versions are not pinned, so library
# behaviour can change between sessions; pin them for reproducible runs.
!pip install -q \
  datasets \
  sentencepiece \
  matplotlib \
  rouge_score \
  evaluate \
  nltk \
  transformers \
  bitsandbytes \
  accelerate \
  trl \
  peft \
  torch torchvision \
  pillow \
  tqdm \
  scikit-learn \
  bert-score \
  flash-attn

In [None]:
# Restore the checkpoint of a previous run so training can resume from it.
# /kaggle/input is mounted read-only, so the checkpoint is copied rather than moved.
# It also keeps its `checkpoint-<step>` folder name inside the output directory:
# the HF Trainer only resumes from sub-folders of `output_dir` named that way
# (moving it *onto* the output path would leave no such sub-folder).
import shutil
from pathlib import Path

RESUME_CHECKPOINT_SRC = Path("/kaggle/input/vqa-data-for-model-training/checkpoint-3500")
TRAIN_OUTPUT_DIR = Path("/kaggle/working/qlora_gemma2b_vqa")

if RESUME_CHECKPOINT_SRC.is_dir():
    shutil.copytree(
        RESUME_CHECKPOINT_SRC,
        TRAIN_OUTPUT_DIR / RESUME_CHECKPOINT_SRC.name,
        dirs_exist_ok=True,
    )
else:
    print(f"No checkpoint found at {RESUME_CHECKPOINT_SRC}; nothing to restore.")

In [None]:
import os
# Process-wide settings; set before `import torch` below so the CUDA runtime and
# allocator see them when they initialise.
os.environ.update(
    {
        "CUDA_VISIBLE_DEVICES": "0",  # Force single-GPU usage
        "WANDB_DISABLED": "true",
        "PYTORCH_CUDA_ALLOC_CONF": "backend:cudaMallocAsync,expandable_segments:True",
    }
)
import json
import time
import random
import torch
import numpy as np
from PIL import Image
from pathlib import Path
from datasets import Dataset
from tqdm.auto import tqdm
from dataclasses import dataclass
from typing import Dict, List, Optional, Union

# HuggingFace imports
from transformers import (
    AutoProcessor, 
    BitsAndBytesConfig,
    TrainingArguments,
    LlavaForConditionalGeneration,
    LlavaProcessor,
    AutoTokenizer
)
from peft import LoraConfig, get_peft_model, prepare_model_for_kbit_training, PeftModel
from trl import SFTTrainer

# For visualization
import matplotlib.pyplot as plt

# Seed every RNG source used by this notebook (Python, NumPy, torch CPU and all GPUs)
# with the same value so runs are reproducible.
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(42)

# Keep matmuls and cuDNN convolutions in full fp32 precision (the T4 has no TF32 support).
torch.backends.cuda.matmul.allow_tf32 = False
torch.backends.cudnn.allow_tf32 = False

In [None]:
# Configure QLoRA: 4-bit NF4 base weights with fp16 compute
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",             # NormalFloat4 quantization
    bnb_4bit_use_double_quant=True,        # also quantize the quantization constants (~0.4 bits/param saved)
    bnb_4bit_compute_dtype=torch.float16,  # matches the torch_dtype used to load the model below
    # `quantize_kv_cache` was dropped: it is not a BitsAndBytesConfig option, so it was
    # only reported as an unused kwarg (and its comment contradicted its value).
)

# Load LLaVA-Gemma-2B model and processor
checkpoint = "Intel/llava-gemma-2b"

# 1. Load the slow (SentencePiece) tokenizer explicitly
tokenizer = AutoTokenizer.from_pretrained(
    checkpoint,
    trust_remote_code=True,
    use_fast=False,
)

# 2. Build the processor around that tokenizer so both share the same instance
processor = LlavaProcessor.from_pretrained(
    checkpoint,
    tokenizer=tokenizer,
    trust_remote_code=True,
    use_fast=False,
)

# 3. Load the 4-bit quantized model.
# `use_flash_attention_2=True` is a deprecated kwarg. FlashAttention-2 kernels also require
# an Ampere-or-newer GPU, which the T4 targeted by this notebook is not; PyTorch SDPA runs on it.
model = LlavaForConditionalGeneration.from_pretrained(
    checkpoint,
    device_map="auto",
    trust_remote_code=True,
    quantization_config=bnb_config,
    torch_dtype=torch.float16,
    attn_implementation="sdpa",
)

# 4. Give the processor the image-token settings it needs, taken from the model config
processor.patch_size = model.config.vision_config.patch_size
processor.num_additional_image_tokens = 1  # presumably the vision tower's CLS token — confirm
processor.vision_feature_select_strategy = getattr(
    model.config, "vision_feature_select_strategy", "default"
)

# Verify the tokenizer is attached (a slow tokenizer class, since use_fast=False)
print(f"Tokenizer class: {type(processor.tokenizer).__name__}")

# Disable the KV cache: it only helps generation and conflicts with the
# gradient checkpointing used during training.
model.config.use_cache = False

# Print model size info
print(f"Model loaded with {model.num_parameters() / 1e9:.2f}B parameters")

In [None]:
# Projection layers that receive LoRA adapters (matched by module-name suffix).
# NOTE(review): name matching also hits any vision-tower modules sharing these
# names (e.g. q/k/v projections) — confirm this is intended.
LORA_TARGET_MODULES = [
    "q_proj", "k_proj", "v_proj", "o_proj",  # attention projections
    "gate_proj", "up_proj", "down_proj",     # MLP projections
]

lora_config = LoraConfig(
    task_type="CAUSAL_LM",
    r=16,             # rank of the low-rank update matrices
    lora_alpha=32,    # scaling numerator; effective scale is lora_alpha / r = 2
    lora_dropout=0.1,
    target_modules=LORA_TARGET_MODULES,
    bias="none",      # biases stay frozen
)

# Prepare the quantized model for k-bit training, then wrap it with the adapters
print("Applying LoRA adapters...")
model = get_peft_model(prepare_model_for_kbit_training(model), lora_config)
model.print_trainable_parameters()

In [None]:
from torch.utils.data import Dataset as TorchDataset

class VQADataset(TorchDataset):
    """Multiple-choice VQA dataset loaded from a JSON annotation file.

    Two JSON layouts are accepted:

    * list format: ``[{"image_path", "question", "options", "answer"}, ...]``.
      An ``image_path`` that does not exist as given is also looked up
      relative to ``img_dir``.
    * dict format: ``{image_id: {"questions": [{"question", "options", "answer"}, ...]}}``.
      The image is ``img_dir/image_id``, or that path with ``.jpg``, ``.jpeg``
      or ``.png`` appended (first existing one wins).

    Items with missing fields, or whose image file cannot be found, are skipped.
    Each sample is a dict with keys ``image_path``, ``question``, ``options``
    (list of str) and ``answer`` (lower-cased str).

    Args:
        json_path: Path to the annotation JSON file.
        img_dir: Folder containing the images.
        processor: Stored on the instance; not used by this class itself.
        max_samples: If truthy and smaller than the number of samples, keep a
            random subset of that size (drawn with the global ``random`` module).
    """

    _IMAGE_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, json_path, img_dir, processor, max_samples=None):
        self.processor = processor
        self.img_dir = img_dir

        # Load data
        print(f"Loading data from {json_path}...")
        with open(json_path, 'r') as f:
            data = json.load(f)

        self.samples = []

        if isinstance(data, list):
            print(f"Detected list format with {len(data)} examples")
            for item in data:
                if not self._validate_item(item):
                    continue
                image_path = self._resolve_image_path(item["image_path"])
                self.samples.append(self._process_item({**item, "image_path": image_path}))

        elif isinstance(data, dict):
            print(f"Detected dictionary format with {len(data)} images")
            for img_id, entry in data.items():
                if not isinstance(entry, dict) or "questions" not in entry:
                    continue
                image_path = self._find_image(img_id)
                if image_path is None:
                    continue
                for q in entry["questions"]:
                    if self._validate_question(q):
                        self.samples.append(self._process_item({**q, "image_path": image_path}))

        # Limit dataset size if requested
        if max_samples and max_samples < len(self.samples):
            self.samples = random.sample(self.samples, max_samples)

        print(f"Loaded {len(self.samples)} VQA samples")

        print("First few image paths:")
        for sample in self.samples[:3]:
            print(f"- {sample['image_path']}")

        # Show a few examples
        if self.samples:
            print("\nExample samples:")
            for sample in self.samples[:2]:
                print(f"Question: {sample['question']}")
                print(f"Options: {', '.join(sample['options'])}")
                print(f"Answer: {sample['answer']}")
                print()

    def _resolve_image_path(self, path):
        """Return ``path`` if it exists, else ``img_dir/path`` if that exists, else None."""
        if not isinstance(path, (str, os.PathLike)):
            return None
        if os.path.exists(path):
            return path
        img_dir = getattr(self, "img_dir", None)
        if img_dir:
            candidate = os.path.join(img_dir, path)
            if os.path.exists(candidate):
                return candidate
        return None

    def _find_image(self, img_id):
        """Return the first existing of ``img_dir/img_id`` and its known-extension variants, else None."""
        base = os.path.join(self.img_dir, img_id)
        for candidate in (base, *(base + ext for ext in self._IMAGE_EXTENSIONS)):
            if os.path.exists(candidate):
                return candidate
        return None

    def _validate_item(self, item):
        """Check that a list-format item has all required fields and a findable image."""
        required = ["image_path", "question", "options", "answer"]
        if not isinstance(item, dict) or not all(k in item for k in required):
            return False
        return self._resolve_image_path(item["image_path"]) is not None

    def _validate_question(self, question):
        """Check that a dict-format question entry has all required fields."""
        required = ["question", "options", "answer"]
        return isinstance(question, dict) and all(k in question for k in required)

    def _process_item(self, item):
        """Normalise an item: options to strings, answer to a lower-cased string."""
        return {
            "image_path": item["image_path"],
            "question": item["question"],
            "options": [str(opt) for opt in item["options"]],
            "answer": str(item["answer"]).lower()
        }

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        return self.samples[idx]

In [None]:
def _answer_turn_text(tokenizer, messages, prompt_text, answer):
    """Return the text the chat template appends after ``prompt_text`` for assistant reply ``answer``.

    Renders the conversation with ``answer`` as an assistant turn and strips the
    already-rendered prompt. The result therefore includes any end-of-turn marker
    the template emits, so the model also learns when to stop. Falls back to
    ``answer + eos_token`` when the template cannot render an assistant turn or
    does not render as ``prompt_text + reply``.
    """
    try:
        full_text = tokenizer.apply_chat_template(
            messages + [{"role": "assistant", "content": answer}], tokenize=False
        )
    except Exception:  # some templates reject an assistant role; use the fallback below
        full_text = None
    if isinstance(full_text, str) and full_text.startswith(prompt_text):
        return full_text[len(prompt_text):]
    return answer + (getattr(tokenizer, "eos_token", None) or "")


def collate_fn(batch, processor, device=None):
    """Turn a list of VQA samples into a model-ready training batch.

    Each sample's prompt (image placeholder, question, options and an
    instruction) is rendered with the tokenizer's chat template. The assistant
    answer turn is then appended, so the answer is actually part of the input
    sequence. Labels are -100 everywhere except on the answer-turn tokens, so
    the loss is computed only on the answer (and its end-of-turn marker).

    Args:
        batch: Samples with ``image_path``, ``question``, ``options`` and
            ``answer`` keys (as produced by ``VQADataset``).
        processor: Processor exposing ``.tokenizer`` and callable as
            ``processor(text=..., images=..., ...)``. Missing ``patch_size`` /
            ``num_additional_image_tokens`` attributes are filled in on it.
        device: Unused; kept for backward compatibility.

    Returns:
        The processor output with an added ``labels`` tensor of the same shape
        as ``input_ids``.

    Raises:
        ValueError: If no sample in the batch could be loaded.
    """
    if getattr(processor, 'patch_size', None) is None:
        processor.patch_size = 14
    if getattr(processor, 'num_additional_image_tokens', None) is None:
        processor.num_additional_image_tokens = 1

    tokenizer = processor.tokenizer
    images = []
    texts = []
    target_lens = []  # number of trailing tokens (answer turn) that carry loss

    for example in batch:
        try:
            image = Image.open(example["image_path"]).convert("RGB")

            options_str = ", ".join(example["options"])
            prompt = f"<image>\n{example['question']}\nOptions: {options_str}\nAnswer with just one option."
            messages = [{"role": "user", "content": prompt}]
            prompt_text = tokenizer.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=True
            )
            answer_text = _answer_turn_text(tokenizer, messages, prompt_text, example["answer"])
            n_target = len(tokenizer.encode(answer_text, add_special_tokens=False))
        except Exception as e:
            print(f"Error processing example: {e}")
            continue
        # Append only after everything succeeded so the three lists stay aligned.
        images.append(image)
        texts.append(prompt_text + answer_text)
        target_lens.append(n_target)

    if len(images) == 0:
        raise ValueError("No valid examples in batch")

    inputs = processor(
        text=texts,
        images=images,
        return_tensors="pt",
        padding=True,
        truncation=True
    )

    input_ids = inputs["input_ids"]
    labels = torch.full_like(input_ids, -100)
    for i, n_target in enumerate(target_lens):
        # Indices of real (non-padding) tokens: correct for both left and right padding.
        real_positions = inputs["attention_mask"][i].nonzero(as_tuple=True)[0]
        if n_target > 0:
            keep = real_positions[-n_target:]
            labels[i, keep] = input_ids[i, keep]

    inputs["labels"] = labels
    return inputs

In [None]:
import math
from transformers import TrainerCallback

# 1. Compute total update steps
def get_total_updates(dataset_len, batch_size, accum_steps, epochs):
    """Return the number of optimizer updates for a run.

    Each update consumes ``batch_size * accum_steps`` samples; a partial final
    group still counts as one update per epoch.
    """
    samples_per_update = batch_size * accum_steps
    return epochs * math.ceil(dataset_len / samples_per_update)

# 2. Define a callback to print progress
class StepProgressCallback(TrainerCallback):
    """Print the training plan at start-up and a step counter every ``logging_steps`` updates."""

    def on_train_begin(self, args, state, control, **kwargs):
        plan = (
            f"{state.max_steps} total update steps "
            f"({args.num_train_epochs} epochs, "
            f"batch_size={args.per_device_train_batch_size}, "
            f"grad_accum={args.gradient_accumulation_steps})"
        )
        print(f"⏳ Starting training: {plan}")
        return control

    def on_step_end(self, args, state, control, **kwargs):
        step = state.global_step
        if step % args.logging_steps == 0:
            print(f"Step {step}/{state.max_steps}")
        return control

In [None]:
from transformers import TrainerCallback
import os

class AdapterCheckpointCallback(TrainerCallback):
    """After each Trainer checkpoint save, also write the LoRA adapter to ``<checkpoint>/adapter_model``.

    If the checkpoint folder contains a full-model ``pytorch_model.bin``, it is
    deleted to save disk space.
    """

    def on_save(self, args, state, control, **kwargs):
        # Only the main process writes; with several processes every rank gets
        # `on_save` and they would race on the same files.
        if not state.is_world_process_zero:
            return control
        model = kwargs.get("model")
        if model is None:
            return control
        # Folder where the Trainer just saved its checkpoint
        ckpt_folder = os.path.join(args.output_dir, f"checkpoint-{state.global_step}")
        # Save only the low-rank adapter tensors & config
        model.save_pretrained(os.path.join(ckpt_folder, "adapter_model"), safe_serialization=True)
        full_model = os.path.join(ckpt_folder, "pytorch_model.bin")
        if os.path.exists(full_model):
            os.remove(full_model)
        return control

In [None]:
from trl import SFTConfig
# Default image folder used by `train_model` for both training and validation data.
img_dir = "/kaggle/input/vqa-data-for-model-training/tinygptv_images/kaggle/working/tinygptv_images"


def train_model(
    model,
    processor,
    train_json,
    val_json=None,
    output_dir="./vqa_model_output",
    epochs=3,
    batch_size=2,
    grad_accum_steps=4,
    learning_rate=3e-4,
    max_samples=None,
    image_dir=None,
    resume_from_checkpoint="auto",
):
    """Fine-tune the LoRA-wrapped LLaVA model on a VQA JSON file with TRL's SFTTrainer.

    Args:
        model: PEFT-wrapped model to train (trained in place).
        processor: LLaVA processor used by ``collate_fn`` to build batches.
        train_json: Path to the training annotations (see ``VQADataset``).
        val_json: Optional validation annotations. If the file exists, up to
            800 randomly chosen samples are evaluated every 500 steps.
        output_dir: Where checkpoints and the final adapter are written.
        epochs, batch_size, grad_accum_steps, learning_rate: Optimisation
            settings; one optimizer update consumes
            ``batch_size * grad_accum_steps`` samples.
        max_samples: Optional cap on the number of training samples.
        image_dir: Image folder for both datasets; defaults to the
            module-level ``img_dir``.
        resume_from_checkpoint: ``"auto"`` (default) resumes from the newest
            ``checkpoint-*`` folder in ``output_dir`` if one exists and starts
            from scratch otherwise. Any other value is passed to ``trainer.train``.

    Returns:
        ``(model, processor)`` after training.
    """
    from transformers.trainer_utils import get_last_checkpoint

    image_dir = img_dir if image_dir is None else image_dir

    # Create training dataset
    print("Creating training dataset...")
    train_dataset = VQADataset(train_json, image_dir, processor, max_samples)

    # Create validation dataset if provided
    val_dataset = None
    if val_json and os.path.exists(val_json):
        print("Creating validation dataset...")
        val_dataset = VQADataset(val_json, image_dir, processor, 800)

    gradient_checkpointing_kwargs = {"use_reentrant": False}
    training_args = SFTConfig(
        output_dir=output_dir,
        num_train_epochs=epochs,
        per_device_train_batch_size=batch_size,
        per_device_eval_batch_size=batch_size,  # set explicitly; the TrainingArguments default is 8
        gradient_checkpointing=True,
        gradient_checkpointing_kwargs=gradient_checkpointing_kwargs,
        gradient_accumulation_steps=grad_accum_steps,
        learning_rate=learning_rate,
        report_to="none",  # `None` means "all installed integrations" in transformers v4
        lr_scheduler_type="cosine",
        warmup_ratio=0.03,
        weight_decay=0.01,
        tf32=False,
        fp16=True,
        logging_steps=1,
        disable_tqdm=False,
        save_strategy="steps",
        save_steps=500,                # save every 500 optimization steps
        save_total_limit=1,            # keep only the most recent checkpoint
        eval_strategy="steps" if val_dataset else "no",
        eval_steps=500,
        run_name=f"vqa-training-{int(time.time())}",
        remove_unused_columns=False,   # the collator needs the raw sample dicts
        push_to_hub=False,
        dataloader_pin_memory=True,
        dataloader_num_workers=8,
        max_grad_norm=0.3,
        optim="paged_adamw_8bit",
        dataset_kwargs={"skip_prepare_dataset": True},  # batches are built by collate_fn
    )

    def collator(examples):
        return collate_fn(examples, processor, device=model.device)

    # TODO: pass compute_metrics=... once a metrics function exists.
    trainer = SFTTrainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=collator,
        processing_class=processor.tokenizer,
        callbacks=[AdapterCheckpointCallback, StepProgressCallback],
    )

    # Enable training optimizations; reuse the same checkpointing kwargs as the
    # Trainer config so this call does not silently change `use_reentrant`.
    model.train()
    model.gradient_checkpointing_enable(gradient_checkpointing_kwargs=gradient_checkpointing_kwargs)

    # Resolve "auto": resume only when a checkpoint actually exists. A bare `True`
    # makes the Trainer raise on a fresh output directory.
    if resume_from_checkpoint == "auto":
        resume_from_checkpoint = (
            get_last_checkpoint(output_dir) if os.path.isdir(output_dir) else None
        )

    print("Starting training...")
    if resume_from_checkpoint:
        print(f"Resuming from checkpoint: {resume_from_checkpoint}")
    trainer.train(resume_from_checkpoint=resume_from_checkpoint)

    # Save the model (adapter weights only, since `model` is a PEFT model)
    model.save_pretrained(output_dir, safe_serialization=True)
    print(f"LoRA adapter saved to {output_dir}")

    return model, processor

In [None]:
# Launch (or resume) QLoRA fine-tuning on the full training split.
# Effective batch size: batch_size * grad_accum_steps = 3 * 4 = 12 samples per optimizer update.
# `output_dir` matches the destination path used by the checkpoint-restore cell at the top
# of the notebook; train_model resumes from the latest checkpoint found in it.
model, processor = train_model(
    model=model,
    processor=processor,
    train_json="/kaggle/input/vqa-data-for-model-training/Segregated Data/Training Data/vqa_outputs_fixed.json",
    # Uncomment to also evaluate on the validation split every 500 steps:
#    val_json="/kaggle/input/vqa-data-for-model-training/Segregated Data/Validation Data/vqa_outputs_fixed.json",
    output_dir="/kaggle/working/qlora_gemma2b_vqa",
    epochs=1,
    batch_size=3,
    grad_accum_steps=4
)

# For a quick pipeline check, pass `max_samples=<N>` to train_model to train on a random subset.