In [1]:
import torch
import os
import json
import random
import numpy as np

from tqdm import tqdm
from torch.utils.data import Dataset, DataLoader
from PIL import Image
from transformers import AutoProcessor, Pix2StructForConditionalGeneration
from sklearn.model_selection import train_test_split

from train import train_model

In [2]:
def seed_everything(seed=42):
    """Seed the Python, NumPy and PyTorch RNGs and set the cuDNN determinism flags.

    Args:
        seed: Integer seed handed to every generator. Defaults to 42.
    """
    # The cuDNN flags are independent of the seeding calls below.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

In [3]:
# Dataset locations. Both keep their trailing slash because file names are
# appended to them directly.
IMAGE_PATH = "data/images/"
QA_PATH = "data/qa/"

# Number of samples per training batch.
BATCH_SIZE = 4

# Seed handed to seed_everything().
RANDOM_SEED = 42

# Use the first CUDA device when one is available, otherwise fall back to CPU.
if torch.cuda.is_available():
    DEVICE = torch.device('cuda:0')
else:
    DEVICE = torch.device('cpu')

In [4]:
# Seed every RNG up front so the cells below start from a fixed random state.
seed_everything(RANDOM_SEED)

In [5]:
# Processor for the "google/matcha-chartqa" checkpoint. RealCQA.__getitem__ uses it
# to encode each image/question pair, and its tokenizer to encode the answer.
processor = AutoProcessor.from_pretrained("google/matcha-chartqa")

In [6]:
class RealCQA(Dataset):
    """Chart question-answering dataset yielding one random QA pair per image.

    Each entry of ``img_list`` is an image file name. For a given entry the
    dataset opens ``IMAGE_PATH/<stem>.jpg`` and ``QA_PATH/<stem>.json``, picks one
    question/answer record at random from the json, and runs both through the
    module-level ``processor``.

    Args:
        img_list: Sequence of image file names (extension included).
    """

    def __init__(self, img_list) -> None:
        super().__init__()
        self.img_list = img_list

    def __len__(self):
        """Return the number of images in the dataset."""
        return len(self.img_list)

    @staticmethod
    def _answer_to_text(answer):
        """Convert a raw answer from the QA json into the text that is tokenized.

        Lists are flattened by descending into their first element until a flat
        list is reached, then joined with ``', '``. Ints and floats are
        stringified. Any other value (normally a string) is returned unchanged.
        """
        if isinstance(answer, list):
            # The truthiness check stops on an empty list, where answer[0]
            # would otherwise raise IndexError.
            while answer and isinstance(answer[0], list):
                answer = answer[0]
            return ', '.join(str(el) for el in answer)
        if isinstance(answer, (int, float)):
            return str(answer)
        return answer

    def __getitem__(self, idx):
        """Build the model inputs for the image at position ``idx``.

        Returns:
            The processor output for the image and a randomly chosen question,
            moved to ``DEVICE``, with the tokenized answer stored under
            ``'labels'`` (padded/truncated to 20 tokens).
        """
        # Strip the real extension instead of assuming it is three characters long.
        item_id = os.path.splitext(self.img_list[idx])[0]

        # Open the image inside a context manager so the file handle is released
        # deterministically; convert() returns a fully loaded copy that outlives it.
        # NOTE(review): the processor is expected to convert to RGB itself, so this
        # should not change its output - confirm against the processor config.
        with Image.open(os.path.join(IMAGE_PATH, item_id + '.jpg')) as img:
            image = img.convert("RGB")

        # Get the corresponding json file with the question/answer records
        with open(os.path.join(QA_PATH, item_id + '.json'), encoding='utf8') as f:
            qa = json.load(f)

        # Every image has many questions; select one of them at random
        rnd_sample = np.random.randint(len(qa))
        record = qa[rnd_sample]
        q = record['question']
        a = self._answer_to_text(record['answer'])

        # Process the image together with its question
        inputs = processor(images=image, text=q, return_tensors="pt", max_patches=768).to(DEVICE)

        # Tokenize the answer
        inputs['labels'] = processor.tokenizer.encode(a, return_tensors="pt", add_special_tokens=True, max_length=20, truncation=True, padding="max_length").to(DEVICE)

        return inputs

