In [6]:
%matplotlib inline

import os

import torch
import pylab as plt
from munch import munchify
from torch.nn import functional as F
from torch.utils.data import TensorDataset, DataLoader

from utils import *
# Keep this after the star import so `tqdm` is the progress-bar function even if utils exports a `tqdm` name.
from tqdm import tqdm
# from preprocessing.s02_create_dataset import load_cores_h5py

In [7]:
# Project root: set the PROJECT_ROOT environment variable to run on another machine;
# the fallback is the original author's local checkout.
project_root = os.environ.get(
    'PROJECT_ROOT',
    'C:/Users/Mahdi/Desktop/Summer21/RA/Codes/Minh_Mahdi_mod/prostate_cancer_classification',
)

args = {
    # yaml configuration file location
    'config': '../yamls/coteaching_local_inference_Exact2D.yml',
    # experiment location to load
    'exp_suffix': '_Patch/lr1e-5_fr.4numgrad6----res10_UVA400_testiLR_crrctep11_2',
}

# opt holds every configuration value: the YAML contents, overlaid with `args`,
# wrapped in a Munch for attribute access, then passed through setup_directories.
# NOTE(review): depending on which `Loader` utils exports, yaml.load may construct arbitrary
# Python objects — only load trusted config files.
with open(args['config']) as f:
    opt = yaml.load(f, Loader)
opt.update(args)
opt = munchify(opt)
opt.project_root = project_root
opt = setup_directories(opt)

num_workers = 0
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

In [8]:
# Build the network and load the checkpoint selected by opt.test.which_iter.
net = construct_network(device, opt)
# NOTE(review): construct_network appears to return a sequence of network constructors
# (net[0]() is called); only the first network is instantiated here.
net = net[0]()
# Presumably 'ct' marks co-teaching runs, whose checkpoint for this network carries a '_1' suffix.
# NOTE(review): this is a plain substring test on the experiment name. Any suffix that merely
# contains 'ct' selects '_1' — the current one does, via 'crrctep'. An earlier version keyed this
# on len(net) > 1, which is sturdier; confirm which rule is intended.
suffix = '' if 'ct' not in args['exp_suffix'] else '_1'
checkpoint_path = f'{opt.project_root}/{opt.paths.checkpoint_dir}/{opt.test.which_iter}_coreN{suffix}.pth'
# map_location lets a checkpoint saved from a GPU run load when `device` fell back to CPU.
net.load_state_dict(torch.load(checkpoint_path, map_location=device))

<All keys matched successfully>

In [4]:
from utils.dataset import DatasetV1, extract_subset

# Involvement thresholds handed to the dataset helpers below.
# NOTE(review): their exact semantics live in utils.dataset (extract_subset / DatasetV1);
# presumably minimum cancer-involvement cut-offs — confirm there.
initial_min_inv = .8
min_inv = .4

input_data = load_pickle('../datasets/BK_RF_P1_140_balance__20210203-175808_mimic.pkl')

train_subset = extract_subset(input_data, 'train', min_inv)
trn_ds = DatasetV1(
    *train_subset,
    aug_type='none',
    initial_min_inv=initial_min_inv,
    transform_prob=.2,
    degree=1,
    n_neighbor=0,
)

# Input normalization is disabled. To enable it, fit the transformer on the training signals:
#     transformer = robust_norm(np.concatenate(trn_ds.data))[1]
transformer = None

In [5]:
# Pool the saved per-split core lists into one patient -> cores lookup (on duplicate patient
# ids, 'val' overrides 'train' and 'test' overrides both), then regroup it by the patient split
# recorded in input_data.
with open('../metadata/matched_tmi_cores_idx.pkl', 'rb') as fp:
    # NOTE(review): pickle.load can execute code embedded in the file; only use trusted metadata.
    saved_core_indices = pickle.load(fp)

cores_by_patient = {
    **saved_core_indices['train'],
    **saved_core_indices['val'],
    **saved_core_indices['test'],
}

# Re-split dataset
core_indices = {
    split: {pid: cores_by_patient[pid] for pid in np.unique(input_data[f'PatientId_{split}'])}
    for split in ('train', 'val', 'test')
}

