In [None]:
import os
import pandas as pd
import httpx
from bs4 import BeautifulSoup

import os
import pandas as pd
from torchvision.io import read_image
from torch.utils.data import Dataset, DataLoader

from transformers import AutoProcessor, Blip2ForConditionalGeneration
import torch


In [None]:
# Read the labels CSV into `df`; the download cell below reads its 'painting' column.
df = pd.read_csv('Contrastive.csv')

In [None]:
# download the image files themselves (the original ZIP FILE IS 25G so we'll scrape a small subset lol)
#
# Each 'painting' value is expected to look like "<painter>_<painting>"; the two
# parts are turned into a WikiArt URL and the response body is stored as
# imgs/<painting value>.jpeg. A file is written only after a successful (2xx)
# response, so error pages are never saved under a .jpeg name.

base_url = "https://uploads8.wikiart.org/images/"
img_dir = 'imgs/'
# The previous loop broke out after handling row index 25, i.e. it visited 26 rows.
n_rows = 26

# open() does not create missing directories.
os.makedirs(img_dir, exist_ok=True)

seen = set()  # painting names already attempted in this run
for _, row in df.head(n_rows).iterrows():
  name = row['painting']
  # The same painting can appear on several rows; one request per name is enough.
  if name in seen:
    continue
  seen.add(name)

  item_name_spl = name.split('_')
  if len(item_name_spl) < 2:
    print(f'skipping {name}: expected "<painter>_<painting>"')
    continue
  painter = item_name_spl[0]
  painting = item_name_spl[1]
  scrape_url = base_url + painter + '/' + painting + '.jpg!Large.jpg'

  try:
    img = httpx.get(scrape_url)
    # Raises httpx.HTTPStatusError for non-success responses (e.g. 404).
    img.raise_for_status()
  except httpx.HTTPError as e:
    # httpx.HTTPError covers both transport failures and bad status codes.
    print(f'skipping {name}: {e}')
    continue

  with open(os.path.join(img_dir, name + '.jpeg'), 'wb') as outf:
    outf.write(img.content)

In [None]:
# define a dataset class
class ArtemisDataset(Dataset):
    """Map-style dataset that pairs image files on disk with rows of a labels CSV.

    Every row of the CSV must provide three columns:
      * ``painting``  - file stem; the image is read from ``<img_dirname>/<painting>.jpeg``
      * ``utterance`` - text that is tokenized into ``input_ids`` / ``attention_mask``
      * ``emotion``   - passed through unchanged

    Args:
        processor: callable image processor that also exposes a ``tokenizer``
            attribute (e.g. the object returned by ``AutoProcessor.from_pretrained``).
        labels_fname: path of the labels CSV.
        img_dirname: directory containing the ``.jpeg`` files.
        max_rows: keep only the first ``max_rows`` rows of the CSV. Defaults to 24,
            the cap that used to be hard-coded; ``None`` keeps every row.
    """

    def __init__(self, processor, labels_fname="Contrastive.csv", img_dirname="imgs/", max_rows=24):
        # iloc[:None] selects everything, so max_rows=None disables the cap.
        self.img_labels = pd.read_csv(labels_fname).iloc[:max_rows]
        self.img_dirname = img_dirname
        self.processor = processor

    def __len__(self):
        """Return the number of CSV rows kept by this dataset."""
        return len(self.img_labels)

    def __getitem__(self, idx):
        """Return one example as a dict.

        The dict holds every entry produced by ``processor(image, ...)`` with the
        leading batch dimension removed, plus ``input_ids`` and ``attention_mask``
        exactly as returned by ``processor.tokenizer(utterance)``, and ``emotion``.

        Raises:
            IndexError: if ``idx`` is outside the kept rows (raised by ``iloc``).
        """
        # Look the row up once instead of once per column.
        row = self.img_labels.iloc[idx]
        img_path = os.path.join(self.img_dirname, row['painting'] + '.jpeg')
        # NOTE(review): read_image keeps the file's native channel count; confirm
        # the processor copes with grayscale / RGBA files if any are present.
        image = read_image(img_path)

        encoding = self.processor(image, padding="max_length", return_tensors="pt")
        # Remove only the leading batch dimension; a bare squeeze() would also
        # drop any other axis that happens to have size 1.
        encoding = {k: v.squeeze(0) for k, v in encoding.items()}

        # Tokenize once and reuse the result for both entries.
        tokens = self.processor.tokenizer(row['utterance'])
        encoding["input_ids"] = tokens["input_ids"]
        encoding["attention_mask"] = tokens["attention_mask"]
        encoding["emotion"] = row['emotion']

        return encoding

