In [3]:
import glob
import os

import numpy as np
import torch
from medcam import medcam
from monai.data.nifti_writer import nib
from torch import Tensor

from create_dataloader import make_dataloaders
from model_picker import ModelType, get_model

In [4]:
# --- Experiment configuration -------------------------------------------------
MODELS_ROOT = "./models"
DATA_PATH = "./datasets/MNInSecT/"
BATCH_SIZE = 1
assert BATCH_SIZE == 1  # later cells rely on one image per batch

model_type: ModelType = ModelType.ResNet18
scale: float = 0.5
# NOTE(review): presumably the scales that have trained checkpoints — confirm.
assert scale in (0.25, 0.5, 1.0)

# Checkpoint id: "<enum member name>_<scale as zero-padded percent>", e.g. "ResNet18_050".
model_string_id = f"{model_type.name}_{int(scale * 100):03d}"
model_path = f"{MODELS_ROOT}/{model_string_id}.pth"

# --- Data and model -------------------------------------------------------------
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")

# Only the test split is used in this notebook.
_, _, test_loader = make_dataloaders(
    num_workers=0,
    persistent_workers=False,
    data_path=DATA_PATH,
    batch_size=BATCH_SIZE,
    scale=scale,
)
num_classes = test_loader.dataset.num_classes()

model = get_model(model_type)
model.eval()
model.to(device)
# The checkpoint file holds a dict; the model weights live under its "net" key.
model.load_state_dict(
    torch.load(model_path, map_location=device)["net"],
    strict=True,
)
# Wrap the model with medcam so that forward passes also return attention maps
# (layer chosen automatically, map for the best-scoring class, per medcam's docs).
model = medcam.inject(model, return_attention=True, layer="auto", label="best")

In [5]:
from torch.nn import functional as F
import plotly.graph_objects as go

def _to_numpy(values) -> np.ndarray:
    """Return ``values`` as a NumPy array, detaching and moving torch tensors to the CPU first."""
    if isinstance(values, torch.Tensor):
        return values.detach().cpu().numpy()
    return np.asarray(values)


def _grid_coordinates(shape) -> tuple:
    """Return flattened integer x, y, z voxel coordinates of a 3-D grid of ``shape``.

    The coordinates are in C (row-major) order, i.e. the order of ``array.ravel()``,
    so element ``i`` of each returned array locates element ``i`` of a flattened volume.
    """
    grid = np.mgrid[tuple(slice(0, int(n)) for n in shape)]
    return tuple(axis.ravel() for axis in grid)


def _resample_to(attention_map: Tensor, shape) -> Tensor:
    """Resample a 3-D ``attention_map`` to ``shape`` using ``F.interpolate`` (default mode: nearest).

    Raises:
        ValueError: if ``attention_map`` is not 3-D.
    """
    if attention_map.ndim != 3:
        raise ValueError(f"attention_map must be 3-D, got shape {tuple(attention_map.shape)}")
    # interpolate expects (batch, channel, *spatial): add those two axes, then strip them.
    resampled = F.interpolate(attention_map[None, None], size=tuple(int(n) for n in shape))
    return resampled[0, 0]


