## Exploratory data analysis for the PathMNIST dataset

In [1]:
import sys, os
# Put the parent of the current working directory at the front of the import
# path (once), so the `data` and `models` imports used below can be resolved.
root = os.path.abspath(os.path.join(os.getcwd(), os.pardir))
if sys.path.count(root) == 0:
    sys.path.insert(0, root)

In [2]:
import torch
import torch.nn.functional as F
import numpy as np
from torchvision import transforms

import matplotlib.pyplot as plt
import numpy as np

import random
from torchvision.transforms.functional import to_pil_image

from matplotlib.colors import ListedColormap
import matplotlib.patches as mpatches

In [3]:
# Load functions from the data folder

from data import (
    get_data, visualize_batch, visualize_one_sample_per_class, visualize_one_sample_per_class_single_row,
    get_class_distribution, plot_class_distribution, compute_mean_std)


# Build the train / validation / test loaders for PathMNIST, requesting
# 128x128 images through im_size.
# NOTE(review): the recorded run of this cell ended in a RuntimeError raised by
# the automatic download of pathmnist_128.npz; the message asks for the file to
# be placed under ./data/dataset/pathmnist. Confirm that this relative path
# resolves to the intended folder from the notebook's working directory.
train_loader, val_loader, test_loader = get_data("PathMNIST", im_size=128)

from medmnist import INFO

 86%|████████████████████████████████▊     | 3.68G/4.26G [16:25<02:34, 3.73MB/s]


RuntimeError: 
                Automatic download failed! Please download pathmnist_128.npz manually.
                1. [Optional] Check your network connection: 
                    Go to https://github.com/MedMNIST/MedMNIST/ and find the Zenodo repository
                2. Download the npz file from the Zenodo repository or its Zenodo data link: 
                    https://zenodo.org/records/10519652/files/pathmnist_128.npz?download=1
                3. [Optional] Verify the MD5: 
                    ac42d08fb904d92c244187169d1fd1d9
                4. Put the npz file under your MedMNIST root folder: 
                    ./data/dataset/pathmnist
                

In [None]:
from models import (AlexNet128, DeconvNet128)

# Initial system configuration.
# The analysis is pinned to the CPU. (The original cell first selected CUDA when
# available and then immediately overwrote that choice; only the effective
# assignment is kept.)
device = torch.device("cpu")

dataset_dir = "../data/dataset"
models_dir = "../models/trained_models/"

dataset_name = "PathMNIST"
num_classes = 9
model_file = "alexnet128_pathmnist.pth"

# Rebuild the architecture, restore the trained weights directly onto `device`,
# and switch to inference mode.
# NOTE: torch.load unpickles the checkpoint, so only load files you trust.
alexnet = AlexNet128(num_classes=num_classes)
state_dict = torch.load(os.path.join(models_dir, model_file), map_location=device)
alexnet.load_state_dict(state_dict)
alexnet.to(device)  # keeps model and inputs on the same device if `device` changes
alexnet.eval()

In [None]:
# Build the deconvolution network from the loaded AlexNet and put it in eval
# mode; it is used further down to project layer-5 feature maps back to the
# input space.
deconvnet = DeconvNet128(alexnet)
deconvnet.eval()

In [None]:
# Class distribution of the training split

# Label lookup table for PathMNIST, re-keyed with integer class ids
info = INFO["pathmnist"]
class_names = dict(
    (int(label_id), label_name) for label_id, label_name in info["label"].items()
)

# Count the samples per class, print the raw counts and plot them
dist = get_class_distribution(train_loader)
print("Frequency per class:", dict(dist))
plot_class_distribution(dist, class_names=class_names)

In [None]:
# Display one training sample for each class, using class_names for the labels
visualize_one_sample_per_class(train_loader, class_names=class_names)

In [None]:
# Same overview as a single row: one training sample per class, labelled via class_names
visualize_one_sample_per_class_single_row(train_loader, class_names=class_names)

