### Imports

In [1]:
import os
import sys
import json 
import glob
import gc
import torch

import pandas as pd
import numpy as np

from PIL import Image
from pathlib import Path

In [2]:
sys.path.insert(0, '/kaggle/input/omegaconf')
from omegaconf import OmegaConf

In [3]:
from torch.utils.data import DataLoader
from accelerate import Accelerator
from transformers import GenerationConfig
from tqdm.auto import tqdm



# Config

In [4]:
%%writefile mga_cfg_deplot.yaml
model:
    backbone_path: /kaggle/input/google-deplot
    max_length: 8
    max_patches: 3072
    patch_size: 16
    len_tokenizer: ???
    pad_token_id: ???
    decoder_start_token_id: ???
    bos_token_id: ???
        
predict_params:
    bs: 1
    checkpoint_path: /kaggle/input/mga-r-final-matcha-v2-ft-i2/mga_model_fold_0.pth.tar

competition_dataset:
    data_dir: /kaggle/input/benetech-making-graphs-accessible

Writing mga_cfg_deplot.yaml


In [5]:
cfg = OmegaConf.load('mga_cfg_deplot.yaml')
cfg

{'model': {'backbone_path': '/kaggle/input/google-deplot', 'max_length': 8, 'max_patches': 3072, 'patch_size': 16, 'len_tokenizer': '???', 'pad_token_id': '???', 'decoder_start_token_id': '???', 'bos_token_id': '???'}, 'predict_params': {'bs': 1, 'checkpoint_path': '/kaggle/input/mga-r-final-matcha-v2-ft-i2/mga_model_fold_0.pth.tar'}, 'competition_dataset': {'data_dir': '/kaggle/input/benetech-making-graphs-accessible'}}

# Dataset

In [6]:
# Reference: https://www.kaggle.com/code/nbroad/donut-train-benetech

import json
import os
import pdb
from copy import deepcopy

import torch
from PIL import Image
from tokenizers import AddedToken
from torch.utils.data import Dataset
from transformers import Pix2StructProcessor

# -- token map --#
TOKEN_MAP = {
    "line": "[<lines>]",
    "vertical_bar": "[<vertical_bar>]",
    "scatter": "[<scatter>]",
    "dot": "[<dot>]",
    "horizontal_bar": "[<horizontal_bar>]",
    "histogram": "[<histogram>]",

    "c_start": "[<c_start>]",
    "c_end": "[<c_end>]",
    "x_start": "[<x_start>]",
    "x_end": "[<x_end>]",
    "y_start": "[<y_start>]",
    "y_end": "[<y_end>]",

    "p_start": "[<p_start>]",
    "p_end": "[<p_end>]",

    "bos_token": "[<mga>]",
}


# -----

def is_nan(val):
    return val != val


def get_processor(cfg):
    """
    load the processor
    """
    processor_path = cfg.model.backbone_path
    print(f"loading processor from {processor_path}")
    processor = Pix2StructProcessor.from_pretrained(processor_path)
    processor.image_processor.is_vqa = False
    processor.image_processor.patch_size = {
        "height": cfg.model.patch_size,
        "width": cfg.model.patch_size
    }

    # NEW TOKENS
    print("adding new tokens...")
    new_tokens = []
    for _, this_tok in TOKEN_MAP.items():
        new_tokens.append(this_tok)
    new_tokens = sorted(new_tokens)

    tokens_to_add = []
    for this_tok in new_tokens:
        tokens_to_add.append(AddedToken(this_tok, lstrip=False, rstrip=False))

    processor.tokenizer.add_tokens(tokens_to_add)

    return processor


