# In [None]:
#1. SNN-Grad + SmoothGrad

import torch


def simple_snn_grad(model, x, target_class=0):
    """Return the gradient of one output score with respect to the input.

    Puts the model in eval mode, runs a single forward/backward pass on a
    detached copy of ``x`` and reads back the gradient stored on that copy.

    Args:
        model: Module called as ``model(x)``. Its output is indexed as
            ``output[0, target_class]``, so it must be at least 2-D with the
            batch dimension first.
        x: Input tensor. The caller's tensor is never modified; only the
            gradient for the first batch element is returned.
        target_class: Column of the model output to differentiate.

    Returns:
        Tensor shaped like ``x[0]`` containing d(score)/d(input) for the
        first batch element.

    Note:
        The model is left in eval mode, and its parameter gradients are
        cleared and then repopulated by the backward pass performed here.
    """
    model.eval()

    # Work on a private leaf tensor so the gradient is stored on it.
    inp = x.detach().clone()
    inp.requires_grad_(True)

    logits = model(inp)

    # Drop stale parameter gradients before back-propagating the score.
    model.zero_grad()
    logits[0, target_class].backward()

    input_grad = inp.grad.detach().clone()
    return input_grad[0]

# n_samples is the number of noisy copies of the input whose gradients are averaged, to smooth out gradient noise.
def snn_grad_smoothgrad(model, x, target_class=0, n_samples=50, noise_std=0.01):
    """Return the SmoothGrad estimate of the input gradient.

    Gaussian noise is added to the input ``n_samples`` times; the gradient of
    the selected output score with respect to each noisy input is computed
    and the gradients are averaged.

    Args:
        model: Module called as ``model(x)``. Its output is indexed as
            ``output[0, target_class]``, so it must be at least 2-D with the
            batch dimension first.
        x: Input tensor. The caller's tensor is never modified; only the
            gradient for the first batch element is returned.
        target_class: Column of the model output to differentiate.
        n_samples: Number of noisy copies to average over (must be >= 1).
        noise_std: Standard deviation of the zero-mean Gaussian noise.

    Returns:
        Tensor shaped like ``x[0]`` with the averaged gradient for the first
        batch element.

    Raises:
        ValueError: If ``n_samples`` is smaller than 1.

    Note:
        The model is left in eval mode, and its parameter gradients are
        cleared and repopulated by every backward pass performed here.
    """
    if n_samples < 1:
        raise ValueError("n_samples must be at least 1")

    model.eval()
    # Keep the base input detached and WITHOUT requires_grad: otherwise
    # ``x + noise`` is a non-leaf tensor whose ``.grad`` stays None.
    x = x.clone().detach()

    grads = []
    for _ in range(n_samples):
        noise = torch.randn_like(x) * noise_std
        # Each noisy copy is a fresh leaf, so backward() stores its gradient.
        x_noisy = (x + noise).requires_grad_(True)

        output = model(x_noisy)
        score = output[0, target_class]

        # Drop stale parameter gradients before back-propagating the score.
        model.zero_grad()
        # A new graph is built every iteration, so nothing needs retaining.
        score.backward()
        grads.append(x_noisy.grad.detach().clone())

    avg_grad = torch.stack(grads).mean(dim=0)
    return avg_grad[0]

#2. Feature Ablation

def feature_ablation(model, x, target_class=0):
    """Score each input feature by the output drop when it is zeroed out.

    For every index ``i`` along dimension 1 of ``x``, feature ``i`` of the
    first batch element is set to 0 and the model is re-evaluated. The
    importance is ``baseline_score - ablated_score``.

    Args:
        model: Module called as ``model(x)``. Its output is indexed as
            ``output[0, target_class]``, so it must be at least 2-D with the
            batch dimension first.
        x: Input tensor indexed as ``x[0, i]``. The caller's tensor is never
            modified.
        target_class: Column of the model output to monitor.

    Returns:
        1-D tensor of length ``x.shape[1]`` built from Python floats, holding
        one importance value per feature.

    Note:
        The model is left in eval mode. One forward pass is run per feature,
        plus one for the baseline.
    """
    model.eval()
    x = x.clone().detach()

    importances = []
    # No gradients are needed anywhere here, so the baseline pass is run
    # under no_grad as well instead of building an unused autograd graph.
    with torch.no_grad():
        baseline_pred = model(x)[0, target_class].item()

        for i in range(x.shape[1]):
            x_ablated = x.clone()
            x_ablated[0, i] = 0.0  # zero out feature i only
            pred_ablated = model(x_ablated)[0, target_class].item()
            importances.append(baseline_pred - pred_ablated)

    return torch.tensor(importances)

#3. Integrated Gradients (IG)

def integrated_gradients(model, x, baseline=None, target_class=0, steps=50):
    """Approximate Integrated Gradients along the straight path to ``x``.

    Gradients of the selected output score are evaluated at ``steps + 1``
    evenly spaced points between ``baseline`` and ``x`` (both endpoints
    included), averaged, and multiplied element-wise by ``x - baseline``.

    Args:
        model: Module called as ``model(x)``. Its output is indexed as
            ``output[0, target_class]``, so it must be at least 2-D with the
            batch dimension first.
        x: Input tensor. The caller's tensor is never modified; only the
            attribution for the first batch element is returned.
        baseline: Reference input with the same shape as ``x``. Defaults to
            an all-zeros tensor.
        target_class: Column of the model output to differentiate.
        steps: Number of intervals along the path (must be >= 1).

    Returns:
        Tensor shaped like ``x[0]`` with the attribution of each input
        element for the first batch element.

    Raises:
        ValueError: If ``steps`` is smaller than 1.

    Note:
        The model is left in eval mode, and its parameter gradients are
        cleared and repopulated by every backward pass performed here.
    """
    if steps < 1:
        raise ValueError("steps must be at least 1")

    model.eval()
    if baseline is None:
        baseline = torch.zeros_like(x)

    x = x.clone().detach()
    baseline = baseline.clone().detach()
    delta = x - baseline

    grads = []
    for i in range(steps + 1):
        # Both operands are detached, so each interpolated point is a leaf
        # tensor and backward() stores its gradient in ``.grad``.
        scaled_x = (baseline + (float(i) / steps) * delta).requires_grad_(True)

        output = model(scaled_x)
        score = output[0, target_class]

        # Drop stale parameter gradients before back-propagating the score.
        model.zero_grad()
        # A new graph is built every iteration, so nothing needs retaining.
        score.backward()
        grads.append(scaled_x.grad.detach().clone())

    avg_grads = torch.stack(grads).mean(dim=0)
    integrated_grads = delta * avg_grads
    return integrated_grads[0]