In [1]:
import torch
import os
from dataloader_helper import *
from models import *
import numpy as np
import random
import torch.optim as optim
from sklearn.metrics import roc_auc_score
import joblib
import time


# --- Runtime configuration and test data -----------------------------------
# Pin the process to a single physical GPU. Both variables are set before the
# first CUDA query below (torch.cuda.is_available) so that they take effect.
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"
os.environ["CUDA_VISIBLE_DEVICES"]="6"

# Seed every RNG the notebook can reach so that Restart & Run All repeats the
# same numbers; cuDNN determinism alone does not fix the random streams.
SEED = 0
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.backends.cudnn.deterministic = True

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Directory prefix for the saved checkpoints; each evaluation cell appends the
# checkpoint file name to it.
start = '/data2/meerak/models/'

# Pre-extracted test features and labels, stored as joblib dumps.
X_test = np.array(joblib.load('/data2/meerak/MIMIC_CXR_FTS/UPDATED_test_fts.joblib'))
y_test = np.array(joblib.load('/data2/meerak/MIMIC_CXR_FTS/UPDATED_test_ys.joblib'))

# CXRDataset and DataLoader are not imported by name in this notebook; they
# presumably arrive through the star imports above — TODO confirm.
test_dg = CXRDataset('test', {'features':torch.tensor(X_test), 'labels':torch.tensor(y_test)})
# One item per batch, in a fixed order, so predictions line up with labels.
test_loader = DataLoader(test_dg,batch_size = 1,shuffle = False)

def get_epoch(name):
    """Return the position of the best row in a run's training log.

    Opens ``best_<name>.txt`` in the current working directory, skips its
    first two lines, and parses field 4 (zero-based, after splitting on
    ``', '``) of every remaining non-blank line as a float. That field is
    presumably the validation metric (the original code collected it in a
    list called ``vals``) — TODO confirm against the training script.

    Parameters
    ----------
    name : str
        Run identifier substituted into ``'best_%s.txt'``.

    Returns
    -------
    int
        Zero-based index, among the parsed rows, of the largest field-4
        value. Ties resolve to the first occurrence (``np.argmax``).

    Raises
    ------
    FileNotFoundError
        If the log file does not exist.
    IndexError
        If a data row has fewer than five ``', '``-separated fields.
    ValueError
        If field 4 is not a float, or if the log has no data rows.
    """
    title = 'best_%s.txt'%(name)
    vals = []
    with open(title) as f:
        for count, line in enumerate(f):
            # The first two lines are not data rows; blank lines (for example
            # a trailing newline at the end of the file) carry no row either.
            if count <= 1 or not line.strip():
                continue
            vals.append(float(line.split(', ')[4]))
    if not vals:
        raise ValueError('no epoch rows found in %s' % title)
    return int(np.argmax(vals))


  from .autonotebook import tqdm as notebook_tqdm


### ABDMIL_NOPE

In [12]:
# Score the ABDMIL checkpoint built with PE = False on the test loader, report
# the AUROC and wall-clock time, then bootstrap an interval for the AUROC.
cD = 2048
model = ABDMIL(cD, PE = False).to(device)
epoch = get_epoch('abdmil_nope')
modelname = start + 'best_abdmil_nope_mimiccxr_densenet_epoch%d'%epoch
# map_location keeps the CPU fallback chosen for `device` usable even when the
# checkpoint was written from a GPU process.
model.load_state_dict(torch.load(modelname, map_location = device))
# Switch dropout / normalisation layers to inference behaviour, as the CLAM
# cells already do before scoring.
model.eval()

bag_preds = []
bag_ys = []
start_epoch = time.time()
# Scoring needs no gradients; skipping the autograd graph saves memory.
with torch.no_grad():
    for curridx, data, target in test_loader:
        # The loader yields one item per batch; drop the leading batch axis.
        data = data.to(device).squeeze(0)

        bag_prediction = model(data)

        # Column 1 of the softmax is used as the positive-class score.
        bag_preds.extend(F.softmax(bag_prediction, dim = 1)[:, 1].cpu().numpy())
        bag_ys.extend(target.cpu().numpy())
end_epoch = time.time()

bag_preds = np.array(bag_preds)
bag_ys = np.array(bag_ys)
print('Test AUROC:', roc_auc_score(bag_ys, bag_preds))
print(end_epoch - start_epoch)

