# Visualizing CNN Feature Maps and Heat Maps

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
import torchvision.datasets as datasets

import numpy as np
import matplotlib.pyplot as plt
from matplotlib.cm import get_cmap

## Load Pre-trained Model

In [None]:
# Load an ImageNet-pretrained ResNet18 for visualization.
# torchvision >= 0.13 deprecates `pretrained=True` in favour of the `weights`
# enum (IMAGENET1K_V1 is the documented equivalent of `pretrained=True`);
# fall back to the legacy flag on older releases that lack the enum.
try:
    _resnet18_weights = models.ResNet18_Weights.IMAGENET1K_V1
except AttributeError:  # torchvision < 0.13
    model = models.resnet18(pretrained=True)
else:
    model = models.resnet18(weights=_resnet18_weights)
model.eval()  # inference mode: BatchNorm uses running stats
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = model.to(device)

## Visualize Feature Maps

In [None]:
# Hook to capture intermediate features
class FeatureExtractor:
    """Capture the output of one named sub-module during forward passes.

    Registers a forward hook on the sub-module of ``model`` called
    ``layer_name``. After each forward pass through that sub-module,
    ``self.features`` holds its output, detached from the autograd graph.
    Call :meth:`remove` (or use the instance as a context manager) to
    unregister the hook.

    Args:
        model: Module whose ``named_modules()`` contains ``layer_name``.
        layer_name: Name as reported by ``model.named_modules()``
            (e.g. ``'layer2'``).

    Raises:
        KeyError: if ``layer_name`` is not a sub-module of ``model``.
    """

    def __init__(self, model, layer_name):
        # None until the hooked layer has run at least once, then a tensor.
        self.features = None
        modules = dict(model.named_modules())
        try:
            self.layer = modules[layer_name]
        except KeyError:
            raise KeyError(
                f'model has no sub-module named {layer_name!r}'
            ) from None
        self.hook = self.layer.register_forward_hook(self.save_features)

    def save_features(self, module, input, output):
        """Forward-hook callback: keep the layer output, detached."""
        self.features = output.detach()

    def remove(self):
        """Unregister the forward hook; ``self.features`` is kept."""
        self.hook.remove()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc, tb):
        # Always detach the hook, even if the body raised.
        self.remove()
        return False

# Load test image
transform = transforms.Compose([
    # Upsample the small CIFAR-10 images to a 224-pixel input for the backbone.
    transforms.Resize(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225])
])

test_dataset = datasets.CIFAR10(root='./data', train=False,
                                download=True, transform=transform)
test_image, label = test_dataset[0]

# Extract features from a middle layer. The hook is removed in `finally` so it
# does not stay attached to the shared model if the forward pass raises
# (e.g. out of memory) and keep capturing on every later forward pass.
extractor = FeatureExtractor(model, 'layer2')
try:
    with torch.no_grad():
        output = model(test_image.unsqueeze(0).to(device))
finally:
    extractor.remove()

features = extractor.features.cpu()

print(f'Feature map shape: {features.shape}')

In [None]:
# Visualize the first 16 feature maps
# features[0] selects the first (only) image in the batch; assumes `features`
# is (batch, channels, H, W) — see the shape printed by the previous cell.
feature_maps = features[0].numpy()
n_features = min(16, feature_maps.shape[0])

fig, axes = plt.subplots(4, 4, figsize=(10, 10))
axes = axes.flatten()

for i, ax in enumerate(axes):
    if i < n_features:
        ax.imshow(feature_maps[i], cmap='viridis')
        ax.set_title(f'Feature {i}')
    # Turn off ticks/frames on every panel, including any left empty when the
    # layer has fewer than 16 channels.
    ax.axis('off')

plt.tight_layout()
plt.show()

## Grad-CAM Heat Map (Gradient-weighted Class Activation Mapping)

