In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import warnings
warnings.filterwarnings('ignore')
from pathlib import Path
import torch
import torch.nn as nn
import numpy as np
import pandas as pd
import seaborn as sns
sns.set()
from model import MultiModalTransformer, MultiModalConv
from main import test_epoch
from data import dataloaders

# Evaluation device: GPU 1 (as used for the original runs) when the host has it,
# otherwise the first GPU, otherwise the CPU, so the notebook still runs elsewhere.
_n_gpus = torch.cuda.device_count()
if _n_gpus > 1:
    device = torch.device('cuda:1')
elif _n_gpus == 1:
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
exp_src = Path('./runs')  # root of the training-run directories
task = 'phenotyping'

In [3]:
def eval_fn(model_name, ts, modalities, data_root='./data', pos_weight=4168 / 717):
    """Load the checkpoint of one training run and evaluate it on the test split.

    Args:
        model_name: image-backbone key, one of 'swin', 'vit' or 'conv'.
        ts: name of the run sub-directory (a timestamp) under
            ``exp_src / task / f'{model_name}_{"_".join(modalities)}'``.
        modalities: modality names the run was trained with. 'text' and
            'diagnoses' become flags; every other entry is passed on as an
            image modality.
        data_root: dataset root handed to ``dataloaders``.
        pos_weight: scalar ``pos_weight`` for ``BCEWithLogitsLoss`` (n_neg / n_pos).
            NOTE(review): a single value is shared by all output classes —
            confirm 4168 / 717 is the intended ratio for this task.

    Returns:
        ``(y_true, y_prob)`` as returned by ``test_epoch``.
    """
    # The run directory name uses the full modality list, including 'text'/'diagnoses'.
    exp_path = exp_src / task / f'{model_name}_{"_".join(modalities)}' / ts
    with_text = 'text' in modalities
    with_diagnoses = 'diagnoses' in modalities
    img_modalities = [m for m in modalities if m not in ('text', 'diagnoses')]
    _model_name = {
        # 'swin': 'microsoft/swin-base-patch4-window7-224-in22k',
        'swin': 'microsoft/swin-large-patch4-window12-384-in22k',
        'vit': 'google/vit-base-patch16-224',
        'conv': 'google/vit-base-patch16-224',  # only for the preprocessor
    }[model_name]
    model_class = MultiModalConv if model_name == 'conv' else MultiModalTransformer
    model = model_class(
        img_model_name=_model_name,
        img_modalities=img_modalities,
        with_text=with_text,
        task=task,
    )
    # Load onto the CPU regardless of which GPU the checkpoint was saved from, then
    # move the model to `device` once (no extra copy staged on the saving GPU, and
    # CPU-only hosts can load it too).
    state_dict = torch.load(exp_path / 'checkpoint.pt', map_location='cpu')
    # The state-dict keys presumably carry nn.DataParallel's 'module.' prefix from
    # training, so load through a wrapper and keep only the inner module.
    wrapped = nn.DataParallel(model)
    wrapped.load_state_dict(state_dict)
    model = wrapped.module.to(device)
    _, _, testloader = dataloaders(
        img_model_name=_model_name,
        modalities=img_modalities,
        with_diagnoses=with_diagnoses,
        root=data_root,
        image_size=384,
        task=task,
    )
    criterion = nn.BCEWithLogitsLoss(pos_weight=torch.tensor(pos_weight).to(device))
    y_true, y_prob = test_epoch(model, criterion, testloader, img_modalities, device)
    return y_true, y_prob

In [4]:
# Per-experiment lists (one entry per training run) of test labels and predicted probabilities.
ys_true = {}
ys_prob = {}

In [5]:
# Experiments to evaluate: [image-backbone key for eval_fn, modalities].
# 'text' and 'diagnoses' are non-image flags; every other entry is an image modality.
# Earlier configurations (their predictions appear in results/results_ViTiMM_phenotyping1.csv):
#   ['swin', ['lab', 'text']]
#   ['swin', ['med', 'text']]
#   ['swin', ['cxr', 'text']]
#   ['swin', ['ecg', 'text']]
#   ['swin', ['lab', 'cxr', 'text']]
#   ['swin', ['lab', 'med', 'cxr', 'ecg', 'text', 'diagnoses']]
experiments = [
    ['swin', ['lab', 'med', 'cxr', 'ecg', 'text']],
]

In [6]:
# Evaluate every training run (one sub-directory per run) of each experiment on the test set.
# Results are stored only once all runs of an experiment have succeeded, so a failure midway
# is retried on re-execution instead of leaving a partial entry that the guard below skips.
for model_name, modalities in experiments:
    exp_name = f'{model_name}_{"_".join(modalities)}'
    if exp_name in ys_prob:
        continue  # already evaluated in this kernel session
    # Sorted for a deterministic run order (iterdir() order is filesystem-dependent);
    # stray files next to the run directories are ignored.
    run_dirs = sorted(p for p in (exp_src / task / exp_name).iterdir() if p.is_dir())
    exp_true, exp_prob = [], []
    for run_dir in run_dirs:
        y_true, y_prob = eval_fn(
            model_name=model_name,
            ts=run_dir.name,
            modalities=modalities,
            data_root='/mnt/hdd/data/MMMedViT_data/data',
        )
        exp_true.append(y_true)
        exp_prob.append(y_prob)
    ys_true[exp_name] = exp_true
    ys_prob[exp_name] = exp_prob

