In [1]:
# Standard library
import os
import statistics
import time # Execution time of some blocks
import xml.etree.ElementTree as ET

# Data processing libraries
import pandas as pd
import numpy as np

# NLP libraries
import nltk
from nltk.tag import StanfordPOSTagger

# Machine Learning Libraries
from IPython.display import display # For displaying DataFrames correctly in Jupyter
from sklearn import svm
import scipy.stats # for RandomizedSearchCV
import sklearn.metrics # recall_score in the evaluation cell
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV, KFold, train_test_split # Parameter selection
import sklearn_crfsuite
from sklearn_crfsuite import scorers
from sklearn_crfsuite import metrics

# Import our own defined functions
# NOTE: these star imports run last, so any names they export (possibly
# including `os` / `ET`) take precedence over the explicit imports above.
from xlm_parsers_functions import *
from drug_interaction_functions import *
from drug_functions import *
from NER_functions import *

In [2]:
data_dir = 'data/Train/DrugBank/'

def readXMLData(data_dir):
    """Parse every ``.xml`` file in a directory and concatenate their records.

    Parameters
    ----------
    data_dir : str or path-like
        Directory holding the annotated XML files. A trailing path separator
        is optional.

    Returns
    -------
    list
        The concatenation of ``listDDIFromXML(root)`` over every XML file,
        processed in sorted filename order. Non-``.xml`` files are skipped.
    """
    drugs_dataset = []
    # os.listdir returns entries in arbitrary, filesystem-dependent order.
    # Sorting makes the dataset order (and hence the seeded train/test split
    # further down) reproducible across machines.
    for filename in sorted(os.listdir(data_dir)):
        if not filename.endswith(".xml"):
            continue
        # NOTE(review): `ET` is expected to be an ElementTree-compatible module
        # (stdlib or whatever the project star imports export) - confirm.
        tree = ET.parse(os.path.join(data_dir, filename))
        # extend() appends in place; repeated `list + list` copied the whole
        # accumulated dataset for every file (quadratic).
        drugs_dataset.extend(listDDIFromXML(tree.getroot()))

    return drugs_dataset

# Parse the training corpus once. Each record is later consumed as
# (sentence, ent1, ent2, entity list, label) by the feature/label helpers.
XMLdata = readXMLData(data_dir=data_dir)

In [3]:
# DrugBank drug names, one per line; passed as `database` to the isInDB features.
with open('data/DrugBank_names_DB.txt', mode='r') as f:
    drugbank_db = f.read().splitlines()
    
def sent2features(tupl, i, database):
    """Build the CRF feature dict for one candidate entity pair.

    Parameters
    ----------
    tupl : tuple
        A 5-element record. Only elements 1 and 2 (the two entities,
        presumably strings since they are sliced for suffixes) are read here.
    i : int
        Unused; kept so existing callers keep working.
    database : list of str
        Drug-name lookup forwarded to ``isTokenInDB``.

    Returns
    -------
    dict
        ``'ent1'`` / ``'ent2'`` raw values. These are followed group by group
        (ent1 before ent2 within each group) by:
        - orthographic flags;
        - the last 5/4/3/2 characters;
        - drug affix flags;
        - DrugBank membership.
        Key insertion order is identical to the former hand-written dict.

    Raises
    ------
    ValueError
        If ``tupl`` does not have exactly 5 elements.
    """
    if len(tupl) != 5:
        raise ValueError('The introduced tuple does not have the correct length')
    entities = (('ent1', tupl[1]), ('ent2', tupl[2]))

    # Orthographic tests, applied to each entity in this exact order.
    orthographic_tests = (
        ('all_uppercase_letters', allCaps),
        ('initial_capital_letter', initCap),
        ('contains_capital_letter', hasCap),
        ('single_capital_letter', singleCap),
        ('punctuation', punctuation),
        ('initial_digit', initDigit),
        ('single_digit', singleDigit),
        ('letter_and_num', alphaNum),
        ('many_numbers', manyNum),
        ('contains_real_numbers', realNum),
        ('intermediate_dash', inDash),
        ('has_digit', hasDigit),
        ('is_Dash', isDash),
        ('is_roman_letter', roman),
        ('is_end_punctuation', endPunctuation),
        ('caps_mix', capsMix),
    )
    # Domain knowledge: drug-name affix checks (per the feature names).
    affix_tests = (
        ('contains_drug_sufix', containsSufix),
        ('contains_drug_prefix', containsPrefix),
    )

    features = {name: ent for name, ent in entities}

    for name, ent in entities:
        for feature, test in orthographic_tests:
            features[name + '_' + feature] = test(ent)

    # Morphological information: suffixes of length 5 down to 2 only
    # (no prefixes or word shapes are extracted).
    for name, ent in entities:
        for length in (5, 4, 3, 2):
            features[name + '_word[-' + str(length) + ':]'] = ent[-length:]

    for name, ent in entities:
        for feature, test in affix_tests:
            features[name + '_' + feature] = test(ent)

    # Membership in the DrugBank name list.
    for name, ent in entities:
        features[name + '_isInDB'] = isTokenInDB(ent, database)

    return features