# Bootstrap: 1000 resamples drawn with replacement from a fixed-seed generator,
# so re-running this cell reproduces the same interval.
bootstrap_rng = np.random.default_rng(0)
n_bags = len(bag_preds)
curr_rocs = []
for _ in range(1000):
    curr_idxs = bootstrap_rng.choice(n_bags, n_bags)
    curr_rocs.append(roc_auc_score(bag_ys[curr_idxs], bag_preds[curr_idxs]))

curr_rocs.sort()
# Sorted positions 500 / 25 / 975 of 1000 approximate the median and the
# 2.5th / 97.5th percentiles.
print('%0.3f (%0.3f, %0.3f)'%(curr_rocs[500], curr_rocs[25], curr_rocs[975]))
joblib.dump(curr_rocs, 'abdmil_nope.joblib')

Test AUROC: 0.7817031184821799
11.872807741165161
0.782 (0.775, 0.789)


['abdmil_nope.joblib']

### ABDMIL_PE

In [3]:
# Score the ABDMIL checkpoint built with PE = True on the test loader, report
# the AUROC and wall-clock time, then bootstrap an interval for the AUROC.
cD = 4096
model = ABDMIL(cD, PE = True).to(device)
epoch = get_epoch('abdmil_pe')
modelname = start + 'best_abdmil_pe_mimiccxr_densenet_epoch%d'%epoch
# map_location keeps the CPU fallback chosen for `device` usable even when the
# checkpoint was written from a GPU process.
model.load_state_dict(torch.load(modelname, map_location = device))
# Switch dropout / normalisation layers to inference behaviour, as the CLAM
# cells already do before scoring.
model.eval()

bag_preds = []
bag_ys = []
start_epoch = time.time()
# Scoring needs no gradients; skipping the autograd graph saves memory.
with torch.no_grad():
    for curridx, data, target in test_loader:
        # The loader yields one item per batch; drop the leading batch axis.
        data = data.to(device).squeeze(0)

        bag_prediction = model(data)

        # Column 1 of the softmax is used as the positive-class score.
        bag_preds.extend(F.softmax(bag_prediction, dim = 1)[:, 1].cpu().numpy())
        bag_ys.extend(target.cpu().numpy())
end_epoch = time.time()

bag_preds = np.array(bag_preds)
bag_ys = np.array(bag_ys)
print('Test AUROC:', roc_auc_score(bag_ys, bag_preds))
print(end_epoch - start_epoch)

# Bootstrap: 1000 resamples drawn with replacement from a fixed-seed generator,
# so re-running this cell reproduces the same interval.
bootstrap_rng = np.random.default_rng(0)
n_bags = len(bag_preds)
curr_rocs = []
for _ in range(1000):
    curr_idxs = bootstrap_rng.choice(n_bags, n_bags)
    curr_rocs.append(roc_auc_score(bag_ys[curr_idxs], bag_preds[curr_idxs]))

curr_rocs.sort()
# Sorted positions 500 / 25 / 975 of 1000 approximate the median and the
# 2.5th / 97.5th percentiles.
print('%0.3f (%0.3f, %0.3f)'%(curr_rocs[500], curr_rocs[25], curr_rocs[975]))
joblib.dump(curr_rocs, 'abdmil_pe.joblib')

Test AUROC: 0.798523189432873
13.17552924156189
0.798 (0.792, 0.806)


['abdmil_pe.joblib']

### CLAM NOPE

In [4]:
# Score the clamSB checkpoint built with PE = False on the test loader, report
# the AUROC and wall-clock time, then bootstrap an interval for the AUROC.
cK = 6
cDrop = 0

model = ClamWrapper(cK, cDrop, True, PE = False).to(device)
epoch = get_epoch('clamSB_nope')
modelname = start + 'best_clamSB_nope_mimiccxr_densenet_epoch%d'%epoch
# map_location keeps the CPU fallback chosen for `device` usable even when the
# checkpoint was written from a GPU process.
model.load_state_dict(torch.load(modelname, map_location = device))
model.eval()

bag_preds = []
bag_ys = []
start_epoch = time.time()
# Scoring needs no gradients; skipping the autograd graph saves memory.
with torch.no_grad():
    for curridx, data, target in test_loader:
        # The loader yields one item per batch; drop the leading batch axis.
        data = data.to(device).squeeze(0)
        target = target.to(device)

        # The wrapper's call takes the label as its second argument; with
        # instance_eval = False it presumably does not influence the bag
        # logits — TODO confirm in ClamWrapper.
        bag_prediction, inst_dict = model(data, target.to(torch.int64), instance_eval = False)

        # Column 1 of the softmax is used as the positive-class score.
        bag_preds.extend(F.softmax(bag_prediction, dim = 1)[:, 1].cpu().numpy())
        bag_ys.extend(target.cpu().numpy())
