# AHLT Term Project - DDI Classifier
## Alex Paranov, Anthony Nixon
### MIRI Masters - Term 2 2018


## Part 1: Defining Xml Classes

The first component of our project is to create data structures in which we can store and manipulate the XML data in an efficient manner.

We create four classes which correspond to the tagged elements in the xml annotation:

#### Document:
Represents and stores a full text sample consisting of sentence objects. The document class also contains a function "set_features()" which passes a call to a set_features() method at the sentence level and assigns featured words to the document featured_words list.

#### Sentence:
A sentence is a discrete segment of text which can be broken down into entities and pairs. The composing entities and pairs are stored in lists in the object of the same name.

An important part of the Sentence class is its "set_features()" function, which iterates over the entities, splits the words in the text and tags them. Then, for each tagged word, the helper function "get_featured_tuple()" returns a list of features based on orthographic features, prefix and suffix, word shapes, etc. (We will cover these features and their rationale in more detail later in the report.)

#### Entity:
Stores a relevant mention of a drug name / substance / etc. in a sentence, as well as its location offset.

#### Pair:
A pair is a drug-drug interaction relating two entities in a sentence.

In [3]:
from nltk import word_tokenize, pos_tag
from nltk.stem import SnowballStemmer
from numpy import isfinite
import re
import string

class Document:
    """A full text sample: an identifier plus the sentence objects it contains.

    Attributes created by ``set_features()``:
        featured_words: flat list of the per-word feature tuples of every
            sentence, in document order.
        featured_words_dict: one dict per feature tuple, mapping the
            stringified position of each feature ("0", "1", ...) to its value.
    """

    def __init__(self, id):
        self.id = id
        self.sentences = []

    def add_sentence(self, sentence):
        """Append ``sentence`` to this document."""
        self.sentences.append(sentence)

    def __str__(self):
        parts = ["DOCUMENT. Id: " + self.id + '\n']
        parts.extend(sentence.__str__() + '\n' for sentence in self.sentences)
        return "".join(parts)

    def set_features(self):
        """Compute features for every sentence and store them on the document.

        Calls ``set_features()`` on each sentence and collects the returned
        feature tuples both as-is (``self.featured_words``) and as
        position-keyed dictionaries (``self.featured_words_dict``), the latter
        being the form needed for preprocessing for the ML algorithm.
        """
        featured_words = []
        featured_words_dict = []
        for sentence in self.sentences:
            sent_features = sentence.set_features()
            for s_feature in sent_features:
                # Build a fresh dict for every word. Sharing one dict across a
                # sentence would make every list entry alias the same object,
                # leaving only the last word's features visible.
                featured_words_dict.append(
                    {str(i): value for i, value in enumerate(s_feature)}
                )
            featured_words.extend(sent_features)

        self.featured_words = featured_words
        self.featured_words_dict = featured_words_dict

