# In[ ]:
import sys
sys.path.append('../')
import torch
from back.core.model import load_checkpoint, load_model
from back.core.xai import XAIManager
from back.core.dataset import BrainTumorDataset
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
from captum.attr import visualization as viz

# Demo: classify one brain-MRI sample and visualise a Grad-CAM heat map for
# the predicted class.

# Use CUDA when available, otherwise fall back to CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = load_checkpoint(load_model(), '../models/brain_tumor_model.pth', device)
# Switch to inference mode so dropout / batch-norm layers (if the architecture
# has any) do not use training behaviour during prediction or explanation.
# Harmless if load_checkpoint already does this.
model.eval()
xai_manager = XAIManager(model, device)
# Only a handful of samples are needed for a quick visual check.
dataset = BrainTumorDataset('../data/raw', transforms.ToTensor(), num_samples=10)
# NOTE(review): assumed to match the label index order used during training — confirm.
classes = ['no_tumor', 'glioma', 'meningioma', 'pituitary']

img, label = dataset[0]
# Prepend a batch dimension and move the sample to the model's device.
# Assumes the dataset applies ToTensor, giving a (C, H, W) tensor.
img = img.unsqueeze(0).to(device)

with torch.no_grad():
    output = model(img)
    pred = output.argmax(dim=1).item()

# int() accepts both a plain int label and a 0-d tensor label.
print(f"True class: {classes[int(label)]}, Predicted: {classes[pred]}")

# Explain the class the model actually predicted.
attributions = xai_manager.grad_cam(img, pred)

# captum's visualize_image_attr expects channel-last (H, W, C) numpy arrays.
# Index the batch dimension explicitly rather than calling squeeze(), so a
# single-channel image keeps its channel axis.
image_hwc = img[0].permute(1, 2, 0).detach().cpu().numpy()

attr = attributions[0].detach().cpu()
# NOTE(review): assumes grad_cam returns channel-first maps, i.e. (N, C, h, w)
# or (N, h, w) as captum's LayerGradCam does — confirm against XAIManager.
if attr.dim() == 2:
    # A bare (h, w) map: add a singleton channel axis.
    attr = attr.unsqueeze(0)
attr_hwc = attr.permute(1, 2, 0).numpy()

fig, ax = plt.subplots()
viz.visualize_image_attr(
    attr_hwc,
    image_hwc,
    method="heat_map",
    show_colorbar=True,
    plt_fig_axis=(fig, ax),
)
plt.show()
