In [14]:
import nltk
import numpy as np
import pandas as pd
import glob
from tqdm import tqdm

In [15]:
def parse_entities_relationships(file_content):
    """
    Parses out entities and relationships from the entities file.

    Args:
    file_content (str): A string representing the content of the entities file.

    Returns:
    tuple: A tuple containing two dictionaries. The first dictionary contains information about entities 
           found in the file, while the second dictionary contains information about relationships among
           entities.

    """

    entities = {}
    relationships = {}
    for line in file_content.strip().split('\n'):
        if line.strip().startswith("#"):
            continue
        cols = line.split('\t')
        identifier = cols[0]
        if "Arg1:" in cols[1]:
            relationship_type, arg1, arg2 = cols[1].split()
            relationships[identifier] = {
                'type': relationship_type,
                'arg1': arg1.split(':')[1],
                'arg2': arg2.split(':')[1]
            }            
        else:
            split_cols = cols[1].split()
            entity_type = split_cols[0]
            value = cols[2]
            start = split_cols[1]
            end = split_cols[-1]
            entities[identifier] = {
                'type': entity_type,
                'start': int(start),
                'end': int(end),
                'value': value
            }
    return entities, relationships


def get_rel_spans(entities, relationships, txt_str, split_char="."):
    """
    Determines the span of each relationship in a given text.

    Args:
    entities (dict): A dictionary containing information about entities.
    relationships (dict): A dictionary containing information about relationships among entities.
    txt_str (str): A string representing the text to be analyzed.
    split_char (str, optional): The character to use for splitting text. Defaults to ".".

    Returns:
    dict: A dictionary containing information about the span of each relationship in the text.
    
    """

    spans = {}
    trim_count = 0
    for rel in relationships.keys():
        arg1 = entities[relationships[rel]['arg1']]
        arg2 = entities[relationships[rel]['arg2']]
        min_left = min(arg1['start'], arg2['start'])
        max_right = max(arg1['end'], arg2['end'])
        # find dots
        prev_char = txt_str[:min_left].rfind(split_char)+1
        next_char = txt_str[max_right:].find(split_char)+max_right
        # see if no exist
        assert(prev_char != -1)
        assert(next_char != -1)
        # trimming very long spans
        if min_left - prev_char > 100:
            prev_char = min_left - 100
            trim_count += 1
        if next_char - max_right:
            next_char = max_right + 100
            trim_count += 1
        relationships[rel]['span_start'] = prev_char
        relationships[rel]['span_end'] = next_char
    # if trim_count > 0:
    #     print(trim_count)
    return relationships


def collate_span_data_with_rels(rel_spans, ents, txt_str):
    """
    Collates information about the span of each relationship in a given text with information about entities.

    Args:
    rel_spans (dict): A dictionary containing information about the span of each relationship in the text.
    ents (dict): A dictionary containing information about entities.
    txt_str (str): A string representing the text to be analyzed.

    Returns:
    dict: A dictionary containing information about the span of each relationship in the text along with 
          information about all entities in the text.
    
    """

    all_data = {}
    for i, (rel, rdata_) in enumerate(rel_spans.items()):
        span_data = {}

        rdata = rdata_.copy()
        span = txt_str[rdata['span_start']:rdata['span_end']]
        span_data['text'] = span

        rdata['arg1_start'] = ents[rdata['arg1']]['start'] - rdata['span_start']
        rdata['arg1_end'] = ents[rdata['arg1']]['end'] - rdata['span_start']
        rdata['arg2_start'] = ents[rdata['arg2']]['start'] - rdata['span_start']
        rdata['arg2_end'] = ents[rdata['arg2']]['end'] - rdata['span_start']
        span_data[rel] = rdata
        
        all_ents = {}
        for ent, edata_ in ents.items():
            edata = edata_.copy()
            if rdata['span_start'] <= edata['start'] and rdata['span_end'] >= edata['end']:
                edata['start_pos'] = edata['start'] - rdata['span_start']
                edata['end_pos'] = edata['end'] - rdata['span_start']
                all_ents[ent] = edata        
        span_data['all_entities'] = all_ents    
        all_data[i] = span_data        
    return all_data


