# AHLT - MIRI
# Drugs Interaction Classifier

In [6]:
# Data processing libraries
import pandas as pd
import numpy as np

# NLP libraries
import nltk
from nltk.tag import StanfordPOSTagger

# Machine Learning Libraries
from IPython.display import display # For displaying DataFrames correctly in Jupyter
from sklearn import svm
import scipy.stats # for RandomizedSearchCV
from sklearn.model_selection import GridSearchCV, RandomizedSearchCV, KFold, train_test_split # Parameter selection
import sklearn_crfsuite
from sklearn_crfsuite import scorers
from sklearn_crfsuite import metrics


# Other libraries
import time # Execution time of some blocks
import statistics

# Import our own defined functions
from xlm_parsers_functions import *
from drug_interaction_functions import *
from drug_functions import *

In [5]:
data_dir = 'data/small_train_DrugBank/'
filename = 'Acamprosate_ddi.xml'
tree = ET.parse(data_dir + filename)
# Create a list of lists with the interactions of the file
train_text_entities_relations = listDDIFromXML(tree.getroot())


## Objectives of this part
In this second part of the project, we will focus on two different things: 
1. Detection of interactions between drugs
2. Classification of each drug-drug interaction according to one of the following types:
    - Advice: 'Interactions may be expected, and Uroxatral should not be used in combination with other alpha-blockers.'
    - Effect: 'In uninfected volunteers, 46% developed rash while receiving Sustiva and Clarithromycin.'
    - Mechanism: 'Grepafloxacin is a competitive inhibitor of the metabolism of theophylline'.
    - Int: The interaction of omeprazole and ketoconazole has been stablished.

## Parsing the XML Files

### DrugBank and MedLine files

In [None]:
# Define the data paths
train_data_dir = 'data/small_train_DrugBank/'
test_data_dir = 'data/small_test_DrugBank/'

# Read the data from the specified directories
DrugBank_df = readTrainingData(train_data_dir)

# Select the initial columns from which we will compute the features for each row
train_df = DrugBank_df[['sentence_text', 'e1_name', 'e2_name', 'list_entities', 'interaction', 'interaction_type']]

## Creation of features
Before training our model, we need to come up with features to help us determine whether there is a relationship between the two drugs or not.

Some ideas for features are the following:
- Does the sentence contain a modal verb (should, must,...) between the two entities?
- Word bigrams: This is a binary feature for all word bigrams that appeared more than once in the corpus, indicating the presence or absence of each such bigram in the sentence
- Number of words between a pair of drugs
- Number of drugs between a pair of drugs
- POS of words between a pair of drugs: This is a binary feature for word POS tags obtained from POS tagging, and indicates the presence or absence of each POS between the two main drugs.
- Path between a pair of drugs: Path between two main drugs in the parse tree is another feature in our system. Because syntactic paths are in general a sparse feature, we reduced the sparsity by collapsing identical adjacent non-terminal labels. E.g., NP-S-VP-VP-NP is converted to NP-S-VP-NP. This technique decreased the number of paths by 24.8%.

In [None]:
# nltk.help.upenn_tagset() List of all POS tags from NLTK

In [None]:
start = time.time()
train_df = createMorphologicFeatures(train_df)
end = time.time()
print('Time for creating morphological features: ', str(end - start))

In [None]:
start = time.time()
train_df = createOrtographicFeatures(train_df)
end = time.time()
print('Time for creating ortographic features: ', str(end - start))

In [None]:
start = time.time()
train_df = createContextFeatures(train_df)
end = time.time()
print('Time for creating context features: ', str(end - start))

In [None]:
pd.options.display.max_columns = 5
display(train_df.head())
#train_df.dtypes

### Categorical variables preprocessing
As we are working with the sklearn.SVM machine learning model, in this case we need our output variable ('interaction') to be a binary variable encoded with 0 and 1's. For this purpose, we will use the pd.replace function.

In [None]:
new_encoding = {'interaction': {'true':1, 'false':0}}
train_df.replace(new_encoding, inplace = True)

