In [198]:
import os
import pandas as pd
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration, Trainer, TrainingArguments
from torch.utils.data import Dataset
from PIL import Image
from tqdm.auto import tqdm

In [199]:
# Ruta al archivo CSV que contiene las rutas de las imágenes y las descripciones
csv_path = "../train/data/output/divide_images_train/finetuning_modificado.csv"

# Cargar el procesador y el modelo preentrenado
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-large")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-large")


In [200]:

class CustomDataset:
    def __init__(self, data, processor):
        self.data = data
        self.processor = processor

    def collate_fn(self, batch):
        inputs = []
        labels = []
        for item in batch:
            image_path, description = item
            image = Image.open(image_path).convert("RGB")

            encoding = self.processor(images=image, return_tensors="pt")

            # Imprime las claves del diccionario encoding
            print(encoding.keys())

            if 'input_ids' in encoding:
                inputs.append({
                    'input_ids': encoding['input_ids'].squeeze(),
                    'attention_mask': encoding['attention_mask'].squeeze(),
                    'labels': encoding['input_ids'].squeeze()
                })
            else:
                print(f"Warning: 'input_ids' not found in encoding for image: {image_path}")

            labels.append(description)

        return inputs, labels

# Ejemplo de cómo usar CustomDataset
data = [
    ("ruta/a/imagen1.jpg", "Descripción de la imagen 1"),
    ("ruta/a/imagen2.jpg", "Descripción de la imagen 2"),
    # Agrega más datos según sea necesario
]

processor = BlipProcessor.from_pretrained("modelo_preentrenado")

dataset = CustomDataset(data, processor)


OSError: modelo_preentrenado is not a local folder and is not a valid model identifier listed on 'https://huggingface.co/models'
If this is a private repository, make sure to pass a token having permission to this repo either by logging in with `huggingface-cli login` or by passing `token=<your_token>`

In [None]:

# Configurar argumentos para el entrenamiento
training_args = TrainingArguments(
    output_dir="./results",
    overwrite_output_dir=True,
    per_device_train_batch_size=4,
    save_steps=1000,
    logging_steps=100,
    eval_steps=500,
    save_total_limit=2,
    num_train_epochs=3,
    prediction_loss_only=True,
)

# Crear el Trainer para el finetuning
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    data_collator=dataset.collate_fn  # Utilizando el collate_fn personalizado
)

# Iniciar el entrenamiento del modelo
trainer.train()