In [None]:
import sys
sys.path.append('../')
import torch
from back.core.model import load_checkpoint, load_model
from back.core.xai import XAIManager
from back.utils import load_image, denormalize, plot_predictions
import matplotlib.pyplot as plt

# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = load_checkpoint(load_model(), '../models/brain_tumor_model.pth', device)
# Force evaluation mode so layers such as dropout / batch-norm behave
# deterministically at inference time. It is not visible here whether
# load_checkpoint already does this; calling eval() again is harmless.
model.eval()
xai_manager = XAIManager(model, device)
# Output-index -> label mapping used when printing/plotting predictions.
# NOTE(review): presumably must match the label order used in training — confirm.
classes = ['no_tumor', 'glioma', 'meningioma', 'pituitary']

# Placeholder input: random noise shaped as a batch of one 3x224x224 image.
# NOTE(review): swap in a real preprocessed scan here (load_image is imported
# above but unused) — its exact signature is not visible in this cell.
img_tensor = torch.randn(1, 3, 224, 224).to(device)

# Forward pass without recording an autograd graph.
with torch.no_grad():
    output = model(img_tensor)

# The highest-scoring output index (dim 1) is taken as the predicted class.
pred_class = torch.argmax(output, dim=1).item()
print(f"Predicted class: {classes[pred_class]}")

# Presumably reverses the input normalization so the tensor is displayable.
original_img = denormalize(img_tensor)
plot_predictions(original_img, pred_class, classes)

# Grad-CAM attributions for the predicted class, then the manager's
# visualization of them alongside the image.
attributions = xai_manager.grad_cam(img_tensor, pred_class)
xai_manager.visualize_explanations(original_img, attributions, 'gradcam')
