# Objective

To create a simple model that can predict which medication from a given list
would be appropriate for a patient based on their symptoms/goals.

In [None]:
!pip install pandas
!pip install datasets

#press ESC to enter command mode

In [1]:
import os
import re
import glob
import shutil
import string
import pathlib


data_dir = os.path.abspath(os.path.join(os.getcwd(),'..','data'))


os.environ['MPLCONFIGDIR'] = os.path.join(data_dir,'plt_configs')
import matplotlib.pyplot as plt


os.environ['HF_HOME'] = os.path.join(data_dir,'hf_cache')
import datasets

import pandas as pd
import numpy as np

import tensorflow as tf
from tensorflow.keras import layers
from tensorflow.keras import losses
from tensorflow.keras.layers import TextVectorization

# Data collection

Drug review dataset obtained from Hugging Face

In [2]:
dataset = datasets.load_dataset("flxclxc/encoded_drug_reviews")

Using custom data configuration flxclxc--encoded_drug_reviews-ee0cdba36988e67d
Found cached dataset json (/tf/data/hf_cache/datasets/flxclxc___json/flxclxc--encoded_drug_reviews-ee0cdba36988e67d/0.0.0/e6070c77f18f01a5ad4551a8b7edfba20b8438b7cad4d94e6ad9378022ce4aab)


  0%|          | 0/1 [00:00<?, ?it/s]

In [3]:
df = pd.DataFrame(dataset)

In [4]:
print(df)

                                                   train
