## ライブラリの準備

In [None]:
# 共通で必要なモジュールのインストール
!pip install torch==1.7.0 torchvision==0.8.1
!pip install matplotlib==3.2.2

## 検証対象のデータ

In [None]:
# 検証サンプル画像
!curl https://raw.githubusercontent.com/jacobgil/pytorch-grad-cam/master/examples/both.png -o both.png

In [None]:
import matplotlib.pyplot as plt
from PIL import Image

# Load the sample image, forcing 3-channel RGB so every later step
# (preprocessing, LIME, Grad-CAM overlays) sees the same channel layout.
img = Image.open("both.png").convert("RGB")
plt.imshow(img)

In [None]:
# クラスのラベル情報
!curl https://raw.githubusercontent.com/marcotcr/lime/master/doc/notebooks/data/imagenet_class_index.json -o imagenet_class_index.json

In [None]:
import json

# The JSON maps each class id, as a string key ("0", "1", ...), to a list
# whose second element is the human-readable label.
with open("imagenet_class_index.json", "r") as f:
    cls_idx = json.load(f)

# idx2label[k] is the label of class id k, in class-id order.
idx2label = [cls_idx[str(k)][1] for k in range(len(cls_idx))]

## AIモデルの準備と予測

In [None]:
import torch
import torch.nn as nn
from torchvision import models, transforms
import torch.nn.functional as F

# Run on the GPU when one is available, otherwise on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Load the pretrained ResNet-50 and switch it to inference mode.
model = models.resnet50(pretrained=True)
model.eval()
# nn.Module.to() moves the parameters and returns the module itself.
# Assigning the result (instead of leaving a bare expression) keeps the cell
# from echoing the full ResNet-50 repr as its output.
model = model.to(device)

In [None]:
# Image preprocessing: convert to a tensor, then normalize each channel with
# the ImageNet mean/std. No resize is applied, so the model sees the image at
# its original resolution (the same pixels LIME perturbs later).
preprocess = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225]
        ),
    ]
)

# Model inference on a batch of one image (unsqueeze(0) adds the batch axis).
# Plain prediction needs no gradients, so torch.no_grad() avoids building the
# autograd graph. img_tensor itself is kept for Grad-CAM later.
img_tensor = preprocess(img).unsqueeze(0).to(device)
with torch.no_grad():
    logits = model(img_tensor)
    probs = F.softmax(logits, dim=1)

# Inspect the top-5 predictions as (probability, class id, label).
probs5 = probs.topk(5)
probability = probs5.values[0].cpu().numpy()
class_id = probs5.indices[0].cpu().numpy()
for p, c in zip(probability, class_id):
    print((p, c, idx2label[c]))

## LIMEによる説明

In [None]:
# LIMEのインストール
!pip install lime==0.2.0.1

### LIMEによるAIモデルの説明

#### 画像の領域分割

In [None]:
from lime.wrappers.scikit_image import SegmentationAlgorithm
import numpy as np
from skimage.segmentation import mark_boundaries

# Superpixel segmentation with quickshift. random_seed=42 is meant to make the
# quickshift result reproducible between runs.
quickshift_params = dict(
    kernel_size=4,
    max_dist=200,
    ratio=0.2,
    random_seed=42,
)
segmentation_fn = SegmentationAlgorithm('quickshift', **quickshift_params)

# Draw the superpixel boundaries on top of the input image.
segments = segmentation_fn(img)
overlay = mark_boundaries(np.array(img), segments)
plt.imshow(overlay)

In [None]:
# Superpixel segmentation with SLIC using the library defaults. This reassigns
# segmentation_fn, so SLIC is the segmenter passed to LIME in the next step.
segmentation_fn = SegmentationAlgorithm("slic")
segments = segmentation_fn(img)

# Draw the SLIC superpixel boundaries on top of the input image.
slic_overlay = mark_boundaries(np.array(img), segments)
plt.imshow(slic_overlay)

#### LIMEによる説明の作成

In [None]:
from lime import lime_image


def batch_predict(images):
    """Classifier function handed to LIME.

    Args:
        images: a batch of H x W x 3 images. Presumably LIME passes perturbed
            copies of ``np.array(img)`` as a numpy array — TODO confirm dtype.

    Returns:
        numpy array of softmax probabilities, one row per input image and one
        column per model output class.
    """
    # Preprocess each image exactly like the single-image prediction, then
    # stack them into one batch on the model's device.
    batch = torch.stack(tuple(preprocess(i) for i in images), dim=0)
    batch = batch.to(device)

    # Inference only. LIME calls this repeatedly (num_samples=5000 below), so
    # skip building an autograd graph for every batch.
    with torch.no_grad():
        logits = model(batch)
        probs = F.softmax(logits, dim=1)
    return probs.cpu().numpy()


