In [1]:
import torch
import numpy as np
import pandas as pd
from config import Config
from models.multiple_input_net import MIN
from torch.utils.data import DataLoader
from utils.read_data import load_dataset
import cv2
import os

In [2]:

def backward_hook(module, grad_in, grad_out):
    grad_block.append(grad_out[0].detach())

def farward_hook(module, input, output):
    fmap_block.append(output)

def comp_class_vec(ouput_vec, index=None):
    """
    计算类向量
    :param ouput_vec: tensor
    :param index: int，指定类别
    :return: tensor
    """
    if not index:
        index = np.argmax(ouput_vec.cpu().data.numpy())
    else:
        index = np.array(index)
    index = index[np.newaxis, np.newaxis]
    index = torch.from_numpy(index)
    one_hot = torch.zeros(1, 2).scatter_(1, index, 1)
    one_hot.requires_grad = True
    class_vec = torch.sum(one_hot * ouput_vec)  # one_hot = 11.8605

    return class_vec
    
def gen_cam(feature_map, grads,H,W):
    """
    依据梯度和特征图，生成cam
    :param feature_map: np.array， in [C, H, W]
    :param grads: np.array， in [C, H, W]
    :return: np.array, [H, W]
    """
    cam = np.zeros(feature_map.shape[1:], dtype=np.float32)  # cam shape (H, W)

    weights = np.mean(grads, axis=(1, 2))  #

    for i, w in enumerate(weights):
        cam += w * feature_map[i, :, :]

    cam = np.maximum(cam, 0)
    cam = cv2.resize(cam, (H, W))
    return cam

def show_cam_on_image(img, mask, out_dir,index):
    heatmap = cv2.applyColorMap(np.uint8(255*mask), cv2.COLORMAP_JET)
    heatmap = np.float32(heatmap) / 255
    cam = heatmap + img
    cam = cam / np.max(cam)
    
    path_cam_img = os.path.join(out_dir, index+"_cam.jpg")
    path_raw_img = os.path.join(out_dir, index+"_raw.jpg")
    if not os.path.exists(out_dir):
        os.makedirs(out_dir)
    cv2.imwrite(path_cam_img, np.uint8(255 * cam))
    cv2.imwrite(path_raw_img, np.uint8(255 * img))


In [3]:
config = Config.Config()
config.device = 'cpu'
model_path = 'logs//best_trained_model'
model = MIN(config,['alexnet','alexnet']).to(config.device)
model.load_state_dict(torch.load(model_path,map_location=config.device))
model.sag_model.last_conv.register_forward_hook(farward_hook)
model.sag_model.last_conv.register_backward_hook(backward_hook)
#model.tra_model.last_conv.register_forward_hook(farward_hook)
#model.tra_model.last_conv.register_backward_hook(backward_hook)
sag_list = ['L3-L4','L4-L5','L5-S1']
train_data,test_data = load_dataset(config,sag_list,train_split = 2/3,from_trained = True) 
train_data = DataLoader(train_data, batch_size=1, shuffle=True)
test_data = DataLoader(test_data, batch_size=1, shuffle=True)

normal_8  missing SAG
unlabelled_0310  missing SAG
unlabelled_0365  missing TRA
unlabelled_0375  missing TRA
unlabelled_0401  missing SAG
unlabelled_0447  missing TRA
unlabelled_0468  missing TRA
Train negative samples: 10
Train positive samples: 75
Test negative samples: 6
Test positive samples: 38


# 单次生成

forward&backward

In [4]:
fmap_block = list()
grad_block = list()
sample_id,(tra,sag), y = test_data.dataset.__getitem__(0)
x1 = [i.to(config.device).unsqueeze(0) for i in sag.values()]
x2 = tra.to(config.device).unsqueeze(0)
#y = y.to(config.device)
model.eval()
outputs = model(x1,x2)
loss = comp_class_vec(outputs)
loss.backward()

RuntimeError: Could not run 'aten::values' with arguments from the 'CPU' backend. This could be because the operator doesn't exist for this backend, or was omitted during the selective/custom build process (if using custom build). If you are a Facebook employee using PyTorch on mobile, please visit https://fburl.com/ptmfixes for possible resolutions. 'aten::values' is only available for these backends: [SparseCPU, SparseCUDA, BackendSelect, Named, AutogradOther, AutogradCPU, AutogradCUDA, AutogradXLA, AutogradNestedTensor, UNKNOWN_TENSOR_TYPE_ID, AutogradPrivateUse1, AutogradPrivateUse2, AutogradPrivateUse3, Tracer, Autocast, Batched, VmapMode].

