# Notebook for token-level model
(Section 5.3 of the report)

Contains the following token-level model specific components:
- Code for dataset creation
- Post-processing algorithm
- Evaluation code

![](../Misc/token_level_model.jpg)

### Import libraries

In [2]:
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"

import sys
from tqdm.notebook import tqdm
import pickle
import json
from sklearn.model_selection import train_test_split
import pandas as pd

sys.path.append('../DataPreprocessing')
from read_into_dicts import DocReader

# Model
from transformers import Trainer, TrainingArguments, AutoTokenizer, EarlyStoppingCallback, AutoConfig, AutoModelForTokenClassification
import torch
from datasets import Dataset, DatasetDict
import gc
from multiprocessing import Pool
from token_level_model import TokenClassificationWithDetailedLabels

# WANDB
import wandb
import random
import string
import yaml

2024-06-01 15:40:42.877073: I tensorflow/core/util/port.cc:113] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-06-01 15:40:45.174767: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.


### Define parameters

In [3]:
# Load configuration
with open('config.yaml', 'r') as config_file:
    config = yaml.safe_load(config_file)

# GENERAL:
DEBUG = config["general"]["DEBUG"]
MODEL = config["general"]["MODEL"]
DATA_SIZE = config["general"]["DATA_SIZE"]
DOMAIN = config["general"]['DATA_DICT_FILE_PATH']['DOMAIN']
DATA_DICT_FILE_PATH = config["general"]['DATA_DICT_FILE_PATH'][DOMAIN]
DATA_PATH = config["general"]["DATA_PATH"]
SPLITS_JSON = config["general"]["SPLITS_JSON"]
TRAIN_SIZE = config["general"]["TRAIN_SIZE"]
VAL_SIZE = config["general"]["VAL_SIZE"]


# MODEL CONFIGS
DEFAULT_MODEL = config["token_level_model"]["DEFAULT_MODEL"]
RETOKENIZATION_NEEDED = config["token_level_model"]["RETOKENIZATION_NEEDED"]
DATASET = config["token_level_model"]["DATASET"]
PRETRAINED_MODEL = config["token_level_model"]['PRETRAINED_MODEL']
CHECKPOINT = config["token_level_model"]["CHECKPOINT"]
EPOCHS = config["token_level_model"]["EPOCHS"]
EARLY_STOPPING_PATIENCE = config["token_level_model"]["EARLY_STOPPING_PATIENCE"]
BATCH_SIZE = config["token_level_model"]["BATCH_SIZE"]
OVERLAP = config["token_level_model"]["SLIDING_WINDOW_OVERLAP"]

# Detailed labels
ALPHA = config["token_level_model"]["ALPHA"]
HIER_LABELS_LEVELS = config["token_level_model"]["HIER_LABELS_LEVELS"]  # also acts as flag if detailed models shall be used (i.e. model with auxiliary objective trained)
DETAILED_LABEL_WEIGHTS = config["token_level_model"]["DETAILED_LABEL_WEIGHTS"]
NUMBER_OF_LEVELS = len(HIER_LABELS_LEVELS)

tokenizer = AutoTokenizer.from_pretrained(MODEL)

# define special tokens
cls_token = tokenizer.cls_token if hasattr(tokenizer, 'cls_token') else '[CLS]'
sep_token = tokenizer.sep_token if hasattr(tokenizer, 'sep_token') else '[SEP]'
pad_token = tokenizer.pad_token if hasattr(tokenizer, 'pad_token') else '[PAD]'

# choose mapping - NOTE: Must also be adpated in dtaaset creation code!
label_to_index = {'O': 0, 'B': 1, 'I': 2, 'E': 3, '-100': -100}   # BIOE approach!
# label_to_index = {'O': 0, 'B': 1, 'I': 2, '-100': -100}     # BIO approach!

index_to_label = {v: k for k, v in label_to_index.items()} # reverse of label_to_index
NUMBER_BIO_LABELS = len(label_to_index) -1 # -1 for -100 above

print(f"{MODEL=}")
print(f"{DOMAIN=}")
print(f"{DATA_DICT_FILE_PATH=}")
print(f"{DATASET=}")
print(f"{OVERLAP=}")
print(f"{NUMBER_BIO_LABELS=}")
print(f"{EPOCHS=}")
print(f"{PRETRAINED_MODEL=}")

# Generate a random tag of 6 characters or choose defined one
unique_tag = config["general"]["UNIQUE_TAG"] if config["general"]["UNIQUE_TAG"] else ''.join(random.choices(string.ascii_letters + string.digits, k=6))      

MODEL='roberta-base'
DOMAIN='AML'
DATA_DICT_FILE_PATH='path/to/aml_data_dict.pkl'
DATASET='Model/datasets/token-level/<huggingface_dataset_dict_for_token-level_model.hf>'
OVERLAP=256
NUMBER_BIO_LABELS=4
EPOCHS=15
PRETRAINED_MODEL='or6jqv'


### Set wandb settings

In [3]:
os.environ['WANDB_DIR'] = f"{DATA_PATH}/Model"
os.environ['WANDB_CACHE_DIR'] = f"{DATA_PATH}/Model"

wandb_init = {
    "project": config["wandb"]["WANDB_PROJECT"],    
    "tags": [unique_tag, f"MODEL={MODEL}", f"DATA_SIZE={DATA_SIZE}"],
    "group": config["wandb"]["WANDB_GROUP"],
    "name": f'{config["wandb"]["WANDB_GROUP"]}-TOKEN_LEVEL-{MODEL}'
}
wandb.login()
run = wandb.init(**wandb_init)
        
run_id = run.id
config_dict = {
    "model": MODEL,
    "data_size": DATA_SIZE,
    "train_size": TRAIN_SIZE,
    "val_size": VAL_SIZE,
    "test_size": 1 - TRAIN_SIZE - VAL_SIZE,
    "epochs": EPOCHS,
    "batch_size": BATCH_SIZE,
    "sliding_window_overlap": OVERLAP,
    "hierarchical_labels_levels": HIER_LABELS_LEVELS,
    "run_id": run_id,
    "dataset": str(DATASET)
}
wandb.config.update(config_dict, allow_val_change=True)


