# Libraries Installation

In [None]:
%pip install -r ../requirements.txt

# Import Libraries

In [1]:
# library
import os

import torch
from torch import nn
from transformers import LongformerTokenizer, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:

test_dataset_doc_parh = "./dataset/validation_dataset/Validation_Release/"

test_path = [test_dataset_doc_parh + file_path for file_path in os.listdir(test_dataset_doc_parh)]


In [3]:
def load_medical_records(paths):
    medical_record_dict = {}
    for data_path in paths:

        if os.path.isfile(data_path):
            file_id = data_path.split("/")[-1].split(".txt")[0]
            with open(data_path, "r", encoding="utf-8") as f:
                file_text = f.read()
                medical_record_dict[file_id] = file_text
    return medical_record_dict

test_record_dict = load_medical_records(test_path)

In [4]:
# double check
print(len(list(test_record_dict.keys())))

560


In [5]:
# fix label_type
labels_type_table={'OTHER': 0, 'PATIENT': 1, 'DOCTOR': 2, 'CITY': 3, 'ROOM': 4, 'STREET': 5, 'MEDICALRECORD': 6, 'DEPARTMENT': 7, 'LOCATION-OTHER': 8, 'COUNTRY': 9, 'IDNUM': 10, 'STATE': 11, 'AGE': 12, 'SET': 13, 'HOSPITAL': 14, 'DATE': 15, 'ZIP': 16, 'URL': 17, 'DURATION': 18, 'ORGANIZATION': 19, 'TIME': 20, 'PHONE': 21}
print(labels_type_table)

{'OTHER': 0, 'PATIENT': 1, 'DOCTOR': 2, 'CITY': 3, 'ROOM': 4, 'STREET': 5, 'MEDICALRECORD': 6, 'DEPARTMENT': 7, 'LOCATION-OTHER': 8, 'COUNTRY': 9, 'IDNUM': 10, 'STATE': 11, 'AGE': 12, 'SET': 13, 'HOSPITAL': 14, 'DATE': 15, 'ZIP': 16, 'URL': 17, 'DURATION': 18, 'ORGANIZATION': 19, 'TIME': 20, 'PHONE': 21}


In [6]:
def create_label_dict(label_path):
    label_dict = {}  # y
    with open(label_path, "r", encoding="utf-8-sig") as f:
        file_text = f.read().strip()  

    # (id, label, start, end, query) or (id, label, start, end, query, time_org, timefix)
    for line in file_text.split("\n"):
        sample = line.split("\t")  
        sample[2], sample[3] = int(sample[2]), int(sample[3])

        if sample[0] not in label_dict:
            label_dict[sample[0]] = [sample[1:]]
        else:
            label_dict[sample[0]].append(sample[1:])

    return label_dict

In [7]:
val_label_path = "./dataset/validation_dataset/answer.txt"
ground_truth = create_label_dict(val_label_path)

# Predict

In [9]:
from transformers import LongformerModel
from torchcrf import CRF

class MyLongformerModel(nn.Module):
    def __init__(self, num_labels):
        super(MyLongformerModel, self).__init__()

        self.longformer = LongformerModel.from_pretrained('allenai/longformer-base-4096')
        self.dropout = nn.Dropout(p=0.1)
        self.classifier = nn.Linear(768, num_labels)
        self.crf = CRF(num_labels, batch_first=True)

    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.longformer(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
        sequence_output = self.dropout(outputs.last_hidden_state)
        logits = self.classifier(sequence_output)

        if labels is not None:
            loss = -self.crf(logits, labels, mask=attention_mask.byte())
            return loss
        else:
            return self.crf.decode(logits, mask=attention_mask.byte())

model = MyLongformerModel(num_labels=22)


In [10]:

model_path = './model/longformer-crf_43_13_0.9842817215780284'
model.load_state_dict(torch.load(model_path, map_location=torch.device('cuda')))




<All keys matched successfully>

In [11]:
model_name = "allenai/longformer-base-4096"
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=True)