def get_uncovered_entity_spans(entities, relations, txt_str, split_char="."):
    """
    Determines the span of each entity in a given text that is not involved in a relationship.

    Args:
    entities (dict): A dictionary containing information about entities.
    relationships (dict): A dictionary containing information about relationships among entities.
    txt_str (str): A string representing the text to be analyzed.
    split_char (str, optional): The character to use for splitting text. Defaults to ".".

    Returns:
    dict: A dictionary containing information about the span of each entity in the text that is not involved in
          a relationship.
    
    """

    ents_covered = set()
    for v in relations.values():
        ents_covered = ents_covered.union({v['arg1']}).union({v['arg2']})
        
    uncovered_dict = {k:v for k, v in entities.items() if k not in ents_covered}
    spans = {}
    trim_count = 0
    for ent, edata in uncovered_dict.items():
        min_left = edata['start']
        max_right = edata['end']
        # find dots
        prev_char = txt_str[:min_left].rfind(split_char)+1
        next_char = txt_str[max_right:].find(split_char)+max_right
        # see if no exist
        assert(prev_char != -1)
        assert(next_char != -1)
        
        # trimming very long spans
        if min_left - prev_char > 100:
            prev_char = min_left - 100
            trim_count += 1
        if next_char - max_right:
            next_char = max_right + 100
            trim_count += 1

        span = txt_str[prev_char: next_char]
        if span not in spans:
            spans[span] = {}
            
        ent_data = {
            'start_pos': edata['start'] - prev_char,
            'end_pos': edata['end'] - prev_char,
            'value': edata['value'],
            'type': edata['type']
        }
        spans[span][ent] = ent_data
    # if trim_count > 0:
    #     print(trim_count)
    return spans


def tokenize_preserving_ent(text, entity_mapping):
    """
    Tokenizes text while preserving entity spans.

    Args:
    text (str): A string representing the text to be tokenized.
    entity_mapping (dict): A dictionary containing information about the span of each entity in the text.

    Returns:
    list: A list of tokens extracted from the input text, while preserving the span of each entity.
    
    """

    span_locs = [0]
    for edata in entity_mapping.values():
        span_locs += [edata['start_pos'], edata['end_pos']]

    span_locs.append(len(text))
    span_locs = sorted(list(set(span_locs)))

    start_index = 0
    splits = []
    for index in span_locs:
        split = text[start_index:index].strip()
        if split:
            splits.append(split)
        start_index = index

    tokens = []
    for s in splits:
        tokens += nltk.word_tokenize(s)
    return tokens


def tag_index(text, ent_type, ent_start, ent_end, all_ents):
    """
    Determines the index tags of entities in a given text.

    Args:
    text (str): A string representing the text to be analyzed.
    ent_type (str): A string representing the type of the entity.
    ent_start (int): An integer representing the starting position of the entity in the text.
    ent_end (int): An integer representing the ending position of the entity in the text.
    all_ents (dict): A dictionary containing information about all entities in the text.

    Returns:
    tuple: A tuple containing three elements. The first element is a dictionary containing index tags of entities
           in the input text. The second and third elements are integers representing the starting and ending 
           positions of the entity in the text, respectively.
    
    """

    index_tag_map = {}
    tokens = np.array(tokenize_preserving_ent(text, all_ents))
    ent_text = text[ent_start: ent_end]
    ent_tokens = nltk.word_tokenize(ent_text)
    possible_loc = np.where(tokens==ent_tokens[0])[0]
    ent_start_loc = possible_loc[np.argmin(
        [abs(ent_start - len(' '.join(tokens[0:x]))) for x in possible_loc]
    )]
    ent_end_loc = ent_start_loc + len(ent_tokens)
    if len(ent_tokens) == 1:
        index_tag_map[ent_start_loc] = f"S-{ent_type}"
    else:
        for i in range(ent_start_loc, ent_end_loc):
            if i == ent_start_loc:
                index_tag_map[i] = f"B-{ent_type}"
            elif i == ent_end_loc - 1:
                index_tag_map[i] = f"E-{ent_type}"
            else:
                index_tag_map[i] = f"I-{ent_type}"
    return index_tag_map, ent_start_loc, ent_end_loc


