In [None]:
from datasets import load_dataset
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
from transformers import TrainingArguments
from transformers import Trainer

# New Yorker caption contest dataset, "explanation" configuration.
ds = load_dataset("jmhessel/newyorker_caption_contest", "explanation")

# Load the processor and the captioning model from one shared checkpoint id.
BLIP_CHECKPOINT = "Salesforce/blip-image-captioning-base"
processor = BlipProcessor.from_pretrained(BLIP_CHECKPOINT)
model = BlipForConditionalGeneration.from_pretrained(BLIP_CHECKPOINT)

In [None]:
from datasets import DatasetDict

# Keep only the first N rows of each split so the experiment stays small.
SUBSET_SIZES = {"train": 1000, "validation": 60, "test": 50}

ds = DatasetDict(
    {split: ds[split].select(range(size)) for split, size in SUBSET_SIZES.items()}
)


# Part $1:$ Data Analysis and Preparation

In [None]:
# Preview the first training image as a 100x100 thumbnail.
# Index the row before the column so that only this one image is decoded,
# instead of materialising the whole `image` column just to read element 0.
ds["train"][0]["image"].resize((100, 100), Image.Resampling.LANCZOS)

In [None]:
WIDTH, HEIGHT = 100, 100

def preprocess_function(examples):
    """Turn a batch of raw examples into BLIP training features.

    Args:
        examples: Batch dict as passed by ``Dataset.map(batched=True)``. The
            ``image`` column is read as PIL images and ``caption_choices`` as
            the caption text (assumed to be one string per example — TODO
            confirm for this dataset configuration).

    Returns:
        dict with ``input_ids``, ``attention_mask``, ``pixel_values`` and
        ``labels`` for the batch. ``labels`` is ``input_ids`` with every
        padding position replaced by -100.
    """
    # Shrink every image to WIDTH x HEIGHT before handing it to the processor.
    images = [img.resize((WIDTH, HEIGHT), Image.Resampling.LANCZOS) for img in examples["image"]]
    captions = list(examples["caption_choices"])

    # Tokenize the captions and featurize the images in a single processor call.
    inputs = processor(images=images, text=captions, return_tensors="pt", padding="max_length", truncation=True)

    # Captions are padded to a fixed length, so copying input_ids verbatim would
    # make the loss score the padding positions too. -100 marks a label position
    # as ignored by the loss; the attention mask is 0 exactly on padding.
    labels = inputs["input_ids"].masked_fill(inputs["attention_mask"] == 0, -100)

    return {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        "pixel_values": inputs["pixel_values"],
        "labels": labels,
    }

# Run the preprocessing over every split in batches of 50 and drop the raw
# columns, so only the features returned by preprocess_function remain.
processed_dataset = ds.map(
    preprocess_function,
    batched=True,
    batch_size=50,
    remove_columns=ds["train"].column_names,
)


# Part $2:$ Exploratory Data Analysis (EDA)

# Part $3:$ Model Development

In [None]:

# Fine-tuning configuration for the captioning model.
# NOTE(review): no evaluation strategy is set here, so `eval_steps` presumably
# has no effect and the validation split is never scored during training —
# confirm whether periodic evaluation was intended.
training_args = TrainingArguments(
    # Optimisation
    learning_rate=5e-5,
    num_train_epochs=3,
    per_device_train_batch_size=4,
    # Logging
    logging_dir="./logs",
    logging_steps=10,
    # Checkpointing and evaluation cadence
    output_dir="./blip-caption-generator",
    save_steps=50,
    save_total_limit=2,
    eval_steps=50,
)


In [None]:
# Display the processed training split to check what preprocessing produced.
processed_dataset['train']

In [None]:

# Wire the model, the training arguments and the processed splits into a Trainer.
trainer = Trainer(
    args=training_args,
    model=model,
    eval_dataset=processed_dataset["validation"],
    train_dataset=processed_dataset["train"],
)


In [None]:
# Run fine-tuning; the cell's output is whatever `train()` returns.
trainer.train()


In [None]:
def keep_selected_columns(example):
    """Return a new dict holding only the model-input fields of ``example``.

    The kept fields are ``input_ids``, ``attention_mask`` and ``pixel_values``;
    any other key (for instance ``labels``) is left out. A missing field raises
    ``KeyError``.
    """
    wanted = ("input_ids", "attention_mask", "pixel_values")
    selected = {}
    for field in wanted:
        selected[field] = example[field]
    return selected

# Re-map every split down to the model-input columns: all existing columns are
# dropped and only what keep_selected_columns returns is kept.
test = processed_dataset.map(
    keep_selected_columns,
    remove_columns=processed_dataset["train"].column_names,
)
test

In [None]:
# Sanity check: compare the row counts of two of the kept columns in the train split.
len(test["train"]["input_ids"]), len(test["train"]["pixel_values"])

In [None]:
# `inputs` only exists inside preprocess_function, so referencing it here raises
# NameError on a fresh kernel. Inspect the stored features instead; viewing the
# split as torch tensors gives a compact tensor repr rather than a nested list.
processed_dataset["train"].with_format("torch")[0]["pixel_values"]