In [6]:
from skimage.morphology import remove_small_objects

def predict_and_visualize(patient_id, transformer=None, state=None, figure_dir=None):
    """Predict per-pixel class-1 probabilities for one patient's cores and plot them.

    For every core of ``patient_id``, the RF signals inside the ROI mask (``core.roi[0] == 1``)
    are batched through ``net``. The softmax probability of class 1 is written back into a
    ``core.heatmap`` array shaped like ``core.roi``. The cores are then converted with
    ``rf2bm_wrapper`` and displayed with ``review_cores``.

    Parameters
    ----------
    patient_id
        Key into ``core_indices[state]``.
    transformer
        Unused (the normalization step that consumed it is commented out); kept so existing
        calls still work.
    state : str, optional
        Split name used to look up ``core_indices``. Defaults to the notebook-global ``state``
        set by the driver loop, so existing calls behave as before.
    figure_dir : str, optional
        Forwarded to ``review_cores``. Defaults to the notebook-global ``figure_dir``.

    Returns
    -------
    None. Cores with an empty ROI mask are skipped and their ``core_id`` is printed. If every
    core is empty, ``patient_id`` is printed and nothing is predicted or plotted.

    Notes
    -----
    Relies on the notebook globals ``core_indices``, ``net``, ``opt``, ``device`` and
    ``num_workers``, and on the helpers ``load_cores_h5py``, ``rf2bm_wrapper`` and
    ``review_cores`` being in scope. NOTE(review): presumably these come via
    ``from utils import *``; the explicit ``load_cores_h5py`` import is commented out — confirm.
    """
    if state is None:
        state = globals()['state']

    cores = load_cores_h5py(patient_id, core_indices[state][patient_id])

    # One (n_roi_pixels, 1, rf.shape[0]) block of RF signals per core, restricted to its ROI.
    inputs = []
    for core in cores:
        roi_mask = core.roi[0] == 1
        if not roi_mask.any():
            print(core.core_id)
            continue
        inputs.append(core.rf[:, roi_mask].T[:, np.newaxis])

        core.wp[0] = remove_small_objects(core.wp[0].astype('bool'))

    if not inputs:
        print(patient_id)
        return

    signal_test = np.concatenate(inputs, axis=0)
    # signal_test = robust_norm(signal_test, transformer)[0]  # normalization currently disabled

    dataset = TensorDataset(torch.tensor(signal_test, dtype=torch.float32))
    dataloader = DataLoader(dataset, shuffle=False, num_workers=num_workers, batch_size=opt.test.batch_size)

    net.eval()
    outputs = []
    with torch.no_grad():
        with tqdm(dataloader, unit="batch") as t_epoch:
            for (data,) in t_epoch:
                # Follow the notebook's `device` instead of hard-coding .cuda(), so CPU-only runs work.
                logits = net(data.to(device))
                outputs.append(F.softmax(logits, dim=1).cpu().numpy())
    outputs = np.concatenate(outputs)

    # Scatter the class-1 probabilities back onto each core. Use the same mask, in the same core
    # order, as when collecting inputs so the output slices line up. Skipped cores take 0 rows.
    offset = 0
    for core in cores:
        roi_mask = core.roi[0] == 1
        n_pixels = int(roi_mask.sum())
        heatmap = np.zeros_like(core.roi, dtype='float32')
        heatmap[:, roi_mask] = outputs[offset:offset + n_pixels, 1]
        core.heatmap = heatmap
        offset += n_pixels

    cores = [rf2bm_wrapper(core, quick_convert=True) for core in cores]

    if figure_dir is None:
        figure_dir = globals()['figure_dir']
    review_cores(cores, figure_dir=figure_dir, patient_id=cores[0].patient_id)

  0%|          | 0/18 [00:00<?, ?patient/s]
  0%|          | 0/7 [00:00<?, ?batch/s][A
 14%|█▍        | 1/7 [00:00<00:03,  1.76batch/s][A
100%|██████████| 7/7 [00:00<00:00,  8.61batch/s][A
  plt.show()
  6%|▌         | 1/18 [01:08<19:32, 68.99s/patient]
  0%|          | 0/3 [00:00<?, ?batch/s][A
 33%|███▎      | 1/3 [00:00<00:01,  1.44batch/s][A
100%|██████████| 3/3 [00:01<00:00,  1.94batch/s][A
 11%|█         | 2/18 [01:45<13:15, 49.69s/patient]
  0%|          | 0/13 [00:00<?, ?batch/s][A
  8%|▊         | 1/13 [00:00<00:08,  1.45batch/s][A
 46%|████▌     | 6/13 [00:00<00:00,  9.57batch/s][A
100%|██████████| 13/13 [00:01<00:00, 11.74batch/s][A
 17%|█▋        | 3/18 [04:12<23:30, 94.06s/patient]
  0%|          | 0/8 [00:00<?, ?batch/s][A

65



 12%|█▎        | 1/8 [00:00<00:05,  1.38batch/s][A
100%|██████████| 8/8 [00:00<00:00,  8.14batch/s][A
 28%|██▊       | 5/18 [06:14<16:19, 75.35s/patient]
  0%|          | 0/10 [00:00<?, ?batch/s][A
 10%|█         | 1/10 [00:00<00:06,  1.44batch/s][A
 60%|██████    | 6/10 [00:00<00:00,  9.60batch/s][A
100%|██████████| 10/10 [00:01<00:00,  9.03batch/s][A
 33%|███▎      | 6/18 [07:51<16:20, 81.70s/patient]
  0%|          | 0/7 [00:00<?, ?batch/s][A
 14%|█▍        | 1/7 [00:00<00:04,  1.41batch/s][A
100%|██████████| 7/7 [00:01<00:00,  6.92batch/s][A
 39%|███▉      | 7/18 [09:17<15:12, 82.99s/patient]
  0%|          | 0/8 [00:00<?, ?batch/s][A
 12%|█▎        | 1/8 [00:00<00:04,  1.49batch/s][A
100%|██████████| 8/8 [00:01<00:00,  7.48batch/s][A
 44%|████▍     | 8/18 [10:54<14:31, 87.15s/patient]
  0%|          | 0/3 [00:00<?, ?batch/s][A
 33%|███▎      | 1/3 [00:00<00:01,  1.47batch/s][A
100%|██████████| 3/3 [00:00<00:00,  3.39batch/s][A
 50%|█████     | 9/18 [11:20<10:19, 68

 29%|██▊       | 6/21 [00:00<00:01,  9.33batch/s][A
 52%|█████▏    | 11/21 [00:00<00:00, 16.75batch/s][A
 76%|███████▌  | 16/21 [00:01<00:00, 23.49batch/s][A
100%|██████████| 21/21 [00:01<00:00, 15.20batch/s][A
 86%|████████▌ | 25/29 [26:40<08:23, 125.88s/patient]
  0%|          | 0/7 [00:00<?, ?batch/s][A
 14%|█▍        | 1/7 [00:00<00:03,  1.61batch/s][A
100%|██████████| 7/7 [00:01<00:00,  7.00batch/s][A
 90%|████████▉ | 26/29 [27:52<05:29, 109.87s/patient]
  0%|          | 0/3 [00:00<?, ?batch/s][A
 33%|███▎      | 1/3 [00:00<00:01,  1.71batch/s][A
100%|██████████| 3/3 [00:00<00:00,  3.84batch/s][A
 93%|█████████▎| 27/29 [28:19<02:49, 84.89s/patient] 
  0%|          | 0/10 [00:00<?, ?batch/s][A
 10%|█         | 1/10 [00:00<00:07,  1.18batch/s][A
 60%|██████    | 6/10 [00:00<00:00,  8.09batch/s][A
100%|██████████| 10/10 [00:01<00:00,  8.00batch/s][A
 97%|█████████▋| 28/29 [30:35<01:40, 100.39s/patient]
  0%|          | 0/2 [00:00<?, ?batch/s][A
100%|██████████| 2/2 [00

In [7]:
# Per-patient heatmap figures for the chosen split(s). This takes on the order of a minute per
# patient, so it is switched off by default; set RUN_PATIENT_FIGURES = True to regenerate them.
RUN_PATIENT_FIGURES = False

if RUN_PATIENT_FIGURES:
    # `state` and `figure_dir` are also read as notebook globals by predict_and_visualize.
    for state in ['test']:
        figure_dir = '/'.join((opt.project_root, opt.paths.result_dir.replace('results', 'figures') + '/' + state))
        os.makedirs(figure_dir, exist_ok=True)
        with tqdm(core_indices[state].keys(), unit="patient") as t_patient:
            for patient_id in t_patient:
                predict_and_visualize(patient_id, transformer)

In [None]:
import gc

# Collect unreachable Python objects first, then ask torch's CUDA allocator to release its
# cached blocks. The guard only skips the call on machines without CUDA.
gc.collect()
if torch.cuda.is_available():
    torch.cuda.empty_cache()

In [None]:
import seaborn as sns

def predict(net, tst_dl, device):
    """Run ``net`` over a test loader; collect inputs, class probabilities and confidence scores.

    Parameters
    ----------
    net : torch.nn.Module
        Called as ``net(x, n)``. Its output is treated as per-class logits along dim 1.
    tst_dl : iterable
        Yields 4-tuples ``(x, y, n, _)`` of tensors. ``y`` is reduced with ``argmax(dim=1)``,
        so it is expected to hold one score per class (e.g. one-hot labels).
    device : torch.device
        Device every tensor of a batch is moved to.

    Returns
    -------
    inputs : np.ndarray
        All ``x`` batches concatenated along dim 0.
    outputs : np.ndarray
        Softmax probabilities, concatenated along dim 0.
    entropic_scores : np.ndarray
        Negative prediction entropy per sample (closer to 0 = more confident).
    features : list
        Always empty; kept for interface compatibility.
    accuracy : float
        Fraction of samples whose predicted class equals ``argmax(y)``.

    Raises
    ------
    ValueError
        If ``tst_dl`` yields no samples.
    """
    inputs, outputs, entropic_scores = [], [], []
    features = []
    total = correct = 0
    net.eval()

    # apply model on test signals
    with torch.no_grad():
        for batch in tst_dl:
            x_raw, y_batch, n_batch, _ = [t.to(device) for t in batch]

            logits = net(x_raw, n_batch)
            probabilities = F.softmax(logits, dim=1)
            # Compute entropy from log_softmax. A probability that underflows to exactly 0
            # would make p * log(p) = 0 * -inf = NaN, whereas log_softmax stays finite.
            entropies = -(probabilities * F.log_softmax(logits, dim=1)).sum(dim=1)
            entropic_scores.append((-entropies).cpu().numpy())

            inputs.append(x_raw.cpu().numpy())
            outputs.append(probabilities.cpu().numpy())
            total += y_batch.size(0)
            correct += (probabilities.argmax(dim=1) == y_batch.argmax(dim=1)).sum().item()

    if total == 0:
        raise ValueError('predict() received an empty test loader')

    inputs = np.concatenate(inputs)
    outputs = np.concatenate(outputs)
    entropic_scores = np.concatenate(entropic_scores)

    return inputs, outputs, entropic_scores, features, correct / total

In [None]:
from utils.dataset import create_datasets_test

# Build the test split without normalization (norm=False, transformer=None).
test_set = create_datasets_test(
    None,
    min_inv=min_inv,
    state='test',
    norm=False,
    input_data=input_data,
    transformer=None,
)

# Slot 0 holds the dataset; swap it for its batched loader. The remaining slots are per-core metadata.
# NOTE(review): jobs=12 differs from the notebook-level num_workers = 0 — confirm whether
# create_loaders_test uses it as a DataLoader worker count.
test_set[0] = create_loaders_test(test_set[0], bs=4096, jobs=12)[0]
(data_loader, core_len, true_involvement, patient_id_bk,
 gs_bk, roi_coors, ts_id, c_id) = test_set

# Patch-level evaluation
inputs, predictions, ood, latents, acc_s = predict(net, data_loader, device)

# Aggregate patch-level predictions into core-wise involvement and prediction maps
inputs, predicted_involvement, ood, latents, prediction_maps = infer_core_wise(
    inputs, predictions, core_len, roi_coors, ood, latents,
)

In [None]:
# Core-level metrics. The patch-level accuracy from predict() is seeded into the scores dict
# handed to compute_metrics, which returns the updated scores.
scores = compute_metrics(predicted_involvement, true_involvement, declare_thr=opt.declare_thr,
                         current_epoch=0, verbose=True, scores={'acc_s': acc_s})

In [None]:
import matplotlib

def norm_01(x):
    """Min-max scale ``x`` to the range [0, 1].

    Works on anything with ``.min()``, ``.max()`` and elementwise arithmetic (e.g. numpy arrays).
    A constant input (max == min) returns zeros of the same shape instead of dividing by zero,
    which would produce NaN with a RuntimeWarning.
    """
    lo = x.min()
    span = x.max() - lo
    if span == 0:
        return x * 0.0
    return (x - lo) / span

# Predicted vs. true involvement per core, over a background that fades from blue (below the
# decision threshold) to red (above it). Both gradients reach their lightest shade at the threshold.
# NOTE(review): compute_metrics above uses opt.declare_thr while this figure hard-codes 0.4 —
# confirm the two should match.
declare_thr = .4
unit = 1e-3               # height step of the background bands
alpha = .2                # opacity of each background band
y_lo, y_hi = -.05, 1.05   # plotted data range on both axes

fig2 = plt.figure(2)

predicted_involvement = np.array(predicted_involvement)
# Boolean-mask indexing needs an array; true_involvement may be a plain list.
true_inv = np.asarray(true_involvement)
is_cancer = true_inv > 0
is_benign = true_inv == 0

ax2 = sns.scatterplot(x=true_inv[is_cancer], y=predicted_involvement[is_cancer],
                      legend=False, s=200, color='red')
sns.scatterplot(x=true_inv[is_benign], y=predicted_involvement[is_benign],
                legend=False, s=200, color='blue', ax=ax2)

# Identity reference line over the whole [0, 1] range (np.arange(0, 1, .05) stopped at 0.95).
diag = np.linspace(0, 1, 21)
sns.lineplot(x=diag, y=diag, color='b', ax=ax2)

ax2.axis('square')
ax2.set(ylim=[y_lo, y_hi], xlim=[y_lo, y_hi])

# Background gradient: thin horizontal bands, stepped by `unit`, run from each edge of the
# y-range toward the threshold. A band's colormap position equals its distance to the
# threshold, so both gradients reach position 0 exactly at `declare_thr`. The colormaps are
# sampled once as arrays; the previous code evaluated them ~279k times to keep ~1.1k colours.
n_blue = int(round((declare_thr - y_lo) / unit))
n_red = int(round((y_hi - declare_thr) / unit))
blue_levels = np.arange(y_lo, y_hi, unit)[:n_blue + 1]
red_levels = np.arange(y_hi, y_lo, -unit)[:n_red + 1]
blue_colors = plt.get_cmap('Blues')(np.arange(n_blue + 1) * unit)[::-1]
red_colors = plt.get_cmap('Reds')(np.arange(n_red + 1) * unit)[::-1]

# xmin/xmax of axhspan are axes fractions; -.05..1.05 runs each band slightly past both sides.
for level, color in zip(blue_levels, blue_colors):
    ax2.axhspan(level - unit, level + unit, -.05, 1.05, alpha=alpha, facecolor=color)
for level, color in zip(red_levels, red_colors):
    ax2.axhspan(level - unit, level + unit, -.05, 1.05, alpha=alpha, facecolor=color)

# Dashed guides: vertical just right of x = 0.1, horizontal just above the decision threshold.
ax2.axvline(x=.101, linewidth=.6, linestyle='--', color='black')
ax2.axhline(y=declare_thr + .001, linewidth=.6, linestyle='--', color='black')

ax2.axis('off')
plt.gcf().set_size_inches(11.7, 8.27)