# Fine-tune BLIP using Hugging Face `transformers` and `datasets` 🤗

This tutorial is largely based from the [GiT tutorial](https://colab.research.google.com/drive/1HLxgrG7xZJ9FvXckNG61J72FkyrbqKAA?usp=sharing) on how to fine-tune GiT on a custom image captioning dataset. Here we will use a dummy dataset of [football players](https://huggingface.co/datasets/ybelkada/football-dataset) ⚽ that is uploaded on the Hub. The images have been manually selected together with the captions.
Check the 🤗 [documentation](https://huggingface.co/docs/datasets/image_dataset) on how to create and upload your own image-text dataset.

## Set-up environment

In [1]:
!pip install git+https://github.com/huggingface/transformers.git@main

Collecting git+https://github.com/huggingface/transformers.git@main
  Cloning https://github.com/huggingface/transformers.git (to revision main) to /tmp/pip-req-build-bvdl2slm
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers.git /tmp/pip-req-build-bvdl2slm
  Resolved https://github.com/huggingface/transformers.git to commit 638d49983f36af910934b38771b4e55c835c1774
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h  Preparing metadata (pyproject.toml) ... [?25ldone
Collecting huggingface-hub<1.0,>=0.19.3 (from transformers==4.36.0.dev0)
  Obtaining dependency information for huggingface-hub<1.0,>=0.19.3 from https://files.pythonhosted.org/packages/05/09/1945ca6ba3ad8ad6e2872ba682ce8d68c5e63c8e55458ed8ab4885709f1d/huggingface_hub-0.19.4-py3-none-any.whl.metadata
  Downloading huggingface_hub-0.19.4-py3-none-any.whl.metadata (14 kB)
INFO: pip is looking at multiple versions of 

In [2]:
!pip install -q datasets

We also quickly upload some telemetry - this tells us which examples and software versions are getting used so we know where to prioritize our maintenance efforts. We don't collect (or care about) any personally identifiable information, but if you'd prefer not to be counted, feel free to skip this step or delete this cell entirely.

## Load the image captioning dataset

Let's load the image captioning dataset, you just need few lines of code for that.

Let's retrieve the caption of the first example:

And the corresponding image

## Create PyTorch Dataset

The lines below are entirely copied from the original notebook!

In [3]:
from datasets import load_dataset
from PIL import Image
from sklearn.model_selection import train_test_split



In [4]:
import pickle
with open('/kaggle/input/28231999/processed_data/processed_data.pkl', 'rb') as file:
    # Load the data from the file
    data = pickle.load(file)

# Now 'data' contains the deserialized object from the pickle file
# print(data)

In [5]:
print(data.keys())

dict_keys(['img_name', 'joint_points', 'text', 'simple_text', 'visibility_text', 'simple_visibility_text', 'normalized_text', 'normalized_simple_text'])


In [None]:
data['img_name'][0]

In [None]:
# images = [Image.open(f"/kaggle/input/28211999/all_joint_label_visible/{name}") for name in data['img_name']]

In [None]:
# from torch.utils.data import Dataset, DataLoader

# class ImageCaptioningDataset(Dataset):
#     def __init__(self, data_dict, processor, split_ratio=0.8):
#         self.processor = processor
#         self.images = [Image.open(f"/kaggle/input/28221999/all_joint_label_part_visible/{name}") for name in data_dict['img_name']]
#         self.text = [text for text in data_dict['text']]
#         self.simple_text = [text for text in data_dict['simple_text']]
#         if "visibility_text" in data_dict.keys():
#             self.visibility_text = [text for text in data_dict['visibility_text']]
#         if "simple_visibility_text"in data_dict.keys():
#             self.simple_visibility_text = [text for text in data_dict['simple_visibility_text']]
#         self.mode = "simple_text"
    
#     def change_mode(self, mode):
#         modes = ["simple_text", 'text', "simple_visibility_text","visibility_text"]
#         if mode in modes:
#             self.mode = mode
#         else:
#             print('invalid mode')
    
    
    
#     def __len__(self):
#         return len(self.text)

#     def __getitem__(self, idx):
#         image = self.images[idx]
#         text_list = getattr(self, self.mode)
#         text = text_list[idx]
# #         text = self.text[idx]
#         encoding = self.processor(images=image, text=text, padding="max_length", return_tensors="pt")
#         # remove batch dimension
#         encoding = {k:v.squeeze() for k,v in encoding.items()}
#         return encoding

In [None]:
from torch.utils.data import Dataset, DataLoader

class JointsVisibilityDataset(Dataset):
    def __init__(self, data_dict, processor, key_name = 'visibility_text', split='train', test_size = 0.2):
        self.processor = processor
        self.images = [Image.open(f"/kaggle/input/28231999/processed_data/{name}") for name in data_dict['img_name']]
        self.text = [text for text in data_dict[key_name]]
        self.test_size = test_size
        self.split = split
        if split == 'train':
            self.images, _, self.text, _ = train_test_split(self.images, self.text, test_size=self.test_size, random_state=42)
        elif split == 'test':
            _, self.images, _, self.text = train_test_split(self.images, self.text, test_size=self.test_size, random_state=42)
        
    def change_mode(self, mode):
        modes = ["simple_text", 'text', "simple_visibility_text","visibility_text"]
        if mode in modes:
            self.mode = mode
        else:
            print('invalid mode')
    
    
    
    def __len__(self):
        if self.split == 'train':
            return int(len(self.text)*(1-self.test_size))
        else:
            return int(len(self.text)*self.test_size)

    def __getitem__(self, idx):
        image = self.images[idx]
        text = self.text[idx]
        encoding = self.processor(images=image, text=text, padding="max_length", return_tensors="pt")
        # remove batch dimension
        encoding = {k:v.squeeze() for k,v in encoding.items()}
        return encoding

## Load model and processor

In [7]:
from transformers import AutoProcessor, BlipForConditionalGeneration

processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

preprocessor_config.json:   0%|          | 0.00/287 [00:00<?, ?B/s]

tokenizer_config.json:   0%|          | 0.00/506 [00:00<?, ?B/s]

vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

special_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

config.json:   0%|          | 0.00/4.56k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/990M [00:00<?, ?B/s]

Now that we have loaded the processor, let's load the dataset and the dataloader:

In [8]:
train_dataset = JointsVisibilityDataset(data, processor)
test_dataset = JointsVisibilityDataset(data, processor,split='test')

In [9]:
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=2)
test_dataloader = DataLoader(test_dataset, shuffle=True, batch_size=1)

In [None]:
# train_dataset.change_mode('visibility_text')

## Train the model

Let's train the model! Run the simply the cell below for training the model

In [10]:
import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

model.train()

for epoch in range(1):
  print("Epoch:", epoch)
  for idx, batch in enumerate(train_dataloader):
#     print(type(batch))
    input_ids = batch.pop("input_ids").to(device)
    pixel_values = batch.pop("pixel_values").to(device)

    outputs = model(input_ids=input_ids,
                    pixel_values=pixel_values,
                    labels=input_ids)

    loss = outputs.loss

    print("Loss:", loss.item())

    loss.backward()

    optimizer.step()
    optimizer.zero_grad()

Epoch: 0


We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


Loss: 13.089844703674316
Loss: 12.081439971923828
Loss: 10.219390869140625
Loss: 10.159173011779785
Loss: 10.136489868164062
Loss: 10.136028289794922


KeyboardInterrupt: 

## Inference

Let's check the results on our train dataset

In [None]:
# # load image
# example = dataset[0]
# image = example["image"]
# image

In [None]:
# train_dataset[0]['pixel_values']

In [None]:
# # prepare image for the model
# inputs = processor(images=image, return_tensors="pt").to(device)
# pixel_values = inputs.pixel_values

# generated_ids = model.generate(pixel_values=pixel_values, max_length=50)
# generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
# print(generated_caption)

In [46]:
def evaluate_accuracy():
    model.eval()
    truth_caption_list = []
    predicted_caption_list = []
    for i in range(10):
        inputs = processor(images=test_dataset.images[i], return_tensors="pt").to(device)
        pixel_values = inputs.pixel_values
        # pixel_values = train_dataset[0]['pixel_values'].to(device)

        generated_ids = model.generate(pixel_values=pixel_values, max_length=200)
        generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        truth_caption = test_dataset.text[i]
#         print('g',generated_caption)
#         print('t', truth_caption)
        truth_caption = truth_caption.split()
        generated_caption = generated_caption.split()
#         print(generated_caption)
        for j in range(len(truth_caption)):
            truth_caption_list.append(int(truth_caption[j][-1]))
            predicted_caption_list += [int(value) for value in generated_caption if value.isdigit()]
    correct = sum(p == l for p, l in zip(truth_caption_list, predicted_caption_list))
    total = len(truth_caption_list )
    baseline = sum(truth_caption_list)/total
    print('finetuned', correct/total)
    print('baseline', baseline)

In [47]:
evaluate_accuracy()

0.64375
0.73125


In [None]:
train_dataset.images[0]

In [None]:
# prepare image for the model
inputs = processor(images=test_dataset.images[0], return_tensors="pt").to(device)
pixel_values = inputs.pixel_values
# pixel_values = train_dataset[0]['pixel_values'].to(device)

generated_ids = model.generate(pixel_values=pixel_values, max_length=200)
generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_caption)

In [None]:
test_dataset.text[0]