## Grad-CAM可视化

2D Grad-CAM可视化模块

In [None]:
import os
# Let the Intel OpenMP runtime tolerate being loaded more than once.
# Presumably a workaround for "OMP: Error #15" when torch and another
# library each bundle their own libiomp — confirm it is still needed.
# It is set here, before monai (and hence torch) is imported below, since the
# runtime presumably reads it at load time.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

import monai
from glob import glob
import matplotlib.pyplot as plt

mydir = r'D:\20221004-ChenRui\Appendix\for_viz'
# Collect every PNG exactly one sub-folder below mydir: <mydir>/<sub-folder>/<image>.png
png_pattern = os.path.join(mydir, '*', '*.png')
samples = glob(png_pattern)
samples

## 确定可视化模型

通过层名称（关键词）指定要提取哪一层的特征进行可视化。

### 支持的模型名称

使用下表中的模型名称替换代码中 `model_name` 变量的值。

| **模型系列** | **模型名称**                                                 |
| ------------ | ------------------------------------------------------------ |
| AlexNet      | alexnet                                                      |
| VGG          | vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19_bn, vgg19 |
| ResNet       | resnet18, resnet34, resnet50, resnet101, resnet152, resnext50_32x4d, resnext101_32x8d, wide_resnet50_2, wide_resnet101_2 |
| DenseNet     | densenet121, densenet169, densenet201, densenet161           |
| Inception    | googlenet, inception_v3                                      |
| SqueezeNet   | squeezenet1_0, squeezenet1_1                                 |
| ShuffleNetV2 | shufflenet_v2_x2_0, shufflenet_v2_x0_5, shufflenet_v2_x1_0, shufflenet_v2_x1_5 |
| MobileNet    | mobilenet_v2, mobilenet_v3_large, mobilenet_v3_small         |
| MNASNet      | mnasnet0_5, mnasnet0_75, mnasnet1_0, mnasnet1_3              |

In [None]:
from onekey_algo.custom.components.comp2 import extract, init_from_model, init_from_onekey

model, transformer, device = init_from_onekey(r'D:\20221004-ChenRui\models\20230118\resnet50\viz')
# List every sub-module with its dotted name; a name printed after
# "Feature name:" is what the Grad-CAM cell takes as its target layer.
for layer_name, layer in model.named_modules():
    print('Feature name:', layer_name, "|| Module:", layer)

## 可视化卷积层

`Feature name:` 之后的名称即为可指定的可视化层名，例如 `layer4.2.conv3`；一般选择深度学习特征提取部分的最后一个卷积层。

**注意**：可视化的层一定要是名称中带有 `conv` 的卷积层，而且一般是最后一个卷积层。

In [None]:
# Dotted name of the layer whose activations/gradients drive the CAM
# (one of the names printed as "Feature name:" by the named_modules() loop).
target_layer = "layer4.2.conv3"
# Fail fast with a clear message if the name is not a sub-module of `model`,
# instead of relying on whatever MONAI does with an unknown layer name.
_known_layers = {name for name, _ in model.named_modules()}
if target_layer not in _known_layers:
    raise ValueError(
        f"target_layer {target_layer!r} is not a sub-module of the model; "
        "pick one of the names printed as 'Feature name:'."
    )
gradcam = monai.visualize.GradCAM(nn_module=model, target_layers=target_layer)

## 生成并保存可视化结果

In [None]:
from onekey_algo.datasets.image_loader import default_loader
from onekey_algo.custom.components.comp2 import show_cam_on_image
import torch
import os
import random
import gc
import matplotlib
# Render off-screen: figures are only written to disk, never displayed.
matplotlib.use('Agg')


