Resources used to support this work:

* paper: https://arxiv.org/pdf/2309.10952.pdf
* model: https://huggingface.co/gpt2

* https://www.kaggle.com/code/aliabdin1/llm-04a-fine-tuning-llms
* https://huggingface.co/docs/transformers/tasks/language_modeling

In [2]:
import os
import re
import pandas as pd
from pydantic import BaseModel, root_validator
from typing import Optional, Dict
import numpy as np

import json

import tempfile
import transformers as tr
from difflib import SequenceMatcher
from functools import partial


# Root folder of the training split on Kaggle; load_data_from_kaggle reads the
# `box` and `entities` sub-directories underneath it.
data_loc = "/kaggle/input/sroie-datasetv2/SROIE2019/train"
# Switch off Weights & Biases through the environment. The Trainer output recorded
# further down flags this variable as deprecated in favour of `report_to`.
os.environ['WANDB_DISABLED'] = "True"


In [4]:
# JSON-schema description of a well-formed extraction: an object whose four
# entity fields are all required strings.
schema = {
    "type": "object",
    "properties": {
        field: {"type": "string"}
        for field in ("TOTAL", "DATE", "ADDRESS", "COMPANY")
    },
    "required": ["TOTAL", "DATE", "ADDRESS", "COMPANY"],
}

# Blank template carrying the same four keys; it is serialised into the task
# prompt by create_word_bbox_text.
empty_schema = dict.fromkeys(("TOTAL", "DATE", "ADDRESS", "COMPANY"), "")

In [None]:
# Tokenizer / model setup; training checkpoints go under a temporary directory.

model_checkpoint = "gpt2"
checkpoint_name = "test-trainer"

# Hold on to the TemporaryDirectory object itself, not only its path.
tmpdir = tempfile.TemporaryDirectory()
local_training_root = tmpdir.name
local_checkpoint_path = os.path.join(local_training_root, checkpoint_name)

# Tokenizer: register the prompt markup tags as additional special tokens and
# use the end-of-sequence token for padding.
tokenizer = tr.AutoTokenizer.from_pretrained(
    model_checkpoint,
    cache_dir="../working/cache",
    additional_special_tokens=[
        '<Document>', '</Document>',
        '<Task>', '</Task>',
        '<Extraction>', '</Extraction>',
    ],
)
tokenizer.pad_token = tokenizer.eos_token

# Model: load the causal LM, then resize its embeddings to the enlarged
# vocabulary (rounded up to a multiple of 4).
model = tr.AutoModelForCausalLM.from_pretrained(
    model_checkpoint,
    cache_dir="../working/cache",
)
model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=4)

In [5]:
# Methods to help load and pre-process data

def load_data_from_kaggle(data_loc):
    """Read every bounding-box file and entity file below ``data_loc``.

    Args:
        data_loc: Directory containing a ``box`` and an ``entities`` sub-folder.

    Returns:
        A 4-tuple ``(bboxes, words, all_entities, file_ids)``:
        ``bboxes`` and ``words`` are the per-document lists produced by
        ``process_bboxes``; ``all_entities`` holds the raw text of each entity
        file; ``file_ids`` are the box file names with ``.txt`` removed.

    Both directory listings are sorted. ``os.listdir`` returns names in
    arbitrary order, and the caller zips the four results together by
    position, so unsorted listings could pair a document with another
    document's entities or id.
    """
    bbox_loc = data_loc+'/box'
    entities_loc = data_loc+'/entities'

    # Sorted so that position i refers to the same file name in both folders.
    all_bboxes_files = sorted(os.listdir(bbox_loc))
    all_bboxes = []
    for bbox_file in all_bboxes_files:
        with open(os.path.join(bbox_loc, bbox_file), 'r') as f:
            all_bboxes.append(f.readlines())

    bboxes, words = process_bboxes(all_bboxes)

    all_entities_files = sorted(os.listdir(entities_loc))
    all_entities = []
    for entities_file in all_entities_files:
        with open(os.path.join(entities_loc, entities_file), 'r') as f:
            all_entities.append(f.read())

    file_ids = [f.replace('.txt', '') for f in all_bboxes_files]

    return bboxes, words, all_entities, file_ids


