In [1]:
# imports

import pandas as pd
import torch

from torch import nn
from torch.utils.data import DataLoader
from transformers import BertModel, BertTokenizer, AdamW, BertConfig
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from tqdm import tqdm

from HierarchyClassifer import HierarchyClassifier, HierarchyDataset, MentionPooling, ContextTokenizer

  from .autonotebook import tqdm as notebook_tqdm


# Data Preparation

In [2]:
# load data

# load annotations
data = pd.read_csv("Data\split_data.csv")

# load code-block mapping
mapping = pd.read_csv("Data\code_block_unique_mapping.csv")

# load the full texts to extract the context
text_data = pd.read_csv("Data\Final_texts.csv")

# incorporate the code-block mapping into the dataset
data = pd.merge(data, mapping, how='left', left_on='code', right_on='code')

In [3]:
# add context columns

def extract_context(text, start_position, end_position, context_size=5):
      # Ensure valid positions
        if start_position < 0 or end_position > len(text) or start_position >= end_position:
            raise ValueError("Invalid mention positions")

        # extract mention
        mention = text[start_position:end_position]

        # get the left context and tokenize
        left_context = text[:start_position]
        left_context = left_context.split()[-context_size:]

        # get the right context and tokenize
        right_context = text[end_position:]
        right_context = right_context.split()[:context_size]

        return mention, " ".join(left_context), " ".join(right_context)


# itarate data and extract the context for each mention
for _, row_texts in text_data.iterrows():
    for _, row_annotations in data.iterrows():
        if row_texts.patient_id == row_annotations.patient_id:
            mention, left_context, right_context = extract_context(row_texts.text, row_annotations.start, row_annotations.end)
            # Add columns to the data DataFrame
            data.at[row_annotations.name, 'left_context'] = left_context
            data.at[row_annotations.name, 'right_context'] = right_context
        else:
             continue

In [4]:
# set splits 

train_df = data[data['set'] == 'train']
test_df = data[data['set'] == 'test']
val_df = data[data['set'] == 'validation']

In [5]:
# tokenize data

max_length = 128

# initialize custom tokenizer from Bert pre-trained
tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-uncased')
ct = ContextTokenizer(tokenizer, special_tokens={'additional_special_tokens': ['[Ms]','[Me]']}, max_length=max_length)

def tokenize_data(df, tokenizer_instance):
    tokenized_data = df.apply(tokenizer_instance.tokenizeWcontext, axis=1)
    return list(tokenized_data)

tokenized_train = tokenize_data(train_df, ct)
tokenized_test = tokenize_data(test_df, ct)
tokenized_val = tokenize_data(val_df, ct)

In [6]:
# special tokens mask created to locate the positions of the special tokens to used for the pooling function

def create_special_tokens_mask(input_ids, tokenizer, special_tokens=["[Ms]", "[Me]"]):
    # convert input_ids to a PyTorch tensor
    input_ids = torch.Tensor(input_ids)

    # get the token IDs for the special tokens
    special_token_ids = tokenizer.convert_tokens_to_ids(special_tokens)

    # create a mask indicating the positions of special tokens
    special_tokens_mask = torch.zeros_like(input_ids, dtype=torch.bool)

    for token_id in special_token_ids:
        special_tokens_mask |= (input_ids == token_id)

    return special_tokens_mask


# Create a list to store special tokens masks
train_special_tokens_masks = []
test_special_tokens_masks = []
val_special_tokens_masks = []

for tokenized_input in tokenized_train:
    input_ids = tokenized_input['input_ids']
    special_tokens_mask = create_special_tokens_mask(input_ids, ct.tokenizer)
    train_special_tokens_masks.append(special_tokens_mask)

for tokenized_input in tokenized_test:
    input_ids = tokenized_input['input_ids']
    special_tokens_mask = create_special_tokens_mask(input_ids, ct.tokenizer)
    test_special_tokens_masks.append(special_tokens_mask)