def text2features(text, database):
    """Return the feature dict for a single candidate-pair record.

    Parameters
    ----------
    text : tuple
        A 5-element record in the format expected by ``sent2features``.
    database : list of str
        Drug-name lookup forwarded to ``sent2features``.

    Returns
    -------
    dict
        The feature dict produced by ``sent2features``.

    Raises
    ------
    ValueError
        Propagated from ``sent2features`` when ``text`` does not have exactly
        5 elements (an empty record no longer yields ``None`` silently).
    """
    # A record describes one pair, so there is nothing to iterate over. The
    # caller-supplied ``database`` is used rather than the global drugbank_db.
    return sent2features(text, 0, database)

def text2labels(text):
    """Return the interaction label stored at index 4 of a record."""
    label = text[4]
    return label

In [11]:
%%time
X = [[text2features(s, drugbank_db)] for s in XMLdata]
y = [[text2labels(s)] for s in XMLdata]

CPU times: user 7.25 s, sys: 110 ms, total: 7.36 s
Wall time: 7.41 s


In [5]:
# Hold out 20% of the samples for testing, then 20% of the remainder for
# validation: 64% train / 16% validation / 20% test overall. The fixed seed
# only gives a reproducible split if XMLdata itself has a stable order.
seed = 16273
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=seed, shuffle=True)
X_train, X_val, y_train, y_val = train_test_split(X_train, y_train, test_size=0.2, random_state=seed, shuffle=True)
# Every sample must land in exactly one of the three splits.
assert len(X_train) + len(X_val) + len(X_test) == len(X)
print('Number of training sentences: ', len(X_train))
print('Number of validation sentences: ', len(X_val))
print('Number of testing sentences: ', len(X_test))

Number of training sentences:  16643
Number of testing sentences:  5201


In [12]:
%%time
crf = sklearn_crfsuite.CRF(
    algorithm='lbfgs',
    c1=0.1,
    c2=0.1,
    max_iterations=100,
    all_possible_transitions=True
)
crf.fit(X_train, y_train)

CPU times: user 3.57 s, sys: 53.5 ms, total: 3.63 s
Wall time: 3.66 s


In [7]:
# Score only the interaction classes: the 'none' class is left out of the
# label list passed to the metric in the evaluation cell.
labels = [label for label in crf.classes_ if label != 'none']
y_pred = crf.predict(X_test)

In [9]:
def _fill_missing_labels(label_seqs, fill='none'):
    """Return a copy of ``label_seqs`` with every ``None`` label replaced by ``fill``.

    New inner lists are built instead of assigning into the existing ones.
    The inner lists of ``y_test`` may be the very objects stored in ``y``
    (the split is done on a list of lists), so mutating them would silently
    edit data produced by earlier cells.
    """
    return [[fill if label is None else label for label in seq] for seq in label_seqs]

# Indices of test items whose gold or predicted label is missing.
missing_label_idx = [i for i, (gold, pred) in enumerate(zip(y_test, y_pred))
                     if None in gold or None in pred]
print('Replaced None labels at indices:', missing_label_idx)

y_test = _fill_missing_labels(y_test)
y_pred = _fill_missing_labels(y_pred)

# Flatten the length-1 label sequences to 1-D as sklearn expects. pos_label is
# omitted: sklearn ignores it unless average='binary'.
sklearn.metrics.recall_score(
    y_true=[label for seq in y_test for label in seq],
    y_pred=[label for seq in y_pred for label in seq],
    labels=labels,
    average='weighted',
)

517


0.3019108280254777

In [None]:
# NOTE(review): this cell only evaluates a string literal holding a disabled
# draft of `transformStrCategoriesIntoInts`; no function is defined. If it is
# ever re-enabled, two things need fixing:
#   - labels outside the listed ones are printed and skipped (nothing is
#     appended), so the result can be shorter than `vector` and misaligned;
#   - the `None` object does not match `el == 'None'` (only the string does).
'''
def transformStrCategoriesIntoInts(vector):
    res = []
    for el in vector:
        if el == 'none' or el == 'None':
            res.append(0)
        elif el == 'mechanism':
            res.append(1)
        elif el == 'effect':
            res.append(2)
        elif el == 'int':
            res.append(3)
        elif el == 'advise':
            res.append(4)
        else:
            print(el)
            print(type(el))
            print(vector.index(el))
    return(res)
'''