In [1]:
import monai
import torch
import medmnist
from acsconv.converters import ACSConverter
from medcam import medcam
from medmnist import INFO, Evaluator
from plot_image import plot_image

from experiments.MedMNIST3D.models import ResNet18
from experiments.MedMNIST3D.utils import Transform3D, model_to_syncbn

  from .autonotebook import tqdm as notebook_tqdm


The ``converters`` are currently experimental. It may not support operations including (but not limited to) Functions in ``torch.nn.functional`` that involved data dimension


In [2]:
# --- Evaluation configuration ---------------------------------------------------
MODEL_PATH = "./output/organmnist3d/resnet18/best_model.pth"
BATCH_SIZE = 1

# Use the first CUDA device when available, otherwise the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# --- Dataset --------------------------------------------------------------------
dataset_name: str = "organmnist3d"
download: bool = True

# MedMNIST's INFO registry provides the label mapping (its length is the class count)
# and the name of the medmnist dataset class to instantiate.
info = INFO[dataset_name]
num_classes = len(info["label"])
DataClass = getattr(medmnist, info["python_class"])

# Both transforms are constructed identically; this cell only uses eval_transform.
train_transform = Transform3D()
eval_transform = Transform3D()

test_dataset = DataClass(
    split="test",
    transform=eval_transform,
    download=download,
    as_rgb=False,
)
test_loader = monai.data.DataLoader(
    dataset=test_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
)

# --- Model ----------------------------------------------------------------------
# Build the single-channel ResNet18 and convert it with ACSConverter + model_to_syncbn.
# Then inject medcam so a forward pass also returns the "layer1" attention maps
# (return_attention=True). Maps are also saved under "attention_maps" (save_maps=True).
backbone = ResNet18(num_classes=num_classes, in_channels=1)
model = medcam.inject(
    model_to_syncbn(ACSConverter(backbone)),
    output_dir="attention_maps",
    save_maps=True,
    return_attention=True,
    layer="layer1",
)

# The checkpoint is a dict whose "net" entry is the model state dict; strict=True raises
# on any missing/unexpected key instead of silently ignoring it.
# NOTE(review): the model is never moved to `device`, so it stays on the CPU;
# map_location only controls where the checkpoint tensors are materialised.
# NOTE(review): torch.load unpickles the file - only load checkpoints from trusted sources.
model.load_state_dict(torch.load(MODEL_PATH, map_location=device)["net"], strict=True)

Using downloaded and verified file: C:\Users\Nylan\.medmnist\organmnist3d.npz


<All keys matched successfully>

In [6]:
import os

import numpy as np
from monai.data.nifti_writer import nib

# Same directory medcam was injected with (output_dir="attention_maps").
# Create it up front so nib.save cannot fail on a missing folder.
OUTPUT_DIR = "attention_maps"
os.makedirs(OUTPUT_DIR, exist_ok=True)


def save_first_channel_as_nifti(volume, path):
    """Write channel 0 of ``volume`` to ``path`` as a NIfTI file.

    ``volume[0]`` must be a 3-D tensor. Its axes are reordered with
    ``transpose(1, 2, 0)`` (first axis moved last) before saving. The image is
    written with an identity affine, so no voxel spacing or orientation is recorded.
    NOTE(review): assumes ``volume`` is channel-first, e.g. (C, D, H, W) - confirm.

    Args:
        volume: tensor whose first index selects the channel.
        path: destination file path (``.nii``).
    """
    # detach()/cpu() keep .numpy() working even if the tensor tracks gradients or is on a GPU.
    first_channel = volume[0].detach().cpu().numpy().transpose(1, 2, 0)
    nib.save(nib.Nifti1Image(first_channel, affine=np.eye(4)), path)


image_batch, batch_labels = next(iter(test_loader))
model.eval()

# The medcam-injected model (return_attention=True) returns (predictions, attention maps).
predictions, attention_maps = model(image_batch)

# `idx` rather than `id`, to avoid shadowing the builtin.
for idx, image in enumerate(image_batch):
    save_first_channel_as_nifti(image, f"{OUTPUT_DIR}/image{idx}.nii")

for idx, attention_map in enumerate(attention_maps):
    save_first_channel_as_nifti(attention_map, f"{OUTPUT_DIR}/attention_map_{idx}.nii")
