In [1]:
import os,sys,inspect
# Put the repository root (the parent of this notebook's directory) on the
# import path so the local `wiser` package can be imported.
# A notebook has no __file__, and with recent ipykernel versions the "file"
# of the current frame is a temporary path (e.g. /tmp/ipykernel_<pid>/<n>.py),
# so a frame-based lookup points at the wrong directory. The working
# directory is the notebook's directory when it is launched normally, and it
# is also the base the relative '../data/...' paths below rely on.
currentdir = os.getcwd()
parentdir = os.path.dirname(currentdir)
if parentdir not in sys.path:
    # Guarded so that re-running the cell does not stack duplicate entries.
    sys.path.insert(0, parentdir)

In [2]:
from wiser.data.dataset_readers.cdr import CDRCombinedDatasetReader
from wiser.rules import TaggingRule, LinkingRule, DictionaryMatcher
from wiser.generative import get_label_to_ix, get_rules
from labelmodels import *
from wiser.generative import train_generative_model
from labelmodels import LearningConfig
from wiser.generative import evaluate_generative_model
from wiser.data import save_label_distribution
from wiser.eval import *
from collections import Counter

import torch

from nltk.tokenize import word_tokenize
from nltk.tokenize import sent_tokenize
from tokenizations import get_alignments, get_original_spans
from typing import List, Optional, Tuple
from tqdm.auto import tqdm

from xml.etree import ElementTree
from transformers import AutoTokenizer, AutoModel

## Load Data

In [3]:
# Read the three BC5CDR splits with one reader and pool them, so that every
# labeling rule below is applied to all documents at once.
cdr_reader = CDRCombinedDatasetReader()
train_data, dev_data, test_data = [
    cdr_reader.read(split_path)
    for split_path in (
        '../data/BC5CDR/CDR_TrainingSet.BioC.xml',
        '../data/BC5CDR/CDR_DevelopmentSet.BioC.xml',
        '../data/BC5CDR/CDR_TestSet.BioC.xml',
    )
]
cdr_docs = train_data + dev_data + test_data

