## 3D可视化

3D Grad-CAM 可视化模块

In [None]:
import os
os.environ['KMP_DUPLICATE_LIB_OK'] = 'True'

import monai
from glob import glob
import matplotlib.pyplot as plt
from onekey_algo import OnekeyDS as okds

mydir = r'D:/20240103-JiYang/Radiology/crop3d/'
# Point mydir at your own directory of NIfTI volumes.

# Collect every .nii / .nii.gz file in mydir. os.listdir returns entries in
# arbitrary order, so sort the paths to make the processing order (and the
# order of everything printed/saved downstream) reproducible across machines.
samples = sorted(
    os.path.join(mydir, f)
    for f in os.listdir(mydir)
    if f.endswith(('.nii', '.nii.gz'))
)

# samples = [samples[-1]]  # uncomment to visualize only the last sample
samples

## 确定可视化模型

通过层名称（关键词）指定要对哪一层进行可视化。

### 支持的模型名称

使用下表中的模型名称替换代码中 `model_name` 变量的值。

| **模型系列** | **模型名称**                                                 |
| ------------ | ------------------------------------------------------------ |
| ResNet       | resnet10, resnet18, resnet34, resnet50, resnet101, resnet152, resnet200 |

In [None]:
from monai.data import ImageDataset
from torch.utils.data import DataLoader
from onekey_algo.custom.components.comp2 import extract, init_from_onekey3d

viz_dir = r"D:\20240103-JiYang\Radiology\models\CV-4\ShuffleNet\viz"

# Load the trained network together with the preprocessing transform (used as
# the dataset transform below) and the device that inputs must be moved to.
model, transformer, device = init_from_onekey3d(viz_dir)

# Print each sub-module under its dotted name; pick a convolution layer from
# this listing as the Grad-CAM target layer.
for layer_name, layer_module in model.named_modules():
    print('Feature name:', layer_name, "|| Module:", layer_module)

## 可视化卷积层

`Feature name:` 之后的名称即为可选的可视化层，例如 `layer4.2.conv3`；一般选择模型特征提取部分的最后一个卷积层。

**注意**：可视化的层必须是名称中带有 `conv` 的卷积层，一般选择最后一个卷积层。

In [None]:
# Dotted module name, as printed by model.named_modules(), whose activations and
# gradients Grad-CAM hooks into.
# NOTE(review): viz_dir points at a ShuffleNet export while this name follows
# the ResNet "layerX.Y.convZ" pattern; confirm it appears in the printed list.
target_layer = "layer3.3.conv3"
gradcam = monai.visualize.GradCAM(target_layers=target_layer, nn_module=model)

# One preprocessed volume per loader item, in the same order as `samples`:
# batch_size=1 and the DataLoader default of no shuffling keep the two aligned.
val_ds = ImageDataset(transform=transformer, image_files=samples)
val_loader = DataLoader(dataset=val_ds, num_workers=0, batch_size=1)

## 打印可视化界面

In [None]:
from onekey_algo.custom.components.comp2 import show_cam_on_image
import torch

# Grad-CAM figures are written to a "Grad-CAM" folder that sits beside the
# model's viz directory. Re-running this cell resolves to the same folder,
# because ".../Grad-CAM/../Grad-CAM" points back to ".../Grad-CAM".
viz_dir = os.path.join(viz_dir, '../Grad-CAM')


def _volume_stem(path):
    """Return the file name of *path* with its NIfTI extension removed.

    Both ``.nii.gz`` and ``.nii`` are stripped, so ``case01.nii.gz`` and
    ``case01.nii`` both map to ``case01``. ``os.path.splitext`` alone would
    leave ``case01.nii`` for the gzipped form. Other names fall back to
    ``os.path.splitext``.
    """
    name = os.path.basename(path)
    for ext in ('.nii.gz', '.nii'):
        if name.endswith(ext):
            return name[:-len(ext)]
    return os.path.splitext(name)[0]


# For each volume, compute its Grad-CAM map and save one PNG per slice along
# the last axis: CAM on the left (jet colormap), input slice on the right with
# a colorbar.
for sample, sample_ in zip(samples, val_loader):
    print(sample)
    res_cam = gradcam(x=sample_.to(device), class_idx=None)
    # Copy both volumes to host memory once per sample rather than once per slice.
    cam_np = res_cam[0][0].cpu().detach().numpy()
    sample_np = sample_.cpu().detach().numpy()

    # The output directory depends only on the sample, so create it once.
    save2 = os.path.join(viz_dir, _volume_stem(sample))
    os.makedirs(save2, exist_ok=True)

    for idx in range(sample_.size()[-1]):
        fig, axes = plt.subplots(1, 2, figsize=(10, 5), facecolor='white')
        axes[0].imshow(cam_np[..., idx], cmap='jet')
        axes[0].axis('off')
        img = axes[1].imshow(sample_np[0][0][..., idx])
        axes[1].axis('off')
        # Place a narrow colorbar axis to the right of the image panel.
        cax = fig.add_axes([0.92, 0.15, 0.02, axes[1].get_position().height])
        fig.colorbar(img, cax=cax)
        fig.savefig(f'{save2}/{idx}.png', bbox_inches='tight')
        # Close this specific figure so memory does not grow across many slices.
        plt.close(fig)