# Fine-tuning the BLIP vision-language model to caption spatial relations in tennis-match images

Import packages

In [2]:
from datasets import load_dataset
from torch.utils.data import Dataset, DataLoader
from PIL import Image, ImageOps
import torch
from transformers import AutoProcessor, BlipForConditionalGeneration
from huggingface_hub import upload_folder, HfApi, login
import io
import matplotlib.pyplot as plt


Load the image-captioning dataset from the Hugging Face Hub.
Then split it into a training set and a test set.

In [4]:
dataset = load_dataset("DinoDave/SpatialRelationsTennis_masked")

# Hold out 10% of the examples for testing. The fixed seed makes the split
# reproducible, so the held-out set stays the same if this cell (or the whole
# notebook) is re-run and no test example drifts into the training set.
train_test_split = dataset['train'].train_test_split(test_size=0.1, seed=42)

# Separate train and test sets
train_dataset_raw = train_test_split['train']
test_dataset_raw = train_test_split['test']

print("Number of training examples:", len(train_dataset_raw))
print("Number of testing examples:", len(test_dataset_raw))

False

Sanity check of the downloaded data: inspect the caption and the image of one training example.

In [None]:
# Show the caption of training example 100 to confirm the "text" column loaded.
train_dataset_raw[100]["text"]

In [None]:
# Show the image of the same training example (index 100).
train_dataset_raw[100]["image"]

Dataset class that resizes each image and encodes image and caption with the BLIP processor.

In [None]:
class ImageCaptioningDataset(Dataset):
    """Map-style dataset that turns raw (image, caption) rows into model inputs.

    Each row of ``dataset`` must provide an ``"image"`` entry (an object with a
    PIL-style ``resize`` method) and a ``"text"`` entry (the caption).

    Args:
        dataset: Indexable collection of rows with ``"image"`` and ``"text"`` keys.
        processor: Callable (the BLIP processor) that encodes an image and a
            caption into a dict of tensors with a leading batch dimension.
        resize_to: ``(width, height)`` the image is resized to with Lanczos
            resampling before encoding; pass a falsy value to skip resizing.
    """

    def __init__(self, dataset, processor, resize_to=(640, 640)):
        self.dataset = dataset
        self.processor = processor
        self.resize_to = resize_to

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        """Return the encoded inputs for row ``idx`` without the batch dimension."""
        item = self.dataset[idx]
        image = item["image"]

        if self.resize_to:
            image = image.resize(self.resize_to, Image.Resampling.LANCZOS)

        # padding="max_length" pads every caption to the same length so the
        # default DataLoader collate can stack items into a batch.
        encoding = self.processor(
            images=image, text=item["text"], padding="max_length", return_tensors="pt"
        )
        # The processor returns a batch of size 1. Drop only that leading
        # dimension: a bare squeeze() would also collapse any other size-1 dim.
        return {key: value.squeeze(0) for key, value in encoding.items()}

Load the pretrained BLIP model and its processor, and build the training and test data loaders.

In [None]:
device = "cuda" if torch.cuda.is_available() else "cpu"

checkpoint = "Salesforce/blip-image-captioning-large"
processor = AutoProcessor.from_pretrained(checkpoint)
model = BlipForConditionalGeneration.from_pretrained(checkpoint).to(device)

# Wrap the raw splits and build the loaders used by the training and
# evaluation cells below (train_dataloader, test_dataloader, test_dataset).
# NOTE(review): the batch size used originally is not visible in this
# notebook; 2 is a guess — confirm / tune for the available GPU memory.
batch_size = 2
train_dataset = ImageCaptioningDataset(train_dataset_raw, processor)
test_dataset = ImageCaptioningDataset(test_dataset_raw, processor)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=batch_size)
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=batch_size)


Check whether a GPU is available and empty the CUDA memory cache.

In [None]:
# Show the device selected above ("cuda" when a GPU is available, else "cpu").
device

In [None]:
# Confirm that PyTorch can see a CUDA GPU.
torch.cuda.is_available()

In [None]:
# Release cached, currently unused GPU memory held by PyTorch (nothing to do without CUDA).
if torch.cuda.is_available():
    torch.cuda.empty_cache()

### Train the model

Function that computes the average loss on the test set.

In [None]:
def evaluate(model, dataloader, device):
    """Return the mean per-batch captioning loss of ``model`` over ``dataloader``.

    The model is switched to eval mode (and left there) and run without
    gradient tracking. The labels are the caption token ids with padded
    positions (``attention_mask == 0``) replaced by -100, so padding does not
    contribute to the loss.

    Args:
        model: Model called with ``input_ids``, ``pixel_values``, ``labels`` and
            ``attention_mask`` keyword arguments; its output must have ``.loss``.
        dataloader: Iterable of dict batches holding ``"input_ids"``,
            ``"pixel_values"`` and ``"attention_mask"`` tensors. Batches are only
            read, never modified.
        device: Device the batch tensors are moved to.

    Returns:
        The average of ``outputs.loss`` over all batches (not weighted by batch size).

    Raises:
        ZeroDivisionError: If ``dataloader`` yields no batches.
    """
    model.eval()
    total_loss = 0.0
    num_batches = 0

    with torch.no_grad():
        for batch in dataloader:
            # Index instead of pop() so the batch dicts stay intact and a
            # re-iterable loader can be evaluated more than once.
            input_ids = batch["input_ids"].to(device)
            pixel_values = batch["pixel_values"].to(device)
            attention_mask = batch["attention_mask"].to(device)
            # -100 is the default ignore_index of PyTorch's cross-entropy loss,
            # so padded positions are excluded instead of dominating the loss.
            labels = input_ids.masked_fill(attention_mask == 0, -100)

            outputs = model(
                input_ids=input_ids,
                pixel_values=pixel_values,
                labels=labels,
                attention_mask=attention_mask,
            )

            total_loss += outputs.loss.item()
            num_batches += 1

    return total_loss / num_batches

