In [6]:
import torch
import torch.nn as nn
import time, shutil, os, cv2
from torchvision import transforms
from ClassicNetwork.ResNet import ResNet50
from PIL import Image
import numpy as np

In [7]:
def returnCAM(feature_conv, weight_softmax, class_idx):
    bz, nc, h, w = feature_conv.shape        #1,2048,7,7
    output_cam = []
    for idx in class_idx:  #只输出预测概率最大值结果不需要for循环
        feature_conv = feature_conv.reshape((nc, h*w))
        cam = weight_softmax[idx].dot(feature_conv.reshape((nc, h*w)))  #(2048, ) * (2048, 7*7) -> (7*7, ) （n,）是一个数组，既不是行向量也不是列向量
        cam = cam.reshape(h, w)
        cam_img = (cam - cam.min()) / (cam.max() - cam.min())  #Normalize
        cam_img = np.uint8(255 * cam_img)                      #Format as CV_8UC1 (as applyColorMap required)

        #output_cam.append(cv2.resize(cam_img, size_upsample))  # Resize as image size
        output_cam.append(cam_img)
    return output_cam

In [8]:
File_path = '/Users/shengguang.xiao/GitHub/pytorch_train/OcvData_Resnet/Valid_Set/class_bad'
# File_name = '_A8D48549-C49C-4284-9091-C5F60CB75058_Region1_r0_c0.jpg'

CAM_RESULT_PATH = os.path.join(File_path, 'CAMs')   #CAM结果的存储地址
os.makedirs(CAM_RESULT_PATH, exist_ok=True)
os.makedirs(CAM_RESULT_PATH +'/0', exist_ok=True)
os.makedirs(CAM_RESULT_PATH +'/1', exist_ok=True)

In [9]:
label  = {"0": 0, "1": 1}
class_ = ['Good', 'Bad']

In [10]:
device = 'cpu'

model = ResNet50(num_classes=len(label), imgsz=64)
model = model.to(device)

#### load model
pth        = './Results/Resnet_100.0_epoch_22.pt'
checkpoint = torch.load(pth)
model.load_state_dict(checkpoint['model'])

#### load model to last feature map to create CAM feature (Classification Activation Map)
model_features = nn.Sequential(*list(model.children())[:-2])

model.eval()
model_features.eval()

trans = transforms.Compose([transforms.ToTensor(),])
f1    = nn.Softmax(dim=1)

for File_name in os.listdir(File_path):
    if File_name.endswith('.png'):
        print("File Path, ", File_name)

        ## open image and transform to tensor
        imgt = Image.open(os.path.join(File_path, File_name))
        img  = trans(imgt)
        img  = img[None]

        img = img.float().to(device)
        #img = torch.rand(1, 3, 47, 41)
        print(img.shape)

        img = img[:,[0,1,2],:,:] # select rgb but not the alpha

        print(img.shape)

        ## model predict
        start_time = time.time()
        out = model(img)    #前向算法
        out = f1(out)
        conf, pred = torch.max(out,1)   #预测结果

        end_time = time.time()

        print("conf ", conf, " pred ", pred)
        print('time cost = ', end_time - start_time)

        ### get feature map
        features   = model_features(img).detach().cpu().numpy() 
        fc_weights = model.state_dict()['fc.weight'].cpu().numpy()
        print("fc_wights", fc_weights)

        # print(features.shape)
        # print(fc_weights.shape)

        CAMs = returnCAM(features, fc_weights, pred)  #输出预测概率最大的特征图集对应的CAM
        # print(img_name + ' output for the top1 prediction: %s' % class_[idx[0]])

        #img  = cv2.imread(os.path.join(File_path, File_name))
        imgt = cv2.cvtColor(np.asarray(imgt),cv2.COLOR_RGB2BGR) 
        height, width, _ = imgt.shape  #读取输入图片的尺寸
        heatmap = cv2.applyColorMap(cv2.resize(CAMs[0], (width, height)), cv2.COLORMAP_JET)  #CAM resize match input image size
        result = heatmap * 0.3 + imgt * 0.5    #比例可以自己调节

        text = '%s %.2f%%' % (class_[pred], conf[0]*100) 				 #激活图结果上的文字显示
        cv2.putText(result, text, (20, height-10), fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=0.9,
                    color=(123, 222, 238), thickness=2, lineType=cv2.LINE_AA)

        image_name_ = File_name[:-4]

        if pred == 0:
            cv2.imwrite(CAM_RESULT_PATH + '/0/' + image_name_ + '_pred_' + class_[pred] + '.jpg', result)  #写入存储磁盘
        else:
            cv2.imwrite(CAM_RESULT_PATH + '/1/' + image_name_ + '_pred_' + class_[pred] + '.jpg', result)  #写入存储磁盘


File Path,  bad_image_2.png
torch.Size([1, 4, 41, 47])
torch.Size([1, 3, 41, 47])
tensor([0.9998], grad_fn=<MaxBackward0>) tensor([1])
time cost =  0.04476428031921387
fc_wights [[ 0.02675873 -0.00400169  0.02323121 ...  0.01285838 -0.0143756
  -0.02610539]
 [-0.00191292  0.02722661 -0.01338094 ... -0.00010555  0.008932
  -0.01079006]]
File Path,  bad_image_1.png
torch.Size([1, 4, 41, 47])
torch.Size([1, 3, 41, 47])
tensor([0.9808], grad_fn=<MaxBackward0>) tensor([1])
time cost =  0.019250869750976562
fc_wights [[ 0.02675873 -0.00400169  0.02323121 ...  0.01285838 -0.0143756
  -0.02610539]
 [-0.00191292  0.02722661 -0.01338094 ... -0.00010555  0.008932
  -0.01079006]]