In [None]:
class GradCAM:
    """Gradient-weighted Class Activation Mapping (Grad-CAM) for one layer.

    Hooks ``target_layer`` to record its forward output (``self.features``)
    and the gradient of the selected class score with respect to that output
    (``self.gradients``). :meth:`generate` combines them into a heat map at
    the layer's spatial resolution.

    Args:
        model: Network to explain; ``model(x)`` must return class scores
            indexed as ``output[sample, class]``.
        target_layer: Sub-module of ``model`` whose activations are explained.
    """

    def __init__(self, model, target_layer):
        self.model = model
        self.target_layer = target_layer
        self.features = None
        self.gradients = None

        # Gradients are captured with a tensor hook on the layer's output
        # (registered in the forward hook) rather than the deprecated
        # Module.register_backward_hook, which can report gradients of the
        # wrong tensor for modules made of several operations.
        self._hook = self.target_layer.register_forward_hook(
            self._forward_hook)

    def _forward_hook(self, module, input, output):
        # Store the activations and arrange to receive their gradient.
        self.save_features(module, input, output)
        # Registering a hook on a tensor that does not require grad raises,
        # e.g. when the model runs under torch.no_grad(); skip in that case.
        if output.requires_grad:
            output.register_hook(self._save_output_gradient)

    def _save_output_gradient(self, grad):
        # Tensor hook: `grad` is d(class score) / d(layer output).
        self.gradients = grad.detach()

    def save_features(self, module, input, output):
        """Forward-hook callback: store the layer output, detached."""
        self.features = output.detach()

    def save_gradients(self, module, grad_input, grad_output):
        """Module-backward-hook-style callback: store ``grad_output[0]``.

        Kept for backward compatibility; this class now captures gradients
        via a tensor hook registered during the forward pass.
        """
        self.gradients = grad_output[0].detach()

    def remove(self):
        """Unregister the hook installed on ``target_layer``."""
        self._hook.remove()

    def generate(self, x, target_class=None):
        """Compute the Grad-CAM map for the first sample of ``x``.

        Args:
            x: Input batch passed to ``model``.
            target_class: Class index to explain. Defaults to the top-scoring
                class of sample 0.

        Returns:
            numpy array with the target layer's spatial size, non-negative
            (ReLU applied), for sample 0 only.

        Raises:
            RuntimeError: if the target layer's hooks did not fire during the
                forward/backward pass (e.g. the layer is not used by ``model``).
        """
        self.model.eval()
        # Clear results of any previous call so stale tensors are never used.
        self.features = None
        self.gradients = None
        output = self.model(x)

        if target_class is None:
            # Only sample 0's map is returned, so pick sample 0's top class;
            # argmax over dim=1 would mix in other samples' classes.
            target_class = output[0].argmax()

        self.model.zero_grad()
        class_loss = output[0, target_class].sum()
        class_loss.backward()

        if self.features is None or self.gradients is None:
            raise RuntimeError(
                'target_layer hooks did not fire; is target_layer part of '
                'the forward pass of model?')

        # Channel weights: spatially averaged gradients (Grad-CAM alpha_k).
        weights = self.gradients.mean(dim=[2, 3], keepdim=True)
        cam = (self.features * weights).sum(dim=1, keepdim=True)
        cam = F.relu(cam)

        return cam[0, 0].cpu().numpy()

# Generate the Grad-CAM map for the model's top-scoring class
grad_cam = GradCAM(model, model.layer4[1].conv2)
test_tensor = test_image.unsqueeze(0).to(device)
cam = grad_cam.generate(test_tensor)

# Min-max normalize the CAM to [0, 1] for display. After Grad-CAM's ReLU the
# map can be constant (e.g. all zeros when no location contributes positively
# to the class score); dividing by its zero range would produce NaNs.
cam_range = cam.max() - cam.min()
if cam_range > 0:
    cam = (cam - cam.min()) / cam_range
else:
    cam = np.zeros_like(cam)

print(f'CAM shape: {cam.shape}')

In [None]:
# Visualize the CAM
fig, axes = plt.subplots(1, 2, figsize=(10, 4))

# Original image: CHW tensor -> HWC array, then undo the Normalize(mean, std)
# applied by `transform` so colours display correctly.
orig_img = test_image.permute(1, 2, 0).numpy()
orig_img = (orig_img * np.array([0.229, 0.224, 0.225]) +
            np.array([0.485, 0.456, 0.406]))
orig_img = np.clip(orig_img, 0, 1)

axes[0].imshow(orig_img)
axes[0].set_title('Original Image')
axes[0].axis('off')

# CAM overlay. The CAM has the target layer's much coarser spatial size, so
# stretch it over the image's full pixel extent; without `extent` imshow would
# draw it in its own cell-index coordinates, covering only a few pixels in the
# image's top-left corner.
img_h, img_w = orig_img.shape[:2]
overlay_extent = (-0.5, img_w - 0.5, img_h - 0.5, -0.5)
axes[1].imshow(orig_img)
axes[1].imshow(cam, cmap='jet', alpha=0.4, extent=overlay_extent,
               interpolation='bilinear')
axes[1].set_title('Grad-CAM Visualization')
axes[1].axis('off')

plt.tight_layout()
plt.show()