# <span style="color:firebrick;"> Drug Classification 💉 </span>

#### If you like my work, it would be great if you upvoted this notebook!
#### If not, leaving a comment on what I need to work on and improve would be really helpful!

## <span style="color:firebrick;"> Importing Libraries </span>

In [None]:
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
import plotly.graph_objs as go
import plotly.express as px
import warnings
warnings.simplefilter("ignore")

## <span style="color:firebrick;"> Loading up the data </span>

In [None]:
# Read the drug dataset from the Kaggle input directory and preview its first rows
data = pd.read_csv(filepath_or_buffer="../input/drug-classification/drug200.csv")
data.head(n=5)

* **Age**: Age of the patient
* **Sex**: Sex of the patient
* **BP**: Blood Pressure of the patient
* **Cholesterol**: Cholesterol of the patient
* **Na_to_K**: Sodium to Potassium ratio in patient's blood
* **Drug**: Drug type given to the patient

In [None]:
# Count the missing entries in every column of the dataset
data.isnull().sum()

In [None]:
# Dimensions of the dataset as a (rows, columns) tuple
data.shape

In [None]:
# Per-column summary: non-null counts, dtypes and memory usage
data.info()

In [None]:
# Data type of each column as loaded from the CSV
data.dtypes

In [None]:
# Number of patients per drug type (class balance of the target)
data.Drug.value_counts()

In [None]:
# Matplotlib 3.6 renamed the bundled "seaborn" style sheet to "seaborn-v0_8" and
# newer releases no longer accept the old name, so use whichever one is installed.
plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")

# Number of patients of each sex
plt.figure(figsize=(15,8))
plt.title("Genders", fontsize=20, y=1.02)
sns.countplot(x = data.Sex, palette="hot")
plt.show()

In [None]:
# Number of patients prescribed each drug type
plt.figure(figsize=(15, 8))
sns.countplot(x=data["Drug"], palette="hot").set_title("Drug Types", fontsize=20, y=1.02)
plt.show()

In [None]:
# Number of patients in each cholesterol category
plt.figure(figsize=(15, 8))
sns.countplot(x=data["Cholesterol"], palette="hot").set_title("Cholesterol", fontsize=20, y=1.02)
plt.show()

In [None]:
# Number of patients in each blood-pressure category
plt.figure(figsize=(15, 8))
sns.countplot(x=data["BP"], palette="hot").set_title("Blood Pressure", fontsize=20, y=1.02)
plt.show()

In [None]:
# Matplotlib 3.6 renamed the bundled "seaborn" style sheet to "seaborn-v0_8" and
# newer releases no longer accept the old name, so use whichever one is installed.
plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")

# Distribution of patient ages with a KDE overlay, drawn on an explicit Axes
fig, ax = plt.subplots(figsize=(15,8))
sns.histplot(data["Age"], kde=True, bins=25, color="firebrick", ax=ax)
ax.set_title("Age of the patients", fontsize=20, y=1.02)
ax.set_xlabel("Age",fontsize=15);

In [None]:
# Peek at a single row to recall the column names
data.head(n=1)

In [None]:
# Matplotlib 3.6 renamed the bundled "seaborn" style sheet to "seaborn-v0_8" and
# newer releases no longer accept the old name, so use whichever one is installed.
plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")

# Distribution of the sodium-to-potassium ratio with a KDE overlay, drawn on an explicit Axes
fig, ax = plt.subplots(figsize=(15,8))
sns.histplot(data["Na_to_K"], color="red", kde=True, bins=25, ax=ax)
ax.set_title("Sodium to Potassium ratio in patient's blood", fontsize=20, y=1.02)
ax.set_xlabel("Na_to_K",fontsize=15);

In [None]:
# Matplotlib 3.6 renamed the bundled "seaborn" style sheet to "seaborn-v0_8" and
# newer releases no longer accept the old name, so use whichever one is installed.
plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")
fig, ax = plt.subplots(1, 2, figsize=(20,8))