## Building the classifier - SVM

### Creation of the training, validation and testing datasets

In [None]:
# Name of the target variable
target_name = 'interaction'
sentence_name = 'sentence_text'
list_entities_name = 'list_entities'
ent_1_name = 'e1_name'
ent_2_name = 'e2_name'
var_not_incl = ['sentence_text', 'e1_name', 'e2_name']

# Create the appropiate data structure to pass it to the SVM.
# X columns should be all but target_name and token_name
# In this first step we will create a whole dataset with 100% of the data, which we will
# split in the following steps into training, validation and testing data
X = train_df.loc[:, [all(x) for x in list(zip( 
                train_df.columns != target_name,
                train_df.columns != list_entities_name))]]
Y = train_df[target_name]

Once we have our dataset with 100% of the data created, we will create the training, validation and testing datasets. For this part of the project we have decided to split the dataset with the following proportions (60, 20, 20).

In [None]:
seed = 16273
X_train, X_test, Y_train, Y_test = train_test_split(X, Y, test_size=0.2, random_state=seed, shuffle = True)
X_train, X_val, Y_train, Y_val = train_test_split(X_train, Y_train, test_size=0.2, random_state=seed, shuffle = True)

In [None]:
print(X_train.shape, X_val.shape, X_test.shape)
display(X.head())
display(Y.head())

## Model selection

## SVM

In [None]:
# Create a SVM object with the corresponding tunned parameters
svc = svm.SVC()

# Look for the best parameters of the SVM model with GridSearchCV
start = time.time()
clf = RandomizedSearchCV(svc,{'C': scipy.stats.expon(scale=100), 'gamma': scipy.stats.expon(scale=.1),
                              'kernel': ['rbf'], 'class_weight':['balanced', None]},n_iter=40,n_jobs=-1)
clf.fit(X_val.drop(var_not_incl, axis = 1), Y_val)
end = time.time()
print('Validating time of the SVM: ', str(end - start),'\n')

print('Best estimator: ', clf.best_estimator_)

In [None]:
# Train the SVM model with the parameters selected before
start = time.time()
model = clf.best_estimator_
model.fit(X_train.drop(var_not_incl, axis=1), Y_train)
end = time.time()
print('Training time of the SVM: ', str(end - start))

In [None]:
pred = model.predict(X_test.drop(var_not_incl, axis = 1))
true = np.array(Y_test)
print(pred, true)

In [None]:
print('Precision: ', round(computePrecision(true=true, pred=pred)*100, 1))
print('Recall: ', round(computeRecall(true=true, pred=pred)*100, 1))
print('F1: ', round(computeF1(true = true, pred = pred)*100, 1))

## CRF

In [None]:
Y_train

In [None]:
%%time

crf = sklearn_crfsuite.CRF(
    algorithm='lbfgs',
    c1=0.1,
    c2=0.1,
    max_iterations=100,
    all_possible_transitions=True
)
crf.fit(X_train, Y_train)

In [None]:
X_test['real_interaction'] = true
X_test['pred_interaction'] = pred

In [None]:
X_test[X_test['pred_interaction'] == 1]

In [None]:
'''
['ent1_contains_numbers',
 'ent1_has_uppercase',
 'ent1_all_uppercase',
 'ent1_initial_capital',
 'ent1_contains_slash',
 'ent1_contains_dash',
 'ent1_n_tokens',
 'ent1_contains_punctuation',
 'ent1_init_digit',
 'ent1_single_digit',
 'ent1_contains_roman',
 'ent1_end_punctuation',
 'ent1_caps_mix',
 'ent2_contains_numbers',
 'ent2_has_uppercase',
 'ent2_all_uppercase',
 'ent2_initial_capital',
 'ent2_contains_slash',
 'ent2_contains_dash',
 'ent2_n_tokens',
 'ent2_contains_punctuation',
 'ent2_init_digit',
 'ent2_single_digit',
 'ent2_contains_roman',
 'ent2_end_punctuation',
 'ent2_caps_mix',
 'n_modal_verbs_bw_entities',
 'n_tokens_bw_entities',
 'n_entities_bw_entities']
 
Precision:  62.8
Recall:  19.3
F1:  29.6
'''

