In [1]:
from pyspark.sql import SparkSession
from pyspark.ml.feature import VectorAssembler
from pyspark.ml.feature import StringIndexer
from pyspark.ml.feature import OneHotEncoder
from pyspark.ml.classification import RandomForestClassifier
from pyspark.ml.evaluation import MulticlassClassificationEvaluator
import numpy as np
import pandas as pd

# import data types
from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from pyspark.sql.functions import isnan, when, count, col
from pyspark.ml import Pipeline
# set up the session
spark = SparkSession.builder \
        .master("local") \
        .appName("diabetes_classifier") \
        .getOrCreate()

In [2]:
#load all data files
#pandas is needed to read .xpt files then spark is used to create a dataframe

demo = spark.createDataFrame(pd.read_sas("DEMO_J.XPT") \
                             .drop(columns = ['SDDSRVYR', 'RIDSTATR', 'RIDAGEMN',
                                           'RIDRETH1', 'RIDEXMON', 'RIDEXAGM', 'DMQMILIZ', 'DMQADFC',
                                           'DMDBORN4', 'DMDCITZN', 'DMDYRSUS', 'DMDEDUC3', 'DMDMARTL',
                                           'RIDEXPRG', 'SIALANG', 'SIAPROXY', 'SIAINTRP', 'FIALANG', 'FIAPROXY',
                                           'FIAINTRP', 'MIALANG', 'MIAPROXY', 'MIAINTRP', 'AIALANGA', 'DMDHHSIZ',
                                           'DMDFMSIZ', 'DMDHHSZA', 'DMDHHSZB', 'DMDHHSZE', 'DMDHRGND', 'DMDHRAGZ',
                                           'DMDHREDZ', 'DMDHRMAZ', 'DMDHSEDZ', 'WTINT2YR', 'WTMEC2YR', 'SDMVPSU',
                                           'SDMVSTRA', 'INDFMIN2', 'INDFMPIR']) \
                             .rename({'SEQN': 'ParticipantID', 
                                    'RIAGENDR': 'Gender',
                                    'RIDAGEYR': 'Age',
                                    'RIDRETH3': 'Race',
                                    'DMDEDUC2': 'Education_Level',
                                    'INDHHIN2': 'Household_income'}, axis = 1))

diet_1 = spark.createDataFrame(pd.read_sas("DR1IFF_J.XPT")\
                              .drop(columns = ['WTDRD1', 'WTDR2D', 'DR1ILINE', 'DR1DRSTZ', 'DR1EXMER', 'DRABF',
                                               'DRDINT', 'DR1DBIH', 'DR1DAY', 'DR1LANG', 'DR1CCMNM', 'DR1CCMTX',
                                               'DR1_020', 'DR1_030Z', 'DR1FS', 'DR1_040Z', 'DR1IFDCD', 'DR1IGRMS',
                                               'DR1IATOC', 'DR1IATOA',
                                               'DR1IRET', 'DR1IVARA', 'DR1IACAR', 'DR1IBCAR', 'DR1ICRYP', 'DR1ILYCO',
                                               'DR1ILZ', 'DR1IVB1', 'DR1IVB2', 'DR1INIAC', 'DR1IVB6', 'DR1IFOLA',
                                               'DR1IFA', 'DR1IFF', 'DR1IFDFE', 'DR1ICHL', 'DR1IVB12', 'DR1IB12A',
                                               'DR1IVC', 'DR1IVD', 'DR1IVK', 'DR1ICALC', 'DR1IPHOS', 'DR1IMAGN',
                                               'DR1IIRON', 'DR1IZINC', 'DR1ICOPP', 'DR1ISODI', 'DR1IPOTA', 'DR1ISELE',
                                               'DR1ICAFF', 'DR1ITHEO', 'DR1IMOIS', 'DR1IS040', 'DR1IS060',
                                               'DR1IS080', 'DR1IS100', 'DR1IS120', 'DR1IS140', 'DR1IS160', 'DR1IS180',
                                               'DR1IM161', 'DR1IM181', 'DR1IM201', 'DR1IM221', 'DR1IP182', 'DR1IP183',
                                               'DR1IP184', 'DR1IP204', 'DR1IP205', 'DR1IP225', 'DR1IP226']) \
                               .rename({'SEQN': 'ParticipantID', 
                                        'DR1IKCAL': 'Energy',
                                        'DR1IPROT': 'Protein',
                                        'DR1ICARB': 'Carbohydrates',
                                        'DR1ISUGR': 'Total_sugar',
                                        'DR1IFIBE': 'Fiber',
                                        'DR1ITFAT': 'Total_fat',
                                        'DR1ISFAT': 'Sat_fat',
                                        'DR1IMFAT': 'Monounsat_fat',
                                        'DR1IPFAT': 'Polyunsat_fat',
                                        'DR1ICHOL': 'cholesterol',
                                        'DR1IALCO': 'Alcohol'}, axis = 1))

