# **Benetech | Inference**

In [1]:
import pandas as pd
import polars as pol

import torch
import pytorch_lightning as pl
from transformers import Pix2StructForConditionalGeneration, AutoProcessor, Pix2StructConfig, get_linear_schedule_with_warmup 
from transformers import AdamW
from transformers.optimization import get_cosine_schedule_with_warmup, Adafactor

from PIL import Image
import os
import json
from pathlib import Path
from typing import List, Dict, Union

import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
from torch.utils.data import Dataset as TorchDataset

import albumentations as A
from albumentations.pytorch import ToTensorV2
import cv2

import re
import warnings
warnings.simplefilter("ignore")

In [2]:
from pathlib import Path

# Marker that opens every model output string.
PROMPT_TOKEN = "<|BOS|>"

# Delimiters wrapping the x series and the y series inside a model output,
# e.g. "<x_start> a;b<x_end><y_start> 1;2<y_end>".
X_START, X_END = "<x_start>", "<x_end>"
Y_START, Y_END = "<y_start>", "<y_end>"

SEPARATOR_TOKENS = [PROMPT_TOKEN, X_START, X_END, Y_START, Y_END]

# One tag per chart type the model can emit.
LINE_TOKEN = "<line>"
VERTICAL_BAR_TOKEN = "<vertical_bar>"
HORIZONTAL_BAR_TOKEN = "<horizontal_bar>"
SCATTER_TOKEN = "<scatter>"
DOT_TOKEN = "<dot>"

CHART_TYPE_TOKENS = [
    LINE_TOKEN,
    VERTICAL_BAR_TOKEN,
    HORIZONTAL_BAR_TOKEN,
    SCATTER_TOKEN,
    DOT_TOKEN,
]

# Every custom token, separators first, then chart types.
# NOTE(review): presumably the tokens registered with the tokenizer at training time - confirm.
NEW_TOKENS = [*SEPARATOR_TOKENS, *CHART_TYPE_TOKENS]

class Config:
    """Notebook-wide settings, read as plain class attributes (never instantiated here).

    Most of the values are training hyper-parameters; the inference cells in
    this notebook do not read them.
    """

    # ---- General ----
    debug = False
    num_proc = 2
    num_workers = 2
    gpus = 2

    # ---- Data ----
    # Root of the competition's training split.
    data_dir = Path("/kaggle/input/benetech-making-graphs-accessible/train")
    images_path = data_dir / "images"
    # Evaluated once, when the class body runs; empty if the directory is absent.
    train_json_files = [*data_dir.glob("annotations/*.json")]

    # ---- Training ----
    epochs = 5
    val_check_interval = 1.0
    check_val_every_n_epoch = 1
    gradient_clip_val = 2.0
    lr = 2e-5
    lr_scheduler_type = "cosine"
    num_warmup_steps = 100
    seed = 42
    output_path = "output"
    log_steps = 200
    batch_size = 2
    use_wandb = True
    image_height = 512
    image_width = 512
    max_length = 1024


In [3]:
# Directory of the MatCha checkpoint; the model weights and the processor are
# both loaded from it.
MATCHA_CHECKPOINT = "/kaggle/input/benetech-matcha-models/matchav16/kaggle/working/matcha16"

# Run on the GPU when one is visible, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

processor = AutoProcessor.from_pretrained(MATCHA_CHECKPOINT, is_vqa=False)
model = Pix2StructForConditionalGeneration.from_pretrained(MATCHA_CHECKPOINT)
model = model.to(device)

## **Inference**

