In [1]:
import os
import json
import pandas as pd
from matplotlib import pyplot as plt
from sklearn.model_selection import train_test_split
from shutil import copyfile
import csv

In [4]:
from spacy.lang.en import English
nlp = English()
sentencizer = nlp.add_pipe("sentencizer")


In [11]:
inputpath = f"data/Chia_with_Scope"
outputpath = f"data/Chia_to_bio"
trainpath = f"data/Train"
testpath = f"data/Test"

In [12]:
inputfiles = set()
for f in os.listdir(inputpath):
    if f.endswith('.ann'):
        inputfiles.add(f.split('.')[0].split('_')[0])
len(inputfiles)

1000

In [13]:
# list of entity types to retain
select_types = ['Condition', 'Value', 'Drug', 'Procedure', 'Measurement', 'Temporal', \
    'Observation', 'Person', 'Mood', 'Device', 'Pregnancy_considerations']

In [15]:
# convert Brat format into BIO format
# function for getting entity annotations from the annotation file
def get_annotation_entities(ann_file, select_types=None):
    entities = []
    with open(ann_file, "r", encoding="utf-8") as f:
        for line in f:
            if line.startswith('T'):
                term = line.strip().split('\t')[1].split()
                if (select_types != None) and (term[0] not in select_types): continue
                if int(term[-1]) <= int(term[1]): continue
                entities.append((int(term[1]), int(term[-1]), term[0]))
    return sorted(entities, key=lambda x: (x[0], x[1]))

# function for handling overlap by keeping the entity with largest text span
def remove_overlap_entities(sorted_entities):
    keep_entities = []
    for idx, entity in enumerate(sorted_entities):
        if idx == 0:
            keep_entities.append(entity)
            last_keep = entity
            continue
        if entity[0] < last_keep[1]:
            if entity[1]-entity[0] > last_keep[1]-last_keep[0]:
                last_keep = entity
                keep_entities[-1] = last_keep
        elif entity[0] == last_keep[1]:
            last_keep = (last_keep[0], entity[1], last_keep[-1])
            keep_entities[-1] = last_keep
        else:
            last_keep = entity
            keep_entities.append(entity)
    return keep_entities

# inverse index of entity annotations
def entity_dictionary(keep_entities, txt_file):
    f_ann = {}
    with open(txt_file, "r", encoding="utf-8") as f: # Marked
        text = f.readlines()
        if file in ['NCT02348918_exc', 'NCT02348918_inc', 'NCT01735955_exc']:
            text = ' '.join([i.strip() for i in text])
        else:
            text = '  '.join([i.strip() for i in text])
    for entity in keep_entities:
        entity_text = text[entity[0]:entity[1]]
        doc = nlp(entity_text)
        token_starts = [(i, doc[i:].start_char) for i in range(len(doc))]
        term_type = entity[-1]
        term_offset = entity[0]
        for i, token in enumerate(doc):
            ann_offset = token_starts[i][1]+term_offset
            if ann_offset not in f_ann:
                f_ann[ann_offset] = [i, token.text, term_type]
    return f_ann

# Brat -> BIO format conversion
for infile in inputfiles:
    for t in ["exc", "inc"]:
        file = f"{infile}_{t}"
        ann_file = f"{inputpath}/{file}.ann"
        txt_file = f"{inputpath}/{file}.txt"
        out_file = f"{outputpath}/{file}.bio.txt"
        sorted_entities = get_annotation_entities(ann_file, select_types)
        keep_entities = remove_overlap_entities(sorted_entities)
        f_ann = entity_dictionary(keep_entities, txt_file)
        with open(out_file, "w", encoding="utf-8") as f_out:
            with open(txt_file, "r", encoding="utf-8") as f:
                sent_offset = 0
                for line in f:
                    # print(line.strip())
                    if '⁄' in line:
                        # print(txt_file)
                        line = line.replace('⁄', '/') # replace non unicode characters
                    doc = nlp(line.strip())
                    token_starts = [(i, doc[i:].start_char) for i in range(len(doc))]
                    for token in doc:
                        token_sent_offset = token_starts[token.i][1]
                        token_doc_offset = token_starts[token.i][1]+sent_offset
                        if token_doc_offset in f_ann:
                            if f_ann[token_doc_offset][0] == 0:
                                label = f"B-{f_ann[token_doc_offset][2]}"
                            else:
                                label = f"I-{f_ann[token_doc_offset][2]}"
                        else:
                            label = f"O"
                        # print(token.text, token_sent_offset, token_sent_offset+len(token.text), token_doc_offset, token_doc_offset+len(token.text), label)
                        f_out.write(f"{token.text} {token_sent_offset} {token_sent_offset+len(token.text)} {token_doc_offset} {token_doc_offset+len(token.text)} {label}\n")
                    # print('\n')
                    f_out.write('\n')
                    if file in ['NCT02348918_exc', 'NCT02348918_inc', 'NCT01735955_exc']: # 3 trials with inconsistent offsets
                        sent_offset += (len(line.strip())+1)
                    else:
                        sent_offset += (len(line.strip())+2)