HBox(children=(HTML(value='reading instances'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width…

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))







HBox(children=(HTML(value='reading instances'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width…

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))







HBox(children=(HTML(value='reading instances'), FloatProgress(value=1.0, bar_style='info', layout=Layout(width…

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))







## Tagging Functions

In [4]:
dict_core_chem = set()
dict_core_chem_exact = set()
dict_core_dis = set()
dict_core_dis_exact = set()

# Maps each entity type in dict_core.txt to its (case-insensitive set,
# exact-case set). Any other type in the file is treated as an error.
_dict_core_targets = {
    'Chemical': (dict_core_chem, dict_core_chem_exact),
    'Disease': (dict_core_dis, dict_core_dis_exact),
}

# Each line of dict_core.txt is "<EntityType> <surface form>". Terms with
# more than one token, or a single token longer than 3 characters, go to the
# case-insensitive set. Short single tokens (e.g. abbreviations) go to the
# exact-case set, presumably because lowercasing them would over-match.
core_tokenizer = cdr_reader.get_tokenizer()
with open('../data/AutoNER_dicts/BC5CDR/dict_core.txt', encoding='utf-8') as f:
    for line_no, line in enumerate(f, start=1):
        fields = line.strip().split(None, 1)
        if not fields:
            # Blank line: nothing to add (previously an IndexError).
            continue
        if len(fields) < 2:
            raise ValueError(
                f"dict_core.txt line {line_no}: expected '<type> <term>', "
                f"got {line.strip()!r}")
        entity_type, surface = fields
        if entity_type not in _dict_core_targets:
            raise ValueError(
                f"dict_core.txt line {line_no}: unknown entity type {entity_type!r}")
        term = tuple(str(x) for x in core_tokenizer(surface))
        if not term:
            # The tokenizer produced no tokens, so there is nothing to match.
            continue
        fuzzy_terms, exact_terms = _dict_core_targets[entity_type]
        if len(term) > 1 or len(term[0]) > 3:
            fuzzy_terms.add(term)
        else:
            exact_terms.add(term)

# One dictionary matcher per (type, casing) set, applied in a fixed order.
for matcher_name, matcher_terms, matcher_label, matcher_uncased in (
        ("DictCore-Chemical", dict_core_chem, "I-Chemical", True),
        ("DictCore-Chemical-Exact", dict_core_chem_exact, "I-Chemical", False),
        ("DictCore-Disease", dict_core_dis, "I-Disease", True),
        ("DictCore-Disease-Exact", dict_core_dis_exact, "I-Disease", False)):
    lf = DictionaryMatcher(
        matcher_name,
        matcher_terms,
        i_label=matcher_label,
        uncased=matcher_uncased)
    lf.apply(cdr_docs)

def _read_umls_terms(path):
    """Read a UMLS term list: one term per line, tokens separated by single spaces.

    Blank lines are skipped; otherwise they would become the empty term [''].
    """
    with open(path, 'r', encoding='utf-8') as term_file:
        return [line.strip().split(" ") for line in term_file if line.strip()]


# (term file, rule name, label), applied in this order. The body-part list
# is applied under the temporary name/label "TEMP": it only feeds the
# BodyTerms rule, after which its labels are deleted from every document.
_umls_matchers = (
    ('../data/umls/umls_element_ion_or_isotope.txt', "Element, Ion, or Isotope", 'I-Chemical'),
    ('../data/umls/umls_organic_chemical.txt', "Organic Chemical", 'I-Chemical'),
    ('../data/umls/umls_antibiotic.txt', "Antibiotic", 'I-Chemical'),
    ('../data/umls/umls_disease_or_syndrome.txt', "Disease or Syndrome", 'I-Disease'),
    ('../data/umls/umls_body_part.txt', "TEMP", 'TEMP'),
)
for umls_path, umls_name, umls_label in _umls_matchers:
    terms = _read_umls_terms(umls_path)
    lf = DictionaryMatcher(
        umls_name,
        terms,
        i_label=umls_label,
        uncased=True,
        match_lemmas=True)
    lf.apply(cdr_docs)


class BodyTerms(TaggingRule):
    """Tag "<body part> <injury word>" bigrams as I-Disease.

    Reads the temporary 'TEMP' labels (UMLS body-part matches) from
    instance['WISER_LABELS']. A token tagged 'TEMP' that is directly followed
    by one of the injury words below is labeled I-Disease, and so is that word.
    """

    def apply_instance(self, instance):
        lowered = [token.text.lower() for token in instance['tokens']]
        labels = ['ABS'] * len(lowered)

        injury_words = {"cancer", "cancers", "damage", "disease", "diseases",
                        "pain", "injury", "injuries"}

        for idx in range(len(lowered) - 1):
            if instance['WISER_LABELS']['TEMP'][idx] != 'TEMP':
                continue
            if lowered[idx + 1] not in injury_words:
                continue
            labels[idx] = labels[idx + 1] = "I-Disease"
        return labels


lf = BodyTerms()
lf.apply(cdr_docs)

# The body-part matches were only an input to BodyTerms; drop them from every
# document (presumably so they are not picked up as a rule later on).
for doc in cdr_docs:
    del doc['WISER_LABELS']['TEMP']


class Acronyms(TaggingRule):
    """Propagate rule labels to acronyms defined in parentheses.

    For each tag, scans the labels of the listed rules. When a token carrying
    that tag is immediately followed by "( X <punct>", and X is not a number,
    X is collected as an acronym. Every occurrence of X in the instance is
    then labeled with the tag. Tags are processed in `other_lfs` order, so a
    later tag overwrites an earlier one on the same token.
    """
    other_lfs = {
        'I-Chemical': ("Antibiotic", "Element, Ion, or Isotope", "Organic Chemical"),
        'I-Disease': ("BodyTerms", "Disease or Syndrome")
    }

    def apply_instance(self, instance):
        tokens = instance['tokens']
        labels = ['ABS'] * len(tokens)

        for tag, lf_names in self.other_lfs.items():
            acronyms = set()
            for lf_name in lf_names:
                lf_labels = instance['WISER_LABELS'][lf_name]
                # Reset per rule. Previously `active` carried over from the
                # last scanned position of the previous rule or tag, which
                # could turn a leading "( X ." into a spurious acronym.
                active = False
                for i in range(len(tokens) - 2):
                    if lf_labels[i] == tag:
                        active = True
                    elif (active and tokens[i].text == '('
                          and tokens[i + 2].pos_ == "PUNCT"
                          and tokens[i + 1].pos_ != "NUM"):
                        acronyms.add(tokens[i + 1].text)
                        active = False
                    else:
                        active = False

            for i, token in enumerate(tokens):
                if token.text in acronyms:
                    labels[i] = tag

        return labels


lf = Acronyms()
lf.apply(cdr_docs)


class _CompoundHeadRule(TaggingRule):
    """Shared logic for "<compound modifiers> <head word>" disease phrases.

    Subclasses set `head_lemma`. Wherever a token whose `dep_` is 'compound'
    is directly followed by a token whose lemma equals `head_lemma`, both are
    labeled I-Disease. The label then extends leftwards over any contiguous
    preceding 'compound' tokens. Never applied directly; only the concrete
    subclasses below are instantiated, so rule names are unchanged.
    """
    head_lemma = None

    def apply_instance(self, instance):
        tokens = instance['tokens']
        labels = ['ABS'] * len(tokens)

        for i in range(len(tokens) - 1):
            if tokens[i].dep_ == 'compound' and tokens[i + 1].lemma_ == self.head_lemma:
                labels[i] = 'I-Disease'
                labels[i + 1] = 'I-Disease'

                # Adds any other compound tokens before the phrase
                for j in range(i - 1, -1, -1):
                    if tokens[j].dep_ != 'compound':
                        break
                    labels[j] = 'I-Disease'

        return labels


class Damage(_CompoundHeadRule):
    head_lemma = 'damage'


class Disease(_CompoundHeadRule):
    head_lemma = 'disease'


class Disorder(_CompoundHeadRule):
    head_lemma = 'disorder'


class Lesion(_CompoundHeadRule):
    head_lemma = 'lesion'


class Syndrome(_CompoundHeadRule):
    head_lemma = 'syndrome'


# Applied in the same order as the original per-class cells.
for rule_cls in (Damage, Disease, Disorder, Lesion, Syndrome):
    lf = rule_cls()
    lf.apply(cdr_docs)


exceptions = {'determine', 'baseline', 'decline',
              'examine', 'pontine', 'vaccine',
              'routine', 'crystalline', 'migraine',
              'alkaline', 'midline', 'borderline',
              'cocaine', 'medicine', 'medline',
              'asystole', 'control', 'protocol',
              'alcohol', 'aerosol', 'peptide',
              'provide', 'outside', 'intestine',
              'combine', 'delirium', 'VIP'}

suffixes = ('ine', 'ole', 'ol', 'ide', 'ium', 'epam')


class ChemicalSuffixes(TaggingRule):
    """Tag chemical-looking words and the acronyms defined right after them.

    A token is labeled I-Chemical when its lemma has at least 7 characters,
    ends with one of `chem_suffixes`, and is not in `chem_exceptions`. If such
    a token is followed by "( ABBR )", every occurrence of ABBR in the
    instance is labeled I-Chemical too, unless ABBR is an exception.
    """
    # Snapshot the lists at class-creation time. Later cells rebind the
    # module-level `exceptions`/`suffixes` for other rules, and reading the
    # globals at call time would make a re-applied rule use the wrong lists.
    chem_exceptions = frozenset(exceptions)
    chem_suffixes = suffixes

    def apply_instance(self, instance):
        tokens = instance['tokens']
        labels = ['ABS'] * len(tokens)

        acronyms = set()
        for i, t in enumerate(tokens):
            lemma = t.lemma_
            if (len(lemma) >= 7 and lemma not in self.chem_exceptions
                    and lemma.endswith(self.chem_suffixes)):
                labels[i] = 'I-Chemical'

                # "<chemical> ( ABBR )": remember ABBR for a second pass.
                if (i < len(tokens) - 3 and tokens[i + 1].text == '('
                        and tokens[i + 3].text == ')'):
                    acronyms.add(tokens[i + 2].text)

        for i, t in enumerate(tokens):
            if t.text in acronyms and t.text not in self.chem_exceptions:
                labels[i] = 'I-Chemical'
        return labels


lf = ChemicalSuffixes()
lf.apply(cdr_docs)


class CancerLike(TaggingRule):
    """Label a token I-Disease when its lowercase text ends in one of a few
    tumour/swelling endings ("edema", "toma", "coma", "noma") or their plural
    with a trailing "s"; every other token abstains."""

    def apply_instance(self, instance):
        stems = ("edema", "toma", "coma", "noma")
        endings = stems + tuple(stem + "s" for stem in stems)
        return ['I-Disease' if token.text.lower().endswith(endings) else 'ABS'
                for token in instance['tokens']]


lf = CancerLike()
lf.apply(cdr_docs)


exceptions = {'diagnosis', 'apoptosis', 'prognosis', 'metabolism'}

suffixes = ("agia", "cardia", "trophy", "itis",
            "emia", "enia", "pathy", "plasia", "lism", "osis")


class DiseaseSuffixes(TaggingRule):
    """Label I-Disease any token whose lemma has at least 5 characters, ends
    with a disease-like suffix, and (lowercased) is not an exception."""
    # Snapshot the lists at class-creation time. Later cells rebind the
    # module-level `exceptions`/`suffixes`, and reading the globals at call
    # time would make a re-applied rule use another rule's lists.
    disease_exceptions = frozenset(exceptions)
    disease_suffixes = suffixes

    def apply_instance(self, instance):
        labels = ['ABS'] * len(instance['tokens'])

        for i, t in enumerate(instance['tokens']):
            lemma = t.lemma_
            if (len(lemma) >= 5 and lemma.lower() not in self.disease_exceptions
                    and lemma.endswith(self.disease_suffixes)):
                labels[i] = 'I-Disease'

        return labels


lf = DiseaseSuffixes()
lf.apply(cdr_docs)


exceptions = {'hypothesis', 'hypothesize', 'hypobaric', 'hyperbaric'}

prefixes = ('hyper', 'hypo')


class DiseasePrefixes(TaggingRule):
    """Label I-Disease any NOUN token whose lemma has at least 5 characters,
    starts with "hyper"/"hypo", and (lowercased) is not an exception."""
    # Snapshot the lists at class-creation time. The module-level
    # `exceptions` is rebound by a later cell, and reading it at call time
    # would make a re-applied rule use another rule's list.
    prefix_exceptions = frozenset(exceptions)
    disease_prefixes = prefixes

    def apply_instance(self, instance):
        labels = ['ABS'] * len(instance['tokens'])

        for i, t in enumerate(instance['tokens']):
            lemma = t.lemma_
            if (len(lemma) >= 5 and lemma.lower() not in self.prefix_exceptions
                    and lemma.startswith(self.disease_prefixes)
                    and t.pos_ == "NOUN"):
                labels[i] = 'I-Disease'

        return labels


lf = DiseasePrefixes()
lf.apply(cdr_docs)


exceptions = {
    "drug",
    "pre",
    "therapy",
    "anesthesia",
    "anesthetic",
    "neuroleptic",
    "saline",
    "stimulus"}


class Induced(TaggingRule):
    """Vote on "<X> - induced" patterns.

    When a '-' token is followed by a token with lemma 'induce', both get
    'O'. The token X before the hyphen gets 'O' if its lemma is in
    `induced_exceptions` or it is punctuation, and 'I-Chemical' otherwise.
    """
    # Snapshot at class-creation time, so that rebinding the module-level
    # `exceptions` elsewhere cannot change this rule. The entry "anesthesia"
    # was previously misspelled "anesthetia" and therefore never matched.
    induced_exceptions = frozenset(exceptions)

    def apply_instance(self, instance):
        tokens = instance['tokens']
        labels = ['ABS'] * len(tokens)

        # Only tokens i-1, i and i+1 are read, so every '-' with a neighbour
        # on each side is checked. The previous bound (len - 3) skipped the
        # last two possible positions.
        for i in range(1, len(tokens) - 1):
            if tokens[i].text == '-' and tokens[i + 1].lemma_ == 'induce':
                labels[i] = 'O'
                labels[i + 1] = 'O'
                prev = tokens[i - 1]
                if prev.lemma_ in self.induced_exceptions or prev.pos_ == "PUNCT":
                    labels[i - 1] = 'O'
                else:
                    labels[i - 1] = 'I-Chemical'
        return labels


lf = Induced()
lf.apply(cdr_docs)


class Vitamin(TaggingRule):
    """Label "vitamin" (any case) as I-Chemical. The next token is also
    labeled when it is at most 2 characters long and upper-case (e.g.
    "vitamin E")."""

    def apply_instance(self, instance):
        tokens = instance['tokens']
        labels = ['ABS'] * len(tokens)

        # Iterate over every token (previously the last token was skipped, so
        # a trailing "vitamin" was never labeled); the look-ahead is guarded.
        for i, token in enumerate(tokens):
            if token.text.lower() != 'vitamin':
                continue
            labels[i] = 'I-Chemical'
            if i + 1 < len(tokens):
                next_text = tokens[i + 1].text
                if len(next_text) <= 2 and next_text.isupper():
                    labels[i + 1] = 'I-Chemical'

        return labels


lf = Vitamin()
lf.apply(cdr_docs)


class Acid(TaggingRule):
    """Label "<word ending in 'ic'> acid" (e.g. "folic acid") as I-Chemical on
    both tokens; the match on "acid" is case-insensitive."""

    def apply_instance(self, instance):
        tokens = instance['tokens']
        labels = ['ABS'] * len(tokens)

        for pos in range(1, len(tokens)):
            if tokens[pos].text.lower() != 'acid':
                continue
            if tokens[pos - 1].text.endswith('ic'):
                labels[pos - 1] = labels[pos] = 'I-Chemical'

        return labels


lf = Acid()
lf.apply(cdr_docs)


class OtherPOS(TaggingRule):
    """Vote 'O' for adpositions, adverbs, determiners and verbs, unless the
    'Organic Chemical' or 'DictCore-Chemical' rule labeled that token."""
    other_pos = {"ADP", "ADV", "DET", "VERB"}

    def apply_instance(self, instance):
        labels = []
        for pos, token in enumerate(instance['tokens']):
            # Some chemicals with long names get tagged as verbs, so only
            # vote 'O' when neither chemical dictionary fired on this token.
            is_other = (token.pos_ in self.other_pos
                        and instance['WISER_LABELS']['Organic Chemical'][pos] == 'ABS'
                        and instance['WISER_LABELS']['DictCore-Chemical'][pos] == 'ABS')
            labels.append("O" if is_other else 'ABS')
        return labels


lf = OtherPOS()
lf.apply(cdr_docs)


stop_words = {
    "a", "an", "as", "be", "but", "do", "even",
    "for", "from",
    "had", "has", "have", "i", "in", "is", "its", "just",
    "may", "my", "no", "not", "on", "or",
    "than", "that", "the", "these", "this", "those", "to", "very",
    "what", "which", "who", "with",
}


class StopWords(TaggingRule):
    """Vote 'O' for every token whose lemma is in the module-level
    `stop_words` set; all other tokens abstain."""

    def apply_instance(self, instance):
        return ['O' if token.lemma_ in stop_words else 'ABS'
                for token in instance['tokens']]


lf = StopWords()
lf.apply(cdr_docs)


class CommonOther(TaggingRule):
    """Vote 'O' for frequent non-entity lemmas (study vocabulary, units,
    species, and so on); all other tokens abstain."""
    other_lemmas = {'patient', '-PRON-', 'induce', 'after', 'study',
                    'rat', 'mg', 'use', 'treatment', 'increase',
                    'day', 'group', 'dose', 'treat', 'case', 'result',
                    'kg', 'control', 'report', 'administration', 'follow',
                    'level', 'suggest', 'develop', 'week', 'compare',
                    'significantly', 'receive', 'mouse',
                    'protein', 'infusion', 'output', 'area', 'effect',
                    'rate', 'weight', 'size', 'time', 'year',
                    'clinical', 'conclusion', 'outcome', 'man', 'woman',
                    'model', 'concentration'}

    def apply_instance(self, instance):
        common = self.other_lemmas
        return ['O' if token.lemma_ in common else 'ABS'
                for token in instance['tokens']]


lf = CommonOther()
lf.apply(cdr_docs)


class Punctuation(TaggingRule):
    """Vote 'O' for tokens that are exactly one of a fixed set of
    punctuation/symbol characters; all other tokens abstain."""

    other_punc = {"?", "!", ";", ":", ".", ",",
                  "%", "<", ">", "=", "\\"}

    def apply_instance(self, instance):
        labels = []
        for token in instance['tokens']:
            labels.append('O' if token.text in self.other_punc else 'ABS')
        return labels


lf = Punctuation()
lf.apply(cdr_docs)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




## Linking Rules

In [5]:

class PossessivePhrase(LinkingRule):
    """Set link 1 at position i (i >= 1) when token i or token i-1 is "'s";
    presumably this links token i to its predecessor. Position 0 is always 0."""

    def apply_instance(self, instance):
        texts = [token.text for token in instance['tokens']]
        links = [0] * len(texts)
        for pos in range(1, len(texts)):
            if "'s" in (texts[pos - 1], texts[pos]):
                links[pos] = 1

        return links


lf = PossessivePhrase()
lf.apply(cdr_docs)


class HyphenatedPrefix(LinkingRule):
    """Set link 1 on a '-' token whose predecessor is a Greek-letter modifier
    (case-insensitive) or shorter than 2 characters, e.g. "alpha -" or "N -"."""
    chem_mods = {"alpha", "beta", "gamma", "delta", "epsilon"}

    def apply_instance(self, instance):
        texts = [token.text for token in instance['tokens']]
        links = [0] * len(texts)
        for pos in range(1, len(texts)):
            if texts[pos] != "-":
                continue
            prev = texts[pos - 1]
            if prev.lower() in self.chem_mods or len(prev) < 2:
                links[pos] = 1

        return links


lf = HyphenatedPrefix()
lf.apply(cdr_docs)


class PostHyphen(LinkingRule):
    """Set link 1 on every token that directly follows a '-' token."""

    def apply_instance(self, instance):
        texts = [token.text for token in instance['tokens']]
        return [1 if pos > 0 and texts[pos - 1] == "-" else 0
                for pos in range(len(texts))]


lf = PostHyphen()
lf.apply(cdr_docs)


dict_full = set()

# Full AutoNER dictionary: one surface form per line. Only multi-token terms
# are kept, because ExtractedPhrase only links tokens after a phrase's first
# token, so single-token terms would produce no links.
full_tokenizer = cdr_reader.get_tokenizer()
with open('../data/AutoNER_dicts/BC5CDR/dict_full.txt', encoding='utf-8') as f:
    for line in f:
        term = tuple(str(x) for x in full_tokenizer(line.strip()))
        if len(term) > 1:
            dict_full.add(term)


class ExtractedPhrase(LinkingRule):
    """Link the tokens of dictionary phrases found in an instance.

    Matching is case-insensitive and greedy, left to right. At each position,
    the longest dictionary term starting with that token wins; every token of
    the match except the first gets link 1, and scanning resumes after it.
    """

    def __init__(self, terms):
        # Index the lowercased terms by their first token.
        index = {}
        for term in terms:
            lowered = [tok.lower() for tok in term]
            index.setdefault(lowered[0], []).append(lowered)

        # Longest candidates first, so the longest match is tried first. The
        # sort is stable, so equal-length terms keep their insertion order.
        for candidates in index.values():
            candidates.sort(key=len, reverse=True)
        self.term_dict = index

    def apply_instance(self, instance):
        words = [token.text.lower() for token in instance['tokens']]
        n_words = len(words)
        links = [0] * n_words

        pos = 0
        while pos < n_words:
            matched_len = 0
            for cand in self.term_dict.get(words[pos], ()):
                stop = pos + len(cand)
                if stop <= n_words and words[pos:stop] == cand:
                    matched_len = len(cand)
                    break

            if matched_len:
                for j in range(pos + 1, pos + matched_len):
                    links[j] = 1
                pos += matched_len
            else:
                pos += 1

        return links


lf = ExtractedPhrase(dict_full)
lf.apply(cdr_docs)

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=1500.0), HTML(value='')))




## Define Functions

In [6]:
def txt_to_token_span(tokens: List[str],
                      text: str,
                      txt_spans: List[tuple]):
    """
    Transfer text-domain (character-offset) spans to token-domain spans.

    :param tokens: tokens of `text`
    :param text: the original text
    :param txt_spans: span tuples ``(start, end, ...)`` of character offsets;
        only the first two elements are used and extra elements are not
        carried over to the result
    :return: a list of ``(start, end)`` token-index tuples, `end` exclusive.
    :raises ValueError: if a span's start or end falls outside every token.
    """
    # get_original_spans gives one (char_start, char_end) per token, or None
    # for a token it could not locate in `text`. Such tokens are skipped.
    token_indices = get_original_spans(tokens, text)
    tgt_spans = list()
    for txt_span in txt_spans:
        char_start = txt_span[0]
        char_end = txt_span[1]
        start = None
        end = None
        for i, tok_span in enumerate(token_indices):
            if tok_span is None:
                continue
            s, e = tok_span
            if s <= char_start < e:
                start = i
            if s <= char_end <= e:
                end = i + 1
            if (start is not None) and (end is not None):
                break
        if start is None or end is None:
            # Explicit raise: an `assert` is removed under `python -O`, and the
            # previous code only used a ValueError as the assertion message.
            raise ValueError("input spans out of scope")
        tgt_spans.append((start, end))
    return tgt_spans

In [7]:
def respan(src_tokens: List[str],
           tgt_tokens: List[str],
           src_span: List[tuple]):
    """
    Transfer spans over `src_tokens` to spans over `tgt_tokens`.

    :param src_tokens: source tokens
    :param tgt_tokens: target tokens
    :param src_span: a list of span tuples. The first element is the start
        index and the second the exclusive end index, both over source tokens;
        extra elements are ignored. An end past the last source token is
        clamped to it.
    :return: a list of ``(start, end)`` tuples over target tokens, `end`
        exclusive. Raises IndexError if the span's first (or last) source
        token has no aligned target token.
    """
    s2t, _ = get_alignments(src_tokens, tgt_tokens)
    tgt_spans = list()
    for spans in src_span:
        start = s2t[spans[0]][0]
        # Last source token covered by the span, clamped to the sequence. The
        # previous code used `s2t[-1][-1]` without `+ 1` when the span reached
        # the final source token, which dropped that token's last target token.
        last_src = min(spans[1], len(s2t)) - 1
        end = s2t[last_src][-1] + 1
        if end <= start:
            # Guarantee a non-empty target span.
            end = start + 1
        tgt_spans.append((start, end))

    return tgt_spans

In [8]:
def label_to_span(labels: List[str],
                  scheme: Optional[str] = 'BIO') -> dict:
    """
    Convert a label sequence to labeled spans.

    'O' and 'ABS' labels are outside any span. In BIO, a span starts at a
    'B-' label and extends over the following 'I-' labels; a stray 'I-'
    without a preceding 'B-' is ignored. In BILOU, 'U-' is a single-token
    span and 'B-' extends up to and including the next 'L-' (or to the end
    of the sequence if there is none). Labels with any other prefix are
    skipped.

    :param labels: a list of labels such as 'B-Chemical'
    :param scheme: labeling scheme, in ['BIO', 'BILOU'].
    :return: dict mapping ``(start_idx, end_idx)`` (end exclusive) to the
        label type (the text after the two-character prefix).
    :raises ValueError: if `scheme` is not supported.
    """
    if scheme not in ('BIO', 'BILOU'):
        raise ValueError("unknown labeling scheme")

    labeled_spans = dict()
    n = len(labels)
    i = 0
    while i < n:
        label = labels[i]
        if label == 'O' or label == 'ABS':
            i += 1
            continue

        prefix, lb = label[0], label[2:]
        if scheme == 'BIO' and prefix == 'B':
            start = i
            i += 1
            while i < n and labels[i][0] == 'I':
                i += 1
            labeled_spans[(start, i)] = lb
        elif scheme == 'BILOU' and prefix == 'U':
            labeled_spans[(i, i + 1)] = lb
            i += 1
        elif scheme == 'BILOU' and prefix == 'B':
            start = i
            i += 1
            while i < n and labels[i][0] != 'L':
                i += 1
            if i < n:
                # The 'L-' token belongs to the span (previously excluded).
                labeled_spans[(start, i + 1)] = lb
                i += 1
            else:
                labeled_spans[(start, n)] = lb
        else:
            # Stray 'I-'/'L-' or an unknown prefix. Always advance; the BIO
            # branch previously looped forever on prefixes other than B/I.
            i += 1

    return labeled_spans

In [9]:
def build_bert_emb(sents: List[List[str]],
                   tokenizer,
                   model,
                   device: str):
    """
    Compute one contextual embedding per original token for every token sequence.

    Each token list is joined with spaces, encoded with special tokens and fed to
    ``model``. The first output of ``model`` is taken as the last hidden layer.
    Word-piece vectors are mapped back to the original tokens with
    ``get_alignments`` and mean-pooled per token.

    :param sents: a list of token lists (one per sentence / document segment).
    :param tokenizer: tokenizer exposing ``tokenize`` and ``encode``.
    :param model: model whose first output has shape (1, seq_len, hidden).
    :param device: device (string or ``torch.device``) for the input ids; it
        should match the device ``model`` lives on.
    :return: a list of CPU tensors. The i-th tensor has
        ``len(sents[i]) + 1`` rows: row 0 is the first-position ([CLS]) vector,
        and rows 1.. are the per-token vectors. A token with no aligned word
        piece gets an all-zero row instead of a NaN mean over an empty slice.
    """
    bert_embs = list()
    for sent in sents:
        joint_sent = ' '.join(sent)
        bert_tokens = tokenizer.tokenize(joint_sent)

        input_ids = torch.tensor([tokenizer.encode(joint_sent, add_special_tokens=True)], device=device)
        # calculate BERT last layer embeddings
        with torch.no_grad():
            last_hidden_states = model(input_ids)[0].squeeze(0).to('cpu')
        # Drop the first and last positions (special tokens) so that row j lines
        # up with bert_tokens[j].
        trunc_hidden_states = last_hidden_states[1:-1, :]
        hidden_size = last_hidden_states.size(-1)

        ori2bert, _ = get_alignments(sent, bert_tokens)

        # TODO: using the embedding of [CLS] may not be the best idea
        # It does not matter since that embedding is not used in the training
        emb_list = [last_hidden_states[0, :]]
        for idx in ori2bert:
            if idx:
                emb_list.append(trunc_hidden_states[idx, :].mean(dim=0))
            else:
                # No word piece aligned to this token (e.g. characters the
                # tokenizer discards). Averaging an empty slice would produce NaN,
                # which would poison every downstream use of this row.
                emb_list.append(torch.zeros(hidden_size, dtype=last_hidden_states.dtype))
        bert_embs.append(torch.stack(emb_list))
    return bert_embs

## Construct the dataset

In [10]:
# Entity type names (presumably the BC5CDR label set; not referenced in this part of the notebook — verify use elsewhere).
LABEL = ['Chemical', 'Disease']
# Suffix for tags built from linking-rule flags: 'B-' + LINK / 'I-' + LINK.
LINK = 'LINK'

In [11]:
train_path = '../data/BC5CDR/CDR_TrainingSet.BioC.xml'
dev_path = '../data/BC5CDR/CDR_DevelopmentSet.BioC.xml'
test_path = '../data/BC5CDR/CDR_TestSet.BioC.xml'


def read_bioc_title_abstracts(path: str) -> List[str]:
    """
    Read a BioC XML collection and return one ``"<title> <abstract>"`` string per document.

    Documents are returned in file order. These strings are later zipped
    positionally with the instances produced by the dataset reader, so the
    order must not change.

    :param path: path to a BioC XML file.
    :return: list of title + " " + abstract strings.
    """
    texts = list()
    for xml_doc in tqdm(ElementTree.parse(path).getroot().findall("./document")):
        title = xml_doc.find("passage[infon='title']").find('text').text
        abstract = xml_doc.find("passage[infon='abstract']").find('text').text
        texts.append(title + " " + abstract)
    return texts


train_xml_sents = read_bioc_title_abstracts(train_path)
dev_xml_sents = read_bioc_title_abstracts(dev_path)
test_xml_sents = read_bioc_title_abstracts(test_path)

xml_sents = train_xml_sents + dev_xml_sents + test_xml_sents

HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))