In [4]:
def display_matcha_output(matcha_output, visualize=True):
    """Parse a decoded model prediction into a chart type and x/y series.

    The expected layout is, for example:
    '<|BOS|><vertical_bar><x_start> Group 1;Group 2<x_end><y_start> 3.6;8.4<y_end>'

    Args:
        matcha_output: Decoded prediction string.
        visualize: Unused; kept so existing callers keep working.

    Returns:
        A ``(chart, x, y)`` tuple. ``chart`` is the name inside the first
        ``<word>``-style tag, or ``None`` when there is none (``<|BOS|>`` is
        skipped because ``|`` is not a word character). ``x`` and ``y`` are
        lists of strings obtained by splitting the series on ``;``. A series
        whose start/end tag pair is missing (for instance because generation
        was cut off) falls back to ``['0', '0']``.
    """
    # Raw strings keep the backslash in the pattern intact.
    chart_match = re.search(r"<(\w+)>", matcha_output)
    chart = chart_match.group(1) if chart_match else None

    def _series(start_tag, end_tag):
        """Return the ';'-separated values between the two tags, or the fallback."""
        found = re.search(start_tag + r"(.*?)" + end_tag, matcha_output)
        if found is None:
            # Two separate placeholder values, i.e. the split form of "0;0".
            return ["0", "0"]
        return found.group(1).strip().split(";")

    x = _series("<x_start>", "<x_end>")
    y = _series("<y_start>", "<y_end>")

    return chart, x, y

In [6]:
def matcha(path, model, processor, device, visualize=True, max_new_tokens=512):
    """Generate and decode the model's prediction for a single chart image.

    Args:
        path: Path of the image file to read.
        model: Generation model exposing ``generate``.
        processor: Processor used to encode the image and decode the tokens;
            its ``tokenizer`` supplies the special token ids.
        device: Device the encoded inputs are moved to.
        visualize: When true, show the image with the notebook's ``display``.
        max_new_tokens: Upper bound on generated tokens. Predictions longer
            than this are cut off before their closing tags.

    Returns:
        The decoded prediction string for the image.
    """
    with torch.no_grad():
        # The context manager releases the file handle as soon as the
        # processor has turned the image into tensors.
        with Image.open(path) as image:
            if visualize:
                display(image)

            inputs = processor(images=image, return_tensors="pt").to(device)

        tokenizer = processor.tokenizer
        # Decoding is deterministic: the previous temperature/top_k/top_p
        # arguments only apply when sampling, and top_k=1 reduces sampling to
        # the argmax anyway, so greedy decoding is requested explicitly.
        predictions = model.generate(
            **inputs,
            max_new_tokens=max_new_tokens,
            do_sample=False,
            early_stopping=True,  # only has an effect under beam search
            use_cache=True,
            eos_token_id=tokenizer.eos_token_id,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            decoder_start_token_id=tokenizer.bos_token_id,
        )
        return processor.decode(predictions[0], skip_special_tokens=True)

In [7]:
def matcha_inference(image_path, visualize):
    """Predict and parse one image using the notebook-level model objects.

    Args:
        image_path: Path of the image to run through the model.
        visualize: Forwarded to both ``matcha`` and ``display_matcha_output``.

    Returns:
        Whatever ``display_matcha_output`` returns for the decoded prediction.
    """
    # `model`, `processor` and `device` are the globals created by the loading cell.
    decoded_prediction = matcha(image_path, model, processor, device, visualize)
    parsed = display_matcha_output(decoded_prediction, visualize)
    return parsed

In [8]:
import os
from tqdm import tqdm
import math

IMAGE_FOLDER = "/kaggle/input/benetech-making-graphs-accessible/test/images"
# Chart types allowed in the submission's chart_type column.
VALID_CHART_TYPES = ("horizontal_bar", "vertical_bar", "line", "dot", "scatter")
# Chart types whose x series is numeric as well as the y series.
NUMERIC_X_CHART_TYPES = ("dot", "scatter")


def _sanitize_numeric(values):
    """Return ``values`` with every entry that is not a finite number replaced by 0.

    Entries that parse as a finite float are kept unchanged (as the original
    strings); unparsable text, NaN and +/-inf all become the integer 0.
    """
    cleaned = []
    for value in values:
        try:
            is_finite = math.isfinite(float(value))
        except (TypeError, ValueError):
            is_finite = False
        cleaned.append(value if is_finite else 0)
    return cleaned


