In [1]:
%matplotlib inline

# Standard library / numpy, imported explicitly so these names do not depend
# on what `from utils import *` happens to re-export.
import gc
import os
import pickle

import numpy as np

import torch
import pylab as plt
from munch import munchify
from torch.nn import functional as F
from utils.cores import preview_cores, rf2bm_wrapper
from torch.utils.data import TensorDataset, DataLoader

# NOTE: keep the star import in this position (before `from tqdm import tqdm`
# and the project import below it) so any same-named export from utils is
# shadowed exactly as before. Presumably it also supplies `yaml` and `Loader`
# used in the config cell — TODO confirm and import them explicitly.
from utils import *
from tqdm import tqdm
from preprocessing.s02_create_dataset import load_cores_h5py

In [2]:
# Experiment configuration. The YAML file provides the defaults; the overrides
# below are merged on top of it (update runs after loading, so they win).
args = dict(
    config='../yamls/coteaching_local_inference.yml',
    exp_suffix='_sd150_bs2048_lr5e-4_no_norm_ep100_ap1_elr',
    backbone='inception',
    gpus_id=[0],
)

with open(args['config']) as config_file:
    # NOTE(review): `yaml` / `Loader` come from `from utils import *`; confirm
    # Loader is a safe loader if config files may come from untrusted sources.
    raw_config = yaml.load(config_file, Loader)
raw_config.update(args)
opt = setup_directories(munchify(raw_config))

num_workers = 12  # worker processes for the inference DataLoader
# Use the first requested GPU when CUDA is present, otherwise fall back to CPU.
if torch.cuda.is_available():
    device = torch.device(f'cuda:{opt.gpus_id[0]}')
else:
    device = torch.device('cpu')

In [3]:
# Build the network and restore the checkpoint selected by opt.test.which_iter.
net = construct_network(device, opt)()
# Make sure the model lives on the configured device (a no-op if
# construct_network already placed it there); inference sends batches to `device`.
net = net.to(device)
checkpoint_path = f'{opt.project_root}/{opt.paths.checkpoint_dir}/{opt.test.which_iter}_coreN_1.pth'
# map_location lets a checkpoint saved on GPU load when `device` fell back to CPU
# (without it torch.load tries to restore tensors onto their original CUDA device).
checkpoint_state = torch.load(checkpoint_path, map_location=device)
net.load_state_dict(checkpoint_state)

<All keys matched successfully>

In [4]:
# Per-split index of cores: read later as core_indices[state][patient_id]
# and core_indices[state].keys() by the prediction cells.
# NOTE(review): pickle.load can execute arbitrary code — only load trusted files.
CORE_INDICES_PATH = '../metadata/matched_tmi_cores_idx.pkl'
with open(CORE_INDICES_PATH, mode='rb') as index_file:
    core_indices = pickle.load(index_file)

In [5]:
def _predict_probabilities(samples):
    """Return ``softmax(net(x), dim=1)`` for every row of ``samples`` as a numpy array.

    ``samples`` is the concatenation of all ROI-pixel inputs of one patient;
    rows are returned in the same order (the loader does not shuffle).
    """
    dataset = TensorDataset(torch.tensor(samples, dtype=torch.float32))
    loader = DataLoader(dataset, shuffle=False, num_workers=num_workers,
                        batch_size=opt.test.batch_size)

    net.eval()
    batch_probs = []
    with torch.no_grad():
        # leave=False: the per-batch bar is nested inside the per-patient bar,
        # so it is cleared when done instead of leaving a line per patient.
        for (batch,) in tqdm(loader, unit="batch", leave=False):
            # Use the configured device (not a hard-coded .cuda()) so CPU runs
            # and non-zero GPU ids work.
            logits = net(batch.to(device))
            batch_probs.append(F.softmax(logits, dim=1).cpu().numpy())
    return np.concatenate(batch_probs)