In [7]:
# os.listdir() returns entries in arbitrary order, so sort them to make the
# split depend only on the directory contents and the seed.
imgs_list = sorted(os.listdir(IMAGE_PATH))

# Pin random_state so the train/test partition is reproducible on its own,
# without depending on the global NumPy RNG state when this cell runs.
train_imgs, test_imgs = train_test_split(imgs_list, test_size=0.3, random_state=RANDOM_SEED)

In [8]:
# Wrap each split in its own dataset instance (train first, then test).
train_ds, test_ds = (RealCQA(split) for split in (train_imgs, test_imgs))

In [9]:
# Training data is shuffled and batched; evaluation runs one sample at a time
# in a fixed order.
train_dl = DataLoader(
    dataset=train_ds,
    batch_size=BATCH_SIZE,
    shuffle=True,
)
test_dl = DataLoader(
    dataset=test_ds,
    batch_size=1,
    shuffle=False,
)

In [10]:
# Load the pretrained checkpoint, then move it to the selected device.
model = Pix2StructForConditionalGeneration.from_pretrained("google/matcha-chartqa")
model = model.to(DEVICE)

In [11]:
# AdamW over every model parameter with a learning rate of 1e-5.
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-5)

In [12]:
# Start training via train_model from the local `train` module, for up to 100
# epochs and without an LR scheduler (scheduler=None).
# NOTE(review): neptune_tracking=True presumably turns on Neptune experiment
# logging - confirm against train.train_model.
train_model(model=model,
            optimizer=optimizer,
            train_dl=train_dl,
            test_dl=test_dl,
            num_epochs=100,
            processor=processor,
            device=DEVICE,
            scheduler=None,
            neptune_tracking=True
            )



https://app.neptune.ai/bng215/Model-Collapse/e/TRAN-981


Epoch: 1: Train stage:   0%|          | 0/4947 [00:01<?, ?it/s]


Predictions: transmission loss r7(dB)
Ground-truth: Line chart
Line chart
transmission loss r7(dB)
Predictions: [Time, µs, 0.05]
Ground-truth: Line chart
Line chart
[Time, µs, 0.05]
Predictions: [PH, PH+3, PH+4, PH+5]
Ground-truth: Line chart
Line chart
[PH, PH+3, PH+4, PH+5]
Line chart
[Moo, ZrMoO, SnMoO]
Line chart
[N-doped NIO, Non-doped NIO]
Scatter chart
[D1, D21]
yes
75
yes
No
Vertical
0.01




