In [None]:
import copy
import os.path

import numpy as np
import torch
from torch import Tensor

from create_dataloader import make_dataloaders
from medcam import medcam

from model_picker import ModelType, get_model
from monai.data.nifti_writer import nib
from tqdm import trange

In [None]:
# --- Paths and batching --------------------------------------------------
MODELS_ROOT = "./models"
DATA_PATH = "./datasets/sorted_downscaled"
BATCH_SIZE = 1
# Batch size is pinned to 1 for the per-sample attention visualisation.
assert BATCH_SIZE == 1

# --- Model selection -----------------------------------------------------
model_type: ModelType = ModelType.ResNet18
scale: float = 0.5  # passed to the dataloaders and encoded in the checkpoint name
assert scale in (0.25, 0.5, 1.0)

# Checkpoint file name: "<ModelType name>_<scale in percent, zero-padded to 3>.pth"
model_string_id = f"{model_type.name}_{int(scale * 100):03d}"
model_path = f"{MODELS_ROOT}/{model_string_id}.pth"

# Prefer the first GPU when one is available, otherwise run on the CPU.
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")

# --- Data ----------------------------------------------------------------
# Only the third loader returned by make_dataloaders (the test loader) is used.
_, _, test_loader = make_dataloaders(
    data_path=DATA_PATH,
    batch_size=BATCH_SIZE,
    scale=scale,
    num_workers=0,
    persistent_workers=False,
)
num_classes = len(test_loader.dataset.get_image_classes())

# --- Model ---------------------------------------------------------------
# Build the network in eval mode on the target device, then restore the
# weights stored under the checkpoint's 'net' key.
model = get_model(model_type)
model.eval()
model.to(device)
model.load_state_dict(
    torch.load(model_path, map_location=device)['net'],
    strict=True,
)
# Wrap the model with medcam so that calling it returns attention maps
# alongside the predictions (return_attention=True).
model = medcam.inject(model, return_attention=True, layer="auto", label="best")

In [None]:
import plotly.graph_objects as go

def plot_volume(attention_map, image, image_range=(0.0, 255.0), attention_range=(0.0, 1.0)):
    """Show an attention map overlaid on its input volume as an interactive plotly figure.

    Both arrays are rendered on one shared voxel grid, taken from their shape.
    The image is drawn as a faint grey volume and the attention map as a
    red/blue volume.

    Args:
        attention_map: 3-D array-like (numpy array or CPU tensor) of attention
            values. Assumed to lie in ``attention_range``. TODO: confirm for the
            medcam backend in use.
        image: 3-D array-like with the same shape as ``attention_map``. Assumed
            to lie in ``image_range``. TODO: confirm against the dataloader's
            preprocessing.
        image_range: ``(isomin, isomax)`` used for the image volume; defaults to
            ``(0, 255)``.
        attention_range: ``(isomin, isomax)`` used for the attention volume;
            defaults to ``(0, 1)``.

    Raises:
        ValueError: if either input is not 3-D or the two shapes differ. A single
            coordinate grid cannot match both volumes in that case.
    """
    # Accept tensors as well as arrays so .ravel()/.shape behave uniformly.
    image_values = np.asarray(image)
    attention_values = np.asarray(attention_map)
    print(f"Image size: {image_values.shape} Attention map size: {attention_values.shape}")
    if image_values.ndim != 3 or image_values.shape != attention_values.shape:
        raise ValueError(
            "plot_volume expects two 3-D volumes of identical shape, got "
            f"image {image_values.shape} and attention map {attention_values.shape}"
        )

    # Derive voxel coordinates from the data instead of a hard-coded
    # (256, 128, 128) * scale grid. This keeps the grid aligned at any resolution.
    # The rows are in C order, matching .ravel() of the value arrays.
    x, y, z = np.indices(image_values.shape).reshape(3, -1)

    image_volume = go.Volume(
        x=x,
        y=y,
        z=z,
        value=image_values.ravel(),
        isomin=image_range[0],
        isomax=image_range[1],
        opacity=0.1,  # needs to be small to see through all surfaces
        surface_count=21,  # needs to be a large number for good volume rendering
        colorscale='greys'
    )
    attention_volume = go.Volume(
        x=x,
        y=y,
        z=z,
        value=attention_values.ravel(),
        isomin=attention_range[0],
        isomax=attention_range[1],
        opacity=0.05,  # needs to be small to see through all surfaces
        surface_count=21,  # needs to be a large number for good volume rendering
        colorscale='RdBu_r'
    )

    # Image first so the attention volume is layered on top of it.
    fig = go.Figure(data=(image_volume, attention_volume))
    fig.show()

In [None]:
# Take the first test batch and run it through the medcam-wrapped model.
image_batch, batch_labels = next(iter(test_loader))
model.eval()

# With return_attention=True the wrapped model returns (predictions, attention_maps).
# NOTE(review): deliberately not wrapped in torch.no_grad(). medcam's default
# backend is gradient based (Grad-CAM); confirm the backend before adding one.
predictions, attention_maps = model(image_batch.to(device))

# zip keeps every attention map paired with the image it was computed from,
# so the plots stay correct even if BATCH_SIZE is raised above 1.
for image, attention_map in zip(image_batch, attention_maps):
    # Only the first channel of each volume is visualised.
    image_volume: np.ndarray = image[0].cpu().numpy()
    attention_volume: np.ndarray = attention_map[0].detach().cpu().numpy()
    plot_volume(attention_volume, image_volume)

print("I'm done, just so you know")