In [1]:
import argparse
from collections import Counter
import logging
import os

import cv2
import numpy as np
import pandas as pd
import torch

In [2]:
%load_ext autoreload
%autoreload 2

In [3]:
import sys
sys.path.append('/repos/mrnet/scripts')
from loader import load_data
from model_choice import MODELS

In [4]:
def get_model_path(model_name, diagnosis, series):
    models_dir = f'/mnt/runs/{model_name}/{series}/{diagnosis}'
    most_recent = sorted(os.listdir(models_dir))[-1]
    most_recent_path = os.path.join(models_dir, most_recent)
    model_paths = sorted([
        fn for fn in os.listdir(most_recent_path)
        if fn.startswith('val')
    ])
    model_path = os.path.join(most_recent_path, model_paths[0])

    return model_path

In [5]:
model_path = get_model_path('MRNet', 'abnormal', 'axial')
print(model_path)

/mnt/runs/MRNet/axial/abnormal/05-20_14-43/val0.1775_train0.0856_epoch19


In [6]:
def get_model(model_name, model_path, gpu):
    model = MODELS[model_name]()
    state_dict = torch.load(
        model_path, map_location=(None if gpu else 'cpu')
    )
    model.load_state_dict(state_dict)

    if gpu:
        model = model.cuda()

    return model

In [7]:
model = get_model('MRNet', model_path, False)
print(model)

MRNet(
  (model): AlexNet(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
      (1): ReLU(inplace)
      (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): ReLU(inplace)
      (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): ReLU(inplace)
      (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (9): ReLU(inplace)
      (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace)
      (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
    (classifier): Sequential(
      (0): Dropout(p=0.5)
      (1): Linear(in_features=9216, out_features=40

In [8]:
def get_data(diagnosis, series, gpu):
    # load the paths dataframe
    paths = pd.read_csv(
        '/mnt/mrnet-image-paths.csv', header=None, names=['path']
    ).path.values

    # load the labels dataframe
    label_df = pd.read_csv('/mnt/mrnet-labels-3way.csv', index_col=0)

    _, _, test_loader = load_data(
        paths=paths, series=series, label_df=label_df,
        diagnosis=diagnosis, use_gpu=gpu, is_full=False,
        augment=False
    )

    return test_loader

In [9]:
test_loader = get_data('abnormal', 'axial', False)

In [10]:
def get_weights(model):
    return model.classifier.weight.data.numpy().reshape(-1)  # n_channel

In [11]:
weights = get_weights(model)
print(weights.shape)

(256,)


In [12]:
batch = next(iter(test_loader))
vol, label, case = batch
label = int(label.view(-1).data.numpy()[0])
case = case[0][:-len('.npy')]
print(label)
print(case)

0
1130


In [13]:
def get_features(model, volume):
    x = torch.squeeze(volume, dim=0)  # only batch size 1 supported
    features = model.model.features(x)
    x = model.gap(features).view(features.size(0), -1)

    name = model.__class__.__name__

    if 'Attention' in name:
        m = torch.softmax(model.attention(x), dim=0).data.cpu().numpy()
        idx = np.argmax(m)
        print(idx)
    else:
        a = torch.argmax(x, 0).view(-1).data.cpu().numpy()
        idx = Counter(a)[0]
        print(idx)

    return features.data.cpu().numpy()[idx], idx

In [14]:
features, idx = get_features(model, vol)

18


In [15]:
features.shape

(256, 6, 6)

In [16]:
model2_path = get_model_path('MRNet-Attend', 'abnormal', 'axial')

In [17]:
model2 = get_model('MRNet-Attend', model2_path, False)
print(model2)

MRNetAttention(
  (model): AlexNet(
    (features): Sequential(
      (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
      (1): ReLU(inplace)
      (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
      (4): ReLU(inplace)
      (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
      (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (7): ReLU(inplace)
      (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (9): ReLU(inplace)
      (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (11): ReLU(inplace)
      (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    )
    (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
    (classifier): Sequential(
      (0): Dropout(p=0.5)
      (1): Linear(in_features=9216, out_fe

In [18]:
features2, idx = get_features(model2, vol)

10


In [19]:
features2.shape

(256, 6, 6)

In [20]:
def get_CAM(model, volume):
    final_dim = (volume.shape[-2], volume.shape[-1])
    features, idx = get_features(model, volume)  # n_channel, w, h
    weights = get_weights(model)

    n_channel, width, height = features.shape
    features = features.reshape(n_channel, width * height)

    cam = weights @ features  # w * h
    cam = cam.reshape(width, height)
    cam -= np.min(cam)
    cam /= np.max(cam)
    cam = np.uint8(255 * cam)
    return cv2.resize(cam, final_dim), idx

In [21]:
cam, idx = get_CAM(model, vol)
cam.shape

18


(224, 224)

In [35]:
def denorm(img, mean=58.09, std=49.73):
    img = img * std
    img = img + mean
    return img

In [42]:
def create_and_save(path, vol, cam, idx):
    _, n_seq, n_channel, width, height = vol.shape

    img = vol.view(n_seq, n_channel, width, height).data.numpy()[idx]
    heatmap = cv2.applyColorMap(cam, cv2.COLORMAP_JET)
    img = denorm(np.moveaxis(img, 0, 2))
    colored = 0.3 * heatmap + 0.5 * img
    cv2.imwrite(path, colored)

In [43]:
create_and_save('test.jpg', vol, cam, idx)

In [44]:
cam2, idx2 = get_CAM(model2, vol)

10


In [45]:
create_and_save('test2.jpg', vol, cam2, idx2)

In [47]:
import cam as pycam

In [51]:
pycam.main('MRNet', 'abnormal', 'axial', False, 'test')

18
