# PII W and B Inference
This notebook runs inference and post-processing for PII detection using fine-tuned DeBERTa models. The classifier first predicts on a validation set, and those predictions are used to tune the prediction threshold applied to the competition (test) data. The classifier works with a single model or with an ensemble combined by weighted averaging.

Model training is done here: https://www.kaggle.com/code/jonathankasprisin/pii-train

### Reference
1. https://www.kaggle.com/code/thedrcat/pii-data-detection-infer-with-w-b

References for specific supporting functions are given in the corresponding code blocks.

# Imports

In [None]:
from pathlib import Path
import os
import json
import argparse
from itertools import chain
import torch
from transformers import AutoTokenizer, Trainer, TrainingArguments
from transformers import AutoModelForTokenClassification, DataCollatorForTokenClassification
from datasets import Dataset, features
import numpy as np
import pandas as pd
import gc

# Config

In [None]:
import yaml
# --- Paths (normally unchanged between runs) ---
DATA_PATH = '../input/pii-detection-removal-from-educational-data'
VAL_PATH = '/kaggle/input/pii-bagging-datasets/val2.json'
OUTPUT_DIR = "/kaggle/working/"

# --- Ensemble: model directory -> weight used when averaging logits ---
MODEL_PATHS = {'/kaggle/input/pii-ens1-train/pii_ens1_base_full': 1}

# --- Tokenization / post-processing settings ---
INFERENCE_MAX_LENGTH = 3500  # truncation length in sub-words
STRIDE = 0                   # forwarded to the tokenizer
THRESHOLD = 0.99             # initial 'O' probability threshold; reassigned by the validation threshold search
print(f'Infer max length: {INFERENCE_MAX_LENGTH}, stride: {STRIDE}, threshold: {THRESHOLD}')


# Training Metric Calculation

In [None]:
# https://www.kaggle.com/code/conjuring92/pii-metric-fine-grained-eval

from collections import defaultdict
from typing import Dict
from scipy.special import softmax
# from utils import parse_predictions #SCRIPT version

class PRFScore:
    """Running tally of true positives, false positives and false negatives.

    Precision, recall, F1 and F5 (beta = 5, i.e. recall-weighted) are exposed
    as read-only properties. Every denominator carries a 1e-100 epsilon so an
    empty tally scores 0.0 instead of raising ZeroDivisionError.
    """

    def __init__(self, *, tp: int = 0, fp: int = 0, fn: int = 0) -> None:
        self.tp, self.fp, self.fn = tp, fp, fn

    def __len__(self) -> int:
        # number of items that have been scored
        return sum((self.tp, self.fp, self.fn))

    def __iadd__(self, other):
        """Accumulate another tally into this one (in place)."""
        for attr in ("tp", "fp", "fn"):
            setattr(self, attr, getattr(self, attr) + getattr(other, attr))
        return self

    def __add__(self, other):
        """Return a new tally holding the sum of both operands."""
        combined = PRFScore(tp=self.tp, fp=self.fp, fn=self.fn)
        combined += other
        return combined

    def score_set(self, cand: set, gold: set) -> None:
        """Score a candidate set against a gold set and add the counts."""
        hits = cand.intersection(gold)
        self.tp += len(hits)
        self.fp += len(cand - gold)
        self.fn += len(gold - cand)

    @property
    def precision(self) -> float:
        denom = self.tp + self.fp + 1e-100
        return self.tp / denom

    @property
    def recall(self) -> float:
        denom = self.tp + self.fn + 1e-100
        return self.tp / denom

    @property
    def f1(self) -> float:
        prec, rec = self.precision, self.recall
        return 2 * ((prec * rec) / (prec + rec + 1e-100))

    @property
    def f5(self) -> float:
        beta_sq = 5 ** 2
        prec, rec = self.precision, self.recall
        return (1 + beta_sq) * prec * rec / (beta_sq * prec + rec + 1e-100)

    def to_dict(self) -> Dict[str, float]:
        """Summarize as {"p": precision, "r": recall, "f5": F5}."""
        return {"p": self.precision, "r": self.recall, "f5": self.f5}


