# Original Method
### Compute the informativeness of every input feature, aggregated over all available data

In [ ]:
import torch
import torch.nn.functional as F
import numpy as np

def gradient_importance(model, x_data, y_data):
    """Score input features by the mean absolute gradient of the loss.

    The cross-entropy loss of ``model`` on the whole dataset is
    differentiated with respect to the inputs. The absolute input gradients
    are averaged over the first (sample) axis and normalised to sum to 1.

    Args:
        model: Callable mapping a float32 tensor built from ``x_data`` to
            logits accepted by ``F.cross_entropy``.
        x_data: Array-like of input samples; the first axis is averaged over.
        y_data: Array-like of integer class labels, one per sample.

    Returns:
        numpy.ndarray with the shape of one sample, holding the normalised
        importances. It is all zeros if every gradient is zero. The array is
        also printed, as before.
    """
    loss_function = F.cross_entropy
    inputs = torch.tensor(x_data, dtype=torch.float32, requires_grad=True)
    labels = torch.tensor(y_data, dtype=torch.long)
    predictions = model(inputs)
    loss = loss_function(predictions, labels)
    # autograd.grad returns d(loss)/d(inputs) only; unlike loss.backward() it
    # does not accumulate gradients into the model parameters' .grad buffers.
    gradients = torch.autograd.grad(loss, inputs)[0]

    feature_importance_train = torch.mean(torch.abs(gradients), dim=0)
    # detach/cpu so the conversion also works for CUDA tensors.
    importance = feature_importance_train.detach().cpu().numpy()
    total = np.sum(importance)
    # Guard against 0/0 -> NaN when the loss is flat w.r.t. every input.
    normalized_importance = importance / total if total > 0 else importance

    print(normalized_importance)
    return normalized_importance

# Integrated Gradients Importance
### Compute the informativeness of each input feature for one specific input

In [ ]:
import torch

def integrated_gradients_v2(pretrained_model, input_tensor, baseline=None, steps=50, target=None):
    """Compute normalised integrated-gradients attributions.

    Gradients are averaged at ``steps`` evenly spaced interior points on the
    straight line from ``baseline`` to ``input_tensor``. The average is
    multiplied by ``(input_tensor - baseline)``, and the absolute values are
    normalised to sum to 1.

    Args:
        pretrained_model: PyTorch model (callable) applied to tensors shaped
            like ``input_tensor``.
        input_tensor: Floating-point tensor holding the input to explain.
        baseline: Reference input with the same shape; defaults to zeros.
        steps: Number of integration points (must be >= 1).
        target: Optional index into the last dimension of the model output
            selecting the output to explain. If None, the sum of all outputs
            is explained; for a single-output model this is just that output.

    Returns:
        A tensor with the same shape as ``input_tensor``: the absolute
        attributions divided by their total. It is all zeros when every
        attribution is zero, e.g. when input equals baseline.

    Raises:
        ValueError: if ``steps`` < 1.
    """
    if steps < 1:
        raise ValueError("steps must be a positive integer")

    input_tensor = input_tensor.detach()
    if baseline is None:
        baseline = torch.zeros_like(input_tensor)
    else:
        baseline = baseline.detach()
    delta = input_tensor - baseline

    def compute_gradients(alpha_):
        # Point on the path baseline -> input:
        # baseline + alpha * (input - baseline), i.e.
        # alpha * input + (1 - alpha) * baseline.
        interpolated_input = (baseline + alpha_ * delta).requires_grad_(True)
        output = pretrained_model(interpolated_input)
        if target is not None:
            output = output[..., target]
        # autograd.grad needs a scalar; summing keeps per-sample gradients
        # separate as long as the model treats batch rows independently
        # (assumption — e.g. no train-mode BatchNorm).
        gradients_ = torch.autograd.grad(output.sum(), interpolated_input)[0]
        return gradients_

    int_grads = torch.zeros_like(input_tensor)
    for inx in range(steps):
        alpha = (inx + 1) / (steps + 1)
        int_grads += compute_gradients(alpha)

    int_grads /= steps
    # Integrated gradients: (x - x') * average path gradient.
    attributions = torch.abs(delta * int_grads)
    total = attributions.sum()
    # Avoid 0/0 -> NaN when nothing is attributed.
    if total > 0:
        attributions = attributions / total
    return attributions