end_epoch = time.time()

bag_preds = np.array(bag_preds)
bag_ys = np.array(bag_ys)
print('Test AUROC:', roc_auc_score(bag_ys, bag_preds))
print(end_epoch - start_epoch)

# Bootstrap: 1000 resamples drawn with replacement from a fixed-seed generator,
# so re-running this cell reproduces the same interval.
bootstrap_rng = np.random.default_rng(0)
n_bags = len(bag_preds)
curr_rocs = []
for _ in range(1000):
    curr_idxs = bootstrap_rng.choice(n_bags, n_bags)
    curr_rocs.append(roc_auc_score(bag_ys[curr_idxs], bag_preds[curr_idxs]))

curr_rocs.sort()
# Sorted positions 500 / 25 / 975 of 1000 approximate the median and the
# 2.5th / 97.5th percentiles.
print('%0.3f (%0.3f, %0.3f)'%(curr_rocs[500], curr_rocs[25], curr_rocs[975]))
joblib.dump(curr_rocs, 'clamsb_nope.joblib')

Test AUROC: 0.7808210386341587
13.586225986480713
0.780 (0.774, 0.788)


['clamsb_nope.joblib']

### CLAM PE

In [5]:
# Score the clamSB checkpoint built with PE = True on the test loader, report
# the AUROC and wall-clock time, then bootstrap an interval for the AUROC.
cK = 6
cDrop = 0

model = ClamWrapper(cK, cDrop, True, PE = True).to(device)
epoch = get_epoch('clamSB_pe')
modelname = start + 'best_clamSB_pe_mimiccxr_densenet_epoch%d'%epoch
# map_location keeps the CPU fallback chosen for `device` usable even when the
# checkpoint was written from a GPU process.
model.load_state_dict(torch.load(modelname, map_location = device))
model.eval()

bag_preds = []
bag_ys = []
start_epoch = time.time()
# Scoring needs no gradients; skipping the autograd graph saves memory.
with torch.no_grad():
    for curridx, data, target in test_loader:
        # The loader yields one item per batch; drop the leading batch axis.
        data = data.to(device).squeeze(0)
        target = target.to(device)

        # The wrapper's call takes the label as its second argument; with
        # instance_eval = False it presumably does not influence the bag
        # logits — TODO confirm in ClamWrapper.
        bag_prediction, inst_dict = model(data, target.to(torch.int64), instance_eval = False)

        # Column 1 of the softmax is used as the positive-class score.
        bag_preds.extend(F.softmax(bag_prediction, dim = 1)[:, 1].cpu().numpy())
        bag_ys.extend(target.cpu().numpy())
end_epoch = time.time()

bag_preds = np.array(bag_preds)
bag_ys = np.array(bag_ys)
print('Test AUROC:', roc_auc_score(bag_ys, bag_preds))
print(end_epoch - start_epoch)

# Bootstrap: 1000 resamples drawn with replacement from a fixed-seed generator,
# so re-running this cell reproduces the same interval.
bootstrap_rng = np.random.default_rng(0)
n_bags = len(bag_preds)
curr_rocs = []
for _ in range(1000):
    curr_idxs = bootstrap_rng.choice(n_bags, n_bags)
    curr_rocs.append(roc_auc_score(bag_ys[curr_idxs], bag_preds[curr_idxs]))

curr_rocs.sort()
# Sorted positions 500 / 25 / 975 of 1000 approximate the median and the
# 2.5th / 97.5th percentiles.
print('%0.3f (%0.3f, %0.3f)'%(curr_rocs[500], curr_rocs[25], curr_rocs[975]))
joblib.dump(curr_rocs, 'clamsb_pe.joblib')

Test AUROC: 0.7988378301214891
15.412106275558472
0.799 (0.792, 0.806)


['clamsb_pe.joblib']

### CLAM MB NOPE

In [6]:
# Score the clamMB checkpoint built with PE = False on the test loader, report
# the AUROC and wall-clock time, then bootstrap an interval for the AUROC.
cK = 8
cDrop = 0

