Cell 1: Model Setup

In [None]:
# Load model directly
from transformers import AutoProcessor, AutoModelForImageTextToText

processor = AutoProcessor.from_pretrained("google/gemma-3-4b-it")
model = AutoModelForImageTextToText.from_pretrained("google/gemma-3-4b-it")
messages = [
    {
        "role": "user",
        "content": [
            {"type": "image", "url": "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/p-blog/candy.JPG"},
            {"type": "text", "text": "What animal is on the candy?"}
        ]
    },
]
inputs = processor.apply_chat_template(
	messages,
	add_generation_prompt=True,
	tokenize=True,
	return_dict=True,
	return_tensors="pt",
).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=40)
print(processor.decode(outputs[0][inputs["input_ids"].shape[-1]:]))

Cell 2: Dataset Download

In [None]:
import kagglehub

# Download latest version
path = kagglehub.dataset_download("ghostbat101/lung-disease-clinical-texts-and-image-processed")

print("Path to dataset files:", path)

Cell 3: Dataset Exploration

In [None]:
import os

print("Dataset root:", path)
print("Subfolders:", os.listdir(path))


Cell 4: Dataloading and Inspection

In [None]:

import csv
import os

clinical_file = os.path.join(path, "lung_disease_clinical_texts_processed.csv")
image_file = os.path.join(path, "lung_disease_images_processed.csv")

# Read clinical text
clinical_data = []
with open(clinical_file, newline='', encoding="utf-8") as f:
    reader = csv.DictReader(f)
    for row in reader:
        clinical_data.append(row)

print("Clinical columns:", clinical_data[0].keys())
print("First row (clinical):", clinical_data[0])

# Read image data
image_data = []
with open(image_file, newline='', encoding="utf-8") as f:
    reader = csv.DictReader(f)
    for row in reader:
        image_data.append(row)

print("Image columns:", image_data[0].keys())
print("First row (image):", image_data[0])



Cell 5: Dataset Processing and Formatting

In [None]:
import pickle
from tqdm import tqdm

dataset = []
for clinical_row, image_row in tqdm(zip(clinical_data, image_data), total=len(clinical_data)):
    # Put <image> token at the start of every instruction
    instruction_text = "<image> Describe the findings in this lung X-ray."

    dataset.append({
        "instruction": instruction_text,
        "input": [float(image_row[f"Pixel_{i}"]) for i in range(65536)],
        "output": clinical_row.get("clinical_text", "").strip(),
        "disease": clinical_row.get("disease", "")
    })

# Save as a single pickled object (simple and consistent)
with open('processed_dataset.pkl', 'wb') as f:
    pickle.dump(dataset, f)

print("Saved processed dataset with <image> inserted in instruction. Total:", len(dataset))


Cell 6: Data Streaming Setup

In [None]:
import pickle

with open("processed_dataset.pkl", "rb") as f_in, open("processed_dataset_stream.pkl", "wb") as f_out:
    dataset = pickle.load(f_in)
    for item in dataset:
        pickle.dump(item, f_out)

def stream_dataset(filename):
    with open(filename, "rb") as f:
        while True:
            try:
                yield pickle.load(f)
            except EOFError:
                break

count = 0
for item in stream_dataset("processed_dataset_stream.pkl"):
    count += 1
    # process item here (or break early for testing)
    # print(item)  # Uncomment to see items

print("Loaded dataset length:", count)

In [None]:
print("Example Dataset:", dataset[0])

Cell 7: Enviroment and Import Setup (Start here if all dataset is loaded)

In [38]:
import os
os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python'
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True,max_split_size_mb:128"

# Then your imports
import torch
from torch.utils.data import Dataset
from PIL import Image
from datasets import load_dataset

from transformers import (
    AutoProcessor,
    AutoModelForImageTextToText,
    TrainingArguments,
    Trainer
)
from peft import LoraConfig, get_peft_model

Cell 8: Model and Processor Loading

In [39]:
model_name = "google/gemma-3-4b-it"  # "google/medgemma-4b-it" "google/gemma-3-4b-it"
processor = AutoProcessor.from_pretrained(model_name)
model = AutoModelForImageTextToText.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    # device_map="auto",
    # max_memory={0: "6GiB", "cpu": "30GiB"},  
    # trust_remote_code=True,
)

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

Cell 9: NOT IN USE

In [40]:
# lora_config = LoraConfig(
#     r=16,
#     lora_alpha=32,
#     target_modules=["q_proj", "v_proj"],  # common in attention layers
#     lora_dropout=0.05,
#     bias="none",
#     task_type="SEQ_2_SEQ_LM"
# )
# model = get_peft_model(model, lora_config)
# model = model.to("cuda" if torch.cuda.is_available() else "cpu")

Cell 10: Special Token Addition

In [41]:
special_tokens_dict = {"additional_special_tokens": ["<image>"]}
num_added = processor.tokenizer.add_special_tokens(special_tokens_dict)

if num_added > 0:
    model.resize_token_embeddings(len(processor.tokenizer))
print("Special tokens now:", processor.tokenizer.additional_special_tokens)


Special tokens now: ['<image>']


Cell 11-12: Token Verification

In [42]:
print("Special tokens now:", processor.tokenizer.additional_special_tokens)
print(processor.tokenizer.tokenize("<image> Describe the findings in this lung X-ray."))