In [12]:
def decode_model_result(model_predict_list, offsets_mapping, labels_type_table):
    """
    Decode the model predictions into a list of labeled segments.

    Parameters:
    - model_predict_list (list): List of predicted labels from the model.
    - offsets_mapping (list): List of offset mappings for the predicted labels.
    - labels_type_table (dict): Dictionary mapping label IDs to label names.

    Returns:
    - list: List of labeled segments, where each segment is represented as (label, start, end).
    """

    id_to_label = {id: label for label, id in labels_type_table.items()}
    predict_y = []
    pre_label_id = 0

    for position_id, label_id in enumerate(model_predict_list):
        if label_id != 0:
            if pre_label_id != label_id:
                start = int(offsets_mapping[position_id][0])
            end = int(offsets_mapping[position_id][1])

        if pre_label_id != label_id and pre_label_id != 0:
            predict_y.append([id_to_label[pre_label_id], start, end])
        pre_label_id = label_id

    if pre_label_id != 0:
        predict_y.append([id_to_label[pre_label_id], start, end])

    return predict_y


def merge_overlapping_predictions(predictions):
    """
    Merge overlapping labeled segments in a list.

    Parameters:
    - predictions (list): List of labeled segments, where each segment is represented as (label, start, end).

    Returns:
    - list: List of merged labeled segments after resolving overlaps.
    """
    if not predictions:
        return []

    sorted_predictions = sorted(predictions, key=lambda x: x[1])

    merged_predictions = [sorted_predictions[0]]
    for current in sorted_predictions[1:]:
        last = merged_predictions[-1]
        if current[0] == last[0] and current[1] <= last[2]:
            merged_predictions[-1] = (last[0], last[1], max(last[2], current[2]))
        else:
            merged_predictions.append(current)

    return merged_predictions


def predict_text_segments(model, tokenizer, text, max_length, overlap, device):
    """
    Predict labeled segments in a given text using the model.

    Parameters:
    - model: The trained model for prediction.
    - tokenizer: Tokenizer for processing the input text.
    - text (str): The input text to be processed.
    - max_length (int): Maximum length of text segments for prediction.
    - overlap (int): Overlapping length between consecutive text segments.
    - device: Device to run the model on.

    Returns:
    - list: List of predicted labeled segments, where each segment is represented as (label, start, end).
    """
    all_predictions = []
    offset = 0

    for i in range(0, len(text), max_length - overlap):
        segment = text[i:i+max_length]
        encodings = tokenizer(segment, padding=True, truncation=True, return_tensors="pt", return_offsets_mapping=True)
        encodings["input_ids"] = encodings["input_ids"].to(device)
        encodings["attention_mask"] = encodings["attention_mask"].to(device)

        with torch.no_grad():
            outputs = model(encodings["input_ids"], encodings["attention_mask"])
            model_predict_list = outputs[0]  
            predictions = decode_model_result(model_predict_list, encodings["offset_mapping"][0], labels_type_table)

        adjusted_predictions = [(label, start+offset, end+offset) for label, start, end in predictions]
        all_predictions.extend(adjusted_predictions)
        offset = i + max_length - overlap

    return all_predictions


In [13]:
def post_processing(label_name, start, end, text_segment):
    """
    Perform post-processing on labeled segments to refine label information.

    Parameters:
    - label_name (str): The predicted label for the segment.
    - start (int): Start position of the labeled segment.
    - end (int): End position of the labeled segment.
    - text_segment (str): The actual text content of the segment.

    Returns:
    - tuple: Processed label, refined start position, refined end position, and updated text content.
    """
    processed_label = label_name.strip()

    if processed_label.endswith('-') or processed_label.endswith('"') or processed_label.endswith("'"):
        processed_label = processed_label[:-1]
        end -= 1

    if processed_label == 'DATE' and text_segment.isdigit() and len(text_segment) > 8:
        end = start + 8
        text_segment = text_segment[:8]

    if processed_label == 'STATE':
        if text_segment.endswith('TAS'):
            text_segment = 'TAS'
            start = end - 3
        elif (len(text_segment) >= 3):
            if text_segment[0].isupper() and text_segment[1].isupper() and text_segment[2].islower():
                if len(text_segment) == 3:
                    text_segment = text_segment[:2]
                    end -= 1
                else:
                    text_segment = text_segment[1:]
                    start += 1

    if processed_label == 'CITY':
        if any(text_segment.endswith(suffix) for suffix in ['ONT', 'LET', 'NET', 'LAT']):
            end -= 1
        elif any(text_segment.endswith(suffix) for suffix in ['RAS', 'CHS', 'LES']):
            end -= 1

    return processed_label, start, end, text_segment


