In [None]:
import copy
import os.path

import numpy as np
import torch
from torch import Tensor

from create_dataloader import make_dataloaders
from medcam import medcam

from model_picker import ModelType, get_model
from monai.data.nifti_writer import nib
from tqdm import trange

In [None]:
# ---- Run configuration ----
MODELS_ROOT = "./models"
DATA_PATH = "./datasets/sorted_downscaled"

# This notebook only supports batches that hold exactly one sample.
BATCH_SIZE = 1
assert BATCH_SIZE == 1

model_type: ModelType = ModelType.ResNet18
scale: float = 0.25
assert scale in (0.25, 0.5, 1.0)

# Checkpoint id: "<model name>_<scale in percent, zero-padded to three digits>".
model_string_id = f"{model_type.name}_{str(int(scale*100)).zfill(3)}"
model_path = f"{MODELS_ROOT}/{model_string_id}.pth"

# Use the first CUDA device when one is available, otherwise fall back to the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# make_dataloaders returns three loaders; only the third one is used in this notebook.
_, _, test_loader = make_dataloaders(
	data_path=DATA_PATH,
	batch_size=BATCH_SIZE,
	scale=scale,
	num_workers=0,
	persistent_workers=False,
)
num_classes = len(test_loader.dataset.get_image_classes())

# Build the network, put it in eval mode on the chosen device, then load the
# weights stored under the 'net' key of the checkpoint file.
model = get_model(model_type)
model.eval()
model.to(device)
model.load_state_dict(
	torch.load(model_path, map_location=device)['net'],
	strict=True,
)

In [None]:
from create_dataloader import Dataset

def save_attention_map(attention_map: Tensor, path: str):
	"""Write the first channel of an attention map to disk as a NIfTI image.

	The tensor is detached and moved to the CPU inside this helper, so callers may
	pass tensors that live on a GPU or that are still attached to the autograd graph.

	Args:
		attention_map: Tensor indexed as ``[channel, a, b, c]``; only ``attention_map[0]``
			is written, so that slice must be three-dimensional.
		path: Destination handed to ``nib.save``.
			NOTE(review): the callers in this notebook pass paths without a file
			extension -- confirm the file name nibabel ends up writing.
	"""
	# detach()/cpu() are no-ops for tensors that are already detached and on the CPU.
	first_channel = attention_map[0].detach().cpu().numpy()
	# Rotate the axes (a, b, c) -> (b, c, a) before building the image.
	first_channel = first_channel.transpose(1, 2, 0)
	# Identity affine: no spatial metadata is attached to the saved volume.
	image_nifti = nib.Nifti1Image(first_channel, affine=np.eye(4))
	nib.save(image_nifti, path)

dataset: Dataset = test_loader.dataset

# models[0] yields the attention map for the predicted class (label="best");
# models[k + 1] yields the attention map for class index k.
# The model to call is the one *returned* by medcam.inject, so the return value is
# stored for the "best" model exactly as it is for the per-class models below.
models = [medcam.inject(copy.copy(model), return_attention=True, layer="auto", label="best")]
for label in range(dataset.num_classes()):
	medcam_model = copy.copy(model)
	medcam_model = medcam.inject(medcam_model, return_attention=True, layer="auto", label=label)
	models.append(medcam_model)


image_output_root = f"attention_maps/{model_string_id}/layer"
# Every batch is indexed with [0] below, so it must contain exactly one image.
assert BATCH_SIZE == 1
for image_id, (image_batch, batch_labels) in enumerate(test_loader):
	# NOTE(review): image_id is the position within the loader; this assumes the loader
	# walks the dataset in index order (no shuffling) -- confirm in make_dataloaders.
	image_name = dataset.get_name_of_image(image_id)
	image_dir = f"{image_output_root}/{image_name}"
	# exist_ok avoids the separate existence check and makes re-runs safe.
	os.makedirs(image_dir, exist_ok=True)

	image_batch = image_batch.to(device)

	# Attention map for whichever class the network predicts for this image.
	prediction, attention_map_predicted_label = models[0](image_batch)
	prediction_label = prediction[0].argmax(dim=0).item()
	save_attention_map(attention_map_predicted_label[0].detach().cpu(), f"{image_dir}/{dataset.label_to_name(prediction_label)}")

	# The target label is taken as the argmax over the first sample's label vector.
	correct_label = batch_labels[0].argmax(dim=0).item()
	if correct_label == prediction_label:
		# Prediction was right: the map saved above already covers the correct class.
		continue
	# Misclassified: additionally save the map for the class that should have been predicted.
	_, attention_map_correct_label = models[correct_label + 1](image_batch)
	save_attention_map(attention_map_correct_label[0].detach().cpu(), f"{image_dir}/{dataset.label_to_name(correct_label)}")


In [11]:
import plotly.graph_objects as go

def plot_volume(attention_map):
	"""Show a 3-D scalar volume as an interactive plotly ``Volume`` figure.

	The sampling grid is derived from the shape of ``attention_map``, so the x/y/z
	coordinate arrays always have one entry per voxel regardless of the volume size.

	Args:
		attention_map: Three-dimensional array-like of scalar values.

	Raises:
		ValueError: If ``attention_map`` does not have exactly three dimensions.
	"""
	values = np.asarray(attention_map)
	if values.ndim != 3:
		raise ValueError(f"plot_volume expects a 3-D volume, got shape {values.shape}")

	# One integer grid point per voxel; mgrid and flatten() both use C order, so the
	# flattened coordinates line up with the flattened values.
	size_x, size_y, size_z = values.shape
	X, Y, Z = np.mgrid[0:size_x, 0:size_y, 0:size_z]

	fig = go.Figure(data=go.Volume(
		x=X.flatten(),
		y=Y.flatten(),
		z=Z.flatten(),
		value=values.flatten(),
		isomin=-0.1,
		isomax=0.8,
		opacity=0.1,  # needs to be small to see through all surfaces
		surface_count=21,  # needs to be a large number for good volume rendering
		colorscale='RdBu'
	))
	fig.show()


In [None]:


image_batch, batch_labels = next(iter(test_loader))
model.eval()

# `model` itself is not MedCam-injected, so its forward pass returns no attention maps.
# Inject a model here so the call below returns (predictions, attention_maps).
cam_model = medcam.inject(model, return_attention=True, layer="auto", label="best")

predictions, attention_maps = cam_model(image_batch.to(device))

# nib.save does not create missing directories.
os.makedirs("attention_maps", exist_ok=True)

# Save and plot the first channel of every input image in the batch.
for image_index, image in enumerate(image_batch):
    first_channel: np.ndarray = image[0].cpu().numpy()
    first_channel_dot_nii = nib.Nifti1Image(first_channel, affine=np.eye(4))
    nib.save(first_channel_dot_nii, f"attention_maps/image{image_index}" + ".nii")
    plot_volume(first_channel)

# Plot the first channel of every returned attention map.
for attention_map in attention_maps:
    # detach() first: numpy() refuses tensors that are still attached to the autograd graph.
    first_channel: np.ndarray = attention_map[0].detach().cpu().numpy()
    plot_volume(first_channel)

print("I'm done, just so you know")