In [1]:
import torch
import torch.nn as nn
from torch.autograd.functional import hessian
from torch.nn.utils.stateless import functional_call

def compute_hessian(model: nn.Module, 
                    x: torch.Tensor, 
                    y: torch.Tensor, 
                    loss_fn: callable) -> torch.Tensor:
    """
    Compute the Hessian of the loss w.r.t. all model parameters.

    Every parameter yielded by ``model.named_parameters()`` is flattened and
    concatenated, in that order, into one 1-D vector. The loss is then
    re-expressed as a function of that vector (by running the model with the
    vector's slices substituted for its parameters) and differentiated twice
    with ``torch.autograd.functional.hessian``.

    The model's own parameters are not modified: the computation works on a
    detached copy of their values.

    Args:
        model (nn.Module): PyTorch model
        x (torch.Tensor): Input data, passed to the model as its only
            positional argument
        y (torch.Tensor): Target data, passed as the second argument of loss_fn
        loss_fn (callable): Loss function that returns scalar (e.g., nn.MSELoss())

    Returns:
        torch.Tensor: Hessian matrix of shape (N, N), where N is total parameter
        count. Row/column order follows ``model.named_parameters()``, with each
        parameter flattened in row-major order.
    """
    # Prefer the maintained torch.func entry point; the torch.nn.utils.stateless
    # variant is deprecated and only kept as a fallback for older PyTorch.
    try:
        from torch.func import functional_call as call_with_params
    except ImportError:
        from torch.nn.utils.stateless import functional_call as call_with_params

    named_params = list(model.named_parameters())

    # Record (name, shape, element count) per parameter. numel() always returns
    # a Python int, including 1 for 0-dim (scalar) parameters; computing the
    # count via torch.tensor(shape).prod() yields a float for an empty shape and
    # breaks the slicing below.
    layout = [(name, p.shape, p.numel()) for name, p in named_params]
    flat_params = torch.cat(
        [p.detach().reshape(-1) for _, p in named_params]
    ).requires_grad_(True)

    # Helper to unflatten a flat vector into a {name: tensor} parameter dict
    def unflatten_params(flat):
        param_dict = {}
        offset = 0
        for name, shape, count in layout:
            param_dict[name] = flat[offset:offset + count].view(shape)
            offset += count
        return param_dict

    # Loss as a function of the flat parameter vector only (x and y are closed over)
    def wrapped_loss(flat):
        y_pred = call_with_params(model, unflatten_params(flat), (x,))
        return loss_fn(y_pred, y)

    # Compute and return Hessian
    return hessian(wrapped_loss, flat_params)


In [4]:
import torch
import torch.nn as nn
import numpy as np

# Assume compute_hessian is already defined

# Define simple linear model: y = wx + b
class SimpleLinear(nn.Module):
    """Affine model with one input feature and one output: ``y = w * x + b``."""

    def __init__(self):
        super(SimpleLinear, self).__init__()
        # A single weight and a single bias term (bias is enabled by default).
        self.linear = nn.Linear(in_features=1, out_features=1)

    def forward(self, x):
        """Apply the affine layer to ``x`` and return the result."""
        prediction = self.linear(x)
        return prediction

def test_hessian_linear():
    """
    Check compute_hessian against the closed-form Hessian of a 1-D linear model.

    For the prediction ``w * x + b`` and the loss ``0.5 * (pred - y) ** 2`` on a
    single sample, the Hessian with respect to ``(w, b)`` is
    ``[[x^2, x], [x, 1]]`` regardless of the values of ``w``, ``b`` and ``y``.

    Raises:
        AssertionError: if the Hessian has the wrong shape or wrong values.
    """
    # Input and target (a single sample)
    x_val = 3.0
    x = torch.tensor([[x_val]])
    y = torch.tensor([[1.0]])

    # Half mean-squared error: the 0.5 factor cancels the 2 produced by
    # differentiating the square, so the expected Hessian has no factor of 2.
    def loss_fn(y_pred, y):
        return 0.5 * ((y_pred - y) ** 2).mean()

    # Model with parameters pinned to zero so the run is deterministic
    model = SimpleLinear()
    with torch.no_grad():
        model.linear.weight.fill_(0.0)
        model.linear.bias.fill_(0.0)

    # Compute Hessian
    H = compute_hessian(model, x, y, loss_fn)

    # The model has exactly two parameters (w, then b). Check the shape
    # explicitly instead of slicing, so an over-sized Hessian cannot pass.
    assert tuple(H.shape) == (2, 2), f"Expected Hessian of shape (2, 2), got {tuple(H.shape)}"

    # Expected Hessian:
    # d^2L/dw^2 = x^2
    # d^2L/dwdb = x
    # d^2L/db^2 = 1
    x_sq = x_val ** 2
    expected = torch.tensor([
        [x_sq, x_val],
        [x_val, 1.0]
    ])

    # Compare
    assert torch.allclose(H, expected, atol=1e-6), f"Expected:\n{expected}\nGot:\n{H}"
    print("✅ Hessian test passed.")

# Run the test when this cell executes; a failed check raises AssertionError,
# otherwise the success message is printed.
test_hessian_linear()


✅ Hessian test passed.


  y_pred = functional_call(model, new_params, (x,))
