In [None]:

!pip install transformers>=4.46.0
!pip install datasets
!pip install unsloth[colab-new]@git+https://github.com/unslothai/unsloth.git
!pip install xformers --no-cache-dir   --index-url https://download.pytorch.org/whl/cu124
!pip install flash-attn --no-build-isolation --no-cache-dir

In [None]:
import os
from getpass import getpass

from huggingface_hub import login

# Never paste the token into the notebook: it would be saved with the file and
# leak when the notebook is shared. Read it from the environment and fall back
# to a hidden prompt when the variable is not set.
HF_TOKEN = os.environ.get("HF_TOKEN") or getpass("Hugging Face token: ")
login(HF_TOKEN)

In [None]:
%cd /home/
from huggingface_hub import snapshot_download, hf_hub_download
from datasets import Dataset

REXVQA_REPO = "rajpurkarlab/ReXVQA"
REXGRAD_REPO = "rajpurkarlab/ReXGradient-160K"


meta_path = snapshot_download(repo_id=REXGRAD_REPO, repo_type="dataset")

!cat {meta_path}/deid_png.part* > /home/deid_png.tar
!tar -xf /home/deid_png.tar
meta_path = snapshot_download(repo_id=REXVQA_REPO, repo_type="dataset")
!mkdir /home/QA_json/
!cp  {meta_path}/metadata/test_vqa_data.json  /home/QA_json/
!cp  {meta_path}/metadata/train_vqa_data.json  /home/QA_json/
!cp  {meta_path}/metadata/valid_vqa_data.json  /home/QA_json/


In [5]:
import os, re, json, random
from typing import Dict, List, Iterable, Union, Optional
import numpy as np
from PIL import Image
import torch
from datasets import IterableDataset

def load_items_once(json_path: str) -> List[Dict]:
    """Read a JSON file a single time and return its records as a list.

    A top-level JSON object contributes its values (the keys are discarded);
    any other top-level value is passed through ``list()``.
    """
    with open(json_path, "r") as handle:
        payload = json.load(handle)
    if isinstance(payload, dict):
        return list(payload.values())
    return list(payload)
# =======================
# Config / Constants
# =======================
# Labels used when rendering answer options (first option -> "A", second -> "B", ...).
_OPTION_LETTERS = ["A", "B", "C", "D", "E", "F"]

# Delimiters wrapped around the two parts of the assistant turn:
# the reasoning text and the final answer.
REASONING_START = "<start_working_out>"
REASONING_END   = "<end_working_out>"
SOLUTION_START  = "<SOLUTION>"
SOLUTION_END    = "</SOLUTION>"

# Per-split probability of adding optional blocks to the prompt.
# PROMPT_META_PROB governs the demographics/context block and
# PROMPT_BODYVIEW_PROB the body-part/view lines; 1.0 means "always include".
PROMPT_META_PROB = {
    "train": 0.05,   # rarely
    "val":   0.05,
    "test":  1.0,    # always on the test split
}
PROMPT_BODYVIEW_PROB = {
    "train": 0.30,   # sometimes
    "val":   0.30,
    "test":  1.0,    # always on the test split
}

# Model identifier passed to FastVisionModel.from_pretrained.
MODEL_ID = "unsloth/medgemma-4b-it"




# Question/answer JSON file for each split.
TRAIN_JSON, VAL_JSON, TEST_JSON = (
    "/home/QA_json/train_vqa_data.json",
    "/home/QA_json/valid_vqa_data.json",
    "/home/QA_json/test_vqa_data.json",
)

# Parse every file exactly once, in train / val / test order.
train_items, val_items, test_items = map(
    load_items_once, (TRAIN_JSON, VAL_JSON, TEST_JSON)
)


In [13]:
import os, re, json, random
from typing import Dict, List, Iterable, Union, Optional
import numpy as np
from PIL import Image
from datasets import IterableDataset


# =======================
# Small helpers
# =======================
def _norm(s: Optional[str]) -> str:
    return re.sub(r"\s+", " ", (s or "")).strip()

