In [1]:
import datasets
import pandas as pd
from pathlib import Path
from PIL import Image
from transformers import AutoTokenizer
from transformers import Blip2Processor, Blip2VisionModel, Blip2QFormerModel, Blip2QFormerConfig, Blip2ForConditionalGeneration
from transformers import AutoProcessor, Blip2ForConditionalGeneration
import os
import bitsandbytes as bnb

import torch

from peft import LoraConfig, get_peft_model, LoftQConfig
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import (
    LoraConfig,
    PeftConfig,
    get_peft_model,
    prepare_model_for_kbit_training,
)
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    BitsAndBytesConfig,)

Could not find the bitsandbytes CUDA binary at PosixPath('/home/ai4103/.conda/envs/main/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cuda121.so')
Could not load bitsandbytes native library: /home/ai4103/.conda/envs/main/lib/python3.10/site-packages/bitsandbytes/libbitsandbytes_cpu.so: cannot open shared object file: No such file or directory
Traceback (most recent call last):
  File "/home/ai4103/.conda/envs/main/lib/python3.10/site-packages/bitsandbytes/cextension.py", line 109, in <module>
    lib = get_native_library()
  File "/home/ai4103/.conda/envs/main/lib/python3.10/site-packages/bitsandbytes/cextension.py", line 96, in get_native_library
    dll = ct.cdll.LoadLibrary(str(binary_path))
  File "/home/ai4103/.conda/envs/main/lib/python3.10/ctypes/__init__.py", line 452, in LoadLibrary
    return self._dlltype(name)
  File "/home/ai4103/.conda/envs/main/lib/python3.10/ctypes/__init__.py", line 374, in __init__
    self._handle = _dlopen(self._name, mode)
OSError: /hom

In [2]:
from peft import LoftQConfig, LoraConfig, get_peft_model

# 4-bit NF4 quantization settings; they only take effect when passed as
# `quantization_config` while loading a model.
bnb_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

# LoRA: rank-16 adapters (alpha 32), dropout 0.05, no bias training, applied
# only to modules named q_proj / k_proj.
config = LoraConfig(
    target_modules=["q_proj", "k_proj"],
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
)