bp = spark.createDataFrame(pd.read_sas("BPX_J.XPT") \
                          .drop(columns = ['PEASCCT1', 'BPXCHR', 'BPAARM', 'BPACSZ', 'BPXPLS',
                                           'BPXPTY', 'BPXML1', 'BPXSY1', 'BPXDI1', 'BPAEN1', 
                                           'BPAEN2', 'BPXSY3', 'BPXDI3', 'BPAEN3', 'BPXSY4', 
                                           'BPXDI4', 'BPAEN4']) \
                          .rename({'SEQN' : 'ParticipantID', 
                                    'BPXPULS' : 'Pulse', #regular = 1, irregular = 2
                                    'BPXSY2' : 'SysBP',
                                    'BPXDI2' : 'DiasBP'}, axis = 1))

bm = spark.createDataFrame(pd.read_sas("BMX_J.XPT") \
                          .drop(columns = ['BMDSTATS', 'BMIWT', 'BMXRECUM', 'BMIRECUM', 'BMXHEAD',
                                           'BMIHEAD', 'BMXHT', 'BMIHT', 'BMXLEG', 'BMILEG', 'BMXARML',
                                           'BMIARML', 'BMXARMC', 'BMIARMC', 'BMIWAIST', 'BMXHIP','BMIHIP']) \
                          .rename({'SEQN': 'ParticipantID',
                                    'BMXWT': 'Weight(kg)',
                                    'BMXBMI': 'BMI',
                                    'BMXWAIST': 'Waist_Circum'}, axis = 1))

ins = spark.createDataFrame(pd.read_sas("INS_J.XPT") \
                           .drop(columns = ['WTSAF2YR', 'LBDINSI', 'LBDINLC']) \
                           .rename({'SEQN': 'ParticipantID',
                                    'LBXIN': 'Insulin'}, axis = 1))

glu = spark.createDataFrame(pd.read_sas("GLU_J.XPT") \
                           .drop(columns = ['WTSAF2YR', 'LBXGLU']) \
                           .rename({'SEQN': 'ParticipantID',
                                  'LBDGLUSI':'Glucose'}, axis = 1))

alc = spark.createDataFrame(pd.read_sas("ALQ_J.XPT") \
                           .drop(columns = ['ALQ111', 'ALQ121', 'ALQ142', 'ALQ170']) \
                           .rename({'SEQN': 'ParticipantID',
                                  'ALQ130':'Avg_Drinks',
                                  'ALQ270':'4-5_Drinks',
                                  'ALQ280':'8+Drinks',
                                  'ALQ290':'12+Drinks',
                                  'ALQ151':'4-5DrinksDaily'}, axis = 1))

bpq = spark.createDataFrame(pd.read_sas("BPQ_J.XPT") \
                           .drop(columns = ['BPQ030', 'BPD035', 'BPQ040A', 'BPQ050A',
                                           'BPQ060', 'BPQ070', 'BPQ090D', 'BPQ100D']) \
                           .rename({'SEQN': 'ParticipantID',
                                  'BPQ020': 'HighBP', #1 = Yes, 2 = no
                                  'BPQ080': 'HighChol'},axis = 1)).filter('HighBP != 9') #1 = Yes, 2 = No
    
pa = spark.createDataFrame(pd.read_sas("PAQ_J.XPT") \
                           .drop(columns = ['PAQ610', 'PAD615', 'PAQ625', 'PAD630',
                                           'PAQ640', 'PAD645', 'PAQ655', 'PAD660',
                                           'PAQ670', 'PAD675', 'PAD680'])\
                          .rename({'SEQN': 'ParticipantID',
                                    'PAQ605': 'VigWork',
                                    'PAQ620': 'ModWork',
                                    'PAQ635': 'Walk_bike',
                                    'PAQ650': 'VigActivity',
                                    'PAQ665': 'ModActivity'},axis = 1))
    