# One panel per (x, hue) pair: patient counts by sex, split by cholesterol and by blood pressure
for axis, (x_col, hue_col) in zip(ax, [("Sex", "Cholesterol"), ("Sex", "BP")]):
    counts = data.groupby([x_col, hue_col]).size().reset_index(name="Count")
    sns.barplot(x=x_col, y="Count", hue=hue_col, data=counts, palette="hot", ax=axis)
    axis.set_xlabel(x_col, fontsize=14)

In [None]:
# Matplotlib 3.6 renamed the bundled "seaborn" style sheet to "seaborn-v0_8" and
# newer releases no longer accept the old name, so use whichever one is installed.
plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")
fig, ax = plt.subplots(1, 2, figsize=(20,8))

# One panel per (x, hue) pair: patient counts by blood pressure and by drug, split by cholesterol
for axis, (x_col, hue_col) in zip(ax, [("BP", "Cholesterol"), ("Drug", "Cholesterol")]):
    counts = data.groupby([x_col, hue_col]).size().reset_index(name="Count")
    sns.barplot(x=x_col, y="Count", hue=hue_col, data=counts, palette="hot", ax=axis)
    axis.set_xlabel(x_col, fontsize=14)

In [None]:
# Matplotlib 3.6 renamed the bundled "seaborn" style sheet to "seaborn-v0_8" and
# newer releases no longer accept the old name, so use whichever one is installed.
plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")
fig, ax = plt.subplots(1, 2, figsize=(20,8))

# One panel per (x, hue) pair: counts by sex split by drug, and by drug split by blood pressure
for axis, (x_col, hue_col) in zip(ax, [("Sex", "Drug"), ("Drug", "BP")]):
    counts = data.groupby([x_col, hue_col]).size().reset_index(name="Count")
    sns.barplot(x=x_col, y="Count", hue=hue_col, data=counts, palette="hot", ax=axis)
    axis.set_xlabel(x_col, fontsize=14)

In [None]:
# Age distribution of the patients for each drug type.
# The legend assigns its labels to the swarms positionally, so the swarms and the
# legend must share one explicit category order. Taking the labels from
# value_counts() (frequency order) while the swarms are drawn in order of
# appearance attaches the wrong drug name to most swarms.
drug_order = data["Drug"].unique().tolist()

plt.figure(figsize = (15,8))
sns.swarmplot(x = "Drug", y = "Age", data = data, order = drug_order, palette="hot")
plt.legend(drug_order)
plt.title("Age - Drug", fontsize=20, y=1.02)
plt.show()

In [None]:
# Column dtypes before the categorical columns are converted to numeric codes
data.dtypes

In [None]:
# Converting the non-numeric values into numeric values

# Integer code assigned to every category of each non-numeric column
CATEGORY_CODES = {
    "Sex": {"M": 1, "F": 2},
    "BP": {"HIGH": 1, "NORMAL": 2, "LOW": 3},
    "Cholesterol": {"HIGH": 1, "NORMAL": 2},
    "Drug": {"DrugY": 1, "drugC": 2, "drugX": 3, "drugA": 4, "drugB": 5},
}

for column, codes in CATEGORY_CODES.items():
    # A column that is already numeric was encoded by a previous run of this cell;
    # mapping it again would replace every code with NaN, so leave it untouched.
    if pd.api.types.is_numeric_dtype(data[column]):
        continue

    encoded = data[column].map(codes)

    # .map() silently yields NaN for any value missing from the mapping; fail loudly
    # instead of passing NaNs on to the models.
    unmapped = data.loc[encoded.isna(), column].unique()
    if len(unmapped):
        raise ValueError(f"Unexpected values in column {column!r}: {list(unmapped)}")

    data[column] = encoded.astype("int64")

In [None]:
# Preview the first rows after the categorical columns were converted to codes
data.head(n=5)

## <span style="color:firebrick;"> Splitting the data into training and test datasets </span>
Here, we predict the drug type to prescribe to a patient from the remaining columns. The "Drug" column is therefore the target (y), and all other columns form the input features (X).

In [None]:
# Feature matrix: every column except the "Drug" target
X = data.drop(columns=["Drug"])
X.head(n=5)