class Sentence:
    """A discrete segment of text together with its entities and pairs."""

    def __init__(self, id, text):
        self.id, self.text = id, text
        self.entities = []
        self.pairs = []

    def add_entity(self, entity):
        """Register an entity mentioned in this sentence."""
        self.entities.append(entity)

    def add_pair(self, pair):
        """Register a pair relating entities of this sentence."""
        self.pairs.append(pair)

    def __str__(self):
        chunks = ["\t---SENTENCE. Id: " + self.id + ", Text: " + self.text + '\n']
        for entity in self.entities:
            chunks.append(entity.__str__() + '\n')
        return "".join(chunks)

    def set_features(self):
        """Return one feature tuple per token of the sentence text.

        The first space-separated word of every entity text is collected as a
        'B' word and the remaining ones as 'I' words. The sentence text is
        then tokenized and POS-tagged; every token of at least two characters
        is labelled 'B', 'I' or 'O' by membership in those collections
        ('B' takes precedence) and turned into a feature tuple.
        """
        begin_words = []   # first word of each entity text
        inside_words = []  # every following word of each entity text
        for entity in self.entities:
            pieces = entity.text.split(" ")
            begin_words.append(pieces[0])
            inside_words.extend(pieces[1:])

        tagged_words = pos_tag(word_tokenize(self.text))
        all_features = []

        for position, tagged_word in enumerate(tagged_words):
            token = tagged_word[0]
            # Single-character tokens (punctuation among them) are not saved.
            if len(token) < 2:
                continue
            if token in begin_words:
                label = 'B'
            elif token in inside_words:
                label = 'I'
            else:
                label = 'O'
            all_features.append(self.get_featured_tuple(position, tagged_words, label))

        return all_features

    # Following some guidelines from this table https://www.hindawi.com/journals/cmmm/2015/913489/tab1/
    def get_featured_tuple(self, index, tagged_words, bio_tag):
        """Build the feature tuple for the token at ``index``.

        The tuple holds, in order: the BIO tag, the context-window words and
        POS tags from ``get_words_window``, the word length, the
        orthographical feature, the prefix/suffix length flags and the two
        word shapes.
        """
        word = tagged_words[index][0]

        # Window size 2, reduced for very short sentences so that it always
        # stays below the number of tokens.
        token_count = len(tagged_words)
        radius = 2 if token_count > 2 else (1 if token_count > 1 else 0)

        features = [
            bio_tag,
            *get_words_window(index, tagged_words, radius),
            len(word),
            get_orthographical_feature(word),
            # prefix / suffix flags for lengths 3, 4 and 5
            *get_prefix_suffix_feature(word),
            # general word shape and brief word shape
            *get_word_shapes(word),
        ]

        # May be add Y,N if drug is in drugbank or FDA approved list of drugs?
        return tuple(features)

def get_words_window(index, tagged_words, n):
    """Return the words and POS tags in the +/- n window around ``index``.

    The result is the flat list
    ``[word-n, pos_tag-n, ..., word, pos_tag, ..., word+n, pos_tag+n]``
    and therefore always has ``2 * (2 * n + 1)`` items. Positions that fall
    before the first or after the last token are filled with empty strings.

    Args:
        index: position of the centre token in ``tagged_words``.
        tagged_words: sequence of ``(word, pos_tag)`` pairs.
        n: number of tokens to include on each side of the centre token.

    Raises:
        ValueError: if ``n`` is not smaller than ``len(tagged_words)``.
    """
    if n >= len(tagged_words):
        raise ValueError("n must be less than length of tagged_words")

    windows = []
    last = len(tagged_words) - 1
    # range(-n, n + 1) is the symmetric window; starting at -(n + 1) would
    # include one extra token on the left only.
    for offset in range(-n, n + 1):
        position = index + offset
        if 0 <= position <= last:
            word, tag = tagged_words[position][0], tagged_words[position][1]
        else:
            word = tag = ''
        windows.extend((word, tag))

    return windows

def get_orthographical_feature(word):
    """Describe the capitalisation / digit pattern of ``word`` as a string.

    The base label is "all-capitalized" when every character is an ASCII
    upper-case letter, "is-capitalized" when only the first one is, and
    "alphanumeric" otherwise; it is replaced by "all-digits" when every
    character is numeric. A trailing "Y" or "N" tells whether the word
    contains a hyphen.
    """
    def is_ascii_upper(ch):
        # Only the ASCII range A-Z counts as upper case here.
        return 'A' <= ch <= 'Z'

    if all(is_ascii_upper(ch) for ch in word):
        label = "all-capitalized"
    elif is_ascii_upper(word[0]):
        label = "is-capitalized"
    else:
        label = "alphanumeric"

    # The digit check runs last and overrides the capitalisation label.
    if all(ch.isnumeric() for ch in word):
        label = "all-digits"

    hyphen_flag = "Y" if "-" in word else "N"
    return label + hyphen_flag