def compute_metrics(p, id2label, valid_ds, valid_df, threshold=0.9):
    """
    Compute the leaderboard metric (micro F5) plus per-entity-type P/R/F5.

    Args:
        p: ``(predictions, labels)`` pair; only the predictions (logits) are used.
        id2label (dict): label id (as a string) -> label name.
        valid_ds: tokenized validation dataset, forwarded to ``val_parse_predictions``.
        valid_df (pd.DataFrame): reference rows with ``document``, ``token`` and ``label``.
        threshold: 'O' probability threshold forwarded to ``val_parse_predictions``.

    Returns:
        dict: flat mapping with ``ents_p``, ``ents_r``, ``ents_f5`` and, for every
        entity type other than 'O', ``ents_per_type_<TYPE>_p`` / ``_r`` / ``_f5``.
    """
    logits, _ = p
    pred_df = val_parse_predictions(logits, id2label, valid_ds, threshold=threshold)

    def entity_of(tag):
        # "B-NAME_STUDENT" / "I-NAME_STUDENT" -> "NAME_STUDENT"; 'O' stays 'O'
        return tag if tag == 'O' else tag[2:]

    # Reference (document, token, label) triples not yet matched by a prediction.
    unmatched = set(zip(valid_df.document, valid_df.token, valid_df.label))
    per_type = defaultdict(PRFScore)

    for triple in zip(pred_df.document, pred_df.token, pred_df.label):
        stats = per_type[entity_of(triple[-1])]
        if triple in unmatched:
            stats.tp += 1
            unmatched.remove(triple)
        else:
            stats.fp += 1

    # Anything still unmatched was never predicted.
    for _doc, _tok, ref_label in unmatched:
        per_type[entity_of(ref_label)].fn += 1

    totals = PRFScore()
    for stats in per_type.values():
        totals += stats

    flat = {
        "ents_p": totals.precision,
        "ents_r": totals.recall,
        "ents_f5": totals.f5,
    }
    for ent, stats in per_type.items():
        if ent == 'O':
            continue
        for metric_name, metric_val in stats.to_dict().items():
            flat[f"ents_per_type_{ent}_{metric_name}"] = metric_val
    return flat

def val_parse_predictions(predictions, id2label, ds, threshold=0.9):
    """
    Turn token-classification logits into one row per predicted PII word.

    A sub-word whose softmax probability for 'O' is below ``threshold`` gets
    the most probable non-'O' label; otherwise the plain argmax is kept.
    Each sub-word is mapped back to its original word via ``offset_mapping``
    and ``token_map``. Several sub-words of the same word that yield the same
    (document, label, word) are emitted once.

    Args:
        predictions (np.ndarray): logits indexed as [row, sub_word, label].
        id2label (dict): label id (as a string) -> label name; must contain 'O'.
        ds: tokenized dataset (or mapping of columns) providing "token_map",
            "offset_mapping", "tokens" and "document".
        threshold (float): 'O' probability below which 'O' is overridden.

    Returns:
        pd.DataFrame: columns eval_row, document, token, label, token_str, row_id.

    Raises:
        ValueError: if ``id2label`` has no 'O' label.
    """
    # Locate the 'O' column from the config instead of assuming index 12.
    o_idx = next((int(k) for k, v in id2label.items() if v == "O"), None)
    if o_idx is None:
        raise ValueError("id2label has no 'O' label")

    # Probabilities over the label axis.
    pred_softmax = softmax(predictions, axis=-1)
    print(f" val_parse- preidicitions.shape {predictions.shape}")
    preds = predictions.argmax(-1)

    # Take the 'O' probabilities, then mask 'O' out (in place on our own softmax
    # copy, avoiding another full-size array) to get the best non-'O' label.
    O_preds = pred_softmax[..., o_idx].copy()
    pred_softmax[..., o_idx] = -np.inf
    preds_without_O = pred_softmax.argmax(-1)
    preds_final = np.where(O_preds < threshold, preds_without_O, preds)

    # Dedup key includes the document: without it, an identical
    # (label, word index, word) hit in a later document was silently dropped.
    seen = set()
    row, document, token, label, token_str = [], [], [], [], []
    rows = zip(preds_final, ds["token_map"], ds["offset_mapping"], ds["tokens"], ds["document"])
    for i, (p, token_map, offsets, tokens, doc) in enumerate(rows):

        for token_pred, (start_idx, end_idx) in zip(p, offsets):
            label_pred = id2label[str(token_pred)]

            # special tokens (e.g. [CLS]) carry a (0, 0) offset
            if start_idx + end_idx == 0:
                continue

            # sub-word starts on an inserted space: step to the next character
            if token_map[start_idx] == -1:
                start_idx += 1

            # ignore whitespace-only words such as "\n\n"
            while start_idx < len(token_map) and tokens[token_map[start_idx]].isspace():
                start_idx += 1

            # ran past the end of the rebuilt text
            if start_idx >= len(token_map):
                break

            token_id = token_map[start_idx]  # index of the original word

            # ignore "O" predictions and whitespace preds
            if label_pred == "O" or token_id == -1:
                continue

            key = (doc, label_pred, token_id, tokens[token_id])
            if key in seen:
                continue
            seen.add(key)

            row.append(i)
            document.append(doc)
            token.append(token_id)
            label.append(label_pred)
            token_str.append(tokens[token_id])

    df = pd.DataFrame({
        "eval_row": row,
        "document": document,
        "token": token,
        "label": label,
        "token_str": token_str
    })

    df = df.drop_duplicates().reset_index(drop=True)

    df["row_id"] = list(range(len(df)))
    return df