Simplified
Communication
Line chart
0.5
no
No
1.363961841841517
0.215277778
Line chart
[Ag (40 nm) + ZnO (18 nm)]
Vertical Boxplots
4
Line chart
[Leaf height (cm) | Millenium | Maximum]
Line chart
7
Vertical Bar chart
[Runs, Conversion]
Line chart
[Flood risk reduction, Flood damage function of Condition 1, Flood damage function of Condition
no
No
no
No




yes
No
yes
No
Temperature of explosive graphitization
10000/7, 1E-5
no
Yes
Line chart
[NeQ-50-5, NeQ-100-5]
Line chart
[Wavelength (nm)]
Line chart
[VITO CASE, VITO CORE]
yes
No




no
No
Line chart
[L/500, L/660]
Line chart
[Neat, Neat]
Line chart
[Chart, Type of chart]
no
Yes
Line chart
[1 - Specificity, 1 - Specificity]
yes
Yes




no
29
Scatter chart
[PCA factor 1, PCA factor 2]
no
No
Line chart
[Anatase, Anatase_]
Line chart
[Plasma frequency, Plasma frequency]
Vertical Bar chart
[Wild-type cells IgG, Wild-type cells 1ug Ab, As
Vertical Boxplots
400
Line chart
[Fe,0,@APPTES]
no
No
Linear
No
Vertical Bar chart
[Hypertension, Female]
0.0
1600
Vertical Boxplots
Ashcroft score
Line chart
[Nine, T16]
no
No
no
Yes
no
1.6
no
No
Line chart
[Cubic-C, A, Sub-C]
no
Yes
Line chart
[Indent 1 on Matic, Indent 2 on Matic]
3.903225806451538
0.214588889
3
2
Line chart
[Energy [keV], Illual]




no
25
Line chart
[Ti, Cross-section depth (um)]
no
No
Line chart
0.100
no
No




0.8831635710005992
0.433333333




no
Yes




0.2476242193863698
0.05
Line chart
100
Line chart
[1, 20]
Line chart
[Balancing, Bounce, OFFS way]
Line chart
[As-blended, Heat-annealed]
Line chart
[Time, The forecasted value]




yes
15
no
Yes
no
No
Vertical Bar chart
[iso31, Alks]
Line chart
[Ti-21, Ti-6]
yes
No




no
No
Line chart
[Length-for-age BOYS, Length-for-age BOYS]
Vertical
0.6
Available, Used
10
Line chart
[1, 2]
no
Yes
Line chart
[Temperature (°C), Temperature (°C)]
Average Surface Roughness (um)
1.05
Line chart
[Distance from core haplotype (Mbp)]
Scatter chart
[Charter, Measure Haines, Charter 2, Charter 3]
Line chart
[Cholera, OCC]
Vertical Bar chart
[Medium, High]
Vertical
76
6
4
Vertical Boxplots
[C, M, U]
Line chart
[1, 2, 3]
no
Yes
Line chart
[SELFA2, SELFA3]




no
2.4
Vertical
4
Line chart
[Composite 1, Composite 2]
Vertical Bar chart
[Stunting, Underweight, Wasting]
Line chart
[PTG-L8 B, PTG-X10 B]
Linear
No
no
No
Line chart
[B, C]
Linear
No
Vertical Bar chart
100
Line chart
[Multiples, Singletons]
Line chart
[20, 30]
Line chart
[BS882, PS89]
Line chart
[2e1, 211]
Line chart
[0, 10000, 20000]
Line chart
[1, 2]




t_{1}, t_{2}, t_{3}, 
2500
Line chart
0.5
Line chart
[Bastine, Week 12, Week 24]
yes
No
yes
Yes
Scatter chart
0
Scatter chart
[Worries (CW8 sum score), Worries (CW8 sum score
yes
Yes
Line chart
[Leaqua, Miro]
Line chart
[Follow-up in years, Cum-Survival]
Vertical Bar chart
[Anaemia, Mild anaemia]
Line chart
[PPTIN of S1, PPTIN of S2]
Vertical
4
Line chart
[BC-RU, BC-RU+BC [valDin]],[BC-RU
Line chart
[n+1, 10]
Line chart
[HEA-H, HEA-H]
Line chart
[1-Specificity, 1-Specificity]
Vertical Bar chart
[FMRI, TCD]
Line chart
[Brazil, Australia, Algeria, Mexico, United Kingdom]
Line chart
[Age at puberty (AGECL - Corrected)]
Line chart
[10, 40]
no
No
no
Yes
\Delta AUC (%)
141
yes
Yes
Line chart
[Loadings, Loadings]
Line chart
[DAM, DAM vs. Vs. Spread]
Horizontal Bar chart
[UP,DOWN]




no
No
Line chart
[Mother died in 1st year, Mother did not die in 1st year]
yes
Yes
Line chart
[AS, YL]
Line chart
Engineering Strain
Line chart
17
no
Yes
Line chart
[JSON, JSON (GZip)]


KeyboardInterrupt: 