In [None]:
!pip install -q transformers peft datasets evaluate bitsandbytes torchvision


In [None]:
import gc
import os

import evaluate
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import torch
from datasets import Dataset
from IPython.display import display
from peft import LoraConfig, TaskType, get_peft_model
from PIL import Image
from transformers import (
    BitsAndBytesConfig,
    Blip2ForConditionalGeneration,
    Blip2Processor,
    Trainer,
    TrainingArguments,
)

In [None]:
# CSV with an `image` column (file names relative to IMAGE_DIR) and a `caption` column.
# NOTE(review): absolute /mnt/data path vs. relative image folder — confirm both match the runtime layout.
CAPTION_FILE = "/mnt/data/100_ex.csv"

# Folder holding the image files referenced by CAPTION_FILE.
IMAGE_DIR = "data/dataset/images"

In [None]:
# Load the caption table and resolve each image file name against IMAGE_DIR.
df = pd.read_csv(CAPTION_FILE).assign(
    image=lambda frame: frame["image"].apply(lambda name: os.path.join(IMAGE_DIR, name))
)

# Fail fast with a readable message instead of a FileNotFoundError deep inside `Dataset.map`.
missing_images = [path for path in df["image"] if not os.path.isfile(path)]
if missing_images:
    raise FileNotFoundError(
        f"{len(missing_images)} image file(s) not found, e.g. {missing_images[:3]}"
    )

dataset = Dataset.from_pandas(df[["image", "caption"]])
# Fixed seed so the train/eval split — and therefore the before/after metrics — is reproducible.
split = dataset.train_test_split(test_size=0.1, seed=42)
train_dataset, eval_dataset = split["train"], split["test"]

In [None]:
# Shared processor for this checkpoint: encodes images + prompt text, and exposes `.tokenizer`
# for the caption labels and for decoding generated ids.
processor = Blip2Processor.from_pretrained(pretrained_model_name_or_path="Salesforce/blip2-flan-t5-xl")

def preprocess(example):
    image = Image.open(example["image"]).convert("RGB")
    caption = example["caption"]
    inputs = processor(images=image, text="Describe this image", return_tensors="pt", padding="max_length", truncation=True, max_length=128)
    inputs["labels"] = processor.tokenizer(caption, return_tensors="pt", padding="max_length", truncation=True, max_length=128).input_ids
    return {k: v.squeeze(0) for k, v in inputs.items()}

# Encode both splits with `preprocess` (train first, then eval). `map` keeps the existing
# columns, so "image" and "caption" remain available for the metric helper later on.
train_dataset, eval_dataset = (part.map(preprocess) for part in (train_dataset, eval_dataset))

In [None]:
model_id = "Salesforce/blip2-flan-t5-xl"

# 8-bit loading goes through BitsAndBytesConfig; passing `load_in_8bit=True` straight to
# from_pretrained is deprecated in transformers and slated for removal.
model = Blip2ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)

# LoRA adapters on modules named "q" / "v" (presumably the Flan-T5 attention projections).
peft_config = LoraConfig(r=8, lora_alpha=16, target_modules=["q", "v"], task_type=TaskType.SEQ_2_SEQ_LM)
model = get_peft_model(model, peft_config)

# 📊 Metrics: BLEU and ROUGE scorers loaded via `evaluate.load`, used by both
# compute_metrics and get_metrics.
bleu, rouge = (evaluate.load(metric_name) for metric_name in ("bleu", "rouge"))

def compute_metrics(pred):
    predictions, labels = pred
    decoded_preds = processor.tokenizer.batch_decode(predictions, skip_special_tokens=True)
    decoded_labels = processor.tokenizer.batch_decode(labels, skip_special_tokens=True)
    bleu_result = bleu.compute(predictions=decoded_preds, references=[[l] for l in decoded_labels])
    rouge_result = rouge.compute(predictions=decoded_preds, references=decoded_labels)
    return {"bleu": bleu_result["bleu"], "rougeL": rouge_result["rougeL"]}