In [14]:
def merge_continuous_time_labels(predictions):
    """
    Merge continuous time labels that are adjacent in predictions.

    Parameters:
    - predictions (list of tuples): List of predictions with each tuple containing label name, start position,
      end position, and predicted text content.

    Returns:
    - list of tuples: Merged predictions where continuous time labels are combined into a single prediction.
    """
    merged_predictions = []
    prev_label = None

    for label_name, start, end, predict_str in predictions:
        if label_name == 'TIME' and prev_label and prev_label['label_name'] == 'TIME':
            if prev_label['end'] + 1 == start:
                prev_label['predict_str'] += ' ' + predict_str
                prev_label['end'] = end
                continue

        if prev_label:
            merged_predictions.append((prev_label['label_name'], prev_label['start'], prev_label['end'], prev_label['predict_str']))

        prev_label = {'label_name': label_name, 'start': start, 'end': end, 'predict_str': predict_str}

    if prev_label:
        merged_predictions.append((prev_label['label_name'], prev_label['start'], prev_label['end'], prev_label['predict_str']))

    return merged_predictions


In [15]:
def predict_for_single_sample(model, tokenizer, sample_id, val_medical_record_dict, device, max_length=4096, overlap=512):
    """
    Predict labels for a single medical record sample.

    Parameters:
    - model (torch.nn.Module): The trained model for making predictions.
    - tokenizer: The tokenizer used for encoding the input text.
    - sample_id (str): Identifier for the medical record sample.
    - val_medical_record_dict (dict): Dictionary containing medical record samples with sample_id as keys and text as values.
    - device: Device (e.g., 'cuda' or 'cpu') on which the model should run.
    - max_length (int): Maximum length for each text segment during prediction.
    - overlap (int): Overlap size between consecutive text segments during prediction.

    Returns:
    - str: String containing the predicted labels in the required format for the given medical record sample.
    """
    output_string = ""
    sample_text = val_medical_record_dict[sample_id]
    predictions = predict_text_segments(model, tokenizer, sample_text, max_length, overlap, device)
    final_predictions = merge_overlapping_predictions(predictions)

    extended_predictions = [(label_name, start, end, sample_text[start:end]) for label_name, start, end in final_predictions]

    merged_predictions = merge_continuous_time_labels(extended_predictions)

    for label_name, start, end, predict_str in merged_predictions:
        label_name, start, end, predict_str = post_processing(label_name, start, end, predict_str)
        sample_result_str = f"{sample_id}\t{label_name}\t{start}\t{end}\t{predict_str}\n"
        output_string += sample_result_str

    return output_string


In [16]:
device = torch.device('cuda') if torch.cuda.is_available() else torch.device('cpu')
model = model.to(device)

In [17]:
# test for one sample
sample_id = "file5124"  
print(predict_for_single_sample(model, tokenizer, sample_id, test_record_dict, device))

  score = torch.where(mask[i].unsqueeze(1), next_score, score)


file5124	DATE	12	20	20150830
file5124	HOSPITAL	36	58	COBRAM DISTRICT HEALTH
file5124	PATIENT	71	76	Corle
file5124	IDNUM	85	95	22N444639B
file5124	MEDICALRECORD	102	109	2254446
file5124	AGE	247	249	79
file5124	DOCTOR	892	902	F Koudelka
file5124	DOCTOR	1362	1374	F Wiltberger
file5124	DOCTOR	1988	1996	F Musich
file5124	DOCTOR	4180	4187	F Comee
file5124	DOCTOR	4570	4582	F Blachowski
file5124	DATE	4638	4644	2/6/72
file5124	DOCTOR	4676	4683	F Itani
file5124	DATE	7618	7626	8/6/2071
file5124	TIME	7888	7907	2846-12-08 00:00:00
file5124	PATIENT	7919	7926	Endsley



