In [4]:
import monai
import numpy as np
import torch
import medmnist
from acsconv.converters import ACSConverter

from create_dataloader import make_dataloaders
from medcam import medcam
from medmnist import INFO, Evaluator
from plot_image import plot_image

from experiments.MedMNIST3D.models import ResNet18, ResNet50
from experiments.MedMNIST3D.utils import Transform3D, model_to_syncbn

In [16]:
# Inference configuration: checkpoint to evaluate, dataset root, and batch size.
MODEL_PATH = "./output/230106_125242/best_model.pth"
DATA_PATH = "./datasets/sorted_downscaled"
BATCH_SIZE = 1

# Use the first GPU if one is visible, otherwise the CPU. Note that `device` is
# only used as the checkpoint's map_location below; the model itself is not moved,
# so inference in later cells runs wherever the model parameters live (CPU).
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Only the second loader returned by make_dataloaders is used here; the first is discarded.
_, test_loader = make_dataloaders(
    num_workers=0,
    persistent_workers=False,
    data_path=DATA_PATH,
    batch_size=BATCH_SIZE,
)
num_classes = len(test_loader.dataset.get_image_classes())

# Build the network in the same order the checkpoint was trained with:
# ResNet18 -> ACSConverter -> model_to_syncbn, then wrap it with medcam so the
# forward pass also returns attention maps (written to "attention_maps").
model = ResNet18(num_classes=num_classes, in_channels=1)
model = ACSConverter(model)
model = model_to_syncbn(model)
model = medcam.inject(
    model,
    output_dir="attention_maps",
    save_maps=True,
    return_attention=True,
    layer="auto",
)

# NOTE(review): torch.load unpickles the file; only load checkpoints from trusted sources.
# The weights are stored under the 'net' key of the checkpoint dict.
model.load_state_dict(
    torch.load(MODEL_PATH, map_location=device)['net'],
    strict=True,
)

<All keys matched successfully>

In [23]:
import plotly.graph_objects as go

def plot_volume(attention_map, isomin=-0.1, isomax=0.8, opacity=0.1, surface_count=21, colorscale='RdBu'):
    """Render a 3-D scalar volume as an interactive plotly ``go.Volume`` figure.

    The sampling grid is derived from the volume's own shape, so inputs of any
    size line up with their coordinates. Previously the grid was hard-coded to
    28x28x28, which misaligned every other size.

    Args:
        attention_map: 3-D array-like (numpy array or CPU tensor). Leading
            singleton axes are dropped, so e.g. ``(1, D, H, W)`` is accepted.
        isomin: lower iso-surface bound forwarded to ``go.Volume``.
        isomax: upper iso-surface bound forwarded to ``go.Volume``.
        opacity: surface opacity; it needs to be small to see through all surfaces.
        surface_count: number of iso-surfaces; it needs to be large for good
            volume rendering.
        colorscale: plotly colorscale name.

    Raises:
        ValueError: if the input is not 3-D after dropping leading singleton axes.
    """
    raw = np.asarray(attention_map)
    volume = raw
    # Strip leading batch/channel axes of size 1, e.g. (1, D, H, W) -> (D, H, W).
    while volume.ndim > 3 and volume.shape[0] == 1:
        volume = volume[0]
    if volume.ndim != 3:
        raise ValueError(f"plot_volume expects a 3-D volume, got shape {raw.shape}")

    # ij-indexed coordinates flatten in the same C order as `volume.flatten()`,
    # so each (x, y, z) triple lines up with its value.
    X, Y, Z = np.indices(volume.shape)

    fig = go.Figure(data=go.Volume(
        x=X.flatten(),
        y=Y.flatten(),
        z=Z.flatten(),
        value=volume.flatten(),
        isomin=isomin,
        isomax=isomax,
        opacity=opacity,
        surface_count=surface_count,
        colorscale=colorscale,
    ))
    fig.show()


In [25]:
from monai.data.nifti_writer import nib

import os

# Same directory the model-loading cell passes to medcam.inject as output_dir.
EXPORT_DIR = "attention_maps"
os.makedirs(EXPORT_DIR, exist_ok=True)


def _first_channel_volume(tensor):
    """Return channel 0 of `tensor` as a numpy array with axes reordered by transpose(1, 2, 0).

    `tensor[0]` must be 3-D for this transpose to succeed.
    Assumes the loader yields depth-first volumes, so this gives
    (D, H, W) -> (H, W, D) for NIfTI. TODO confirm against the dataset.
    """
    # detach/cpu make this safe for tensors that require grad or live on a GPU.
    return tensor[0].detach().cpu().numpy().transpose(1, 2, 0)


def _export_volume(volume, path):
    """Write `volume` to `path` as a NIfTI file (identity affine), then plot it."""
    nib.save(nib.Nifti1Image(volume, affine=np.eye(4)), path)
    plot_volume(volume)


image_batch, batch_labels = next(iter(test_loader))
model.eval()

# The medcam-injected model returns (predictions, attention maps). This is
# deliberately not wrapped in torch.no_grad(): presumably medcam's
# gradient-based backends need autograd. Verify before changing.
predictions, attention_maps = model(image_batch)

for index, image in enumerate(image_batch):
    _export_volume(_first_channel_volume(image), f"{EXPORT_DIR}/image{index}.nii")

for index, attention_map in enumerate(attention_maps):
    _export_volume(_first_channel_volume(attention_map), f"{EXPORT_DIR}/attention_map_{index}.nii")

print("I'm done, just so you know")

I'm done, just so you know
