# Final Project Part 3
- DataViz, Spring 2020
- Author: Tzu-Kun Hsiao
- NetID: tkhsiao2

In [1]:
import pandas as pd
import numpy as np
import bqplot
import matplotlib.pyplot as plt
import matplotlib
import ipywidgets
from ipywidgets import Layout


%matplotlib inline

# 1. Central interactive visualization featuring the primary dataset

## 1.1 Import dataset and get the subset of 30 most common clinical conditions
- Dataset: [State Summary of Inpatient Charge Data by Medicare Severity Diagnosis Related Group (MS-DRG), FY2017](https://data.cms.gov/Medicare-Inpatient/State-Summary-of-Inpatient-Charge-Data-by-Medicare/q5hc-zvkx)


- The dataset contains 736 clinical conditions (MS-DRGs), which is too many to show in a single readable plot. This project therefore keeps only the 30 most common conditions, ranked by total discharges summed over all states in the dataset.

In [2]:
# Load the state-level MS-DRG summary and round numeric columns to 2 decimals.
data = pd.read_csv("ProjectData/State_Summary_of_Inpatient_Charge_Data_by_Medicare_Severity_Diagnosis_Related_Group__MS-DRG___FY2017.csv")
data = data.round(2)

# 'DRG Definition' has the form "<code> - <description>". Split once, on the
# FIRST " - " only (n=1): taking element [1] of an unbounded split would
# silently truncate any description that itself contains " - ".
drg_parts = data['DRG Definition'].str.split(' - ', n=1)
data['drg_id'] = drg_parts.str[0].str.strip()
data['drg_name'] = drg_parts.str[1].str.strip()

# Sanity checks: dimensions, non-null counts and distinct values per column.
print(data.shape)
print(data.count())
print(data.nunique())

data.head()


(27543, 8)
DRG Definition               27543
Provider State               27543
Total Discharges             27543
Average Covered Charges      27543
Average Total Payments       27543
Average Medicare Payments    27543
drg_id                       27543
drg_name                     27543
dtype: int64
DRG Definition                 736
Provider State                  51
Total Discharges              2395
Average Covered Charges      27510
Average Total Payments       27381
Average Medicare Payments    27368
drg_id                         736
drg_name                       736
dtype: int64


Unnamed: 0,DRG Definition,Provider State,Total Discharges,Average Covered Charges,Average Total Payments,Average Medicare Payments,drg_id,drg_name
0,001 - HEART TRANSPLANT OR IMPLANT OF HEART ASS...,OR,11,561665.0,281456.09,229310.82,1,HEART TRANSPLANT OR IMPLANT OF HEART ASSIST SY...
1,001 - HEART TRANSPLANT OR IMPLANT OF HEART ASS...,MD,13,324398.0,299865.77,279902.0,1,HEART TRANSPLANT OR IMPLANT OF HEART ASSIST SY...
2,001 - HEART TRANSPLANT OR IMPLANT OF HEART ASS...,UT,13,597063.69,231009.08,215347.08,1,HEART TRANSPLANT OR IMPLANT OF HEART ASSIST SY...
3,001 - HEART TRANSPLANT OR IMPLANT OF HEART ASS...,MS,14,814803.86,206432.86,193311.14,1,HEART TRANSPLANT OR IMPLANT OF HEART ASSIST SY...
4,001 - HEART TRANSPLANT OR IMPLANT OF HEART ASS...,AL,16,875798.44,249810.06,125897.38,1,HEART TRANSPLANT OR IMPLANT OF HEART ASSIST SY...


In [3]:
# Summary statistics of the numeric columns, shown with 2 decimals.
data.describe().round(decimals=2)

Unnamed: 0,Total Discharges,Average Covered Charges,Average Total Payments,Average Medicare Payments
count,27543.0,27543.0,27543.0,27543.0
mean,349.52,67967.66,17415.74,14695.23
std,1250.61,73462.97,18357.6,16179.96
min,11.0,4445.22,2834.77,1999.08
25%,29.0,29001.01,7886.54,6185.24
50%,75.0,46451.97,11950.43,9847.71
75%,238.0,78801.01,19664.98,16797.6
max,63412.0,1831198.83,351416.79,329469.14


In [4]:
# Total discharges per DRG code, largest first.
discharge_count = (
    data.groupby('drg_id')['Total Discharges']
    .sum()
    .to_frame('sum_of_total_discharges')
    .sort_values('sum_of_total_discharges', ascending=False)
)

# The 30 DRG codes with the most discharges; their index is used to subset `data`.
discharge_top30 = discharge_count.iloc[:30]
discharge_top30

Unnamed: 0_level_0,sum_of_total_discharges
drg_id,Unnamed: 1_level_1
871,598649
470,511763
291,361006
190,217482
189,168238
872,160020
392,158685
690,141441
683,137183
378,134270


In [5]:
#data[['state_desc', 'drg_desc_id']].groupby(['state_desc']).count()

In [6]:
#data.drg_desc_name.unique()

In [7]:
# Show the column labels of `data`, including the derived drg_id / drg_name columns.
data.columns

Index(['DRG Definition', 'Provider State', 'Total Discharges',
       'Average Covered Charges', 'Average Total Payments',
       'Average Medicare Payments', 'drg_id', 'drg_name'],
      dtype='object')

In [8]:
# Keep only the rows whose DRG code is among the 30 most common ones.
is_top30 = data['drg_id'].isin(discharge_top30.index)
data_subset = data[is_top30].copy()

# Same sanity checks as for the full table: shape, non-null counts, distinct values.
for summary in (data_subset.shape, data_subset.count(), data_subset.nunique()):
    print(summary)

data_subset.head()

(1530, 8)
DRG Definition               1530
Provider State               1530
Total Discharges             1530
Average Covered Charges      1530
Average Total Payments       1530
Average Medicare Payments    1530
drg_id                       1530
drg_name                     1530
dtype: int64
DRG Definition                 30
Provider State                 51
Total Discharges             1298
Average Covered Charges      1530
Average Total Payments       1529
Average Medicare Payments    1528
drg_id                         30
drg_name                       30
dtype: int64


Unnamed: 0,DRG Definition,Provider State,Total Discharges,Average Covered Charges,Average Total Payments,Average Medicare Payments,drg_id,drg_name
1773,064 - INTRACRANIAL HEMORRHAGE OR CEREBRAL INFA...,WY,60,30648.02,15041.63,13974.37,64,INTRACRANIAL HEMORRHAGE OR CEREBRAL INFARCTION...
1774,064 - INTRACRANIAL HEMORRHAGE OR CEREBRAL INFA...,VT,135,34225.96,15740.94,13847.07,64,INTRACRANIAL HEMORRHAGE OR CEREBRAL INFARCTION...
1775,064 - INTRACRANIAL HEMORRHAGE OR CEREBRAL INFA...,AK,143,90033.17,20097.44,18608.55,64,INTRACRANIAL HEMORRHAGE OR CEREBRAL INFARCTION...
1776,064 - INTRACRANIAL HEMORRHAGE OR CEREBRAL INFA...,RI,242,67740.71,19552.56,17194.55,64,INTRACRANIAL HEMORRHAGE OR CEREBRAL INFARCTION...
1777,064 - INTRACRANIAL HEMORRHAGE OR CEREBRAL INFA...,MT,243,32401.18,11705.04,10407.91,64,INTRACRANIAL HEMORRHAGE OR CEREBRAL INFARCTION...


## 1.2 Transform data into a numpy 3-D array for making the interactive visualization

In [9]:
def tranport_data(data_df, col_vals, idx_vals, val_column):
    """Reshape long-format rows into a state x DRG grid of one measure.

    Parameters
    ----------
    data_df : pandas.DataFrame
        Long-format table containing the columns 'drg_id', 'Provider State'
        and `val_column`.
    col_vals : list
        DRG ids to use as the columns of the result, in this order.
    idx_vals : list
        State labels to use as the index of the result, in this order.
    val_column : str
        Name of the column in `data_df` whose values fill the grid.

    Returns
    -------
    pandas.DataFrame
        Frame indexed by `idx_vals` with columns `col_vals`. Cells with no
        matching (state, DRG) row stay 0. If a pair occurs more than once,
        the last occurrence wins.
    """
    # Start from an explicitly numeric (float64) zero grid. Building an empty
    # object-dtype frame and calling fillna(0) relies on implicit downcasting,
    # and then writing floats into the resulting cells relies on implicit
    # upcasting; both behaviours are deprecated in recent pandas releases.
    df_temp = pd.DataFrame(0.0, index=idx_vals, columns=col_vals)

    for col_val in col_vals:
        data_slice = data_df.loc[data_df['drg_id'] == col_val]
        for state, value in zip(data_slice['Provider State'], data_slice[val_column]):
            df_temp.at[state, col_val] = value

    return df_temp
    

In [10]:
# Row / column labels of the grid: states (rows) and DRG codes (columns), both sorted.
states = sorted(data_subset['Provider State'].unique().tolist())
drg = sorted(data_subset['drg_id'].unique().tolist())
cost_items = ['Average Covered Charges', 'Average Total Payments', 'Average Medicare Payments']

# Build one state x DRG grid per measure. Layer 0 holds the discharge counts,
# layers 1-3 hold the cost measures in the order listed in `cost_items`.
layers = []
for value_column in ['Total Discharges'] + cost_items:
    layer_df = tranport_data(data_subset, drg, states, value_column)
    print(layer_df.shape)
    # Force a float array so every layer has the same numeric dtype.
    layers.append(layer_df.to_numpy(dtype=float))

discharges, covered_charges, total_payments, medicare_payments = layers

# Stack along a new leading axis. This yields the same values as
# concatenate(axis=0) + reshape((4, 51, 30)), but the shape follows the data
# instead of being hard-coded.
grid_vals_array = np.stack(layers, axis=0)
assert grid_vals_array.shape == (len(layers), len(states), len(drg))

grid_vals_array.shape

(51, 30)
(51, 30)
(51, 30)
(51, 30)


(4, 51, 30)

In [11]:
#discharges

In [12]:
#grid_vals_array[0:, :, :]

In [13]:
#grid_vals_array[0, :, :]

In [14]:
#grid_vals_array.max()

In [15]:
#np.arange(1, 32)

In [16]:
#grid_vals_array[0, :, :]

In [17]:
#grid_vals_array[0, :, :]+0.00001

## 1.3 Plot the interactive visualization

In [18]:
# CREATE LABEL - 1
# Text label that shows the discharge count of the selected heat-map cell.
mySelectedLabel = ipywidgets.Label()

# CREATE HEATMAP ELEMENTS - 2
# scale
x_sc = bqplot.OrdinalScale()
y_sc = bqplot.OrdinalScale()
c_sc = bqplot.ColorScale(scheme='BuPu')

# axis
x_ax = bqplot.Axis(scale=x_sc, tick_rotate=90, label='DRG_codes',
                   tick_style={'font-size': 10})
y_ax = bqplot.Axis(scale=y_sc, orientation='vertical', label = 'State',
                   tick_style={'font-size': 10})
c_ax = bqplot.ColorAxis(scale=c_sc, orientation='vertical', side='right')

# mark
# Colour encodes layer 0 (discharges), divided by its maximum value.
discharge_grid = grid_vals_array[0, :, :]
heat_map = bqplot.GridHeatMap(color = discharge_grid/discharge_grid.max(),
                              scales = {'color':c_sc, 'row':y_sc, 'column': x_sc},
                              interactions = {'click':'select'},
                              anchor_style = {'fill':'blue'},
                              row=states, column=drg,
                              stroke='white')


# CREATE BAR PLOT ELEMENTS - 3
# scale
x_scl = bqplot.OrdinalScale()
# The bars only ever show the cost layers (1-3), so the y range is taken from
# those layers; layer 0 holds discharge counts, which are not dollar amounts.
y_scl = bqplot.LinearScale(min=0, max=grid_vals_array[1:].max()/1000) # set range of y
c_scl = bqplot.ColorScale()

# axis
x_axl = bqplot.Axis(scale=x_scl, tick_rotate=10, grid_lines='none')

y_axl = bqplot.Axis(scale=y_scl, tick_format='0.2f', orientation='vertical',
                    label='US dollars (in thousands)')
# NOTE: c_scl / c_axl are created but not attached to the bar figure below.
c_axl = bqplot.ColorAxis(scale=c_scl)

# mark
bar_ticks = ['covered_charges', 'total_payments', 'medicare_payments']
# Bars start at zero height until a heat-map cell is selected.
bar_plot = bqplot.Bars(x = bar_ticks, y = np.zeros(len(bar_ticks)),
                       colors = ['#609EFC'],
                       scales = {'y':y_scl, 'x': x_scl})

# LINKING BAR PLOT WITH HEATMAP - 4
def on_selected_3d(change):
    """Update the label and the bar chart for the selected heat-map cell.

    `change` is the trait-change dict passed by `observe`; only its 'owner'
    (the heat map) is used. Nothing is updated unless exactly one cell is
    selected.
    """
    selected = change['owner'].selected
    # `selected` may be None (nothing selected); len(None) would raise.
    if selected is None or len(selected) != 1: # only 1 selection per time allowed
        return

    j, i = selected[0] # (row, column) indices = (state, DRG) position in the grid
    v = grid_vals_array[0, j, i] # discharge count of the selected state/DRG cell
    mySelectedLabel.value = 'Discharges = ' + str(v)
    # The three cost measures of the same cell, in thousands of dollars.
    bar_plot.y = grid_vals_array[1:4, j, i]/1000

# create interaction through "observe"
heat_map.observe(on_selected_3d, 'selected')

# CREATE FIG OBJECT - 5
fig_heatmap = bqplot.Figure(marks=[heat_map], axes=[c_ax, x_ax, y_ax])
fig_bar = bqplot.Figure(marks=[bar_plot], axes=[x_axl, y_axl])

fig_heatmap.layout.min_width='600px'
fig_heatmap.layout.min_height='600px'

fig_bar.layout.min_width='400px'
fig_bar.layout.min_height='600px'


# Heat map and bar chart side by side, with the label on top.
plots = ipywidgets.HBox([fig_heatmap, fig_bar], layout=Layout(width='100%', height='100%'))

myDashboard = ipywidgets.VBox([mySelectedLabel, plots], layout=Layout(width='100%', height='100%'))
myDashboard

VBox(children=(Label(value=''), HBox(children=(Figure(axes=[ColorAxis(orientation='vertical', scale=ColorScale…