HBox(children=(HTML(value=''), FloatProgress(value=0.0, max=500.0), HTML(value='')))




In [12]:
src_token_list = list()
src_anno_list = list()
weak_anno_list = list()
link_anno_list = list()

allen_data = cdr_docs
mapping_dict = {0:'O', 1:'I'}

# XML texts and reader instances are paired purely by position; a count
# mismatch would silently truncate or misalign every following document.
assert len(xml_sents) == len(allen_data), (len(xml_sents), len(allen_data))


def _repair_bio(labels: List[str]) -> List[str]:
    """
    Return a copy of ``labels`` in which every ``I-`` tag not preceded by an
    ``I-``/``B-`` tag becomes ``B-``, so ``label_to_span`` does not skip it.
    The "preceding tag" test looks at the original, unrepaired tag.
    """
    repaired = list(labels)
    prev = 'O'
    for i, tag in enumerate(labels):
        if tag[0] == 'I' and prev[0] not in ('I', 'B'):
            repaired[i] = 'B-' + tag[2:]
        prev = tag
    return repaired


def _link_flags_to_bio(flags) -> List[str]:
    """
    Convert 0/1 linking flags to tags: 0 -> 'O'; a run of 1s becomes
    'B-' + LINK followed by 'I-' + LINK.
    """
    tags = [mapping_dict[flag] for flag in flags]
    bio = list(tags)
    prev = 'O'
    for i, tag in enumerate(tags):
        if tag == 'I':
            bio[i] = ('I-' if prev == 'I' else 'B-') + LINK
        prev = tag
    return bio