In [18]:
def predict_for_entire_dataset(model, tokenizer, val_medical_record_dict, device, max_length=4096, overlap=512):
    """
    Predict labels for an entire dataset of medical record samples.

    Parameters:
    - model (torch.nn.Module): The trained model for making predictions.
    - tokenizer: The tokenizer used for encoding the input text.
    - val_medical_record_dict (dict): Dictionary containing medical record samples with sample_id as keys and text as values.
    - device: Device (e.g., 'cuda' or 'cpu') on which the model should run.
    - max_length (int): Maximum length for each text segment during prediction.
    - overlap (int): Overlap size between consecutive text segments during prediction.

    Returns:
    - str: String containing the predicted labels for the entire dataset in the required format.
    """
    output_string = ""
    for sample_id, sample_text in val_medical_record_dict.items():
        predictions = predict_text_segments(model, tokenizer, sample_text, max_length, overlap, device)
        final_predictions = merge_overlapping_predictions(predictions)

        extended_predictions = [(label_name, start, end, sample_text[start:end]) for label_name, start, end in final_predictions]

        merged_predictions = merge_continuous_time_labels(extended_predictions)

        for label_name, start, end, predict_str in merged_predictions:
            label_name, start, end, predict_str = post_processing(label_name, start, end, predict_str)
            sample_result_str = f"{sample_id}\t{label_name}\t{start}\t{end}\t{predict_str}\n"
            output_string += sample_result_str

    return output_string


In [None]:
output_string = predict_for_entire_dataset(model, tokenizer, test_record_dict, device)

In [None]:
file_text = output_string

In [None]:
predict = {}
for line in file_text.split("\n"):
    if (line):
        sample = line.split("\t") 
        sample[2], sample[3] = int(sample[2]), int(sample[3])

        if sample[0] not in predict:
             predict[sample[0]] = [sample[1:]]
        else:
            predict[sample[0]].append(sample[1:])

In [None]:
def compare_ner(ground_truth, predictions, category):
    """
    Compare predicted Named Entity Recognition (NER) results with ground truth for a specific category.

    Parameters:
    - ground_truth (dict): Ground truth labeled data with document IDs as keys and associated labels.
    - predictions (dict): Predicted labeled data with document IDs as keys and associated labels.
    - category (str): Specific NER category to evaluate.

    Prints:
    - Outputs differences between ground truth and predictions for the specified category.
    - Calculates and prints Precision, Recall, and F1-Score for the specified category.
    """
    def extract_entities(label_dict, category):
        """
        Extract entities of a specific category from a label dictionary.

        Parameters:
        - label_dict (dict): Dictionary containing labels with document IDs as keys and associated labels.
        - category (str): Specific NER category to extract.

        Returns:
        - dict: Dictionary with entities and their corresponding labels.
        """
        entities = {}
        for doc_id, labels in label_dict.items():
            for label in labels:
                if label[0] == category:
                    entities[(doc_id, tuple(label[1:3]))] = label[3]
        return entities

    gt_entities = extract_entities(ground_truth, category)
    pred_entities = extract_entities(predictions, category)

    # Calculate True Positives (TP), False Positives (FP), and False Negatives (FN)
    TP = len([e for e in pred_entities if e in gt_entities and pred_entities[e] == gt_entities[e]])
    FP = len([e for e in pred_entities if e not in gt_entities or pred_entities[e] != gt_entities[e]])
    FN = len([e for e in gt_entities if e not in pred_entities])

    # Print differences
    print(f"Differences in '{category}':")
    for e in pred_entities:
        if e not in gt_entities or pred_entities[e] != gt_entities[e]:
            print(f"Predicted but incorrect or not in ground truth: {e}, Prediction: '{pred_entities[e]}'")

    for e in gt_entities:
        if e not in pred_entities:
            print(f"Missing in predictions: {e}, Ground Truth: '{gt_entities[e]}'")

    # Calculate Precision, Recall, F1
    precision = TP / (TP + FP) if TP + FP > 0 else 0
    recall = TP / (TP + FN) if TP + FN > 0 else 0
    f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0

    print(f"\nPrecision: {precision:.4f}, Recall: {recall:.4f}, F1-Score: {f1_score:.4f}")


In [64]:
compare_ner(ground_truth, predict, 'DATE')