In [17]:
# dataset separation: 800 trials (80%) for training, 100 trials (10%) for validation and 100 trials (10%) for testing
train_ids, dev_ids = train_test_split(list(inputfiles), train_size=0.8, random_state=13, shuffle=True)
dev_ids, test_ids = train_test_split(dev_ids, train_size=0.5, random_state=13, shuffle=True)
print(len(train_ids), len(dev_ids), len(test_ids))
chia_datasets = {"train":train_ids, "dev":dev_ids, "test":test_ids}
json.dump(chia_datasets, open("data/chia_datasets.json", "w", encoding="utf-8"))

800 100 100


In [18]:
# Merge BIO format train, validation and test datasets
# chia_datasets = json.load(open("chia/chia_datasets.json", "r", encoding="utf-8"))
# merge the train dataset
with open("data/train.txt", "w", encoding="utf-8") as f:
    for fid in chia_datasets["train"]:
        copyfile(f"{outputpath}/{fid}_exc.bio.txt", f"{trainpath}/{fid}_exc.bio.txt")
        copyfile(f"{outputpath}/{fid}_inc.bio.txt", f"{trainpath}/{fid}_inc.bio.txt")
        with open(f"{outputpath}/{fid}_exc.bio.txt", "r", encoding="utf-8") as fr:
            txt = fr.read().strip()
            if txt != '':
                f.write(txt)
                f.write("\n\n")
        with open(f"{outputpath}/{fid}_inc.bio.txt", "r", encoding="utf-8") as fr:
            txt = fr.read().strip()
            if txt != '':
                f.write(txt)
                f.write("\n\n")

# merge the validation dataset
with open("data/dev.txt", "w", encoding="utf-8") as f:
    for fid in chia_datasets["dev"]:
        copyfile(f"{outputpath}/{fid}_exc.bio.txt", f"{trainpath}/{fid}_exc.bio.txt")
        copyfile(f"{outputpath}/{fid}_inc.bio.txt", f"{trainpath}/{fid}_inc.bio.txt")
        with open(f"{outputpath}/{fid}_exc.bio.txt", "r", encoding="utf-8") as fr:
            txt = fr.read().strip()
            if txt != '':
                f.write(txt)
                f.write("\n\n")
        with open(f"{outputpath}/{fid}_inc.bio.txt", "r", encoding="utf-8") as fr:
            txt = fr.read().strip()
            if txt != '':
                f.write(txt)
                f.write("\n\n")

# merge the test dataset
with open("data/test.txt", "w", encoding="utf-8") as f:
    for fid in chia_datasets["test"]:
        copyfile(f"{outputpath}/{fid}_exc.bio.txt", f"{testpath}/{fid}_exc.bio.txt")
        copyfile(f"{outputpath}/{fid}_inc.bio.txt", f"{testpath}/{fid}_inc.bio.txt")
        with open(f"{outputpath}/{fid}_exc.bio.txt", "r", encoding="utf-8") as fr:
            txt = fr.read().strip()
            if txt != '':
                f.write(txt)
                f.write("\n\n")
        with open(f"{outputpath}/{fid}_inc.bio.txt", "r", encoding="utf-8") as fr:
            txt = fr.read().strip()
            if txt != '':
                f.write(txt)
                f.write("\n\n")

