# Create ExTRI Named Entity Recognition Dataset

This notebook takes the gold-standard and validation sets from ExTRI and processes them to create a Named Entity Recognition (NER) dataset. The code consists of the following steps:
1. Import and preprocess the table of sentences from the _sentence coverage_ supplementary file from the ExTRI paper.
2. Import and preprocess the table of gold-standard sentences from ExTRI.
3. Import and preprocess the table of validation set sentences from ExTRI.
4. Merge the gold-standard and validation set tables.
5. Create a dictionary containing gene names and gene synonyms for each gene in the merged table.
6. Create a Hugging Face NER dataset by annotating each sentence with a fine-tuned BERT model for the identification of genetic entities, then fuzzy matching those entities to the reported transcription factor(s) and target gene(s) to determine their classification.

## 1. Importing packages and setting global variables

In [None]:
# Standard library imports
import os
import re
from itertools import islice

# Third-party imports
import pandas as pd
import pyarrow as pa
from datasets import ClassLabel, Dataset, Features, Sequence, Value
from pyhere import here
from spacy import displacy
from thefuzz import process
import torch
from torchtext.data.utils import get_tokenizer
from tqdm import tqdm
from transformers import AutoModelForTokenClassification, AutoTokenizer

In [None]:
# Input TSV tables from ExTRI, resolved relative to the project root by `here`.
# NOTE(review): "gold_standatrd_set.tsv" looks like a typo of "gold_standard";
# kept verbatim because it must match the file name on disk — confirm.
sentence_coverage_path, gold_standard_set_path, validation_set_path = (
    str(here(relative_path))
    for relative_path in (
        "data/extri/interim/sentence_coverage.tsv",
        "data/extri/interim/gold_standatrd_set.tsv",
        "data/extri/interim/validation_set.tsv",
    )
)

# Directory where the processed Hugging Face DatasetDict is saved
output_dir = str(here("data/extri/processed"))


In [None]:
# Prefer Apple's Metal backend, then CUDA, and fall back to the CPU.
device = torch.device(
    "mps" if torch.backends.mps.is_available()
    else "cuda" if torch.cuda.is_available()
    else "cpu"
)

## 2. Custom functions

`construct_sentence_uid` creates a unique ID for each sentence. In ExTRI, a sentence is identified by its PMID together with a sentence ID that is only unique within that PMID, so the unique ID has the form `PMID:Sentence ID`.

In [None]:
def construct_sentence_uid(row):
    """
    Return the sentence unique id ``"PMID:Sentence ID"`` for an ExTRI row.

    `row` must support ``row["PMID:Sentence ID:TF:TG"]`` (for example a
    pandas row passed by ``DataFrame.apply(..., axis=1)``). Only the first
    two colon-separated fields are kept, because ExTRI sentence IDs are
    only unique within a PMID.
    """
    pmid_and_sid = row["PMID:Sentence ID:TF:TG"].split(":", 2)
    return f"{pmid_and_sid[0]}:{pmid_and_sid[1]}"

