In [None]:
!pip install -q transformers peft datasets evaluate bitsandbytes torchvision

In [None]:
import evaluate
import numpy as np
import torch
from datasets import load_dataset
from peft import LoraConfig, TaskType, get_peft_model
from transformers import (
    BitsAndBytesConfig,
    Blip2ForConditionalGeneration,
    Blip2Processor,
    EarlyStoppingCallback,
    Trainer,
    TrainingArguments,
)

In [None]:
# Checkpoint shared by the processor and the model.
model_id = "Salesforce/blip2-flan-t5-xl"
processor = Blip2Processor.from_pretrained(model_id)

# Request 8-bit weights through an explicit quantization config; passing
# `load_in_8bit=True` straight to from_pretrained is the deprecated spelling.
model = Blip2ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)

In [None]:
# LoRA adapters (rank 16, alpha 32) on the modules named "q" and "v".
#
# task_type is deliberately left unset. TaskType.SEQ_2_SEQ_LM makes PEFT wrap
# the model in its seq2seq-specific class, whose forward passes text-only
# keyword arguments (e.g. inputs_embeds / decoder_inputs_embeds) down to the
# base model; BLIP-2's forward takes pixel_values and is not a plain seq2seq
# signature, so the generic wrapper (which forwards arguments untouched) is
# the safe choice.
# NOTE(review): confirm against the installed peft/transformers versions.
peft_config = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q", "v"],
)
model = get_peft_model(model, peft_config)

In [None]:
dataset = load_dataset("nlphuji/flickr30k")

# Split from "train" when it exists, otherwise from the first split available.
# NOTE(review): the Hub copy of this dataset appears to ship only a "test"
# split — confirm which split is intended as the training pool.
source_split = "train" if "train" in dataset else next(iter(dataset))

# Fixed seed so Restart & Run All always yields the same 90/10 partition;
# without it the held-out examples change on every run.
split = dataset[source_split].train_test_split(test_size=0.1, seed=42)
train_dataset = split["train"]
eval_dataset = split["test"]

In [None]:
def preprocess(example):
    """Convert one raw example into model inputs plus loss-masked labels.

    Args:
        example: Mapping with an "image" entry and a caption stored under
            "sentence". If "sentence" is absent, "caption" is used instead;
            when the caption value is a list/tuple, its first item is taken.
            NOTE(review): which column the dataset actually exposes is not
            visible here — confirm against the dataset schema.

    Returns:
        dict with every processor output and "labels", each with the leading
        batch dimension of size 1 squeezed away.

    Raises:
        KeyError: if neither "sentence" nor "caption" is present.
    """
    image = example["image"]
    caption = example["sentence"] if "sentence" in example else example["caption"]
    if isinstance(caption, (list, tuple)):
        caption = caption[0]

    inputs = processor(
        images=image,
        text="Describe this image",
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=128,
    )
    labels = processor.tokenizer(
        caption,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=128,
    ).input_ids

    # Padding positions must not contribute to the loss: -100 is the index the
    # cross-entropy loss ignores. Leaving the pad token id in place would train
    # the model to emit padding for most of the 128 label positions.
    labels = labels.masked_fill(labels == processor.tokenizer.pad_token_id, -100)
    inputs["labels"] = labels

    return {k: v.squeeze(0) for k, v in inputs.items()}

# Run `preprocess` over the training split first, then the evaluation split,
# rebinding both names to the mapped datasets.
train_dataset, eval_dataset = [
    subset.map(preprocess) for subset in (train_dataset, eval_dataset)
]

In [None]:
# Metric objects are loaded once and reused on every evaluation pass.
bleu = evaluate.load("bleu")
rouge = evaluate.load("rouge")


def compute_metrics(pred):
    """Score decoded predictions against decoded labels with BLEU and ROUGE-L.

    Args:
        pred: Pair ``(predictions, labels)``. ``predictions`` may be token
            ids, raw logits with a trailing vocabulary axis, or a tuple whose
            first element is one of those. ``labels`` are token ids in which
            ignored positions may be marked with -100.

    Returns:
        dict with keys "bleu" and "rougeL".
    """
    predictions, labels = pred

    # A plain Trainer may hand over a tuple of model outputs; the first entry
    # is taken to be the token-level predictions.
    if isinstance(predictions, (tuple, list)):
        predictions = predictions[0]
    predictions = np.asarray(predictions)
    # Raw logits (batch, seq, vocab) are reduced to the most likely token ids.
    if predictions.ndim == 3:
        predictions = predictions.argmax(axis=-1)

    # -100 is a loss-masking sentinel, not a vocabulary id: swap it for the pad
    # token so decoding succeeds and skip_special_tokens drops it again.
    pad_id = processor.tokenizer.pad_token_id
    predictions = np.where(predictions == -100, pad_id, predictions)
    labels = np.asarray(labels)
    labels = np.where(labels == -100, pad_id, labels)

    decoded_preds = processor.tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = processor.tokenizer.batch_decode(labels, skip_special_tokens=True)

    # BLEU takes a list of references per prediction; ROUGE takes one string each.
    bleu_result = bleu.compute(predictions=decoded_preds, references=[[l] for l in decoded_labels])
    rouge_result = rouge.compute(predictions=decoded_preds, references=decoded_labels)
    return {"bleu": bleu_result["bleu"], "rougeL": rouge_result["rougeL"]}