def identify_incorrect_labels(reference_df, pred_df):
    """
    List every (document, token) whose predicted label differs from the reference.

    Parameters:
    - reference_df (DataFrame): reference labels (document, token, label, ...).
    - pred_df (DataFrame): predicted labels; its 'eval_row' and 'row_id'
      columns are dropped before comparing.

    Returns:
    - DataFrame of the mismatching rows from an outer merge on
      (document, token). Missing labels are filled with 'O', and an 'error'
      column holds 'FN' (a reference entity was missed or mislabeled) or
      'FP' (an entity was predicted where the reference has none).
    """
    merged = reference_df.merge(
        pred_df.drop(columns=['eval_row', 'row_id']),
        on=['document', 'token'],
        how='outer',
        suffixes=('_actual', '_pred'),
    )

    # NaN != anything, so tokens present on only one side are kept as mismatches
    mismatch_mask = merged['label_actual'] != merged['label_pred']
    wrong = merged[mismatch_mask].copy()

    # a label absent on one side means 'O' on that side
    actual = wrong['label_actual'].fillna('O')
    predicted = wrong['label_pred'].fillna('O')
    wrong['label_actual'] = actual
    wrong['label_pred'] = predicted

    missed = (actual != 'O') & ((predicted == 'O') | (actual != predicted))
    spurious = (actual == 'O') & (predicted != 'O')
    wrong['error'] = np.select([missed, spurious], ['FN', 'FP'], default=None)

    return wrong

# From Training Helpers


In [None]:
import numpy as np
from datasets import Dataset

#prep data for NER training by tokenize the text and align labels to tokens
def val_tokenize(example, tokenizer, label2id, max_length, stride):
    """Tokenize one labelled example and align its word labels to sub-words.

    The text is rebuilt from ``example["tokens"]`` and
    ``example["trailing_whitespace"]`` so that every character carries the
    label of the word it belongs to and that word's index (an inserted
    trailing space gets label "O" and index -1). Each sub-word then takes the
    label of the character at its start offset, skipping one leading
    whitespace character.

    Args:
        example (dict): row with "tokens", "provided_labels" and
            "trailing_whitespace" lists of equal length.
        tokenizer (Tokenizer): Hugging Face tokenizer, called with
            ``return_offsets_mapping=True`` and truncation.
        label2id (dict): label string -> label id.
        max_length (int): truncation length in sub-words.
        stride (int): forwarded to the tokenizer. NOTE(review): overflowing
            windows are not requested here, so this presumably has no effect;
            confirm before relying on it.

    Returns:
        dict: tokenizer outputs plus "labels" (one label id per sub-word),
        "length" (number of input ids) and "token_map" (word index per
        character of the rebuilt text).

    Reference: credit to https://www.kaggle.com/code/valentinwerner/915-deberta3base-training/notebook
    """
    text_parts = []
    char_labels = []
    token_map = []

    words = zip(example["tokens"], example["provided_labels"], example["trailing_whitespace"])
    for idx, (word, word_label, has_ws) in enumerate(words):
        text_parts.append(word)
        token_map.extend([idx] * len(word))
        char_labels.extend([word_label] * len(word))

        if has_ws:
            text_parts.append(" ")
            char_labels.append("O")
            token_map.append(-1)

    text = "".join(text_parts)

    # Tokenize and return start/end character offsets for every sub-word.
    tokenized = tokenizer(
        text,
        return_offsets_mapping=True,
        max_length=max_length,
        truncation=True,
        stride=stride,
    )

    # numpy array for positional lookup
    char_labels = np.array(char_labels)
    token_labels = []

    for start_idx, end_idx in tokenized.offset_mapping:
        # special tokens such as [CLS] have a (0, 0) offset -> "O"
        if start_idx == 0 and end_idx == 0:
            token_labels.append(label2id["O"])
            continue

        # sub-word starts with whitespace: use the next character's label, but
        # never step past the end of the text (that used to raise IndexError)
        if text[start_idx].isspace() and start_idx + 1 < len(text):
            start_idx += 1

        token_labels.append(label2id[char_labels[start_idx]])

    length = len(tokenized.input_ids)

    return {**tokenized, "labels": token_labels, "length": length, "token_map": token_map}