# Explain the model's top-2 labels for the sample image, using the
# segmentation_fn defined above. hide_color=0 is the fill used for
# switched-off superpixels. random_state=42 seeds LIME for reproducibility.
explainer = lime_image.LimeImageExplainer(random_state=42)
explanation = explainer.explain_instance(
    np.array(img),
    batch_predict,
    top_labels=2,
    hide_color=0,
    num_samples=5000,
    segmentation_fn=segmentation_fn,
)

### LIMEの説明の可視化と解釈

bull_mastiffの可視化

In [None]:
# First label LIME explained (top_labels[0]), mapped to its readable name.
class_index = explanation.top_labels[0]
class_label = idx2label[class_index]  # class id -> human-readable label
print(f"class_index: {class_index}, class_label: {class_label}")

In [None]:
# Mark the top num_features=5 superpixels LIME reports for class_index.
# positive_only=False also keeps superpixels with negative weight, and
# hide_rest=False keeps the rest of the image visible.
image, mask = explanation.get_image_and_mask(
    class_index,
    num_features=5,
    positive_only=False,
    hide_rest=False,
)
img_boundary = mark_boundaries(image, mask)
plt.imshow(img_boundary)

tiger_catの可視化

In [None]:
# Second label LIME explained (top_labels[1]), mapped to its readable name.
class_index = explanation.top_labels[1]
class_label = idx2label[class_index]  # class id -> human-readable label
print(f"class_index: {class_index}, class_label: {class_label}")

In [None]:
# Mark the top num_features=5 superpixels LIME reports for class_index.
# With positive_only and negative_only both False, superpixels of either sign
# are kept. hide_rest=False leaves the rest of the image visible.
image, mask = explanation.get_image_and_mask(
    class_index, num_features=5, hide_rest=False,
    positive_only=False, negative_only=False,
)
img_boundary = mark_boundaries(image, mask)
plt.imshow(img_boundary)

#### 寄与の大きい部分領域

bull_mastiffの上位5位の領域と寄与度

In [None]:
# Index of the first label LIME explained (bull_mastiff for this image).
index = explanation.top_labels[0]

# Visit the first five (superpixel id, weight) pairs of the explanation.
# NOTE(review): LIME presumably orders local_exp by |weight|, so these are the
# five most influential superpixels — confirm for the LIME version in use.
for area_index, value in explanation.local_exp[index][:5]:
    print(f"area_index: {area_index}, value: {value}")

    # Highlight the superpixel on a copy of the explained image: saturate
    # channel 0 (red in RGB) for a negative weight, channel 1 (green) for a
    # positive one. Use explanation.segments, the segmentation LIME actually
    # used, not the global `segments`, which holds whichever segmentation
    # cell happened to run last.
    image = explanation.image.copy()
    c = 0 if value < 0 else 1
    image[explanation.segments == area_index, c] = np.max(image)
    plt.imshow(image)
    plt.title(f"area_index: {area_index}, value: {value:.4f}")
    plt.show()

tiger_catの上位5位の領域と寄与度

In [None]:
# Index of the second label LIME explained (tiger_cat for this image).
index = explanation.top_labels[1]

# Visit the first five (superpixel id, weight) pairs of the explanation.
# NOTE(review): LIME presumably orders local_exp by |weight|, so these are the
# five most influential superpixels — confirm for the LIME version in use.
for area_index, value in explanation.local_exp[index][:5]:
    print(f"area_index: {area_index}, value: {value}")

    # Highlight the superpixel on a copy of the explained image: saturate
    # channel 0 (red in RGB) for a negative weight, channel 1 (green) for a
    # positive one. Use explanation.segments, the segmentation LIME actually
    # used, not the global `segments`, which holds whichever segmentation
    # cell happened to run last.
    image = explanation.image.copy()
    c = 0 if value < 0 else 1
    image[explanation.segments == area_index, c] = np.max(image)
    plt.imshow(image)
    plt.title(f"area_index: {area_index}, value: {value:.4f}")
    plt.show()

#### 全領域に対する可視化

bull_mastiffに対する全領域の可視化

In [None]:
# Index of the first label LIME explained (bull_mastiff for this image).
index = explanation.top_labels[0]