In [None]:
# Target vector: the "Drug" column
y = data.loc[:, "Drug"]
y.head(n=5)

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 20% of the rows for testing with a fixed seed for reproducibility.
# stratify=y keeps the drug-class proportions the same in both splits, so the
# rarer drug classes are not under-represented in the small test set by chance.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.2, random_state=42, stratify=y
)

In [None]:
# Number of rows in the training and test splits
X_train.shape[0], X_test.shape[0]

In [None]:
# Standardising the features
from sklearn.preprocessing import StandardScaler

# Fit the scaler on the training split only, then apply that same transform to both splits
scaler = StandardScaler().fit(X_train)
X_train, X_test = scaler.transform(X_train), scaler.transform(X_test)

## <span style="color:firebrick;"> Logistic Regression </span>

In [None]:
from sklearn.linear_model import LogisticRegression

# Train a logistic regression model with default hyperparameters on the scaled training data
lr = LogisticRegression().fit(X_train, y_train)
lr

In [None]:
# Mean accuracy of the logistic regression model on the held-out test split
LogisticRegressionScore = lr.score(X_test, y_test)
lr_accuracy_pct = LogisticRegressionScore * 100
print("Accuracy obtained by Logistic Regression model:", lr_accuracy_pct)

In [None]:
# Having a look at the confusion matrix for Logistic Regression

from sklearn.metrics import confusion_matrix, classification_report

y_pred_lr = lr.predict(X_test)

# Pin the row/column order to the classes the model was trained on, and show those
# class codes on the ticks; otherwise the heatmap ticks are positional indices
# (0, 1, ...) that do not match the encoded drug labels.
cf_matrix = confusion_matrix(y_test, y_pred_lr, labels=lr.classes_)
ax = sns.heatmap(cf_matrix, annot=True, fmt="d", cmap="vlag_r",
                 xticklabels=lr.classes_, yticklabels=lr.classes_)
ax.set_xlabel("Predicted drug code")
ax.set_ylabel("True drug code")
ax.set_title("Confusion Matrix for Logistic Regression", fontsize=14, fontname="Helvetica", y=1.03);

In [None]:
# Per-class precision, recall and F1 for Logistic Regression on the test split

from sklearn import metrics

lr_report = metrics.classification_report(y_test, y_pred_lr)
print(lr_report)

## <span style="color:firebrick;"> Random Forest Classifier </span>

In [None]:
from sklearn.ensemble import RandomForestClassifier

# Random forests are stochastic (bootstrap samples and random feature subsets);
# fixing the seed makes the reported accuracy reproducible across runs.
rfc = RandomForestClassifier(random_state=42)
rfc.fit(X_train, y_train)

In [None]:
# Mean accuracy of the random forest on the held-out test split
RandomForestClassifierScore = rfc.score(X_test,y_test)
# Message spelling corrected ("Accacy" -> "Accuracy") to match the other models' output
print("Accuracy obtained by Random Forest Classifier :", RandomForestClassifierScore*100)

In [None]:
# Confusion Matrix of Random Forest Classifier

y_pred_rfc = rfc.predict(X_test)

# Pin the row/column order to the classes the model was trained on, and show those
# class codes on the ticks; otherwise the heatmap ticks are positional indices
# (0, 1, ...) that do not match the encoded drug labels.
cf_matrix = confusion_matrix(y_test, y_pred_rfc, labels=rfc.classes_)
ax = sns.heatmap(cf_matrix, annot=True, fmt="d", cmap="vlag_r",
                 xticklabels=rfc.classes_, yticklabels=rfc.classes_)
ax.set_xlabel("Predicted drug code")
ax.set_ylabel("True drug code")
ax.set_title("Confusion Matrix for Random Forest Classifier", fontsize=14, fontname="Helvetica", y=1.03);

In [None]:
# Per-class precision, recall and F1 for the Random Forest Classifier on the test split
rfc_report = metrics.classification_report(y_test, y_pred_rfc)
print(rfc_report)

## <span style="color:firebrick;"> K Neighbors Classifier </span>

In [None]:
from sklearn.neighbors import KNeighborsClassifier

# Train a k-nearest-neighbours classifier with default hyperparameters on the scaled training data
knn = KNeighborsClassifier().fit(X_train, y_train)
knn

