In [1]:
from pyspark.sql import SparkSession
import pyspark.sql as sparksql
spark = SparkSession.builder.appName('stroke').getOrCreate()
train = spark.read.csv('Stroke_Prediction_Train.csv', inferSchema=True,header=True)

In [2]:
train.printSchema()

root
 |-- id: integer (nullable = true)
 |-- gender: string (nullable = true)
 |-- age: double (nullable = true)
 |-- hypertension: integer (nullable = true)
 |-- heart_disease: integer (nullable = true)
 |-- ever_married: string (nullable = true)
 |-- work_type: string (nullable = true)
 |-- Residence_type: string (nullable = true)
 |-- avg_glucose_level: double (nullable = true)
 |-- bmi: double (nullable = true)
 |-- smoking_status: string (nullable = true)
 |-- stroke: integer (nullable = true)



## Summary of numerical variables

In [3]:
from pyspark.sql.functions import format_number
cols = ["age", "hypertension", "heart_disease", "avg_glucose_level", "bmi", "stroke"]
summary_train = train.select(cols).describe()
summary_train.select(summary_train['summary'], 
                    format_number(summary_train['age'].cast('float'),3).alias('Age'),
                    format_number(summary_train['hypertension'].cast('float'),3).alias('Hypertension'),
                    format_number(summary_train['heart_disease'].cast('float'),3).alias('Heart_disease'),
                    format_number(summary_train['avg_glucose_level'].cast('float'),3).alias('Avg_glucose_level'),
                    format_number(summary_train['bmi'].cast('float'),3).alias('BMI'),
                    format_number(summary_train['stroke'].cast('float'),3).alias('Stroke'),
                    ).show()

+-------+----------+------------+-------------+-----------------+----------+----------+
|summary|       Age|Hypertension|Heart_disease|Avg_glucose_level|       BMI|    Stroke|
+-------+----------+------------+-------------+-----------------+----------+----------+
|  count|43,400.000|  43,400.000|   43,400.000|       43,400.000|41,938.000|43,400.000|
|   mean|    42.218|       0.094|        0.048|          104.483|    28.605|     0.018|
| stddev|    22.520|       0.291|        0.213|           43.112|     7.770|     0.133|
|    min|     0.080|       0.000|        0.000|           55.000|    10.100|     0.000|
|    max|    82.000|       1.000|        1.000|          291.050|    97.600|     1.000|
+-------+----------+------------+-------------+-----------------+----------+----------+



## Correlations between numeric variables

In [4]:
from pyspark.ml.stat import Correlation
from pyspark.ml.feature import VectorAssembler

# convert to vector column first

cols = ["age", "hypertension", "heart_disease", "avg_glucose_level", "stroke"]
vector_col = "corr"
assembler = VectorAssembler(inputCols=cols, outputCol=vector_col)
train_vector = assembler.transform(train).select(vector_col)

# get correlation matrix
matrix = Correlation.corr(train_vector, vector_col)
matrix.collect()[0]["pearson({})".format(vector_col)].values

array([1.        , 0.27216879, 0.25018839, 0.23762684, 0.15604896,
       0.27216879, 1.        , 0.11977703, 0.16021129, 0.07533225,
       0.25018839, 0.11977703, 1.        , 0.14693807, 0.11376294,
       0.23762684, 0.16021129, 0.14693807, 1.        , 0.0789171 ,
       0.15604896, 0.07533225, 0.11376294, 0.0789171 , 1.        ])

## Exploring categorical variables

In [5]:
train.groupBy('gender').count().show()

+------+-----+
|gender|count|
+------+-----+
|Female|25665|
| Other|   11|
|  Male|17724|
+------+-----+



In [6]:
train.crosstab('gender', 'stroke').show()

+-------------+-----+---+
|gender_stroke|    0|  1|
+-------------+-----+---+
|        Other|   11|  0|
|         Male|17372|352|
|       Female|25234|431|
+-------------+-----+---+



In [7]:
train.groupBy('ever_married').count().show()

+------------+-----+
|ever_married|count|
+------------+-----+
|          No|15462|
|         Yes|27938|
+------------+-----+



In [8]:
train.crosstab('ever_married', 'stroke').show()

+-------------------+-----+---+
|ever_married_stroke|    0|  1|
+-------------------+-----+---+
|                 No|15382| 80|
|                Yes|27235|703|
+-------------------+-----+---+



In [9]:
train.groupBy('smoking_status').count().orderBy('count').show()

