# Identifying diseases in medical data using BERT
In this notebook, we use IBM's Watson NLP library to extract mentions of diseases from the NCBI Disease dataset, hosted on [Hugging Face's repository](https://huggingface.co/datasets/ncbi_disease).

## 1. Imports

In [None]:
import os
import datasets
import pandas as pd
import json

import watson_nlp
import watson_nlp.data_model as dm
from watson_nlp.toolkit import entity_mentions_utils

## 2. Read and preprocess data
### 2.1. Load data from Hugging Face

In [None]:
# Load the NCBI Disease dataset from the Hugging Face repository
ncbi = datasets.load_dataset('ncbi_disease')

# Turn each split into a DataFrame, reporting its shape as it is built.
# list_df keeps the order train, test, validation.
_split_messages = (
    ('train', '[+] Training set shape: {}'),
    ('test', '[+] Testing set shape: {}'),
    ('validation', '[+] Validation set shape: {}'),
)
list_df = []
for _split, _message in _split_messages:
    _frame = pd.DataFrame(ncbi[_split])
    print(_message.format(_frame.shape))
    list_df.append(_frame)
df_train, df_test, df_val = list_df

# Inspect
df_train.head()


### 2.2. Reshape data so that it can be used by Watson NLP

The data does not come in a shape that's suitable for Watson NLP. We would like our data to be in a JSON format in which each sentence is associated with a set of mentions (see the help output of the cell below). We'll take several steps to reformat it.

In [None]:
# Show the expected input format for the JSON -> IOB conversion helper
help(entity_mentions_utils.prepare_train_from_json)

In [None]:
# Look at the last few rows of the training split
df_train.tail(5)

#### 2.2.1. Extract mentions from tokens column based on ner_tags column


In [None]:
# Define function
def get_entities_from_token_list(list1, list2):
    '''
    Returns a list of labeled entities from a list of tokens and a list of ner tags.
    Each entity is returned as its tokens joined by single spaces.

    Tag meaning as handled here: 1 starts a new entity, 2 continues the current
    entity (or starts one if none is open), 0 is outside any entity.

    Params:
    - list1: list of tokens
    - list2: list of ner tags, taking values 0, 1 or 2, aligned with list1

    Raises ValueError if the two lists differ in length or a tag is not 0, 1 or 2.
    '''
    if len(list1) != len(list2):
        raise ValueError('[-] NER indexing error! {} tokens but {} tags'.format(len(list1), len(list2)))

    entities = []
    current = []  # tokens of the entity currently being built

    for token, tag in zip(list1, list2):
        if tag == 1:
            # A new entity begins: flush an entity that ends right here (adjacent entities)
            if current:
                entities.append(' '.join(current))
            current = [token]
        elif tag == 2:
            current.append(token)
        elif tag == 0:
            if current:
                entities.append(' '.join(current))
                current = []
        else:
            raise ValueError('[-] NER indexing error! Unexpected tag: {!r}'.format(tag))

    # Flush an entity that runs up to the last token of the sentence
    if current:
        entities.append(' '.join(current))

    return entities

In [None]:
# Add an 'entities' column to the train, test and validation frames
for frame in list_df:
    frame['entities'] = [
        get_entities_from_token_list(tokens, tags)
        for tokens, tags in zip(frame['tokens'], frame['ner_tags'])
    ]
df_train.head()

#### 2.2.2. Transform tokens list into just a string


In [None]:
# Define function
def transform_tokens_into_sentence(list1):
    '''
    Joins a list of tokens into a single string, with one space between consecutive tokens.
    '''
    separator = " "
    return separator.join(list1)

In [None]:
# Add a space-joined 'sentence' column to every split
for frame in list_df:
    frame['sentence'] = [transform_tokens_into_sentence(tokens) for tokens in frame['tokens']]

df_val

#### 2.2.3. Identify the position (beginning and end) of each entity in its sentence, and create JSON-like structures


In [None]:
# Define function - Get the position in the sentence where each entity begins and ends
def _find_token_aligned(sentence, entity, start):
    '''
    Returns the index of the first occurrence of `entity` in `sentence` at or after `start`
    whose edges fall on space-separated token boundaries. If no such occurrence exists,
    falls back to the first plain substring occurrence. Returns -1 if `entity` is absent.
    '''
    first = sentence.find(entity, start)
    pos = first
    while pos != -1:
        end = pos + len(entity)
        starts_token = pos == 0 or sentence[pos - 1] == ' '
        ends_token = end == len(sentence) or sentence[end] == ' '
        if starts_token and ends_token:
            return pos
        pos = sentence.find(entity, pos + 1)
    return first


def get_entities_position(sentence:str, entities:list):
    '''
    For a given sentence containing a list of entities, it returns a list of dictionaries
    with the entity, its type (hardcoded to Disease) and its beginning and end offsets.
    - sentence: a sentence in string format (tokens separated by single spaces)
    - entities: a list of entities that appear in sentence, in order of appearance

    Occurrences aligned with token boundaries are preferred, so an entity such as
    "cancer" is not matched inside a longer token like "precancer".

    Raises ValueError if an entity cannot be found after the previous entity's end.
    '''
    mentions = []
    # Entities are located in order; each search starts where the previous match ended,
    # so a repeated entity maps to its next occurrence rather than the first one.
    search_from = 0

    for entity in entities:
        begin = _find_token_aligned(sentence, entity, search_from)
        if begin == -1:
            raise ValueError('Entity {!r} not found in sentence {!r} after position {}'.format(entity, sentence, search_from))
        end = begin + len(entity)

        mentions.append({
            'text': entity,
            'type': "Disease",
            'location': {'begin': begin, 'end': end},
        })
        search_from = end

    return mentions


# Define function - Produce a json-like structure with all entities' positions
def build_json_structure(_df):
    '''
    Returns a json-like structure (a list of dictionaries) that is later wrapped in a
    Watson NLP DataStream.

    One record per row of `_df`, with keys:
    - 'id': the row's index label
    - 'text': the row's 'sentence' value
    - 'mentions': get_entities_position applied to the row's sentence and 'entities' list
    '''
    sentences = _df['sentence']
    entity_lists = _df['entities']

    return [
        {
            'id': row_label,
            'text': sentences[row_label],
            'mentions': get_entities_position(sentences[row_label], entity_lists[row_label]),
        }
        for row_label in _df.index
    ]

In [None]:
# Build the JSON-like records for each split (train, test, validation)
train_list, test_list, val_list = (
    build_json_structure(frame) for frame in (df_train, df_test, df_val)
)

In [None]:
# Save data as json. The test file is read back by evaluate_quality in section 4.2.
# Create the output directory first: open(..., "w") fails if ./Data does not exist.
os.makedirs("./Data", exist_ok=True)

for json_path, records in (
    ("./Data/train_set.json", train_list),
    ("./Data/test_set.json", test_list),
    ("./Data/val_set.json", val_list),
):
    # `with` closes the file even if json.dump raises part-way through
    with open(json_path, "w") as out_file:
        json.dump(records, out_file, indent = 4)

#### 2.2.4. Convert data to Watson NLP data streams

In [None]:
# Download the English syntax model, then load it. It is used below to convert the
# JSON data into IOB streams and to preprocess text before running the BERT block.
syntax_model_path = watson_nlp.download('syntax_izumo_en_stock')
syntax_model = watson_nlp.load(syntax_model_path)

In [None]:
# Convert the entity labeled data in standard format to IOB streams
def _json_records_to_iob(records):
    '''
    Wraps JSON-like records in a DataStream and converts them to an IOB stream with the
    syntax model. Returns (labeled_stream, iob_stream).
    '''
    labeled_stream = dm.DataStream.from_iterable(records)
    iob_stream = entity_mentions_utils.prepare_train_from_json(labeled_stream, syntax_model)
    return labeled_stream, iob_stream

train_labeled_data_stream, train_iob_stream = _json_records_to_iob(train_list)
val_labeled_data_stream, val_iob_stream = _json_records_to_iob(val_list)
test_labeled_data_stream, test_iob_stream = _json_records_to_iob(test_list)

## 3. Train model

### 3.1. Load pretrained model - BERT in this case

In [None]:
# Download the pretrained BERT resource (per its name: multilingual, cased), then load it
pretrained_model_path = watson_nlp.download('pretrained-model_bert_multi_bert_multi_cased')
pretrained_model_resource = watson_nlp.load(pretrained_model_path)

### 3.2. Train 

See method arguments below

In [None]:
# Show the signature and docstring of the BERT entity-mentions training method
bert_block = watson_nlp.blocks.entity_mentions.BERT
help(bert_block.train)

In [None]:
# Build the IOB label list for the single 'Disease' entity type
labels = entity_mentions_utils.create_iob_labels(['Disease'])

In [None]:
# Train the model; returns the instance of the trained block.
# The pretrained resource is a *cased* BERT model, so the text must not be lower-cased
# (do_lower_case=False). Lower-casing input for a cased vocabulary discards
# capitalization that the model was pretrained on (e.g. disease acronyms).
entities_model = watson_nlp.blocks.entity_mentions.BERT.train(
    train_labeled_documents = train_iob_stream,
    dev_labeled_documents = val_iob_stream,
    label_list = labels,
    pretrained_model_resource = pretrained_model_resource,
    do_lower_case=False,
    num_train_epochs=4,
    train_batch_size=32,
    dev_batch_size=32,
    keep_model_artifacts=False)

In [None]:
# Persist the trained block to local disk so it can be reloaded later
model_path = './Models/entities_bert'  # target directory for the saved model
entities_model.save(model_path)

### 3.3. Load model
We read the model back from local disk for evaluation.

In [None]:
# Reload the saved BERT entity-mentions block from local disk
model_path = './Models/entities_bert'  # same directory the model was saved to
entities_model = watson_nlp.blocks.entity_mentions.BERT.load(model_path)

## 4. Test model
### 4.1. Quick test

In [None]:
# Define one quick test: pick a test sentence and show its reference entities.
# `sample_idx` is used instead of `id` to avoid shadowing Python's built-in id().
sample_idx = 1
test_sentence = df_test['sentence'][sample_idx]
test_entities = df_test['entities'][sample_idx]

print('Sentence: ', test_sentence)
print('Entities: ', test_entities)

# Run syntax model on text (token parser only); the result is the input to the BERT block
syntax_analysis_en = syntax_model.run(test_sentence, parsers=('token',))

In [None]:
# Apply the BERT mentions block to the syntax analysis and display the predicted mentions
ent_prediction = entities_model.run(syntax_analysis_en)
ent_prediction

### 4.2. Evaluate model quality

In [None]:
def preprocess_func(raw_doc):
    '''Runs the syntax model on a raw document before it is passed to the BERT block.'''
    return syntax_model.run(raw_doc)

# Execute the model on the saved test set and generate the quality report
quality_report = entities_model.evaluate_quality('./Data/test_set.json', preprocess_func)

# Print the quality report
print(json.dumps(quality_report, indent=4))