In [None]:
# load models
# The processor is taken from the Salesforce checkpoint; the model weights come
# from the sharded fp16 checkpoint and are loaded in 8-bit with automatic
# device placement.
processor = AutoProcessor.from_pretrained(
    "Salesforce/blip2-opt-2.7b"
)
model = Blip2ForConditionalGeneration.from_pretrained(
    "ybelkada/blip2-opt-2.7b-fp16-sharded",
    load_in_8bit=True,
    device_map="auto",
)

In [None]:
# lora for fting
from peft import LoraConfig, PeftModel, get_peft_model

# Adapter settings: rank-16 LoRA with alpha 32 and 5% dropout, no bias training,
# applied to the modules selected by the names in target_modules.
lora_config = LoraConfig(
    r=16,
    lora_alpha=32,
    lora_dropout=0.05,
    bias="none",
    target_modules=["q_proj", "k_proj"]
)

# `model` is rebound to its own wrapped version, so running this cell twice
# would wrap an already-wrapped model. Only wrap when it is not a PeftModel yet.
if not isinstance(model, PeftModel):
    model = get_peft_model(model, lora_config)

In [None]:
# Build the training dataset with the default labels CSV and image directory.
train_dataset = ArtemisDataset(processor=processor)
# DataLoader batching is left disabled (no collate_fn is defined in the visible cells);
# the fine-tuning cell iterates `train_dataset` directly.
#train_dataloader = DataLoader(train_dataset, shuffle=True, batch_size=1)#, collate_fn=collate_fn)

In [None]:
# fine-tuning
#
# The dataset already returns processed pixel values and token ids, so they are
# fed to the model directly. Passing them through `processor` a second time is
# what triggered the "values outside the range [0, 1]" error: the processor
# expects raw images, not its own normalised output.

# Only hand the optimizer the parameters that are actually trained.
trainable_params = [p for p in model.parameters() if p.requires_grad]
optimizer = torch.optim.Adam(trainable_params, lr=5e-4)
device = "cuda" if torch.cuda.is_available() else "cpu"

model.train()
for epoch in range(2):
  for idx in range(len(train_dataset)):
    # Some downloaded files cannot be decoded as images; skip just that item
    # (and say so) instead of silently abandoning the rest of the epoch.
    try:
      item = train_dataset[idx]
    except Exception as e:
      print(f'skipping item {idx}: {e}')
      continue

    # Add the batch dimension the model expects (batch size 1).
    pixel_values = item["pixel_values"].unsqueeze(0).to(device, torch.float16)
    input_ids = torch.as_tensor(item["input_ids"], dtype=torch.long).unsqueeze(0).to(device)
    attention_mask = torch.as_tensor(item["attention_mask"], dtype=torch.long).unsqueeze(0).to(device)

    # `labels` must be supplied, otherwise the model returns no loss.
    # NOTE(review): recent transformers releases may expect image placeholder
    # tokens inside input_ids for BLIP-2 -- confirm against the installed version.
    outputs = model(
      pixel_values=pixel_values,
      input_ids=input_ids,
      attention_mask=attention_mask,
      labels=input_ids,
    )
    loss = outputs.loss
    print('loss: ',loss.item())
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()