# Two rows are produced per image: "<id>_x" and "<id>_y".
all_ids = []
all_values = []
all_chart_types = []

# sorted() gives a reproducible row order; os.listdir order is arbitrary.
for image_name in tqdm(sorted(os.listdir(IMAGE_FOLDER))):
    if not image_name.endswith(".jpg"):
        continue
    try:
        image_path = os.path.join(IMAGE_FOLDER, image_name)
        chart, x_values, y_values = matcha_inference(image_path, False)
        if chart not in VALID_CHART_TYPES:
            # Route unparsable predictions to the same fallback as any other
            # failure instead of writing an invalid chart type.
            raise ValueError(f"unrecognised chart type {chart!r}")

        y_values = _sanitize_numeric(y_values)
        if chart in NUMERIC_X_CHART_TYPES:
            x_values = _sanitize_numeric(x_values)

        # Both series must have the same number of points; trim the longer one.
        length = min(len(x_values), len(y_values))
        x_values = ";".join([str(v).strip() for v in x_values][:length])
        y_values = ";".join([str(v).strip() for v in y_values][:length])
        chart_type = chart
    except Exception as e:
        # Best effort: a failed image still gets placeholder rows.
        print("Exception", e)
        chart_type = "scatter"
        x_values = "0;0"
        y_values = "0;0"

    # splitext keeps ids that themselves contain dots intact.
    image_id = os.path.splitext(image_name)[0]

    all_ids.append(image_id + "_x")
    all_values.append(x_values)

    all_ids.append(image_id + "_y")
    all_values.append(y_values)

    all_chart_types.extend([chart_type, chart_type])

 20%|██        | 1/5 [00:03<00:14,  3.71s/it]

<|BOS|><line><x_start> 0;6;12;18;24<x_end><y_start> 0;1;1;2;1<y_end>


 40%|████      | 2/5 [00:07<00:11,  3.76s/it]

<|BOS|><vertical_bar><x_start> 21-Feb;22-Feb;23-Feb;24-Feb;25-Feb;26-Feb;27-Feb;28-Feb;29-Feb;01-Mar;02-Mar;03-Mar;04-Mar;05-Mar;06-Mar;07-Mar;08-Mar;09-Mar;10-Mar<x_end><y_start> 90000;150000;170000;180000;130000;99000;-0.0000;40000;60000;60000;50000;40000;60000;80000;80000;100000;130000;100000<y_end>


 60%|██████    | 3/5 [00:16<00:12,  6.09s/it]

<|BOS|><scatter><x_start> 5;6;7;8;9;10;11;12;13;14;15;16;17;18;19;20;21;22;23;24;25;26;27;28;29;30;31;32;33;34;35;36;37;38;39;40;41;42;43;44;45;46;47;48;49;50;51;52;53;54;55;56;57;58;59;60;61;62;63;64;65;66;67;68;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;69;


 80%|████████  | 4/5 [00:16<00:03,  3.91s/it]

<|BOS|><vertical_bar><x_start> Group 1;Group 2<x_end><y_start> 3.6;8.4<y_end>


100%|██████████| 5/5 [00:18<00:00,  3.66s/it]

<|BOS|><line><x_start> 0.0;0.4;0.8;1.2;1.6;2.0;2.4;2.8<x_end><y_start> 7.0;6.0;7.0;5.0;6.0;6.0;6.0;5.0<y_end>





In [9]:
# One row per (image, axis): the id, its ';'-joined series and the chart type.
submission_columns = dict(
    id=all_ids,
    data_series=all_values,
    chart_type=all_chart_types,
)
submission_df = pd.DataFrame(submission_columns)

# Written without the index column.
submission_df.to_csv("submission.csv", index=False)