Actual training of the model.

In [None]:
# Number of batches whose gradients are accumulated before each optimizer step.
accumulation_steps = 2

learning_rate = 3e-5
number_epochs = 20

optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

training_loss = []  # loss of every training batch, in processing order (floats)
test_loss = []      # mean test loss after each epoch


for epoch in range(number_epochs):
    print("Epoch:", epoch)
    model.train()
    num_batches = 0
    for idx, batch in enumerate(train_dataloader):
        input_ids = batch["input_ids"].to(device)
        pixel_values = batch["pixel_values"].to(device)
        attention_mask = batch["attention_mask"].to(device)
        # Padded positions get label -100 (PyTorch cross-entropy's default
        # ignore_index) so the model is not trained to predict padding.
        labels = input_ids.masked_fill(attention_mask == 0, -100)

        outputs = model(input_ids=input_ids,
                        pixel_values=pixel_values,
                        labels=labels,
                        attention_mask=attention_mask)

        loss = outputs.loss
        loss_value = loss.item()
        training_loss.append(loss_value)
        print("Loss:", loss_value)

        # Scale so the accumulated gradient is the mean over the group of
        # batches rather than their sum.
        (loss / accumulation_steps).backward()
        num_batches = idx + 1

        if num_batches % accumulation_steps == 0:
            optimizer.step()
            optimizer.zero_grad()

    # Apply the gradients of a trailing, incomplete accumulation group instead
    # of silently carrying them over into the next epoch.
    if num_batches % accumulation_steps != 0:
        optimizer.step()
        optimizer.zero_grad()

    # Evaluate on test dataset after each epoch
    test_loss_item = evaluate(model, test_dataloader, device)
    print(f"Test Loss after epoch {epoch}: {test_loss_item}")
    test_loss.append(test_loss_item)

model.eval()

print("Finetuning process done!")

Plot the training loss and the test loss.

In [None]:
# training_loss gets one entry per training batch, so the x-axis counts
# batches (across all epochs), not epochs.
plt.figure(figsize=(10, 6))
plt.plot(training_loss, marker='o', linestyle='-', color='b', label='Training Loss')
plt.xlabel('Training batch')
plt.ylabel('Loss')
plt.title('Training Loss per Batch')
plt.legend()
plt.grid(True)
plt.show()

In [None]:
# Mean test loss recorded after each epoch (one point per epoch, starting at 0).
fig, ax = plt.subplots(figsize=(10, 6))
ax.plot(test_loss, marker='o', linestyle='-', color='g', label='Test Loss')
ax.set_xlabel('Epochs')
ax.set_ylabel('Loss')
ax.set_title('Test Loss over Epochs')
ax.legend()
ax.grid(True)
plt.show()

Show the test images together with the captions generated by the fine-tuned model.

In [None]:
# Caption every held-out image with the fine-tuned model and show each image
# with its generated caption, one per row.
num_examples = len(test_dataset_raw)
# squeeze=False always returns a 2-D array of Axes; with a single example
# plt.subplots would otherwise return a lone Axes that zip() cannot iterate.
fig, axes = plt.subplots(num_examples, 1, figsize=(5, 5 * num_examples), squeeze=False)

for ax, example in zip(axes[:, 0], test_dataset_raw):
    # Same resize as ImageCaptioningDataset's default (640x640, Lanczos).
    image = example["image"].resize((640, 640), Image.Resampling.LANCZOS)
    inputs = processor(image, return_tensors="pt").to(device)

    out = model.generate(**inputs, max_length=50)

    ax.imshow(image)
    ax.set_title(processor.decode(out[0], skip_special_tokens=True))  # generated caption
    ax.axis('off')  # Hide the axes

plt.tight_layout()
plt.show()

Upload the fine-tuned model to the Hugging Face Hub (as a backup).

In [None]:
import os

# Authenticate with the Hugging Face Hub. The access token must not be
# hard-coded in the notebook, because anyone who can read the file could use it.
# It is read from the HF_TOKEN environment variable instead; when that is
# unset, login() prompts for the token interactively.
# NOTE(review): the token previously hard-coded in this cell is exposed and
# should be revoked at https://huggingface.co/settings/tokens.
login(token=os.environ.get("HF_TOKEN"))

# Save the fine-tuned model and its processor locally, then push that folder.
output_dir = "./fine_tuned_model"
model.save_pretrained(output_dir)
processor.save_pretrained(output_dir)

# Define repository name and organization (if applicable)
repo_name = "BLIP_finetuned_spatial_relations"
organization = "DinoDave"  # Set to None if not uploading to an organization

# Upload the folder to Hugging Face Hub
upload_folder(
    repo_id=f"{organization}/{repo_name}" if organization else repo_name,
    folder_path=output_dir,
    commit_message="Initial commit of fine-tuned model",
    ignore_patterns=["*.pyc", "__pycache__/*"],
    create_pr=False,  # Set to True to open a pull request instead of committing directly
)