In [96]:
!apt-get install openjdk-8-jdk-headless -qq > /dev/null
!wget -q http://archive.apache.org/dist/spark/spark-3.1.1/spark-3.1.1-bin-hadoop3.2.tgz
!tar xf spark-3.1.1-bin-hadoop3.2.tgz
!pip install -q findspark

import os
os.environ["JAVA_HOME"] = "/usr/lib/jvm/java-8-openjdk-amd64"
os.environ["SPARK_HOME"] = "/content/spark-3.1.1-bin-hadoop3.2"

^C


In [86]:
import findspark
findspark.init()

In [87]:
from pyspark.sql import SparkSession
import pyspark.sql.functions as F

In [88]:
spark = SparkSession.builder.getOrCreate()

In [89]:
df = spark.read.csv("/content/healthcare-dataset-stroke-data.csv", header=True, inferSchema=True)

In [90]:
df.show()

+-----+------+----+------------+-------------+------------+-------------+--------------+-----------------+----+---------------+------+
|   id|gender| age|hypertension|heart_disease|ever_married|    work_type|Residence_type|avg_glucose_level| bmi| smoking_status|stroke|
+-----+------+----+------------+-------------+------------+-------------+--------------+-----------------+----+---------------+------+
| 9046|  Male|67.0|           0|            1|         Yes|      Private|         Urban|           228.69|36.6|formerly smoked|     1|
|51676|Female|61.0|           0|            0|         Yes|Self-employed|         Rural|           202.21| N/A|   never smoked|     1|
|31112|  Male|80.0|           0|            1|         Yes|      Private|         Rural|           105.92|32.5|   never smoked|     1|
|60182|Female|49.0|           0|            0|         Yes|      Private|         Urban|           171.23|34.4|         smokes|     1|
| 1665|Female|79.0|           1|            0|         

In [91]:
df.select("stroke").show()

+------+
|stroke|
+------+
|     1|
|     1|
|     1|
|     1|
|     1|
|     1|
|     1|
|     1|
|     1|
|     1|
|     1|
|     1|
|     1|
|     1|
|     1|
|     1|
|     1|
|     1|
|     1|
|     1|
+------+
only showing top 20 rows



In [92]:
df.count()

5110

In [93]:
len(df.columns)

12

In [94]:
df.printSchema()

root
 |-- id: integer (nullable = true)
 |-- gender: string (nullable = true)
 |-- age: double (nullable = true)
 |-- hypertension: integer (nullable = true)
 |-- heart_disease: integer (nullable = true)
 |-- ever_married: string (nullable = true)
 |-- work_type: string (nullable = true)
 |-- Residence_type: string (nullable = true)
 |-- avg_glucose_level: double (nullable = true)
 |-- bmi: string (nullable = true)
 |-- smoking_status: string (nullable = true)
 |-- stroke: integer (nullable = true)



In [95]:
df.groupBy("stroke").count().show()

+------+-----+
|stroke|count|
+------+-----+
|     1|  249|
|     0| 4861|
+------+-----+



In [97]:
df.groupBy("age").count().show()

+----+-----+
| age|count|
+----+-----+
|67.0|   49|
|70.0|   45|
| 8.0|   58|
|69.0|   54|
| 7.0|   32|
|1.16|    4|
|0.16|    3|
|1.08|    8|
|1.72|    6|
|49.0|   79|
| 1.4|    3|
|0.72|    5|
|29.0|   51|
|64.0|   53|
|75.0|   53|
|0.24|    5|
|47.0|   75|
|42.0|   71|
|44.0|   75|
|35.0|   54|
+----+-----+
only showing top 20 rows



In [98]:
df.groupBy("gender").count().show()

+------+-----+
|gender|count|
+------+-----+
|Female| 2994|
| Other|    1|
|  Male| 2115|
+------+-----+



In [102]:
from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import VectorAssembler

In [100]:
df.printSchema()

root
 |-- id: integer (nullable = true)
 |-- gender: string (nullable = true)
 |-- age: double (nullable = true)
 |-- hypertension: integer (nullable = true)
 |-- heart_disease: integer (nullable = true)
 |-- ever_married: string (nullable = true)
 |-- work_type: string (nullable = true)
 |-- Residence_type: string (nullable = true)
 |-- avg_glucose_level: double (nullable = true)
 |-- bmi: string (nullable = true)
 |-- smoking_status: string (nullable = true)
 |-- stroke: integer (nullable = true)



In [101]:
indexer = StringIndexer(inputCol = "gender", outputCol = "gender_up")
indexed = indexer.fit(df).transform(df)

In [104]:
indexer = StringIndexer(inputCol = "ever_married", outputCol = "ever_married_up")
indexed = indexer.fit(indexed).transform(indexed)

In [105]:
indexer = StringIndexer(inputCol = "work_type", outputCol = "work_type_up")
indexed = indexer.fit(indexed).transform(indexed)