def _format_options(options: List[str]) -> str:
    """Render answer options one per line, labelled ``A)``, ``B)``, ...

    Each option is whitespace-normalized first. An option that already starts
    with a letter A-F (any case) followed by one of ``) . - :`` and a whitespace
    character is emitted unchanged instead of being labelled again.
    """
    already_labeled = re.compile(r"^[A-F][\)\.\-:]\s", flags=re.IGNORECASE)
    rendered = []
    for position, raw_option in enumerate(options):
        if position < len(_OPTION_LETTERS):
            label = _OPTION_LETTERS[position]
        else:
            # Past the predefined letters, keep counting through the alphabet.
            label = chr(ord("A") + position)
        text = _norm(raw_option)
        rendered.append(text if already_labeled.match(text) else f"{label}) {text}")
    return "\n".join(rendered)

def _safe_open_as_rgb(img_path: str) -> Optional[Image.Image]:
    """Open an image file and return it as an 8-bit RGB ``PIL.Image``.

    Args:
        img_path: Path of the image file to read.

    Returns:
        The RGB image, or ``None`` when the file cannot be opened, decoded or
        converted. Failures are swallowed on purpose so callers can simply skip
        unreadable files.
    """
    try:
        # The context manager releases the underlying file even when decoding
        # fails part-way through.
        with Image.open(img_path) as source:
            source.load()  # force the full decode now so errors surface here
            pixels = np.array(source)

            if pixels.dtype == np.uint16:
                # 16-bit samples: keep the high byte of every value.
                high_bytes = (pixels >> 8).astype(np.uint8)
                return Image.fromarray(high_bytes).convert("RGB")

            if source.mode == "I":
                # 32-bit integer samples: min-max normalize to 0..255. The
                # arithmetic is done in float64 so that (value - min) cannot
                # wrap around int32 when the value range is very wide.
                values = pixels.astype(np.float64)
                low, high = float(values.min()), float(values.max())
                if high > low:
                    scale = 255.0 / (high - low)
                    scaled = ((values - low) * scale).astype(np.uint8)
                else:
                    # Constant image: nothing to stretch, emit all zeros.
                    scaled = np.zeros(pixels.shape, dtype=np.uint8)
                return Image.fromarray(scaled, mode="L").convert("RGB")

            # convert() returns a new image, so it stays usable after the
            # source file is closed.
            return source.convert("RGB")
    except Exception:
        return None

# =======================
# Reasoning & Prompt builders
# =======================
def _build_reasoning(record: Dict) -> str:
    """
    Compose the reasoning text from the record's report sections.

    Produces a ``Findings: ...`` line followed by an ``Impression: ...`` line,
    each only when the corresponding field is non-empty after whitespace
    normalization. No other field of the record (answer, explanation) is read.
    """
    findings_text = _norm(record.get("Findings"))
    impression_text = _norm(record.get("Impression"))

    candidates = [
        f"Findings: {findings_text}" if findings_text else "",
        f"Impression: {impression_text}" if impression_text else "",
    ]
    return "\n".join(line for line in candidates if line).strip()

def _maybe_meta_block(record: Dict, prob: float) -> str:
    """Demographics/context block with probability `prob` (always for test)."""
    if random.random() >= prob:
        return ""
    patient_sex  = record.get("PatientSex")
    patient_age  = record.get("PatientAge")
    ethnic_group = record.get("EthnicGroup")
    weight       = record.get("PatientWeight")
    size         = record.get("PatientSize")
    indication   = record.get("Indication")
    comparison   = record.get("Comparison")

    lines = []
    if patient_sex: lines.append(f"PatientSex: {patient_sex}")
    if patient_age: lines.append(f"PatientAge: {patient_age}")
    if ethnic_group: lines.append(f"EthnicGroup: {ethnic_group}")
    try:
        if weight is not None and weight == weight:
            lines.append(f"PatientWeight: {weight}")
    except Exception:
        pass
    try:
        if size is not None and size == size:
            lines.append(f"PatientSize: {size}")
    except Exception:
        pass
    if indication: lines.append(f"Indication: {indication}")
    if comparison: lines.append(f"Comparison: {comparison}")

    return ("CONTEXT:\n" + "\n".join(lines)) if lines else ""

