In [None]:
class GuidedBackprop:
    """
    Guided Backpropagation for a model whose first convolution is
    ``inc.double_conv[0]`` (optionally wrapped, e.g. by ``nn.DataParallel``,
    which exposes the real network as ``.module``).

    On construction the model is switched to eval mode and backward hooks are
    registered on every ``nn.ReLU`` in the model (only positive gradients are
    passed back through them) and on the first convolution (to capture the
    gradient w.r.t. its input). The hooks only act while
    :meth:`generate_gradients` is running, so other backward passes through
    the same model (e.g. Grad-CAM) see ordinary gradients. Call
    :meth:`remove_hooks` to detach them permanently.
    """

    def __init__(self, model):
        self.model = model
        self.gradients = None
        # Kept for backward compatibility; this class never populates it.
        self.forward_relu_outputs = []
        # Handles of every hook registered on the model, for remove_hooks().
        self._hook_handles = []
        # Hooks are no-ops unless this is True (set only inside generate_gradients).
        self._active = False
        self.model.eval()
        self.update_relus()
        self.hook_layers()

    def hook_layers(self):
        """Hook the first conv layer so it records the gradient w.r.t. its input."""
        def hook_function(module, grad_in, grad_out):
            if self._active:
                # grad_in[0] is the gradient w.r.t. the conv's input tensor.
                self.gradients = grad_in[0]

        # Unwrap DataParallel-style wrappers when present; accept bare models too.
        net = getattr(self.model, "module", self.model)
        first_layer = net.inc.double_conv[0]
        # NOTE(review): the legacy (non-"full") backward hook is kept on purpose;
        # register_full_backward_hook can error when a module's output is later
        # modified in place (e.g. by nn.ReLU(inplace=True)) — verify before migrating.
        self._hook_handles.append(first_layer.register_backward_hook(hook_function))

    def update_relus(self):
        """Make every nn.ReLU in the model pass back only positive gradients."""
        def relu_backward_hook_function(module, grad_in, grad_out):
            if not self._active:
                return None  # leave ordinary backward passes untouched
            # grad_in[0] already carries ReLU's own mask (zero where the forward
            # input was <= 0); clamping also drops negative upstream gradients,
            # which is the "guided" rule.
            return (torch.clamp(grad_in[0], min=0.0),)

        # modules() also yields nested sub-modules; ReLUs inside blocks such as
        # inc.double_conv are not direct children of the model.
        for module in self.model.modules():
            if isinstance(module, nn.ReLU):
                self._hook_handles.append(
                    module.register_backward_hook(relu_backward_hook_function)
                )

    def remove_hooks(self):
        """Detach every hook this instance registered on the model."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    def generate_gradients(self, input_image, target_class):
        """
        Return the guided-backprop gradient of ``target_class`` w.r.t. the input
        of the first conv layer.

        Args:
            input_image: batched input tensor; it is not modified.
            target_class: index into dimension 1 of the model output; only
                batch item 0 is set in the one-hot seed gradient.

        Returns:
            The gradient tensor captured at the first conv layer's input.

        Raises:
            RuntimeError: if the first-layer hook did not record a gradient.
        """
        # The first-layer hook only receives a gradient for its input when that
        # input requires grad; use a detached leaf so the caller's tensor is untouched.
        input_image = input_image.detach().requires_grad_(True)
        self.gradients = None
        self._active = True
        try:
            model_output = self.model(input_image)
            # Some models return a tuple; the first element is the prediction.
            if isinstance(model_output, tuple):
                model_output = model_output[0]

            self.model.zero_grad()
            one_hot_output = torch.zeros_like(model_output)
            one_hot_output[0][target_class] = 1
            model_output.backward(gradient=one_hot_output)
        finally:
            self._active = False

        if self.gradients is None:
            raise RuntimeError(
                "Guided Backpropagation hook on the first conv layer did not fire"
            )
        return self.gradients

def guided_grad_cam(guided_bp, grad_cam, input_image, target_class):
    """
    Combine Guided Backpropagation and Grad-CAM into a Guided Grad-CAM map.

    Args:
        guided_bp: object with ``generate_gradients(input_image, target_class)``
            returning a batch-first gradient tensor.
        grad_cam: object with ``generate_cam(input_image, target_class)``
            returning a 2-D heatmap array.
        input_image: batched image tensor; ``shape[2]`` is used as the height
            and ``shape[3]`` as the width.
        target_class: class/channel index to explain.

    Returns:
        Element-wise product of the resized Grad-CAM map and the guided
        gradients of batch item 0 (channel-averaged when they are 3-D).

    Raises:
        RuntimeError: if Guided Backpropagation produced no gradients.
        ValueError: if the heatmap and gradient map shapes differ.
    """
    guided_gradients = guided_bp.generate_gradients(input_image, target_class)
    if guided_gradients is None:
        raise RuntimeError("Guided Backpropagation returned no gradients")
    cam = grad_cam.generate_cam(input_image, target_class)

    # Resize the heatmap to the input's spatial size; cv2 expects (width, height).
    height, width = input_image.shape[2], input_image.shape[3]
    cam_resized = cv2.resize(cam, (width, height))

    # Take batch item 0 and average over channels so the map is 2-D like the CAM.
    gradient_map = guided_gradients.detach().cpu().numpy()[0]
    if gradient_map.ndim == 3:
        gradient_map = np.mean(gradient_map, axis=0)

    # Fail loudly instead of letting NumPy broadcast mismatched shapes.
    if cam_resized.shape != gradient_map.shape:
        raise ValueError(
            f"Grad-CAM map shape {cam_resized.shape} does not match "
            f"guided gradient shape {gradient_map.shape}"
        )

    return np.multiply(cam_resized, gradient_map)

def generate_guided_grad_cam(model, input_image, target_class, target_layer=None):
    """
    Compute, plot and return a Guided Grad-CAM map for ``target_class``.

    Args:
        model: network to explain.
        input_image: batched input tensor fed to the model.
        target_class: class/channel index to explain.
        target_layer: module whose activations Grad-CAM uses. When omitted,
            the module-level global ``target_layer`` is used (the original
            behaviour; presumably defined in another notebook cell).

    Returns:
        The Guided Grad-CAM array that was plotted.

    Raises:
        NameError: if ``target_layer`` is omitted and no such global exists.
    """
    if target_layer is None:
        try:
            target_layer = globals()["target_layer"]
        except KeyError:
            raise NameError(
                "target_layer was not passed and no global 'target_layer' is defined"
            ) from None

    guided_bp = GuidedBackprop(model)
    grad_cam = GradCam(model, target_layer)

    guided_grad_cam_output = guided_grad_cam(guided_bp, grad_cam, input_image, target_class)

    # Visualize the combined map.
    plt.imshow(guided_grad_cam_output, cmap='jet')
    plt.colorbar()
    plt.show()

    return guided_grad_cam_output

In [None]:
class GradCam:
    """
    Grad-CAM heatmaps from the activations and gradients of one layer.

    A forward hook stores ``target_layer``'s output and a backward hook stores
    the gradient w.r.t. that output. :meth:`generate_cam` weights each
    activation channel by its spatially averaged gradient. Call
    :meth:`remove_hooks` when done so the hooks do not stay on the model.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.gradients = None
        self.activations = None
        # Handles of the registered hooks, for remove_hooks().
        self._hook_handles = []
        self.register_hooks()

    def register_hooks(self):
        """Attach hooks that record the target layer's output and its gradient."""
        def forward_hook(module, input, output):
            self.activations = output.detach()

        def backward_hook(module, grad_in, grad_out):
            # grad_out[0] is the gradient w.r.t. the layer's output.
            self.gradients = grad_out[0].detach()

        self._hook_handles.append(self.target_layer.register_forward_hook(forward_hook))
        self._hook_handles.append(self.target_layer.register_backward_hook(backward_hook))

    def remove_hooks(self):
        """Detach the hooks this instance registered on ``target_layer``."""
        for handle in self._hook_handles:
            handle.remove()
        self._hook_handles.clear()

    @staticmethod
    def _normalize(cam):
        """Shift ``cam`` to start at 0 and scale its peak to 1; a constant map becomes all zeros."""
        cam = cam - np.min(cam)
        peak = np.max(cam)
        # Guard against 0/0 (NaNs) when the map is constant, e.g. all zero after ReLU.
        return cam / peak if peak > 0 else cam

    def generate_cam(self, input_image, target_class):
        """
        Return a Grad-CAM heatmap for ``target_class`` resized to the input.

        Args:
            input_image: batched tensor; its last two dims set the resize target.
            target_class: channel of the 4-D model output to explain (set to 1
                for every batch item and spatial position in the seed gradient).

        Returns:
            2-D array for batch item 0, scaled to [0, 1] (all zeros if constant).

        Raises:
            RuntimeError: if the target layer's hooks did not fire.
        """
        # Drop values from earlier passes so a missed hook cannot reuse them.
        self.gradients = None
        self.activations = None

        model_output = self.model(input_image)
        if isinstance(model_output, tuple):
            model_output = model_output[0]

        one_hot_output = torch.zeros_like(model_output)
        one_hot_output[:, target_class, :, :] = 1

        self.model.zero_grad()
        model_output.backward(gradient=one_hot_output)

        if self.gradients is None or self.activations is None:
            raise RuntimeError("target_layer was not reached during the forward/backward pass")

        # Batch item 0; indexing below treats these as (channels, h, w).
        layer_grads = self.gradients.cpu().numpy()[0]
        activations = self.activations.cpu().numpy()[0]

        # Channel weights = spatial mean of the gradient; CAM = weighted channel sum.
        weights = np.mean(layer_grads, axis=(1, 2))
        cam = np.tensordot(weights, activations, axes=1).astype(np.float32)

        cam = np.maximum(cam, 0)
        # shape[2:][::-1] gives (width, height), the order cv2.resize expects.
        cam = cv2.resize(cam, input_image.shape[2:][::-1])
        return self._normalize(cam)