In [88]:
from pyspark.sql import SparkSession
import pyspark.sql as sparksql
# Start (or reuse) the Spark session used throughout the notebook.
spark = (
    SparkSession.builder
    .appName('stroke')
    .getOrCreate()
)

# Read the CSV with its header row and let Spark infer the column types.
train = spark.read.csv('train_2v.csv', header=True, inferSchema=True)

In [89]:
# Show which class the loaded dataset object is.
type(train)

pyspark.sql.dataframe.DataFrame

In [65]:
# Print the column names, inferred types and nullability of the loaded data.
train.printSchema()

root
 |-- id: integer (nullable = true)
 |-- gender: string (nullable = true)
 |-- age: double (nullable = true)
 |-- hypertension: integer (nullable = true)
 |-- heart_disease: integer (nullable = true)
 |-- ever_married: string (nullable = true)
 |-- work_type: string (nullable = true)
 |-- Residence_type: string (nullable = true)
 |-- avg_glucose_level: double (nullable = true)
 |-- bmi: double (nullable = true)
 |-- smoking_status: string (nullable = true)
 |-- stroke: integer (nullable = true)



In [94]:
# Summary statistics for a few columns of interest.
cols = ["age", "heart_disease", "bmi", "ever_married"]
selected_summary = train.select(*cols).describe()
selected_summary.show()

+-------+------------------+-------------------+------------------+------------+
|summary|               age|      heart_disease|               bmi|ever_married|
+-------+------------------+-------------------+------------------+------------+
|  count|             43400|              43400|             41938|       43400|
|   mean| 42.21789400921646|0.04751152073732719|28.605038390004545|        null|
| stddev|22.519648680503554|0.21273274050209726| 7.770020497238766|        null|
|    min|              0.08|                  0|              10.1|          No|
|    max|              82.0|                  1|              97.6|         Yes|
+-------+------------------+-------------------+------------------+------------+



In [99]:
# Stroke cases broken down by marital status.
# The query needs the SQL view 'table'. This cell comes before the cell that first
# registers that view, so register it here as well; createOrReplaceTempView is safe
# to call repeatedly, and this keeps a fresh top-to-bottom run working.
train.createOrReplaceTempView('table')
spark.sql("Select ever_married, count(ever_married) from table where stroke = 1 group by ever_married").show()

+------------+-------------------+
|ever_married|count(ever_married)|
+------------+-------------------+
|          No|                 80|
|         Yes|                703|
+------------+-------------------+



In [66]:
# Number of rows per value of the target column `stroke`.
stroke_counts = train.groupBy('stroke').count()
stroke_counts.show()

+------+-----+
|stroke|count|
+------+-----+
|     1|  783|
|     0|42617|
+------+-----+



In [67]:
# Expose the DataFrame to spark.sql(...) under the view name 'table'.
train.createOrReplaceTempView('table')

In [68]:
# Stroke cases per work type, most frequent first.
work_type_query = (
    "SELECT work_type, count(work_type) as work_type_count "
    "FROM table "
    "WHERE stroke == 1 "
    "GROUP BY work_type "
    "ORDER BY work_type_count DESC"
)
spark.sql(work_type_query).show()

+-------------+---------------+
|    work_type|work_type_count|
+-------------+---------------+
|      Private|            441|
|Self-employed|            251|
|     Govt_job|             89|
|     children|              2|
+-------------+---------------+



In [69]:
# Row count per gender and its share of all rows; the window sum over the
# grouped counts supplies the denominator.
gender_share_query = (
    "SELECT gender, count(gender) as count_gender, "
    "count(gender)*100/sum(count(gender)) over() as percent "
    "FROM table "
    "GROUP BY gender"
)
spark.sql(gender_share_query).show()

+------+------------+-------------------+
|gender|count_gender|            percent|
+------+------------+-------------------+
|Female|       25665|  59.13594470046083|
| Other|          11|0.02534562211981567|
|  Male|       17724|  40.83870967741935|
+------+------------+-------------------+



In [70]:
# Stroke cases among men, as a percentage of all male rows
# (the scalar subquery counts every male row as the denominator).
male_stroke_rate_query = (
    "SELECT gender, count(gender), "
    "(COUNT(gender) * 100.0) /"
    "(SELECT count(gender) FROM table WHERE gender == 'Male') as percentage "
    "FROM table "
    "WHERE stroke = '1' and gender = 'Male' "
    "GROUP BY gender"
)
spark.sql(male_stroke_rate_query).show()

