<a href="https://colab.research.google.com/github/OneFineStarstuff/Cosmic-Brilliance/blob/main/explain_policy_py.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install captum

In [None]:
import torch
import captum.attr as ca

def explain_policy(agent, obs):
    """
    Compute Integrated Gradients feature attributions for the agent's
    greedy (highest-scoring) action.

    When a batch is given, every sample is explained with respect to its
    own greedy action rather than the action chosen for the first sample.

    Args:
        agent: object whose ``.policy`` is a ``torch.nn.Module`` that maps a
               (batch, features) float tensor to per-action scores, expected
               to be of shape (batch, num_actions)
        obs:   array-like or torch tensor of shape (features,) or
               (batch, features)

    Returns:
        numpy array of attributions with the same shape as obs
    """
    policy = agent.policy

    # Run on the policy's device; a parameter-free policy falls back to CPU
    # instead of failing with a bare StopIteration.
    first_param = next(policy.parameters(), None)
    if first_param is not None:
        device = first_param.device
    else:
        device = torch.device("cpu")

    # torch.as_tensor may return the caller's own tensor unchanged; detach()
    # yields a separate tensor object so the caller's tensor is never tied
    # into (or altered by) the gradient computation below.
    obs_t = torch.as_tensor(obs, dtype=torch.float32, device=device).detach()

    # Remember whether the batch dimension is ours, so the output can mirror
    # the input shape exactly (a (1, features) input stays 2-D).
    added_batch_dim = obs_t.ndim == 1
    if added_batch_dim:
        obs_t = obs_t.unsqueeze(0)

    # Action selection needs no autograd graph.
    with torch.no_grad():
        act_scores = policy(obs_t)

    # One target per sample: the index of that sample's highest score.
    targets = act_scores.argmax(dim=1)

    # Integrated Gradients attributes each sample's target score to its inputs.
    ig = ca.IntegratedGradients(policy)
    attributions = ig.attribute(obs_t, target=targets)

    # Detach, move to CPU, and drop the batch dim only if it was added here.
    attributions = attributions.detach().cpu()
    if added_batch_dim:
        attributions = attributions.squeeze(0)

    return attributions.numpy()

# --------------------------------------------------------------------
# Example usage with a dummy policy network
if __name__ == "__main__":
    class DummyAgent:
        """Toy agent whose policy is a small two-layer MLP."""

        def __init__(self, in_dim=4, out_dim=2):
            hidden_dim = 16
            layers = (
                torch.nn.Linear(in_dim, hidden_dim),
                torch.nn.ReLU(),
                torch.nn.Linear(hidden_dim, out_dim),
            )
            self.policy = torch.nn.Sequential(*layers)

        def to(self, device):
            """Move the policy network to ``device`` and return the agent."""
            self.policy.to(device)
            return self

    # Prefer the GPU when one is present, otherwise stay on the CPU.
    if torch.cuda.is_available():
        device = torch.device("cuda")
    else:
        device = torch.device("cpu")

    # Build the demo agent and place its network on the chosen device.
    agent = DummyAgent()
    agent.to(device)

    # A single hand-written observation with four features.
    obs = [0.5, -1.2, 0.3, 2.0]

    # Attribute the greedy action to the input features and show the result.
    attr = explain_policy(agent, obs)
    print("Attributions:", attr)