diab = spark.createDataFrame(pd.read_sas("DIQ_J.XPT") \
                             .drop(columns = ['DIQ172', 'DIQ175C', 'DIQ175D', 'DIQ175E', 'DIQ175F', 'DIQ175G',
                                               'DIQ175H', 'DIQ175I', 'DIQ175J', 'DIQ175K', 'DIQ175L', 'DIQ175M',
                                               'DIQ175N', 'DIQ175O', 'DIQ175P', 'DIQ175Q', 'DIQ175R', 'DIQ175S',
                                               'DIQ175T', 'DIQ175U', 'DIQ175V', 'DIQ175W', 'DIQ175X',
                                               'DIQ050', 'DID060', 'DIQ060U', 'DIQ070', 'DIQ230', 'DIQ240', 'DID250',
                                               'DID260', 'DIQ260U', 'DIQ275', 'DIQ280', 'DIQ291', 'DIQ300S', 'DIQ300D',
                                               'DID310S', 'DID310D', 'DID320', 'DID330', 'DID341', 'DID350', 'DIQ350U',
                                               'DIQ360', 'DIQ080', 'DIQ175B']) \
                            .rename({'SEQN': 'ParticipantID',
                                    'DIQ010': 'label', #label 
                                    'DID040': 'Diagnosed_age',
                                    'DIQ160': 'Prediabetes',
                                    'DIQ170': 'Diabetes_risk',
                                    'DIQ175A': 'Fam_hist',
                                    'DIQ180': 'Blood_test'}, #1 = yes, 2 = no
                                    axis = 1))

diab = diab.na.fill({'Fam_hist': 0})
diab = diab.withColumn("Fam_hist", \
             when(diab["Fam_hist"] == 99, 0).otherwise(diab["Fam_hist"]))
diab = diab.withColumn("Fam_hist", \
              when(diab["Fam_hist"] == 10, 1).otherwise(diab["Fam_hist"]))
diab = diab.where("label<4")

smq = spark.createDataFrame(pd.read_sas("SMQ_J.XPT") \
                            .drop(columns = ['SMD030', 'SMQ050Q', 'SMQ050U', 'SMD057',
                                           'SMQ078', 'SMD641', 'SMD650', 'SMD093', 'SMDUPCA', 'SMD100BR',
                                           'SMD100FL', 'SMD100MN', 'SMD100LN', 'SMD100TR', 'SMD100NI', 'SMD100CO',
                                           'SMQ621', 'SMD630', 'SMQ661', 'SMQ665A', 'SMQ665B', 'SMQ665C',
                                           'SMQ665D', 'SMQ670', 'SMQ848', 'SMQ852Q', 'SMQ852U', 'SMQ895',
                                           'SMQ905', 'SMQ915', 'SMAQUEX2']) \
                           .rename({'SEQN': 'ParticipantID',
                                  'SMQ020': '100Cigs',
                                  'SMQ040': 'Smoke_Cigs', #1 = Yes, 2 = No
                                  'SMQ890': 'Smoke_Cigar',
                                  'SMQ900': 'E_cig',
                                  'SMQ910': 'Smokeless_tobacco'},axis = 1))

In [3]:
#join all dataframes on participant ID
df = demo.join(bp, on=['ParticipantID'], how='left').join(diet_1, on=['ParticipantID'], how='left').join(bm, on=['ParticipantID'], how='left').join(ins, on=['ParticipantID'], how='left') \
.join(glu, on=['ParticipantID'], how='left').join(alc, on=['ParticipantID'], how='left').join(bpq, on=['ParticipantID'], how='left').join(pa, on=['ParticipantID'], how='left') \
.join(smq, on=['ParticipantID'], how='left').join(diab, on=['ParticipantID'], how='left')

In [4]:
# examine schema
df.printSchema()

