# In [3]:
# Fine-tune a model with the COCO 2017 dataset
# This script prepares the dataset, configures the model, and performs fine-tuning.

from transformers import VisionEncoderDecoderModel, Trainer, TrainingArguments
from datasets import load_dataset
from transformers import ViTImageProcessor, GPT2TokenizerFast
import torch

# Fetch the caption data through the `datasets` loader, requesting the
# "2017" configuration of the "coco_captions" dataset.
dataset = load_dataset(path="coco_captions", name="2017")

# Preprocess the dataset for the model
_CAPTIONING_CHECKPOINT = "nlpconnect/vit-gpt2-image-captioning"
_PREPROCESSORS = {}  # lazily filled cache: {"processor": ..., "tokenizer": ...}


def _get_preprocessors():
    """Return the (image processor, tokenizer) pair, loading each only once."""
    if not _PREPROCESSORS:
        processor = ViTImageProcessor.from_pretrained(_CAPTIONING_CHECKPOINT)
        tokenizer = GPT2TokenizerFast.from_pretrained(_CAPTIONING_CHECKPOINT)
        if tokenizer.pad_token is None:
            # padding="max_length" cannot work without a pad token; reuse EOS.
            tokenizer.pad_token = tokenizer.eos_token
        _PREPROCESSORS["processor"] = processor
        _PREPROCESSORS["tokenizer"] = tokenizer
    return _PREPROCESSORS["processor"], _PREPROCESSORS["tokenizer"]


def _to_rgb(image):
    """Convert an image to RGB when it exposes a non-RGB ``mode``."""
    # NOTE(review): assumes PIL-style images (``mode`` / ``convert``); any
    # other object is passed through untouched -- confirm against the dataset.
    if getattr(image, "mode", "RGB") != "RGB":
        return image.convert("RGB")
    return image


def preprocess_data(example):
    """Turn raw image/caption data into tensors for the captioning model.

    Accepts either a single example or a batch as produced by
    ``Dataset.map(..., batched=True)``; a batch is recognised by
    ``example["image"]`` being a list or tuple.

    Args:
        example: Mapping with an ``"image"`` entry and a ``"caption"`` entry.
            In batched mode both entries are sequences.

    Returns:
        A dict with ``"pixel_values"`` and ``"labels"`` tensors. Batched input
        keeps its leading batch axis (even for a batch of one); for a single
        example the leading axis is squeezed away. Padding positions in
        ``"labels"`` are set to -100.
    """
    processor, tokenizer = _get_preprocessors()

    images = example["image"]
    is_batched = isinstance(images, (list, tuple))
    if is_batched:
        images = [_to_rgb(image) for image in images]
    else:
        images = _to_rgb(images)

    # Process image(s)
    pixel_values = processor(images, return_tensors="pt")["pixel_values"]

    # Tokenize caption(s)
    targets = tokenizer(
        example["caption"],
        padding="max_length",
        truncation=True,
        max_length=128,
        return_tensors="pt",
    )

    # Mask padding via the attention mask (not the pad id) so that a pad token
    # sharing its id with EOS does not also hide genuine tokens. -100 is the
    # default ignore index of the cross-entropy loss.
    labels = targets["input_ids"].clone()
    labels[targets["attention_mask"] == 0] = -100

    if is_batched:
        # Never squeeze here: a one-element batch must keep its batch axis.
        return {"pixel_values": pixel_values, "labels": labels}
    return {"pixel_values": pixel_values.squeeze(0), "labels": labels.squeeze(0)}

# Run preprocess_data over the "train" and "validation" splits in batched mode
train_dataset, val_dataset = (
    dataset[split_name].map(preprocess_data, batched=True)
    for split_name in ("train", "validation")
)

# Have both splits hand back only the model inputs, as PyTorch tensors
for prepared_split in (train_dataset, val_dataset):
    prepared_split.set_format(type="torch", columns=["pixel_values", "labels"])

# Checkpoint used as the starting point for fine-tuning
model_name = "nlpconnect/vit-gpt2-image-captioning"
# Build the VisionEncoderDecoderModel from that checkpoint
model = VisionEncoderDecoderModel.from_pretrained(
    pretrained_model_name_or_path=model_name,
)

# Hyperparameters and bookkeeping options for the fine-tuning run
training_args = TrainingArguments(
    # --- where things are written ---
    output_dir="./results",  # checkpoint directory
    logging_dir="./logs",  # log directory
    save_total_limit=2,  # keep at most two checkpoints
    report_to="none",  # no external experiment trackers
    # --- schedule ---
    num_train_epochs=5,  # number of passes over the training split
    # NOTE(review): newer transformers releases appear to rename this option
    # to `eval_strategy` -- confirm against the installed version.
    evaluation_strategy="epoch",  # evaluate once per epoch
    # --- optimisation ---
    learning_rate=5e-5,
    weight_decay=0.01,
    # --- batch sizes (per device) ---
    per_device_train_batch_size=16,
    per_device_eval_batch_size=8,
)

# Wire the model, its training configuration and both data splits together
trainer = Trainer(
    args=training_args,  # configuration built above
    model=model,  # network being fine-tuned
    eval_dataset=val_dataset,  # split used for evaluation
    train_dataset=train_dataset,  # split used for optimisation
)

# Begin fine-tuning
trainer.train()

# Save the final fine-tuned model
model.save_pretrained("./fine_tuned_model")
# Report where the weights were written (replaces a leftover debug print).
print("Fine-tuned model saved to ./fine_tuned_model")