In [None]:
print("🔍 Метрики до обучения:")
# Separate copy of the same checkpoint without LoRA adapters, used as the baseline.
# 8-bit loading goes through BitsAndBytesConfig; the bare `load_in_8bit` kwarg is deprecated.
pre_model = Blip2ForConditionalGeneration.from_pretrained(
    model_id,
    device_map="auto",
    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
)

def get_metrics(model, dataset, processor, prompt="Describe this image", max_new_tokens=50):
    """Generate a caption for each example and score it against the reference caption.

    Args:
        model: captioning model exposing ``generate`` (base or PEFT-wrapped BLIP-2 here).
        dataset: iterable of rows with an ``"image"`` file path and a ``"caption"`` string.
        processor: turns (image, prompt) into model inputs; its ``tokenizer`` decodes outputs.
        prompt: text prompt passed alongside every image.
        max_new_tokens: generation length cap per caption.

    Returns:
        dict with ``"bleu"`` and ``"rougeL"`` scores over all examples.

    Raises:
        ValueError: if ``dataset`` yields no examples.
    """
    # Cast floating inputs (pixel_values) to the model's floating parameter dtype. The
    # 8-bit checkpoint is loaded without an explicit torch_dtype, so its non-quantized
    # weights are presumably float16; float32 pixels would then hit a dtype mismatch.
    input_dtype = next((p.dtype for p in model.parameters() if p.is_floating_point()), torch.float32)

    was_training = model.training
    # generate() does not switch modes itself; disable dropout (e.g. right after Trainer.train()).
    model.eval()
    preds, targets = [], []
    try:
        for example in dataset:
            with Image.open(example["image"]) as img:
                image = img.convert("RGB")
            # Build inputs lazily so only one example's tensors sit on the device at a time.
            inputs = processor(images=image, text=prompt, return_tensors="pt").to(
                device=model.device, dtype=input_dtype
            )
            with torch.no_grad():
                output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
            preds.append(output_ids[0])
            targets.append(example["caption"])
    finally:
        model.train(was_training)  # restore the caller's mode

    if not targets:
        raise ValueError("get_metrics() needs at least one example")

    decoded_preds = processor.tokenizer.batch_decode(preds, skip_special_tokens=True)
    bleu_result = bleu.compute(predictions=decoded_preds, references=[[t] for t in targets])
    rouge_result = rouge.compute(predictions=decoded_preds, references=targets)
    return {"bleu": bleu_result["bleu"], "rougeL": rouge_result["rougeL"]}

before_metrics = get_metrics(pre_model, eval_dataset, processor)
print("До обучения:", before_metrics)

# The baseline copy is not used again; release it so it does not stay resident
# (device_map="auto" typically places it on the GPU) next to the model being trained.
del pre_model
gc.collect()
torch.cuda.empty_cache()

In [None]:
training_args = TrainingArguments(
    output_dir="./debug-out",
    per_device_train_batch_size=1,
    num_train_epochs=1,
    logging_steps=10,
    # `evaluation_strategy` was renamed to `eval_strategy` (transformers 4.41) and the old
    # keyword was later removed, so it raises a TypeError on current releases.
    eval_strategy="no",
    save_strategy="no",
    fp16=True,  # mixed-precision training; transformers rejects this without a supported GPU
)

In [None]:
# Train the PEFT-wrapped model on the preprocessed training split. No eval_dataset or
# compute_metrics is passed: evaluation is done separately with get_metrics.
trainer = Trainer(
    args=training_args,
    train_dataset=train_dataset,
    model=model,
)
trainer.train()

In [None]:
# Re-score the same held-out split, this time with the LoRA-adapted model.
print("Метрики после обучения:")
after_metrics = get_metrics(model=model, dataset=eval_dataset, processor=processor)
print("После обучения:", after_metrics)