In [None]:
import torch
import torch.nn as nn
import numpy as np
import cv2
import pickle
import matplotlib.pyplot as plt
from torchvision import models

# Load dataset and dataloader
# NOTE(security): pickle.load can execute arbitrary code embedded in the file.
# Only load pickles that were produced by a trusted pipeline.
with open("dataset.pkl", "rb") as f:
    dataset = pickle.load(f)

with open("dataloader.pkl", "rb") as f:
    dataloader = pickle.load(f)

# Load trained model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Randomly initialised backbone; the trained weights come from the checkpoint
# below. Calling resnet18() with no arguments avoids the `pretrained=` keyword,
# which newer torchvision releases deprecate, while behaving the same on old ones.
model = models.resnet18()
num_ftrs = model.fc.in_features
# Replace the classification head so it has one output per distinct label.
model.fc = nn.Linear(num_ftrs, len(set(dataset.labels)))
# map_location remaps tensors saved on another device (e.g. a checkpoint
# written on a GPU machine) so loading also works on a CPU-only host.
model.load_state_dict(torch.load("histology_model.pth", map_location=device))
model.to(device)
model.eval()

# Evaluation function
def evaluate_model(model, dataloader):
    """Print and return the top-1 accuracy of ``model`` over ``dataloader``.

    Batches are moved to the device that holds the model's parameters, so the
    function does not depend on any module-level ``device`` variable. The model
    is switched to eval mode for the duration of the call and its previous
    train/eval mode is restored afterwards.

    Args:
        model: Module mapping a batch of inputs to per-class scores laid out
            as ``(batch, num_classes)`` (the prediction is the argmax over
            dimension 1).
        dataloader: Iterable yielding ``(images, labels)`` tensor pairs.

    Returns:
        Accuracy as a percentage in ``[0, 100]``, or ``float("nan")`` when the
        dataloader yields no samples.
    """
    first_param = next(model.parameters(), None)
    eval_device = first_param.device if first_param is not None else torch.device("cpu")

    correct = 0
    total = 0

    was_training = model.training
    model.eval()
    try:
        with torch.no_grad():
            for images, labels in dataloader:
                images, labels = images.to(eval_device), labels.to(eval_device)
                predicted = model(images).argmax(dim=1)
                correct += (predicted == labels).sum().item()
                total += labels.size(0)
    finally:
        # Leave the caller's train/eval mode exactly as it was.
        model.train(was_training)

    if total == 0:
        # Avoid a ZeroDivisionError when there is nothing to evaluate.
        print("Accuracy: N/A (no samples evaluated)")
        return float("nan")

    accuracy = 100 * correct / total
    print(f"Accuracy: {accuracy:.2f}%")
    return accuracy

# Report the loaded model's accuracy over the loaded dataloader
evaluate_model(model=model, dataloader=dataloader)

# Grad-CAM for Explainability
def grad_cam(model, image_tensor, target_layer="layer4", class_idx=None):
    """Compute a Grad-CAM heatmap for a single image.

    The feature maps produced by ``target_layer`` are weighted by the spatial
    mean of the gradient of the chosen class score with respect to them,
    summed over channels, passed through ReLU, upsampled to the input
    resolution and min-max normalised.

    The model is put into eval mode (and left there). The temporary forward
    hook is always removed, and parameter ``.grad`` buffers are not touched
    because the gradient is obtained with ``torch.autograd.grad``.

    Args:
        model: Network to explain; its output is indexed as
            ``output[0, class_idx]``.
        image_tensor: Single un-batched image; a batch dimension is added
            here. It is moved to the device of the model's parameters.
        target_layer: Name of the module, as reported by
            ``model.named_modules()``, whose output is visualised. Its output
            must be a 4-D tensor.
        class_idx: Class whose score is explained. Defaults to the class with
            the highest score.

    Returns:
        2-D NumPy array with the same height and width as the input image and
        values in ``[0, 1]``. All zeros when the map carries no signal.

    Raises:
        ValueError: If ``target_layer`` does not name a module of ``model`` or
            its output is not 4-D.
        RuntimeError: If the layer exists but was not executed during the
            forward pass.
    """
    layer = dict(model.named_modules()).get(target_layer)
    if layer is None:
        raise ValueError(f"Layer {target_layer!r} not found in model")

    model.eval()

    first_param = next(model.parameters(), None)
    target_device = image_tensor.device if first_param is None else first_param.device
    # requires_grad on the input guarantees an autograd graph exists even if
    # every model parameter is frozen.
    batch = image_tensor.detach().unsqueeze(0).to(target_device).requires_grad_(True)

    activations = []
    handle = layer.register_forward_hook(
        lambda _module, _inputs, output: activations.append(output)
    )
    try:
        # enable_grad makes this work when called from inside torch.no_grad().
        with torch.enable_grad():
            output = model(batch)
            if not activations:
                raise RuntimeError(
                    f"Layer {target_layer!r} was not executed in the forward pass"
                )
            feature_maps = activations[-1]
            if feature_maps.dim() != 4:
                raise ValueError(
                    f"Layer {target_layer!r} must output a 4-D tensor, "
                    f"got {feature_maps.dim()}-D"
                )
            if class_idx is None:
                class_idx = int(torch.argmax(output))
            gradients = torch.autograd.grad(output[0, class_idx], feature_maps)[0]
    finally:
        # Without this, every call would leave another hook on the model.
        handle.remove()

    # Channel importance = gradient averaged over the spatial dimensions.
    channel_weights = gradients.mean(dim=(2, 3), keepdim=True)
    cam = torch.relu((channel_weights * feature_maps.detach()).sum(dim=1, keepdim=True))

    # Upsample to the input's own height/width rather than a fixed 224x224.
    cam = nn.functional.interpolate(
        cam, size=tuple(batch.shape[-2:]), mode="bilinear", align_corners=False
    )[0, 0]

    cam = cam - cam.min()
    peak = cam.max()
    if peak > 0:  # a constant map would otherwise divide by zero and yield NaN
        cam = cam / peak

    return cam.detach().cpu().numpy()

# Example usage: visualise the heatmap for the first dataset sample.
image_tensor, _ = dataset[0]  # Get an image (the second element is discarded)
# The model was moved to `device` above; the input must be on the same device
# or the forward pass fails with a device-mismatch error on CUDA machines.
heatmap = grad_cam(model, image_tensor.to(device))

plt.imshow(heatmap, cmap="jet")
plt.show()