Differences in 'DATE':
Predicted but incorrect or not in ground truth: ('file292', (12, 19)), Prediction: '1994073'
Predicted but incorrect or not in ground truth: ('993', (6002, 6007)), Prediction: '333.4'
Predicted but incorrect or not in ground truth: ('735', (5435, 5443)), Prediction: 'original'
Predicted but incorrect or not in ground truth: ('file20783', (729, 738)), Prediction: '4/03/2021'
Predicted but incorrect or not in ground truth: ('650', (6103, 6106)), Prediction: 'Now'
Predicted but incorrect or not in ground truth: ('file18432', (772, 781)), Prediction: '5/06/2021'
Predicted but incorrect or not in ground truth: ('775', (339, 343)), Prediction: '2008'
Predicted but incorrect or not in ground truth: ('893', (555, 565)), Prediction: '13/2/62 LV'
Predicted but incorrect or not in ground truth: ('file9042', (18428, 18432)), Prediction: '2512'
Predicted but incorrect or not in ground truth: ('file5124', (12, 19)), Prediction: '2015083'
Predicted but incorrect or not in groun

In [None]:
def calculate_macro_scores(ground_truth, predictions):
    """
    Calculate Macro Precision, Recall, and F1-Score across different categories for Named Entity Recognition (NER).

    Parameters:
    - ground_truth (dict): Ground truth labeled data with document IDs as keys and associated labels.
    - predictions (dict): Predicted labeled data with document IDs as keys and associated labels.

    Prints:
    - Macro Precision, Recall, and F1-Score.

    Note:
    - Assumes the labeled data format with (doc_id, (start, end), label, text) for each entity.

    Example:
    ```python
    calculate_macro_scores(ground_truth, predictions)
    ```

    The function calculates the Macro Precision, Recall, and F1-Score across different categories for NER.

    """
    def extract_entities(label_dict):
        """
        Extract entities from a label dictionary.

        Parameters:
        - label_dict (dict): Dictionary containing labels with document IDs as keys and associated labels.

        Returns:
        - dict: Dictionary with entities and their corresponding labels.
        """
        entities = {}
        for doc_id, labels in label_dict.items():
            for label in labels:
                key = (doc_id, tuple(label[1:3]), label[0])
                entities[key] = label[3]
        return entities

    gt_entities = extract_entities(ground_truth)
    pred_entities = extract_entities(predictions)

    # Organize entities by category
    categories = set([key[2] for key in gt_entities.keys()] + [key[2] for key in pred_entities.keys()])

    total_precision, total_recall, total_f1 = 0, 0, 0
    for category in categories:
        # Filter entities by category
        gt_cat = {k: v for k, v in gt_entities.items() if k[2] == category}
        pred_cat = {k: v for k, v in pred_entities.items() if k[2] == category}

        # Calculate True Positives (TP), False Positives (FP), and False Negatives (FN)
        TP = len([e for e in pred_cat if e in gt_cat and pred_cat[e] == gt_cat[e]])
        FP = len([e for e in pred_cat if e not in gt_cat or pred_cat[e] != gt_cat[e]])
        FN = len([e for e in gt_cat if e not in pred_cat])

        # Calculate Precision, Recall, F1 for this category
        precision = TP / (TP + FP) if TP + FP > 0 else 0
        recall = TP / (TP + FN) if TP + FN > 0 else 0
        f1_score = 2 * (precision * recall) / (precision + recall) if precision + recall > 0 else 0

        total_precision += precision
        total_recall += recall
        total_f1 += f1_score

    # Calculate Macro Precision, Recall, F1
    num_categories = len(categories)
    macro_precision = total_precision / num_categories
    macro_recall = total_recall / num_categories
    macro_f1 = total_f1 / num_categories

    print(f"Macro Precision: {macro_precision:.4f}, Macro Recall: {macro_recall:.4f}, Macro F1-Score: {macro_f1:.4f}")


In [56]:
calculate_macro_scores(ground_truth, predict)

Macro Precision: 0.8860, Macro Recall: 0.9050, Macro F1-Score: 0.8939


In [26]:
calculate_macro_scores(ground_truth, predict)

Macro Precision: 0.9792, Macro Recall: 0.9821, Macro F1-Score: 0.9805


In [None]:
calculate_macro_scores(ground_truth, predict)