class MGADataset(Dataset):
    """Dataset class for MGA dataset
    """

    def __init__(self, cfg, graph_ids, transform=None):

        self.cfg = cfg
        self.data_dir = cfg.competition_dataset.data_dir.rstrip("/")
        self.image_dir = os.path.join(self.data_dir, "test", "images")

        self.graph_ids = deepcopy(graph_ids)
        self.transform = transform

        # load processor
        self.load_processor()

    def load_processor(self):
        self.processor = get_processor(self.cfg)

    def load_image(self, graph_id):
        image_path = os.path.join(self.image_dir, f"{graph_id}.jpg")
        image = Image.open(image_path)
        if image.mode != 'RGB':
            image = image.convert('RGB')

        return image

    def __len__(self):
        return len(self.graph_ids)

    def __getitem__(self, index):
        graph_id = self.graph_ids[index]
        image = self.load_image(graph_id)

        p_img = self.processor(
            images=image,
            max_patches=self.cfg.model.max_patches,
            add_special_tokens=True,
        )

        r = {}
        r['id'] = graph_id
        r['image'] = image
        r['flattened_patches'] = p_img['flattened_patches']
        r['attention_mask'] = p_img['attention_mask']

        return r

# Data Loader

In [7]:
import pdb
from copy import deepcopy
from dataclasses import dataclass

import numpy as np
import torch
from transformers import DataCollatorWithPadding


@dataclass
class MGACollator(DataCollatorWithPadding):
    """
    data collector for mga task
    """

    tokenizer = None
    padding = True
    max_length = None
    pad_to_multiple_of = None
    return_tensors = "pt"

    def __call__(self, features):
        batch = dict()

        # graph ids
        batch["id"] = [feature["id"] for feature in features]
        batch["images"] = [feature["image"] for feature in features]

        # image features ---
        flattened_patches = [feature["flattened_patches"] for feature in features]
        attention_mask = [feature["attention_mask"] for feature in features]

        flattened_patches = np.concatenate(flattened_patches, axis=0)
        attention_mask = np.concatenate(attention_mask, axis=0)

        batch["flattened_patches"] = flattened_patches
        batch["attention_mask"] = attention_mask

        # casting ---
        tensor_keys = ["flattened_patches", "attention_mask"]
        
        for key in tensor_keys:
            if key != "flattened_patches":
                batch[key] = torch.tensor(batch[key], dtype=torch.int64)
            else:
                batch[key] = torch.tensor(batch[key], dtype=torch.float32)

        return batch


# Model

In [8]:
import sys
from pathlib import Path

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.checkpoint
from transformers import (AutoConfig, AutoModel, Pix2StructConfig,
                          Pix2StructForConditionalGeneration)


def get_model(cfg):
    model = Pix2StructForConditionalGeneration.from_pretrained(cfg.model.backbone_path)
    return model


class MGAModel(nn.Module):
    """
    The MGA model
    """

    def __init__(self, cfg):
        print("initializing the MGA model...")

        super(MGAModel, self).__init__()
        self.cfg = cfg

        backbone_config = Pix2StructConfig.from_pretrained(cfg.model.backbone_path)
        backbone_config.text_config.max_length = cfg.model.max_length
        backbone_config.text_config.is_decoder = True

        backbone_config.text_config.pad_token_id = cfg.model.pad_token_id
        backbone_config.text_config.decoder_start_token_id = cfg.model.decoder_start_token_id
        backbone_config.text_config.bos_token_id = cfg.model.bos_token_id

        # backbone_config.decoder.max_length = cfg.model.max_length

        self.backbone = Pix2StructForConditionalGeneration.from_pretrained(
            cfg.model.backbone_path,
            config=backbone_config,
        )

        # resize model embeddings
        print("resizing model embeddings...")
        print(f"tokenizer length = {cfg.model.len_tokenizer}")
        self.backbone.decoder.resize_token_embeddings(cfg.model.len_tokenizer)

        self.loss_fn = nn.CrossEntropyLoss(
            ignore_index=-100,
            reduction="mean",
        )

    def forward(
            self,
            flattened_patches,
            attention_mask,
            labels,
    ):

        outputs = self.backbone(
            flattened_patches=flattened_patches,
            attention_mask=attention_mask,
            labels=labels,
        )

        loss_main = outputs.loss

        # logits = outputs.logits  # [:, 0]  # (bs, vocab)
        # labels_cls = labels[:, 0]  # (bs, )
        # loss = self.loss_fn(logits, labels)

        # logits_num = outputs.logits[:, 1]  # (bs, vocab)
        # labels_num = labels[:, 1]  # (bs, )
        # loss_num = self.loss_fn(logits_num, labels_num)

        loss = loss_main  # + 0.10 * loss_cls

        loss_dict = {
            "loss_main": loss_main,
            "loss_cls": loss,
        }

        return loss, loss_dict


