In [1]:
import torch
import os
import json
import requests
import numpy as np

from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import AutoProcessor, Pix2StructForConditionalGeneration

In [2]:
# Directories holding the chart images and the question/answer JSON files.
IMAGE_PATH, QA_PATH = "data/images/", "data/qa/"

# Number of samples drawn per batch by the DataLoader.
BATCH_SIZE = 4

# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')

In [3]:
# Load the processor published with the "google/matcha-chartqa" checkpoint.
# It is used below both to turn (image, question) pairs into model inputs and,
# through its tokenizer, to encode/decode answers.
processor = AutoProcessor.from_pretrained(
    pretrained_model_name_or_path="google/matcha-chartqa",
)

In [4]:
class RealCQA(Dataset):
    """Chart question-answering dataset that yields one random QA pair per image.

    Every entry of ``img_list`` is an image file name located in ``IMAGE_PATH``.
    The questions for that image are read from the JSON file with the same stem
    (``<stem>.json``) located in ``QA_PATH``; the JSON is expected to be a
    sequence of dicts carrying ``'question'`` and ``'answer'`` keys.

    Args:
        img_list: Sequence of image file names (with extension).
        max_patches: Forwarded to the processor as ``max_patches``.
        max_label_length: Length to which tokenized answers are padded/truncated.
    """

    def __init__(self, img_list, max_patches=768, max_label_length=20) -> None:
        super().__init__()
        self.img_list = img_list
        self.max_patches = max_patches
        self.max_label_length = max_label_length

    def __len__(self):
        """Return the number of images (not the number of questions)."""
        return len(self.img_list)

    @staticmethod
    def _answer_to_text(answer):
        """Convert a raw answer value into the string that will be tokenized.

        * ``str`` is returned unchanged.
        * ``list`` is unwrapped by repeatedly taking its first element while
          that element is itself a list, then joined with ``', '``. An empty
          list (at any nesting level) yields ``''`` instead of raising
          ``IndexError``.
        * Anything else (numbers, booleans, ...) is passed through ``str``.
        """
        if isinstance(answer, str):
            return answer

        if isinstance(answer, list):
            # Guard on emptiness so `answer[0]` can never raise.
            while answer and isinstance(answer[0], list):
                answer = answer[0]
            return ', '.join(str(el) for el in answer)

        return str(answer)

    def __getitem__(self, idx):
        """Build the model inputs for image ``idx`` and one random question.

        Returns:
            The processor output for the (image, question) pair with an extra
            ``'labels'`` entry holding the tokenized answer. Tensors are left
            on the CPU: the consumers in this notebook move them to ``device``
            themselves, and creating CUDA tensors inside a Dataset prevents
            using DataLoader worker processes.

        Raises:
            ValueError: If the QA file of the image contains no entries.
        """
        file_name = self.img_list[idx]
        # Strip the real extension instead of assuming it is three characters.
        stem = os.path.splitext(file_name)[0]

        # Decode the pixels eagerly and close the file handle right away;
        # Image.open alone keeps the file open until the data is first used.
        with Image.open(os.path.join(IMAGE_PATH, file_name)) as img:
            image = img.convert("RGB")

        # Get corresponding json file
        with open(os.path.join(QA_PATH, stem + '.json'), encoding='utf8') as f:
            qa = json.load(f)

        if not qa:
            raise ValueError(f"No question/answer pairs found for image {file_name!r}")

        # Since every image has a plethora of questions, select one from them randomly
        rnd_sample = np.random.randint(len(qa))

        # Take only question and corresponding answer from dict
        sample = qa[rnd_sample]
        q = sample['question']
        a = self._answer_to_text(sample['answer'])

        # Process images and corresponding questions
        inputs = processor(images=image, text=q, return_tensors="pt", max_patches=self.max_patches)

        # Tokenize answers (padded/truncated to a fixed length so they can be batched)
        inputs['labels'] = processor.tokenizer.encode(
            a,
            return_tensors="pt",
            add_special_tokens=True,
            max_length=self.max_label_length,
            truncation=True,
            padding="max_length",
        )

        return inputs

