<a href="https://colab.research.google.com/github/Mathijsgeelen/Sentence-extraction-annotation/blob/main/Extract_annotate.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Required installations

In [None]:
pip install layoutparser

In [None]:
pip install "layoutparser[effdet]"

In [None]:
pip install layoutparser torchvision && pip install "git+https://github.com/facebookresearch/detectron2.git@v0.5#egg=detectron2"

In [None]:
pip install "layoutparser[paddledetection]"

In [None]:
pip install "layoutparser[ocr]"

In [None]:
pip install vila

In [None]:
!apt-get install poppler-utils

In [None]:
pip install torch

In [None]:
! apt install tesseract-ocr
! apt install libtesseract-dev

In [None]:
! pip install Pillow
! pip install pytesseract

In [None]:
pip install fitz

In [None]:
pip install PyMuPDF

In [None]:
pip install nltk

In [None]:
pip install truecase

# Required imports

In [None]:
# Mount Google Drive so that files stored there are reachable under /content/drive/
# (the model and PDF paths configured further down point into this mount).
from google.colab import drive
drive.mount("/content/drive/")

In [None]:
import torch
import fitz
from torch.optim import Adam, AdamW
from torch.utils.data import TensorDataset, DataLoader, RandomSampler, SequentialSampler
from sklearn.model_selection import train_test_split
from tqdm import tqdm, trange
from transformers import AutoModelForTokenClassification, AdamW
import random
import os
import tensorflow as tf
import numpy as np
from transformers import RobertaTokenizer, DebertaTokenizer, DebertaForTokenClassification, RobertaForTokenClassification, AutoTokenizer, AutoModelForTokenClassification, pipeline
import pandas as pd
import pytesseract
import layoutparser as lp 
import pdf2image
ocr_agent = lp.TesseractAgent(languages='eng')
from vila.pdftools.pdf_extractor import PDFExtractor
from vila.predictors import HierarchicalPDFPredictor
import truecase
import nltk
#download punkt
nltk.download('punkt')
import spacy
from spacy import displacy

In [None]:
# Seed every random number generator used in this notebook so that runs are repeatable.
seed = 1
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
tf.random.set_seed(seed)
# NOTE(review): PYTHONHASHSEED is normally read at interpreter start-up, so setting it
# here presumably only affects child processes - confirm whether that is intended.
os.environ['PYTHONHASHSEED'] = str(seed)
# TF_DETERMINISTIC_OPS is an environment variable; assigning a Python variable of the
# same name (as was done before) has no effect on TensorFlow.
os.environ['TF_DETERMINISTIC_OPS'] = '1'
# Ask cuDNN for deterministic kernels and disable its auto-tuner.
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# DeBERTa tokenizer; used as a global by the SBD helper functions below.
tokenizer = DebertaTokenizer.from_pretrained("microsoft/deberta-base")

# Ask TensorFlow for the name of the GPU device.
device_name = tf.test.gpu_device_name()

# Stop right away when the expected GPU ('/device:GPU:0') is not available.
if device_name != '/device:GPU:0':
    raise SystemError('GPU device not found')
print('Found GPU at: {}'.format(device_name))

# File paths
Set the path to the DeBERTa sentence boundary detection (SBD) model, which can be downloaded from https://github.com/Mathijsgeelen/Sentence-extraction-annotation/blob/main/DeBERTa_legal_SBD, or use your own SBD model. Upload the model to your Google Drive (or upload it to the runtime manually) and change the paths below accordingly.

Change `file_path` to the PDF file to be analysed.

In [32]:
# Sentence boundary detection model, loaded from Google Drive.
# NOTE(review): torch.load unpickles the file, which can execute arbitrary code - only
# load checkpoints that come from a trusted source.
model_sbd = torch.load("/content/drive/My Drive/extract-annotate/DeBERTa_legal_SBD")
# PDF that is analysed and highlighted by the cells below.
file_path = "/content/drive/My Drive/extract-annotate/pdf/legal.pdf"

# Word and block-level DLA example visualisation

In [None]:
# Load the tokens and the page images of the PDF using the "pdfplumber" backend.
pdf_extractor = PDFExtractor("pdfplumber")
page_tokens, page_images = pdf_extractor.load_tokens_and_image(file_path)

# Block detector and token-level predictor used for this visualisation.
vision_model = lp.EfficientDetLayoutModel("lp://PubLayNet")
pdf_predictor = HierarchicalPDFPredictor.from_pretrained("allenai/hvila-row-layoutlm-finetuned-docbank")

# Only the first page (index 0) is visualised.
ind = 0
blocks = vision_model.detect(page_images[ind])
# Attach the detected blocks to the page tokens before running the token predictor.
page_tokens[ind].annotate(blocks=blocks)
pdf_data = page_tokens[ind].to_pagedata().to_dict()
predicted_tokens = pdf_predictor.predict(pdf_data)
# Last expression of the cell: the page image with the predicted token boxes drawn on it.
lp.draw_box(page_images[ind], predicted_tokens, box_width=3, box_alpha=0.25) 

In [None]:
# Render the PDF to images and keep the page that is visualised.
# (pdf2image is already imported in the imports cell.)
image = pdf2image.convert_from_path(file_path)
ind = 0
image = image[ind]

