Skip to content

Commit

Permalink
Merge pull request #476 from SameerMahajan-GSLab/Sameer_Mahajan_473
Browse files (browse the repository at this point in the history)
Sameer Mahajan 473
  • Loading branch information
levithatcher committed Jul 10, 2018
2 parents 2ab52b5 + 6bdbaa8 commit 25cb3ae
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 4 deletions.
8 changes: 7 additions & 1 deletion healthcareai/common/top_factors.py
@@ -1,10 +1,12 @@
import numpy as np
import pandas as pd
from distutils.version import StrictVersion

from sklearn.linear_model import LogisticRegression, LinearRegression

from healthcareai.common.healthcareai_error import HealthcareAIError

CHECK_PANDAS_VERSION = StrictVersion("0.23.1")

def descending_sort(row):
# TODO Low priority, consider testing
Expand Down / Expand Up
@@ -44,7 +46,11 @@ def top_k_features(dataframe, linear_model, k=3):

# Multiply the values with the coefficients from the trained model and take the magnitude
step1 = pd.DataFrame(np.abs(dataframe.values * linear_model.coef_), columns=dataframe.columns)
step2 = step1.apply(descending_sort, axis=1)

if StrictVersion(pd.__version__) >= CHECK_PANDAS_VERSION:
step2 = step1.apply(descending_sort, axis=1, result_type='expand')
else:
step2 = step1.apply(descending_sort, axis=1)

results = list(step2.values[:, :k])
return results
Expand Down
5 changes: 2 additions & 3 deletions healthcareai/common/transformers.py
Expand Up
@@ -10,7 +10,6 @@
from imblearn.under_sampling import RandomUnderSampler
from sklearn.preprocessing import StandardScaler


class DataFrameImputer(TransformerMixin):
"""
Impute missing values in a dataframe.
Expand All
@@ -36,7 +35,7 @@ def fit(self, X, y=None):

self.fill = pd.Series([X[c].value_counts().index[0]
if X[c].dtype == np.dtype('O')
or pd.core.common.is_categorical_dtype(X[c])
or pd.api.types.is_categorical_dtype(X[c])
else X[c].mean() for c in X], index=X.columns)

if self.verbose:
Expand Down / Expand Up
@@ -105,7 +104,7 @@ def fit(self, X, y=None):
return self

def transform(self, X, y=None):
columns_to_dummify = X.select_dtypes(include=[object, 'category'])
columns_to_dummify = list(X.select_dtypes(include=[object, 'category']))

# remove excluded columns (if they are still in the list)
for column in columns_to_dummify:
Expand Down

0 comments on commit 25cb3ae

Please sign in to comment.