In [None]:
# Grad-CAM

import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import cv2

from torchvision.models import (
    resnet34 as resnet34_fn,
    densenet121 as densenet121_fn,
    efficientnet_b0 as efficientnet_b0_fn,
    ResNet34_Weights,
    DenseNet121_Weights,
    EfficientNet_B0_Weights
)

# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")

def load_models():
    """Build the three single-channel classifiers and load their checkpoints.

    Each torchvision backbone (ResNet-34, DenseNet-121, EfficientNet-B0) has
    its first convolution replaced by a 1-input-channel convolution and its
    classification head replaced by a single-output ``nn.Linear``. The
    matching ``best_model_*.pth`` state dict is then loaded from the current
    working directory.

    Returns:
        tuple: ``(resnet34, densenet121, efficientnet_b0)``, each moved to
        ``DEVICE`` and switched to eval mode.

    Raises:
        FileNotFoundError: if a checkpoint file is missing.
        RuntimeError: if a checkpoint's keys or shapes do not match the
            modified architecture (``load_state_dict`` is strict).
    """

    def _restore(model, checkpoint_path):
        # Load the checkpoint into `model`, move it to DEVICE, set eval mode.
        # map_location lets a checkpoint saved on a GPU load on a CPU-only host.
        state_dict = torch.load(checkpoint_path, map_location=DEVICE)
        model.load_state_dict(state_dict)
        return model.to(DEVICE).eval()

    # weights=None: the strict load_state_dict in _restore overwrites every
    # parameter and buffer, so fetching ImageNet weights first would only add
    # a download and a network dependency without affecting the result.
    m1 = resnet34_fn(weights=None)
    m1.conv1 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    m1.fc = nn.Linear(m1.fc.in_features, 1)
    m1 = _restore(m1, "best_model_resnet34.pth")

    m2 = densenet121_fn(weights=None)
    m2.features.conv0 = nn.Conv2d(1, 64, kernel_size=7, stride=2, padding=3, bias=False)
    m2.classifier = nn.Linear(m2.classifier.in_features, 1)
    m2 = _restore(m2, "best_model_densenet121.pth")

    m3 = efficientnet_b0_fn(weights=None)
    m3.features[0][0] = nn.Conv2d(1, 32, kernel_size=3, stride=2, padding=1, bias=False)
    m3.classifier[1] = nn.Linear(m3.classifier[1].in_features, 1)
    m3 = _restore(m3, "best_model_efficientnet_b0.pth")

    return m1, m2, m3

# Build all three classifiers once; the Grad-CAM helpers below read these globals.
resnet34_model, densenet_model, effnet_model = load_models()

def apply_gradcam(model, image_tensor, target_layer):
    """Compute a Grad-CAM heat map for a single image.

    Args:
        model: network whose forward pass yields a single-element output
            for one image (the models in this notebook emit one logit).
        image_tensor: un-batched input; a batch dimension is prepended here.
            The callers in this notebook pass a (1, 224, 224) tensor,
            assuming the loaded slice is 2-D.
        target_layer: sub-module of ``model`` whose output feature map is
            explained. Its output must be a 4-D (N, C, H, W) tensor.

    Returns:
        numpy.ndarray: 2-D map scaled to [0, 1] and resized to the spatial
        size (last two dimensions) of ``image_tensor``.

    Raises:
        RuntimeError: if ``target_layer`` is never executed during the
            forward pass, or if the model output has more than one element.
    """
    model.eval()
    batch = image_tensor.unsqueeze(0).to(DEVICE).requires_grad_()
    activations = []

    def capture(module, inputs, output):
        activations.append(output)

    # Only a forward hook is needed: the gradient w.r.t. the captured
    # activation is requested directly from autograd below. This avoids the
    # deprecated Module.register_backward_hook.
    handle = target_layer.register_forward_hook(capture)
    try:
        output = model(batch)
    finally:
        # Always detach the hook; the models are shared notebook globals and a
        # leaked hook would keep firing on every later forward pass.
        handle.remove()

    if not activations:
        raise RuntimeError("target_layer was not executed during the forward pass")

    acts = activations[0]
    # Gradient of the output w.r.t. the feature map only; parameter .grad
    # buffers are left untouched.
    grads = torch.autograd.grad(output, acts)[0]

    # Grad-CAM: channel weights are the spatially averaged gradients.
    channel_weights = grads.mean(dim=(2, 3), keepdim=True)
    # Index the batch element instead of squeeze() so that a feature map with a
    # spatial dimension of size 1 keeps its 2-D shape.
    cam = (channel_weights * acts).sum(dim=1)[0].detach().cpu().numpy()

    cam = np.maximum(cam, 0)
    cam = cam / (cam.max() + 1e-8)

    # Resize to the input's own spatial size; cv2 expects (width, height).
    height, width = image_tensor.shape[-2:]
    cam = cv2.resize(cam, (int(width), int(height)))
    return cam