for src_txt, allen_annos in zip(xml_sents, allen_data):

    # handle the data read from the source text
    # Map the quote tokens `` and '' emitted by word_tokenize back to a plain double quote.
    src_tokens = ['"' if tok in ('``', "''") else tok for tok in word_tokenize(src_txt)]

    allen_tokens = list(map(str, allen_annos['tokens']))

    # Gold spans over the reader's tokens, re-mapped onto src_tokens. respan's
    # output is paired positionally with the labels of its input spans.
    src_spans = label_to_span(_repair_bio(allen_annos['tags']))
    src_annos = dict(zip(respan(allen_tokens, src_tokens, src_spans), src_spans.values()))

    src_token_list.append(src_tokens)
    src_anno_list.append(src_annos)

    # handle the data constructed using Allennlp: one span dict per tagging rule
    weak_anno = dict()
    for k in allen_annos['WISER_LABELS']:
        weak_span = label_to_span(_repair_bio(allen_annos['WISER_LABELS'][k]))
        src_weak_span = respan(allen_tokens, src_tokens, weak_span)
        weak_anno[k] = {span: [(lb, 1.0)] for span, lb in zip(src_weak_span, weak_span.values())}
    weak_anno_list.append(weak_anno)

    # linking rules: 0/1 flags -> spans over src_tokens
    linked_dict = dict()
    for src, entity_lbs in allen_annos['WISER_LINKS'].items():
        entity_spans = label_to_span(_link_flags_to_bio(entity_lbs))
        complete_span = dict()
        for (start, end), lb in entity_spans.items():
            # Extend each span one token to the left (presumably the first token
            # of a linked run carries no flag — TODO confirm against the linking rules).
            if start != 0:
                start = start - 1
            complete_span[(start, end)] = lb
        linked_dict[src] = respan(allen_tokens, src_tokens, complete_span)
    link_anno_list.append(linked_dict)

