In [2]:
%load_ext autoreload
%autoreload 2
import os
import pickle as pkl
from functools import partial
from os.path import join as oj

import numpy as np
import pandas as pd

pd.set_option('display.max_rows', 500)
pd.set_option('display.max_columns', 50)
import sklearn as sk

import imodels
from imodels.util import data_util
from imodels.discretization import discretizer, simple

import matplotlib as mpl
import matplotlib.pyplot as plt

mpl.rcParams['figure.dpi'] = 250

# change working directory to project root
while os.getcwd().split('/')[-1] != 'csi_pecarn':
    os.chdir('..')

from models.stable import StableLinearClassifier

np.random.seed(0)

In [3]:
# Load the cleaned CSI dataset (feature_names presumably labels X's columns).
X, y, feature_names = data_util.get_clean_dataset('csi_all.csv', data_source='imodels')

# Number of held-out test rows. With shuffle=False, train_test_split splits by
# position: the first rows form the training set and the last N_TEST rows form
# the test set, so random_state has no effect on this split.
N_TEST = 744
X_train, X_test, y_train, y_test = sk.model_selection.train_test_split(
    X, y, test_size=N_TEST, random_state=0, shuffle=False)

# sanity checks on the split sizes
assert len(X_test) == len(y_test) == N_TEST
assert len(X_train) + len(X_test) == len(X)

In [4]:
# Rule-based submodels whose saved comparison results are loaded below.
SUBMODELS = ['rulefit', 'skope_rules', 'brs']


def _load_comparison_df(submodel):
    """Return the 'df' entry stored in results/<submodel>_comparisons.pkl.

    The file is opened in a `with` block so the handle is closed after reading.
    NOTE: pickle.load can execute arbitrary code; only load result files this
    project produced itself.
    """
    with open(oj('results', f'{submodel}_comparisons.pkl'), 'rb') as f:
        return pkl.load(f)['df']


submodel_dfs = [_load_comparison_df(submodel) for submodel in SUBMODELS]

In [5]:
# Hyper-parameters for the StableLinearClassifier, kept in one place.
stbl_params = dict(
    max_rules=13,
    max_complexity=40,
    min_mult=2,
    penalty='l2',
    metric='best_spec_0.96_sens',
    cv=False,
    random_state=0,
)
stbl = StableLinearClassifier(**stbl_params)

# Hand the submodel comparison frames to the model (presumably as candidate
# rules, selected via the '_train' suffix — confirm in models.stable), then fit
# on the training split. The fit call is the last expression so its return
# value is displayed.
stbl.set_rules(submodel_dfs, '_train')
stbl.fit(X_train, y_train, feature_names=feature_names)

StableLinearClassifier(cv=False, max_complexity=40, max_rules=13,
                       metric='best_spec_0.96_sens', penalty='l2',
                       random_state=0)

In [6]:
# Test-set ROC AUC and average precision of column 1 of predict_proba (the
# positive class for 0/1 labels). The scores are computed once and reused for
# both metrics instead of calling predict_proba twice.
test_scores = stbl.predict_proba(X_test)[:, 1]
print(sk.metrics.roc_auc_score(y_test, test_scores))
print(sk.metrics.average_precision_score(y_test, test_scores))

0.8080224183678995
0.4955251878326584


In [7]:
# Display whatever the fitted model's visualize() returns (the saved output
# shows a rule/coefficient table).
stbl_summary = stbl.visualize()
stbl_summary

Unnamed: 0,rule,coef
0,HighriskDiving <= 0.5,-2.25
5,Torticollis2 <= 0.5,-0.36
6,HighriskMVC <= 0.5,-0.8
7,MedsRecd2 <= 0.5,-0.51
8,Position_L <= 0.5,0.1
10,FocalNeuroFindings2 <= 0.5 and MedsRecd2 <= 0.5,0.59
1,AlteredMentalStatus2 <= 0.5 and FocalNeuroFindings2 <= 0.5 and Torticollis2 <= 0.5,-1.04
9,PainNeck2 > 0.5,0.92
4,is_ems <= 0.5,0.53
2,AlteredMentalStatus2 > 0.5,0.4
