In [None]:
from ludwig.api import LudwigModel
import pandas as pd
import numpy as np
import yaml
from sklearn.model_selection import train_test_split
import datetime
import re

dataset = "../MERGED_1998_2014_TSA.csv"
# Columns forced to raw strings (object): later cells compare them against
# string codes such as 'A', 'C', '0', '1' and '99'. `year` is read as int32.
column_dtypes = {
    'dispuniform': object,
    'female': object,
    'age': object,
    'h_contrl': object,
    'year': np.int32,
    'tran_in': object,
}
# low_memory=False parses each column in one pass so dtype inference is consistent
data = pd.read_csv(dataset, dtype=column_dtypes, low_memory=False)

In [None]:
# functions for column processing
def convert_payer(row):
    """Translate the primary-payer code in ``row['pay1']`` into a label.

    The value is compared as text ('1'..'4'); anything else, including NaN or
    non-integer text such as '1.0', maps to 'Other'.
    """
    payer_by_code = {
        '1': 'Medicare',
        '2': 'Medicaid',
        '3': 'Private',
        '4': 'Self-pay',
    }
    return payer_by_code.get(str(row['pay1']), 'Other')
def convert_race(row):
    """Translate the race code in ``row['race']`` into a label.

    The value is compared as text ('1'..'6'); any other value (NaN, '7',
    '1.0', ...) yields 'Unknown'.
    NOTE(review): if `race` is parsed as float, every value stringifies as
    e.g. '1.0' and becomes 'Unknown' — confirm the column's dtype.
    """
    race_by_code = {
        '1': 'White',
        '2': 'Black',
        '3': 'Hispanic',
        '4': 'Asian or Pacific Islander',
        '5': 'Native American',
        '6': 'Other',
    }
    return race_by_code.get(str(row['race']), 'Unknown')

def convert_ownership(h_contrl):
    """Translate a hospital-control code into an ownership label.

    ``h_contrl`` is compared with ``==`` against the string codes '1', '2' and
    '3'. Any value that matches none of them (including ints and NaN) returns
    ``None``.
    """
    ownership_by_code = {
        '1': 'Government, nonfederal',
        '2': 'Private not-profit',
        '3': 'Private invest-own',
    }
    return next(
        (label for code, label in ownership_by_code.items() if h_contrl == code),
        None,
    )
        
def code_arthritis(row):
    """Return 1 if any diagnosis column matches an arthritis-like ICD-9 code, else 0.

    Each of ``dx1``..``dx30`` is converted with ``str()`` and tested with
    ``re.search`` against the patterns below. Every pattern is anchored at the
    start of the code so that a pattern cannot match in the middle of an
    unrelated code. NaN cells stringify to 'nan' and never match.

    Parameters
    ----------
    row : Mapping (e.g. a pandas row Series)
        Must provide ``dx1``..``dx30``.

    Returns
    -------
    int
        1 if any diagnosis matches, otherwise 0.
    """
    dx_cols = ['dx%d' % i for i in range(1, 31)]
    # 71[456]\d\d = arthritis
    # 711[1-8]\d = arthropathy
    # 712\d\d = crystal arthropathy
    # 6960 = psoriatic arthropathy
    # 73340 = Aseptic necrosis, NOS
    # 73341 = Aseptic necrosis of head of humerus
    # 73349 = Aseptic necrosis of bone, other
    # technically not arthritis, but surgical operation should be similar
    # 73381 Malunion of fracture
    # 73382 Nonunion of fracture
    # 9052 Late effect of fracture of upper extremities
    # 71880 Other joint derangement, not elsewhere classified, site unspecified
    # 71881 Other joint derangement, not elsewhere classified, shoulder region
    # 72610 Disorders of bursae and tendons in shoulder region, unspecified
    # 72761 Complete rupture of rotator cuff
    # 8407 Superior glenoid labrum lesion
    # 71831 = Recurrent dislocation of joint, shoulder region

    # I AM PURPOSEFULLY NOT INCLUDING PRIMARY / SECONDARY NEOPLASMS OF BONE
    # '^71[45]\d+' was dropped (already covered by '^71[456]\d+'), the duplicate
    # '^72761' was removed, and '7334[019]' gained the '^' anchor its siblings have.
    regexes = [r'^71[456]\d+', r'^711[12345678]\d+', r'^712\d+',
               r'^6960', r'^7334[019]', r'^7338[12]', r'^9052$',
               r'^7188[01]', r'^72610', r'^72761', r'^8407$', r'^71831$']
    for col in dx_cols:
        code = str(row[col])
        if any(re.search(regex, code) for regex in regexes):
            return 1
    return 0