In [None]:
def create_entities_for_displacy(dataset):
    """
    Convert an NER Hugging Face Dataset into displacy's manual entity format.

    Each row of `dataset` must have a "tokens" list of strings and an
    "ner_tags" list of ClassLabel ids for IOB tags (e.g.
    "B-TRANSCRIPTION_FACTOR", "I-TARGET_GENE", "O"). Tag names are decoded
    with the dataset's own ``features["ner_tags"]`` ClassLabel, so the
    function does not depend on any notebook-level schema variable.

    Returns a list with one dict per row, ``{"text": ..., "ents": [...]}``.
    Each entity is ``{"start", "end", "label"}``, with character offsets into
    "text" and the label stripped of its "B-"/"I-" prefix. Any element of the
    list can be passed to ``displacy.render(..., style="ent", manual=True)``.

    An entity starts at a "B-" tag or whenever the label changes. It ends at
    the next "O", at the start of the next entity, or at the end of the
    sentence (entities on the last token are included).
    """
    tag_feature = dataset.features["ner_tags"].feature

    dataset_entities = []

    for row in tqdm(dataset):
        tokens_list = row["tokens"]

        # Build the text by joining the tokens instead of using the original
        # sentence: some compound words have no spaces between their tokens,
        # which would break the token-to-character offset mapping below.
        text = " ".join(tokens_list)

        ents = []
        current_entity = None  # entity being built, or None
        token_start = 0  # character offset of the current token in `text`

        for token, tag_id in zip(tokens_list, row["ner_tags"]):
            token_end = token_start + len(token)
            tag = tag_feature.int2str(tag_id)
            # Remove "B-" or "I-" from the tag to get the entity label
            label = re.sub(r"^(B-|I-)", "", tag)

            if label == "O":
                # Outside any entity: close the one being built, if any
                if current_entity is not None:
                    ents.append(current_entity)
                    current_entity = None
            else:
                starts_new_entity = (
                    current_entity is None
                    or tag.startswith("B-")
                    or label != current_entity["label"]
                )
                if starts_new_entity:
                    if current_entity is not None:
                        ents.append(current_entity)
                    current_entity = {
                        "start": token_start,
                        "end": token_end,
                        "label": label,
                    }
                else:
                    current_entity["end"] = token_end

            # The next token starts after this one plus the joining space
            token_start = token_end + 1

        # Flush an entity that runs up to the last token
        if current_entity is not None:
            ents.append(current_entity)

        dataset_entities.append({"text": text, "ents": ents})

    # Return sentences with entities formatted for displacy
    return dataset_entities

## 3. Load the sentence coverage table

Load the `sentence coverage` table from ExTRI, which contains the following columns:
- `PMID:Sentence ID:TF:TG`
- `Transcription Factor (Associated Gene Name)`
- `Target Gene (Associated Gene Name)`
- `Sentence`

In [None]:
# Read only the columns needed to build the sentence lookup table: the
# `PMID:Sentence ID:TF:TG` key and the sentence text. The TF/TG columns of
# this table are not used in this notebook.
sentence_cov = pd.read_csv(
    sentence_coverage_path,
    sep="\t",
    usecols=["PMID:Sentence ID:TF:TG", "Sentence"],
)

# Create a column of unique ID (uid) for each sentence by merging the PMID and the sentence ID.
# This is needed because sentence IDs are only unique within each PMID.
sentence_cov["sentence_uid"] = sentence_cov.apply(construct_sentence_uid, axis=1)

# Keep only the `sentence_uid` and `Sentence` columns. The row key also
# contains the TF and TG, so one sentence can appear on several rows;
# drop those duplicates.
sentence_cov = sentence_cov[["sentence_uid", "Sentence"]].drop_duplicates()
sentence_cov.head(5)

## 4. Load Gold-Standard set

Load the _gold-standard_ table from ExTRI and format it.

In [None]:
# Gold-standard ExTRI sentences. PMID and SID are read as nullable integers,
# so missing values don't turn the columns into floats.
gs_set = pd.read_csv(
    gold_standard_set_path,
    sep="\t",
    encoding="latin",
    dtype={"PMID": "Int64", "SID": "Int64"},
)

# Provenance column and the `PMID:SID` sentence uid
gs_set["dataset_source"] = "gold-standard"
gs_set["sentence_uid"] = gs_set.apply(lambda rec: f"{rec['PMID']}:{rec['SID']}", axis=1)

# Keep rows that are neither negated nor commented on
gs_set = gs_set[gs_set["Negated"].isna() & gs_set["Comments"].isna()]

# Drop helper columns, harmonise names with the validation set and fix column order
gs_set = (
    gs_set.drop(columns=["PMID", "SID", "Negated", "Comments"])
    .rename(columns={"DbTF": "TF", "Sentence": "sentence"})
    .loc[:, ["sentence_uid", "sentence", "dataset_source", "TF", "TG"]]
)

gs_set.head(2)

## 5. Load Validation set