def predict_and_visualize(patient_id, split=None, out_dir=None):
    """Predict per-pixel probabilities for one patient's cores and save a preview figure.

    Loads the cores listed in ``core_indices[split][patient_id]``, runs ``net``
    on the RF signal of every ROI pixel (``core.roi[0] == 1``), stores the
    class-1 softmax probability in a new ``core.heatmap`` array shaped like
    ``core.roi``, converts the cores with ``rf2bm_wrapper`` and saves the
    ``preview_cores`` figure as ``<out_dir>/Patient<patient_id>.png``.

    Args:
        patient_id: key into ``core_indices[split]``.
        split: dataset split name (e.g. 'test', 'val'). Defaults to the
            notebook-global ``state`` for backward compatibility.
        out_dir: directory the figure is written to. Defaults to the
            notebook-global ``figure_dir``.

    Returns:
        None. Cores without ROI pixels are reported (core_id printed) and get an
        all-zero heatmap; if no core has ROI pixels the patient id is printed
        and nothing is saved.
    """
    # Only touch the globals when the caller did not pass explicit values.
    split = state if split is None else split
    out_dir = figure_dir if out_dir is None else out_dir

    cores = load_cores_h5py(patient_id, core_indices[split][patient_id])

    # One boolean ROI mask per core, reused for both gathering the inputs and
    # scattering the predictions back, so the pixel counts always agree
    # (counting with core.roi.sum() instead would drift if roi is non-binary
    # or has more than one channel).
    masks = [core.roi[0] == 1 for core in cores]

    inputs = []
    for core, mask in zip(cores, masks):
        if not mask.any():
            print(core.core_id)
            continue
        # One row per ROI pixel with a singleton axis inserted at position 1.
        inputs.append(core.rf[:, mask].T[:, np.newaxis])

    if not inputs:
        print(patient_id)
        return

    probs = _predict_probabilities(np.concatenate(inputs, axis=0))

    # Scatter class-1 probabilities back in the same core order used above.
    offset = 0
    for core, mask in zip(cores, masks):
        n_pixels = int(mask.sum())
        heatmap = np.zeros_like(core.roi, dtype='float32')
        heatmap[:, mask] = probs[offset:offset + n_pixels, 1]
        core.heatmap = heatmap
        offset += n_pixels

    cores = [rf2bm_wrapper(core) for core in cores]

    figure_filename = os.path.join(out_dir, f'Patient{cores[0].patient_id}.png')
    fig = preview_cores(cores)
    fig.savefig(figure_filename, bbox_inches='tight')
    # Close every figure: preview_cores may create more than the one it returns.
    plt.close('all')

