In [1]:
import io
import os
import torch
from PIL import Image
from torchvision import transforms
from torch.utils.data import Dataset
from transformers import PaliGemmaProcessor
from datasets import load_dataset, Dataset

# Path to the JSON annotation file used as the training split.
train_json_path = "Train data comprised 60% of the dataset.json"

# Load the file with the "json" builder and expose it under the "train" split,
# which later statements access as dataset["train"].
dataset = load_dataset("json", data_files={"train": train_json_path})

def transform_format(example):
    """Flatten one conversation-style record into question/answer fields.

    Args:
        example: Mapping with "id", "image" and "conversations" keys, where
            "conversations" is a sequence whose first two items each carry a
            "value" entry (the first is used as the question, the second as
            the answer).

    Returns:
        dict: ``{"id", "image", "Question", "Answer"}`` built from the record.
    """
    # Copy the pass-through fields first, then pull the two turns' text.
    flat = {key: example[key] for key in ("id", "image")}
    turns = example["conversations"]
    flat["Question"] = turns[0]["value"]
    flat["Answer"] = turns[1]["value"]
    return flat

# Apply transform_format to every example of the loaded dataset.
dataset = dataset.map(transform_format)

def load_images_as_pil(dataset, image_root="./Demoface"):
    """Load every example's image from disk as an RGB PIL image.

    Examples whose image cannot be opened are reported and skipped, so the
    returned list never contains ``None`` images (which would otherwise fail
    later when the images are encoded to bytes).

    Args:
        dataset: Iterable of mappings with "image" (file name relative to
            ``image_root``), "Question" and "Answer" keys.
        image_root: Directory prefix prepended to each image file name.
            Defaults to the previously hard-coded "./Demoface".

    Returns:
        list[dict]: One ``{"image", "Question", "Answer"}`` dict per example
        whose image was loaded successfully, in input order.
    """
    updated_entries = []
    skipped = 0
    for example in dataset:
        image_path = f"{image_root}/{example['image']}"
        try:
            # The context manager closes the underlying file; convert()
            # returns an independent, fully loaded image that outlives it.
            with Image.open(image_path) as raw_image:
                pil_image = raw_image.convert("RGB")
        except Exception as e:
            # Best-effort loading: report the failure and drop the example
            # instead of propagating a None image downstream.
            print(f"Error loading image {image_path}: {e}")
            skipped += 1
            continue

        updated_entries.append({
            "image": pil_image,
            "Question": example["Question"],
            "Answer": example["Answer"],
        })

    if skipped:
        print(f"Skipped {skipped} example(s) with unreadable images")
    return updated_entries

# Rebind `dataset` to the list of dicts returned by load_images_as_pil for the
# "train" split (note: the name no longer refers to the loaded split mapping).
dataset = load_images_as_pil(dataset["train"])

def pil_image_to_bytes(image):
    """Encode a PIL image as JPEG and return the encoded bytes.

    JPEG cannot store every PIL mode (for example RGBA, LA or palette
    images), and saving such an image raises ``OSError``. Images in a mode
    the JPEG writer does not accept are therefore converted to RGB first;
    images already in a JPEG-compatible mode are encoded unchanged.

    Args:
        image: A PIL image (any object exposing ``mode``, ``convert`` and
            ``save`` in the PIL style).

    Returns:
        bytes: The JPEG-encoded image data.
    """
    # Modes Pillow's JPEG encoder can write directly.
    jpeg_modes = ("1", "L", "RGB", "RGBX", "CMYK", "YCbCr")
    if image.mode not in jpeg_modes:
        image = image.convert("RGB")

    buf = io.BytesIO()
    image.save(buf, format='JPEG')
    return buf.getvalue()

def bytes_to_pil_image(byte_data):
    """Open encoded image bytes as a PIL image.

    Args:
        byte_data: Encoded image data as a bytes-like object.

    Returns:
        The image object returned by ``Image.open`` for an in-memory stream
        wrapping ``byte_data``.
    """
    stream = io.BytesIO(byte_data)
    return Image.open(stream)