We build a `sentence_uid` column (`PMID:Sentence ID`) from this dataset and use it to look up each sentence's text in the sentence coverage table.

In [None]:
validation_set = pd.read_csv(
    validation_set_path,
    sep="\t"
)

# Keep only sentences marked as valid. `.copy()` makes the filtered frame
# independent of the original, so the column assignments below don't write
# to a possible view (pandas SettingWithCopyWarning).
validation_set = validation_set[validation_set["Valid"] == "Valid"].copy()

# Create `sentence_uid` column (`PMID:Sentence ID`)
validation_set["sentence_uid"] = validation_set.apply(construct_sentence_uid, axis=1)

# Add dataset source column
validation_set["dataset_source"] = "Validation"

# Create TF and TG columns from the 3rd and 4th fields of the
# `PMID:Sentence ID:TF:TG` key
row_keys = validation_set["PMID:Sentence ID:TF:TG"]
validation_set["TF"] = row_keys.map(lambda key: key.split(":")[2])
validation_set["TG"] = row_keys.map(lambda key: key.split(":")[3])

# Merge with `sentence_cov` to obtain sentences
validation_set = pd.merge(validation_set, sentence_cov, how="left", on="sentence_uid")

# Rows whose uid is absent from `sentence_cov` get a NaN sentence. Report
# them, because pandas `groupby` drops NaN keys by default and they would
# otherwise vanish silently when rows are grouped by sentence.
num_missing_sentences = validation_set["Sentence"].isna().sum()
if num_missing_sentences:
    print(f"Warning: {num_missing_sentences} validation rows have no sentence in sentence_cov")

# Drop columns
validation_set = validation_set.drop(["PMID:Sentence ID:TF:TG", "Valid"], axis=1)

validation_set.rename(columns={"Sentence": "sentence"}, inplace=True)

# Reorder columns
validation_set = validation_set.loc[:, ["sentence_uid", "sentence", "dataset_source", "TF", "TG"]]

validation_set.head(2)

The `Sentence` column is added to `validation_set` by merging it with the `sentence_cov` table on `sentence_uid` (see the cell above).

## 6. Create DataFrame for combined sets (Gold-Standard and Validation)

Create `combined_sets_df` by concatenating the gold-standard and validation sets.

In [None]:
# Stack the gold-standard rows on top of the validation rows (row-wise concatenation)
combined_sets_df = pd.concat((gs_set, validation_set))
print(f"combined_sets_df shape is {combined_sets_df.shape}")
combined_sets_df.head(2)

## 7. Obtain gene synonyms for TFs and TGs

Save a list of all genes (TFs and TGs) into `all_genes_list`.

In [None]:
# Every TF and TG symbol, deduplicated in order of first appearance (TFs first)
gene_symbols = pd.concat((combined_sets_df["TF"], combined_sets_df["TG"]))
all_genes_list = list(gene_symbols.unique())
print(f"Number of unique genes: {len(all_genes_list)}")

Aggregate rows in `combined_sets_df` with the same `sentence_uid` and concatenate TFs and TGs for the same sentence with a comma.

In [None]:
# Collapse the rows of each sentence into one row, joining its unique TFs and
# TGs with ",". The unique names are sorted so the joined string is the same
# on every run (iteration order of a `set` of strings depends on Python's
# per-process hash seed).
# NOTE(review): `dataset_source` is part of the key, so a sentence present in
# both the gold-standard and validation sets yields two rows. Rows with a NaN
# `sentence` are dropped by `groupby` (default `dropna=True`).
combined_sets_df = combined_sets_df.groupby(["sentence_uid", "dataset_source", "sentence"], as_index=False).agg(
    {
        "TF": lambda x: ",".join(sorted(set(x))),
        "TG": lambda x: ",".join(sorted(set(x))),
    }
)

combined_sets_df.head(2)

We now create a dictionary mapping each gene to its synonyms, using the NCBI Datasets command-line tools (`datasets` and `dataformat`).

In [None]:
from io import StringIO
import subprocess