# NOTE(review): the third positional argument is True here exactly as in the
# clamSB cells. If that flag selects the single-branch variant, confirm it is
# intended for a clamMB checkpoint.
model = ClamWrapper(cK, cDrop, True, PE = False).to(device)
epoch = get_epoch('clamMB_nope')
modelname = start + 'best_clamMB_nope_mimiccxr_densenet_epoch%d'%epoch
# map_location keeps the CPU fallback chosen for `device` usable even when the
# checkpoint was written from a GPU process.
model.load_state_dict(torch.load(modelname, map_location = device))
model.eval()

bag_preds = []
bag_ys = []
start_epoch = time.time()
# Scoring needs no gradients; skipping the autograd graph saves memory.
with torch.no_grad():
    for curridx, data, target in test_loader:
        # The loader yields one item per batch; drop the leading batch axis.
        data = data.to(device).squeeze(0)
        target = target.to(device)

        # The wrapper's call takes the label as its second argument; with
        # instance_eval = False it presumably does not influence the bag
        # logits — TODO confirm in ClamWrapper.
        bag_prediction, inst_dict = model(data, target.to(torch.int64), instance_eval = False)

        # Column 1 of the softmax is used as the positive-class score.
        bag_preds.extend(F.softmax(bag_prediction, dim = 1)[:, 1].cpu().numpy())
        bag_ys.extend(target.cpu().numpy())
end_epoch = time.time()

bag_preds = np.array(bag_preds)
bag_ys = np.array(bag_ys)
print('Test AUROC:', roc_auc_score(bag_ys, bag_preds))
print(end_epoch - start_epoch)

# Bootstrap: 1000 resamples drawn with replacement from a fixed-seed generator,
# so re-running this cell reproduces the same interval.
bootstrap_rng = np.random.default_rng(0)
n_bags = len(bag_preds)
curr_rocs = []
for _ in range(1000):
    curr_idxs = bootstrap_rng.choice(n_bags, n_bags)
    curr_rocs.append(roc_auc_score(bag_ys[curr_idxs], bag_preds[curr_idxs]))

curr_rocs.sort()
# Sorted positions 500 / 25 / 975 of 1000 approximate the median and the
# 2.5th / 97.5th percentiles.
print('%0.3f (%0.3f, %0.3f)'%(curr_rocs[500], curr_rocs[25], curr_rocs[975]))
joblib.dump(curr_rocs, 'clammb_nope.joblib')

Test AUROC: 0.7782535313501515
12.4870023727417
0.778 (0.771, 0.786)


['clammb_nope.joblib']

### CLAM MB PE

In [7]:
# Score the clamMB checkpoint built with PE = True on the test loader, report
# the AUROC and wall-clock time, then bootstrap an interval for the AUROC.
cK = 4
cDrop = 0

# NOTE(review): the third positional argument is True here exactly as in the
# clamSB cells. If that flag selects the single-branch variant, confirm it is
# intended for a clamMB checkpoint.
model = ClamWrapper(cK, cDrop, True, PE = True).to(device)
epoch = get_epoch('clamMB_pe')
modelname = start + 'best_clamMB_pe_mimiccxr_densenet_epoch%d'%epoch
# map_location keeps the CPU fallback chosen for `device` usable even when the
# checkpoint was written from a GPU process.
model.load_state_dict(torch.load(modelname, map_location = device))
model.eval()

bag_preds = []
bag_ys = []
start_epoch = time.time()
# Scoring needs no gradients; skipping the autograd graph saves memory.
with torch.no_grad():
    for curridx, data, target in test_loader:
        # The loader yields one item per batch; drop the leading batch axis.
        data = data.to(device).squeeze(0)
        target = target.to(device)

        # The wrapper's call takes the label as its second argument; with
        # instance_eval = False it presumably does not influence the bag
        # logits — TODO confirm in ClamWrapper.
        bag_prediction, inst_dict = model(data, target.to(torch.int64), instance_eval = False)

        # Column 1 of the softmax is used as the positive-class score.
        bag_preds.extend(F.softmax(bag_prediction, dim = 1)[:, 1].cpu().numpy())
        bag_ys.extend(target.cpu().numpy())
end_epoch = time.time()

bag_preds = np.array(bag_preds)
bag_ys = np.array(bag_ys)
print('Test AUROC:', roc_auc_score(bag_ys, bag_preds))
print(end_epoch - start_epoch)

# Bootstrap: 1000 resamples drawn with replacement from a fixed-seed generator,
# so re-running this cell reproduces the same interval.
bootstrap_rng = np.random.default_rng(0)
n_bags = len(bag_preds)
curr_rocs = []
for _ in range(1000):
    curr_idxs = bootstrap_rng.choice(n_bags, n_bags)
    curr_rocs.append(roc_auc_score(bag_ys[curr_idxs], bag_preds[curr_idxs]))