def show_all_gradcams(npy_path):
    """Display a slice next to its Grad-CAM overlays from the three models.

    The array stored at ``npy_path`` is min-max scaled, resized to 224x224
    with bilinear interpolation, and passed through ``apply_gradcam`` for
    ResNet34, DenseNet121 and EfficientNet-B0. A 1x4 figure is shown: the
    grayscale slice followed by the three heat-map overlays.
    """
    raw = np.load(npy_path).astype(np.float32)
    lo, hi = raw.min(), raw.max()
    scaled = (raw - lo) / (hi - lo + 1e-8)

    # (H, W) -> (1, 1, H, W) for interpolation, then back to (1, 224, 224).
    img_tensor = F.interpolate(
        torch.tensor(scaled).unsqueeze(0).unsqueeze(0),
        size=(224, 224),
        mode='bilinear',
        align_corners=False,
    ).squeeze(0)

    # One (model, target layer) pair per architecture, in display order.
    targets = (
        (resnet34_model, resnet34_model.layer4),
        (densenet_model, densenet_model.features.denseblock4),
        (effnet_model, effnet_model.features[-1]),
    )
    cams = [apply_gradcam(net, img_tensor, layer) for net, layer in targets]

    # Grayscale slice replicated to three channels so it can be blended.
    gray = img_tensor.squeeze().cpu().numpy()
    original = np.repeat(gray[:, :, None], 3, axis=2)

    def overlay(cam):
        # JET colour map comes back as BGR uint8; convert to float RGB in [0, 1].
        bgr = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
        rgb = (np.float32(bgr) / 255)[..., ::-1]
        return np.clip(0.4 * rgb + original, 0, 1)

    panels = [original] + [overlay(cam) for cam in cams]
    titles = ["Original", "ResNet34", "DenseNet121", "EffNet-B0"]

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    for ax, panel, title in zip(axes, panels, titles):
        ax.imshow(panel)
        ax.set_title(title)
        ax.axis('off')
    plt.tight_layout()
    plt.show()

# Put the six slices to visualise here, then run the cell.
paths = [
    "/data1/lidc-idri/slices/LIDC-IDRI-0194/slice_078_5.npy",
    "/data1/lidc-idri/slices/LIDC-IDRI-0332/slice_066_5.npy",
    "/data1/lidc-idri/slices/LIDC-IDRI-0081/slice_083_5.npy",
    "/data1/lidc-idri/slices/LIDC-IDRI-0118/slice_028_1.npy",
    "/data1/lidc-idri/slices/LIDC-IDRI-0426/slice_047_1.npy",
    "/data1/lidc-idri/slices/LIDC-IDRI-0379/slice_063_2.npy"
]

# Print each path, then display its four-panel Grad-CAM figure.
for path in paths:
    print(f"📌 {path}")
    show_all_gradcams(path)

In [None]:
import os
import numpy as np
import torch
import torch.nn.functional as F
import matplotlib.pyplot as plt
import cv2

# Create the output directory for the saved figures.
# (The output directory was changed to live under the working directory.)
save_dir = "./gradcam_outputs"  # saved relative to the current working directory
os.makedirs(save_dir, exist_ok=True)