# Columns requested from `dataformat`; only "Synonyms" is used below.
NCBI_GENE_FIELDS = "gene-id,symbol,tax-name,common-name,synonyms,gene-type,ensembl-geneids"

genes_synonyms_dict = {}
genes_without_entries = []

for gene in tqdm(all_genes_list):
    # Run `datasets ... | dataformat ...` as two processes connected
    # explicitly instead of through a shell. The gene symbol comes from the
    # input tables and is passed as a plain argument, so shell
    # metacharacters in it are never interpreted. Orthologs are requested for
    # human and the NCBI taxonomy ids 10090 (mouse) and 10116 (rat).
    # If the CLI tools are not installed, this raises FileNotFoundError
    # instead of silently reporting every gene as missing.
    summary = subprocess.run(
        [
            "datasets", "summary", "gene", "symbol", str(gene),
            "--ortholog", "human,10090,10116",
            "--as-json-lines",
        ],
        stdout=subprocess.PIPE,
    )
    table = subprocess.run(
        ["dataformat", "tsv", "gene", "--fields", NCBI_GENE_FIELDS],
        input=summary.stdout,
        stdout=subprocess.PIPE,
    )
    data = StringIO(table.stdout.decode("utf-8"))

    try:
        ncbi_genes_df = pd.read_csv(data, sep="\t", na_values="NaN")
    except (pd.errors.EmptyDataError, pd.errors.ParserError):
        # No (parsable) table returned: NCBI has no entry for this symbol
        genes_without_entries.append(gene)
        continue

    # Flatten the comma-separated synonym lists of all returned records
    genes_synonyms_dict[gene] = [
        synonym
        for synonyms in ncbi_genes_df["Synonyms"].dropna()
        for synonym in synonyms.split(",")
    ]

Genes that were not present in NCBI datasets were:

In [None]:
# Display the gene symbols for which no NCBI table could be read
genes_without_entries

## 8. Create NER Hugging Face dataset

Next we annotate the sentences from `combined_sets_df` with NER tags and collect them into a Hugging Face `Dataset` (`ner_dataset`). First, we initialize the model and tokenizers. The `spacy_tokenizer` tokenizes the sentences, because they will be stored in the dataset (`ner_dataset`) as lists of tokens. The `bert_tokenizer` is a Hugging Face tokenizer that prepares each sentence for input into the model.

In [None]:
# NER tags of the output dataset: IOB tags for transcription factors and target genes
entities_list = ["O", "B-TRANSCRIPTION_FACTOR", "I-TRANSCRIPTION_FACTOR", "B-TARGET_GENE", "I-TARGET_GENE"]

# Hugging Face checkpoint shared by the BERT tokenizer and the NER model
ner_model_checkpoint = "alvaroalon2/biobert_genetic_ner"

# Word-level tokenizer (spaCy through torchtext); its tokens are stored in the dataset.
# NOTE(review): torchtext is deprecated upstream; consider calling spaCy directly.
spacy_tokenizer = get_tokenizer("spacy", language="en_core_web_trf")
# Sub-word tokenizer matching the NER model
bert_tokenizer = AutoTokenizer.from_pretrained(ner_model_checkpoint)

model = AutoModelForTokenClassification.from_pretrained(ner_model_checkpoint)
model.to(device)
model.eval()  # inference only: disables dropout

# Features schema of the output dataset; `ner_tags` are stored as ClassLabel ids
features_schema = Features({
    "sentence_uid": Value(dtype="string"),
    "dataset_source": Value(dtype="string"),
    "sentence": Value(dtype="string"),
    "tokens": Sequence(feature=Value(dtype="string")),
    "ner_tags": Sequence(
        ClassLabel(
            num_classes=len(entities_list),
            names=entities_list
        )
    )
})

Now we iterate over each row in `combined_sets_df` and tag entities in each sentence as TF or TG using a Hugging Face NER model and fuzzy search.

In [None]:
# Fuzzy-match score (thefuzz, 0-100) below which a detected genetic entity is
# treated as unrelated to every reported TF and TG and re-tagged "O".
MIN_MATCH_SCORE = 50

