### Import Necessary Libraries

In [None]:
import os
import pandas as pd
import matplotlib.pyplot as plt
import numpy as np
import seaborn as sns

### File Methods

In [None]:
def import_file(import_df_location, file_name):
    """Load a CSV file into a DataFrame and print a preview of it.

    Args:
        import_df_location: Directory that contains the file. Must not be
            an empty string.
        file_name: Name of the CSV file inside ``import_df_location``.

    Returns:
        A ``(dataframe, column_names)`` tuple: the loaded DataFrame and its
        column labels as a plain list.

    Raises:
        ValueError: If ``import_df_location`` is an empty string.
    """
    # Guard clause: refuse to build a path from an empty directory.
    if import_df_location == "":
        raise ValueError("import_df_location cannot be empty.")

    csv_path = os.path.join(import_df_location, file_name)
    df_imported = pd.read_csv(csv_path)

    # Show the first rows so the import can be sanity-checked by eye.
    print(df_imported.head())

    return df_imported, df_imported.columns.tolist()

In [None]:
def bar_plot_by_readmitted_categories(df_imported, column_to_group_by_categories, bar_distance_from_center_of_line, bar_width, x_size, y_size, x_label_rotation, dataset_is_grouped, before_edits):
    """Plot, side by side, how often each 'readmitted' value occurs per category.

    One group of bars is drawn per category of ``column_to_group_by_categories``
    and the plotted counts are printed afterwards.

    Args:
        df_imported: Data to plot. When ``dataset_is_grouped`` is False it must
            contain a 'readmitted' column and ``column_to_group_by_categories``.
        column_to_group_by_categories: Column whose categories go on the x
            axis; also shown in the plot title.
        bar_distance_from_center_of_line: Horizontal offset of the outer bars
            from each x tick.
        bar_width: Width of every bar.
        x_size: Figure width in inches.
        y_size: Figure height in inches.
        x_label_rotation: Rotation of the x tick labels, in degrees.
        dataset_is_grouped: False to compute the counts here; any other value
            means ``df_imported`` already holds the counts.
        before_edits: True when 'readmitted' holds 'NO', '<30' and '>30'
            (three bars per category); otherwise the values 0 and 1 are
            expected (two bars per category).
    """
    # Explicit comparison kept on purpose: only a value equal to False takes
    # the "count it here" path, exactly as before.
    if dataset_is_grouped == False:
        # Count rows per ('readmitted', category) pair. unstack(fill_value=0)
        # pivots the categories into columns and writes 0 for combinations
        # that never occur; stack() then restores the long format, so every
        # 'readmitted' value has an entry for every category.
        grouped_churn_count = (
            df_imported.groupby(['readmitted', column_to_group_by_categories])
            .size()
            .unstack(fill_value=0)
            .stack()
        )
        category_labels = grouped_churn_count.index.get_level_values(1)
    else:
        # Previously this path never assigned grouped_churn_count and failed
        # with a NameError as soon as a bar was drawn.
        # NOTE(review): assumes a pre-grouped frame has one row per
        # 'readmitted' value and one column per category (the x labels were
        # already taken from its columns) -- confirm against the caller.
        grouped_churn_count = df_imported
        category_labels = df_imported.columns

    # Keep the first occurrence of every label, preserving order.
    unique_category_labels = list(dict.fromkeys(category_labels))

    x = np.arange(len(unique_category_labels))
    plt.figure(figsize=(x_size, y_size))
    if before_edits == True:
        plt.bar(x - bar_distance_from_center_of_line, grouped_churn_count.loc['NO'], bar_width, color='cyan', label='no')
        # The middle bar sits on the tick itself. It used to be drawn at the
        # constant position 0, stacking every '<30' bar on the first tick.
        plt.bar(x, grouped_churn_count.loc['<30'], bar_width, color='orange', label='<30')
        plt.bar(x + bar_distance_from_center_of_line, grouped_churn_count.loc['>30'], bar_width, color='green', label='>30')
    else:
        plt.bar(x - bar_distance_from_center_of_line, grouped_churn_count.loc[0], bar_width, color='red', label='>30 or no')
        plt.bar(x + bar_distance_from_center_of_line, grouped_churn_count.loc[1], bar_width, color='blue', label='<30')
    plt.title(f"Plot of 'readmitted' by number of occurrence of each value in '{column_to_group_by_categories}'.")
    plt.xticks(x, unique_category_labels, rotation=x_label_rotation, ha='right')
    plt.legend()
    plt.show()

    print(grouped_churn_count)

