In [9]:
import os.path

import numpy as np
import torch
from torch import Tensor

from create_dataloader import make_dataloaders
from medcam import medcam

from model_picker import ModelType, get_model
from monai.data.nifti_writer import nib

In [None]:
# Configuration: which trained model to explain and where its data / weights live.
MODELS_ROOT = "./models"
DATA_PATH = "./datasets/sorted_downscaled"
BATCH_SIZE = 1

model_type: ModelType = ModelType.ResNet18
# Input downscaling factor; only the scales listed below are supported.
scale: float = 0.25
# Class index passed to medcam; None defers the choice to medcam
# (presumably the predicted class — TODO confirm against medcam docs).
label: int | None = None
assert scale in [0.25, 0.5, 1.0]

# e.g. "ResNet18_025" for ModelType.ResNet18 at scale 0.25; used to locate the checkpoint.
model_string_id = f"{model_type.name}_{str(int(scale*100)).zfill(3)}"

model_path = f"{MODELS_ROOT}/{model_string_id}.pth"

# To force CPU execution, replace this with torch.device("cpu").
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
_, _, test_loader = make_dataloaders(num_workers=0, persistent_workers=False, data_path=DATA_PATH, batch_size=BATCH_SIZE, scale=scale)
num_classes = len(test_loader.dataset.get_image_classes())
# Validate the requested label against the dataset's actual class count rather than a hard-coded 10.
assert label is None or 0 <= label < num_classes, f"label must be None or in [0, {num_classes}), got {label}"

# Load the trained weights into the plain architecture *before* medcam injection, so the
# strict key match is performed against the model exactly as get_model builds it.
model = get_model(model_type)
model.to(device)
checkpoint = torch.load(model_path, map_location=device)
model.load_state_dict(checkpoint['net'], strict=True)
model = medcam.inject(model, output_dir="attention_maps", save_maps=True, return_attention=True, layer="auto", label=label)
# Inference only from here on: layers such as BatchNorm/Dropout (if present) use eval behaviour.
model.eval()

In [11]:
import plotly.graph_objects as go

def plot_volume(attention_map, isomin=-0.1, isomax=0.8, opacity=0.1, surface_count=21):
    """Render a 3-D array as an interactive plotly volume plot.

    The x/y/z coordinate grid is built from the array's own shape, so volumes of
    any size (e.g. inputs at any supported scale) line up with their values.

    Args:
        attention_map: 3-D array-like (e.g. one channel of an attention map or image).
        isomin: lowest value rendered as an iso-surface.
        isomax: highest value rendered as an iso-surface.
        opacity: per-surface opacity; needs to be small to see through all surfaces.
        surface_count: number of iso-surfaces; needs to be large for good volume rendering.

    Raises:
        ValueError: if ``attention_map`` is not three-dimensional.
    """
    values = np.asarray(attention_map)
    if values.ndim != 3:
        raise ValueError(f"plot_volume expects a 3-D array, got shape {values.shape}")

    depth, height, width = values.shape
    # One integer coordinate per voxel, in the same (C-order) layout as values.flatten().
    xs, ys, zs = np.mgrid[0:depth, 0:height, 0:width]

    fig = go.Figure(data=go.Volume(
        x=xs.flatten(),
        y=ys.flatten(),
        z=zs.flatten(),
        value=values.flatten(),
        isomin=isomin,
        isomax=isomax,
        opacity=opacity,
        surface_count=surface_count,
        colorscale='RdBu',
    ))
    fig.show()


In [12]:
# Export the first channel of every attention map produced over the test set as a NIfTI file.
output_path = f"attention_maps/{model_string_id}/layer/{label if label is not None else 'none'}"
# exist_ok avoids the check-then-create race and keeps the cell safe to re-run.
os.makedirs(output_path, exist_ok=True)

# Make this cell self-contained: run inference in eval mode regardless of prior cell state.
model.eval()

image_number = 0
for image_batch, batch_labels in test_loader:
    predictions, attention_maps = model(image_batch.to(device))
    # presumably one entry per sample in the batch (batch-first) — TODO confirm medcam output layout
    for attention_map in attention_maps:
        # First channel, moved to host memory; .numpy() yields an ndarray, not a Tensor.
        first_channel: np.ndarray = attention_map[0].cpu().numpy()
        first_channel_dot_nii = nib.Nifti1Image(first_channel, affine=np.eye(4))
        nib.save(first_channel_dot_nii, f"{output_path}/{image_number}.nii")
        image_number += 1


In [None]:
# Inspect a single test batch: save each input volume as NIfTI, then plot inputs and attention maps.
image_batch, batch_labels = next(iter(test_loader))
model.eval()

predictions, attention_maps = model(image_batch.to(device))

# The images are written next to medcam's output; make sure the directory exists.
os.makedirs("attention_maps", exist_ok=True)

for image_index, image in enumerate(image_batch):
    # First channel of the input volume as a host-side ndarray.
    first_channel: np.ndarray = image[0].cpu().numpy()
    first_channel_dot_nii = nib.Nifti1Image(first_channel, affine=np.eye(4))
    nib.save(first_channel_dot_nii, f"attention_maps/image{image_index}.nii")
    plot_volume(first_channel)

for attention_map in attention_maps:
    first_channel: np.ndarray = attention_map[0].cpu().numpy()
    plot_volume(first_channel)

print("I'm done, just so you know")