# Stroke Prediction

In [22]:
import pandas as pd
import numpy as np
import plotly.express as px
import plotly.figure_factory as ff

In [23]:
df = pd.read_csv("data/stroke.csv")

In [24]:
df.head()

Unnamed: 0,id,gender,age,hypertension,heart_disease,ever_married,work_type,Residence_type,avg_glucose_level,bmi,smoking_status,stroke
0,9046,Male,67.0,0,1,Yes,Private,Urban,228.69,36.6,formerly smoked,1
1,51676,Female,61.0,0,0,Yes,Self-employed,Rural,202.21,,never smoked,1
2,31112,Male,80.0,0,1,Yes,Private,Rural,105.92,32.5,never smoked,1
3,60182,Female,49.0,0,0,Yes,Private,Urban,171.23,34.4,smokes,1
4,1665,Female,79.0,1,0,Yes,Self-employed,Rural,174.12,24.0,never smoked,1


In [25]:
df.info()

<class 'pandas.core.frame.DataFrame'>
RangeIndex: 5110 entries, 0 to 5109
Data columns (total 12 columns):
 #   Column             Non-Null Count  Dtype  
---  ------             --------------  -----  
 0   id                 5110 non-null   int64  
 1   gender             5110 non-null   object 
 2   age                5110 non-null   float64
 3   hypertension       5110 non-null   int64  
 4   heart_disease      5110 non-null   int64  
 5   ever_married       5110 non-null   object 
 6   work_type          5110 non-null   object 
 7   Residence_type     5110 non-null   object 
 8   avg_glucose_level  5110 non-null   float64
 9   bmi                4909 non-null   float64
 10  smoking_status     5110 non-null   object 
 11  stroke             5110 non-null   int64  
dtypes: float64(3), int64(4), object(5)
memory usage: 479.2+ KB


In [26]:
df.describe()

Unnamed: 0,id,age,hypertension,heart_disease,avg_glucose_level,bmi,stroke
count,5110.0,5110.0,5110.0,5110.0,5110.0,4909.0,5110.0
mean,36517.829354,43.226614,0.097456,0.054012,106.147677,28.893237,0.048728
std,21161.721625,22.612647,0.296607,0.226063,45.28356,7.854067,0.21532
min,67.0,0.08,0.0,0.0,55.12,10.3,0.0
25%,17741.25,25.0,0.0,0.0,77.245,23.5,0.0
50%,36932.0,45.0,0.0,0.0,91.885,28.1,0.0
75%,54682.0,61.0,0.0,0.0,114.09,33.1,0.0
max,72940.0,82.0,1.0,1.0,271.74,97.6,1.0


In [27]:
df.isnull().sum()

id                     0
gender                 0
age                    0
hypertension           0
heart_disease          0
ever_married           0
work_type              0
Residence_type         0
avg_glucose_level      0
bmi                  201
smoking_status         0
stroke                 0
dtype: int64

In [28]:
bmi_mean = df["bmi"].mean()
bmi_mean

28.893236911794666

In [29]:
df["bmi"] = df["bmi"].fillna(bmi_mean)

In [30]:
df.drop("id", axis = 1, inplace=True)

In [31]:

age_stroke = df[df["stroke"] == 1]["age"]
age_nostroke = df[df["stroke"] == 0]["age"]

data = [ age_stroke, age_nostroke]
data_labels = [ "age_stroke", "age_nostroke"]

fig = ff.create_distplot(data, data_labels, bin_size=10)
fig.show()

In [32]:
df["gender"].value_counts()

gender
Female    2994
Male      2115
Other        1
Name: count, dtype: int64

In [33]:
df.drop(df.loc[df["gender"] == "Other"].index, inplace = True)

In [34]:
fig = px.histogram(df, x = "gender", color = "stroke", barmode="group")
fig.show()

In [35]:
fig = px.histogram(df, x = "work_type", color = "stroke", barmode='group')
fig.show()