In [19]:
# convert Chia in Brat into format for Att-BiLSTM-CRF model
out_file = f"data/chia_ner.tsv"
with open(out_file, "w", encoding="utf-8") as f_out:
    for infile in inputfiles:
        for t in ["exc", "inc"]:
            file = f"{infile}_{t}"
            ann_file = f"{inputpath}/{file}.ann"
            txt_file = f"{inputpath}/{file}.txt"
            sorted_entities = get_annotation_entities(ann_file, select_types)
            keep_entities = remove_overlap_entities(sorted_entities)
            with open(txt_file, "r", encoding="utf-8") as f:
                sent_offset = 0
                for line in f:
                    # print(line.strip())
                    if '⁄' in line: line = line.replace('⁄', '/')
                    sent_end = sent_offset + len(line)
                    sent_ents = []
                    for ent in keep_entities:
                        if ent[0] < sent_offset or ent[1] < sent_offset: continue
                        if ent[0] >= sent_end or ent[1] > sent_offset+len(line.strip()): break
                        ent_start = ent[0]-sent_offset+1
                        ent_end = ent[1]-sent_offset+1
                        sent_ents.append(f"{ent_start}:{ent_end}:{ent[2].lower()}")
                    if sent_ents == []:
                        if file in ['NCT02348918_exc', 'NCT02348918_inc', 'NCT01735955_exc']:
                            sent_offset += (len(line.strip())+1)
                        else:
                            sent_offset += (len(line.strip())+2)
                        continue
                    # print(f"{file}\t{','.join(sent_ents)}\t{line.strip()}")
                    f_out.write(f"{file}\t{','.join(sent_ents)}\t{line.strip()}")
                    # print('\n')
                    f_out.write('\n')
                    if file in ['NCT02348918_exc', 'NCT02348918_inc', 'NCT01735955_exc']:
                        sent_offset += (len(line.strip())+1)
                    else:
                        sent_offset += (len(line.strip())+2)

In [23]:
# split Chia in format for Att-BiLSTM-CRF model into train, validation and test datasets
# chia_datasets = json.load(open("chia/chia_datasets.json", "r", encoding="utf-8"))
with open("data/chia_ner_train.tsv", "w", encoding="utf-8") as ftrain, open("data/chia_ner_dev.tsv", "w", encoding="utf-8") as fdev, open("data/chia_ner_test.tsv", "w", encoding="utf-8") as ftest:
    with open("data/chia_ner.tsv", "r", encoding="utf-8") as fread:
        for line in fread:
            if line.split('\t', 1)[0].split("_")[0] in chia_datasets["train"]:
                ftrain.write(line)
            elif line.split('\t', 1)[0].split("_")[0] in chia_datasets["dev"]:
                fdev.write(line)
            else:
                ftest.write(line)

In [24]:
def abc_strict_match(gs, pred, s_idx, e_idx, ent_type):
    if s_idx == 0:
        for idx in range(s_idx, e_idx):
            if gs[idx] != pred[idx]:
                return False
        if e_idx < len(gs):
            if gs[e_idx] == ent_type or pred[e_idx] == ent_type:
                return False
    else:
        if gs[s_idx-1] == ent_type or pred[s_idx-1] == ent_type:
            return False
        for idx in range(s_idx, e_idx):
            if gs[idx] != pred[idx]:
                return False
        if e_idx < len(gs):
            if gs[e_idx] == ent_type or pred[e_idx] == ent_type:
                return False
    return True

def abc_relax_match(gs, pred, s_idx, e_idx, ent_type):
    for idx in range(s_idx, e_idx):
        if gs[idx] == pred[idx] == ent_type:
            return True
    return False