In [106]:
indexer = StringIndexer(inputCol = "Residence_type", outputCol = "Residence_type_up")
indexed = indexer.fit(indexed).transform(indexed)

In [107]:
indexer = StringIndexer(inputCol = "bmi", outputCol = "bmi_up")
indexed = indexer.fit(indexed).transform(indexed)

In [108]:
indexer = StringIndexer(inputCol = "smoking_status", outputCol = "smoking_status_up")
indexed = indexer.fit(indexed).transform(indexed)

In [109]:
indexed.show()

+-----+------+----+------------+-------------+------------+-------------+--------------+-----------------+----+---------------+------+---------+---------------+------------+-----------------+------+-----------------+
|   id|gender| age|hypertension|heart_disease|ever_married|    work_type|Residence_type|avg_glucose_level| bmi| smoking_status|stroke|gender_up|ever_married_up|work_type_up|Residence_type_up|bmi_up|smoking_status_up|
+-----+------+----+------------+-------------+------------+-------------+--------------+-----------------+----+---------------+------+---------+---------------+------------+-----------------+------+-----------------+
| 9046|  Male|67.0|           0|            1|         Yes|      Private|         Urban|           228.69|36.6|formerly smoked|     1|      1.0|            0.0|         0.0|              0.0| 155.0|              2.0|
|51676|Female|61.0|           0|            0|         Yes|Self-employed|         Rural|           202.21| N/A|   never smoked|     

In [110]:

assembler = VectorAssembler(inputCols=["gender_up","age", "hypertension", "heart_disease", "ever_married_up",   "work_type_up", "Residence_type_up" , "avg_glucose_level", "bmi_up", "smoking_status_up"],
                            outputCol = "features")

In [111]:
assembler

VectorAssembler_57a7fadadf4b

In [112]:
output = assembler.transform(indexed)

In [113]:
output.show()

