In [1]:
import argparse
from collections import Counter
import logging
import os

import cv2
import numpy as np
import pandas as pd
import torch

In [2]:
# Re-import edited project modules (e.g. cam) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [3]:
import sys
# Make the project's helper scripts (loader, model_choice, and presumably cam)
# importable. Guarded so re-running this cell does not append duplicate entries.
MRNET_SCRIPTS_DIR = '/repos/mrnet/scripts'
if MRNET_SCRIPTS_DIR not in sys.path:
    sys.path.append(MRNET_SCRIPTS_DIR)
from loader import load_data
from model_choice import MODELS

In [4]:
import cam

In [14]:
# Which diagnosis / MRI plane to visualise, and which trained architectures to compare.
diagnosis = 'acl'
series = 'axial'
model_names = ['MRNet', 'MRNet-Squeeze', 'MRNet-Attend', 'MRNet-SqueezeAttend']
# Model file path for each architecture, as resolved by cam.get_model_path.
model_paths = dict(
    (name, cam.get_model_path(name, diagnosis, series))
    for name in model_names
)

In [20]:
# Device flag handed to cam.get_data / cam.get_model; False keeps everything on the CPU.
gpu = False

In [45]:
# Data loader for the chosen diagnosis/series (presumably the evaluation split — TODO confirm in cam.get_data).
test_loader = cam.get_data(diagnosis, series, gpu)

In [46]:
# Take one example (index 88) straight from the dataset: its volume tensor, label
# tensor and case filename. NOTE(review): 88 is a hand-picked example — consider
# promoting it to a named constant in the config cell.
vol, label, case = test_loader.dataset[88]

In [23]:
# Show which case file was selected.
case

'1218.npy'

In [49]:
# For each architecture: load the model and record the slice index that
# cam.get_CAM returns for this volume. The CAM map `c` is not used further here.
# NOTE(review): `vol` is passed here without a batch axis, whereas make_save_img
# passes vol.reshape(1, *vol.shape); both runs produced matching indices in the
# saved outputs, but confirm cam.get_CAM / the models accept either shape.
idxs = {}
for mn in model_names:
    model = cam.get_model(mn, model_paths[mn], gpu)
    c, idx = cam.get_CAM(model, vol)
    idxs[mn] = idx
    print(f'{mn}: {idx}')

(256,)
MRNet: 11
(512,)
MRNet-Squeeze: 7
(26, 1)
MRNet-Attend: 20
(26, 1)
MRNet-SqueezeAttend: 6


In [50]:
# Slice index returned by cam.get_CAM for each model on this case.
idxs

{'MRNet': 11, 'MRNet-Squeeze': 7, 'MRNet-Attend': 20, 'MRNet-SqueezeAttend': 6}

