In [None]:
!pip install -U pip
!pip install --upgrade --no-cache-dir git+https://github.com/huggingface/transformers.git
!pip install --upgrade --no-cache-dir torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
!pip install -U accelerate bitsandbytes einops

Collecting pip
  Downloading pip-25.1.1-py3-none-any.whl.metadata (3.6 kB)
Downloading pip-25.1.1-py3-none-any.whl (1.8 MB)
[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m1.8/1.8 MB[0m [31m24.3 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 24.1.2
    Uninstalling pip-24.1.2:
      Successfully uninstalled pip-24.1.2
Successfully installed pip-25.1.1
Collecting git+https://github.com/huggingface/transformers.git
  Cloning https://github.com/huggingface/transformers.git to /tmp/pip-req-build-j10tt1mv
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /tmp/pip-req-build-j10tt1mv
  Resolved https://github.com/huggingface/transformers.git to commit ccf2ca162e33f381e454cdb74bf4b41a51ab976d
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml

In [None]:
# =============================================================================
# OPTIMIZED VERSION - Faster Training Startup
# =============================================================================

import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
import os
import json
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import TrainingArguments, Trainer
from peft import LoraConfig, get_peft_model
import multiprocessing as mp


In [None]:
# =============================================================================
# 0. DEVICE SETUP - Define device early and consistently
# =============================================================================
# Pick the accelerator once; later cells read this module-level `device`.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

# Report what was selected so the run log records the hardware.
print(f"🔧 Using device: {device}")
print(f"🔧 CUDA available: {torch.cuda.is_available()}")
if torch.cuda.is_available():
    # Device 0 is queried for total memory, converted from bytes to GiB.
    print(f"🔧 GPU: {torch.cuda.get_device_name()}")
    print(f"🔧 GPU Memory: {torch.cuda.get_device_properties(0).total_memory / 1024**3:.1f} GB")


🔧 Using device: cuda
🔧 CUDA available: True
🔧 GPU: NVIDIA A100-SXM4-40GB
🔧 GPU Memory: 39.6 GB


In [None]:
# =============================================================================
# 1. FASTER MODEL LOADING
# =============================================================================
def load_model_optimized():
    """Load the MedGemma processor and model and move the model to `device`.

    The weight dtype is chosen from the hardware that is actually present:
    bfloat16 on CUDA devices that support it, float16 on other CUDA devices,
    and float32 when no CUDA device is available.

    Relies on the module-level `device` defined in the device-setup cell.

    Returns:
        tuple: ``(model, processor)`` for the ``google/medgemma-4b-it`` checkpoint.
    """
    model_name = "google/medgemma-4b-it"

    print("Loading processor...")
    processor = AutoProcessor.from_pretrained(model_name)

    print("Loading model with optimizations...")
    # Only query bf16 support when a CUDA device exists; the query inspects the
    # GPU and is not meaningful (and may fail) on CPU-only machines.
    if not torch.cuda.is_available():
        dtype = torch.float32
    elif torch.cuda.is_bf16_supported():
        dtype = torch.bfloat16
    else:
        dtype = torch.float16

    # trust_remote_code is intentionally not set: the checkpoint is loaded through
    # the model classes shipped with transformers, so there is no need to allow
    # execution of code downloaded from the Hub.
    model = AutoModelForImageTextToText.from_pretrained(
        model_name,
        torch_dtype=dtype,
        low_cpu_mem_usage=True,  # Reduces CPU memory usage during loading
        use_safetensors=True,    # Load weights from the safetensors files
    )
    # device_map is not used, so the model is placed explicitly.
    model = model.to(device)
    return model, processor

In [None]:
# =============================================================================
# 2. OPTIMIZED DATASET - PRE-PROCESS DATA
# =============================================================================
class OptimizedRSNADataset(Dataset):
    """Image + report dataset that can tokenize every sample once, up front.

    Each annotation entry is expected to be a mapping with a ``"key"`` (or
    ``"input"``) field naming an image file under ``vis_root`` and an optional
    ``"output"`` field holding the target text.

    Args:
        vis_root: Directory that contains the image files.
        ann_path: Path to a JSON file with the annotation entries.
        processor: Processor providing ``apply_chat_template``, ``__call__`` and
            a ``tokenizer`` attribute.
        device: Device that pre-processed tensors are moved to in ``__getitem__``.
        preprocess: When true, every entry is processed during construction and
            kept in memory; otherwise entries are processed on access.
    """

    def __init__(self, vis_root, ann_path, processor, device='cuda', preprocess=True):
        self.vis_root = vis_root
        self.processor = processor
        self.device = device

        # Load the full annotation file (no subsetting is applied here)
        with open(ann_path, 'r') as f:
            self.ann = json.load(f)

        # PRE-PROCESS everything during initialization
        if preprocess:
            print("🔄 Pre-processing dataset (this happens once)...")
            self.processed_data = []
            self._preprocess_all_data()
        else:
            # None (as opposed to an empty list) marks on-the-fly mode.
            self.processed_data = None

    def _preprocess_all_data(self):
        """Process every annotation once and keep the successful results."""
        for i, info in enumerate(self.ann):
            try:
                processed_item = self._process_single_item(info)
                # Compare against None explicitly; failures are signalled by None.
                if processed_item is not None:
                    self.processed_data.append(processed_item)

                if i % 10 == 0:
                    print(f"Processed {i+1}/{len(self.ann)} items...")

            except Exception as e:
                print(f"Error processing item {i}: {e}")
                continue

        # Make dropped samples visible instead of silently shrinking the dataset.
        skipped = len(self.ann) - len(self.processed_data)
        if skipped:
            print(f"⚠️ Skipped {skipped}/{len(self.ann)} items that failed to process")

    def _process_single_item(self, info):
        """Turn one annotation entry into tensors ready for training.

        Returns:
            dict with ``input_ids``, ``attention_mask``, ``pixel_values`` and
            ``labels`` tensors (batch dimension squeezed away), or ``None`` when
            the entry could not be processed. Failures are reported on stdout.
        """
        try:
            # Load and process image
            image_path = os.path.join(self.vis_root, info.get("key", info.get("input")))
            image = Image.open(image_path).convert("RGB")

            # Create prompt and answer
            prompt = "Describe this X-ray"
            answer = info.get("output", "No findings")

            # Build conversation
            messages = [
                {"role": "user", "content": [{"type": "image"}, {"type": "text", "text": prompt}]},
                {"role": "model", "content": answer}
            ]

            # Process with processor
            formatted_text = self.processor.apply_chat_template(
                messages, tokenize=False, add_generation_prompt=False
            )

            inputs = self.processor(
                text=formatted_text,
                images=image,
                return_tensors="pt",
                max_length=512,  # Reduced for speed
                truncation=True,
                padding=False
            )

            # Create labels
            input_ids = inputs['input_ids'].squeeze(0)
            labels = input_ids.clone()

            # Mask everything up to and including the first "model" token so that
            # only the tokens after it keep their ids as labels. If the token is
            # not found, the labels are left unmasked.
            model_token = self.processor.tokenizer.convert_tokens_to_ids("model")
            model_indices = (labels == model_token).nonzero(as_tuple=True)[0]

            if len(model_indices) > 0:
                cut_point = model_indices[0].item() + 1
                labels[:cut_point] = -100

            return {
                'input_ids': input_ids,
                'attention_mask': inputs['attention_mask'].squeeze(0),
                'pixel_values': inputs['pixel_values'].squeeze(0),
                'labels': labels
            }

        except Exception as e:
            # Report the failure; returning None lets callers drop the sample.
            item_id = info.get("key", info.get("input")) if isinstance(info, dict) else info
            print(f"⚠️ Failed to process item {item_id!r}: {type(e).__name__}: {e}")
            return None

    def __len__(self):
        """Number of usable samples: pre-processed count, else annotation count."""
        # `is not None` matters: an empty pre-processed list must report 0, not
        # fall back to the raw annotation count.
        if self.processed_data is not None:
            return len(self.processed_data)
        return len(self.ann)

    def __getitem__(self, index):
        """Return the sample at ``index``.

        In pre-processed mode the cached tensors are moved to ``self.device``.
        In on-the-fly mode the entry is processed on access; the result is
        returned as produced by ``_process_single_item`` (not moved to
        ``self.device``) and may be ``None`` if processing failed.
        """
        if self.processed_data is not None:
            # Data is already processed - just return it
            item = self.processed_data[index]
            return {
                'input_ids': item['input_ids'].to(self.device),
                'attention_mask': item['attention_mask'].to(self.device),
                'pixel_values': item['pixel_values'].to(self.device),
                'labels': item['labels'].to(self.device)
            }
        else:
            # Fallback to on-the-fly processing
            return self._process_single_item(self.ann[index])

In [None]:
# =============================================================================
# 3. SIMPLIFIED COLLATE FUNCTION
# =============================================================================
def fast_collate_fn(batch, pad_token_id=0, min_length=256):
    """Pad a list of samples to a common length and stack them into a batch.

    Args:
        batch: List of sample dicts with ``input_ids``, ``attention_mask``,
            ``labels`` (1-D tensors of equal length per sample) and
            ``pixel_values`` (tensors of identical shape across samples).
            ``None`` entries are dropped.
        pad_token_id: Value used to fill ``input_ids`` past each sample's length.
        min_length: Lower bound on the padded sequence length.

    Returns:
        dict with batched ``input_ids``, ``attention_mask``, ``pixel_values`` and
        ``labels``; an empty dict when no usable sample remains. Padded positions
        get attention mask 0 and label -100.
    """
    # Filter None items (samples that failed to process)
    samples = [item for item in batch if item is not None]
    if not samples:
        return {}

    # Pad to the longest sample, but never below the configured minimum
    longest = max(item['input_ids'].size(0) for item in samples)
    max_len = max(longest, min_length)

    batch_size = len(samples)
    # Allocate on the device the samples already live on
    device = samples[0]['input_ids'].device

    # Pre-allocate tensors already filled with their padding values
    input_ids = torch.full((batch_size, max_len), pad_token_id, dtype=torch.long, device=device)
    attention_mask = torch.zeros((batch_size, max_len), dtype=torch.long, device=device)
    labels = torch.full((batch_size, max_len), -100, dtype=torch.long, device=device)

    # Stack pixel values (assuming same shape)
    pixel_values = torch.stack([item['pixel_values'] for item in samples])

    # Copy each sample into the leading positions of its row
    for row, item in enumerate(samples):
        seq_len = item['input_ids'].size(0)
        input_ids[row, :seq_len] = item['input_ids']
        attention_mask[row, :seq_len] = item['attention_mask']
        labels[row, :seq_len] = item['labels']

    return {
        'input_ids': input_ids,
        'attention_mask': attention_mask,
        'pixel_values': pixel_values,
        'labels': labels
    }


In [None]:
# =============================================================================
# 4. LORA CONFIGURATION
# =============================================================================
def setup_training_optimized():
    """Build the LoRA-wrapped model, the datasets and the Trainer.

    Loads the model and processor via ``load_model_optimized()``, attaches LoRA
    adapters to the listed projection layers and ``lm_head``, pre-processes the
    training and validation sets from Google Drive, and configures a
    ``Trainer`` that uses ``fast_collate_fn`` for batching.

    Relies on the module-level `device` and on Google Drive being mounted at
    ``/content/drive``.

    Returns:
        tuple: ``(trainer, processor)``.
    """
    model, processor = load_model_optimized()
    lora_config = LoraConfig(
        r=8,
        lora_alpha=32,
        target_modules=[
            "model.language_model.layers.33.self_attn.q_proj",
            "model.language_model.layers.33.self_attn.v_proj",
            "model.vision_tower.vision_model.encoder.layers.26.self_attn.q_proj",
            "model.vision_tower.vision_model.encoder.layers.26.self_attn.v_proj",
            "lm_head"
        ],
        lora_dropout=0.05,
        bias="none",
        task_type="CAUSAL_LM"
    )

    print("Applying LoRA configuration...")
    model = get_peft_model(model, lora_config)

    # Count trainable parameters
    total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
    print(f"✅ Total trainable parameters: {total_params:,}")

    # Create datasets with preprocessing. Pass the notebook's `device` so the
    # samples end up on the same device the model was moved to.
    print("Creating training dataset...")
    train_dataset = OptimizedRSNADataset(
        vis_root="/content/drive/MyDrive/FilteredData/filtered_train_images",
        ann_path="/content/drive/MyDrive/FilteredData/augmented_same_size.json",
        processor=processor,
        device=device,
        preprocess=True  # Pre-process everything
    )

    print("Creating validation dataset...")
    val_dataset = OptimizedRSNADataset(
        vis_root="/content/drive/MyDrive/FilteredData/filtered_val_images",
        ann_path="/content/drive/MyDrive/FilteredData/converted_val_augmented_paraphrased.json",
        processor=processor,
        device=device,
        preprocess=True  # Pre-process everything
    )

    # Match mixed precision to the dtype the model is loaded in: bf16 where the
    # GPU supports it, fp16 on other CUDA devices, full precision on CPU.
    cuda_available = torch.cuda.is_available()
    use_bf16 = cuda_available and torch.cuda.is_bf16_supported()
    use_fp16 = cuda_available and not use_bf16

    # Optimized training arguments
    training_args = TrainingArguments(
        output_dir="/content/drive/MyDrive/FilteredData/MEDGEMMA",
        num_train_epochs=1,
        per_device_train_batch_size=4,
        gradient_accumulation_steps=2,
        learning_rate=2e-5,
        weight_decay=0.01,
        logging_steps=10,
        save_steps=50,
        eval_steps=50,
        eval_strategy="steps",
        save_strategy="steps",
        warmup_steps=50,
        bf16=use_bf16,
        fp16=use_fp16,
        # Samples may already live on the GPU (see the dataset's `device`), and
        # such tensors cannot be pinned.
        dataloader_pin_memory=False,
        dataloader_num_workers=0,  # Single worker for simplicity
        remove_unused_columns=False,  # keep pixel_values etc. for the collate fn
        # The string "none" disables all reporters; passing None does not, which
        # is why the recorded run stopped at an interactive wandb login prompt.
        report_to="none",
        # The PEFT wrapper hides the base model's signature, so the Trainer
        # cannot infer the label column on its own.
        label_names=["labels"],
        save_total_limit=1,  # Keep only 1 checkpoint
        disable_tqdm=False,
        group_by_length=False,
    )

    # Create trainer. `processing_class` replaces the deprecated `tokenizer`
    # argument of Trainer.
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=val_dataset,
        data_collator=fast_collate_fn,
        processing_class=processor.tokenizer,
    )

    return trainer, processor


In [None]:
# Mount Google Drive at /content/drive. The dataset paths and output directories
# used by setup_training_optimized() and by the final save step all live under
# /content/drive/MyDrive, so this cell must run before the main execution cell.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# =============================================================================
# 5. MAIN EXECUTION
# =============================================================================
if __name__ == "__main__":
    print("🚀 Starting optimized training setup...")

    # Model, datasets and Trainer are all built by the setup helper.
    trainer, processor = setup_training_optimized()

    print("✅ Setup complete! Starting training...")

    try:
        # Run the fine-tuning loop.
        trainer.train()
        print("🎉 Training completed!")

        # Write the trained model first, then the processor, to the same folder.
        for saver in (trainer.save_model, processor.save_pretrained):
            saver("/content/drive/MyDrive/FilteredData/FINAL_MEDGEMMA")

    except Exception as e:
        # Best effort: report the failure with its traceback and let the cell finish.
        print(f"Training error: {e}")
        import traceback
        traceback.print_exc()

🚀 Starting optimized training setup...
Loading processor...


processor_config.json:   0%|          | 0.00/70.0 [00:00<?, ?B/s]

chat_template.jinja:   0%|          | 0.00/1.53k [00:00<?, ?B/s]

preprocessor_config.json:   0%|          | 0.00/570 [00:00<?, ?B/s]

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


tokenizer_config.json:   0%|          | 0.00/1.16M [00:00<?, ?B/s]

tokenizer.model:   0%|          | 0.00/4.69M [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/33.4M [00:00<?, ?B/s]

added_tokens.json:   0%|          | 0.00/35.0 [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/662 [00:00<?, ?B/s]

Loading model with optimizations...


config.json:   0%|          | 0.00/1.61k [00:00<?, ?B/s]

model.safetensors.index.json:   0%|          | 0.00/90.6k [00:00<?, ?B/s]

Fetching 2 files:   0%|          | 0/2 [00:00<?, ?it/s]

model-00002-of-00002.safetensors:   0%|          | 0.00/3.64G [00:00<?, ?B/s]

model-00001-of-00002.safetensors:   0%|          | 0.00/4.96G [00:00<?, ?B/s]

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

generation_config.json:   0%|          | 0.00/173 [00:00<?, ?B/s]

Applying LoRA configuration...




✅ Total trainable parameters: 2,220,544
Creating training dataset...
🔄 Pre-processing dataset (this happens once)...
Processed 1/630 items...
Processed 11/630 items...
Processed 21/630 items...
Processed 31/630 items...
Processed 41/630 items...
Processed 51/630 items...
Processed 61/630 items...
Processed 71/630 items...
Processed 81/630 items...
Processed 91/630 items...
Processed 101/630 items...
Processed 111/630 items...
Processed 121/630 items...
Processed 131/630 items...
Processed 141/630 items...
Processed 151/630 items...
Processed 161/630 items...
Processed 171/630 items...
Processed 181/630 items...
Processed 191/630 items...
Processed 201/630 items...
Processed 211/630 items...
Processed 221/630 items...
Processed 231/630 items...
Processed 241/630 items...
Processed 251/630 items...
Processed 261/630 items...
Processed 271/630 items...
Processed 281/630 items...
Processed 291/630 items...
Processed 301/630 items...
Processed 311/630 items...
Processed 321/630 items...
Pro

  trainer = Trainer(
No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


✅ Setup complete! Starting training...




<IPython.core.display.Javascript object>

[34m[1mwandb[0m: Logging into wandb.ai. (Learn how to deploy a W&B server locally: https://wandb.me/wandb-server)
[34m[1mwandb[0m: You can find your API key in your browser here: https://wandb.ai/authorize?ref=models
wandb: Paste an API key from your profile and hit enter:

 ··········


[34m[1mwandb[0m: No netrc file found, creating one.
[34m[1mwandb[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc
[34m[1mwandb[0m: Currently logged in as: [33mmoreenmohsen[0m ([33mmoreenmohsen-fci[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


Step,Training Loss,Validation Loss
50,8.0102,5.742097




🎉 Training completed!




In [None]:
#=============================================================================
# 6. ADDITIONAL SPEEDUP TIPS
# =============================================================================

# TIP 1: Monitor GPU memory
def print_gpu_memory():
    """Print allocated and reserved CUDA memory in GiB; do nothing without CUDA."""
    if not torch.cuda.is_available():
        return

    bytes_per_gib = 1024**3
    # "Cached" is the memory reserved by the allocator, as reported by memory_reserved().
    allocated = torch.cuda.memory_allocated() / bytes_per_gib
    cached = torch.cuda.memory_reserved() / bytes_per_gib
    print(f"GPU Memory - Allocated: {allocated:.2f}GB, Cached: {cached:.2f}GB")

# TIP 2: Clear cache periodically
def clear_cache():
    """Release unreferenced Python objects, then return cached CUDA memory.

    Safe to call on machines without a CUDA device; in that case only the
    garbage collector runs.
    """
    import gc

    # Collect first: tensors that are only kept alive by reference cycles must be
    # destroyed before their memory can be handed back by empty_cache().
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()

# TIP 3: Use smaller batches if OOM
# Reduce per_device_train_batch_size to 2 or 1 if you get OOM errors