### Import modules

In [2]:
import json
import random
from datetime import date

import datasets
import pandas as pd
import torch
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from transformers import BlipProcessor, BlipForConditionalGeneration

from common.partition import get_random_partition

### Define custom dataset class

In [2]:
class ImageCaptioningDataset(Dataset):
    """Torch dataset that turns image/caption records into processor encodings.

    Args:
        dataset: Indexable, sized collection whose items provide an ``"image"``
            entry and a ``"text"`` entry holding a non-empty sequence of captions.
        processor: Callable invoked as
            ``processor(images=..., text=..., padding="max_length", return_tensors="pt")``
            that returns a mapping of tensors carrying a leading batch dimension.
    """

    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        """Return the number of records in the wrapped dataset."""
        return len(self.dataset)

    def __getitem__(self, idx):
        """Encode record ``idx`` using one of its captions, picked at random.

        Returns:
            dict mapping each key produced by the processor to its tensor with
            the leading batch dimension removed.
        """
        item = self.dataset[idx]
        # A different caption may be drawn every time the same record is requested.
        caption = random.choice(item["text"])
        encoding = self.processor(images=item["image"], text=caption, padding="max_length", return_tensors="pt")
        # Remove only the leading batch dimension. A bare squeeze() would also
        # collapse any other size-1 dimension (e.g. a one-token sequence).
        return {key: tensor.squeeze(0) for key, tensor in encoding.items()}

### Define training function

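`train()` runs the fine-tuning loop: for every epoch it iterates over the dataloader, computes the model loss using the input token ids as labels, and updates the weights with AdamW.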

In [3]:
def train(model, dataloader, epochs = 3, lr = 5e-5):
    """Fine-tune ``model`` in place on the batches yielded by ``dataloader``.

    The model is moved to CUDA when available (CPU otherwise), switched to
    training mode, and optimized with a freshly created AdamW optimizer.
    The epoch index and every batch loss are printed.

    Args:
        model: Module callable as ``model(input_ids=..., pixel_values=..., labels=...)``
            whose output exposes a scalar ``loss`` tensor.
        dataloader: Iterable of mappings providing ``"input_ids"`` and
            ``"pixel_values"`` tensors. It is iterated once per epoch.
        epochs: Number of passes over ``dataloader``.
        lr: Learning rate handed to AdamW.

    Returns:
        None. The model is updated in place.
    """
    optimizer = torch.optim.AdamW(model.parameters(), lr=lr)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    model.train()

    # Discard gradients that may be left on the parameters from earlier use of the model.
    optimizer.zero_grad()

    for epoch in range(epochs):
        print("Epoch:", epoch)
        for batch in dataloader:
            # Index instead of pop: the batch mapping is left intact, so an iterable
            # that yields the same mapping objects every epoch keeps working.
            input_ids = batch["input_ids"].to(device)
            pixel_values = batch["pixel_values"].to(device)

            # NOTE(review): the labels are the raw input ids, so padded positions
            # count towards the loss unless the model masks them itself — confirm.
            outputs = model(input_ids=input_ids,
                            pixel_values=pixel_values,
                            labels=input_ids)

            loss = outputs.loss

            print("Loss:", loss.item())

            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

### Parameters

- **sample_index**: index of the random sample (an entry of `random_samples_7.json`) that is used for training
- **subsample_ratio**: size of each subsample as a fraction of the selected sample (with `0.25` the sample is partitioned into 4 subsamples)
- **model_path**: path to pretrained model

In [3]:
# Run configuration: edit these values before executing the cells below.

# Directory of the pretrained checkpoint that fine-tuning starts from.
model_path = '../../../trained_models/blip_image_captioning_base'

# Position of the random sample to train on.
sample_index = 0

# Ratio handed to get_random_partition when the sample is split into subsamples.
subsample_ratio = 0.25

### Load processor and model

In [5]:
# Restore both halves of the checkpoint stored under `model_path`:
# the captioning model to fine-tune ...
model = BlipForConditionalGeneration.from_pretrained(model_path)
# ... and the processor that encodes image/caption pairs for it.
processor = BlipProcessor.from_pretrained(model_path)

### Load sample and build subsamples from it

In [4]:
# Root folder of the dataset; the trailing slash lets file names be appended directly.
dataset_path = '../../../datasets/flickr8/'
captions_csv = pd.read_csv(f'{dataset_path}captions.txt')

# The JSON file stores several pre-drawn random samples; one is selected by index below.
with open(f'{dataset_path}random_samples_7.json', 'rb') as file:
    random_samples = json.load(file)

# Pick the configured sample and split it into subsamples.
images = random_samples[sample_index]
samples = get_random_partition(images, subsample_ratio)

print(f'Samples: {len(samples)}')
# Size of every subsample, shown as the cell output.
[len(subsample) for subsample in samples]

Samples: 4


[141, 141, 141, 143]

### Train the model with each subsample

In [None]:
# Fine-tune the same model on every subsample in turn.
for i, sample in enumerate(samples):
    print(f'Start on sample {i + 1} out of {len(samples)}')

    # One slot per image of this subsample; a slot stays None until the first
    # caption row for that image is seen.
    data_dict = dict.fromkeys(sample)

    for value in captions_csv.values:
        image_name, image_caption = value[0], value[1]

        if image_name not in data_dict:
            continue

        record = data_dict[image_name]
        if record is None:
            # dataset_path already ends with a slash, so no extra separator is added.
            # The context manager closes the source file once the RGB copy exists.
            with Image.open(f'{dataset_path}Images/{image_name}') as image_file:
                rgb_image = image_file.convert('RGB')
            data_dict[image_name] = {'image': rgb_image, 'text': [image_caption]}
        else:
            record['text'].append(image_caption)

    # Images that never matched a caption row would otherwise reach the dataset
    # builder as None entries; report and drop them instead.
    uncaptioned = [name for name, record in data_dict.items() if record is None]
    if uncaptioned:
        print(f'Skipping {len(uncaptioned)} image(s) without captions')
    data_list = [record for record in data_dict.values() if record is not None]
    print('Done processing images')

    if not data_list:
        # Nothing to train on; a shuffling DataLoader cannot be built over an empty dataset.
        print(f'No captioned images in sample {i + 1}, skipping training')
        continue

    dataset = datasets.Dataset.from_list(data_list)
    train_dataset = ImageCaptioningDataset(dataset, processor)
    train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=1)
    print('Done building the training dataset')

    print('Start training')
    train(model, train_dataloader)
    print(f'Done training with sample {i + 1} out of {len(samples)}')

### Save fine-tuned model

In [None]:
# Tag the output directory with today's date in ISO format. Runs on the same
# day resolve to the same directory name.
date_id = date.today().isoformat()
# Build the path once so that model and processor always end up in the same directory.
output_dir = f"../../../trained_models/blip_image_captioning_{date_id}"

model.save_pretrained(output_dir)
processor.save_pretrained(output_dir)