# Mask-RCNN layout detector; detections scoring below 0.50 are discarded.
model_layout = lp.Detectron2LayoutModel('lp://PubLayNet/mask_rcnn_X_101_32x8d_FPN_3x/config',
                                 extra_config=["MODEL.ROI_HEADS.SCORE_THRESH_TEST", 0.50],
                                 label_map={0: "Text", 1: "Title", 2: "List"})

layout = model_layout.detect(image)
# Keep only the blocks that carry running text.
text_blocks = lp.Layout([b for b in layout if b.type=='Text' or b.type=="List"])

h, w = np.array(image).shape[:2]

# Blocks whose centre lies in the left ~half of the page form the left column.
left_interval = lp.Interval(0, w/2*1.05, axis='x').put_on_canvas(image)

left_blocks = text_blocks.filter_by(left_interval, center=True)
# Layout.sort returns a sorted copy unless inplace=True is passed; without it the
# left column was left in detection order.
left_blocks.sort(key = lambda b:b.coordinates[1], inplace=True)

# right_blocks is a plain Python list, so list.sort already sorts in place.
right_blocks = [b for b in text_blocks if b not in left_blocks]
right_blocks.sort(key = lambda b:b.coordinates[1])

# And finally combine the two lists and add the index
# according to the reading order (left column first, top to bottom).
text_blocks = lp.Layout([b.set(id = idx) for idx, b in enumerate(left_blocks + right_blocks)])

lp.draw_box(image, text_blocks,
            box_width=3,
            show_element_id=True)





In [None]:
# Run OCR on every detected text block and store the recognised text on the block itself.
for block in text_blocks:
    segment_image = (block
                       .pad(left=10, right=10, top=10, bottom=10)
                       .crop_image(np.array(image)))
        # add padding in each image segment can help
        # improve robustness

    text = ocr_agent.detect(segment_image)
    block.set(text=text, inplace=True)

# Print the OCR result of each block, separated by a '---' line.
for txt in text_blocks.get_texts():
    print(txt, end='\n---\n')

# Functions for text extraction module

In [27]:
def extract_text_blocks(file_name, page_number, pdf_predictor, model_layout):
  """Detect the words and the reading-ordered text blocks of one PDF page.

  Args:
    file_name: path of the PDF file.
    page_number: zero-based index of the page to process.
    pdf_predictor: pre-trained LayoutLM model for detecting words from an image.
    model_layout: pre-trained Mask-RCNN for text block detection from an image.

  Returns:
    (text_blocks, predicted_tokens, image): the 'Text'/'List' blocks of the page with
    ids assigned left column first, top to bottom; the word-level predictions of the
    page; and the rendered page image the blocks were detected on.
  """
  ### transforms desired page from pdf file to an image.
  image = pdf2image.convert_from_path(file_name)[page_number]

  ### Extracts all the words from a pdf page and stores this in predicted_tokens
  ### Predicted tokens holds all words from left-to-right, top-to-bottom of a page
  pdf_extractor = PDFExtractor("pdfplumber")
  page_tokens, page_images = pdf_extractor.load_tokens_and_image(file_name)
  blocks = model_layout.detect(page_images[page_number])
  page_tokens[page_number].annotate(blocks=blocks)
  pdf_data = page_tokens[page_number].to_pagedata().to_dict()
  predicted_tokens = pdf_predictor.predict(pdf_data)

  ### Extracts the text blocks from an image and stores this in text_blocks
  ### Text blocks is an iterable containing all the identified text blocks
  layout = model_layout.detect(image)
  text_blocks = lp.Layout([b for b in layout if b.type=='Text' or b.type=='List'])

  # Blocks whose centre lies in the left ~half of the page form the left column.
  w = np.array(image).shape[1]
  left_interval = lp.Interval(0, w/2*1.05, axis='x').put_on_canvas(image)

  left_blocks = text_blocks.filter_by(left_interval, center=True)
  # Layout.sort returns a sorted copy unless inplace=True is passed; without it the
  # left column kept its detection order.
  left_blocks.sort(key = lambda b:b.coordinates[1], inplace=True)

  # right_blocks is a plain Python list, so list.sort already sorts in place.
  right_blocks = [b for b in text_blocks if b not in left_blocks]
  right_blocks.sort(key = lambda b:b.coordinates[1])

  # And finally combine the two lists and add the index
  # according to the order
  text_blocks = lp.Layout([b.set(id = idx) for idx, b in enumerate(left_blocks + right_blocks)])

  # The former lp.draw_box call was dropped: its rendered image was never used or returned.
  return text_blocks, predicted_tokens, image





# Functions for SBD

In [28]:
def create_attention_masks(sent):
  """Return the attention mask for one sequence of token ids.

  A token id greater than 0 is treated as a real token (mask 1); every other id is
  treated as padding (mask 0).
  """
  attention_masks = []
  for token_id in sent:
    attention_masks.append(1 if token_id > 0 else 0)
  return attention_masks

def create_sentences(tokenized_text, start_index):
  """Turn a window of tokens into a fixed-length id sequence plus attention mask.

  Takes at most 510 tokens starting at start_index, converts them to ids with the
  global tokenizer, wraps them in the ids 101 and 102 and pads with 0 up to a length
  of 512.

  Returns:
    (tokenized_sent, attention_masks), two lists of equal length.
  """
  # Slicing past the end of the list simply yields the remaining tokens.
  window = tokenized_text[start_index:start_index + 510]
  tokenized_sent = [101, *tokenizer.convert_tokens_to_ids(window), 102]

  # Pad with zeros when the sequence is shorter than 512 ids.
  tokenized_sent.extend([0] * (512 - len(tokenized_sent)))

  return tokenized_sent, create_attention_masks(tokenized_sent)


