In [6]:
import pickle
import pandas as pd
from tqdm import tqdm

import torch
from torch import nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
from transformers import AdamW, BertConfig, BertModel, BertTokenizer

from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

from ICD10Encoder import EntityMentionTokenizer, ICD10BiEncoder
from HierarchyClassifer import ContextTokenizer, HierarchyClassifier, HierarchyDataset, MentionPooling

In [2]:
# Load the annotated mentions, the code mapping table and the full patient texts.
# Forward slashes keep these relative paths valid on Windows and POSIX alike; the
# previous "Data\..." literals depended on backslash sequences that Python does not
# recognise as escapes (a SyntaxWarning on recent Python versions).
data = pd.read_csv("Data/split_data.csv")
mapping = pd.read_csv("Data/code_block_unique_mapping.csv")
text_data = pd.read_csv("Data/Final_texts.csv")

# Left join on `code`: every annotation row is kept and gains the columns of `mapping`.
data = pd.merge(data, mapping, how='left', on='code')
mentions = data['text']
descriptions = data['description']

# Load the code -> embedding dict extracted from the trained bi-encoder.
# NOTE(review): pickle.load can execute arbitrary code while unpickling; only open
# files that were produced by this project.
file_path = 'Data/embeddings_dict.pkl'
with open(file_path, 'rb') as file:
    entity_embeddings_dict = pickle.load(file)

# Initialize classifier and bi-encoder models

In [3]:
# Hierarchy classifier (shared BERT encoder + parent/child linear heads) used for reranking.

# The encoder is created first; its embedding matrix is resized further down, once the
# context tokenizer (which receives '[Ms]' / '[Me]' as additional special tokens) exists.
model = BertModel.from_pretrained('bert-base-multilingual-uncased')

tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-uncased')
ct = ContextTokenizer(
    tokenizer,
    max_length=128,
    special_tokens={'additional_special_tokens': ['[Ms]', '[Me]']},
)

# Match the embedding table to the length of the wrapped tokenizer before the
# optimizer captures the encoder's parameters.
model.resize_token_embeddings(len(ct.tokenizer))
optimizer = AdamW(params=model.parameters(), lr=1e-5)

# One linear head per hierarchy level, sized by the number of distinct labels in `data`.
config = BertConfig.from_pretrained('bert-base-multilingual-uncased')
parent_classifier = nn.Linear(in_features=config.hidden_size, out_features=data['block'].nunique())
child_classifier = nn.Linear(in_features=config.hidden_size, out_features=data['code'].nunique())
criterion = nn.CrossEntropyLoss()

# Dedicated optimizers for the two linear heads.
optimizer_parent = torch.optim.Adam(params=parent_classifier.parameters(), lr=1e-3)
optimizer_child = torch.optim.Adam(params=child_classifier.parameters(), lr=1e-3)

pooling_function = MentionPooling(pool_type='average')

average_pooling_model = HierarchyClassifier(
    model=model,
    parent_classifier=parent_classifier,
    child_classifier=child_classifier,
    optimizer=optimizer,
    optimizer_parent=optimizer_parent,
    optimizer_child=optimizer_child,
    criterion=criterion,
)


In [4]:
# Restore the hierarchy classifier from its saved checkpoint. A forward slash is used so
# the relative path works on every OS (the old literal contained the unrecognised
# escape "\c").
average_pooling_model.load_model('Hier_CONTEXT_6th/checkpoint_epoch_0.pt')

In [8]:
# Bi-encoder (separate mention and entity encoders) used for candidate generation.

max_length = 128

# NOTE: this rebinds the notebook-level name `tokenizer`; `ct` is unaffected because the
# later cells reach its tokenizer through `ct.tokenizer`.
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-uncased')
emt = EntityMentionTokenizer(tokenizer, max_length=max_length)

# Two independent encoders initialised from the same pretrained checkpoint.
model_mention = BertModel.from_pretrained('bert-base-multilingual-uncased')
model_entity = BertModel.from_pretrained('bert-base-multilingual-uncased')

