In [1]:
import warnings
# Suppress every warning category for the rest of the session.
# NOTE(review): this also hides deprecation notices from the libraries imported below.
warnings.filterwarnings("ignore")

In [2]:
import torch
import os
import json
import random
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import AutoProcessor, Pix2StructForConditionalGeneration
from transformers.optimization import Adafactor, get_cosine_schedule_with_warmup
from sklearn.model_selection import train_test_split

from train import train_model

In [3]:
def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch random number generators.

    Also switches cuDNN to deterministic mode and turns off its benchmark
    mode.

    Args:
        seed: Integer seed handed to every generator (default 42).
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True

In [4]:
# Directory with the chart images and the JSON file holding their question/answer pairs.
IMAGE_PATH = "data/images/"
QA_PATH = "data/merged_qa.json"

# Batch size of the training DataLoader (the test DataLoader uses batch_size=1).
BATCH_SIZE = 4

# Use the first CUDA GPU when one is available, otherwise fall back to the CPU.
DEVICE = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Seed passed to seed_everything().
RANDOM_SEED = 42

In [5]:
# Seed the Python, NumPy and PyTorch RNGs before any data splitting or sampling happens.
seed_everything(RANDOM_SEED)

In [6]:
# Processor of the "google/matcha-chartqa" checkpoint. The dataset uses it to encode
# image + question, and the collator uses its tokenizer to encode the answers.
processor = AutoProcessor.from_pretrained("google/matcha-chartqa")

In [7]:
class RealCQA(Dataset):
    """Chart question-answering dataset yielding one random QA pair per image.

    The class relies on the notebook-level names ``IMAGE_PATH``, ``processor``
    and ``DEVICE`` being defined before an item is requested.

    Args:
        img_list: Image file names (including extension) found under
            ``IMAGE_PATH``.
        qa_json: Mapping from an image id (the file name without its
            extension) to a list of dicts, each with a ``'question'`` and an
            ``'answer'`` entry.
    """

    def __init__(self, img_list, qa_json) -> None:
        super().__init__()
        self.img_list = img_list
        self.qa_json = qa_json

    def __len__(self):
        """Return the number of images; each image contributes one sample."""
        return len(self.img_list)

    @staticmethod
    def _answer_to_text(answer):
        """Turn a raw answer (scalar, list or nested list) into a string.

        Nested lists are unwrapped by repeatedly taking their first element,
        and the resulting flat list is joined with ``', '``. Empty lists are
        mapped to an empty string instead of raising ``IndexError``.
        """
        if isinstance(answer, list):
            # Guard on emptiness: `answer[0]` on [] (or [[]]) would raise IndexError.
            while answer and isinstance(answer[0], list):
                answer = answer[0]
            answer = ', '.join(str(el) for el in answer)
        return str(answer)

    def __getitem__(self, idx):
        """Encode the image at ``idx`` together with one randomly chosen question.

        Returns:
            The processor output (moved to ``DEVICE``) with an extra
            ``'labels'`` entry holding the answer as a plain string; the
            collator is responsible for tokenizing it.

        Raises:
            KeyError: If the image id has no entry in ``qa_json``.
        """
        file_name = self.img_list[idx]
        # Strip the extension properly instead of assuming it is 3 characters long.
        item_id = os.path.splitext(file_name)[0]

        # Get corresponding information from json file
        qa = self.qa_json[item_id]

        # Since every image has a plethora of questions, select one from them randomly
        rnd_sample = np.random.randint(len(qa))

        # Take only question and corresponding answer from dict
        q, a = qa[rnd_sample]['question'], qa[rnd_sample]['answer']

        # Process image and corresponding question. PIL opens files lazily, so
        # keep the file open only while the processor reads the pixels and
        # close it deterministically afterwards.
        with Image.open(os.path.join(IMAGE_PATH, file_name)) as image:
            inputs = processor(images=image, text=q, return_tensors="pt", max_patches=768).to(DEVICE)

        # Answers stay as text here; they are tokenized per batch in the collator
        inputs['labels'] = self._answer_to_text(a)

        return inputs

In [8]:
# os.listdir() returns entries in arbitrary order, so sort the names to make the
# split identical across machines and runs.
imgs_list = sorted(os.listdir(IMAGE_PATH))

# Pin the split to RANDOM_SEED explicitly instead of relying on whatever state the
# global NumPy RNG happens to be in when this cell is executed.
train_imgs, test_imgs = train_test_split(imgs_list, test_size=0.15, random_state=RANDOM_SEED)

In [9]:
# Decode explicitly as UTF-8 so that reading the QA file does not depend on the
# platform's default text encoding.
with open(QA_PATH, "r", encoding="utf-8") as f:
    qa_json = json.load(f)