def process_bboxes(all_bboxes):
    """Split raw annotation lines into integer coordinates and text.

    Each line is expected to look like ``x,x,x,x,x,x,x,x,text``: eight
    comma-separated non-negative integers followed by free text, which may
    itself contain commas.

    Args:
        all_bboxes: One list of raw lines per document.

    Returns:
        ``(bboxes, words)``. ``bboxes[d][i]`` is the list of eight ints and
        ``words[d][i]`` the text of line ``i`` in document ``d``. Lines that do
        not match the expected layout (blank lines, for example) are skipped in
        both outputs, so the two stay aligned.
    """
    line_pattern = re.compile(r"(\d+),(\d+),(\d+),(\d+),(\d+),(\d+),(\d+),(\d+),(.*)")

    bboxes = []
    words = []
    for doc in all_bboxes:
        doc_bboxes = []
        doc_words = []
        for line in doc:
            # Strip the real line terminator. The earlier r'\n' replace only
            # removed a literal backslash followed by "n".
            m = line_pattern.match(line.rstrip('\r\n'))
            if m is None:
                # Blank or malformed line: nothing to extract.
                continue
            *coords, text = m.groups()
            doc_bboxes.append([int(c) for c in coords])
            doc_words.append(text)
        bboxes.append(doc_bboxes)
        words.append(doc_words)

    return bboxes, words

In [7]:
# Loading the data from kaggle, processing it and then saving a local json

bboxes, words, entities, ids = load_data_from_kaggle(data_loc)

# One record per document: its boxes, its text lines, the parsed entity JSON
# and the file id used as key.
data_json = [
    {'bbox': bbox, 'word': word, 'entity': json.loads(entity), 'key': file_id}
    for bbox, word, entity, file_id in zip(bboxes, words, entities, ids)
]

# Keep a local copy of the assembled records.
with open('sorie-datasetv3.json','w') as f:
    json.dump(data_json, f)

In [10]:
# Preprocessing util methods

def combine_bbox(bboxes):
    """Collapse several boxes into one.

    Columns 0 and 1 of the stacked boxes are reduced with ``min`` and columns
    2 and 3 with ``max``; the four results are returned as a list in that order.
    """
    grid = np.array(bboxes)
    reducers = (np.min, np.min, np.max, np.max)
    return [reduce_column(grid[:, col]) for col, reduce_column in enumerate(reducers)]


def convert_labels_entity(labels, words):
    """Group words by label into an entity dictionary.

    For every distinct label other than ``"O"``, the words carrying that label
    are joined with single spaces, in their original order. Keys appear in the
    sorted order produced by ``np.unique``.
    """
    label_array = np.array(labels)
    word_array = np.array(words)
    return {
        label: ' '.join(word_array[label_array == label])
        for label in np.unique(labels)
        if label != "O"
    }



def get_bins_from_list(cord_list, bins=100):
    """Quantise coordinates into ``bins`` equal-width buckets numbered from 1.

    Args:
        cord_list: One-dimensional sequence of numeric coordinates.
        bins: Number of equal-width buckets spanning the min..max of the input.

    Returns:
        A list with one bucket number (``1..bins``) per input value.
    """
    # `.tolist()` rather than `.to_list()`: the latter is deprecated on the
    # Categorical returned by pd.cut.
    return pd.cut(cord_list, bins=bins, labels=np.arange(1, bins+1)).tolist()


def spliting_criteria_len(text, tokenizer, max_length):
    """Return True when ``text`` encodes to at most ``max_length`` token ids.

    ``tokenizer`` is called on the text and the length of its ``'input_ids'``
    entry is compared with the limit.
    """
    encoded = tokenizer(text)
    n_tokens = len(encoded['input_ids'])
    within_budget = n_tokens <= max_length
    return within_budget



