In [1]:
import nltk
# nltk.download('punkt')
import numpy as np
from nltk.tokenize import sent_tokenize
import json
import itertools
import re

In [2]:
# base_path = "../../data/test/"
# file_name = "164726"
base_path = "../../data/"
file_name = "100187"

In [3]:
ann_str = None
with open(base_path + file_name + ".ann", "r") as file:
    ann_str = file.read()

In [22]:
def parse_entities_relationships(file_content):
    """
    parses out entities and relationships from the entities file
    """
    entities = {}
    relationships = {}

    for line in file_content.strip().split('\n'):
        cols = line.split('\t')
        identifier = cols[0]

        if "Arg1:" in cols[1]:
            relationship_type, arg1, arg2 = cols[1].split()
            relationships[identifier] = {
                'type': relationship_type,
                'arg1': arg1.split(':')[1],
                'arg2': arg2.split(':')[1]
            }
            
        else:
            split_cols = cols[1].split()
            entity_type = split_cols[0]
            value = cols[2]
            start = split_cols[1]
            end = split_cols[-1]
            entities[identifier] = {
                'type': entity_type,
                'start': int(start),
                'end': int(end),
                'value': value
            }

    return entities, relationships


def sentence_start_end(sentences, txt_str):
    """
    get starting and ending positions of sentences
    """
    curr_start = 0
    sent_positions = []
    for i in range(len(sentences)):
        sent = sentences[i]
        match_idx = txt_str.find(sent, curr_start)
        start = match_idx
        end = match_idx + len(sent)
        curr_start = end
        sent_positions.append((start, end))
    return sent_positions


def update_entity_indices(entities, sent_positions):
    """
    for each entity update its position relative to this sentence
    """
    for entity_key in entities:
        entity = entities[entity_key]
        start = entity['start']
        end = entity['end']
        if entity_key == "T277":
            print(start, end)
        for i, (sent_start, sent_end) in enumerate(sent_positions):
            if start >= sent_start and end <= sent_end:
                entity['sentence'] = i
                entity['sent_start'] = start - sent_start
                entity['sent_end'] = end - sent_start
                break
    return entities

In [23]:
# [(m.start(0), m.end(0)) for m in re.finditer("7-1", "5-7-110010")]

In [24]:
txt_str = None
with open(base_path + file_name + ".txt", "r") as file:
    txt_str = file.read()

In [25]:
ent, rel = parse_entities_relationships(ann_str)

In [26]:
sentences = sent_tokenize(txt_str)
sent_pos = sentence_start_end(sentences, txt_str)

In [28]:
for i in range(len(sentences)):
    s = sentences[i]
    if "Pantoprazole" in s:
        print(i)
sent_pos[215]

157
179
215


(17279, 17328)

In [29]:
new_ent = update_entity_indices(ent, sent_pos)

17279 17291


In [30]:
new_ent['T277']

{'type': 'Drug',
 'start': 17279,
 'end': 17291,
 'value': 'Pantoprazole',
 'sentence': 215,
 'sent_start': 0,
 'sent_end': 12}

In [31]:
ents_covered = set(
    itertools.chain.from_iterable(([[r['arg1'], r['arg2']] for k, r in rel.items()]))
    )
uncovered = set(ent.keys()) - ents_covered

uncovered_dict = {k: v for k, v in ent.items() if k in uncovered}

In [None]:
ents = update_entity_indices(uncovered_dict, sent_pos)

In [None]:
ent_updated = update_entity_indices(ent, sent_pos)

In [None]:
spans = get_relationship_spans(ent_updated, rel, txt_str)

In [None]:
def get_relationship_spans(entities, relationships, txt_str):
    """
    get relationship spans
    """
    sentences = sent_tokenize(txt_str)
    entities = update_entity_indices(entities, sentence_start_end(sentences, txt_str))

    relationship_spans = {}

    for rel_id, relationship in relationships.items():
        arg1 = entities[relationship['arg1']]
        arg2 = entities[relationship['arg2']]
        if arg1['sentence'] <= arg2['sentence']:
            sent_start = arg1['sentence']
            sent_end = arg2['sentence']
            arg1_start = arg1['sent_start']
            arg1_end = arg1['sent_end']
            offset = len(' '.join(sentences[sent_start: sent_end]))
            arg2_start = offset + arg2['sent_start']
            arg2_end = offset + arg2['sent_end']
        else:
            sent_start = arg2['sentence']
            sent_end = arg1['sentence']
            arg2_start = arg2['sent_start']
            arg2_end = arg2['sent_end']
            offset = len(' '.join(sentences[sent_start: sent_end]))
            arg1_start = offset + arg1['sent_start']
            arg1_end = offset + arg1['sent_end']
        
        span = ' '.join(sentences[sent_start:sent_end+1])
        