In [51]:
dataset = 'chia' # or 'frd', 'chia'
outfolder = f"data/attbilstmcrf"
outfile = f"{dataset}_attbilstmcrf_results.txt"
labels_dict = {'chia':['Mood', 'Condition', 'Procedure', 'Measurement', 'Value', 'Drug', 'Temporal', 'Observation', 'Pregnancy', 'Person', 'Device'], 'frd':['chronic_disease', 'treatment', 'upper_bound', 'pregnancy', 'clinical_variable', 'lower_bound', 'cancer', 'age', 'language_fluency', 'gender', 'contraception_consent', 'technology_access', 'allergy_name', 'bmi', 'ethnicity']}
labels = labels_dict[dataset]

In [73]:
eval_metrics = {"category":{},"overall":{}, "prediction":{}}
with open(f"{outfolder}/{outfile}", "r", encoding="utf-8") as f:
    next(f)
    for line in f:
        gs = eval(line.split('\t')[1])
        pred = eval(line.split('\t')[0])
        print(gs, pred)
        for i in zip(gs, pred):
            if i[0] == i[1]: eval_metrics["overall"]["acc_true"] = eval_metrics["overall"].get("acc_true", 0) + 1
            else: eval_metrics["overall"]["acc_false"] = eval_metrics["overall"].get("acc_false", 0) + 1
        llen = len(gs)
        cur_idx = 0
        while cur_idx < llen:
            if gs[cur_idx] == 0:
                cur_idx += 1
            else:
                start_idx = cur_idx
                end_idx = start_idx + 1
                cate = gs[start_idx]
                while end_idx < llen and gs[end_idx] == cate:
                    end_idx += 1
                eval_metrics["overall"]['gs'] = eval_metrics["overall"].get('gs', {})
                eval_metrics["overall"]['gs']['count'] = eval_metrics["overall"]['gs'].get('count', 0) + 1
                eval_metrics["overall"]['gs'][labels[cate-1]] = eval_metrics["overall"]['gs'].get(labels[cate-1], 0) + 1
                if abc_strict_match(gs, pred, start_idx, end_idx, cate):
                    eval_metrics["overall"]["strict_predicted"] = eval_metrics["overall"].get("strict_predicted", 0) + 1
                    eval_metrics["category"][labels[cate-1]] = eval_metrics["category"].get(labels[cate-1], {"strict":0, "relax":0, "miss":0})
                    eval_metrics["category"][labels[cate-1]]["strict"] += 1
                elif abc_relax_match(gs, pred, start_idx, end_idx, cate):
                    eval_metrics["overall"]["relax_predicted"] = eval_metrics["overall"].get("relax_predicted", 0) + 1
                    eval_metrics["category"][labels[cate-1]] = eval_metrics["category"].get(labels[cate-1], {"strict":0, "relax":0, "miss":0})
                    eval_metrics["category"][labels[cate-1]]["relax"] += 1
                else:
                    eval_metrics["overall"]["miss_predicted"] = eval_metrics["overall"].get("miss_predicted", 0) + 1
                    eval_metrics["category"][labels[cate-1]] = eval_metrics["category"].get(labels[cate-1], {"strict":0, "relax":0, "miss":0})
                    eval_metrics["category"][labels[cate-1]]["miss"] += 1
                cur_idx = end_idx
        cur_idx = 0
        while cur_idx < llen:
            if pred[cur_idx] == 0:
                cur_idx += 1
            else:
                start_idx = cur_idx
                end_idx = start_idx + 1
                cate = pred[start_idx]
                while end_idx < llen and pred[end_idx] == cate:
                    end_idx += 1
                if abc_strict_match(gs, pred, start_idx, end_idx, cate):
                    eval_metrics["overall"]["strict_predict"] = eval_metrics["overall"].get("strict_predict", 0) + 1
                    eval_metrics["prediction"][labels[cate-1]] = eval_metrics["prediction"].get(labels[cate-1], {"strict":0, "relax":0, "miss":0})
                    eval_metrics["prediction"][labels[cate-1]]["strict"] += 1
                elif abc_relax_match(gs, pred, start_idx, end_idx, cate):
                    eval_metrics["overall"]["relax_predict"] = eval_metrics["overall"].get("relax_predict", 0) + 1
                    eval_metrics["prediction"][labels[cate-1]] = eval_metrics["prediction"].get(labels[cate-1], {"strict":0, "relax":0, "miss":0})
                    eval_metrics["prediction"][labels[cate-1]]["relax"] += 1
                else:
                    eval_metrics["overall"]["miss_predict"] = eval_metrics["overall"].get("miss_predict", 0) + 1
                    eval_metrics["prediction"][labels[cate-1]] = eval_metrics["prediction"].get(labels[cate-1], {"strict":0, "relax":0, "miss":0})
                    eval_metrics["prediction"][labels[cate-1]]["miss"] += 1
                cur_idx = end_idx