In [None]:
# transformers renamed `evaluation_strategy` to `eval_strategy`; use whichever
# keyword the installed TrainingArguments dataclass actually declares so the
# cell works on both sides of the rename.
eval_strategy_key = (
    "eval_strategy"
    if "eval_strategy" in TrainingArguments.__dataclass_fields__
    else "evaluation_strategy"
)

training_args = TrainingArguments(
    per_device_train_batch_size=1,
    gradient_accumulation_steps=4,  # effective batch of 4 examples per optimizer step
    num_train_epochs=10,
    eval_steps=200,
    save_strategy="steps",
    logging_steps=50,
    save_steps=200,  # kept equal to eval_steps, as load_best_model_at_end needs aligned checkpoints
    load_best_model_at_end=True,
    metric_for_best_model="bleu",
    greater_is_better=True,
    output_dir="./blip2-xl-finetuned",
    report_to="tensorboard",
    save_total_limit=2,
    fp16=True,
    # Name the label column explicitly: the Trainer cannot always infer it
    # from a PEFT-wrapped model, and without it evaluation yields no labels.
    label_names=["labels"],
    **{eval_strategy_key: "steps"},
)

In [None]:
def preprocess_logits_for_metrics(logits, labels):
    """Reduce evaluation outputs to token ids before the Trainer caches them.

    Without this hook the evaluation loop keeps the full-vocabulary logits
    (and any auxiliary model outputs) for every evaluation example, and
    compute_metrics would receive logits rather than decodable token ids.

    Args:
        logits: Logits tensor, or a tuple/list of model outputs whose first
            element is taken to be the token logits.
        labels: Unused; part of the hook signature expected by Trainer.

    Returns:
        Tensor of argmax token ids over the last (vocabulary) axis.
    """
    if isinstance(logits, (tuple, list)):
        logits = logits[0]
    return logits.argmax(dim=-1)


trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    compute_metrics=compute_metrics,
    preprocess_logits_for_metrics=preprocess_logits_for_metrics,
    # Stop once the tracked metric has failed to improve for 3 evaluations.
    callbacks=[EarlyStoppingCallback(early_stopping_patience=3)],
)

trainer.train()

In [None]:
from IPython.display import display
import matplotlib.pyplot as plt

def show_examples(model, processor, dataset, num_images=3):
    """Display the first few images with their reference and generated captions.

    Puts ``model`` in eval mode, then for each example renders the image and
    prints the ground-truth caption next to the model's generated caption.

    Args:
        model: Model exposing ``eval()``, ``device`` and ``generate(...)``.
        processor: Processor that builds model inputs and owns the tokenizer
            used to decode the generated ids.
        dataset: Indexable collection of examples with an "image" entry and a
            caption under "sentence" (falling back to "caption" when
            "sentence" is absent).
        num_images: Maximum number of examples to show; capped at
            ``len(dataset)`` so a short dataset does not raise IndexError.
    """
    model.eval()
    for i in range(min(num_images, len(dataset))):
        example = dataset[i]
        image = example["image"]
        gt_caption = example["sentence"] if "sentence" in example else example["caption"]

        # Same prompt as used for training inputs; tensors go to the model's device.
        inputs = processor(images=image, text="Describe this image", return_tensors="pt").to(model.device)
        output = model.generate(**inputs, max_new_tokens=50)
        caption = processor.tokenizer.decode(output[0], skip_special_tokens=True)

        display(image)
        print(f"💬 Ground Truth: {gt_caption}")
        print(f"🤖 Model Output: {caption}")
        print("-" * 60)


In [None]:
print("До обучения:")
# Fresh copy of the base checkpoint (no adapters) as the "before" baseline.
# 8-bit loading goes through an explicit quantization config; passing
# `load_in_8bit=True` straight to from_pretrained is the deprecated spelling.
pretrained_model = Blip2ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)
show_examples(pretrained_model, processor, eval_dataset)

In [None]:
# Same examples as the baseline cell, now captioned by the fine-tuned model.
print("После обучения:")
show_examples(
    model=model,
    processor=processor,
    dataset=eval_dataset,
)