# Column-oriented view of the loaded examples: one list per field, with each
# image stored as the bytes produced by pil_image_to_bytes.
data_dict = dict(
    image=[pil_image_to_bytes(row["image"]) for row in dataset],
    Question=[row["Question"] for row in dataset],
    Answer=[row["Answer"] for row in dataset],
)

# Rebuild `dataset` as a datasets.Dataset from the column lists and show it.
dataset = Dataset.from_dict(data_dict)

print(dataset)

  from .autonotebook import tqdm as notebook_tqdm


Dataset({
    features: ['image', 'Question', 'Answer'],
    num_rows: 18144
})


In [2]:
import io
import torch
from PIL import Image
from transformers import (
    PaliGemmaForConditionalGeneration, AutoProcessor,
    TrainingArguments, Trainer, TrainerCallback, TrainerState, TrainerControl,
    BitsAndBytesConfig
)
from peft import get_peft_model, LoraConfig
from datasets import load_dataset
from tqdm.auto import tqdm

def bytes_to_pil_image(byte_data):
    """Open encoded image bytes as a PIL image.

    Same helper as the one defined in the data-preparation cell; this later
    definition is the one in effect once this cell has run.

    Args:
        byte_data: Encoded image data as a bytes-like object.

    Returns:
        The image object returned by ``Image.open`` for an in-memory stream
        wrapping ``byte_data``.
    """
    stream = io.BytesIO(byte_data)
    return Image.open(stream)

def collate_fn(examples):
    """Turn a list of dataset rows into one processor batch for training.

    Each row must provide "Question", "Answer" and "image" (encoded image
    bytes). Questions become the text prompts, answers become the suffixes,
    and images are decoded and converted to RGB. Relies on the module-level
    ``processor`` being defined before the first call.

    NOTE(review): the prompt contains literal "<image> <bos>" / "<eos>"
    markers and the suffix a literal " <eos>"; whether the processor also
    inserts its own special tokens depends on the installed transformers
    version - confirm the resulting token sequence is as intended.
    NOTE(review): the batch is cast with torch.bfloat16 while the training
    configuration in this notebook uses fp16 - confirm the dtype choice.

    Args:
        examples: List of row mappings as described above.

    Returns:
        The processor output after ``.to(torch.bfloat16)``.
    """
    prompts = [f"<image> <bos> answer {row['Question']} <eos>" for row in examples]
    targets = [f"{row['Answer']} <eos>" for row in examples]

    pictures = []
    for row in examples:
        decoded = bytes_to_pil_image(row["image"])
        pictures.append(decoded.convert("RGB"))

    batch = processor(
        text=prompts,
        images=pictures,
        suffix=targets,
        return_tensors="pt",
        padding="longest",
        truncation=True,
    )
    return batch.to(torch.bfloat16)


In [None]:
    
# Load the model with 4-bit quantization (double quantization enabled,
# float16 compute dtype).
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_compute_dtype=torch.float16,
    bnb_4bit_use_double_quant=True,
)
# Local checkpoint directory the model weights are loaded from.
checkpoint_dir = "./paligemma-3b-checkpoint"