#         print(len(span[arg1_start:arg1_end]), len(arg1['value']))
#         print(span[arg1_start:arg1_end], arg1['value'])
#         print(len(span[arg2_start:arg2_end]), len(arg2['value']))
#         print(span[arg2_start:arg2_end], arg2['value'])
#         assert (
#             (span[arg1_start:arg1_end] == arg1['value']) and 
#             (span[arg2_start:arg2_end] == arg2['value'])
#         )

        if span not in relationship_spans:
            relationship_spans[span] = {
                'relationships': {},
                'all_entities': {}
            }

        relationship_spans[span]['relationships'][rel_id] = {
            'type': relationship['type'],
            'arg1': relationship['arg1'],
            'arg1_start': arg1_start,
            'arg1_end': arg1_end,
            'arg2': relationship['arg2'],
            'arg2_start': arg2_start,
            'arg2_end': arg2_end
        }

        if relationship['arg1'] not in relationship_spans[span]['all_entities']:
            ent_data = {
                'start_pos': arg1_start,
                'end_pos': arg1_end,
                'text': arg1['value'],
                'type': arg1['type']
            }
            relationship_spans[span]['all_entities'][relationship['arg1']] = ent_data
        
        if relationship['arg2'] not in relationship_spans[span]['all_entities']:
            ent_data = {
                'start_pos': arg2_start,
                'end_pos': arg2_end,
                'text': arg2['value'],
                'type': arg2['type']
            }
            relationship_spans[span]['all_entities'][relationship['arg2']] = ent_data
    return relationship_spans


def get_ent_tags(text, ent_type, ent_text, ent_loc):
    index_tag_map = {}
    tokens = np.array(nltk.word_tokenize(text))
    ent_tokens = nltk.word_tokenize(ent_text)
    possible_loc = np.where(tokens==ent_tokens[0])[0]
    ent_start_loc = possible_loc[np.argmin(
        [abs(ent_loc - len(' '.join(tokens[0:x]))) for x in possible_loc]
    )]
    ent_end_loc = ent_start_loc + len(ent_tokens) - 1
    if len(ent_tokens) == 1:
        index_tag_map[ent_start_loc] = f"S-{ent_type}"
    else:
        for i in range(ent_start_loc, ent_end_loc):
            if i == ent_start_loc:
                index_tag_map[i] = f"B-{ent_type}"
            elif i == ent_end_loc - 1:
                index_tag_map[i] = f"E-{ent_type}"
            else:
                index_tag_map[i] = f"I-{ent_type}"
    return index_tag_map, ent_start_loc, ent_end_loc


def get_token_tags(text, all_ents, all_rels=None):
    tokens = nltk.word_tokenize(text)
    tags = ['O'] * len(tokens)
    tag_dict = {}
    ent_dict = {}
    for eid, val in all_ents.items():
        tmp, s_idx, e_idx = get_ent_tags(
            text, val['type'], val['text'], val['start_pos']
        )
        tag_dict = tag_dict | tmp
        ent_dict[eid] = {
            'start_idx': s_idx,
            'end_idx': e_idx
        }
    for i, t in tag_dict.items():
        tags[i] = t
        
    if all_rels:
        rels = all_rels.copy()
        for _, rdata in rels.items():
            rdata['arg1_start'] = ent_dict[rdata['arg1']]['start_idx']
            rdata['arg1_end'] = ent_dict[rdata['arg1']]['end_idx']
            rdata['arg2_start'] = ent_dict[rdata['arg2']]['start_idx']
            rdata['arg2_end'] = ent_dict[rdata['arg2']]['end_idx']
        return list(zip(tokens, tags)), rels
    else:
        return list(zip(tokens, tags))

In [None]:
for k, v in spans.items():
    print(f"{k}\n")
    print(v)
    break

In [None]:
v

In [None]:
spans

In [None]:
# t = 'Aggressive diuresis\nwith IV lasix and diuril were unsuccessful so CVVH was\ninitiated.'
# get_token_tags(t, spans[t]['all_entities'], spans[t]['relationships'])

In [None]:
def process_sentences_without_relationships(sentences, sent_positions, entities, relationships, relationship_spans):
    processed_sentences = {entities[rel['arg1']]['sentence'] for rel in relationships.values()}
    processed_sentences.update({entities[rel['arg2']]['sentence'] for rel in relationships.values()})
    for i, (sent_start, sent_end) in enumerate(sent_positions):
        if i not in processed_sentences:
            entities_in_sent = [e for e in entities.values() if e['sentence'] == i]
            if len(entities_in_sent) > 1:
                span = sentences[i]

                if span not in relationship_spans:
                    relationship_spans[span] = {
                        'relationships': [],
                        'entities': {}
                    }

                for j, (e1, e2) in enumerate(zip(entities_in_sent[:-1], entities_in_sent[1:])):
                    print(e1['id'], e2['id'])
                    relationship_spans[span]['relationships'].append({
                        'id': f'NOREL{i}_{j}',
                        'type': 'NOREL',
                        'arg1': e1['id'],
                        'arg2': e2['id']
                    })

                for e in entities_in_sent:
                    relationship_spans[span]['entities'][e['id']] = e
    return relationship_spans

In [None]:
norels = process_sentences_without_relationships(sentences, sent_pos, ent, rel, spans)

In [None]:
norels