def code_hardware_failure(row):
    """Return 1 if any diagnosis column matches an orthopedic-implant complication code, else 0.

    Each of ``dx1``..``dx30`` is converted with ``str()`` and tested with
    ``re.search``. All patterns are anchored at the start of the code,
    consistently with the other coders in this notebook.

    Parameters
    ----------
    row : Mapping (e.g. a pandas row Series)
        Must provide ``dx1``..``dx30``.

    Returns
    -------
    int
        1 if any diagnosis matches, otherwise 0.
    """
    dx_cols = ['dx%d' % i for i in range(1, 31)]
    # 99640 Unspecified mechanical complication of internal orthopedic device, implant, and graft
    # 99641 Mechanical loosening of prosthetic joint
    # 99642 Dislocation of prosthetic joint
    # 99643 Broken prosthetic joint implant
    # 99644 Peri-prosthetic fracture around prosthetic joint
    # 99645 Peri-prosthetic osteolysis
    # 99646 Articular bearing surface wear of prosthetic joint
    # 99647 Other mechanical complication of prosthetic joint implant
    # 99649 Other mechanical complication of other internal orthopedic device, implant, and graft
    # 99666 Infection and inflammatory reaction due to internal joint prosthesis
    # 99667 Infection and inflammatory reaction due to other internal orthopedic device, implant, and graft
    # 99677 Other complications due to internal joint prosthesis
    # 99678 Other complications due to other internal orthopedic device, implant, and graft
    # '^9964' is equivalent to the former '^9964\d?' under re.search (the optional
    # digit never changed the outcome); '9966[67]' now carries the '^' anchor.
    regexes = [r'^9964', r'^9966[67]', r'^9967[78]$']
    for col in dx_cols:
        code = str(row[col])
        if any(re.search(regex, code) for regex in regexes):
            return 1
    return 0

def code_trauma(row):
    """Return 1 if the row looks like a traumatic admission, else 0.

    A row is traumatic when any of ``e_ccs1``..``e_ccs4`` equals one of the
    external-cause CCS codes listed below, or any of ``dx1``..``dx30`` matches
    one of the fracture / dislocation ICD-9 patterns below.

    Parameters
    ----------
    row : Mapping (e.g. a pandas row Series)
        Must provide ``e_ccs1``..``e_ccs4`` and ``dx1``..``dx30``.

    Returns
    -------
    int
        1 for trauma, 0 otherwise.
    """
    def _ccs_text(value):
        # The e_ccs columns get no explicit dtype when the CSV is read. A numeric
        # column containing NaN is parsed as float, and str(2603.0) == '2603.0'
        # would never equal '2603'. Integral non-string numbers are therefore
        # rendered without the '.0'; strings and NaN are left exactly as str() gives them.
        if not isinstance(value, str):
            try:
                number = float(value)
            except (TypeError, ValueError):
                return str(value)
            if number.is_integer():
                return str(int(number))
        return str(value)

    # it is easier to code external causes of injury by CCS classifications instead of ICD9 E codes
    e_cols = ['e_ccs1', 'e_ccs2', 'e_ccs3', 'e_ccs4']
    # 2603 = fall
    # 2606 = machinery
    # 2607 = motor vehicle traffic injury
    # 2608 = pedal cyclist not MVT
    # 2609 = pedestrian, not MVT
    # 2610 = transport, not MVT
    # 2614 = struck by - against
    # 2619 = other specified
    # 2620 = Unspecified
    trauma_codes = {'2603', '2606', '2607',
                    '2608', '2609', '2610',
                    '2614', '2619', '2620'}
    if any(_ccs_text(row[col]) in trauma_codes for col in e_cols):
        return 1

    # some require ICD 9 codes
    dx_cols = ['dx%d' % i for i in range(1, 31)]
    # 81[2][0123]\d = superior humeral fractures
    # 831[01]\d = shoulder dislocations
    # 73311 = pathologic fracture of humerus
    # 73310 = pathologic fracture unspecified
    # 73319 = Pathologic fracture of other specified site
    fracture_codes = [r'^81[2][0123]\d', r'^831[01]\d', r'^7331[019]']
    for col in dx_cols:
        code = str(row[col])
        if any(re.search(regex, code) for regex in fracture_codes):
            return 1
    return 0

def include(row):
    """Return 1 when at least one cohort flag is set on the row, otherwise 0.

    The flags are checked in the order arthritis, hardware_failure, trauma, each
    coerced with ``int()``. Checking stops at the first non-zero flag.
    """
    cohort_flags = ('arthritis', 'hardware_failure', 'trauma')
    if all(int(row[flag]) == 0 for flag in cohort_flags):
        return 0
    return 1
    