curr_rocs.sort()
# Sorted positions 500 / 25 / 975 of 1000 approximate the median and the
# 2.5th / 97.5th percentiles.
print('%0.3f (%0.3f, %0.3f)'%(curr_rocs[500], curr_rocs[25], curr_rocs[975]))
joblib.dump(curr_rocs, 'clammb_pe.joblib')

Test AUROC: 0.7978062184902243
14.317506790161133
0.798 (0.791, 0.805)


['clammb_pe.joblib']

### DTFD NOPE

In [8]:
# Score the DTFD checkpoint built with PE = False on the test loader, report
# the AUROC and wall-clock time, then bootstrap an interval for the AUROC.
cD = 128
cNPB = 4
model = DTFD(cD, cNPB, PE = False).to(device)
epoch = get_epoch('dtfd_nope')
modelname = start + 'best_dtfd_nope_mimiccxr_densenet_epoch%d'%epoch
# map_location keeps the CPU fallback chosen for `device` usable even when the
# checkpoint was written from a GPU process.
model.load_state_dict(torch.load(modelname, map_location = device))
# Switch dropout / normalisation layers to inference behaviour, as the CLAM
# cells already do before scoring.
model.eval()

bag_preds = []
bag_ys = []
start_epoch = time.time()
criterion = nn.CrossEntropyLoss()

# Scoring needs no gradients; skipping the autograd graph saves memory.
# NOTE(review): assumes DTFD's forward never calls backward() itself — confirm.
with torch.no_grad():
    for curridx, data, target in test_loader:
        # The loader yields one item per batch; drop the leading batch axis.
        data = data.to(device).squeeze(0)
        target = target.to(device)

        # NOTE(review): the DTFD PE cell passes loss = False to this call,
        # whereas this one relies on the default. Confirm both cells are meant
        # to score the same way.
        bag_prediction, _ = model(data, target, criterion)

        # Column 1 of the softmax is used as the positive-class score.
        bag_preds.extend(F.softmax(bag_prediction, dim = 1)[:, 1].cpu().numpy())
        bag_ys.extend(target.cpu().numpy())
end_epoch = time.time()

bag_preds = np.array(bag_preds)
bag_ys = np.array(bag_ys)
print('Test AUROC:', roc_auc_score(bag_ys, bag_preds))
print(end_epoch - start_epoch)

# Bootstrap: 1000 resamples drawn with replacement from a fixed-seed generator,
# so re-running this cell reproduces the same interval.
bootstrap_rng = np.random.default_rng(0)
n_bags = len(bag_preds)
curr_rocs = []
for _ in range(1000):
    curr_idxs = bootstrap_rng.choice(n_bags, n_bags)
    curr_rocs.append(roc_auc_score(bag_ys[curr_idxs], bag_preds[curr_idxs]))

curr_rocs.sort()
# Sorted positions 500 / 25 / 975 of 1000 approximate the median and the
# 2.5th / 97.5th percentiles.
print('%0.3f (%0.3f, %0.3f)'%(curr_rocs[500], curr_rocs[25], curr_rocs[975]))
joblib.dump(curr_rocs, 'dtfd_nope.joblib')

Test AUROC: 0.7772931977830448
48.343833684921265
0.777 (0.770, 0.785)


['dtfd_nope.joblib']

### DTFD PE

In [9]:
# Score the DTFD checkpoint built with PE = True on the test loader, report
# the AUROC and wall-clock time, then bootstrap an interval for the AUROC.
cD = 1024
cNPB = 4
model = DTFD(cD, cNPB, PE = True).to(device)
epoch = get_epoch('dtfd_pe')
modelname = start + 'best_dtfd_pe_mimiccxr_densenet_epoch%d'%epoch
# map_location keeps the CPU fallback chosen for `device` usable even when the
# checkpoint was written from a GPU process.
model.load_state_dict(torch.load(modelname, map_location = device))
# Switch dropout / normalisation layers to inference behaviour, as the CLAM
# cells already do before scoring.
model.eval()

