<a href="https://colab.research.google.com/github/RyosukeHanaoka/TechTeacher_New/blob/main/vit_eval_with_gradcam.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
import os
import torch
import timm
import numpy as np
import cv2
import matplotlib.pyplot as plt
from PIL import Image
from torchvision import transforms
from captum.attr import LayerGradCam
from captum.attr import visualization as viz


In [None]:
class ImagePreprocessor:
    """Load an image file and turn it into a normalized tensor.

    The pipeline is ``Resize(size)`` -> ``ToTensor()`` ->
    ``Normalize(mean, std)``, built once in the constructor and reused for
    every image.
    """

    def __init__(self, size=(224, 224), mean=(0.5, 0.5, 0.5), std=(0.5, 0.5, 0.5)):
        """Build the transform pipeline.

        Args:
            size: Target size handed to ``transforms.Resize``.
            mean: Per-channel mean handed to ``transforms.Normalize``.
            std: Per-channel standard deviation handed to
                ``transforms.Normalize``.
        """
        self.size = size
        # Store private copies. The defaults are immutable tuples and each
        # instance owns its own list, so mutating ``preprocessor.mean`` can no
        # longer leak into the defaults or into other instances.
        self.mean = list(mean)
        self.std = list(std)
        self.transform = transforms.Compose([
            transforms.Resize(self.size),
            transforms.ToTensor(),
            transforms.Normalize(mean=self.mean, std=self.std)
        ])

    def process_image(self, image_path, flip_left_hand=False):
        """Read ``image_path`` and return the transformed image tensor.

        Args:
            image_path: Path of the image file to open with PIL.
            flip_left_hand: When true, mirror the image horizontally before
                any other processing.

        Returns:
            The output of ``self.transform`` applied to the RGB image.
        """
        # The context manager closes the underlying file as soon as the pixel
        # data has been copied out by ``transpose`` / ``convert``.
        with Image.open(image_path) as source:
            image = source
            if flip_left_hand:
                image = image.transpose(Image.FLIP_LEFT_RIGHT)
            image = image.convert('RGB')
        return self.transform(image)

class RheumatoidArthritisModel:
    """Two-class timm model restored from a checkpoint for inference."""

    def __init__(self, checkpoint_path, model_name='vit_base_patch16_224_in21k'):
        """Create the network, load its weights and switch to eval mode.

        Args:
            checkpoint_path: Path of the file passed to ``torch.load``.
            model_name: Architecture name passed to ``timm.create_model``.
        """
        self.checkpoint_path = checkpoint_path
        self.model = timm.create_model(model_name, pretrained=False, num_classes=2)
        self.load_checkpoint()
        self.model.eval()

    def load_checkpoint(self):
        """Load the state dict stored at ``self.checkpoint_path`` into the model."""
        # NOTE: torch.load unpickles the file; only load checkpoints from
        # trusted sources.
        cpu = torch.device('cpu')
        state_dict = torch.load(self.checkpoint_path, map_location=cpu)
        self.model.load_state_dict(state_dict)

    def predict(self, image_tensor):
        """Run the model on one image and return its raw outputs.

        Args:
            image_tensor: A single, unbatched input; a leading batch axis of
                size 1 is added before the forward pass.

        Returns:
            The model output for the one-element batch, computed without
            gradient tracking.
        """
        with torch.no_grad():
            batch = image_tensor.unsqueeze(0)
            return self.model(batch)

def reshape_transform(tensor, height=14, width=14):
    """Rearrange a token sequence into a channels-first 2-D feature map.

    The first entry along axis 1 is dropped, the remaining ``height * width``
    entries are laid out on a ``height`` x ``width`` grid, and the last axis
    is moved in front of the grid axes.

    Args:
        tensor: 3-D tensor of shape ``(batch, 1 + height * width, channels)``.
        height: Number of grid rows.
        width: Number of grid columns.

    Returns:
        A tensor of shape ``(batch, channels, height, width)``.
    """
    batch, channels = tensor.size(0), tensor.size(2)
    grid = tensor[:, 1:, :].reshape(batch, height, width, channels)
    # (B, H, W, C) -> (B, C, H, W); same result as the two chained transposes.
    return grid.permute(0, 3, 1, 2)

def overlay_cam_on_image(img, cam, alpha=0.4, use_rgb=True):
    """Blend a class-activation map onto an image as a JET heat map.

    Args:
        img: Image array with 8-bit values in 0-255, shape ``(H, W, 3)``.
        cam: Activation map of any size; it is resized to ``img`` and
            min-max normalized before colouring.
        alpha: Weight of the heat map added on top of the image.
        use_rgb: When true, the heat map is converted from OpenCV's BGR
            channel order to RGB so that it matches an RGB ``img``. Pass
            ``False`` when ``img`` is in BGR order.

    Returns:
        A ``uint8`` array with the same height and width as ``img``.
    """
    cam = np.asarray(cam, dtype=np.float32)
    cam = cv2.resize(cam, (img.shape[1], img.shape[0]))

    # Min-max normalize; a constant map would otherwise divide by zero and
    # fill the result with NaN.
    low = cam.min()
    span = float(cam.max() - low)
    cam = (cam - low) / span if span > 0 else np.zeros_like(cam)

    heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
    if use_rgb:
        # applyColorMap produces BGR; without this red and blue are swapped
        # relative to an RGB image.
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)

    # Bring both layers to [0, 1] before mixing. Blending the raw 0-255 heat
    # map with a [0, 1] image made the final uint8 cast wrap around.
    heatmap = np.float32(heatmap) / 255
    base = np.float32(img) / 255
    blended = alpha * heatmap + base

    # Scale down only when the sum exceeds 1 so the cast cannot overflow.
    peak = max(float(blended.max()), 1.0)
    return np.uint8(255 * (blended / peak))