def _maybe_body_view_block(body_part_sel, view_sel, prob: float) -> str:
    """Randomly include ImageBodyPart & ImageViewPosition (always on test)."""
    if random.random() >= prob:
        return ""
    lines = []
    if body_part_sel:
        if isinstance(body_part_sel, list):
            lines.append(f"ImageBodyPart: {', '.join(body_part_sel)}")
        else:
            lines.append(f"ImageBodyPart: {body_part_sel}")
    if view_sel:
        if isinstance(view_sel, list):
            lines.append(f"ImageViewPosition: {', '.join(view_sel)}")
        else:
            lines.append(f"ImageViewPosition: {view_sel}")
    return "\n".join(lines) if lines else ""

def _test_only_block(record: Dict, split: str) -> str:
    if split != "test":
        return ""
    cat_ = record.get("category")
    cls_ = record.get("class")
    lines = []
    if cat_:
        lines.append(f"Category: {cat_}")
    if cls_:
        lines.append(f"Class: {cls_}")
    return "\n".join(lines)

def _build_instruction(
    record: Dict,
    split: str,
    body_part_sel,
    view_sel,
) -> str:
    """Assemble the user prompt text for one record.

    Sections appear in this order, separated by blank lines, each only when
    non-empty: the test-only Category/Class lines, the question, an
    ``OPTIONS:`` list, the CONTEXT metadata block, and the body-part/view
    lines. The two optional blocks are sampled with the per-split
    probabilities in ``PROMPT_META_PROB`` and ``PROMPT_BODYVIEW_PROB``; a split
    missing from those tables gets probability 0.0.
    """
    question_text = _norm(record.get("question"))
    raw_options = record.get("options") or []
    formatted_options = _format_options(raw_options) if raw_options else ""
    options_section = f"OPTIONS:\n{formatted_options}" if formatted_options else ""

    # Sample the optional blocks in a fixed order (metadata first, then
    # body/view) so the sequence of random draws is stable.
    context_section = _maybe_meta_block(record, PROMPT_META_PROB.get(split, 0.0))
    body_view_section = _maybe_body_view_block(
        body_part_sel, view_sel, PROMPT_BODYVIEW_PROB.get(split, 0.0)
    )
    header_section = _test_only_block(record, split)

    ordered_sections = (
        header_section,
        question_text,
        options_section,
        context_section,
        body_view_section,
    )
    return "\n\n".join(section for section in ordered_sections if section).strip()

# =======================
# Image + metadata selection (synced)
# =======================
def _select_image_and_meta(item: Dict) -> (Optional[str], Optional[Union[str, List[str]]], Optional[Union[str, List[str]]]):
    """
    Choose one image path. If ImagePath, ImageBodyPart, ImageViewPosition are lists,
    pick a single random index and return matched body/view for the same index.
    Maps '../' -> '/home/'.
    """
    raw_image_paths = item.get("ImagePath")
    body_parts_list = item.get("ImageBodyPart")
    views_list      = item.get("ImageViewPosition")

    body_part_sel = None
    view_sel      = None

    if isinstance(raw_image_paths, list) and raw_image_paths:
        idx = random.randrange(len(raw_image_paths))
        raw_image_path = raw_image_paths[idx]
        if isinstance(body_parts_list, list) and len(body_parts_list) == len(raw_image_paths):
            body_part_sel = body_parts_list[idx]
        else:
            body_part_sel = body_parts_list
        if isinstance(views_list, list) and len(views_list) == len(raw_image_paths):
            view_sel = views_list[idx]
        else:
            view_sel = views_list
    else:
        raw_image_path = raw_image_paths
        body_part_sel  = body_parts_list
        view_sel       = views_list

    if not isinstance(raw_image_path, str) or not raw_image_path:
        return None, body_part_sel, view_sel

    img_path = raw_image_path.replace("../", "/home/")
    return img_path, body_part_sel, view_sel