In [5]:
# Collect the image file names to train on.
# os.listdir returns entries in arbitrary order and also lists directories and
# hidden entries (for example Jupyter's ".ipynb_checkpoints"), which are not
# images. Keep only regular, non-hidden files and sort them so the list is
# the same on every run.
imgs_list = sorted(
    name
    for name in os.listdir(IMAGE_PATH)
    if not name.startswith('.') and os.path.isfile(os.path.join(IMAGE_PATH, name))
)

In [6]:
# Wrap the image file names in the dataset defined above.
ds = RealCQA(img_list=imgs_list)

In [8]:
# Batch the dataset, visiting the images in a new random order on every pass.
dataloader = DataLoader(
    dataset=ds,
    shuffle=True,
    batch_size=BATCH_SIZE,
)

In [9]:
# Fetch a single batch to check that the dataset and collation work end to end.
batch_iterator = iter(dataloader)
batch = next(batch_iterator)

In [10]:
# Load the pretrained "google/matcha-chartqa" weights, then move the model
# onto the selected device.
model = Pix2StructForConditionalGeneration.from_pretrained("google/matcha-chartqa")
model = model.to(device)

In [None]:
# Fine-tune the model on the dataloader and, after every epoch, print the
# model's answers for the last training batch next to the reference answers.
EPOCHS = 5000

# Place the weights on the target device before the optimizer captures them.
model.to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

# Answers are padded to a fixed length. Per the Hugging Face documentation the
# loss ignores label positions equal to -100, so padding is masked with that
# value; otherwise most of the loss would come from predicting pad tokens.
pad_token_id = processor.tokenizer.pad_token_id

model.train()

for epoch in range(EPOCHS):
    saw_batch = False

    # tqdm shows progress and the running loss without printing one line per step.
    progress = tqdm(dataloader, desc=f"Epoch {epoch}")
    for idx, batch in enumerate(progress):
        # Each sample carries a leading dimension of size 1 that ends up at
        # dim 1 after batching; drop it.
        labels = batch.pop("labels").to(device).squeeze(1)
        flattened_patches = batch.pop("flattened_patches").to(device).squeeze(1)
        attention_mask = batch.pop("attention_mask").to(device).squeeze(1)
        saw_batch = True

        # Keep `labels` untouched (it is decoded below); mask only the copy
        # that is used for the loss.
        if pad_token_id is not None:
            loss_labels = labels.masked_fill(labels == pad_token_id, -100)
        else:
            loss_labels = labels

        outputs = model(flattened_patches=flattened_patches,
                        attention_mask=attention_mask,
                        labels=loss_labels)

        loss = outputs.loss
        progress.set_postfix(loss=loss.item())

        # Clear gradients first so nothing left over from an interrupted run
        # of this cell leaks into the update.
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    if not saw_batch:
        # Nothing was loaded, so there is no batch to preview.
        print("Dataloader yielded no batches; skipping preview.")
        continue

    # Qualitative check on the last training batch of the epoch.
    model.eval()

    predictions = model.generate(flattened_patches=flattened_patches, attention_mask=attention_mask)
    print("Predictions:", processor.batch_decode(predictions, skip_special_tokens=True))
    print("Ground-truth:", processor.batch_decode(labels, skip_special_tokens=True))

    model.train()

In [18]:
# Draw one fresh (shuffled) batch for a manual look at the model's output.
batch_iterator = iter(dataloader)
batch = next(batch_iterator)

In [20]:
# Take the three tensors out of the batch (in this order), move each to the
# selected device and drop its size-1 dimension at position 1.
labels, flattened_patches, attention_mask = (
    batch.pop(key).to(device).squeeze(1)
    for key in ("labels", "flattened_patches", "attention_mask")
)

In [21]:
# Put the model in evaluation mode and compare its generated answers with the
# reference answers for the batch prepared above.
model.eval()
predictions = model.generate(
    attention_mask=attention_mask,
    flattened_patches=flattened_patches,
)

decoded_predictions = processor.batch_decode(predictions, skip_special_tokens=True)
print("Predictions:", decoded_predictions)

decoded_labels = processor.batch_decode(labels, skip_special_tokens=True)
print("Ground-truth:", decoded_labels)



Predictions: ['no', 'no']
Ground-truth: ['yes', 'no']