def assign_line_label(line: str, entity: dict):
    """Guess which entity, if any, a text line belongs to.

    Commas are removed and the line and each entity value are split on
    whitespace. Line tokens are fuzzy-matched against the entity's tokens
    (``SequenceMatcher`` ratio above 0.8). After each line token the running
    match count is tested, and the entity's upper-cased key is returned as
    soon as one of these holds:

    * the key is ``ADDRESS`` and at least half of the line tokens matched;
    * the key is not ``ADDRESS`` and every line token matched;
    * the match count equals the number of tokens in the entity value.

    Entities are tried in dictionary order; the first hit wins.

    Args:
        line: One text line of the document.
        entity: Mapping of entity name to its ground-truth string value.

    Returns:
        The upper-cased entity name, or ``"O"`` when nothing matches.
    """
    line_set = line.replace(",", "").strip().split()
    for column, value in entity.items():
        entity_set = value.replace(",", "").strip().split()
        if not entity_set:
            # An empty entity value would satisfy "match count == number of
            # entity tokens" at zero matches and claim every line.
            continue

        is_address = column.upper() == 'ADDRESS'
        matches_count = 0
        for token in line_set:
            if any(SequenceMatcher(a=token, b=candidate).ratio() > 0.8 for candidate in entity_set):
                matches_count += 1

            if (is_address and (matches_count / len(line_set)) >= 0.5) or \
               (not is_address and matches_count == len(line_set)) or \
               matches_count == len(entity_set):
                return column.upper()

    return "O"


def assign_labels(words: list, bboxes: list, entity: dict):
    """Produce one label per line of a document.

    Each line is first labelled by ``assign_line_label``. Then:

    * an ``ADDRESS`` hit is downgraded to ``"O"`` once a ``TOTAL`` line has been
      seen, and a ``COMPANY`` hit once a ``DATE`` or ``TOTAL`` line has been seen;
    * ``TOTAL`` and ``DATE`` hits are treated as candidates only. Exactly one
      line per field keeps the label: the candidate whose box has the largest
      size score. All other candidates become ``"O"``.

    Args:
        words: Text lines of the document.
        bboxes: Box coordinates for each line, parallel to ``words``.
        entity: Ground-truth mapping of entity name to value.

    Returns:
        List of labels, one per line.
    """
    # field -> (best size score so far, index of that line); -1 means no candidate yet.
    max_area = {"TOTAL": (0, -1), "DATE": (0, -1)}
    already_labeled = {"TOTAL": False,
                       "DATE": False,
                       "ADDRESS": False,
                       "COMPANY": False,
                       "O": False
    }

    labels = []
    for i, (line, bbox) in enumerate(zip(words, bboxes)):
        label = assign_line_label(line=line, entity=entity)

        # Recorded before any downgrade, so a suppressed hit still counts as seen.
        already_labeled[label] = True
        if (label == "ADDRESS" and already_labeled["TOTAL"]) or \
           (label == "COMPANY" and (already_labeled["DATE"] or already_labeled["TOTAL"])):
            label = "O"

        if label in ["TOTAL", "DATE"]:
            # Size score = (bbox[4] - bbox[0]) + (bbox[5] - bbox[1]).
            # NOTE(review): this is a sum of two extents, not a product, and it
            # presumably treats indices 0/1 and 4/5 as opposite corners - confirm.
            area = (bbox[4] - bbox[0]) + (bbox[5] - bbox[1])

            if max_area[label][0] < area:
                max_area[label] = (area, i)

            # Provisional: the winner is restored after the loop.
            label = "O"

        labels.append(label)

    # DATE first, then TOTAL, so TOTAL wins if both point at the same line.
    # Skip fields with no candidate: writing through the -1 placeholder would
    # relabel the last line (or fail on an empty document).
    for field in ("DATE", "TOTAL"):
        winner_index = max_area[field][1]
        if winner_index >= 0:
            labels[winner_index] = field

    return labels