# Inference

In [9]:
def post_process(pred_string, token_map, delimiter="|"):
    # get chart type ---
    chart_options = [
        "horizontal_bar",
        "dot",
        "scatter",
        "vertical_bar",
        "line",
        "histogram",
    ]

    chart_type = "line"  # default type

    for ct in chart_options:
        if token_map[ct] in pred_string:
            chart_type = ct
            break

    if chart_type == "histogram":
        chart_type = "vertical_bar"

    # get x series ---
    x_start_tok = token_map["x_start"]
    x_end_tok = token_map["x_end"]

    try:
        x = pred_string.split(x_start_tok)[1].split(x_end_tok)[0].split(delimiter)
        x = [elem.strip() for elem in x if len(elem.strip()) > 0]
    except IndexError:
        x = []

    # get y series ---
    y_start_tok = token_map["y_start"]
    y_end_tok = token_map["y_end"]

    try:
        y = pred_string.split(y_start_tok)[1].split(y_end_tok)[0].split(delimiter)
        y = [elem.strip() for elem in y if len(elem.strip()) > 0]
    except IndexError:
        y = []

    return chart_type, x, y


In [10]:
def run_evaluation(model, infer_dl, tokenizer, token_map, max_len):

    # # config for text generation ---
    conf_g = {
        "max_new_tokens": max_len, # 1024, # 512,
        "do_sample": False,
        "top_k": 1,
        "use_cache": True,
    }

    generation_config = GenerationConfig(**conf_g)

    # put model in eval mode ---
    model.eval()

    all_ids = []
    all_texts = []

    progress_bar = tqdm(range(len(infer_dl)))
    for batch in infer_dl:
        with torch.no_grad():
            batch_ids = batch["id"]
            generated_ids = model.backbone.generate(
                flattened_patches=batch['flattened_patches'],
                attention_mask=batch['attention_mask'],
                generation_config=generation_config,
            )
            generated_texts = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)

            all_ids.extend(batch_ids)
            all_texts.extend(generated_texts)

        progress_bar.update(1)
    progress_bar.close()

    # prepare output dataframe ---
    preds = []
    extended_preds = []
    for this_id, this_text in zip(all_ids, all_texts):
        id_x = f"{this_id}_x"
        id_y = f"{this_id}_y"
        pred_chart, pred_x, pred_y = post_process(this_text, token_map)

        preds.append([id_x, pred_x, pred_chart])
        preds.append([id_y, pred_y, pred_chart])

        extended_preds.append([id_x, pred_x, pred_chart, this_text])
        extended_preds.append([id_y, pred_y, pred_chart, this_text])

    pred_df = pd.DataFrame(preds)
    pred_df.columns = ["id", "data_series", "chart_type"]

    return pred_df

### Inference

In [11]:
image_dir = Path('/kaggle/input/benetech-making-graphs-accessible/test/images')
image_paths = list(image_dir.glob("*.jpg"))
image_ids = [str(pth).split("/")[-1].split(".jpg")[0] for pth in image_paths]
print(image_ids[:5])

infer_ds = MGADataset(cfg, image_ids)

tokenizer = infer_ds.processor.tokenizer
cfg.model.len_tokenizer = len(tokenizer)

cfg.model.pad_token_id = tokenizer.pad_token_id
BOS_TOKEN = TOKEN_MAP["bos_token"]

cfg.model.decoder_start_token_id = tokenizer.convert_tokens_to_ids([BOS_TOKEN])[0]
cfg.model.bos_token_id = tokenizer.convert_tokens_to_ids([BOS_TOKEN])[0]

collate_fn = MGACollator(tokenizer=tokenizer)

infer_dl = DataLoader(
    infer_ds,
    batch_size=cfg.predict_params.bs,
    collate_fn=collate_fn,
    shuffle=False,
)

cfg_dict = OmegaConf.to_container(cfg, resolve=True)