# Build the heatmap by painting every superpixel with its LIME weight. A
# lookup table indexed by segment id maps all pixels in one vectorized step.
# Superpixels absent from local_exp get 0 rather than None. This assumes
# non-negative integer segment labels, as skimage segmenters produce.
dict_heatmap = dict(explanation.local_exp[index])
segment_ids = explanation.segments
weight_lookup = np.zeros(segment_ids.max() + 1)
for seg_id, weight in dict_heatmap.items():
    weight_lookup[seg_id] = weight
heatmap = weight_lookup[segment_ids]

# Overlay the heatmap on the input image.
plt.imshow(img)
plt.imshow(heatmap, alpha=0.5, cmap='jet')
plt.title(f"LIME weights: {idx2label[index]}")
plt.colorbar()
plt.show()

tiger_catに対する全領域の可視化

In [None]:
# Index of the second label LIME explained (tiger_cat for this image).
index = explanation.top_labels[1]

# Build the heatmap by painting every superpixel with its LIME weight. A
# lookup table indexed by segment id maps all pixels in one vectorized step.
# Superpixels absent from local_exp get 0 rather than None. This assumes
# non-negative integer segment labels, as skimage segmenters produce.
dict_heatmap = dict(explanation.local_exp[index])
segment_ids = explanation.segments
weight_lookup = np.zeros(segment_ids.max() + 1)
for seg_id, weight in dict_heatmap.items():
    weight_lookup[seg_id] = weight
heatmap = weight_lookup[segment_ids]

# Overlay the heatmap on the input image.
plt.imshow(img)
plt.imshow(heatmap, alpha=0.5, cmap='jet')
plt.title(f"LIME weights: {idx2label[index]}")
plt.colorbar()
plt.show()

## Grad-CAMによる説明

In [None]:
# pytorch-grad-camの依存モジュールのOpenCVのインストール
!pip install opencv-python==4.5.1.48

In [None]:
!git clone https://github.com/jacobgil/pytorch-grad-cam.git 
!cd pytorch-grad-cam && git checkout 6c83c8f  # Check out the latest commit hash on 2021/1

In [None]:
# Make the cloned pytorch-grad-cam checkout importable; `gradcam` is imported
# from the repository root. The membership check keeps re-runs of this cell
# from appending duplicate entries to sys.path.
import sys

GRAD_CAM_DIR = "pytorch-grad-cam"
if GRAD_CAM_DIR not in sys.path:
    sys.path.append(GRAD_CAM_DIR)

### Grad-CAMによるAIモデルの説明

In [None]:
from gradcam import GradCam

# Grad-CAM on sub-module "2" of model.layer4. ResNet-50's layer4 has blocks
# 0-2, so this is its last bottleneck block.
grad_cam = GradCam(
    model=model,
    feature_module=model.layer4,
    target_layer_names=["2"],
    use_cuda=torch.cuda.is_available(),
)

# Class-activation map for the "bull_mastiff" class.
bull_mastiff_id = idx2label.index("bull_mastiff")
grayscale_cam = grad_cam(img_tensor, bull_mastiff_id)

### Grad-CAMの説明の可視化と解釈

bull_mastiffに対する可視化

In [None]:
import cv2

# Resize the bull_mastiff Grad-CAM map to the input image. cv2.resize takes
# dsize as (width, height), which is exactly PIL's img.size. Taking the size
# from `img`, rather than the `image` array left over from the LIME cells,
# keeps this cell runnable without the LIME section.
cam_resized = cv2.resize(grayscale_cam, img.size)

# Overlay the class-activation map on the input image.
plt.imshow(img)
plt.imshow(cam_resized, alpha=0.5, cmap='jet')
plt.title("Grad-CAM: bull_mastiff")
plt.colorbar()
plt.show()

tiger_catに対する可視化

In [None]:
# Class-activation map for the "tiger_cat" class.
grayscale_cam = grad_cam(img_tensor, idx2label.index("tiger_cat"))

# Resize the map to the input image. cv2.resize takes dsize as
# (width, height), which is exactly PIL's img.size. Taking the size from
# `img`, rather than the `image` array left over from the LIME cells, keeps
# this cell runnable without the LIME section.
cam_resized = cv2.resize(grayscale_cam, img.size)

# Overlay the class-activation map on the input image.
plt.imshow(img)
plt.imshow(cam_resized, alpha=0.5, cmap='jet')
plt.title("Grad-CAM: tiger_cat")
plt.colorbar()
plt.show()