In [None]:
# Mean accuracy of the k-nearest-neighbours classifier on the held-out test split
KNeighborsClassifierScore = knn.score(X_test, y_test)
knn_accuracy_pct = KNeighborsClassifierScore * 100
print("Accuracy obtained by K Neighbors Classifier :", knn_accuracy_pct)

In [None]:
# Confusion Matrix of K Neighbors Classifier

y_pred_knn = knn.predict(X_test)

# Pin the row/column order to the classes the model was trained on, and show those
# class codes on the ticks; otherwise the heatmap ticks are positional indices
# (0, 1, ...) that do not match the encoded drug labels.
cf_matrix = confusion_matrix(y_test, y_pred_knn, labels=knn.classes_)
ax = sns.heatmap(cf_matrix, annot=True, fmt="d", cmap="vlag_r",
                 xticklabels=knn.classes_, yticklabels=knn.classes_)
ax.set_xlabel("Predicted drug code")
ax.set_ylabel("True drug code")
ax.set_title("Confusion Matrix for K Neighbors Classifier", fontsize=14, fontname="Helvetica", y=1.03);

In [None]:
# Per-class precision, recall and F1 for the K Neighbors Classifier on the test split
knn_report = metrics.classification_report(y_test, y_pred_knn)
print(knn_report)

## <span style="color:firebrick;"> Decision Tree Classifier </span>

In [None]:
from sklearn.tree import DecisionTreeClassifier

# Tree construction involves randomness (features are permuted at each split, which
# decides ties between equally good splits); fixing the seed makes the fitted tree
# and its accuracy reproducible across runs.
tree = DecisionTreeClassifier(random_state=42)
tree.fit(X_train, y_train)

In [None]:
# Mean accuracy of the decision tree on the held-out test split
DecisionTreeClassifierScore = tree.score(X_test, y_test)
tree_accuracy_pct = DecisionTreeClassifierScore * 100
print("Accuracy obtained by Decision Tree Classifier :", tree_accuracy_pct)

In [None]:
# Confusion Matrix of Decision Tree Classifier

y_pred_tree = tree.predict(X_test)

# Pin the row/column order to the classes the model was trained on, and show those
# class codes on the ticks; otherwise the heatmap ticks are positional indices
# (0, 1, ...) that do not match the encoded drug labels.
cf_matrix = confusion_matrix(y_test, y_pred_tree, labels=tree.classes_)
ax = sns.heatmap(cf_matrix, annot=True, fmt="d", cmap="vlag_r",
                 xticklabels=tree.classes_, yticklabels=tree.classes_)
ax.set_xlabel("Predicted drug code")
ax.set_ylabel("True drug code")
# Title spelling corrected ("Metrix" -> "Matrix") to match the other models' plots
ax.set_title("Confusion Matrix for Decision Tree Classifier", fontsize=14, fontname="Helvetica", y=1.03);

In [None]:
# Per-class precision, recall and F1 for the Decision Tree Classifier on the test split
tree_report = metrics.classification_report(y_test, y_pred_tree)
print(tree_report)

In [None]:
# Matplotlib 3.6 renamed the bundled "seaborn" style sheet to "seaborn-v0_8" and
# newer releases no longer accept the old name, so use whichever one is installed.
plt.style.use("seaborn-v0_8" if "seaborn-v0_8" in plt.style.available else "seaborn")

# Test-set accuracy of every model, keyed by its display name.
# Kept under a dedicated name instead of rebinding `x`/`y`: `y` is the target
# vector used by the train/test split cell and must not be overwritten.
model_scores = {
    "LogisticRegression": LogisticRegressionScore,
    "RandomForestClassifier": RandomForestClassifierScore,
    "KNeighborsClassifier": KNeighborsClassifierScore,
    "Decision Tree Classifier": DecisionTreeClassifierScore,
}

fig, ax = plt.subplots(figsize=(15,8))
sns.barplot(x=list(model_scores), y=list(model_scores.values()), palette="hot", ax=ax)
ax.set_ylabel("Model Accuracy")
plt.xticks(rotation=40)
ax.set_title("Model Comparison - Model Accuracy", fontsize=15, fontname="Helvetica", y=1.03);