for tokenized_input in tokenized_val:
    input_ids = tokenized_input['input_ids']
    special_tokens_mask = create_special_tokens_mask(input_ids, ct.tokenizer)
    val_special_tokens_masks.append(special_tokens_mask)

In [7]:
# map the labels to integer ranges and create labelsets

parent_class_to_index = {class_name: index for index, class_name in enumerate(data['block'].unique())}
child_class_to_index = {class_name: index for index, class_name in enumerate(data['code'].unique())}

# Create lists of indices for parent and child labels
train_parent_labels = [parent_class_to_index[label] for label in train_df['block']]
train_child_labels = [child_class_to_index[label] for label in train_df['code']]

test_parent_labels = [parent_class_to_index[label] for label in test_df['block']]
test_child_labels = [child_class_to_index[label] for label in test_df['code']]

val_parent_labels = [parent_class_to_index[label] for label in val_df['block']]
val_child_labels = [child_class_to_index[label] for label in val_df['code']]

train_labels = [[child, parent] for child, parent in zip(train_child_labels, train_parent_labels)]
val_labels = [[child, parent] for child, parent in zip(val_child_labels, val_parent_labels)]
test_labels = [[child, parent] for child, parent in zip(test_child_labels, test_parent_labels)]

In [8]:
train_dataset = HierarchyDataset(tokenized_train, train_labels, train_special_tokens_masks)
test_dataset = HierarchyDataset(tokenized_test, test_labels, test_special_tokens_masks)
val_dataset = HierarchyDataset(tokenized_val, val_labels, val_special_tokens_masks)

# Model training

In [9]:
# Create DataLoaders

batch_size = 32

train_loader = DataLoader(dataset=train_dataset, batch_size=batch_size, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=batch_size, shuffle=False)
val_loader = DataLoader(dataset=val_dataset, batch_size=batch_size, shuffle=False)

In [10]:
model = BertModel.from_pretrained('bert-base-multilingual-uncased')

# Resize token embeddings
model.resize_token_embeddings(len(ct.tokenizer))

optimizer = AdamW(model.parameters(), lr=1e-5, weight_decay=0.0)
config = BertConfig.from_pretrained('bert-base-multilingual-uncased')
parent_classifier = nn.Linear(config.hidden_size, data['block'].nunique())
child_classifier = nn.Linear(config.hidden_size, data['code'].nunique())
criterion = nn.CrossEntropyLoss()

# set optimizers for the linear layers
optimizer_parent = torch.optim.Adam(parent_classifier.parameters(), lr=1e-3)
optimizer_child = torch.optim.Adam(child_classifier.parameters(), lr=1e-3)

average_pooling_model = HierarchyClassifier(model=model,
                                 optimizer=optimizer,
                                 parent_classifier=parent_classifier,
                                 child_classifier=child_classifier,
                                 optimizer_parent = optimizer_parent,
                                 optimizer_child = optimizer_child,
                                 criterion=criterion)

In [14]:
pooling_function = MentionPooling(pool_type='average')

average_pooling_model.train(train_dataloader=train_loader,
                      val_dataloader=val_loader,
                      num_epochs=6,
                      foldername='prev_model',
                      paren_weight= 0.9,
                      train_parent=True,
                      pooling_function=pooling_function)

In [12]:
# load model

# average_pooling_model.load_model('prev_model\checkpoint_epoch_5.pt')

## Evaluation

In [15]:
# evaluation
device = 'cpu'

average_pooling_model.model.eval()
total_loss = 0.0
total_child = 0.0
total_child_correct = 0.0
total_child_samples = 0.0
all_true_child_labels = []
all_pred_child_labels = []

all_true_parent_labels = []
all_pred_parent_labels = []