In [None]:
def plot_number_occurrences_per_readmitted_values(df_imported, column_names_list, bar_distance_from_center_of_line, bar_width, y_size, x_label_rotation, dataset_is_grouped, before_edits):
    """Draw the 'readmitted' bar plot for every column except the excluded ones.

    The columns 'encounter_id', 'patient_nbr' and 'readmitted' are skipped.
    Each remaining column is plotted with a figure width equal to its number
    of unique values, capped at 200.

    Args:
        df_imported: DataFrame holding the columns to plot.
        column_names_list: Names of the candidate columns.
        bar_distance_from_center_of_line: Forwarded to the bar plot.
        bar_width: Forwarded to the bar plot.
        y_size: Figure height forwarded to the bar plot.
        x_label_rotation: Tick label rotation forwarded to the bar plot.
        dataset_is_grouped: Forwarded to the bar plot.
        before_edits: Forwarded to the bar plot.
    """
    skipped_columns = ("encounter_id", "patient_nbr", "readmitted")
    widest_figure = 200

    for column_name in column_names_list:
        if column_name in skipped_columns:
            continue

        # One unit of width per distinct value, up to the cap.
        figure_width = min(df_imported[column_name].nunique(), widest_figure)

        bar_plot_by_readmitted_categories(
            df_imported,
            column_name,
            bar_distance_from_center_of_line,
            bar_width,
            figure_width,
            y_size,
            x_label_rotation,
            dataset_is_grouped,
            before_edits,
        )

In [None]:
def check_for_empty(df_imported, column_name):
    """Summarise the missing, empty and blank values of one column.

    Args:
        df_imported: DataFrame that holds the column.
        column_name: Name of the column to inspect.

    Returns:
        A 5-tuple ``(report, dtype_name, nan_count, empty_count,
        whitespace_count)``. ``report`` is a multi-line text summary of the
        other four values. ``nan_count`` also includes values whose string
        form is exactly '?'.
    """
    column = df_imported[column_name]
    # Every text-based check below works on the string form of the values.
    as_text = column.astype(str)

    data_type_check = str(column.dtype)
    nan_value_count = column.isna().sum() + (as_text == '?').sum()
    empty_value_not_nan_count = (as_text.str.len() == 0).sum()
    only_whitespace_value_count = as_text.str.isspace().sum()

    report_lines = [
        f"Data type: {data_type_check}",
        f"NaN value count (? counts as NaN): {nan_value_count}",
        f"Empty value (not listed as NaN) count: {empty_value_not_nan_count}",
        f"Only whitespace value count: {only_whitespace_value_count}",
    ]
    string_to_return = "\n".join(report_lines) + "\n"

    return string_to_return, data_type_check, nan_value_count, empty_value_not_nan_count, only_whitespace_value_count


def check_duplicates(df_imported, column_name):
    """Report the repeated values of one column.

    Args:
        df_imported: DataFrame that holds the column.
        column_name: Name of the column to inspect.

    Returns:
        A 3-tuple ``(report, duplicate_count, duplicated_values)``:
        ``duplicate_count`` is the number of entries that repeat an earlier
        entry, and ``duplicated_values`` lists each repeated value once.
    """
    column = df_imported[column_name]
    # True for every entry that repeats a value seen earlier in the column.
    repeat_mask = column.duplicated()

    duplicate_element_count = repeat_mask.sum()
    duplicate_element_list = column[repeat_mask].unique().tolist()

    string_to_return = (
        f"Duplicate Element Count: {duplicate_element_count}\n"
        f"Duplicate Elements: {duplicate_element_list}\n"
    )

    return string_to_return, duplicate_element_count, duplicate_element_list