0      {'patient_id': 184648, 'drugName': 'Efudex', '...
1      {'patient_id': 25268, 'drugName': 'Flector Pat...
2      {'patient_id': 172019, 'drugName': 'Amitiza', ...
3      {'patient_id': 196063, 'drugName': 'Stendra', ...
4      {'patient_id': 225264, 'drugName': 'Bupropion'...
...                                                  ...
53466  {'patient_id': 199190, 'drugName': 'Depo-Prove...
53467  {'patient_id': 188476, 'drugName': 'ParaGard',...
53468  {'patient_id': 105752, 'drugName': 'Methylpred...
53469  {'patient_id': 56713, 'drugName': 'Meclizine',...
53470  {'patient_id': 215006, 'drugName': 'Fluoxetine...

[53471 rows x 1 columns]


In [5]:
complete_dataset = dataset['train'].to_pandas()
complete_dataset.head()

Unnamed: 0,patient_id,drugName,condition,review,rating,date,usefulCount,review_length,encoded
0,184648,Efudex,basal cell carcinoma,"""I have BCC on my upper arm and SCC on upper l...",1.0,"August 30, 2016",16,36,"[-0.0633561835, 0.0115883639, -0.0027463636, 0..."
1,25268,Flector Patch,pain,"""I tore my shoulder labrum and the pain can be...",8.0,"May 29, 2014",40,45,"[-0.083280459, 0.0182377025, 0.0619471855, 0.0..."
2,172019,Amitiza,irritable bowel syndrome,"""Amitiza is the best if you have ibs!""",10.0,"July 13, 2016",9,8,"[-0.0300639421, -0.0081300493, 0.0343461707, 0..."
3,196063,Stendra,erectile dysfunction,"""Viagra works in a strong, crude way with side...",10.0,"November 10, 2014",82,141,"[-0.0037669495, -0.0845683292, 0.0196341239, 0..."
4,225264,Bupropion,depression,"""I really wanted Wellbutrin to work. I was giv...",3.0,"October 4, 2015",15,62,"[-0.0633124188, 0.0167291258, 0.0707527027, 0...."


In [6]:
print(df.info()) # 53,471 total drug reviews in dataset

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 53471 entries, 0 to 53470
Data columns (total 1 columns):
 #   Column  Non-Null Count  Dtype 
---  ------  --------------  ----- 
 0   train   53471 non-null  object
dtypes: object(1)
memory usage: 417.9+ KB
None


In [7]:
complete_dataset.info() 

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 53471 entries, 0 to 53470
Data columns (total 9 columns):
 #   Column         Non-Null Count  Dtype  
---  ------         --------------  -----  
 0   patient_id     53471 non-null  int64  
 1   drugName       53471 non-null  object 
 2   condition      53471 non-null  object 
 3   review         53471 non-null  object 
 4   rating         53471 non-null  float64
 5   date           53471 non-null  object 
 6   usefulCount    53471 non-null  int64  
 7   review_length  53471 non-null  int64  
 8   encoded        53471 non-null  object 
dtypes: float64(1), int64(3), object(5)
memory usage: 3.7+ MB


# Data Cleaning

-Isolated drugs from the dataset with review counts > 200

In [8]:
# unique values in column "drugName"
drugs = complete_dataset['drugName'].unique()
for drug in drugs:
    print(drug)

Efudex
Flector Patch
Amitiza
Stendra
Bupropion
Amiodarone
Sprintec
Acetaminophen / tramadol
Sertraline
Lisinopril
Levonorgestrel
Liraglutide
Fluoxetine
Methylprednisolone
Doxylamine / pyridoxine
Zoladex
Clarithromycin
Tranylcypromine
Phentermine
Prednisolone
Tarceva
Ezetimibe / simvastatin
Belsomra
Enbrel
Flexeril
Ethinyl estradiol / levonorgestrel
Prozac
Clomipramine
Klonopin
Ethinyl estradiol / norgestimate
Zovirax
Adipex-P
Vilazodone
Etonogestrel
Buspirone
Phentermine / topiramate
Micardis HCT
Dapsone
TriNessa
Estarylla
Venlafaxine
Aviane
Geodon
Tiotropium
Remicade
Synthroid
Minocycline
Ativan
Eletriptan
Levoxyl
Duloxetine
Dexbrompheniramine / pseudoephedrine
Loestrin 24 Fe
Haloperidol
Nexplanon
Robaxin
Cefdinir
Ortho Tri-Cyclen Lo
Desogen
Hiprex
Lorazepam
Loratadine / pseudoephedrine
Lyrica
Norco
Ethinyl estradiol / norelgestromin
Ambrisentan
Trulicity
Duac
Zyrtec
Propranolol
Alli
Exemestane
Mirena
NuvaRing
Ethinyl estradiol / norethindrone
Desogestrel / ethinyl estradiol
Magnesium

In [15]:
pd.set_option("display.max_rows", None)
frequence = complete_dataset['drugName'].value_counts()
print(frequence)

Levonorgestrel                                                                                      1265
Etonogestrel                                                                                        1081
Ethinyl estradiol / norethindrone                                                                    869
Nexplanon                                                                                            736
Ethinyl estradiol / norgestimate                                                                     649
Ethinyl estradiol / levonorgestrel                                                                   591
Phentermine                                                                                          539
Sertraline                                                                                           506
Escitalopram                                                                                         452
Mirena                                                 

-Removed brand name drugs from dataset so that model did not classify brand/generic as two separate medications (ex. Lexapro/Escitalopram or Chantix/Varenicline)

In [10]:
base_dir = os.path.join(data_dir,'drugs')
if not os.path.exists(base_dir):
    os.makedirs(base_dir)

drug_names = ["Levonorgestrel",
              "Etonogestrel",
              "Ethinyl estradiol / norethindrone",
              "Nexplanon",
              "Ethinyl estradiol / norgestimate",
              "Ethinyl estradiol / levonorgestrel",
              "Phentermine",
              "Sertraline",
              "Escitalopram",
              "Mirena",
              "Implanon",
              "Gabapentin",
              "Miconazole",
              "Bupropion",
              "Venlafaxine",
              "Duloxetine",
              "Tramadol",
              "Clonazepam",
              "Citalopram",
              "Medroxyprogesterone",
              "Bupropion / naltrexone",
              "Varenicline",
              "Metronidazole",
              "Drospirenone / ethinyl estradiol",
              "Tioconazole",
              "Depo-Provera",
              "Liraglutide",
              "Skyla",
              "Fluoxetine",
              "Quetiapine",
              "Lo Loestrin Fe",
              "Alprazolam",
              "Chantix",
              "Amitriptyline",
              "Doxycycline",
              "Desvenlafaxine",
              "Trazodone",
              "Suprep Bowel Prep Kit",
              "Paroxetine",
              "NuvaRing",
              "Bisacodyl",
              "Lorcaserin"]
drug_directories = []

for drug_name in drug_names:
    current_drug_dir = os.path.join(data_dir,'drugs',drug_name)
    print(current_drug_dir)
    drug_directories.append(current_drug_dir)
    if not os.path.exists(current_drug_dir):
        os.makedirs(current_drug_dir)


/tf/data/drugs/Levonorgestrel
/tf/data/drugs/Etonogestrel
/tf/data/drugs/Ethinyl estradiol / norethindrone
/tf/data/drugs/Nexplanon
/tf/data/drugs/Ethinyl estradiol / norgestimate
/tf/data/drugs/Ethinyl estradiol / levonorgestrel
/tf/data/drugs/Phentermine
/tf/data/drugs/Sertraline
/tf/data/drugs/Escitalopram
/tf/data/drugs/Mirena
/tf/data/drugs/Implanon
/tf/data/drugs/Gabapentin
/tf/data/drugs/Miconazole
/tf/data/drugs/Bupropion
/tf/data/drugs/Venlafaxine
/tf/data/drugs/Duloxetine
/tf/data/drugs/Tramadol
/tf/data/drugs/Clonazepam
/tf/data/drugs/Citalopram
/tf/data/drugs/Medroxyprogesterone
/tf/data/drugs/Bupropion / naltrexone
/tf/data/drugs/Varenicline
/tf/data/drugs/Metronidazole
/tf/data/drugs/Drospirenone / ethinyl estradiol
/tf/data/drugs/Tioconazole
/tf/data/drugs/Depo-Provera
/tf/data/drugs/Liraglutide
/tf/data/drugs/Skyla
/tf/data/drugs/Fluoxetine
/tf/data/drugs/Quetiapine
/tf/data/drugs/Lo Loestrin Fe
/tf/data/drugs/Alprazolam
/tf/data/drugs/Chantix
/tf/data/drugs/Amitriptyli

In [11]:
data = np.array(["Levonorgestrel",
                 "Etonogestrel",
                 "Ethinyl estradiol / norethindrone",
                 "Nexplanon",
                 "Ethinyl estradiol / norgestimate",
                 "Ethinyl estradiol / levonorgestrel",
                 "Phentermine",
                 "Sertraline",
                 "Escitalopram",
                 "Mirena",
                 "Implanon",
                 "Gabapentin",
                 "Miconazole",
                 "Bupropion",
                 "Venlafaxine",
                 "Duloxetine",
                 "Tramadol",
                 "Clonazepam",
                 "Citalopram",
                 "Medroxyprogesterone",
                 "Bupropion / naltrexone",
                 "Varenicline",
                 "Metronidazole",
                 "Drospirenone / ethinyl estradiol",
                 "Tioconazole",
                 "Depo-Provera",
                 "Liraglutide",
                 "Skyla",
                 "Fluoxetine",
                 "Quetiapine",
                 "Lo Loestrin Fe",
                 "Alprazolam", 
                 "Chantix", 
                 "Amitriptyline",
                 "Doxycycline",
                 "Desvenlafaxine",
                 "Trazodone",
                 "Suprep Bowel Prep Kit",
                 "Paroxetine",
                 "NuvaRing",
                 "Bisacodyl",
                 "Lorcaserin"])
s = pd.Series(data)

In [12]:
print(s[:3])

0                       Levonorgestrel
1                         Etonogestrel
2    Ethinyl estradiol / norethindrone
dtype: object


In [13]:
print(s[-3:])

39      NuvaRing
40     Bisacodyl
41    Lorcaserin
dtype: object


In [14]:
complete_dataset.head()

Unnamed: 0,patient_id,drugName,condition,review,rating,date,usefulCount,review_length,encoded
0,184648,Efudex,basal cell carcinoma,"""I have BCC on my upper arm and SCC on upper l...",1.0,"August 30, 2016",16,36,"[-0.0633561835, 0.0115883639, -0.0027463636, 0..."
1,25268,Flector Patch,pain,"""I tore my shoulder labrum and the pain can be...",8.0,"May 29, 2014",40,45,"[-0.083280459, 0.0182377025, 0.0619471855, 0.0..."
2,172019,Amitiza,irritable bowel syndrome,"""Amitiza is the best if you have ibs!""",10.0,"July 13, 2016",9,8,"[-0.0300639421, -0.0081300493, 0.0343461707, 0..."
3,196063,Stendra,erectile dysfunction,"""Viagra works in a strong, crude way with side...",10.0,"November 10, 2014",82,141,"[-0.0037669495, -0.0845683292, 0.0196341239, 0..."
4,225264,Bupropion,depression,"""I really wanted Wellbutrin to work. I was giv...",3.0,"October 4, 2015",15,62,"[-0.0633124188, 0.0167291258, 0.0707527027, 0...."


In [16]:
drug_datasets = []
for drug_name in drug_names:

    filtered_df = complete_dataset[complete_dataset['drugName'] == drug_name]
    drug_datasets.append(filtered_df)

In [None]:
drug_datasets[0].head()

In [None]:
drug_datasets[41].head()

In [None]:
print(drug_names[41])
print(drug_directories[41])

In [None]:
drug_names
drug_directories
drug_datasets

for i in range(len(drug_names)):
    review_counter = 0
    for text in drug_datasets[i]['review']:
        with open(os.path.join(drug_directories[i],str(review_counter)+'.txt'), 'w') as f:
            f.write(text)
        review_counter+=1

# Loading data set for training

https://www.tensorflow.org/tutorials/load_data/text

In [None]:
batch_size = 32
seed = 42

raw_train_ds = tf.keras.utils.text_dataset_from_directory(
    pathlib.Path(base_dir),
    batch_size=batch_size,
    validation_split=0.2,
    subset='training',
    seed=seed)

In [None]:
for text_batch, label_batch in raw_train_ds.take(1):
    for i in range(10):
        print("Patient Review: ", text_batch.numpy()[i])
        print("Label:", label_batch.numpy()[i])

In [None]:
for i, label in enumerate(raw_train_ds.class_names):
    print("Label", i, "corresponds to", label)

In [None]:
# Create a validation set.

raw_val_ds = tf.keras.utils.text_dataset_from_directory(
    pathlib.Path(base_dir),
    batch_size=batch_size,
    validation_split=0.2,
    subset='validation',
    seed=seed)


In [None]:
# Prepare dataset for training

def custom_standardization(input_data):
  lowercase = tf.strings.lower(input_data)
  stripped_html = tf.strings.regex_replace(lowercase, '<br />', ' ')
  return tf.strings.regex_replace(stripped_html,
                                  '[%s]' % re.escape(string.punctuation),
                                  '')

In [None]:
max_features = 10000
sequence_length = 250

vectorize_layer = layers.TextVectorization(
    standardize=custom_standardization,
    max_tokens=max_features,
    output_mode='int',
    output_sequence_length=sequence_length)

In [None]:
train_text = raw_train_ds.map(lambda x, y: x)
vectorize_layer.adapt(train_text)

In [None]:
def vectorize_text(text, label):
  text = tf.expand_dims(text, -1)
  return vectorize_layer(text), label

In [None]:
text_batch, label_batch = next(iter(raw_train_ds))
first_review, first_label = text_batch[0], label_batch[0]
print("Review", first_review)
print("Label", raw_train_ds.class_names[first_label])
print("Vectorized review", vectorize_text(first_review, first_label))

In [None]:
print("1287 ---> ",vectorize_layer.get_vocabulary()[1287])
print(" 313 ---> ",vectorize_layer.get_vocabulary()[313])
print('Vocabulary size: {}'.format(len(vectorize_layer.get_vocabulary())))

In [None]:
#Dataset does not contain test so I'm omitting test_ds = raw_test_ds.map(vectorize_text)

train_ds = raw_train_ds.map(vectorize_text)
val_ds = raw_val_ds.map(vectorize_text)

In [None]:
AUTOTUNE = tf.data.AUTOTUNE

train_ds = train_ds.cache().prefetch(buffer_size=AUTOTUNE)
val_ds = val_ds.cache().prefetch(buffer_size=AUTOTUNE)

# Create the model

In [None]:
embedding_dim = 16

In [None]:
model = tf.keras.Sequential([
  layers.Embedding(max_features + 1, embedding_dim),
  layers.Dropout(0.2),
  layers.GlobalAveragePooling1D(),
  layers.Dropout(0.2),
  layers.Dense(len(drug_names), activation='softmax')])


In [None]:
model.summary()

# Train the model

In [None]:
model.compile(loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
              optimizer='nadam',
              metrics=[tf.keras.metrics.SparseCategoricalAccuracy()])

In [None]:
epochs = 400
history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=epochs)

In [None]:
loss, accuracy = model.evaluate(train_ds)

print("Loss: ", loss)
print("Accuracy: ", accuracy)

In [None]:
loss, accuracy = model.evaluate(val_ds)

print("Loss: ", loss)
print("Accuracy: ", accuracy)

In [None]:
hist = pd.DataFrame(history.history)
hist['epoch'] = history.epoch
hist.tail()


# Export the model

In [None]:
export_model = tf.keras.Sequential([
  vectorize_layer,
  model,
  layers.Activation('softmax')
])

export_model.compile(
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=False),
    optimizer='nadam',
    metrics=[tf.keras.metrics.SparseCategoricalAccuracy()]
)


In [None]:
loss, accuracy = export_model.evaluate(raw_train_ds)
print(accuracy)


# Inference on new data

In [None]:
examples = [
  "This medication relieved constipation.",
  "This medication helped me quit smoking.",
  "This medication helped me lose weight.",
    "Birth control medication."
]

predictions=export_model.predict(examples)


In [None]:
for j in range(len(examples)):
    print(examples[j])
    prediction=predictions[j]
    for i, label in enumerate(raw_train_ds.class_names):
        print(label+':'+str(prediction[i]))

# Constipation relief 

Model predicts Bisacodyl:0.028075192 

Bisacodyl is a laxative. Good prediction!

# Smoking cessation 

Model predicts Bupropion:0.028822651, Chantix:0.027415203, and Varenicline:0.027441071

Bupropion, Chantix, and Varencline are all used to help people quit smoking.

# Weight loss

Model predicts Phentermine:0.025583226. 

Phentermine is used for weight loss, so this is a good choice. However, I would have also expected the model to predict Bupropion / Naltrexone (0.02402098) with a higher degree of confidence as it is the other weight loss medication in this list of 40 medications.

# Birth control

Model predicts Ethinyl estradiol :0.02980319

Ethinyl Estradiol is the most commonly prescribed form of birth control medication so it makes sense that the model weighted this drug the highest. However, the model did not seem prefer the other birth control medications in this list (NuvaRing, Lo Loestrin FE

# Where can improvements be made?


