# <b><span style='color:#F1A424'>|</span> Benetech: <span style='color:#F1A424'>Pix2Struct</span><span style='color:#ABABAB'> [Inference]</span></b> 

***


### <b><span style='color:#F1A424'>Table of Contents</span></b> <a class='anchor' id='top'></a>
<div style=" background-color:#3b3745; padding: 13px 13px; border-radius: 8px; color: white">
<li> <a href="#introduction">Introduction</a></li>
<li> <a href="#install_libraries">Install libraries</a></li>
<li><a href="#import_libraries">Import Libraries</a></li>
<li><a href="#configuration">Configuration</a></li>
<li><a href="#utils">Utils</a></li>
<li><a href="#load_data">Load Data</a></li>
<li><a href="#model">Model</a></li>
<li><a href="#dataset">Dataset</a></li>
<li><a href="#collate">Collate Function</a></li>
<li><a href="#dataloader">DataLoader</a></li>
<li><a href="#inference">Inference</a></li>
<li><a href="#submission">Submission</a></li>
<li><a href="#q_and_a">Quality Assurance</a></li>
</div>


# <b><span style='color:#F1A424'>|</span> Introduction</b><a class='anchor' id='introduction'></a> [↑](#top) 

***

### <b><span style='color:#F1A424'>Useful References</span></b>

- [Pix2Struct HuggingFace Demo](https://github.com/huggingface/notebooks/blob/main/examples/image_captioning_pix2struct.ipynb)
- [Issues in training discussion](https://github.com/huggingface/transformers/issues/22903)
- [Pix2Struct Niels Rogge Demo](https://github.com/NielsRogge/Transformers-Tutorials/blob/master/Pix2Struct/Fine_tune_Pix2Struct_on_key_value_pair_dataset_(PyTorch_Lightning).ipynb)

# <b><span style='color:#F1A424'>|</span> Install Libraries</b><a class='anchor' id='install_libraries'></a> [↑](#top) 

***

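The training issue linked under *Useful References* is fixed only in recent `transformers` builds, so we uninstall the preinstalled version and install a fixed one offline from a Kaggle dataset. Future Kaggle environments are expected to include the fix, which will make this step unnecessary.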

In [1]:
!pip uninstall transformers -y
!python -m pip install --no-index --find-links=/kaggle/input/benetech-pip transformers

Found existing installation: transformers 4.28.1
Uninstalling transformers-4.28.1:
  Successfully uninstalled transformers-4.28.1
Looking in links: /kaggle/input/benetech-pip
Processing /kaggle/input/benetech-pip/transformers-4.30.0.dev0.zip
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Processing /kaggle/input/benetech-pip/huggingface_hub-0.14.1-py3-none-any.whl
Building wheels for collected packages: transformers
  Building wheel for transformers (pyproject.toml) ... done
  Created wheel for transformers: filename=transformers-4.30.0.dev0-py3-none-any.whl size=7120338 sha256=cd73a6fa626a1484d585f373693f5fbd010e5b1e27cdf34a44bc12bec1aded3e
  Stored in directory: /root/.cache/pip/wheels/c9/06/f4/a315c5665163a83bde1d7b42bfc743ec

In [2]:
import datasets
import transformers

print(f"datasets version: {datasets.__version__}") # should be 2.12.0 
print(f"transformers version: {transformers.__version__}") # should be 4.29.0.dev0 or higher

datasets version: 2.1.0
transformers version: 4.30.0.dev0


# <b><span style='color:#F1A424'>|</span> Import Libraries</b><a class='anchor' id='import_libraries'></a> [↑](#top) 

***

Import all the required libraries for this notebook.

In [3]:
import ast
import cv2
import json
import matplotlib.pyplot as plt
import multiprocessing
import numpy as np
import os
import pandas as pd
import pyarrow
import random
import re
import torch
import wandb


from collections import Counter
from datasets import load_dataset, concatenate_datasets
from datasets import Dataset as HFDataset, DatasetDict
from datasets import Image as ds_img
from glob import glob
from itertools import chain
from nltk import edit_distance
from pathlib import Path
from PIL import Image
from torch.utils.data import Dataset, DataLoader
from tqdm import tqdm
from transformers import Pix2StructForConditionalGeneration, AutoProcessor
from typing import List, Dict, Union, Tuple, Any



# <b><span style='color:#F1A424'>|</span> Configuration</b><a class='anchor' id='configuration'></a> [↑](#top) 

***

Central repository for this notebook's hyperparameters.

In [4]:
class config:
    BATCH_SIZE = 4
    DEBUG = False
    DEVICE = 'cuda' if torch.cuda.is_available() else 'cpu'
    GPUS = 2
    NUM_PROCESS = 2
    NUM_WORKERS = multiprocessing.cpu_count()
    VERBOSE = True

    
class paths:
    BEST_MODEL = "/kaggle/input/matcha-amp"
    TEST_FOLDER = "/kaggle/input/benetech-making-graphs-accessible/test"
    TEST_IMAGES_FOLDER = "/kaggle/input/benetech-making-graphs-accessible/test/images"

# <b><span style='color:#F1A424'>|</span> Utils</b><a class='anchor' id='utils'></a> [↑](#top) 

***

Utility functions used throughout the notebook.

In [5]:
X_START = "<s_x_values>"
X_END = "</s_x_values>"
Y_START = "<s_y_values>"
Y_END = "</s_y_values>"
CHART_START = "<s_chart>"
CHART_END = "</s_chart>"


def get_text(text: str, start_token: str, end_token: str, exception: str):
    """
    Retrieve the text between two tokens. If the tokens are not present, return the
    `exception` string instead.
    """
    pattern = f"{start_token}(.*?){end_token}"
    matches = re.findall(pattern, text)
    if matches:
        data = matches[0]
    else:
        data = exception
    return data


def get_prediction(preds: List[Dict]):
    """
    Extract the relevant information from the model's predictions. `preds` is a list of dicts
    with "id" and "prediction" keys; the relevant data in each prediction string is enclosed
    by special tokens.
    """
    output = []
    for pred in preds:
        pred_dictionary = {}
        prediction = pred["prediction"]
        pred_dictionary["id"] = pred["id"]
        pred_dictionary["chart_type"] = get_text(text=prediction, start_token=CHART_START,
                                                 end_token=CHART_END, exception="line")
        pred_dictionary["data_series_x"] = get_text(text=prediction, start_token=X_START,
                                                  end_token=X_END, exception="0;0")
        pred_dictionary["data_series_y"] = get_text(text=prediction, start_token=Y_START,
                                                  end_token=Y_END, exception="0;0")
        output.append(pred_dictionary)
    return output


def format_submission(preds: List[Dict]):
    """
    Split each prediction into two rows, one per axis, as required by the submission format.
    """
    output = []
    # Iterate over the argument, not the global `predictions` list.
    for prediction in preds:
        prediction_x, prediction_y = {}, {}
        prediction_x["id"] = prediction["id"] + "_x"
        prediction_y["id"] = prediction["id"] + "_y"
        prediction_x["data_series"] = prediction["data_series_x"]
        prediction_y["data_series"] = prediction["data_series_y"]
        prediction_x["chart_type"] = prediction["chart_type"]
        prediction_y["chart_type"] = prediction["chart_type"]
        output.append(prediction_x)
        output.append(prediction_y)
    return output
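
To make the parsing step concrete, here is a minimal standalone sketch of how one raw prediction string maps to two submission rows. It does not reuse the functions above: `extract_between` is a hypothetical helper that additionally escapes the tokens with `re.escape` and strips surrounding whitespace, both of which are assumptions not present in the original `get_text`. The sample string is copied from the inference output shown later in this notebook.

```python
import re


def extract_between(text, start_token, end_token, default):
    """Return the stripped text between two tokens, or `default` if absent."""
    # re.escape is an added safeguard (assumption); the original builds the pattern directly.
    pattern = re.escape(start_token) + r"(.*?)" + re.escape(end_token)
    match = re.search(pattern, text, flags=re.DOTALL)
    return match.group(1).strip() if match else default


sample = ("<s_chart> line</s_chart><s_x_values> 0;6;12;18;24</s_x_values>"
          "<s_y_values> 0.0000;-1.3232;-1.2132;-1.9432;-1.2432</s_y_values>")

chart_type = extract_between(sample, "<s_chart>", "</s_chart>", "line")
x_values = extract_between(sample, "<s_x_values>", "</s_x_values>", "0;0")
y_values = extract_between(sample, "<s_y_values>", "</s_y_values>", "0;0")

# One row per axis, sharing the chart type.
rows = [
    {"id": "000b92c3b098_x", "data_series": x_values, "chart_type": chart_type},
    {"id": "000b92c3b098_y", "data_series": y_values, "chart_type": chart_type},
]
for row in rows:
    print(row)
```

When a token pair is missing, the helper returns the default, which mirrors the `exception` fallback of `get_text`.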

# <b><span style='color:#F1A424'>|</span> Load Data</b><a class='anchor' id='load_data'></a> [↑](#top) 

***

In [6]:
image_dir = Path(paths.TEST_IMAGES_FOLDER)
images = list(image_dir.glob("*.jpg"))

ds = HFDataset.from_dict({
    "image_path": [str(x) for x in images],
    "id": [x.stem for x in images]
}).cast_column("image_path", ds_img())

# <b><span style='color:#F1A424'>|</span> Model</b><a class='anchor' id='model'></a> [↑](#top) 

***

In [7]:
processor = AutoProcessor.from_pretrained(paths.BEST_MODEL, is_vqa=False)
model = Pix2StructForConditionalGeneration.from_pretrained(paths.BEST_MODEL, is_vqa=False)

# <b><span style='color:#F1A424'>|</span> Dataset</b><a class='anchor' id='dataset'></a> [↑](#top) 

***

We will create a `CustomDataset` class. It is a standard PyTorch `Dataset` with the required `__init__()`, `__len__()` and `__getitem__()` methods. Unlike a training dataset, it does not add tokens to the tokenizer: the `new_tokens` argument is accepted but not used.

In [8]:
class CustomDataset(Dataset):
    def __init__(
        self,
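        dataset: HFDataset,
        max_patches: int = 512,
        max_length: int = 512,
        new_tokens: Union[list, None] = None  # avoid a mutable default argument
        ):
        """
        Initialize CustomDataset instance.
        :param dataset (HFDataset): HuggingFace Dataset with "image_path" and "id" columns
        :param max_patches (int): Maximum number of patches to extract
        :param max_length (int): Maximum sequence length (stored, not used in __getitem__)
        :param new_tokens (list): Accepted but unused in this notebook
        """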
        super().__init__()

        self.dataset = dataset
        self.max_patches = max_patches
        self.max_length = max_length
    
    def __len__(self) -> int:
        return len(self.dataset)

    def __getitem__(self, idx: int):
        item = self.dataset[idx]
        encoding = processor(
            images=item["image_path"],
            max_patches=self.max_patches,
            add_special_tokens=True,
            return_tensors="pt"
        )
        encoding = {k:v.squeeze() for k,v in encoding.items()}
        encoding["id"] = item["id"]
        return encoding

### <b><span style='color:#F1A424'>Create inference dataset</span></b>

Each sample consists of three parts. The first two come from the processor's [preprocess](https://github.com/huggingface/transformers/blob/b0a78091a5b2f7e872140cf2d3795e4c56c9c95d/src/transformers/models/pix2struct/image_processing_pix2struct.py#L323) step; the third is added by `CustomDataset`:
- `flattened_patches`: image patches.
- `attention_mask`: attention mask. Tensor with 1s and 0s.
- `id`: id of the image.

In [9]:
inference_dataset = CustomDataset(dataset=ds)

# === Let's check one sample ===
encoding = inference_dataset[0]
# No "labels" key exists at inference time, so there is no target text to decode here.
print(f"Encoding keys: {encoding.keys()} \n") 
print(f"Unique tokens in tokenizer: {len(processor.tokenizer)} \n")

Encoding keys: dict_keys(['flattened_patches', 'attention_mask', 'id']) 

Unique tokens in tokenizer: 50350 



# <b><span style='color:#F1A424'>|</span> Collate Function</b><a class='anchor' id='collate'></a> [↑](#top) 

***

A collate function combines a list of individual samples into a single batch. In PyTorch it is passed to the `DataLoader` through the `collate_fn` argument.

Hugging Face ships ready-made data collators (for example `DataCollatorWithPadding`) that pad sequences to a common length and build attention masks, which is mostly useful for text inputs.

Here the processor already pads every image to `max_patches`, so a small custom collator is enough: it stacks `flattened_patches` and `attention_mask` into batch tensors and keeps the ids as a plain list.

In [10]:
def collator(batch):
    new_batch = {"flattened_patches":[], "attention_mask":[], "id":[]}
    
    for item in batch:
        new_batch["flattened_patches"].append(item["flattened_patches"])
        new_batch["attention_mask"].append(item["attention_mask"])
        new_batch["id"].append(item["id"])

    new_batch["flattened_patches"] = torch.stack(new_batch["flattened_patches"])
    new_batch["attention_mask"] = torch.stack(new_batch["attention_mask"])

    return new_batch
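
The first half of what the collator does, turning a list of per-sample dicts into one dict of per-key lists, can be illustrated without any tensors. This is only a sketch: `transpose_samples` is a hypothetical helper, and nested Python lists stand in for the tensors that the real collator passes to `torch.stack`.

```python
def transpose_samples(samples):
    """Turn a list of per-sample dicts into one dict of per-key lists."""
    keys = samples[0].keys()
    return {key: [sample[key] for sample in samples] for key in keys}


# Nested lists stand in for tensors; the real collator stacks them with torch.stack.
samples = [
    {"flattened_patches": [[0.1, 0.2]], "attention_mask": [1], "id": "a"},
    {"flattened_patches": [[0.3, 0.4]], "attention_mask": [1], "id": "b"},
]

batch = transpose_samples(samples)
print(batch["id"])                      # ['a', 'b']
print(len(batch["flattened_patches"]))  # 2
```

The ids stay a plain list because strings cannot be stacked into a tensor.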

# <b><span style='color:#F1A424'>|</span> DataLoader</b><a class='anchor' id='dataloader'></a> [↑](#top) 

***

In [11]:
inference_dataloader = DataLoader(inference_dataset, batch_size=config.BATCH_SIZE,
                                  shuffle=False, num_workers=config.NUM_WORKERS, collate_fn=collator)

# === Let's check one sample ===
batch = next(iter(inference_dataloader))
encoding = batch

# === Iterate over each element in the dictionary and print shape ===
for k,v in encoding.items():
    # Tensors expose .shape; the list of ids only has a length.
    if hasattr(v, "shape"):
        print(f"{k} shape: {v.shape} \n")
    else:
        print(f"{k} shape: {len(v)} \n")

flattened_patches shape: torch.Size([4, 512, 770]) 

attention_mask shape: torch.Size([4, 512]) 

id shape: 4 
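
The last dimension of `flattened_patches` can be checked by hand. The sketch below assumes the processor uses 16×16 patches on 3-channel RGB images and prepends a row index and a column index to every flattened patch. These are the usual Pix2Struct settings, but they are assumptions here and are not read from this notebook's checkpoint.

```python
# Assumed processor settings (not read from the checkpoint used in this notebook).
patch_height, patch_width = 16, 16
channels = 3            # RGB
position_features = 2   # row index and column index prepended to every patch

features_per_patch = position_features + patch_height * patch_width * channels
print(features_per_patch)  # 770

# Expected batch shape for BATCH_SIZE = 4 and max_patches = 512.
batch_size, max_patches = 4, 512
expected_shape = (batch_size, max_patches, features_per_patch)
print(expected_shape)  # (4, 512, 770)
```

This matches the `torch.Size([4, 512, 770])` printed above.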



# <b><span style='color:#F1A424'>|</span> Inference</b><a class='anchor' id='inference'></a> [↑](#top) 

***

In [12]:
device = config.DEVICE
model.to(device)
model.eval()
predictions = []
with torch.no_grad():
    for step, batch in enumerate(tqdm(inference_dataloader, unit="test_batch")):
        ids = batch.pop("id")
        flattened_patches = batch.pop("flattened_patches").to(device)
        attention_mask = batch.pop("attention_mask").to(device)
        prediction = model.generate(
            flattened_patches=flattened_patches,
            attention_mask=attention_mask,
            max_new_tokens=512,
            pad_token_id=processor.tokenizer.pad_token_id,
            eos_token_id=processor.tokenizer.eos_token_id
        )
        preds = processor.batch_decode(prediction, skip_special_tokens=True)
        preds = [{'id': id_value, 'prediction': pred_value} for id_value, pred_value in zip(ids, preds)]
        predictions += get_prediction(preds)
        print(f'Step: {step} | Prediction: {preds}')
        print("\n")

 50%|█████     | 1/2 [00:43<00:43, 43.81s/test_batch]

Step: 0 | Prediction: [{'id': '000b92c3b098', 'prediction': '<s_chart> line</s_chart><s_x_values> 0;6;12;18;24</s_x_values><s_y_values> 0.0000;-1.3232;-1.2132;-1.9432;-1.2432</s_y_values>'}, {'id': '01b45b831589', 'prediction': '<s_chart> vertical_bar</s_chart><s_x_values> 21-Feb ; 21-Mar ; 23-Apr ; 23-May ; 24-Jun ; 22-Jul ; 26-Aug ; 27-Sep ; 27-Oct ; 2-Nov ; 2-Dec ; 0-Dec ; 0-Nov ; 0-Dec ; 0-Dec ; 0-Dec ; 0-Dec ; 0-Dec ; 0-Dec ; 0-Dec ; 0-Nov ; 0-Dec ; 0-Dec ; 0-Dec ; 0-Dec ; 0-Nov ; 0-Dec ; 0-Dec ; 0-Dec ; 0-Dec ; 0-Dec ; 0-'}, {'id': '00f5404753cf', 'prediction': '<s_chart> scatter</s_chart><s_x_values> 0.6 ; 0.6 ; 0.6 ; 0.7 ; 0.8 ; 0.8 ; 0.9 ; 1.0 ; 1.0 ; 1.1 ; 1.2 ; 1.2 ; 1.3 ; 1.4 ; 1.4 ; 1.5 ; 1.6 ; 1.6 ; 1.7 ; 1.7 ; 1.8 ; 1.8 ; 1.8 ; 1.9 ; 2.0 ; 2.0 ; 2.0 ; 2.1 ; 2.2 ; 2.2 ; 2.3 ; 2.4 ; 2.4 ; 2.5 ; 2.6 ; 2.6 ; 2.7 ; 2.7 ; 2.8 ; 2.8 ; 2.9 ; 3.0 ; 3.0 ; 3.1 ; 3.2 ; 3.3 ; 3.4 ; 3.5 ; 3.5 ; 3.6 ; 3.7 ; 3.7 ; 3.8 ; 3.8 ; 3.9 ; 4.0 ; 4.0 ; 4.1 ; 4.2 ; 4.3 ; 4.4 ; 4.5 ; 4.6 ; 4.6 ; 4

100%|██████████| 2/2 [00:46<00:00, 23.21s/test_batch]

Step: 1 | Prediction: [{'id': '007a18eb4e09', 'prediction': '<s_chart> line</s_chart><s_x_values> 0;0.4;0.8;1.2;1.6;2.0;2.4;2.8</s_x_values><s_y_values> 0.0132;0.0132;0.0132;0.0132;0.0132;0.0132;0.0132;0.0132</s_y_values>'}]
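
Some predictions above stop mid-series: generation reached `max_new_tokens` before the closing token was emitted, so `get_text` finds no match and falls back to `0;0`. One possible mitigation, sketched here as an assumption rather than as part of the original pipeline, is to keep whatever was generated after the opening token and drop the last, possibly incomplete, item. `salvage_series` is a hypothetical helper.

```python
def salvage_series(text, start_token, end_token, default="0;0"):
    """Extract a ';'-separated series, tolerating a missing end token."""
    start = text.find(start_token)
    if start == -1:
        return default  # the series was never started
    body = text[start + len(start_token):]
    end = body.find(end_token)
    if end != -1:
        items = body[:end].split(";")
    else:
        # Truncated output: the last item may be cut off, so discard it.
        items = body.split(";")[:-1]
    items = [item.strip() for item in items if item.strip()]
    return ";".join(items) if items else default


truncated = "<s_chart> scatter</s_chart><s_x_values> 0.6 ; 0.7 ; 0.8 ; 0"
salvaged_x = salvage_series(truncated, "<s_x_values>", "</s_x_values>")
salvaged_y = salvage_series(truncated, "<s_y_values>", "</s_y_values>")
print(salvaged_x)  # 0.6;0.7;0.8
print(salvaged_y)  # 0;0
```

Whether a partial series scores better than the `0;0` fallback depends on the competition metric and is not verified here.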







# <b><span style='color:#F1A424'>|</span> Submission</b><a class='anchor' id='submission'></a> [↑](#top) 

***

In [13]:
submission = pd.DataFrame(format_submission(predictions))
submission["chart_type"] = submission["chart_type"].apply(lambda x: re.sub(r"\s", "", x))
submission.to_csv("submission.csv", index=False)
submission

| | id | data_series | chart_type |
|---|---|---|---|
| 0 | 000b92c3b098_x | 0;6;12;18;24 | line |
| 1 | 000b92c3b098_y | 0.0000;-1.3232;-1.2132;-1.9432;-1.2432 | line |
| 2 | 01b45b831589_x | 0;0 | vertical_bar |
| 3 | 01b45b831589_y | 0;0 | vertical_bar |
| 4 | 00f5404753cf_x | 0;0 | scatter |
| 5 | 00f5404753cf_y | 0;0 | scatter |
| 6 | 00dcf883a459_x | Group 1;Group 2 | vertical_bar |
| 7 | 00dcf883a459_y | 3.5;8.3 | vertical_bar |
| 8 | 007a18eb4e09_x | 0;0.4;0.8;1.2;1.6;2.0;2.4;2.8 | line |
| 9 | 007a18eb4e09_y | 0.0132;0.0132;0.0132;0.0132;0.0132;0.0132;0.0... | line |


# <b><span style='color:#F1A424'>|</span> Quality Assurance</b><a class='anchor' id='q_and_a'></a> [↑](#top) 

***

In [14]:
n_images = len(glob(paths.TEST_IMAGES_FOLDER + "/*.jpg"))

# Two rows (x and y) are expected per image. Explicit checks are used instead of
# assert, which is skipped when Python runs with -O.
if len(submission) != n_images * 2:
    raise ValueError("The number of rows in submission/2 does not match the number of images.")

if submission.id.nunique() != n_images * 2:
    raise ValueError("The number of unique IDs/2 does not match the number of images.")

if submission.columns.tolist() != ['id', 'data_series', 'chart_type']:
    raise ValueError("Wrong column names.")
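
One further check could complement the ones above: every image id should contribute exactly one `_x` and one `_y` row. The sketch below works on a plain list of ids rather than on the `submission` DataFrame, and `check_axis_pairs` is a hypothetical helper that is not part of the original notebook.

```python
from collections import defaultdict


def check_axis_pairs(row_ids):
    """Raise ValueError unless every image id has exactly one _x and one _y row."""
    axes = defaultdict(list)
    for row_id in row_ids:
        # Split on the last underscore: "000b92c3b098_x" -> ("000b92c3b098", "x").
        image_id, _, axis = row_id.rpartition("_")
        axes[image_id].append(axis)
    bad = {k: v for k, v in axes.items() if sorted(v) != ["x", "y"]}
    if bad:
        raise ValueError(f"Ids with missing or duplicated axes: {bad}")
    return len(axes)


ids = ["000b92c3b098_x", "000b92c3b098_y", "007a18eb4e09_x", "007a18eb4e09_y"]
n_checked = check_axis_pairs(ids)
print(n_checked)  # 2
```

In the notebook it could be called as `check_axis_pairs(submission["id"].tolist())`.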