# Process samples in random order (no seed is set, so the order differs per run).
random.shuffle(samples)
save_dir = 'D:/20221004-ChenRui/models/20230118/resnet50/Grad-CAM_Small'
for sample in samples:
    # Mirror the input layout: <save_dir>/<parent folder of sample>/<file name>.
    save2 = f"{save_dir}/{os.path.basename(os.path.dirname(sample))}"
    out_path = f'{save2}/{os.path.basename(sample)}'
    # Resume support: skip samples already rendered by an earlier run.
    if os.path.exists(out_path):
        continue
    os.makedirs(save2, exist_ok=True)

    img = default_loader(sample)
    # Add a leading batch dimension and move the input to the model's device.
    sample_ = transformer(img)
    sample_ = sample_.view(1, *sample_.size()).to(device)
    res_cam = gradcam(x=sample_, class_idx=None)

    # After the view above sample_ is (1, C, H, W), but PIL's Image.resize takes
    # (width, height); passing sample_.size()[2:] directly swapped the axes for
    # non-square inputs. Resize once and reuse the result for both panels.
    # (Assumes default_loader returns a PIL image — TODO confirm.)
    height, width = sample_.shape[-2:]
    img_resized = img.resize((width, height))

    fig, axes = plt.subplots(1, 2, figsize=(20, 10), facecolor='white')
    # Left panel: the resized input image.
    axes[0].imshow(img_resized)
    axes[0].axis('off')
    # Right panel: the negated CAM map of the first batch item overlaid on the image.
    overlay = show_cam_on_image(img_resized, -res_cam[0][0].cpu(), use_rgb=True, reverse=False)
    # NOTE(review): if `overlay` is an RGB array, imshow ignores `cmap`, so the
    # colorbar below shows the jet colormap over the image's value range rather
    # than raw CAM values — confirm this is the intended legend.
    imshow = axes[1].imshow(overlay, cmap='jet')
    axes[1].axis('off')
    # Colorbar in its own axes at the right edge, as tall as the right panel.
    cax = fig.add_axes([0.92, 0.15, 0.02, axes[1].get_position().height])
    plt.colorbar(imshow, cax=cax)
    fig.savefig(out_path, bbox_inches='tight')
    # Release the figure and per-sample tensors so memory stays flat over long runs.
    plt.close(fig)
    del sample_, res_cam, imshow, overlay
    gc.collect()

In [None]:
# Show the device returned by init_from_onekey (the CAM loop moves each input tensor there).
device

In [None]:
import os
import pandas as pd
import seaborn as sns
import matplotlib.pyplot as plt
plt.rcParams['figure.dpi'] = 300
# Training log from the model's viz folder; keep every 10th row to thin the curves.
# Columns used below: 'Iters', 'Loss', 'Acc@1'.
data = pd.read_csv(r'D:\20221004-ChenRui\models\20230118\resnet50\viz/training_log.txt')[::10]

# Dual-axis plot: loss on the left (red), top-1 accuracy on the right (blue),
# sharing the iteration axis.
fig, ax1 = plt.subplots()
color = 'tab:red'
ax1.set_xlabel('iters')
ax1.set_ylabel('Training Loss', color=color)
ax1.plot(data['Iters'], data['Loss'], color=color)
# tick_params(axis='y') with no other arguments changes nothing; colour the
# tick labels to match their curve.
ax1.tick_params(axis='y', labelcolor=color)

ax2 = ax1.twinx()  # second y-axis sharing the same x-axis

color = 'tab:blue'
ax2.set_ylabel('Acc@1', color=color)  # the x-label is already set on ax1
ax2.plot(data['Iters'], data['Acc@1'], color=color)
ax2.tick_params(axis='y', labelcolor=color)

fig.tight_layout()  # otherwise the right y-label is slightly clipped
# savefig does not create missing directories.
os.makedirs('img', exist_ok=True)
plt.savefig('img/Path_train.svg', bbox_inches='tight')

In [None]:
# Inspect the training-log rows (every 10th row) used for the loss/accuracy plot.
data

In [None]:
from PIL import Image
from glob import glob

# Shrink every exported Grad-CAM JPEG to 1/SCALE_DOWN of its width and height,
# overwriting each file in place. NOTE: destructive and not idempotent —
# re-running this cell shrinks the images again.
SCALE_DOWN = 4
samples = glob(r'D:/20221004-ChenRui/models/20230118/resnet50/Grad-CAM/*/*.jpg')
for sample in samples:
    # The context manager closes the source file before it is overwritten below.
    with Image.open(sample) as img:
        print(sample)
        width, height = img.size
        # max(1, ...) keeps images smaller than SCALE_DOWN px from collapsing
        # to a zero-size target, which Image.resize rejects.
        new_size = (max(1, width // SCALE_DOWN), max(1, height // SCALE_DOWN))
        small = img.resize(new_size)
    small.save(sample)