root
 |-- ParticipantID: double (nullable = true)
 |-- Gender: double (nullable = true)
 |-- Age: double (nullable = true)
 |-- Race: double (nullable = true)
 |-- Education_Level: double (nullable = true)
 |-- Household_income: double (nullable = true)
 |-- Pulse: double (nullable = true)
 |-- SysBP: double (nullable = true)
 |-- DiasBP: double (nullable = true)
 |-- Energy: double (nullable = true)
 |-- Protein: double (nullable = true)
 |-- Carbohydrates: double (nullable = true)
 |-- Total_sugar: double (nullable = true)
 |-- Fiber: double (nullable = true)
 |-- Total_fat: double (nullable = true)
 |-- Sat_fat: double (nullable = true)
 |-- Monounsat_fat: double (nullable = true)
 |-- Polyunsat_fat: double (nullable = true)
 |-- cholesterol: double (nullable = true)
 |-- Alcohol: double (nullable = true)
 |-- Weight(kg): double (nullable = true)
 |-- BMI: double (nullable = true)
 |-- Waist_Circum: double (nullable = true)
 |-- Insulin: double (nullable = true)
 |-- Glucose: double

In [5]:
# we need to fix a lot of these to be ints or floats
df.dtypes

[('ParticipantID', 'double'),
 ('Gender', 'double'),
 ('Age', 'double'),
 ('Race', 'double'),
 ('Education_Level', 'double'),
 ('Household_income', 'double'),
 ('Pulse', 'double'),
 ('SysBP', 'double'),
 ('DiasBP', 'double'),
 ('Energy', 'double'),
 ('Protein', 'double'),
 ('Carbohydrates', 'double'),
 ('Total_sugar', 'double'),
 ('Fiber', 'double'),
 ('Total_fat', 'double'),
 ('Sat_fat', 'double'),
 ('Monounsat_fat', 'double'),
 ('Polyunsat_fat', 'double'),
 ('cholesterol', 'double'),
 ('Alcohol', 'double'),
 ('Weight(kg)', 'double'),
 ('BMI', 'double'),
 ('Waist_Circum', 'double'),
 ('Insulin', 'double'),
 ('Glucose', 'double'),
 ('Avg_Drinks', 'double'),
 ('4-5_Drinks', 'double'),
 ('8+Drinks', 'double'),
 ('12+Drinks', 'double'),
 ('4-5DrinksDaily', 'double'),
 ('HighBP', 'double'),
 ('HighChol', 'double'),
 ('VigWork', 'double'),
 ('ModWork', 'double'),
 ('Walk_bike', 'double'),
 ('VigActivity', 'double'),
 ('ModActivity', 'double'),
 ('100Cigs', 'double'),
 ('Smoke_Cigs', 'double

In [6]:
#drop duplicates
df = df.drop().dropDuplicates()

In [7]:
# all columns we want
cols = ['ParticipantID',
        'label',
       'Gender',
       'Age',
       'Race',
       'Fam_hist',
       'Smoke_Cigs',
       'BMI',
        'HighBP' 
       ]
data = df.select(cols)


# set gender to 0/1 instead of 1/2
data = data.withColumn('Gender', data.Gender - 1)

#drop na values
data = data.na.drop('any')

In [8]:
data.count()
#data.take(5)

27297

In [9]:
data.take(5)

[Row(ParticipantID=94021.0, label=2.0, Gender=1.0, Age=38.0, Race=3.0, Fam_hist=0.0, Smoke_Cigs=3.0, BMI=34.5, HighBP=2.0),
 Row(ParticipantID=94021.0, label=2.0, Gender=1.0, Age=38.0, Race=3.0, Fam_hist=0.0, Smoke_Cigs=3.0, BMI=34.5, HighBP=2.0),
 Row(ParticipantID=94021.0, label=2.0, Gender=1.0, Age=38.0, Race=3.0, Fam_hist=0.0, Smoke_Cigs=3.0, BMI=34.5, HighBP=2.0),
 Row(ParticipantID=94021.0, label=2.0, Gender=1.0, Age=38.0, Race=3.0, Fam_hist=0.0, Smoke_Cigs=3.0, BMI=34.5, HighBP=2.0),
 Row(ParticipantID=94021.0, label=2.0, Gender=1.0, Age=38.0, Race=3.0, Fam_hist=0.0, Smoke_Cigs=3.0, BMI=34.5, HighBP=2.0)]

In [10]:
data.write.parquet("data.parquet")