In [13]:
def _labels_overlapping(link_span, tagged_spans):
    """Distinct labels (first-seen order) of tagged spans that share a token with ``link_span``."""
    covered = set(range(link_span[0], link_span[1]))
    found = list()
    for token_ids, label in tagged_spans:
        if covered & token_ids and label not in found:
            found.append(label)
    return found


updated_link_anno_list = list()
for tag_anno, link_anno in zip(weak_anno_list, link_anno_list):
    # Every span proposed by any tagging rule, as (token-index set, label).
    tagged_spans = [
        (set(range(span[0], span[1])), votes[0][0])
        for rule_spans in tag_anno.values()
        for span, votes in rule_spans.items()
    ]

    link_entities = dict()
    for src, spans in link_anno.items():
        # Keep only linked spans overlapping some tagged span; each kept span gets
        # a uniform distribution over the distinct overlapping labels.
        valid_anno = dict()
        for span in spans:
            labels = _labels_overlapping(span, tagged_spans)
            if labels:
                valid_anno[span] = [(lb, 1 / len(labels)) for lb in labels]
        link_entities[src] = valid_anno
    updated_link_anno_list.append(link_entities)

In [14]:
# Merge tagging-rule and linking-rule annotations per document (shallow merge;
# on a key collision the linking entry wins because it is unpacked last).
combined_anno_list = [
    {**tag_anno, **link_anno}
    for tag_anno, link_anno in zip(weak_anno_list, updated_link_anno_list)
]