bag_preds = []
bag_ys = []
start_epoch = time.time()
criterion = nn.CrossEntropyLoss()
# Scoring needs no gradients; skipping the autograd graph saves memory.
# NOTE(review): assumes DTFD's forward never calls backward() itself — confirm.
with torch.no_grad():
    for curridx, data, target in test_loader:
        # The loader yields one item per batch; drop the leading batch axis.
        data = data.to(device).squeeze(0)
        target = target.to(device)

        # loss = False: only the first element of the returned pair is used.
        bag_prediction, _ = model(data, target, criterion, loss = False)

        # Column 1 of the softmax is used as the positive-class score.
        bag_preds.extend(F.softmax(bag_prediction, dim = 1)[:, 1].cpu().numpy())
        bag_ys.extend(target.cpu().numpy())
end_epoch = time.time()

bag_preds = np.array(bag_preds)
bag_ys = np.array(bag_ys)
print('Test AUROC:', roc_auc_score(bag_ys, bag_preds))
print(end_epoch - start_epoch)

# Bootstrap: 1000 resamples drawn with replacement from a fixed-seed generator,
# so re-running this cell reproduces the same interval.
bootstrap_rng = np.random.default_rng(0)
n_bags = len(bag_preds)
curr_rocs = []
for _ in range(1000):
    curr_idxs = bootstrap_rng.choice(n_bags, n_bags)
    curr_rocs.append(roc_auc_score(bag_ys[curr_idxs], bag_preds[curr_idxs]))

curr_rocs.sort()
# Sorted positions 500 / 25 / 975 of 1000 approximate the median and the
# 2.5th / 97.5th percentiles.
print('%0.3f (%0.3f, %0.3f)'%(curr_rocs[500], curr_rocs[25], curr_rocs[975]))
joblib.dump(curr_rocs, 'dtfd_pe.joblib')

Test AUROC: 0.7918336038439552
46.30561542510986
0.792 (0.785, 0.799)


['dtfd_pe.joblib']

### SGL NO PE

In [2]:
# Score the SGL checkpoint built with PE = False on the test loader, report
# the AUROC and wall-clock time, then bootstrap an interval for the AUROC.
from sgl_functions import*
cD = 1024
model = SGL_Model(cD, PE = False).to(device)
epoch = get_epoch('sgl_nope')
modelname = start + 'best_sgl_nope_mimiccxr_densenet_epoch%d'%epoch
# map_location keeps the CPU fallback chosen for `device` usable even when the
# checkpoint was written from a GPU process.
model.load_state_dict(torch.load(modelname, map_location = device))
# Switch dropout / normalisation layers to inference behaviour, as the CLAM
# cells already do before scoring.
model.eval()

bag_preds = []
bag_ys = []
start_epoch = time.time()
# Scoring needs no gradients; skipping the autograd graph saves memory.
with torch.no_grad():
    for curridx, data, target in test_loader:
        # The loader yields one item per batch; drop the leading batch axis.
        data = data.to(device).squeeze(0)

        # Model output -> sigmoid -> lin_soft yields the bag-level score.
        # lin_soft presumably comes from sgl_functions — TODO confirm.
        inst_pred = model(data)
        bag_prediction = lin_soft(inst_pred.sigmoid())

        bag_preds.extend(bag_prediction.cpu().numpy())
        bag_ys.extend(target.cpu().numpy())
end_epoch = time.time()

bag_preds = np.array(bag_preds)
bag_ys = np.array(bag_ys)
print('Test AUROC:', roc_auc_score(bag_ys, bag_preds))
print(end_epoch - start_epoch)

# Bootstrap: 1000 resamples drawn with replacement from a fixed-seed generator,
# so re-running this cell reproduces the same interval.
bootstrap_rng = np.random.default_rng(0)
n_bags = len(bag_preds)
curr_rocs = []
for _ in range(1000):
    curr_idxs = bootstrap_rng.choice(n_bags, n_bags)
    curr_rocs.append(roc_auc_score(bag_ys[curr_idxs], bag_preds[curr_idxs]))

curr_rocs.sort()
# Sorted positions 500 / 25 / 975 of 1000 approximate the median and the
# 2.5th / 97.5th percentiles.
print('%0.3f (%0.3f, %0.3f)'%(curr_rocs[500], curr_rocs[25], curr_rocs[975]))
joblib.dump(curr_rocs, 'sgl_nope.joblib')

Test AUROC: 0.7549947161714602
31.815017700195312
0.755 (0.747, 0.764)


['sgl_nope.joblib']

### SGL PE

