In [1]:
import torch
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
from torchvision import transforms
from Utils import SimpleCNN
from Utils import get_device, train, test, get_predictions
import torchvision
import torchvision.transforms as transforms
from torch.utils.data import DataLoader

  from .autonotebook import tqdm as notebook_tqdm


In [3]:
# Preprocessing: convert to a [0, 1] tensor, then apply the standard MNIST
# normalization (mean 0.1307, std 0.3081).
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.1307,), std=(0.3081,)),
    ]
)

# Download (if needed) both MNIST splits, sharing the same preprocessing.
train_dataset, test_dataset = (
    torchvision.datasets.MNIST(root='./data', train=split_is_train, download=True, transform=transform)
    for split_is_train in (True, False)
)

In [4]:
# Shuffled mini-batches for training; large, fixed-order batches for evaluation.
train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=1000, shuffle=False)


In [5]:
def add_frame(image_tensor, label, frame_size=1, mean=0.1307, std=0.3081, frame_value=1.0):
    """Paint a solid border inside a normalized image and return it re-normalized.

    Args:
        image_tensor: image tensor normalized as ``(x - mean) / std``. It is squeezed to 2-D,
            so presumably shape (1, H, W) as produced by the MNIST transform — TODO confirm.
        label: returned unchanged.
        frame_size: border thickness in pixels. ``0`` (or negative) leaves the image unframed.
        mean, std: normalization constants used to undo and then re-apply normalization.
        frame_value: pixel value, in un-normalized space, painted on the border.

    Returns:
        ``(framed_tensor, label)``, where ``framed_tensor`` has a leading dimension of 1 added.
    """
    # Denormalize. The arithmetic produces a new array, so the caller's tensor is never modified.
    image_np = image_tensor.squeeze().numpy() * std + mean
    d = frame_size
    # Guard d > 0: with d == 0 the slice [-0:] equals [0:] and would paint the whole image.
    if d > 0:
        image_np[:d, :] = frame_value   # top border
        image_np[-d:, :] = frame_value  # bottom border
        image_np[:, :d] = frame_value   # left border
        image_np[:, -d:] = frame_value  # right border
    # Re-normalize so framed samples match the rest of the dataset.
    image_np = (image_np - mean) / std
    return torch.tensor(image_np).unsqueeze(0), label

# Frame every (image, label) pair of both splits; the results are materialized as Python lists.
framed_train_dataset = [add_frame(*sample) for sample in train_dataset]
framed_test_dataset = [add_frame(*sample) for sample in test_dataset]

# Undo the MNIST normalization so images can be shown in pixel space.
def denormalize(tensor):
    """Map a normalized image back to pixel scale: ``0.1307 + 0.3081 * tensor``.

    Works on anything that supports scalar ``*`` and ``+`` (torch tensors, numpy arrays, floats).
    """
    scale, offset = 0.3081, 0.1307
    return offset + scale * tensor

In [6]:
# Fixed-order evaluation loaders over the clean and the framed test sets.
original_test_loader, framed_test_loader = (
    DataLoader(split, batch_size=1000, shuffle=False)
    for split in (test_dataset, framed_test_dataset)
)


In [7]:

device = get_device()

# `denormalize` is already defined in the framing cell (In [5]); the identical redefinition
# that used to live here has been removed to avoid duplicate definitions.

# Load the trained models onto `device`. map_location lets a checkpoint saved on a GPU be
# loaded on a CPU-only machine (and vice versa).
# NOTE(review): torch.load unpickles the file — only load trusted checkpoints
# (consider weights_only=True if the installed torch version supports it).
original_model = SimpleCNN().to(device)
original_model.load_state_dict(torch.load('mnist_cnn.pth', map_location=device))
original_model.eval()

mixed_model = SimpleCNN().to(device)
mixed_model.load_state_dict(torch.load('mixed_mnist_cnn.pth', map_location=device))
mixed_model.eval();


In [134]:
import torch
import numpy as np
import matplotlib.pyplot as plt
from pytorch_grad_cam import GradCAM
from pytorch_grad_cam.utils.image import show_cam_on_image, preprocess_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget

