In [None]:
import json
import os
from getpass import getpass

import matplotlib.patches as patches
import matplotlib.pyplot as plt
import numpy as np
import torch
import vl_convert as vlc
from huggingface_hub import login
from IPython.display import SVG, Image as Image2
from PIL import Image
from transformers import AutoProcessor, Pix2StructForConditionalGeneration, Pix2StructProcessor

# Use the GPU when torch can see one; otherwise run everything on the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

In [None]:
# Never hardcode credentials: read the Hugging Face token from the HF_TOKEN
# environment variable, falling back to an interactive prompt.
# NOTE(review): the token previously committed in this cell is exposed and
# should be revoked on huggingface.co.
hf_token = os.environ.get("HF_TOKEN") or getpass("Hugging Face token: ")
login(token=hf_token)

# Chart to reverse-engineer into a Vega-Lite spec.
IMAGE_PATH = "dataset_vega/test/82.png"
image = Image.open(IMAGE_PATH)

In [None]:
# Show the input chart without axis ticks (the ticks would be pixel indices,
# not the chart's own axes).
fig, ax = plt.subplots()
ax.imshow(image)
ax.set_xticks([])
ax.set_yticks([])
ax.set_title("Input chart")
plt.show()

In [None]:
# MatCha-based model fine-tuned to describe chart structure (mark type and
# variable types) as an XML-like caption.
MATCHA_PROCESSOR_ID = "google/matcha-base"
MATCHA_MODEL_ID = "martinsinnona/visdecode_2024_7"

processor = AutoProcessor.from_pretrained(MATCHA_PROCESSOR_ID)
# Plain image-to-text generation: no question text is rendered onto the image.
processor.image_processor.is_vqa = False

# The model must live on the same device as the inputs; otherwise generate()
# fails with a device-mismatch error whenever CUDA is available.
model = Pix2StructForConditionalGeneration.from_pretrained(MATCHA_MODEL_ID).to(device)
model.eval()

inputs = processor(images=image, return_tensors="pt", max_patches=1024).to(device)

generated_ids = model.generate(
    flattened_patches=inputs.flattened_patches,
    attention_mask=inputs.attention_mask,
    max_length=200,
)
generated_caption = processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
generated_caption

In [None]:
def get_mark_type(str):
    """Return the text between the first ``<mark>`` tag and the ``</mark>`` that follows it.

    Args:
        str: Caption text produced by the chart-structure model.

    Returns:
        The enclosed mark type (e.g. ``"bar"``), or ``""`` when the opening tag
        is missing or has no matching closing tag after it.
    """
    open_tag, close_tag = "<mark>", "</mark>"

    start = str.find(open_tag)
    if start == -1:
        return ""
    start += len(open_tag)

    # Search for the closing tag only after the opening one, so a stray
    # "</mark>" earlier in the text cannot produce a bogus slice.
    end = str.find(close_tag, start)
    if end == -1:
        return ""

    return str[start:end]

def get_var_types(str):
    """Extract the first two ``<type>...</type>`` values (x type, y type) from a caption.

    Args:
        str: Caption text produced by the chart-structure model.

    Returns:
        A ``(x_type, y_type)`` tuple of strings. Returns ``("", "")`` when no
        complete first tag exists; the second element is ``""`` when only one
        complete tag is present.
    """
    open_tag, close_tag = "<type>", "</type>"

    def _tag_body_from(pos):
        # Return (body, position just past the closing tag), or ("", -1) if no complete tag.
        start = str.find(open_tag, pos)
        if start == -1:
            return "", -1
        start += len(open_tag)
        end = str.find(close_tag, start)
        if end == -1:
            return "", -1
        return str[start:end], end + len(close_tag)

    first, next_pos = _tag_body_from(0)
    if next_pos == -1:
        return "", ""

    # A missing second tag must yield "" rather than slicing with find()'s -1.
    second, _ = _tag_body_from(next_pos)
    return first, second

In [None]:
# DePlot linearizes the chart's underlying data table as text.
# Distinct names keep the MatCha model/processor from being silently shadowed.
DEPLOT_MODEL_ID = "google/deplot"
DEPLOT_PROMPT = "Generate underlying data table of the figure below:"

deplot_processor = Pix2StructProcessor.from_pretrained(DEPLOT_MODEL_ID)
deplot_model = Pix2StructForConditionalGeneration.from_pretrained(DEPLOT_MODEL_ID).to(device)
deplot_model.eval()

deplot_inputs = deplot_processor(images=image, text=DEPLOT_PROMPT, return_tensors="pt").to(device)
deplot_predictions = deplot_model.generate(**deplot_inputs, max_new_tokens=512)

# DePlot encodes newlines as the literal byte token "<0x0A>".
generated_data = deplot_processor.decode(deplot_predictions[0], skip_special_tokens=True).replace("<0x0A>", "\n")
generated_data

In [None]:
def get_vega_from_xml(str_xml):
    """Build the structural half of a Vega-Lite spec from the model's XML-like caption.

    Args:
        str_xml: Caption containing ``<mark>`` and ``<type>`` tags.

    Returns:
        ``{"encoding": {"x": {"type": ...}, "y": {"type": ...}}, "mark": ...}``,
        with empty strings for anything the caption does not contain.
    """
    x_type, y_type = get_var_types(str_xml)

    return {
        "encoding": {
            "x": {"type": x_type},
            "y": {"type": y_type},
        },
        "mark": get_mark_type(str_xml),
    }