# Word-level tags derived from the BERT model before TF/TG assignment
GENETIC_TAGS = ("B-GENETIC", "I-GENETIC")

# Class id of the model's "O" label. Falls back to 2 (the id previously
# hard-coded here) if the model config does not name an "O" label.
outside_label_id = model.config.label2id.get("O", 2)

# ClassLabel feature used to validate the final tags
tag_feature = features_schema["ner_tags"].feature


def _predict_genetic_tags(words):
    """
    Return one tag ("B-GENETIC", "I-GENETIC" or "O") per word in `words`.

    `words` is the spaCy-tokenized sentence (list of str). Each word takes
    the model prediction of its first BERT sub-token. Tags are written at the
    word's own index (taken from `word_ids`). Words that produce no
    sub-tokens (e.g. whitespace-only tokens) or fall past the 512 sub-token
    truncation limit therefore stay "O" and cannot misalign the following
    tags. The result always has len(words) items.
    """
    tags = ["O"] * len(words)

    encoded = bert_tokenizer(
        words,
        padding=False,
        truncation=True,
        max_length=512,
        is_split_into_words=True,
        return_tensors="pt",
    )
    word_ids = encoded.word_ids(batch_index=0)
    encoded = encoded.to(device)

    # Inference only: don't build the autograd graph
    with torch.no_grad():
        logits = model(**encoded)["logits"].squeeze(0)
    label_preds = logits.argmax(dim=1).tolist()

    prev_word_idx = None
    for subtoken_idx, word_idx in enumerate(word_ids):
        # Skip special tokens (word_idx is None) and non-first sub-tokens of a word
        if word_idx is None or word_idx == prev_word_idx:
            continue
        prev_word_idx = word_idx

        if label_preds[subtoken_idx] != outside_label_id:
            # Continue the entity if the previous word is genetic, else start one
            if word_idx > 0 and tags[word_idx - 1] in GENETIC_TAGS:
                tags[word_idx] = "I-GENETIC"
            else:
                tags[word_idx] = "B-GENETIC"

    return tags


def _gene_aliases(genes_field):
    """
    Return the gene symbols in the comma-separated `genes_field` plus their
    synonyms from `genes_synonyms_dict`, without duplicates.
    """
    aliases = []
    for gene in genes_field.split(","):
        aliases.append(gene)
        aliases.extend(genes_synonyms_dict.get(gene, []))
    return list(dict.fromkeys(aliases))


def _assign_tf_tg_labels(words, tags, tf_aliases, tg_aliases):
    """
    Relabel every contiguous B-/I-GENETIC run in `tags` in place.

    The run's text (its words joined by spaces) is fuzzy-matched against
    `tf_aliases` and `tg_aliases`. If both best scores are below
    MIN_MATCH_SCORE, the run becomes "O". Otherwise it becomes
    TRANSCRIPTION_FACTOR when the TF score is strictly higher, and
    TARGET_GENE otherwise (ties go to TARGET_GENE).

    Returns ``(num_tf_entities, num_tg_entities)``.
    """
    runs = []
    current_run = []
    for idx, tag in enumerate(tags):
        if tag in GENETIC_TAGS:
            current_run.append(idx)
        elif current_run:
            runs.append(current_run)
            current_run = []
    if current_run:
        runs.append(current_run)

    num_tf_entities = 0
    num_tg_entities = 0
    for run in runs:
        genetic_entity = " ".join(words[idx] for idx in run)
        _, tf_match_score = process.extractOne(genetic_entity, tf_aliases)
        _, tg_match_score = process.extractOne(genetic_entity, tg_aliases)

        if tf_match_score < MIN_MATCH_SCORE and tg_match_score < MIN_MATCH_SCORE:
            for idx in run:
                tags[idx] = "O"
            continue

        if tf_match_score > tg_match_score:
            entity_label = "TRANSCRIPTION_FACTOR"
            num_tf_entities += 1
        else:
            entity_label = "TARGET_GENE"
            num_tg_entities += 1

        for idx in run:
            tags[idx] = tags[idx].replace("GENETIC", entity_label)

    return num_tf_entities, num_tg_entities


