In [278]:
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import IntegerType

In [279]:
spark = SparkSession.builder \
        .master("local[4]") \
        .appName("stroke analysis") \
        .getOrCreate()

In [280]:
healthdata = spark.read.csv("./healthcarestrokedata/train_2v.csv", header= True)
healthdata.show()

+-----+------+---+------------+-------------+------------+-------------+--------------+-----------------+----+---------------+------+
|   id|gender|age|hypertension|heart_disease|ever_married|    work_type|Residence_type|avg_glucose_level| bmi| smoking_status|stroke|
+-----+------+---+------------+-------------+------------+-------------+--------------+-----------------+----+---------------+------+
|30669|  Male|  3|           0|            0|          No|     children|         Rural|            95.12|  18|           null|     0|
|30468|  Male| 58|           1|            0|         Yes|      Private|         Urban|            87.96|39.2|   never smoked|     0|
|16523|Female|  8|           0|            0|          No|      Private|         Urban|           110.89|17.6|           null|     0|
|56543|Female| 70|           0|            0|         Yes|      Private|         Rural|            69.04|35.9|formerly smoked|     0|
|46136|  Male| 14|           0|            0|          No| Nev

In [281]:
healthdata.count()

43400

In [282]:
healthdata.columns

['id',
 'gender',
 'age',
 'hypertension',
 'heart_disease',
 'ever_married',
 'work_type',
 'Residence_type',
 'avg_glucose_level',
 'bmi',
 'smoking_status',
 'stroke']

In [283]:
healthdata.printSchema()

root
 |-- id: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- age: string (nullable = true)
 |-- hypertension: string (nullable = true)
 |-- heart_disease: string (nullable = true)
 |-- ever_married: string (nullable = true)
 |-- work_type: string (nullable = true)
 |-- Residence_type: string (nullable = true)
 |-- avg_glucose_level: string (nullable = true)
 |-- bmi: string (nullable = true)
 |-- smoking_status: string (nullable = true)
 |-- stroke: string (nullable = true)



In [284]:
healthdata = healthdata.withColumn("age",col("age").cast(IntegerType())) \
.withColumn("hypertension", col("hypertension").cast(IntegerType())) \
.withColumn("heart_disease", col("heart_disease").cast(IntegerType())) \
.withColumn("avg_glucose_level", col("avg_glucose_level").cast(IntegerType())) \
.withColumn("bmi", col("bmi").cast(IntegerType())) \
.withColumn("stroke", col("stroke").cast(IntegerType()))

In [285]:
healthdata.dtypes

[('id', 'string'),
 ('gender', 'string'),
 ('age', 'int'),
 ('hypertension', 'int'),
 ('heart_disease', 'int'),
 ('ever_married', 'string'),
 ('work_type', 'string'),
 ('Residence_type', 'string'),
 ('avg_glucose_level', 'int'),
 ('bmi', 'int'),
 ('smoking_status', 'string'),
 ('stroke', 'int')]

In [286]:
healthdata.describe()

DataFrame[summary: string, id: string, gender: string, age: string, hypertension: string, heart_disease: string, ever_married: string, work_type: string, Residence_type: string, avg_glucose_level: string, bmi: string, smoking_status: string, stroke: string]

In [287]:
healthdata.groupby("stroke").count().show()

+------+-----+
|stroke|count|
+------+-----+
|     1|  783|
|     0|42617|
+------+-----+



healthdata.groupby("work_type").count().sort(col("count").desc()).show()

In [289]:
healthdata.createOrReplaceTempView("tmptable")

In [290]:
spark.sql("SELECT work_type, count(work_type) as work_type_count FROM tmptable WHERE stroke == 1 \
          GROUP BY work_type ORDER BY work_type_count DESC").show()

+-------------+---------------+
|    work_type|work_type_count|
+-------------+---------------+
|      Private|            441|
|Self-employed|            251|
|     Govt_job|             89|
|     children|              2|
+-------------+---------------+



In [291]:
spark.sql("SELECT gender, count(gender) as count_gender, count(gender)*100/sum(count(gender)) \
          over() as percent FROM tmptable GROUP BY gender").show()

+------+------------+-------------------+
|gender|count_gender|            percent|
+------+------------+-------------------+
|Female|       25665|  59.13594470046083|
| Other|          11|0.02534562211981567|
|  Male|       17724|  40.83870967741935|
+------+------------+-------------------+



In [292]:
spark.sql("SELECT gender, count(gender), (COUNT(gender) * 100.0) /(SELECT count(gender) FROM tmptable WHERE gender == 'Male') as percentage FROM tmptable \
WHERE stroke = '1' and gender = 'Male' GROUP BY gender").show()

+------+-------------+----------------+
|gender|count(gender)|      percentage|
+------+-------------+----------------+
|  Male|          352|1.98600767321146|
+------+-------------+----------------+



In [293]:
spark.sql("SELECT gender, count(gender), (COUNT(gender) * 100.0) /(SELECT count(gender) FROM tmptable WHERE gender == 'Female') as percentage FROM tmptable \
WHERE stroke = '1' and gender = 'Female' GROUP BY gender").show()

+------+-------------+----------------+
|gender|count(gender)|      percentage|
+------+-------------+----------------+
|Female|          431|1.67932982661212|
+------+-------------+----------------+



In [294]:
spark.sql("SELECT age, count(age) as age_count, (COUNT(age) * 100) / (SELECT count(age) FROM tmptable) as percent from tmptable \
GROUP BY age").show()