In [None]:
# Single row with one sample from each class
# NOTE(review): this cell repeats the preceding cell verbatim; consider removing it.
visualize_one_sample_per_class_single_row(train_loader, class_names=class_names)

In [None]:
# Single row with one sample from each class
# NOTE(review): third verbatim copy of the same call; consider removing it.
visualize_one_sample_per_class_single_row(train_loader, class_names=class_names)

### Occlusion tests

#### Functions

In [None]:
def get_occlusion_result(img, label, x, y, patch_size, model=None, feature_idx=None):
    """
    Applies grey square occlusion and returns relevant values.

    Args:
        img: image tensor indexed as [batch, channel, row, col]. A batch of one
            is expected, because `.item()` is called on the arg-max of the logits.
        label: index of the true class, used to read its softmax probability.
        x, y: column and row of the top-left corner of the occluding square.
        patch_size: side length of the square in pixels. A square that extends
            past the image border is clipped by the slicing.
        model: callable returning `(logits, activations)` where
            `activations["feat5"]` holds the layer-5 feature maps. Defaults to
            the notebook-level `alexnet`.
        feature_idx: index of the layer-5 feature map whose activation is
            summed. Defaults to the notebook-level `strongest_idx`, which must
            have been computed for `img` before calling.

    Returns:
        Tuple `(occluded_img, activation, prob_true, pred_class)`:
        the occluded image without its batch axis (on the CPU), the summed
        activation of the selected feature map, the softmax probability of
        `label`, and the predicted class index.
    """
    # Fall back to the notebook globals so existing positional calls keep working.
    if model is None:
        model = alexnet
    if feature_idx is None:
        feature_idx = strongest_idx

    occluded = img.clone()                                # never modify the caller's tensor
    occluded[:, :, y:y+patch_size, x:x+patch_size] = 0.5  # apply the grey square

    with torch.no_grad():
        logits_occ, acts_occ = model(occluded)                    # forward pass of the occluded image
        feat5_occ = acts_occ["feat5"]                             # layer-5 feature maps
        prob_occ = F.softmax(logits_occ, dim=1)[0, label].item()  # probability of the true class
        pred_class = logits_occ.argmax(dim=1).item()              # predicted class
        # total activation of the selected layer-5 feature map
        activation = feat5_occ[0, feature_idx].sum().item()

    # squeeze(0) drops only the batch axis, so a singleton channel axis would survive
    return occluded.squeeze(0).cpu(), activation, prob_occ, pred_class

In [None]:
def generate_occlusion_maps(img, label, patch_size=64, stride=10):
    """
    Slides the occluding square over the image and records the model response
    at every position.

    Args:
        img: image tensor whose axes 2 and 3 are height and width.
        label: index of the true class, forwarded to get_occlusion_result.
        patch_size: side length of the occluding square.
        stride: step between consecutive patch positions, in pixels.

    Returns:
        Tuple `(activation_map, prob_map, pred_map, xs, ys)`. The three maps
        are arrays of shape (len(ys), len(xs)), indexed [row, column], holding
        the activation, true-class probability and predicted class reported by
        get_occlusion_result. `xs` and `ys` list the top-left patch coordinates.
    """
    height, width = img.shape[2:]
    # Only positions where the whole patch fits inside the image are visited
    ys = list(range(0, height - patch_size + 1, stride))
    xs = list(range(0, width - patch_size + 1, stride))

    grid_shape = (len(ys), len(xs))
    activation_map = np.zeros(grid_shape)
    prob_map = np.zeros(grid_shape)
    pred_map = np.zeros(grid_shape, dtype=int)

    # Row-major walk over the grid: row -> y position, column -> x position
    for row, col in np.ndindex(*grid_shape):
        _, act, prob, pred = get_occlusion_result(img, label, xs[col], ys[row], patch_size)
        activation_map[row, col] = act
        prob_map[row, col] = prob
        pred_map[row, col] = pred

    return activation_map, prob_map, pred_map, xs, ys

