In [1]:
!pip install -q -U git+https://github.com/huggingface/transformers.git
!pip install -q -U datasets accelerate
!pip install -q -U accelerate
!pip install -q -U peft
!pip install -q -U bitsandbytes

# Load Dataset

I had already downloaded DocVQA (60+ GB) and preprocessed it: made it PaliGemma-input-ready and removed unnecessary columns.

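*(A Hugging Face access token was pasted here and has been redacted. Never store tokens in a notebook: revoke the exposed token and authenticate with `notebook_login()` instead.)*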

In [2]:
# Enter your Hugging Face token to authenticate. It must be a token with write
# permission, because the fine-tuned model is pushed to the Hub later on.

from huggingface_hub import notebook_login
notebook_login()

In [3]:
from datasets import load_dataset

# Load the dataset repo from the Hugging Face Hub (presumably the preprocessed
# DocVQA subset — TODO confirm). The result is indexed by split name later on.
train_ds = load_dataset("abhishekvidhate/DocQVA_small")

In [4]:
train_ds

In [5]:
# Pick the 'train_small' split; this is the dataset handed to the Trainer.
train_small = train_ds['train_small']

In [6]:
train_small

In [7]:
# Show the split as a pandas DataFrame.
# NOTE(review): this converts the whole split (including the image column)
# just for display — consider previewing only a few rows instead.
train_small.to_pandas()

# Create model inputs: process the dataset to make it PaliGemma-ready, i.e. build the data collator

In [8]:
from transformers import PaliGemmaProcessor

# Base checkpoint id; the same id is reused below when the model is loaded.
model_id = "google/paligemma-3b-pt-224"
# Processor for that checkpoint; collate_fn calls it with text, images and suffix.
processor = PaliGemmaProcessor.from_pretrained(model_id)

In [9]:
import torch
# Device that collate_fn moves every collated batch to.
device = "cuda"

# Token id the tokenizer assigns to the "<image>" token.
# NOTE(review): not referenced by any other visible cell — confirm it is still needed.
image_token = processor.tokenizer.convert_tokens_to_ids("<image>")
def collate_fn(examples):
    """Turn a list of dataset rows into one processor batch on `device`.

    Each row must provide "question" (str), "answers" (iterable of str) and
    "image" (an object with `.convert("RGB")`). The prompt is the question
    prefixed with "answer "; the answers are joined with "; " and passed to
    the processor as `suffix`.

    Returns:
        The processor output for the whole batch, moved to `device`.
    """
    prompts, targets, pictures = [], [], []
    for sample in examples:
        prompts.append("answer " + sample["question"])
        targets.append("; ".join(sample["answers"]))
        pictures.append(sample["image"].convert("RGB"))

    batch = processor(
        text=prompts,
        images=pictures,
        suffix=targets,
        return_tensors="pt",
        padding="longest",
        tokenize_newline_separately=False,
    )
    return batch.to(device)

# Load Model

I'm loading the model in a 4-bit configuration so that I can easily do LoRA / QLoRA fine-tuning.

In [10]:
from transformers import BitsAndBytesConfig
from peft import get_peft_model, LoraConfig
from transformers import PaliGemmaForConditionalGeneration

# 4-bit NF4 quantisation settings used when loading the base model.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    # The keyword is `bnb_4bit_compute_dtype`. The previous spelling
    # (`bnb_4bit_compute_type`) is not a BitsAndBytesConfig option, so the
    # intended float16 compute dtype was never applied.
    bnb_4bit_compute_dtype=torch.float16,
)

# LoRA adapter settings: rank-4 adapters on the listed target modules.
# NOTE(review): "out_proj" / "fc1" / "fc2" look like vision-encoder layer names
# rather than language-model ones — confirm against print_trainable_parameters().
# Alternatives tried earlier: ["q_proj", "k_proj"] and ["q_proj", "v_proj", "k_proj"].
lora_config = LoraConfig(
    r=4,
    target_modules=["self_attn.out_proj", "fc1", "fc2"],
    task_type="CAUSAL_LM",
)

# Load the quantised base model, wrap it with the LoRA adapters and report
# how many parameters are trainable.
model = PaliGemmaForConditionalGeneration.from_pretrained(
    model_id,
    quantization_config=bnb_config,
    device_map="auto",
)
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# Training arguments for fine-tuning

In [18]:
from transformers import TrainingArguments

# Training configuration consumed by the Trainer.
args = TrainingArguments(
    num_train_epochs=2,
    # Keep every dataset column: collate_fn reads "question", "answers" and
    # "image" directly from the raw rows.
    remove_unused_columns=False,
    # 2 samples per step x 4 accumulation steps = 8 samples per optimizer update (per device).
    per_device_train_batch_size=2,
    gradient_accumulation_steps=4,
    warmup_steps=2,
    learning_rate=2e-5,
    weight_decay=1e-6,
    adam_beta2=0.999,
    logging_steps=100,
    # "adamw_hf" selected transformers' own AdamW implementation, which is
    # deprecated in favour of torch.optim.AdamW and may be unavailable in the
    # main-branch transformers this notebook installs; use the PyTorch optimizer.
    optim="adamw_torch",
    save_strategy="steps",
    save_steps=500,
    push_to_hub=True,
    save_total_limit=1,
    output_dir="Abhishek-PaliGemma-FT",
    # collate_fn already returns tensors on `device`, so pinning is disabled.
    dataloader_pin_memory=False,
    report_to=[],  # Disable Wandb integration, no tracing to wandb
)

In [19]:
from transformers import Trainer

# Wire the LoRA-wrapped model, the 'train_small' split, the custom collator and
# the training arguments together. No evaluation dataset is passed.
trainer = Trainer(
        model=model,
        train_dataset=train_small ,
        # collate_fn builds each batch from raw rows via the processor.
        data_collator=collate_fn,
        args=args
        )

In [20]:
# Run fine-tuning with the configuration above.
trainer.train()

# Pushing the model to Hugging Face for future inference

In [21]:
# Upload the trained model to the Hugging Face Hub (requires the login performed earlier).
trainer.push_to_hub()

# Simple inference with my fine-tuned PaliGemma

Next step: build a Streamlit app deployed to the cloud, so that users won't have to run Python code again and again, and so that secrets stay private.

In [26]:
from transformers import AutoProcessor, PaliGemmaForConditionalGeneration

# Hub repo to load for inference (presumably the one created by
# trainer.push_to_hub() — TODO confirm). Note that this cell rebinds
# `model_id`, `model` and `processor` from the training cells.
model_id = "abhishekvidhate/Abhishek-PaliGemma-FT"
model = PaliGemmaForConditionalGeneration.from_pretrained(model_id)
# The processor is loaded from the base checkpoint, not from the fine-tuned repo.
processor = AutoProcessor.from_pretrained("google/paligemma-3b-pt-224")

In [33]:
from PIL import Image
import requests


prompt = "What is behind the cat?"
image_file = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cat.png?download=true"

# Fail fast on HTTP errors or a hung connection instead of handing an error
# page to PIL (which would surface as a confusing "cannot identify image" error).
response = requests.get(image_file, stream=True, timeout=30)
response.raise_for_status()
raw_image = Image.open(response.raw)

In [37]:
# Pass text and image by keyword: the positional order of the processor's
# arguments has changed between transformers releases.
inputs = processor(text=prompt, images=raw_image.convert("RGB"), return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=20)

# Decode only the newly generated tokens. Slicing the decoded string by
# len(prompt) assumed the decoded text starts with exactly the prompt.
prompt_length = inputs["input_ids"].shape[-1]
print(processor.decode(output[0][prompt_length:], skip_special_tokens=True))


In [None]:
!pip install matplotlib 

In [44]:
from PIL import Image
import requests
from io import BytesIO
import matplotlib.pyplot as plt
import os

def load_image(image_path_or_url):
    """Load a PIL image from a local file path or an HTTP(S) URL.

    Args:
        image_path_or_url: Path of an existing local file, or a URL starting
            with "http://" or "https://".

    Returns:
        The image opened with PIL.

    Raises:
        ValueError: If the input is neither an existing path nor an HTTP(S) URL.
        requests.HTTPError: If the download answers with an error status.
    """
    if os.path.exists(image_path_or_url):
        # Case 1: local image file
        return Image.open(image_path_or_url)

    if image_path_or_url.startswith(("http://", "https://")):
        # Case 2: image URL. Use a timeout so a stalled server cannot hang the
        # notebook, and check the status so an error page is never parsed as an image.
        response = requests.get(image_path_or_url, timeout=30)
        response.raise_for_status()
        return Image.open(BytesIO(response.content))

    raise ValueError("Unsupported image input. Please provide a valid local file path or image URL.")

# Example usage:
# image_path_or_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cat.png?download=true"
# or
# Either a local file path or an image URL works here, e.g.
# "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cat.png?download=true"
image_path_or_url = "/kaggle/input/testing-photos/yuta.jpg"  # Replace with your local image path

prompt = "is this a boy or a girl?"

# Load the image once (it was previously loaded twice) and display it.
raw_image = load_image(image_path_or_url)
plt.imshow(raw_image)
plt.axis('off')  # Turn off axis labels
plt.show()

# Pass text and image by keyword: the positional order of the processor's
# arguments has changed between transformers releases.
inputs = processor(text=prompt, images=raw_image.convert("RGB"), return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=20)

# Decode only the newly generated tokens rather than slicing the decoded
# string by len(prompt), which assumed the text starts with exactly the prompt.
prompt_length = inputs["input_ids"].shape[-1]
print(processor.decode(output[0][prompt_length:], skip_special_tokens=True))


In [46]:
# Either an image URL or a local file path works here, e.g.
# "/kaggle/input/testing-photos/yuta.jpg"
image_path_or_url = "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/cat.png?download=true"

prompt = "what is cat doing?"

# Load the image once (it was previously loaded twice) and display it.
raw_image = load_image(image_path_or_url)
plt.imshow(raw_image)
plt.axis('off')  # Turn off axis labels
plt.show()

# Pass text and image by keyword: the positional order of the processor's
# arguments has changed between transformers releases.
inputs = processor(text=prompt, images=raw_image.convert("RGB"), return_tensors="pt")
output = model.generate(**inputs, max_new_tokens=20)

# Decode only the newly generated tokens rather than slicing the decoded
# string by len(prompt), which assumed the text starts with exactly the prompt.
prompt_length = inputs["input_ids"].shape[-1]
print(processor.decode(output[0][prompt_length:], skip_special_tokens=True))