In [1]:
import sys
from transformers import ViTImageProcessor, RobertaTokenizer, Blip2Processor, Blip2ForConditionalGeneration, InstructBlipProcessor, InstructBlipForConditionalGeneration

from datasets import list_metrics

from diffusers import AudioLDMPipeline

import torch
from torch.utils.data import Dataset, DataLoader

from PIL import Image, ImageFile
ImageFile.LOAD_TRUNCATED_IMAGES = True

import pandas as pd


from peft import LoraConfig, get_peft_model

In [2]:
device = "cuda" if torch.cuda.is_available else "cpu"
print(device)

cuda


In [3]:
processor = Blip2Processor.from_pretrained("Salesforce/blip2-opt-2.7b")
model = Blip2ForConditionalGeneration.from_pretrained("ybelkada/blip2-opt-2.7b-fp16-sharded", device_map="auto", load_in_8bit=True, torch_dtype=torch.float32)
tokenizer = RobertaTokenizer.from_pretrained("FacebookAI/roberta-base")

The `load_in_4bit` and `load_in_8bit` arguments are deprecated and will be removed in the future versions. Please, pass a `BitsAndBytesConfig` object in `quantization_config` argument instead.


Loading checkpoint shards:   0%|          | 0/8 [00:00<?, ?it/s]

  return self.fget.__get__(instance, owner)()


In [4]:
# class CustomDataset(Dataset):
#     def __init__(self, dataframe):
#         self.dataframe = dataframe
        
#     def __len__(self):
#         return len(self.dataframe)
    
#     def __getitem__(self, idx):
#         # transform = transforms.Compose([
#         #     transforms.PILToTensor(),
#         # ])
#         image = Image.open(f"images/{idx}.png")
#         image_features = processor(image, return_tensors="pt").pixel_values

#         labels = tokenizer(self.dataframe["caption"][idx],return_tensors="pt",
#                                           max_length=46,
#                                           pad_to_max_length=True,
#                                           return_token_type_ids=True,
#                                           truncation=True).input_ids
#         return {'pixel_values':image_features.squeeze(0),'labels':self.dataframe["caption"][idx], "idx":idx}

class CustomDataset(Dataset):
    def __init__(self, dataset, processor):
        self.dataset = dataset
        self.processor = processor

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        image = Image.open(f"images/{idx}.png")
        encoding = self.processor(images=image, return_tensors="pt")
        # remove batch dimension
        encoding = {k: v.squeeze() for k, v in encoding.items()}
        encoding["text"] = self.dataset["caption"][idx]
        return encoding

### Generate outputs to see base model results

In [5]:
image = Image.open("images/1.png")
inputs = processor(images=image, return_tensors="pt").to(device, torch.float32)
outputs = model.generate(**inputs, max_length = 30)



In [6]:
generated_text = processor.batch_decode(outputs, skip_special_tokens=True)[0].strip()
print(generated_text)

a close up of a guitar with strings and strings


### Add LoRA layer

In [7]:
config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="lora_only",
    target_modules=["q_proj", "k_proj"]
)

training_model = get_peft_model(model, config)
training_model.print_trainable_parameters()

trainable params: 5,406,720 || all params: 3,749,922,816 || trainable%: 0.1441821676150467


### Loading data

In [8]:
data = pd.read_csv("musiccaps-public.csv")

In [9]:
train_data = data[data["is_audioset_eval"] == False].reset_index(drop=True)
display(train_data)

Unnamed: 0,ytid,start_s,end_s,audioset_positive_labels,aspect_list,caption,author_id,is_balanced_subset,is_audioset_eval
0,-0SdAVK79lg,30,40,"/m/0155w,/m/01lyv,/m/0342h,/m/042v_gx,/m/04rlf...","['guitar song', 'piano backing', 'simple percu...",This song features an electric guitar as the m...,0,False,False
1,-1LrH01Ei1w,30,40,"/m/02p0sh1,/m/04rlf","['rubab instrument', 'repetitive melody on dif...",This song features a rubber instrument being p...,0,False,False
2,-4NLarMj4xU,30,40,"/m/04rlf,/t/dd00034","['pop', 'tinny wide hi hats', 'mellow piano me...",The Pop song features a soft female vocal sing...,4,False,False
3,-5f6hjZf9Yw,30,40,"/m/02w4v,/m/04rlf","['folk music', 'rubab', 'male voice', 'slow te...",This folk song features a male voice singing t...,0,False,False
4,-5xOcMJpTUk,70,80,"/m/018vs,/m/0342h,/m/042v_gx,/m/04rlf,/m/04szw...","['guitarist', 'male talking', 'twang sounds', ...",A male guitarist plays the guitar and speaks a...,1,False,False
...,...,...,...,...,...,...,...,...,...
2658,zqbHYVH6Wqo,30,40,"/m/085jw,/m/0l14l2","['traditional horn instruments', 'devotional',...",The song is an instrumental. The tempo is medi...,1,False,False
2659,zrb76mJOZQQ,3,13,"/m/0395lw,/m/0gy1t2s","['amateur recording', 'no music', 'sound of be...",This amateur recording features the sound of t...,0,False,False
2660,zu_1zpF--Zg,0,10,"/m/01xqw,/m/02fsn,/m/0d8_n,/m/0l14_3","['amateur recording', 'jazz/bossa-nova', 'upri...",This audio contains someone playing jazz chord...,6,False,False
2661,zw5dkiklbhE,15,25,"/m/01sm1g,/m/0l14md","['amateur recording', 'percussion', 'wooden bo...",This audio contains someone playing a wooden b...,6,False,False