## Build BERT embeddings

In [15]:
# Tokenizer and encoder must come from the same checkpoint so word pieces match.
BERT_MODEL_NAME = 'allenai/scibert_scivocab_uncased'
tokenizer = AutoTokenizer.from_pretrained(BERT_MODEL_NAME)
model = AutoModel.from_pretrained(BERT_MODEL_NAME)

In [16]:
import numpy as np

# 512 positions minus the two special tokens added by encode().
MAX_BERT_PIECES = 510

# Split documents that are too long for BERT into three segments at NLTK
# sentence boundaries. o2n_map[d] lists the indices of document d's segments
# in standarized_sents.
standarized_sents = list()
o2n_map = list()
n = 0
for doc_tokens in src_token_list:
    joint_sent = ' '.join(doc_tokens)
    len_bert_tokens = len(tokenizer.tokenize(joint_sent))
    if len_bert_tokens >= MAX_BERT_PIECES:
        sts = sent_tokenize(joint_sent)
        # Both cumulative arrays are indexed per sentence. word_ends gives slice
        # points into doc_tokens; piece_ends gives the word-piece budget used to
        # pick balanced cut points. (Previously the word-level ends were compared
        # against word-piece targets, which mixed units and unbalanced the split.)
        word_ends = np.cumsum([len(word_tokenize(st)) for st in sts])
        piece_ends = np.cumsum([len(tokenizer.tokenize(st)) for st in sts])
        total_pieces = piece_ends[-1]
        cut1 = int(word_ends[np.argmin(np.abs(piece_ends - total_pieces / 3))])
        cut2 = int(word_ends[np.argmin(np.abs(piece_ends - total_pieces * 2 / 3))])
        standarized_sents.append(doc_tokens[:cut1])
        standarized_sents.append(doc_tokens[cut1:cut2])
        standarized_sents.append(doc_tokens[cut2:])
        o2n_map.append([n, n + 1, n + 2])
        n += 3
    else:
        standarized_sents.append(doc_tokens)
        o2n_map.append([n])
        n += 1