In [None]:
'''
['ent1_contains_prefix_alk',
 'ent1_contains_prefix_meth',
 'ent1_contains_prefix_eth',
 'ent1_contains_prefix_prop',
 'ent1_contains_prefix_but',
 'ent1_contains_prefix_pent',
 'ent1_contains_prefix_hex',
 'ent1_contains_prefix_hept',
 'ent1_contains_prefix_oct',
 'ent1_contains_prefix_non',
 'ent1_contains_prefix_dec',
 'ent1_contains_prefix_undec',
 'ent1_contains_prefix_dodec',
 'ent1_contains_prefix_eifcos',
 'ent1_contains_prefix_di',
 'ent1_contains_prefix_tri',
 'ent1_contains_prefix_tetra',
 'ent1_contains_prefix_penta',
 'ent1_contains_prefix_hexa',
 'ent1_contains_prefix_hepta',
 'ent1_contains_suffix_ane',
 'ent1_contains_suffix_ene',
 'ent1_contains_suffix_yne',
 'ent1_contains_suffix_yl',
 'ent1_contains_suffix_ol',
 'ent1_contains_suffix_al',
 'ent1_contains_suffix_oic',
 'ent1_contains_suffix_one',
 'ent1_contains_suffix_ate',
 'ent1_contains_suffix_amine',
 'ent1_contains_suffix_amide',
 'ent2_contains_prefix_alk',
 'ent2_contains_prefix_meth',
 'ent2_contains_prefix_eth',
 'ent2_contains_prefix_prop',
 'ent2_contains_prefix_but',
 'ent2_contains_prefix_pent',
 'ent2_contains_prefix_hex',
 'ent2_contains_prefix_hept',
 'ent2_contains_prefix_oct',
 'ent2_contains_prefix_non',
 'ent2_contains_prefix_dec',
 'ent2_contains_prefix_undec',
 'ent2_contains_prefix_dodec',
 'ent2_contains_prefix_eifcos',
 'ent2_contains_prefix_di',
 'ent2_contains_prefix_tri',
 'ent2_contains_prefix_tetra',
 'ent2_contains_prefix_penta',
 'ent2_contains_prefix_hexa',
 'ent2_contains_prefix_hepta',
 'ent2_contains_suffix_ane',
 'ent2_contains_suffix_ene',
 'ent2_contains_suffix_yne',
 'ent2_contains_suffix_yl',
 'ent2_contains_suffix_ol',
 'ent2_contains_suffix_al',
 'ent2_contains_suffix_oic',
 'ent2_contains_suffix_one',
 'ent2_contains_suffix_ate',
 'ent2_contains_suffix_amine',
 'ent2_contains_suffix_amide',
 'ent1_contains_numbers',
 'ent1_has_uppercase',
 'ent1_all_uppercase',
 'ent1_initial_capital',
 'ent1_contains_slash',
 'ent1_contains_dash',
 'ent1_n_tokens',
 'ent1_contains_punctuation',
 'ent1_init_digit',
 'ent1_single_digit',
 'ent1_contains_roman',
 'ent1_end_punctuation',
 'ent1_caps_mix',
 'ent2_contains_numbers',
 'ent2_has_uppercase',
 'ent2_all_uppercase',
 'ent2_initial_capital',
 'ent2_contains_slash',
 'ent2_contains_dash',
 'ent2_n_tokens',
 'ent2_contains_punctuation',
 'ent2_init_digit',
 'ent2_single_digit',
 'ent2_contains_roman',
 'ent2_end_punctuation',
 'ent2_caps_mix',
 'n_modal_verbs_bw_entities',
 'n_tokens_bw_entities',
 'n_entities_bw_entities']

Precision:  62.2
Recall:  21.4
F1:  31.8
'''

In [None]:
list(X_train.drop(var_not_incl, axis=1))