In [14]:
import torch
import numpy as np
import matplotlib
import matplotlib.pyplot as plt
import sys
import sklearn.model_selection
import torch.nn

from laplace import Laplace
import matplotlib.colors as colors
import seaborn as sns
#import geomai.utils.geometry as geometry
#from torch import nn
#from manifold import cross_entropy_manifold, linearized_cross_entropy_manifold
from torch.distributions import MultivariateNormal
from tqdm import tqdm
import sklearn.datasets
#from datautils import make_pinwheel_data
#from utils.metrics import accuracy, nll, brier, calibration
from sklearn.metrics import brier_score_loss
import argparse
from torchmetrics.functional.classification import calibration_error
from functorch import grad, jvp, make_functional, vjp, make_functional_with_buffers, hessian, jacfwd, jacrev, vmap
from functorch_utils import get_params_structure, stack_gradient, custum_hvp, stack_gradient2
import os

# seaborn's "colorblind" palette, kept in a module-level name.
palette = sns.color_palette("colorblind")

# Use CUDA when torch reports it as available; otherwise use the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

#batch_data = args.batch_data

# run with several seeds
#seed = args.seed
# The seed is hard-coded here and handed to torch's global RNG.
seed = 0
torch.manual_seed(seed)

import torch_kfac


In [26]:
def custum_hvp(f, primals, tangents, strict=False):
    """Hessian-vector product of ``f`` via a JVP through its gradient.

    Args:
        f: Scalar-valued function to differentiate.
        primals: Tuple of inputs at which ``grad(f)`` is evaluated.
        tangents: Tuple of direction vectors, matching ``primals``.
        strict: Forwarded unchanged to ``jvp``.

    Returns:
        Whatever ``jvp`` returns for ``grad(f)``: the pair
        ``(grad(f)(*primals), Hessian-vector product)``.
    """
    gradient_fn = grad(f)
    return jvp(gradient_fn, primals, tangents, strict=strict)

In [27]:
# Define the quadratic function
def quadratic_function(params):
    """Evaluate the quadratic form 0.5 * params^T A params with A = diag(2, 3).

    Args:
        params: Tensor that is matrix-multiplied with the 2x2 matrix ``A`` on both sides.

    Returns:
        The value of the quadratic form.
    """
    diag_matrix = torch.tensor([[2.0, 0.0], [0.0, 3.0]])  # Diagonal matrix
    half_params = 0.5 * params
    return half_params @ diag_matrix @ params

# Evaluation point x and direction v, each wrapped in a 1-tuple
primals = (torch.tensor([1.0, 2.0], requires_grad=True),)
tangents = (torch.ones(2),)

# Run the HVP on the quadratic; the call returns a pair
output, hvp = custum_hvp(
    quadratic_function,
    primals,
    tangents,
)
print("Function output:", output)  # first element: gradient of f at x
print("Hessian-vector product:", hvp)  # second element: A @ v


Function output: tensor([2., 6.], grad_fn=<AliasBackward0>)
Hessian-vector product: tensor([2., 3.], grad_fn=<AddBackward0>)


In [28]:
def get_activations(self, x):
    """Feed ``x`` through the direct children of ``self.model`` and record each output.

    The children are applied one after another, each receiving the previous
    child's output (this assumes ``self.model`` is a sequential container).

    Args:
        x: Input passed to the first child module.

    Returns:
        List with the output of every child module, in application order.
    """
    collected = []
    hidden = x
    for module in self.model.children():
        hidden = module(hidden)
        collected.append(hidden)
    return collected


In [30]:
def compute_fisher_information(self, gradients, data, num_samples=100):
    """
    Average, over samples of ``data``, the per-layer outer products of gradients and activations.

    For each visited item of ``data`` the layer outputs are taken from
    ``self.get_activations``; for every (activation, gradient) pair the outer
    product of the flattened gradient and the flattened activation is formed.
    All of these matrices are summed and divided by the number of samples
    actually visited.

    NOTE(review): this pairs gradients with activations, whereas the Fisher
    information is usually written as an expected gradient outer product;
    confirm this is the intended approximation.

    Args:
        gradients: Sequence of gradient tensors, paired positionally with the
            activations returned by ``self.get_activations``.
        data: Iterable of inputs; each item is passed as-is to
            ``self.get_activations``.
        num_samples: Upper bound on how many items of ``data`` are visited.

    Returns:
        fisher_information: A single tensor holding the sample average of the
        summed outer products. Because the per-layer matrices are added
        together, they must all have the same shape.

    Raises:
        ValueError: If no sample was visited (empty ``data`` or
            ``num_samples`` <= 0).
    """
    outer_products = []
    samples_seen = 0

    for x in data:
        if samples_seen >= num_samples:
            break
        samples_seen += 1

        # Get activations for each layer
        activations = self.get_activations(x)

        # For each layer, compute the outer product of gradients and activations.
        # The loop variable is `layer_grad` so functorch's `grad` is not shadowed.
        for activation, layer_grad in zip(activations, gradients):
            outer_products.append(
                torch.outer(layer_grad.flatten(), activation.flatten())
            )

    if samples_seen == 0:
        raise ValueError("compute_fisher_information needs at least one sample")

    # Average over the samples that were actually used, not the requested cap.
    return sum(outer_products) / samples_seen


In [31]:
def custum_hvp(f, primals, tangents, strict=False, wrapper=None, data=None):
    """
    Approximate a Hessian-Vector Product (HVP) using the Fisher Information in place of the Hessian.

    NOTE(review): this definition reuses the name ``custum_hvp`` and therefore
    shadows the exact ``jvp(grad(f))`` version defined earlier in the notebook
    (and the one imported from ``functorch_utils``). It returns only the
    product, not a ``(gradient, product)`` pair.

    Args:
        f: The loss function; it is differentiated with ``grad(f)``.
        primals: The model parameters, passed to ``grad(f)`` as a single
            argument (presumably a tuple with one tensor per layer — confirm
            against callers).
        tangents: Direction vectors, one per layer, in the same order as the
            gradients.
        strict: Not used by this approximation; kept so the signature matches
            the exact-HVP variant.
        wrapper: Object handed to ``compute_fisher_information`` as its first
            argument. Required.
        data: Inputs handed to ``compute_fisher_information``. Required.

    Returns:
        hessian_vector_product: Sum over layers of ``layer_fisher @ layer_tangent``.

    Raises:
        ValueError: If ``wrapper`` or ``data`` is not supplied.
    """
    # Both inputs must be passed explicitly; there is no enclosing `self` or
    # module-level `data` this function could rely on.
    if wrapper is None or data is None:
        raise ValueError(
            "custum_hvp (Fisher approximation) requires both `wrapper` and `data`"
        )

    # Compute the gradients of the loss (single computation)
    grad_f = grad(f)(primals)

    # Compute the Fisher Information (using the precomputed gradients)
    fisher_information = compute_fisher_information(wrapper, grad_f, data)

    # Accumulate the per-layer products; tangents are matched to layers by position.
    hvp = 0.0
    for _, layer_fisher, layer_tangent in zip(grad_f, fisher_information, tangents):
        hvp += layer_fisher @ layer_tangent

    return hvp