In [3]:
class BlipZSIC:
    """BLIP-2 captioner whose language model is re-tokenized and tuned with LoRA.

    The BLIP-2 checkpoint supplies the vision encoder / Q-Former / OPT decoder;
    the decoder's token embeddings are resized to a replacement tokenizer so that
    captions are generated in that tokenizer's vocabulary.

    Adapter indexing used by ``switchAdapter`` and ``forward(modeltype=...)``:
      * 0      -> all LoRA adapter layers disabled
      * 1..N   -> the N-th adapter registered through ``addAdapter``
    """

    def __init__(
        self,
        bnb_config: "BitsAndBytesConfig",
        loraConfig: "LoraConfig",
        base_model_name: str = "Salesforce/blip2-opt-2.7b",
        tokenizer_name: str = "airesearch/wangchanberta-base-att-spm-uncased",
    ) -> None:
        """Load the base model, processor and replacement tokenizer.

        Args:
            bnb_config: bitsandbytes quantization config. Currently NOT applied
                (quantized loading is disabled, see TODO below).
            loraConfig: LoRA config used by ``compileModel`` and ``addAdapter``.
            base_model_name: Hugging Face id of the BLIP-2 checkpoint.
            tokenizer_name: Hugging Face id of the tokenizer that replaces OPT's.
        """
        # TODO: pass quantization_config=bnb_config (and re-test) to load quantized.
        self.base = Blip2ForConditionalGeneration.from_pretrained(
            base_model_name,
            device_map={"": 0},
        )
        self.processor = Blip2Processor.from_pretrained(base_model_name)
        self.tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

        # Swap OPT's vocabulary for the new tokenizer's. BLIP-2's captioning loss
        # reshapes logits using config.text_config.vocab_size, so it must equal the
        # resized embedding size exactly -- derive it instead of hard-coding it.
        # NOTE(review): shrinking with resize_token_embeddings likely just truncates
        # OPT's embedding matrix, so new token ids reuse unrelated pretrained
        # vectors, and LoRA on q_proj/k_proj leaves embeddings/lm_head frozen.
        self.base.language_model.resize_token_embeddings(len(self.tokenizer))
        self.base.config.text_config.vocab_size = len(self.tokenizer)
        # Presumably the new tokenizer's end-of-sequence id -- verify against
        # self.tokenizer.eos_token_id.
        self.base.config.eos_token_id = 6
        self.processor.tokenizer = self.tokenizer

        self.model = None  # PEFT-wrapped model, set by compileModel()
        self.loraConfig = loraConfig
        self.adapterList = []  # adapter names added via addAdapter(), in order
        self.currentState = ''  # last index given to switchAdapter ('' = never switched)

    def compileModel(self) -> None:
        """Wrap ``self.base`` with LoRA (``self.loraConfig``) and print trainable-parameter stats."""
        self.model = get_peft_model(self.base, self.loraConfig)
        self.model.print_trainable_parameters()

    def addAdapter(self, adapter: str) -> None:
        """Create a new LoRA adapter named ``adapter`` using ``self.loraConfig``.

        The adapter becomes selectable as ``switchAdapter(len(self.adapterList))``.
        Requires ``compileModel()`` to have been called.
        NOTE(review): this creates a freshly initialised adapter under that name; it
        does not load weights from a path (PeftModel.load_adapter does that).
        """
        # PEFT's signature is add_adapter(adapter_name, peft_config).
        self.model.add_adapter(adapter_name=adapter, peft_config=self.loraConfig)
        # Record the name only once the adapter actually exists.
        self.adapterList.append(adapter)

    def switchAdapter(self, adapterNum: int) -> None:
        """Activate adapter ``adapterNum`` (1-based) or disable all adapters with 0.

        Out-of-range indices (including negatives) print a message and change nothing.
        """
        if adapterNum == 0:
            self.model.base_model.disable_adapter_layers()
            self.currentState = adapterNum
            print("adapters disabled")
            return

        # Explicit bounds check: negative values must not wrap around the list.
        if not 1 <= adapterNum <= len(self.adapterList):
            print("index out of range, returning")
            return

        if self.currentState == adapterNum:
            return

        if self.currentState == 0:
            # set_adapter alone does not undo a previous disable.
            self.model.base_model.enable_adapter_layers()
        # Index 0 means "disabled", so adapter k lives at adapterList[k - 1].
        self.model.set_adapter(self.adapterList[adapterNum - 1])
        self.currentState = adapterNum
        print(f"switched to adapter {adapterNum}")

    def forward(self, input_ids, pixel_values, modeltype=-1, attention_mask=None):
        """Run a forward pass with the captions as their own language-model labels.

        Args:
            input_ids: tokenized target captions; also used as labels.
            pixel_values: images preprocessed by ``self.processor``.
            modeltype: adapter index for ``switchAdapter``; -1 keeps the current one.
            attention_mask: optional tokenizer mask. When given, padded positions
                (mask == 0) get label -100 so they are excluded from the loss.

        Returns:
            The model output; ``.loss`` holds the captioning loss.
        """
        if modeltype != -1:
            self.switchAdapter(modeltype)

        labels = input_ids
        if attention_mask is not None:
            # -100 is the cross-entropy ignore_index: do not train on padding.
            labels = input_ids.masked_fill(attention_mask == 0, -100)

        return self.model(
            input_ids=input_ids,
            pixel_values=pixel_values,
            attention_mask=attention_mask,
            labels=labels,
        )

    def predict(self, imgs, **generate_kwargs):
        """Generate a caption for ``imgs`` and return the first decoded caption.

        Args:
            imgs: image(s) accepted by ``self.processor``.
            **generate_kwargs: forwarded to ``generate`` (e.g. ``max_new_tokens``).

        Returns:
            The decoded caption (str) of the first generated sequence.
        """
        # Leaves the model in eval mode; call .train() again before training.
        self.model.eval()
        # Use the device the model actually lives on instead of a notebook global.
        device = next(self.model.parameters()).device
        pixel_values = self.processor(imgs, return_tensors="pt").to(device).pixel_values

        with torch.no_grad():
            outputs = self.model.generate(pixel_values=pixel_values, **generate_kwargs)

        return self.processor.batch_decode(outputs, skip_special_tokens=True)[0]
        
Blip = BlipZSIC(bnb_config=bnb_config, loraConfig=config)  # base model + processor + Thai tokenizer

Loading checkpoint shards:   0%|          | 0/2 [00:00<?, ?it/s]

In [4]:
# Wrap the base model with the LoRA config and print the trainable-parameter summary.
Blip.compileModel()

trainable params: 5,242,880 || all params: 3,685,236,736 || trainable%: 0.1423


In [5]:
from torch.utils.data import Dataset, DataLoader

class ImageCaptioningDataset(Dataset):
    """Map-style dataset of processor-encoded images paired with raw caption text.

    Each item of ``dataset`` must provide ``"image"`` and ``"text"``. The text is
    left untokenized so the collate function can pad captions per batch.
    """

    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        item = self.dataset[idx]
        encoding = self.processor(images=item["image"], padding="max_length", return_tensors="pt")
        # Drop only the leading batch axis of size 1; a bare squeeze() would also
        # collapse any other size-1 axis and corrupt the tensor shape.
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}
        encoding["text"] = item["text"]
        return encoding