def create_tensors(seq_list, masks):
  """Convert the id sequence and its attention mask to torch tensors.

  Returns:
    (inputs, masks) as two torch tensors.
  """
  return torch.tensor(seq_list), torch.tensor(masks)

def run_tensors(start_index,text):
  """Build the (1, 512) input-id and attention-mask tensors for one token window.

  Args:
    start_index: index of the first token of the window.
    text: the tokenized text (passed on to create_sentences).
  """
  token_ids, mask = create_sentences(text, start_index)
  input_tensor, mask_tensor = create_tensors(token_ids, mask)
  batch_shape = (1, 512)  # a batch holding a single 512-id sequence

  return np.reshape(input_tensor, batch_shape), np.reshape(mask_tensor, batch_shape)

def create_loader(train_inputs, train_masks):
  """Wrap the input and mask tensors in a DataLoader.

  The loader uses a batch size of 1 and draws its samples through a RandomSampler.
  """
  dataset = TensorDataset(train_inputs, train_masks)
  return DataLoader(dataset, batch_size=1, sampler=RandomSampler(dataset))


def create_dict(preds, masks, input):
  """Split flattened predictions into begin/end spans and record them in preds_dict.

  NOTE(review): this function appends to a global `preds_dict` (keys "preds",
  "inputs_pred", "tokens_pred", "pred_begin_end", "text") that is not defined anywhere
  in the visible notebook, and nothing visible calls this function - confirm whether
  it is still needed.

  Args:
    preds: array of per-token class scores; argmax is taken over axis 2.
    masks: attention masks (flattened but not used afterwards).
    input: input token ids; must offer flatten() and tolist().
  """
  
  # NOTE(review): total_count, match_count, mismatch and masks_flat are never read.
  total_count = 0
  match_count = 0
  mismatch = False
  pred_flat = np.argmax(preds, axis=2).flatten()
  masks_flat = masks.flatten ()
  input_tkns_flat = input.flatten()
  preds_dict["preds"].append(pred_flat)

  i = 0
  while i != len(pred_flat):
    #for i, mask_token in enumerate(masks_flat):
      #check if attention mask is an actual interesting token, i.e. not 0
    begin_pred = i
    end_pred = i


    # Walk left to the nearest position predicted as class 2 (or to position 0) ...
    while pred_flat[begin_pred] !=2 and begin_pred >= 1:
      begin_pred -=1
    # ... and right to the nearest position predicted as class 3 (or to the last position).
    while pred_flat[end_pred] != 3 and end_pred < len(pred_flat)-1:
      end_pred += 1


    # Decode the ids of the span (both ends inclusive) back to text with the global tokenizer.
    inputs_pred = input_tkns_flat.tolist()[begin_pred:end_pred+1]
    tokens_pred = tokenizer.convert_ids_to_tokens(inputs_pred)
    text = tokenizer.convert_tokens_to_string(tokens_pred)
    
    preds_dict["inputs_pred"].append(inputs_pred)
    preds_dict["tokens_pred"].append(tokens_pred)
    preds_dict["pred_begin_end"].append([begin_pred,end_pred])
    preds_dict["text"].append(text)

    # Continue right after the span that was just recorded.
    i = end_pred+1

      
def predict_boundary(model, start_index, text):
  """Predict one class label per position for a single 512-id window of `text`.

  Args:
    model: the sentence boundary detection model; called with input ids,
      token_type_ids=None and an attention mask, its first output holds the logits.
    start_index: index of the first token of the window.
    text: the tokenized text.

  Returns:
    A flat numpy array with the arg-max class of every position in the window.
  """
  # Send the inputs to the device the model lives on rather than assuming CUDA;
  # fall back to the previous hard-coded "cuda" when the model exposes no parameters.
  try:
    device = next(model.parameters()).device
  except (StopIteration, AttributeError):
    device = torch.device("cuda")

  model.eval()
  input_ids, attention_mask = run_tensors(start_index, text)
  dataloader = create_loader(input_ids, attention_mask)

  # The loader holds exactly one sample, so the first batch is the whole window.
  for batch in dataloader:
    b_input_ids, b_input_mask = tuple(t.to(device) for t in batch)

    with torch.no_grad():
      output = model(b_input_ids, token_type_ids=None,
                     attention_mask=b_input_mask)

    logits = output[0].detach().cpu().numpy()
    return np.argmax(logits, axis=2).flatten()
      




# Functions for extracting coordinates