+-----+------+----+------------+-------------+------------+-------------+--------------+-----------------+----+---------------+------+---------+---------------+------------+-----------------+------+-----------------+--------------------+
|   id|gender| age|hypertension|heart_disease|ever_married|    work_type|Residence_type|avg_glucose_level| bmi| smoking_status|stroke|gender_up|ever_married_up|work_type_up|Residence_type_up|bmi_up|smoking_status_up|            features|
+-----+------+----+------------+-------------+------------+-------------+--------------+-----------------+----+---------------+------+---------+---------------+------------+-----------------+------+-----------------+--------------------+
| 9046|  Male|67.0|           0|            1|         Yes|      Private|         Urban|           228.69|36.6|formerly smoked|     1|      1.0|            0.0|         0.0|              0.0| 155.0|              2.0|[1.0,67.0,0.0,1.0...|
|51676|Female|61.0|           0|            0|  

In [114]:
output.select("features", "stroke").show(truncate= False)

+-----------------------------------------------+------+
|features                                       |stroke|
+-----------------------------------------------+------+
|[1.0,67.0,0.0,1.0,0.0,0.0,0.0,228.69,155.0,2.0]|1     |
|(10,[1,5,6,7],[61.0,1.0,1.0,202.21])           |1     |
|[1.0,80.0,0.0,1.0,0.0,0.0,1.0,105.92,99.0,0.0] |1     |
|(10,[1,7,8,9],[49.0,171.23,119.0,3.0])         |1     |
|[0.0,79.0,1.0,0.0,0.0,1.0,1.0,174.12,28.0,0.0] |1     |
|(10,[0,1,7,8,9],[1.0,81.0,186.21,50.0,2.0])    |1     |
|[1.0,74.0,1.0,1.0,0.0,0.0,1.0,70.09,86.0,0.0]  |1     |
|(10,[1,4,7,8],[69.0,1.0,94.39,59.0])           |1     |
|(10,[1,6,7,9],[59.0,1.0,76.15,1.0])            |1     |
|(10,[1,7,8,9],[78.0,58.57,23.0,1.0])           |1     |
|(10,[1,2,6,7,8],[81.0,1.0,1.0,80.43,42.0])     |1     |
|[0.0,61.0,0.0,1.0,0.0,3.0,1.0,120.46,223.0,3.0]|1     |
|(10,[1,7,8,9],[54.0,104.51,8.0,3.0])           |1     |
|(10,[0,1,3,7,9],[1.0,78.0,1.0,219.84,1.0])     |1     |
|(10,[1,3,7,8],[79.0,1.0,214.09

In [115]:
model_df = output.select("features","stroke")

In [116]:
model_df.show()

+--------------------+------+
|            features|stroke|
+--------------------+------+
|[1.0,67.0,0.0,1.0...|     1|
|(10,[1,5,6,7],[61...|     1|
|[1.0,80.0,0.0,1.0...|     1|
|(10,[1,7,8,9],[49...|     1|
|[0.0,79.0,1.0,0.0...|     1|
|(10,[0,1,7,8,9],[...|     1|
|[1.0,74.0,1.0,1.0...|     1|
|(10,[1,4,7,8],[69...|     1|
|(10,[1,6,7,9],[59...|     1|
|(10,[1,7,8,9],[78...|     1|
|(10,[1,2,6,7,8],[...|     1|
|[0.0,61.0,0.0,1.0...|     1|
|(10,[1,7,8,9],[54...|     1|
|(10,[0,1,3,7,9],[...|     1|
|(10,[1,3,7,8],[79...|     1|
|[0.0,50.0,1.0,0.0...|     1|
|[1.0,64.0,0.0,1.0...|     1|
|[1.0,75.0,1.0,0.0...|     1|
|(10,[1,4,7,8],[60...|     1|
|[1.0,57.0,0.0,1.0...|     1|
+--------------------+------+
only showing top 20 rows



In [120]:
training_df, test_df = model_df.randomSplit([0.8, 0.2])

In [121]:
print(training_df.count())

4119


In [122]:
print(test_df.count())

991


In [123]:
from pyspark.ml.classification import LogisticRegression

In [124]:
log_reg

LogisticRegressionModel: uid=LogisticRegression_c7db12a9bd5b, numClasses=2, numFeatures=10

In [125]:
train_results = log_reg.evaluate(training_df).predictions

In [126]:
train_results.show()

+--------------------+------+--------------------+--------------------+----------+
|            features|stroke|       rawPrediction|         probability|prediction|
+--------------------+------+--------------------+--------------------+----------+
|(10,[0,1,2,5,7],[...|     0|[4.75333801103636...|[0.99145085422500...|       0.0|
|(10,[0,1,2,6,7],[...|     1|[1.13941893410074...|[0.75757293863208...|       0.0|
|(10,[0,1,2,7,8],[...|     0|[4.96633210974834...|[0.99307956466982...|       0.0|
|(10,[0,1,2,7,8],[...|     0|[3.83336353340501...|[0.97882151560014...|       0.0|
|(10,[0,1,2,7,8],[...|     0|[4.09490384162854...|[0.98361557314964...|       0.0|
|(10,[0,1,2,7,8],[...|     0|[3.72579402916739...|[0.97647289953218...|       0.0|
|(10,[0,1,2,7,8],[...|     0|[3.59903744165459...|[0.97337807481998...|       0.0|
|(10,[0,1,2,7,8],[...|     0|[2.88175342876386...|[0.94693703771258...|       0.0|
|(10,[0,1,2,7,8],[...|     0|[3.02015678140512...|[0.95347648065887...|       0.0|
|(10

In [127]:
results = log_reg.evaluate(test_df).predictions

In [128]:
results.show()

+--------------------+------+--------------------+--------------------+----------+
|            features|stroke|       rawPrediction|         probability|prediction|
+--------------------+------+--------------------+--------------------+----------+
|(10,[0,1,2,7,8],[...|     0|[3.94168027995710...|[0.98095422080095...|       0.0|
|(10,[0,1,2,7,8],[...|     0|[3.80467194426990...|[0.97821849722395...|       0.0|
|(10,[0,1,2,7,8],[...|     0|[2.98945001580435...|[0.95209523166136...|       0.0|
|(10,[0,1,3,5,7],[...|     0|[2.54649326690379...|[0.92733757566770...|       0.0|
|(10,[0,1,3,7,8],[...|     0|[3.13502356520881...|[0.95831453593406...|       0.0|
|(10,[0,1,3,7,8],[...|     0|[1.91619375416141...|[0.87171338629596...|       0.0|
|(10,[0,1,3,7,9],[...|     1|[1.92708065945725...|[0.87292594080106...|       0.0|
|(10,[0,1,3,7,9],[...|     0|[1.10326944568710...|[0.75087219987168...|       0.0|
|(10,[0,1,4,7,8],[...|     0|[6.45014496010363...|[0.99842220031323...|       0.0|
|(10

In [129]:
results.select(["stroke", "prediction"]).show(10)

+------+----------+
|stroke|prediction|
+------+----------+
|     0|       0.0|
|     0|       0.0|
|     0|       0.0|
|     0|       0.0|
|     0|       0.0|
|     0|       0.0|
|     1|       0.0|
|     0|       0.0|
|     0|       0.0|
|     0|       0.0|
+------+----------+
only showing top 10 rows



In [130]:
tp = results[(results.stroke==1) & (results.prediction==1)].count()
tp

1

In [131]:
tn = results[(results.stroke==0) & (results.prediction==0)].count()
tn

940

In [132]:
fp = results[(results.stroke==0) & (results.prediction==1)].count()
fp

0

In [133]:
fn = results[(results.stroke==1) & (results.prediction==0)].count()
fn

50

In [134]:
accuracy = float((tp+tn)/ results.count())
print(accuracy)

0.9495459132189707


In [135]:
recall = float(tn)/ (tp+tn)
print(recall)

0.9989373007438895