with torch.no_grad():
    for batch in tqdm(test_loader, desc='Evaluation', leave=True):
        # Extract tensors from the batch dictionary
        input_ids = batch['input_ids'].to(device).squeeze(1)
        attention_mask = batch['attention_mask'].to(device).squeeze(1)
        token_type_ids = batch['token_type_ids'].to(device).squeeze(1)
        special_token_masks = batch['special_token_mask']
        parent_labels = batch['parent_label'].to(device)
        child_labels = batch['child_label'].to(device)

        # Forward pass
        outputs = average_pooling_model.model(input_ids=input_ids, attention_mask=attention_mask, token_type_ids=token_type_ids)

        logits_parent = average_pooling_model.parent_classifier(outputs.pooler_output)
        pooled_mentions = pooling_function(outputs.last_hidden_state, special_token_masks)
        logits_child = average_pooling_model.child_classifier(pooled_mentions)
        
        # Calculate losses
        loss_child = average_pooling_model.criterion(logits_child, child_labels)

        loss = loss_child
        total_loss += loss.item()
        total_child += loss_child

        _, predicted_parent = torch.max(logits_parent, 1)
        _, predicted_child = torch.max(logits_child, 1)

        total_child_correct += (predicted_child == child_labels).sum().item()

        total_child_samples += child_labels.size(0)

        # Store true and predicted labels for precision, recall, and F1 score
        all_true_child_labels.extend(child_labels.cpu().numpy())
        all_pred_child_labels.extend(predicted_child.cpu().numpy())

        all_true_parent_labels.extend(parent_labels.cpu().numpy())
        all_pred_parent_labels.extend(predicted_parent.cpu().numpy())

    # Calculate average loss and accuracy
    average_loss = total_loss / len(test_df)
    accuracy_child = accuracy_score(all_true_child_labels, all_pred_child_labels)
    accuracy_parent = accuracy_score(all_true_parent_labels, all_pred_parent_labels)

    print()
    print("Evaluation metrics")
    print(f'Eval Avg Loss: {average_loss}, Child Accuracy: {accuracy_child}, Parent Accuracy {accuracy_parent}')
    print()

    def calc_metrics(average):

        # Calculate precision, recall, and F1 score
        precision_child = precision_score(all_true_child_labels, all_pred_child_labels, average=average)
        recall_child = recall_score(all_true_child_labels, all_pred_child_labels, average=average)
        f1_child = f1_score(all_true_child_labels, all_pred_child_labels, average=average)
        
        precision_parent = precision_score(all_true_parent_labels, all_pred_parent_labels, average=average)
        recall_parent = recall_score(all_true_parent_labels, all_pred_parent_labels, average=average)
        f1_parent = f1_score(all_true_parent_labels, all_pred_parent_labels, average=average)

        print(f'{average.capitalize()} Average')
        print(f'Eval Child Precision: {precision_child}, Eval Child Recall: {recall_child}, Eval Child F1: {f1_child}')
        print(f'Eval Parent Precision: {precision_parent}, Eval Parent Recall: {recall_parent}, Eval Parent F1: {f1_parent}')
        print()

    calc_metrics('macro')
    calc_metrics('micro')
    calc_metrics('weighted')
   

Evaluation: 100%|██████████| 25/25 [03:10<00:00,  7.61s/it]

Evaluation metrics
Eval Avg Loss: 0.025794467602303896, Child Accuracy: 0.8108108108108109, Parent Accuracy 0.9317889317889317
Macro, Average
Eval Child Precision: 0.5383369892361333, Eval Child Recall: 0.5448285247444841, Eval Child F1: 0.534431898223665
Eval Parent Precision: 0.7211277173913043, Eval Parent Recall: 0.6863920673877083, Eval Parent F1: 0.6953271586474395

Micro, Average
Eval Child Precision: 0.8108108108108109, Eval Child Recall: 0.8108108108108109, Eval Child F1: 0.8108108108108109
Eval Parent Precision: 0.9317889317889317, Eval Parent Recall: 0.9317889317889317, Eval Parent F1: 0.9317889317889317

Weighted, Average
Eval Child Precision: 0.7896622711968275, Eval Child Recall: 0.8108108108108109, Eval Child F1: 0.7948467502235301
Eval Parent Precision: 0.9242241187505161, Eval Parent Recall: 0.9317889317889317, Eval Parent F1: 0.9260974172882394




  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
  _warn_prf(average, modifier, msg_start, len(result))