def split_train_val_test(df, train_frac, val_frac, test_frac, random_seed = 42):
    """Shuffle ``df`` reproducibly and split it into train / validation / test frames.

    Rows are shuffled with ``df.sample(frac=1, random_state=random_seed)``. This
    gives the same permutation as seeding NumPy's global RNG with
    ``random_seed`` and then sampling, but without mutating global RNG state.
    The shuffled frame is cut at ``int(train_frac*len(df))`` and
    ``int((1 - test_frac)*len(df))``.

    Parameters
    ----------
    df : pandas.DataFrame
    train_frac, val_frac, test_frac : float
        Non-negative fractions that must sum to 1.0 (checked to 5 decimals).
    random_seed : int, default 42

    Returns
    -------
    tuple of pandas.DataFrame
        ``(train, validate, test)``, each keeping the original index labels.

    Raises
    ------
    ValueError
        If the fractions do not sum to 1.0 or any fraction is negative.
    """
    if round(train_frac + val_frac + test_frac, 5) != 1.0:
        raise ValueError('Train val test must sum to 1.0')
    if not ((train_frac >= 0.0) and (val_frac >= 0.0) and (test_frac >= 0.0)):
        raise ValueError('Train_frac, val_frac, test_frac must be >= 0 ')
    n_rows = len(df)
    train_end = int(train_frac * n_rows)
    val_end = int((1.0 - test_frac) * n_rows)
    # iloc slicing replaces np.split, which relies on the deprecated DataFrame.swapaxes
    shuffled = df.sample(frac=1, random_state=random_seed)
    train = shuffled.iloc[:train_end]
    validate = shuffled.iloc[train_end:val_end]
    test = shuffled.iloc[val_end:]
    return train, validate, test

    

In [None]:
# process dataset
# Drop ROWS (not columns) whose values are flagged invalid / unusable. The order
# matters: non-numeric flag values are removed before a column is run through
# pd.to_numeric, and the age filter relies on the numeric conversion above it.
# NOTE(review): assigning into `data` right after boolean filtering may emit
# SettingWithCopyWarning; a `.copy()` after filtering would silence it.
data = data[data.died != "A"] # 77 rows dropped (per the original author's note)
"""
drop based on dispuniform
A = Invalid
99 = Unknown
"""
data = data[data.dispuniform != 'A']
data = data[data.dispuniform != '99']
# 'A' / 'C' are non-numeric flag values; they must go before the to_numeric conversion
data = data[data.los != 'A']
data = data[data.los != 'C']
data['los'] = data['los'].apply(pd.to_numeric)
data = data[data.female != "C"]
data = data[data.age != "C"]
data['age'] = data['age'].apply(pd.to_numeric)
data = data[data.year >= 2003] # e codes not implemented until 2003
data = data[data.age >= 18]  # adults only
# NOTE(review): neomat / hospbrth are compared with the integer 0; presumably these
# exclude maternal/neonatal and in-hospital-birth records — confirm, and confirm
# the columns are numeric (an object column would make every comparison False).
data = data[data.neomat == 0]
data = data[data.pay1 != 'A']
data = data[data.hospbrth == 0]
# rows must have an admission month and hospital characteristics
data = data[pd.notnull(data.amonth)]
data['amonth'] = data['amonth'].apply(pd.to_numeric)
data = data[pd.notnull(data.hosp_bedsize)]
data = data[pd.notnull(data.hosp_locteach)]
data = data[pd.notnull(data.h_contrl)]
data = data[data.elective != 'A']
data = data[data.zipincgrp != 'A']

# assign new columns
dx_cols = ['dx%d' % i for i in range(1, 31)]
ecode_cols = ['ecode%d' % i for i in range(1, 5)]
pr_cols = ['pr%d' % i for i in range(1, 16)]


def _join_codes(code_row):
    """Space-join one row of code columns, blank out the text 'nan', and trim the ends."""
    joined = ' '.join(code_row.values.astype(str))
    return joined.replace('nan', '').strip()


data['combined_dx'] = data[dx_cols].apply(_join_codes, axis=1)
data['combined_e_codes'] = data[ecode_cols].apply(_join_codes, axis=1)
data['planned_pr'] = data[pr_cols].apply(_join_codes, axis=1)

# code cases — order matters: `include` reads the three flags computed before it
for flag_col, classifier in (('arthritis', code_arthritis),
                             ('hardware_failure', code_hardware_failure),
                             ('trauma', code_trauma),
                             ('include', include)):
    data[flag_col] = data.apply(classifier, axis=1)


def _admit_month_name(row):
    """Full English month name for the row's numeric admission month."""
    return datetime.date(1900, int(row['amonth']), 1).strftime('%B')