In [33]:
def extract_boundaries(tokenizer, file_name, page_number, pdf_predictor, model_layout,model_sbd):
  """Predict the sentence boundaries of one PDF page as small word windows.

  Args:
    tokenizer: tokenizer used to tokenize the OCR text and to turn tokens back into strings.
    file_name: path of the PDF file.
    page_number: zero-based index of the page to process.
    pdf_predictor: word-level extractor passed on to extract_text_blocks.
    model_layout: block-level extractor passed on to extract_text_blocks.
    model_sbd: sentence boundary detection model passed to predict_boundary.

  Returns:
    (preds_dict, predicted_tokens). preds_dict["boundaries"] maps a running integer to
    [words before the boundary, words after the boundary] (the second element is " "
    for the rule-based boundaries); preds_dict["start_end"] holds one
    [first words, last words] pair per text block.
  """
  ### Uses pre-trained DeBERTa for SBD and loops over these predictions to extract the sentence boundaries
  ### The text blocks are fed block by block into the SBD model, first being transformed in the proper format and then
  ### for each token the class is predicted.
  ### All predicted sentence boundaries are saved in preds_dict, which is a dictionary containing all the boundaries

  ### First 2 words prior to and after the sentence boundary are extracted and stored in preds_dict
  ### preds_dict is a dictionary ultimately containing all the extracted sentence boundaries.

  preds_dict = {"boundaries":{},"start_end":[]}
  entry = 0

  # Extract text_blocks and tokens from image, takes as input the file_path, the desired page number, pre-trained word-level extractor
  # and pre-trained block level extractor.
  text_blocks, predicted_tokens, image = extract_text_blocks(file_name, page_number, pdf_predictor, model_layout)

  # Extract text from blocks using OCR
  for block in text_blocks:
    segment_image = (block
                       .pad(left=10, right=10, top=10, bottom=10)
                       .crop_image(np.array(image)))
        # add padding in each image segment can help
        # improve robustness

    text = ocr_agent.detect(segment_image)
    block.set(text=text, inplace=True)

  # Loop over the extracted text from the text block.
  # Predict sentence boundaries using predict_boundary
  # Once a sentence boundary is encountered, extract the first 2 words prior to and 2 words following the sentence boundary and store this
  # in the preds_dict. If no words prior or following, store what's possible.
  for txt in text_blocks.get_texts():
    text_tokenized = tokenizer.tokenize(txt)
    # NOTE(review): start_index is always 0 and create_sentences keeps at most 510
    # tokens, so tokens beyond the first 510 of a block get no prediction.
    preds = predict_boundary(model_sbd,0, text_tokenized)
    end = 0
    start = 0
    # s / e grow until the first s tokens / last e tokens decode to enough space-separated pieces.
    s = 1
    e = 1
    while len(tokenizer.convert_tokens_to_string(text_tokenized[:s]).replace("\n"," ").replace("\x0c","").split(" ")) < 4 and s < len(text_tokenized)-1:
      s+=1
    while len(tokenizer.convert_tokens_to_string(text_tokenized[:s]).replace("\n"," ").replace("\x0c","").split(" ")) < 5 and s < len(text_tokenized)-1:
      s+=1
    while len(tokenizer.convert_tokens_to_string(text_tokenized[-e:]).replace("\n"," ").replace("\x0c","").split(" ")) < 4 and e < len(text_tokenized)-1:
      e+=1
    sent = tokenizer.convert_tokens_to_string(text_tokenized[-e:]).replace("\n"," ").replace("\x0c","").split(" ")
    # Empty strings produced by the split are not counted as words here.
    while len(tokenizer.convert_tokens_to_string(text_tokenized[-e:]).replace("\n"," ").replace("\x0c","").split(" ")) < 5+sent.count("") and e < len(text_tokenized)-1:
      e+=1

    # The opening and closing words of the block; they are matched against the page tokens later on.
    start_sent = tokenizer.convert_tokens_to_string(text_tokenized[:s-1]).replace("\n"," ").replace("\x0c","").rstrip().lstrip()
    end_sent = tokenizer.convert_tokens_to_string(text_tokenized[-e+1:]).replace("\n"," ").replace("\x0c","").rstrip().lstrip()
    preds_dict["start_end"].append([start_sent,end_sent])

    # NOTE(review): class 3 is handled as "sentence ends here" and class 2 as
    # "sentence starts here" - presumably the label scheme of the SBD model; confirm.
    for i,j in enumerate(preds):
      if j == 3:
        end = i
        left = len(text_tokenized) - end
        tokens = []
        # Widen the window to the left of the boundary until it holds enough words.
        t = 1
        while len(tokenizer.convert_tokens_to_string(text_tokenized[end-t:end]).replace("\n", " ").replace("\x0c","").rstrip().lstrip().split(" ")) < 3 and end-t > 1:
          t += 1
        while len(tokenizer.convert_tokens_to_string(text_tokenized[end-t:end]).replace("\n", " ").replace("\x0c","").rstrip().lstrip().split(" ")) < 4 and end-t > 1:
          t += 1

        text = tokenizer.convert_tokens_to_string(text_tokenized[end-t+1:end]).replace("\n", " ").replace("\x0c","").replace("  "," ").replace("   "," ").lstrip().strip()
        tokens.append(text)

        # Same procedure for the words to the right of the boundary.
        t = 1
        while len(tokenizer.convert_tokens_to_string(text_tokenized[end:end+t]).replace("\n"," ").replace("\x0c","").rstrip().split(" ")) < 3 and left-t > 1:
          t += 1
        while len(tokenizer.convert_tokens_to_string(text_tokenized[end:end+t]).replace("\n"," ").replace("\x0c","").rstrip().split(" ")) < 4 and left-t > 1:
          t += 1
        text = tokenizer.convert_tokens_to_string(text_tokenized[end:end+t-1]).replace("\n"," ").replace("\x0c","").replace("  "," ").replace("   "," ").lstrip().strip()
        tokens.append(text)
        preds_dict["boundaries"][entry] = tokens

        entry += 1

      elif j == 2:
        # A start that does not directly follow a predicted end also marks a boundary.
        # NOTE(review): for i == 0, preds[i-1] reads the last prediction - confirm intended.
        if preds[i-1] != 3:
          end = i-1
          left = len(text_tokenized) - end
          tokens = []
          t = 1
          while len(tokenizer.convert_tokens_to_string(text_tokenized[end-t:end]).replace("\n", " ").replace("\x0c"," ").rstrip().split(" ")) < 3 and end-t > 1:
            t += 1
          while len(tokenizer.convert_tokens_to_string(text_tokenized[end-t:end]).replace("\n", " ").replace("\x0c"," ").rstrip().split(" ")) < 4 and end-t > 1:
            t += 1
          
          text = tokenizer.convert_tokens_to_string(text_tokenized[end-t+1:end]).replace("\n", " ").replace("\x0c"," ").replace("  "," ").replace("   "," ").lstrip().strip()
          tokens.append(text)

          t = 1
          while len(tokenizer.convert_tokens_to_string(text_tokenized[end:end+t]).replace("\n"," ").replace("\x0c"," ").rstrip().split(" ")) < 3 and left-t > 1:
            t += 1
          while len(tokenizer.convert_tokens_to_string(text_tokenized[end:end+t]).replace("\n"," ").replace("\x0c"," ").rstrip().split(" ")) < 4 and left-t > 1:
            t += 1
          text = tokenizer.convert_tokens_to_string(text_tokenized[end:end+t-1]).replace("\n"," ").replace("\x0c"," ").replace("  "," ").replace("   "," ").lstrip().strip()
          tokens.append(text)

          preds_dict["boundaries"][entry] = tokens
          entry += 1
    
    # Drop the 'Č' and 'Ċ' tokens - presumably the tokenizer's symbols for form feed
    # and newline; verify against the tokenizer's vocabulary.
    text_tokenized2 = list(filter(lambda a: a != 'Č', text_tokenized))
    text_tokenized2 = list(filter(lambda a: a != 'Ċ', text_tokenized2)) 

    # Slight addition to the pre-trained SBD model. If a text blocks ends with ';', als predict a sentence boundary
    # This is in line with Savelka et al. (2017) their documentation, but not correctly captured by the SBD model
    if len(text_tokenized2) >= 2:
      if text_tokenized2[-1] == ";" or text_tokenized2[-2] == ";":
        t = 1
        tokens = []
        while len(tokenizer.convert_tokens_to_string(text_tokenized2[-1-t:]).split(" ")) < 3 and abs(-1-t) < len(text_tokenized2)-1:
          t += 1
        while len(tokenizer.convert_tokens_to_string(text_tokenized2[-1-t:]).split(" ")) < 4 and abs(-1-t) < len(text_tokenized2)-1:
          t += 1
        #####
        text = tokenizer.convert_tokens_to_string(text_tokenized2[-1-t+1:]).replace("\n", " ").replace("\x0c"," ").replace("  "," ").replace("   "," ").lstrip().strip()
        tokens.append(text)
        # " " as second element marks a boundary without following words.
        tokens.append(" ")
        preds_dict["boundaries"][entry] = tokens
        entry+=1

    # Rule-based boundaries inside the block: ';' (optionally followed by ' or') directly
    # before a line/page break, and '.' followed by a newline and a form feed.
    for i in range(len(text_tokenized)-2):
      if tokenizer.convert_tokens_to_string(text_tokenized[i]) == ";":
        if tokenizer.convert_tokens_to_string(text_tokenized[i+1]) == "\x0c" or tokenizer.convert_tokens_to_string(text_tokenized[i+1]) == "\n":
          t = 1
          tokens = []
          while len(tokenizer.convert_tokens_to_string(text_tokenized[i-t:i+1]).split(" ")) < 3 and t+1 < len(text_tokenized)-1:
            t+=1
          while len(tokenizer.convert_tokens_to_string(text_tokenized[i-t:i+1]).split(" ")) < 4 and t+1 < len(text_tokenized)-1:
            t+=1
          ###
          
          text = tokenizer.convert_tokens_to_string(text_tokenized[i-t:i+1]).replace("\n", " ").replace("\x0c"," ").replace("  "," ").replace("   "," ").lstrip().strip()
          tokens.append(text)
          tokens.append(" ")
          preds_dict["boundaries"][entry] = tokens
          entry+=1
        elif tokenizer.convert_tokens_to_string(text_tokenized[i+1]) == " or" and (tokenizer.convert_tokens_to_string(text_tokenized[i+2]) == "\x0c" or tokenizer.convert_tokens_to_string(text_tokenized[i+2]) == "\n"):
          t = 1
          tokens = []
          while len(tokenizer.convert_tokens_to_string(text_tokenized[i-t:i+2]).split(" ")) < 3 and t+1 < len(text_tokenized)-1:
            t+=1
          while len(tokenizer.convert_tokens_to_string(text_tokenized[i-t:i+2]).split(" ")) < 4 and t+1 < len(text_tokenized)-1:
            t+=1
          ##
          text = tokenizer.convert_tokens_to_string(text_tokenized[i-t+1:i+2]).replace("\n", " ").replace("\x0c"," ").replace("  "," ").replace("   "," ").lstrip().strip()
          tokens.append(text)
          tokens.append(" ")
          preds_dict["boundaries"][entry] = tokens
          entry+=1
      if tokenizer.convert_tokens_to_string(text_tokenized[i]) == "." and tokenizer.convert_tokens_to_string(text_tokenized[i+1]) == "\n" and tokenizer.convert_tokens_to_string(text_tokenized[i+2]) == "\x0c":
        t = 1
        tokens = []
        while len(tokenizer.convert_tokens_to_string(text_tokenized[i-t:i+1]).split(" ")) < 3 and t+1 < len(text_tokenized)-1:
          t+=1
        while len(tokenizer.convert_tokens_to_string(text_tokenized[i-t:i+1]).split(" ")) < 4 and t+1 < len(text_tokenized)-1:
          t+=1
        ##
        text = tokenizer.convert_tokens_to_string(text_tokenized[i-t+1:i+1]).replace("\n", " ").replace("\x0c"," ").replace("  "," ").replace("   "," ").lstrip().strip()
        tokens.append(text)
        tokens.append(" ")
        preds_dict["boundaries"][entry] = tokens
        entry+=1

  
  return preds_dict, predicted_tokens