def plot_volume(attention_map: Tensor, image: Tensor, title: str = "") -> None:
    """Show a 3-D image with an attention map overlaid, as two semi-transparent plotly volumes.

    Args:
        attention_map: 3-D attention volume. It is resampled to the image's spatial shape,
            so it may be coarser than the image.
        image: image tensor or array; singleton dimensions (e.g. batch/channel) are squeezed
            away and what remains must be 3-D.
        title: optional figure title, handy for telling several plots apart.

    Raises:
        ValueError: if the squeezed image or the attention map is not 3-D.
    """
    image_values = image.squeeze()
    if image_values.ndim != 3:
        raise ValueError(f"image must be 3-D after squeezing, got shape {tuple(image_values.shape)}")
    attention_values = _resample_to(attention_map, image_values.shape)
    print(f"Image size: {image_values.shape} Attention map size: {attention_values.shape}")

    # The grid is derived from the image itself, so coordinates and values always line up
    # (previously it was hard-coded from the global `scale`).
    x, y, z = _grid_coordinates(image_values.shape)

    attention_volume = go.Volume(
        x=x,
        y=y,
        z=z,
        value=_to_numpy(attention_values).ravel(),
        isomin=0.0,
        isomax=1.0,
        opacity=0.05,  # needs to be small to see through all surfaces
        surface_count=21,  # needs to be a large number for good volume rendering
        colorscale='RdBu_r'
    )
    image_volume = go.Volume(
        x=x,
        y=y,
        z=z,
        value=_to_numpy(image_values).ravel(),
        isomin=0.0,
        isomax=255.0,  # NOTE(review): assumes 8-bit intensities — confirm the dataset's range
        opacity=0.1,  # needs to be small to see through all surfaces
        surface_count=21,  # needs to be a large number for good volume rendering
        colorscale='greys'
    )

    fig = go.Figure(data=(image_volume, attention_volume))
    if title:
        fig.update_layout(title=title)
    fig.show()

In [6]:
current_image_id: int = 0  # test-set index the viewer cell shows next; that cell increments it

In [7]:
layer: int = 3  # picks the ./attention_maps/<model id>/layer<N>/ subdirectory the viewer reads

In [None]:
from create_dataloader import Dataset
# Viewer: plot the saved attention maps for test image `current_image_id`, then advance the
# index, so re-running this cell steps through the test set one image at a time.
dataset: Dataset = test_loader.dataset
if not 0 <= current_image_id < len(dataset):
    raise IndexError(
        f"current_image_id={current_image_id} is outside the test set (size {len(dataset)})"
    )
image_name = dataset.get_name_of_image(current_image_id)

# Take the image at the same index as `image_name`. Pulling the first batch from a fresh
# `test_loader` iterator always yields image 0, which mismatched the maps for every later id.
# NOTE(review): assumes dataset[i][0] is the image, as the loader's (images, labels) batches
# suggest — confirm against create_dataloader.Dataset.
image = dataset[current_image_id][0]

attention_dir = f"./attention_maps/{model_string_id}/layer{layer}/{image_name}"
# Sorted so the choice below is deterministic if several files match the same tag.
attentionmap_paths = sorted(glob.glob(f"{attention_dir}/*"))


def _path_with_tag(paths, tag, directory):
    """Return the first of ``paths`` whose file name contains ``tag``.

    Raises:
        FileNotFoundError: if no file name in ``paths`` contains ``tag``.
    """
    # Match on the file name only, so the tag cannot be picked up from a parent
    # directory such as the model id or the image name.
    matches = [path for path in paths if tag in os.path.basename(path)]
    if not matches:
        raise FileNotFoundError(f"No '{tag}' attention map in {directory!r}; found {paths}")
    return matches[0]


def _load_attention_map(path):
    """Load a NIfTI attention map as a torch tensor with its axes permuted by (2, 0, 1)."""
    # NOTE(review): the permutation presumably moves the NIfTI last axis first to match the
    # image's axis order — confirm against how the maps were written.
    return torch.from_numpy(np.array(nib.load(path).dataobj).transpose(2, 0, 1))


prediction_path = _path_with_tag(attentionmap_paths, "prediction", attention_dir)
correct_path = _path_with_tag(attentionmap_paths, "correct", attention_dir)

prediction_map = _load_attention_map(prediction_path)
print(f"{image_name}: 'prediction' attention map")
plot_volume(prediction_map, image)

# Presumably a correct prediction is stored as one file carrying both tags; when the paths
# differ, the correct class has its own map, which is plotted as well.
if prediction_path != correct_path:
    correct_map = _load_attention_map(correct_path)
    print(f"{image_name}: 'correct' attention map")
    plot_volume(correct_map, image)

current_image_id += 1