In [15]:
import os
import base64
from io import BytesIO
from PIL import Image

In [16]:
dataset = []

for folder in os.listdir('data'):
    if folder == '.DS_Store':
        continue
    
    text = open('data/' + folder + '/annotation.txt', 'r').read()

    image_text = open('data/' + folder + '/image.txt', 'r').read()
    image = Image.open(BytesIO(base64.b64decode(image_text)))

    dataset.append({'text': text, 'image': image})

In [17]:
print(len(dataset))

1475


In [21]:
from torch.utils.data import Dataset, DataLoader

class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        encoding = self.processor(images=item["image"], text=item["text"], padding="max_length", return_tensors="pt")
        # remove batch dimension
        encoding = {k:v.squeeze() for k,v in encoding.items()}
        return encoding

In [22]:
from transformers import AutoProcessor, BlipForConditionalGeneration

processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

In [23]:
train_dataset = ImageCaptioningDataset(dataset, processor)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=2)

In [25]:
import torch

optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

model.train()

for epoch in range(5):
  print("Epoch:", epoch)
  for idx, batch in enumerate(train_dataloader):
    input_ids = batch.pop("input_ids").to(device)
    pixel_values = batch.pop("pixel_values").to(device)

    outputs = model(input_ids=input_ids,
                    pixel_values=pixel_values,
                    labels=input_ids)
    
    loss = outputs.loss

    print("Loss:", loss.item())

    loss.backward()

    optimizer.step()
    optimizer.zero_grad()

Epoch: 0
Loss: 1.448843002319336
Loss: 1.406710147857666
Loss: 1.4019355773925781
Loss: 1.455841064453125
Loss: 1.4141844511032104
Loss: 1.440647006034851
Loss: 1.4473156929016113
Loss: 1.5178241729736328
Loss: 1.432844877243042
Loss: 1.486398458480835
Loss: 1.450452208518982
Loss: 1.445716381072998
Loss: 1.447212815284729
Loss: 1.4390712976455688
Loss: 1.4186346530914307
Loss: 1.5286283493041992
Loss: 1.4437127113342285
Loss: 1.4639393091201782
Loss: 1.4389865398406982
Loss: 1.4348111152648926
Loss: 1.425684928894043
Loss: 1.4234188795089722
Loss: 1.4554870128631592
Loss: 1.4485057592391968
Loss: 1.5299510955810547
Loss: 1.4246214628219604
Loss: 1.4001436233520508
Loss: 1.5191121101379395
Loss: 1.4624799489974976
Loss: 1.3950421810150146
Loss: 1.436233401298523
Loss: 1.5251874923706055
Loss: 1.3959378004074097
Loss: 1.4688321352005005
Loss: 1.4066022634506226
Loss: 1.4662952423095703
Loss: 1.4665287733078003
Loss: 1.4250178337097168
Loss: 1.436220645904541
Loss: 1.4365336894989014
Los

In [26]:
model.save_pretrained("blip-finetuned")