# Project: DrugGuardian Pro

## Setup

In [1]:
import os

# Silence TensorFlow's native (C++) logging; '3' is the most restrictive level.
# This only takes effect if it is set before TensorFlow is first imported.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='3')

## Tensorflow with CUDA
Uncomment the line below to force TensorFlow to run training and testing on the CPU.

In [2]:
#os.environ['CUDA_VISIBLE_DEVICES'] = '-1'

## Libraries

In [3]:
import joblib

import numpy as np
from sklearn.metrics import accuracy_score, confusion_matrix
from sklearn.metrics import roc_auc_score, roc_curve, auc, precision_recall_curve, f1_score, recall_score, precision_score
from sklearn.model_selection import train_test_split

from models import neural_binary_model, k_nearest_neighbors, linear_discriminant_analysis
from storage import save_data, load_data, save_label_encoder, load_label_encoder, load_model_data
from tools import prepare_dataset1, preprocess_labels

## Parameters

In [4]:
# Directory holding the cached train/test split, the label encoder and saved models
storage_path = "./storage/DS1/"

# Train/test split caching: write the split to disk / read a previously written split back
save_train_test_to_file, load_train_test_from_files = True, True

# Model caching: persist fitted models / restore previously fitted models from disk
save_model_to_file, load_model_from_file = False, True

# Number of samples passed to prepare_dataset1 when the split is rebuilt
# (i.e. NOT loaded from disk); None keeps the whole dataset
n_sample = None

## Data Loading
Attempt to load the data from files on disk (uses the specified storage path).

In [5]:
# Prefer the on-disk cache; if any cached file is missing, the next cell rebuilds the split
if load_train_test_from_files:
    try:
        (X_train, X_test,
         y_train, y_test) = load_data(storage_path)
        encoder = load_label_encoder(storage_path)
    except FileNotFoundError:
        # Cache absent or incomplete: clear the flag so the preparation cell runs
        load_train_test_from_files = False

## Data preparation
This runs only if loading from files is disabled or unsuccessful.

In [6]:
# Build the train/test split from scratch when it was not (or could not be) loaded from disk
if not load_train_test_from_files:
    # Load the features and labels
    X, y = prepare_dataset1(sample_size=n_sample, keep_all_features=False, separate=True)

    # Encode the labels; `encoder` is kept so predictions can be mapped back to original labels
    y, encoder = preprocess_labels(y)

    # Hold out 20% of the rows for testing; the fixed seed keeps the split reproducible
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state=42)

    # Release memory by deleting the original data
    del X, y

    # Flatten every sample to a single feature row
    X_train = X_train.reshape(X_train.shape[0], -1)
    X_test = X_test.reshape(X_test.shape[0], -1)

    # Caching the split is controlled by the train/test flag (previously this
    # wrongly checked `save_model_to_file`, so the split was never saved)
    if save_train_test_to_file:
        # Save the data and labels to files
        save_data(X_train, X_test, y_train, y_test, storage_path)
        save_label_encoder(encoder, storage_path)

        # Load the data back from the files; per the original note this is for
        # mmap_mode (presumably load_data memory-maps the arrays — confirm in `storage`)
        X_train, X_test, y_train, y_test = load_data(storage_path)

## Model Fitting
Fit different machine learning models to the data

### K-Nearest Neighbors


In [7]:
# k-nearest-neighbours classifier from the project's `models` module,
# fitted directly on the encoded training labels
knn = k_nearest_neighbors()
knn.fit(X_train, y_train)

### Linear Discriminant Analysis

In [8]:
# Create (or restore) the Linear Discriminant Analysis classifier.
# NOTE: joblib.load unpickles the file and can execute arbitrary code —
# only load model files that come from a trusted source.
lda_model_path = os.path.join(storage_path, "lda.sav")

lda = None
if load_model_from_file:
    try:
        lda = joblib.load(lda_model_path)
    except FileNotFoundError:
        # No saved model yet: fall through and train a fresh one
        lda = None

if lda is None:
    lda = linear_discriminant_analysis()
    # LDA is fitted on decoded labels: threshold the encoded targets, take the
    # column index of each row, then map it back through the label encoder
    lda.fit(X_train, encoder.inverse_transform((y_train > 0.5).astype(int).argmax(axis=1)))
    if save_model_to_file:
        os.makedirs(storage_path, exist_ok=True)
        joblib.dump(lda, lda_model_path)

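Note: `joblib.load` unpickles the model file and can execute arbitrary code, so only load models from trusted sources. See the scikit-learn notes on model persistence: https://scikit-learn.org/stable/model_persistence.html#security-maintainability-limitations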


### Neural Binary Model

In [9]:
# Network input width = number of flattened features per sample
nbm = neural_binary_model(X_train.shape[1])
nbm.fit(
    X_train,
    y_train,
    epochs=20,
    batch_size=64,
    validation_split=0,
    shuffle=True,
)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<keras.src.callbacks.History at 0x7f7b881e9ea0>

## Evaluation
Evaluate the performance of the fitted models

### K-Nearest Neighbors

In [10]:
# Score KNN on the held-out split
y_pred = knn.predict(X_test)

# The encoder decodes class indices, so collapse each encoded row with argmax first
y_test_labels, y_pred_labels = (
    encoder.inverse_transform(encoded.argmax(axis=1)) for encoded in (y_test, y_pred)
)

accuracy = accuracy_score(y_test_labels, y_pred_labels)
print(f"Accuracy: {accuracy}")

confusion_matrix(y_test_labels, y_pred_labels)

Accuracy: 0.7647724813106674


array([[39650,   949],
       [13179,  6283]])

### Linear Discriminant Analysis

In [11]:
# LDA was fitted on decoded labels, so its predictions are already original labels
y_pred = lda.predict(X_test)

# Only the ground truth needs decoding: encoded row -> column index -> original label
y_test_labels = encoder.inverse_transform(np.argmax(y_test, axis=1))

accuracy = accuracy_score(y_test_labels, y_pred)
print(f"Accuracy: {accuracy}")

confusion_matrix(y_test_labels, y_pred)

Accuracy: 0.8766920297697341


array([[37459,  3140],
       [ 4266, 15196]])

### Neural Binary Model

In [12]:
# Network outputs one score per column; threshold at 0.5 to get a 0/1 matrix.
# Rows where no score exceeds 0.5 become all zeros and argmax maps them to index 0.
y_pred = nbm.predict(X_test)
y_pred_binary = (y_pred > 0.5).astype(int)

# Decode both matrices to original labels via their column index
y_test_labels, y_pred_labels = (
    encoder.inverse_transform(encoded.argmax(axis=1)) for encoded in (y_test, y_pred_binary)
)

accuracy = accuracy_score(y_test_labels, y_pred_labels)
print(f"Accuracy: {accuracy}")

confusion_matrix(y_test_labels, y_pred_labels)

Accuracy: 0.8976041024957959


array([[38614,  1985],
       [ 4165, 15297]])

## Prediction example
Predictions are made for two drug pairs: Meperidine & Dexmedetomidine, and Amiodarone & Digoxin.

In [33]:
# DB00454 (Meperidine) Feature vector
drug_1_sim = [   0.00004,   0.00012,   0.00011,   0.00012,   0.84896,   0.00008,   0.00005,   0.00009,   0.00009,   0.00011,   0.00010,   0.00011,   0.00009,   0.00007,   0.00008,   0.00008,   0.00012,   0.00007,   0.00010,   0.00009,   0.00010,   0.00006,   0.00011,   0.00007,   0.00006,   0.00007,   0.00006,   0.00012,   0.00009,   0.00006,   0.00006,   0.00280,   0.00006,   0.00008,   0.00007,   0.00008,   0.00010,   0.00005,   0.00010,   0.00010,   0.00008,   0.00009,   0.00007,   0.00004,   0.00010,   0.00007,   0.00007,   0.00007,   0.00006,   0.00010,   0.00007,   0.00012,   0.00006,   0.00005,   0.00008,   0.00012,   0.00013,   0.00005,   0.00008,   0.00011,   0.00012,   0.00010,   0.00006,   0.00009,   0.00008,   0.00011,   0.00005,   0.00009,   0.00009,   0.00009,   0.00006,   0.00005,   0.00007,   0.00007,   0.00005,   0.00011,   0.00011,   0.00012,   0.00008,   0.00008,   0.00012,   0.00008,   0.00009,   0.00007,   0.00010,   0.00010,   0.00006,   0.00009,   0.00010,   0.00005,   0.00008,   0.00012,   0.00006,   0.00005,   0.00011,   0.00039,   0.00006,   0.00008,   0.00009,   0.00007,   0.00012,   0.00007,   0.00012,   0.00010,   0.00012,   0.00009,   0.00008,   0.00006,   0.00006,   0.00010,   0.00011,   0.00007,   0.00010,   0.00013,   0.00006,   0.00007,   0.00007,   0.00007,   0.00011,   0.00006,   0.00008,   0.00005,   0.00006,   0.00009,   0.00005,   0.00012,   0.00011,   0.04033,   0.00011,   0.00011,   0.00010,   0.00004,   0.00006,   0.00011,   0.00006,   0.00007,   0.00010,   0.00008,   0.00009,   0.00007,   0.00005,   0.00734,   0.00099,   0.00007,   0.00008,   0.00007,   0.00007,   0.00010,   0.00006,   0.00008,   0.00007,   0.00006,   0.00010,   0.00005,   0.00010,   0.00010,   0.00011,   0.00008,   0.00004,   0.00039,   0.00006,   0.00009,   0.00007,   0.00012,   0.00009,   0.00005,   0.00012,   0.00012,   0.00012,   0.00007,   0.00010,   0.00008,   0.00008,   0.00011,   0.00005,   0.00008,   0.00011,   0.00010,   0.00099,   0.00009,   
0.00009,   0.00011,   0.00011,   0.00009,   0.00009,   0.00009,   0.00012,   0.00012,   0.00009,   0.00009,   0.00007,   0.01886,   0.00006,   0.00006,   0.00007,   0.04033,   0.00011,   0.00010,   0.00008,   0.00008,   0.00010,   0.00007,   0.00009,   0.00012,   0.00007,   0.00012,   0.00006,   0.00010,   0.00005,   0.00005,   0.00008,   0.00008,   0.00013,   0.00011,   0.00014,   0.00009,   0.00010,   0.00008,   0.00010,   0.00007,   0.00009,   0.00009,   0.00008,   0.00008,   0.00008,   0.00011,   0.00007,   0.00010,   0.00008,   0.00007,   0.00010,   0.00007,   0.00009,   0.00010,   0.00008,   0.00010,   0.00009,   0.00005,   0.00009,   0.00008,   0.00010,   0.00010,   0.00009,   0.00010,   0.00011,   0.00006,   0.00006,   0.00007,   0.00008,   0.00011,   0.00007,   0.00006,   0.00006,   0.00008,   0.00012,   0.00007,   0.00008,   0.00008,   0.00005,   0.00005,   0.00007,   0.00012,   0.00008,   0.00006,   0.00005,   0.00005,   0.00008,   0.00007,   0.00008,   0.00012,   0.00007,   0.00011,   0.00006,   0.00014,   0.00008,   0.00008,   0.00005,   0.00012,   0.00006,   0.00007,   0.00011,   0.00007,   0.00009,   0.00007,   0.00010,   0.00007,   0.00011,   0.00012,   0.00008,   0.01886,   0.00010,   0.00010,   0.00007,   0.00006,   0.00008,   0.00012,   0.00007,   0.00010,   0.00006,   0.00009,   0.00006,   0.00010,   0.00012,   0.00007,   0.00007,   0.00006,   0.00013,   0.00009,   0.00008,   0.00007,   0.00007,   0.00005,   0.00005,   0.00006,   0.00007,   0.00010,   0.00008,   0.00009,   0.00009,   0.00006,   0.00008,   0.00005,   0.00008,   0.00007,   0.00007,   0.00011,   0.00007,   0.00010,   0.00013,   0.00012,   0.00012,   0.00008,   0.00008,   0.00011,   0.00011,   0.00010,   0.00010,   0.00005,   0.00013,   0.00005,   0.00008,   0.14047,   0.00010,   0.00010,   0.00013,   0.00008,   0.00008,   0.00008,   0.00009,   0.00011,   0.00009,   0.00007,   0.00010,   0.00011,   0.00006,   0.00008,   0.00008,   0.00011,   0.00009,   0.00007,   0.00004,   0.00009, 
  0.00020,   0.00007,   0.00006,   0.00010,   0.00007,   0.00006,   0.00006,   0.00008,   0.00005,   0.00010,   0.00010,   0.00008,   0.00006,   0.00011,   0.00006,   0.00012,   0.00009,   0.00009,   0.00008,   0.00010,   0.00009,   0.00007,   0.00011,   0.00009,   0.00006,   0.00011,   0.00012,   0.00009,   0.00008,   0.00006,   0.00007,   0.00012,   0.00005,   0.00009,   0.00008,   0.00011,   0.00007,   0.00008,   0.00007,   0.00009,   0.00008,   0.00006,   0.00009,   0.00010,   0.00012,   0.00010,   0.00010,   0.00008,   0.00007,   0.00008,   0.00020,   0.00008,   0.00009,   0.00009,   0.00011,   0.00009,   0.00012,   0.00007,   0.00012,   0.00005,   0.00006,   0.00009,   0.00008,   0.00006,   0.00005,   0.00013,   0.00009,   0.00011,   0.00006,   0.00008,   0.00008,   0.00011,   0.00005,   0.00009,   0.00007,   0.00012,   0.00011,   0.00010,   0.00007,   0.00007,   0.00008,   0.00009,   0.00007,   0.00005,   0.00008,   0.00004,   0.00010,   0.00005,   0.00013,   0.00012,   0.00008,   0.00007,   0.00012,   0.00010,   0.00006,   0.00009,   0.00010,   0.00007,   0.00006,   0.00011,   0.00012,   0.00007,   0.00010,   0.09169,   0.00005,   0.00008,   0.00008,   0.00007,   0.00008,   0.00009,   0.00010,   0.00005,   0.00009,   0.00012,   0.00011,   0.00009,   0.00011,   0.00012,   0.00009,   0.00008,   0.09169,   0.00005,   0.00006,   0.00010,   0.00012,   0.00006,   0.00011,   0.00009,   0.00004,   0.00005,   0.00006,   0.00012,   0.00012,   0.00009,   0.00006,   0.00006,   0.00006,   0.00007,   0.00006,   0.00006,   0.00011,   0.00007,   0.00011,   0.00009,   0.00007,   0.00010,   0.00009,   0.00011,   0.00005,   0.00011,   0.00006,   0.00008,   0.00009,   0.00006,   0.00009,   0.00007,   0.00007,   0.00006,   0.00007,   0.00008,   0.00009,   0.00006,   0.14047,   0.00011,   0.00007,   0.00735,   0.00008,   0.00011,   0.00007,   0.00011,   0.00005,   0.00007,   0.00008,   0.00007,   0.00009,   0.00007,   0.00009,   0.00009,   0.00006,   0.00007,   0.00009,   
0.00008,   0.00005,   0.00280,   0.00009,   0.00012
]

# DB00633 (Dexmedetomidine) Feature vector
drug_2_sim = [   0.00002,   0.00006,   0.00006,   0.00006,   0.00007,   0.00016,   0.00003,   0.00011,   0.00005,   0.00006,   0.00005,   0.00008,   0.00011,   0.00044,   0.00004,   0.00004,   0.00008,   0.00004,   0.00010,   0.00018,   0.00010,   0.00003,   0.00009,   0.04014,   0.00003,   0.00004,   0.00003,   0.00008,   0.00012,   0.00003,   0.00003,   0.00007,   0.00003,   0.00016,   0.85059,   0.00013,   0.00010,   0.00029,   0.00010,   0.00006,   0.00004,   0.00011,   0.00004,   0.00002,   0.00005,   0.00016,   0.00003,   0.00004,   0.00003,   0.00005,   0.14089,   0.00008,   0.00003,   0.00002,   0.00005,   0.00008,   0.00007,   0.00003,   0.00004,   0.00009,   0.00007,   0.00010,   0.00003,   0.00005,   0.00013,   0.00006,   0.00003,   0.00011,   0.00012,   0.00005,   0.00003,   0.00003,   0.00017,   0.04062,   0.00003,   0.00006,   0.00009,   0.00006,   0.00014,   0.00014,   0.00008,   0.00016,   0.00012,   0.00004,   0.00005,   0.00009,   0.00003,   0.00005,   0.00010,   0.00003,   0.00014,   0.00008,   0.00003,   0.00002,   0.00006,   0.00007,   0.00003,   0.00013,   0.00012,   0.01859,   0.00006,   0.00017,   0.00006,   0.00006,   0.00008,   0.00012,   0.00015,   0.00003,   0.00003,   0.00005,   0.00006,   0.00004,   0.00009,   0.00007,   0.00003,   0.00004,   0.00017,   0.00016,   0.00009,   0.00003,   0.00013,   0.00003,   0.00003,   0.00017,   0.00003,   0.00008,   0.00009,   0.00007,   0.00009,   0.00006,   0.00005,   0.00002,   0.00003,   0.00009,   0.00003,   0.00025,   0.00009,   0.00004,   0.00005,   0.00004,   0.00003,   0.00007,   0.00007,   0.00004,   0.00013,   0.00004,   0.00004,   0.00006,   0.00003,   0.00013,   0.00017,   0.00003,   0.00005,   0.00003,   0.00005,   0.00010,   0.00006,   0.00004,   0.00002,   0.00007,   0.00003,   0.00012,   0.00004,   0.00008,   0.00005,   0.00003,   0.00006,   0.00008,   0.00007,   0.00003,   0.00010,   0.00004,   0.00013,   0.00009,   0.00003,   0.00004,   0.00006,   0.00010,   0.00007,   0.00005,   
0.00011,   0.00006,   0.00008,   0.00011,   0.00012,   0.00012,   0.00008,   0.00008,   0.00005,   0.00005,   0.00004,   0.00007,   0.00003,   0.00003,   0.00004,   0.00007,   0.00006,   0.00009,   0.00013,   0.00017,   0.00005,   0.00004,   0.00011,   0.00008,   0.00017,   0.00007,   0.00003,   0.00010,   0.00003,   0.00003,   0.00014,   0.00014,   0.00007,   0.00006,   0.00007,   0.00005,   0.00005,   0.00004,   0.00010,   0.00004,   0.00012,   0.00018,   0.00004,   0.00014,   0.00013,   0.00006,   0.00105,   0.00009,   0.00005,   0.00004,   0.00005,   0.00004,   0.00005,   0.00011,   0.00004,   0.00011,   0.00012,   0.00003,   0.00012,   0.00013,   0.00010,   0.00005,   0.00011,   0.00006,   0.00009,   0.00003,   0.00003,   0.00004,   0.00014,   0.00006,   0.00004,   0.00003,   0.00003,   0.00004,   0.00006,   0.00018,   0.00013,   0.00004,   0.00029,   0.00003,   0.00004,   0.00008,   0.00004,   0.00003,   0.00003,   0.00003,   0.00004,   0.00003,   0.00014,   0.00008,   0.00004,   0.00008,   0.00003,   0.00007,   0.00004,   0.00014,   0.00029,   0.00007,   0.00003,   0.00004,   0.00006,   0.00004,   0.00012,   0.00004,   0.00005,   0.00004,   0.00009,   0.00007,   0.00014,   0.00007,   0.00005,   0.00010,   0.00017,   0.00003,   0.00004,   0.00008,   0.00016,   0.00010,   0.00003,   0.00005,   0.00003,   0.00005,   0.00008,   0.00004,   0.00004,   0.00003,   0.00007,   0.00011,   0.00004,   0.00016,   0.00745,   0.00003,   0.00003,   0.00003,   0.00004,   0.00005,   0.00004,   0.00012,   0.00005,   0.00003,   0.00015,   0.00003,   0.00014,   0.00003,   0.01904,   0.00009,   0.00004,   0.00005,   0.00007,   0.00007,   0.00008,   0.00014,   0.00004,   0.00009,   0.00009,   0.00010,   0.00010,   0.00003,   0.00007,   0.00003,   0.00004,   0.00007,   0.00005,   0.00005,   0.00007,   0.00004,   0.00004,   0.00013,   0.00011,   0.00009,   0.00011,   0.00003,   0.00006,   0.00009,   0.00003,   0.00013,   0.00014,   0.00009,   0.00011,   0.00016,   0.00002,   0.00011, 
  0.00007,   0.00004,   0.00003,   0.00010,   0.00004,   0.00050,   0.00003,   0.00004,   0.00003,   0.00005,   0.00009,   0.00004,   0.00003,   0.00006,   0.00003,   0.00008,   0.00005,   0.00005,   0.00016,   0.00010,   0.00005,   0.00017,   0.00009,   0.00005,   0.00273,   0.00009,   0.00008,   0.00005,   0.00015,   0.00003,   0.00019,   0.00006,   0.00003,   0.00005,   0.00004,   0.00006,   0.00004,   0.00004,   0.00004,   0.00005,   0.00004,   0.00003,   0.00013,   0.00010,   0.00006,   0.00005,   0.00010,   0.00004,   0.00017,   0.00014,   0.00007,   0.00015,   0.00012,   0.00005,   0.00006,   0.00005,   0.00006,   0.14115,   0.00006,   0.00003,   0.00003,   0.00011,   0.00015,   0.00003,   0.00003,   0.00007,   0.00005,   0.00006,   0.00003,   0.00015,   0.00004,   0.00009,   0.00003,   0.00005,   0.00004,   0.00007,   0.00006,   0.00006,   0.00016,   0.09180,   0.00004,   0.00012,   0.00004,   0.00003,   0.00004,   0.00002,   0.00009,   0.00003,   0.00007,   0.00008,   0.00013,   0.00017,   0.00006,   0.00005,   0.00103,   0.00005,   0.00005,   0.00004,   0.00003,   0.00008,   0.00008,   0.00004,   0.00010,   0.00007,   0.00003,   0.00013,   0.00004,   0.00716,   0.00004,   0.00012,   0.00010,   0.00003,   0.00011,   0.00006,   0.00006,   0.00005,   0.00006,   0.00007,   0.00005,   0.00013,   0.00007,   0.00003,   0.00003,   0.00010,   0.00008,   0.00003,   0.00008,   0.00012,   0.00002,   0.00003,   0.00003,   0.00008,   0.00006,   0.00005,   0.00003,   0.00003,   0.00003,   0.00004,   0.00003,   0.00003,   0.00009,   0.00004,   0.00009,   0.00011,   0.00016,   0.00010,   0.00012,   0.00008,   0.00029,   0.00008,   0.00003,   0.00004,   0.00005,   0.00003,   0.00005,   0.00004,   0.00288,   0.00030,   0.00017,   0.00005,   0.00005,   0.00034,   0.00007,   0.00006,   0.00004,   0.00007,   0.00004,   0.00006,   0.00004,   0.00006,   0.00003,   0.00004,   0.00014,   0.00017,   0.00005,   0.09228,   0.00012,   0.00011,   0.00003,   0.00018,   0.00011,   
0.00016,   0.00003,   0.00007,   0.00011,   0.00008
]

# Join both drugs' feature vectors into a single row (shape (1, n_features))
drugs_comb = np.concatenate((drug_1_sim, drug_2_sim))
drugs_comb = drugs_comb.reshape(1, -1)

# KNN
pred_label = encoder.inverse_transform(knn.predict(drugs_comb).argmax(axis=1))
print("KNN: Drug A & Drug B has {}".format("an interaction." if pred_label[0] else "no interactions."))

# LDA: fitted on decoded labels, so predict() already returns the original label.
# Report LDA's own prediction (previously this printed the stale KNN `pred_label`).
y_pred = lda.predict(drugs_comb)
print("LDA: Drug A & Drug B has {}".format("an interaction." if y_pred[0] else "no interactions."))

# NBM
pred_binary = (nbm.predict(drugs_comb, verbose=0) > 0.5).astype(int)
# Convert one-hot encoded labels back to original labels
pred_label = encoder.inverse_transform(pred_binary.argmax(axis=1))
print("NBM: Drug A & Drug B has {}".format("an interaction." if pred_label[0] else "no interactions."))

KNN: Drug A & Drug B has no interactions.
LDA: Drug A & Drug B has no interactions.
NBM: Drug A & Drug B has no interactions.


The results above indicate a prediction of "False", i.e. no interaction is predicted between Meperidine and Dexmedetomidine.

In [35]:
# DB01118 (Amiodarone) Feature vector
drug_1_sim = [   0.00009,   0.00007,   0.00008,   0.00007,   0.00006,   0.00004,   0.00012,   0.00004,   0.00009,   0.00008,   0.00008,   0.00005,   0.00004,   0.00003,   0.00011,   0.00011,   0.00006,   0.00013,   0.00005,   0.00004,   0.00005,   0.00013,   0.00005,   0.00003,   0.04044,   0.00014,   0.00282,   0.00006,   0.00004,   0.00013,   0.00283,   0.00007,   0.00014,   0.00004,   0.00003,   0.00004,   0.00005,   0.00003,   0.00005,   0.00008,   0.00011,   0.00005,   0.00013,   0.00010,   0.00009,   0.00004,   0.00016,   0.00013,   0.00014,   0.00009,   0.00003,   0.00006,   0.00014,   0.00011,   0.00010,   0.00006,   0.00007,   0.00012,   0.00011,   0.00005,   0.00007,   0.00005,   0.00014,   0.00010,   0.00004,   0.00008,   0.00012,   0.00004,   0.00004,   0.00010,   0.00014,   0.00013,   0.00003,   0.00003,   0.00013,   0.00008,   0.00005,   0.00007,   0.00004,   0.00004,   0.00006,   0.00004,   0.00004,   0.00014,   0.00008,   0.00005,   0.01892,   0.00009,   0.00005,   0.00012,   0.00004,   0.00006,   0.00101,   0.00011,   0.00008,   0.00006,   0.00014,   0.00004,   0.00004,   0.00003,   0.00007,   0.00003,   0.00007,   0.00008,   0.00006,   0.00004,   0.00004,   0.09193,   0.04044,   0.00009,   0.00008,   0.00013,   0.00005,   0.00007,   0.00016,   0.00014,   0.00004,   0.00004,   0.00005,   0.00014,   0.00004,   0.00012,   0.00014,   0.00004,   0.00013,   0.00006,   0.00005,   0.00006,   0.00005,   0.00007,   0.00009,   0.00009,   0.00021,   0.00005,   0.00014,   0.00003,   0.00005,   0.00011,   0.00009,   0.00013,   0.00012,   0.00006,   0.00006,   0.00013,   0.00004,   0.00012,   0.00012,   0.00008,   0.00014,   0.00004,   0.00003,   0.00014,   0.00008,   0.00012,   0.00009,   0.00005,   0.00008,   0.00012,   0.00009,   0.00007,   0.00014,   0.00004,   0.00012,   0.00006,   0.00009,   0.00012,   0.00007,   0.00006,   0.00007,   0.00015,   0.00005,   0.00011,   0.00004,   0.00005,   0.00012,   0.00010,   0.00008,   0.00005,   0.00007,   0.00010,   
0.00004,   0.00007,   0.00005,   0.00004,   0.00004,   0.00004,   0.00006,   0.00006,   0.00010,   0.00009,   0.00013,   0.00006,   0.00013,   0.00015,   0.00014,   0.00006,   0.00007,   0.00005,   0.00004,   0.00004,   0.00008,   0.00012,   0.00004,   0.00006,   0.00004,   0.00007,   0.00014,   0.00005,   0.00012,   0.00012,   0.00004,   0.00004,   0.00007,   0.00007,   0.00006,   0.00010,   0.00009,   0.00010,   0.00005,   0.00013,   0.00004,   0.00004,   0.00012,   0.00004,   0.00004,   0.00007,   0.00003,   0.00005,   0.00010,   0.00014,   0.00008,   0.00012,   0.00009,   0.00005,   0.00011,   0.00005,   0.00004,   0.00012,   0.00004,   0.00004,   0.00005,   0.00009,   0.00005,   0.00008,   0.00005,   0.00014,   0.14082,   0.00013,   0.00004,   0.00008,   0.00013,   0.00014,   0.00014,   0.00011,   0.00007,   0.00003,   0.00004,   0.00010,   0.00003,   0.00012,   0.00013,   0.00006,   0.00011,   0.00014,   0.00013,   0.00012,   0.00012,   0.00014,   0.00004,   0.00006,   0.00012,   0.00006,   0.09192,   0.00007,   0.00011,   0.00004,   0.00003,   0.00007,   0.84981,   0.00014,   0.00008,   0.00012,   0.00004,   0.00013,   0.00008,   0.00013,   0.00005,   0.00007,   0.00004,   0.00006,   0.00008,   0.00005,   0.00004,   0.00013,   0.00011,   0.00006,   0.00004,   0.00005,   0.00014,   0.00009,   0.00014,   0.00009,   0.00006,   0.00012,   0.00014,   0.00013,   0.00007,   0.00004,   0.00012,   0.00004,   0.00003,   0.00012,   0.00012,   0.00014,   0.00012,   0.00009,   0.00010,   0.00004,   0.00009,   0.01892,   0.00004,   0.00012,   0.00004,   0.00021,   0.00003,   0.00005,   0.00013,   0.00008,   0.00006,   0.00007,   0.00006,   0.00004,   0.00011,   0.00005,   0.00005,   0.00005,   0.00005,   0.00013,   0.00006,   0.00011,   0.00011,   0.00006,   0.00008,   0.00008,   0.00006,   0.00011,   0.00010,   0.00004,   0.00005,   0.00005,   0.00005,   0.00040,   0.00008,   0.00005,   0.00015,   0.00004,   0.00004,   0.00005,   0.00004,   0.00004,   0.00009,   0.00004, 
  0.00006,   0.00014,   0.00014,   0.00005,   0.00013,   0.00003,   0.00013,   0.00011,   0.00013,   0.00009,   0.00005,   0.00012,   0.00014,   0.00008,   0.00013,   0.00006,   0.00010,   0.00009,   0.00004,   0.00005,   0.00010,   0.00004,   0.00005,   0.00009,   0.00003,   0.00005,   0.00006,   0.00010,   0.00004,   0.00015,   0.00003,   0.00007,   0.00012,   0.00009,   0.00011,   0.00007,   0.00012,   0.00011,   0.00013,   0.00009,   0.00011,   0.00014,   0.00004,   0.00005,   0.00007,   0.00009,   0.00005,   0.00010,   0.00003,   0.00004,   0.00007,   0.00004,   0.00004,   0.00009,   0.00007,   0.00010,   0.00007,   0.00003,   0.00007,   0.00013,   0.00014,   0.00004,   0.00004,   0.00013,   0.00013,   0.00006,   0.00009,   0.00008,   0.00738,   0.00004,   0.00011,   0.00005,   0.00012,   0.00010,   0.00012,   0.00007,   0.00008,   0.00008,   0.00004,   0.00003,   0.00011,   0.00004,   0.00014,   0.00013,   0.00012,   0.00010,   0.00005,   0.00013,   0.00006,   0.00006,   0.00004,   0.00003,   0.00007,   0.00008,   0.00003,   0.00010,   0.00008,   0.00012,   0.00041,   0.00005,   0.00006,   0.00012,   0.00005,   0.00006,   0.00013,   0.00004,   0.00010,   0.00003,   0.00011,   0.00004,   0.00005,   0.00013,   0.00004,   0.00007,   0.00007,   0.00009,   0.00008,   0.00007,   0.00010,   0.00004,   0.00006,   0.00013,   0.00101,   0.00005,   0.00006,   0.00013,   0.00006,   0.00004,   0.00010,   0.00012,   0.00014,   0.00006,   0.00007,   0.00010,   0.00738,   0.00013,   0.00015,   0.00012,   0.00014,   0.14082,   0.00005,   0.00016,   0.00005,   0.00004,   0.00004,   0.00005,   0.00004,   0.00006,   0.00003,   0.00005,   0.00014,   0.00011,   0.00009,   0.00014,   0.00009,   0.00014,   0.00003,   0.00003,   0.00004,   0.00010,   0.00010,   0.00003,   0.00006,   0.00008,   0.00014,   0.00006,   0.00010,   0.00007,   0.00013,   0.00008,   0.00012,   0.00013,   0.00004,   0.00003,   0.00009,   0.00003,   0.00004,   0.00005,   0.00014,   0.00003,   0.00004,   
0.00004,   0.00012,   0.00006,   0.00004,   0.00006
]

# DB00390 (Digoxin) Feature vector
drug_2_sim = [   0.00002,   0.00007,   0.00006,   0.00007,   0.00007,   0.00017,   0.00003,   0.00012,   0.00005,   0.00006,   0.00006,   0.00009,   0.00012,   0.00016,   0.00004,   0.00004,   0.00008,   0.00004,   0.00010,   0.00049,   0.00010,   0.00003,   0.00009,   0.00016,   0.00004,   0.00004,   0.00004,   0.00008,   0.00013,   0.00003,   0.00004,   0.00007,   0.00003,   0.04065,   0.00016,   0.00014,   0.00010,   0.00013,   0.00011,   0.00006,   0.00004,   0.00011,   0.00004,   0.00002,   0.00006,   0.09223,   0.00004,   0.00004,   0.00003,   0.00005,   0.00016,   0.00008,   0.00003,   0.00003,   0.00005,   0.00009,   0.00007,   0.00003,   0.00005,   0.00009,   0.00007,   0.00011,   0.00003,   0.00005,   0.00013,   0.00006,   0.00003,   0.00012,   0.00013,   0.00005,   0.00003,   0.00003,   0.00017,   0.00016,   0.00003,   0.00006,   0.00009,   0.00007,   0.00014,   0.00015,   0.00008,   0.00752,   0.00013,   0.00004,   0.00006,   0.00010,   0.00004,   0.00005,   0.00010,   0.00003,   0.00014,   0.00008,   0.00004,   0.00003,   0.00006,   0.00008,   0.00003,   0.00014,   0.00012,   0.00016,   0.00007,   0.00017,   0.00007,   0.00006,   0.00008,   0.00013,   0.00016,   0.00004,   0.00004,   0.00005,   0.00006,   0.00004,   0.00010,   0.00007,   0.00003,   0.00004,   0.00744,   0.14110,   0.00009,   0.00003,   0.00014,   0.00003,   0.00003,   0.00018,   0.00003,   0.00008,   0.00009,   0.00007,   0.00010,   0.00006,   0.00005,   0.00002,   0.00003,   0.00010,   0.00003,   0.00016,   0.00010,   0.00004,   0.00005,   0.00004,   0.00003,   0.00007,   0.00008,   0.00004,   0.00014,   0.00004,   0.00004,   0.00006,   0.00003,   0.00014,   0.00017,   0.00003,   0.00006,   0.00003,   0.00005,   0.00011,   0.00006,   0.00004,   0.00002,   0.00007,   0.00003,   0.00013,   0.00004,   0.00008,   0.00005,   0.00003,   0.00007,   0.00008,   0.00007,   0.00004,   0.00011,   0.00005,   0.00014,   0.00009,   0.00003,   0.00005,   0.00006,   0.00011,   0.00007,   0.00005,   
0.00012,   0.00006,   0.00009,   0.00011,   0.00013,   0.00013,   0.00008,   0.00009,   0.00005,   0.00005,   0.00004,   0.00007,   0.00003,   0.00004,   0.00004,   0.00008,   0.00006,   0.00010,   0.00014,   0.00111,   0.00006,   0.00004,   0.00012,   0.00009,   0.00043,   0.00007,   0.00003,   0.00011,   0.00003,   0.00003,   0.00015,   0.00015,   0.00007,   0.00006,   0.00008,   0.00005,   0.00005,   0.00005,   0.00010,   0.00004,   0.00013,   0.00021,   0.00004,   0.00015,   0.00014,   0.00006,   0.00016,   0.00010,   0.00005,   0.00004,   0.00006,   0.00004,   0.00005,   0.00011,   0.00005,   0.00011,   0.00012,   0.00003,   0.00012,   0.00014,   0.00010,   0.00005,   0.00011,   0.00006,   0.00009,   0.00003,   0.00004,   0.00004,   0.00014,   0.00006,   0.00004,   0.00003,   0.00003,   0.00005,   0.00007,   0.00017,   0.00014,   0.00005,   0.00012,   0.00003,   0.00004,   0.00008,   0.00004,   0.00003,   0.00003,   0.00003,   0.00004,   0.00004,   0.00015,   0.00008,   0.00004,   0.00009,   0.00004,   0.00007,   0.00005,   0.00015,   0.00012,   0.00007,   0.00004,   0.00004,   0.00006,   0.00004,   0.00013,   0.00004,   0.00006,   0.00004,   0.00010,   0.00007,   0.00015,   0.00008,   0.00006,   0.00010,   0.00286,   0.00003,   0.00005,   0.00009,   0.01902,   0.00011,   0.00003,   0.00005,   0.00003,   0.00006,   0.00008,   0.00004,   0.00004,   0.00003,   0.00007,   0.00012,   0.00004,   0.04059,   0.00016,   0.00003,   0.00003,   0.00003,   0.00004,   0.00006,   0.00005,   0.00013,   0.00005,   0.00004,   0.00016,   0.00003,   0.00015,   0.00004,   0.00016,   0.00010,   0.00004,   0.00006,   0.00008,   0.00007,   0.00009,   0.00015,   0.00005,   0.00010,   0.00009,   0.00011,   0.00011,   0.00003,   0.00008,   0.00003,   0.00004,   0.00007,   0.00006,   0.00006,   0.00008,   0.00004,   0.00005,   0.00013,   0.00011,   0.00009,   0.00011,   0.00004,   0.00006,   0.00009,   0.00003,   0.00013,   0.00015,   0.00009,   0.00012,   0.14109,   0.00002,   0.00012, 
  0.00008,   0.00004,   0.00003,   0.00011,   0.00004,   0.00014,   0.00003,   0.00004,   0.00003,   0.00005,   0.00010,   0.00004,   0.00003,   0.00006,   0.00003,   0.00008,   0.00005,   0.00005,   0.01911,   0.00011,   0.00005,   0.00024,   0.00009,   0.00005,   0.00015,   0.00009,   0.00008,   0.00005,   0.00015,   0.00003,   0.00016,   0.00007,   0.00003,   0.00005,   0.00005,   0.00006,   0.00004,   0.00005,   0.00004,   0.00005,   0.00004,   0.00003,   0.00013,   0.00011,   0.00007,   0.00006,   0.00010,   0.00005,   0.00017,   0.00015,   0.00007,   0.00016,   0.00012,   0.00005,   0.00006,   0.00005,   0.00007,   0.00016,   0.00007,   0.00003,   0.00003,   0.00012,   0.00015,   0.00003,   0.00003,   0.00008,   0.00005,   0.00006,   0.00004,   0.00016,   0.00004,   0.00010,   0.00003,   0.00005,   0.00004,   0.00007,   0.00006,   0.00006,   0.85064,   0.00016,   0.00004,   0.00012,   0.00004,   0.00003,   0.00004,   0.00002,   0.00010,   0.00003,   0.00008,   0.00009,   0.00014,   0.00017,   0.00007,   0.00006,   0.00014,   0.00005,   0.00006,   0.00004,   0.00003,   0.00009,   0.00008,   0.00004,   0.00011,   0.00007,   0.00003,   0.00013,   0.00005,   0.00015,   0.00004,   0.00013,   0.00010,   0.00003,   0.00012,   0.00007,   0.00006,   0.00005,   0.00006,   0.00007,   0.00005,   0.00014,   0.00008,   0.00003,   0.00004,   0.00010,   0.00008,   0.00003,   0.00009,   0.00013,   0.00002,   0.00003,   0.00003,   0.00009,   0.00007,   0.00005,   0.00004,   0.00003,   0.00004,   0.00004,   0.00003,   0.00004,   0.00009,   0.00004,   0.00010,   0.00012,   0.09227,   0.00010,   0.00012,   0.00009,   0.00012,   0.00009,   0.00003,   0.00004,   0.00005,   0.00003,   0.00005,   0.00004,   0.00016,   0.00013,   0.00104,   0.00005,   0.00005,   0.00013,   0.00007,   0.00006,   0.00004,   0.00008,   0.00005,   0.00006,   0.00004,   0.00006,   0.00003,   0.00004,   0.00015,   0.00018,   0.00005,   0.00016,   0.00013,   0.00011,   0.00003,   0.00017,   0.00012,   
0.00294,   0.00003,   0.00008,   0.00012,   0.00008
]

# Join both drugs' feature vectors into a single row (shape (1, n_features))
drugs_comb = np.concatenate((drug_1_sim, drug_2_sim))
drugs_comb = drugs_comb.reshape(1, -1)

# KNN
pred_label = encoder.inverse_transform(knn.predict(drugs_comb).argmax(axis=1))
print("KNN: Drug A & Drug B has {}".format("an interaction." if pred_label[0] else "no interactions."))

# LDA: fitted on decoded labels, so predict() already returns the original label.
# Report LDA's own prediction (previously this printed the stale KNN `pred_label`).
y_pred = lda.predict(drugs_comb)
print("LDA: Drug A & Drug B has {}".format("an interaction." if y_pred[0] else "no interactions."))

# NBM
pred_binary = (nbm.predict(drugs_comb, verbose=0) > 0.5).astype(int)
# Convert one-hot encoded labels back to original labels
pred_label = encoder.inverse_transform(pred_binary.argmax(axis=1))
print("NBM: Drug A & Drug B has {}".format("an interaction." if pred_label[0] else "no interactions."))


KNN: Drug A & Drug B has an interaction.
LDA: Drug A & Drug B has an interaction.
NBM: Drug A & Drug B has an interaction.