In [13]:
# Data models to store document level data and splits

class SorieDocument(BaseModel):
    """A whole document plus two fields derived from the raw input.

    ``bbox_quantized`` and ``labels`` are not supplied by the caller; the two
    pre-validators below fill them in from ``bbox``, ``word`` and ``entity``
    before field validation runs.
    """
    key: str
    bbox: list
    word: list
    entity: Dict[str,str]
    bbox_quantized: list
    labels: list

    @root_validator(pre=True)
    def compute_bbox_quantized(cls, values):
        """Derive a binned centre for every box.

        For each box, the even-indexed and the odd-indexed coordinates are
        averaged separately; each of the two resulting columns is then binned
        across the document with ``get_bins_from_list``.
        """
        centres = []
        for b in values.get('bbox'):
            even_mean = np.mean([b[0], b[2], b[4], b[6]])
            odd_mean = np.mean([b[1], b[3], b[5], b[7]])
            centres.append([even_mean, odd_mean])
        binned = np.apply_along_axis(get_bins_from_list, arr=centres, axis=0)
        values['bbox_quantized'] = binned.tolist()
        return values

    @root_validator(pre=True)
    def compute_labels_embeddings(cls, values):
        """Attach one label per line, computed by ``assign_labels``."""
        values['labels'] = assign_labels(
            words=values.get('word'),
            bboxes=values.get('bbox'),
            entity=values.get('entity'),
        )
        return values


class SplitSorieDocument(BaseModel):
    """A contiguous slice of a document.

    ``entity`` is not passed in by the caller; the pre-validator rebuilds it
    from the slice's own ``labels`` and ``word`` lists.
    """
    key: str
    index: int
    bbox: list
    word: list
    bbox_quantized: list
    labels: list
    entity: dict

    @root_validator(pre=True)
    def compute_entities(cls, values):
        """Fill ``entity`` by grouping this slice's words by label."""
        slice_labels = values.get('labels')
        slice_words = values.get('word')
        values['entity'] = convert_labels_entity(labels=slice_labels, words=slice_words)
        return values

In [15]:
# Preprocess data for modelling, including schema generation and data splitting

def create_string(txt, bbox):
    """Render one prompt row: the text, a space, then ``bbox[0]|bbox[1]``."""
    first, second = bbox[0], bbox[1]
    return "{} {}|{}".format(txt, first, second)


def create_word_bbox_text(words, bboxes):
    """Build the model prompt for a list of lines and their quantised boxes.

    The prompt is a ``<Document>`` section with one ``create_string`` row per
    line, a ``<Task>`` section embedding the JSON of ``empty_schema``, and a
    trailing opening ``<Extraction>`` tag.
    """
    rows = []
    for bbox, word in zip(bboxes, words):
        rows.append(create_string(word, bbox))
    document_section = "<Document>\n" + "\n".join(rows) + "\n</Document>\n"
    task_section = (
        "<Task>From the document, extract the text values and tags of the following entities:"
        + json.dumps(empty_schema)
        + "</Task>\n"
    )
    return document_section + task_section + "<Extraction>"


