# NER for Document Extraction Tool
This notebook contains script to train NER and run inference using the LayoutLM model "[LayoutLM: Pre-training of Text and Layout for Document Image Understanding](https://arxiv.org/abs/1912.13318)" by Yiheng Xu, Minghao Li, Lei Cui, Shaohan Huang, Furu Wei and Ming Zhou as implemented in Huggingface Transformers library. We take open-source receipt data from SROIE competition and form data from FUNSD repo. For more details, visit these links:
- LayoutLM Github repo [here](https://github.com/microsoft/unilm/tree/master/layoutlm).
- Read about the SROIE competition and dataset [here](https://rrc.cvc.uab.es/?ch=13).
- More about FUNSD [here] (https://guillaumejaume.github.io/FUNSD/).
- "Fine tune SROIE on LayoutLM" by ruifcruz [here](https://github.com/ruifcruz/sroie-on-layoutlm).
- Notebook is inspired from Neils Rogge transformer's [tutorials](https://github.com/NielsRogge/Transformers-Tutorials) and Urban Knuples kaggle [notebooks](https://www.kaggle.com/urbikn/layoutlm-using-the-sroie-dataset).

# 1. Pre-processing SROIE dataset
Before fine-tuning the model, we have to preprocess the SROIE dataset which can be downloaded from [here](https://drive.google.com/drive/folders/1ShItNWXyiY1tFDM5W02bceHuJjyeeJl2). The dataset contains multiple subfolders, because the competition is split up into three tasks **Text Localization, Optical character recognition (OCR)** and **Information Extraction (IE)** and some folders are meant for their specific task. For our purposes we're only interested in the last task, so we'll be using these two folders: 
- **0325updated.task1train(626p)** - contains receipt images (.jpg) and corresponding OCR'd bounding boxes and text (.txt)
- **0325updated.task2train(626p)** - contains labeled text (.txt) in a JSON format.

In [1]:
import os
import glob
import json
import random
from pathlib import Path
from difflib import SequenceMatcher
import math

import cv2
import pandas as pd
import numpy as np
from PIL import Image
from tqdm.notebook import tqdm
from IPython.display import display
import matplotlib
from matplotlib import pyplot, patches

## Preparing the dataset
The location of the SROIE dataset and the name of an example file used for demonstration purposes

In [13]:
sroie_folder_path = Path('/home/fsmlp/Downloads/SROIE2019')
example_file = Path('X51005757324.txt')

In [14]:
def normalize(points: list, width: int, height: int) -> list:
  x0, y0, x2, y2 = [int(p) for p in points]
  
  x0 = int(1000 * (x0 / width))
  x2 = int(1000 * (x2 / width))
  y0 = int(1000 * (y0 / height))
  y2 = int(1000 * (y2 / height))

  return [x0, y0, x2, y2]

In [24]:
input_folder = '/home/fsmlp/Downloads/SROIE2019/test_data/'
name='test'
# input_folder = '/home/fsmlp/Downloads/SROIE2019/0325updated.task2train(626p)/'
# name='train'

output_dir = '/home/fsmlp/Downloads/SROIE2019/datan'

In [25]:
csv = input_folder[:-1] + '-receiptResults.csv'
df_csv = pd.read_csv(csv, delimiter='\t')
rejected_list = []
for i in range(len(df_csv)):
    f = input_folder + df_csv['Filename'][i][:-3] + 'json'
    with open(f,'r') as fil:
        data = fil.read()
    js = json.loads(data) 
    width,height = js['analyzeResult']['readResults'][0]['width'], js['analyzeResult']['readResults'][0]['height']
    # rel_box = [w,h,w,h,w,h,w,h]

    vals = df_csv['MerchantAddress-elements'][i]
    bio_list = []
    if not (isinstance(vals,float) and math.isnan(vals)):
        vals = eval(vals)
        first = True
        for v in vals:
            bio={}
            i1 = v.split('/')[-3]
            i2 = v.split('/')[-1]
            # print(i1, i2)
            tok_det = js['analyzeResult']['readResults'][0]['lines'][int(i1)]['words'][int(i2)]
            bbox = tok_det['boundingBox']
            bio['x0'] = min(int(bbox[0]),int(bbox[2]),int(bbox[4]),int(bbox[6]))
            bio['y0'] = min(int(bbox[1]),int(bbox[3]),int(bbox[5]),int(bbox[7]))
            bio['x2'] = max(int(bbox[0]),int(bbox[2]),int(bbox[4]),int(bbox[6]))
            bio['y2'] = max(int(bbox[1]),int(bbox[3]),int(bbox[5]),int(bbox[7]))
            bio['line'] = tok_det["text"]
            if bio['x2'] < bio['x0'] or bio['y2'] < bio['y0']:
                print("Emptyyyyyyyyyyyyy")
            if first:
                bio['label'] = "B-Address"
                first=False
            else:
                bio['label'] = "I-Address" 
            bio['val'] = int(i1+i2)
            bio_list.append(bio)
    else:
        # print(df_csv['Filename'][i])
        rejected_list.append(df_csv['Filename'][i])
        continue
        # vals = eval(df_csv[t+'-elements'][i])
    
    vals = df_csv['Total-elements'][i]
    tot = True
    if (isinstance(vals,float) and math.isnan(vals)):
        new_vals = df_csv['Subtotal-elements'][i]
        if not (isinstance(vals,float) and math.isnan(vals)):
            vals=new_vals
        else:
            tot=False 
    if tot:
        vals = eval(vals)
        first = True
        for v in vals:
            bio={}
            i1 = v.split('/')[-3]
            i2 = v.split('/')[-1]
            # print(i1, i2)
            tok_det = js['analyzeResult']['readResults'][0]['lines'][int(i1)]['words'][int(i2)]
            bbox = tok_det['boundingBox']
            bio['x0'] = min(int(bbox[0]),int(bbox[2]),int(bbox[4]),int(bbox[6]))
            bio['y0'] = min(int(bbox[1]),int(bbox[3]),int(bbox[5]),int(bbox[7]))
            bio['x2'] = max(int(bbox[0]),int(bbox[2]),int(bbox[4]),int(bbox[6]))
            bio['y2'] = max(int(bbox[1]),int(bbox[3]),int(bbox[5]),int(bbox[7]))
            bio['line'] = tok_det["text"]
            if bio['x2'] < bio['x0'] or bio['y2'] < bio['y0']:
                print("Emptyyyyyyyyyyyyy")
            if first:
                bio['label'] = "B-Total"
                first=False
            else:
                bio['label'] = "I-Total" 
            bio['val'] = int(i1+i2)
            bio_list.append(bio)
    
    vals = df_csv['MerchantName-elements'][i]
    if not (isinstance(vals,float) and math.isnan(vals)):
        vals = eval(vals)
        first = True
        for v in vals:
            bio={}
            i1 = v.split('/')[-3]
            i2 = v.split('/')[-1]
            # print(i1, i2)
            tok_det = js['analyzeResult']['readResults'][0]['lines'][int(i1)]['words'][int(i2)]
            bbox = tok_det['boundingBox']
            bio['x0'] = min(int(bbox[0]),int(bbox[2]),int(bbox[4]),int(bbox[6]))
            bio['y0'] = min(int(bbox[1]),int(bbox[3]),int(bbox[5]),int(bbox[7]))
            bio['x2'] = max(int(bbox[0]),int(bbox[2]),int(bbox[4]),int(bbox[6]))
            bio['y2'] = max(int(bbox[1]),int(bbox[3]),int(bbox[5]),int(bbox[7]))
            bio['line'] = tok_det["text"]
            if bio['x2'] < bio['x0'] or bio['y2'] < bio['y0']:
                print("Emptyyyyyyyyyyyyy")
            if first:
                bio['label'] = "B-Company"
                first=False
            else:
                bio['label'] = "I-Company" 
            bio['val'] = int(i1+i2)
            bio_list.append(bio)
            
    vals = df_csv['TransactionDate-elements'][i]
    if not (isinstance(vals,float) and math.isnan(vals)):
        vals = eval(vals)
        first = True
        for v in vals:
            bio={}
            i1 = v.split('/')[-3]
            i2 = v.split('/')[-1]
            # print(i1, i2)
            tok_det = js['analyzeResult']['readResults'][0]['lines'][int(i1)]['words'][int(i2)]
            bbox = tok_det['boundingBox']
            bio['x0'] = min(int(bbox[0]),int(bbox[2]),int(bbox[4]),int(bbox[6]))
            bio['y0'] = min(int(bbox[1]),int(bbox[3]),int(bbox[5]),int(bbox[7]))
            bio['x2'] = max(int(bbox[0]),int(bbox[2]),int(bbox[4]),int(bbox[6]))
            bio['y2'] = max(int(bbox[1]),int(bbox[3]),int(bbox[5]),int(bbox[7]))
            bio['line'] = tok_det["text"]
            if bio['x2'] < bio['x0'] or bio['y2'] < bio['y0']:
                print("Emptyyyyyyyyyyyyy")
            if first:
                bio['label'] = "B-Date"
                first=False
            else:
                bio['label'] = "I-Date" 
            bio['val'] = int(i1+i2)
            bio_list.append(bio)
            
    df_temp = pd.DataFrame(bio_list)
    # print(df_temp)
    final_list = []
    for y,te in enumerate(js["analyzeResult"]['readResults'][0]["lines"]):
        for z,wo in enumerate(te['words']):
            search=int(str(y)+str(z))
            biol = df_temp['val'].tolist()
            final = {}
            try:
                a = biol.index(search)
                final['x0'] = df_temp['x0'][a]
                final['y0'] = df_temp['y0'][a]
                final['x2'] = df_temp['x2'][a]
                final['y2'] = df_temp['y2'][a]
                final['line'] = wo['text']
                final['label'] = df_temp['label'][a]
            except ValueError:
                ct=0
                bbox = wo['boundingBox']
                final['x0'] = min(int(bbox[0]),int(bbox[2]),int(bbox[4]),int(bbox[6]))
                final['y0'] = min(int(bbox[1]),int(bbox[3]),int(bbox[5]),int(bbox[7]))
                final['x2'] = max(int(bbox[0]),int(bbox[2]),int(bbox[4]),int(bbox[6]))
                final['y2'] = max(int(bbox[1]),int(bbox[3]),int(bbox[5]),int(bbox[7]))
                final['line'] = wo['text']
                if final['x2'] < final['x0'] or final['y2'] < final['y0']:
                    # print("Issue Hereeeeeeeeeeeeeeeeeeeeeeeeeee")
                    print(df_csv['Filename'][i],final["line"])
                    print()
                    # break
                final['label'] = 'O'
            final_list.append(final)
    tm = pd.DataFrame(final_list)[['x0','y0','x2','y2','line','label']]
    tm.to_csv(input_folder + df_csv['Filename'][i][:-3]+'csv',sep='\t', index=False)
    
    with open(f"{output_dir}/{name}.txt", "a", encoding="utf8") as file, \
         open(f"{output_dir}/{name}_box.txt", "a", encoding="utf8") as file_bbox, \
         open(f"{output_dir}/{name}_image.txt", "a", encoding="utf8") as file_image:
        for index, row in tm.iterrows():
            bbox = [int(p) for p in row[['x0', 'y0', 'x2', 'y2']]]
            normalized_bbox = normalize(bbox, width, height)

            file.write("{}\t{}\n".format(row['line'], row['label']))
            file_bbox.write("{}\t{} {} {} {}\n".format(row['line'], *normalized_bbox))
            file_image.write("{}\t{} {} {} {}\t{} {}\t{}\n".format(row['line'], *bbox, width, height, df_csv['Filename'][i][:-4]))

        # Write a second newline to separate dataset from others
        file.write("\n")
        file_bbox.write("\n")
        file_image.write("\n")
    # break        

### Reading the words and bounding boxes
So, the first step is reading the OCR data, where every line in the file includes a group of words and a bounding box which defines them. All we have to do is read the file, discard the unneeded points in the bounding box (because the model requires only the top-left and bottom-right points) and save it in Pandas Dataframe.

### Reading the entities file
Now we need to read the entities file to know what to label in our text.

In [7]:
def read_entities(path: Path):
  with open(path, 'r') as f:
    data = json.load(f)

  dataframe = pd.DataFrame([data])
  return dataframe


# Example usage
entities_file_path = sroie_folder_path / "0325updated.task2train(626p)" / example_file
print("== File content ==")
!head "{entities_file_path}"

entities = read_entities(path=entities_file_path)
print("\n\n== Dataframe ==")
entities

== File content ==
{
    "company": "MR. D.I.Y. (M) SDN BHD",
    "date": "25-03-18",
    "address": "LOT 1851-A & 1851-B, JALAN KPB 6, KAWASAN PERINDUSTRIAN BALAKONG, 43300 SERI KEMBANGAN, SELANGOR",
    "total": "50.80"
}

== Dataframe ==


Unnamed: 0,company,date,address,total
0,MR. D.I.Y. (M) SDN BHD,25-03-18,"LOT 1851-A & 1851-B, JALAN KPB 6, KAWASAN PERI...",50.8


### Assigning labels to words using the entities data
We have our words/lines and entities, now we just need to put them together by labeling our lines using the entities values. We'll be doing that by substring matching the entities values with the lines and if they don't match to a similarity check using pythons _difflib.SequenceMatcher_ and assigning anything above the 0.8 (80%) prediction match.

The **label "O"** will define all our words not labeled during the assignment step, because it's required for us to label everything.

In [None]:
# Assign a label to the line by checking the similarity
# of the line and all the entities
def assign_line_label(line: str, entities: pd.DataFrame):
    line_set = line.replace(",", "").strip().split()
    for i, column in enumerate(entities):
        entity_values = entities.iloc[0, i].replace(",", "").strip()
        entity_set = entity_values.split()

        matches_count = 0
        for l in line_set:
          if any(SequenceMatcher(a=l, b=b).ratio() > 0.8 for b in entity_set):
            matches_count += 1

        if matches_count == len(line_set) or matches_count == len(entity_set):
            return column.upper()

    return "O"


line = bbox.loc[0,"line"]
label = assign_line_label(line, entities)
print("Line:", line)
print("Assigned label:", label)

With a function which can handle the labeling of our lines, we'll create another function to label all our line in one DataFrame (so one receipt).

As simple as this could be, the problem arises when we get lines which would all pass the same match, like **TOTAL** for example; a receipt could have only one item on it and its price could be the same as the final total, so duplicate labels. Or maybe part of the address is also present at the end of the receipt.

To ignore such examples, I wrote simple hard-coded rules to assign *total* and *date* to only the largest bounding boxes it could find (based on its area) and to not allow the address to be assigned after date or total.

In [None]:
def assign_labels(words: pd.DataFrame, entities: pd.DataFrame):
    max_area = {"TOTAL": (0, -1), "DATE": (0, -1)}  # Value, index
    already_labeled = {"TOTAL": False,
                       "DATE": False,
                       "ADDRESS": False,
                       "COMPANY": False,
                       "O": False
    }

    # Go through every line in $words and assign it a label
    labels = []
    for i, line in enumerate(words['line']):
        label = assign_line_label(line, entities)

        already_labeled[label] = True
        if label == "ADDRESS" and (already_labeled["DATE"] or already_labeled["TOTAL"]):
            label = "O"

        # Assign to the largest bounding box
        if label in ["TOTAL", "DATE"]:
            x0_loc = words.columns.get_loc("x0")
            bbox = words.iloc[i, x0_loc:x0_loc+4].to_list()
            area = (bbox[2] - bbox[0]) + (bbox[3] - bbox[1])

            if max_area[label][0] < area:
                max_area[label] = (area, i)

            label = "O"

        labels.append(label)

    labels[max_area["DATE"][1]] = "DATE"
    labels[max_area["TOTAL"][1]] = "TOTAL"

    words["label"] = labels
    return words


# Example usage
bbox_labeled = assign_labels(bbox, entities)
bbox_labeled.head(25)

### Split words
For the last part we're splitting the lines into separate tokens with their own bounding boxes.

Splitting the bounding boxes based on word length is probably not the best approach, but it's good enough.

In [None]:
def split_line(line: pd.Series):
  line_copy = line.copy()

  line_str = line_copy.loc["line"]
  words = line_str.split(" ")

  # Filter unwanted tokens
  words = [word for word in words if len(word) >= 1]

  x0, y0, x2, y2 = line_copy.loc[['x0', 'y0', 'x2', 'y2']]
  bbox_width = x2 - x0
  

  new_lines = []
  for index, word in enumerate(words):
    x2 = x0 + int(bbox_width * len(word)/len(line_str))
    line_copy.at['x0', 'x2', 'line'] = [x0, x2, word]
    new_lines.append(line_copy.to_list())
    x0 = x2 + 5 

  return new_lines


# Example usage
new_lines = split_line(bbox_labeled.loc[0])
print("Original row:")
display(bbox_labeled.loc[0:0,:])

print("Splitted row:")
pd.DataFrame(new_lines, columns=bbox_labeled.columns)

### Putting it all together
We defined all our functions, now we just have to use them on every file and transform the dataset into a format which the script/model can parse.

In [None]:
from time import perf_counter
def dataset_creator(folder: Path, total=1000):
  bbox_folder = folder / '0325updated.task1train(626p)'
  entities_folder = folder / '0325updated.task2train(626p)'

  # Ignoring unwanted files which produced problems when I wanted to fine-tune the model with them included
  ignore = ['X51006619545.txt', 'X51006619785.txt', 'X51005663280(1).txt', 'X51005663280.txt'] 
  files = [file for file in bbox_folder.glob("*.txt") if file.name not in ignore]
  files = files[:total]

  data = []

  print("Reading dataset:")
  for file in tqdm(files, total=len(files)):
    bbox_file_path = file
    entities_file_path = entities_folder / file.name
    image_file_path = bbox_folder / file.with_suffix(".jpg")
  
    # Check if all the required files exist
    if not bbox_file_path.is_file() or not entities_file_path.is_file() or not image_file_path.is_file():
      continue
  
    # Read the files
    bbox = read_bbox_and_words(bbox_file_path)
    entities = read_entities(entities_file_path)
    image = Image.open(image_file_path)

    # Assign labels to lines in bbox using entities
    bbox_labeled = assign_labels(bbox, entities)
    del bbox

    # Split lines into separate tokens
    new_bbox_l = []
    for index, row in bbox_labeled.iterrows():
      new_bbox_l += split_line(row)
    new_bbox = pd.DataFrame(new_bbox_l, columns=bbox_labeled.columns, dtype=np.int16)
    del bbox_labeled


    # Do another label assignment to keep the labeling more precise 
    for index, row in new_bbox.iterrows():
      label = row['label']

      if label != "O":
        entity = entities.iloc[0, entities.columns.get_loc(label.lower())]
        if row['line'] not in entity:
          label = "O"
        else:
            # Not really IOB tagging, but it gives the best results
            label = "S-" + label
      
      new_bbox.at[index, 'label'] = label

    width, height = image.size
  
    data.append([new_bbox, width, height])
  return data

dataset = dataset_creator(sroie_folder_path)

## Writing the dataset into training and testing files
With our dataset transformed, we'll split the dataset into a trainable and testable set. I'm allocating 80% of the dataset to training and the other 20% to testing.

In [None]:
random.Random(4).shuffle(dataset)

# Set split point to be 80% of the dataset
split_point = int(len(dataset) * 0.8) 

dataset_train  = dataset[:split_point]
dataset_test = dataset[split_point:]
del(dataset)

### Defining the writing function
We'll use the same function to write into the train and test files

The normalization function is meant to normalize the bounding boxes points in a range [0,1000] using the width and height of the image of the receipt [\[source\]](https://huggingface.co/transformers/model_doc/layoutlm.html#overview).

In [None]:
def normalize(points: list, width: int, height: int) -> list:
  x0, y0, x2, y2 = [int(p) for p in points]
  
  x0 = int(1000 * (x0 / width))
  x2 = int(1000 * (x2 / width))
  y0 = int(1000 * (y0 / height))
  y2 = int(1000 * (y2 / height))

  return [x0, y0, x2, y2]


def write_dataset(dataset: list, output_dir: Path, name: str):
  print(f"Writing {name}ing dataset:")
  with open(output_dir / f"{name}.txt", "w+", encoding="utf8") as file, \
       open(output_dir / f"{name}_box.txt", "w+", encoding="utf8") as file_bbox, \
       open(output_dir / f"{name}_image.txt", "w+", encoding="utf8") as file_image:

      # Go through each dataset
      for datas in tqdm(dataset, total=len(dataset)):
        data, width, height = datas
        
        filename = data.iloc[0, data.columns.get_loc('filename')]

        # Go through every row in dataset
        for index, row in data.iterrows():
          bbox = [int(p) for p in row[['x0', 'y0', 'x2', 'y2']]]
          normalized_bbox = normalize(bbox, width, height)

          file.write("{}\t{}\n".format(row['line'], row['label']))
          file_bbox.write("{}\t{} {} {} {}\n".format(row['line'], *normalized_bbox))
          file_image.write("{}\t{} {} {} {}\t{} {}\t{}\n".format(row['line'], *bbox, width, height, filename))

        # Write a second newline to separate dataset from others
        file.write("\n")
        file_bbox.write("\n")
        file_image.write("\n")

In [None]:
dataset_directory = Path('/home/fsmlp/Downloads/SROIE2019','data')

dataset_directory.mkdir(parents=True, exist_ok=True)

write_dataset(dataset_train, dataset_directory, 'train')
write_dataset(dataset_test, dataset_directory, 'test')

# Creating the 'labels.txt' file to the the model what categories to predict.
labels = ['COMPANY', 'DATE', 'ADDRESS', 'TOTAL']
IOB_tags = ['S']
with open(dataset_directory / 'labels.txt', 'w') as f:
  for tag in IOB_tags:
    for label in labels:
      f.write(f"{tag}-{label}\n")
  f.write("O")

# 2. Fine tune LayoutLM
We downloaded and transformed our dataset into a trainable and testable set, now we can start the fine-tuning of the model.

## Download the model
First we're going to clone the LayoutLM Github project which contains the script to fine tune our model.

In [10]:
# ! rm -r unilm
! git clone -b remove_torch_save https://github.com/NielsRogge/unilm.git
! cd unilm/layoutlm
! pip install unilm/layoutlm

Cloning into 'unilm'...
remote: Enumerating objects: 3210, done.[K
remote: Counting objects: 100% (2344/2344), done.[K
remote: Compressing objects: 100% (1813/1813), done.[K
remote: Total 3210 (delta 1000), reused 1647 (delta 455), pack-reused 866[K
Receiving objects: 100% (3210/3210), 4.99 MiB | 6.88 MiB/s, done.
Resolving deltas: 100% (1480/1480), done.
Processing ./unilm/layoutlm
[33m  DEPRECATION: A future pip version will change local packages to be built in-place without first copying to a temporary directory. We recommend you use --use-feature=in-tree-build to test your packages with this new behavior before it becomes the default.
   pip 21.3 will remove support for this functionality. You can find discussion regarding this at https://github.com/pypa/pip/issues/7555.[0m
Building wheels for collected packages: layoutlm
  Building wheel for layoutlm (setup.py) ... [?25ldone
[?25h  Created wheel for layoutlm: filename=layoutlm-0.0-py3-none-any.whl size=11486 sha256=fb301d2

In [12]:
# ! rm -r transformers
! git clone https://github.com/huggingface/transformers.git
! cd transformers
! pip install ./transformers

Cloning into 'transformers'...
remote: Enumerating objects: 88571, done.[K
remote: Counting objects: 100% (282/282), done.[K
remote: Compressing objects: 100% (3/3), done.[K
remote: Total 88571 (delta 281), reused 279 (delta 279), pack-reused 88289[K
Receiving objects: 100% (88571/88571), 71.29 MiB | 14.76 MiB/s, done.
Resolving deltas: 100% (63821/63821), done.
Processing ./transformers
[33m  DEPRECATION: A future pip version will change local packages to be built in-place without first copying to a temporary directory. We recommend you use --use-feature=in-tree-build to test your packages with this new behavior before it becomes the default.
   pip 21.3 will remove support for this functionality. You can find discussion regarding this at https://github.com/pypa/pip/issues/7555.[0m
  Installing build dependencies ... [?25ldone
[?25h  Getting requirements to build wheel ... [?25ldone
[?25h    Preparing wheel metadata ... [?25ldone
Collecting tokenizers<0.11,>=0.10.1
  Using 

## Define a PyTorch dataset

First, we create a list containing the unique labels based on `data/labels.txt` (run this from the content directory):

In [None]:
! wget https://t3638486.p.clickup-attachments.com/t3638486/1678c4f6-0f92-484b-9e3f-0ea8275ad7cc/datan.zip

In [2]:
output_dir = '/home/fsmlp/Downloads/SROIE2019/datan'

In [3]:
from torch.nn import CrossEntropyLoss

def get_labels(path):
    with open(path, "r") as f:
        labels = f.read().splitlines()
    if "O" not in labels:
        labels = ["O"] + labels
    return labels

labels = get_labels(output_dir+"/labels.txt")
num_labels = len(labels)
label_map = {i: label for i, label in enumerate(labels)}
# Use cross entropy ignore index as padding label id so that only real label ids contribute to the loss later
pad_token_label_id = CrossEntropyLoss().ignore_index
# pad_token_label_id = -100

In [4]:
print(labels, pad_token_label_id)

['B-Company', 'I-Company', 'B-Address', 'I-Address', 'B-Total', 'I-Total', 'B-Date', 'I-Date', 'O'] -100


Next, we can create a PyTorch dataset and corresponding dataloader (both for training and evaluation):

In [5]:
from transformers import LayoutLMTokenizer
from layoutlm.data.funsd import FunsdDataset, InputFeatures
from torch.utils.data import DataLoader, RandomSampler, SequentialSampler

args = {'local_rank': -1,
        'overwrite_cache': True,
        'data_dir': output_dir,
        'model_name_or_path':'microsoft/layoutlm-base-uncased',
        'max_seq_length': 512,
        'model_type': 'layoutlm',}

# class to turn the keys of a dict into attributes (thanks Stackoverflow)
class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self

args = AttrDict(args)

tokenizer = LayoutLMTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")

# the LayoutLM authors already defined a specific FunsdDataset, so we are going to use this here
train_dataset = FunsdDataset(args, tokenizer, labels, pad_token_label_id, mode="train")
train_sampler = RandomSampler(train_dataset)
train_dataloader = DataLoader(train_dataset,
                              sampler=train_sampler,
                              batch_size=1)

eval_dataset = FunsdDataset(args, tokenizer, labels, pad_token_label_id, mode="test")
eval_sampler = SequentialSampler(eval_dataset)
eval_dataloader = DataLoader(eval_dataset,
                             sampler=eval_sampler,
                            batch_size=1)

In [6]:
batch = next(iter(train_dataloader))
input_ids = batch[0][0]
tokenizer.decode(input_ids)

'[CLS] aso electrical trading sdn bhd 1000131 - k no 31g, jalan sepadu c 25 / c, section 25, taman industries, axis 40400 shah alam, selangor. tel : 03 - 51221701, 51313091 fax : 03 - 51215716 gst no : 000683900928 tax invoice bill to : receipt # : cs00087400 date : 27 / 09 / 2017 salesperson : cashier : user time : 10 : 51 : 00 ( gst ) ( gst ) item qty rsp rsp amount 107636 3 78. 00 82. 68 248. 04 sr : hager timer, 24hrs power reserve tot qty : 3 248. 04 ( excluded gst ) sub total : 234. 00 discount : 0. 00 total gst : 14. 04 rounding : 0. 01 total : 248. 05 cash : 248. 015 change : 0. 00 gst summary tax code % amount gst sr 6 234. 00 14. 04 total : 234. 00 14. 04 goods sold are not returnable, thank you. [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [P

## Define and fine-tune the model

As this is a sequence labeling task, we are going to load `LayoutLMForTokenClassification` (the base sized model) from the hub. We are going to fine-tune it on a downstream task, namely FUNSD.

In [7]:
from transformers import LayoutLMForTokenClassification
import torch

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = LayoutLMForTokenClassification.from_pretrained("microsoft/layoutlm-base-uncased", num_labels=num_labels)

model.to(device)


Some weights of the model checkpoint at microsoft/layoutlm-base-uncased were not used when initializing LayoutLMForTokenClassification: ['cls.predictions.bias', 'cls.predictions.transform.LayerNorm.bias', 'cls.predictions.decoder.bias', 'cls.predictions.transform.dense.weight', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.bias', 'cls.predictions.decoder.weight']
- This IS expected if you are initializing LayoutLMForTokenClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing LayoutLMForTokenClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of LayoutLMForTokenClassification were not initialized from the model checkpoint at microsoft

LayoutLMForTokenClassification(
  (layoutlm): LayoutLMModel(
    (embeddings): LayoutLMEmbeddings(
      (word_embeddings): Embedding(30522, 768, padding_idx=0)
      (position_embeddings): Embedding(512, 768)
      (x_position_embeddings): Embedding(1024, 768)
      (y_position_embeddings): Embedding(1024, 768)
      (h_position_embeddings): Embedding(1024, 768)
      (w_position_embeddings): Embedding(1024, 768)
      (token_type_embeddings): Embedding(2, 768)
      (LayerNorm): LayerNorm((768,), eps=1e-12, elementwise_affine=True)
      (dropout): Dropout(p=0.1, inplace=False)
    )
    (encoder): LayoutLMEncoder(
      (layer): ModuleList(
        (0): LayoutLMLayer(
          (attention): LayoutLMAttention(
            (self): LayoutLMSelfAttention(
              (query): Linear(in_features=768, out_features=768, bias=True)
              (key): Linear(in_features=768, out_features=768, bias=True)
              (value): Linear(in_features=768, out_features=768, bias=True)
         

## Training

In [8]:
from transformers import AdamW
from tqdm import tqdm

global_step = 0
num_train_epochs = 10
t_total = len(train_dataloader) * num_train_epochs # total number of training steps 
# optimizer = AdamW(model.parameters(), lr=5e-5)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.001, steps_per_epoch=len(train_dataloader), epochs=num_train_epochs)

#put the model in training mode
model.train()
for epoch in range(num_train_epochs):
  for batch in tqdm(train_dataloader, desc="Training"):
      input_ids = batch[0].to(device)
      bbox = batch[4].to(device)
      print(bbox[:, :, 2] - bbox[:, :, 0])
      attention_mask = batch[1].to(device)
      token_type_ids = batch[2].to(device)
      labels = batch[3].to(device)

      # forward pass
      outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids,
                      labels=labels)
      loss = outputs.loss

      # print loss every 100 steps
      if global_step % 10 == 0:
        print(f"Loss after {global_step} steps: {loss.item()}")

      # backward pass to get the gradients 
      loss.backward()

      #print("Gradients on classification head:")
      #print(model.classifier.weight.grad[6,:].sum())

      # update
      optimizer.step()
      optimizer.zero_grad()
      global_step += 1

Training:   0%|          | 1/624 [00:00<01:41,  6.12it/s]

tensor([[  0, 206, 206, 206, 206,  86,  86,  86,  86,  73,  73,  80,  80, 131,
         131, 131, 131,  49,  49,  63,  37,  37,  91,  91, 137, 137, 137,  88,
          88,  88,  88,  91,  91,  88, 111, 111, 167, 167,  64,  64,  58,  58,
         157, 157, 157, 157, 134, 134, 134, 134, 134, 157, 157, 157, 157, 143,
         143, 143, 143, 235, 235, 235, 235, 235, 235,  59, 143, 143, 143,  26,
         202,  88,  75,  66,  66,  38,  41,  41, 169, 169, 169, 169, 169,  44,
          98,  98, 195, 195, 195, 195, 195, 153, 153, 153, 153, 104, 127,  68,
          68,  65,  65, 114, 119, 119,  64,  64,  18,  21,  41,  41,  87,  41,
         114, 114, 114, 114, 114,  91,  91, 115, 115,  95,  95, 153, 103, 103,
         103, 103, 145, 145, 145, 131, 131, 131, 131,  68,  68, 180, 180, 180,
         180, 180,  94,  94,  94, 209,  50,  50,  50,  58,  60,  60,  20,  79,
         138, 138, 138, 138, 138,  78,  78,  78,  78,  99,  76,  76,  76,  14,
          14,  78,  78,  78, 168, 168, 168,  73,  73

Training:   0%|          | 1/624 [00:00<03:19,  3.12it/s]


tensor([[  0, 128, 124,  35, 196,  34,  34, 185, 185, 185,  61,  61,  63,  63,
         194, 194, 194, 194, 194, 194, 194, 165, 165, 165, 165,  64,  64,  66,
          46, 248, 248, 248, 248, 248, 248,  78,  71,  71,  77,  77,  19,  97,
          97, 137, 137, 114, 114,  94,  94, 145, 145, 145, 145,  48,  16, 203,
         203, 203, 203, 203, 203,  53, 119, 119, 119,  20, 353, 353, 353, 353,
         353, 353, 353, 353, 353, 353, 353, 146, 146, 146, 146, 146, 120, 120,
          17, 142, 142, 142, 142, 103, 103, 103, 148, 148, 148, 148, 148,  61,
          61, 114,  85,  77,  77,  59,  59, 264, 264, 264, 264, 264, 264, 264,
         264, 134, 134, 134, 134, 134,  94,  94,  94,  19, 103, 100, 100, 130,
         130, 130, 130, 177, 177, 264, 264, 264, 264, 264, 264, 264, 128, 128,
         128, 128, 128,  92,  92,  92,  20, 101,  97,  97, 127, 127, 127, 127,
         142, 265, 265, 265, 265, 265, 265, 265, 265, 265, 128, 128, 128, 128,
         128,  89,  89,  89,  21,  98,  98,  19,  82

RuntimeError: CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 7.92 GiB total capacity; 1.97 GiB already allocated; 33.62 MiB free; 1.98 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF

In [None]:
model.save_pretrained('.')

## Evaluation

Now let's evaluate on the test set:

In [None]:
import numpy as np
from seqeval.metrics import (
    classification_report,
    f1_score,
    precision_score,
    recall_score,
)

eval_loss = 0.0
nb_eval_steps = 0
preds = None
out_label_ids = None

# put model in evaluation mode
model.eval()
for batch in tqdm(eval_dataloader, desc="Evaluating"):
    with torch.no_grad():
        input_ids = batch[0].to(device)
        bbox = batch[4].to(device)
        attention_mask = batch[1].to(device)
        token_type_ids = batch[2].to(device)
        labels = batch[3].to(device)

        # forward pass
        outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids,
                        labels=labels)
        # get the loss and logits
        tmp_eval_loss = outputs.loss
        logits = outputs.logits

        eval_loss += tmp_eval_loss.item()
        nb_eval_steps += 1

        # compute the predictions
        if preds is None:
            preds = logits.detach().cpu().numpy()
            out_label_ids = labels.detach().cpu().numpy()
        else:
            preds = np.append(preds, logits.detach().cpu().numpy(), axis=0)
            out_label_ids = np.append(
                out_label_ids, labels.detach().cpu().numpy(), axis=0
            )

# compute average evaluation loss
eval_loss = eval_loss / nb_eval_steps
preds = np.argmax(preds, axis=2)

out_label_list = [[] for _ in range(out_label_ids.shape[0])]
preds_list = [[] for _ in range(out_label_ids.shape[0])]

for i in range(out_label_ids.shape[0]):
    for j in range(out_label_ids.shape[1]):
        if out_label_ids[i, j] != pad_token_label_id:
            out_label_list[i].append(label_map[out_label_ids[i][j]])
            preds_list[i].append(label_map[preds[i][j]])

results = {
    "loss": eval_loss,
    "precision": precision_score(out_label_list, preds_list),
    "recall": recall_score(out_label_list, preds_list),
    "f1": f1_score(out_label_list, preds_list),
}
print(results)

## Inference

Now comes the fun part! We can now use the fine-tuned model and test it on unseen data.

Note that LayoutLM relies on an external OCR engine (it's not end-to-end -> that's probably something for the future). The test data itself also contains the annotated bounding boxes, but let's run an OCR engine ourselves.

So let's load in a image of the test set, run our own OCR on it to get the bounding boxes, then run LayoutLM on the individual tokens and visualize the result!

Sources:
* https://www.kaggle.com/jpmiller/layoutlm-starter
* https://bhadreshpsavani.medium.com/how-to-use-tesseract-library-for-ocr-in-google-colab-notebook-5da5470e4fe0

## Download Dataset and Install library

In [None]:
! wget https://t3638486.p.clickup-attachments.com/t3638486/ab080a87-3910-47e2-96e2-7143ad53ab82/test_data.zip
! unzip test_data.zip

# Inference

In [None]:
from transformers import LayoutLMForTokenClassification, LayoutLMTokenizer
import torch

tokenizer = LayoutLMTokenizer.from_pretrained("microsoft/layoutlm-base-uncased")
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = torch.jit.load("traced_sroie.pt")
model.to(device)
model.eval();

In [None]:
import base64
import pandas as pd
import requests

def convert_b64(input_file_name):
    """Open image and convert it to Base64"""
    with open(input_file_name, "rb") as input_file:
        jpeg_bytes = base64.b64encode(input_file.read()).decode("utf-8")
    return jpeg_bytes

image_file_name = '/home/fsmlp/Downloads/test_data/X51009453729.jpg'
instance = {"data": {"b64": convert_b64(image_file_name)}}

# res = requests.post("http://localhost:7080/wfpredict/ocr", json={"instances": [instance]})
res = requests.post("http://164.52.218.27:7080/wfpredict/ocr", json={"instances": [instance]})
dictFromServer = res.json()

df = pd.DataFrame(dictFromServer["predictions"])

In [None]:
df

In [None]:
from PIL import Image
image = Image.open(image_file_name)
image = image.convert("RGB")
width, height = image.size

bboxes = df['bbox'].tolist()
def format_box(box, width, height):
    return ([
        int(1000 * box[0]),
        int(1000 * box[1]),
        int(1000 * box[2]),
        int(1000 * box[3]),
    ], [
        int(box[0] * width),
        int(box[1] * height),
        int(box[2] * width),
        int(box[3] * height),
    ])
  
actual_boxes = []
boxes = []
for box in bboxes:
    a,b=format_box(box, width, height)
    boxes.append(a)
    actual_boxes.append(b)

words = df['ocr'].tolist()
len(words)

In [None]:
def convert_example_to_features(image, words, boxes, actual_boxes, tokenizer, max_seq_length=512, cls_token_box=[0, 0, 0, 0],
                                 sep_token_box=[1000, 1000, 1000, 1000],
                                 pad_token_box=[0, 0, 0, 0]):
      width, height = image.size

      tokens = []
      token_boxes = []
      actual_bboxes = [] # we use an extra b because actual_boxes is already used
      token_actual_boxes = []
      ocr_boxes = []
      for word, box, actual_bbox in zip(words, boxes, actual_boxes):
          word_tokens = tokenizer.tokenize(word)
          tokens.extend(word_tokens)
          token_boxes.extend([box] * len(word_tokens))
          actual_bboxes.extend([actual_bbox] * len(word_tokens))
          token_actual_boxes.extend([actual_bbox] * len(word_tokens))
          ocr_boxes.extend([word]* len(word_tokens))

      # Truncation: account for [CLS] and [SEP] with "- 2". 
      special_tokens_count = 2 
      if len(tokens) > max_seq_length - special_tokens_count:
          tokens = tokens[: (max_seq_length - special_tokens_count)]
          token_boxes = token_boxes[: (max_seq_length - special_tokens_count)]
          actual_bboxes = actual_bboxes[: (max_seq_length - special_tokens_count)]
          token_actual_boxes = token_actual_boxes[: (max_seq_length - special_tokens_count)]
          ocr_boxes = ocr_boxes[:(max_seq_length - special_tokens_count)]

      # add [SEP] token, with corresponding token boxes and actual boxes
      tokens += [tokenizer.sep_token]
      token_boxes += [sep_token_box]
      actual_bboxes += [[0, 0, width, height]]
      token_actual_boxes += [[0, 0, width, height]]
      ocr_boxes += ['']
      
      segment_ids = [0] * len(tokens)

      # next: [CLS] token
      tokens = [tokenizer.cls_token] + tokens
      token_boxes = [cls_token_box] + token_boxes
      actual_bboxes = [[0, 0, width, height]] + actual_bboxes
      token_actual_boxes = [[0, 0, width, height]] + token_actual_boxes
      ocr_boxes = [''] + ocr_boxes
      segment_ids = [1] + segment_ids

      input_ids = tokenizer.convert_tokens_to_ids(tokens)

      # The mask has 1 for real tokens and 0 for padding tokens. Only real
      # tokens are attended to.
      input_mask = [1] * len(input_ids)

      # Zero-pad up to the sequence length.
      padding_length = max_seq_length - len(input_ids)
      input_ids += [tokenizer.pad_token_id] * padding_length
      input_mask += [0] * padding_length
      segment_ids += [tokenizer.pad_token_id] * padding_length
      token_boxes += [pad_token_box] * padding_length
      token_actual_boxes += [pad_token_box] * padding_length
      ocr_boxes += [pad_token_box] * padding_length

      assert len(input_ids) == max_seq_length
      assert len(input_mask) == max_seq_length
      assert len(segment_ids) == max_seq_length
      assert len(token_boxes) == max_seq_length
      assert len(token_actual_boxes) == max_seq_length
      assert len(ocr_boxes) == max_seq_length
      
      return input_ids, input_mask, segment_ids, token_boxes, token_actual_boxes, ocr_boxes

In [None]:
input_ids, input_mask, segment_ids, token_boxes, token_actual_boxes, ocr_boxes = convert_example_to_features(image=image, words=words, boxes=boxes, actual_boxes=actual_boxes, tokenizer=tokenizer)
tokenizer.decode(input_ids)

In [None]:
bbox

In [None]:
input_ids = torch.tensor(input_ids, device=device).unsqueeze(0)
attention_mask = torch.tensor(input_mask, device=device).unsqueeze(0)
token_type_ids = torch.tensor(segment_ids, device=device).unsqueeze(0)
bbox = torch.tensor(token_boxes, device=device).unsqueeze(0)
# traced_model = torch.jit.trace(model, [input_ids,bbox,attention_mask,token_type_ids])
# torch.jit.save(traced_model,'traced_sroie.pt')
outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids)

In [None]:
token_predictions = outputs[0].argmax(-1).squeeze().tolist() # the predictions are at the token level
# token_predictions = outputs.logits.argmax(-1).squeeze().tolist() # the predictions are at the token level

word_level_predictions = [] # let's turn them into word level predictions
t_boxes = []
final_boxes = []
ocr_results = []
for id, token_pred, box, tb, ocr in zip(input_ids.squeeze().tolist(), token_predictions, token_actual_boxes, token_boxes, ocr_boxes):
  if (tokenizer.decode([id]).startswith("##")) or (id in [tokenizer.cls_token_id, 
                                                           tokenizer.sep_token_id, 
                                                          tokenizer.pad_token_id]):
    # skip prediction + bounding box

    continue
  else:
    word_level_predictions.append(token_pred)
    final_boxes.append(box)
    t_boxes.append(tb)
    ocr_results.append(ocr)

Compare this to the ground truth:

In [None]:
from PIL import Image, ImageDraw, ImageFont

draw = ImageDraw.Draw(image)
font = ImageFont.load_default()

label2color = {'company':'blue', 'date':'green', 'address':'orange',"total":'violet','#other':'grey'}
label_map = dict(zip(range(5), label2color.keys()))
print(label_map)
predictions = []
for prediction, box, tb, ocr in zip(word_level_predictions, final_boxes, t_boxes, ocr_results):
    predicted_label = label_map[prediction]
    rel_box = [b/1000 for b in tb]
    json_result = {}
    json_result["bbox"] = rel_box
    json_result["ocr"] = ocr
    json_result["key"] = predicted_label
    predictions.append(json_result)
    
    if predicted_label!='other':
        # print(predicted_label, box)
        draw.rectangle(box, outline=label2color[predicted_label])
        draw.text((box[0] + 10, box[1] - 10), text=predicted_label, fill=label2color[predicted_label], font=font)
predictions

In [None]:
image

## Install Tesseract or upload any other OCR results

In [None]:
!sudo apt install tesseract-ocr
!pip install pytesseract

## With Tesseract

In [None]:
import pytesseract
from PIL import Image
image = Image.open("/home/fsmlp/Downloads/test_data/X51009453729.jpg")
image = image.convert("RGB")

In [None]:
import numpy as np

width, height = image.size
w_scale = 1000/width
h_scale = 1000/height

ocr_df = pytesseract.image_to_data(image, output_type='data.frame') \
            
ocr_df = ocr_df.dropna() \
               .assign(left_scaled = ocr_df.left*w_scale,
                       width_scaled = ocr_df.width*w_scale,
                       top_scaled = ocr_df.top*h_scale,
                       height_scaled = ocr_df.height*h_scale,
                       right_scaled = lambda x: x.left_scaled + x.width_scaled,
                       bottom_scaled = lambda x: x.top_scaled + x.height_scaled)

float_cols = ocr_df.select_dtypes('float').columns
ocr_df[float_cols] = ocr_df[float_cols].round(0).astype(int)
ocr_df = ocr_df.replace(r'^\s*$', np.nan, regex=True)
ocr_df = ocr_df.dropna().reset_index(drop=True)
ocr_df[:20]

Here we create a list of words, actual bounding boxes, and normalized boxes.

In [None]:
words = list(ocr_df.text)
coordinates = ocr_df[['left', 'top', 'width', 'height']]
actual_boxes = []
for idx, row in coordinates.iterrows():
  x, y, w, h = tuple(row) # the row comes in (left, top, width, height) format
  actual_box = [x, y, x+w, y+h] # we turn it into (left, top, left+widght, top+height) to get the actual box 
  actual_boxes.append(actual_box)

def normalize_box(box, width, height):
    return [
        int(1000 * (box[0] / width)),
        int(1000 * (box[1] / height)),
        int(1000 * (box[2] / width)),
        int(1000 * (box[3] / height)),
    ]

boxes = []
for box in actual_boxes:
  boxes.append(normalize_box(box, width, height))
# boxes

This should become the future API of LayoutLMTokenizer (`prepare_for_model()`): 

In [None]:
def convert_example_to_features(image, words, boxes, actual_boxes, tokenizer, args, cls_token_box=[0, 0, 0, 0],
                                 sep_token_box=[1000, 1000, 1000, 1000],
                                 pad_token_box=[0, 0, 0, 0]):
      width, height = image.size

      tokens = []
      token_boxes = []
      actual_bboxes = [] # we use an extra b because actual_boxes is already used
      token_actual_boxes = []
      for word, box, actual_bbox in zip(words, boxes, actual_boxes):
          word_tokens = tokenizer.tokenize(word)
          tokens.extend(word_tokens)
          token_boxes.extend([box] * len(word_tokens))
          actual_bboxes.extend([actual_bbox] * len(word_tokens))
          token_actual_boxes.extend([actual_bbox] * len(word_tokens))

      # Truncation: account for [CLS] and [SEP] with "- 2". 
      special_tokens_count = 2 
      if len(tokens) > args.max_seq_length - special_tokens_count:
          tokens = tokens[: (args.max_seq_length - special_tokens_count)]
          token_boxes = token_boxes[: (args.max_seq_length - special_tokens_count)]
          actual_bboxes = actual_bboxes[: (args.max_seq_length - special_tokens_count)]
          token_actual_boxes = token_actual_boxes[: (args.max_seq_length - special_tokens_count)]

      # add [SEP] token, with corresponding token boxes and actual boxes
      tokens += [tokenizer.sep_token]
      token_boxes += [sep_token_box]
      actual_bboxes += [[0, 0, width, height]]
      token_actual_boxes += [[0, 0, width, height]]
      
      segment_ids = [0] * len(tokens)

      # next: [CLS] token
      tokens = [tokenizer.cls_token] + tokens
      token_boxes = [cls_token_box] + token_boxes
      actual_bboxes = [[0, 0, width, height]] + actual_bboxes
      token_actual_boxes = [[0, 0, width, height]] + token_actual_boxes
      segment_ids = [1] + segment_ids

      input_ids = tokenizer.convert_tokens_to_ids(tokens)

      # The mask has 1 for real tokens and 0 for padding tokens. Only real
      # tokens are attended to.
      input_mask = [1] * len(input_ids)

      # Zero-pad up to the sequence length.
      padding_length = args.max_seq_length - len(input_ids)
      input_ids += [tokenizer.pad_token_id] * padding_length
      input_mask += [0] * padding_length
      segment_ids += [tokenizer.pad_token_id] * padding_length
      token_boxes += [pad_token_box] * padding_length
      token_actual_boxes += [pad_token_box] * padding_length

      assert len(input_ids) == args.max_seq_length
      assert len(input_mask) == args.max_seq_length
      assert len(segment_ids) == args.max_seq_length
      #assert len(label_ids) == args.max_seq_length
      assert len(token_boxes) == args.max_seq_length
      assert len(token_actual_boxes) == args.max_seq_length
      
      return input_ids, input_mask, segment_ids, token_boxes, token_actual_boxes

In [None]:
input_ids, input_mask, segment_ids, token_boxes, token_actual_boxes = convert_example_to_features(image=image, words=words, boxes=boxes, actual_boxes=actual_boxes, tokenizer=tokenizer, args=args)
tokenizer.decode(input_ids)

In [None]:
input_ids = torch.tensor(input_ids, device=device).unsqueeze(0)
attention_mask = torch.tensor(input_mask, device=device).unsqueeze(0)
token_type_ids = torch.tensor(segment_ids, device=device).unsqueeze(0)
bbox = torch.tensor(token_boxes, device=device).unsqueeze(0)
# traced_model = torch.jit.trace(model, [input_ids,bbox,attention_mask,token_type_ids])
# torch.jit.save(traced_model,'traced_sroie.pt')
outputs = model(input_ids=input_ids, bbox=bbox, attention_mask=attention_mask, token_type_ids=token_type_ids)

In [None]:
token_predictions = outputs[0].argmax(-1).squeeze().tolist() # the predictions are at the token level
# token_predictions = outputs.logits.argmax(-1).squeeze().tolist() # the predictions are at the token level

word_level_predictions = [] # let's turn them into word level predictions
final_boxes = []
for id, token_pred, box in zip(input_ids.squeeze().tolist(), token_predictions, token_actual_boxes):
  if (tokenizer.decode([id]).startswith("##")) or (id in [tokenizer.cls_token_id, 
                                                           tokenizer.sep_token_id, 
                                                          tokenizer.pad_token_id]):
    # skip prediction + bounding box

    continue
  else:
    word_level_predictions.append(token_pred)
    final_boxes.append(box)

Compare this to the ground truth:

In [None]:
from PIL import Image, ImageDraw, ImageFont
draw = ImageDraw.Draw(image)
font = ImageFont.load_default()

def iob_to_label(label):
  if label != 'O':
    return label[2:]
  else:
    return "other"

label2color = {'company':'blue', 'address':'green', 'total':'orange', 'other':'grey', "date":'violet'}

for prediction, box in zip(word_level_predictions, final_boxes):
    predicted_label = iob_to_label(label_map[prediction]).lower()
    if predicted_label!='other2':
        # print(predicted_label, box)
        draw.rectangle(box, outline=label2color[predicted_label])
        draw.text((box[0] + 10, box[1] - 10), text=predicted_label, fill=label2color[predicted_label], font=font)
image