# =======================
# Dataset factory
# =======================
def create_iterable_dataset_from_items(
    items: List[Dict],
    split: str,
):
    """
    Wrap preloaded records in a lazily evaluated ``IterableDataset``.

    For every record one image is selected and opened; records whose image path
    is missing, does not exist on disk, or cannot be decoded are skipped.
    Each remaining record becomes a two-turn chat:
      - user: the instruction text plus the image
      - assistant: the reasoning wrapped in REASONING_START / REASONING_END,
        followed by the answer wrapped in SOLUTION_START / SOLUTION_END.
        The answer is the record's 'answer' (or 'correct_answer'), extended with
        " - Explanation: ..." when 'explanation' (or
        'correct_answer_explanation') is non-empty.

    What is yielded depends on ``split``:
      - "test": a dict with messages, split, category, class, correct_answer
        and image_path.
      - any other split: only the messages list.

    A progress line is printed for every 1000th sample produced.
    """
    def generator() -> Iterable[Dict]:
        emitted = 0
        for record in items:
            # One image per record, with body part / view taken from the same index.
            image_path, body_part, view = _select_image_and_meta(record)
            if not (image_path and os.path.exists(image_path)):
                continue

            picture = _safe_open_as_rgb(image_path)
            if picture is None:
                continue

            prompt_text = _build_instruction(record, split, body_part, view)
            reasoning_text = _build_reasoning(record)

            answer_text = _norm(record.get("answer") or record.get("correct_answer"))
            explanation_text = _norm(
                record.get("explanation") or record.get("correct_answer_explanation")
            )
            # The explanation is appended to the answer only when there is one.
            solution_text = (
                f"{answer_text} - Explanation: {explanation_text}"
                if explanation_text
                else answer_text
            )

            reply_text = (
                f"{REASONING_START}\n{reasoning_text}\n{REASONING_END}\n\n"
                f"{SOLUTION_START}\n{solution_text}\n{SOLUTION_END}"
            )

            emitted += 1
            if emitted % 1000 == 0:
                print(f"[{split}] Processed {emitted} samples...")

            user_turn = {
                "role": "user",
                "content": [
                    {"type": "text", "text": prompt_text},
                    {"type": "image", "image": picture},
                ],
            }
            assistant_turn = {
                "role": "assistant",
                "content": [{"type": "text", "text": reply_text}],
            }
            sample = {
                "messages": [user_turn, assistant_turn],
                "split": split,
                "category": _norm(record.get("category")),
                "class": _norm(record.get("class")),
                "correct_answer": solution_text,
                "image_path": image_path,  # kept for debugging / audits
            }

            # Only the test split carries the extra bookkeeping fields.
            yield sample if split == "test" else sample["messages"]

    return IterableDataset.from_generator(generator)

# =======================
# Build the split datasets
# =======================
# train_items / val_items / test_items are the record lists loaded from the
# split JSON files; the datasets are created in train, val, test order.
train_dataset, val_dataset, test_dataset = (
    create_iterable_dataset_from_items(records, split=split_name)
    for records, split_name in (
        (train_items, "train"),
        (val_items, "val"),
        (test_items, "test"),
    )
)


In [14]:
# Pull the first validation sample to inspect the prompt/response format.
# Non-test splits yield only the messages list, so this prints a list of chat turns.
sample = next(iter(val_dataset))

print(sample)