SparseCPU: registered at aten\src\ATen\RegisterSparseCPU.cpp:557 [kernel]
SparseCUDA: registered at aten\src\ATen\RegisterSparseCUDA.cpp:655 [kernel]
BackendSelect: fallthrough registered at ..\aten\src\ATen\core\BackendSelectFallbackKernel.cpp:3 [backend fallback]
Named: registered at ..\aten\src\ATen\core\NamedRegistrations.cpp:7 [backend fallback]
AutogradOther: registered at ..\torch\csrc\autograd\generated\VariableType_1.cpp:9509 [autograd kernel]
AutogradCPU: registered at ..\torch\csrc\autograd\generated\VariableType_1.cpp:9509 [autograd kernel]
AutogradCUDA: registered at ..\torch\csrc\autograd\generated\VariableType_1.cpp:9509 [autograd kernel]
AutogradXLA: registered at ..\torch\csrc\autograd\generated\VariableType_1.cpp:9509 [autograd kernel]
AutogradNestedTensor: registered at ..\torch\csrc\autograd\generated\VariableType_1.cpp:9509 [autograd kernel]
UNKNOWN_TENSOR_TYPE_ID: registered at ..\torch\csrc\autograd\generated\VariableType_1.cpp:9509 [autograd kernel]
AutogradPrivateUse1: registered at ..\torch\csrc\autograd\generated\VariableType_1.cpp:9509 [autograd kernel]
AutogradPrivateUse2: registered at ..\torch\csrc\autograd\generated\VariableType_1.cpp:9509 [autograd kernel]
AutogradPrivateUse3: registered at ..\torch\csrc\autograd\generated\VariableType_1.cpp:9509 [autograd kernel]
Tracer: registered at ..\torch\csrc\autograd\generated\TraceType_1.cpp:11324 [kernel]
Autocast: fallthrough registered at ..\aten\src\ATen\autocast_mode.cpp:250 [backend fallback]
Batched: registered at ..\aten\src\ATen\BatchingRegistrations.cpp:1016 [backend fallback]
VmapMode: fallthrough registered at ..\aten\src\ATen\VmapModeRegistrations.cpp:33 [backend fallback]


In [None]:
sag = {}
sag_list = ['L3-L4','L4-L5','L5-S1']
img_path = test_data.dataset.samples[0][0]
for x in os.listdir(img_path + '//' + 'SAG'):
    if x[0:5] in sag_list:
        img = cv2.imread(img_path + '//' + 'SAG'+'//'+x,1)
        sag[x[0:5]] = img
tra = cv2.imread(img_path + '//' + 'TRA'+'//'+ os.listdir(img_path + '//' + 'TRA')[0],1)
output_dir = 'attention_maps//'
idx = ['tra']+list(sag.keys())
graphs = [tra] + list(sag.values())

generate CAM

In [None]:
for i,graph in enumerate(graphs):
    grads = np.array(grad_block[i][0].cpu())
    fmap = np.array(fmap_block[i][0].cpu().detach())
    H = len(graph[0])
    W = len(graph)
    cam = gen_cam(fmap, grads,H,W)
    img_show = np.float32(cv2.resize(img, (H,W))) / 255
    show_cam_on_image(img_show, cam, output_dir+str(sample_id),idx[i])



# 全过程

In [10]:
num = len(test_data.dataset)
for item_idx in range(num):
    print('------------'+str(item_idx/num*100)+'%')
    fmap_block = list()
    grad_block = list()
    sample_id,(tra,sag), y = test_data.dataset.__getitem__(item_idx)
    if y == 2:
        continue
    x1 = [sag.to(config.device)]
    x2 = [tra.to(config.device)]
    #y = y.to(config.device)
    model.eval()
    outputs = model(x1,x2)
    loss = comp_class_vec(outputs,np.array([1],dtype = 'int64')[0])
    loss.backward()

    sag = {}
    sag_list = ['L3-L4','L4-L5','L5-S1']
    img_path = test_data.dataset.samples[item_idx][0]
    for x in os.listdir(img_path + '//' + 'SAG'):
        if x[0:5] in sag_list:
            img = cv2.imread(img_path + '//' + 'SAG'+'//'+x,1)
            sag[x[0:5]] = img
    tra = cv2.imread(img_path + '//' + 'TRA'+'//'+ os.listdir(img_path + '//' + 'TRA')[0],1)
    output_dir = 'attention_maps//'
    idx = ['tra']+list(sag.keys())
    graphs = [tra] + list(sag.values())

    cam = list(range(len(idx)))
    min_cam = 100
    max_cam = 0
    for i,graph in enumerate(graphs):
        grads = np.array(grad_block[i][0].cpu())
        fmap = np.array(fmap_block[i][0].cpu().detach())
        H = len(graph[0])
        W = len(graph)
        cam[i] = gen_cam(fmap, grads,H,W)
        if np.min(cam[i]) < min_cam:
            min_cam = np.min(cam[i])
        if np.max(cam[i]) > max_cam:
            max_cam = np.max(cam[i])
         
    for i in range(len(cam)):
        cam[i] -= min_cam
        if max_cam!=0:
            cam[i] /= max_cam

    for i,graph in enumerate(graphs):
        H = len(graph[0])
        W = len(graph)
        img_show = np.float32(cv2.resize(graph, (H,W))) / 255
        show_cam_on_image(img_show, cam[i], output_dir+str(sample_id),idx[i])

------------0.0%
------------0.22522522522522523%
------------0.45045045045045046%
------------0.6756756756756757%
------------0.9009009009009009%
------------1.1261261261261262%
------------1.3513513513513513%
------------1.5765765765765765%
------------1.8018018018018018%
------------2.027027027027027%
------------2.2522522522522523%
------------2.4774774774774775%
------------2.7027027027027026%
------------2.9279279279279278%
------------3.153153153153153%
------------3.3783783783783785%
------------3.6036036036036037%
------------3.8288288288288284%
------------4.054054054054054%
------------4.2792792792792795%
------------4.504504504504505%
------------4.72972972972973%
------------4.954954954954955%
------------5.18018018018018%
------------5.405405405405405%
------------5.63063063063063%
------------5.8558558558558556%
------------6.081081081081082%
------------6.306306306306306%
------------6.531531531531531%
------------6.756756756756757%
------------6.981981981981981%
------