Special tokens now: ['<image>']
['<image>', '▁Describe', '▁the', '▁findings', '▁in', '▁this', '▁lung', '▁X', '-', 'ray', '.']


Cell 13: Advanced Preprocessing Function

In [43]:
# from PIL import Image
# import numpy as np
# import torch
# import tempfile, os
# from tqdm import tqdm
# import time

# def preprocess_dataset_chat_template(
#     raw_data,
#     processor,
#     max_length = 512,
#     dir = "./gemma_preproc",
#     max_items = 8000
# ):
#     os.makedirs(dir, exist_ok=True)
#     processed = []

#     n = len(raw_data) if max_items is None else min(max_items, len(raw_data))

#     for i, item in enumerate(tqdm(raw_data[:n], desc="Processing")):
#         t0 = time.time()
#         # rebuild PIL image
#         arr = np.array(item["input"], dtype=np.uint8)
#         t1 = time.time()
#         if arr.size == int(np.sqrt(arr.size))**2:
#             side = int(np.sqrt(arr.size))
#             pil_img = Image.fromarray(arr.reshape(side, side)).convert("RGB")
#             pil_img = pil_img.resize((224, 224))
#         else:
#             assert arr.size % 3 == 0
#             side = int(np.sqrt(arr.size // 3))
#             pil_img = Image.fromarray(arr.reshape(side, side, 3)).convert("RGB")
#         t2 = time.time()

#         # save the image locally
#         local_path = os.path.join(dir, f"img_{i}.png")
#         pil_img.save(local_path)
#         t3 = time.time()

#         # chat template with image
#         messages = [
#             {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
#             {"role": "user", "content": [
#                 {"type": "image", "url": local_path},
#                 {"type": "text", "text": item["instruction"].strip()}
#             ]}
#         ]
#         t4 = time.time()

#         inputs = processor.apply_chat_template(
#             messages,
#             add_generation_prompt = True,
#             tokenize = True,
#             return_dict = True,
#             return_tensors = "pt",
#             max_length = max_length,  
#             padding = "max_length",
#             truncation = True,
#         )
#         t5 = time.time()

#         # labels (target caption / answer)
#         target = item.get("output", item.get("disease", ""))
#         label_ids = processor.tokenizer(
#             target,
#             padding = "max_length",
#             truncation = True,
#             max_length = max_length,
#             return_tensors = "pt"
#         ).input_ids
#         t6 = time.time()

#         processed.append({
#             "input_ids": inputs["input_ids"].squeeze(0),
#             "attention_mask": inputs["attention_mask"].squeeze(0),
#             "pixel_values": inputs["pixel_values"].squeeze(0),
#             "labels": label_ids.squeeze(0),
#         })
#         t7 = time.time()

#         # print(
#         #     f"[{i}] np->arr: {t1-t0:.3f}s | arr->img: {t2-t1:.3f}s | save: {t3-t2:.3f}s | msg: {t4-t3:.3f}s | chat_template: {t5-t4:.3f}s | label: {t6-t5:.3f}s | append: {t7-t6:.3f}s | total: {t7-t0:.3f}s"
#         # )

#     return processed

In [44]:
print(processor.tokenizer.additional_special_tokens)

['<image>']


In [45]:
# import pickle
# import json

# with open("processed_dataset.pkl", "rb") as f:
#     data = pickle.load(f)

# # Optionally, only save the first N items for a quick look
# data = data[:1]

# with open("processed_dataset.json", "w", encoding="utf-8") as f:
#     json.dump(data, f, ensure_ascii=False, indent=2)

# print("Saved as processed_dataset.json")

Cell 14: Dataset Class Definition

In [46]:
# Replace your XrayDataset class with this streaming version
from PIL import Image
import numpy as np
import torch
import tempfile, os
from tqdm import tqdm
import time
import pickle