In [10]:
# Both splits share the same QA mapping; each dataset looks up only the ids of its own images.
train_ds = RealCQA(train_imgs, qa_json)
test_ds = RealCQA(test_imgs, qa_json)

In [11]:
def collator(batch):
    """Merge the per-sample dicts produced by the dataset into one batch dict.

    The answer strings stored under ``'labels'`` are tokenized together
    (padded/truncated to 20 tokens) and moved to ``DEVICE``. The
    ``'flattened_patches'`` and ``'attention_mask'`` tensors of all samples
    are stacked along a new leading dimension.

    Args:
        batch: List of dataset items, each providing ``'labels'``,
            ``'flattened_patches'`` and ``'attention_mask'``.

    Returns:
        Dict with the keys ``'flattened_patches'``, ``'attention_mask'`` and
        ``'labels'`` (token ids of the answers).
    """
    answer_texts = [sample['labels'] for sample in batch]
    encoded_answers = processor.tokenizer.batch_encode_plus(
        answer_texts,
        return_tensors="pt",
        add_special_tokens=True,
        max_length=20,
        truncation=True,
        padding="max_length",
    ).to(DEVICE)

    patches = [sample["flattened_patches"] for sample in batch]
    masks = [sample["attention_mask"] for sample in batch]

    return {
        "flattened_patches": torch.stack(patches),
        "attention_mask": torch.stack(masks),
        "labels": encoded_answers['input_ids'],
    }

In [12]:
# Training batches are shuffled; the test loader walks the test images one at a time in a fixed order.
# Note: the dataset draws a random question per image on every access, so the questions
# seen during evaluation can still differ between passes.
train_dl = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True, collate_fn=collator)
test_dl = DataLoader(test_ds, batch_size=1, shuffle=False, collate_fn=collator)

In [13]:
# Load the "google/matcha-chartqa" checkpoint and move the model to DEVICE.
matcha_chartqa = Pix2StructForConditionalGeneration.from_pretrained("google/matcha-chartqa").to(DEVICE)

In [14]:
# AdamW over all model parameters; the learning rate follows a cosine schedule with
# 120 warmup steps and a 600-step horizon.
# NOTE(review): the progress bar of this run reports 6007 train batches per epoch. If
# train_model steps the scheduler once per batch, the 600-step horizon is exceeded early
# in the epoch — confirm this is intended.
optimizer = torch.optim.AdamW(matcha_chartqa.parameters(), lr=1e-4)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=120, num_training_steps=600)

In [15]:
# Run one epoch of train_model (imported from the local `train` module) on matcha_chartqa
# with Neptune tracking enabled.
train_model(model=matcha_chartqa,
            optimizer=optimizer,
            train_dl=train_dl,
            test_dl=test_dl,
            num_epochs=1,
            processor=processor,
            device=DEVICE,
            scheduler=scheduler,
            neptune_tracking=True,
            model_name="matcha-chartqa"
            )

https://app.neptune.ai/bng215/Model-Collapse/e/TRAN-1004


Epoch: 1: Val stage: 100%|██████████| 4240/4240 [18:37<00:00,  3.79it/s]s]
Epoch: 1: Val stage: 100%|██████████| 4240/4240 [18:23<00:00,  3.84it/s]/s]   
Epoch: 1: Val stage: 100%|██████████| 4240/4240 [19:34<00:00,  3.61it/s]/s]   
Epoch: 1: Val stage: 100%|██████████| 4240/4240 [20:24<00:00,  3.46it/s]it/s]   
Epoch: 1: Val stage: 100%|██████████| 4240/4240 [18:47<00:00,  3.76it/s]s/it]   
Epoch: 1: Val stage: 100%|██████████| 4240/4240 [18:27<00:00,  3.83it/s]it/s]   
Epoch: 1: Train stage:  11%|█         | 656/6007 [2:05:16<17:01:49, 11.46s/it]  


KeyboardInterrupt: 

In [17]:
# Load the "google/matcha-base" checkpoint and move the model to DEVICE.
matcha_base = Pix2StructForConditionalGeneration.from_pretrained("google/matcha-base").to(DEVICE)

config.json:   0%|          | 0.00/4.89k [00:00<?, ?B/s]

pytorch_model.bin:   0%|          | 0.00/1.13G [00:00<?, ?B/s]