+------+-------------+----------------+
|gender|count(gender)|      percentage|
+------+-------------+----------------+
|  Male|          352|1.98600767321146|
+------+-------------+----------------+



In [71]:
# Stroke cases among women, as a percentage of all female rows
# (the scalar subquery counts every female row as the denominator).
female_stroke_rate_query = (
    "SELECT gender, count(gender), "
    "(COUNT(gender) * 100.0) /"
    "(SELECT count(gender) FROM table WHERE gender == 'Female') as percentage "
    "FROM table "
    "WHERE stroke = '1' and gender = 'Female' "
    "GROUP BY gender"
)
spark.sql(female_stroke_rate_query).show()

+------+-------------+----------------+
|gender|count(gender)|      percentage|
+------+-------------+----------------+
|Female|          431|1.67932982661212|
+------+-------------+----------------+



In [72]:
# Stroke cases per age value, most frequent ages first.
age_query = (
    "SELECT age, count(age) as age_count "
    "FROM table "
    "WHERE stroke == 1 "
    "GROUP BY age "
    "ORDER BY age_count DESC"
)
spark.sql(age_query).show()

+----+---------+
| age|age_count|
+----+---------+
|79.0|       70|
|78.0|       57|
|80.0|       49|
|81.0|       43|
|82.0|       36|
|70.0|       25|
|77.0|       24|
|74.0|       24|
|76.0|       24|
|75.0|       23|
|67.0|       23|
|72.0|       21|
|68.0|       20|
|59.0|       20|
|69.0|       20|
|57.0|       19|
|71.0|       19|
|63.0|       18|
|65.0|       18|
|66.0|       17|
+----+---------+
only showing top 20 rows



In [73]:
# Count the stroke cases whose age is above 50.
# DataFrame-API alternative, kept for reference only:
#   train.filter((train['stroke'] == 1) & (train['age'] > 50)).count()
over_50_query = "Select count(*) from table where stroke == 1 and age > 50"
spark.sql(over_50_query).show()

+--------+
|count(1)|
+--------+
|     708|
+--------+



In [87]:
# Summary statistics for every column; the per-column `count` row shows which
# columns have missing values.
full_summary = train.describe()
full_summary.show()

+-------+-----------------+------+------------------+-------------------+-------------------+------------+---------+--------------+------------------+------------------+---------------+-------------------+
|summary|               id|gender|               age|       hypertension|      heart_disease|ever_married|work_type|Residence_type| avg_glucose_level|               bmi| smoking_status|             stroke|
+-------+-----------------+------+------------------+-------------------+-------------------+------------+---------+--------------+------------------+------------------+---------------+-------------------+
|  count|            43400| 43400|             43400|              43400|              43400|       43400|    43400|         43400|             43400|             41938|          30108|              43400|
|   mean|36326.14235023042|  null| 42.21789400921646|0.09357142857142857|0.04751152073732719|        null|     null|          null|104.48274999999916|28.605038390004545|       

In [40]:
# Fill in missing values.
# Missing smoking_status entries become their own explicit 'No Info' category.
train_f = train.na.fill('No Info', subset=['smoking_status'])

# Missing bmi entries are replaced with the column mean.
# The query result is stored only in `mean_bmi`; binding it to the name `mean`
# would overwrite the `mean` function imported just below.
# NOTE(review): the mean is taken over the full dataset, i.e. before the later
# train/test split, so the test rows contribute to the imputed value.
from pyspark.sql.functions import mean
mean_bmi = train_f.select(mean(train_f['bmi'])).collect()[0][0]
train_f = train_f.na.fill(mean_bmi, ['bmi'])

In [41]:
from pyspark.ml.feature import (VectorAssembler,OneHotEncoder,
                                StringIndexer)

In [54]:
def _indexer_and_encoder(column):
    """Build the StringIndexer / OneHotEncoder pair for one categorical column.

    The indexer reads `column` and writes `<column>Index`; the encoder reads
    `<column>Index` and writes `<column>Vec`.

    Returns:
        (indexer, encoder) tuple of unfitted stages.
    """
    index_column = column + 'Index'
    indexer = StringIndexer(inputCol=column, outputCol=index_column)
    encoder = OneHotEncoder(inputCol=index_column, outputCol=column + 'Vec')
    return indexer, encoder


