# Make predictions based on document level labels

In [211]:
import pandas as pd
import torch

from torch import nn
from torch.utils.data import DataLoader
from transformers import BertModel, BertTokenizer, AdamW, BertConfig

from HierarchyClassifer import HierarchyClassifier, HierarchyDataset, MentionPooling, ContextTokenizer
from ICD10Encoder import ICD10BiEncoder, EntityMentionTokenizer

# Prerare the Data

In [212]:
# load data 

data = pd.read_csv("Data\split_data.csv")
mapping = pd.read_csv("Data\code_block_unique_mapping.csv")
text_data = pd.read_csv("Data\Final_texts.csv")

data = pd.merge(data, mapping, how='left', left_on='code', right_on='code')

# split test set
test_df = data[data['set'] == 'test']


Index(['Unnamed: 0', 'patient_id', 'code', 'start', 'end', 'text', 'base_code',
       'chapter', 'description', 'set'],
      dtype='object')
(3886, 10)
Index(['code', 'block', 'index', 'block_index'], dtype='object')
Index(['Unnamed: 0', 'patient_id', 'code', 'start', 'end', 'text', 'base_code',
       'chapter', 'description', 'set', 'block', 'index', 'block_index'],
      dtype='object')
(3886, 13)


In [214]:
# extarct labels of each document

def extract_context(text, start_position, end_position, context_size=5):
      # Ensure valid positions
        if start_position < 0 or end_position > len(text) or start_position >= end_position:
            raise ValueError("Invalid mention positions")

        # Extract mention
        mention = text[start_position:end_position]

        # get the left context and tokenize
        left_context = text[:start_position]
        left_context = left_context.split()[-context_size:]

        # get the right context and tokenize
        right_context = text[end_position:]
        right_context = right_context.split()[:context_size]

        return mention, " ".join(left_context), " ".join(right_context)

doc_labels = []

for _, row_texts in text_data.iterrows():
    for _, row_annotations in data.iterrows():
        if row_texts.patient_id == row_annotations.patient_id:
            mention, left_context, right_context = extract_context(row_texts.text, row_annotations.start, row_annotations.end)
            # Add columns to the data DataFrame
            data.at[row_annotations.name, 'left_context'] = left_context
            data.at[row_annotations.name, 'right_context'] = right_context
        else:
            continue

In [None]:
# prepare dataset by adding the document labels

grouped_df = data.groupby('patient_id').agg(list).reset_index()
merged_df = pd.merge(data, grouped_df[['patient_id', 'code']], on='patient_id', how='left', suffixes=('', '_grouped'))

test_df = merged_df[merged_df['set'] == 'test']

In [None]:
# initialize tokenizer 

tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-uncased')
ct = ContextTokenizer(tokenizer, special_tokens={'additional_special_tokens': ['[Ms]','[Me]']}, max_length=128)

In [219]:
def tokenize_data(df, tokenizer_instance):
    tokenized_data = df.apply(tokenizer_instance.tokenizeWcontext, axis=1)
    return list(tokenized_data)

tokenized_test = tokenize_data(test_df, ct)

In [220]:
def create_special_tokens_mask(input_ids, tokenizer, special_tokens=["[Ms]", "[Me]"]):
    # Convert input_ids to a PyTorch tensor
    input_ids = torch.Tensor(input_ids)

    # Get the token IDs for the special tokens
    special_token_ids = tokenizer.convert_tokens_to_ids(special_tokens)

    # Create a mask indicating the positions of special tokens
    special_tokens_mask = torch.zeros_like(input_ids, dtype=torch.bool)

    for token_id in special_token_ids:
        special_tokens_mask |= (input_ids == token_id)

    return special_tokens_mask


# Create a list to store special tokens masks
test_special_tokens_masks = []

for tokenized_input in tokenized_test:
    input_ids = tokenized_input['input_ids']
    special_tokens_mask = create_special_tokens_mask(input_ids, ct.tokenizer)
    test_special_tokens_masks.append(special_tokens_mask)

In [221]:
parent_class_to_index = {class_name: index for index, class_name in enumerate(data['block'].unique())}
child_class_to_index = {class_name: index for index, class_name in enumerate(data['code'].unique())}

test_parent_labels = [parent_class_to_index[label] for label in test_df['block']]
test_child_labels = [child_class_to_index[label] for label in test_df['code']]

test_labels = [[child, parent] for child, parent in zip(test_child_labels, test_parent_labels)]

In [222]:
code_to_description_mapping = dict(zip(data['code'], data['description']))
print(code_to_description_mapping)
print(child_class_to_index)

In [None]:
# exctract code groups

codes = test_df['code_grouped'].tolist()
codes = [set(item) for item in codes]

In [223]:
# create dataloader