# NOTE(review): a second function named extract_indexes_boundaries is defined further
# down in this cell. Python keeps the later definition, so this version is never the
# one that gets called - consider deleting one of the two.
def extract_indexes_boundaries(preds_dict, predicted_tokens):

  ### Extract the index from the predicted boundaries in preds_dict by looking it up in the predicted_tokens list
  ### Works by string matching. If predicted boundary tokens exist in the predicted_tokens, then extract desired index

  indexes = []
  for text in range(len(preds_dict["boundaries"])):
    for i in range(len(predicted_tokens)-3):
      # Boundary with words after it: match two words before plus the first word after.
      if preds_dict["boundaries"][text][1] != " " and preds_dict["boundaries"][text][1] !="":
        if len(preds_dict["boundaries"][text][0].split()) >= 2:
          if predicted_tokens[i].text == preds_dict["boundaries"][text][0].split()[0].replace(" ","") and predicted_tokens[i+1].text == preds_dict["boundaries"][text][0].split()[1].replace(" ","") and predicted_tokens[i+2].text == preds_dict["boundaries"][text][1].split()[0].replace(" ",""):
            indexes.append(i+1)
          # Fall back to the two preceding words alone when that pair occurs exactly once on the page.
          elif get_count(preds_dict["boundaries"][text][0].split()[0].replace(" ",""),preds_dict["boundaries"][text][0].split()[1].replace(" ",""), predicted_tokens) == 1:
            if predicted_tokens[i].text == preds_dict["boundaries"][text][0].split()[0].replace(" ","") and predicted_tokens[i+1].text == preds_dict["boundaries"][text][0].split()[1].replace(" ",""):
              indexes.append(i+1)
      # Boundary without following words: match the two preceding words only.
      else:
        if len(preds_dict["boundaries"][text][0].split()) == 2:
          if predicted_tokens[i].text == preds_dict["boundaries"][text][0].split()[0].replace(" ","") and predicted_tokens[i+1].text == preds_dict["boundaries"][text][0].split()[1].replace(" ",""):
            indexes.append(i+1)
          elif predicted_tokens[i+2].text == preds_dict["boundaries"][text][0].split()[-2].replace(" ","") and predicted_tokens[i+3].text == preds_dict["boundaries"][text][0].split()[-1].replace(" ",""):
            indexes.append(i+3)
  
  # Duplicates are removed; the order of the returned list is not sorted.
  return list(set(indexes))