+---------------+-----+
| smoking_status|count|
+---------------+-----+
|         smokes| 6562|
|formerly smoked| 7493|
|           null|13292|
|   never smoked|16053|
+---------------+-----+



In [10]:
train.crosstab('smoking_status', 'stroke').show()

+---------------------+-----+---+
|smoking_status_stroke|    0|  1|
+---------------------+-----+---+
|      formerly smoked| 7272|221|
|               smokes| 6429|133|
|         never smoked|15769|284|
|                 null|13147|145|
+---------------------+-----+---+



In [11]:
train.groupBy('work_type').count().orderBy('count', ascending = False).show()

+-------------+-----+
|    work_type|count|
+-------------+-----+
|      Private|24834|
|Self-employed| 6793|
|     children| 6156|
|     Govt_job| 5440|
| Never_worked|  177|
+-------------+-----+



In [12]:
train.crosstab('work_type', 'stroke').show()

+----------------+-----+---+
|work_type_stroke|    0|  1|
+----------------+-----+---+
|        children| 6154|  2|
|    Never_worked|  177|  0|
|   Self-employed| 6542|251|
|         Private|24393|441|
|        Govt_job| 5351| 89|
+----------------+-----+---+



In [13]:
train.groupBy('Residence_type').count().show()

+--------------+-----+
|Residence_type|count|
+--------------+-----+
|         Urban|21756|
|         Rural|21644|
+--------------+-----+



In [14]:
train.crosstab('work_type', 'stroke').show()

+----------------+-----+---+
|work_type_stroke|    0|  1|
+----------------+-----+---+
|        children| 6154|  2|
|    Never_worked|  177|  0|
|   Self-employed| 6542|251|
|         Private|24393|441|
|        Govt_job| 5351| 89|
+----------------+-----+---+



In [15]:
train.groupBy('stroke').count().show()

+------+-----+
|stroke|count|
+------+-----+
|     1|  783|
|     0|42617|
+------+-----+



In [16]:
train.createOrReplaceTempView('table')

## Exploring using SQL queries

In [17]:
spark.sql("SELECT gender, count(gender) as count_gender, count(gender)*100/sum(count(gender)) over() as percent FROM table GROUP BY gender").show()

+------+------------+-------------------+
|gender|count_gender|            percent|
+------+------------+-------------------+
|Female|       25665|  59.13594470046083|
| Other|          11|0.02534562211981567|
|  Male|       17724|  40.83870967741935|
+------+------------+-------------------+



In [18]:
spark.sql("SELECT gender, count(gender), (COUNT(gender) * 100.0) /(SELECT count(gender) FROM table WHERE gender == 'Male') as percentage FROM table WHERE stroke = '1' and gender = 'Male' GROUP BY gender").show()

+------+-------------+----------------+
|gender|count(gender)|      percentage|
+------+-------------+----------------+
|  Male|          352|1.98600767321146|
+------+-------------+----------------+



In [19]:
spark.sql("SELECT gender, count(gender), (COUNT(gender) * 100.0) /(SELECT count(gender) FROM table WHERE gender == 'Female') as percentage FROM table WHERE stroke = '1' and gender = 'Female' GROUP BY gender").show()

+------+-------------+----------------+
|gender|count(gender)|      percentage|
+------+-------------+----------------+
|Female|          431|1.67932982661212|
+------+-------------+----------------+



In [20]:
spark.sql("SELECT age, count(age) as age_count FROM table WHERE stroke == 1 GROUP BY age ORDER BY age_count DESC").show()

+----+---------+
| age|age_count|
+----+---------+
|79.0|       70|
|78.0|       57|
|80.0|       49|
|81.0|       43|
|82.0|       36|
|70.0|       25|
|76.0|       24|
|74.0|       24|
|77.0|       24|
|67.0|       23|
|75.0|       23|
|72.0|       21|
|68.0|       20|
|69.0|       20|
|59.0|       20|
|71.0|       19|
|57.0|       19|
|63.0|       18|
|65.0|       18|
|66.0|       17|
+----+---------+
only showing top 20 rows



In [21]:
spark.sql("Select count(*) from table where stroke == 1 and age > 50").show()

+--------+
|count(1)|
+--------+
|     708|
+--------+



In [22]:
# fill in missing values
train_f = train.na.fill('No Info', subset=['smoking_status'])
# fill in miss values with mean
from pyspark.sql.functions import mean
mean = train_f.select(mean(train_f['bmi'])).collect()
mean_bmi = mean[0][0]
train_f = train_f.na.fill(mean_bmi,['bmi'])