test_dataset = HierarchyDataset(tokenized_test, test_labels, test_special_tokens_masks)

batch_size = 32

test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)

# Classifier Evaluation

In [218]:
# initialize classifier model

model = BertModel.from_pretrained('bert-base-multilingual-uncased')

# Resize token embeddings
model.resize_token_embeddings(len(ct.tokenizer))

optimizer = AdamW(model.parameters(), lr=1e-5)
config = BertConfig.from_pretrained('bert-base-multilingual-uncased')
parent_classifier = nn.Linear(config.hidden_size, data['block'].nunique())
child_classifier = nn.Linear(config.hidden_size, data['code'].nunique())
criterion = nn.CrossEntropyLoss()

# set optimizers for the linear layers
optimizer_parent = torch.optim.Adam(parent_classifier.parameters(), lr=1e-3)
optimizer_child = torch.optim.Adam(child_classifier.parameters(), lr=1e-3)

pooling_function = MentionPooling(pool_type='average')

average_pooling_model = HierarchyClassifier(model=model,
                                 optimizer=optimizer,
                                 parent_classifier=parent_classifier,
                                 child_classifier=child_classifier,
                                 optimizer_parent = optimizer_parent,
                                 optimizer_child = optimizer_child,
                                 criterion=criterion)

average_pooling_model.load_model('Hier_CONTEXT_6th\checkpoint_epoch_0.pt')

In [None]:
# evaluate test set
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm
import numpy as np

with torch.no_grad():
    all_true_child_labels = []
    all_pred_child_labels = []
    i = 0

    for mention in tqdm(tokenized_test, desc='Evaluation', leave=True):

        # retrieve candidate concepts
        cc = codes[i]
        candidate_indexes = [child_class_to_index[code] for code in cc]
    
        input_ids = mention['input_ids']
        attention_mask = mention['attention_mask']
        token_type_ids = mention['token_type_ids']
        child_labels = test_labels[i][0]
        special_tokens_mask = create_special_tokens_mask(input_ids, ct.tokenizer)

        # Forward pass
        outputs = average_pooling_model.model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

        pooled_mentions = pooling_function(outputs.last_hidden_state, special_tokens_mask)
        logits_child = average_pooling_model.child_classifier(pooled_mentions)
        
        # Apply softmax to get probabilities
        probabilities = torch.nn.functional.softmax(logits_child, dim=-1)
      
        selected_probs = []
        for x in candidate_indexes:
            selected_probs.append(probabilities[0][x].item())

        zipped_data = list(zip(candidate_indexes, selected_probs))
    
        # Sort based on probabilities
        sorted_data = sorted(zipped_data, key=lambda x: x[1], reverse=True)
    
        # Extract the sorted candidate indexes
        sorted_candidate_indexes = [item[0] for item in sorted_data]
       
        predicted_label = sorted_candidate_indexes[0]

        # Sort candidate concepts based on predicted probabilities
        
        # Update all_true_child_labels and all_pred_child_labels
        all_true_child_labels.append(child_labels)

        all_pred_child_labels.append(predicted_label)

        i += 1

    # Calculate precision, recall, and F1 score
    accuracy_child = accuracy_score(all_true_child_labels, all_pred_child_labels)
    precision_child = precision_score(all_true_child_labels, all_pred_child_labels, average='macro')
    recall_child = recall_score(all_true_child_labels, all_pred_child_labels, average='macro')
    f1_child = f1_score(all_true_child_labels, all_pred_child_labels, average='macro')

    # Print or store evaluation metrics
    print()
    print("Evaluation Metrics")
    print(f'Eval Accuracy: {accuracy_child}')
    print(f'Eval Child Precision: {precision_child}, Eval Child Recall: {recall_child}, Eval Child F1: {f1_child}')

# Bi-encoder Evaluation

In [227]:
# load embeddings dict
import pickle

file_path = 'Data\embeddings_dict.pkl'
with open(file_path, 'rb') as file:
    entity_embeddings_dict = pickle.load(file)

In [228]:
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-uncased')
emt = EntityMentionTokenizer(tokenizer, max_length=128)

model_mention = BertModel.from_pretrained('bert-base-multilingual-uncased')
model_entity = BertModel.from_pretrained('bert-base-multilingual-uncased')

# Resize token embeddings
model_mention.resize_token_embeddings(len(emt.tokenizer))
model_entity.resize_token_embeddings(len(emt.tokenizer))

# Define separate optimizers for mention and entity transformers
optimizer_mention = AdamW(model_mention.parameters(), lr=1e-5)
optimizer_entity = AdamW(model_entity.parameters(), lr=1e-5)