data['admit_month_str'] = data.apply(_admit_month_name, axis=1)
# NOTE(review): every value other than '0' (including NaN) is labelled "Female".
data["sex"] = ['Male' if code == '0' else "Female" for code in data['female']]
data['payer'] = data.apply(convert_payer, axis=1)
data['race_converted'] = data.apply(convert_race, axis=1)
data['h_contrl_converted'] = data['h_contrl'].apply(convert_ownership)
# NOTE(review): every value other than '0' (including NaN) is condensed to "1".
data["tran_in_condensed"] = ['0' if code == '0' else "1" for code in data['tran_in']]

# Outlier removal. The thresholds are percentiles of the rows present at this
# point. Rows lying exactly ON a percentile are kept, so that only LOS > 99th
# percentile and charges < 1st / > 99th percentile are removed, as intended.
# (LOS is integer-valued with many ties, so a strict `<` would also drop every
# patient whose LOS equals the 99th-percentile value.)
# LOS: only the upper tail is trimmed (NOT removing LOS < 1% because LOS is left censored)
top_los = data.los.quantile(0.99)
# total charges: both tails are trimmed
bottom_charges = data.totchg_2014_normalized.quantile(0.01)
top_charges = data.totchg_2014_normalized.quantile(0.99)

# remove patients by los and charges (NaN values fail the comparisons and are dropped)
data = data[data.los <= top_los]
data = data[data.totchg_2014_normalized >= bottom_charges]
data = data[data.totchg_2014_normalized <= top_charges]




# convert output columns
data['disp_routine'] = [1 if x == '1' else 0 for x in data['dispuniform']]
data['los_2'] = [1 if x > 2 else 0 for x in data['los']]
# Tercile labels must be an ordered list (lowest bin first). A set literal has no
# defined iteration order (string hashing is randomized per process), so labels
# could be attached to the wrong tercile, and newer pandas may reject a set here.
tercile_labels = ['Low', 'Medium', 'High']
data['totcost2014_3'] = pd.qcut(data.totcost2014, 3, labels = tercile_labels)
data['totchg_2014_normalized_3'] = pd.qcut(data.totchg_2014_normalized, 3, labels = tercile_labels)

# make z score columns (population standard deviation, ddof=0)
# NOTE(review): terciles and z-scores are computed over all rows still present,
# i.e. before the `include == 1` filter below — confirm this is intended rather
# than computing them on the included cohort only.
zscore_source_cols = ['totcost2014', 'los', 'totchg_2014_normalized']
for col in zscore_source_cols:
    data[col + '_zscore'] = (data[col] - data[col].mean())/(data[col].std(ddof=0))

# bin z scores: (-inf, -1] -> low, (-1, 1] -> average, (1, inf) -> high.
# -np.inf / np.inf are used because the aliases np.NINF / np.Inf were removed in NumPy 2.0.
zscore_bin_edges = [-np.inf, -1, 1, np.inf]
for col in zscore_source_cols:
    data[col + '_zscore_bins'] = pd.cut(data[col + '_zscore'],
                                        zscore_bin_edges, labels = ['low', 'average','high'])

data = data[data.include == 1 ]
# save for inspection
data.to_csv('all_processed_NIS_TSA.csv', index = False)
print('Done processing data')

In [None]:
# Keep only rows that have a cost value, then export three cohorts. Each cohort
# gets a full CSV plus 70/10/20 train / validate / test CSVs.
data = data[data['totcost2014'].notnull()]


def _export_cohort(cohort, full_csv, split_csvs):
    """Print the cohort's shape and write it to ``full_csv``.

    Then split it 70/10/20 with ``split_train_val_test`` and write the
    (train, validate, test) frames to the three paths in ``split_csvs``.
    Returns the three frames.
    """
    print(cohort.shape)
    cohort.to_csv(full_csv, index = False)
    parts = split_train_val_test(cohort, 0.7, 0.1, 0.2)
    for part, path in zip(parts, split_csvs):
        part.to_csv(path, index = False)
    return parts


non_traumatic = data[(data.trauma == 0) & (data.hardware_failure == 0)]
non_t_train, non_t_val, non_t_test = _export_cohort(
    non_traumatic, 'non_traumatic_TSA.csv',
    ('non_traumatic_train.csv', 'non_traumatic_validate.csv', 'non_traumatic_test.csv'))

revision = data[data.hardware_failure == 1]
r_train, r_val, r_test = _export_cohort(
    revision, 'hardware_failure_TSA.csv',
    ('revision_train.csv', 'revision_validate.csv', 'revision_test.csv'))

traumatic = data[(data.trauma == 1) & (data.hardware_failure == 0)]
t_train, t_val, t_test = _export_cohort(
    traumatic, 'traumatic_TSA.csv',
    ('traumatic_train.csv', 'traumatic_validate.csv', 'traumatic_test.csv'))