#create dataset if using wandb
def val_create_dataset(data, tokenizer, max_length, label2id, stride):
    '''
    Build a tokenized Hugging Face ``Dataset`` from a validation DataFrame.

    Args:
        data (pandas.DataFrame): one row per document with columns full_text,
            document, tokens, trailing_whitespace, labels and token_indices
            (for wandb artifact).
        tokenizer: tokenizer forwarded to ``val_tokenize``.
        max_length (int): truncation length in sub-words.
        label2id (dict): label string -> label id.
        stride (int): forwarded to ``val_tokenize``.

    Returns:
        datasets.Dataset: the raw columns plus the ``val_tokenize`` outputs.
    '''
    # Dataset column name -> DataFrame column it is read from
    column_sources = {
        "full_text": "full_text",
        "document": "document",
        "tokens": "tokens",
        "trailing_whitespace": "trailing_whitespace",
        "provided_labels": "labels",
        "token_indices": "token_indices",
    }
    raw_ds = Dataset.from_dict(
        {dst: data[src].tolist() for dst, src in column_sources.items()}
    )

    tokenize_kwargs = dict(
        tokenizer=tokenizer,
        label2id=label2id,
        max_length=max_length,
        stride=stride,
    )
    return raw_ds.map(val_tokenize, fn_kwargs=tokenize_kwargs, num_proc=2)

def get_reference_df_val(raw_df):
    """
    Flatten per-document token/label lists into one row per non-'O' word.

    Args:
        raw_df (pd.DataFrame): one row per document with 'document',
            'tokens' and 'labels' (parallel lists).

    Returns:
        pd.DataFrame with columns:
            row_id    - position of the word in the fully exploded word list
            document  - document id
            token     - index of the word within its document
            label     - gold label (never 'O')
            token_str - the word text
    """
    exploded = (
        raw_df[['document', 'tokens', 'labels']]
        .explode(['tokens', 'labels'])
        .reset_index(drop=True)
        .rename(columns={'tokens': 'token_str', 'labels': 'label'})
    )
    # running word index within each document
    exploded['token'] = exploded.groupby('document').cumcount()

    labelled = exploded[exploded['label'] != 'O']
    labelled = labelled.reset_index().rename(columns={'index': 'row_id'})
    return labelled[['row_id', 'document', 'token', 'label', 'token_str']].copy()

# Val Score

In [None]:
import json
import pandas as pd
from transformers import Trainer

# VAL: load the validation documents and build the gold reference table.
# The file is read inside `with` so the handle is closed.
print("validiation dataset")
with open(VAL_PATH, encoding="utf-8") as val_file:
    data = json.load(val_file)
df = pd.DataFrame(data)
# df = df.head(50) #TEMP
reference_df = get_reference_df_val(df)
del data
gc.collect()

In [None]:
# Load id2label configuration from first model. all models should have same id2label.
# JSON object keys are strings, so id2label is keyed by "0", "1", ...
first_model_path = list(MODEL_PATHS.keys())[0]
with open(first_model_path + "/config.json", encoding="utf-8") as config_file:
    config = json.load(config_file)