In [23]:
from pyspark.ml.feature import VectorAssembler,OneHotEncoder,StringIndexer

In [24]:
gender_indexer = StringIndexer(inputCol = 'gender', outputCol = 'genderIndex')
gender_encoder = OneHotEncoder(inputCol = 'genderIndex', outputCol = 'genderVec')

ever_married_indexer = StringIndexer(inputCol = 'ever_married', outputCol = 'ever_marriedIndex')
ever_married_encoder = OneHotEncoder(inputCol = 'ever_marriedIndex', outputCol = 'ever_marriedVec')

work_type_indexer = StringIndexer(inputCol = 'work_type', outputCol = 'work_typeIndex')
work_type_encoder = OneHotEncoder(inputCol = 'work_typeIndex', outputCol = 'work_typeVec')

Residence_type_indexer = StringIndexer(inputCol = 'Residence_type', outputCol = 'Residence_typeIndex')
Residence_type_encoder = OneHotEncoder(inputCol = 'Residence_typeIndex', outputCol = 'Residence_typeVec')

smoking_status_indexer = StringIndexer(inputCol = 'smoking_status', outputCol = 'smoking_statusIndex')
smoking_status_encoder = OneHotEncoder(inputCol = 'smoking_statusIndex', outputCol = 'smoking_statusVec')

In [25]:
train.printSchema()

root
 |-- id: integer (nullable = true)
 |-- gender: string (nullable = true)
 |-- age: double (nullable = true)
 |-- hypertension: integer (nullable = true)
 |-- heart_disease: integer (nullable = true)
 |-- ever_married: string (nullable = true)
 |-- work_type: string (nullable = true)
 |-- Residence_type: string (nullable = true)
 |-- avg_glucose_level: double (nullable = true)
 |-- bmi: double (nullable = true)
 |-- smoking_status: string (nullable = true)
 |-- stroke: integer (nullable = true)



In [26]:
assembler = VectorAssembler(inputCols=['genderVec',
 'age',
 'hypertension',
 'heart_disease',
 'ever_marriedVec',
 'work_typeVec',
 'Residence_typeVec',
 'avg_glucose_level',
 'bmi',
 'smoking_statusVec'],outputCol='features')

In [27]:
train_data,test_data = train_f.randomSplit([0.6,0.4])

In [None]:
from pyspark.ml.classification import DecisionTreeClassifier
from pyspark.ml.classification import LogisticRegression
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.classification import GBTClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator

from pyspark.ml import Pipeline

classifiers = [
    LogisticRegression(labelCol='stroke',featuresCol='features', maxIter=1000),
    DecisionTreeClassifier(labelCol='stroke',featuresCol='features', maxDepth=7),
    RandomForestClassifier(labelCol='stroke',featuresCol='features'),
]

model_names = ['Logistic Regression', 'Decision Tree', 'Random Forest']
i = 0

for cls in classifiers:
    pipeline = Pipeline(stages=[gender_indexer, ever_married_indexer, work_type_indexer, Residence_type_indexer,
                           smoking_status_indexer, gender_encoder, ever_married_encoder, work_type_encoder,
                           Residence_type_encoder, smoking_status_encoder, assembler, cls])

    model = pipeline.fit(train_data)
    model_predictions = model.transform(test_data)
    
    acc_evaluator = MulticlassClassificationEvaluator(labelCol="stroke", predictionCol="prediction", metricName="f1")
    acc = acc_evaluator.evaluate(model_predictions)
    
    print(model_names[i]+' has an F1 score of: {0:2.2f}%'.format(acc*100))
    i += 1

    

## Conclusion : Though Random Forest and Logistic Regression have the best same F1 score, it is wise to go with Logistic Regression classifier as it is a simpler, less computationally expensive model.

In [None]:
Predict_data = spark.read.csv('Stroke_Prediction_Test.csv', inferSchema=True,header=True)

In [None]:
cls = LogisticRegression(labelCol = 'stroke', featuresCol='features', maxIter=1000)

In [None]:
pipeline_predict = Pipeline(stages=[gender_indexer, ever_married_indexer, work_type_indexer, Residence_type_indexer,
                           smoking_status_indexer, gender_encoder, ever_married_encoder, work_type_encoder,
                           Residence_type_encoder, smoking_status_encoder, assembler, cls])
model = pipeline_predict.fit(train_data)

In [None]:
final_predictions = model.transform(Predict_data)

In [None]:
final_predictions.printSchema()

In [None]:
final_predictions.head(1)