model = MGAModel(cfg)

print("=="*50)
checkpoint_path = cfg.predict_params.checkpoint_path
print(f"loading model from checkpoint: {checkpoint_path}")
ckpt = torch.load(checkpoint_path)
model.load_state_dict(ckpt['state_dict'])
del ckpt
gc.collect()
print("loaded!")

print("accelerator setup...")
accelerator = Accelerator()  # cpu = True

model, infer_dl = accelerator.prepare(model, infer_dl)

model.eval()
print("model is in eval mode")
print("=="*50)

pred_df = run_evaluation(model, infer_dl, tokenizer, TOKEN_MAP,  max_len=cfg.model.max_length)

pred_df["data_series"] = pred_df["data_series"].apply(lambda x: ";".join(x))
pred_df.head()

['000b92c3b098', '01b45b831589', '00f5404753cf', '00dcf883a459', '007a18eb4e09']
loading processor from /kaggle/input/google-deplot
adding new tokens...
initializing the MGA model...
resizing model embeddings...
tokenizer length = 50359
loading model from checkpoint: /kaggle/input/mga-r-final-matcha-v2-ft-i2/mga_model_fold_0.pth.tar
loaded!
accelerator setup...
model is in eval mode


  0%|          | 0/5 [00:00<?, ?it/s]

Unnamed: 0,id,data_series,chart_type
0,000b92c3b098_x,,line
1,000b92c3b098_y,,line
2,01b45b831589_x,,vertical_bar
3,01b45b831589_y,,vertical_bar
4,00f5404753cf_x,,scatter


In [12]:
del model
gc.collect()
torch.cuda.empty_cache()

# SPLIT

In [13]:
scatter_ids = pred_df[pred_df['chart_type'] == 'scatter']['id'].values.tolist()
scatter_ids = [this_id.split('_x')[0].split('_y')[0] for this_id in scatter_ids]
scatter_ids = list(set(scatter_ids))

other_ids = pred_df[pred_df['chart_type'] != 'scatter']['id'].values.tolist()
other_ids = [this_id.split('_x')[0].split('_y')[0] for this_id in other_ids]
other_ids = list(set(other_ids))

In [14]:
pred_df_v0 = deepcopy(pred_df)

# Pass 2 - Other

In [15]:
%%writefile mga_cfg_deplot.yaml
model:
    backbone_path: /kaggle/input/google-deplot
    max_length: 1024
    max_patches: 4096
    patch_size: 16
    len_tokenizer: ???
    pad_token_id: ???
    decoder_start_token_id: ???
    bos_token_id: ???
        
predict_params:
    bs: 1
    checkpoint_path: /kaggle/input/mga-r-rest-v2/mga_model_fold_0.pth.tar

competition_dataset:
    data_dir: /kaggle/input/benetech-making-graphs-accessible

Overwriting mga_cfg_deplot.yaml


In [16]:
cfg = OmegaConf.load('mga_cfg_deplot.yaml')
cfg

{'model': {'backbone_path': '/kaggle/input/google-deplot', 'max_length': 1024, 'max_patches': 4096, 'patch_size': 16, 'len_tokenizer': '???', 'pad_token_id': '???', 'decoder_start_token_id': '???', 'bos_token_id': '???'}, 'predict_params': {'bs': 1, 'checkpoint_path': '/kaggle/input/mga-r-rest-v2/mga_model_fold_0.pth.tar'}, 'competition_dataset': {'data_dir': '/kaggle/input/benetech-making-graphs-accessible'}}

In [17]:
image_ids = deepcopy(other_ids)
infer_ds = MGADataset(cfg, image_ids)

tokenizer = infer_ds.processor.tokenizer
cfg.model.len_tokenizer = len(tokenizer)

cfg.model.pad_token_id = tokenizer.pad_token_id
BOS_TOKEN = TOKEN_MAP["bos_token"]

cfg.model.decoder_start_token_id = tokenizer.convert_tokens_to_ids([BOS_TOKEN])[0]
cfg.model.bos_token_id = tokenizer.convert_tokens_to_ids([BOS_TOKEN])[0]

collate_fn = MGACollator(tokenizer=tokenizer)

