In [1]:
import findspark
findspark.init()

In [9]:
from pyspark.sql import SparkSession
from pyspark.sql import Row
from pyspark.sql.functions import isnan, isnull, when, count, col,expr
from pyspark.sql.types import StructField, StructType, StringType, LongType,IntegerType,StringType
from pyspark.sql import functions as fn
from pyspark.ml import feature,regression,Pipeline
from pyspark.sql.functions import *
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns
%matplotlib inline

In [3]:
spark = SparkSession.builder.getOrCreate()
sc = spark.sparkContext

Dataset link
https://archive.ics.uci.edu/ml/datasets/Diabetes+130-US+hospitals+for+years+1999-2008#

In [12]:
df = spark.read.format("csv").load("data/diabetic_data.csv",header=True)

## Data Cleaning

In [6]:
df.toPandas().head(2)

Unnamed: 0,encounter_id,patient_nbr,race,gender,age,weight,admission_type_id,discharge_disposition_id,admission_source_id,time_in_hospital,...,citoglipton,insulin,glyburide-metformin,glipizide-metformin,glimepiride-pioglitazone,metformin-rosiglitazone,metformin-pioglitazone,change,diabetesMed,readmitted
0,2278392,8222157,Caucasian,Female,[0-10),?,6,25,1,1,...,No,No,No,No,No,No,No,No,No,NO
1,149190,55629189,Caucasian,Female,[10-20),?,1,1,7,3,...,No,Up,No,No,No,No,No,Ch,Yes,>30


As noted in the above table, we have '?' mark in place of null values in our dataset. We will replace all the '?' with a none.

#### Cleaning data - converting ? into None

In [13]:
x = df.columns
for i in range(0,len(df.columns)):
    temp = x[i]
    new_column_udf = udf(lambda temp: np.nan if temp == "?" else temp, StringType())
    df = df.withColumn(df.columns[i], new_column_udf(df.columns[i]))

In [14]:
df.toPandas().head(2)

Unnamed: 0,encounter_id,patient_nbr,race,gender,age,weight,admission_type_id,discharge_disposition_id,admission_source_id,time_in_hospital,...,citoglipton,insulin,glyburide-metformin,glipizide-metformin,glimepiride-pioglitazone,metformin-rosiglitazone,metformin-pioglitazone,change,diabetesMed,readmitted
0,2278392,8222157,Caucasian,Female,[0-10),,6,25,1,1,...,No,No,No,No,No,No,No,No,No,NO
1,149190,55629189,Caucasian,Female,[10-20),,1,1,7,3,...,No,Up,No,No,No,No,No,Ch,Yes,>30


### Handling missing values

In [108]:
from pyspark.sql.functions import isnan,isnull, when, count, col
nullDF = df.select([count(when(isnan(c), c)).alias(c) for c in df.columns])
nullDF = nullDF.select([c for c in nullDF.columns if nullDF.first()[c]!=0]).toPandas()

In [109]:
nullDF

Unnamed: 0,race,weight,payer_code,medical_specialty,diag_1,diag_2,diag_3
0,2273,98569,40256,49949,21,358,1423


In [114]:
49949/df.count()

0.49082208203132677

We have over 98000 thousand null values in weight column which is over 96% perct of the data, we will drop the column since our 96 perct of the data is null values and no method or interpolation will be effective in this case.



In [72]:
nullDF.select('race').collect()[0][0]

2273

## Exploratory Data Analysis

In [None]:
df.groupby('race').count().toPandas()

In [None]:
raceDF = df.groupby('race').count().toPandas()
plt.figure(figsize=(14,5))
sns.set(style="white")

g = sns.barplot(y=raceDF['count'],x=raceDF['race'],color=(0.21569, 0.21569 ,0.21569))

In [None]:
raceDF

In [None]:
conditionGroupDF = pd.DataFrame(df.groupby(['condition'])['drugName'].nunique().sort_values(ascending=False)[:20]).reset_index()

plt.figure(figsize=(14,5))
sns.set(style="white")

g = sns.barplot(y=conditionGroupDF['drugName'],x=conditionGroupDF['condition'],color=(0.21569, 0.21569 ,0.21569))

g.set_xticklabels(rotation=90,labels=conditionGroupDF['condition'])
g.set(yticklabels=[])
g.set_ylabel('')
g.axes.set_title("Drug count per condition",fontsize=20)
g.set_xlabel('Conditions',fontsize=14)
sns.despine(left=True)

locs, labels = plt.xticks() # get the current tick locations and labels

for loc, label in zip(locs, labels):
    count = conditionGroupDF.iloc[loc].drugName
    plt.text(loc, count-8, '{:0.0f}'.format(count), ha = 'center',va='top', color = 'w',size=14)

In [None]:
#gender = df.groupby('gender').count().toPandas()
#fig = plt.bar(gender["gender"],gender["count"]);

In [None]:
age = df.groupby('admission_type_id').count().toPandas()
age

In [None]:
#plt.figure(figsize=(12,8))
#fig = plt.bar(age["age"],age["count"])