In [2]:
# Install necessary libraries
!pip install -q datasets
!pip install -q git+https://github.com/huggingface/transformers
!pip install -q bitsandbytes sentencepiece accelerate loralib
!pip install -q -U git+https://github.com/huggingface/peft.git



  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone
  Installing build dependencies ... [?25l[?25hdone
  Getting requirements to build wheel ... [?25l[?25hdone
  Preparing metadata (pyproject.toml) ... [?25l[?25hdone


In [3]:
# Import required libraries
import torch
import torchvision.transforms as transforms
from datasets import load_dataset
from peft import LoraConfig, PeftModel, get_peft_model
from PIL import Image
from transformers import IdeficsForVisionText2Text, AutoProcessor, Trainer, TrainingArguments, BitsAndBytesConfig



In [4]:
# Use the GPU when CUDA is available; otherwise run on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

# Hub checkpoint shared by the processor and the model loaded below.
model_name = "HuggingFaceM4/idefics-9b"



In [None]:
# The processor (tokenizer + image processor) comes from the same checkpoint.
processor = AutoProcessor.from_pretrained(model_name)

# 4-bit NF4 weights with nested ("double") quantization, fp16 compute.
# Modules listed in llm_int8_skip_modules are excluded from conversion.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_use_double_quant=True,
    bnb_4bit_compute_dtype=torch.float16,
    llm_int8_skip_modules=["lm_head", "embed_tokens"],
)

# device_map="auto" lets accelerate decide where each layer is placed.
model = IdeficsForVisionText2Text.from_pretrained(
    model_name,
    device_map="auto",
    quantization_config=quant_config,
)



In [None]:
# Function for inference
def generate_caption(model, processor, prompt, max_tokens=50, device=None):
    """Generate text for ``prompt``, print it and return it.

    Args:
        model: a model exposing ``generate`` (e.g. the IDEFICS model loaded above).
        processor: the processor paired with ``model``. Its ``tokenizer`` is used
            to look up the ``</s>`` end-of-sequence id and the ids of the image
            placeholder tokens, which are banned from the generated text.
        prompt: passed unchanged to ``processor`` (for IDEFICS presumably a list
            of prompts, each a list mixing images/URLs and text — confirm).
        max_tokens: maximum number of newly generated tokens.
        device: where to place the encoded inputs; defaults to ``model.device``.

    Returns:
        The decoded text of the first generated sequence, with special tokens
        skipped. The same text is also printed.
    """
    tokenizer = processor.tokenizer
    eos_id = tokenizer.convert_tokens_to_ids("</s>")
    # Keep the image placeholder tokens out of the generated text.
    banned_tokens = tokenizer(["<image>", "<fake_token_around_image>"], add_special_tokens=False).input_ids

    # Default to the model's own device rather than a notebook-level global, so
    # the helper works however the model was placed.
    target_device = model.device if device is None else device
    inputs = processor(prompt, return_tensors="pt").to(target_device)
    # No beam-search settings are passed, so the beam-only `early_stopping`
    # flag is intentionally omitted.
    output_ids = model.generate(
        **inputs,
        eos_token_id=[eos_id],
        bad_words_ids=banned_tokens,
        max_new_tokens=max_tokens,
    )
    output_text = processor.batch_decode(output_ids, skip_special_tokens=True)[0]
    print(output_text)
    return output_text



In [None]:
# Preprocessing helper for images
def ensure_rgb(image):
    """Return an RGB version of ``image``, flattening any transparency onto white.

    An image already in "RGB" mode is returned unchanged (the same object).
    Any other mode is converted to RGBA and alpha-composited over a white RGBA
    canvas of the same size, and the composite is then converted to RGB.
    """
    if image.mode != "RGB":
        with_alpha = image.convert("RGBA")
        white_canvas = Image.new("RGBA", with_alpha.size, (255, 255, 255))
        image = Image.alpha_composite(white_canvas, with_alpha).convert("RGB")
    return image

# Dataset transformation
def preprocess_batch(batch):
    """Turn a batch of dataset rows into IDEFICS training inputs.

    Registered with ``set_transform``, so ``batch`` maps column names to lists
    of values. Each example becomes the prompt
    ``[image, "Describe this image: <caption up to the first '.'></s>"]``.
    The prompts are passed to the processor together with an augmentation /
    normalization pipeline built from the processor's image settings.

    Args:
        batch: mapping of column name -> list of values. The image comes from
            ``"image_url"`` when that column exists, otherwise from ``"image"``.
            ``"caption"`` entries may be strings or lists of strings; for a list,
            its first string is used (an empty list yields an empty caption).

    Returns:
        The processor output, moved to ``device``, with an added ``"labels"``
        entry that is the same object as ``input_ids``.
    """
    size = processor.image_processor.image_size
    mean = processor.image_processor.image_mean
    std = processor.image_processor.image_std

    transform = transforms.Compose([
        ensure_rgb,
        transforms.RandomResizedCrop((size, size), scale=(0.9, 1.0), interpolation=transforms.InterpolationMode.BICUBIC),
        transforms.ToTensor(),
        transforms.Normalize(mean=mean, std=std),
    ])

    # NOTE(review): nlphuji/flickr30k appears to expose a PIL `image` column and
    # list-valued `caption`s rather than `image_url`/str — both shapes are accepted.
    images = batch["image_url"] if "image_url" in batch else batch["image"]

    prompts = []
    for image, caption in zip(images, batch["caption"]):
        if not isinstance(caption, str):
            caption = caption[0] if caption else ""
        # Keep only the text before the first period.
        clean_caption = caption.split(".")[0]
        prompts.append([image, f"Describe this image: {clean_caption}</s>"])

    inputs = processor(prompts, transform=transform, return_tensors="pt").to(device)
    inputs["labels"] = inputs["input_ids"]
    return inputs

# Load and prepare the dataset: take the first 1% of the split and hold out
# 10% of it for validation. The fixed seed keeps the train/validation split
# identical across kernel restarts.
# NOTE(review): confirm that nlphuji/flickr30k publishes a "train" split name.
SPLIT_SEED = 42
dataset = load_dataset("nlphuji/flickr30k", split="train[:1%]").train_test_split(test_size=0.1, seed=SPLIT_SEED)
train_set, val_set = dataset["train"], dataset["test"]

# Preprocess lazily, batch by batch, when items are accessed.
train_set.set_transform(preprocess_batch)
val_set.set_transform(preprocess_batch)



In [None]:
# LoRA adapter settings: adapters go on every module named q_proj / k_proj /
# v_proj, with rank 16, lora_alpha 32 and 5% dropout; no bias terms are trained.
lora_config = LoraConfig(
    target_modules=["q_proj", "k_proj", "v_proj"],
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)


In [None]:
# Wrap the base model with LoRA adapters. The isinstance guard makes this cell
# safe to re-run: an already-wrapped PeftModel is not wrapped a second time.
if not isinstance(model, PeftModel):
    model = get_peft_model(model, lora_config)
model.print_trainable_parameters()

# Define training arguments.
# Per device, each optimizer step sees per_device_train_batch_size (2)
# x gradient_accumulation_steps (8) examples.
training_args = TrainingArguments(
    output_dir="idefics-flickr30k-lora",
    learning_rate=2e-4,
    fp16=True,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    gradient_accumulation_steps=8,
    max_steps=25,
    # `eval_strategy` replaces the deprecated `evaluation_strategy` keyword,
    # which recent transformers releases no longer accept.
    eval_strategy="steps",
    save_strategy="steps",
    eval_steps=10,
    save_steps=25,
    logging_steps=5,
    # Keep the raw dataset columns: the on-the-fly transform reads them
    # (e.g. the caption column) when batches are built.
    remove_unused_columns=False,
    save_total_limit=3,
    push_to_hub=False,
    report_to="none",
    label_names=["labels"],
    optim="paged_adamw_8bit",
)



In [None]:
# Build the Trainer from the model, the training arguments and both dataset
# splits, then run fine-tuning. The value returned by train() is left as the
# cell's output.
trainer = Trainer(model=model, args=training_args, train_dataset=train_set, eval_dataset=val_set)
trainer.train()