infer_dl = DataLoader(
    infer_ds,
    batch_size=cfg.predict_params.bs,
    collate_fn=collate_fn,
    shuffle=False,
)

cfg_dict = OmegaConf.to_container(cfg, resolve=True)

model = MGAModel(cfg)

print("=="*50)
checkpoint_path = cfg.predict_params.checkpoint_path
print(f"loading model from checkpoint: {checkpoint_path}")
ckpt = torch.load(checkpoint_path)
model.load_state_dict(ckpt['state_dict'])
del ckpt
gc.collect()
print("loaded!")

print("accelerator setup...")
accelerator = Accelerator()  # cpu = True

model, infer_dl = accelerator.prepare(model, infer_dl)

model.eval()
print("model is in eval mode")
print("=="*50)

pred_df = run_evaluation(model, infer_dl, tokenizer, TOKEN_MAP,  max_len=cfg.model.max_length)

pred_df["data_series"] = pred_df["data_series"].apply(lambda x: ";".join(x))
pred_df.head()

loading processor from /kaggle/input/google-deplot
adding new tokens...
initializing the MGA model...
resizing model embeddings...
tokenizer length = 50359
loading model from checkpoint: /kaggle/input/mga-r-rest-v2/mga_model_fold_0.pth.tar
loaded!
accelerator setup...
model is in eval mode


  0%|          | 0/4 [00:00<?, ?it/s]

Unnamed: 0,id,data_series,chart_type
0,000b92c3b098_x,0;6;12;18;24,line
1,000b92c3b098_y,1.03e-02;-6.82e-01;-1.38e+00;-2.08e+00;-2.74e+00,line
2,00dcf883a459_x,Group 1;Group 2,vertical_bar
3,00dcf883a459_y,3.60e+00;8.40e+00,vertical_bar
4,007a18eb4e09_x,0.0;0.4;0.8;1.2;1.6;2.0;2.4,line


In [18]:
pred_df_v1_p1 = deepcopy(pred_df)

In [19]:
del model
gc.collect()
torch.cuda.empty_cache()

# Pass 2 - Scatter

In [20]:
%%writefile mga_cfg_deplot.yaml
model:
    backbone_path: /kaggle/input/google-deplot
    max_length: 1280
    max_patches: 3072
    patch_size: 16
    len_tokenizer: ???
    pad_token_id: ???
    decoder_start_token_id: ???
    bos_token_id: ???
        
predict_params:
    bs: 1
    checkpoint_path: /kaggle/input/mga-r-scatter-matcha/mga_model_fold_0.pth.tar

competition_dataset:
    data_dir: /kaggle/input/benetech-making-graphs-accessible

Overwriting mga_cfg_deplot.yaml


In [21]:
cfg = OmegaConf.load('mga_cfg_deplot.yaml')
cfg

{'model': {'backbone_path': '/kaggle/input/google-deplot', 'max_length': 1280, 'max_patches': 3072, 'patch_size': 16, 'len_tokenizer': '???', 'pad_token_id': '???', 'decoder_start_token_id': '???', 'bos_token_id': '???'}, 'predict_params': {'bs': 1, 'checkpoint_path': '/kaggle/input/mga-r-scatter-matcha/mga_model_fold_0.pth.tar'}, 'competition_dataset': {'data_dir': '/kaggle/input/benetech-making-graphs-accessible'}}

In [22]:
image_ids = deepcopy(scatter_ids)
infer_ds = MGADataset(cfg, image_ids)

tokenizer = infer_ds.processor.tokenizer
cfg.model.len_tokenizer = len(tokenizer)

cfg.model.pad_token_id = tokenizer.pad_token_id
BOS_TOKEN = TOKEN_MAP["bos_token"]

cfg.model.decoder_start_token_id = tokenizer.convert_tokens_to_ids([BOS_TOKEN])[0]
cfg.model.bos_token_id = tokenizer.convert_tokens_to_ids([BOS_TOKEN])[0]

collate_fn = MGACollator(tokenizer=tokenizer)

infer_dl = DataLoader(
    infer_ds,
    batch_size=cfg.predict_params.bs,
    collate_fn=collate_fn,
    shuffle=False,
)