id2label = config["id2label"]
label2id = config["label2id"]

del config
gc.collect()

In [None]:
# Tokenize the validation set once, with the first model's tokenizer.
# NOTE(review): every model in MODEL_PATHS is run on this same tokenization,
# so the ensemble members are assumed to share one tokenizer - confirm.
tokenizer = AutoTokenizer.from_pretrained(first_model_path)

# create dataset to do inference on; validation predictions are used to check errors
ds = val_create_dataset(df, tokenizer, INFERENCE_MAX_LENGTH, label2id, STRIDE)
print(f"val ds exampes: {ds.num_rows}, features: {ds.num_columns}") #DEBUG
print(f"The features in the dataset are: {ds.column_names}") #DEBUG

# weighted logits from each model; averaged later
all_preds = []

# Calculate the total weight for ensemble
total_weight = sum(MODEL_PATHS.values())

# iterate over all the models
for runs, (model_path, weight) in enumerate(MODEL_PATHS.items()):
    print(f"inference model {runs}...")

    # load model and collator from each model
    model_tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = AutoModelForTokenClassification.from_pretrained(model_path)
    collator = DataCollatorForTokenClassification(model_tokenizer, pad_to_multiple_of=8)

    # eval batch size 1 keeps peak memory low for long sequences
    args = TrainingArguments(
        ".",
        per_device_eval_batch_size=1,
        report_to="none",
    )

    trainer = Trainer(model=model, args=args, data_collator=collator, tokenizer=model_tokenizer)

    predictions = trainer.predict(ds).predictions
    print(f"prediction np.shape expecting: num_examples, seq_length, num_labels: {predictions.shape}") #DEBUG

    # weigh the model's predictions and add to list
    all_preds.append(predictions * weight)

    # Free everything tied to this model (including the raw logits and the
    # collator) before loading the next one.
    del model_tokenizer, model, collator, trainer, predictions
    torch.cuda.empty_cache()
    gc.collect()

print(f"all_preds len expecting {len(MODEL_PATHS)}: {len(all_preds)}")

In [None]:
# #TEMP
# import pickle

# # Replace with your file path
# file_path = '/kaggle/input/temp-infer-debug/variables.pkl'

# # Open the file in binary read mode
# with open(file_path, 'rb') as file:
#     # Load all objects
#     data = pickle.load(file)

# # Extract the objects from the dictionary
# all_preds = data['all_preds']
# total_weight = data['total_weight']
# id2label = data['id2label']
# ds = data['ds']
# reference_df = data['reference_df']
# threshold_tests = data['threshold_tests']

In [None]:
# Weighted average of the ensemble logits. Each entry of all_preds is already
# multiplied by its model weight, so dividing by the weight total normalizes it.
weighted_average_predictions = np.stack(all_preds, axis=0).sum(axis=0) / total_weight



In [None]:
#### single (uniform) 'O' threshold search #####
print("doing threshold tests:")
threshold_tests = [.6,.7,.8,.9,.99] #TEMP
scores = []

for threshold in threshold_tests:
    metrics = compute_metrics((weighted_average_predictions, None), id2label, ds, reference_df, threshold=threshold)
    f5_score = metrics['ents_f5']
    scores.append(f5_score)
    print(f'threshold:f5 {threshold}: {f5_score}')

# Pick the first threshold with the highest F5; np.argmax breaks ties toward
# the earlier candidate. Unlike a 0.0 sentinel, this always selects a threshold
# that was actually evaluated, even when every score is 0.
best_idx = int(np.argmax(scores))
best_threshold = threshold_tests[best_idx]
best_f5 = scores[best_idx]
print(f'Best f5 {best_f5}, best ensemble uniform threshold: {best_threshold}')

# threshold applied to the test-set predictions
THRESHOLD = best_threshold

In [None]:
###### Alternate threshold test ######
# print("doing threshold tests:")

# #initialize starting thresholds for each class at .99
# thresholds = np.array([0.99 for _ in range(len(id2label))])
# metrics = compute_metrics((weighted_average_predictions, None), id2label, ds, reference_df, threshold=thresholds)
# print("f5 with all .99", metrics['ents_f5'])