In [3]:
# Score the SGL checkpoint built with PE = True on the test loader, report
# the AUROC and wall-clock time, then bootstrap an interval for the AUROC.
from sgl_functions import*
cD = 1024
model = SGL_Model(cD, PE = True).to(device)
epoch = get_epoch('sgl_pe')
modelname = start + 'best_sgl_pe_mimiccxr_densenet_epoch%d'%epoch
# map_location keeps the CPU fallback chosen for `device` usable even when the
# checkpoint was written from a GPU process.
model.load_state_dict(torch.load(modelname, map_location = device))
# Switch dropout / normalisation layers to inference behaviour, as the CLAM
# cells already do before scoring.
model.eval()

bag_preds = []
bag_ys = []
start_epoch = time.time()
# Scoring needs no gradients; skipping the autograd graph saves memory.
with torch.no_grad():
    for curridx, data, target in test_loader:
        # The loader yields one item per batch; drop the leading batch axis.
        data = data.to(device).squeeze(0)

        # Model output -> sigmoid -> lin_soft yields the bag-level score.
        # lin_soft presumably comes from sgl_functions — TODO confirm.
        inst_pred = model(data)
        bag_prediction = lin_soft(inst_pred.sigmoid())

        bag_preds.extend(bag_prediction.cpu().numpy())
        bag_ys.extend(target.cpu().numpy())
end_epoch = time.time()

bag_preds = np.array(bag_preds)
bag_ys = np.array(bag_ys)
print('Test AUROC:', roc_auc_score(bag_ys, bag_preds))
print(end_epoch - start_epoch)

# Bootstrap: 1000 resamples drawn with replacement from a fixed-seed generator,
# so re-running this cell reproduces the same interval.
bootstrap_rng = np.random.default_rng(0)
n_bags = len(bag_preds)
curr_rocs = []
for _ in range(1000):
    curr_idxs = bootstrap_rng.choice(n_bags, n_bags)
    curr_rocs.append(roc_auc_score(bag_ys[curr_idxs], bag_preds[curr_idxs]))

curr_rocs.sort()
# Sorted positions 500 / 25 / 975 of 1000 approximate the median and the
# 2.5th / 97.5th percentiles.
print('%0.3f (%0.3f, %0.3f)'%(curr_rocs[500], curr_rocs[25], curr_rocs[975]))
joblib.dump(curr_rocs, 'sgl_pe.joblib')

Test AUROC: 0.7658301046542109
32.290428161621094
0.766 (0.758, 0.773)


['sgl_pe.joblib']

### Transformer

In [8]:
# Score every Transformer variant (attention type x aggregation x PE) on the
# test loader. Each variant is identified by the filename suffixes used for
# its epoch log, checkpoint and output files.
cD = 128

# Filename suffix -> constructor argument. A lookup raises KeyError on an
# unknown suffix instead of silently reusing the value from an earlier
# iteration, which an if/elif chain without an else would do.
agg_by_suffix = {'': 'cls_token', '_max': 'max', '_avg': 'avg'}
attn_by_suffix = {'': 'Nystrom', '_orig': 'Orig'}
pe_by_suffix = {'': False, '_pe': True}

for dict_attn in ['', '_orig']:
    for dict_agg in ['']:
        for dict_addPE in ['', '_pe']:
            agg = agg_by_suffix[dict_agg]
            attn = attn_by_suffix[dict_attn]
            pe = pe_by_suffix[dict_addPE]

            # Shared stem for the epoch log, the checkpoint and the outputs.
            run_name = 'transformer%s%s%s'%(dict_attn, dict_agg, dict_addPE)

            model = Transformer(cD, agg, attn, PE = pe).to(device)
            epoch = get_epoch(run_name)
            modelname = start + 'best_%s_mimiccxr_densenet_epoch%d'%(run_name, epoch)
            # map_location keeps the CPU fallback chosen for `device` usable
            # even when the checkpoint was written from a GPU process.
            model.load_state_dict(torch.load(modelname, map_location = device))
            # Switch dropout / normalisation layers to inference behaviour,
            # as the CLAM cells already do before scoring.
            model.eval()

            bag_preds = []
            bag_ys = []
            start_epoch = time.time()
            # Scoring needs no gradients; skipping the autograd graph saves memory.
            with torch.no_grad():
                for curridx, data, target in test_loader:
                    # One item per batch; drop the leading batch axis.
                    data = data.to(device).squeeze(0)

                    bag_prediction = model(data)

                    # Column 1 of the softmax is the positive-class score.
                    bag_preds.extend(F.softmax(bag_prediction, dim = 1)[:, 1].cpu().numpy())
                    bag_ys.extend(target.cpu().numpy())
            end_epoch = time.time()

            bag_preds = np.array(bag_preds)
            bag_ys = np.array(bag_ys)
            print('Test AUROC:', roc_auc_score(bag_ys, bag_preds))
            print(end_epoch - start_epoch)

            # Bootstrap: 1000 resamples drawn with replacement from a
            # fixed-seed generator, re-created for every variant so each
            # interval is reproducible on its own.
            bootstrap_rng = np.random.default_rng(0)
            n_bags = len(bag_preds)
            curr_rocs = []
            for _ in range(1000):
                curr_idxs = bootstrap_rng.choice(n_bags, n_bags)
                curr_rocs.append(roc_auc_score(bag_ys[curr_idxs], bag_preds[curr_idxs]))

            curr_rocs.sort()
            print(agg, attn, pe)
            # Sorted positions 500 / 25 / 975 of 1000 approximate the median
            # and the 2.5th / 97.5th percentiles.
            print('%0.3f (%0.3f, %0.3f)'%(curr_rocs[500], curr_rocs[25], curr_rocs[975]))

            # Persist the bootstrap distribution plus the raw labels and scores.
            joblib.dump(curr_rocs, '%s.joblib'%run_name)
            joblib.dump(bag_ys, 'bagy_%s.joblib'%run_name)
            joblib.dump(bag_preds, 'bagpred_%s.joblib'%run_name)