In [6]:
# Render heatmap figures for every patient of each held-out split.
# NOTE: predict_and_visualize falls back to the globals `state` and
# `figure_dir` bound by this loop when called with only a patient id.
for state in ['test', 'val']:
    figure_dir = '/'.join((opt.project_root,
                           opt.paths.result_dir.replace('results', 'figures'),
                           state))
    os.makedirs(figure_dir, exist_ok=True)
    patient_ids = core_indices[state].keys()

    for patient_id in tqdm(patient_ids, unit="patient"):
        predict_and_visualize(patient_id)

  0%|          | 0/18 [00:00<?, ?patient/s]
  0%|          | 0/7 [00:00<?, ?batch/s][A
 14%|█▍        | 1/7 [00:00<00:03,  1.76batch/s][A
100%|██████████| 7/7 [00:00<00:00,  8.61batch/s][A
  plt.show()
  6%|▌         | 1/18 [01:08<19:32, 68.99s/patient]
  0%|          | 0/3 [00:00<?, ?batch/s][A
 33%|███▎      | 1/3 [00:00<00:01,  1.44batch/s][A
100%|██████████| 3/3 [00:01<00:00,  1.94batch/s][A
 11%|█         | 2/18 [01:45<13:15, 49.69s/patient]
  0%|          | 0/13 [00:00<?, ?batch/s][A
  8%|▊         | 1/13 [00:00<00:08,  1.45batch/s][A
 46%|████▌     | 6/13 [00:00<00:00,  9.57batch/s][A
100%|██████████| 13/13 [00:01<00:00, 11.74batch/s][A
 17%|█▋        | 3/18 [04:12<23:30, 94.06s/patient]
  0%|          | 0/8 [00:00<?, ?batch/s][A

65



 12%|█▎        | 1/8 [00:00<00:05,  1.38batch/s][A
100%|██████████| 8/8 [00:00<00:00,  8.14batch/s][A
 28%|██▊       | 5/18 [06:14<16:19, 75.35s/patient]
  0%|          | 0/10 [00:00<?, ?batch/s][A
 10%|█         | 1/10 [00:00<00:06,  1.44batch/s][A
 60%|██████    | 6/10 [00:00<00:00,  9.60batch/s][A
100%|██████████| 10/10 [00:01<00:00,  9.03batch/s][A
 33%|███▎      | 6/18 [07:51<16:20, 81.70s/patient]
  0%|          | 0/7 [00:00<?, ?batch/s][A
 14%|█▍        | 1/7 [00:00<00:04,  1.41batch/s][A
100%|██████████| 7/7 [00:01<00:00,  6.92batch/s][A
 39%|███▉      | 7/18 [09:17<15:12, 82.99s/patient]
  0%|          | 0/8 [00:00<?, ?batch/s][A
 12%|█▎        | 1/8 [00:00<00:04,  1.49batch/s][A
100%|██████████| 8/8 [00:01<00:00,  7.48batch/s][A
 44%|████▍     | 8/18 [10:54<14:31, 87.15s/patient]
  0%|          | 0/3 [00:00<?, ?batch/s][A
 33%|███▎      | 1/3 [00:00<00:01,  1.47batch/s][A
100%|██████████| 3/3 [00:00<00:00,  3.39batch/s][A
 50%|█████     | 9/18 [11:20<10:19, 68

 29%|██▊       | 6/21 [00:00<00:01,  9.33batch/s][A
 52%|█████▏    | 11/21 [00:00<00:00, 16.75batch/s][A
 76%|███████▌  | 16/21 [00:01<00:00, 23.49batch/s][A
100%|██████████| 21/21 [00:01<00:00, 15.20batch/s][A
 86%|████████▌ | 25/29 [26:40<08:23, 125.88s/patient]
  0%|          | 0/7 [00:00<?, ?batch/s][A
 14%|█▍        | 1/7 [00:00<00:03,  1.61batch/s][A
100%|██████████| 7/7 [00:01<00:00,  7.00batch/s][A
 90%|████████▉ | 26/29 [27:52<05:29, 109.87s/patient]
  0%|          | 0/3 [00:00<?, ?batch/s][A
 33%|███▎      | 1/3 [00:00<00:01,  1.71batch/s][A
100%|██████████| 3/3 [00:00<00:00,  3.84batch/s][A
 93%|█████████▎| 27/29 [28:19<02:49, 84.89s/patient] 
  0%|          | 0/10 [00:00<?, ?batch/s][A
 10%|█         | 1/10 [00:00<00:07,  1.18batch/s][A
 60%|██████    | 6/10 [00:00<00:00,  8.09batch/s][A
100%|██████████| 10/10 [00:01<00:00,  8.00batch/s][A
 97%|█████████▋| 28/29 [30:35<01:40, 100.39s/patient]
  0%|          | 0/2 [00:00<?, ?batch/s][A
100%|██████████| 2/2 [00

In [7]:
import gc

# Housekeeping after the inference loop: run a full garbage-collection pass,
# then ask PyTorch's caching allocator to release its unoccupied cached GPU memory.
n_collected = gc.collect()
torch.cuda.empty_cache()