# Resize both embedding tables to the length of the entity/mention tokenizer.
model_mention.resize_token_embeddings(len(emt.tokenizer))
model_entity.resize_token_embeddings(len(emt.tokenizer))

# One optimizer per encoder.
optimizer_mention = AdamW(params=model_mention.parameters(), lr=1e-5)
optimizer_entity = AdamW(params=model_entity.parameters(), lr=1e-5)

# NOTE(review): the keyword `entity_trasformer` is spelled exactly as in the original
# call; presumably it mirrors the constructor's parameter name - confirm in ICD10Encoder.
bi_encoder = ICD10BiEncoder(
    mention_transformer=model_mention,
    entity_trasformer=model_entity,
    optimizer_mention=optimizer_mention,
    optimizer_entity=optimizer_entity,
)

Some weights of the model checkpoint at bert-base-multilingual-uncased were not used when initializing BertModel: ['cls.predictions.transform.dense.bias', 'cls.seq_relationship.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.seq_relationship.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.LayerNorm.bias']
- This IS expected if you are initializing BertModel from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertModel from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of the model checkpoint at bert-base-multilingual-uncased were not used when initializing BertModel: ['cls.predictions.transform.

In [10]:
# Load the checkpoint
def load_model(model, checkpoint_path, map_location=None):
    """Restore a bi-encoder's weights, optimizer states and loss history from a checkpoint.

    Args:
        model: object exposing `entity_transformer`, `mention_transformer`,
            `entity_optimizer` and `mention_optimizer`. Its `train_losses` and
            `val_losses` attributes are overwritten with the stored values.
        checkpoint_path: path of the checkpoint file to read with `torch.load`.
        map_location: forwarded to `torch.load`. Pass e.g. 'cpu' to open a checkpoint
            that was saved on a GPU when no GPU is available. The default (None)
            keeps the previous behaviour.

    Raises:
        KeyError: if the checkpoint lacks one of the expected entries. The check runs
            before anything is restored, so `model` is never left half-loaded.
    """
    # NOTE: torch.load unpickles the file; only load checkpoints from a trusted source.
    checkpoint = torch.load(checkpoint_path, map_location=map_location)

    expected_keys = (
        'entity_model_state_dict',
        'mention_model_state_dict',
        'entity_optimizer_state_dict',
        'mention_optimizer_state_dict',
        'train_losses',
        'val_losses',
    )
    missing_keys = [key for key in expected_keys if key not in checkpoint]
    if missing_keys:
        raise KeyError(f"Checkpoint {checkpoint_path!r} is missing entries: {missing_keys}")

    # Load the model and optimizer states
    model.entity_transformer.load_state_dict(checkpoint['entity_model_state_dict'])
    model.mention_transformer.load_state_dict(checkpoint['mention_model_state_dict'])
    model.entity_optimizer.load_state_dict(checkpoint['entity_optimizer_state_dict'])
    model.mention_optimizer.load_state_dict(checkpoint['mention_optimizer_state_dict'])

    # Loss history recorded during training
    model.train_losses = checkpoint['train_losses']
    model.val_losses = checkpoint['val_losses']
    

# Restore the trained bi-encoder (both encoders, both optimizers, loss history).
load_model(model=bi_encoder, checkpoint_path='ICD10BiEncoder.pt')

# Prepare the test data

In [11]:
# prepare data

def extract_context(text, start_position, end_position, context_size=5):
    """Cut a mention out of `text` together with a window of surrounding words.

    Args:
        text: the full string containing the mention.
        start_position: index of the mention's first character.
        end_position: index one past the mention's last character.
        context_size: maximum number of whitespace-separated words kept on each side.

    Returns:
        A tuple `(mention, left_context, right_context)`; both contexts are the kept
        words joined by single spaces (empty string when there are none).

    Raises:
        ValueError: if the positions do not describe a non-empty span inside `text`,
            or if `context_size` is negative.
    """
    # Ensure valid positions
    if start_position < 0 or end_position > len(text) or start_position >= end_position:
        raise ValueError("Invalid mention positions")
    if context_size < 0:
        raise ValueError("context_size must be non-negative")

    mention = text[start_position:end_position]

    left_words = text[:start_position].split()
    right_words = text[end_position:].split()

    # `words[-0:]` is the whole list, so a zero-sized window needs an explicit branch
    # on the left side; `words[:0]` on the right is already empty.
    left_context = left_words[-context_size:] if context_size else []
    right_context = right_words[:context_size]

    return mention, " ".join(left_context), " ".join(right_context)

# Attach the left/right context of every annotated mention to `data`.
# A patient_id -> full text lookup replaces the former nested iterrows() cross product
# (rows of text_data x rows of data) with a single pass over the annotations. As before,
# when a patient_id occurs more than once in text_data the last text wins, and
# annotations without a matching patient text are left untouched.
text_by_patient = dict(zip(text_data['patient_id'], text_data['text']))

for row_label, row_annotations in data.iterrows():
    patient_text = text_by_patient.get(row_annotations.patient_id)
    if patient_text is None:
        continue
    _, left_context, right_context = extract_context(patient_text, row_annotations.start, row_annotations.end)
    # Add columns to the data DataFrame
    data.at[row_label, 'left_context'] = left_context
    data.at[row_label, 'right_context'] = right_context

In [12]:
# Rows whose `set` column marks them as belonging to the test split.
test_df = data.loc[data['set'] == 'test']

In [13]:
# tokenize data

def tokenize_data(df, tokenizer_instance):
    """Apply `tokenizer_instance.tokenizeWcontext` to every row of `df`.

    Args:
        df: DataFrame whose rows are passed one at a time to the tokenizer.
        tokenizer_instance: object providing a `tokenizeWcontext(row)` method.

    Returns:
        A list with one tokenizer result per row, in row order.
    """
    row_tokenizer = tokenizer_instance.tokenizeWcontext
    per_row_results = df.apply(row_tokenizer, axis=1)
    return list(per_row_results)

# Tokenize every test row with the context tokenizer.
tokenized_test = tokenize_data(df=test_df, tokenizer_instance=ct)

In [14]:
def create_special_tokens_mask(input_ids, tokenizer, special_tokens=("[Ms]", "[Me]")):
    """Build a boolean mask marking where the given special tokens occur in `input_ids`.

    Args:
        input_ids: token ids as a tensor or anything `torch.as_tensor` accepts
            (e.g. a nested list of ints).
        tokenizer: object providing `convert_tokens_to_ids(list_of_tokens)`.
        special_tokens: the token strings to look for; a single string is treated as
            one token.

    Returns:
        A `torch.bool` tensor with the same shape as `input_ids`, True at every
        position holding one of the special tokens' ids.
    """
    # torch.as_tensor keeps the integer dtype and avoids a copy when a tensor is passed.
    # The former torch.Tensor(...) cast the ids to float32, which cannot represent every
    # integer above 2**24 exactly.
    input_ids = torch.as_tensor(input_ids)

    if isinstance(special_tokens, str):
        special_tokens = [special_tokens]

    # Get the token IDs for the special tokens
    special_token_ids = tokenizer.convert_tokens_to_ids(list(special_tokens))

    # Accumulate the positions of each special token into one mask
    special_tokens_mask = torch.zeros_like(input_ids, dtype=torch.bool)
    for token_id in special_token_ids:
        special_tokens_mask |= (input_ids == token_id)

    return special_tokens_mask


# One special-token mask per tokenized test mention, in the order of `tokenized_test`.
test_special_tokens_masks = [
    create_special_tokens_mask(encoded_mention['input_ids'], ct.tokenizer)
    for encoded_mention in tokenized_test
]

In [15]:
# Hierarchical test labels.
# Every distinct block (parent) and code (child) gets an integer id following its order
# of first appearance in the full `data` frame.
parent_class_to_index = {
    block_name: position
    for position, block_name in enumerate(data['block'].unique())
}
child_class_to_index = {
    code_name: position
    for position, code_name in enumerate(data['code'].unique())
}

# Translate the test rows' labels into those integer ids.
test_parent_labels = list(map(parent_class_to_index.__getitem__, test_df['block']))
test_child_labels = list(map(child_class_to_index.__getitem__, test_df['code']))

# Each entry is [child_id, parent_id].
test_labels = [list(label_pair) for label_pair in zip(test_child_labels, test_parent_labels)]

In [376]:
# Lookup from each code to its textual description, so labels can be read as text.
code_to_description_mapping = {
    code: description
    for code, description in zip(data['code'], data['description'])
}
print(code_to_description_mapping)
print(child_class_to_index)

# One "code ; description" line per entry.
for code, description in code_to_description_mapping.items():
    print(code, ";", description)

{'R06.0': 'Δύσπνοια', 'I48.9': 'Κολπική µαρµαρυγή και πτερυγισµός, µη καθορισµένα', 'I50.9': 'Καρδιακή ανεπάρκεια, μη καθορισμένη', 'I48.2': 'Κολπική µαρµαρυγή, µόνιµη [χρόνια]', 'I10': 'Ιδιοπαθής (πρωτοπαθής) υπέρταση', 'E11': 'Μη ινσουλινοεξαρτώμενος σακχαρώδης διαβήτης', 'I51.7': 'Καρδιομεγαλία', 'I34.0': 'Ανεπάρκεια μιτροειδούς βαλβίδας', 'I35.1': 'Ανεπάρκεια αορτικής (βαλβίδας)', 'I07.1': 'Ανεπάρκεια της τριγλώχινας', 'R07': 'Πόνος στο λαιμό και το στήθος (θώρακα)', 'I21.4': 'Οξύ υπενδοκαρδιακό έμφραγμα του μυοκαρδίου', 'Z95.5': 'Παρουσία εμφυτεύματος και μοσχεύματος στεφανιαίας αγγειοπλαστικής', 'I79.0': 'Ανεύρυσμα της αορτής σε παθήσεις που ταξινομούνται αλλού', 'I40.9': 'Οξεία μυοκαρδίτιδα, μη καθορισμένη', 'R07.4': 'Θωρακικός πόνος, μη καθορισμένος', 'E03': 'Άλλες μορφές υποθυρεοειδισμού', 'I25': 'Χρόνια ισχαιμική καρδιοπάθεια', 'E11.8': 'Άλλες συγκεκριμένες μορφές σακχαρώδους διαβήτη με διάφορες επιπλοκές', 'I35.0': 'Στένωση αορτικής (βαλβίδας)', 'E87.1': 'Υποωσμωτικότητα και

In [17]:
# Dataset and loader over the test mentions, their [child, parent] labels and masks.
batch_size = 32

test_dataset = HierarchyDataset(tokenized_test, test_labels, test_special_tokens_masks)

# shuffle=False: batches follow the order of the test set.
test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

# Candidate Generation Eval

In [305]:
def generate_candidates(tokenized_mention, bi_encoder, entity_embeddings, pooling_function, top_k=50, similarity_metric='dot_product', tokenizer=None):
    """Retrieve the entities whose embeddings score highest against a mention.

    Args:
        tokenized_mention: mapping with 'input_ids', 'attention_mask' and
            'token_type_ids' for the mention.
        bi_encoder: object whose `mention_transformer` encodes the mention and returns
            an output exposing `last_hidden_state`.
        entity_embeddings: dict mapping an entity key to its embedding tensor.
            Assumes every stored tensor holds exactly one vector, shaped (dim,) or
            (1, dim) - TODO confirm against the code that wrote the dict.
        pooling_function: callable `(last_hidden_state, special_tokens_mask)` producing
            the mention embedding(s).
        top_k: number of candidates wanted; clamped to the number of entities.
        similarity_metric: 'dot_product', 'euclidean' (negative L2 distance),
            'cosine' or 'jaccard'.
        tokenizer: tokenizer used to locate the mention markers; defaults to the
            notebook-level `ct.tokenizer`, as the function always did.

    Returns:
        `(top_indices, top_candidates)`. `top_indices` has shape (rows, k), one row per
        mention embedding returned by `pooling_function`, with columns sorted from best
        to worst; the entries are positions in `entity_embeddings`' key order.
        `top_candidates` lists the keys for the first row. Every metric now returns
        this layout (previously only 'dot_product' did).

    Raises:
        ValueError: for an unknown `similarity_metric` or empty `entity_embeddings`.
    """
    supported_metrics = ('dot_product', 'euclidean', 'cosine', 'jaccard')
    if similarity_metric not in supported_metrics:
        raise ValueError(f"Unknown similarity_metric {similarity_metric!r}; expected one of {supported_metrics}")
    if not entity_embeddings:
        raise ValueError("entity_embeddings is empty")
    if tokenizer is None:
        tokenizer = ct.tokenizer

    # prepare mention input
    input_ids = tokenized_mention['input_ids']
    attention_mask = tokenized_mention['attention_mask']
    token_type_ids = tokenized_mention['token_type_ids']
    special_tokens_mask = create_special_tokens_mask(input_ids, tokenizer)

    # forward pass from bi encoder to get mention embedding
    with torch.no_grad():
        mention_outputs = bi_encoder.mention_transformer(input_ids, attention_mask, token_type_ids)

    mention_embeddings = pooling_function(mention_outputs.last_hidden_state, special_tokens_mask)
    if mention_embeddings.dim() == 1:
        mention_embeddings = mention_embeddings.unsqueeze(0)

    # Materialise keys once; the same order is used for the rows of the entity matrix.
    entity_keys = list(entity_embeddings.keys())
    entity_matrix = torch.stack(tuple(entity_embeddings.values())).reshape(len(entity_keys), -1)
    entity_matrix = entity_matrix.to(device=mention_embeddings.device, dtype=mention_embeddings.dtype)

    # Every branch yields scores of shape (rows, num_entities); higher means more similar.
    if similarity_metric == 'dot_product':
        similarity_scores = torch.matmul(mention_embeddings, entity_matrix.t())
    elif similarity_metric == 'euclidean':
        differences = mention_embeddings.unsqueeze(1) - entity_matrix.unsqueeze(0)
        similarity_scores = -torch.norm(differences, p=2, dim=2)
    elif similarity_metric == 'cosine':
        similarity_scores = F.cosine_similarity(mention_embeddings.unsqueeze(1), entity_matrix.unsqueeze(0), dim=2)
    else:
        # NOTE(review): this branch used to stop before selecting any candidate. It is
        # completed with the formula the original already computed (cosine similarity
        # of the ReLU-clipped vectors), which is not a true Jaccard index - confirm the
        # intended definition.
        similarity_scores = F.cosine_similarity(
            F.relu(mention_embeddings).unsqueeze(1), F.relu(entity_matrix).unsqueeze(0), dim=2
        )

    k = min(top_k, len(entity_keys))
    _, top_indices = torch.topk(similarity_scores, k, dim=1, largest=True, sorted=True)
    top_candidates = [entity_keys[position] for position in top_indices[0].tolist()]

    return top_indices, top_candidates


# generate_candidates("Υπέρταση", bi_encoder, entity_loader, pooling_function)

In [354]:
# evaluation loop

def find_key_by_value(dictionary, target_value):
    """Return the first key of `dictionary` whose value equals `target_value`, else None."""
    for key, value in dictionary.items():
        if value == target_value:
            return key
    return None


# Inference only: put both encoders in eval mode explicitly so dropout stays disabled
# even if an earlier cell (e.g. a checkpoint load) left them in training mode.
average_pooling_model.model.eval()
bi_encoder.mention_transformer.eval()

with torch.no_grad():
    all_true_child_labels = []
    all_pred_child_labels = []
    label_pairs = []
    misclassified_instances = []

    # test_labels is aligned with tokenized_test; each entry is [child_id, parent_id].
    for mention, (child_labels, _parent_label) in zip(tqdm(tokenized_test, desc='Evaluation', leave=True), test_labels):

        # Retrieve candidate concepts from the bi-encoder.
        # With top_k=1 there is a single candidate, so the reranking below cannot change
        # the bi-encoder's choice; raise top_k to let the classifier rerank.
        candidate_indexes, _ = generate_candidates(mention, bi_encoder, entity_embeddings_dict, pooling_function, similarity_metric='dot_product', top_k=1)

        # prepare mention inputs
        input_ids = mention['input_ids']
        attention_mask = mention['attention_mask']
        token_type_ids = mention['token_type_ids']
        special_tokens_mask = create_special_tokens_mask(input_ids, ct.tokenizer)

        # forward pass through the hierarchy classifier's encoder and child head
        outputs = average_pooling_model.model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)
        pooled_mentions = pooling_function(outputs.last_hidden_state, special_tokens_mask)
        logits_child = average_pooling_model.child_classifier(pooled_mentions)

        # get probabilities
        probabilities = torch.nn.functional.softmax(logits_child, dim=-1)

        # NOTE(review): the candidate indices are positions in entity_embeddings_dict,
        # while the columns of `probabilities` and the true labels follow
        # child_class_to_index. This only lines up if both orderings coincide - TODO confirm.
        candidate_ids = candidate_indexes[0].tolist()
        selected_probs = probabilities[0][candidate_ids].tolist()

        # sort the candidates by classifier probability, best first (stable for ties)
        sorted_data = sorted(zip(candidate_ids, selected_probs), key=lambda pair: pair[1], reverse=True)
        sorted_candidate_indexes = [candidate_id for candidate_id, _ in sorted_data]

        # predicted label is the top candidate
        predicted_label = sorted_candidate_indexes[0]

        all_true_child_labels.append(child_labels)
        all_pred_child_labels.append(predicted_label)

        # map the integer ids back to their codes for reporting
        true_label = find_key_by_value(child_class_to_index, child_labels)
        true_pred = find_key_by_value(child_class_to_index, predicted_label)

        label_pairs.append([true_label, true_pred])

        # check if predicted label is correct
        if predicted_label != child_labels:
            # decode the tokenized input (special tokens kept so the mention markers stay visible)
            mention_text = ct.tokenizer.decode(mention['input_ids'][0], skip_special_tokens=False)

            # save misclassified instance details
            misclassified_instances.append({
                'mention_text': mention_text,
                'true_label': true_label,
                'predicted_label': true_pred
            })

    # calculate metrics; zero_division=0 yields the same scores as the default but
    # without an UndefinedMetricWarning for classes that are never predicted/present.
    accuracy_child = accuracy_score(all_true_child_labels, all_pred_child_labels)
    precision_child = precision_score(all_true_child_labels, all_pred_child_labels, average='macro', zero_division=0)
    recall_child = recall_score(all_true_child_labels, all_pred_child_labels, average='macro', zero_division=0)
    f1_child = f1_score(all_true_child_labels, all_pred_child_labels, average='macro', zero_division=0)

    print()
    print("Evaluation Metrics")
    print(f'Eval Accuracy: {accuracy_child}')
    print(f'Eval Child Precision: {precision_child}, Eval Child Recall: {recall_child}, Eval Child F1: {f1_child}')


Evaluation: 100%|██████████| 777/777 [08:25<00:00,  1.54it/s]


Evaluation Metrics
Eval Accuracy: 0.7078507078507078
Eval Child Precision: 0.46655272833474687, Eval Child Recall: 0.4814782171973313, Eval Child F1: 0.44431054176565854



  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
