In [1]:
!pip install -U transformers datasets peft accelerate bitsandbytes huggingface_hub safetensors

Collecting transformers
  Downloading transformers-4.51.2-py3-none-any.whl.metadata (38 kB)
Collecting datasets
  Downloading datasets-3.5.0-py3-none-any.whl.metadata (19 kB)
Collecting peft
  Downloading peft-0.15.1-py3-none-any.whl.metadata (13 kB)
Collecting accelerate
  Downloading accelerate-1.6.0-py3-none-any.whl.metadata (19 kB)
Collecting bitsandbytes
  Downloading bitsandbytes-0.45.5-py3-none-manylinux_2_24_x86_64.whl.metadata (5.0 kB)
Collecting huggingface_hub
  Downloading huggingface_hub-0.30.2-py3-none-any.whl.metadata (13 kB)
Collecting dill<0.3.9,>=0.3.0 (from datasets)
  Downloading dill-0.3.8-py3-none-any.whl.metadata (10 kB)
Collecting xxhash (from datasets)
  Downloading xxhash-3.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl.metadata (12 kB)
Collecting multiprocess<0.70.17 (from datasets)
  Downloading multiprocess-0.70.16-py311-none-any.whl.metadata (7.2 kB)
Collecting fsspec<=2024.12.0,>=2023.1.0 (from fsspec[http]<=2024.12.0,>=2023.1.0->datasets)

## Load the dataset I created from hugging face

In [38]:
from datasets import load_dataset

ds = load_dataset("SiyunHE/medical-pilagemma-lora", split="train")

print(ds[0])

ds = ds.train_test_split(test_size=0.1)
train_ds = ds["train"]
val_ds = ds["test"]
print(train_ds[0])
print(val_ds[0])

{'image': <PIL.Image.Image image mode=RGB size=4032x2412 at 0x7F9830162790>, 'question': "I've noticed my foot swelling up quite a bit, should I be concerned about this?", 'answer': "Swelling in the foot can be caused by a variety of conditions such as an injury, infection, arthritis, or even poor circulation. It's important to check for other symptoms like redness, warmth, or difficulty walking, and I would recommend elevating your foot and applying a cold compress to help with the swelling while scheduling an appointment to determine the underlying cause."}
{'image': <PIL.JpegImagePlugin.JpegImageFile image mode=RGB size=1500x1101 at 0x7F931392B690>, 'question': "I've noticed some inflammation in my eye and I'm worried it might be something serious. What could be causing this, and should I get it checked out soon?", 'answer': "It sounds like you might be experiencing conjunctivitis, which is the inflammation of the lining of your eye and can be caused by infections, allergies, or irr

## Load PaliGemma from hugging face

In [39]:
from huggingface_hub import login
from transformers import AutoProcessor, AutoModelForVision2Seq
import os

# Hugging Face Login
login("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")

from transformers import PaliGemmaProcessor
model_id = "google/paligemma-3b-pt-224"
processor = PaliGemmaProcessor.from_pretrained(model_id)


## data embedding

In [40]:
import torch
device = "cuda"

image_token = processor.tokenizer.convert_tokens_to_ids("<image>")
def collate_fn(examples):
  texts = ["answer " + example["question"] for example in examples]
  labels= [example['answer'] for example in examples]
  images = [example["image"].convert("RGB") for example in examples]
  tokens = processor(text=texts, images=images, suffix=labels,
                    return_tensors="pt", padding="longest")

  tokens = tokens.to(torch.bfloat16).to(device)
  return tokens

## load the model directly.
Freeze the image encoder and the projector, and only fine-tune the decoder.

In [41]:
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, torch_dtype=torch.bfloat16).to(device)

for param in model.vision_tower.parameters():
    param.requires_grad = False

for param in model.multi_modal_projector.parameters():
    param.requires_grad = True


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

## Set up Lora

In [42]:
from transformers import BitsAndBytesConfig
from peft import get_peft_model, LoraConfig

bnb_config = BitsAndBytesConfig(
        load_in_4bit=True,
        bnb_4bit_quant_type="nf4",
        bnb_4bit_compute_type=torch.bfloat16
)

