In [None]:
import sys
sys.path.append('../')
import torch
from src.model import load_checkpoint
from src.xai import XAIManager
from src.dataset import BrainTumorDataset
from torchvision import transforms
import numpy as np
import matplotlib.pyplot as plt
from captum.attr import visualization as viz

# Load the model, the XAI helper and a small dataset sample.
# NOTE(review): `load_model` was called here without ever being imported or
# defined, which raised NameError. It is assumed to be exported by src.model
# next to load_checkpoint — confirm against the project.
from src.model import load_model

device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = load_checkpoint(load_model(), '../models/brain_tumor_model.pth', device)
# Switch dropout / batch-norm layers to inference behaviour before predicting.
# eval() is idempotent on an nn.Module (assumed to be what load_checkpoint
# returns), so this is safe even if load_checkpoint already calls it.
model.eval()
xai_manager = XAIManager(model, device)
# NOTE(review): num_samples presumably limits the dataset to 10 images — confirm.
dataset = BrainTumorDataset('../data/raw', transforms.ToTensor(), num_samples=10)
# Class names indexed by label id; assumed to match the training label order.
classes = ['no_tumor', 'glioma', 'meningioma', 'pituitary']

# Evaluate a single sample: the first item of the dataset.
img, label = dataset[0]
# Prepend a batch axis (same as unsqueeze(0)) and move to the model's device.
img = img[None].to(device)

# Forward pass without autograd bookkeeping; only the top class is needed.
with torch.no_grad():
    output = model(img)
pred = output.argmax(dim=1).item()

print(f"Истинный класс: {classes[label]}, Предсказанный: {classes[pred]}")

# Grad-CAM: attribute the predicted class and render the attribution as a heat map.
attributions = xai_manager.grad_cam(img, pred)

# Captum's visualize_image_attr expects NumPy arrays laid out as (H, W, C).
# Both tensors may live on the GPU and may be attached to the autograd graph,
# so detach and move them to the CPU first (np.array on a CUDA tensor raises).
# NOTE(review): the attribution is assumed to be channel-first like the input
# image, i.e. (C, H, W), or a plain (H, W) map — confirm against
# XAIManager.grad_cam.
attr_map = attributions[0].detach().cpu().numpy()
if attr_map.ndim == 2:
    attr_map = attr_map[..., np.newaxis]
else:
    attr_map = np.transpose(attr_map, (1, 2, 0))

# Index the batch axis explicitly instead of squeeze(): squeeze() would also
# drop a singleton channel axis for grayscale input and break the permute.
original_image = img[0].detach().cpu().permute(1, 2, 0).numpy()

fig, ax = plt.subplots()
viz.visualize_image_attr(attr_map, original_image, method="heat_map", show_colorbar=True, plt_fig_axis=(fig, ax))
plt.show()