def split_doc_into_splits(doc, spliting_criteria):
    """Cut a document into consecutive chunks that satisfy ``spliting_criteria``.

    Lines are added to the current chunk one at a time. As soon as the prompt
    built from the chunk no longer satisfies the criterion, the chunk is closed
    before the offending line and that line starts the next chunk.

    Two edge cases are handled explicitly:

    * the last line goes through the same check as every other line, so the
      final chunk cannot exceed the limit just because it is final;
    * a single line that fails the criterion on its own becomes a one-line
      chunk, and no empty chunk is emitted.

    Args:
        doc: Document exposing ``word``, ``bbox``, ``bbox_quantized``,
            ``labels`` and ``key``.
        spliting_criteria: Callable taking the prompt text and returning True
            while the chunk is still acceptable.

    Returns:
        List of ``SplitSorieDocument`` with ``index`` numbered from 0.
    """
    split_docs = []

    words = doc.word
    bboxes_quantized = doc.bbox_quantized

    def _close_chunk(start, stop):
        # Append the slice [start, stop) as the next numbered split.
        split_docs.append(
            SplitSorieDocument(
                word=words[start:stop],
                bbox=doc.bbox[start:stop],
                index=len(split_docs),
                key=doc.key,
                bbox_quantized=bboxes_quantized[start:stop],
                labels=doc.labels[start:stop],
            )
        )

    start_index = 0
    for index in range(len(words)):
        input_text = create_word_bbox_text(
            words[start_index:index+1], bboxes_quantized[start_index:index+1]
        )
        # `index > start_index` keeps at least one line per chunk.
        if not spliting_criteria(input_text) and index > start_index:
            _close_chunk(start_index, index)
            start_index = index

    # Whatever remains after the last cut forms the final chunk.
    if start_index < len(words):
        _close_chunk(start_index, len(words))

    return split_docs

In [None]:
# Parse every raw record into a SorieDocument.
all_docs = list(map(SorieDocument.parse_obj, data_json))

# Cut each document into chunks; a chunk is accepted while its prompt encodes
# to at most 256 token ids.
all_splits = [
    split
    for doc in all_docs
    for split in split_doc_into_splits(
        doc,
        spliting_criteria=partial(spliting_criteria_len, tokenizer=tokenizer, max_length=256),
    )
]

In [19]:
from torch.utils.data import Dataset

class SROIE_Dataset(Dataset):
    """Dataset that turns each document split into one tokenised training text.

    Args:
        split_list: Sequence of splits exposing ``word``, ``bbox_quantized``
            and ``entity``.
        tokenizer: Callable tokenizer used to encode the text.
    """

    def __init__(self, split_list, tokenizer):
        self.split_list = split_list
        self.tokenizer = tokenizer

    def __len__(self):
        """Number of splits."""
        return len(self.split_list)

    def __getitem__(self, index):
        """Encode split ``index``.

        The text is the prompt from ``create_word_bbox_text``, a newline, the
        JSON of the split's entities and the closing ``</Extraction>`` tag.

        Returns:
            Dict mapping each encoding field to its first (and only) row.
        """
        split = self.split_list[index]
        input_text = create_word_bbox_text(split.word, split.bbox_quantized)
        outputs = json.dumps(split.entity) + '</Extraction>'
        # Use the tokenizer given to the constructor, not whichever global
        # `tokenizer` happens to exist in the notebook.
        encoding = self.tokenizer(
            '\n'.join([input_text, outputs]),
            return_tensors="pt",
            truncation=True,
            padding=True,
        )
        # return_tensors="pt" adds a leading batch axis; drop it.
        return {name: tensor[0] for name, tensor in encoding.items()}

In [20]:
# Wrap the document splits so the Trainer can index them.
dataset = SROIE_Dataset(split_list=all_splits, tokenizer=tokenizer)

In [21]:
# Training configuration; checkpoints are written to the temporary folder.
training_args = tr.TrainingArguments(
    output_dir=local_checkpoint_path,
    num_train_epochs=3,  # default number of epochs to train is 3
    per_device_train_batch_size=4,
    logging_steps=25,
    eval_steps=50,
    optim="adamw_torch",
    # The string "none" turns every logging integration off. Python None does
    # not, and the recorded warning below asks for `report_to none`.
    report_to="none",
)

# mlm=False selects causal language modelling in the collator.
data_collator = tr.DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

trainer = tr.Trainer(
    model=model,
    args=training_args,
    train_dataset=dataset,
    tokenizer=tokenizer,
    data_collator=data_collator,
)