Some weights of RobertaModel were not initialized from the model checkpoint at roberta-base and are newly initialized: ['roberta.pooler.dense.bias', 'roberta.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
                                               

In [7]:
# Test-split stay ids and labels for the last configured experiment.
# The modality list comes from `experiments` rather than from the `modalities` variable
# left over by the evaluation loop, and the non-image flags are removed as in eval_fn.
ref_modalities = [m for m in experiments[-1][1] if m not in ['text', 'diagnoses']]
_, _, testloader = dataloaders(
    img_model_name='google/vit-base-patch16-224',
    modalities=ref_modalities,
    root='/mnt/hdd/data/MMMedViT_data/data',
    task=task,
)
stay_ids = testloader.dataset.data.stay_id.tolist()
y_true = testloader.dataset.data.y_true.tolist()

In [8]:
# Long-format predictions: one row per (test stay, phenotype class) for each experiment.
# NOTE(review): assumes the test loader yields rows in `testloader.dataset.data` order
# (no shuffling), so prediction row i lines up with label row i — confirm.
LABEL_COLS = slice(31, 56)  # positional range of the 25 label columns in dataset.data
dfs = []
for model_name, modalities in experiments:
    exp_name = f'{model_name}_{"_".join(modalities)}'
    modalities = [m for m in modalities if m not in ['text', 'diagnoses']]
    _, _, testloader = dataloaders(
        img_model_name='google/vit-base-patch16-224',
        modalities=modalities,
        root='/mnt/hdd/data/MMMedViT_data/data',
        task=task,
    )
    data = testloader.dataset.data
    y_true = data.iloc[:, LABEL_COLS].to_numpy()
    class_names = np.asarray(data.columns[LABEL_COLS])
    stay_ids = data.stay_id.to_numpy()

    # Only the first stored run of the experiment is exported.
    ys_prob_exp = ys_prob[exp_name][0]
    # Guard against silently mis-flattening predictions whose layout differs from the labels.
    assert tuple(ys_prob_exp.shape) == y_true.shape, (
        f'{exp_name}: predictions {tuple(ys_prob_exp.shape)} vs labels {y_true.shape}'
    )
    n_rows, n_classes = y_true.shape
    dfi = pd.DataFrame({
        # Named 'hadm_id' to match the previously saved results CSVs it is merged with,
        # but the values are the dataset's stay_id column.
        'hadm_id': np.repeat(stay_ids, n_classes),  # each stay repeated once per class
        'class_name': np.tile(class_names, n_rows),  # class list repeated once per stay
        'GT': y_true.flatten(),
        exp_name: ys_prob_exp.flatten(),
    })
    dfs.append(dfi)

In [11]:
# Earlier experiments' test predictions in long format (one row per stay/class pair).
df1 = pd.read_csv(filepath_or_buffer='results/results_ViTiMM_phenotyping1.csv', index_col=0)
df1

Unnamed: 0,hadm_id,class_name,swin_lab_text,swin_med_text,swin_cxr_text,swin_ecg_text,swin_lab_cxr_text,swin_lab_med_cxr_ecg_text_diagnoses
0,30008792,Acute and unspecified renal failure,0.776039,0.761100,0.792653,0.730146,0.819554,0.213115
1,30008792,Acute cerebrovascular disease,0.043447,0.051549,0.061030,0.055669,0.086462,0.029455
2,30008792,Acute myocardial infarction,0.140637,0.219879,0.182432,0.252347,0.079761,0.298785
3,30008792,Cardiac dysrhythmias,0.562242,0.654142,0.698392,0.664893,0.624661,0.955679
4,30008792,Chronic kidney disease,0.388478,0.567263,0.552175,0.567266,0.292466,0.158354
...,...,...,...,...,...,...,...,...
34320,39983674,Pleurisy; pneumothorax; pulmonary collapse,0.105716,0.127103,0.098830,0.098628,0.121596,0.082246
34321,39983674,Pneumonia (except that caused by tuberculosis ...,0.333036,0.325367,0.219076,0.281017,0.136459,0.071117
34322,39983674,Respiratory failure; insufficiency; arrest (ad...,0.217293,0.337005,0.170183,0.312666,0.148843,0.067845
34323,39983674,Septicemia (except in labor),0.263896,0.325518,0.177409,0.176516,0.241277,0.824257


In [13]:
# Attach the newly evaluated experiment to the earlier predictions and save the combined table.
# The inner merge must neither duplicate nor silently drop (stay, class) rows.
combined = df1.merge(dfs[0], on=['hadm_id', 'class_name'], validate='one_to_one')
assert len(combined) == len(df1) == len(dfs[0]), (
    f'merge kept {len(combined)} rows out of {len(df1)} / {len(dfs[0])}'
)
combined.to_csv('results/results_ViTiMM_phenotyping.csv')

In [None]:
# df1 = pd.read_csv('results_ViTiMM_phenotyping.csv', index_col=0)
# df1.drop(columns=['swin_lab_med_cxr_ecg_text_diagnoses'], inplace=True)
# df = df1.merge(df, on=['hadm_id','class_name','GT'],how='outer')

In [53]:
# df.to_csv('results_ViTiMM_phen*otyping.csv')

In [48]:
# Outer-merge the per-experiment tables on their shared keys; rows missing from an
# experiment are kept with NaN in that experiment's column.
# NOTE(review): the literal '*' in the output filename is a glob metacharacter and is
# invalid on Windows — confirm the intended path.
df = dfs[0]
for exp_df in dfs[1:]:
    df = pd.merge(df, exp_df, how='outer', on=['hadm_id', 'class_name', 'GT'])
df.to_csv('results_ViTiMM_phen*otyping.csv')
df

Unnamed: 0,hadm_id,class_name,GT,swin_lab_med_cxr_ecg_text
0,30008792,Acute and unspecified renal failure,0,0.505536
1,30008792,Acute cerebrovascular disease,0,0.067849
2,30008792,Acute myocardial infarction,1,0.039910
3,30008792,Cardiac dysrhythmias,1,0.747679
4,30008792,Chronic kidney disease,0,0.283198
...,...,...,...,...
34320,39983674,Pleurisy; pneumothorax; pulmonary collapse,0,0.134708
34321,39983674,Pneumonia (except that caused by tuberculosis ...,0,0.145808
34322,39983674,Respiratory failure; insufficiency; arrest (ad...,0,0.119305
34323,39983674,Septicemia (except in labor),1,0.059272


In [55]:
from utils import metrics

# Per-class test metrics for every prediction column of `df`.
# NOTE(review): the output below covers six experiments, while the `df` displayed by the
# merge cell above has a single prediction column — `df` was evidently rebuilt out of order
# (cf. the commented-out cell). Make that step explicit so Restart & Run All reproduces it.
key_cols = {'hadm_id', 'class_name', 'GT'}
exp_names = [c for c in df.columns if c not in key_cols]
phenotypes = df['class_name'].unique()

records = []
for exp_name in exp_names:
    # Outer merges leave NaN where an experiment has no prediction; drop those rows once.
    scored = df.loc[df[exp_name].notna(), ['class_name', 'GT', exp_name]]
    for class_name in phenotypes:
        rows = scored[scored['class_name'] == class_name]
        roc_auc, auprc, balanced_accuracy = metrics(
            torch.tensor(rows['GT'].tolist()),
            torch.tensor(rows[exp_name].tolist()),
        )
        records.append({
            'exp_name': exp_name,
            'class_name': class_name,
            'roc_auc': roc_auc,
            'auprc': auprc,
            'balanced_accuracy': balanced_accuracy,
        })
results = pd.DataFrame(records)
results

Unnamed: 0,exp_name,class_name,roc_auc,auprc,balanced_accuracy
0,swin_lab_text,Acute and unspecified renal failure,0.869727,0.761811,0.791506
1,swin_lab_text,Acute cerebrovascular disease,0.870143,0.487868,0.575322
2,swin_lab_text,Acute myocardial infarction,0.891282,0.578035,0.620309
3,swin_lab_text,Cardiac dysrhythmias,0.702526,0.606044,0.627663
4,swin_lab_text,Chronic kidney disease,0.903373,0.762826,0.786765
...,...,...,...,...,...
145,swin_lab_med_cxr_ecg_text,Pleurisy; pneumothorax; pulmonary collapse,0.671315,0.186470,0.501976
146,swin_lab_med_cxr_ecg_text,Pneumonia (except that caused by tuberculosis ...,0.738201,0.472036,0.563127
147,swin_lab_med_cxr_ecg_text,Respiratory failure; insufficiency; arrest (ad...,0.809758,0.695455,0.683749
148,swin_lab_med_cxr_ecg_text,Septicemia (except in labor),0.816660,0.641930,0.690124


In [59]:
# Macro-average of the per-class metrics for each experiment.
(
    results
    .groupby(by='exp_name')[['roc_auc', 'auprc', 'balanced_accuracy']]
    .agg('mean')
    .round(decimals=3)
)

Unnamed: 0_level_0,roc_auc,auprc,balanced_accuracy
exp_name,Unnamed: 1_level_1,Unnamed: 2_level_1,Unnamed: 3_level_1
swin_cxr_text,0.729,0.461,0.59
swin_ecg_text,0.685,0.431,0.573
swin_lab_cxr_text,0.776,0.524,0.636
swin_lab_med_cxr_ecg_text,0.782,0.546,0.646
swin_lab_text,0.765,0.505,0.618
swin_med_text,0.713,0.432,0.577