def visualize_gradcam_grid(model, images, labels, predictions, title, bias):
    """Show each image beside its Grad-CAM heatmap, one row per image.

    Args:
        model: classifier whose ``conv2`` layer is used as the Grad-CAM target layer.
        images: sequence of normalized image tensors (undone with ``denormalize``); each is
            squeezed to 2-D before use.
        labels: sequence of one-element tensors holding the true class of each image.
        predictions: sequence of one-element tensors holding the predicted class; Grad-CAM
            explains this class.
        title: figure super-title.
        bias: passed to ``preprocess_image`` as the normalization mean, so the Grad-CAM input is
            ``(denormalize(image) - bias) / 0.3081``. It is chosen to amplify the heatmap and is
            not the 0.1307 mean the images were normalized with.

    Side effects:
        Puts ``model`` in train mode while each CAM is computed and always leaves it in eval
        mode afterwards. Displays the figure with ``plt.show()``.
    """
    num_images = len(images)
    if num_images == 0:
        # plt.subplots(0, 2) would raise; there is nothing to draw.
        print(f"{title}: no images to visualize.")
        return

    # squeeze=False keeps `axs` 2-D even for a single row, so axs[i, j] indexing always works.
    fig, axs = plt.subplots(num_images, 2, figsize=(20, 10 * num_images), squeeze=False)
    fig.suptitle(title, fontsize=16)

    cam = GradCAM(model=model, target_layers=[model.conv2])
    # Send Grad-CAM inputs to the same device as the model's weights.
    model_device = next(model.parameters()).device

    for i, (image, label, pred) in enumerate(zip(images, labels, predictions)):
        pixel_image = denormalize(image.squeeze().cpu().numpy())

        input_tensor = preprocess_image(pixel_image, mean=[bias], std=[0.3081]).to(model_device)
        input_tensor.requires_grad = True
        targets = [ClassifierOutputTarget(pred.item())]

        # NOTE(review): train() enables dropout and BatchNorm batch statistics (including
        # running-stat updates) if SimpleCNN has such layers, which would make the heatmap
        # stochastic — confirm this is intended.
        model.train()
        try:
            grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]
        finally:
            # Restore eval mode even if Grad-CAM raises, so later inference is unaffected.
            model.eval()

        # Negate so the digit renders dark on a light background with the gray colormap.
        axs[i, 0].imshow(-pixel_image, cmap='gray')
        axs[i, 0].set_title(f"True: {label.item()}, Pred: {pred.item()}", fontsize=28)
        axs[i, 0].axis('off')

        axs[i, 1].imshow(grayscale_cam, cmap='jet')
        axs[i, 1].set_title("Grad-CAM", fontsize=28)
        axs[i, 1].axis('off')

    plt.tight_layout()
    plt.show()


In [None]:

def analyze_specific_digit(model, test_loader, digit, num_samples, title, bias=0.5):
    """Collect the first `num_samples` test images labelled `digit` and plot their Grad-CAM maps.

    Args:
        model: classifier to evaluate and explain; it is switched to eval mode.
        test_loader: iterable of ``(images, labels)`` batches where ``labels`` is a tensor.
        digit: class label to select.
        num_samples: maximum number of matching images to visualize.
        title: figure title forwarded to ``visualize_gradcam_grid``.
        bias: Grad-CAM preprocessing mean forwarded to ``visualize_gradcam_grid``.

    If no matching images are found (or ``num_samples <= 0``), a message is printed and nothing
    is plotted. If fewer than ``num_samples`` are found, the ones found are plotted.
    """
    model.eval()
    # Use the model's own device so inputs always match the weights.
    model_device = next(model.parameters()).device
    digit_images, digit_labels, digit_preds = [], [], []

    if num_samples > 0:
        with torch.no_grad():
            for images, labels in test_loader:
                # Only run the model on images of the requested digit.
                mask = labels == digit
                if not mask.any():
                    continue
                images, labels = images[mask].to(model_device), labels[mask].to(model_device)
                predictions = model(images).argmax(dim=1)

                needed = num_samples - len(digit_images)
                digit_images.extend(images[:needed])
                digit_labels.extend(labels[:needed])
                digit_preds.extend(predictions[:needed])
                if len(digit_images) >= num_samples:
                    break

    if not digit_images:
        print(f"No test images of digit {digit} found; nothing to visualize.")
        return

    visualize_gradcam_grid(model, digit_images, digit_labels, digit_preds, title, bias)

# Analyze one specific digit (here digit 7) in the framed test set using the mixed model
digit_to_analyze = 7
num_samples = 2

print(f"Analyzing mixed model on digit {digit_to_analyze} (framed):")
analyze_specific_digit(mixed_model, framed_test_loader, digit_to_analyze, num_samples, f"Mixed Model - Digit {digit_to_analyze} (Framed)", bias=-1)

# For comparison, analyze the original model on the same digit (unframed).
# This previously passed `mixed_model`, contradicting the message and the figure title.
print(f"Analyzing original model on digit {digit_to_analyze} (unframed):")
analyze_specific_digit(original_model, original_test_loader, digit_to_analyze, num_samples, f"Original Model - Digit {digit_to_analyze} (Unframed)", bias=-0.1)

In [None]:
# Analyze all digits (0-9) for the mixed model on unframed test images
num_samples = 2
for digit_to_analyze in range(10):
    analyze_specific_digit(mixed_model, original_test_loader, digit_to_analyze, num_samples, f"Mixed Model - Digit {digit_to_analyze} (Unframed)", bias=-0.1)

# Then digit 9 from the framed test set, for the mixed model
digit_to_analyze = 9
analyze_specific_digit(mixed_model, framed_test_loader, digit_to_analyze, num_samples, f"Mixed Model - Digit {digit_to_analyze} (Framed)", bias=-1)