# best_threshold = thresholds.copy()
# best_score = np.zeros(len(id2label))

# #loop through each label threshold
# #decrease all class thresholds and choose the best based on the per_type_{}_f5 score for each label
# #while threshold for label is > .1 
# while thresholds[0] > .2:
#     metrics = compute_metrics((weighted_average_predictions, None), id2label, ds, reference_df, threshold=thresholds)
#     #get f5 score for each label at for the threshold level. give f5 of -1 if the label isnt present resulting in no metric generated
#     f5_score = np.array([metrics[f'ents_per_type_{id2label[str(label_idx)][2:]}_f5'] if f'ents_per_type_{id2label[str(label_idx)][2:]}_f5' in metrics else -1 for label_idx in range(len(thresholds)-1)])
#     #for each label see if the decrease in threshold improved the score
#     for label_idx in range(len(thresholds)-1):
#         if f5_score[label_idx] > best_score[label_idx]: 
#             best_score[label_idx] = f5_score[label_idx]
#             best_threshold[label_idx] = thresholds[label_idx]

#     #decrease the value of that index by .1
#     thresholds= thresholds - 0.1

# #assign updated threshold for submission
# THRESHOLD = best_threshold

# metrics = compute_metrics((weighted_average_predictions, None), id2label, ds, reference_df, threshold=best_threshold)
# print("f5 best_threshold", metrics['ents_f5'])

In [None]:
# Persist the validation inference outputs so threshold tuning can be repeated
# in a separate (CPU-only) environment without re-running the models.
# NOTE(review): this writes a 5-tuple, while the commented-out TEMP loader cell
# expects a dict keyed by name - keep the two in sync if that cell is revived.
import pickle

with open('inference_outputs.pkl', 'wb') as pkl_file:
    # tuple is passed inline (not bound to a name) so the later `del` cell can still free these objects
    pickle.dump(
        (all_preds, weighted_average_predictions, id2label, ds, reference_df),
        pkl_file,
    )

In [None]:
# Re-parse the validation logits at the selected threshold and save every
# mismatch against the reference labels for error analysis.
preds_df = val_parse_predictions(
    weighted_average_predictions,
    id2label,
    ds,
    threshold=best_threshold,
)
incorrectly_labeled_df = identify_incorrect_labels(reference_df, preds_df)
incorrectly_labeled_df.to_csv("val_incorrectly_labeled.csv", index=False)

In [None]:
# Release validation-only objects before the test-set pass (memory saving).
del all_preds, weighted_average_predictions, id2label
del ds, reference_df, incorrectly_labeled_df
torch.cuda.empty_cache()
gc.collect()

# Submission


In [None]:
#functions
import numpy as np
import pandas as pd
from scipy.special import softmax

def add_token_indices(doc_tokens):
    """Return the positional index of every token: [0, 1, ..., len(doc_tokens) - 1]."""
    return [*range(len(doc_tokens))]

def infer_tokenize(example, tokenizer, max_length, stride):
    """
    Tokenize one unlabelled example for NER inference.

    The text is rebuilt from the word-level tokens so every character can be
    traced back to the word it came from.

    Args:
        example (dict): row containing
            - "tokens": list of word strings.
            - "trailing_whitespace": list of bools; True appends one space
              after the word when rebuilding the text.
        tokenizer: Hugging Face tokenizer, called with
            ``return_offsets_mapping=True`` and truncation.
        max_length (int): truncation length in sub-words.
        stride (int): forwarded to the tokenizer. NOTE(review): overflowing
            windows are not requested here, so this presumably has no effect;
            confirm before relying on it.

    Returns:
        dict: the tokenizer outputs (e.g. "input_ids", "attention_mask",
        "offset_mapping") plus
            - "length": number of input ids.
            - "token_map": one entry per character of the rebuilt text giving
              the index of the word it belongs to, or -1 for an inserted space.

    Reference: https://www.kaggle.com/code/valentinwerner/893-deberta3base-Inference
    """
    text_parts = []
    token_map = []

    words = zip(example["tokens"], example["trailing_whitespace"])
    for word_idx, (word, has_ws) in enumerate(words):
        text_parts.append(word)
        # every character of the word points back to the word's index
        token_map.extend([word_idx] * len(word))
        if has_ws:
            text_parts.append(" ")
            token_map.append(-1)

    tokenized = tokenizer(
        "".join(text_parts),
        return_offsets_mapping=True,
        truncation=True,
        max_length=max_length,
        stride=stride,
    )

    return {
        **tokenized,
        "length": len(tokenized.input_ids),
        "token_map": token_map,
    }

