# Donut with additional hyperparameter tuning

In [None]:
import re
from pathlib import Path
from typing import List
from functools import partial

from transformers import (
    DonutProcessor,
    VisionEncoderDecoderConfig,
    VisionEncoderDecoderModel,
)
import torch
from torch.utils.data import DataLoader
import numpy as np
import pandas as pd
from datasets import Dataset
from datasets import Image as ds_img
from tqdm.notebook import tqdm

import warnings
warnings.filterwarnings("ignore")

In [None]:
class CFG:
    """Static settings for the inference run (read as class attributes)."""

    # Where the fine-tuned checkpoint and the test images live.
    model_dir = "/kaggle/input/benetech-donut"
    image_path = "/kaggle/input/benetech-making-graphs-accessible/test/images"

    # DataLoader batch size and the max_length handed to model.generate.
    batch_size = 4
    max_length = 512

    # Flags that are not read anywhere in the visible part of this file.
    test_grayscale = True
    debug_clean = False

    
# Marker strings. The X_*/Y_* markers are what string2preds searches for in
# the decoded prediction; BOS_TOKEN is not referenced in the visible source.
BOS_TOKEN = "<|BOS|>"
X_START, X_END = "<x_start>", "<x_end>"
Y_START, Y_END = "<y_start>", "<y_end>"

# Fallbacks written to the submission when a prediction cannot be parsed or
# yields an empty series.
PLACEHOLDER_CHART_TYPE = "line"
PLACEHOLDER_DATA_SERIES = "0;0"

In [None]:
from typing import List, Tuple
def clean_preds(x: List[str], y: List[str]) -> Tuple[List[str], List[str]]:
    """Sanitise the raw x and y value strings of one prediction.

    Every entry is reduced to the characters that can appear in a number
    (digits, ".", "-", "e", "E").  Entries that are blank *before* cleaning
    are dropped; an entry that becomes empty *after* cleaning is kept as "".

    Args:
        x: Raw x-series entries.
        y: Raw y-series entries.

    Returns:
        A ``(x_cleaned, y_cleaned)`` tuple of cleaned string lists.
    """

    def clean(value):
        value = re.sub(r"[^\d.\-eE]", "", value)

        # A value that already parses as a number must not be "repaired":
        # "-1e-5" legitimately contains two minus signs, and the fallback
        # below would strip both and silently turn it into "1e5".
        try:
            float(value)
        except ValueError:
            pass
        else:
            return value

        # Best-effort fallbacks for text that is not a valid number.
        if value.count(".") > 1:
            value = value.replace(".", "", 1)
        if value.count("-") > 1:
            value = value.replace("-", "")
        if value.count("e") > 1:
            value = value.replace("e", "", 1)
        return value

    def clean_list(str_list):
        return [clean(val) for val in str_list if val.strip()]

    return clean_list(x), clean_list(y)
    

def string2preds(pred_string: str) -> Tuple[str, List[str], List[str]]:
    """Parse one decoded model output into a chart type and two value lists.

    The chart type is the first known chart token found in ``pred_string``.
    The x series is the text between ``X_START`` and ``X_END`` and the y
    series the text between ``Y_START`` and ``Y_END``; both are split on
    ";" and passed through ``clean_preds``.

    Args:
        pred_string: Decoded text of a single generation.

    Returns:
        ``(chart_type, x_values, y_values)``.  When no chart token is present
        the result is ``("vertical_bar", [], [])``; when the series markers
        cannot be found the detected chart type is returned with two empty
        lists.
    """
    chart_type_mapping = {
        "<dot>": "dot",
        "<horizontal_bar>": "horizontal_bar",
        "<vertical_bar>": "vertical_bar",
        "<scatter>": "scatter",
        "<line>": "line",
    }

    # First token (in mapping order) that occurs in the prediction wins.
    chart_type = next(
        (name for token, name in chart_type_mapping.items() if token in pred_string),
        None,
    )
    if chart_type is None:
        return "vertical_bar", [], []

    # The captured groups must sit *between* each start/end marker pair.  The
    # previous pattern placed the start and end markers back to back and
    # required a bare chart-type word after the y values.
    pattern = r"{}(.*?){}.*?{}(.*?){}".format(
        re.escape(X_START), re.escape(X_END), re.escape(Y_START), re.escape(Y_END)
    )
    # DOTALL so a stray newline inside a series does not prevent the match.
    match = re.search(pattern, pred_string, flags=re.DOTALL)
    if not match:
        return chart_type, [], []

    x_values = match.group(1).split(";")
    y_values = match.group(2).split(";")

    x_cleaned, y_cleaned = clean_preds(x_values, y_values)

    return chart_type, x_cleaned, y_cleaned

In [None]:
# Collect every .jpg in the configured image directory.
image_dir = Path(CFG.image_path)
images = [*image_dir.glob("*.jpg")]

# One row per image: the file path plus the file stem used as the sample id.
ds = Dataset.from_dict(
    {
        "image_path": list(map(str, images)),
        "id": [path.stem for path in images],
    }
)
# Cast the path column to the datasets Image feature.
ds = ds.cast_column("image_path", ds_img())