Test AUROC: 0.7987075135987847
93.23540759086609
cls_token Nystrom False
0.799 (0.792, 0.806)
Test AUROC: 0.8027213618873578
93.65480637550354
cls_token Nystrom True
0.803 (0.796, 0.809)
Test AUROC: 0.7846019601725319
24.940438985824585
cls_token Orig False
0.785 (0.777, 0.792)
Test AUROC: 0.8048384946308439
26.085609912872314
cls_token Orig True
0.805 (0.798, 0.812)


### TRANSMIL

In [11]:
# Score the TransMIL checkpoint on the test loader, report the AUROC and
# wall-clock time, then bootstrap an interval for the AUROC.
cD = 128
model = TransMIL(cD).to(device)
# NOTE(review): the epoch log is looked up as 'transmil_nope', while the
# checkpoint and the output file use plain 'transmil'. Confirm the log really
# belongs to this checkpoint.
epoch = get_epoch('transmil_nope')
modelname = start + 'best_transmil_mimiccxr_densenet_epoch%d'%epoch
# map_location keeps the CPU fallback chosen for `device` usable even when the
# checkpoint was written from a GPU process.
model.load_state_dict(torch.load(modelname, map_location = device))
# Switch dropout / normalisation layers to inference behaviour, as the CLAM
# cells already do before scoring.
model.eval()

bag_preds = []
bag_ys = []
start_epoch = time.time()
# Scoring needs no gradients; skipping the autograd graph saves memory.
with torch.no_grad():
    for curridx, data, target in test_loader:
        # The loader yields one item per batch; drop the leading batch axis.
        data = data.to(device).squeeze(0)

        bag_prediction = model(data)

        # Column 1 of the softmax is used as the positive-class score.
        bag_preds.extend(F.softmax(bag_prediction, dim = 1)[:, 1].cpu().numpy())
        bag_ys.extend(target.cpu().numpy())
end_epoch = time.time()

bag_preds = np.array(bag_preds)
bag_ys = np.array(bag_ys)
print('Test AUROC:', roc_auc_score(bag_ys, bag_preds))
print(end_epoch - start_epoch)

# Bootstrap: 1000 resamples drawn with replacement from a fixed-seed generator,
# so re-running this cell reproduces the same interval.
bootstrap_rng = np.random.default_rng(0)
n_bags = len(bag_preds)
curr_rocs = []
for _ in range(1000):
    curr_idxs = bootstrap_rng.choice(n_bags, n_bags)
    curr_rocs.append(roc_auc_score(bag_ys[curr_idxs], bag_preds[curr_idxs]))

curr_rocs.sort()
# Sorted positions 500 / 25 / 975 of 1000 approximate the median and the
# 2.5th / 97.5th percentiles.
print('%0.3f (%0.3f, %0.3f)'%(curr_rocs[500], curr_rocs[25], curr_rocs[975]))
joblib.dump(curr_rocs, 'transmil.joblib')

Test AUROC: 0.7987287473206001
122.60532546043396
0.799 (0.792, 0.806)


['transmil.joblib']