In [None]:
import torch

# Display the installed PyTorch version string. The "+cuXXX" suffix shows which
# CUDA build is installed, which must match the unsloth install extra used at the
# bottom of this notebook ("cu124-torch250").
torch.__version__

In [None]:
from huggingface_hub import snapshot_download, hf_hub_download
from datasets import Dataset

REXVQA_REPO = "rajpurkarlab/ReXVQA"
REXGRAD_REPO = "rajpurkarlab/ReXGradient-160K"


meta_path = snapshot_download(repo_id=REXGRAD_REPO, repo_type="dataset")

!cat {meta_path}/deid_png.part* > deid_png.tar
!tar -xf /home/deid_png.tar
meta_path = snapshot_download(repo_id=REXVQA_REPO, repo_type="dataset")
!cp  {meta_path}/metadata/test_vqa_data.json  /home/QA_json/
!cp  {meta_path}/metadata/train_vqa_data.json  /home/QA_json/
!cp  {meta_path}/metadata/valid_vqa_data.json  /home/QA_json/


In [1]:
"""
Configuration file for MedGemma SFT training
"""

# Dataset paths (POSIX paths). All three JSON splits live in one directory,
# which the download cell populates via `cp ... /home/QA_json/`.
QA_JSON_DIR = "/home/QA_json"
TRAIN_JSON = f"{QA_JSON_DIR}/train_vqa_data.json"
VAL_JSON = f"{QA_JSON_DIR}/valid_vqa_data.json"
TEST_JSON = f"{QA_JSON_DIR}/test_vqa_data.json"

# Model configuration
MODEL_ID = "unsloth/medgemma-4b-it"  # Medical vision-language model
# NOTE(review): this flag is not read anywhere yet. The dataset loader always
# takes only the first image of each record, whatever the flag says.
USE_ONLY_FIRST_IMAGE = False  # Intended: True = use only the first image per sample

# Training hyperparameters.
# NOTE(review): only "output_dir" is read downstream (the final save path).
# The SFTConfig in the training cell hard-codes its own values (batch size 1,
# max_steps 30, cosine schedule, ...), so editing the other keys has no effect.
TRAINING_CONFIG = {
    "output_dir": "medgemma4b_it_sft_reasoning",
    "per_device_train_batch_size": 2,
    "gradient_accumulation_steps": 4,  # Effective batch size = 2 * 4 = 8 per device
    "learning_rate": 2e-4,
    "num_train_epochs": 3,
    "max_steps": -1,  # -1 = train for num_train_epochs instead of a fixed step budget
    "max_seq_length": 2048,  # Maximum sequence length in tokens
    "bf16": True,
    "remove_unused_columns": False,
    "logging_steps": 10,
    "save_steps": 500,
    "save_total_limit": 3,
    "report_to": "none",
    "dataloader_num_workers": 0,
    "gradient_checkpointing": True,
    "warmup_steps": 50,
    "weight_decay": 0.01,
    "lr_scheduler_type": "linear",
}

# LoRA configuration.
# NOTE(review): not read. The model-loading cell passes its own LoRA arguments
# (r=16, lora_alpha=16, lora_dropout=0, target_modules="all-linear").
LORA_CONFIG = {
    "r": 16,
    "lora_alpha": 32,
    "lora_dropout": 0.05,
    "target_modules": ["q_proj", "v_proj", "k_proj", "o_proj", "gate_proj", "up_proj", "down_proj"],
    "task_type": "CAUSAL_LM"
}

# Quantization configuration.
# NOTE(review): not read. The model is loaded with load_in_4bit=True passed
# directly to FastVisionModel.from_pretrained.
QUANTIZATION_CONFIG = {
    "load_in_4bit": True,
    "bnb_4bit_quant_type": "nf4",
    "bnb_4bit_use_double_quant": True,
    "bnb_4bit_compute_dtype": "bfloat16"
}

# Tags that structure the assistant's training target: the reasoning sits
# between REASONING_START/END, and "<letter>: <explanation>" sits between
# SOLUTION_START/END.
REASONING_START = "<start_working_out>"
REASONING_END = "<end_working_out>"
SOLUTION_START = "<SOLUTION>"
SOLUTION_END = "</SOLUTION>"

In [2]:
import torch
from unsloth import FastVisionModel
from transformers import TextStreamer
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported
import json
from datasets import Dataset, IterableDataset
from PIL import Image
import os