In [36]:
fig = px.histogram(df, x = "Residence_type", color = "stroke", barmode= "group")
fig.show()

In [37]:
fig = px.histogram(df, x = "smoking_status", color = "stroke", barmode= "group")
fig.show()

In [38]:
#check for imbalance

fig = px.histogram(df, x = "stroke", color = "stroke")
fig.show()

In [39]:
#data preprocessing

X = df.iloc[:, 0:-1].values
y = df.iloc[:, -1].values

(X.shape, y.shape)

((5109, 10), (5109,))

In [40]:
df.head()

Unnamed: 0,gender,age,hypertension,heart_disease,ever_married,work_type,Residence_type,avg_glucose_level,bmi,smoking_status,stroke
0,Male,67.0,0,1,Yes,Private,Urban,228.69,36.6,formerly smoked,1
1,Female,61.0,0,0,Yes,Self-employed,Rural,202.21,28.893237,never smoked,1
2,Male,80.0,0,1,Yes,Private,Rural,105.92,32.5,never smoked,1
3,Female,49.0,0,0,Yes,Private,Urban,171.23,34.4,smokes,1
4,Female,79.0,1,0,Yes,Self-employed,Rural,174.12,24.0,never smoked,1


In [41]:
from sklearn.compose import ColumnTransformer
from sklearn.preprocessing import LabelEncoder, OneHotEncoder

In [42]:
le_transformer = LabelEncoder()

X[:, 0] = le_transformer.fit_transform(X[:, 0])
X[:, 4] = le_transformer.fit_transform(X[:, 4])
X[:, 6] = le_transformer.fit_transform(X[:, 6])

In [43]:
col_transformer = ColumnTransformer(transformers=[("onehot", OneHotEncoder(),[5,9])], remainder="passthrough")
X = np.array(col_transformer.fit_transform(X))

In [44]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, random_state = 1)

In [45]:
from sklearn.preprocessing import StandardScaler

scaler = StandardScaler()
X_train = scaler.fit_transform(X_train)
X_test = scaler.fit_transform(X_test)

(X_train.shape, X_test.shape, y_train.shape, y_test.shape)

((4087, 17), (1022, 17), (4087,), (1022,))

In [46]:
#upsampling using imblearn

from imblearn.over_sampling import SMOTE

smote = SMOTE(random_state=1)

X_train, y_train = smote.fit_resample(X_train, y_train.ravel())

In [47]:
print(X_train.shape)
print(X_test.shape)
print(y_train.shape)
print(y_test.shape)


(7796, 17)
(1022, 17)
(7796,)
(1022,)


In [48]:
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, roc_auc_score, confusion_matrix, ConfusionMatrixDisplay, precision_score, recall_score, f1_score,classification_report, roc_curve, auc, precision_recall_curve, average_precision_score
from sklearn.model_selection import cross_val_score

In [49]:
lr_model = LogisticRegression()

In [50]:
lr_model.fit(X_train, y_train)

In [51]:
y_pred = lr_model.predict(X_test)

In [52]:
cv_score = cross_val_score(lr_model, X_train, y_train, cv= 6)

precision = precision_score(y_test, y_pred)
accuracy = accuracy_score(y_test, y_pred)
recall = recall_score(y_test, y_pred)
f1 = f1_score(y_test, y_pred)
roc_score = roc_auc_score(y_test, y_pred)
conf_matrix = confusion_matrix(y_test, y_pred)


print("cv score: ", cv_score.mean())
print("accuracy: ", accuracy)
print("precision: ", precision)
print("recall: ", recall)
print("f1 score: ", f1)
print("roc_auc_score: ", roc_score)
print("confusion_matric: ", conf_matrix)



cv score:  0.7847597757643947
accuracy:  0.7495107632093934
precision:  0.1643835616438356
recall:  0.8
f1 score:  0.2727272727272727
roc_auc_score:  0.7731808731808731
confusion_matric:  [[718 244]
 [ 12  48]]