In [18]:
# Fresh AdamW optimizer and cosine schedule (120 warmup steps, 600-step horizon) for
# matcha_base; this rebinds the `optimizer` and `scheduler` names used by the previous run.
# NOTE(review): the progress bar reports 6007 train batches per epoch, far more than the
# 600-step horizon — confirm how train_model steps the scheduler.
optimizer = torch.optim.AdamW(matcha_base.parameters(), lr=1e-4)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=120, num_training_steps=600)

In [19]:
# Run one epoch of train_model on matcha_base with Neptune tracking enabled.
# NOTE(review): `processor` was loaded from "google/matcha-chartqa", not from this model's
# checkpoint — confirm both checkpoints share the same preprocessing and tokenizer.
train_model(model=matcha_base,
            optimizer=optimizer,
            train_dl=train_dl,
            test_dl=test_dl,
            num_epochs=1,
            processor=processor,
            device=DEVICE,
            scheduler=scheduler,
            neptune_tracking=True,
            model_name="matcha-base"
            )

https://app.neptune.ai/bng215/Model-Collapse/e/TRAN-1005


Epoch: 1: Val stage: 100%|██████████| 4240/4240 [19:39<00:00,  3.59it/s]s]
Epoch: 1: Val stage: 100%|██████████| 4240/4240 [19:37<00:00,  3.60it/s]it]   
Epoch: 1: Val stage: 100%|██████████| 4240/4240 [19:26<00:00,  3.63it/s]/s]   
Epoch: 1: Val stage: 100%|██████████| 4240/4240 [20:02<00:00,  3.53it/s]s/it]   
Epoch: 1: Val stage: 100%|██████████| 4240/4240 [19:30<00:00,  3.62it/s]it/s]   
Epoch: 1: Val stage: 100%|██████████| 4240/4240 [19:51<00:00,  3.56it/s]s/it]   
Epoch: 1: Val stage:  18%|█▊        | 780/4240 [03:56<17:28,  3.30it/s]1s/it]   
Epoch: 1: Train stage:  12%|█▏        | 699/6007 [2:14:32<17:01:40, 11.55s/it]


KeyboardInterrupt: 

In [13]:
# Load the "google/matcha-plotqa-v1" checkpoint and move the model to DEVICE.
# NOTE(review): this cell's execution count (13) is lower than that of the preceding cell
# (19), which suggests the kernel was restarted before this run — confirm the setup cells
# above were re-executed first.
matcha_plotqa = Pix2StructForConditionalGeneration.from_pretrained("google/matcha-plotqa-v1").to(DEVICE)

pytorch_model.bin:  17%|#6        | 189M/1.13G [00:00<?, ?B/s]

In [14]:
# AdamW optimizer and cosine schedule (120 warmup steps, 600-step horizon) for matcha_plotqa.
# NOTE(review): the progress bar reports 6007 train batches per epoch, far more than the
# 600-step horizon — confirm how train_model steps the scheduler.
optimizer = torch.optim.AdamW(matcha_plotqa.parameters(), lr=1e-4)
scheduler = get_cosine_schedule_with_warmup(optimizer, num_warmup_steps=120, num_training_steps=600)

In [16]:
# Run one epoch of train_model on matcha_plotqa with Neptune tracking enabled.
# NOTE(review): `processor` was loaded from "google/matcha-chartqa", not from this model's
# checkpoint — confirm both checkpoints share the same preprocessing and tokenizer.
train_model(model=matcha_plotqa,
            optimizer=optimizer,
            train_dl=train_dl,
            test_dl=test_dl,
            num_epochs=1,
            processor=processor,
            device=DEVICE,
            scheduler=scheduler,
            neptune_tracking=True,
            model_name="matcha-plotqa"
            )

https://app.neptune.ai/bng215/Model-Collapse/e/TRAN-1006


Epoch: 1: Val stage: 100%|██████████| 4240/4240 [19:48<00:00,  3.57it/s]t]
Epoch: 1: Val stage: 100%|██████████| 4240/4240 [20:09<00:00,  3.51it/s]/s]   
Epoch: 1: Val stage: 100%|██████████| 4240/4240 [18:24<00:00,  3.84it/s]it]   
Epoch: 1: Val stage: 100%|██████████| 4240/4240 [18:41<00:00,  3.78it/s]it/s]   
Epoch: 1: Val stage: 100%|██████████| 4240/4240 [18:53<00:00,  3.74it/s]it/s]   
Epoch: 1: Val stage: 100%|██████████| 4240/4240 [18:56<00:00,  3.73it/s]it/s]   
Epoch: 1: Val stage:  88%|████████▊ | 3723/4240 [16:17<02:15,  3.81it/s]it/s]   
Epoch: 1: Train stage:  12%|█▏        | 699/6007 [2:23:22<18:08:46, 12.31s/it]


KeyboardInterrupt: 