[{'role': 'user', 'content': [{'type': 'text', 'text': 'Which of the following findings is absent in the lung fields on this chest X-ray?\n\nOPTIONS:\nA. Focal airspace disease\nB. Mild peribronchial thickening\nC. Cardiomegaly\nD. Pleural effusion'}, {'type': 'image', 'image': <PIL.Image.Image image mode=RGB size=748x747 at 0x7F723A8908E0>}]}, {'role': 'assistant', 'content': [{'type': 'text', 'text': '<start_working_out>\nFindings: The cardiomediastinal silhouette is unremarkable. Mild peribronchial thickening is noted. There is no evidence of focal airspace disease, pulmonary edema, suspicious pulmonary nodule/mass, pleural effusion, or pneumothorax. No acute bony abnormalities are identified.\nImpression: Mild peribronchial thickening-likely chronic. No other significant abnormalities.\n<end_working_out>\n\n<SOLUTION>\nA - Explanation: The chest X-ray does not show any focal airspace disease, making A the correct answer. Mild peribronchial thickening is present, and there is no evi

In [16]:
from unsloth.trainer import UnslothVisionDataCollator
from trl import SFTTrainer, SFTConfig
import torch 
from unsloth import FastVisionModel 
from transformers import TextStreamer 
from trl import SFTTrainer 
from transformers import TrainingArguments 
from unsloth import is_bfloat16_supported 

# Load the base vision-language model and its processor.
model, processor = FastVisionModel.from_pretrained(
    MODEL_ID,
    load_in_4bit = True, # 4-bit weights to reduce memory use; set False for 16-bit LoRA.
    use_gradient_checkpointing = "unsloth", # True or "unsloth" for long context
)
# Attach LoRA adapters; the returned PEFT-wrapped model replaces `model`.
model = FastVisionModel.get_peft_model(
    model,
    finetune_vision_layers     = True, # False to leave the vision layers untouched
    finetune_language_layers   = True, # False to leave the language layers untouched
    finetune_attention_modules = True, # False to leave the attention modules untouched
    finetune_mlp_modules       = True, # False to leave the MLP modules untouched

    r = 64,                           # LoRA rank: larger can fit more, but may overfit
    lora_alpha = 64,                  # Kept equal to r here
    lora_dropout = 0,
    bias = "none",
    random_state = 3407,
    use_rslora = False,               # Rank-stabilized LoRA is off
    loftq_config = None,               # LoftQ is off
    target_modules = "all-linear",    # Optional; a list of module names can be given instead
    # Listed in addition to the LoRA targets.
    # NOTE(review): presumably these modules are trained and saved in full - confirm.
    modules_to_save=[
        "lm_head",
        "embed_tokens",
    ],
)
print("Model loaded successfully with LoRA adapters")

FastVisionModel.for_training(model) # Switch the model to training mode

trainer = SFTTrainer(
    model=model,
    train_dataset=train_dataset,
    processing_class=processor.tokenizer,
    data_collator=UnslothVisionDataCollator(model, processor),
    args = SFTConfig(
        # 4 samples per step x 4 accumulation steps = 16 samples per optimizer update.
        per_device_train_batch_size = 4,
        gradient_accumulation_steps = 4,
        gradient_checkpointing = True,

        # use_reentrant=False selects the non-reentrant checkpointing implementation
        gradient_checkpointing_kwargs = {"use_reentrant": False},
        max_grad_norm = 0.3,              # max gradient norm based on QLoRA paper
        warmup_ratio = 0.03,
        max_steps = 100,                  # short run; see num_train_epochs below for full training
        #num_train_epochs = 2,          # Set this instead of max_steps for full training runs
        learning_rate = 2e-4,
        logging_steps = 1,
        # NOTE(review): no save_steps is given, so the library's default interval
        # applies - confirm a checkpoint is actually written within max_steps.
        save_strategy="steps",
        optim = "adamw_torch_fused",
        weight_decay = 0.01,
        lr_scheduler_type = "cosine",
        seed = 3407,
        output_dir = "outputs",
        report_to = "none",             # "none" disables experiment-tracker reporting (e.g. Weights & Biases)

        # Required for vision finetuning: keep every dataset column and skip the
        # trainer's own dataset preparation so the data collator gets the raw samples.
        remove_unused_columns = False,
        dataset_text_field = "",
        dataset_kwargs = {"skip_prepare_dataset": True},
        max_length = 2048,
    )
)




==((====))==  Unsloth 2025.8.5: Fast Gemma3 patching. Transformers: 4.55.2.
   \\   /|    NVIDIA A40. Num GPUs = 1. Max memory: 44.352 GB. Platform: Linux.
O^O/ \_/ \    Torch: 2.6.0+cu124. CUDA: 8.6. CUDA Toolkit: 12.4. Triton: 3.2.0
\        /    Bfloat16 = TRUE. FA [Xformers = 0.0.29.post3. FA2 = False]
 "-____-"     Free license: http://github.com/unslothai/unsloth
Unsloth: Fast downloading is enabled - ignore downloading bars which are red colored!
Unsloth: Making `base_model.model.model.vision_tower.vision_model` require gradients
Model loaded successfully with LoRA adapters


In [17]:

# Run supervised fine-tuning; the value returned by train() is kept in trainer_stats.
print("Starting SFT training...")
trainer_stats = trainer.train()


Starting SFT training...


==((====))==  Unsloth - 2x faster free finetuning | Num GPUs used = 1
   \\   /|    Num examples = 1,600 | Num Epochs = 9,223,372,036,854,775,807 | Total steps = 100
O^O/ \_/ \    Batch size per device = 4 | Gradient accumulation steps = 4
\        /    Data Parallel GPUs = 1 | Total batch size (4 x 4 x 1) = 16
 "-____-"     Trainable parameters = 153,991,168 of 4,454,070,640 (3.46% trained)
`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`.


Step,Training Loss
1,2.9329
2,3.0065
3,2.6456
4,2.2543
5,1.757
6,1.5723
7,1.4291
8,1.1988
9,1.3756
10,1.1384


[train] Processed 1000 samples...


In [28]:
FastVisionModel.for_inference(model)  # Switch the model to inference mode

# First sample of the test split; test samples are dicts that carry the chat
# messages plus class / correct_answer metadata.
data = next(iter(test_dataset))


# Take the PIL image out of the user turn (this mutates the sample and leaves a
# bare {"type": "image"} entry); the image is handed to the processor separately.
image = data["messages"][0]["content"][1].pop("image")
# Keep only the user turn so the model has to generate the assistant reply.
messages = [data["messages"][0]]
# Reference values for this sample; not used further in this cell.
class_data= data["class"]
correct_answer= data["correct_answer"]


input_text = processor.apply_chat_template(messages, add_generation_prompt=True)
inputs = processor(
    image,
    input_text,
    add_special_tokens=False,
    return_tensors="pt",
).to("cuda")

from transformers import TextStreamer

# Stream generated tokens to stdout as they are produced, without echoing the prompt.
text_streamer = TextStreamer(processor.tokenizer, skip_prompt=True)
# NOTE(review): temperature / top_p / top_k are passed without do_sample; whether
# they take effect depends on the model's generation config - confirm.
_ = model.generate(**inputs, streamer = text_streamer, max_new_tokens = 512,
                   use_cache=True, temperature = 1.0, top_p = 0.95, top_k = 64)

<start_working_out>
Findings: Cardiomegaly is noted. There is stable bibasilar scarring noted. The aorta remains calcified. Mediastinal contours appear stable. The patient status post CABG. Postoperative surgical clips are noted.
Impression: No acute cardiopulmonary findings. Stable bibasilar scarring.
<end_working_out>

<SOLUTION>
C - Explanation: The chest X-ray shows stable bibasilar scarring, indicating that the scarring is not worsening or resolving, but remains stable.
</SOLUTION><end_of_turn>


In [None]:
# Optional: save a merged 16-bit copy locally
#model.save_pretrained_merged("/home/unsloth_finetune", processor)
# Optional: save the model and the processor to /home/weights
#model.save_pretrained("/home/weights")
#processor.save_pretrained("/home/weights")


# Export the merged model together with the processor to the Hugging Face Hub,
# authenticating with HF_TOKEN.
model.push_to_hub_merged("SerdarHelli/medgemma-4b-it_rexvqa", processor, token = HF_TOKEN)