In [17]:
# Sanity check: print (index, #tokens, #word pieces) for any segment still at or over 510 word pieces.
for seg_idx, seg_tokens in enumerate(standarized_sents):
    piece_count = len(tokenizer.tokenize(' '.join(seg_tokens)))
    if piece_count >= 510:
        print(seg_idx, len(seg_tokens), piece_count)

In [18]:
# Prefer the GPU when available; build_bert_emb moves its outputs back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)
embs = build_bert_emb(standarized_sents, tokenizer, model, device)

In [32]:
# Stitch segment embeddings back into one tensor per document. Every segment
# starts with its own first-position ([CLS]) row; only the first segment keeps
# it, so each document ends up with len(tokens) + 1 rows. This handles any
# number of segments, not just exactly three.
combined_embs = list()
for o2n in o2n_map:
    if len(o2n) == 1:
        combined_embs.append(embs[o2n[0]])
    else:
        pieces = [embs[o2n[0]]] + [embs[seg][1:] for seg in o2n[1:]]
        combined_embs.append(torch.cat(pieces, dim=0))

In [38]:
# One embedding per document, with one row per source token plus the leading [CLS] row.
# The count check guards against zip() silently hiding missing documents.
assert len(combined_embs) == len(src_token_list), (len(combined_embs), len(src_token_list))
for doc_idx, (emb, sent) in enumerate(zip(combined_embs, src_token_list)):
    assert len(emb) == len(sent) + 1, (doc_idx, len(emb), len(sent))