def get_token_tags(text, all_ents, all_rels=None):
    """
    Determines the tags for each token in a given text.

    Args:
    text (str): A string representing the text to be analyzed.
    all_ents (dict): A dictionary containing information about all entities in the text.
    all_rels (dict, optional): A dictionary containing information about relationships among entities. Defaults to None.

    Returns:
    pandas.core.frame.DataFrame: A DataFrame containing information about tags for each token in the text.
    
    """

    tokens = tokenize_preserving_ent(text, all_ents)
    tags = ['O'] * len(tokens)
    tag_dict = {}
    ent_dict = {}
    for eid, val in all_ents.items():
        tmp, s_idx, e_idx = tag_index(
            text, val['type'], val['start_pos'], val['end_pos'], all_ents
        )
        tag_dict = tag_dict | tmp
        ent_dict[eid] = {
            'start_idx': s_idx,
            'end_idx': e_idx
        }
    for i, t in tag_dict.items():
        tags[i] = t
        
    if all_rels:
        rels = all_rels.copy()
        rel_lst = []
        for _, rdata in rels.items():
            tokens_str = "|".join(tokens)
            arg1_idx = (f"{ent_dict[rdata['arg1']]['start_idx']}"
            f":{ent_dict[rdata['arg1']]['end_idx']}")
            arg2_idx = (f"{ent_dict[rdata['arg2']]['start_idx']}"
            f":{ent_dict[rdata['arg2']]['end_idx']}")
            rlabel = rdata['type']
            rel_lst.append([tokens_str, arg1_idx, arg2_idx, rlabel])
        
        ner_df = pd.DataFrame(
            list(zip(tokens, tags)), columns=['token', 'tag']
        ) 
        rels_df = pd.DataFrame(rel_lst, columns=['text', 'arg1', 'arg2', 'label'])
        return ner_df, rels_df
    else:
        ner_df = pd.DataFrame(
            list(zip(tokens, tags)), columns=['token', 'tag']
        ) 
        return ner_df

In [16]:
def parse_data(path):
    """
    The parse_data function takes a file path as an argument and returns two pandas DataFrames. The function reads in text and annotation files from the specified path and processes the data to extract entity and relationship information.

    The function returns two DataFrames df_ner and df_rel. df_ner contains tokenized text along with associated entity labels, and df_rel contains tokenized text along with associated relationship labels.

    Parameters:

    path: string representing the path to the directory containing the text and annotation files
    Returns:

    df_ner: a pandas DataFrame containing the tokenized text along with the associated entity labels
    df_rel: a pandas DataFrame containing the tokenized text along with the associated relationship labels
    """
    anns = sorted(glob.glob(f"{path}./*.ann"))
    files = sorted(glob.glob(f"{path}./*.txt"))
    df1_lst = []
    df2_lst = []
    for a, f in tqdm(list(zip(anns, files))):
        idx = a.split('\\')[-1].split('.')[0]
        assert idx == f.split('\\')[-1].split('.')[0]
        with open(a, 'r') as a_f:
            ann_str = a_f.read()
        with open(f, 'r') as f_f:
            txt_str = f_f.read()

        ents, rels = parse_entities_relationships(ann_str)
        rel_spans = get_rel_spans(ents, rels, txt_str)
        collated_data = collate_span_data_with_rels(rel_spans, ents, txt_str)
        
        all_text_spans = set()
        df_ner_lst = []
        df_rel_lst = []
        for i, data in collated_data.items():
            try:
                s = data['text']
                rel_data = {
                    k:v for k,v in data.items() if k not in ['text', 'all_entities']
                }
                df1, df2 = get_token_tags(s, data['all_entities'], rel_data)
                
                token_str = "|".join(df1['token'].to_list())
                if token_str not in all_text_spans:
                    df1['sid'] = i
                    df1['contains_rel'] = 1
                    df_ner_lst.append(df1)
                    all_text_spans = all_text_spans.union({token_str})
                    
                df2['sid'] = i
                df_rel_lst.append(df2)
            except ValueError:
                print(s)
                print(data['all_entities'])
        
        uncovered_ent_spans = get_uncovered_entity_spans(ents, rels, txt_str)
        for i, (s, v) in enumerate(uncovered_ent_spans.items()):
            try:
                df1 = get_token_tags(s, v)
                df1['sid'] = i
                df1['contains_rel'] = 0
                df_ner_lst.append(df1)
            except ValueError:
                print(s)
                print(v)

        if len(df_ner_lst) > 0:
            df_ner_per_rec = pd.concat(df_ner_lst)
            df_ner_per_rec['uid'] = idx
            df1_lst.append(df_ner_per_rec)
            
        if len(df_rel_lst) > 0:
            df_rel_per_rec = pd.concat(df_rel_lst)
            df_rel_per_rec['uid'] = idx
            df2_lst.append(df_rel_per_rec)

    df_ner = pd.concat(df1_lst)
    df_rel = pd.concat(df2_lst)
    df_ner[['token', 'tag']] = df_ner[['token', 'tag']].astype('string')
    df_ner[
        ['sid', 'contains_rel', 'uid']
    ] = df_ner[['sid', 'contains_rel', 'uid']].astype('int')

    df_rel[['sid', 'uid']] = df_rel[['sid', 'uid']].astype('int')
    df_rel[
        ['text', 'arg1', 'arg2', 'label']
    ] = df_rel[['text', 'arg1', 'arg2', 'label']].astype('string')

    return df_ner, df_rel 