def check_element_counts(df_imported, column_name):
    """Report the total and distinct value counts of one column.

    Args:
        df_imported: DataFrame that holds the column.
        column_name: Name of the column to inspect.

    Returns:
        A 3-tuple ``(report, total_count, unique_count)`` where the counts
        come from ``Series.count()`` and ``Series.nunique()``.
    """
    column = df_imported[column_name]
    total_element_count = column.count()
    unique_element_count = column.nunique()

    string_to_return = (
        f"Total Element Count: {total_element_count}\n"
        f"Unique Element Count: {unique_element_count}\n"
    )

    return string_to_return, total_element_count, unique_element_count


def data_check(df_imported, column_name):
    """Build the combined text report for one column.

    The report consists of the element counts, then a duplicates section,
    then the missing/empty value summary.

    Args:
        df_imported: DataFrame that holds the column.
        column_name: Name of the column to inspect.

    Returns:
        The report as a single multi-line string.
    """
    # Only the pieces needed for the report are given real names; the rest
    # of each helper's result tuple is discarded.
    empty_report, *_ = check_for_empty(df_imported, column_name)
    duplicates_report, duplicate_count, duplicated_values = check_duplicates(df_imported, column_name)
    counts_report, _, unique_count = check_element_counts(df_imported, column_name)

    if duplicate_count == 0:
        duplicates_section = f"No duplicate elements in '{column_name}'.\n"
    elif len(duplicated_values) == unique_count:
        # Every distinct value is repeated, so listing them all adds nothing.
        duplicates_section = (
            f"Duplicate Element Count: {duplicate_count}\n"
            f"Number of unique elements in '{column_name}' = number of duplicate elements in '{column_name}'.\n"
        )
    else:
        duplicates_section = duplicates_report

    return counts_report + duplicates_section + empty_report


def perform_data_check_on_columns(df_imported):
    """Run ``data_check`` on every column and concatenate the reports.

    Args:
        df_imported: DataFrame whose columns are inspected.

    Returns:
        One string with a "Column Name" header, the column's data check and
        a blank line for each column, in column order.
    """
    report_sections = [
        f"Column Name: {column_name}\nData check:\n{data_check(df_imported, column_name)}\n"
        for column_name in df_imported.columns.tolist()
    ]

    return "".join(report_sections)

In [None]:
def get_chart_width(df_imported, current_element, default_max):
    """Derive a chart width from the number of distinct values in a column.

    Args:
        df_imported: DataFrame that holds the column.
        current_element: Name of the column.
        default_max: Upper bound for the returned width.

    Returns:
        A ``(chart_width, column_values)`` tuple: the column's number of
        unique values limited to ``default_max``, and the column itself.
    """
    current_column_values = df_imported[current_element]
    chart_width = min(current_column_values.nunique(), default_max)

    return chart_width, current_column_values


def do_basic_analysis_and_charts(columns_in_question, numerical_or_categorical, df_imported, x_label_rotation):
    """Chart each column and print its basic descriptive statistics.

    Numerical columns get a box plot stacked above a histogram, followed by
    printed measures of center and spread. Categorical columns get a count
    plot followed by the printed number of occurrences per category.

    Args:
        columns_in_question: Names of the columns to analyse.
        numerical_or_categorical: Either 'numerical' or 'categorical'; selects
            the kind of analysis applied to every column.
        df_imported: DataFrame that holds the columns.
        x_label_rotation: Rotation of the x tick labels, in degrees.

    Raises:
        ValueError: If ``numerical_or_categorical`` is neither 'numerical'
            nor 'categorical'.
    """
    # Validate once, up front, so a bad mode is reported even when there are
    # no columns to process and before any work is done.
    if numerical_or_categorical not in ('numerical', 'categorical'):
        raise ValueError("'numerical_or_categorical' must be either 'numerical' or 'categorical'.")

    for element in columns_in_question:
        chart_width, current_column_values = get_chart_width(df_imported, element, 200)

        if numerical_or_categorical == 'numerical':
            # Thin box plot on top, histogram below, sharing the x axis.
            _, (ax_box, ax_hist) = plt.subplots(2, sharex=True, gridspec_kw={"height_ratios": (.15, .85)}, figsize=(chart_width, 10))
            sns.boxplot(x=current_column_values, ax=ax_box)
            sns.histplot(x=current_column_values, ax=ax_hist)
            plt.xticks(rotation=x_label_rotation, ha='right')
            plt.show()

            current_mean = current_column_values.mean()
            current_median = current_column_values.median()
            current_mode = current_column_values.mode()
            current_midrange = (current_column_values.min() + current_column_values.max())/2
            print(f"\nMeasures of center for the '{element}' column:\nMean: {current_mean}\nMedian: {current_median}\nMode: {current_mode}\nMidrange: {current_midrange}\n")

            current_standard_deviation = current_column_values.std()
            current_variance = current_column_values.var()
            # NOTE(review): quantile() without arguments returns the 0.5
            # quantile, i.e. the median again rather than a measure of spread.
            current_quantile = current_column_values.quantile()
            print(f"Measures of spread for the '{element}' column:\nStandard Deviation: {current_standard_deviation}\nVariance: {current_variance}\nQuantile: {current_quantile}\n\n")

        else:
            # Bars are ordered by sorted category value. Sorting fails with a
            # TypeError when values cannot be compared with each other (for
            # example strings mixed with NaN); in that case missing values
            # are replaced by '' before sorting.
            try:
                category_order = sorted(current_column_values.unique())
            except TypeError:
                category_order = sorted(current_column_values.fillna('').unique())

            plt.figure(figsize=(chart_width, 10))
            sns.countplot(x=element, data=df_imported, orient='v', order=category_order)
            plt.title(f"Countplot of unique values in '{element}'.")
            plt.xlabel('Values')
            plt.xticks(rotation=x_label_rotation, ha='right')
            plt.ylabel('Counts')
            plt.show()

            print(f"Number of occurrences of each category in '{element}':\n{str(current_column_values.value_counts().sort_index())}")

            print("\n")