In [None]:
def _parse_number(text):
    """Convert a table cell to ``int``, else ``float``, else return it unchanged.

    DePlot frequently emits decimal values (e.g. ``"12.5"``), which ``int()``
    alone rejects. Non-numeric cells are kept as strings instead of aborting
    the whole parse.
    """
    try:
        return int(text)
    except ValueError:
        pass
    try:
        return float(text)
    except ValueError:
        return text


def get_vega_from_data(str_data):
    """Build the data half of a Vega-Lite spec from a DePlot linearized table.

    ``str_data`` is newline-separated rows of ``|``-separated cells. The first
    line is skipped (DePlot puts a ``TITLE | ...`` row there). The next
    non-blank row is the header naming the x and y columns, and the remaining
    rows are data. Only the first two columns are used.

    Args:
        str_data: Decoded DePlot output, with ``<0x0A>`` already replaced by newlines.

    Returns:
        A dict with ``encoding.x.field`` / ``encoding.y.field`` taken from the
        header, and ``data.values`` holding one ``{x_name: x_str, y_name: y_value}``
        record per data row. ``y_value`` is an int or float when parseable,
        otherwise the raw string.

    Raises:
        ValueError: if there is no header row with at least two columns.
    """
    res = {"encoding": {"x": {}, "y": {}}, "data": {"values": []}}

    # Skip the title line and ignore blank lines (generated output often ends with one).
    rows = [line for line in str_data.split("\n")[1:] if line.strip()]
    if not rows:
        raise ValueError("DePlot output has no header row")

    var_names = [name.strip() for name in rows[0].split("|")]
    if len(var_names) < 2:
        raise ValueError(f"expected at least two columns in header, got {rows[0]!r}")
    var_names_x, var_names_y = var_names[0], var_names[1]

    res["encoding"]["x"]["field"] = var_names_x
    res["encoding"]["y"]["field"] = var_names_y

    for line in rows[1:]:
        cells = line.split("|")
        if len(cells) < 2:
            # Malformed row (e.g. truncated generation): skip it rather than crash.
            continue
        res["data"]["values"].append(
            {var_names_x: cells[0].strip(), var_names_y: _parse_number(cells[1].strip())}
        )

    return res

In [None]:
def merge_vegas(vega1, vega2):
    """Recursively merge two Vega-Lite spec fragments into a new dict.

    Keys from ``vega2`` are layered onto a copy of ``vega1``. When both sides
    hold a dict under the same key, the two dicts are merged recursively. For
    any other collision, the value from ``vega2`` wins.

    Args:
        vega1: Base spec fragment. It is not modified.
        vega2: Spec fragment whose values take precedence.

    Returns:
        The merged spec as a new top-level dict.
    """
    # Copy so the caller's vega1 is not mutated in place.
    res = dict(vega1)

    for key, value in vega2.items():
        current = res.get(key)
        if isinstance(current, dict) and isinstance(value, dict):
            res[key] = merge_vegas(current, value)
        else:
            # Leaf collision (or new key): previously this recursed into a
            # non-dict and raised AttributeError.
            res[key] = value

    return res

In [None]:
def merge_dicts(dict1, dict2):
    """Recursively merge ``dict2`` into a copy of ``dict1``, collecting conflicting leaves.

    - Keys only in one dict are taken as-is.
    - When both values are dicts, they are merged recursively.
    - When both values are equal, the value is kept once.
    - Otherwise the values are gathered into a list: ``[dict1_value, dict2_value]``,
      or the existing list from ``dict1`` extended with ``dict2``'s value.

    Neither input is modified.

    Args:
        dict1: Base mapping.
        dict2: Mapping merged on top of ``dict1``.

    Returns:
        A new merged dict.
    """
    merged = dict1.copy()  # shallow copy: nested values are still shared with dict1

    for key, value in dict2.items():
        if key not in merged:
            merged[key] = value
            continue

        current = merged[key]
        if isinstance(current, dict) and isinstance(value, dict):
            merged[key] = merge_dicts(current, value)
        elif current != value:
            # Build a new list instead of appending: `current` may be a list
            # owned by dict1, and .append() would silently mutate the caller's input.
            existing = current if isinstance(current, list) else [current]
            merged[key] = existing + [value]

    return merged

In [None]:
def draw_vega(dict, scale = 1):
    """Render a Vega-Lite spec to PNG with vl-convert and wrap it for notebook display.

    Args:
        dict: Vega-Lite specification as a JSON-serializable mapping.
        scale: Pixel scale factor passed through to ``vegalite_to_png``.

    Returns:
        An ``IPython.display.Image`` holding the PNG bytes, shown with ``retina=True``.
    """
    return Image2(
        vlc.vegalite_to_png(vl_spec=json.dumps(dict, indent=4), scale=scale),
        retina=True,
    )

In [None]:
# Data half of the spec: column names and values from the DePlot table.
data_dict = get_vega_from_data(str_data=generated_data)
data_dict

In [None]:
# Structural half of the spec: mark and variable types from the MatCha caption.
matcha_dict = get_vega_from_xml(str_xml=generated_caption)
matcha_dict

In [None]:
# Combine structure and data into a single Vega-Lite spec.
vega_dict = merge_dicts(dict1=matcha_dict, dict2=data_dict)
vega_dict

In [None]:
# Re-display the original chart right above the reconstruction for visual comparison.
fig, ax = plt.subplots()
ax.imshow(image)
ax.set_xticks([])
ax.set_yticks([])
ax.set_title("Input chart (compare with reconstruction below)")
plt.show()

In [None]:
# Render the reconstructed spec at 3x resolution.
draw_vega(vega_dict, scale=3)

In [None]:
# Render the spec as SVG; wrapping it in SVG() displays the graphic instead of
# dumping the raw markup string as cell output.
SVG(vlc.vegalite_to_svg(vega_dict))