In [8]:
!pip install datasets



In [9]:
import random
import json
from PIL import Image
import pandas as pd
import torch
import random
import datasets
from torch.utils.data import Dataset, DataLoader
from transformers import BlipProcessor, BlipForConditionalGeneration
import json
import pandas as pd
from sqlalchemy import create_engine

In [10]:
BASE_PATH = '/content/drive/MyDrive/social/'

images_path = f'{BASE_PATH}images/'
db_path = f'{BASE_PATH}metadata.db'
train_json_path = f'{BASE_PATH}train_images.json'

model_path = "Salesforce/blip-image-captioning-base"


In [11]:
class ImageCaptioningDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        encoding = self.processor(images=item["image"], text=random.choice(item["text"]), padding="max_length", return_tensors="pt")
        # remove batch dimension
        encoding = {k:v.squeeze() for k,v in encoding.items()}
        return encoding

In [12]:
def train(model, dataloader, epochs = 10):
    optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)

    model.train()

    for epoch in range(epochs):
        print("Epoch:", epoch)
        for _, batch in enumerate(dataloader):
            input_ids = batch.pop("input_ids").to(device)
            pixel_values = batch.pop("pixel_values").to(device)

            outputs = model(input_ids=input_ids,
                            pixel_values=pixel_values,
                            labels=input_ids)

            loss = outputs.loss

            print("Loss:", loss.item())

            loss.backward()

            optimizer.step()
            optimizer.zero_grad()

In [13]:
# Load the processor and model
processor = BlipProcessor.from_pretrained(model_path)
model = BlipForConditionalGeneration.from_pretrained(model_path)

In [14]:
engine = create_engine(f'sqlite:///{db_path}')
metadata_df = pd.read_sql('ImageData', engine)

data_dict = {}
train_images_dict : dict[str, str]
with open(train_json_path, "r") as file:
    train_images_dict = json.load(file)

for id, name in train_images_dict.items():
    caption = metadata_df[metadata_df['id'] == int(id)][['caption']].values[0][0]
    image = Image.open(f'{images_path}{name}').convert('RGB')
    data_dict[id] = {'image': image, 'text': [caption]}


In [15]:
data_list = list(data_dict.values())
dataset = datasets.Dataset.from_list(data_list)


In [16]:
train_dataset = ImageCaptioningDataset(dataset, processor)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=1)


In [None]:
train(model, train_dataloader)

Epoch: 0


We strongly recommend passing in an `attention_mask` since your input_ids may be padded. See https://huggingface.co/docs/transformers/troubleshooting#incorrect-output-when-padding-tokens-arent-masked.


Loss: 12.594154357910156
Loss: 11.109750747680664
Loss: 10.298810958862305
Loss: 10.175548553466797
Loss: 10.23611068725586
Loss: 10.290913581848145
Loss: 10.28325366973877
Loss: 10.26285457611084
Loss: 10.305441856384277
Loss: 10.262384414672852
Loss: 10.254898071289062
Loss: 10.274365425109863
Loss: 10.254914283752441
Loss: 10.245553970336914
Loss: 10.225752830505371
Loss: 10.207948684692383
Loss: 10.211370468139648
Loss: 10.169900894165039
Loss: 10.119539260864258
Loss: 10.121381759643555
Loss: 10.102998733520508
Loss: 9.758451461791992
Loss: 9.320759773254395
Loss: 8.70272159576416
Loss: 9.896804809570312
Loss: 8.753671646118164
Loss: 8.7097749710083
Loss: 8.376005172729492
Loss: 8.08980655670166
Loss: 7.891321659088135
Loss: 7.570362091064453
Loss: 7.4641852378845215
Loss: 7.191710948944092
Loss: 6.943793773651123
Loss: 6.822363376617432
Loss: 6.668732166290283
Loss: 6.511059761047363
Loss: 6.347950458526611
Loss: 6.316605091094971
Loss: 6.122056484222412
Loss: 5.8859405517578125


In [None]:
# Save the model and processor locally
processor.save_pretrained(f"{BASE_PATH}trained_models/blip_image_captioning_tuned_10e")
model.save_pretrained(f"{BASE_PATH}trained_models/blip_image_captioning_tuned_10e")