def extract_coordinates(predicted_tokens,x):
  """Return the bounding box of the token at index x as (x0, y0, x1, y1).

  x0,y0 is the top left corner, x1,y1 the bottom right corner.
  """
  box = predicted_tokens[x].block
  return box.x_1, box.y_1, box.x_2, box.y_2

def get_count(word1,word2,predicted_tokens):
  """Count how often word1 is directly followed by word2 in predicted_tokens."""
  return sum(
    1
    for i in range(len(predicted_tokens)-1)
    if predicted_tokens[i].text == word1 and predicted_tokens[i+1].text == word2
  )


def extract_indexes_boundaries(preds_dict, predicted_tokens):
  """Look up the predicted sentence boundaries in the list of page tokens.

  Works by string matching: the words stored in front of each boundary in
  preds_dict["boundaries"] are searched in predicted_tokens.

  Args:
    preds_dict: dictionary whose "boundaries" entry maps 0..n-1 to
      [words before the boundary, words after the boundary].
    predicted_tokens: sequence of tokens exposing a `.text` attribute.

  Returns:
    List of unique token indexes of the last word in front of each matched boundary
    (not sorted).
  """
  token_texts = [token.text for token in predicted_tokens]
  # Positions are only tried while three more tokens follow, as before.
  n_positions = len(token_texts) - 3

  indexes = []
  for entry in range(len(preds_dict["boundaries"])):
    if n_positions <= 0:
      break
    before, after = preds_dict["boundaries"][entry][0], preds_dict["boundaries"][entry][1]
    # split() already removes all whitespace, so the words need no further cleaning.
    words = before.split()
    has_following_words = after != " " and after != ""

    for i in range(n_positions):
      if has_following_words:
        if len(words) == 3 and token_texts[i:i+3] == words:
          indexes.append(i+2)
        elif len(words) == 2 and token_texts[i:i+2] == words:
          indexes.append(i+1)
        # The former `elif get_count(...) == 1` fallback repeated the exact two-word
        # test that had just failed, so it could never add an index; it only cost a
        # full scan of the page per token and was removed.
      elif len(words) == 2:
        if token_texts[i:i+2] == words:
          indexes.append(i+1)
        elif token_texts[i+2:i+4] == words:
          # Look two tokens ahead so a pair at the very end of the page is found too.
          indexes.append(i+3)

  return list(set(indexes))