cfg_dict = OmegaConf.to_container(cfg, resolve=True)

model = MGAModel(cfg)

print("=="*50)
checkpoint_path = cfg.predict_params.checkpoint_path
print(f"loading model from checkpoint: {checkpoint_path}")
ckpt = torch.load(checkpoint_path)
model.load_state_dict(ckpt['state_dict'])
del ckpt
gc.collect()
print("loaded!")

print("accelerator setup...")
accelerator = Accelerator()  # cpu = True

model, infer_dl = accelerator.prepare(model, infer_dl)

model.eval()
print("model is in eval mode")
print("=="*50)

pred_df = run_evaluation(model, infer_dl, tokenizer, TOKEN_MAP,  max_len=cfg.model.max_length)

pred_df["data_series"] = pred_df["data_series"].apply(lambda x: ";".join(x))
pred_df.head()

loading processor from /kaggle/input/google-deplot
adding new tokens...
initializing the MGA model...
resizing model embeddings...
tokenizer length = 50359
loading model from checkpoint: /kaggle/input/mga-r-scatter-matcha/mga_model_fold_0.pth.tar
loaded!
accelerator setup...
model is in eval mode


  0%|          | 0/1 [00:00<?, ?it/s]

Unnamed: 0,id,data_series,chart_type
0,00f5404753cf_x,5.00e+00;5.00e+00;5.00e+00;6.00e+00;6.00e+00;6...,scatter
1,00f5404753cf_y,1.10e+01;1.20e+01;1.40e+01;1.20e+01;1.30e+01;1...,scatter


In [23]:
pred_df_v1_p2 = deepcopy(pred_df)

# Merge

In [24]:
pred_df = pd.concat([pred_df_v1_p1, pred_df_v1_p2])
pred_df = pred_df.reset_index(drop=True)
pred_df.head()

Unnamed: 0,id,data_series,chart_type
0,000b92c3b098_x,0;6;12;18;24,line
1,000b92c3b098_y,1.03e-02;-6.82e-01;-1.38e+00;-2.08e+00;-2.74e+00,line
2,00dcf883a459_x,Group 1;Group 2,vertical_bar
3,00dcf883a459_y,3.60e+00;8.40e+00,vertical_bar
4,007a18eb4e09_x,0.0;0.4;0.8;1.2;1.6;2.0;2.4,line


In [25]:
pred_df_final = pd.merge(pred_df_v0[['id', 'chart_type']], pred_df[['id', 'data_series']], on='id', how='left')
pred_df_final['data_series'] = pred_df_final['data_series'].fillna('')

In [26]:
pred_df_final = pred_df_final[["id", "data_series", "chart_type"]].copy()
pred_df_final = pred_df_final.reset_index(drop=True)

In [27]:
pred_df_final.to_csv("submission.csv", index=False)

In [28]:
pred_df_final.head(10)

Unnamed: 0,id,data_series,chart_type
0,000b92c3b098_x,0;6;12;18;24,line
1,000b92c3b098_y,1.03e-02;-6.82e-01;-1.38e+00;-2.08e+00;-2.74e+00,line
2,01b45b831589_x,21-Feb;22-Feb;23-Feb;24-Feb;25-Feb;26-Feb;27-F...,vertical_bar
3,01b45b831589_y,8.83e+04;1.50e+05;1.71e+05;1.76e+05;1.37e+05;9...,vertical_bar
4,00f5404753cf_x,5.00e+00;5.00e+00;5.00e+00;6.00e+00;6.00e+00;6...,scatter
5,00f5404753cf_y,1.10e+01;1.20e+01;1.40e+01;1.20e+01;1.30e+01;1...,scatter
6,00dcf883a459_x,Group 1;Group 2,vertical_bar
7,00dcf883a459_y,3.60e+00;8.40e+00,vertical_bar
8,007a18eb4e09_x,0.0;0.4;0.8;1.2;1.6;2.0;2.4,line
9,007a18eb4e09_y,1.33e-02;1.33e-02;1.33e-02;1.33e-02;1.33e-02;1...,line