def _load_display_image(image_path, size=(224, 224)):
    """Read ``image_path`` with OpenCV and return it as a resized RGB array."""
    bgr = cv2.imread(image_path, 1)
    if bgr is None:
        # cv2.imread reports failure by returning None instead of raising.
        raise FileNotFoundError(f"Could not read image: {image_path}")
    return cv2.resize(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), size)


def _vit_grad_cam(network, layer, input_tensor, target):
    """Compute a 2-D Grad-CAM map for one image on a ViT token layer.

    The layer output is a token sequence, so activations and gradients are
    first rearranged with ``reshape_transform`` into ``(1, C, H, W)``. The
    gradients are then averaged over the spatial grid to weight each channel.
    No ReLU is applied, matching Captum's ``LayerGradCam`` default
    (``relu_attributions=False``) that the previous code relied on.

    Args:
        network: The ``torch.nn.Module`` to explain.
        layer: Sub-module of ``network`` whose output is used.
        input_tensor: A single, unbatched input tensor.
        target: Index of the output logit to explain.

    Returns:
        A 2-D ``numpy`` array holding the activation map on the token grid.
    """
    captured = {}

    def _keep_output(_module, _inputs, output):
        captured["activations"] = output

    handle = layer.register_forward_hook(_keep_output)
    try:
        # Gradients are required here even if the caller is inside no_grad.
        with torch.enable_grad():
            logits = network(input_tensor.unsqueeze(0))
            activations = captured["activations"]
            gradients = torch.autograd.grad(logits[0, target], activations)[0]
    finally:
        handle.remove()

    spatial_acts = reshape_transform(activations.detach())
    spatial_grads = reshape_transform(gradients)
    channel_weights = spatial_grads.mean(dim=(2, 3), keepdim=True)
    cam = (channel_weights * spatial_acts).sum(dim=1)
    return cam.squeeze(0).cpu().numpy()


def main():
    """Classify one right-hand and one left-hand image and show Grad-CAM overlays.

    Both images are preprocessed (the left hand mirrored), run through the
    checkpointed model, and the raw outputs are printed. If either hand is
    predicted as class 1, a Grad-CAM overlay is rendered for each hand.
    """
    output_directory_righthand = "/content/drive/MyDrive/converted_righthand"
    output_directory_lefthand = "/content/drive/MyDrive/converted_lefthand"

    preprocessor = ImagePreprocessor()

    right_hand_image_path = os.path.join(output_directory_righthand, "sample.jpg")
    left_hand_image_path = os.path.join(output_directory_lefthand, "sample.jpg")

    right_hand_image = preprocessor.process_image(right_hand_image_path)
    left_hand_image = preprocessor.process_image(left_hand_image_path, flip_left_hand=True)

    checkpoint_path = "/content/drive/MyDrive/OptPhotoFiles/model.pth"
    model = RheumatoidArthritisModel(checkpoint_path)

    right_hand_prediction = model.predict(right_hand_image)
    left_hand_prediction = model.predict(left_hand_image)

    print("Right Hand Prediction:", right_hand_prediction)
    print("Left Hand Prediction:", left_hand_prediction)

    right_class = torch.argmax(right_hand_prediction).item()
    left_class = torch.argmax(left_hand_prediction).item()
    if right_class != 1 and left_class != 1:
        return

    print("Rheumatoid Arthritis Detected")

    # Layer whose token activations feed the Grad-CAM computation.
    target_layer = model.model.blocks[-1].norm1

    rgb_img_right = _load_display_image(right_hand_image_path)
    rgb_img_left = _load_display_image(left_hand_image_path)

    # The tensors computed above are reused; preprocessing is deterministic,
    # so re-reading the files would give the same inputs.
    grayscale_cam_right = _vit_grad_cam(
        model.model, target_layer, right_hand_image, right_class)
    print("Grayscale CAM Right Shape:", grayscale_cam_right.shape)

    cam_image_right = overlay_cam_on_image(rgb_img_right, grayscale_cam_right)

    grayscale_cam_left = _vit_grad_cam(
        model.model, target_layer, left_hand_image, left_class)
    # The model saw the left hand mirrored; mirror the map back so it lines
    # up with the unflipped photo it is drawn on.
    grayscale_cam_left = np.ascontiguousarray(np.fliplr(grayscale_cam_left))
    print("Grayscale CAM Left Shape:", grayscale_cam_left.shape)

    cam_image_left = overlay_cam_on_image(rgb_img_left, grayscale_cam_left)

    # Show each overlay on its own for debugging.
    plt.imshow(cam_image_right)
    plt.show()
    plt.imshow(cam_image_left)
    plt.show()

    plt.figure(figsize=(10, 5))
    plt.subplot(1, 2, 1)
    plt.title("Right Hand")
    plt.imshow(cam_image_right)
    plt.subplot(1, 2, 2)
    plt.title("Left Hand")
    plt.imshow(cam_image_left)
    plt.show()

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()