def get_end_index(predicted_tokens, words):
  """Return the index of the token at which the trailing words of `words` end.

  The last three words are compared (the last four when `words` holds more than
  three) against consecutive tokens; the index of the first token sequence that
  matches is returned.

  Args:
    predicted_tokens: sequence of tokens exposing a `.text` attribute.
    words: list of words; fewer than three words never match.

  Returns:
    The index of the token matching words[-1], or None when nothing matches.
  """
  if len(words) < 3:
    return None

  n_compare = min(len(words), 4)
  # Start at n_compare-1 so that i-k never becomes negative: negative indexes wrap
  # around to the end of the list and used to produce false matches.
  for i in range(n_compare - 1, len(predicted_tokens)):
    if all(predicted_tokens[i-k].text == words[-1-k] for k in range(n_compare)):
      return i
  return None


def get_start_index(predicted_tokens, words):
  """Return the index of the token at which the leading words of `words` start.

  The first three words are compared (the first four when `words` holds more than
  three) against consecutive tokens; the index of the first match is returned.

  Args:
    predicted_tokens: sequence of tokens exposing a `.text` attribute.
    words: list of words; fewer than three words never match.

  Returns:
    The index of the token matching words[0], or None when nothing matches.
  """
  if len(words) < 3:
    return None

  n_compare = min(len(words), 4)
  # Every start position that leaves room for n_compare tokens is tried; the former
  # fixed bound of len-3 skipped a three-word match on the last three tokens.
  for i in range(len(predicted_tokens) - n_compare + 1):
    if all(predicted_tokens[i+k].text == words[k] for k in range(n_compare)):
      return i
  return None

def extract_indexes_end(sentence, predicted_tokens):
  """Find every token index (from 4 upwards) at which `sentence` ends.

  `sentence` is split on single spaces and is only searched when that yields exactly
  three or four words; the words must match consecutive tokens ending at the index.

  Returns:
    List of unique matching indexes (not sorted).
  """
  words = sentence.split(" ")
  n_words = len(words)

  indexes = []
  if n_words in (3, 4):
    # Scan from the end of the page towards the front.
    for i in reversed(range(4, len(predicted_tokens))):
      if all(predicted_tokens[i-k].text == words[-1-k] for k in range(n_words)):
        indexes.append(i)

  return list(set(indexes))


def extract_indexes_start(sentence,predicted_tokens):
  """Find every token index at which the first four words of `sentence` start.

  `sentence` is split on single spaces; its first four words must match four
  consecutive tokens.

  Returns:
    List of unique matching indexes (not sorted).
  """
  words = sentence.split(" ")

  indexes = []
  for i in range(len(predicted_tokens)-3):
    # all() stops at the first mismatch, so later words are only read when needed.
    if all(predicted_tokens[i+k].text == words[k] for k in range(4)):
      indexes.append(i)
  return list(set(indexes))

def extract_start_end_indexes(preds_dict, predicted_tokens):
  """Map the start token index of each text block to its candidate end indexes.

  Args:
    preds_dict: dictionary whose "start_end" entry holds one
      [first words, last words] pair per text block.
    predicted_tokens: sequence of tokens exposing a `.text` attribute.

  Returns:
    Dictionary {start index: list of end indexes}. Blocks whose first words are not
    exactly four space-separated words, or whose start or end cannot be found, are
    left out.
  """
  dic = {}
  for pair in preds_dict["start_end"]:
    if len(pair[0].split(" ")) != 4:
      continue

    start = extract_indexes_start(pair[0], predicted_tokens)
    end = extract_indexes_end(pair[1], predicted_tokens)
    if start == [] or end == []:
      continue

    i = 0
    if len(start) > 1:
      # Several blocks open with the same words: take the first candidate that has
      # not been assigned to an earlier block.
      while i != len(start) and start[i] in dic:
        i += 1
      if i == len(start):
        # Every candidate is taken already; indexing start[i] used to raise IndexError.
        continue

    if len(end) > 1:
      # Drop the ends that lie in front of the chosen start. A new list is built
      # because removing items from `end` while iterating over it skipped elements.
      end = [j for j in end if j >= start[i]]
    dic[start[i]] = end

  return dic


def get_all_sentences_indexes(preds_dict, predicted_tokens):
  """Return the [first token index, last token index] pair of every sentence.

  Each text block found by extract_start_end_indexes is cut at the boundaries from
  extract_indexes_boundaries that lie strictly inside it; a sentence following a
  boundary starts one token after that boundary.

  Returns:
    List of [first, last] index pairs, sorted by first index.
  """
  start_end = extract_start_end_indexes(preds_dict, predicted_tokens)
  boundaries = []

  if start_end != None and len(start_end) >= 1:
    # One (start, first candidate end) pair per block that has both.
    spans = [(key, value[0]) for key, value in start_end.items() if key != [] and value != []]
    starts = [first for first, _ in spans]
    ends = [last for _, last in spans]

    # Boundaries that coincide with a block start or end are not inner boundaries.
    inbetween = [
      tok
      for tok in sorted(extract_indexes_boundaries(preds_dict, predicted_tokens))
      if tok not in starts and tok not in ends
    ]

    for first, last in spans:
      cuts = sorted([first, last] + [tok for tok in inbetween if first < tok < last])
      lower = cuts[0]
      for upper in cuts[1:]:
        boundaries.append([lower, upper])
        lower = upper + 1

  boundaries.sort(key=lambda x: x[0])

  return boundaries