lora_config = LoraConfig(
    r=8,
    target_modules=["q_proj", "o_proj", "k_proj", "v_proj", "gate_proj", "up_proj", "down_proj"],
    task_type="CAUSAL_LM",
)
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id, quantization_config=bnb_config, device_map={"":0})
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()
#trainable params: 11,298,816 || all params: 2,934,634,224 || trainable%: 0.38501616002417344


Loading checkpoint shards:   0%|          | 0/3 [00:00<?, ?it/s]

trainable params: 11,298,816 || all params: 2,934,765,296 || trainable%: 0.3850


In [None]:
# Hugging Face Login for my model space
login("SIYUN-UPLOAD-DATA-HF-TOKEN")

## Train the model

In [59]:
from transformers import TrainingArguments, Trainer
args=TrainingArguments(
            num_train_epochs=10,
            remove_unused_columns=False,
            per_device_train_batch_size=8,
            gradient_accumulation_steps=4,
            warmup_steps=2,
            learning_rate=2e-5,
            weight_decay=1e-6,
            adam_beta2=0.999,
            logging_steps=100,
            optim="adamw_torch",
            save_strategy="steps",
            save_steps=500,
            push_to_hub=True,
            hub_model_id="SiyunHE/medical-pilagemma-lora",
            save_total_limit=1,
            bf16=True,
            report_to=["tensorboard"],
            dataloader_pin_memory=False
        )


trainer = Trainer(
        model=model,
        train_dataset=train_ds,
        eval_dataset=val_ds,
        data_collator=collate_fn,
        args=args
        )


No label_names provided for model class `PeftModelForCausalLM`. Since `PeftModel` hides base models input arguments, if label_names is not given, label_names can't be set automatically within `Trainer`. Note that empty label_names list will be used instead.


## Start training

In [49]:
trainer.train()
metrics = trainer.evaluate()
print(metrics)

You are passing both `text` and `images` to `PaliGemmaProcessor`. The processor expects special image tokens in the text, as many tokens as there are images per each text. It is recommended to add `<image>` tokens in the very beginning of your text. For this call, we will infer how many images each text has and add special tokens.
You are passing both `text` and `images` to `PaliGemmaProcessor`. The processor expects special image tokens in the text, as many tokens as there are images per each text. It is recommended to add `<image>` tokens in the very beginning of your text. For this call, we will infer how many images each text has and add special tokens.
You are passing both `text` and `images` to `PaliGemmaProcessor`. The processor expects special image tokens in the text, as many tokens as there are images per each text. It is recommended to add `<image>` tokens in the very beginning of your text. For this call, we will infer how many images each text has and add special tokens.
Y

Step,Training Loss


You are passing both `text` and `images` to `PaliGemmaProcessor`. The processor expects special image tokens in the text, as many tokens as there are images per each text. It is recommended to add `<image>` tokens in the very beginning of your text. For this call, we will infer how many images each text has and add special tokens.
You are passing both `text` and `images` to `PaliGemmaProcessor`. The processor expects special image tokens in the text, as many tokens as there are images per each text. It is recommended to add `<image>` tokens in the very beginning of your text. For this call, we will infer how many images each text has and add special tokens.
You are passing both `text` and `images` to `PaliGemmaProcessor`. The processor expects special image tokens in the text, as many tokens as there are images per each text. It is recommended to add `<image>` tokens in the very beginning of your text. For this call, we will infer how many images each text has and add special tokens.
Y

{'eval_loss': 2.654895067214966, 'eval_runtime': 0.6221, 'eval_samples_per_second': 12.859, 'eval_steps_per_second': 1.607, 'epoch': 6.888888888888889}


In [57]:
for log in trainer.state.log_history:
      print(log)

{'train_runtime': 53.4695, 'train_samples_per_second': 12.905, 'train_steps_per_second': 0.374, 'total_flos': 2553371829120096.0, 'train_loss': 8.127518463134766, 'epoch': 6.888888888888889, 'step': 20}
{'eval_loss': 2.654895067214966, 'eval_runtime': 0.6221, 'eval_samples_per_second': 12.859, 'eval_steps_per_second': 1.607, 'epoch': 6.888888888888889, 'step': 20}


reference: https://huggingface.co/blog/paligemma