In [4]:
# load both dataframes
import pandas as pd
from sklearn.linear_model import LogisticRegression
from sklearn.ensemble import RandomForestClassifier
from sklearn.tree import DecisionTreeClassifier
import numpy as np

original_df = pd.read_csv("./spreadsheets/reflacx_clinical.csv")
original_clinical_cols = [
    "temperature",
    "heartrate",
    "resprate",
    "o2sat",
    "sbp",
    "dbp",
]
chexpert_label_cols = [col for col in original_df if col.endswith("chexpert")]
auged_features = []
extended_clinical_features = original_clinical_cols + []

In [12]:
from torchmetrics.classification import MultilabelAccuracy
from torchmetrics.classification import MultilabelAUROC
import torch


def get_acc_auc(df, input_cols, label_cols, cls):
    cls = cls.fit(
        df[df["split"] == "train"][input_cols],
        df[df["split"] == "train"][label_cols],
    )

    pred = cls.predict(df[df["split"] == "test"][original_clinical_cols])

    mla = MultilabelAccuracy(num_labels=len(chexpert_label_cols), average="micro")
    acc = mla(
        torch.tensor(pred),
        (
            torch.tensor(np.array(df[df["split"] == "test"][chexpert_label_cols] > 0))
        ).float(),
    )

    ml_auroc = MultilabelAUROC(num_labels=len(chexpert_label_cols), average="micro")
    auc = ml_auroc(
        torch.tensor(pred),
        (
            torch.tensor(np.array(df[df["split"] == "test"][chexpert_label_cols] > 0))
        ).long(),
    )

    feature_importance_dict = {
        c: i for c, i in zip(original_clinical_cols, cls.feature_importances_)
    }
    return feature_importance_dict, acc, auc

In [13]:
get_acc_auc(original_df, original_clinical_cols, chexpert_label_cols, RandomForestClassifier())

({'temperature': 0.18047536313756044,
  'heartrate': 0.19388825026802414,
  'resprate': 0.09835030061680662,
  'o2sat': 0.10989466313676344,
  'sbp': 0.21898688381459713,
  'dbp': 0.19840453902624824},
 tensor(0.9167),
 tensor(0.6337))

In [14]:
get_acc_auc(original_df, original_clinical_cols, chexpert_label_cols, DecisionTreeClassifier())

({'temperature': 0.2033609512252407,
  'heartrate': 0.19171308294411565,
  'resprate': 0.06108838368969653,
  'o2sat': 0.12356266913056456,
  'sbp': 0.24519643351944334,
  'dbp': 0.17507847949093927},
 tensor(0.8774),
 tensor(0.6653))