class StreamingXrayDataset(Dataset):
    def __init__(self, raw_data, processor, max_length=512, cache_dir="./xray_cache"):
        self.raw_data = raw_data
        self.processor = processor
        self.max_length = max_length
        self.cache_dir = cache_dir
        os.makedirs(cache_dir, exist_ok=True)
        
    def __len__(self):
        return len(self.raw_data)
    
    def __getitem__(self, idx):
        t0 = time.time()
        item = self.raw_data[idx]
        arr = np.array(item["input"], dtype=np.uint8)
        if arr.size == int(np.sqrt(arr.size))**2:
            side = int(np.sqrt(arr.size))
            pil_img = Image.fromarray(arr.reshape(side, side)).convert("RGB").resize((224, 224))
        else:
            side = int(np.sqrt(arr.size // 3))
            pil_img = Image.fromarray(arr.reshape(side, side, 3)).convert("RGB")
        # Pass pil_img directly to processor
        messages = [
            {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
            {"role": "user", "content": [
                {"type": "image", "image": pil_img},  # pass PIL image directly
                {"type": "text", "text": item["instruction"].strip()}
            ]}
        ]
        inputs = self.processor.apply_chat_template(
            messages,
            add_generation_prompt=True,
            tokenize=True,
            return_dict=True,
            return_tensors="pt",
            max_length=self.max_length,
            padding="max_length",
            truncation=False,
        )
        t1 = time.time()
        print(f"[{idx}] chat_template: {t1-t0:.3f}s")

        # Process labels
        
        target = item.get("output", item.get("disease", ""))
        label_ids = self.processor.tokenizer(
            target,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt"
        ).input_ids
        t2 = time.time()
        print(f"[{idx}] label_tokenize: {t2-t1:.3f}s")
        # Clean up temp file

        return {
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "pixel_values": inputs["pixel_values"].squeeze(0),
            "labels": label_ids.squeeze(0),
        }



In [47]:
    # def __getitem__(self, idx):
    #     # Process data on-demand instead of pre-processing everything
    #     item = self.raw_data[idx]
        
    #     # Create image
    #     arr = np.array(item["input"], dtype=np.uint8)
    #     if arr.size == int(np.sqrt(arr.size))**2:
    #         side = int(np.sqrt(arr.size))
    #         pil_img = Image.fromarray(arr.reshape(side, side)).convert("RGB")
    #         pil_img = pil_img.resize((224, 224))
    #     else:
    #         side = int(np.sqrt(arr.size // 3))
    #         pil_img = Image.fromarray(arr.reshape(side, side, 3)).convert("RGB")
        
    #     # Save temporarily (you can optimize this further by caching)
    #     temp_path = os.path.join(self.cache_dir, f"temp_{idx}.png")
    #     pil_img.save(temp_path)
        
    #     # Process with chat template
    #     messages = [
    #         {"role": "system", "content": [{"type": "text", "text": "You are a helpful assistant."}]},
    #         {"role": "user", "content": [
    #             {"type": "image", "url": temp_path},
    #             {"type": "text", "text": item["instruction"].strip()}
    #         ]}
    #     ]
        
    #     inputs = self.processor.apply_chat_template(
    #         messages,
    #         add_generation_prompt=True,
    #         tokenize=True,
    #         return_dict=True,
    #         return_tensors="pt",
    #         max_length=self.max_length,
    #         padding="max_length",
    #         truncation=False,
    #     )

Cell 15: Data Processing Pipeline

In [48]:
# import pickle

# # Load raw dataset
# t0 = time.time()
# with open("processed_dataset.pkl", "rb") as f:
#     raw_dataset = pickle.load(f)
# t1 = time.time()
# print(f"Loaded raw dataset with {len(raw_dataset)} items in {t1-t0:.3f}s")

# #_______________Debug___________________________________________________
# print("Raw dataset length:", len(raw_dataset))
# print("Special tokens:", processor.tokenizer.additional_special_tokens)

# # Check that <image> token is present in each instruction
# # i = 0
# # for item in raw_dataset:
# #     if "<image>" in item["instruction"]:
# #         print(f"✅ {i}. Image token is present in the prompt!")
# #     else:
# #         print(f"❌ {i}. Image token is NOT present in the prompt!")
# #         print("❌ Instruction:", repr(item["instruction"]))
# #         break  # Stop after first failure
# #     i += 1
# #______________________________________________________________________

# # Preprocess



In [49]:
import pickle

# Replace your data processing section with this:
from sklearn.model_selection import train_test_split

# Load raw dataset
t0 = time.time()
with open("processed_dataset.pkl", "rb") as f:
    raw_dataset = pickle.load(f)
t1 = time.time()
print(f"Loaded raw dataset with {len(raw_dataset)} items in {t1-t0:.3f}s")

# Split raw data (not processed data)
train_raw, val_raw = train_test_split(raw_dataset, test_size=0.2, random_state=42)

# Create streaming datasets
train_dataset = StreamingXrayDataset(train_raw, processor)
val_dataset = StreamingXrayDataset(val_raw, processor)

print("Train dataset length:", len(train_dataset))
print("Validation dataset length:", len(val_dataset))
print("Train example:", {k: v.shape for k, v in train_dataset[0].items()})
print("Validation example:", {k: v.shape for k, v in val_dataset[0].items()})

t2 = time.time()
print(f"Created streaming datasets in {t2-t1:.3f}s")

Loaded raw dataset with 8000 items in 55.432s
Train dataset length: 6400
Validation dataset length: 1600
[0] chat_template: 0.611s
[0] label_tokenize: 0.043s
Train example: {'input_ids': torch.Size([512]), 'attention_mask': torch.Size([512]), 'pixel_values': torch.Size([3, 896, 896]), 'labels': torch.Size([512])}
[0] chat_template: 0.072s
[0] label_tokenize: 0.031s
Validation example: {'input_ids': torch.Size([512]), 'attention_mask': torch.Size([512]), 'pixel_values': torch.Size([3, 896, 896]), 'labels': torch.Size([512])}
Created streaming datasets in 11.150s


In [50]:
print("Example Dataset:", train_dataset[0])

[0] chat_template: 0.057s
[0] label_tokenize: 0.004s
Example Dataset: {'input_ids': tensor([     0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,      0,      0,      0,      0,      0,      0,
             0,      0,      0,     

In [51]:
# import json

# print(raw_dataset[0])
# print(processed_data[0])

# def tensor_to_list(obj):
#     if hasattr(obj, "tolist"):
#         return obj.tolist()
#     elif isinstance(obj, dict):
#         return {k: tensor_to_list(v) for k, v in obj.items()}
#     elif isinstance(obj, list):
#         return [tensor_to_list(v) for v in obj]
#     else:
#         return obj

# print(raw_dataset[0])
# print(processed_data[0])

# # Convert all tensors to lists for JSON serialization
# processed_data_serializable = [tensor_to_list(item) for item in processed_data]

# with open("processed_dataset.json", "w", encoding="utf-8") as f:
#     json.dump(processed_data_serializable, f, ensure_ascii=False, indent=2)

# print("Saved as processed_dataset.json")

In [52]:
import torch
import torch.nn as nn
from torch.utils.data import Dataset
from transformers import Trainer, TrainingArguments
from peft import get_peft_model, LoraConfig, TaskType
from bitsandbytes.optim import Adam8bit

In [53]:
lora_config = LoraConfig(
    r=4,
    lora_alpha=8,
    target_modules=["q_proj", "v_proj"],  # Add more targets
    lora_dropout=0.05,
    bias="none",
    task_type=TaskType.CAUSAL_LM 
)

model = get_peft_model(model, lora_config) 

In [54]:
print(model)

PeftModelForCausalLM(
  (base_model): LoraModel(
    (model): Gemma3ForConditionalGeneration(
      (model): Gemma3Model(
        (vision_tower): SiglipVisionModel(
          (vision_model): SiglipVisionTransformer(
            (embeddings): SiglipVisionEmbeddings(
              (patch_embedding): Conv2d(3, 1152, kernel_size=(14, 14), stride=(14, 14), padding=valid)
              (position_embedding): Embedding(4096, 1152)
            )
            (encoder): SiglipEncoder(
              (layers): ModuleList(
                (0-26): 27 x SiglipEncoderLayer(
                  (layer_norm1): LayerNorm((1152,), eps=1e-06, elementwise_affine=True)
                  (self_attn): SiglipAttention(
                    (k_proj): Linear(in_features=1152, out_features=1152, bias=True)
                    (v_proj): lora.Linear(
                      (base_layer): Linear(in_features=1152, out_features=1152, bias=True)
                      (lora_dropout): ModuleDict(
                        (defaul

In [55]:
loss_fn = nn.CrossEntropyLoss(
    ignore_index=processor.tokenizer.pad_token_id
)

In [56]:
class CustomTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
        labels = inputs.pop("labels")
        outputs = model(**inputs, labels=labels)
        logits = outputs.logits
        loss = outputs.loss
        
        # # Shift labels for causal LM (important!)
        # shift_logits = logits[..., :-1, :].contiguous()
        # shift_labels = labels[..., 1:].contiguous()
        
        # # Flatten for CrossEntropyLoss
        # loss_fct = nn.CrossEntropyLoss(ignore_index=processor.tokenizer.pad_token_id)
        # loss = loss_fct(
        #     shift_logits.view(-1, shift_logits.size(-1)),
        #     shift_labels.view(-1)
        # )
        
        return (loss, outputs) if return_outputs else loss


In [57]:
from torch.optim import AdamW
optimizer = AdamW(model.parameters(), lr=1e-5, fused=True)

In [58]:
import os
import json
from transformers import TrainingArguments

training_args = TrainingArguments(
    output_dir="./results",
    per_device_train_batch_size=4,
    gradient_accumulation_steps=4,
    num_train_epochs=1,

    # # 🚀 Make training faster
    # eval_strategy="no",    # disable eval during training (run separately later)
    # save_strategy="epoch",       # save only once per epoch
    # save_total_limit=1,          # keep just final checkpoint
    # load_best_model_at_end=False,# skip extra eval+loading

    # Logging
    logging_strategy="steps",
    logging_steps=100,           # less frequent logs

    # Performance
    bf16=True,
    fp16=False,
    optim="adamw_torch_fused",
    dataloader_num_workers=4,
    remove_unused_columns=False,
    report_to=None,
)


# 1. TRAINING ARGUMENTS WITH CHECKPOINTING
# training_args = TrainingArguments(
#     output_dir="./results",                    # Main checkpoint directory
#     per_device_train_batch_size=4,
#     per_device_eval_batch_size=4,
#     gradient_accumulation_steps=4,
#     num_train_epochs=1,
#     eval_strategy="no",
#     eval_steps=200,
    
#     # ✅ CHECKPOINTING SETTINGS
#     save_strategy="epoch",                     # Save every N steps
#     save_steps=200,                           # Save checkpoint every 200 steps
#     save_total_limit=2,                       # Keep only 5 most recent checkpoints
#     save_on_each_node=True,                   # Save on each distributed node
    
#     # ✅ RESUMING SETTINGS  
#     resume_from_checkpoint=None,              # Set this when resuming (see below)
#     load_best_model_at_end=True,             # Load best checkpoint at end
#     metric_for_best_model="eval_loss",       # Which metric to track for "best"
#     greater_is_better=False,                 # Lower eval_loss is better
    
#     # ✅ LOGGING & MONITORING
#     logging_dir="./logs",
#     logging_steps=50,
#     logging_strategy="steps",
    
#     # Other settings
#     bf16=True,
#     fp16=False,
#     optim="adamw_torch_fused",
#     dataloader_num_workers=4,
#     remove_unused_columns=False,
#     report_to=None,
# )

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


In [59]:
# def find_latest_checkpoint(output_dir):
#     """Find the most recent checkpoint in output_dir"""
#     if not os.path.exists(output_dir):
#         return None
    
#     checkpoints = []
#     for item in os.listdir(output_dir):
#         if item.startswith("checkpoint-") and os.path.isdir(os.path.join(output_dir, item)):
#             try:
#                 step_num = int(item.split("-")[1])
#                 checkpoints.append((step_num, os.path.join(output_dir, item)))
#             except ValueError:
#                 continue
    
#     if not checkpoints:
#         return None
    
#     # Return path to checkpoint with highest step number
#     latest_step, latest_path = max(checkpoints, key=lambda x: x[0])
#     return latest_path

In [60]:
# def save_emergency_checkpoint(trainer, step_name="emergency"):
#     """Save checkpoint manually (e.g., if you need to stop training)"""
#     checkpoint_dir = f"./results/checkpoint-{step_name}"
#     trainer.save_model(checkpoint_dir)
#     trainer.save_state()
#     print(f"💾 Emergency checkpoint saved to: {checkpoint_dir}")
#     return checkpoint_dir

In [61]:
# def setup_training_with_resume(output_dir="./results"):
#     """Setup training arguments with automatic resume detection"""
    
#     # Check for existing checkpoints
#     latest_checkpoint = find_latest_checkpoint(output_dir)
    
#     if latest_checkpoint:
#         print(f"✅ Found checkpoint: {latest_checkpoint}")
#         print("Training will resume from this checkpoint")
        
#         # Load trainer state to see progress
#         trainer_state_file = os.path.join(latest_checkpoint, "trainer_state.json")
#         if os.path.exists(trainer_state_file):
#             with open(trainer_state_file, 'r') as f:
#                 trainer_state = json.load(f)
#                 print(f"📊 Resuming from epoch {trainer_state.get('epoch', 'unknown')}")
#                 print(f"📊 Resuming from global step {trainer_state.get('global_step', 'unknown')}")
#                 print(f"📊 Best eval loss so far: {trainer_state.get('best_metric', 'unknown')}")
        
#         resume_from = latest_checkpoint
#     else:
#         print("🆕 No existing checkpoints found. Starting fresh training.")
#         resume_from = None
    
#     # Update training arguments
#     training_args.resume_from_checkpoint = resume_from
#     return training_args, resume_from

In [62]:
# training_args, resume_checkpoint = setup_training_with_resume("./results")

In [63]:
# def compute_metrics(eval_pred):
#     predictions, labels = eval_pred
#     # Add your custom metrics here (BLEU, ROUGE, etc.)
#     return {"perplexity": torch.exp(torch.tensor(predictions.mean())).item()}

In [64]:
from transformers import get_linear_schedule_with_warmup

# Calculate total training steps
num_training_steps = len(train_dataset) * training_args.num_train_epochs // (
    training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps
)

scheduler = get_linear_schedule_with_warmup(
    optimizer,
    num_warmup_steps=int(0.1 * num_training_steps),  # 10% warmup
    num_training_steps=num_training_steps
)


In [65]:
# model = get_peft_model(model, lora_config) 
model.enable_input_require_grads()

In [66]:
trainer = CustomTrainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    # data_collator=MultimodalDataCollator(processor),
    tokenizer=processor.tokenizer,
    optimizers=(optimizer, scheduler),
    # compute_metrics=compute_metrics, 
)

  trainer = CustomTrainer(
Using auto half precision backend


In [67]:
model.config.use_cache = False  # Disable KV cache during training
# model.gradient_checkpointing_enable()  # Save memory

In [68]:
# import wandb  # Optional but recommended

# Initialize wandb (optional)
# wandb.init(project="lung-xray-gemma", name="lora-finetuning")

import os
os.environ["WANDB_DISABLED"] = "true"

In [69]:
# from transformers import EarlyStoppingCallback

# trainer.add_callback(EarlyStoppingCallback(early_stopping_patience=3))

In [70]:
# print("Running validation before training...")
# eval_results = trainer.evaluate()
# print(f"Pre-training eval loss: {eval_results['eval_loss']:.4f}")

In [71]:
torch.cuda.empty_cache()
print(f"GPU memory: {torch.cuda.memory_allocated()/1e9:.2f}GB allocated")
print(f"GPU memory: {torch.cuda.memory_reserved()/1e9:.2f}GB reserved")

# trainer.train()

# # Save LoRA adapter and processor
# model.save_pretrained("./lung_lora")
# processor.save_pretrained("./lung_lora")

GPU memory: 17.24GB allocated
GPU memory: 17.28GB reserved


In [72]:
# -------------------------
# Simple training monitor
# -------------------------
import logging
from transformers import TrainerCallback

# 1. Enable detailed logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger("transformers.trainer")
logger.setLevel(logging.DEBUG)

# 2. Define progress callback
class SimpleProgressCallback(TrainerCallback):
    def on_train_begin(self, args, state, control, **kwargs):
        print("🚀 Training starting...")
        return control

    def on_step_begin(self, args, state, control, **kwargs):
        print(f"⚡ Step {state.global_step}: Starting training step...")
        return control

    def on_step_end(self, args, state, control, **kwargs):
        print(f"✅ Step {state.global_step}: Completed")
        return control

    def on_epoch_begin(self, args, state, control, **kwargs):
        print(f"🚀 Epoch {state.epoch}: Starting...")
        return control

    def on_log(self, args, state, control, **kwargs):
        print(f"📊 Step {state.global_step}: Logging metrics")
        return control

# 3. Attach AFTER you create the trainer
trainer.add_callback(SimpleProgressCallback())

# 4. Test single training step
# print("🧪 Testing single training step...")
# trainer.args.max_steps = 1          # only run 1 step
# trainer.args.logging_steps = 1
# trainer.args.save_steps = 1

# try:
#     trainer.train()
#     print("✅ Single step completed! Training is working.")

#     # Reset for full training
# trainer.args.max_steps = -1
# trainer.args.num_train_epochs = 3
# trainer.args.logging_steps = 10
# trainer.args.save_steps = 200

#     print("🚀 Starting full training...")
#     trainer.train()

# except KeyboardInterrupt:
#     print("⚠️ Training interrupted by user")
# except Exception as e:
#     print(f"❌ Training failed at: {e}")
#     import traceback
#     traceback.print_exc()


In [73]:
import torch
print("CUDA available:", torch.cuda.is_available())
print("Device:", torch.cuda.get_device_name(0) if torch.cuda.is_available() else "CPU")

CUDA available: True
Device: NVIDIA GeForce RTX 3060 Laptop GPU


In [None]:
trainer.train()

model.save_pretrained("./lung_lora")
processor.save_pretrained("./lung_lora")

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 1, 'bos_token_id': 2, 'pad_token_id': 0}.
Currently training with a batch size of: 4
***** Running training *****
  Num examples = 6,400
  Num Epochs = 1
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 4
  Total optimization steps = 400
  Number of trainable parameters = 1,611,776


🚀 Training starting...
🚀 Epoch 0: Starting...


In [None]:
# import time
# from transformers import Trainer

# class DebugTrainer(CustomTrainer):
#     def training_step(self, model, inputs):
#         start = time.time()
#         out = super().training_step(model, inputs)
#         print(f"[Step] One training step took {time.time() - start:.3f}s")
#         return out

# # Replace Trainer with DebugTrainer
# trainer = DebugTrainer(
#     model=model,
#     args=training_args,
#     train_dataset=train_dataset,
#     eval_dataset=val_dataset,
#     tokenizer=processor.tokenizer,
#     optimizers=(optimizer, scheduler),
# )

  trainer = DebugTrainer(
Using auto half precision backend


In [None]:
# try:
#     trainer.train(resume_from_checkpoint=None) # resume_checkpoint
    
#     # Save final model
#     final_save_path = "./lung_xray_gemma_final"
#     trainer.save_model(final_save_path)
#     processor.save_pretrained(final_save_path)
#     print(f"✅ Final model saved to: {final_save_path}")
    
# except KeyboardInterrupt:
#     print("\n⚠️  Training interrupted by user!")
#     emergency_path = save_emergency_checkpoint(trainer, "interrupted")
#     print(f"💾 Emergency save completed: {emergency_path}")
    
# except Exception as e:
#     print(f"\n❌ Training failed with error: {e}")
#     emergency_path = save_emergency_checkpoint(trainer, "error")
#     print(f"💾 Emergency save completed: {emergency_path}")
#     raise

The tokenizer has new PAD/BOS/EOS tokens that differ from the model config and generation config. The model config and generation config were aligned accordingly, being updated with the tokenizer's values. Updated tokens: {'eos_token_id': 1, 'bos_token_id': 2, 'pad_token_id': 0}.
Currently training with a batch size of: 4
***** Running training *****
  Num examples = 6,400
  Num Epochs = 1
  Instantaneous batch size per device = 4
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 4
  Total optimization steps = 400
  Number of trainable parameters = 1,611,776
wandb: Currently logged in as: jiraphat-sabutr (jiraphat-sabutr-king-mongkut-s-institute-of-technology-l) to https://api.wandb.ai. Use `wandb login --relogin` to force relogin
  Expected `list[str]` but got `tuple` - serialized value may not be as expected
  Expected `list[str]` but got `tuple` - serialized value may not be as expected
  return self.__pydantic_serializer__.to_py

---

## Helper cells inserted on 2025-09-23T07:06:56.255374Z
These helper functions were appended here: metric calculators, plotting helpers, evaluation wrapper, history extractor, and Grad-CAM utilities (one function per cell).
---

### Helper: imports and device

In [None]:
# imports and device
import torch
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import r2_score, mean_absolute_error, mean_squared_error, accuracy_score
from typing import Dict, Any, Tuple, List
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print("Device:", device)

Device: cuda


### compute_regression_metrics, compute_classification_metrics, compute_text_metrics

In [None]:
def compute_regression_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    """Return R2, MAE, MSE."""
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred).ravel()
    return {
        "r2": float(r2_score(y_true, y_pred)),
        "mae": float(mean_absolute_error(y_true, y_pred)),
        "mse": float(mean_squared_error(y_true, y_pred)),
    }

def compute_classification_metrics(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
    """Return accuracy. y_pred may be probabilities/logits; this function expects integer labels or will argmax if needed."""
    y_true = np.asarray(y_true).ravel()
    y_pred = np.asarray(y_pred)
    if y_pred.ndim > 1:
        y_pred = np.argmax(y_pred, axis=1)
    y_pred = y_pred.ravel()
    return {"accuracy": float(accuracy_score(y_true, y_pred))}

def compute_text_metrics(references: List[str], predictions: List[str]) -> Dict[str, float]:
    """Simple text metrics: exact match rate and average token overlap (simple proxy)."""
    exact = sum(1 for r,p in zip(references, predictions) if r.strip() == p.strip())
    exact_rate = exact / max(1, len(references))
    overlaps = []
    for r,p in zip(references, predictions):
        rs = set(r.split())
        ps = set(p.split())
        if len(rs) + len(ps) == 0:
            overlaps.append(1.0)
        else:
            overlaps.append(len(rs & ps) / max(1, len(rs | ps)))
    return {"exact_match": float(exact_rate), "token_overlap": float(np.mean(overlaps))}

### plot_metrics(history)

In [None]:
def plot_metrics(history: Dict[str, List[float]], title_prefix: str = "Training"):
    """history is a dict like {'train_loss': [...], 'eval_loss': [...], 'accuracy': [...], 'r2': [...]}"""
    keys = list(history.keys())
    n = len(keys)
    plt.figure(figsize=(6, 3*n))
    for i, k in enumerate(keys):
        plt.subplot(n, 1, i+1)
        plt.plot(history[k], marker='o')
        plt.title(f"{title_prefix} - {k}")
        plt.grid(True)
    plt.tight_layout()
    plt.show()

### evaluate_trainer(trainer, eval_dataset, task, ...)

In [None]:
def evaluate_trainer(trainer, eval_dataset, task: str = 'regression', label_extractor=None, tokenizer=None, text_label_key='output'):
    """Run trainer.predict and compute metrics.
    - task: 'regression' | 'classification' | 'text'
    - label_extractor: optional callable to extract numeric labels from label_ids or dataset entries
    - tokenizer: required for text decoding (if task == 'text')
    - text_label_key: dataset key that contains reference text (for dataset entries) when task=='text'
    Returns: dict(metrics) and raw predictions
    """
    print('Running prediction (this may take some time) ...')
    pred_out = trainer.predict(eval_dataset)
    preds = pred_out.predictions
    label_ids = pred_out.label_ids

    # If text generation: predictions will often be token ids (seq2seq) -> decode
    if task == 'text':
        if tokenizer is None:
            raise ValueError('tokenizer required for text task')
        if isinstance(preds, tuple) or hasattr(preds, 'logits'):
            preds = preds[0] if isinstance(preds, tuple) else preds.logits
        if preds.ndim == 3:
            token_ids = np.argmax(preds, axis=-1)
        else:
            token_ids = preds.astype(int)
        decoded = tokenizer.batch_decode(token_ids, skip_special_tokens=True)
        refs = [ex[text_label_key] for ex in eval_dataset]
        metrics = compute_text_metrics(refs, decoded)
        return metrics, {'predictions': decoded, 'references': refs}

    if isinstance(preds, tuple):
        preds = preds[0]
    preds = np.asarray(preds)
    if preds.ndim > 1 and preds.shape[-1] > 1:
        if task == 'classification':
            y_pred = np.argmax(preds, axis=-1)
        else:
            y_pred = preds[:, 0]
    else:
        y_pred = preds.ravel()

    if label_ids is None:
        if label_extractor is not None:
            y_true = label_extractor(eval_dataset)
        else:
            try:
                y_true = np.array([ex['labels'] for ex in eval_dataset])
            except Exception as e:
                raise RuntimeError('Could not get labels from trainer output or eval_dataset. Provide label_extractor.') from e
    else:
        y_true = np.asarray(label_ids).ravel()

    if task == 'regression':
        metrics = compute_regression_metrics(y_true, y_pred)
    elif task == 'classification':
        metrics = compute_classification_metrics(y_true, y_pred)
    else:
        raise ValueError('Unknown task')

    return metrics, {'y_true': y_true, 'y_pred': y_pred}

### extract_history_from_trainer(trainer)

In [None]:
def extract_history_from_trainer(trainer):
    """Return dict with 'train_loss' and any logged eval metrics as lists.
    Uses trainer.state.log_history (list of dicts).
    """
    logs = trainer.state.log_history
    train_loss = [entry['loss'] for entry in logs if 'loss' in entry]
    eval_keys = set(k for e in logs for k in e.keys() if k.startswith('eval_'))
    history = {'train_loss': train_loss}
    for k in sorted(eval_keys):
        history[k] = [e[k] for e in logs if k in e]
    return history

### Grad-CAM for CNN models (GradCAM_CNN) and overlay helper

In [None]:
import torch.nn.functional as F
from PIL import Image

class GradCAM_CNN:
    def __init__(self, model: torch.nn.Module, target_layer: torch.nn.Module):
        self.model = model.eval()
        self.target_layer = target_layer
        self.activations = None
        self.gradients = None
        def forward_hook(module, inp, out):
            self.activations = out.detach()
        def backward_hook(module, grad_in, grad_out):
            self.gradients = grad_out[0].detach()
        self.target_layer.register_forward_hook(forward_hook)
        self.target_layer.register_backward_hook(backward_hook)

    def __call__(self, input_tensor: torch.Tensor, target_index: int = None) -> np.ndarray:
        device = next(self.model.parameters()).device
        input_tensor = input_tensor.to(device)
        outputs = self.model(input_tensor)
        if isinstance(outputs, tuple):
            logits = outputs[0]
        else:
            logits = outputs
        if target_index is None:
            target_index = int(logits.argmax(dim=-1).cpu().numpy().ravel()[0])
        score = logits[:, target_index].squeeze()
        self.model.zero_grad()
        score.backward(retain_graph=True)
        grads = self.gradients.cpu().numpy()[0]
        acts = self.activations.cpu().numpy()[0]
        weights = np.mean(grads, axis=(1,2))
        cam = np.zeros(acts.shape[1:], dtype=np.float32)
        for i, w in enumerate(weights):
            cam += w * acts[i]
        cam = np.maximum(cam, 0)
        cam = cam - cam.min()
        if cam.max() > 0:
            cam = cam / cam.max()
        cam = np.uint8(255 * cam)
        cam = Image.fromarray(cam).resize((input_tensor.shape[-1], input_tensor.shape[-2]), resample=Image.BILINEAR)
        cam = np.array(cam) / 255.0
        return cam

def overlay_heatmap(image: np.ndarray, heatmap: np.ndarray, alpha=0.5):
    if image.ndim == 2:
        image = np.stack([image]*3, axis=-1)
    cmap = plt.cm.jet
    colored = cmap(heatmap)[:,:,:3]
    overlay = (1-alpha)*image + alpha*colored
    overlay = np.clip(overlay, 0, 1)
    return overlay

### Grad-CAM for ViT / patch-based models (best-effort)

In [None]:
def grad_cam_vit_patch(model, processor, pil_image, target_token_index=None, model_image_encoder_attr_candidates=None):
    """Compute a patch-level Grad-CAM heatmap for ViT-style encoders inside HF models.
    - model: the HF model (AutoModelForImageTextToText)
    - processor: processor to convert pil_image into pixel_values
    - pil_image: PIL.Image
    - target_token_index: for image->text models this is a proxy; we use last token argmax by default.
    - model_image_encoder_attr_candidates: list of attribute paths to try to find the image encoder or patch embedding module in the model.
    Returns heatmap resized to original image size (HxW, float [0,1]).
    """
    import torch.nn.functional as F
    model.eval()
    if model_image_encoder_attr_candidates is None:
        model_image_encoder_attr_candidates = [
            'vision_model.embeddings.patch_embeddings',
            'vision_model.embeddings',
            'encoder.vit.embeddings',
            'visual_encoder.embeddings',
        ]

    image_inputs = processor(images=pil_image, return_tensors='pt').to(next(model.parameters()).device)
    outputs = model(**image_inputs)
    logits = None
    if hasattr(outputs, 'logits'):
        logits = outputs.logits
    elif isinstance(outputs, tuple) and len(outputs) > 0:
        logits = outputs[0]
    else:
        raise RuntimeError('Cannot find logits in model outputs (model-specific).')

    if target_token_index is None:
        if logits.ndim == 3:
            target_token_index = int(logits.argmax(dim=-1)[0, -1].cpu().numpy())
        else:
            target_token_index = int(logits.argmax(dim=-1)[0].cpu().numpy())

    target_module = None
    for path in model_image_encoder_attr_candidates:
        try:
            m = model
            for part in path.split('.'):
                m = getattr(m, part)
            target_module = m
            break
        except Exception:
            continue
    if target_module is None:
        raise RuntimeError('Could not find a suitable image encoder module automatically. Please pass model_image_encoder_attr_candidates pointing to the patch embedding module or final encoder layer.')

    activations = None
    gradients = None
    def forward_hook(module, inp, out):
        nonlocal activations
        activations = out.detach()
    def backward_hook(module, grad_in, grad_out):
        nonlocal gradients
        gradients = grad_out[0].detach()

    target_module.register_forward_hook(forward_hook)
    target_module.register_backward_hook(backward_hook)

    score = logits[0, -1, target_token_index]
    model.zero_grad()
    score.backward(retain_graph=True)
    if activations is None or gradients is None:
        raise RuntimeError('Hooks did not capture activations/gradients. You may need to adjust the target module path.')

    acts = activations.cpu().numpy()[0]
    grads = gradients.cpu().numpy()[0]
    if acts.ndim == 2:
        acts_t = acts.T
        weights = np.mean(grads, axis=0)
        cam_patch = np.dot(weights, acts_t)
        n_patches = cam_patch.shape[0]
        side = int(np.sqrt(n_patches))
        cam_map = cam_patch.reshape(side, side)
    else:
        weights = np.mean(grads, axis=(1,2))
        cam_map = np.zeros(acts.shape[1:], dtype=np.float32)
        for i, w in enumerate(weights):
            cam_map += w * acts[i]
    cam_map = np.maximum(cam_map, 0)
    if cam_map.max() > 0:
        cam_map = cam_map / cam_map.max()
    from PIL import Image
    cam_img = Image.fromarray(np.uint8(255 * cam_map)).resize(pil_image.size, resample=Image.BILINEAR)
    cam_arr = np.array(cam_img) / 255.0
    return cam_arr

### Example usage (commented)

In [None]:
# Example usage (commented)
# history = extract_history_from_trainer(trainer)
# plot_metrics(history, title_prefix='Gemma Training')
# metrics, raw = evaluate_trainer(trainer, eval_dataset=val_dataset, task='regression')
# from PIL import Image
# pil = Image.fromarray((val_image*255).astype('uint8')).convert('RGB')
# heat = grad_cam_vit_patch(model, processor, pil)
# overlay = overlay_heatmap(np.array(pil)/255.0, heat)
# plt.imshow(overlay); plt.axis('off')