In [None]:
def create_correlation_heatmap(df_imported, columns_in_question_numerical):
    """Show a heatmap of the pairwise correlations between the given columns.

    Args:
        df_imported: DataFrame that holds the columns.
        columns_in_question_numerical: Names of the columns to correlate. All
            of them are named in the plot title.
    """
    current_column_correlation = df_imported[columns_in_question_numerical].corr()
    sns.heatmap(current_column_correlation, cmap="coolwarm")

    # Name every column in the title. The title used to read exactly the
    # first three entries, which raised IndexError for shorter lists and
    # silently omitted the remaining columns of longer ones.
    quoted_names = [f"'{name}'" for name in columns_in_question_numerical]
    if len(quoted_names) > 2:
        listed_names = ", ".join(quoted_names[:-1]) + f", and {quoted_names[-1]}"
    else:
        listed_names = " and ".join(quoted_names)

    plt.title(f"Heatmap of features in {listed_names}.", fontsize=16)
    plt.show()

In [None]:
def do_violinplots(columns_in_question_categorical, columns_in_question_numerical, df_imported, x_label_rotation):
    """Draw one violin plot per (categorical, numerical) column pair.

    The figure width follows the categorical column's chart width and the
    figure height follows the numerical column's chart width. A pair made of
    the same column twice is skipped.

    Args:
        columns_in_question_categorical: Columns placed on the x axis.
        columns_in_question_numerical: Columns placed on the y axis.
        df_imported: DataFrame that holds the columns.
        x_label_rotation: Rotation of the x tick labels, in degrees.
    """
    for categorical_column in columns_in_question_categorical:
        figure_width, _ = get_chart_width(df_imported, categorical_column, 200)

        for numerical_column in columns_in_question_numerical:
            if categorical_column == numerical_column:
                continue

            figure_height, _ = get_chart_width(df_imported, numerical_column, 200)

            plt.figure(figsize=(figure_width, figure_height))
            sns.violinplot(data=df_imported, x=categorical_column, y=numerical_column)
            plt.xticks(rotation=x_label_rotation, ha='right')
            plt.show()

