In [1]:
!pip install wandb
import wandb
wandb.login()



[34m[1mwandb[0m: Using wandb-core as the SDK backend.  Please refer to https://wandb.me/wandb-core for more information.
[34m[1mwandb[0m: Currently logged in as: [33mahmadsait[0m ([33mahmadsait-king-abdullah-university-of-science-and-technology[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin


True

In [2]:
!pip install transformers datasets accelerate torchvision pandas



In [1]:
import pandas as pd
from PIL import Image
from datasets import load_dataset, Dataset
from transformers import BlipProcessor, BlipForConditionalGeneration, TrainingArguments, Trainer
import torch
import unicodedata

In [1]:
df = pd.read_csv("allam_enhancements_only.csv")

blip_df = df[["image_file", "allam_arabic"]].rename(columns={
    "image_file": "image",
    "allam_arabic": "caption"
})

blip_df["image"] = blip_df["image"].apply(lambda x: x.replace("YOUR/PATH/TO/WIKIART", "artelingo/dataset/wikiart"))

blip_df.to_csv("blip_caption_data.csv", index=False, encoding="utf-8-sig")

In [8]:
#filtering out bad samples from the dataset

df = pd.read_csv("blip_caption_data.csv", encoding="utf-8-sig")
df["image"] = df["image"].apply(lambda x: unicodedata.normalize("NFC", x))
dataset = Dataset.from_pandas(df)
from PIL import Image
from tqdm import tqdm
import os
import json

broken_indices = []
MAX_PIXELS = 89_478_485 # PIL's threshold

for i in tqdm(range(len(dataset)), desc="Checking images"):
    path = dataset[i]["image"]
    if not os.path.exists(path):
        broken_indices.append(i)
        continue
    try:
        with Image.open(path) as img:
            if img.width * img.height > MAX_PIXELS:
                print(f"[Skipped] {path} exceeds safe pixel size.")
                broken_indices.append(i)
                continue
            img = img.convert("RGB").copy()
    except Exception as e:
        print(f"Bad image at {i}: {path}, reason: {e}")
        broken_indices.append(i)

with open("broken_image_indices.json", "w") as f:
    json.dump(broken_indices, f)
    
with open("broken_image_indices.json") as f:
    broken_indices = set(json.load(f))

dataset = dataset.filter(lambda example, idx: idx not in broken_indices, with_indices=True)

Checking images:  10%|█         | 40183/386410 [04:44<38:12, 151.00it/s] 

[Skipped] artelingo/dataset/wikiart/Color_Field_Painting/barnett-newman_uriel-1955.jpg exceeds safe pixel size.
[Skipped] artelingo/dataset/wikiart/Color_Field_Painting/barnett-newman_uriel-1955.jpg exceeds safe pixel size.
[Skipped] artelingo/dataset/wikiart/Color_Field_Painting/barnett-newman_uriel-1955.jpg exceeds safe pixel size.
[Skipped] artelingo/dataset/wikiart/Color_Field_Painting/barnett-newman_uriel-1955.jpg exceeds safe pixel size.
[Skipped] artelingo/dataset/wikiart/Color_Field_Painting/barnett-newman_uriel-1955.jpg exceeds safe pixel size.


Checking images:  42%|████▏     | 163678/386410 [21:02<31:11, 119.00it/s]

[Skipped] artelingo/dataset/wikiart/Color_Field_Painting/barnett-newman_vir-heroicus-sublimis-1950.jpg exceeds safe pixel size.
[Skipped] artelingo/dataset/wikiart/Color_Field_Painting/barnett-newman_vir-heroicus-sublimis-1950.jpg exceeds safe pixel size.
[Skipped] artelingo/dataset/wikiart/Color_Field_Painting/barnett-newman_vir-heroicus-sublimis-1950.jpg exceeds safe pixel size.
[Skipped] artelingo/dataset/wikiart/Color_Field_Painting/barnett-newman_vir-heroicus-sublimis-1950.jpg exceeds safe pixel size.
[Skipped] artelingo/dataset/wikiart/Color_Field_Painting/barnett-newman_vir-heroicus-sublimis-1950.jpg exceeds safe pixel size.
[Skipped] artelingo/dataset/wikiart/Color_Field_Painting/barnett-newman_vir-heroicus-sublimis-1950.jpg exceeds safe pixel size.


Checking images: 100%|██████████| 386410/386410 [53:14<00:00, 120.95it/s] 


Filter:   0%|          | 0/386410 [00:00<?, ? examples/s]

In [15]:
df_filtered = dataset.to_pandas()

df_filtered.to_csv("filtered_blip_caption_data.csv", index=False, encoding="utf-8-sig")

In [4]:
df = pd.read_csv("filtered_blip_caption_data.csv", encoding="utf-8-sig")
df["image"] = df["image"].apply(lambda x: unicodedata.normalize("NFC", x))

In [5]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [6]:
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base").to(device)

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [7]:
!nproc

3


In [8]:
from torch.utils.data import Dataset
from PIL import Image

class BlipOnTheFlyDataset(Dataset):
    def __init__(self, df, processor):
        self.data = df.reset_index(drop=True)
        self.processor = processor

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        row = self.data.loc[idx]
        image_path = row["image"]
        caption = row["caption"]
    
        try:
            image = Image.open(image_path).convert("RGB")
        except:
            image = Image.new("RGB", (224, 224))
    
        inputs = self.processor(
            images=image,
            text=caption,
            return_tensors="pt",
            padding="max_length",
            truncation=True
        )
    
        return {
            "pixel_values": inputs["pixel_values"].squeeze(0),
            "input_ids": inputs["input_ids"].squeeze(0),
            "attention_mask": inputs["attention_mask"].squeeze(0),
            "labels": inputs["input_ids"].squeeze(0),
        }

In [9]:
df = df.iloc[:332152].reset_index(drop=True)
dataset = BlipOnTheFlyDataset(df, processor)

In [12]:
def blip_data_collator(features):
    return {
        "pixel_values": torch.stack([f["pixel_values"] for f in features]),
        "input_ids": torch.stack([f["input_ids"] for f in features]),
        "attention_mask": torch.stack([f["attention_mask"] for f in features]),
        "labels": torch.stack([f["labels"] for f in features]),
    }



training_args = TrainingArguments(
    output_dir="./blip-finetuned-allam",         
    per_device_train_batch_size=32,
    num_train_epochs=5,
    save_strategy="epoch",                              
    logging_steps=50,
    save_total_limit=5,                        
    fp16=True,
    disable_tqdm=False,
    report_to="wandb",                      
    run_name="blip-allam"
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=processor.tokenizer,
    data_collator=blip_data_collator
)

trainer.train()
trainer.save_model("./blip-finetuned-allam")


  trainer = Trainer(


Step,Training Loss
50,0.7966
100,0.4251
150,0.4008
200,0.3879
250,0.3595
300,0.3332
350,0.3313
400,0.3167
450,0.2969
500,0.2951


It took 18 hours and 32 minutes to train BLIP on ALLAM captions

In [2]:
processor = BlipProcessor.from_pretrained("Salesforce/blip-image-captioning-base")

Using a slow image processor as `use_fast` is unset and a slow processor was saved with this model. `use_fast=True` will be the default behavior in v4.52, even if the model was saved with a slow processor. This will result in minor differences in outputs. You'll still be able to use a slow processor with `use_fast=False`.


In [10]:
from PIL import Image
from transformers import BlipProcessor, BlipForConditionalGeneration
import torch
import os

image_path = "../../../../ibex/ai/home/saitaa0b/wikiart/Ukiyo_e/utagawa-kuniyoshi_women-8.jpg"
image = Image.open(image_path).convert("RGB")

# listing saved checkpoint directories
checkpoint_dirs = [
    "checkpoint-10380",
    "checkpoint-20760",
    "checkpoint-31140",
    "checkpoint-41520",
    "checkpoint-51900"
]

# running inference for each checkpoint
for ckpt in checkpoint_dirs:
    model_path = os.path.join("blip-finetuned-allam", ckpt)
    model = BlipForConditionalGeneration.from_pretrained(model_path).to("cuda")
    model.eval()

    inputs = processor(images=image, return_tensors="pt").to("cuda")

    with torch.no_grad():
        output = model.generate(**inputs, max_length=128)
        caption = processor.tokenizer.decode(output[0], skip_special_tokens=True)
        print(f"[{ckpt}] Caption: {caption}")


[checkpoint-10380] Caption: السيدة تجلس على الشاطي ، وامامها المركب الصغير ، والسماء تزينها السحب البيضاء.
[checkpoint-20760] Caption: امراة ترتدي ملابس سوداء ، تجلس على متن قارب ، وتبدو عليها علامات الحزن والاسى.
[checkpoint-31140] Caption: امراة ترتدي ملابس سوداء تجلس على مركب صغير ، وتبدو وكانها تستمتع باللعب في المياه.
[checkpoint-41520] Caption: امراة ترتدي فستانا اسودا مزينا بنقوش بديعة ، تجلس على متن قارب خشبي صغير.
[checkpoint-51900] Caption: امراة ترتدي فستانا اسود مزينا بنقوش بديعة ، تجلس على متن قارب صيد صيد صيد صيد منيف.


epoch 3 clearly conveys emotion:

[checkpoint-10380] Caption: السيدة ترتدي ملابس ذات لون احمر ، وتحمل طفلها الصغير بين ذراعيها.

[checkpoint-20760] Caption: امراة ترتدي ملابس حمراء وتحمل طفلها الرضيع ، وتضع على راسها غطاء احمر.

[checkpoint-31140] Caption: امراة تحتضن طفلها الصغير ، وتبدو عليها علامات الحزن والاسى.

[checkpoint-41520] Caption: امراة ترتدي ملابس ذات لون احمر وتحمل طفلا صغيرا ، وتضع على راسها غطاء احمر.

[checkpoint-51900] Caption: امراة ترتدي ملابس حمراء وتحمل طفلا صغيرا ، وتقوم بارضاعه.