bi_encoder = ICD10BiEncoder(mention_transformer=model_mention,
                            entity_trasformer=model_entity,
                            optimizer_mention=optimizer_mention,
                            optimizer_entity=optimizer_entity)

In [229]:
# Load the checkpoint
def load_model(model, checkpoint_path):
    # checkpoint_path = "path/to/your/checkpoint.pt"  # Replace with the actual path to your checkpoint
    checkpoint = torch.load(checkpoint_path)

    # Load the model state
    model.entity_transformer.load_state_dict(checkpoint['entity_model_state_dict'])
    model.mention_transformer.load_state_dict(checkpoint['mention_model_state_dict'])

    # If you have optimizers, load their state_dicts as well
    model.entity_optimizer.load_state_dict(checkpoint['entity_optimizer_state_dict'])
    model.mention_optimizer.load_state_dict(checkpoint['mention_optimizer_state_dict'])

    # Other information
    current_epoch = checkpoint['epoch']
    model.train_losses = checkpoint['train_losses']
    model.val_losses = checkpoint['val_losses']
    

load_model(bi_encoder, 'ICD10BiEncoder.pt')

In [278]:
# chnaged it to filter candidate concepts based on the document labels

def generate_candidates(tokenized_mention, bi_encoder, entity_embeddings, pooling_function, document_labels = 0, top_k=50, similarity_metric='dot_product'):
   
    input_ids = tokenized_mention['input_ids']
    attention_mask = tokenized_mention['attention_mask']
    token_type_ids = tokenized_mention['token_type_ids']
    special_tokens_mask = create_special_tokens_mask(input_ids, ct.tokenizer)

    with torch.no_grad():
        mention_outputs = bi_encoder.mention_transformer(input_ids, attention_mask, token_type_ids)
        
    mention_embeddings = pooling_function(mention_outputs.last_hidden_state, special_tokens_mask)

    entity_tensors = tuple(entity_embeddings.values())
  
    similarity_scores = torch.matmul(mention_embeddings, torch.squeeze(torch.stack(entity_tensors), dim=1).t())
    
    selected_probabilities = similarity_scores[0, document_labels]

    zipped_data = list(zip(document_labels, selected_probabilities))

    sorted_data = sorted(zipped_data, key=lambda x: x[1], reverse=True)

    highest_k = sorted_data[:top_k]

    return highest_k


In [292]:
# def generate_candidates(tokenized_mention, bi_encoder, entity_embeddings, pooling_function, top_k=50, similarity_metric='dot_product'):
   
#     input_ids = tokenized_mention['input_ids']
#     attention_mask = tokenized_mention['attention_mask']
#     token_type_ids = tokenized_mention['token_type_ids']
#     special_tokens_mask = create_special_tokens_mask(input_ids, ct.tokenizer)

#     # print(input_ids)
    
#     # Step 2: Forward Pass
#     with torch.no_grad():
#         mention_outputs = bi_encoder.mention_transformer(input_ids, attention_mask, token_type_ids)


#     mention_embeddings = pooling_function(mention_outputs.last_hidden_state, special_tokens_mask)
#     # mention_embeddings = mention_outputs.last_hidden_state[:, 0, :]  # Assuming 'CLS' pooling

#     # print(mention_embeddings.shape)
#     entity_tensors = tuple(entity_embeddings.values())

#     # get similarity score
#     if similarity_metric == 'dot_product':
#         similarity_scores = torch.matmul(mention_embeddings, torch.squeeze(torch.stack(entity_tensors), dim=1).t())
#         _, top_indices = torch.topk(similarity_scores, top_k, dim=1, largest=True, sorted=True)
#         top_candidates = [list(entity_embeddings.keys())[i.item()] for i in top_indices[0]]
#     elif similarity_metric =='euclidean':
#         similarity_scores = -torch.norm(torch.stack(entity_tensors) - mention_embeddings.unsqueeze(0), dim=2, p=2)
#         _, top_indices = torch.topk(similarity_scores, top_k, dim=0, largest=True, sorted=True)
#         top_candidates = [list(entity_embeddings.keys())[i.item()] for i in top_indices[0]]
#     elif similarity_metric == 'jaccard':
#         similarity_scores = F.cosine_similarity(F.relu(mention_embeddings), F.relu(torch.stack(entity_tensors)), dim=1)
#     elif similarity_metric == 'cosine':
#         similarity_scores = F.cosine_similarity(torch.stack(entity_tensors).squeeze(1), mention_embeddings, dim=1)
#         # print(similarity_scores)
#         _, top_indices = torch.topk(similarity_scores, top_k, largest=True, sorted=True)
#         # print(top_indices.tolist())
#         top_candidates = [list(entity_embeddings.keys())[i] for i in top_indices]
#     # print(top_indices)
    
#     return top_indices #, top_candidates

In [297]:
# evaluation loop