# Grad-CAM visualisation + save helper
def save_all_gradcams(npy_path, save_path):
    """Render a slice with its three Grad-CAM overlays and save the figure.

    The array stored at ``npy_path`` is min-max scaled, resized to 224x224
    with bilinear interpolation, and passed through ``apply_gradcam`` for
    ResNet34, DenseNet121 and EfficientNet-B0. The resulting 1x4 figure
    (grayscale slice plus three overlays) is written to ``save_path`` and
    then closed; nothing is displayed.

    Args:
        npy_path: path of the ``.npy`` slice to load.
        save_path: destination passed to ``Figure.savefig``. When it is a
            string or path-like, a missing parent directory is created.

    Returns:
        None
    """
    img = np.load(npy_path).astype(np.float32)
    img = (img - img.min()) / (img.max() - img.min() + 1e-8)
    img_tensor = torch.tensor(img).unsqueeze(0)
    img_tensor = F.interpolate(
        img_tensor.unsqueeze(0), size=(224, 224), mode='bilinear', align_corners=False
    ).squeeze(0)

    cam_r = apply_gradcam(resnet34_model, img_tensor, resnet34_model.layer4)
    cam_d = apply_gradcam(densenet_model, img_tensor, densenet_model.features.denseblock4)
    cam_e = apply_gradcam(effnet_model, img_tensor, effnet_model.features[-1])

    # Grayscale slice replicated to three channels so it can be blended.
    original = img_tensor.squeeze().cpu().numpy()
    original = np.repeat(original[:, :, None], 3, axis=2)

    def overlay(cam):
        # JET colour map comes back as BGR uint8; convert to float RGB in [0, 1].
        heatmap = cv2.applyColorMap(np.uint8(255 * cam), cv2.COLORMAP_JET)
        heatmap = np.float32(heatmap) / 255
        heatmap = heatmap[..., ::-1]
        return np.clip(0.4 * heatmap + original, 0, 1)

    panels = [original, overlay(cam_r), overlay(cam_d), overlay(cam_e)]
    titles = ["Original", "ResNet34", "DenseNet121", "EffNet-B0"]

    # Create the destination folder up front so savefig cannot fail on a
    # missing directory (file-like targets are left alone).
    if isinstance(save_path, (str, os.PathLike)):
        parent = os.path.dirname(os.fspath(save_path))
        if parent:
            os.makedirs(parent, exist_ok=True)

    fig, axes = plt.subplots(1, 4, figsize=(20, 5))
    try:
        # `panel` (not `img`) so the loaded image above is not shadowed.
        for ax, panel, title in zip(axes, panels, titles):
            ax.imshow(panel)
            ax.set_title(title)
            ax.axis('off')
        fig.tight_layout()
        fig.savefig(save_path)
    finally:
        # Close even when plotting or saving fails; otherwise the figure stays
        # registered with pyplot and accumulates across loop iterations.
        plt.close(fig)

# Paths of the slices to render and save
gradcam_paths = [
    "/data1/lidc-idri/slices/LIDC-IDRI-0194/slice_078_5.npy",
    "/data1/lidc-idri/slices/LIDC-IDRI-0332/slice_066_5.npy",
    "/data1/lidc-idri/slices/LIDC-IDRI-0081/slice_083_5.npy",
    "/data1/lidc-idri/slices/LIDC-IDRI-0118/slice_028_1.npy",
    "/data1/lidc-idri/slices/LIDC-IDRI-0426/slice_047_1.npy",
    "/data1/lidc-idri/slices/LIDC-IDRI-0379/slice_063_2.npy"
]

# Save loop: the label comes purely from list position (first three entries
# are "malignant", the rest "benign"), so the order of gradcam_paths matters.
for idx, path in enumerate(gradcam_paths):
    label = "malignant" if idx < 3 else "benign"
    # Output names are gradcam_<label>_1.png .. gradcam_<label>_3.png.
    save_path = os.path.join(save_dir, f"gradcam_{label}_{idx%3+1}.png")
    save_all_gradcams(path, save_path)

# Confirm the save (the messages are Korean for "Saved!" and "Saved files:").
print("✅ 저장 완료!")
print("\n📂 저장된 파일:")
# Lists every entry currently in save_dir, not only the files written by this run.
for f in os.listdir(save_dir):
    print("•", f)