In [37]:
def make_save_img(model, vol, idx, case, diagnosis, series, label, output_path,
                  model_name=None):
    """Render a CAM overlay for one slice of ``vol`` and save it as a JPEG.

    Args:
        model: network passed to ``cam.get_CAM`` and used for the prediction
            that is written into the filename.
        vol: single-case volume tensor without a batch axis. It must be 4-D with
            the slice axis first (presumably (n_seq, channels, H, W) — TODO
            confirm against the loader); a batch axis is prepended here.
        idx: slice index forwarded to ``cam.get_CAM``. The index that
            ``cam.get_CAM`` returns is the one rendered and put in the filename.
        case: case filename such as ``'1218.npy'``; its extension is dropped
            in the output name.
        diagnosis: diagnosis tag embedded in the output filename.
        series: series/plane tag embedded in the output filename.
        label: ground-truth label tensor; its first element is written as
            ``t<label>`` in the filename.
        output_path: existing directory the image is written into.
        model_name: name used in the filename. Defaults to the model's class
            name, which was the only behaviour before this parameter existed.

    Returns:
        Path of the written image file.

    Raises:
        ValueError: if ``vol`` is not 4-D.
        OSError: if OpenCV fails to write the image.
    """
    if vol.dim() != 4:
        raise ValueError(f'expected a 4-D volume, got shape {tuple(vol.shape)}')

    v = vol.reshape(1, *vol.shape)  # prepend the batch axis
    c, idx = cam.get_CAM(model, v, idx)

    # Slice the CAM refers to; detach + cpu so this also works for GPU tensors.
    img = vol[idx].detach().cpu().numpy()
    heatmap = cv2.applyColorMap(c, cv2.COLORMAP_JET)
    # Channels-first -> channels-last, the layout OpenCV images use.
    img = cam.denorm(np.moveaxis(img, 0, 2))
    # Blend, then convert to 8-bit explicitly rather than relying on the
    # implicit conversion cv2.imwrite applies to non-8-bit images for JPEG.
    colored = np.clip(np.rint(0.3 * heatmap + 0.5 * img), 0, 255).astype(np.uint8)

    # forward() is called directly (bypassing nn.Module hooks), as before —
    # possibly deliberate if cam.get_CAM registers a hook; TODO confirm.
    # no_grad: inference only, so don't build an autograd graph.
    with torch.no_grad():
        pred = torch.sigmoid(model.forward(v)).cpu().numpy()[0][0]

    if model_name is None:
        model_name = model.__class__.__name__
    # reshape(-1) makes this work whether label is 1-D or carries extra axes.
    label_val = int(label.detach().cpu().reshape(-1)[0].item())
    case_stem = os.path.splitext(case)[0]

    img_path = 'result-{}-i{}-c{}-{}-{}-t{}-p{:.3f}.jpg'.format(
        model_name, idx, case_stem, diagnosis, series, label_val, pred
    )
    img_path = os.path.join(output_path, img_path)
    # cv2.imwrite reports failure via its return value rather than raising.
    if not cv2.imwrite(img_path, colored):
        raise OSError(f'cv2.imwrite failed to write {img_path}')
    return img_path

In [27]:
# Directory the CAM overlay images are written to; created if it does not exist.
output_path = '/mnt/final_images'
os.makedirs(output_path, exist_ok=True)

In [51]:
# Render every model's CAM at every slice index collected in `idxs`, so the
# overlays can be compared across models on the same slices.
for mn in model_names:
    model = cam.get_model(mn, model_paths[mn], gpu)
    for idx in idxs.values():
        img_path = make_save_img(
            model=model, vol=vol, idx=idx, case=case,
            diagnosis=diagnosis, series=series, label=label,
            output_path=output_path,
        )
        print(img_path)

(256,)
/mnt/final_images/result-MRNet-i11-c1218-acl-axial-t1-p0.844.jpg
(256,)
/mnt/final_images/result-MRNet-i7-c1218-acl-axial-t1-p0.844.jpg
(256,)
/mnt/final_images/result-MRNet-i20-c1218-acl-axial-t1-p0.844.jpg
(256,)
/mnt/final_images/result-MRNet-i6-c1218-acl-axial-t1-p0.844.jpg
(512,)
/mnt/final_images/result-MRNetSqueeze-i11-c1218-acl-axial-t1-p0.767.jpg
(512,)
/mnt/final_images/result-MRNetSqueeze-i7-c1218-acl-axial-t1-p0.767.jpg
(512,)
/mnt/final_images/result-MRNetSqueeze-i20-c1218-acl-axial-t1-p0.767.jpg
(512,)
/mnt/final_images/result-MRNetSqueeze-i6-c1218-acl-axial-t1-p0.767.jpg
(26, 1)
/mnt/final_images/result-MRNetAttention-i11-c1218-acl-axial-t1-p0.246.jpg
(26, 1)
/mnt/final_images/result-MRNetAttention-i7-c1218-acl-axial-t1-p0.246.jpg
(26, 1)
/mnt/final_images/result-MRNetAttention-i20-c1218-acl-axial-t1-p0.246.jpg
(26, 1)
/mnt/final_images/result-MRNetAttention-i6-c1218-acl-axial-t1-p0.246.jpg
(26, 1)
/mnt/final_images/result-MRNetSqueezeAttention-i11-c1218-acl-axial