# One indexer/encoder pair per string-typed feature column.
gender_indexer, gender_encoder = _indexer_and_encoder('gender')
ever_married_indexer, ever_married_encoder = _indexer_and_encoder('ever_married')
work_type_indexer, work_type_encoder = _indexer_and_encoder('work_type')
Residence_type_indexer, Residence_type_encoder = _indexer_and_encoder('Residence_type')
smoking_status_indexer, smoking_status_encoder = _indexer_and_encoder('smoking_status')

In [55]:
# Re-print the schema of the raw frame `train` (not the imputed `train_f`)
# as a reference for the column names used by the feature assembler.
train.printSchema()

root
 |-- id: integer (nullable = true)
 |-- gender: string (nullable = true)
 |-- age: double (nullable = true)
 |-- hypertension: integer (nullable = true)
 |-- heart_disease: integer (nullable = true)
 |-- ever_married: string (nullable = true)
 |-- work_type: string (nullable = true)
 |-- Residence_type: string (nullable = true)
 |-- avg_glucose_level: double (nullable = true)
 |-- bmi: double (nullable = true)
 |-- smoking_status: string (nullable = true)
 |-- stroke: integer (nullable = true)



In [56]:
# Model inputs: the numeric columns as-is plus the one-hot vectors produced by
# the encoders, combined into a single 'features' vector column.
feature_columns = [
    'genderVec',
    'age',
    'hypertension',
    'heart_disease',
    'ever_marriedVec',
    'work_typeVec',
    'Residence_typeVec',
    'avg_glucose_level',
    'bmi',
    'smoking_statusVec',
]
assembler = VectorAssembler(inputCols=feature_columns, outputCol='features')

In [57]:
from pyspark.ml.classification import DecisionTreeClassifier
# Decision tree that reads the assembled 'features' vector and predicts 'stroke'.
dtc = DecisionTreeClassifier(
    featuresCol='features',
    labelCol='stroke',
)

In [58]:
from pyspark.ml import Pipeline
# Stage order matters: index the string columns, one-hot encode the indices,
# assemble the feature vector, then fit the classifier.
indexer_stages = [
    gender_indexer,
    ever_married_indexer,
    work_type_indexer,
    Residence_type_indexer,
    smoking_status_indexer,
]
encoder_stages = [
    gender_encoder,
    ever_married_encoder,
    work_type_encoder,
    Residence_type_encoder,
    smoking_status_encoder,
]
pipeline = Pipeline(stages=indexer_stages + encoder_stages + [assembler, dtc])

In [59]:
# 70/30 train/test split of the imputed data.
# A fixed seed keeps the split, and therefore the reported metrics, repeatable
# from one run of the notebook to the next.
train_data, test_data = train_f.randomSplit([0.7, 0.3], seed=42)

In [60]:
# Fit every pipeline stage (indexers, encoders, assembler, decision tree) on the training split.
model = pipeline.fit(train_data)

Exception ignored in: <function JavaWrapper.__del__ at 0x123ae0040>
Traceback (most recent call last):
  File "/Users/vinithangadi/.local/share/virtualenvs/HealthcareProject-u2jW5S_3/lib/python3.8/site-packages/pyspark/ml/wrapper.py", line 42, in __del__
    if SparkContext._active_spark_context and self._java_obj is not None:
AttributeError: 'OneHotEncoder' object has no attribute '_java_obj'


In [61]:
# Apply the fitted pipeline to the held-out split to obtain predictions.
dtc_predictions = model.transform(test_data)

In [62]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
# Compare the prediction with the true label and compute test accuracy.
acc_evaluator = MulticlassClassificationEvaluator(labelCol="stroke", predictionCol="prediction", metricName="accuracy")
dtc_acc = acc_evaluator.evaluate(dtc_predictions)
print('A Decision Tree algorithm had an accuracy of: {0:2.2f}%'.format(dtc_acc*100))

# Accuracy alone says little here: only 783 of the 43,400 rows are stroke cases,
# so always predicting "no stroke" already scores about 98.2%. Also report how
# many of the actual stroke cases in the test split the model finds
# (recall for stroke = 1).
stroke_cases = dtc_predictions.filter("stroke = 1")
stroke_total = stroke_cases.count()
stroke_found = stroke_cases.filter("prediction = 1").count()
if stroke_total:
    print('Recall on stroke cases: {0:2.2f}% ({1} of {2})'.format(
        stroke_found * 100 / stroke_total, stroke_found, stroke_total))
else:
    # Guard against an empty positive class in the split (division by zero).
    print('No stroke cases in the test split; recall is undefined.')

A Decision Tree algorithm had an accuracy of: 98.27%