print("Torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())

# Load the base model together with its processor. The processor is used later
# both as a whole (by the vision data collator) and via its .tokenizer.
print(f"Loading model: {MODEL_ID}")
model, processor = FastVisionModel.from_pretrained(
    MODEL_ID,
    load_in_4bit=True,                     # 4-bit base weights to cut memory; False for 16-bit LoRA
    use_gradient_checkpointing="unsloth",  # True or "unsloth" for long context
)

# LoRA adapter settings, gathered in one place and unpacked into get_peft_model.
# NOTE(review): these literals are what actually takes effect. LORA_CONFIG in the
# config cell (lora_alpha=32, lora_dropout=0.05) is not consulted here.
peft_kwargs = {
    # Which parts of the network receive adapters.
    "finetune_vision_layers": True,
    "finetune_language_layers": True,
    "finetune_attention_modules": True,
    "finetune_mlp_modules": True,
    # Adapter capacity / regularisation. A larger r is more expressive but may overfit.
    "r": 16,
    "lora_alpha": 16,                     # alpha == r, i.e. a LoRA scaling factor of 1
    "lora_dropout": 0,
    "bias": "none",
    "random_state": 3407,
    "use_rslora": False,                  # rank-stabilised LoRA disabled
    "loftq_config": None,                 # no LoftQ initialisation
    "target_modules": "all-linear",
    # Modules trained in full (not through LoRA) and saved alongside the adapter.
    # NOTE(review): full copies of embed_tokens / lm_head are large; confirm this
    # is intended, because it dominates trainable parameters and adapter size.
    "modules_to_save": ["lm_head", "embed_tokens"],
}
model = FastVisionModel.get_peft_model(model, **peft_kwargs)
print("Model loaded successfully with LoRA adapters")

import os, json

def _first_path(raw):
    """Return the first path string from a str-or-list image field, else None.

    Empty lists, empty strings and non-string entries all give None. A
    malformed record is therefore skipped, instead of raising IndexError on
    ``raw[0]`` for ``[]`` or resolving to a directory via
    ``os.path.join(base, "")``.
    """
    if isinstance(raw, (list, tuple)):
        raw = raw[0] if raw else None
    return raw if isinstance(raw, str) and raw else None


def _resolve_image_path(raw_path, images_base_path, parent_dir):
    """Map a path string from the JSON to a local filesystem path.

    A leading '../' is re-rooted under ``parent_dir``. Only that leading
    segment is rewritten; ``str.replace`` would also rewrite any later '../'
    in the string. Any other path is joined onto ``images_base_path``, and
    ``os.path.join`` leaves an absolute path unchanged.
    """
    if raw_path.startswith('../'):
        return os.path.normpath(os.path.join(parent_dir, raw_path[len('../'):]))
    return os.path.join(images_base_path, raw_path)


def _format_question(item):
    """Build the user prompt: the question text plus any answer options.

    The training target names an answer letter, so the options that letter
    refers to must be visible to the model. ``options`` is assumed to be a list
    of strings such as "A. ..." (a letter->text dict is also accepted).
    NOTE(review): confirm against the ReXVQA JSON schema. Records without an
    ``options`` field produce just the question, exactly as before.
    """
    question = item.get('question', '')
    options = item.get('options')
    if not options:
        return question
    if isinstance(options, dict):
        option_lines = [f"{key}. {text}" for key, text in options.items()]
    elif isinstance(options, str):
        option_lines = [options]
    else:
        option_lines = [str(option) for option in options]
    return "\n".join([question, *option_lines])


def create_iterable_dataset_generator(json_path, images_base_path=".", parent_dir="/home"):
    """Create a generator function for ``IterableDataset.from_generator``.

    Args:
        json_path: VQA JSON file. Either a list of records or a dict whose
            values are records.
        images_base_path: directory against which relative image paths (those
            not starting with '../') are resolved.
        parent_dir: directory that a leading '../' in an image path maps to.

    Returns:
        A zero-argument generator function. Nothing is read until it is
        iterated. Each item it yields is ``{"messages": [user, assistant]}``:
        the user turn holds the prompt text and one RGB ``PIL.Image``, and the
        assistant turn holds the reasoning/solution-tagged answer. Records with
        no usable image path, a missing file, or an unreadable image are
        skipped.
    """
    def generator():
        print(f"Loading dataset from {json_path}")

        with open(json_path, 'r', encoding='utf-8') as f:
            data = json.load(f)

        # Handle both dict and list formats
        items = data.values() if isinstance(data, dict) else data

        processed = 0
        for item in items:
            # Only the first image of each record is used.
            raw_path = _first_path(item.get('image_path')) or _first_path(item.get('ImagePath'))
            if raw_path is None:
                continue
            image_path = _resolve_image_path(raw_path, images_base_path, parent_dir)
            if not os.path.exists(image_path):
                continue

            # Decode eagerly and close the file. A corrupt or unreadable file
            # is reported and skipped rather than aborting the training stream.
            try:
                with Image.open(image_path) as img:
                    image = img.convert('RGB')
            except OSError as exc:
                print(f"Skipping unreadable image {image_path}: {exc}")
                continue

            reasoning = item.get('heur_reason', item.get('reason', ''))
            answer_letter = item.get('answer', item.get('correct_answer', ''))
            explanation = item.get('explanation', item.get('correct_answer_explanation', ''))

            assistant_response = (
                f"{REASONING_START}\n{reasoning}\n{REASONING_END}\n\n"
                f"{SOLUTION_START}\n{answer_letter}: {explanation}\n{SOLUTION_END}"
            )

            processed += 1
            if processed % 1000 == 0:
                print(f"Processed {processed} samples...")

            # Dataset rows are column->value dicts. The conversation goes under
            # "messages", which is the row format Unsloth's vision examples feed
            # to UnslothVisionDataCollator.
            yield {
                "messages": [
                    {
                        "role": "user",
                        "content": [
                            {"type": "text", "text": _format_question(item)},
                            {"type": "image", "image": image},
                        ],
                    },
                    {
                        "role": "assistant",
                        "content": [
                            {"type": "text", "text": assistant_response},
                        ],
                    },
                ]
            }

    return generator


# Wrap the generator function in a streaming dataset. The JSON and the images
# are only read once the trainer starts iterating, which is why "Loading dataset
# from ..." is printed by the training cell rather than here.
print("Creating iterable training dataset...")
train_generator = create_iterable_dataset_generator(json_path=TRAIN_JSON)
train_dataset = IterableDataset.from_generator(generator=train_generator)




Unsloth: Patching Xformers to fix some performance issues.
🦥 Unsloth: Will patch your computer to enable 2x faster free finetuning.
🦥 Unsloth Zoo will now patch everything to make training faster!
Torch: 2.5.0+cu124
CUDA available: True
Loading model: unsloth/medgemma-4b-it
==((====))==  Unsloth 2025.8.5: Fast Gemma3 patching. Transformers: 4.55.2.
   \\   /|    NVIDIA H100 80GB HBM3. Num GPUs = 1. Max memory: 79.209 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.5.0+cu124. CUDA: 9.0. CUDA Toolkit: 12.4. Triton: 3.1.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.28.post2. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!


Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


Unsloth: Making `base_model.model.model.vision_tower.vision_model` require gradients
Model loaded successfully with LoRA adapters
Creating iterable training dataset...


In [None]:
from unsloth.trainer import UnslothVisionDataCollator
from trl import SFTTrainer, SFTConfig

FastVisionModel.for_training(model) # Enable for training!

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    processing_class=processor.tokenizer,
    data_collator=UnslothVisionDataCollator(model, processor),
    args = SFTConfig(
        per_device_train_batch_size = 1,
        gradient_accumulation_steps = 4,  # effective batch size = 1 * 4 = 4 per device
        gradient_checkpointing = True,

        # use_reentrant=False selects NON-reentrant activation checkpointing
        gradient_checkpointing_kwargs = {"use_reentrant": False},
        max_grad_norm = 0.3,              # max gradient norm based on QLoRA paper
        warmup_ratio = 0.03,
        # train_dataset is an IterableDataset with no length, so the Trainer needs
        # a positive max_steps to know when to stop; num_train_epochs alone cannot
        # be used here. Raise this budget for a full run (30 looks like a smoke test).
        max_steps = 30,
        learning_rate = 2e-4,
        logging_steps = 1,
        # NOTE(review): save_steps keeps the transformers default (500), which is
        # larger than max_steps, so no intermediate checkpoint is written.
        save_strategy="steps",
        optim = "adamw_torch_fused",
        weight_decay = 0.01,
        lr_scheduler_type = "cosine",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none",             # e.g. "wandb" to log to Weights & Biases

        # You MUST put the below items for vision finetuning:
        remove_unused_columns = False,
        dataset_text_field = "",
        dataset_kwargs = {"skip_prepare_dataset": True},
        max_length = 2048,
    )
)


print("Starting SFT training...")
trainer_stats = trainer.train()

# Save the adapter plus the full processor (tokenizer and image-processor
# config). No `tokenizer` variable exists in this notebook -- the model-loading
# cell binds `processor` -- so referencing `tokenizer` raised NameError here,
# after training had already finished.
model.save_pretrained(TRAINING_CONFIG["output_dir"])
processor.save_pretrained(TRAINING_CONFIG["output_dir"])

# Save to GGUF format for inference (optional)
# NOTE(review): confirm the installed Unsloth version supports GGUF export for
# this vision model; this step can fail even when training succeeded.
print("Saving model in GGUF format...")
model.save_pretrained_gguf(
    f"{TRAINING_CONFIG['output_dir']}_gguf",
    processor,
    quantization_method="q4_k_m"
)

print("Training completed!")

Loading dataset from /home/QA_json/train_vqa_data.json


In [None]:
!pip install "unsloth[cu124-torch250] @ git+https://github.com/unslothai/unsloth.git"