def get_prefix_suffix_feature(word):
    """Return flags telling whether the prefix / suffix has length 3, 4 or 5.

    The word is stemmed with the English Snowball stemmer; whatever precedes
    the stem inside the word is taken as the prefix and whatever follows it
    as the suffix.

    Returns:
        ``(pl3, pl4, pl5, sufl3, sufl4, sufl5)`` where each item is 1 when the
        prefix (``pl``) or suffix (``sufl``) has exactly that length, else 0.
        All flags are 0 when the stem cannot be located inside the word.
    """
    snowball_stemmer = SnowballStemmer("english")
    stemmed_word = snowball_stemmer.stem(word)

    # The stemmer returns a lower-cased stem, so the stem has to be searched
    # in the lower-cased word; otherwise any capitalised word misses.
    lowered = word.lower()
    ind = lowered.find(stemmed_word)
    if ind < 0:
        # The stemmer rewrote the word (the stem is not a substring of it).
        # Slicing with -1 would produce a bogus prefix length and a negative
        # suffix length, so report "no prefix / suffix" instead.
        return (0, 0, 0, 0, 0, 0)

    prefix_len = ind
    suffix_len = len(lowered) - prefix_len - len(stemmed_word)

    lengths = (3, 4, 5)
    prefix_flags = tuple(int(prefix_len == k) for k in lengths)
    suffix_flags = tuple(int(suffix_len == k) for k in lengths)
    return prefix_flags + suffix_flags

def get_word_shapes(word):
    """Return the generalized and the brief word shape of ``word``.

    Generalized shape: every upper-case letter, lower-case letter, digit and
    other character is mapped to "X", "x", "0" and "O" respectively, e.g.
    "Aspirin1+" becomes "Xxxxxxx0O".

    Brief shape: the same mapping with each run of consecutive identical
    symbols collapsed into one, e.g. "Aspirin1+" becomes "Xx0O".

    Returns:
        ``(word_shape, word_shape_brief)``; both are "" for an empty word.
    """
    symbols = []
    for ch in word:
        if ch.isupper():
            symbols.append("X")
        elif ch.islower():
            symbols.append("x")
        elif ch.isnumeric():
            symbols.append("0")
        else:
            symbols.append("O")
    word_shape = "".join(symbols)

    # Derive the brief shape from the generalized one so that both use the
    # same character classes and no character is ever skipped between runs.
    brief = []
    for symbol in symbols:
        if not brief or brief[-1] != symbol:
            brief.append(symbol)
    word_shape_brief = "".join(brief)

    return (word_shape, word_shape_brief)

class Entity:
    """A mention (drug name, substance, ...) inside a sentence.

    Keeps the identifier, the character offset, the type and the text that
    were passed to the constructor.
    """

    def __init__(self, id, charOffset, type, text):
        self.id, self.charOffset = id, charOffset
        self.type, self.text = type, text

    def __str__(self):
        fields = (
            "Id: " + self.id,
            "CharOffSet: " + self.charOffset,
            "Type: " + self.type,
            "Text: " + self.text,
        )
        return "\t\t---ENTITY. " + ", ".join(fields)

class Pair:
    """A candidate interaction between two entities of a sentence.

    ``ddi`` tells whether the pair is an interaction; ``type`` starts out as
    an empty string and can be filled in with ``set_type()``.
    """

    def __init__(self, id, e1, e2, ddi):
        self.id = id
        self.e1, self.e2 = e1, e2
        self.ddi = ddi
        self.type = ""

    def set_type(self, type):
        """Store the interaction type of this pair."""
        self.type = type

    def __str__(self):
        pieces = [
            "\t\t---PAIR. Id: " + self.id,
            ", E1: " + self.e1,
            ", E2: " + self.e2,
            ", DDI: " + str(self.ddi),
        ]
        # The type is only shown for pairs flagged as interactions.
        if self.ddi:
            pieces.append(", Type: " + self.type)
        return "".join(pieces)


## Part 2: Parsing

The following code is our parser. The primary execution of the block is initiated by the parse_all_files() method. The parser first checks whether the files have already been parsed and stored locally; if not, it begins parsing.

The parser stores the data in our Document, Sentence, Entity, and Pair objects.

In [4]:
#!/usr/bin/python3
from xml_classes import *
import xml.etree.ElementTree as ET
from os.path import abspath, join, isdir, exists
from os import listdir, makedirs
import sys
import pickle