checkpoint_path =checkpoint_dir  # same directory under a second name
model = PaliGemmaForConditionalGeneration.from_pretrained(
    checkpoint_path,
    quantization_config=bnb_config,
    device_map="auto",
)
# NOTE(review): the processor is loaded from the "google/paligemma-3b-pt-224"
# identifier rather than from checkpoint_path, and PaliGemmaProcessor is
# imported in the first cell (this cell's import group only brings in
# AutoProcessor) - confirm both are intended.
processor = PaliGemmaProcessor.from_pretrained("google/paligemma-3b-pt-224")
# LoRA configuration: rank-8 adapters on the q_proj and v_proj modules.
lora_config = LoraConfig(
    r=8, lora_alpha=32, target_modules=["q_proj", "v_proj"],
    lora_dropout=0.1, bias="none", task_type="CAUSAL_LM"
)
# Wrap the quantized model with the LoRA adapters and report how many
# parameters are trainable.
model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# Training arguments.
#
# Warmup: only `warmup_steps` is set. The original configuration also passed
# `warmup_ratio=0.3`, but a positive `warmup_steps` takes precedence over
# `warmup_ratio`, so the ratio never had any effect. It is removed here so the
# configuration states the schedule that is actually used. To switch to
# proportional warmup instead, drop `warmup_steps` and set `warmup_ratio`.
args = TrainingArguments(
    num_train_epochs=75,
    remove_unused_columns=False,  # collate_fn needs the raw dataset columns
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    warmup_steps=2,
    learning_rate=1e-4,
    weight_decay=1e-6,
    adam_beta2=0.999,
    logging_steps=150,
    optim="adamw_hf",
    save_strategy="steps",
    save_steps=1000,
    save_total_limit=1,
    output_dir="paligemma_lora_finetuned_45_rawpng",
    fp16=True,
    dataloader_pin_memory=True,
    push_to_hub=False,  # Set to True if pushing to Hugging Face Hub
    lr_scheduler_type='cosine',  # Cosine learning-rate decay after warmup
)


import csv
from tqdm import tqdm
from transformers import TrainerCallback

class ProgressBarCallback(TrainerCallback):
    """Trainer callback that shows a tqdm progress bar and records losses.

    Every logged training loss is appended to a CSV file as a
    ``(step, loss)`` row. Rows are written from ``on_log``, which receives
    the freshly logged values directly. Inferring logging steps inside
    ``on_step_end`` (e.g. with ``global_step % (logging_steps + 1)``) drifts
    relative to the real logging cadence and silently drops entries on long
    runs.

    Args:
        csv_filename: Path of the CSV file to write. The file is truncated
            and given a ``step,loss`` header when the callback is created.
    """

    def __init__(self, csv_filename="Paligemma_loss.csv"):
        self.progress_bar = None
        self.csv_filename = csv_filename

        # Create CSV file and write header
        with open(self.csv_filename, mode='w', newline='') as file:
            writer = csv.writer(file)
            writer.writerow(["step", "loss"])  # Column names

    def on_train_begin(self, args, state, control, **kwargs):
        """Create the progress bar sized to the total number of steps."""
        if self.progress_bar is None:
            self.progress_bar = tqdm(total=state.max_steps, desc="Training", dynamic_ncols=True)

    def on_step_end(self, args, state, control, **kwargs):
        """Advance the progress bar by one optimizer step."""
        if self.progress_bar is not None:
            self.progress_bar.update(1)
            self.progress_bar.refresh()  # Force update to prevent misalignment

    def on_log(self, args, state, control, logs=None, **kwargs):
        """Append the logged training loss, if any, to the CSV file."""
        # Log events without a "loss" entry (e.g. end-of-training summaries)
        # are ignored, matching the original "loss is not None" filter.
        loss = (logs or {}).get("loss")
        if loss is None:
            return

        with open(self.csv_filename, mode='a', newline='') as file:
            writer = csv.writer(file)
            writer.writerow([state.global_step, loss])

    def on_train_end(self, args, state, control, **kwargs):
        """Close the progress bar and forget it so a new run gets a new bar."""
        if self.progress_bar is not None:
            self.progress_bar.close()
            # Without the reset, a second train() call would keep updating
            # the already-closed bar.
            self.progress_bar = None


# Build the Trainer from the LoRA-wrapped model, the prepared dataset and the
# custom collate function. ProgressBarCallback() is created with its default
# CSV file name, so constructing it (re)creates "Paligemma_loss.csv".
trainer = Trainer(
    model=model,
    train_dataset=dataset,
    data_collator=collate_fn,
    args=args,
    callbacks=[ProgressBarCallback()],
)

# Run training.
trainer.train()


In [None]:
import os

# Directory that receives the saved model and processor files.
save_path = "./Paligemma_fine_tuned_75"  # Define your save path

# Ensure the directory exists (no error if it is already there).
os.makedirs(save_path, exist_ok=True)

# Save the trainer's model and the processor side by side in save_path.
trainer.save_model(save_path)
processor.save_pretrained(save_path)

print(f"✅ Model saved to {save_path}")