In [None]:
def plot_occlusion_entry(occluded_img, activation_map, prob_map, pred_map, pred_class, recon_np, x, y, xs, ys):
    """
    Draws the five-panel occlusion summary for one image.

    Panels: (a) the occluded input, (b) layer-5 activation per patch position,
    (c) the feature projection, (d) true-class probability per patch position,
    (e) predicted class per patch position.

    Args:
        occluded_img: image accepted by `to_pil_image`, shown in panel (a).
        activation_map, prob_map, pred_map: 2-D maps indexed [row, column] over
            the patch positions.
        pred_class: predicted class of `occluded_img`; its name is looked up in
            the notebook-level `class_names` for the title of panel (e).
        recon_np: image array for panel (c), passed to `imshow` unchanged.
        x, y: top-left corner of the patch in `occluded_img` (title only).
        xs, ys: patch positions that produced the maps; they set the axis extent.
    """
    # Label the heatmap axes with the patch coordinates; the y axis is flipped
    # so that row 0 is drawn at the top.
    extent = [xs[0], xs[-1] + 1, ys[-1] + 1, ys[0]]

    fig, axs = plt.subplots(1, 5, figsize=(20, 4))

    axs[0].imshow(to_pil_image(occluded_img))
    axs[0].set_title(f"(a) Patch at ({x},{y})")
    axs[0].axis("off")

    axs[1].imshow(activation_map, cmap="hot", extent=extent)
    axs[1].set_title("(b) Layer 5 activation map")

    axs[2].imshow(recon_np)
    axs[2].set_title("(c) Feature projection")
    axs[2].axis("off")

    im_prob = axs[3].imshow(prob_map, cmap="coolwarm", vmin=0, vmax=1, extent=extent)
    axs[3].set_title("(d) P(true class)")
    fig.colorbar(im_prob, ax=axs[3])

    pred_name = class_names.get(pred_class, str(pred_class))
    # Fixed normalisation: class k always maps to the k-th tab20 colour.
    # With the default auto-scaling the colours depended on which classes
    # happened to appear in this particular map.
    axs[4].imshow(pred_map, cmap="tab20", vmin=-0.5, vmax=19.5,
                  interpolation="nearest", extent=extent)
    axs[4].set_title(f"(e) Predicted class:\n{pred_name}")
    axs[4].axis("off")

    plt.tight_layout()
    plt.show()


#### Tests

In [None]:
# Pick a random image from the first batch yielded by a fresh pass over the
# test loader (the batch itself is only random if the loader shuffles)
batch = next(iter(test_loader))
rand_idx = random.randint(0, len(batch[0]) - 1)
img, label = batch[0][rand_idx], batch[1][rand_idx]
img = img.unsqueeze(0).to(device)  # add the batch axis expected by the model
label = label.item()

print("Random image selected.")
print("True label index:", label)
print("True label name:", class_names[label])

# Forward pass through model to get original activations
alexnet.eval()
with torch.no_grad():
    logits, acts = alexnet(img)

# Index of the layer-5 feature map with the largest summed activation.
# get_occlusion_result reads `strongest_idx` from the notebook namespace.
feat5 = acts["feat5"]
strongest_idx = feat5[0].sum(dim=(1, 2)).argmax().item()

# Unlike the cells below, this cell projects ALL layer-5 feature maps back,
# with negative activations clamped to zero.
feat_pos = torch.clamp(feat5, min=0.0)

# Get the projected feature visualization using deconvnet
with torch.no_grad():
    recon = deconvnet(feat_pos, acts, layer=5)
recon_np = recon.squeeze().cpu().permute(1, 2, 0).numpy()

# Min-max scale to [0, 1]; a constant projection becomes all zeros instead of
# NaN from a division by zero.
recon_min, recon_max = recon_np.min(), recon_np.max()
if recon_max > recon_min:
    recon_np = (recon_np - recon_min) / (recon_max - recon_min)