with torch.no_grad():
    all_true_child_labels = []
    all_pred_child_labels = []
    label_pairs = []
    misclassified_instances = []
    correctly_classified_instances = []
    k_indexes = []
    i = 0

    for mention in tqdm(tokenized_test, desc='Evaluation', leave=True):


        document_labels = [child_class_to_index[code] for code in codes[i]]
        child_labels = test_labels[i][0]

        # get dot product distance from the bi encoder embeddings
        predicted_k = generate_candidates(mention, bi_encoder, entity_embeddings_dict, pooling_function, similarity_metric='dot_product', top_k=100)

        predicted_k = predicted_k.tolist()[0]
        
        if child_labels in predicted_k:
            predicted_label = child_labels
            index = predicted_k.index(predicted_label)
            k_indexes.append(index)
    
        all_true_child_labels.append(child_labels)
        all_pred_child_labels.append(predicted_label)

        def find_key_by_value(dictionary, target_value):
            for key, value in dictionary.items():
                if value == target_value:
                    return key
            return None
        
        true_label = find_key_by_value(child_class_to_index, child_labels)
        true_pred = find_key_by_value(child_class_to_index, predicted_label)


        label_pairs.append([true_label, true_pred])

        # check if predicted label is correct
        if predicted_label != child_labels:
            mention_text = ct.tokenizer.decode(mention['input_ids'][0], skip_special_tokens=False)

            # Save misclassified instance details
            misclassified_instances.append({
                'mention_text': mention_text,
                'true_label': true_label,
                'predicted_label': true_pred
            })
        
        elif predicted_label == child_labels:

            mention_text = ct.tokenizer.decode(mention['input_ids'][0], skip_special_tokens=False)

            # Save misclassified instance details
            correctly_classified_instances.append({
                'mention_text': mention_text,
                'true_label': true_label,
                'predicted_label': true_pred
            })


        i += 1

    # Calculate metrics
    accuracy_child = accuracy_score(all_true_child_labels, all_pred_child_labels)
    precision_child = precision_score(all_true_child_labels, all_pred_child_labels, average='macro')
    recall_child = recall_score(all_true_child_labels, all_pred_child_labels, average='macro')
    f1_child = f1_score(all_true_child_labels, all_pred_child_labels, average='macro')

    print()
    print("Evaluation Metrics")
    print(f'Eval Accuracy: {accuracy_child}')
    print(f'Eval Child Precision: {precision_child}, Eval Child Recall: {recall_child}, Eval Child F1: {f1_child}')


Evaluation:  78%|███████▊  | 604/777 [03:16<00:55,  3.14it/s]

# Misclassification Analysis

In [269]:
misclassified_instances

[{'mention_text': '[CLS] ενζυμων. ο υπερηχοκαρδιογραφικος ελεγχος ανεδειξε [Ms] μη διατηρημενο κλασμα εξωθησεως [Me] ( 35 - 40 % ). διενεργηθηκε στεφανιογραφικος ελεγχος ο [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD]',
  'true_label': 'I50.9',
  'predicted_label': 'Z95'},
 {'mention_text': '[CLS] υπερταση, σακχαρωδης διαβητης, [Ms] αγχωδης συνδρομη [Me], εκτακτοσυστολικη αρρυθμια πορεια νοσου [SEP] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [PAD] [

In [240]:
misclassified_true_labels = [entry['true_label'] for entry in misclassified_instances]

train_df = data[data['set'] == 'train']
training_labels = train_df['code'].tolist()

seen_misclassified = [label for label in misclassified_true_labels if label not in training_labels]

num_seen_misclassified = len(seen_misclassified)

print("Number of misclassified true labels not present in the training set: ", num_seen_misclassified)
print("Misclassified true labels not seen in training set: ", seen_misclassified)


Number of misclassified true labels not present in the training set: 7
Misclassified true labels not seen in training set: ['I31.8', 'Z96.6', 'I42', 'I65', 'I35', 'I70', 'Z86.7']


In [273]:
# see if the model correctly classified any unseen classes

test_labels_list = test_df['code'].to_list()
correctly_classified_true_labels = [entry['true_label'] for entry in correctly_classified_instances]

unseen_labels = set(test_labels_list) ^ set(training_labels)

correct_unseen = [label for label in correctly_classified_true_labels if label in unseen_labels]

print("Number of correctly classified true labels not present in the training set: ", len(correct_unseen))
print("Correctly classifie labels not seen in training set: ", correct_unseen)


15
['I35.8', 'I31.8', 'Z96.6', 'I42', 'I65', 'D86.8', 'M25.5', 'N17', 'I51.4', 'I70.1', 'I35', 'D68.5', 'I44.6', 'I70', 'Z86.7']