+---+---------+------------------+
|age|age_count|           percent|
+---+---------+------------------+
| 31|      592|1.3640552995391706|
| 65|      529|1.2188940092165899|
| 53|      701|1.6152073732718895|
| 78|      698| 1.608294930875576|
| 34|      540|1.2442396313364055|
| 81|      454|1.0460829493087558|
| 28|      540|1.2442396313364055|
| 76|      336|0.7741935483870968|
| 27|      558|1.2857142857142858|
| 26|      503|1.1589861751152073|
| 44|      671|1.5460829493087558|
| 12|      398|0.9170506912442397|
| 22|      503|1.1589861751152073|
| 47|      684| 1.576036866359447|
|  1|      629|1.4493087557603688|
| 52|      721|1.6612903225806452|
| 13|      419|0.9654377880184332|
| 16|      426|0.9815668202764977|
|  6|      246|0.5668202764976958|
|  3|      402|0.9262672811059908|
+---+---------+------------------+
only showing top 20 rows



In [295]:
healthdata.filter((col("age") > 50) & (col("stroke") == 1)).count()

708

## Cleaning the data

In [296]:
healthdata.filter(healthdata["smoking_status"].isNull()).count()

13292

In [297]:
healthdata.where(healthdata["smoking_status"].isNull()).count()

13292

In [298]:
healthdata.columns

['id',
 'gender',
 'age',
 'hypertension',
 'heart_disease',
 'ever_married',
 'work_type',
 'Residence_type',
 'avg_glucose_level',
 'bmi',
 'smoking_status',
 'stroke']

In [299]:
for column in healthdata.columns:
    count_null = healthdata.filter(healthdata[column].isNull()).count()
    if count_null > 0:
        print("Total null values in column {} is {}".format(column, count_null))
    

Total null values in column bmi is 1462
Total null values in column smoking_status is 13292


In [300]:
healthdata1 = healthdata.na.fill("No Info", subset = ['smoking_status'])


In [301]:
healthdata1.select(mean(healthdata['bmi'])).collect()

[Row(avg(bmi)=28.155562973913874)]

In [302]:
from pyspark.sql.functions import mean
mean = healthdata1.select(mean(healthdata['bmi'])).collect()
mean_bmi = mean[0][0]
healthdata1 = healthdata1.na.fill(mean_bmi,['bmi'])

In [303]:
from pyspark.ml.feature import (VectorAssembler,OneHotEncoder,
                                StringIndexer)

In [304]:
gender_indexer = StringIndexer(inputCol="gender", outputCol="genderIndexed")
gender_encoder = OneHotEncoder(inputCol="genderIndexed", outputCol="genderVector")
ever_married_indexer = StringIndexer(inputCol="ever_married", outputCol="ever_marriedIndexed")
ever_married_encoder = OneHotEncoder(inputCol="ever_marriedIndexed", outputCol="ever_marriedVector")
work_type_indexer = StringIndexer(inputCol="work_type", outputCol="work_typeIndexed")
work_type_encoder = OneHotEncoder(inputCol="work_typeIndexed", outputCol="work_typeVector")
Residence_type_indexer = StringIndexer(inputCol="Residence_type", outputCol="Residence_typeIndexed")
Residence_type_encoder = OneHotEncoder(inputCol="Residence_typeIndexed", outputCol="Residence_typeVector")
smoking_status_indexer = StringIndexer(inputCol="smoking_status", outputCol="smoking_statusIndexed")
smoking_status_encoder = OneHotEncoder(inputCol="smoking_statusIndexed", outputCol="smoking_statusVector")

In [305]:
assembler = VectorAssembler(inputCols=['genderVector',
 'age',
 'hypertension',
 'heart_disease',
 'ever_marriedVector',
 'work_typeVector',
 'Residence_typeVector',
 'avg_glucose_level',
 'bmi',
 'smoking_statusVector'],outputCol='features')

In [306]:
from pyspark.ml.classification import DecisionTreeClassifier

In [307]:
dtc = DecisionTreeClassifier(labelCol='stroke',featuresCol='features')

In [308]:
from pyspark.ml import Pipeline

pipeline = Pipeline(stages=[gender_indexer, ever_married_indexer, work_type_indexer, Residence_type_indexer,
                           smoking_status_indexer, gender_encoder, ever_married_encoder, work_type_encoder,
                           Residence_type_encoder, smoking_status_encoder, assembler, dtc])

In [309]:
train_data,test_data = healthdata1.randomSplit([0.7,0.3])

In [310]:
# ,,,,,,,,,,,,,,,,,,,,dt_model = pipeline.fit(train_data)

In [311]:
dtc_pred = dt_model.transform(test_data)

In [313]:
dtc_pred.show()

+-----+------+---+------------+-------------+------------+---------+--------------+-----------------+---+---------------+------+-------------+-------------------+----------------+---------------------+---------------------+-------------+------------------+---------------+--------------------+--------------------+--------------------+--------------+--------------------+----------+
|   id|gender|age|hypertension|heart_disease|ever_married|work_type|Residence_type|avg_glucose_level|bmi| smoking_status|stroke|genderIndexed|ever_marriedIndexed|work_typeIndexed|Residence_typeIndexed|smoking_statusIndexed| genderVector|ever_marriedVector|work_typeVector|Residence_typeVector|smoking_statusVector|            features| rawPrediction|         probability|prediction|
+-----+------+---+------------+-------------+------------+---------+--------------+-----------------+---+---------------+------+-------------+-------------------+----------------+---------------------+---------------------+-----------

In [314]:
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
acc_evaluator = MulticlassClassificationEvaluator(labelCol="stroke", predictionCol="prediction", metricName="accuracy")

In [315]:
dtc_acc = acc_evaluator.evaluate(dtc_pred)

In [317]:
dtc_acc

0.9833206397562834

In [318]:
print('A Decision Tree algorithm had an accuracy of: {0:2.2f}%'.format(dtc_acc*100))

A Decision Tree algorithm had an accuracy of: 98.33%