In [17]:
df_ner_train, df_rel_train = parse_data("../../data/training")

  0%|          | 0/303 [00:00<?, ?it/s]

Diastolic CHF and possible restrictive cardiomyopathy
(5) AEA/VEA/CAD
(6) IDDM with albuminuria
(7) Gout on prednisone
(8) COPD


Social History:
- Tobacco: Quit 45 years ago
- Alcohol: Denies
- Illicits: Denies


Fami
{'T20': {'type': 'Drug', 'start': 3607, 'end': 3617, 'value': 'prednisone', 'start_pos': 108, 'end_pos': 118}, 'T141': {'type': 'Reason', 'start': 3599, 'end': 3603, 'value': 'Gout', 'start_pos': 100, 'end_pos': 104}, 'T295': {'type': 'Reason', 'start': 3573, 'end': 3577, 'value': 'IDDM', 'start_pos': 74, 'end_pos': 78}, 'T32': {'type': 'Drug', 'start': 3573, 'end': 3574, 'value': 'I', 'start_pos': 74, 'end_pos': 75}}
1**]
(3) Hypertension
(4) Diastolic CHF and possible restrictive cardiomyopathy
(5) AEA/VEA/CAD
(6) IDDM with albuminuria
(7) Gout on prednisone
(8) COPD


Social History:
- Tobacco: Quit 45 years ago
- A
{'T20': {'type': 'Drug', 'start': 3607, 'end': 3617, 'value': 'prednisone', 'start_pos': 134, 'end_pos': 144}, 'T141': {'type': 'Reason', 'start': 3599, '

In [19]:
df_rel_train.shape, df_ner_train.shape

((36346, 6), (1435233, 5))

In [20]:
df_ner_train[(df_ner_train.uid==115267) & (df_ner_train.sid==79) & (df_ner_train.contains_rel==1)]

Unnamed: 0,token,tag,sid,contains_rel,uid
0,multivitamin,S-Drug,79,1,115267
1,Tablet,S-Form,79,1,115267
2,Sig,O,79,1,115267
3,:,O,79,1,115267
4,One,B-Dosage,79,1,115267
5,(,I-Dosage,79,1,115267
6,1,I-Dosage,79,1,115267
7,),E-Dosage,79,1,115267
8,Tablet,S-Form,79,1,115267
9,PO,S-Route,79,1,115267


In [12]:
df_rel_train.shape, df_ner_train.shape

((36346, 6), (1435395, 5))

In [23]:
df_rel_train.to_parquet("../../data/parsed_data/rel_train.parquet")
df_ner_train.to_parquet("../../data/parsed_data/ner_train.parquet")

In [24]:
df_ner_test, df_rel_test = parse_data("../../data/test")

  0%|          | 0/202 [00:00<?, ?it/s]

In [25]:
df_rel_test.shape, df_ner_test.shape

((23462, 6), (931604, 5))

In [26]:
df_rel_test.to_parquet("../../data/parsed_data/rel_test.parquet")
df_ner_test.to_parquet("../../data/parsed_data/ner_test.parquet")