[34m[1mwandb[0m: Currently logged in as: [33mtlh45[0m. Use [1m`wandb login --relogin`[0m to force relogin


### Load Full Data

In [3]:
docReader = DocReader(MODEL, tokenizer)
create_original_snippets = True
add_full_page = True

if DATA_DICT_FILE_PATH:
    print(f"Load data_dict from {DATA_PATH}/{DATA_DICT_FILE_PATH}.")
    with open(f'{DATA_PATH}/{DATA_DICT_FILE_PATH}', 'rb') as handle:
        data_dict = pickle.load(handle)
    print("Done.")
else:
    data_dict = docReader.preprocess_folder(preprocess=False, folder_path= f'{DATA_PATH}/28245V231219'
                                        , data_size=DATA_SIZE, num_workers=76, chunksize=1
                                        , extract_title=True, extract_doc_long_id=True, refine_regions=True, create_original_snippets=create_original_snippets, add_full_page=add_full_page
                                        , data_dict_folder=f"{DATA_PATH}/preprocessing/data_dicts/28245V231219_AML", file_name_additional_suffix="-06-04")

Load data_dict from /home/tlh45/rds/hpc-work/preprocessing/data_dicts/28245V231219_AML/data_dict_roberta-base-10-04.pkl.
Done.


In [7]:
# save data_dict file
if not DATA_DICT_FILE_PATH:
    with open(f'{DATA_PATH}/preprocessing/data_dicts/44704v240404_CYBER/data_dict_{MODEL}-06-04.pkl', 'wb') as handle:
        pickle.dump(data_dict, handle, protocol=pickle.HIGHEST_PROTOCOL)

#### Logic to create training, validation, and test data doc ids --> create centralised data split file
(Implemented here, but used across all approaches)

In [8]:
if False: # uncomment to run this code (only needed once to create splits)
    
    def extract_metadata(data_dict):
        '''
        Extract metadata from data_dict
        '''
        metadata = {}
        for doc_id, doc_data in data_dict.items():
            num_pages = len([page for page in doc_data if page not in ['title', 'doc_long_id']])
            metadata[doc_id] = num_pages
        return metadata

    def stratify_and_split(metadata, random_state=42, train_size=TRAIN_SIZE, val_size=VAL_SIZE):
        '''
        Calculate stratified splits based on metadata and split IDs in data_dict into train, validation and test sets
        '''
        # Convert metadata to lists
        doc_ids = list(metadata.keys())
        num_pages = list(metadata.values())

        test_size = 1.0 - train_size - val_size
        
        # Stratify based on number of pages
        quartiles = pd.qcut(num_pages, 4, labels=False)
        
        # First split to separate out training
        train_ids, test_val_ids, _, test_val_quartiles = train_test_split(doc_ids, quartiles, test_size=test_size+val_size, stratify=quartiles, random_state=random_state)
        # Adjust test_size for second split based on remaining documents and create second split
        test_size_adjusted = test_size / (test_size + val_size)
        val_ids, test_ids, _, _ = train_test_split(test_val_ids, test_val_quartiles, test_size=test_size_adjusted, stratify=test_val_quartiles, random_state=random_state)
        
        # Log sizes
        sizes = {'train': len(train_ids), 'validation': len(val_ids), 'test': len(test_ids)}

        return train_ids, val_ids, test_ids, quartiles, sizes, random_state

    def save_splits_as_json(train_ids, val_ids, test_ids, quartiles, sizes, random_state):
        '''
        Write information about splits to a json file
        '''
        splits = {
            'train_ids': train_ids,
            'val_ids': val_ids,
            'test_ids': test_ids,
            'quartiles_used': quartiles.tolist(),
            'sizes': sizes,
            'random_state': random_state
        }
        total = len(train_ids)+len(val_ids)+len(test_ids)
        with open(f"document_splits_for_datasets-{total}.json", 'w') as f:
            json.dump(splits, f, indent=4)

    metadata = extract_metadata(data_dict)
    train_ids, val_ids, test_ids, quartiles, sizes, random_state = stratify_and_split(metadata)
    save_splits_as_json(train_ids, val_ids, test_ids, quartiles, sizes, random_state)

### Data Preprocessing

#### Load detailed labels

In [4]:
# Extract and analyze tags
import detailed_labels_handler
import importlib
importlib.reload(detailed_labels_handler)


tag_statistics, mapping_dicts = detailed_labels_handler.extract_and_analyze_tags(data_dict, HIER_LABELS_LEVELS)

# Calculate unique tag counts
unique_level_tags = {level: len(tags) for level, tags in tag_statistics.items()}
total_unique_tags = sum(len(tags) for tags in tag_statistics.values())

    
if not HIER_LABELS_LEVELS:
    print("Note: no HIER_LABELS_LEVELS provided.")
# Print statistics and unique counts
for level, tags in tag_statistics.items():
    print(f"{level.capitalize()} Tags:", tags)
    print(f"Unique {level.capitalize()} Tags:", unique_level_tags[level])

# Print full tag count if available
if 'full_tags' in tag_statistics:
    print("Full Tags Count:", tag_statistics['full_tags'])
print("Total Unique Tags:", total_unique_tags)

# Print mapping dictionaries
for level, mapping_dict in mapping_dicts.items():
    print(f"{level.capitalize()} Mapping Dictionary:", mapping_dict)

# Example below shows how detauled labels are a list of lists, with each sublist containing the detailed labels for a token
encoded_labels = [[0, 0, 0, 0], [0, 0, 1, 1]]
decoded_labels = detailed_labels_handler.decode_detailed_labels(encoded_labels, mapping_dicts, HIER_LABELS_LEVELS)
print("Decoded Labels:", decoded_labels)

Note: no HIER_LABELS_LEVELS provided.
Total Unique Tags: 0
Decoded Labels: []


### Dataset creation (Section 5.3.2)

In [5]:
# define output folder here
output_folder = f'{DATA_PATH}/preprocessing/token_level_temp_files/temp_token_level_singles_{str(MODEL).replace("/", "-")}'
skipped_regions = []

def tokenize_and_create_labels(tokenized_page_text, refined_regions, max_length=512, overlap=50, doc_id=None, page_index=None):
    '''
    Labelling function to create BIO(E) labels for tokenized text using indeces identified during data pre-processing
    '''
    
    # Initialise labels for main and auxiliary objective with "outside region" labels
    labels = ['O'] * len(tokenized_page_text)
    detailed_labels = ['N/A'] * len(tokenized_page_text)

    # Region labeling logic
    for region_index, refined_region in enumerate(refined_regions):
        if RETOKENIZATION_NEEDED:
            region_text = refined_region['text']
            region_tokens = tokenizer.tokenize(region_text)
        else:
            region_tokens = refined_region['tokenized_text'] # simply pull tokens from data_dict
            
        if len(region_tokens) == 0:
            continue #jump to next region if region is empty
        
        region_tags = ';'.join(refined_region.get('tags', [])) # extract detailed labels ("tags"), if available
        
        # Extract IDs as identified during pre-processing
        start_idx = refined_region['start_idx_in_page']
        end_idx = refined_region['end_idx_in_page']
        if end_idx >= len(labels) or start_idx == -1 or end_idx == -1: # skip regions that could not be identified during preprocessing
            print("WARNING:")
            print(f"{end_idx=}, len of labels: {len(labels)}")
            print(f"Skipping {doc_id=}, {page_index=}.")
            skipped_regions.append((doc_id, page_index, region_index))
            continue
        
        # Simply assign labels
        labels[start_idx] = 'B'
        for i in range(start_idx + 1, end_idx):
            if i < len(labels):
                labels[i] = 'I'
        labels[end_idx] = 'E' # need to mark here whether BIOE ("E") or BIO ("I")!

        # For auxiliary objective, we mark the inside region labels with the region tags
        if HIER_LABELS_LEVELS:
            for idx in range(start_idx, end_idx+1):
                detailed_labels[idx] = region_tags if region_tags else "N/A"

    # Sliding window approach: split now labelled full page into sequences with overlaps
    # Necessary step due to context window of BERT-based models
    start = 0
    page_length = len(tokenized_page_text)
    while start < page_length:
        end = min(start + max_length - 2, page_length) # account for special start and end tokens
        
        # Adjust end to avoid splitting a word!
        while end < page_length and tokenized_page_text[end].startswith('##') and end - start < max_length - 2:
            end += 1

        # Define tokens belonging to this window/sequence
        seq_tokens = tokenized_page_text[start:end]
        seq_labels = labels[start:end]
        seq_detailed_labels = detailed_labels[start:end]
        yield (seq_tokens, seq_labels, seq_detailed_labels)

        # Update start for next window considering overlap
        start = max(start + overlap, end - overlap) if end < page_length else end


def finalize_sequences(sequences, tokenizer, max_length=512):
    '''
    Function to finalise all windows/sequences by adding special tokens and padding to max_length
    '''
    
    # Define special tokens according to tokenizer used (i.e. special tokens are different between models such as BERT and RoBERTa)
    # --> pull from corresponding tokenizer
    cls_token = tokenizer.cls_token if hasattr(tokenizer, 'cls_token') else '[CLS]'
    sep_token = tokenizer.sep_token if hasattr(tokenizer, 'sep_token') else '[SEP]'
    pad_token = tokenizer.pad_token if hasattr(tokenizer, 'pad_token') else '[PAD]'

    # Iterate over all sequences
    for tokens, bio_labels, detailed_labels in sequences:
        # Initial len before adding special tokens or padding
        original_len = len(tokens)

        # Add special tokens
        tokens = [cls_token] + list(tokens) + [sep_token] # beginning and end tokens
        bio_labels = ['-100'] + list(bio_labels) + ['-100']  # add label '-100' for special tokens
        detailed_labels = ['-100'] + list(detailed_labels) + ['-100']  # add label '-100' for special tokens
        if len(tokens) > max_length:
            print(f"WARNING: length of tokens ({len(tokens)}) exceed max length of {max_length}!")

        # Padding
        padding_length = max_length - len(tokens)
        tokens.extend([pad_token] * padding_length)
        bio_labels.extend(['-100'] * padding_length)  # add label '-100' for padding tokens
        detailed_labels.extend(['-100'] * padding_length)  # add label '-100' for padding tokens

        # Convert tokens to IDs
        input_ids = tokenizer.convert_tokens_to_ids(tokens)
        attention_mask = [1 if token != pad_token else 0 for token in tokens] # create attention mask: 0 for padding tokens, 1 for real tokens

        # Convert labels to numerical values
        numerical_bio_labels = [-100] * max_length  # initialise all labels to -100
        numerical_bio_labels[1:original_len+1] = [label_to_index[label] for label in bio_labels[1:original_len+1]]  # align labels with tokens (given special tokens)

        numerical_detailed_labels = [[-1]*NUMBER_OF_LEVELS] * max_length  # initialise all labels to -100
        numerical_detailed_labels[1:original_len+1] = detailed_labels_handler.encode_detailed_labels(detailed_labels[1:original_len+1], mapping_dicts, NUMBER_OF_LEVELS) # encoding via mapping dictionaries

        yield (input_ids, attention_mask, numerical_bio_labels, numerical_detailed_labels)
        
def process_documents_batch(batch_args):
    '''
    Process batch of documents
    '''
    results_files = []
    for args in batch_args:
        results_file = process_document(args)
        results_files.append(results_file)
    return results_files

def process_document(args):
    '''
    Process a single document
    '''
    doc_id, pages, tokenizer, max_length, overlap = args # unload arguments

    results = []
    for page_index, page_content in pages.items(): # go through all pages of the document
        if page_index in ['title', 'doc_long_id']:
            continue
        
        if RETOKENIZATION_NEEDED:
            tokenized_page_text = tokenizer.tokenize(page_content['full_text'])
        else:
            tokenized_page_text = page_content['tokenized_full_text']
        refined_regions = page_content['refined_regions']

        seq_index = 0
        # Tokenize and create labels for each window/sequence
        for seq in tokenize_and_create_labels(tokenized_page_text, refined_regions, max_length=max_length, overlap=overlap, doc_id=doc_id, page_index=page_index):
            # Dinalise each window/sequence on the page
            for finalized_seq in finalize_sequences([seq], tokenizer, max_length=max_length):
                input_ids, attention_mask, numerical_bio_labels, numerical_detailed_labels = finalized_seq
                metadata = {'doc_id': doc_id, 'page_id': page_index, 'seq_index': seq_index} # add metadata
                # Construct sample
                sample = {
                    'input_ids': input_ids
                    , 'attention_mask': attention_mask
                    , 'bio_labels': numerical_bio_labels
                    , 'detailed_labels': numerical_detailed_labels
                    , 'metadata': metadata
                }
                # Add to results
                results.append(sample)
                seq_index += 1

    # Write into temporary file (necessary for more efficient memory handling)
    os.makedirs(output_folder, exist_ok=True)
    results_file = os.path.join(output_folder, f"doc_{doc_id}_results.pkl")
    with open(results_file, 'wb') as file:
        pickle.dump(results, file)

    return results_file

def load_processed_data(processed_files, document_splits):
    '''
    Load processed data and split into train, validation and test sets according to split definitions
    '''
    train_ids = set(document_splits['train_ids'])
    val_ids = set(document_splits['val_ids'])
    test_ids = set(document_splits['test_ids'])
        
    train_samples = []
    val_samples = []
    test_samples = []

    for results_file in tqdm(processed_files):
        with open(results_file, 'rb') as file:
            doc_results = pickle.load(file)
            for sample in doc_results:
                if sample['metadata']['doc_id'] in train_ids:
                    train_samples.append(sample)
                if sample['metadata']['doc_id'] in val_ids:
                    val_samples.append(sample)
                if sample['metadata']['doc_id'] in test_ids:
                    test_samples.append(sample)

    return train_samples, val_samples, test_samples


# MAIN ENTRY POINT
# Process and tokenize the data
# Create or load dataset
if DATASET is not None:
    print(f"Load Datadict from {DATA_PATH}/{DATASET}.")
    dataset_dict = DatasetDict.load_from_disk(f"{DATA_PATH}/{DATASET}")
    print("Datadict successfully loaded.")
else:
    # Create arguments
    args = [(doc_id, pages, tokenizer, 512, OVERLAP) for doc_id, pages in data_dict.items()]

    # Debug: analyse process for a single document
    if DEBUG:
        for arg in tqdm(args):
            if arg[0] == "17222128":
                results = process_document(arg)
    else:
        print("Check which files need processing.")
        os.makedirs(output_folder, exist_ok=True)
        existing_files = set(os.listdir(output_folder))
        already_processed_doc_ids = {file_name.split('_')[1] for file_name in existing_files if file_name.endswith('.pkl')}

        # Filter args to exclude already processed documents
        args = [arg for arg in args if str(arg[0]) not in already_processed_doc_ids]

        if args:
            del data_dict
            gc.collect()

            print("Initiate Dataset Creation.")
            batch_size = 10  # Number of documents per batch

            with Pool(1) as pool: # allows multiprocessing, but not needed here as process using pre-determined labels is already fast enough
                total_batches = (len(args) + batch_size - 1) // batch_size  # calculate total number of batches
                pbar = tqdm(total=total_batches, unit="batch", desc="Process documents in batches")  # progress bar

                for i in range(0, len(args), batch_size):
                    batch_args = [args[i:i + batch_size]]
                    for batch_output_files in pool.imap_unordered(process_documents_batch, batch_args):
                        pbar.update(1)  # update pb per completed batch
                    
                    # force garbage collection
                    del batch_args
                    gc.collect() 
                pbar.close()

            # force garbage collection
            del args
            gc.collect()
        else:
            print(f"NOTE: No preprocessing necessary. You might want to delete the {output_folder=} to force new preprocessing.")
        print(f"Load processed data from {output_folder}.")
        
        # Split dataset into different datasets
        with open(SPLITS_JSON, 'r') as f:
            document_splits = json.load(f)
            
        output_files = set(os.listdir(output_folder))
        output_files = [f"{output_folder}/{file}" for file in output_files]
        train_samples, val_samples, test_samples = load_processed_data(output_files, document_splits)

        # Create final dataset
        print("Create datasets.")
        train_dataset = Dataset.from_list(train_samples)
        val_dataset = Dataset.from_list(val_samples)
        test_dataset = Dataset.from_list(test_samples)

        # Combining splits into a DatasetDict
        dataset_dict = DatasetDict({
            'train': train_dataset,
            'validation': val_dataset,
            'test': test_dataset
        })

        # Save to disk
        dataset_path = f'{DATA_PATH}/Model/datasets/token-level/dataset_dict_{DATA_SIZE}-docs_{MODEL.replace("/", "-")}-model_win{OVERLAP}_BIO-24-04.hf'
        print(f"Save dataset_dict to {dataset_path}.")
        dataset_dict.save_to_disk(dataset_path)

Load Datadict from /home/tlh45/rds/hpc-work/Model/datasets/token-level/dataset_dict_1149-docs_roberta-base-model_win256-24-04.hf.
Datadict successfully loaded.


##### Manually analyse preprocessed data

In [6]:
# Load JSON splits
with open(SPLITS_JSON, 'r') as file:
    document_splits = json.load(file)
    
# Convert JSON lists to sets for easier comparison
json_train_ids = set(document_splits['train_ids'])
json_val_ids = set(document_splits['val_ids'])
json_test_ids = set(document_splits['test_ids'])

Assertion for correct document splits:
(same as in model script)

In [6]:
def extract_doc_ids_from_dataset(dataset):
    return {sample['metadata']['doc_id'] for sample in dataset}

print("Extract document IDs from train.")
train_doc_ids = set(extract_doc_ids_from_dataset(dataset_dict['train']))
print("Extract document IDs from validation.")
val_doc_ids = set(extract_doc_ids_from_dataset(dataset_dict['validation']))
print("Extract document IDs from test.")
test_doc_ids = set(extract_doc_ids_from_dataset(dataset_dict['test']))

# Validate match
assert train_doc_ids == json_train_ids, "Mismatch in train dataset document IDs"
assert val_doc_ids == json_val_ids, "Mismatch in validation dataset document IDs"
assert test_doc_ids == json_test_ids, "Mismatch in test dataset document IDs"

print("Validation successful: Datasets contain the correct document IDs according to the JSON splits.")

Extract document IDs from train.
Extract document IDs from validation.
Extract document IDs from test.
Validation successful: Datasets contain the correct document IDs according to the JSON splits.


### Create mapping between doc_id, page_index and sample in dataset:

In [6]:
train_dataset_index_map = {}
for idx, sample in enumerate(tqdm(dataset_dict['train'])):
    doc_id = sample['metadata']['doc_id']
    page_id = sample['metadata']['page_id']
    if (doc_id, page_id) not in train_dataset_index_map:
        train_dataset_index_map[(doc_id, page_id)] = idx

  0%|          | 0/74450 [00:00<?, ?it/s]

### Print individual samples

In [7]:
# We keep one example from the dataset here
doc_id = '18753674'
page_index = '71'

target_index = train_dataset_index_map.get((doc_id, page_index)) # use above mapping for fast accesses
if target_index is not None:
    sample = dataset_dict['train'][target_index]
    input_ids = tokenizer.convert_ids_to_tokens(sample['input_ids'])
    attention_mask = sample['attention_mask']
    bio_labels = sample['bio_labels']
    detailed_labels = sample['detailed_labels']

    # Print a header
    print(f"{'Token':<15} {'Attention Mask':<15} {'BIO Label':<10} {'Detailed Labels':<30}")
    print('-' * 70)

    # Print each token and its corresponding details
    for token, mask, bio, detail in zip(input_ids, attention_mask, bio_labels, detailed_labels):
        detail_str = str(detail) if isinstance(detail, list) else detail
        print(f"{token:<15} {mask:<15} {bio:<10} {detail_str:<30}")


Token           Attention Mask  BIO Label  Detailed Labels               
----------------------------------------------------------------------
<s>             1               -100       [-1, -1, -1, -1]              
A               1               0          [-1, -1, -1, -1]              
Ġ108            1               0          [-1, -1, -1, -1]              
Ċ               1               0          [-1, -1, -1, -1]              
SR              1               0          [-1, -1, -1, -1]              
O               1               0          [-1, -1, -1, -1]              
.               1               0          [-1, -1, -1, -1]              
Ġ6              1               0          [-1, -1, -1, -1]              
Ġ               1               0          [-1, -1, -1, -1]              
Ġ               1               0          [-1, -1, -1, -1]              
Ġ               1               0          [-1, -1, -1, -1]              
Ġ               1               0        

### Setup Model

In [None]:
from torch.nn.utils.rnn import pad_sequence
def token_classification_collate_fn(self, batch):
        '''
        Custom collate function for token-level model.
        Ensure proper/uniform padding for all elements in the batch.
        '''
        
        # MAIN OBJECTIVE
        # pad input_ids, attention_mask and bio_labels (i.e. main objective labels)
        input_ids = pad_sequence([torch.tensor(sample['input_ids'], dtype=torch.long) for sample in batch], batch_first=True, padding_value=self.tokenizer.convert_tokens_to_ids(self.tokenizer.pad_token if hasattr(self.tokenizer, 'pad_token') else '[PAD]')) # if possible, we ensure that the padding value is the same as the one used by the tokenizer
        attention_mask = pad_sequence([torch.tensor(sample['attention_mask'], dtype=torch.long) for sample in batch], batch_first=True, padding_value=0) # attention mask consists of 0s and 1s and is thefore padded with 0s
        bio_labels_padded = pad_sequence([torch.tensor(sample['bio_labels'], dtype=torch.long) for sample in batch], batch_first=True, padding_value=-100)  # -100 is used as the padding value for the main objective (BIO(E) labels)

        # put batch elements together in a dictionary
        batch_dict = {"input_ids": input_ids, "attention_mask": attention_mask, "bio_labels": bio_labels_padded}
        
        # AUXILIARY OBJECTIVE
        # pad detailed_labels if available
        if self.HIER_LABELS_LEVELS:
            batch_detailed_labels = []
            NUMBER_OF_LEVELS = len(self.HIER_LABELS_LEVELS)
            
            for sample in batch:
                detailed_labels = sample.get('detailed_labels', [])

                # filtered detailed labels for page
                filtered_token_labels_list = []

                for token_labels in detailed_labels:
                    # Filter the token's labels based on HIER_LABELS_LEVELS
                    # This step is necessary as the dataset contains all levels of labels, but the model only needs the specified levels
                    filtered_token_labels = [token_labels[idx] for idx in self.HIER_LABELS_LEVELS if idx < len(token_labels)]
                    filtered_token_labels_list.append(filtered_token_labels)
                
                # Pad detailed labels to ensure uniform length
                token_detailed_labels_padded = [labels + [-1] * (NUMBER_OF_LEVELS - len(labels)) for labels in filtered_token_labels_list]

                # Convert to tensor
                if token_detailed_labels_padded:
                    detailed_labels_tensor = torch.tensor(token_detailed_labels_padded, dtype=torch.long) 
                else:
                    # Create placeholder tensor if there are no detailed labels
                    detailed_labels_tensor = torch.full((1, NUMBER_OF_LEVELS), -1, dtype=torch.long)
                
                batch_detailed_labels.append(detailed_labels_tensor)
            
            detailed_labels_padded_uniform = pad_sequence(batch_detailed_labels, batch_first=True, padding_value=-1) # using -1 as padding value for detailed labels for consistency

            # Add to bacth dict
            batch_dict["detailed_labels"] = detailed_labels_padded_uniform

        return batch_dict

In [9]:
import importlib
import evaluator
importlib.reload(evaluator)


# Definition of custom trainer and training arguments here (same as in script, but needed for working with models)
class CustomTokenTrainer(Trainer):
    def compute_loss(self, model, inputs, return_outputs=False):
        bio_labels = inputs.pop("bio_labels", None)
        detailed_labels = inputs.pop("detailed_labels", None)
        outputs = model(**inputs, bio_labels=bio_labels, detailed_labels=detailed_labels)
        loss = outputs[0]
        return (loss, outputs) if return_outputs else loss
    
training_args = TrainingArguments(
    output_dir=f'{DATA_PATH}/Model/results/{unique_tag}',
    num_train_epochs=EPOCHS,
    per_device_train_batch_size=BATCH_SIZE,
    per_device_eval_batch_size=BATCH_SIZE,
    warmup_steps=500,
    weight_decay=0.01,
    logging_dir=f'{DATA_PATH}/Model/logs',
    logging_steps=50,
    evaluation_strategy="epoch",
    save_strategy="epoch",
    save_steps=500, 
    save_total_limit=5,
    load_best_model_at_end=True, 
    metric_for_best_model='loss',
    greater_is_better=False, 
    report_to="wandb",
    fp16=True,
)

## Load pretrained model

In [None]:
from token_level_model import TokenClassificationWithDetailedLabels

# If pre-trained model is defined, load it from disk
if PRETRAINED_MODEL:
    model_path = f"{DATA_PATH}/Model/results/saved_model/{PRETRAINED_MODEL}/torch_model.pth"
    config_path = os.path.join(os.path.dirname(model_path), "config.json")

    # Load configuration
    with open(config_path, 'r') as config_file:
        loaded_config = json.load(config_file)

    # Recreate model architecture based on loaded configuration
    if DEFAULT_MODEL:
        token_model = AutoModelForTokenClassification.from_pretrained(
            MODEL,
            num_labels=NUMBER_BIO_LABELS
        )
    else:
        token_model = TokenClassificationWithDetailedLabels(
            model_name_or_path=MODEL,
            num_labels=loaded_config["num_labels"],
            num_detailed_labels_per_level=loaded_config["num_detailed_labels_per_level"],
            detailed_label_weights=loaded_config["detailed_label_weights"],
            alpha=loaded_config["alpha"]
        )

    # Load weights
    token_model.load_state_dict(torch.load(model_path))
    token_model.eval()
    token_model.to("cuda")
    print(f"{loaded_config['model_architecture']} model successfully loaded.")
    print(f"File path used: {model_path}")

#### Evaluate detailed labels predictions on test set using trainer class

In [10]:
trainer = CustomTokenTrainer(
            model=token_model, 
            args=training_args,
            train_dataset=dataset_dict['train'],
            eval_dataset=dataset_dict['validation'],
            data_collator=token_classification_collate_fn,
            compute_metrics=evaluator.compute_metrics_wrapper(index_to_label, HIER_LABELS_LEVELS, mapping_dicts, calc_windowdiff=True),
            callbacks=[EarlyStoppingCallback(early_stopping_patience=EARLY_STOPPING_PATIENCE)]
        )


test_results = trainer.evaluate(dataset_dict["test"])

print("Test Set Evaluation Results:")
for key, value in test_results.items():
    print(f"{key}: {value}")


[34m[1mwandb[0m: Currently logged in as: [33mtlh45[0m. Use [1m`wandb login --relogin`[0m to force relogin


Test Set Evaluation Results:
eval_loss: 0.8003551363945007
eval_normal_labels_metrics: {'cross_entropy_loss': 0.41854533553123474, 'precision': 0.7450200849816477, 'recall': 0.9255944872151243, 'f1': 0.8255482682840314, 'accuracy': 0.8858901939988688, 'cohen_kappa': 0.7401486005630533, 'model_pk': 0.006727384516473018, 'model_pk_value_segeval': 0.10541919016688624, 'model_windowdiff': 0.6402827258188656, 'model_windowdiff_value_segeval': 0.15752762437901768, 'random_baseline_accuracy': 0.5002875396186099, 'random_baseline_precision': 0.2920038226533108, 'random_baseline_recall': 0.5005517425931951, 'random_baseline_f1': 0.3688398105627545, 'random_baseline_pk': 0.375576440068596, 'random_pk_value_segeval': 0.4999439034742256, 'random_baseline_windowdiff': 63.23811077902266, 'random_windowdiff_value_segeval': 0.4999439034742256}
eval_hierarchical_labels_metrics: {'level_1': {'cross_entropy_loss': 1.4253383874893188, 'precision': 0.5951899009837123, 'recall': 0.5891009238091965, 'f1': 0.

#### Analyse consecuitve tokens in a confusion matrix (Section 6.2)

In [8]:
from torch.utils.data import DataLoader, TensorDataset
import torch
import numpy as np

# Create TensorDataset for efficient loading
def create_dataset(dataset):
    # List to hold data
    input_ids_list = []
    attention_masks_list = []
    labels_list = []
    
    for item in tqdm(dataset, desc='Creating dataset'):
        input_ids_list.append(item['input_ids'])
        attention_masks_list.append(item['attention_mask'])
        labels_list.append(item['bio_labels'])
    
    # Convert lists to tensors
    input_ids = torch.tensor(input_ids_list)
    attention_masks = torch.tensor(attention_masks_list)
    labels = torch.tensor(labels_list)
    
    return TensorDataset(input_ids, attention_masks, labels)

# DataLoader for batch processing
def create_data_loader(dataset, batch_size=512):
    return DataLoader(create_dataset(dataset), batch_size=batch_size)

def get_consecutive_label_pairs(model, data_loader, label_map, device='cuda'):
    all_true_label_pairs = []
    all_pred_label_pairs = []

    model.to(device)
    model.eval()

    for batch in tqdm(data_loader, total=len(data_loader), desc="Making Predictions", unit="batch"):
        
        input_ids, attention_mask, true_labels = tuple(t.to(device) for t in batch) # move to device
        
        # Make predictions using loaded model
        with torch.no_grad():
            outputs = model(input_ids, attention_mask=attention_mask)
        logits = outputs[0]
        predictions = torch.argmax(logits, dim=2)

        for i in range(input_ids.size(0)):  # iterate over batch
            true_label_seq = true_labels[i].cpu().numpy()
            pred_label_seq = predictions[i].cpu().numpy()
            
            # Remove padding and special tokens
            mask = true_label_seq != -100
            true_label_seq = true_label_seq[mask]
            pred_label_seq = pred_label_seq[mask]
            
            # Get pairs by zipping the sequence with itself offset by one
            true_label_pairs = list(zip(true_label_seq[:-1], true_label_seq[1:]))
            pred_label_pairs = list(zip(pred_label_seq[:-1], pred_label_seq[1:]))
            
            # Convert to labels and store
            true_label_pairs = [(label_map[i], label_map[j]) for i, j in true_label_pairs]
            pred_label_pairs = [(label_map[i], label_map[j]) for i, j in pred_label_pairs]
            
            # Write to lsts
            all_true_label_pairs.extend(true_label_pairs)
            all_pred_label_pairs.extend(pred_label_pairs)
    
    return all_true_label_pairs, all_pred_label_pairs

# Get true and predicted label pairs
test_data_loader = create_data_loader(dataset_dict['test'])
true_label_pairs, pred_label_pairs = get_consecutive_label_pairs(token_model, test_data_loader, index_to_label, device='cuda')

# Write both lists to files
save_path = "Evaluation/token-level-model-CM-final-"
if save_path:
        with open(save_path + 'true_label_pairs.pkl', 'wb') as f:
            pickle.dump(true_label_pairs, f)
        with open(save_path + 'pred_label_pairs.pkl', 'wb') as f:
            pickle.dump(pred_label_pairs, f)

Creating dataset:   0%|          | 0/15618 [00:00<?, ?it/s]

Making Predictions:   0%|          | 0/31 [00:00<?, ?batch/s]

In [7]:
# Load lists with label pairs from files (see above)
load_path = "Evaluation/token-level-model-CM-final-"
if load_path:
    with open(load_path + 'true_label_pairs.pkl', 'rb') as f:
        true_label_pairs = pickle.load(f)
    with open(load_path + 'pred_label_pairs.pkl', 'rb') as f:
        pred_label_pairs = pickle.load(f)

In [34]:
from sklearn.metrics import confusion_matrix

# Flatten list of pairs
flat_true_label_pairs = [f"{i}-{j}" for i, j in true_label_pairs]
flat_pred_label_pairs = [f"{i}-{j}" for i, j in pred_label_pairs]

# Define labels and their order for cm
ordered_labels = [
    'B-B', 'B-I', 'B-O',
    'I-B', 'I-I', 'I-E', 'I-O',
    'E-B', 'E-I', 'E-E', 'E-O',
    'O-B', 'O-I', 'O-E', 'O-O'
]

# Create confusion matrix with reordered labels
cm = confusion_matrix(flat_true_label_pairs, flat_pred_label_pairs, labels=ordered_labels, normalize=None) #'true') (unnormalised for report)
print("Confusion Matrix:\n", cm)
print("Labels:\n", ordered_labels)

Confusion Matrix:
 [[      0       0       0       0       2       0       0       0       0
        0       0       0       0       0       0]
 [      6    2925     101       4    2589       2      31       0       8
        0      71      23     172       0    1327]
 [      0       0       0       0       0       0       0       0       0
        0       0       0       0       0       0]
 [      0       2       0      26     535      61       4       0       0
        0       2      59      11       2      60]
 [      1    1368      36     289 1805881     844    4918       0     189
        0     637    1057    5205       9  181068]
 [      0       2       0       1    2022    3677     142       0       1
        8      18       0      62      62     812]
 [      0       0       0       0       0       0       0       0       0
        0       0       0       0       0       0]
 [      0       1       0       0      28       0       2       0       1
        0      19       1       

In [35]:
def format_number(x):
    '''
    Helper function to format numbers for the confusion matrix to make output more readable
    '''
    if x >= 1_000_000:
        return f"{x/1_000_000:.1f}m"
    elif x >= 1_000:
        return f"{x/1_000:.1f}k"
    else:
        return f"{x:,}"

# Normalise cm by row
row_sums = cm.sum(axis=1, keepdims=True)
cm_normalised = np.divide(cm, row_sums, where=row_sums!=0)
cm_normalised = np.nan_to_num(cm_normalised)  # Replace NaN with 0

# Create DataFrame
df_cm = pd.DataFrame(cm, index=ordered_labels, columns=ordered_labels)
df_cm_normalized = pd.DataFrame(cm_normalised, index=ordered_labels, columns=ordered_labels)

# Format numbers in df
df_cm_formatted = df_cm.applymap(format_number)

# Heatmap coloring with shades of green
def color_cell(value, max_value):
    normalized_value = value / max_value
    red_intensity = 255 - int(normalized_value * 128)
    return f'\\cellcolor[RGB]{{{red_intensity},255,{red_intensity}}}'

max_value = df_cm_normalized.values.max()
for i in range(len(ordered_labels)):
    for j in range(len(ordered_labels)):
        color = color_cell(df_cm_normalized.iloc[i, j], max_value)
        df_cm_formatted.iloc[i, j] = color + df_cm_formatted.iloc[i, j]

# Make diagonal elements bold
for label in ordered_labels:
    df_cm_formatted.at[label, label] = '\\textbf{' + df_cm_formatted.at[label, label] + '}'

# Convert to LaTeX code
latex_code = df_cm_formatted.to_latex(escape=False)

latex_preamble = fr'''
\begin{{table}}[htbp]
\centering
\footnotesize
\caption{{Confusion matrix of consecutively predicted labels. Values are highlighted in a row-wise normalised manner to account for the significant class imbalance.}}
\label{{tab:confusion_matrix}}
\begin{{tabular}}{{@{{}}l{"c"*len(ordered_labels)}@{{}}}}
% \toprule
& \multicolumn{{{len(ordered_labels)}}}{{c}}{{\textbf{{Predictions}}}} \\
\cmidrule(lr){{2-{len(ordered_labels)+1}}}
\textbf{{Gold}}
'''.rstrip()

latex_postamble = r'''
\end{table}
'''

# Merge latex code together
latex_code = latex_preamble + latex_code.split('\n', 2)[-1] + latex_postamble
latex_code = latex_code.replace('\label{tab:confusion_matrix}', '\\label{tab:confusion_matrix} \n\\setlength{\\tabcolsep}{4.5pt} % reduce column spacing')

print(latex_code)


\begin{table}[htbp]
\centering
\footnotesize
\caption{Confusion matrix of consecutively predicted labels. Values are highlighted in a row-wise normalised manner to account for the significant class imbalance.}
\label{tab:confusion_matrix} 
\setlength{\tabcolsep}{4.5pt} % reduce column spacing
\begin{tabular}{@{}lccccccccccccccc@{}}
% \toprule
& \multicolumn{15}{c}{\textbf{Predictions}} \\
\cmidrule(lr){2-16}
\textbf{Gold} & B-B & B-I & B-O & I-B & I-I & I-E & I-O & E-B & E-I & E-E & E-O & O-B & O-I & O-E & O-O \\
\midrule
B-B & \textbf{\cellcolor[RGB]{255,255,255}0} & \cellcolor[RGB]{255,255,255}0 & \cellcolor[RGB]{255,255,255}0 & \cellcolor[RGB]{255,255,255}0 & \cellcolor[RGB]{127,255,127}2 & \cellcolor[RGB]{255,255,255}0 & \cellcolor[RGB]{255,255,255}0 & \cellcolor[RGB]{255,255,255}0 & \cellcolor[RGB]{255,255,255}0 & \cellcolor[RGB]{255,255,255}0 & \cellcolor[RGB]{255,255,255}0 & \cellcolor[RGB]{255,255,255}0 & \cellcolor[RGB]{255,255,255}0 & \cellcolor[RGB]{255,255,255}0 & \cellcol

  df_cm_formatted = df_cm.applymap(format_number)


### Function to predict tags for a page from the data_dict
(data_dict not available here due to confidentiality)

In [7]:
def predict_page(block_text):
    inputs = tokenizer(block_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    inputs.pop('token_type_ids', None) 

    # Move tensors to same device as the model
    inputs = {k: v.to(token_model.device) for k, v in inputs.items()}

    # Get predictions                        
    with torch.no_grad():
        outputs = token_model(**inputs)

    # Process outputs for token classification
    logits = outputs[0]
    predicted_labels = torch.argmax(logits, dim=2)

    # Convert predicted labels to tags (e.g., 'B', 'I', 'O')
    predicted_tags = [index_to_label[label.item()] for label in predicted_labels[0]]

    return predicted_tags

doc_id = "27605974"
page_index = "18"
page_text = data_dict[doc_id][page_index]['full_text']
print(page_text)

predicted_tags = predict_page(page_text)
print(f"Predicted tags: {predicted_tags}")


19 
CONSULTATION PAPER ON DRAFT RTS UNDER ARTICLE 45(6) OF DIRECTIVE (EU) 2015/849 
 
 
 
 C. Baseline scenario 
 
In the baseline scenario, Directive (EU) 2015/849 would be transposed without accompanying 
draft RTS under Article 45(6). This means that Member States and credit and financial institutions 
may adopt divergent views about the way credit and financial institutions should address the risk 
associated with business in  third  countries where the implementation of local law does not 
permit the application of group-wide policies and procedures. 
 
 D. Options considered 
 
Option 1: The draft RTS could require credit and financial institutions to close down all 
relationships and withdraw entirely from business in the third country. 
 
Option 2: The draft RTS could set out minimum  actions and additional measures credit and 
financial institutions have to apply in all cases, irrespective of the risk and the type of legal 
impediment. 
 
Option 3: The draft RTS could distingu

### Code for better visualisation of model predictions for a page

In [8]:
from torch.nn.functional import softmax
from detailed_labels_handler import decode_detailed_labels

def display_token_predictions(model, tokenizer, text):
    # Tokenize input text and prepare inputs for model (for simplicity, we truncate the text here)
    inputs = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    inputs.pop('token_type_ids', None) 
    input_ids = inputs["input_ids"][0]

    # Move tensors to the same device as the model
    inputs = {k: v.to(model.device) for k, v in inputs.items()}

    # Get predictions
    with torch.no_grad():
        outputs = model(**inputs)

   # Process outputs for token classification
    logits = outputs[0]
    predictions = torch.argmax(logits, dim=2)[0]
    probabilities = softmax(logits, dim=2)[0]
    
    # Auxiliary objective: detailed labels
    if HIER_LABELS_LEVELS:
        detailed_logits = outputs[1]
        detailed_predictions = []
        if detailed_logits is not None:
            encoded_detailed_labels_per_token = [[] for _ in range(len(input_ids))] # initialise list of lists to store detailed labels per token
            for level_logits in detailed_logits:
                level_probabilities = softmax(level_logits, dim=-1)
                predicted_detailed_labels = torch.argmax(level_probabilities, dim=-1).squeeze()
                if predicted_detailed_labels.dim() == 0:
                    # in case of only one senetnce --> scalar
                    encoded_detailed_labels_per_token[0].append(predicted_detailed_labels.item())
                else:
                    # Otherwise, iterate as normal
                    for token_idx, label in enumerate(predicted_detailed_labels):
                        encoded_detailed_labels_per_token[token_idx].append(label.item())

            for encoded_labels in encoded_detailed_labels_per_token:
                decoded_label = decode_detailed_labels([encoded_labels], mapping_dicts, HIER_LABELS_LEVELS)
                detailed_predictions.append(decoded_label)
    else: # placeholder
        detailed_predictions = ["N/A"]*len(input_ids)

    # Convert input ids to tokens and predicted labels to tags
    tokens = tokenizer.convert_ids_to_tokens(input_ids)
    predicted_tags = [index_to_label[label.item()] for label in predictions]

    # Display tokens with their corresponding tags and probabilities
    for token, tag, prob, detailed_tag in zip(tokens, predicted_tags, probabilities, detailed_predictions):
        prob_value = prob[torch.argmax(prob)].item()  # Get highest probability value
        if token not in [cls_token, sep_token, pad_token]:  # Skip special tokens as defined in code cells above
            print(f"{token:15} : {tag} ({prob_value:.4f}) --> {detailed_tag}")


# Get prediction for page from data_dict below
doc_id = "27416407"
page_index = "3" 
page_text = data_dict[doc_id][page_index]['full_text']
display_token_predictions(token_model, tokenizer, page_text)

OP              : O (0.9969) --> ['definitions']
IN              : O (0.9972) --> ['definitions']
ION             : O (0.9975) --> ['definitions']
ĠON             : O (0.9974) --> ['definitions']
ĠAM             : O (0.9977) --> ['definitions']
L               : O (0.9972) --> ['definitions']
/               : O (0.9977) --> ['definitions']
C               : O (0.9970) --> ['definitions']
FT              : O (0.9967) --> ['definitions']
ĠAND            : O (0.9979) --> ['definitions']
ĠC              : O (0.9979) --> ['definitions']
UST             : O (0.9976) --> ['definitions']
OM              : O (0.9978) --> ['definitions']
ERS             : O (0.9979) --> ['definitions']
ĠWHO            : O (0.9981) --> ['definitions']
ĠARE            : O (0.9981) --> ['definitions']
ĠAS             : O (0.9981) --> ['definitions']
YL              : O (0.9978) --> ['definitions']
UM              : O (0.9978) --> ['definitions']
ĠSEE            : O (0.9975) --> ['definitions']
K               : O 

### Function to create snippets / post-processing algorithm (Section 5.4)

![](../Misc/post-processing.png)

In [9]:
def majority_vote(tags):
    '''
    Majority vote function to determine the detailed label for a snippet
    '''
    if not tags:
        return "N/A"  # Default label if no tags are present
    
    # Flatten into a single list
    flat_tags = [item for sublist in tags for item in sublist if item is not None]
    if not flat_tags:
        return "N/A"  # Return default if flattened list is empty

    # Return most frequent tag
    return max(set(flat_tags), key=flat_tags.count)

def create_snippets(tokens, tags, tokenizer, merged_detailed_tags=None, min_snippet_length=10, ignore_o_threshold = 4):
    '''
    Post-processing code to create snippets from NER (i.e. BIO(E)) tags
    '''
    snippets = []
    current_snippet = []
    current_detailed_tags = []
    
    # Sequentially process each token and its corresponding tag
    idx = 0
    while idx < len(tokens):
        
        # Get token and tag
        token = tokens[idx]
        tag = tags[idx]
        detailed_tag = merged_detailed_tags[idx] if merged_detailed_tags else None

        # Decide what to do based on the tag
        # "B": Begin a new snippet
        # "I": Continue the current snippet
        # "O": End the current snippet
        # "E": End the current snippet
        if tag == 'I' and token not in [cls_token, sep_token]:
            current_snippet.append(token)
            if detailed_tag:
                current_detailed_tags.append(detailed_tag)
        elif tag == 'O':
            o_count = 1
            while idx + o_count < len(tags) and tags[idx + o_count] == 'O':
                o_count += 1
            
            if current_snippet and o_count <= ignore_o_threshold:
                # Append all 'O' sequence at once if below threshold
                current_snippet.extend(tokens[idx:idx + o_count])
                if detailed_tag:
                    current_detailed_tags.extend(merged_detailed_tags[idx:idx + o_count])
                idx += o_count - 1  # Skip 'O' sequence to the end
            else:
                # create snippet
                if len(current_snippet) >= min_snippet_length:
                    snippet_text = tokenizer.convert_tokens_to_string(current_snippet)
                    snippets.append((snippet_text.strip(), majority_vote(current_detailed_tags)))
                current_snippet = []
                current_detailed_tags = []
        elif tag == 'B':
            # create snippet
            if current_snippet and len(current_snippet) >= min_snippet_length:
                snippet_text = tokenizer.convert_tokens_to_string(current_snippet)
                snippets.append((snippet_text.strip(), majority_vote(current_detailed_tags)))
            current_snippet = [token]
            current_detailed_tags = [detailed_tag] if detailed_tag else []
        elif tag == 'E':
            if current_snippet:
                current_snippet.append(token)
                if detailed_tag:
                    current_detailed_tags.append(detailed_tag)
                if len(current_snippet) >= min_snippet_length:
                    snippet_text = tokenizer.convert_tokens_to_string(current_snippet)
                    snippets.append((snippet_text.strip(), majority_vote(current_detailed_tags)))
                current_snippet = []
                current_detailed_tags = []
        idx += 1

    # Check if there is a current snippet at the end
    if current_snippet and len(current_snippet) >= min_snippet_length:
        snippet_text = tokenizer.convert_tokens_to_string(current_snippet)
        snippets.append((snippet_text.strip(), majority_vote(current_detailed_tags)))
    return snippets

def display_token_predictions_and_create_snippets(model, tokenizer, text, window_size=512, stride=256, min_snippet_length=10, ignore_o_threshold=4, verbose=True):
    '''
    Get token predictions and create snippets from the predictions using the custom post-processing algorithm.
    '''
    adjusted_window_size = window_size - 2  # Adjust for special tokens

    # Tokenize text with stride handling to get multiple windows (including overlapping windows)
    tokenized_text = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=adjusted_window_size, return_overflowing_tokens=True, stride=stride, add_special_tokens=True)

    # Setup inputs dict for model
    inputs = {"input_ids": [], "attention_mask": []}
    for window in tokenized_text.encodings:
        inputs["input_ids"].append(torch.tensor([window.ids]).to(model.device))
        inputs["attention_mask"].append(torch.tensor([window.attention_mask]).to(model.device))

    # Convert list of tensors to a single tensor for batch processing
    batch_input_ids = torch.cat(inputs["input_ids"], dim=0)
    batch_attention_mask = torch.cat(inputs["attention_mask"], dim=0)
    batch_inputs = {"input_ids": batch_input_ids, "attention_mask": batch_attention_mask}

    # Model prediction in one go (significantly more efficient due to batch handling)
    with torch.no_grad():
        outputs = model(**batch_inputs)
    logits = outputs[0]
    predicted_labels = torch.argmax(logits, dim=2)
    
    # Auxiliary Objective: detailed labels
    detailed_predictions = []
    if len(outputs) > 1 and outputs[1] is not None:
        detailed_logits = outputs[1]
        encoded_detailed_labels_per_token = [[] for _ in range(len(tokenized_text.input_ids[0]))]  # Adjust to token count

        for level_logits in detailed_logits:
            level_probabilities = torch.nn.functional.softmax(level_logits, dim=-1)
            predicted_detailed_labels = torch.argmax(level_probabilities, dim=-1)
            
            # Predicted_detailed_labels shape is [batch_size, num_tokens]
            # Batch Handling:
            for token_idx in range(predicted_detailed_labels.shape[1]):  # Iterate over tokens
                labels_for_token = predicted_detailed_labels[:, token_idx]  # All batch predictions for this token
                encoded_detailed_labels_per_token[token_idx].extend(labels_for_token.tolist())
        # Decode using heloer function
        for encoded_labels in encoded_detailed_labels_per_token:
            decoded_label = detailed_labels_handler.decode_detailed_labels([encoded_labels], mapping_dicts, HIER_LABELS_LEVELS)
            detailed_predictions.append(decoded_label)
    else:
        detailed_predictions = ["N/A"] * len(tokenized_text.input_ids[0])
        
    # Merge tokens and tags from all windows back together --> this is the same as the original text for the page  
    merged_tokens = []
    merged_tags = []
    merged_detailed_tags = []
    for window_idx, window in enumerate(tokenized_text.encodings):
        start_index = 1 if window_idx == 0 else stride + 1
        end_index = -1 if window_idx < len(tokenized_text.encodings) - 1 else None
        
        tokens = tokenizer.convert_ids_to_tokens(batch_input_ids[window_idx])
        window_tags = [index_to_label[label.item()] for label in predicted_labels[window_idx][start_index:end_index]]
        window_detailed_tags = detailed_predictions[start_index:end_index]
        window_tokens = [tok for tok in tokens[start_index:end_index] if tok not in [cls_token, sep_token, pad_token]]
        
        merged_tokens.extend(window_tokens)
        merged_tags.extend(window_tags)
        merged_detailed_tags.extend(window_detailed_tags)
    
    if verbose:
        for token, tag, detailed_tag in zip(merged_tokens, merged_tags, merged_detailed_tags):
            print(f"{token:15} : {tag} --> {detailed_tag}")

    # Create snippets for page
    snippets = create_snippets(merged_tokens, merged_tags, tokenizer, merged_detailed_tags=merged_detailed_tags,  min_snippet_length=min_snippet_length, ignore_o_threshold=ignore_o_threshold)
    
    if verbose:
        for idx, (snippet, detailed_tag) in enumerate(snippets, 1):
            print(f'---------------------- Snippet {idx} | Label: {detailed_tag} ----------------------')
            print(snippet + "\n")
    
    return snippets

# Use above function with exemplary page from data_dict
# Note: This is how the snippet identifier system would be used in practice! (We would only want to pass the text of a page and get the snippets)
doc_id = "20829680" 
page_index = "34" 
page_text = data_dict[doc_id][page_index]['full_text']
snippets = display_token_predictions_and_create_snippets(token_model, tokenizer, page_text, ignore_o_threshold=7, verbose=True)

### Example of how custom PDFs could be fed into the model to identify FRI

In [28]:
import fitz
from collections import defaultdict

def default_page():
    return {'blocks': [], 'full_text': ""}

def create_adjusted_dict(pdf_path: str, doc_id: str):
    '''
    Read custom pdf using PyMuPDF (alias fitz) and create a dictionary with the same structure as the data_dict
    '''
    pdf_data_dict = defaultdict(lambda: defaultdict(default_page))

    # Open PDF file
    doc = fitz.open(pdf_path)
    # Process each page
    for page_index, page in enumerate(doc):
        # Extract text blocks from the page
        blocks = page.get_text("blocks")
        # Process each block
        block_texts = []
        for block in blocks:
            block_text = block[4].strip()  # Extract the text from the block
            block_texts.append(block_text)
            pdf_data_dict[doc_id][str(page_index)]['blocks'].append({'text': block_text})
        full_page_text = ' '.join([block_text for block_text in block_texts if block_text is not None])
        pdf_data_dict[doc_id][str(page_index)]['full_text'] = full_page_text
    return pdf_data_dict


####################################################################

page_index = 21
file_index = 2 # easy indexing in example pdfs folder

pdf_files = [file for file in os.listdir("../Example_PDFs") if file.endswith('.pdf')]
print(pdf_files)
pdf_path = os.path.join("../Example_PDFs", pdf_files[file_index])

# create dict for local file
adjusted_dict = create_adjusted_dict(pdf_path, 'test_pdf')
page_text = adjusted_dict["test_pdf"][str(page_index)]['full_text']
snippets = display_token_predictions_and_create_snippets(token_model, tokenizer, page_text)

['EXAMPLE1_finregulation.pdf', 'tlh45_Proposal.pdf', 'EXAMPLE2_CELEX_32018L0843_EN_TXT.pdf', 'EXAMPLE3_uksi_20170692_en.pdf', 'EXAMPLE0_wipo_financial_regulations.pdf']
Ġ(              : I --> ['str']
18              : I --> ['str']
)               : I --> ['str']
Ġin             : I --> ['str']
ĠArticle        : I --> ['str']
Ġ32             : I --> ['str']
Ġthe            : I --> ['str']
Ġfollowing      : I --> ['str']
Ġparagraph      : I --> ['str']
Ġis             : I --> ['str']
Ġadded          : I --> ['str']
:               : I --> ['str']
ĠâĢ             : I --> ['str']
ĺ               : I --> ['str']
9               : I --> ['str']
.               : I --> ['str']
Ġ               : I --> ['str']
Ċ               : I --> ['str']
Without         : I --> ['str']
Ġprejudice      : I --> ['str']
Ġto             : I --> ['str']
ĠArticle        : I --> ['str']
Ġ34             : I --> ['str']
(               : I --> ['str']
2               : I --> ['str']
),              : I --> ['str']

In [None]:
def get_sample_doc_page_ids_from_test_set(test_dataset, num_samples=20):
    '''
    Helper function to exract IDs from dataset
    '''
    sample_doc_page_ids = []
    sample_doc_ids = set()

    for data_point in test_dataset:
        metadata = data_point['metadata']
        doc_id = metadata['doc_id']
        page_id = metadata['page_id']
        
        if doc_id not in sample_doc_ids:
            sample_doc_ids.add(doc_id)
            sample_doc_page_ids.append((doc_id, page_id))
        
        if len(sample_doc_ids) >= num_samples:
            break

    return sample_doc_page_ids, list(sample_doc_ids)

sample_doc_page_ids, sample_doc_ids = get_sample_doc_page_ids_from_test_set(test_dataset, 20)
print(sample_doc_ids)

['22396093', '22371945', '19180743', '22683853', '20518933', '19534497', '20575785', '19507185', '26369135', '19586509', '19073734', '17459693', '20545737', '20529447', '22243253', '19592316', '20202989', '22057835', '21450084', '20753115']


## Conduct Predictions on Test Pages:

Collect Documents and pages from test dataset for easy access

In [8]:
from collections import defaultdict

test_doc_pages_overview = defaultdict(list)
for sample in tqdm(dataset_dict['test']):
    doc_id = sample['metadata']['doc_id']
    page_index = sample['metadata']['page_id']
    if page_index not in test_doc_pages_overview[doc_id]:
        test_doc_pages_overview[doc_id].append(page_index)

  0%|          | 0/15618 [00:00<?, ?it/s]

#### Method to predict and process snippets in batches:

In [10]:
import torch.nn.functional as F
import torch
from functools import partial
from tqdm.contrib.concurrent import process_map

import importlib
importlib.reload(detailed_labels_handler)

def process_chunk(data, batch_input_ids, predicted_labels, detailed_predictions, tokenizer, stride, min_snippet_length, ignore_o_threshold):
    '''
    Worker function to process a chunk of data and create snippets
    '''
    index, encodings = data
    merged_tokens = []
    merged_tags = []
    merged_detailed_tags = []
    current_idx = index

    # Merge tokens and tags from all windows back together
    for window_idx, window in enumerate(encodings):
        start_index = 1 if window_idx == 0 else stride + 1
        end_index = -1 if window_idx < len(encodings) - 1 else None

        tokens = tokenizer.convert_ids_to_tokens(batch_input_ids[current_idx])
        window_tags = [index_to_label[label.item()] for label in predicted_labels[current_idx][start_index:end_index]]
        window_tokens = [tok for tok in tokens[start_index:end_index] if tok not in [cls_token, sep_token, pad_token]]
        window_detailed_tags = detailed_predictions[start_index:end_index] 

        merged_tokens.extend(window_tokens)
        merged_tags.extend(window_tags)
        merged_detailed_tags.extend(window_detailed_tags)
        current_idx += 1

    # Create snippets for page (i.e. merged windows)
    snippets = create_snippets(merged_tokens, merged_tags, tokenizer, merged_detailed_tags=merged_detailed_tags, min_snippet_length=min_snippet_length, ignore_o_threshold=ignore_o_threshold)
    return snippets

def create_snippets_multiprocessing(all_encodings, batch_input_ids, predicted_labels
                                    , detailed_predictions, tokenizer, stride
                                    , min_snippet_length, ignore_o_threshold, workers, chunksize):
    '''
    Create snippets from the predictions using the custom post-processing algorithm with multiprocessing
    '''
    
    # Partial function to pass arguments to process_chunk (easier handling for multiprocessing)
    partial_process_chunk = partial(
        process_chunk, 
        batch_input_ids=batch_input_ids, 
        predicted_labels=predicted_labels,
        detailed_predictions=detailed_predictions,
        tokenizer=tokenizer, 
        stride=stride, 
        min_snippet_length=min_snippet_length, 
        ignore_o_threshold=ignore_o_threshold
    )

    # Creating index and encodings pair for proper order (very important for correct snippet creation)
    data = []
    encding_idx = 0
    current_idx = 0
    while encding_idx < len(all_encodings):
        encodings = all_encodings[encding_idx]
        num_windows = len(encodings)
        data.append((current_idx, encodings))
        current_idx += num_windows
        encding_idx += 1
    
    results = process_map(partial_process_chunk, data, max_workers=workers, chunksize=chunksize, desc="Creating snippets")

    return results



def batch_predict(model, tokenizer, texts, metadata, window_size=512, stride=256, sub_batch_size=256):
    '''
    Process multiple texts in batches and create predictions
    Window size and stride/overap are used to handle long texts.
    '''
    
    adjusted_window_size = window_size - 2 # account for special tokens
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model = model.to(device)

    # Tokenize all texts with stride handling to get multiple windows for each text, considering overlap
    # Prepare inputs for model
    all_inputs = {"input_ids": [], "attention_mask": [], "metadata": []}
    all_encodings = []
    max_length = 0 
    for text, meta in tqdm(zip(texts, metadata), desc="Tokenizing", total=len(texts)):
        tokenized_text = tokenizer(text, return_tensors="pt", padding=True, truncation=True, max_length=adjusted_window_size, return_overflowing_tokens=True, stride=stride, add_special_tokens=True)
        all_encodings.append(tokenized_text.encodings)
        for window in tokenized_text.encodings:
            tensor_ids = torch.tensor([window.ids]).to(model.device)
            all_inputs["input_ids"].append(tensor_ids)
            all_inputs["attention_mask"].append(torch.tensor([window.attention_mask]).to(model.device))
            all_inputs["metadata"].append(meta)
            max_length = max(max_length, tensor_ids.size(1))
    
    # Apply padding to ensure uniform length within batch
    padded_input_ids = [F.pad(tensor, (0, max_length - tensor.shape[1]), 'constant', 0) for tensor in all_inputs["input_ids"]] # pad difference between max_length and current length with 0
    padded_attention_mask = [F.pad(tensor, (0, max_length - tensor.shape[1]), 'constant', 0) for tensor in all_inputs["attention_mask"]] # pad difference between max_length and current length with 0
    
    # Convert list of tensors to a single tensor for batch processing
    batch_input_ids = torch.cat(padded_input_ids, dim=0)
    batch_attention_mask = torch.cat(padded_attention_mask, dim=0)
    batch_inputs = {"input_ids": batch_input_ids, "attention_mask": batch_attention_mask}

    print("Number of samples:", len(padded_attention_mask))
    
    # Model predictions in one go
    model.eval()
    
    # Calculate bacthes as number of samples divided by sub_batch_size + 1 if there is a remainder
    total_batches = len(batch_input_ids) // sub_batch_size + (len(batch_input_ids) % sub_batch_size != 0)
    
    predicted_labels = [] 
    detailed_predictions = [] 
    batch_metadata = []
    with tqdm(total=total_batches, desc="Predicting") as pbar:
        for i in range(0, len(batch_input_ids), sub_batch_size):
            
            # Prepare sub-batch
            sub_input_ids = batch_input_ids[i:i+sub_batch_size]
            sub_attention_mask = batch_attention_mask[i:i+sub_batch_size]
            sub_metadata = all_inputs["metadata"][i:i+sub_batch_size]
            batch_inputs = {"input_ids": sub_input_ids, "attention_mask": sub_attention_mask}

            # Make predictions with token-level model
            with torch.no_grad():
                outputs = model(**batch_inputs)
            logits = outputs[0]
            labels = torch.argmax(logits, dim=2)
            predicted_labels.extend(labels)
            batch_metadata.extend(sub_metadata) 
            
            # Get detailed labels (auxiliary objective)
            if len(outputs) > 1 and outputs[1] is not None:
                    detailed_logits = outputs[1]
                    level_predictions = [torch.argmax(level_logits, dim=2) for level_logits in detailed_logits]
                    merged_predictions = torch.stack(level_predictions, dim=2)
                    detailed_predictions.extend(merged_predictions)

            pbar.update()
    
    # Decode detailed labels using helper function
    decoded_detailed_labels = [
        detailed_labels_handler.decode_detailed_labels(sample_predictions.tolist(), mapping_dicts, HIER_LABELS_LEVELS) 
        for sample_predictions in detailed_predictions
    ]
    
    # Result with metadata (used for debugging)
    results_with_metadata = list(zip(batch_metadata, predicted_labels, decoded_detailed_labels))

    return all_encodings, batch_input_ids, predicted_labels, decoded_detailed_labels, tokenizer

In [None]:
import time
from tqdm.notebook import tqdm

doc_count = 0
page_count = 0

# Step 1: Collect all texts and metadata for batch processing
all_texts = []
metadata = []
for doc_id in tqdm(test_doc_pages_overview):
    for page_index in test_doc_pages_overview[doc_id]:
        page_text = data_dict[doc_id][page_index]['full_text']
        all_texts.append(page_text)
        metadata.append((doc_id, page_index))
        page_count += 1
    doc_count += 1
    #if doc_count >= 10:
        #break

# Step 2: Process all texts in batches
stride = OVERLAP
sub_batch_size = 512 if MODEL == "roberta-large" else 2048 # batch size needs to be smaller for larger model to not run out of memory

start_time = time.time()  # Start timing
all_encodings, batch_input_ids, predicted_labels, detailed_predictions, tokenizer = batch_predict(token_model
                                                                                            , tokenizer
                                                                                            , all_texts
                                                                                            , metadata
                                                                                            , window_size=512
                                                                                            , stride = stride
                                                                                            , sub_batch_size=sub_batch_size)
end_time = time.time()  # End timing

# Move to cpu for subsequent processing and snippet creation
batch_input_ids = batch_input_ids.to("cpu")
predicted_labels = [label.to("cpu") for label in predicted_labels]

# Create snippets using multiprocessing
snippets_collection = create_snippets_multiprocessing(all_encodings
                                            , batch_input_ids
                                            , predicted_labels
                                            , detailed_predictions
                                            , tokenizer
                                            , stride = stride
                                            , min_snippet_length = 10
                                            , ignore_o_threshold = 4
                                            # , multiprocessing = True
                                            , workers = 8
                                            , chunksize = 512)

# Step 3: Assign results back to the original data_dict
for ((doc_id, page_index), snippets) in tqdm(zip(metadata, snippets_collection), desc="Assign results", total=len(metadata)):
    snippet_texts = []
    detailed_tags = []
    for snippet_text, detailed_tag in snippets:
        if snippet_text.strip():
            snippet_texts.append(snippet_text)
            detailed_tags.append(detailed_tag)
    predicted_snippets = []

    if snippet_texts:
        tokenized_snippets = tokenizer(snippet_texts, padding=True, truncation=True, return_tensors="pt")
        
        for snippet_text, detailed_tag, snippet_tokenized_ids in zip(snippet_texts, detailed_tags, tokenized_snippets['input_ids']):
            snippet_tokenized_text = tokenizer.convert_ids_to_tokens(snippet_tokenized_ids)
            dict_to_add = {'text': snippet_text, 'tokenized_text': snippet_tokenized_text, 'detailed_tag': detailed_tag}
            predicted_snippets.append(dict_to_add)

    # Write predicted snippets into data_dict, following structure of other elements
    data_dict[doc_id][page_index]['predicted_snippets'] = predicted_snippets

total_time = end_time - start_time  # Calculate total time

print(f"Total time for processing {doc_count} documents (={page_count} pages): {total_time} seconds.")

## Calculate evaluation metrics (Chapter 6 in report)

In [8]:
import importlib
import evaluator
importlib.reload(evaluator)

metrics_config = {
    'iou': True,
    'bleu': True,
    'jaccard': True,
    'precision': True,
    'recall': True,
    'f1': True,
    'precision_region_lvl': True,
    'recall_region_lvl': True,
    'f1_region_lvl': True,
    'edit_distance': True,
    'rouge-1-f': True,
    'rouge-2-f': True,
    'rouge-l-f': True,
    'pk': False,
    'windowdiff': False,
    'cohen_kappa': False
}


In [33]:
aggregated_metrics_predicted_snippets = evaluator.evaluate_snippets_parallel(data_dict, "predicted_snippets", "refined_regions", metrics_config)
aggregated_metrics_predicted_snippets['inference_time'] = total_time
aggregated_metrics_predicted_snippets['pages'] = page_count
aggregated_metrics_predicted_snippets['batch_size'] = sub_batch_size
aggregated_metrics_predicted_snippets['GPU'] = torch.cuda.get_device_properties(0).name

Evaluation Started.


  0%|          | 0/4807 [00:00<?, ?it/s]

Average Precision Score: 0.6530978259125895
Average Recall Score: 0.8428955629920146
Average F1 Score: 0.7359568080237774
Average Iou Score: 0.7685942283292141
Average Bleu Score: 0.5939380039220407
Average Jaccard Score: 0.7685942283292143
Average Edit_distance Score: 148.65251626155427
Average Precision_region_lvl Score: 0.6631284134971608
Average Recall_region_lvl Score: 0.8922598553166202
Average F1_region_lvl Score: 0.6980769874318282
Average Rouge-1-f Score: 0.8588084778041937
Average Rouge-2-f Score: 0.8334367448925841
Average Rouge-l-f Score: 0.8574261693576304


In [34]:
MODEL

'roberta-base'

In [35]:
PRETRAINED_MODEL

'or6jqv'

In [36]:
aggregated_metrics_predicted_snippets

{'precision': 0.6530978259125895,
 'recall': 0.8428955629920146,
 'f1': 0.7359568080237774,
 'iou': 0.7685942283292141,
 'bleu': 0.5939380039220407,
 'jaccard': 0.7685942283292143,
 'edit_distance': 148.65251626155427,
 'precision_region_lvl': 0.6631284134971608,
 'recall_region_lvl': 0.8922598553166202,
 'f1_region_lvl': 0.6980769874318282,
 'rouge-1-f': 0.8588084778041937,
 'rouge-2-f': 0.8334367448925841,
 'rouge-l-f': 0.8574261693576304,
 'inference_time': 120.92954063415527,
 'pages': 6918,
 'batch_size': 2048,
 'GPU': 'NVIDIA A100-SXM4-80GB'}

In [37]:
# Save result to json file
with open(f'Evaluation/{DOMAIN}/token-{str(MODEL).replace("/", "-")}_{PRETRAINED_MODEL if PRETRAINED_MODEL else unique_tag}.json', 'w') as f:
    json.dump(aggregated_metrics_predicted_snippets, f, indent=4)

### Validation loop for best snippet creation parameters (see Section 5.4 in report)

Clean data_dict predictions, if relevant

In [10]:
for doc_id, page_dict in tqdm(data_dict.items()):
    for page_index, page_content in page_dict.items():
        if page_index in ['title', 'doc_long_id']:
            continue
        data_dict[doc_id][page_index]['predicted_snippets'] = None

  0%|          | 0/1149 [00:00<?, ?it/s]

In [11]:
from collections import defaultdict

# collect validation data in dict for easy access (same as done above for test data)
validation_doc_pages_overview = defaultdict(list)
for sample in tqdm(dataset_dict['test']):
    doc_id = sample['metadata']['doc_id']
    page_index = sample['metadata']['page_id']
    if page_index not in validation_doc_pages_overview[doc_id]:
        validation_doc_pages_overview[doc_id].append(page_index)

  0%|          | 0/15618 [00:00<?, ?it/s]

In [13]:
doc_count = 0
page_count = 0

# Step 1: Collect all texts and metadata for batch processing of validation data
all_texts = []
metadata = []
for doc_id in tqdm(validation_doc_pages_overview):
    for page_index in validation_doc_pages_overview[doc_id]:
        page_text = data_dict[doc_id][page_index]['full_text']
        all_texts.append(page_text)
        metadata.append((doc_id, page_index))
        page_count += 1
    doc_count += 1

# Step 2: Process all texts in batches
stride = OVERLAP

all_encodings, batch_input_ids, predicted_labels, tokenizer = batch_predict(token_model
                                                                                            , tokenizer
                                                                                            , all_texts
                                                                                            , window_size=512
                                                                                            , stride = stride
                                                                                            , sub_batch_size=512 if MODEL == "roberta-large" else 2048)
batch_input_ids = batch_input_ids.to("cpu")
predicted_labels = [label.to("cpu") for label in predicted_labels]

  0%|          | 0/173 [00:00<?, ?it/s]

Tokenizing:   0%|          | 0/6918 [00:00<?, ?it/s]

Number of samples: 15763


Predicting:   0%|          | 0/8 [00:00<?, ?it/s]

In [17]:
import importlib
import evaluator
importlib.reload(evaluator)

metrics_config = {
    'iou': True,
    'bleu': True,
    'jaccard': True,
    'precision': True,
    'recall': True,
    'f1': True,
    'precision_region_lvl': True,
    'recall_region_lvl': True,
    'f1_region_lvl': True,
    'edit_distance': False,
    'rouge-1-f': True,
    'rouge-2-f': True,
    'rouge-l-f': True,
    'pk': False,
    'windowdiff': False,
    'cohen_kappa': False
}


# Test values in simple grid search
all_scores = []
best_run = {}
best_f1 = 0
run_id = 0
ignore_o_thresholds = [0, 2, 4, 8, 10]
min_snippet_lengths = [0, 4, 8, 10, 12]
for ignore_o_threshold in ignore_o_thresholds:
    for min_snippet_length in min_snippet_lengths:
        print(f"Run {run_id} out of {len(ignore_o_thresholds)*len(min_snippet_lengths)}. Combination tested: {ignore_o_threshold=} and {min_snippet_length=}. Currently best: {best_run}")

        # Create snippets through adjusted post-proessing algorithm using multiprocessing
        snippets_collection = create_snippets_multiprocessing(all_encodings
                                                    , batch_input_ids
                                                    , predicted_labels
                                                    , tokenizer
                                                    , stride = stride
                                                    , min_snippet_length = min_snippet_length
                                                    , ignore_o_threshold = ignore_o_threshold
                                                    # , multiprocessing = True
                                                    , workers = 8
                                                    , chunksize = 512)

        # Step 3: Assign results back to original data_dict
        for ((doc_id, page_index), snippets) in tqdm(zip(metadata, snippets_collection), desc="Assign results", total=len(metadata)):
            snippet_texts = []
            detailed_tags = []
            for snippet_text, detailed_tag in snippets:
                if snippet_text.strip():
                    snippet_texts.append(snippet_text)
                    detailed_tags.append(detailed_tag)
            predicted_snippets = []

            if snippet_texts:
                tokenized_snippets = tokenizer(snippet_texts, padding=True, truncation=True, return_tensors="pt")
                
                for snippet_text, detailed_tag, snippet_tokenized_ids in zip(snippet_texts, detailed_tags, tokenized_snippets['input_ids']):
                    snippet_tokenized_text = tokenizer.convert_ids_to_tokens(snippet_tokenized_ids)
                    dict_to_add = {'text': snippet_text, 'tokenized_text': snippet_tokenized_text, 'detailed_tag': detailed_tag}
                    predicted_snippets.append(dict_to_add)

            data_dict[doc_id][page_index]['predicted_snippets'] = predicted_snippets
        
        # Get scores for the current run
        aggregated_metrics_predicted_snippets = evaluator.evaluate_snippets_parallel(data_dict, "predicted_snippets", "refined_regions", metrics_config)
        all_scores.append(aggregated_metrics_predicted_snippets)
        f1_score = aggregated_metrics_predicted_snippets['f1']
        if f1_score > best_f1:
            best_run['run_id'] = run_id
            best_run['ignore_o_threshold'] = ignore_o_threshold
            best_run['min_snippet_length'] = min_snippet_length
            best_run['f1'] = f1_score
            best_f1 = f1_score
        run_id += 1

Run 0 out of 25. Combination tested: ignore_o_threshold=0 and min_snippet_length=0. Currently best: {}


Creating snippets:   0%|          | 0/6918 [00:00<?, ?it/s]

Assign results:   0%|          | 0/6918 [00:00<?, ?it/s]

Evaluation Started.


  0%|          | 0/4807 [00:00<?, ?it/s]

Average Precision Score: 0.6534951734094638
Average Recall Score: 0.7885198727110102
Average F1 Score: 0.7146859283339813
Average Iou Score: 0.7435251950611076
Average Bleu Score: 0.5622263863505115
Average Jaccard Score: 0.7435251950611076
Average Precision_region_lvl Score: 0.667132528976865
Average Recall_region_lvl Score: 0.8439876973903111
Average F1_region_lvl Score: 0.6757403110200386
Average Rouge-1-f Score: 0.841391720239641
Average Rouge-2-f Score: 0.8137685035997634
Average Rouge-l-f Score: 0.8399571872883745
Run 1 out of 25. Combination tested: ignore_o_threshold=0 and min_snippet_length=4. Currently best: {'run_id': 0, 'ignore_o_threshold': 0, 'min_snippet_length': 0, 'f1': 0.7146859283339813}


Creating snippets:   0%|          | 0/6918 [00:00<?, ?it/s]

Assign results:   0%|          | 0/6918 [00:00<?, ?it/s]

Evaluation Started.


  0%|          | 0/4807 [00:00<?, ?it/s]

Average Precision Score: 0.653500110222604
Average Recall Score: 0.7891874668522366
Average F1 Score: 0.7149629687945669
Average Iou Score: 0.7443366934743155
Average Bleu Score: 0.56290025531491
Average Jaccard Score: 0.7443366934743156
Average Precision_region_lvl Score: 0.6675567026206406
Average Recall_region_lvl Score: 0.8450144874845837
Average F1_region_lvl Score: 0.6765345689867835
Average Rouge-1-f Score: 0.8420190933957421
Average Rouge-2-f Score: 0.814592498478804
Average Rouge-l-f Score: 0.8405830866092516
Run 2 out of 25. Combination tested: ignore_o_threshold=0 and min_snippet_length=8. Currently best: {'run_id': 1, 'ignore_o_threshold': 0, 'min_snippet_length': 4, 'f1': 0.7149629687945669}


Creating snippets:   0%|          | 0/6918 [00:00<?, ?it/s]

Assign results:   0%|          | 0/6918 [00:00<?, ?it/s]

Evaluation Started.


  0%|          | 0/4807 [00:00<?, ?it/s]

Average Precision Score: 0.6535276534357056
Average Recall Score: 0.7892388960119511
Average F1 Score: 0.7150005576554852
Average Iou Score: 0.7441568549620449
Average Bleu Score: 0.5629956247939677
Average Jaccard Score: 0.7441568549620449
Average Precision_region_lvl Score: 0.6675884155465397
Average Recall_region_lvl Score: 0.8453393543361369
Average F1_region_lvl Score: 0.6766699911229992
Average Rouge-1-f Score: 0.8417659431885891
Average Rouge-2-f Score: 0.8143258393867231
Average Rouge-l-f Score: 0.8403202528647408
Run 3 out of 25. Combination tested: ignore_o_threshold=0 and min_snippet_length=10. Currently best: {'run_id': 2, 'ignore_o_threshold': 0, 'min_snippet_length': 8, 'f1': 0.7150005576554852}


Creating snippets:   0%|          | 0/6918 [00:00<?, ?it/s]

Assign results:   0%|          | 0/6918 [00:00<?, ?it/s]

Evaluation Started.


  0%|          | 0/4807 [00:00<?, ?it/s]

Average Precision Score: 0.6535804337583346
Average Recall Score: 0.789405684754522
Average F1 Score: 0.7151005865321617
Average Iou Score: 0.7443551183427839
Average Bleu Score: 0.5632200564546811
Average Jaccard Score: 0.7443551183427839
Average Precision_region_lvl Score: 0.6676519339194741
Average Recall_region_lvl Score: 0.845996735257491
Average F1_region_lvl Score: 0.6769203592450044
Average Rouge-1-f Score: 0.8419477171959541
Average Rouge-2-f Score: 0.8145069442560818
Average Rouge-l-f Score: 0.8404991846567922
Run 4 out of 25. Combination tested: ignore_o_threshold=0 and min_snippet_length=12. Currently best: {'run_id': 3, 'ignore_o_threshold': 0, 'min_snippet_length': 10, 'f1': 0.7151005865321617}


Creating snippets:   0%|          | 0/6918 [00:00<?, ?it/s]

Assign results:   0%|          | 0/6918 [00:00<?, ?it/s]

Evaluation Started.


  0%|          | 0/4807 [00:00<?, ?it/s]

Average Precision Score: 0.6535877685593585
Average Recall Score: 0.7894565108666783
Average F1 Score: 0.7151258303968845
Average Iou Score: 0.7440688157984248
Average Bleu Score: 0.563315423756124
Average Jaccard Score: 0.7440688157984248
Average Precision_region_lvl Score: 0.6676564003371623
Average Recall_region_lvl Score: 0.8459997000611125
Average F1_region_lvl Score: 0.6770364818095016
Average Rouge-1-f Score: 0.8413773472792347
Average Rouge-2-f Score: 0.8138673879262516
Average Rouge-l-f Score: 0.839919672640271
Run 5 out of 25. Combination tested: ignore_o_threshold=2 and min_snippet_length=0. Currently best: {'run_id': 4, 'ignore_o_threshold': 0, 'min_snippet_length': 12, 'f1': 0.7151258303968845}


Creating snippets:   0%|          | 0/6918 [00:00<?, ?it/s]

Assign results:   0%|          | 0/6918 [00:00<?, ?it/s]

Evaluation Started.


  0%|          | 0/4807 [00:00<?, ?it/s]

Average Precision Score: 0.6441759143375186
Average Recall Score: 0.8093745391700305
Average F1 Score: 0.7173876662529682
Average Iou Score: 0.7453425172972481
Average Bleu Score: 0.5696131841360336
Average Jaccard Score: 0.7453425172972481
Average Precision_region_lvl Score: 0.6581789528875815
Average Recall_region_lvl Score: 0.866333792369093
Average F1_region_lvl Score: 0.6799449789981434
Average Rouge-1-f Score: 0.8408828615175408
Average Rouge-2-f Score: 0.8138799935610037
Average Rouge-l-f Score: 0.8395866073631703
Run 6 out of 25. Combination tested: ignore_o_threshold=2 and min_snippet_length=4. Currently best: {'run_id': 5, 'ignore_o_threshold': 2, 'min_snippet_length': 0, 'f1': 0.7173876662529682}


Creating snippets:   0%|          | 0/6918 [00:00<?, ?it/s]

Assign results:   0%|          | 0/6918 [00:00<?, ?it/s]

Evaluation Started.


  0%|          | 0/4807 [00:00<?, ?it/s]

Average Precision Score: 0.6441783504046438
Average Recall Score: 0.8096437515380496
Average F1 Score: 0.717494906129535
Average Iou Score: 0.7457159004683019
Average Bleu Score: 0.5699057382560535
Average Jaccard Score: 0.7457159004683019
Average Precision_region_lvl Score: 0.6583811438520701
Average Recall_region_lvl Score: 0.866824239941157
Average F1_region_lvl Score: 0.6802901856733803
Average Rouge-1-f Score: 0.8411890566545775
Average Rouge-2-f Score: 0.8141585288086293
Average Rouge-l-f Score: 0.8398923588814001
Run 7 out of 25. Combination tested: ignore_o_threshold=2 and min_snippet_length=8. Currently best: {'run_id': 6, 'ignore_o_threshold': 2, 'min_snippet_length': 4, 'f1': 0.717494906129535}


Creating snippets:   0%|          | 0/6918 [00:00<?, ?it/s]

Assign results:   0%|          | 0/6918 [00:00<?, ?it/s]

Evaluation Started.


  0%|          | 0/4807 [00:00<?, ?it/s]

Average Precision Score: 0.6442193361560231
Average Recall Score: 0.8099099018345771
Average F1 Score: 0.717624824086556
Average Iou Score: 0.7458745572991652
Average Bleu Score: 0.5701082717970113
Average Jaccard Score: 0.7458745572991652
Average Precision_region_lvl Score: 0.6585209774053409
Average Recall_region_lvl Score: 0.8672977079253992
Average F1_region_lvl Score: 0.6805659066917427
Average Rouge-1-f Score: 0.8412475821744897
Average Rouge-2-f Score: 0.814218566277425
Average Rouge-l-f Score: 0.8399484500804423
Run 8 out of 25. Combination tested: ignore_o_threshold=2 and min_snippet_length=10. Currently best: {'run_id': 7, 'ignore_o_threshold': 2, 'min_snippet_length': 8, 'f1': 0.717624824086556}


Creating snippets:   0%|          | 0/6918 [00:00<?, ?it/s]

Assign results:   0%|          | 0/6918 [00:00<?, ?it/s]

Evaluation Started.


  0%|          | 0/4807 [00:00<?, ?it/s]

Average Precision Score: 0.6442193361560231
Average Recall Score: 0.8099099018345771
Average F1 Score: 0.717624824086556
Average Iou Score: 0.7458745572991652
Average Bleu Score: 0.5701082717970113
Average Jaccard Score: 0.7458745572991652
Average Precision_region_lvl Score: 0.6585209774053409
Average Recall_region_lvl Score: 0.8672977079253992
Average F1_region_lvl Score: 0.6805659066917427
Average Rouge-1-f Score: 0.8412475821744897
Average Rouge-2-f Score: 0.814218566277425
Average Rouge-l-f Score: 0.8399484500804423
Run 9 out of 25. Combination tested: ignore_o_threshold=2 and min_snippet_length=12. Currently best: {'run_id': 7, 'ignore_o_threshold': 2, 'min_snippet_length': 8, 'f1': 0.717624824086556}


Creating snippets:   0%|          | 0/6918 [00:00<?, ?it/s]

Assign results:   0%|          | 0/6918 [00:00<?, ?it/s]

Evaluation Started.


  0%|          | 0/4807 [00:00<?, ?it/s]

Average Precision Score: 0.6442292082574644
Average Recall Score: 0.8099223129999372
Average F1 Score: 0.7176358210633038
Average Iou Score: 0.7456441128140486
Average Bleu Score: 0.5701117326955447
Average Jaccard Score: 0.7456441128140486
Average Precision_region_lvl Score: 0.6585475193798747
Average Recall_region_lvl Score: 0.8672918808945806
Average F1_region_lvl Score: 0.6806028995267532
Average Rouge-1-f Score: 0.8407959262865686
Average Rouge-2-f Score: 0.8137279512571438
Average Rouge-l-f Score: 0.8394891864478369
Run 10 out of 25. Combination tested: ignore_o_threshold=4 and min_snippet_length=0. Currently best: {'run_id': 9, 'ignore_o_threshold': 2, 'min_snippet_length': 12, 'f1': 0.7176358210633038}


Creating snippets:   0%|          | 0/6918 [00:00<?, ?it/s]

Assign results:   0%|          | 0/6918 [00:00<?, ?it/s]

Evaluation Started.


  0%|          | 0/4807 [00:00<?, ?it/s]

Average Precision Score: 0.6419695304064112
Average Recall Score: 0.8156838613382433
Average F1 Score: 0.7184755832751906
Average Iou Score: 0.7451742964108466
Average Bleu Score: 0.5710486810653063
Average Jaccard Score: 0.7451742964108466
Average Precision_region_lvl Score: 0.6561882605452536
Average Recall_region_lvl Score: 0.8720900651730561
Average F1_region_lvl Score: 0.6810806748588459
Average Rouge-1-f Score: 0.8404688167271754
Average Rouge-2-f Score: 0.8135940246601502
Average Rouge-l-f Score: 0.839209938694407
Run 11 out of 25. Combination tested: ignore_o_threshold=4 and min_snippet_length=4. Currently best: {'run_id': 10, 'ignore_o_threshold': 4, 'min_snippet_length': 0, 'f1': 0.7184755832751906}


Creating snippets:   0%|          | 0/6918 [00:00<?, ?it/s]

Assign results:   0%|          | 0/6918 [00:00<?, ?it/s]

Evaluation Started.


  0%|          | 0/4807 [00:00<?, ?it/s]

Average Precision Score: 0.6419719273849819
Average Recall Score: 0.8159248115636341
Average F1 Score: 0.7185705405425095
Average Iou Score: 0.7454198342437712
Average Bleu Score: 0.57124417559797
Average Jaccard Score: 0.7454198342437712
Average Precision_region_lvl Score: 0.6562767339609791
Average Recall_region_lvl Score: 0.8724341805476294
Average F1_region_lvl Score: 0.6813091208663313
Average Rouge-1-f Score: 0.8407740885347628
Average Rouge-2-f Score: 0.8138724144034043
Average Rouge-l-f Score: 0.8395147797481768
Run 12 out of 25. Combination tested: ignore_o_threshold=4 and min_snippet_length=8. Currently best: {'run_id': 11, 'ignore_o_threshold': 4, 'min_snippet_length': 4, 'f1': 0.7185705405425095}


Creating snippets:   0%|          | 0/6918 [00:00<?, ?it/s]

Assign results:   0%|          | 0/6918 [00:00<?, ?it/s]

Evaluation Started.


  0%|          | 0/4807 [00:00<?, ?it/s]

Average Precision Score: 0.6419951677328206
Average Recall Score: 0.8159820666009053
Average F1 Score: 0.7186073024711069
Average Iou Score: 0.7454365247464708
Average Bleu Score: 0.5713438639340951
Average Jaccard Score: 0.7454365247464708
Average Precision_region_lvl Score: 0.6563016149365688
Average Recall_region_lvl Score: 0.8726855317307785
Average F1_region_lvl Score: 0.6814434106302335
Average Rouge-1-f Score: 0.8406870094376678
Average Rouge-2-f Score: 0.8137941023408615
Average Rouge-l-f Score: 0.8394218748136362
Run 13 out of 25. Combination tested: ignore_o_threshold=4 and min_snippet_length=10. Currently best: {'run_id': 12, 'ignore_o_threshold': 4, 'min_snippet_length': 8, 'f1': 0.7186073024711069}


Creating snippets:   0%|          | 0/6918 [00:00<?, ?it/s]

Assign results:   0%|          | 0/6918 [00:00<?, ?it/s]

Evaluation Started.


  0%|          | 0/4807 [00:00<?, ?it/s]

Average Precision Score: 0.6419979466687059
Average Recall Score: 0.8161975361078724
Average F1 Score: 0.7186925874431567
Average Iou Score: 0.7456770504596139
Average Bleu Score: 0.5715394624023001
Average Jaccard Score: 0.7456770504596139
Average Precision_region_lvl Score: 0.656386247153565
Average Recall_region_lvl Score: 0.8729778976287863
Average F1_region_lvl Score: 0.6816645406185242
Average Rouge-1-f Score: 0.8409363403734572
Average Rouge-2-f Score: 0.8140532795110852
Average Rouge-l-f Score: 0.8396707726324711
Run 14 out of 25. Combination tested: ignore_o_threshold=4 and min_snippet_length=12. Currently best: {'run_id': 13, 'ignore_o_threshold': 4, 'min_snippet_length': 10, 'f1': 0.7186925874431567}


Creating snippets:   0%|          | 0/6918 [00:00<?, ?it/s]

Assign results:   0%|          | 0/6918 [00:00<?, ?it/s]

Evaluation Started.


  0%|          | 0/4807 [00:00<?, ?it/s]

Average Precision Score: 0.6419941502294405
Average Recall Score: 0.8161927095435656
Average F1 Score: 0.7186883374719957
Average Iou Score: 0.7455170655170201
Average Bleu Score: 0.571533744106066
Average Jaccard Score: 0.7455170655170201
Average Precision_region_lvl Score: 0.6563821364198162
Average Recall_region_lvl Score: 0.8728474791484877
Average F1_region_lvl Score: 0.6816566127919944
Average Rouge-1-f Score: 0.8406217057909169
Average Rouge-2-f Score: 0.8137223426190354
Average Rouge-l-f Score: 0.8393561380499306
Run 15 out of 25. Combination tested: ignore_o_threshold=8 and min_snippet_length=0. Currently best: {'run_id': 13, 'ignore_o_threshold': 4, 'min_snippet_length': 10, 'f1': 0.7186925874431567}


Creating snippets:   0%|          | 0/6918 [00:00<?, ?it/s]

Assign results:   0%|          | 0/6918 [00:00<?, ?it/s]

Evaluation Started.


  0%|          | 0/4807 [00:00<?, ?it/s]

Average Precision Score: 0.6385847970549449
Average Recall Score: 0.8195262357823188
Average F1 Score: 0.717828729325119
Average Iou Score: 0.7429200490382156
Average Bleu Score: 0.5707674093632182
Average Jaccard Score: 0.7429200490382156
Average Precision_region_lvl Score: 0.6535739077529336
Average Recall_region_lvl Score: 0.8760401800084575
Average F1_region_lvl Score: 0.6809781446154587
Average Rouge-1-f Score: 0.8381857040877957
Average Rouge-2-f Score: 0.8110093126443688
Average Rouge-l-f Score: 0.8369408433025517
Run 16 out of 25. Combination tested: ignore_o_threshold=8 and min_snippet_length=4. Currently best: {'run_id': 13, 'ignore_o_threshold': 4, 'min_snippet_length': 10, 'f1': 0.7186925874431567}


Creating snippets:   0%|          | 0/6918 [00:00<?, ?it/s]

Assign results:   0%|          | 0/6918 [00:00<?, ?it/s]

Evaluation Started.


  0%|          | 0/4807 [00:00<?, ?it/s]

Average Precision Score: 0.6385860781393131
Average Recall Score: 0.8197669489334863
Average F1 Score: 0.717921862799457
Average Iou Score: 0.7431678352697195
Average Bleu Score: 0.5709627104442664
Average Jaccard Score: 0.7431678352697195
Average Precision_region_lvl Score: 0.6536606738462623
Average Recall_region_lvl Score: 0.8763367651766931
Average F1_region_lvl Score: 0.6812049566015791
Average Rouge-1-f Score: 0.8384639326755831
Average Rouge-2-f Score: 0.8112868179694823
Average Rouge-l-f Score: 0.8372186459328419
Run 17 out of 25. Combination tested: ignore_o_threshold=8 and min_snippet_length=8. Currently best: {'run_id': 13, 'ignore_o_threshold': 4, 'min_snippet_length': 10, 'f1': 0.7186925874431567}


Creating snippets:   0%|          | 0/6918 [00:00<?, ?it/s]

Assign results:   0%|          | 0/6918 [00:00<?, ?it/s]

Evaluation Started.


  0%|          | 0/4807 [00:00<?, ?it/s]

Average Precision Score: 0.6385860781393131
Average Recall Score: 0.8197669489334863
Average F1 Score: 0.717921862799457
Average Iou Score: 0.7430865164259257
Average Bleu Score: 0.5709613715369265
Average Jaccard Score: 0.7430865164259257
Average Precision_region_lvl Score: 0.6536606738462623
Average Recall_region_lvl Score: 0.8763367651766931
Average F1_region_lvl Score: 0.6812049566015791
Average Rouge-1-f Score: 0.8383001785519164
Average Rouge-2-f Score: 0.8111362687285999
Average Rouge-l-f Score: 0.8370548918091752
Run 18 out of 25. Combination tested: ignore_o_threshold=8 and min_snippet_length=10. Currently best: {'run_id': 13, 'ignore_o_threshold': 4, 'min_snippet_length': 10, 'f1': 0.7186925874431567}


Creating snippets:   0%|          | 0/6918 [00:00<?, ?it/s]

Assign results:   0%|          | 0/6918 [00:00<?, ?it/s]

Evaluation Started.


  0%|          | 0/4807 [00:00<?, ?it/s]

Average Precision Score: 0.6385855543116108
Average Recall Score: 0.819791948857828
Average F1 Score: 0.7179311185805491
Average Iou Score: 0.7431864472280022
Average Bleu Score: 0.5710567426030477
Average Jaccard Score: 0.7431864472280022
Average Precision_region_lvl Score: 0.6536480776701112
Average Recall_region_lvl Score: 0.8764613696628377
Average F1_region_lvl Score: 0.681279400665653
Average Rouge-1-f Score: 0.8383909738759759
Average Rouge-2-f Score: 0.8112322879399753
Average Rouge-l-f Score: 0.8371454740451545
Run 19 out of 25. Combination tested: ignore_o_threshold=8 and min_snippet_length=12. Currently best: {'run_id': 13, 'ignore_o_threshold': 4, 'min_snippet_length': 10, 'f1': 0.7186925874431567}


Creating snippets:   0%|          | 0/6918 [00:00<?, ?it/s]

Assign results:   0%|          | 0/6918 [00:00<?, ?it/s]

Evaluation Started.


  0%|          | 0/4807 [00:00<?, ?it/s]

Average Precision Score: 0.6385844804025012
Average Recall Score: 0.8197905702139862
Average F1 Score: 0.7179299112360311
Average Iou Score: 0.7431265327762528
Average Bleu Score: 0.5710535136615903
Average Jaccard Score: 0.7431265327762528
Average Precision_region_lvl Score: 0.6536467460306233
Average Recall_region_lvl Score: 0.876432850383805
Average F1_region_lvl Score: 0.6812768561946612
Average Rouge-1-f Score: 0.838248377481117
Average Rouge-2-f Score: 0.8110801851188194
Average Rouge-l-f Score: 0.8370028776502956
Run 20 out of 25. Combination tested: ignore_o_threshold=10 and min_snippet_length=0. Currently best: {'run_id': 13, 'ignore_o_threshold': 4, 'min_snippet_length': 10, 'f1': 0.7186925874431567}


Creating snippets:   0%|          | 0/6918 [00:00<?, ?it/s]

Assign results:   0%|          | 0/6918 [00:00<?, ?it/s]

Evaluation Started.


  0%|          | 0/4807 [00:00<?, ?it/s]

Average Precision Score: 0.6375605005881332
Average Recall Score: 0.8202098145930294
Average F1 Score: 0.7174427610898976
Average Iou Score: 0.742475255026636
Average Bleu Score: 0.5704749597757481
Average Jaccard Score: 0.742475255026636
Average Precision_region_lvl Score: 0.6526835007626605
Average Recall_region_lvl Score: 0.8768379247385243
Average F1_region_lvl Score: 0.6806971371341963
Average Rouge-1-f Score: 0.8378928806149109
Average Rouge-2-f Score: 0.8106548258227755
Average Rouge-l-f Score: 0.8366832791866671
Run 21 out of 25. Combination tested: ignore_o_threshold=10 and min_snippet_length=4. Currently best: {'run_id': 13, 'ignore_o_threshold': 4, 'min_snippet_length': 10, 'f1': 0.7186925874431567}


Creating snippets:   0%|          | 0/6918 [00:00<?, ?it/s]

Assign results:   0%|          | 0/6918 [00:00<?, ?it/s]

Evaluation Started.


  0%|          | 0/4807 [00:00<?, ?it/s]

Average Precision Score: 0.6375617730650689
Average Recall Score: 0.8204507308263628
Average F1 Score: 0.7175357155882639
Average Iou Score: 0.7427228890617288
Average Bleu Score: 0.570670160788503
Average Jaccard Score: 0.7427228890617288
Average Precision_region_lvl Score: 0.6527699621829387
Average Recall_region_lvl Score: 0.8771347828733057
Average F1_region_lvl Score: 0.6809238529672007
Average Rouge-1-f Score: 0.8381710090064712
Average Rouge-2-f Score: 0.8109322098521418
Average Rouge-l-f Score: 0.8369609936855233
Run 22 out of 25. Combination tested: ignore_o_threshold=10 and min_snippet_length=8. Currently best: {'run_id': 13, 'ignore_o_threshold': 4, 'min_snippet_length': 10, 'f1': 0.7186925874431567}


Creating snippets:   0%|          | 0/6918 [00:00<?, ?it/s]

Assign results:   0%|          | 0/6918 [00:00<?, ?it/s]

Evaluation Started.


  0%|          | 0/4807 [00:00<?, ?it/s]

Average Precision Score: 0.6375617730650689
Average Recall Score: 0.8204507308263628
Average F1 Score: 0.7175357155882639
Average Iou Score: 0.7426415702179349
Average Bleu Score: 0.570668821881163
Average Jaccard Score: 0.7426415702179349
Average Precision_region_lvl Score: 0.6527699621829387
Average Recall_region_lvl Score: 0.8771347828733057
Average F1_region_lvl Score: 0.6809238529672007
Average Rouge-1-f Score: 0.8380072548828046
Average Rouge-2-f Score: 0.8107816606112592
Average Rouge-l-f Score: 0.8367972395618565
Run 23 out of 25. Combination tested: ignore_o_threshold=10 and min_snippet_length=10. Currently best: {'run_id': 13, 'ignore_o_threshold': 4, 'min_snippet_length': 10, 'f1': 0.7186925874431567}


Creating snippets:   0%|          | 0/6918 [00:00<?, ?it/s]

Assign results:   0%|          | 0/6918 [00:00<?, ?it/s]

Evaluation Started.


  0%|          | 0/4807 [00:00<?, ?it/s]

Average Precision Score: 0.6375612444781544
Average Recall Score: 0.8204757562033803
Average F1 Score: 0.717544951115323
Average Iou Score: 0.742741424882741
Average Bleu Score: 0.5707641428874527
Average Jaccard Score: 0.742741424882741
Average Precision_region_lvl Score: 0.6527572135920608
Average Recall_region_lvl Score: 0.8772595239127864
Average F1_region_lvl Score: 0.6809982489300367
Average Rouge-1-f Score: 0.8380980000830333
Average Rouge-2-f Score: 0.8108776191436278
Average Rouge-l-f Score: 0.8368877777094977
Run 24 out of 25. Combination tested: ignore_o_threshold=10 and min_snippet_length=12. Currently best: {'run_id': 13, 'ignore_o_threshold': 4, 'min_snippet_length': 10, 'f1': 0.7186925874431567}


Creating snippets:   0%|          | 0/6918 [00:00<?, ?it/s]

Assign results:   0%|          | 0/6918 [00:00<?, ?it/s]

Evaluation Started.


  0%|          | 0/4807 [00:00<?, ?it/s]

Average Precision Score: 0.6375965971451114
Average Recall Score: 0.820521251450161
Average F1 Score: 0.7175847388657597
Average Iou Score: 0.7426647029140002
Average Bleu Score: 0.5707664066475714
Average Jaccard Score: 0.7426647029140002
Average Precision_region_lvl Score: 0.652785547472765
Average Recall_region_lvl Score: 0.8773667119563194
Average F1_region_lvl Score: 0.6810438606967714
Average Rouge-1-f Score: 0.8379245295698652
Average Rouge-2-f Score: 0.8106749944062592
Average Rouge-l-f Score: 0.8367083031375858


In [18]:
# save run data to file
import json

with open('Evaluation/token_model_snippet_creation_params_all_scores.json', 'w') as outfile:
    json.dump(all_scores, outfile, indent=4)


# Evaluation of CYBER I & CYBER II
(samle process as done for aml test data above)

In [6]:
docReader = DocReader(MODEL, tokenizer)
create_original_snippets = True
add_full_page = True

cyber_file_path_i = config["general"]['DATA_DICT_FILE_PATH']['CYBER_I']
cyber_file_path_ii = config["general"]['DATA_DICT_FILE_PATH']['CYBER_II']

if True:
    print(f"Load data_dict from {DATA_PATH}/{cyber_file_path_i}.")
    with open(f'{DATA_PATH}/{cyber_file_path_i}', 'rb') as handle:
        data_dict_cyber_i = pickle.load(handle)
    print("Done.")

if True:
    print(f"Load data_dict from {DATA_PATH}/{cyber_file_path_ii}.")
    with open(f'{DATA_PATH}/{cyber_file_path_ii}', 'rb') as handle:
        data_dict_cyber_ii = pickle.load(handle)
    print("Done.")

Load data_dict from /home/tlh45/rds/hpc-work/preprocessing/data_dicts/28259v240404_CYBER/data_dict_roberta-base-06-04.pkl.
Done.


In [12]:
import math
import time
import importlib
import evaluator
importlib.reload(evaluator)

metrics_config = {
    'iou': True,
    'bleu': True,
    'jaccard': True,
    'precision': True,
    'recall': True,
    'f1': True,
    'precision_region_lvl': True,
    'recall_region_lvl': True,
    'f1_region_lvl': True,
    'edit_distance': False,
    'rouge-1-f': True,
    'rouge-2-f': True,
    'rouge-l-f': True,
    'pk': False,
    'windowdiff': False,
    'cohen_kappa': False
}

for theme, data_dict in zip(["CYBER_I", "CYBER_II"], [data_dict_cyber_i, data_dict_cyber_ii]):   
    doc_count = 0
    page_count = 0
    
    # Step 1: Collect all texts and metadata for batch processing
    all_texts = []
    metadata = []
    for doc_id, page_dict in tqdm(data_dict.items(), desc="Collect texts and metadata"):
        for page_index, page_content in page_dict.items():
            if page_index in ['title', 'doc_long_id']:
                continue
            page_text = data_dict[doc_id][page_index]['full_text']
            all_texts.append(page_text)
            metadata.append((doc_id, page_index))
            page_count += 1
        doc_count += 1
        
        
    # Split metadata and all_texts into chunks
    # This step is necessary as the evaluation of the cyber data is considerably larger than the evaluation of the AML test data only
    text_chunk_size = 10000
    all_texts_chunks = [all_texts[i:i + text_chunk_size] for i in range(0, len(all_texts), text_chunk_size)]
    total_time = 0
    full_snippets_collection = []
    for all_texts_chunk in all_texts_chunks:

        # Step 2: Process all texts in batches
        stride = OVERLAP
        sub_batch_size = 512 if MODEL == "roberta-large" else 2048
        if True:
            start_time = time.time()  # Start timing
            all_encodings, batch_input_ids, predicted_labels, detailed_predictions, tokenizer = batch_predict(token_model
                                                                                                            , tokenizer
                                                                                                            , all_texts_chunk
                                                                                                            , metadata
                                                                                                            , window_size=512
                                                                                                            , stride = stride
                                                                                                            , sub_batch_size=sub_batch_size)
            end_time = time.time()  # End timing
            total_time_chunk = end_time - start_time  # Calculate total time
            total_time += total_time_chunk
        # del token_model
        gc.collect()
        torch.cuda.empty_cache()
        batch_input_ids = batch_input_ids.to("cpu")
        predicted_labels = [label.to("cpu") for label in predicted_labels]

        snippets_collection = create_snippets_multiprocessing(all_encodings
                                                , batch_input_ids
                                                , predicted_labels
                                                , detailed_predictions
                                                , tokenizer
                                                , stride = stride
                                                , min_snippet_length = 10
                                                , ignore_o_threshold = 4
                                                # , multiprocessing = True
                                                , workers = 8
                                                , chunksize = 512)

        full_snippets_collection.extend(snippets_collection)
        del snippets_collection, batch_input_ids, predicted_labels, all_encodings
        gc.collect()
    
    with open(f"full_snippets_collection_{theme}_{PRETRAINED_MODEL}.pkl", "wb") as f:
        pickle.dump(full_snippets_collection, f)
        
    # Step 3: Assign results back to original data_dict
    for ((doc_id, page_index), snippets) in tqdm(zip(metadata, full_snippets_collection), desc="Assign results", total=len(metadata)):
        snippet_texts = []
        detailed_tags = []
        for snippet_text, detailed_tag in snippets:
            if snippet_text.strip():
                snippet_texts.append(snippet_text)
                detailed_tags.append(detailed_tag)
        predicted_snippets = []

        if snippet_texts:
            tokenized_snippets = tokenizer(snippet_texts, padding=True, truncation=True, return_tensors="pt")
            
            for snippet_text, detailed_tag, snippet_tokenized_ids in zip(snippet_texts, detailed_tags, tokenized_snippets['input_ids']):
                snippet_tokenized_text = tokenizer.convert_ids_to_tokens(snippet_tokenized_ids)
                dict_to_add = {'text': snippet_text, 'tokenized_text': snippet_tokenized_text, 'detailed_tag': detailed_tag}
                predicted_snippets.append(dict_to_add)

        data_dict[doc_id][page_index]['predicted_snippets'] = predicted_snippets

        
    aggregated_metrics_predicted_snippets = evaluator.evaluate_snippets_parallel(data_dict, "predicted_snippets", "refined_regions", metrics_config, batch_size=64)
    aggregated_metrics_predicted_snippets['inference_time'] = total_time
    aggregated_metrics_predicted_snippets['pages'] = page_count
    aggregated_metrics_predicted_snippets['batch_size'] = sub_batch_size
    aggregated_metrics_predicted_snippets['GPU'] = torch.cuda.get_device_properties(0).name
        
    # Save to json
    with open(f'Evaluation/{theme}/token-{str(MODEL).replace("/", "-")}_{PRETRAINED_MODEL if PRETRAINED_MODEL else unique_tag}.json', 'w') as f:
        json.dump(aggregated_metrics_predicted_snippets, f, indent=4)
        

Collect texts and metadata:   0%|          | 0/730 [00:00<?, ?it/s]

Tokenizing:   0%|          | 0/10000 [00:00<?, ?it/s]

Number of samples: 21450


Predicting:   0%|          | 0/11 [00:00<?, ?it/s]

Creating snippets:   0%|          | 0/10000 [00:00<?, ?it/s]

Tokenizing:   0%|          | 0/10000 [00:00<?, ?it/s]

Number of samples: 21311


Predicting:   0%|          | 0/11 [00:00<?, ?it/s]

Creating snippets:   0%|          | 0/10000 [00:00<?, ?it/s]

Tokenizing:   0%|          | 0/4513 [00:00<?, ?it/s]

Number of samples: 9977


Predicting:   0%|          | 0/5 [00:00<?, ?it/s]

Creating snippets:   0%|          | 0/4513 [00:00<?, ?it/s]

Assign results:   0%|          | 0/24513 [00:00<?, ?it/s]

Evaluation Started.


  0%|          | 0/383 [00:00<?, ?it/s]

Average Precision Score: 0.6701474538585565
Average Recall Score: 0.8092387605230164
Average F1 Score: 0.7331544523751762
Average Iou Score: 0.7443593037861259
Average Bleu Score: 0.6040033016575665
Average Jaccard Score: 0.7443593037861259
Average Precision_region_lvl Score: 0.6933163250659018
Average Recall_region_lvl Score: 0.8724206453249644
Average F1_region_lvl Score: 0.7104174815874686
Average Rouge-1-f Score: 0.8330772493050197
Average Rouge-2-f Score: 0.8059755768699488
Average Rouge-l-f Score: 0.8316925812118936


Collect texts and metadata:   0%|          | 0/965 [00:00<?, ?it/s]

Tokenizing:   0%|          | 0/10000 [00:00<?, ?it/s]

Number of samples: 21184


Predicting:   0%|          | 0/11 [00:00<?, ?it/s]

Creating snippets:   0%|          | 0/10000 [00:00<?, ?it/s]

Tokenizing:   0%|          | 0/10000 [00:00<?, ?it/s]

Number of samples: 20113


Predicting:   0%|          | 0/10 [00:00<?, ?it/s]

Creating snippets:   0%|          | 0/10000 [00:00<?, ?it/s]

Tokenizing:   0%|          | 0/9058 [00:00<?, ?it/s]

Number of samples: 18048


Predicting:   0%|          | 0/9 [00:00<?, ?it/s]

Creating snippets:   0%|          | 0/9058 [00:00<?, ?it/s]

Assign results:   0%|          | 0/29058 [00:00<?, ?it/s]

Evaluation Started.


  0%|          | 0/454 [00:00<?, ?it/s]

Average Precision Score: 0.7200198200254957
Average Recall Score: 0.803640139905733
Average F1 Score: 0.7595353873134623
Average Iou Score: 0.7534106807910367
Average Bleu Score: 0.6240424194086169
Average Jaccard Score: 0.7535112439446973
Average Precision_region_lvl Score: 0.7285348941660446
Average Recall_region_lvl Score: 0.8608933583949359
Average F1_region_lvl Score: 0.727774583826841
Average Rouge-1-f Score: 0.8391584908899389
Average Rouge-2-f Score: 0.8119460703854102
Average Rouge-l-f Score: 0.8377474080570861