def create_dataset(data, tokenizer, max_length, stride):
    """Build a tokenized Hugging Face ``Dataset`` from the test DataFrame.

    Args:
        data (pandas.DataFrame): one row per document with columns full_text,
            document, tokens, trailing_whitespace and token_indices.
        tokenizer: tokenizer forwarded to ``infer_tokenize``.
        max_length (int): truncation length in sub-words.
        stride (int): forwarded to ``infer_tokenize``.

    Returns:
        datasets.Dataset: the raw columns plus the ``infer_tokenize`` outputs.
    """
    wanted = ["full_text", "document", "tokens", "trailing_whitespace", "token_indices"]
    raw_ds = Dataset.from_dict({col: data[col].tolist() for col in wanted})

    return raw_ds.map(
        infer_tokenize,
        fn_kwargs=dict(tokenizer=tokenizer, max_length=max_length, stride=stride),
        num_proc=3,
    )

def parse_predictions(predictions, id2label, ds, threshold=0.9):
    """
    Turn token-classification logits into one row per predicted PII word.

    A sub-word whose softmax probability for 'O' is below ``threshold`` gets
    the most probable non-'O' label; otherwise the plain argmax is kept.
    Each sub-word is mapped back to its original word via ``offset_mapping``
    and ``token_map``. Several sub-words of the same word that yield the same
    (document, label, word) are emitted once.

    Args:
        predictions (np.ndarray): logits indexed as [row, sub_word, label].
        id2label (dict): label id (as a string) -> label name; must contain 'O'.
        ds: tokenized dataset (or mapping of columns) providing "token_map",
            "offset_mapping", "tokens" and "document".
        threshold (float): 'O' probability below which 'O' is overridden.

    Returns:
        pd.DataFrame: columns eval_row, document, token, label, token_str, row_id.

    Raises:
        ValueError: if ``id2label`` has no 'O' label.
    """
    # Locate the 'O' column from the config instead of assuming index 12.
    o_idx = next((int(k) for k, v in id2label.items() if v == "O"), None)
    if o_idx is None:
        raise ValueError("id2label has no 'O' label")

    # Probabilities over the label axis.
    pred_softmax = softmax(predictions, axis=-1)
    print(f" val_parse- preidicitions.shape {predictions.shape}")
    preds = predictions.argmax(-1)

    # Take the 'O' probabilities, then mask 'O' out (in place on our own softmax
    # copy, avoiding another full-size array) to get the best non-'O' label.
    O_preds = pred_softmax[..., o_idx].copy()
    pred_softmax[..., o_idx] = -np.inf
    preds_without_O = pred_softmax.argmax(-1)
    preds_final = np.where(O_preds < threshold, preds_without_O, preds)

    # Dedup key includes the document: without it, an identical
    # (label, word index, word) hit in a later document was silently dropped.
    seen = set()
    row, document, token, label, token_str = [], [], [], [], []
    rows = zip(preds_final, ds["token_map"], ds["offset_mapping"], ds["tokens"], ds["document"])
    for i, (p, token_map, offsets, tokens, doc) in enumerate(rows):

        for token_pred, (start_idx, end_idx) in zip(p, offsets):
            label_pred = id2label[str(token_pred)]

            # special tokens (e.g. [CLS]) carry a (0, 0) offset
            if start_idx + end_idx == 0:
                continue

            # sub-word starts on an inserted space: step to the next character
            if token_map[start_idx] == -1:
                start_idx += 1

            # ignore whitespace-only words such as "\n\n"
            while start_idx < len(token_map) and tokens[token_map[start_idx]].isspace():
                start_idx += 1

            # ran past the end of the rebuilt text
            if start_idx >= len(token_map):
                break

            token_id = token_map[start_idx]  # index of the original word

            # ignore "O" predictions and whitespace preds
            if label_pred == "O" or token_id == -1:
                continue

            key = (doc, label_pred, token_id, tokens[token_id])
            if key in seen:
                continue
            seen.add(key)

            row.append(i)
            document.append(doc)
            token.append(token_id)
            label.append(label_pred)
            token_str.append(tokens[token_id])

    df = pd.DataFrame({
        "eval_row": row,
        "document": document,
        "token": token,
        "label": label,
        "token_str": token_str
    })

    df = df.drop_duplicates().reset_index(drop=True)

    df["row_id"] = list(range(len(df)))
    return df

