In [None]:
# Mount Google Drive at /content/drive so files stored there can be read and
# written from this notebook.
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Install transformers from the tip of its GitHub main branch and the
# bert-score package. NOTE(review): the `@main` install is unpinned, so the
# installed transformers code can differ from one run to the next.
!pip install git+https://github.com/huggingface/transformers.git@main
!pip install bert-score

In [None]:
# -*- coding: utf-8 -*-

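"""
    task: caption
    version: 6
    preprocess: blip preprocessor
    base model: BLIP image captioning base (Salesforce/blip-image-captioning-base)
    freeze: no (all model parameters are passed to the optimizer)
    architecture: fine-tuned BLIP caption
    learning rate: 5e-5

"""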

"""imports"""
import glob
import re
import os
import cv2
import numpy as np
import tensorflow as tf
import pandas as pd
from sklearn.preprocessing import MultiLabelBinarizer
from sklearn.metrics import f1_score
import time
import json
import nltk
import string
import warnings
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import AutoProcessor, BlipForConditionalGeneration
import torch
from bert_score import score
warnings.filterwarnings("ignore")

# Wall-clock reference taken before any other work; used for the run-time report.
start_time = time.time()

# Locations: the Drive folder (label CSVs and run outputs) and the image folders.
path = "/content/drive/MyDrive/clef2023/"
train_path, valid_path, test_path = "/content/train", "/content/valid", "/content/test"

# Run settings.
version = 6                  # names the output folder of this run
epochs = 5                   # number of training epochs
each_n_epoch = 5             # predictions/weights are produced every n-th epoch
im_size = (224, 224)         # size images are resized to when loaded for training
should_save_weights = True

# Output folders of this run: <path>runs/<version>/{results,weights}.
model_path = f"{path}runs/{version}"
os.makedirs(os.path.join(model_path, "results"), exist_ok=True)
os.makedirs(os.path.join(model_path, "weights"), exist_ok=True)

# Optional subsampling: the limits are only applied when `sample` is True.
sample = False
train_limit, valid_limit, test_limit = 60918, 10437, 10473

dataset_batchsize = 4

# Loading CSVs

# Read the train captions CSV (tab-separated, indexed by image ID).
train_df = pd.read_csv(
    path + "ImageCLEFmedical_Caption_2023_caption_prediction_train_labels.csv",
    delimiter="\t",
    index_col="ID",
)
if sample:
    # Clamp the sample size: DataFrame.sample raises ValueError when asked for
    # more rows than the frame holds (sampling is without replacement).
    train_df = train_df.sample(min(train_limit, len(train_df))).sort_values(by=["ID"])

print(f"2- read {len(train_df)} train images")

# Read the valid captions CSV (same layout) and apply the limit the same way.
valid_df = pd.read_csv(
    path + "ImageCLEFmedical_Caption_2023_caption_prediction_valid_labels.csv",
    delimiter="\t",
    index_col="ID",
)
if sample:
    valid_df = valid_df.sample(min(valid_limit, len(valid_df))).sort_values(by=["ID"])

print(f"3- read {len(valid_df)} valid images")

# Image names (the CSV index values) in the same order as the frames' rows.
train_img_names = train_df.index.values.tolist()
valid_img_names = valid_df.index.values.tolist()

# Test images have no CSV: discover them on disk and strip folder/extension.
# glob returns files in arbitrary order, so sort to make both the optional
# `[:test_limit]` subset and the order of the saved predictions reproducible.
test_img_names = sorted(
    os.path.splitext(os.path.basename(x))[0] for x in glob.glob(test_path + "/*.jpg")
)
if sample:
    test_img_names = test_img_names[:test_limit]
print(f"5- found {len(test_img_names)} test images")

class ImageCaptioningDataset(Dataset):
    """Dataset of processor-encoded (image, caption) pairs.

    Args:
        folder_path: Directory containing the ``<image name>.jpg`` files.
        image_list: Image names without extension; defines length and order.
        dataset_df: DataFrame indexed by image name with a ``caption`` column.
        processor: Callable invoked as ``processor(images=..., text=..., ...)``
            that returns a mapping of tensors carrying a leading batch
            dimension of size 1.
        max_length: Token length every caption is padded and truncated to.
            Defaults to 512, the cut-off the dataset has always applied to
            ``input_ids``.
    """

    def __init__(self, folder_path, image_list, dataset_df, processor, max_length=512):
        self.folder_path = folder_path
        self.image_list = image_list
        self.processor = processor
        self.dataset_df = dataset_df
        self.max_length = max_length

    def __len__(self):
        """Return the number of images in ``image_list``."""
        return len(self.image_list)

    def __getitem__(self, idx):
        """Return the encoded sample for ``image_list[idx]`` as a dict of tensors."""
        img_name = self.image_list[idx]
        img_path = self.folder_path + "/" + img_name + ".jpg"
        # convert()/resize() return new images, so the file handle can be
        # released as soon as the resized copy exists.
        with Image.open(img_path) as raw:
            img = raw.convert('RGB').resize(im_size)
        caption = self.dataset_df.loc[img_name, "caption"]
        # Truncate inside the processor so every text tensor it returns
        # (input_ids and any accompanying mask) ends up with the same length;
        # slicing only input_ids afterwards left the other tensors untouched.
        encoding = self.processor(
            images=img,
            text=caption,
            padding="max_length",
            truncation=True,
            max_length=self.max_length,
            return_tensors="pt",
        )
        # Drop only the leading batch dimension; a bare squeeze() would also
        # remove any other dimension that happens to have size 1.
        return {k: v.squeeze(0) for k, v in encoding.items()}

# Load processor and model from the same pretrained checkpoint.
processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# Datasets and dataloaders.
train_dataset = ImageCaptioningDataset(
    folder_path=train_path,
    image_list=train_img_names,
    dataset_df=train_df,
    processor=processor,
)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=dataset_batchsize)

valid_dataset = ImageCaptioningDataset(
    folder_path=valid_path,
    image_list=valid_img_names,
    dataset_df=valid_df,
    processor=processor,
)
# The validation loader is only used to sum a loss over the whole set, so the
# visiting order does not matter; keep it fixed instead of consuming random
# state on a shuffle.
valid_dataloader = DataLoader(valid_dataset, shuffle=False, batch_size=dataset_batchsize)

# Use the GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Predicting
def pred_img_caption(image):
    """Generate a caption for one image with the module-level processor and model.

    Args:
        image: Image handed to the processor as ``images=image``.

    Returns:
        The first decoded caption, with special tokens stripped.
    """
    pixel_values = processor(images=image, return_tensors="pt").to(device).pixel_values
    token_ids = model.generate(pixel_values=pixel_values, max_length=50)
    decoded = processor.batch_decode(token_ids, skip_special_tokens=True)
    return decoded[0]

# Parallel histories filled by predict_and_save: the epoch each evaluation was
# run at and the mean validation BERTScore F1 measured at that epoch.
epoch_number = []
bert_scores = []


def _caption_images(folder, names):
    """Return one generated caption per name for the ``<name>.jpg`` files in folder."""
    captions = []
    for name in names:
        # The context manager releases the file handle once the caption exists.
        with Image.open(folder + "/" + name + ".jpg") as img:
            captions.append(pred_img_caption(img))
    return captions


def _save_captions(filename, names, captions):
    """Write (ID, caption) rows as a header-less, tab-separated file under results/."""
    df = pd.DataFrame(
        data={
            "ID": names,
            "caption": captions,
        }
    )
    df.to_csv(
        model_path + "/results/" + filename, index=False, header=None, sep="\t"
    )


def predict_and_save(model, epoch):
    """Caption the test and valid images, save both files and record BERTScore.

    Writes ``e-<epoch>-test-caption.csv`` and ``e-<epoch>-valid-caption.csv``
    to ``<model_path>/results/`` and appends ``epoch`` and the mean validation
    BERTScore F1 to ``epoch_number`` / ``bert_scores``.

    Args:
        model: Kept for call compatibility; captions are produced through
            ``pred_img_caption``, which uses the module-level model.
        epoch: Epoch number used in the file names and recorded with the score.
    """
    # Test split: predictions only, there are no reference captions.
    test_captions = _caption_images(test_path, test_img_names)
    _save_captions(f"e-{epoch}-test-caption.csv", test_img_names, test_captions)

    # Valid split: save the predictions before scoring so that a failure in
    # the scoring step cannot discard them.
    valid_captions = _caption_images(valid_path, valid_img_names)
    _save_captions(f"e-{epoch}-valid-caption.csv", valid_img_names, valid_captions)

    # References come from valid_df in row order, the same order
    # valid_img_names was taken from its index.
    reference = valid_df['caption'].tolist()
    P, R, F1 = score(valid_captions, reference, lang='en', verbose=False)
    # Average over the captions actually scored; dividing by the fixed
    # valid_limit is wrong whenever the valid set has a different size.
    score_v = F1.mean().item()
    epoch_number.append(epoch)
    bert_scores.append(score_v)




# Training
optimizer = torch.optim.AdamW(model.parameters(), lr=5e-5)

device = "cuda" if torch.cuda.is_available() else "cpu"
model.to(device)

for epoch in range(epochs):
    print("Epoch:", epoch+1)
    # Re-enter training mode at the start of every epoch: the validation pass
    # below switches the model to eval mode, and without this every epoch
    # after the first would be trained in eval mode.
    model.train()
    train_loss = 0.0
    valid_loss = 0.0
    for batch_num, batch in enumerate(train_dataloader, start=1):
        print("Batch: ", str(batch_num))
        torch.cuda.empty_cache()
        input_ids = batch.pop("input_ids").to(device)
        pixel_values = batch.pop("pixel_values").to(device)
        # Nothing visible in ImageCaptioningDataset adds a "labels" entry, so
        # fall back to the caption tokens themselves as targets, exactly as
        # the validation pass does. An explicit "labels" entry still wins.
        labels = batch.pop("labels").to(device) if "labels" in batch else input_ids

        outputs = model(input_ids=input_ids,
                        pixel_values=pixel_values,
                        labels=labels)

        loss = outputs.loss
        # Weight the batch loss by the batch size so the epoch figure is a
        # per-sample average even when the last batch is smaller.
        train_loss += loss.item() * input_ids.size(0)

        loss.backward()

        optimizer.step()
        optimizer.zero_grad()

    # Average loss over the training dataset.
    train_loss /= len(train_dataloader.dataset)

    # Evaluate on the validation dataset without tracking gradients.
    model.eval()
    with torch.no_grad():
        for batch in valid_dataloader:
            input_ids = batch.pop("input_ids").to(device)
            pixel_values = batch.pop("pixel_values").to(device)

            outputs = model(input_ids=input_ids,
                            pixel_values=pixel_values,
                            labels=input_ids)

            loss = outputs.loss
            valid_loss += loss.item() * input_ids.size(0)

    # Average loss over the validation dataset.
    valid_loss /= len(valid_dataloader.dataset)

    print('Epoch [{}/{}], Train Loss: {:.4f}, Valid Loss: {:.4f}'
          .format(epoch + 1, epochs, train_loss, valid_loss))

    # Predict after the first epoch and then after every each_n_epoch-th one.
    if epoch == 0 or (epoch + 1) % each_n_epoch == 0:
        predict_and_save(model, epoch + 1)
        # Honour the should_save_weights switch, which was previously ignored.
        if should_save_weights:
            weights_dir = model_path + "/weights/epoch" + str(epoch+1)
            os.makedirs(weights_dir, exist_ok=True)
            model.save_pretrained(weights_dir)


with open(model_path + "/results/running_time.json", "w+") as f:
    json.dump({
        "task": "caption",
        "total_running_time_mins": (time.time() - start_time) / 60,
        "version": version,
        # Persist the validation scores gathered by predict_and_save; they
        # were otherwise lost when the session ended.
        "valid_bert_f1_by_epoch": dict(zip(epoch_number, bert_scores)),
    }, f)

2- read 60918 train images
3- read 10437 valid images
5- found 10473 test images


Downloading (…)rocessor_config.json:   0%|          | 0.00/287 [00:00<?, ?B/s]

Downloading (…)okenizer_config.json:   0%|          | 0.00/438 [00:00<?, ?B/s]

Downloading (…)solve/main/vocab.txt:   0%|          | 0.00/232k [00:00<?, ?B/s]

Downloading (…)/main/tokenizer.json:   0%|          | 0.00/711k [00:00<?, ?B/s]

Downloading (…)cial_tokens_map.json:   0%|          | 0.00/125 [00:00<?, ?B/s]

Downloading (…)lve/main/config.json:   0%|          | 0.00/4.56k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/990M [00:00<?, ?B/s]

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Batch:  1319
Batch:  1320
Batch:  1321
Batch:  1322
Batch:  1323
Batch:  1324
Batch:  1325
Batch:  1326
Batch:  1327
Batch:  1328
Batch:  1329
Batch:  1330
Batch:  1331
Batch:  1332
Batch:  1333
Batch:  1334
Batch:  1335
Batch:  1336
Batch:  1337
Batch:  1338
Batch:  1339
Batch:  1340
Batch:  1341
Batch:  1342
Batch:  1343
Batch:  1344
Batch:  1345
Batch:  1346
Batch:  1347
Batch:  1348
Batch:  1349
Batch:  1350
Batch:  1351
Batch:  1352
Batch:  1353
Batch:  1354
Batch:  1355
Batch:  1356
Batch:  1357
Batch:  1358
Batch:  1359
Batch:  1360
Batch:  1361
Batch:  1362
Batch:  1363
Batch:  1364
Batch:  1365
Batch:  1366
Batch:  1367
Batch:  1368
Batch:  1369
Batch:  1370
Batch:  1371
Batch:  1372
Batch:  1373
Batch:  1374
Batch:  1375
Batch:  1376
Batch:  1377
Batch:  1378
Batch:  1379
Batch:  1380
Batch:  1381
Batch:  1382
Batch:  1383
Batch:  1384
Batch:  1385
Batch:  1386
Batch:  1387
Batch:  1388
Batch:  1389
Batch:  1390