def preprocess(examples, processor, *, random_padding=True):
    """Turn a batch of images into a stacked ``pixel_values`` tensor.

    Args:
        examples: Batch mapping whose ``"image_path"`` entry is a sequence of
            images convertible with ``np.array``.
        processor: Callable invoked as ``processor(arr, random_padding=...)``
            that returns an object exposing ``pixel_values``.
        random_padding: Forwarded to ``processor``.  Defaults to ``True`` to
            keep the previous behaviour.
            NOTE(review): random padding is presumably a training-time
            augmentation; consider passing ``False`` for reproducible
            inference -- confirm against the processor documentation.

    Returns:
        ``{"pixel_values": tensor}`` with the per-image outputs stacked along
        the first axis.
    """
    pixel_values = []

    for sample in examples["image_path"]:
        arr = np.array(sample)

        # Grayscale images arrive as 2-D arrays and used to make this fail;
        # replicate the single channel three times along a new last axis.
        if arr.ndim == 2:
            print("Changing grayscale to 3 channel format")
            print(arr.shape)
            arr = np.stack([arr] * 3, axis=-1)

        encoded = processor(arr, random_padding=random_padding)
        pixel_values.append(encoded.pixel_values)

    return {
        "pixel_values": torch.tensor(np.vstack(pixel_values)),
    }

# Load the checkpoint from CFG.model_dir and put it in eval mode.
model = VisionEncoderDecoderModel.from_pretrained(CFG.model_dir)
model.eval()

# Use the first GPU when one is available; otherwise fall back to the CPU
# instead of failing on a hard-coded "cuda:0".
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

model.to(device)
decoder_start_token_id = model.config.decoder_start_token_id
processor = DonutProcessor.from_pretrained(CFG.model_dir)

# Read the ids before installing the transform, whose output contains only
# "pixel_values".
ids = ds["id"]
ds.set_transform(partial(preprocess, processor=processor))

# shuffle=False keeps the generations in the same order as `ids`.
data_loader = DataLoader(
    ds, batch_size=CFG.batch_size, shuffle=False
)


from tqdm import tqdm

# Decoding options shared by every batch.
# NOTE(review): do_sample is not set, so temperature/top_k/top_p presumably
# have no effect on the beam search used here -- confirm against the installed
# transformers version.
generate_kwargs = dict(
    max_length=CFG.max_length,
    early_stopping=True,
    pad_token_id=processor.tokenizer.pad_token_id,
    eos_token_id=processor.tokenizer.eos_token_id,
    use_cache=True,
    num_beams=2,
    temperature=.9,
    top_k=1,
    top_p=.4,
    return_dict_in_generate=True,
)

# Decoded text of every generation, in data_loader order.
all_generations = []
for batch in tqdm(data_loader, total=len(data_loader)):
    pixel_values = batch["pixel_values"].to(device)

    # One decoder start token per image in the batch.
    decoder_input_ids = torch.full(
        (pixel_values.shape[0], 1),
        decoder_start_token_id,
        device=pixel_values.device,
    )

    outputs = model.generate(
        pixel_values,
        decoder_input_ids=decoder_input_ids,
        **generate_kwargs,
    )

    all_generations += processor.batch_decode(outputs.sequences)

# Convert every decoded generation into a chart type plus ";"-joined x and y
# series, substituting the placeholders whenever parsing fails or a series
# comes back empty.
chart_types, x_preds, y_preds = [], [], []
for gen in all_generations:
    try:
        new_chart_type, x, y = string2preds(gen)
        x_str = ";".join(str(value) for value in x)
        y_str = ";".join(str(value) for value in y)
    except Exception as e:
        # Best effort: report the offending generation and keep going.
        print("Failed to convert to string:", gen)
        print(e)
        new_chart_type = PLACEHOLDER_CHART_TYPE
        x_str = y_str = PLACEHOLDER_DATA_SERIES

    if not x_str:
        x_str = PLACEHOLDER_DATA_SERIES
    if not y_str:
        y_str = PLACEHOLDER_DATA_SERIES

    chart_types.append(new_chart_type)
    x_preds.append(x_str)
    y_preds.append(y_str)
    
        

# Two rows per image: "<id>_x" rows first, then "<id>_y" rows, with the chart
# type repeated for both halves.
x_ids = [f"{id_}_x" for id_ in ids]
y_ids = [f"{id_}_y" for id_ in ids]

sub_df = pd.DataFrame(
    {
        "id": x_ids + y_ids,
        "data_series": x_preds + y_preds,
        "chart_type": chart_types + chart_types,
    }
)

# Write the submission file without the index column.
sub_df.to_csv("submission.csv", index=False)

In [None]:
# Show the submission table as the cell output. `display` is not imported in
# the visible source; presumably the IPython notebook builtin.
display(sub_df)

In [None]:
# Reload the checkpoint from CFG.model_dir; this rebinds `model` to a fresh
# instance that is not moved to `device` here.
# NOTE(review): presumably an inspection cell -- the trailing bare
# `model.eval()` expression would be rendered as the cell output. Confirm it
# is still needed.
model = VisionEncoderDecoderModel.from_pretrained(CFG.model_dir)

model.eval()