# Notebook for sentence-level model
(Section 5.2 of the report)

Contains the following components specific to the sentence-level model:
- Code for dataset creation
- Post-processing algorithm
- Evaluation code

![](../Misc/sentence_level_model.jpg)

### Import libraries

In [1]:
import os
import sys
from tqdm.notebook import tqdm
from multiprocessing import cpu_count, Pool, Manager
import numpy as np
import pandas as pd
import pickle
import json
import gc
from sklearn.model_selection import train_test_split
from datasets import Dataset, DatasetDict, concatenate_datasets, load_from_disk
import importlib

# Preprocessing
sys.path.append('../DataPreprocessing')
from read_into_dicts import DocReader

# Model
from sentence_transformers import SentenceTransformer
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader
from torch.optim import Adam
from transformers import Trainer, TrainingArguments, AutoTokenizer, EarlyStoppingCallback

# WANDB
import wandb
import random
import string
import yaml

2024-05-29 21:52:36.124859: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-05-29 21:52:38.593565: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


### Define Parameters

In [2]:
# Load configuration
with open('config.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)

# GENERAL:
DEBUG = config["general"]["DEBUG"]
MODEL = config["general"]["MODEL"]
DATA_SIZE = config["general"]["DATA_SIZE"]
DOMAIN = config["general"]['DATA_DICT_FILE_PATH']["DOMAIN"]
DATA_DICT_FILE_PATH = config["general"]['DATA_DICT_FILE_PATH'][DOMAIN]
DATA_PATH = config["general"]["DATA_PATH"]
SPLITS_JSON = config["general"]["SPLITS_JSON"]
TRAIN_SIZE = config["general"]["TRAIN_SIZE"]
VAL_SIZE = config["general"]["VAL_SIZE"]
USE_MULTIPROCESSING = config["general"]["USE_MULTIPROCESSING"]
SPLIT_BASE = config["general"]["SPLIT_BASE"]
SENTENCE_SPLITTER_MODEL = config["general"]["SENTENCE_SPLITTER_MODEL"]
IS_HPO_RUN = config["hpo"]["IS_HPO_RUN"]

# SENTENCE LEVEL MODEL
DATASET = config["sentence_level_model"]["DATASET"]
PRETRAINED_MODEL = config["sentence_level_model"]["PRETRAINED_MODEL"]
CHECKPOINT = config["sentence_level_model"]["CHECKPOINT"]
EPOCHS = config["sentence_level_model"]["EPOCHS"]
EARLY_STOPPING_PATIENCE = config["sentence_level_model"]["EARLY_STOPPING_PATIENCE"]
BATCH_SIZE = config["sentence_level_model"]["BATCH_SIZE"]
USE_POS_ENCODING = config["sentence_level_model"]["USE_POS_ENCODING"]
ALPHA = config["sentence_level_model"]["ALPHA"]
HIER_LABELS_LEVELS = config["sentence_level_model"]["HIER_LABELS_LEVELS"]  # also acts as flag if detailed models shall be used (i.e. model with auxiliary objective trained)
DETAILED_LABEL_WEIGHTS = config["sentence_level_model"]["DETAILED_LABEL_WEIGHTS"]
# An empty/None HIER_LABELS_LEVELS means "no detailed labels" -> zero levels (len(None) would raise)
NUMBER_OF_LEVELS = len(HIER_LABELS_LEVELS) if HIER_LABELS_LEVELS else 0

tokenizer = AutoTokenizer.from_pretrained(MODEL)
# Use CUDA only when a prebuilt DATASET is loaded AND a GPU is actually available.
# Building the dataset (DATASET unset) uses multiprocessing and is meant to run on CPU.
SENTENCE_MODEL_DEVICE = "cuda" if DATASET and torch.cuda.is_available() else "cpu"
sentence_model = SentenceTransformer(config["sentence_level_model"]["SENTENCE_TRANSFORMER_MODEL"], device=SENTENCE_MODEL_DEVICE)
EMBEDDING_DIMENSIONS = sentence_model.get_sentence_embedding_dimension()
SENTENCE_MODEL_IS_BINARY = config["sentence_level_model"]["SENTENCE_MODEL_IS_BINARY"]
SENTENCE_MODEL_ARCHITECTURE = config["sentence_level_model"]["SENTENCE_MODEL_ARCHITECTURE"]

# BIO label encoding (the older 22-03 dataset used {"B": 0, "I": 1, "O": 2} instead)
label_to_index = {"O": 0, "B": 1, "I": 2}
index_to_label = {v: k for k, v in label_to_index.items()} # reverse of label_to_index

print(f"{DATA_DICT_FILE_PATH=}")
print(f"{USE_MULTIPROCESSING=}")
print(f"{SPLIT_BASE=}")
print(f"{SENTENCE_SPLITTER_MODEL=}")
print(f"{DATASET=}")
print(f"{PRETRAINED_MODEL=}")

DATA_DICT_FILE_PATH='preprocessing/data_dicts/28245V231219_AML/data_dict_roberta-base-10-04.pkl'
USE_MULTIPROCESSING=True
SPLIT_BASE='blocks'
SENTENCE_SPLITTER_MODEL='transformer'
DATASET='Model/datasets/sentence-level/dataset_dict_1149-docs_roberta-base-model_transformer-07-05.hf'
USE_MULTIPROCESSING=True
PRETRAINED_MODEL='oiHQRN'


In [3]:
# Display the loaded SentenceTransformer (its module stack is shown in the cell output)
sentence_model

SentenceTransformer(
  (0): Transformer({'max_seq_length': 128, 'do_lower_case': False}) with Transformer model: BertModel 
  (1): Pooling({'word_embedding_dimension': 384, 'pooling_mode_cls_token': False, 'pooling_mode_mean_tokens': True, 'pooling_mode_max_tokens': False, 'pooling_mode_mean_sqrt_len_tokens': False})
  (2): Normalize()
)

Weights & Biases (wandb) setup

In [3]:
# Short random alphanumeric tag (6 characters) that makes this run easy to find in wandb
unique_tag = ''.join(random.choices(string.ascii_letters + string.digits, k=6))

# Keep wandb's run files and its cache under the project's Model directory
for wandb_env_key in ('WANDB_DIR', 'WANDB_CACHE_DIR'):
    os.environ[wandb_env_key] = f"{DATA_PATH}/Model"

# Hyperparameters recorded with the run; the test share is whatever train/val leave over
config_dict = dict(
    model=MODEL,
    data_size=DATA_SIZE,
    train_size=TRAIN_SIZE,
    val_size=VAL_SIZE,
    test_size=1 - TRAIN_SIZE - VAL_SIZE,
    epochs=EPOCHS,
    batch_size=BATCH_SIZE,
    hierarchical_labels_levels=HIER_LABELS_LEVELS,
)

# Run identity: project/group from config.yaml, name encodes architecture, label scheme and model
wandb_init = dict(
    project=config["wandb"]["WANDB_PROJECT"],
    tags=[unique_tag, f"MODEL={MODEL}", f"DATA_SIZE={DATA_SIZE}"],
    group=config["wandb"]["WANDB_GROUP"],
    name=f'{config["wandb"]["WANDB_GROUP"]}-{SENTENCE_MODEL_ARCHITECTURE}-SENTENCE-{"BINARY" if SENTENCE_MODEL_IS_BINARY else "BIO"}-{HIER_LABELS_LEVELS}-{MODEL}',
)

wandb.login()
wandb.init(**wandb_init)
wandb.config.update(config_dict)


[34m[1mwandb[0m: Currently logged in as: [33mtlh45[0m. Use [1m`wandb login --relogin`[0m to force relogin


### Load Full Data

In [3]:
docReader = DocReader(MODEL, tokenizer)
create_original_snippets = True
add_full_page = False

if not DATA_DICT_FILE_PATH:
    # No cached data_dict configured: preprocess the raw document folder from scratch
    data_dict = docReader.preprocess_folder(
        preprocess=False,
        folder_path=f'{DATA_PATH}/28245V231219',
        num_workers=min(70, cpu_count()),
        data_size=DATA_SIZE,
        chunksize=1,
        extract_title=True,
        extract_doc_long_id=True,
        refine_regions=True,
        create_original_snippets=create_original_snippets,
        add_full_page=add_full_page,
        data_dict_folder=f"{DATA_PATH}/preprocessing/data_dicts",
        file_name_additional_suffix="-22-02",
    )
else:
    # Reuse a previously pickled data_dict.
    # NOTE: pickle.load can execute code embedded in the file -- only load trusted data dicts.
    print(f"Load data_dict from {DATA_PATH}/{DATA_DICT_FILE_PATH}.")
    with open(f'{DATA_PATH}/{DATA_DICT_FILE_PATH}', 'rb') as handle:
        data_dict = pickle.load(handle)
    print("Done.")

Load data_dict from /home/tlh45/rds/hpc-work/preprocessing/data_dicts/28245V231219_AML/data_dict_roberta-base-10-04.pkl.
Done.


Load Cyber data

In [None]:
# Load the two Cyber data dicts (CYBER_I and CYBER_II) from the paths configured in config.yaml.
# NOTE: pickle.load can execute code embedded in the file -- only load trusted data dicts.
cyber_data_dict_paths = config["general"]["DATA_DICT_FILE_PATH"]

with open(f'{DATA_PATH}/{cyber_data_dict_paths["CYBER_I"]}', 'rb') as cyber_i_file:
    data_dict_cyber_i = pickle.load(cyber_i_file)

with open(f'{DATA_PATH}/{cyber_data_dict_paths["CYBER_II"]}', 'rb') as cyber_ii_file:
    data_dict_cyber_ii = pickle.load(cyber_ii_file)

In [None]:
import custom_sentence_tokenizer
importlib.reload(custom_sentence_tokenizer)

# Example page used to inspect the sentence splitter's output
doc_id = '22206856'
page_id = '12'

print(f"Sentence Splitter Model chosen: {SENTENCE_SPLITTER_MODEL}")
# NOTE: `sentenizer` is a notebook-level global that the dataset-creation code
# (create_training_samples_per_document) relies on, so this cell must run before it.
# CUDA only when a prebuilt DATASET is loaded and a GPU is available; dataset creation runs on CPU.
sentenizer_device = "cuda" if DATASET and torch.cuda.is_available() else "cpu"
sentenizer = custom_sentence_tokenizer.Sentenizer(SENTENCE_SPLITTER_MODEL, f"{DATA_PATH}/Model/punkt_tokenizer.pkl", sentenizer_device)

# Below shows the two different sentence splitting options:
# - "blocks": split each extracted text block of the page separately
# - anything else: split the page's full text in one go
if SPLIT_BASE == "blocks":
    sentences = []
    for block in data_dict[doc_id][page_id]['blocks']:
        sentences.extend(sentenizer.tokenize_into_sentences(block['text']))
else:
    sentences = sentenizer.tokenize_into_sentences(data_dict[doc_id][page_id]['full_text'])

for sentence in sentences:
    print("------------- SENTENCE -------------")
    print(sentence)
    print("\n")

### Data Preprocessing

In [5]:
import detailed_labels_handler

tag_statistics, mapping_dicts = detailed_labels_handler.extract_and_analyze_tags(data_dict, HIER_LABELS_LEVELS)

# Number of distinct tags per entry of tag_statistics, and their sum over all entries
unique_level_tags = {}
for level_name, level_counter in tag_statistics.items():
    unique_level_tags[level_name] = len(level_counter)
total_unique_tags = sum(unique_level_tags.values())

if not HIER_LABELS_LEVELS:
    print("Note: no HIER_LABELS_LEVELS provided.")

# Tag counts and number of distinct tags for every level
for level_name, level_counter in tag_statistics.items():
    level_title = level_name.capitalize()
    print(f"{level_title} Tags:", level_counter)
    print(f"Unique {level_title} Tags:", unique_level_tags[level_name])

if 'full_tags' in tag_statistics:
    print("Full Tags Count:", tag_statistics['full_tags'])
print("Total Unique Tags:", total_unique_tags)

# Mapping dictionaries returned by the handler (one per level)
for level_name, level_mapping in mapping_dicts.items():
    print(f"{level_name.capitalize()} Mapping Dictionary:", level_mapping)

# Sanity check: decode two example encoded label vectors
encoded_labels = [[0, 0, 0, 0], [0, 0, 1, 1]]
decoded_labels = detailed_labels_handler.decode_detailed_labels(encoded_labels, mapping_dicts, HIER_LABELS_LEVELS)
print("Decoded Labels:", decoded_labels)

Level_0 Tags: Counter({'aml': 43085})
Unique Level_0 Tags: 1
Level_1 Tags: Counter({'cdd': 7645, 'customeridentification': 5187, 'definitions': 4547, 'str': 3991, 'pcp': 3426, 'riskassessment': 2341, 'penalties': 2140, 'assetfreeze': 2122, 'tfs': 2092, 'other': 2057, 'competentauthority': 2010, 'recordkeeping': 1516, 'mlro': 977, 'wiretransfer': 869, 'activitybasedsanction': 774, 'intlcooperation': 755, 'bo': 505, 'sanctionrights': 98, 'internationalcooperation': 33})
Unique Level_1 Tags: 19
Level_2 Tags: Counter({'other': 6436, 'verification': 3261, 'firms': 2650, 'measures': 1503, 'enhanced': 1327, 'information': 1107, 'amlprogram': 1103, 'monitoring': 1064, 'firm': 748, 'regulation': 696, 'counterparty': 686, 'fiu': 638, 'prohibition': 583, 'training': 565, 'cdd': 560, 'simplified': 557, 'reporting': 507, 'lawenforcement': 504, 'grouplevel': 490, 'identificationdocs': 483, 'correspondentbank': 444, 'highlevel': 431, 'moneylaundering': 400, 'kyc': 398, 'sanctionspermit': 394, 'obliga

In [6]:
# Thin wrapper around the first-level encoder of the hierarchical sentence-level model:
# the SentenceTransformer maps each sentence to one embedding vector.
# The embeddings are stored in the dataset so they can be re-used instead of being
# recalculated on every run.
def calculate_sentence_embeddings(sentences):
    '''Encode `sentences` with the global `sentence_model` (no progress bar) and return the embeddings.'''
    embeddings = sentence_model.encode(sentences, show_progress_bar=False)
    return embeddings

Code to create the dataset for the sentence-level model

In [6]:
# Note: the dataset creation should best happen on a cpu due to multiprocessing
from functools import partial
import unicodedata

def save_documents_to_files(data_dict, temp_folder):
    '''
    Write each document of `data_dict` to its own pickle file `<temp_folder>/<doc_id>.pkl`,
    creating `temp_folder` if it does not exist yet. Worker processes later read one
    document at a time from these files instead of needing the whole data_dict.
    '''
    if not os.path.exists(temp_folder):
        os.makedirs(temp_folder, exist_ok=True)

    progress = tqdm(data_dict.items(), desc="Saving documents")
    for document_id, document in progress:
        target_path = os.path.join(temp_folder, f"{document_id}.pkl")
        with open(target_path, 'wb') as out_file:
            pickle.dump(document, out_file)

def label_sentences_with_token_overlap(sentences, regions, threshold=0.9, short_sentence_length=20):
    '''
    Fallback labeller: assign a BIO label (and, with HIER_LABELS_LEVELS, the region's tags) to each
    sentence based on whitespace-token overlap with the annotated regions. Less accurate than
    index-based labelling, hence only used as a fallback.

    Args:
        sentences: list of sentence strings of one page, in page order.
        regions: list of region dicts; each needs 'text' and may carry 'tags'
            (a list of tag strings, a single string, or missing).
        threshold: minimum fraction of a sentence's tokens that must be found, in order, within a
            region's tokens for the sentence to count as inside that region.
        short_sentence_length: currently unused; kept for backward compatibility.

    Returns:
        (numerical_bio_labels, detailed_labels) if HIER_LABELS_LEVELS is set, where detailed labels
        are ';'-joined tag strings or 'N/A'; otherwise just numerical_bio_labels.
        A sentence is 'B' when it starts a region (the previous sentence was outside it or in a
        different region) and 'I' when it continues the region of the previous sentence.
    '''
    region_info = [(region['text'].split(), region.get('tags', None)) for region in regions]

    def sentence_in_region(sentence_tokens, region_tokens):
        '''Return True if at least `threshold` of `sentence_tokens` appear, in order, in `region_tokens`.'''
        if not sentence_tokens:
            # Empty / whitespace-only sentences cannot be matched (avoids division by zero)
            return False
        matched_token_count = 0
        region_index = 0
        for token in sentence_tokens:
            try:
                # Greedy in-order match: search only after the previously matched token
                region_index = region_tokens.index(token, region_index) + 1
            except ValueError:
                continue
            matched_token_count += 1
        return matched_token_count / len(sentence_tokens) >= threshold

    def format_tags(region_tags):
        '''Turn region tags into one detailed-label string (same convention as label_sentences_by_indices).'''
        if region_tags is None:
            return 'N/A'
        if isinstance(region_tags, (list, tuple)):
            return ';'.join(region_tags)  # Join tags if there are multiple
        return region_tags  # already a single tag string

    bio_labels = []
    detailed_labels = []
    previous_region = None  # index of the region the previous sentence matched, None if it was 'O'

    # For each sentence, go through the regions and stop at the first one that contains it
    for sentence in sentences:
        tokenized_sentence = sentence.split()  # simply tokenize by whitespace
        bio_label = 'O'
        detailed_label = 'N/A'
        matched_region = None

        for region_idx, (region_tokens, region_tags) in enumerate(region_info):
            if sentence_in_region(tokenized_sentence, region_tokens):
                matched_region = region_idx
                # Continue a region only if the previous sentence was in the same region
                bio_label = 'I' if previous_region == region_idx else 'B'
                detailed_label = format_tags(region_tags)
                break

        previous_region = matched_region
        bio_labels.append(bio_label)
        if HIER_LABELS_LEVELS:
            detailed_labels.append(detailed_label)

    numerical_bio_labels = [label_to_index[label] for label in bio_labels]

    if HIER_LABELS_LEVELS:
        return numerical_bio_labels, detailed_labels
    return numerical_bio_labels

def is_unwanted_character(c):
    '''
    Return True if character `c` should be ignored when normalising text for matching:
    it belongs to an ignored Unicode category (control, format, private-use, surrogate,
    unassigned, line/paragraph separator, math symbol, "other" punctuation) or is one of a
    few explicitly listed characters (soft hyphen, BOM, brackets, slashes, replacement char).
    '''
    return (
        unicodedata.category(c) in ('Cc', 'Cf', 'Co', 'Cs', 'Cn', 'Zl', 'Zp', 'Sm', 'Po')
        or c in ('\xad', '\ufeff', '[', ']', '\\', '/', '�')
    )

def label_sentences_by_indices(sentences, regions, tokenized_page_text, doc_id, page_index):
    '''
    Primary labeller: align each sentence to a span of token indices in the tokenized page text,
    then label it from the regions' token-index spans. More accurate than token-overlap matching,
    hence used as the primary method. For more details, refer to Section 5.2.2 of the report.

    Args:
        sentences: list of sentence strings of the page, in page order.
        regions: list of region dicts with 'start_idx_in_page', 'end_idx_in_page' and optionally 'tags'.
        tokenized_page_text: list of model tokens of the full page (decoded with the global `tokenizer`).
        doc_id, page_index: only used in diagnostic messages.

    Returns:
        Same shape as label_sentences_with_token_overlap: (numerical_bio_labels, detailed_labels)
        if HIER_LABELS_LEVELS is set, otherwise just numerical_bio_labels. If the sentence/token
        alignment fails, falls back to label_sentences_with_token_overlap.
    '''
    try:
        bio_labels = []
        detailed_labels = []
        sentence_boundaries = []  # inclusive (start_token_idx, end_token_idx) per sentence

        # STEP 1: Map sentences to indices in page.
        # Tokens are accumulated until their decoded, whitespace- and symbol-stripped text
        # contains the equally cleaned current sentence.
        current_string = ""
        sentence_idx = 0
        current_token_list = []
        idx = 0

        while idx < len(tokenized_page_text):
            token = tokenized_page_text[idx]
            if sentence_idx >= len(sentences):
                break

            sentence = sentences[sentence_idx]
            current_token_list.append(token)
            current_string = tokenizer.convert_tokens_to_string(current_token_list)

            clean_current_string = "".join(current_string.split())
            clean_current_string = "".join([char for char in clean_current_string if not is_unwanted_character(char)])

            clean_sentence = "".join(sentence.split())
            # NOTE(review): the first two characters of the cleaned sentence are dropped before
            # matching -- presumably to tolerate a leading prefix; confirm intent.
            clean_sentence = "".join([char for char in clean_sentence if not is_unwanted_character(char)])[2:]

            # Cleaned sentence of at most one character: assign it the current token without a text match
            if len(clean_sentence) <= 1:
                if sentence_idx == len(sentences) - 1:
                    end_idx = len(tokenized_page_text) - 1  # last sentence runs to the end of the page
                else:
                    end_idx = idx
                sentence_boundaries.append((idx, end_idx))
                idx -= 1  # cancels the increment below: the next sentence starts at this same token
                current_token_list = []
                sentence_idx += 1

            # Otherwise the sentence is covered once its cleaned text occurs in the accumulated tokens
            # (this also covers the "starts with" and "ends with" cases)
            elif clean_sentence in clean_current_string:
                if sentence_idx == len(sentences) - 1:
                    end_idx = len(tokenized_page_text) - 1  # last sentence runs to the end of the page
                else:
                    end_idx = idx
                sentence_boundaries.append((idx - len(current_token_list) + 1, end_idx))
                current_token_list = []
                sentence_idx += 1
            idx += 1
        assert len(sentence_boundaries) == len(sentences), f"Number of sentences and boundaries do not match in {doc_id=}, {page_index=}: {len(sentences)} vs {len(sentence_boundaries)}"

    # If we have a mismatch in the number of sentences and boundaries, fall back to overlap matching
    except Exception as e:
        print(f"Error in {doc_id=}, {page_index=}: {e}. Falling back to overlap matching.")
        # The fallback already returns the shape matching HIER_LABELS_LEVELS (tuple or plain list),
        # exactly like the normal return below, so it must not be wrapped again.
        return label_sentences_with_token_overlap(sentences, regions)

    # STEP 2: label sentences based on region indices
    # This logic is described in more detail in Section 5.2.2 of the report
    for (start_idx, end_idx), sentence in zip(sentence_boundaries, sentences):
        label = 'O'  # Default label
        detailed_label = "N/A"
        for region in regions:
            region_start_idx = region['start_idx_in_page']
            region_end_idx = region['end_idx_in_page']
            # Sentence starts at/before the region start and extends past it -> beginning of region
            if start_idx <= region_start_idx and end_idx > region_start_idx:
                label = 'B'
                detailed_label = region.get('tags', "N/A")
                break

            # Sentence starts after the region start and ends within the region -> inside region
            elif start_idx > region_start_idx and end_idx <= region_end_idx:
                label = 'I'
                detailed_label = region.get('tags', "N/A")
                break

        bio_labels.append(label)
        if HIER_LABELS_LEVELS:
            if isinstance(detailed_label, list):
                detailed_label = ';'.join(detailed_label)
            detailed_labels.append(detailed_label)

    # Convert labels to numerical form using the global label_to_index mapping
    numerical_bio_labels = [label_to_index[label] for label in bio_labels]

    if HIER_LABELS_LEVELS:
        return numerical_bio_labels, detailed_labels
    else:
        return numerical_bio_labels

def worker_create_training_samples(doc_id, temp_folder, save_folder):
    '''
    Multiprocessing worker: load one document from `<temp_folder>/<doc_id>.pkl` and build its
    training samples via create_training_samples_per_document, returning that function's result.
    '''
    source_path = os.path.join(temp_folder, f"{doc_id}.pkl")
    with open(source_path, 'rb') as doc_file:
        pages = pickle.load(doc_file)
    gc.collect()
    document = (doc_id, pages)
    return create_training_samples_per_document(document, save_folder)

def create_training_samples_per_document(doc_data, save_folder=None):
    '''
    Build one training sample per page of a document and, unless DEBUG is set, pickle the list
    of samples to `<save_folder>/<doc_id>_page_samples.pkl`. Called by the worker function.

    Each sample holds the page's sentences, their sentence-transformer embeddings, numerical BIO
    labels, encoded detailed labels and {'doc_id', 'page_id'} metadata. Relies on notebook
    globals: `sentenizer` (created in the sentence-splitter cell), SPLIT_BASE, HIER_LABELS_LEVELS,
    mapping_dicts, NUMBER_OF_LEVELS and DEBUG.

    Args:
        doc_data: (doc_id, pages) where pages maps page ids to page dicts; the 'title' and
            'doc_long_id' entries are skipped.
        save_folder: output folder; required unless DEBUG is set.

    Returns:
        The list of page samples when DEBUG is set (nothing is written); otherwise None, so that
        multiprocessing workers do not send every document's samples back to the parent process.

    Raises:
        ValueError: if SPLIT_BASE is neither "full_text" nor "blocks".
    '''
    doc_id, pages = doc_data
    page_samples = []
    for page_index, page_content in pages.items():
        # Document-level metadata entries, not pages
        if page_index in ["title", "doc_long_id"]:
            continue

        # Tokenize into sentences (default: use blocks for higher accuracy)
        if SPLIT_BASE == "full_text":
            sentences = sentenizer.tokenize_into_sentences(page_content['full_text'])
        elif SPLIT_BASE == "blocks":
            sentences = []
            # Split each block's text separately and concatenate the sentences in block order
            for block in page_content['blocks']:
                sentences.extend(sentenizer.tokenize_into_sentences(block['text']))
        else:
            raise ValueError(f'Unknown SPLIT_BASE {SPLIT_BASE!r}; expected "full_text" or "blocks".')

        if not sentences:
            continue

        # Get sentence embeddings (i.e. run through first transformer network in the hierarchical model)
        sentence_embeddings = calculate_sentence_embeddings(sentences)

        tokenized_page_text = page_content['tokenized_full_text']
        refined_regions = page_content['refined_regions']

        # Get labels for each sentence based on region indices
        if HIER_LABELS_LEVELS:
            bio_labels, detailed_labels = label_sentences_by_indices(sentences, refined_regions, tokenized_page_text, doc_id, page_index)
        else:
            bio_labels = label_sentences_by_indices(sentences, refined_regions, tokenized_page_text, doc_id, page_index)
            detailed_labels = []

        # Create sample for each page
        page_sample = {
            'sentences': sentences,
            'embeddings': [embedding.tolist() for embedding in sentence_embeddings],
            'labels': bio_labels,
            'detailed_labels': detailed_labels_handler.encode_detailed_labels(detailed_labels, mapping_dicts, NUMBER_OF_LEVELS),
            'metadata': {'doc_id': doc_id, 'page_id': page_index},
        }
        page_samples.append(page_sample)

    if not DEBUG:
        temp_file = os.path.join(save_folder, f"{doc_id}_page_samples.pkl")
        with open(temp_file, 'wb') as f:
            pickle.dump(page_samples, f)
    gc.collect()
    # Only hand the samples back in debug mode (see docstring)
    return page_samples if DEBUG else None


def process_documents_per_document(doc_ids_list, num_docs_dict, temp_folder, save_folder, num_workers=None, chunk_size=1):
    '''
    Build training samples for every document in `doc_ids_list` by calling the worker function,
    in a process pool when USE_MULTIPROCESSING is set and sequentially otherwise.

    `num_docs_dict` is the total number of documents; it is only used for progress reporting
    (documents not in `doc_ids_list` count as already done). `num_workers` defaults to
    min(32, cpu_count()) and, like `chunk_size`, only applies to the multiprocessing path.
    Returns `save_folder`.
    '''
    if not USE_MULTIPROCESSING:
        for doc_id in tqdm(doc_ids_list):
            worker_create_training_samples(doc_id, temp_folder=temp_folder, save_folder=save_folder)
        return save_folder

    pool_size = min(32, cpu_count()) if num_workers is None else num_workers
    already_done = num_docs_dict - len(doc_ids_list)
    print(f'{num_docs_dict-len(doc_ids_list)} valid documents found. {len(doc_ids_list)} out of {num_docs_dict} documents require processing.')
    process_one = partial(worker_create_training_samples, temp_folder=temp_folder, save_folder=save_folder)

    if doc_ids_list:
        with Pool(pool_size) as pool:
            results = pool.imap_unordered(process_one, doc_ids_list, chunksize=chunk_size)
            # Drain the iterator so every document is processed; the progress bar starts at the done count
            list(tqdm(results, initial=already_done, total=num_docs_dict, desc="Process documents"))
    return save_folder

def is_valid_data_file(file_path):
    '''
    Return True if `file_path` can be unpickled and holds a non-empty container.
    Any error while opening, unpickling or measuring the data is printed and reported as False.
    '''
    try:
        with open(file_path, 'rb') as handle:
            payload = pickle.load(handle)
        has_content = len(payload) > 0
    except Exception as e:
        print(f"Error loading {file_path}: {e}")
        return False
    return has_content

def collect_results(temp_folder, valid_doc_ids):
    '''
    Collect the per-document sample files in `temp_folder` into a dictionary
    (one file at a time, for memory reasons).

    The doc id of a file is the part of its name before the first '.', then before the first '_'
    (e.g. "<doc_id>_page_samples.pkl"). Files whose doc id is not in `valid_doc_ids`, and files
    that are empty or unreadable according to is_valid_data_file, are skipped.

    Returns:
        dict mapping doc_id -> list of documents, where each document is the list of page samples
        loaded from one file (normally exactly one file, hence one document, per doc id).
    '''
    all_training_samples = {}
    for filename in tqdm(os.listdir(temp_folder)):
        doc_id = filename.split('.')[0].split('_')[0]
        if doc_id not in valid_doc_ids:
            continue
        file_path = os.path.join(temp_folder, filename)
        if not is_valid_data_file(file_path):
            continue
        with open(file_path, 'rb') as f:
            document_sample = pickle.load(f)
        # Always store whole documents: downstream (create_df in create_dataset_from_samples)
        # iterates documents and then the pages inside each document.
        all_training_samples.setdefault(doc_id, []).append(document_sample)

    print(f"Processed {len(all_training_samples)} items.")
    return all_training_samples

# for debugging
def create_samples_singleprocessing(data_dict):
    '''
    Debug helper: build samples page by page in the current process (no multiprocessing).
    When DEBUG is set, only page "2" of document "19865597" is processed.

    Returns:
        Flat list of the page samples returned by create_training_samples_per_document;
        calls that return None (or nothing) contribute no samples.
    '''
    all_training_samples = []
    for doc_id, pages in tqdm(data_dict.items(), desc="Processing"):
        if DEBUG and doc_id != "19865597":
            continue
        for page_index, page_content in pages.items():
            if DEBUG and page_index != "2":
                continue

            samples = create_training_samples_per_document((doc_id, {page_index: page_content}))
            if samples:
                all_training_samples.extend(samples)
    return all_training_samples

def filter_samples_based_on_splits(all_samples, splits):
    '''
    Distribute per-document samples into (train, validation, test) lists using the document-id
    lists in `splits` ('train_ids', 'val_ids', 'test_ids'). A document listed in several splits
    goes to the first of train, validation, test; unlisted documents are dropped.
    '''
    train_ids = set(splits['train_ids'])
    val_ids = set(splits['val_ids'])
    test_ids = set(splits['test_ids'])

    partitions = ([], [], [])
    split_ids = (train_ids, val_ids, test_ids)
    for doc_id, samples in all_samples.items():
        for ids, bucket in zip(split_ids, partitions):
            if doc_id in ids:
                bucket.extend(samples)
                break
    return partitions

def create_dataset_from_samples(dataset_name: str, folder_path: str):
    '''
    Create a Hugging Face Dataset from the pickled sample batches in `folder_path` and save it to
    `{DATA_PATH}/Model/datasets/sentence-level/{dataset_name}_dataset.hf`.

    Each batch file holds a list of documents, each document a list of page samples (as written
    by save_samples_in_batches). Every page becomes one row with the columns 'sentences',
    'embeddings', 'bio_labels', 'detailed_labels' and 'metadata'.

    Batch files are processed in numeric order (batch_0, batch_1, ..., batch_10, ...) so that the
    row order is reproducible; os.listdir alone returns an arbitrary order.

    Returns:
        The path the dataset was saved to.

    Raises:
        ValueError: if `folder_path` contains no files.
    '''
    def create_df(samples):
        '''Flatten documents into pages and tabulate them, one row per page.'''
        flattened_samples = [page for document in samples for page in document]

        df_samples = pd.DataFrame({
            'sentences': [page['sentences'] for page in flattened_samples],
            'embeddings': [page['embeddings'] for page in flattened_samples],
            'bio_labels': [page['labels'] for page in flattened_samples],
            'detailed_labels': [page['detailed_labels'] for page in flattened_samples],
            'metadata': [page['metadata'] for page in flattened_samples],
        })
        return df_samples

    def save_dataset_to_file(dataset, dataset_name):
        '''Save `dataset` under the sentence-level datasets folder and return the path.'''
        dataset_path = f'{DATA_PATH}/Model/datasets/sentence-level/{dataset_name}_dataset.hf'
        dataset.save_to_disk(dataset_path)
        return dataset_path

    def batch_sort_key(file_name):
        '''Order "batch_<n>.pkl" files by n; any other file goes last, by name.'''
        stem = os.path.splitext(file_name)[0]
        suffix = stem.rsplit('_', 1)[-1]
        if suffix.isdigit():
            return (0, int(suffix), file_name)
        return (1, 0, file_name)

    # Process samples in batches
    batch_files = sorted(os.listdir(folder_path), key=batch_sort_key)
    num_batches = len(batch_files)
    print(f"Total number of batches: {num_batches}")
    if not batch_files:
        raise ValueError(f"No sample batches found in {folder_path}; cannot create the '{dataset_name}' dataset.")

    dataset = None
    for i, batch_file_name in enumerate(batch_files):
        print(f"Processing batch {i + 1}/{num_batches}")
        with open(f"{folder_path}/{batch_file_name}", "rb") as batch_file:
            batch_samples = pickle.load(batch_file)

        # Convert samples to pandas DataFrame
        df_batch = create_df(batch_samples)
        del batch_samples
        gc.collect()

        # Convert DataFrame to Hugging Face Dataset
        batch_dataset = Dataset.from_pandas(df_batch)
        del df_batch
        gc.collect()

        # Append to the main dataset
        if dataset is None:
            dataset = batch_dataset
        else:
            dataset = concatenate_datasets([dataset, batch_dataset])

        # Save memory by deleting the batch dataset
        del batch_dataset
        gc.collect()

    # Save final dataset to file
    dataset_path = save_dataset_to_file(dataset, dataset_name)

    return dataset_path

def save_samples_in_batches(samples, batch_size, folder):
    '''
    Pickle `samples` in consecutive chunks of at most `batch_size` items to
    `<folder>/batch_<i>.pkl` (i = 0, 1, ...), creating `folder` if needed.
    '''
    os.makedirs(folder, exist_ok=True)
    total = len(samples)
    num_batches = (total + batch_size - 1) // batch_size  # ceiling division

    for i in range(num_batches):
        first = i * batch_size
        last = min(first + batch_size, total)
        file_path = f"{folder}/batch_{i}.pkl"
        print(f"Saving batch {i + 1}/{num_batches} to {file_path}")
        with open(file_path, 'wb') as handle:
            pickle.dump(samples[first:last], handle)


def combine_datasets_into_dict(train_dataset_path, val_dataset_path, test_dataset_path):
    '''
    Load the saved train, validation and test datasets from disk (in that order) and return
    them as one DatasetDict with the keys 'train', 'validation' and 'test'.
    '''
    split_paths = {
        'train': train_dataset_path,
        'validation': val_dataset_path,
        'test': test_dataset_path,
    }
    return DatasetDict({split: load_from_disk(path) for split, path in split_paths.items()})


# Main entry point for the dataset creation:
# either load a prebuilt DatasetDict (DATASET configured) or build it from data_dict:
# per-document samples -> split by SPLITS_JSON -> batched pickles -> one Hugging Face
# dataset per split -> DatasetDict saved to disk.

DEBUG = False  # NOTE: overrides the DEBUG value read from config.yaml
dataset_dict = None

# Truthiness (not `is not None`) so an empty DATASET entry means "build from scratch",
# consistent with the `"cuda" if DATASET else "cpu"` device selection.
if DATASET:
    print(f"Load Datadict from {DATA_PATH}/{DATASET}.")
    dataset_dict = load_from_disk(f"{DATA_PATH}/{DATASET}")
    print("Datadict successfully loaded.")
else:
    if DEBUG:
        all_training_samples = create_samples_singleprocessing(data_dict)
    else:
        temp_folder = f"{DATA_PATH}/preprocessing/temp_data_dict_singles"
        save_folder = f"{DATA_PATH}/preprocessing/dataset_files"
        os.makedirs(save_folder, exist_ok=True)

        valid_doc_ids = set(data_dict.keys())

        # A document needs (re)processing unless its samples file exists and holds data
        docs_to_process = set()
        print('Check which documents need processing.')
        print(f'You might want to empty the {save_folder=} to force preprocessing from scratch.')
        for doc_id in tqdm(data_dict.keys()):
            file_path = os.path.join(save_folder, f"{doc_id}_page_samples.pkl")
            if not os.path.exists(file_path) or not is_valid_data_file(file_path):
                docs_to_process.add(doc_id)

        doc_ids_list = list(docs_to_process)
        num_docs_dict = len(data_dict.keys())
        print(f"{len(doc_ids_list)} out of {num_docs_dict} documents need to be processed.")

        print(f"Save data_dict to individual files for memory reduction. (Folder: {temp_folder})")
        # Workers only read the files of documents that still need processing, so only those are written
        save_documents_to_files({doc_id: data_dict[doc_id] for doc_id in doc_ids_list}, temp_folder)

        gc.collect()

        # 8 and 1 works quite okay
        process_documents_per_document(doc_ids_list, num_docs_dict, temp_folder, save_folder, num_workers=110, chunk_size=1) #100

        all_training_samples_document_level_dict = collect_results(save_folder, valid_doc_ids)

        print(f"Total number of documents collected: {len(all_training_samples_document_level_dict)}")
        print(f"Load split information from {SPLITS_JSON}.")
        with open(SPLITS_JSON, 'r') as file:
            document_splits = json.load(file)

        # below we keep triggering the garbage collector to avoid memory issues
        train_samples, val_samples, test_samples = filter_samples_based_on_splits(all_training_samples_document_level_dict, document_splits)

        del all_training_samples_document_level_dict
        gc.collect()

        batch_size = 200 # documents
        split_folders = {
            split: f'{DATA_PATH}/Model/datasets/sentence-level/{split}_samples_{DATA_SIZE}-docs_{batch_size}_samples'
            for split in ("train", "val", "test")
        }

        def _remove_stale_batches(folder):
            '''Delete batch_*.pkl files left in `folder` by an earlier run so they are not mixed into this dataset.'''
            if os.path.isdir(folder):
                for name in os.listdir(folder):
                    if name.startswith("batch_") and name.endswith(".pkl"):
                        os.remove(os.path.join(folder, name))

        _remove_stale_batches(split_folders["train"])
        save_samples_in_batches(train_samples, batch_size=batch_size, folder=split_folders["train"])
        del train_samples
        gc.collect()
        _remove_stale_batches(split_folders["val"])
        save_samples_in_batches(val_samples, batch_size=batch_size, folder=split_folders["val"])
        del val_samples
        gc.collect()
        _remove_stale_batches(split_folders["test"])
        save_samples_in_batches(test_samples, batch_size=batch_size, folder=split_folders["test"])
        del test_samples
        gc.collect()

        print("Create single datasets.")
        train_dataset_path = create_dataset_from_samples("train", split_folders["train"])
        validation_dataset_path = create_dataset_from_samples("validation", split_folders["val"])
        test_dataset_path = create_dataset_from_samples("test", split_folders["test"])

        print("Create dataset_dict.")
        dataset_dict = combine_datasets_into_dict(train_dataset_path, validation_dataset_path, test_dataset_path)

        # Save the dataset
        file_path = f'{DATA_PATH}/Model/datasets/sentence-level/dataset_dict_{DATA_SIZE}-docs_{MODEL.replace("/", "-")}-model_{SENTENCE_SPLITTER_MODEL}-07-05.hf'
        print(f'Saving dataset_dict to {file_path}')
        dataset_dict.save_to_disk(file_path)
        print("Dataset successfully saved.")

Load Datadict from /home/tlh45/rds/hpc-work/Model/datasets/sentence-level/dataset_dict_1149-docs_roberta-base-model_transformer-07-05.hf.
Datadict successfully loaded.


##### Manually analyse preprocessed data

In [11]:
def check_dataset_of_id(doc_id:str, datasets=None):
    """
    Report which split(s) of the dataset contain a given document.

    Args:
        doc_id: document ID to look for (compared with ``metadata['doc_id']``).
        datasets: mapping of split name -> dataset exposing a ``'metadata'`` column.
            Defaults to the notebook-global ``dataset_dict``.

    Returns:
        List of split names, in train/validation/test order, that contain the
        document. One line is printed per match, or a not-found message if the
        document is in none of them.
    """
    if datasets is None:
        datasets = dataset_dict

    found_in = []
    for name in ['train', 'validation', 'test']:
        # Tolerate dataset dicts that lack one of the splits
        if name not in datasets:
            continue
        if any(doc_id == metadata_entry['doc_id'] for metadata_entry in datasets[name]['metadata']):
            print(f" Document {doc_id} is in the {name} dataset.")
            found_in.append(name)

    if not found_in:
        print(f" Document {doc_id} was not found in any dataset split.")
    return found_in

# Example: find out which split (train/validation/test) a given document ended up in
doc_id = "27363191"
check_dataset_of_id(doc_id)

 Document 27363191 is in the test dataset.


Function to display the labelled sentences of a specific page in the dataset

In [12]:
def display_sentence_labels(dataset, doc_id, page_id, mapping_dicts):
    '''
    Print every sentence of one page with its BIO label and decoded detailed label.

    Args:
        dataset: a dataset split whose samples have 'metadata', 'sentences',
            'bio_labels' and 'detailed_labels' fields.
        doc_id: document ID; converted to str before comparing with metadata['doc_id'].
        page_id: page index; converted to str before comparing with metadata['page_id'].
        mapping_dicts: label mappings passed on to decode_detailed_labels.

    Only the first sample matching (doc_id, page_id) is printed; a message is
    printed if there is no match. Relies on the notebook globals index_to_label,
    HIER_LABELS_LEVELS and detailed_labels_handler.
    '''
    # Metadata IDs are stored as strings, so accept ints as well
    doc_id = str(doc_id)
    page_id = str(page_id)

    for sample in tqdm(dataset, total=len(dataset)):
        metadata = sample['metadata']
        if metadata['doc_id'] != doc_id or metadata['page_id'] != page_id:
            continue

        for sentence, bio_label, detailed_label in zip(sample['sentences'], sample['bio_labels'], sample['detailed_labels']):
            label = index_to_label[bio_label]
            # Keep only the configured hierarchy levels that this label actually has
            filtered_detailed_labels = [detailed_label[index] for index in HIER_LABELS_LEVELS if index < len(detailed_label)]
            decoded_detailed_label = detailed_labels_handler.decode_detailed_labels([filtered_detailed_labels], mapping_dicts, HIER_LABELS_LEVELS)
            print("################################ Sentence: ################################")
            print(f"\n{sentence}\n")
            print(f"######## Label: {label} | Detailed Label: {decoded_detailed_label} ########\n")
        return

    print(f"No sample found for document {doc_id}, page {page_id}.")

# Example: show the labelled sentences of page 3 of a document that lives in the test split
doc_id = "27363191"
page_id = "3"
display_sentence_labels(dataset_dict['test'], doc_id, page_id, mapping_dicts)

  0%|          | 0/6908 [00:00<?, ?it/s]

################################ Sentence: ################################

3

######## Label: O | Detailed Label:  ########n
################################ Sentence: ################################

1. 14

######## Label: O | Detailed Label:  ########n
################################ Sentence: ################################

" provision of correspondent banking " to reflect our interpretation that the intent is to refer to customers of RFIs which are themselves Banks and are using the RFI to provide correspondent banking services to them.

######## Label: O | Detailed Label:  ########n
################################ Sentence: ################################

1. 19

######## Label: O | Detailed Label:  ########n
################################ Sentence: ################################

We fully support the sentiment expressed here, but would note our concerns about the lack of " Safe Harbour " provisions within the Legislative framework in the context of the desire to " sha

### Load pre-trained model

In [8]:
# Show the ID of the saved run to load (its folder name under Model/results/saved_model/)
PRETRAINED_MODEL

'oiHQRN'

In [9]:
import importlib
import sentence_level_models as SLM
importlib.reload(SLM)  # pick up edits to the model definitions without restarting the kernel

if PRETRAINED_MODEL:
    # The weights file and its config.json live in the run's saved_model folder
    model_load_path = f"{DATA_PATH}/Model/results/saved_model/{PRETRAINED_MODEL}/torch_model.pth"
    config_save_path = os.path.join(os.path.dirname(model_load_path), "config.json")

    # Load the constructor arguments the model was saved with
    with open(config_save_path, 'r') as config_file:
        loaded_config = json.load(config_file)

    # model_type selects the class to instantiate; it is not a constructor argument
    model_type = loaded_config.pop("model_type")

    # Use loaded configuration to recreate the model instance
    if model_type in ["BiLSTM", "BiLSTM_CRF"]:
        loaded_model = SLM.SentenceTaggingBiLSTM(**loaded_config)
    elif model_type == "Transformer":
        # `bidirectional` is presumably not accepted by STATO; drop it when present
        loaded_config.pop("bidirectional", None)
        loaded_model = SLM.STATO(**loaded_config, positional_encoding=USE_POS_ENCODING)
    else:
        # Fail clearly instead of a NameError on `loaded_model` further down
        raise ValueError(f"Unknown model_type '{model_type}' in {config_save_path}")

    # map_location="cpu" lets weights saved on a GPU load on any machine;
    # the model is moved to the chosen device before prediction
    loaded_model.load_state_dict(torch.load(model_load_path, map_location="cpu"))
    loaded_model.eval()

    print(f"Model successfully loaded from {model_load_path}.")

    model = loaded_model



Model successfully loaded from /home/tlh45/rds/hpc-work/Model/results/saved_model/oiHQRN/torch_model.pth.


#### Code to make predictions using the sentence-level model

In [None]:
from torch.nn.functional import softmax
from detailed_labels_handler import decode_detailed_labels

# Run on the GPU when one is available, otherwise fall back to the CPU.
# NOTE(review): assumes `model` already exists — the loading cell only defines it when PRETRAINED_MODEL is set.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.to(device)

def predict_page(sentences, sentence_embeddings=None, prob_threshold = None):
    '''
    Predict labels for the sentences of one page using the sentence-level model.

    Args:
        sentences: list of sentence strings of the page, in reading order.
        sentence_embeddings: optional precomputed embeddings, one per sentence.
            If None, they are computed with `calculate_sentence_embeddings`.
        prob_threshold: optional float. When given, a sentence is tagged
            `index_to_label[1]` whenever the probability of class 1 exceeds it;
            otherwise (and when None) the argmax class is used.

    Returns:
        List of tuples `(sentence, bio_tag, detailed_prediction, bio_probabilities)`,
        one per sentence. `detailed_prediction` is "N/A" when the model has no
        auxiliary detailed-label output or HIER_LABELS_LEVELS is empty.
        An empty list is returned when there is nothing to predict.

    Relies on the notebook globals `model`, `device`, `index_to_label`,
    `HIER_LABELS_LEVELS` and `mapping_dicts`.
    '''

    # Calculate sentence embeddings
    if sentence_embeddings is None:
        sentence_embeddings = calculate_sentence_embeddings(sentences)
    sentence_tensors = [torch.tensor(embedding, dtype=torch.float) for embedding in sentence_embeddings]

    if not sentence_tensors:
        # torch.stack([]) would raise, so actually skip the page as the warning says
        print("Warning: No sentences or sentence embeddings found. Skipping...")
        return []

    # Stack the per-sentence embeddings (assumed 1-D each) into one tensor
    inputs = torch.stack(sentence_tensors).to(device)

    inputs = inputs.unsqueeze(1)  # Add batch dimension at dim 1 (sequence-first layout)

    # Get predictions
    model.eval()
    with torch.no_grad():
        logits, *model_outputs = model(inputs)
        # Optional second output: per-level logits of the detailed (hierarchical) labels
        detailed_logits = model_outputs[0] if model_outputs else None

    # Handle shape of logits to properly extract probabilities
    if logits.dim() == 3 and logits.shape[1] == 1:  # Shape [seq_len, 1, num_classes]
        bio_probabilities = softmax(logits, dim=-1).squeeze(1)  # Remove batch dimension, now [seq_len, num_classes]
    else:
        bio_probabilities = softmax(logits, dim=-1)

    bio_probability_lists = [probs.tolist() for probs in bio_probabilities]

    # Get BIO labels (main objective); `is not None` so a threshold of 0 is honoured
    predicted_bio_tags = []
    for probs in bio_probabilities:
        if prob_threshold is not None and probs[1] > prob_threshold:
            predicted_bio_tags.append(index_to_label[1])
        else:
            predicted_label_index = probs.argmax().item()
            predicted_bio_tags.append(index_to_label[predicted_label_index])

    # Get detailed labels (auxiliary objective) if available
    detailed_predictions = []
    if detailed_logits is not None and HIER_LABELS_LEVELS:
        encoded_detailed_labels_per_sentence = [[] for _ in range(len(sentences))]
        for level_logits in detailed_logits:
            level_probabilities = softmax(level_logits, dim=-1)
            predicted_detailed_labels = torch.argmax(level_probabilities, dim=-1).squeeze()
            if predicted_detailed_labels.dim() == 0:
                # Only one sentence: squeeze() collapsed the result to a scalar
                encoded_detailed_labels_per_sentence[0].append(predicted_detailed_labels.item())
            else:
                # Otherwise, one predicted label per sentence for this level
                for sentence_idx, label in enumerate(predicted_detailed_labels):
                    encoded_detailed_labels_per_sentence[sentence_idx].append(label.item())

        # Decode the per-level indices of each sentence into a label string
        for encoded_labels in encoded_detailed_labels_per_sentence:
            decoded_label = decode_detailed_labels([encoded_labels], mapping_dicts, HIER_LABELS_LEVELS)
            detailed_predictions.append(decoded_label)
    else:
        detailed_predictions = ["N/A" for _ in sentences]

    # zip predictions data together
    sentence_tag_details = list(zip(sentences, predicted_bio_tags, detailed_predictions, bio_probability_lists))

    return sentence_tag_details

# Exemplary usage on one page from data_dict (the data itself is not available here for confidentiality reasons)
doc_id = "27367085"
page_index = "2"

# Split the page into sentences using either its full text or its blocks
if SPLIT_BASE == "full_text":
    page_text = data_dict[doc_id][page_index]['full_text']
    sentences = sentenizer.tokenize_into_sentences(page_text)
elif SPLIT_BASE == "blocks":
    blocks = data_dict[doc_id][page_index]['blocks']
    sentences = []
    # Split each block separately, so every sentence comes from a single block
    for block in blocks:
        block_text = block['text']  # Extract the text for the current block
        block_sentences = sentenizer.tokenize_into_sentences(block_text)
        sentences.extend(block_sentences)
else:
    # A bare `raise` here would only give "No active exception to reraise"
    raise ValueError(f"Invalid SPLIT_BASE value: {SPLIT_BASE!r}")

# Get predictions for the page
sentence_tag_details = predict_page(sentences, prob_threshold=0.4)

# Create snippets (create_snippets is defined in a later cell, which must be run first)
snippets = create_snippets(sentence_tag_details, look_ahead=0)

# Print the resulting snippets
for snippet in snippets:
    print("-------------------")
    print(snippet)
print("-" * 100)
# Print each sentence with its predicted BIO tag and detailed labels
for sentence, bio_tag, detailed_tags_str, bio_probs in sentence_tag_details:
    print(f"Sentence: {sentence[:100]}...")
    print(f"BIO Tag: {bio_tag} ({bio_probs})")
    print(f"Detailed Tags: aml-{detailed_tags_str}")
    print("-" * 100)

In [13]:
# Sanity check: decode one example list of encoded detailed-label indices back into its label string
decode_detailed_labels([[5, 35]], mapping_dicts=mapping_dicts)

'customeridentification'

### Convert into snippets and evaluate

#### Custom post-processing function for the sentence-level model to create snippets

In [11]:
def create_snippets(sentence_tag_pairs, look_ahead=0, is_binary=None):
    '''
    Post-process per-sentence tags into text snippets.

    The sentences of a snippet are joined with single spaces.

    Args:
        sentence_tag_pairs: sequence of tuples whose first two items are
            (sentence, tag); extra items (e.g. the detailed labels and
            probabilities returned by predict_page) are ignored.
        look_ahead: number of following sentences inspected when an "O"
            sentence follows an open snippet. If any of them carries a snippet
            tag ("B" in binary mode, "B" or "I" in BIO mode) the snippet stays
            open, so short "O" gaps are bridged.
        is_binary: True for a binary B/O model, False for a BIO model.
            Defaults to the notebook global SENTENCE_MODEL_IS_BINARY.

    Tag handling:
        - Binary: "B" sentences are appended to the current snippet; "I" is ignored.
        - BIO: "B" closes any open snippet and starts a new one; "I" is appended
          (and starts a snippet if none is open).
        - "O" after an open snippet is appended to it; the snippet is then closed
          unless the look-ahead window contains a snippet tag. "O" with no open
          snippet is dropped.

    NOTE(review): the "O" sentence on which a snippet is closed is kept as the
    snippet's last sentence. This reproduces the original behaviour (its
    `buffer` was always flushed into the snippet). If snippets should end at
    the last tagged sentence, skip the append on the closing path.

    Returns:
        List of snippet strings in reading order.
    '''
    if is_binary is None:
        is_binary = SENTENCE_MODEL_IS_BINARY
    # Tags in the look-ahead window that keep a snippet open across an "O"
    continuation_tags = ("B",) if is_binary else ("B", "I")

    tags = [pair[1] for pair in sentence_tag_pairs]
    snippets = []
    current_snippet = []

    for index, (sentence, tag, *_) in enumerate(sentence_tag_pairs):
        if tag == "B":
            if current_snippet and not is_binary:
                # BIO: every "B" begins a new snippet, so close the open one
                snippets.append(' '.join(current_snippet))
                current_snippet = []
            current_snippet.append(sentence)
        elif tag == "I" and not is_binary:
            current_snippet.append(sentence)
        elif tag == "O" and current_snippet:
            current_snippet.append(sentence)
            upcoming_tags = tags[index + 1:index + 1 + look_ahead]
            if not any(t in upcoming_tags for t in continuation_tags):
                snippets.append(' '.join(current_snippet))
                current_snippet = []

    # Close a snippet that runs to the end of the page
    if current_snippet:
        snippets.append(' '.join(current_snippet))

    return snippets


Collect the doc_ids and page indices of the test set

In [12]:
# Map every test-set document ID to the page indices it contributes, sorted numerically
test_docs_pages = {}

for sample in tqdm(dataset_dict['test']):
    doc_id = sample['metadata']['doc_id']
    page_index = sample['metadata']['page_id']
    # The inner dict serves as an insertion-ordered set: repeated pages are stored once
    test_docs_pages.setdefault(doc_id, {})[page_index] = None

# Replace each ordered page set with a list sorted by integer value ("2" before "10")
for doc_id, page_set in test_docs_pages.items():
    test_docs_pages[doc_id] = sorted(page_set, key=int)

print(test_docs_pages)

  0%|          | 0/6908 [00:00<?, ?it/s]

{'20541584': ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '80', '81'], '20277318': ['0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '10', '11', '12', '13', '14', '15', '16', '17', '18', '19', '20', '21', '22', '23', '24', '25', '26', '27', '28', '29', '30', '31', '32', '33', '34', '35', '36', '37', '38', '39', '40', '41', '42', '43', '44', '45', '46', '47', '48', '49', '50', '51', '52', '53', '54', '55', '56', '57', '58', '59', '60', '61', '62', '63', '64', '65', '66', '67', '68', '69', '70', '71', '72', '73', '74', '75', '76', '77', '78', '79', '80', '81', '82', '8

Make predictions for test data

In [13]:
# Pickle file the predicted snippets are written to and later read back from
# (the "17-05" suffix presumably tags the run date — update it for new runs)
predicted_snippets_path = f'{DATA_PATH}/predicted_snippets/predicted_snippets_sentence_level_{DOMAIN}_17-05.pkl'

In [14]:
import time

# Post-processing settings for the test-set predictions
SNIPPET_PROB_THRESHOLD = 0.4  # passed to predict_page as prob_threshold
SNIPPET_LOOK_AHEAD = 1        # passed to create_snippets as look_ahead

# Predicted snippets per page: {doc_id: {page_index: [{'text': ..., 'tokenized_text': ...}, ...]}}
predicted_snippets_dict = {}
total_time = 0  # seconds spent inside predict_page (embedding + model forward pass) only
number_of_documents = len(test_docs_pages)
number_of_pages = 0

for doc_id in tqdm(test_docs_pages):
    predicted_snippets_dict.setdefault(doc_id, {})

    for page_index in test_docs_pages[doc_id]:

        # 1. Split the page into sentences
        if SPLIT_BASE == "full_text":
            page_text = data_dict[doc_id][page_index]['full_text']
            sentences = sentenizer.tokenize_into_sentences(page_text)
        elif SPLIT_BASE == "blocks":
            sentences = []
            for block in data_dict[doc_id][page_index]['blocks']:
                sentences.extend(sentenizer.tokenize_into_sentences(block['text']))
        else:
            raise ValueError("Invalid SPLIT_BASE value")

        # 2. Predict sentence tags; perf_counter is the monotonic clock intended for durations
        start_time = time.perf_counter()
        sentence_tag_details = predict_page(sentences, prob_threshold=SNIPPET_PROB_THRESHOLD)
        total_time += time.perf_counter() - start_time

        # 3. Merge tagged sentences into snippets
        snippets = create_snippets(sentence_tag_details, look_ahead=SNIPPET_LOOK_AHEAD)

        # 4. Store snippets in the structure expected by the evaluation script (similar to data_dict regions)
        predicted_snippets_dict[doc_id][page_index] = [
            {'text': snippet, 'tokenized_text': tokenizer.tokenize(snippet)}
            for snippet in snippets
        ]
        number_of_pages += 1

# Create the output folder if needed so the dump does not fail after a long run
os.makedirs(os.path.dirname(predicted_snippets_path), exist_ok=True)
with open(predicted_snippets_path, 'wb') as f:
    pickle.dump(predicted_snippets_dict, f)

print(f"Total time taken: {total_time:.2f} seconds")
print(f"Predicted snippets saved to {predicted_snippets_path}.")

  0%|          | 0/173 [00:00<?, ?it/s]

Token indices sequence length is longer than the specified maximum sequence length for this model (612 > 512). Running this sequence through the model will result in indexing errors


Total time taken: 129.26 seconds
Predicted snippets saved to /home/tlh45/rds/hpc-work/predicted_snippets/predicted_snippets_sentence_level_AML_17-05.pkl.


#### Update data_dict with predictions

In [17]:
def _resplit_and_join(text):
    """Sentence-split `text` with `sentenizer` and re-join the sentences with single spaces."""
    return ' '.join(sentenizer.tokenize_into_sentences(text))


# Read back the predictions written by the inference cell
with open(predicted_snippets_path, 'rb') as f:
    loaded_predicted_snippets_dict = pickle.load(f)

# Attach the predictions to data_dict and rewrite each predicted page's gold text so it is
# formatted like the sentence-level output (sentences joined by single spaces); this matters
# for evaluation consistency, because sentence splitting removes the original formatting.
# NOTE: data_dict is mutated in place, so re-running this cell re-splits already re-split text.
# NOTE(review): the gold full text is always rebuilt from 'blocks', although the prediction
# cell chooses between 'full_text' and 'blocks' via SPLIT_BASE — confirm this is intended.
for doc_id in tqdm(loaded_predicted_snippets_dict):
    if doc_id not in data_dict:
        # Handle cases where the doc_id does not exist in data_dict
        print(f"Document ID {doc_id} was not found in data_dict.")
        continue

    doc_predictions = loaded_predicted_snippets_dict[doc_id]
    for page_index in doc_predictions:
        if page_index not in data_dict[doc_id]:
            # Handle cases where the page_index does not exist in data_dict[doc_id]
            print(f"Page index {page_index} for document ID {doc_id} was not found in data_dict.")
            continue

        page = data_dict[doc_id][page_index]
        page['predicted_snippets'] = doc_predictions[page_index]

        # 1.-2. Rebuild the full text from the sentences of all blocks, then tokenize it
        full_text = ' '.join(
            sentence
            for block in page['blocks']
            for sentence in sentenizer.tokenize_into_sentences(block['text'])
        )
        tokenized_text = tokenizer.tokenize(full_text)
        page['full_text'] = full_text
        page['tokenized_text'] = tokenized_text

        # 3. Re-split every gold refined region the same way (only text, tokens and tags are kept)
        rebuilt_regions = []
        for region in page['refined_regions']:
            region_text = _resplit_and_join(region['text'])
            rebuilt_regions.append({'text': region_text, 'tokenized_text': tokenizer.tokenize(region_text), 'tags': region['tags']})
        page['refined_regions'] = tuple(rebuilt_regions)

  0%|          | 0/173 [00:00<?, ?it/s]

#### Evaluate snippets vs. ground-truth

In [18]:
import importlib
import evaluator
importlib.reload(evaluator)  # pick up edits to evaluator.py without restarting the kernel

# Metrics flagged True are computed; pk, windowdiff and cohen_kappa are disabled
metrics_config = {
    'iou': True,
    'bleu': True,
    'jaccard': True,
    'precision': True,
    'recall': True,
    'f1': True,
    'precision_region_lvl': True,
    'recall_region_lvl': True,
    'f1_region_lvl': True,
    'edit_distance': True,
    'rouge-1-f': True,
    'rouge-2-f': True,
    'rouge-l-f': True,
    'pk': False,
    'windowdiff': False,
    'cohen_kappa': False
}

# Score the predicted snippets of every page against its refined (gold) regions
aggregated_metrics_snippets = evaluator.evaluate_snippets_parallel(data_dict, "predicted_snippets", "refined_regions", metrics_config, DEBUG=False)

# Attach the inference statistics gathered by the prediction cell
aggregated_metrics_snippets['inference_time'] = total_time
aggregated_metrics_snippets['pages'] = number_of_pages
aggregated_metrics_snippets['batch_size'] = "All sentences of a page"
# Inference may have run on the CPU fallback, so only query the GPU when one exists
aggregated_metrics_snippets['GPU'] = torch.cuda.get_device_properties(0).name if torch.cuda.is_available() else "CPU"
print(aggregated_metrics_snippets)

Evaluation Started.


  0%|          | 0/4807 [00:00<?, ?it/s]

Average Precision Score: 0.6154113052658
Average Recall Score: 0.7528642310241872
Average F1 Score: 0.6772337103370286
Average Iou Score: 0.6049309109738842
Average Bleu Score: 0.5151440009778114
Average Jaccard Score: 0.6051036822247481
Average Edit_distance Score: 138.73980649619904
Average Precision_region_lvl Score: 0.7012859916489036
Average Recall_region_lvl Score: 0.7951105926751747
Average F1_region_lvl Score: 0.6608763024807144
Average Rouge-1-f Score: 0.7215226353820612
Average Rouge-2-f Score: 0.6673863751875584
Average Rouge-l-f Score: 0.7175806180757638
{'precision': 0.6154113052658, 'recall': 0.7528642310241872, 'f1': 0.6772337103370286, 'iou': 0.6049309109738842, 'bleu': 0.5151440009778114, 'jaccard': 0.6051036822247481, 'edit_distance': 138.73980649619904, 'precision_region_lvl': 0.7012859916489036, 'recall_region_lvl': 0.7951105926751747, 'f1_region_lvl': 0.6608763024807144, 'rouge-1-f': 0.7215226353820612, 'rouge-2-f': 0.6673863751875584, 'rouge-l-f': 0.71758061807576

In [19]:
# Show the domain being evaluated (it also selects the Evaluation/<DOMAIN>/ output folder)
DOMAIN

'AML'

In [20]:
# Save the aggregated metrics of this run as JSON under Evaluation/<DOMAIN>/
eval_path = f"Evaluation/{DOMAIN}/sentence_eval_final_{PRETRAINED_MODEL}_different_tokenization.json"
# Create the domain folder if it does not exist yet, otherwise open() fails
os.makedirs(os.path.dirname(eval_path), exist_ok=True)
with open(eval_path, 'w') as f:
    json.dump(aggregated_metrics_snippets, f, indent=4)

#### Evaluate similarity between blocks and refined regions
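Note: this is independent of the sentence-level model. It scores the raw text blocks, treated as if they were predicted snippets, against the refined regions. This shows how closely the blocks already match the regions.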

In [5]:
import importlib
import evaluator
importlib.reload(evaluator)  # re-import evaluator.py so local edits take effect

# Enable every metric of the first list and disable the three in the second (key order preserved)
metrics_config = dict.fromkeys(
    ['iou', 'bleu', 'jaccard',
     'precision', 'recall', 'f1',
     'precision_region_lvl', 'recall_region_lvl', 'f1_region_lvl',
     'edit_distance',
     'rouge-1-f', 'rouge-2-f', 'rouge-l-f'],
    True,
)
metrics_config.update(dict.fromkeys(['pk', 'windowdiff', 'cohen_kappa'], False))

# Baseline: score the raw text blocks as if they were predicted snippets against the refined regions.
# NOTE(review): this reuses the name `aggregated_metrics_snippets`, overwriting the sentence-level
# results if run after that evaluation; and if the "Update data_dict with predictions" cell ran
# first, refined_regions have already been re-split — results depend on execution order.
aggregated_metrics_snippets = evaluator.evaluate_snippets_parallel(data_dict, "blocks", "refined_regions", metrics_config, DEBUG=False)

Evaluation Started.


  0%|          | 0/4807 [00:00<?, ?it/s]

Average Precision Score: 0.8613154066652959
Average Recall Score: 0.818381861647683
Average F1 Score: 0.8392999372804058
Average Iou Score: 0.807934991304034
Average Bleu Score: 0.7315791964983502
Average Jaccard Score: 0.8081440359699109
Average Edit_distance Score: 81.82154553689638
Average Precision_region_lvl Score: 0.9091115198953504
Average Recall_region_lvl Score: 0.846457969011016
Average F1_region_lvl Score: 0.8371476073282121
Average Rouge-1-f Score: 0.8747242314469701
Average Rouge-2-f Score: 0.8557768136541352
Average Rouge-l-f Score: 0.874457730429435
