In [None]:
# Expose Google Drive inside the Colab VM; the dataset is read from MyDrive later on.
from google.colab import drive

drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

Mounted at /content/drive


In [None]:
!pip install pyspark
from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.master("local[*]").getOrCreate()

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting pyspark
  Downloading pyspark-3.3.1.tar.gz (281.4 MB)
[K     |████████████████████████████████| 281.4 MB 47 kB/s 
[?25hCollecting py4j==0.10.9.5
  Downloading py4j-0.10.9.5-py2.py3-none-any.whl (199 kB)
[K     |████████████████████████████████| 199 kB 69.5 MB/s 
[?25hBuilding wheels for collected packages: pyspark
  Building wheel for pyspark (setup.py) ... [?25l[?25hdone
  Created wheel for pyspark: filename=pyspark-3.3.1-py2.py3-none-any.whl size=281845512 sha256=1927d413320f28dd681ad82e0a533e0c971e5b590b9e86b62176da4937eb43b0
  Stored in directory: /root/.cache/pip/wheels/43/dc/11/ec201cd671da62fa9c5cc77078235e40722170ceba231d7598
Successfully built pyspark
Installing collected packages: py4j, pyspark
Successfully installed py4j-0.10.9.5 pyspark-3.3.1


In [None]:
# Load the stroke dataset from Drive: the first row is the header and the
# column types are inferred from the data.
csv_reader = spark.read.options(header=True, inferSchema=True)
df = csv_reader.csv("/content/drive/MyDrive/stroke_data.csv")

# Summary statistics for every column.
summaryDF = df.describe()
summaryDF.show()

+-------+------------------+------+------------------+-------------------+-------------------+------------+---------+--------------+------------------+------------------+---------------+-------------------+
|summary|                 0|gender|               age|       hypertension|      heart_disease|ever_married|work_type|Residence_type| avg_glucose_level|               bmi| smoking_status|             stroke|
+-------+------------------+------+------------------+-------------------+-------------------+------------+---------+--------------+------------------+------------------+---------------+-------------------+
|  count|             67135| 67135|             67135|              67135|              67135|       67135|    67135|         67135|             67135|             67135|          67135|              67135|
|   mean|           33568.0|  null| 51.95950845311693|0.16410218217025396|0.10142250688910405|        null|     null|          null|113.41439606762462| 29.16154047813857|  

In [None]:
# Total number of rows in the dataset.
record_count = df.count()
print(f'DF has { record_count } records')

DF has 67135 records


In [None]:
# Class balance: number of rows per stroke label, largest class first.
strokeCounts = df.groupBy("stroke").count()
strokeCounts.orderBy(col("count").desc()).show()

+------+-----+
|stroke|count|
+------+-----+
|     1|40287|
|     0|26848|
+------+-----+



In [None]:
# Register only the stroke-positive rows, so every query below counts strokes.
strokeRows = df.where(col("stroke") == 1)
strokeRows.createOrReplaceTempView("work_view")

# One subset per group of interest. Note that the last one is selected by age
# (14 or younger), not by work_type.
privateWorkDF = spark.sql("SELECT * FROM work_view WHERE work_type='Private'")
selfWorkDF = spark.sql("SELECT * FROM work_view WHERE work_type='Self-employed'")
govWorkDF = spark.sql("SELECT * FROM work_view WHERE work_type='Govt_job'")
childrenDF = spark.sql("SELECT * FROM work_view WHERE age<=14")

# Count each subset, then report in the same order as they were defined.
privateStrokes = privateWorkDF.count()
selfStrokes = selfWorkDF.count()
govStrokes = govWorkDF.count()
childrenStrokes = childrenDF.count()

print(f'DF has { privateStrokes } strokes on private work type')
print(f'DF has { selfStrokes } strokes on self employment work type')
print(f'DF has { govStrokes } strokes on government work type')
print(f'DF has { childrenStrokes } strokes on children')

DF has 23711 strokes on private work type
DF has 10807 strokes on self employment work type
DF has 5164 strokes on government work type
DF has 574 strokes on children


In [None]:
# Expose the full dataset to Spark SQL under the name "gender_view".
df.createOrReplaceTempView("gender_view")

# Number of rows per gender value (count(gender) skips NULL genders).
gender_count_sql = """
    SELECT gender, count(gender) AS count
    FROM gender_view
    GROUP BY gender"""
genderDF = spark.sql(gender_count_sql)
genderDF.show()

+------+-----+
|gender|count|
+------+-----+
|Female|39530|
| Other|   11|
|  Male|27594|
+------+-----+



In [None]:
# Expose the full dataset to Spark SQL; both queries below read this view.
df.createOrReplaceTempView("hipertension_view")

# Hypertension breakdown among rows labelled stroke=1.
with_stroke_sql = """
    SELECT hypertension, count(hypertension) AS count
    FROM hipertension_view
    WHERE stroke=1
    GROUP BY hypertension"""

# Hypertension breakdown among rows labelled stroke=0.
without_stroke_sql = """
    SELECT hypertension, count(hypertension) AS count
    FROM hipertension_view
    WHERE stroke=0
    GROUP BY hypertension"""

hipertensionWithStrokeDF = spark.sql(with_stroke_sql)
hipertensionWithStrokeDF.show()

hipertensionWithoutStrokeDF = spark.sql(without_stroke_sql)
hipertensionWithoutStrokeDF.show()

+------------+-----+
|hypertension|count|
+------------+-----+
|           1| 8817|
|           0|31470|
+------------+-----+

+------------+-----+
|hypertension|count|
+------------+-----+
|           1| 2200|
|           0|24648|
+------------+-----+



In [None]:
# Counts strokes by age, most affected ages first.
# Only stroke-positive rows are registered in the view.
df.filter(col("stroke") == 1).createOrReplaceTempView("age_view")

# ORDER BY (not SORT BY) is required here: SORT BY only orders rows inside each
# partition, so with more than one partition the "top 20" printed by show()
# would not be the global top 20. ORDER BY guarantees a total ordering.
ageDF = spark.sql("""
    SELECT age, count(*) AS count
    FROM age_view
    GROUP BY age
    ORDER BY count DESC""")
ageDF.show()

+----+-----+
| age|count|
+----+-----+
|79.0| 2916|
|78.0| 2279|
|80.0| 1858|
|81.0| 1738|
|82.0| 1427|
|77.0|  994|
|74.0|  987|
|63.0|  942|
|76.0|  892|
|70.0|  881|
|66.0|  848|
|75.0|  809|
|67.0|  801|
|57.0|  775|
|73.0|  759|
|65.0|  716|
|72.0|  709|
|68.0|  688|
|69.0|  677|
|71.0|  667|
+----+-----+
only showing top 20 rows



In [None]:
# Stroke-positive rows whose age is strictly greater than 50.
olderThan50 = col("age") > 50
hadStroke = col("stroke") == 1
elderlyDF = df.where(olderThan50).where(hadStroke)

elderlyCount = elderlyDF.count()
print(f"{ elderlyCount } people had a stroke after 50 years old")

28938 people had a stroke after 50 years old


In [None]:
# Mean glucose level for each stroke label, computed through Spark SQL.
df.createOrReplaceTempView("glucose_view")

glucose_by_class_sql = """
    SELECT stroke, AVG(avg_glucose_level) AS avg_glucose_level
    FROM glucose_view
    GROUP BY stroke"""
glucoseDF = spark.sql(glucose_by_class_sql)
glucoseDF.show()

+------+------------------+
|stroke| avg_glucose_level|
+------+------------------+
|     1|119.95307046938272|
|     0|103.60273130214506|
+------+------------------+



In [None]:
# Mean BMI for each stroke label, computed through Spark SQL.
df.createOrReplaceTempView("bmi_view")

bmi_by_class_sql = """
    SELECT stroke, AVG(bmi) AS avg_bmi
    FROM bmi_view
    GROUP BY stroke"""
bmiDF = spark.sql(bmi_by_class_sql)
bmiDF.show()

+------+------------------+
|stroke|           avg_bmi|
+------+------------------+
|     1|29.942490629729495|
|     0|27.989678933253657|
+------+------------------+



In [None]:
from pyspark.ml import Pipeline
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
from pyspark.ml.feature import VectorAssembler, StringIndexer, OneHotEncoder

# Baseline model: a decision tree trained on the numeric columns only.

# Columns used as predictors of the `stroke` label
cols = ["age", "bmi", "hypertension", "heart_disease", "avg_glucose_level"]
vecAssembler = VectorAssembler(inputCols=cols, outputCol="features")
decisionTree = DecisionTreeClassifier(labelCol='stroke', featuresCol='features')

# Pipeline: assemble the feature vector, then fit the tree
pipeline = Pipeline(stages=[vecAssembler, decisionTree])

# 70/30 train/test split. The seed is fixed so that a re-run produces the same
# split, and therefore the same model and accuracy, instead of a new random one.
train_data, test_data = df.randomSplit([0.7, 0.3], seed=42)

# Fit on the training part and score the held-out part
pipelineModel = pipeline.fit(train_data)
predictionsDF = pipelineModel.transform(test_data)

In [None]:
# Accuracy of the predictions against the `stroke` label.
evaluator = MulticlassClassificationEvaluator(labelCol='stroke', metricName="accuracy")
accuracy = evaluator.evaluate(predictionsDF)
print(f"Accuracy: { accuracy }")

Accuracy: 0.6864701211038317


In [None]:
# Second model: same decision tree, with the categorical columns added as
# one-hot encoded features.

# Categorical columns are first indexed (string -> numeric index) and then
# one-hot encoded.
categoricalCols = ["gender", "smoking_status"]
indexedCols = [x + "Index" for x in categoricalCols]
encodedCols = [x + "OHE" for x in categoricalCols]
stringIndexer = StringIndexer(inputCols=categoricalCols, outputCols=indexedCols)
oneHotEncoder = OneHotEncoder(inputCols=indexedCols, outputCols=encodedCols)

# Append the encoded columns to the existing feature list. Any encoded column
# already present is removed first, so re-running this cell does not add the
# same column twice (the previous `cols = cols + [...]` grew on every run).
cols = [c for c in cols if c not in encodedCols] + encodedCols
vecAssembler = VectorAssembler(inputCols=cols, outputCol="features")
decisionTree = DecisionTreeClassifier(labelCol='stroke', featuresCol='features')

# Redefines the pipeline: index -> encode -> assemble -> fit the tree
pipeline = Pipeline(stages=[stringIndexer, oneHotEncoder, vecAssembler, decisionTree])

# 70/30 split with a fixed seed: the result is reproducible across runs, and an
# unseeded split would make this accuracy depend on which rows happened to land
# in the test set.
train_data, test_data = df.randomSplit([0.7, 0.3], seed=42)
pipelineModel = pipeline.fit(train_data)
predictionsDF = pipelineModel.transform(test_data)

# Gets accuracy including the categorical columns
evaluator = MulticlassClassificationEvaluator(metricName="accuracy", labelCol='stroke')
print(f"Accuracy: { evaluator.evaluate(predictionsDF) }")

Accuracy: 0.8351829634073186


In [None]:
# Prints the importance of each feature to the decision tree.
va = pipelineModel.stages[-2]    # the VectorAssembler stage
tree = pipelineModel.stages[-1]  # the fitted decision tree

# featureImportances has one entry per SLOT of the assembled vector, and each
# one-hot encoded column occupies several slots. Zipping it with the assembler's
# input column names therefore mislabels the values and silently drops the
# trailing slots. The real slot names are stored in the ML attribute metadata
# of the assembled column, so they are read from there instead.
featureField = predictionsDF.schema[va.getOutputCol()]
attrGroups = featureField.metadata.get("ml_attr", {}).get("attrs", {})
slotNames = {
    attr["idx"]: attr["name"]
    for group in attrGroups.values()
    for attr in group
}

# One (slot name, importance) pair per vector slot; slots without metadata fall
# back to a positional name so that nothing is dropped.
importances = [
    (slotNames.get(idx, f"slot_{idx}"), float(score))
    for idx, score in enumerate(tree.featureImportances.toArray())
]
print(importances)

[('age', 0.1662830647450132), ('bmi', 0.0008064603705455093), ('hypertension', 0.0), ('heart_disease', 0.0), ('avg_glucose_level', 0.007570020040638776), ('genderOHE', 0.000542931280392311), ('smoking_statusOHE', 0.0)]