In [None]:
def do_pairplot(columns_in_question, df_imported, x_label_rotation):
    """Draw a seaborn pair plot of the given columns.

    The overall figure size is twice the largest chart width among the
    columns (see ``get_chart_width``), split evenly across the grid.

    Args:
        columns_in_question: Names of the columns to plot against each other.
        df_imported: DataFrame that holds the columns.
        x_label_rotation: Rotation of the x tick labels, in degrees.

    Raises:
        ValueError: If ``columns_in_question`` is empty.
    """
    column_count = len(columns_in_question)
    if column_count == 0:
        raise ValueError("'columns_in_question' must contain at least one column.")

    max_chart_width = max(
        2 * get_chart_width(df_imported, element, 200)[0]
        for element in columns_in_question
    )

    # pairplot is a figure-level function: it creates its own figure and
    # ignores one opened beforehand with plt.figure(), which only left an
    # empty extra figure behind. The size is therefore passed through
    # `height`, the size in inches of each facet of the grid.
    # NOTE(review): columns with many unique values make this figure very
    # large (up to 400 inches per side) -- confirm that is acceptable.
    pair_grid = sns.pairplot(df_imported[columns_in_question], height=max_chart_width / column_count)

    # Rotate the tick labels of every subplot; plt.xticks() would only reach
    # the current (last created) axes.
    for axis in pair_grid.axes.flat:
        plt.setp(axis.get_xticklabels(), rotation=x_label_rotation, ha='right')

    plt.show()

In [None]:
def do_stripplot(main_column, columns_in_question, df_imported, x_label_rotation):
    """Draw one strip plot of ``main_column`` against each given column.

    Every figure is square, with a side of twice the column's chart width.

    Args:
        main_column: Column placed on the y axis of every plot.
        columns_in_question: Columns placed, one per plot, on the x axis.
        df_imported: DataFrame that holds the columns.
        x_label_rotation: Rotation of the x tick labels, in degrees.
    """
    for column_name in columns_in_question:
        base_width, _ = get_chart_width(df_imported, column_name, 200)
        side_length = 2 * base_width

        plt.figure(figsize=(side_length, side_length))
        sns.stripplot(data=df_imported, x=column_name, y=main_column)
        plt.title(f"'{main_column}' Over '{column_name}'")
        plt.xticks(rotation=x_label_rotation, ha='right')
        plt.show()

### EDA Method

In [None]:
def perform_diabetes_EDA(before_edits, file_location, file_name, what_EDA_to_do, columns_in_question_numerical, columns_in_question_categorical, y_size, x_label_rotation, dataset_is_grouped):
    """Import a CSV file and run the EDA steps that are switched on.

    Args:
        before_edits: True for the run that uses a bar offset and bar width
            of 0.1; any other value uses 0.2 for both.
        file_location: Directory that contains the CSV file.
        file_name: Name of the CSV file.
        what_EDA_to_do: Mapping of step name to switch; a step runs only when
            its value equals 1.
        columns_in_question_numerical: Columns analysed as numerical.
        columns_in_question_categorical: Columns analysed as categorical.
        y_size: Figure height used by the 'readmitted' bar plots.
        x_label_rotation: Rotation of the x tick labels, in degrees.
        dataset_is_grouped: Forwarded to the 'readmitted' bar plots.
    """
    # Normalise the flag to a plain bool, then derive both bar settings.
    before_edits = bool(before_edits == True)
    bar_distance_from_center_of_line = bar_width = 0.1 if before_edits else 0.2

    df_imported, column_names_list = import_file(file_location, file_name)

    # Steps in execution order, each paired with the call that performs it.
    eda_steps = (
        (
            "plot_number_occurrences_per_readmitted_values",
            lambda: plot_number_occurrences_per_readmitted_values(df_imported, column_names_list, bar_distance_from_center_of_line, bar_width, y_size, x_label_rotation, dataset_is_grouped, before_edits),
        ),
        (
            "perform_data_check_on_columns",
            lambda: print(perform_data_check_on_columns(df_imported)),
        ),
        (
            "do_basic_analysis_and_charts_numerical",
            lambda: do_basic_analysis_and_charts(columns_in_question_numerical, 'numerical', df_imported, x_label_rotation),
        ),
        (
            "do_basic_analysis_and_charts_categorical",
            lambda: do_basic_analysis_and_charts(columns_in_question_categorical, 'categorical', df_imported, x_label_rotation),
        ),
        (
            "create_correlation_heatmap",
            lambda: create_correlation_heatmap(df_imported, columns_in_question_numerical),
        ),
        (
            "do_violinplots",
            lambda: do_violinplots(columns_in_question_categorical, columns_in_question_numerical, df_imported, x_label_rotation),
        ),
        (
            "do_pairplot",
            lambda: do_pairplot(columns_in_question_numerical, df_imported, x_label_rotation),
        ),
        (
            "do_stripplot",
            lambda: do_stripplot('readmitted', columns_in_question_categorical, df_imported, x_label_rotation),
        ),
    )

    for step_name, run_step in eda_steps:
        if what_EDA_to_do.get(step_name) == 1:
            run_step()