# Every dataset below is a dictionary holding the dataset's name ('name') and
# the paths of all files found in its directory ('data').
def _dataset(name, directory):
    """Return the {'name', 'data'} dictionary for the files in ``directory``."""
    return {'name': name, 'data': [join(directory, f) for f in listdir(directory)]}


# Training data
train_path = abspath("data/train/DrugBank")
drug_bank_train = _dataset('drug_bank_train', train_path)

train_path = abspath("data/train/MedLine")
medline_train = _dataset('medline_train', train_path)

# Test for DDI extraction task
test_path = abspath("data/test/Test_DDI_Extraction_task/DrugBank")
drug_bank_ddi_test = _dataset('drug_bank_ddi_test', test_path)

test_path = abspath("data/test/Test_DDI_Extraction_task/MedLine")
medline_ddi_test = _dataset('medline_ddi_test', test_path)

# Test for DrugNER task
test_path = abspath("data/test/Test_DrugNER_task/DrugBank")
drug_bank_ner_test = _dataset('drug_bank_ner_test', test_path)

test_path = abspath("data/test/Test_DrugNER_task/MedLine")
medline_ner_test = _dataset('medline_ner_test', test_path)

class Parser:
    """Turns annotated XML files into Document / Sentence / Entity / Pair objects."""

    def set_path(self, xml_path):
        """Remember the XML file that the next ``parse_xml()`` call reads."""
        self.path = xml_path

    def parse_xml(self):
        """Parse the file at ``self.path`` and return it as a Document.

        Only "sentence" children of the root are considered, and sentences
        whose text is shorter than two characters are dropped. Inside a
        sentence, "entity" and "pair" children are converted to Entity and
        Pair objects; a pair's type is only set when the pair is an
        interaction and the attribute is present.
        """
        root = ET.parse(self.path).getroot()
        document = Document(root.attrib['id'])

        for node in root:
            if node.tag != "sentence":
                continue

            sentence = Sentence(node.attrib['id'], node.attrib['text'])
            if len(sentence.text) < 2:
                continue

            for item in node:
                attr = item.attrib
                if item.tag == "entity":
                    sentence.add_entity(
                        Entity(attr['id'], attr['charOffset'], attr['type'], attr['text'])
                    )
                elif item.tag == "pair":
                    # The attribute is the literal string "true" for interactions.
                    is_interaction = attr['ddi'] == "true"
                    pair = Pair(attr['id'], attr['e1'], attr['e2'], is_interaction)
                    if pair.ddi and 'type' in attr:
                        pair.set_type(attr['type'])
                    sentence.add_pair(pair)

            document.add_sentence(sentence)

        return document

    def parse_save_xml_dict(self, xml_dict):
        """Parse every file of a dataset and pickle the resulting documents.

        ``xml_dict['data']`` lists the files to parse and ``xml_dict['name']``
        names the pickle, which is written to ``data/pickle/<name>.pkl`` (the
        directory is created when missing).
        """
        parsed_docs = []
        for xml_file in xml_dict['data']:
            print("Parsing: "+xml_file)
            self.set_path(xml_file)
            parsed_docs.append(self.parse_xml())

        dir_path = abspath("data/pickle")
        if not isdir(dir_path):
            makedirs(dir_path)

        pickle_name = xml_dict['name']+".pkl"
        target = join(dir_path, pickle_name)
        with open(target, "wb") as f:
            pickle.dump(parsed_docs, f)
            print("Saved parsed documents from " + pickle_name + " into pickle!\n")

def parse_all_files():
    """Parse and pickle every dataset that has no pickle file yet.

    For each dataset, parsing is skipped when ``data/pickle/<name>.pkl``
    already exists.
    """
    parser = Parser()
    datasets = (
        drug_bank_train,
        medline_train,
        drug_bank_ddi_test,
        medline_ddi_test,
        drug_bank_ner_test,
        medline_ner_test,
    )
    for dataset in datasets:
        if not exists("data/pickle/"+dataset['name']+".pkl"):
            parser.parse_save_xml_dict(dataset)

def main():
    """Script entry point: run parse_all_files()."""
    parse_all_files()

# Only start parsing when this file is executed directly, not when imported.
if __name__ == "__main__":
    main()
