### Imports

In [1]:
import findspark
# findspark.init() is called before pyspark is imported
# (presumably so the local Spark installation is importable — confirm for your setup).
findspark.init()
from pyspark import SparkContext, SparkConf

# Reuse the active SparkContext if one exists, otherwise start a local one with
# 4 worker threads. Constructing SparkContext(...) directly raises when this
# cell is executed a second time in the same kernel.
sc = SparkContext.getOrCreate(SparkConf().setMaster("local[4]"))

In [2]:
from pyspark.sql import SparkSession
# Get the SparkSession used for all DataFrame work below
# (getOrCreate returns an existing session or builds a new one).
spark = SparkSession.builder.getOrCreate()

In [3]:
import pyspark.sql.functions as F

In [4]:
from pyspark.sql.types import *

### Dataset

In [36]:
# Location of the raw CSV, relative to the notebook's working directory.
dataset_path = "./dataset_diabetes/diabetic_data.csv"

# Comma-separated file with a header row; column types are inferred and
# every "?" cell is read as null.
csv_reader = spark.read.options(sep=",", header=True, inferSchema=True, nullValue="?")
dataset = csv_reader.csv(dataset_path)

In [37]:
# Check what kind of object spark.read.csv returned.
type(dataset)

pyspark.sql.dataframe.DataFrame

In [38]:
# Number of rows in the loaded dataset.
dataset.count()

101766

In [39]:
# (column name, type) pairs as inferred by the CSV reader.
dataset.dtypes

[('encounter_id', 'int'),
 ('patient_nbr', 'int'),
 ('race', 'string'),
 ('gender', 'string'),
 ('age', 'string'),
 ('weight', 'string'),
 ('admission_type_id', 'int'),
 ('discharge_disposition_id', 'int'),
 ('admission_source_id', 'int'),
 ('time_in_hospital', 'int'),
 ('payer_code', 'string'),
 ('medical_specialty', 'string'),
 ('num_lab_procedures', 'int'),
 ('num_procedures', 'int'),
 ('num_medications', 'int'),
 ('number_outpatient', 'int'),
 ('number_emergency', 'int'),
 ('number_inpatient', 'int'),
 ('diag_1', 'string'),
 ('diag_2', 'string'),
 ('diag_3', 'string'),
 ('number_diagnoses', 'int'),
 ('max_glu_serum', 'string'),
 ('A1Cresult', 'string'),
 ('metformin', 'string'),
 ('repaglinide', 'string'),
 ('nateglinide', 'string'),
 ('chlorpropamide', 'string'),
 ('glimepiride', 'string'),
 ('acetohexamide', 'string'),
 ('glipizide', 'string'),
 ('glyburide', 'string'),
 ('tolbutamide', 'string'),
 ('pioglitazone', 'string'),
 ('rosiglitazone', 'string'),
 ('acarbose', 'string'),

In [40]:
# Inspect the first row as a {column name: value} dict.
dataset.first().asDict()

{'encounter_id': 2278392,
 'patient_nbr': 8222157,
 'race': 'Caucasian',
 'gender': 'Female',
 'age': '[0-10)',
 'weight': None,
 'admission_type_id': 6,
 'discharge_disposition_id': 25,
 'admission_source_id': 1,
 'time_in_hospital': 1,
 'payer_code': None,
 'medical_specialty': 'Pediatrics-Endocrinology',
 'num_lab_procedures': 41,
 'num_procedures': 0,
 'num_medications': 1,
 'number_outpatient': 0,
 'number_emergency': 0,
 'number_inpatient': 0,
 'diag_1': '250.83',
 'diag_2': None,
 'diag_3': None,
 'number_diagnoses': 1,
 'max_glu_serum': 'None',
 'A1Cresult': 'None',
 'metformin': 'No',
 'repaglinide': 'No',
 'nateglinide': 'No',
 'chlorpropamide': 'No',
 'glimepiride': 'No',
 'acetohexamide': 'No',
 'glipizide': 'No',
 'glyburide': 'No',
 'tolbutamide': 'No',
 'pioglitazone': 'No',
 'rosiglitazone': 'No',
 'acarbose': 'No',
 'miglitol': 'No',
 'troglitazone': 'No',
 'tolazamide': 'No',
 'examide': 'No',
 'citoglipton': 'No',
 'insulin': 'No',
 'glyburide-metformin': 'No',
 'gli

In [41]:
# Print the schema as a tree (column name, type, nullability).
dataset.printSchema()

root
 |-- encounter_id: integer (nullable = true)
 |-- patient_nbr: integer (nullable = true)
 |-- race: string (nullable = true)
 |-- gender: string (nullable = true)
 |-- age: string (nullable = true)
 |-- weight: string (nullable = true)
 |-- admission_type_id: integer (nullable = true)
 |-- discharge_disposition_id: integer (nullable = true)
 |-- admission_source_id: integer (nullable = true)
 |-- time_in_hospital: integer (nullable = true)
 |-- payer_code: string (nullable = true)
 |-- medical_specialty: string (nullable = true)
 |-- num_lab_procedures: integer (nullable = true)
 |-- num_procedures: integer (nullable = true)
 |-- num_medications: integer (nullable = true)
 |-- number_outpatient: integer (nullable = true)
 |-- number_emergency: integer (nullable = true)
 |-- number_inpatient: integer (nullable = true)
 |-- diag_1: string (nullable = true)
 |-- diag_2: string (nullable = true)
 |-- diag_3: string (nullable = true)
 |-- number_diagnoses: integer (nullable = true)
 |-

#### Columns with the most missing data (marked `?` in the raw file): `weight`, `payer_code`, and `medical_specialty`

In [42]:
# Count the null entries of every column in a single aggregation.
# F.when(...) without otherwise() is null for rows that do not match, and
# F.count skips nulls, so each aggregate counts only the rows where the column is null.
null_count_exprs = []
for column_name in dataset.columns:
    null_marker = F.when(F.isnull(column_name), column_name)
    null_count_exprs.append(F.count(null_marker).alias(column_name))
dataset.select(null_count_exprs).first().asDict()

{'encounter_id': 0,
 'patient_nbr': 0,
 'race': 2273,
 'gender': 0,
 'age': 0,
 'weight': 98569,
 'admission_type_id': 0,
 'discharge_disposition_id': 0,
 'admission_source_id': 0,
 'time_in_hospital': 0,
 'payer_code': 40256,
 'medical_specialty': 49949,
 'num_lab_procedures': 0,
 'num_procedures': 0,
 'num_medications': 0,
 'number_outpatient': 0,
 'number_emergency': 0,
 'number_inpatient': 0,
 'diag_1': 21,
 'diag_2': 358,
 'diag_3': 1423,
 'number_diagnoses': 0,
 'max_glu_serum': 0,
 'A1Cresult': 0,
 'metformin': 0,
 'repaglinide': 0,
 'nateglinide': 0,
 'chlorpropamide': 0,
 'glimepiride': 0,
 'acetohexamide': 0,
 'glipizide': 0,
 'glyburide': 0,
 'tolbutamide': 0,
 'pioglitazone': 0,
 'rosiglitazone': 0,
 'acarbose': 0,
 'miglitol': 0,
 'troglitazone': 0,
 'tolazamide': 0,
 'examide': 0,
 'citoglipton': 0,
 'insulin': 0,
 'glyburide-metformin': 0,
 'glipizide-metformin': 0,
 'glimepiride-pioglitazone': 0,
 'metformin-rosiglitazone': 0,
 'metformin-pioglitazone': 0,
 'change': 0,

In [43]:
# Drop the three columns with the most missing values (weight, payer_code,
# medical_specialty) in one call.
# Dropping by name is a no-op for a column that is already gone, so this cell
# can be re-run safely; `dataset.weight` raises once the column has been dropped.
dataset = dataset.drop("weight", "payer_code", "medical_specialty")

In [44]:
# Remove every row that still contains at least one null value.
dataset = dataset.na.drop(how="any")

In [45]:
# Re-count the nulls per column to confirm the cleanup left none behind.
# F.when(...) without otherwise() is null for rows that do not match, and
# F.count skips nulls, so each aggregate counts only the rows where the column is null.
remaining_null_exprs = []
for column_name in dataset.columns:
    null_marker = F.when(F.isnull(column_name), column_name)
    remaining_null_exprs.append(F.count(null_marker).alias(column_name))
dataset.select(remaining_null_exprs).first().asDict()

{'encounter_id': 0,
 'patient_nbr': 0,
 'race': 0,
 'gender': 0,
 'age': 0,
 'admission_type_id': 0,
 'discharge_disposition_id': 0,
 'admission_source_id': 0,
 'time_in_hospital': 0,
 'num_lab_procedures': 0,
 'num_procedures': 0,
 'num_medications': 0,
 'number_outpatient': 0,
 'number_emergency': 0,
 'number_inpatient': 0,
 'diag_1': 0,
 'diag_2': 0,
 'diag_3': 0,
 'number_diagnoses': 0,
 'max_glu_serum': 0,
 'A1Cresult': 0,
 'metformin': 0,
 'repaglinide': 0,
 'nateglinide': 0,
 'chlorpropamide': 0,
 'glimepiride': 0,
 'acetohexamide': 0,
 'glipizide': 0,
 'glyburide': 0,
 'tolbutamide': 0,
 'pioglitazone': 0,
 'rosiglitazone': 0,
 'acarbose': 0,
 'miglitol': 0,
 'troglitazone': 0,
 'tolazamide': 0,
 'examide': 0,
 'citoglipton': 0,
 'insulin': 0,
 'glyburide-metformin': 0,
 'glipizide-metformin': 0,
 'glimepiride-pioglitazone': 0,
 'metformin-rosiglitazone': 0,
 'metformin-pioglitazone': 0,
 'change': 0,
 'diabetesMed': 0,
 'readmitted': 0}

In [46]:
# Row count after dropping the sparse columns and the rows containing nulls.
dataset.count()

98053

In [47]:
# Drop the identifier columns in one call.
# Dropping by name is a no-op for a column that is already gone, so this cell
# can be re-run safely; `dataset.encounter_id` raises once the column has been dropped.
dataset = dataset.drop(
    "encounter_id",
    "patient_nbr",
    "admission_type_id",
    "discharge_disposition_id",
    "admission_source_id",
)

#### Convert string/categorical values into integers
The string columns listed below are one-hot encoded: each distinct value becomes its own 0/1 integer column.

In [48]:
# String-typed columns that will be one-hot encoded, in schema order.
# `readmitted` and diag_1..3 are left out on purpose: they are converted
# separately further down.
# Filtering with a set (instead of list.remove) keeps this cell re-runnable:
# remove() raises ValueError once one of these columns has been dropped or
# is no longer string-typed.
not_one_hot_encoded = {'readmitted', 'diag_1', 'diag_2', 'diag_3'}
string_columns = [
    name
    for name, dtype in dataset.dtypes
    if dtype == 'string' and name not in not_one_hot_encoded
]
string_columns

['race',
 'gender',
 'age',
 'max_glu_serum',
 'A1Cresult',
 'metformin',
 'repaglinide',
 'nateglinide',
 'chlorpropamide',
 'glimepiride',
 'acetohexamide',
 'glipizide',
 'glyburide',
 'tolbutamide',
 'pioglitazone',
 'rosiglitazone',
 'acarbose',
 'miglitol',
 'troglitazone',
 'tolazamide',
 'examide',
 'citoglipton',
 'insulin',
 'glyburide-metformin',
 'glipizide-metformin',
 'glimepiride-pioglitazone',
 'metformin-rosiglitazone',
 'metformin-pioglitazone',
 'change',
 'diabetesMed']

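The `readmitted` column is left out of the list above; it is converted to a numeric type separately, further below. The next cell lists the columns that are already numeric.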

In [49]:
# Names of all integer-typed columns, in schema order
# (DataFrame.dtypes reports integer columns as 'int').
numeric_columns = [name for name, dtype in dataset.dtypes if dtype == 'int']
numeric_columns

['time_in_hospital',
 'num_lab_procedures',
 'num_procedures',
 'num_medications',
 'number_outpatient',
 'number_emergency',
 'number_inpatient',
 'number_diagnoses']

In [50]:
# One-hot encode every categorical column: for each distinct value `v` of a
# column `c`, add an integer column named "c_v" that is 1 where c == v, else 0.
# All indicator expressions are collected first and appended with a single
# select(); chaining one withColumn() per value adds a projection each time
# and makes the query plan very large.
# NOTE: assumes no generated name collides with an existing column
# (withColumn would replace such a column, select would duplicate it).
indicator_columns = []
for column_name in string_columns:
    distinct_rows = dataset.select(column_name).distinct().collect()
    for row in distinct_rows:
        value = row[0]
        indicator = F.when(dataset[column_name] == value, 1).otherwise(0)
        indicator_columns.append(indicator.alias(column_name + '_' + value))
dataset = dataset.select('*', *indicator_columns)

In [52]:
# Remove the original string columns now that their indicator columns exist.
# One name-based drop replaces the per-column loop and is a no-op for names
# that are already gone, so the cell can be re-run safely.
dataset = dataset.drop(*string_columns)

In [53]:
# Print the schema after one-hot encoding the string columns.
dataset.printSchema()

root
 |-- time_in_hospital: integer (nullable = true)
 |-- num_lab_procedures: integer (nullable = true)
 |-- num_procedures: integer (nullable = true)
 |-- num_medications: integer (nullable = true)
 |-- number_outpatient: integer (nullable = true)
 |-- number_emergency: integer (nullable = true)
 |-- number_inpatient: integer (nullable = true)
 |-- diag_1: string (nullable = true)
 |-- diag_2: string (nullable = true)
 |-- diag_3: string (nullable = true)
 |-- number_diagnoses: integer (nullable = true)
 |-- readmitted: string (nullable = true)
 |-- race_Caucasian: integer (nullable = false)
 |-- race_Other: integer (nullable = false)
 |-- race_AfricanAmerican: integer (nullable = false)
 |-- race_Hispanic: integer (nullable = false)
 |-- race_Asian: integer (nullable = false)
 |-- gender_Female: integer (nullable = false)
 |-- gender_Unknown/Invalid: integer (nullable = false)
 |-- gender_Male: integer (nullable = false)
 |-- age_[70-80): integer (nullable = false)
 |-- age_[90-100)

`diag_1`, `diag_2`, and `diag_3` have many unique values, so one-hot encoding all of them would greatly increase the size of the dataset.
Instead, we encode only the values that occur more than 2500 times in the dataset.

In [86]:
# diag_1: keep only the codes that occur more than 2500 times
diag1_dataset = (
    dataset
    .select("diag_1")
    .groupBy("diag_1")
    .count()
    .where(F.col("count") > 2500)
)

In [88]:
# Show the frequent diag_1 codes together with their counts.
diag1_dataset.collect()

[Row(diag_1='428', count=6730),
 Row(diag_1='786', count=3900),
 Row(diag_1='410', count=3514),
 Row(diag_1='486', count=3412),
 Row(diag_1='414', count=6374),
 Row(diag_1='427', count=2701)]

In [89]:
# Add one 0/1 indicator column ("diag_1_<code>") per frequent diag_1 code.
for frequent_row in diag1_dataset.collect():
    code = frequent_row[0]
    matches_code = dataset['diag_1'] == code
    dataset = dataset.withColumn('diag_1_' + code, F.when(matches_code, 1).otherwise(0))

In [90]:
# The raw diag_1 column is no longer needed once its indicator columns exist.
# Dropping by name is a no-op if the column is already gone (safe to re-run).
dataset = dataset.drop('diag_1')

In [91]:
# diag_2: keep only the codes that occur more than 2500 times
diag2_dataset = (
    dataset
    .select("diag_2")
    .groupBy("diag_2")
    .count()
    .where(F.col("count") > 2500)
)

In [92]:
# Add one 0/1 indicator column ("diag_2_<code>") per frequent diag_2 code.
for frequent_row in diag2_dataset.collect():
    code = frequent_row[0]
    matches_code = dataset['diag_2'] == code
    dataset = dataset.withColumn('diag_2_' + code, F.when(matches_code, 1).otherwise(0))

In [93]:
# The raw diag_2 column is no longer needed once its indicator columns exist.
# Dropping by name is a no-op if the column is already gone (safe to re-run).
dataset = dataset.drop('diag_2')

In [94]:
# diag_3: keep only the codes that occur more than 2500 times
diag3_dataset = (
    dataset
    .select("diag_3")
    .groupBy("diag_3")
    .count()
    .where(F.col("count") > 2500)
)

In [95]:
# Add one 0/1 indicator column ("diag_3_<code>") per frequent diag_3 code.
for frequent_row in diag3_dataset.collect():
    code = frequent_row[0]
    matches_code = dataset['diag_3'] == code
    dataset = dataset.withColumn('diag_3_' + code, F.when(matches_code, 1).otherwise(0))

In [96]:
# The raw diag_3 column is no longer needed once its indicator columns exist.
# Dropping by name is a no-op if the column is already gone (safe to re-run).
dataset = dataset.drop('diag_3')

In [97]:
# Print the schema after encoding and dropping the diag_* columns.
dataset.printSchema()

root
 |-- time_in_hospital: integer (nullable = true)
 |-- num_lab_procedures: integer (nullable = true)
 |-- num_procedures: integer (nullable = true)
 |-- num_medications: integer (nullable = true)
 |-- number_outpatient: integer (nullable = true)
 |-- number_emergency: integer (nullable = true)
 |-- number_inpatient: integer (nullable = true)
 |-- number_diagnoses: integer (nullable = true)
 |-- readmitted: string (nullable = true)
 |-- race_Caucasian: integer (nullable = false)
 |-- race_Other: integer (nullable = false)
 |-- race_AfricanAmerican: integer (nullable = false)
 |-- race_Hispanic: integer (nullable = false)
 |-- race_Asian: integer (nullable = false)
 |-- gender_Female: integer (nullable = false)
 |-- gender_Unknown/Invalid: integer (nullable = false)
 |-- gender_Male: integer (nullable = false)
 |-- age_[70-80): integer (nullable = false)
 |-- age_[90-100): integer (nullable = false)
 |-- age_[40-50): integer (nullable = false)
 |-- age_[10-20): integer (nullable = fa

Now convert the `readmitted` column to integers: `NO` becomes 0 and every other value becomes 1.

In [111]:
# Recode `readmitted` as an integer: "NO" becomes 0, every other value becomes 1.
# NOTE(review): the column is overwritten in place, so re-running this cell on the
# already-converted column will not reproduce this encoding — restart from the load cell instead.
not_readmitted = dataset["readmitted"] == "NO"
dataset = dataset.withColumn("readmitted", F.when(not_readmitted, 0).otherwise(1))

In [112]:
# Row count after the encoding steps.
dataset.count()

98053

In [113]:
# Print the schema after recoding `readmitted` as an integer.
dataset.printSchema()

root
 |-- time_in_hospital: integer (nullable = true)
 |-- num_lab_procedures: integer (nullable = true)
 |-- num_procedures: integer (nullable = true)
 |-- num_medications: integer (nullable = true)
 |-- number_outpatient: integer (nullable = true)
 |-- number_emergency: integer (nullable = true)
 |-- number_inpatient: integer (nullable = true)
 |-- number_diagnoses: integer (nullable = true)
 |-- readmitted: integer (nullable = false)
 |-- race_Caucasian: integer (nullable = false)
 |-- race_Other: integer (nullable = false)
 |-- race_AfricanAmerican: integer (nullable = false)
 |-- race_Hispanic: integer (nullable = false)
 |-- race_Asian: integer (nullable = false)
 |-- gender_Female: integer (nullable = false)
 |-- gender_Unknown/Invalid: integer (nullable = false)
 |-- gender_Male: integer (nullable = false)
 |-- age_[70-80): integer (nullable = false)
 |-- age_[90-100): integer (nullable = false)
 |-- age_[40-50): integer (nullable = false)
 |-- age_[10-20): integer (nullable = 

Parquet does not allow any of the characters `" ,;{}()\n\t="` in column names, so we rename the one-hot `age_*` columns, replacing the closing `)` with `]`.

In [116]:
# Replace the closing ")" with "]" in every column name that contains one,
# e.g. "age_[70-80)" becomes "age_[70-80]". Driving the rename from the actual
# column names avoids hard-coding the age buckets that happen to be present.
for column_name in dataset.columns:
    if ")" in column_name:
        dataset = dataset.withColumnRenamed(column_name, column_name.replace(")", "]"))

In [118]:
# dataset.printSchema()

In [119]:
# Persist the prepared dataset as Parquet.
# mode='overwrite' replaces a previous export; with the default mode the write
# fails when the target path already exists, which breaks a full re-run.
dataset.write.parquet('./parquet_files/dataset.parquet', mode='overwrite')