In [1]:
from transformers import DonutProcessor, AutoModel, VisionEncoderDecoderModel, AutoConfig, AutoProcessor
from PIL import Image
import torch, os, re
from pprint import pprint


# Generation code from https://github.com/vis-nlp/UniChart
model_name = "ahmed-masry/unichart-chart2text-statista-960" #"ahmed-masry/unichart-base-960"


def _load_unichart(checkpoint):
    """Return the ``(model, processor)`` pair loaded from one checkpoint id."""
    model = VisionEncoderDecoderModel.from_pretrained(checkpoint)
    processor = AutoProcessor.from_pretrained(checkpoint)
    return model, processor


# One (model, processor) pair per checkpoint, loaded in the same order as before.
base_model, base_processor = _load_unichart("ahmed-masry/unichart-base-960")
statista_model, statista_processor = _load_unichart("ahmed-masry/unichart-chart2text-statista-960")
pew_model, pew_processor = _load_unichart("ahmed-masry/unichart-chart2text-pew-960")

# Prefer the GPU when torch reports one; otherwise run on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def generate_summary(image_path, model, processor, input_prompt="<summarize_chart> <s_answer>"):
    """Generate a text description of a chart image and pretty-print it.

    Args:
        image_path: Path of the chart image; it is opened with PIL and
            converted to RGB before preprocessing.
        model: Model exposing ``generate`` and ``decoder.config``. It is moved
            to the module-level ``device`` on every call.
        processor: Processor providing ``tokenizer``, image preprocessing and
            ``batch_decode`` for ``model``.
        input_prompt: Decoder prompt fed to the model. When the decoded output
            contains ``<s_answer>``, only the segment following the first
            marker is printed; otherwise the whole cleaned output is printed.

    Returns:
        None. The cleaned text is written to stdout with ``pprint``.
    """
    answer_marker = "<s_answer>"
    tokenizer = processor.tokenizer

    model.to(device)

    # Close the file handle deterministically instead of relying on Pillow's
    # lazy loading / garbage collection; convert() returns an independent image.
    with Image.open(image_path) as chart_file:
        image = chart_file.convert("RGB")

    decoder_input_ids = tokenizer(input_prompt, add_special_tokens=False, return_tensors="pt").input_ids
    pixel_values = processor(image, return_tensors="pt").pixel_values

    outputs = model.generate(
        pixel_values.to(device),
        decoder_input_ids=decoder_input_ids.to(device),
        max_length=model.decoder.config.max_position_embeddings,
        early_stopping=True,
        pad_token_id=tokenizer.pad_token_id,
        eos_token_id=tokenizer.eos_token_id,
        use_cache=True,
        num_beams=4,
        # Forbid the unknown token so it never appears in the generated text.
        bad_words_ids=[[tokenizer.unk_token_id]],
        return_dict_in_generate=True,
    )

    # Only one image is submitted, so only the first decoded sequence is used.
    sequence = processor.batch_decode(outputs.sequences)[0]
    sequence = sequence.replace(tokenizer.eos_token, "").replace(tokenizer.pad_token, "")

    # A custom prompt may not contain the answer marker; fall back to the whole
    # sequence instead of raising IndexError on the missing segment.
    segments = sequence.split(answer_marker)
    answer = segments[1] if len(segments) > 1 else segments[0]
    pprint(answer.strip())

  from .autonotebook import tqdm as notebook_tqdm


In [12]:
# Run every loaded checkpoint on the same line chart, in the original order.
for _model, _processor in (
    (base_model, base_processor),
    (statista_model, statista_processor),
    (pew_model, pew_processor),
):
    generate_summary("./example_imgs/lines.png", _model, _processor)

('The chart shows the average monthly hours of sunshine in Seattle vs. Number '
 'of Bikes that Cross Fremont Bridge for the months of January, March, May, '
 'July, September, and November. The chart indicates that the average monthly '
 'hours of sunshine in Seattle vs. Number of Bikes in July were significantly '
 'lower than the average hours of sunshine for the months of January, March, '
 'May, July, September, and November. The data suggests that the average '
 'monthly hours of sunshine in Seattle vs. Number of Bikes were significantly '
 'lower than the average hours of sunshine for the months of July.')
('The short-time presents data on September 4, 2013, shows the average monthly '
 'hours of sunshine in Seattle vs. Member of the Sunshine Coast. The average '
 'monthly hours of sunshine in July was 39.52 hours, which was the highest '
 'monthly average hour of sunshine. The data is calculated by Statista based '
 "on the ASTA Research Project's history, published by Statista

In [3]:
# Base checkpoint on the line-plot example image.
generate_summary(
    image_path="../imgs/line_plot_example.png",
    model=base_model,
    processor=base_processor,
)

[W NNPACK.cpp:51] Could not initialize NNPACK! Reason: Unsupported hardware.


('The chart shows the average monthly hours of sunshine in Seattle vs. Seattle '
 'vs. Surface. The chart reveals that the average monthly hours of sunshine in '
 'Seattle vs. Seattle vs. Surface. The chart reveals that the average monthly '
 'hours of sunshine in Seattle vs. Seattle vs. Seattle vs. Surface. Surface. '
 'The chart reveals that the average monthly hours of sunshine in Seattle vs. '
 "Surface for men's Surface for men's Surface visits Surface visits Surface "
 'visits Surface visits Surface visits Surface visits Surfaces Surfaces '
 'Surfaces Surfaces Surfaces Surfaces Surfaces Surfaces Surfaces Surfaces '
 'Surfaces Surfaces Surfaces Surfaces Surfaces Surfaces Surfaces Surfaces '
 'Surfaces Surfaces Surfaces and '
 'Surfacesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesatesat

In [13]:
# Base checkpoint on the scatter-plot example image.
generate_summary(
    image_path="./example_imgs/scatter.png",
    model=base_model,
    processor=base_processor,
)

('The chart shows the points from a 2d Gaussian Distribution in Random '
 'Gaussian X from 0.0 to 8. The blue and the organic data shows the points '
 'from a 2d Gaussian X from 0.0 to 8. The chart shows that the points from a '
 '2d Gaussian X decreased from -4.0 to -2.0, while the points from a 2d '
 'Gaussian X decreased from -2.0 to -2.0. Overall, the chart shows that the '
 'points from a 2d Gaussian X decreased from -4.0 to -2.0, while the points '
 'from a 2d Gaussian X decreased from -2.0 to -2.0.')


In [14]:
# Base checkpoint on the line-subplots example image.
generate_summary(
    image_path="./example_imgs/line_subplots.png",
    model=base_model,
    processor=base_processor,
)

("The chart displays the number of Anoscombe's quarterbacks in the United "
 'States over a twenty-year period from 2013 to 2015. The number of Anoscombe '
 'started at 5.4 in 2013, increased to 5.7 in 2014, then to 5.7 in 2015, and '
 'further increased to 6.9 in 2016. The number of Anoscombe then decreased to '
 '6.4 in 2017, and further decreased to 6.4 in 2018. Therefore, the chart '
 "depicts a fluctuating trend of Anoscombe's quarterbacks in the United States "
 'during the twenty-year period, with an overall increasing trend from 2013 to '
 '2015.')