else:
    recon_np = np.zeros_like(recon_np)


# Generate heatmaps
activation_map, prob_map, pred_map, xs, ys = generate_occlusion_maps(img, label, patch_size=32, stride=20)


# Single occlusion example, placed at the middle grid position. Its patch is
# larger than the sliding patch above; where it extends past the image border
# the slicing inside get_occlusion_result clips it.
x, y = xs[len(xs)//2], ys[len(ys)//2]
occ_img, _, _, pred = get_occlusion_result(img, label, x, y, patch_size=110)

# Plot
plot_occlusion_entry(occ_img, activation_map, prob_map, pred_map, pred, recon_np, x, y, xs, ys)


In [None]:
# Pick a random image from the first batch yielded by a fresh pass over the
# test loader (the batch itself is only random if the loader shuffles)
batch = next(iter(test_loader))
rand_idx = random.randint(0, len(batch[0]) - 1)
img, label = batch[0][rand_idx], batch[1][rand_idx]
img = img.unsqueeze(0).to(device)  # add the batch axis expected by the model
label = label.item()

print("Random image selected.")
print("True label index:", label)
print("True label name:", class_names[label])

# Forward pass through model to get original activations
alexnet.eval()
with torch.no_grad():
    logits, acts = alexnet(img)

# Index of the layer-5 feature map with the largest summed activation.
# get_occlusion_result reads `strongest_idx` from the notebook namespace.
feat5 = acts["feat5"]
strongest_idx = feat5[0].sum(dim=(1, 2)).argmax().item()

# Keep only the strongest feature map; every other map is zeroed
one_hot = torch.zeros_like(feat5)
one_hot[0, strongest_idx] = feat5[0, strongest_idx]

# Get the projected feature visualization using deconvnet
with torch.no_grad():
    recon = deconvnet(one_hot, acts, layer=5)
recon_np = recon.squeeze().cpu().permute(1, 2, 0).numpy()

# Min-max scale to [0, 1]; a constant projection becomes all zeros instead of
# NaN from a division by zero.
recon_min, recon_max = recon_np.min(), recon_np.max()
if recon_max > recon_min:
    recon_np = (recon_np - recon_min) / (recon_max - recon_min)
else:
    recon_np = np.zeros_like(recon_np)


# Generate heatmaps
activation_map, prob_map, pred_map, xs, ys = generate_occlusion_maps(img, label, patch_size=64, stride=20)


# Single occlusion example: a large patch centred in the image. The centre is
# derived from the actual image size rather than a hard-coded 128.
patch_size = 110
img_h, img_w = img.shape[2:]
center_x = (img_w - patch_size) // 2
center_y = (img_h - patch_size) // 2
occ_img, _, _, pred = get_occlusion_result(img, label, center_x, center_y, patch_size=patch_size)


# Plot. The coordinates passed here are the ones the patch was really placed
# at, so the title of panel (a) matches the image it shows.
plot_occlusion_entry(occ_img, activation_map, prob_map, pred_map, pred, recon_np, center_x, center_y, xs, ys)


In [None]:
# Pick a random image from the first batch yielded by a fresh pass over the
# test loader (the batch itself is only random if the loader shuffles)
batch = next(iter(test_loader))
rand_idx = random.randint(0, len(batch[0]) - 1)
img, label = batch[0][rand_idx], batch[1][rand_idx]
img = img.unsqueeze(0).to(device)  # add the batch axis expected by the model
label = label.item()

print("Random image selected.")
print("True label index:", label)
print("True label name:", class_names[label])

# Forward pass through model to get original activations
alexnet.eval()
with torch.no_grad():
    logits, acts = alexnet(img)

# Index of the layer-5 feature map with the largest summed activation.
# get_occlusion_result reads `strongest_idx` from the notebook namespace.
feat5 = acts["feat5"]
strongest_idx = feat5[0].sum(dim=(1, 2)).argmax().item()

# Keep only the strongest feature map; every other map is zeroed
one_hot = torch.zeros_like(feat5)
one_hot[0, strongest_idx] = feat5[0, strongest_idx]

# Get the projected feature visualization using deconvnet
with torch.no_grad():
    recon = deconvnet(one_hot, acts, layer=5)
recon_np = recon.squeeze().cpu().permute(1, 2, 0).numpy()

# Min-max scale to [0, 1]; a constant projection becomes all zeros instead of
# NaN from a division by zero.
recon_min, recon_max = recon_np.min(), recon_np.max()
if recon_max > recon_min:
    recon_np = (recon_np - recon_min) / (recon_max - recon_min)
else:
    recon_np = np.zeros_like(recon_np)


# Generate heatmaps
activation_map, prob_map, pred_map, xs, ys = generate_occlusion_maps(img, label, patch_size=64, stride=20)


# Single occlusion example: a large patch centred in the image. The centre is
# derived from the actual image size rather than a hard-coded 128.
patch_size = 110
img_h, img_w = img.shape[2:]
center_x = (img_w - patch_size) // 2
center_y = (img_h - patch_size) // 2
occ_img, _, _, pred = get_occlusion_result(img, label, center_x, center_y, patch_size=patch_size)


# Plot. The coordinates passed here are the ones the patch was really placed
# at, so the title of panel (a) matches the image it shows.
plot_occlusion_entry(occ_img, activation_map, prob_map, pred_map, pred, recon_np, center_x, center_y, xs, ys)


In [None]:
# Pick a random image from the first batch yielded by a fresh pass over the
# test loader (the batch itself is only random if the loader shuffles)
batch = next(iter(test_loader))
rand_idx = random.randint(0, len(batch[0]) - 1)
img, label = batch[0][rand_idx], batch[1][rand_idx]
img = img.unsqueeze(0).to(device)  # add the batch axis expected by the model
label = label.item()

print("Random image selected.")
print("True label index:", label)
print("True label name:", class_names[label])

# Forward pass through model to get original activations
alexnet.eval()
with torch.no_grad():
    logits, acts = alexnet(img)

# Index of the layer-5 feature map with the largest summed activation.
# get_occlusion_result reads `strongest_idx` from the notebook namespace.
feat5 = acts["feat5"]
strongest_idx = feat5[0].sum(dim=(1, 2)).argmax().item()

# Keep only the strongest feature map; every other map is zeroed
one_hot = torch.zeros_like(feat5)
one_hot[0, strongest_idx] = feat5[0, strongest_idx]

# Get the projected feature visualization using deconvnet
with torch.no_grad():
    recon = deconvnet(one_hot, acts, layer=5)
recon_np = recon.squeeze().cpu().permute(1, 2, 0).numpy()

# Min-max scale to [0, 1]; a constant projection becomes all zeros instead of
# NaN from a division by zero.
recon_min, recon_max = recon_np.min(), recon_np.max()
if recon_max > recon_min:
    recon_np = (recon_np - recon_min) / (recon_max - recon_min)
else:
    recon_np = np.zeros_like(recon_np)


# Generate heatmaps
activation_map, prob_map, pred_map, xs, ys = generate_occlusion_maps(img, label, patch_size=64, stride=20)


# Single occlusion example at the middle grid position, using a patch larger
# than the sliding one
x, y = xs[len(xs)//2], ys[len(ys)//2]
occ_img, _, _, pred = get_occlusion_result(img, label, x, y, patch_size=80)

# Plot
plot_occlusion_entry(occ_img, activation_map, prob_map, pred_map, pred, recon_np, x, y, xs, ys)


In [None]:
# Pick a random image from the first batch yielded by a fresh pass over the
# test loader (the batch itself is only random if the loader shuffles)
batch = next(iter(test_loader))
rand_idx = random.randint(0, len(batch[0]) - 1)
img, label = batch[0][rand_idx], batch[1][rand_idx]
img = img.unsqueeze(0).to(device)  # add the batch axis expected by the model
label = label.item()

print("Random image selected.")
print("True label index:", label)
print("True label name:", class_names[label])

# Forward pass through model to get original activations
alexnet.eval()
with torch.no_grad():
    logits, acts = alexnet(img)

# Index of the layer-5 feature map with the largest summed activation.
# get_occlusion_result reads `strongest_idx` from the notebook namespace.
feat5 = acts["feat5"]
strongest_idx = feat5[0].sum(dim=(1, 2)).argmax().item()

# Keep only the strongest feature map; every other map is zeroed
one_hot = torch.zeros_like(feat5)
one_hot[0, strongest_idx] = feat5[0, strongest_idx]

# Get the projected feature visualization using deconvnet
with torch.no_grad():
    recon = deconvnet(one_hot, acts, layer=5)
recon_np = recon.squeeze().cpu().permute(1, 2, 0).numpy()

# Min-max scale to [0, 1]; a constant projection becomes all zeros instead of
# NaN from a division by zero.
recon_min, recon_max = recon_np.min(), recon_np.max()
if recon_max > recon_min:
    recon_np = (recon_np - recon_min) / (recon_max - recon_min)
else:
    recon_np = np.zeros_like(recon_np)


# Generate heatmaps
activation_map, prob_map, pred_map, xs, ys = generate_occlusion_maps(img, label, patch_size=64, stride=10)


# Single occlusion example at the middle grid position, with the same patch
# size as the sliding patch
x, y = xs[len(xs)//2], ys[len(ys)//2]
occ_img, _, _, pred = get_occlusion_result(img, label, x, y, patch_size=64)

# Plot
plot_occlusion_entry(occ_img, activation_map, prob_map, pred_map, pred, recon_np, x, y, xs, ys)


In [None]:
# Pick a random image from the first batch yielded by a fresh pass over the
# test loader (the batch itself is only random if the loader shuffles)
batch = next(iter(test_loader))
rand_idx = random.randint(0, len(batch[0]) - 1)
img, label = batch[0][rand_idx], batch[1][rand_idx]
img = img.unsqueeze(0).to(device)  # add the batch axis expected by the model
label = label.item()

print("Random image selected.")
print("True label index:", label)
print("True label name:", class_names[label])

# Forward pass through model to get original activations
alexnet.eval()
with torch.no_grad():
    logits, acts = alexnet(img)

# Index of the layer-5 feature map with the largest summed activation.
# get_occlusion_result reads `strongest_idx` from the notebook namespace.
feat5 = acts["feat5"]
strongest_idx = feat5[0].sum(dim=(1, 2)).argmax().item()

# Keep only the strongest feature map; every other map is zeroed
one_hot = torch.zeros_like(feat5)
one_hot[0, strongest_idx] = feat5[0, strongest_idx]

# Get the projected feature visualization using deconvnet
with torch.no_grad():
    recon = deconvnet(one_hot, acts, layer=5)
recon_np = recon.squeeze().cpu().permute(1, 2, 0).numpy()

# Min-max scale to [0, 1]; a constant projection becomes all zeros instead of
# NaN from a division by zero.
recon_min, recon_max = recon_np.min(), recon_np.max()
if recon_max > recon_min:
    recon_np = (recon_np - recon_min) / (recon_max - recon_min)
else:
    recon_np = np.zeros_like(recon_np)


# Generate heatmaps
activation_map, prob_map, pred_map, xs, ys = generate_occlusion_maps(img, label, patch_size=32, stride=10)


# Single occlusion example at the middle grid position, with the same patch
# size as the sliding patch
x, y = xs[len(xs)//2], ys[len(ys)//2]
occ_img, _, _, pred = get_occlusion_result(img, label, x, y, patch_size=32)

# Plot
plot_occlusion_entry(occ_img, activation_map, prob_map, pred_map, pred, recon_np, x, y, xs, ys)