# load, predict and submit

In [None]:
import json
import pandas as pd
from transformers import Trainer

# Load the competition test documents (submission data).
with open(DATA_PATH + "/test.json", encoding="utf-8") as test_file:
    data = json.load(test_file)
df_full = pd.DataFrame(data)
df_full['token_indices'] = df_full['tokens'].apply(add_token_indices)
del data

# Split the documents into batches of `split_size` to bound peak memory
# (addresses out-of-memory errors on large test sets).
split_size = 3000
df_doc_batches = [
    df_full.iloc[start:start + split_size]
    for start in range(0, len(df_full), split_size)
]
print(f"{len(df_full)} documents -> {len(df_doc_batches)} batch(es) of up to {split_size}")
del df_full


# Load id2label configuration from first model. all models should have same id2label
first_model_path = list(MODEL_PATHS.keys())[0]
with open(first_model_path + "/config.json", encoding="utf-8") as config_file:
    config = json.load(config_file)
id2label = config["id2label"]
label2id = config["label2id"]
del config

gc.collect()


# Tokenizer used to build every test dataset. Per-model tokenizers below use
# their own name, so later batches are not tokenized with whichever model's
# tokenizer happened to be loaded last.
tokenizer = AutoTokenizer.from_pretrained(first_model_path)

# Calculate the total weight for ensemble (same for every batch)
total_weight = sum(MODEL_PATHS.values())

# per-batch prediction frames, concatenated once after the loop
batch_pred_dfs = []

# run model in batches of documents
for df in df_doc_batches:
    ds = create_dataset(df, tokenizer, INFERENCE_MAX_LENGTH, STRIDE)
    print(f"val ds exampes: {ds.num_rows}, features: {ds.num_columns}") #DEBUG

    # weighted logits from each model for this batch
    all_preds = []

    for runs, (model_path, weight) in enumerate(MODEL_PATHS.items()):
        print(f"inference model {runs}...")

        # load model and collator from each model
        model_tokenizer = AutoTokenizer.from_pretrained(model_path)
        model = AutoModelForTokenClassification.from_pretrained(model_path)
        collator = DataCollatorForTokenClassification(model_tokenizer, pad_to_multiple_of=8)

        # eval batch size 1 for memory
        args = TrainingArguments(".", per_device_eval_batch_size=1, report_to="none",)

        trainer = Trainer(model=model, args=args, data_collator=collator, tokenizer=model_tokenizer,)

        predictions = trainer.predict(ds).predictions
        print(f"prediction np.shape expecting: num_examples, seq_length, num_labels: {predictions.shape}") #DEBUG

        # weigh the model's predictions and add to list
        all_preds.append(predictions * weight)
        del model_tokenizer, model, collator, trainer, predictions  # memory saving
        torch.cuda.empty_cache()
        gc.collect()

    weighted_average_predictions = np.sum(all_preds, axis=0) / total_weight
    batch_pred_dfs.append(
        parse_predictions(weighted_average_predictions, id2label, ds, threshold=THRESHOLD)
    )
    del weighted_average_predictions, all_preds
    gc.collect()

# Concatenate once; keep the expected columns even when there were no documents.
if batch_pred_dfs:
    sub_df = pd.concat(batch_pred_dfs, ignore_index=True)
else:
    sub_df = pd.DataFrame(columns=["eval_row", "document", "token", "label", "token_str", "row_id"])

In [None]:
# Renumber rows globally: each batch's parse_predictions call restarted row_id at 0.
sub_df["row_id"] = range(len(sub_df))
# sanity-check the first rows of the submission
display(sub_df.head(10))

In [None]:
# Write the competition submission (only the required columns, no index).
sub_df.loc[:, ["row_id", "document", "token", "label"]].to_csv("submission.csv", index=False)