### Provided Information

#### File names and locations.

In [None]:
# Directory that contains the original "diabetic_data.csv" file.
# Replace the placeholder text with a real path before running the notebook.
before_file_location = "<Insert the location of the diabetic_data.csv file here.>"

# Directory and name of the output CSV file.
# Replace the placeholder text with a real path before running the notebook.
after_file_location = "<Insert the location of the output.csv file here.>"
after_file_name = "output.csv"

#### Numerical and categorical columns to analyze.

In [None]:
# Columns analysed as numerical: box plot + histogram with summary statistics,
# the correlation heatmap, the y axis of the violin plots and the pair plot.
columns_in_question_numerical = [
    'time_in_hospital',
    'num_lab_procedures',
    'num_procedures',
    'num_medications',
    'number_outpatient',
    'number_emergency',
    'number_inpatient',
    'number_diagnoses'
    ]

# Columns analysed as categorical: count plots with per-category counts, the
# x axis of the violin plots and the x axis of the strip plots.
columns_in_question_categorical = [
    'race',
    'gender',
    'age',
    'admission_type_id',
    'discharge_disposition_id',
    'admission_source_id',
    'payer_code',
    'medical_specialty',
    'diag_1',
    'diag_2',
    'diag_3',
    'max_glu_serum',
    'A1Cresult',
    'metformin',
    'repaglinide',
    'nateglinide',
    'chlorpropamide',
    'glimepiride',
    'acetohexamide',
    'glipizide',
    'glyburide',
    'tolbutamide',
    'pioglitazone',
    'rosiglitazone',
    'acarbose',
    'miglitol',
    'troglitazone',
    'tolazamide',
    'examide',
    'citoglipton',
    'insulin',
    'glyburide-metformin',
    'glipizide-metformin',
    'glimepiride-pioglitazone',
    'metformin-rosiglitazone',
    'metformin-pioglitazone',
    'change',
    'diabetesMed'
    ]

#### EDA to perform on the original "diabetic_data.csv" file and an output file.

In [None]:
# Switches read by perform_diabetes_EDA: a step runs only when its value is 1.
# Every step is enabled for the original file.
what_EDA_to_do_before = dict.fromkeys(
    (
        "plot_number_occurrences_per_readmitted_values",
        "perform_data_check_on_columns",
        "do_basic_analysis_and_charts_numerical",
        "do_basic_analysis_and_charts_categorical",
        "create_correlation_heatmap",
        "do_violinplots",
        "do_pairplot",
        "do_stripplot",
    ),
    1,
)

# Same switches for the output file, with the violin, pair and strip plots
# turned off.
what_EDA_to_do_after = {
    **what_EDA_to_do_before,
    "do_violinplots": 0,
    "do_pairplot": 0,
    "do_stripplot": 0,
}

### Perform EDA on the original "diabetic_data.csv" file and an output file.

In [None]:
# Run the enabled EDA steps on the original "diabetic_data.csv" file.
perform_diabetes_EDA(
    before_edits=True,
    file_location=before_file_location,
    file_name="diabetic_data.csv",
    what_EDA_to_do=what_EDA_to_do_before,
    columns_in_question_numerical=columns_in_question_numerical,
    columns_in_question_categorical=columns_in_question_categorical,
    y_size=10,
    x_label_rotation=90,
    dataset_is_grouped=False,
)

In [None]:
# Run the enabled EDA steps on the output file.
# The file is read from after_file_location; this call used to pass
# before_file_location, so "output.csv" was looked up in the directory of the
# original data and after_file_location was never used.
perform_diabetes_EDA(
    before_edits=False,
    file_location=after_file_location,
    file_name=after_file_name,
    what_EDA_to_do=what_EDA_to_do_after,
    columns_in_question_numerical=columns_in_question_numerical,
    columns_in_question_categorical=columns_in_question_categorical,
    y_size=10,
    x_label_rotation=90,
    dataset_is_grouped=False,
)