# Function for NER

In [30]:
def NER(all_indexes,predicted_tokens):
  """Run spaCy named entity recognition on every sentence of the page.

  Args:
    all_indexes: list of [first token index, last token index] pairs, one per sentence.
    predicted_tokens: sequence of tokens exposing a `.text` attribute.

  Returns:
    Dictionary {token index: entity label} for the words of the recognised entities.
  """
  ner = spacy.load("en_core_web_sm")
  dic = {}

  for index in all_indexes:
    first = index[0]
    words = [predicted_tokens[ind].text for ind in range(first, index[1]+1)]
    ner_results = ner(" ".join(words).rstrip())
    if len(ner_results) < 1:
      continue

    for entity in ner_results.ents:
      label = entity.label_
      for piece in entity.text.split(" "):
        # The page token may still carry sentence punctuation, so the bare word is
        # tried first and then the word followed by '.', '?' or '!'.
        for candidate in (piece, piece + ".", piece + "?", piece + "!"):
          if candidate in words:
            # words.index returns the first occurrence within the sentence.
            dic[words.index(candidate) + first] = label
            break

  return dic



# Add highlights

In [None]:
# Highlight colour (RGB, components between 0 and 1) per entity label.
# spaCy's English pipelines label persons "PERSON"; "PER" is kept for compatibility.
_NER_HIGHLIGHT_COLORS = {
  "GPE": [0.2, 1, 1],           # place - light blue
  "DATE": [0.2, 0.7, 0.4],      # green
  "PER": [0.3, 0, 0.7],         # purple
  "PERSON": [0.3, 0, 0.7],      # purple
  "ORG": [0.7, 0.7, 0.8],       # light grey
  "ORDINAL": [0.8, 0.5, 0.0],   # light orange
}
# Colour for every other entity label - light pink.
_NER_DEFAULT_COLOR = [1.0, 0.9, 1.0]


def _token_quad(predicted_tokens, index):
  """Return the fitz.Quad that covers the bounding box of the token at `index`."""
  x0,y0,x1,y1 = extract_coordinates(predicted_tokens, index)
  # Corner order: upper left, upper right, lower left, lower right.
  return fitz.Quad((x0,y0), (x1,y0), (x0,y1), (x1,y1))


def add_higlights(file_path, model_sbd, output_path="highlighted.pdf"):
  """Highlight the predicted sentence boundaries and named entities of a PDF.

  Loops over the file page by page, extracts the coordinates of the predicted
  sentence boundaries and of the recognised entities and highlights them.

  Args:
    file_path: path of the PDF file to annotate.
    model_sbd: sentence boundary detection model.
    output_path: where the annotated copy is saved; defaults to "highlighted.pdf".
  """
  ### Model for block-level DLA is pre-trained Mask-RCNN on the PubLayNet dataset
  model_layout = lp.Detectron2LayoutModel('lp://PubLayNet/mask_rcnn_X_101_32x8d_FPN_3x/config',
                                  extra_config=["MODEL.ROI_HEADS.SCORE_THRESH_TEST", 0.5],
                                  label_map={0: "Text", 1: "Title", 2: "List"})

  ### Model for word-level DLA is pre-trained LayoutLM on the DocBank dataset
  pdf_predictor = HierarchicalPDFPredictor.from_pretrained("allenai/hvila-block-layoutlm-finetuned-docbank")

  pdf_file = fitz.open(file_path)
  try:
    for p in range(len(pdf_file)):
      print(p)
      page = pdf_file[p]
      # file_path is passed on here; the previously used name `file_name` is not
      # defined anywhere in the notebook.
      preds_dict, predicted_tokens = extract_boundaries(tokenizer, file_path, p, pdf_predictor, model_layout,model_sbd)
      indices = extract_indexes_boundaries(preds_dict, predicted_tokens)
      all_indexes = get_all_sentences_indexes(preds_dict, predicted_tokens)
      ner = NER(all_indexes, predicted_tokens)

      # Last token of every sentence.
      for sentence in all_indexes:
        page.add_highlight_annot(_token_quad(predicted_tokens, sentence[1]))

      # Every token matched to a predicted boundary.
      for i in indices:
        page.add_highlight_annot(_token_quad(predicted_tokens, i))

      # Named entities, coloured by label.
      for key,value in ner.items():
        highlight = page.add_highlight_annot(_token_quad(predicted_tokens, key))
        highlight.set_colors(stroke=_NER_HIGHLIGHT_COLORS.get(value, _NER_DEFAULT_COLOR))
        highlight.update()

    ### File is saved containing the exact same layout as before, but with added highlights
    pdf_file.save(output_path)
  finally:
    # Release the document even when a page fails.
    pdf_file.close()
    
add_higlights(file_path,model_sbd)