StopIteration: 

In [65]:
attbc_metrics = {"category":{},"overall":{}}
attbc_metrics["overall"]["acc"] = eval_metrics["overall"]["acc_true"]/(eval_metrics["overall"]["acc_true"]+eval_metrics["overall"]["acc_false"])
pred_all = eval_metrics["overall"]['strict_predict'] + eval_metrics["overall"]['relax_predict'] + eval_metrics["overall"]['miss_predict']
pre_relax_all = (eval_metrics["overall"]['strict_predict'] + eval_metrics["overall"]['relax_predict'])/ pred_all
rec_relax_all = (eval_metrics["overall"]['strict_predicted'] + eval_metrics["overall"]['relax_predicted'])/ eval_metrics["overall"]['gs']['count']
f1_relax_all = (2*pre_relax_all*rec_relax_all)/(pre_relax_all+rec_relax_all)
print(f"Overall Relax Level: Precision: {pre_relax_all}, Recall: {rec_relax_all}, F1: {f1_relax_all}")
attbc_metrics["overall"]["relax"] = {"f_score": f1_relax_all, "precision":pre_relax_all, "recall":rec_relax_all}

pre_strict_all = eval_metrics["overall"]['strict_predict'] / pred_all
rec_strict_all = eval_metrics["overall"]['strict_predicted'] / eval_metrics["overall"]['gs']['count']
f1_strict_all = (2*pre_strict_all*rec_strict_all)/(pre_strict_all+rec_strict_all)
print(f"Overall Strict Level: Precision: {pre_strict_all}, Recall: {rec_strict_all}, F1: {f1_strict_all}")
print('\n')
attbc_metrics["overall"]["strict"] = {"f_score": f1_strict_all, "precision":pre_strict_all, "recall":rec_strict_all}

for i in eval_metrics["category"].keys():
    tt = eval_metrics["overall"]['gs'][i]
    tp = eval_metrics["prediction"][i]['strict'] + eval_metrics["prediction"][i]['relax'] + eval_metrics["prediction"][i]['miss']
    
    pre_relax = (eval_metrics["prediction"][i]['strict']+eval_metrics["prediction"][i]['relax'])/tp
    rec_relax = (eval_metrics["category"][i]['strict']+eval_metrics["category"][i]['relax'])/tt
    f1_relax = (2*pre_relax*rec_relax)/(pre_relax+rec_relax)
    print(f"Relax Level for {i}: Precision: {pre_relax}, Recall: {rec_relax}, F1: {f1_relax}")
    attbc_metrics["category"]["relax"] = attbc_metrics["category"].get("relax", {})
    attbc_metrics["category"]["relax"][i] = {"f_score": f1_relax, "precision":pre_relax, "recall":rec_relax}

    pre_strict = eval_metrics["prediction"][i]['strict']/tp
    rec_strict = eval_metrics["category"][i]['strict']/tt
    f1_strict = (2*pre_strict*rec_strict)/(pre_strict+rec_strict) if (pre_strict+rec_strict) != 0 else 0.0
    print(f"Strict Level for {i}: Precision: {pre_strict}, Recall: {rec_strict}, F1: {f1_strict}")
    print('\n')
    attbc_metrics["category"]["strict"] = attbc_metrics["category"].get("strict", {})
    attbc_metrics["category"]["strict"][i] = {"f_score": f1_strict, "precision":pre_strict, "recall":rec_strict}
json.dump(attbc_metrics, open(f"{outfolder}/{outfile}.json", "w", encoding="utf-8"), indent=4)

KeyError: 'acc_true'