In [10]:
val_data = data[data["is_audioset_eval"] == True].reset_index(drop=True)
display(val_data)

Unnamed: 0,ytid,start_s,end_s,audioset_positive_labels,aspect_list,caption,author_id,is_balanced_subset,is_audioset_eval
0,-0Gj8-vB1q4,30,40,"/m/0140xf,/m/02cjck,/m/04rlf","['low quality', 'sustained strings melody', 's...",The low quality recording features a ballad so...,4,False,True
1,-0vPFx-wRRI,30,40,"/m/025_jnm,/m/04rlf","['amateur recording', 'finger snipping', 'male...",a male voice is singing a melody with changing...,6,False,True
2,-0xzrMun0Rs,30,40,"/m/01g90h,/m/04rlf","['backing track', 'jazzy', 'digital drums', 'p...",This song contains digital drums playing a sim...,6,False,True
3,-1OlgJWehn8,30,40,"/m/04rlf,/m/06bz3","['instrumental', 'white noise', 'female vocali...",This clip is three tracks playing consecutivel...,7,False,True
4,-1UWSisR2zo,30,40,"/m/04rlf,/m/0xzly","['live performance', 'poor audio quality', 'am...",A male singer sings this groovy melody. The so...,1,False,True
...,...,...,...,...,...,...,...,...,...
2853,zrrM6Qg2Dwg,30,40,"/m/04rlf,/m/0l156b","['steel pan music', 'happy mood', 'caribbean f...",This song features the main melody played on a...,0,True,True
2854,ztfegVzqeCI,30,40,"/m/015lz1,/m/01v1d8,/m/04rlf,/m/07kc_,/m/0l14qv","['female singer', 'synth bass', 'reverb', 'the...",A quirky drum machine and warm synth bass prov...,8,True,True
2855,zwfo7wnXdjs,30,40,"/m/02p0sh1,/m/04rlf,/m/06j64v","['instrumental music', 'arabic music', 'genera...",The song is an instrumental. The song is mediu...,1,True,True
2856,zx_vcwOsDO4,50,60,"/m/01glhc,/m/02sgy,/m/0342h,/m/03lty,/m/04rlf,...","['instrumental', 'no voice', 'electric guitar'...",The rock music is purely instrumental and feat...,2,True,True


In [11]:
train_data = CustomDataset(train_data, processor)
val_data = CustomDataset(val_data, processor)

In [12]:
def collate_fn(batch):
    # pad the input_ids and attention_mask
    processed_batch = {}
    for key in batch[0].keys():
        if key != "text":
            processed_batch[key] = torch.stack([example[key] for example in batch])
        else:
            text_inputs = processor.tokenizer(
                [example["text"] for example in batch], padding=True, return_tensors="pt"
            )
            processed_batch["input_ids"] = text_inputs["input_ids"]
            processed_batch["attention_mask"] = text_inputs["attention_mask"]
    return processed_batch

In [13]:
train_dataloader = DataLoader(train_data, shuffle=True, batch_size=3, collate_fn=collate_fn)

In [14]:
from transformers import Trainer, TrainingArguments
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')

In [None]:
optimizer = torch.optim.Adam(training_model.parameters(), lr=5e-4)
# loss = torch.nn.CosineEmbeddingLoss()
device = "cuda" if torch.cuda.is_available() else "cpu"

training_model.train()

for epoch in range(5):
  print("Epoch:", epoch)
  for idx, batch in enumerate(train_dataloader):
    input_ids = batch.pop("input_ids").to(device)
    pixel_values = batch.pop("pixel_values").to(device, torch.float32)
    outputs = training_model(input_ids=input_ids,
                    pixel_values=pixel_values,
                    labels=input_ids)
    
    loss = outputs.loss

    print("Loss:", loss.item())
      
    loss.backward()

    optimizer.step()
    optimizer.zero_grad()
  print(f"Epoch {epoch} Loss:", loss.item())

Epoch: 0
Loss: 2.2588915824890137
Loss: 1.9077609777450562
Loss: 2.089205503463745
Loss: 1.6575348377227783
Loss: 2.530764102935791
Loss: 1.530328631401062
Loss: 1.9788475036621094
Loss: 1.7268588542938232
Loss: 2.0601391792297363
Loss: 2.533592939376831
Loss: 1.9488040208816528
Loss: 1.663870096206665
Loss: 2.2682833671569824
Loss: 1.4550635814666748
Loss: 1.8368775844573975
Loss: 2.1807525157928467
Loss: 1.6045293807983398
Loss: 1.6954013109207153
Loss: 1.535036325454712
Loss: 1.625605821609497
Loss: 2.015024423599243
Loss: 1.7804661989212036
Loss: 2.141453504562378
Loss: 2.0120861530303955
Loss: 2.4234282970428467
Loss: 1.5865806341171265
Loss: 1.6444388628005981
Loss: 1.7841315269470215
Loss: 2.5273563861846924
Loss: 1.6347119808197021
Loss: 1.6874228715896606
Loss: 1.364664077758789
Loss: 2.335240125656128
Loss: 1.71576988697052
Loss: 2.1358940601348877
Loss: 2.124706745147705
Loss: 1.608621597290039
Loss: 2.097714424133301
Loss: 2.098231792449951
Loss: 2.0292553901672363
Loss: 2.

In [None]:
training_model.eval()

In [None]:
image = Image.open("images/1.png")
inputs = processor(images=image, return_tensors="pt").to(device, torch.float32)
pixel_values = inputs.pixel_values

generated_ids = training_model.generate(pixel_values=pixel_values, max_length=30)
generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
print(generated_caption)

In [None]:
torch.save(training_model.state_dict(), "model_v3.pt")