def collate_fn(batch, tokenizer=None):
    """Stack image tensors and tokenize/pad the captions of one batch.

    Args:
        batch: list of dicts from ``ImageCaptioningDataset``; every key except
            ``"text"`` holds a tensor of identical shape across examples.
        tokenizer: tokenizer used for the captions. Defaults to
            ``Blip.processor.tokenizer`` (looked up only when text is present), so
            DataLoader can keep calling ``collate_fn(batch)``.

    Returns:
        dict with the stacked tensors plus ``input_ids`` and ``attention_mask``.
    """
    processed_batch = {}
    for key in batch[0].keys():
        if key != "text":
            processed_batch[key] = torch.stack([example[key] for example in batch])
            continue
        if tokenizer is None:
            tokenizer = Blip.processor.tokenizer
        # Pad to the longest caption in this batch only.
        text_inputs = tokenizer(
            [example["text"] for example in batch], padding=True, return_tensors="pt"
        )
        processed_batch["input_ids"] = text_inputs["input_ids"]
        processed_batch["attention_mask"] = text_inputs["attention_mask"]
    return processed_batch

# Data and Dataloader setup


In [6]:
# Collect the training images. sorted() makes the order -- and therefore the
# seeded train/test split built from it -- independent of filesystem listing order.
images_path = sorted(Path("/project/lt900331-ai24nr/blip/Image").glob("*.jpg"))

# Placeholder labels: every image currently gets the same Thai caption.
labels = pd.DataFrame(
    {
        "image": images_path,
        "caption": ["น่าจะเป็นดาวแหละ"] * len(images_path),
    }
)
# labels = pd.read_csv("/root/Datasets/preprocess.csv")

In [7]:
# labels["image"] already holds the full paths produced by the glob, so open them
# directly (joining them onto images_path[0].parent was redundant and would double
# the directory if the paths were relative). Converting inside `with` reads the
# pixels eagerly and closes each file instead of keeping lazy handles open.
images = []
for path in labels["image"]:
    with Image.open(path) as img:
        images.append(img.convert("RGB"))

dataset = datasets.Dataset.from_dict({"image": images, "text": labels["caption"]})
dataset = dataset.train_test_split(test_size=0.1, seed=42)

train_dataset = ImageCaptioningDataset(dataset["train"], Blip.processor)
train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=4, collate_fn=collate_fn)

# Evaluation does not benefit from shuffling; keep its order deterministic.
test_dataset = ImageCaptioningDataset(dataset["test"], Blip.processor)
test_dataloader = DataLoader(test_dataset, shuffle=False, batch_size=4, collate_fn=collate_fn)

# Training

In [8]:
import torch
from tqdm import tqdm
# Device the batches are moved to before each forward pass.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"

optimizer = torch.optim.Adam(Blip.model.parameters(), lr=5e-4)
Blip.model.train()

# Captions serve as their own labels inside Blip.forward; pass modeltype=<int>
# to Blip.forward to train a specific adapter.
for epoch in range(10):
    print("Epoch:", epoch)
    training_loss = 0
    for batch in tqdm(train_dataloader):
        input_ids = batch.pop("input_ids").to(device)
        pixel_values = batch.pop("pixel_values").to(device)

        loss = Blip.forward(input_ids=input_ids, pixel_values=pixel_values).loss
        training_loss += loss.item()

        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

    # Mean loss over the epoch's batches.
    print(training_loss / len(train_dataloader))

Epoch: 0


100%|██████████| 1/1 [00:02<00:00,  2.46s/it]


10.002299308776855
Epoch: 1


100%|██████████| 1/1 [00:00<00:00,  3.50it/s]


9.557210922241211
Epoch: 2


100%|██████████| 1/1 [00:00<00:00,  3.54it/s]


8.596506118774414
Epoch: 3


100%|██████████| 1/1 [00:00<00:00,  3.56it/s]


7.382016181945801
Epoch: 4


100%|██████████| 1/1 [00:00<00:00,  3.56it/s]


6.052756309509277
Epoch: 5


100%|██████████| 1/1 [00:00<00:00,  3.56it/s]


4.780925750732422
Epoch: 6


100%|██████████| 1/1 [00:00<00:00,  3.56it/s]


3.195636510848999
Epoch: 7


100%|██████████| 1/1 [00:00<00:00,  3.55it/s]


2.720263719558716
Epoch: 8


100%|██████████| 1/1 [00:00<00:00,  3.55it/s]


1.5085463523864746
Epoch: 9


100%|██████████| 1/1 [00:00<00:00,  3.57it/s]

0.9145419001579285





In [9]:
# First image of the held-out split, used for a quick qualitative check.
sample_img = dataset["test"][0]["image"]

In [10]:
Blip.predict(imgs=sample_img)

'จน่าจะเป็นดาวแหละ น่าจะเป็นดาวแหละ น่าจะเป็นดาวแหละ น่าจะเป็นดาว'