In [39]:
# Split boundaries come from the per-file document counts read from the XML,
# instead of hard-coded 500/1000. Documents are ordered train + dev + test
# (see xml_sents), so the split stays correct if a corpus file changes size.
n_train = len(train_xml_sents)
n_train_dev = n_train + len(dev_xml_sents)

test_token_list = src_token_list[n_train_dev:]
test_anno_list = combined_anno_list[n_train_dev:]
test_lb_list = src_anno_list[n_train_dev:]
test_emb = combined_embs[n_train_dev:]

train_token_list = src_token_list[:n_train]
train_anno_list = combined_anno_list[:n_train]
train_lb_list = src_anno_list[:n_train]
train_emb = combined_embs[:n_train]

dev_token_list = src_token_list[n_train:n_train_dev]
dev_anno_list = combined_anno_list[n_train:n_train_dev]
dev_lb_list = src_anno_list[n_train:n_train_dev]
dev_emb = combined_embs[n_train:n_train_dev]

In [40]:
def _save_split(split: str, sentences, annotations, labels, embeddings) -> dict:
    """
    Write one split to ``BC5CDR-linked-<split>.pt`` (tokens, weak annotations,
    gold labels) and ``BC5CDR-emb-<split>.pt`` (embeddings).

    :return: the dict stored in the ``linked`` file.
    """
    split_data = {
        "sentences": sentences,
        "annotations": annotations,
        "labels": labels
    }
    torch.save(split_data, f"BC5CDR-linked-{split}.pt")
    torch.save(embeddings, f"BC5CDR-emb-{split}.pt")
    return split_data


train_data = _save_split("train", train_token_list, train_anno_list, train_lb_list, train_emb)
dev_data = _save_split("dev", dev_token_list, dev_anno_list, dev_lb_list, dev_emb)
test_data = _save_split("test", test_token_list, test_anno_list, test_lb_list, test_emb)