Using the `WANDB_DISABLED` environment variable is deprecated and will be removed in v5. Use the --report_to flag to control the integrations used for logging result (for instance --report_to none).


In [22]:
# Launch fine-tuning. The output recorded below shows this run stopping with a
# ValueError complaining that the encodings passed for padding had no `input_ids`.
trainer.train()

You're using a GPT2TokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.


ValueError: You should supply an encoding or a list of encodings to this method that includes input_ids, but you provided []

In [23]:
# Choose one split and tokenise its prompt (document and task sections, ending
# at the opening <Extraction> tag) as input for generation.
index = 11
item = tokenizer(create_word_bbox_text(all_splits[index].word,all_splits[index].bbox_quantized),return_tensors="pt", truncation=True, padding=True,)

In [None]:
# Sample a continuation for the prompt and decode it back to text.
# - Inputs are moved to the model's own device, not a hard-coded "cuda".
# - Special tokens are kept: the <Extraction> tags were registered as special
#   tokens, so skipping them would remove the markers that post-processing
#   searches for.
result_example = tokenizer.batch_decode(
    trainer.model.generate(
        input_ids=item["input_ids"].to(trainer.model.device),
        attention_mask=item["attention_mask"].to(trainer.model.device),
        max_new_tokens=100,
        do_sample=True,
        top_k=10,
        top_p=0.95,
    ),
    skip_special_tokens=False,
)

In [None]:
# Show the prompt text that was fed to the model for the chosen split.
print(create_word_bbox_text(all_splits[index].word,all_splits[index].bbox_quantized))

In [None]:
# Expected extraction for the chosen split, serialised the same way as the training targets.
json.dumps(all_splits[index].entity)

In [None]:
# Show the first decoded generation.
print(result_example[0])

## Post Processing

In [None]:
import re
def _matches_schema(instance, schema):
    """Check ``instance`` against the small JSON-Schema subset used here.

    Only the ``type``, ``required`` and ``properties`` keywords are honoured;
    other keywords are ignored. Returns True or False and never raises for a
    mismatch.
    """
    json_types = {
        "object": dict,
        "array": list,
        "string": str,
        "boolean": bool,
        "integer": int,
        "number": (int, float),
        "null": type(None),
    }

    declared = schema.get("type")
    if declared is not None:
        type_names = [declared] if isinstance(declared, str) else declared

        def _has_type(type_name):
            python_type = json_types.get(type_name)
            if python_type is None:
                return False
            # bool is a subclass of int in Python but is not a JSON number.
            if type_name in ("integer", "number") and isinstance(instance, bool):
                return False
            return isinstance(instance, python_type)

        if not any(_has_type(type_name) for type_name in type_names):
            return False

    if isinstance(instance, dict):
        if any(key not in instance for key in schema.get("required", [])):
            return False
        for key, subschema in schema.get("properties", {}).items():
            if key in instance and not _matches_schema(instance[key], subschema):
                return False

    return True


def find_results(text, schema):
    """Pull the first schema-conforming JSON object out of generated text.

    Newlines are removed from ``text``. Every ``<Extraction>...</Extraction>``
    span is then parsed as JSON and checked against ``schema``; spans that are
    not valid JSON or do not conform are skipped.

    Args:
        text: Decoded model output.
        schema: JSON-Schema-style dict (``type`` / ``properties`` / ``required``).

    Returns:
        The first conforming parsed object, or ``{}`` when there is none.
    """
    flattened = text.replace("\n", "")
    for candidate in re.findall(r"<Extraction>(.*?)</Extraction>", flattened):
        try:
            json_object = json.loads(candidate)
        except ValueError:
            # Not valid JSON (json.JSONDecodeError is a ValueError): try the next span.
            continue
        if _matches_schema(json_object, schema):
            return json_object
    return {}

In [None]:
# Extract the first schema-valid JSON object from the first decoded generation
# ({} when none is found).
a = find_results(result_example[0], schema)