# Iterate over each row in `combined_sets_df`
dataset_entries = []
for _, row in tqdm(combined_sets_df.iterrows(), total=combined_sets_df.shape[0]):

    # Remove repeated whitespaces: spaCy would emit them as whitespace tokens,
    # which produce no BERT sub-tokens (and would therefore get no prediction).
    sentence = re.sub(" +", " ", row["sentence"])
    # Tokenize sentence with spacy
    tokenized_text = spacy_tokenizer(sentence)

    sentence_ner_tags = _predict_genetic_tags(tokenized_text)
    num_tf_entities, num_tg_entities = _assign_tf_tg_labels(
        tokenized_text,
        sentence_ner_tags,
        _gene_aliases(row["TF"]),
        _gene_aliases(row["TG"]),
    )

    # Sanity check: every tag must belong to the dataset's label set
    try:
        for tag in sentence_ner_tags:
            tag_feature.str2int(tag)
    except (ValueError, KeyError) as err:
        raise ValueError(f"{row['sentence']} |TAGS| {sentence_ner_tags}") from err

    # Add sentence to dataset if it contains at least one TF and TG entities
    if num_tf_entities >= 1 and num_tg_entities >= 1:
        dataset_entries.append({
            "sentence_uid": row["sentence_uid"],
            "dataset_source": row["dataset_source"],
            "sentence": row["sentence"],
            "tokens": tokenized_text,
            "ner_tags": sentence_ner_tags,
        })


In [None]:
# Build the Hugging Face dataset from the collected rows, enforcing
# `features_schema` (string `ner_tags` are encoded as ClassLabel ids).
ner_dataset = Dataset.from_list(mapping=dataset_entries, features=features_schema)
ner_dataset

In [None]:
# Inspect the dataset schema (`ner_tags` is a sequence of ClassLabel ids over `entities_list`)
ner_dataset.features

To visualize the entities in the dataset using `displacy`, we run our custom function that formats the Hugging Face dataset for input to `displacy`.

In [None]:
# Convert every sentence into displacy's manual-rendering format ({"text": ..., "ents": [...]})
ner_dataset_displacy_entities = create_entities_for_displacy(ner_dataset)

Here we visualize one example.

In [None]:
# Colours used by displacy for the two entity labels
options = {"colors": {"TRANSCRIPTION_FACTOR": "#D6D77F", "TARGET_GENE": "#01DFED"}}

# Render the first annotated sentence; `manual=True` means the input is a
# pre-built {"text", "ents"} dict rather than a spaCy Doc.
example_sentence = ner_dataset_displacy_entities[0]
displacy.render(example_sentence, style="ent", manual=True, options=options)

Create the splits for the dataset (train and validation).

In [None]:
# Fixed seed so the 80/20 train/validation split is reproducible across runs.
# NOTE(review): the split is per row; a sentence present in both source sets
# (two rows) could land in both splits — confirm whether that matters.
split_seed = 42
ner_dataset_dict = ner_dataset.train_test_split(test_size=0.2, seed=split_seed)

# Rename the "test" split to "validation" directly on the DatasetDict, so the
# `validation_set` DataFrame (ExTRI validation sentences) is not overwritten.
ner_dataset_dict["validation"] = ner_dataset_dict.pop("test")
ner_dataset_dict

## Save the dataset

In [None]:
# Write the DatasetDict (train and validation splits) to `output_dir`
ner_dataset_dict.save_to_disk(output_dir)

In [None]:
# To load the dataset run:
# `load_from_disk` is the counterpart of `DatasetDict.save_to_disk`. It reads
# every split, and every Arrow shard of each split, from `output_dir`, instead
# of assuming a single `data-00000-of-00001.arrow` file per split.

from datasets import load_from_disk

raw_dataset = load_from_disk(output_dir)
raw_dataset