### Attempt to compute H*d (Block diagonal Hessian H, direction d)  

 First define a simple model (that is supported by Backpack):



In [15]:
import torch
from torch import rand
import torch.nn as nn
import torch.optim as optim
from backpack import backpack, extend
from backpack.extensions import (
    GGNMP,
    HMP,
    KFAC,
    KFLR,
    KFRA,
    PCHMP,
    BatchDiagGGNExact,
    BatchDiagGGNMC,
    BatchDiagHessian,
    BatchGrad,
    BatchL2Grad,
    DiagGGNExact,
    DiagGGNMC,
    DiagHessian,
    SqrtGGNExact,
    SqrtGGNMC,
    SumGradSquared,
    Variance,
)
from backpack.utils.examples import load_one_batch_mnist


# Define a simple neural network with nn.ReLU instead of torch.relu
# Define a simple neural network using Sequential
model = nn.Sequential(
    nn.Linear(10, 5),
    nn.ReLU(),  # Use nn.ReLU module directly in Sequential
    nn.Linear(5, 1)
)

# The following is only to create a somewhat learnable datasetw
# Set dimensions
num_samples = 1000
input_dim = 10
output_dim = 1

# Create input data X with random values
X = torch.randn(num_samples, input_dim)

# Define a random weight matrix and bias for generating learnable y
true_weights = torch.randn(input_dim, output_dim)
true_bias = torch.randn(output_dim)

# Generate target y as a linear combination of X plus some noise
y = X @ true_weights + true_bias + 0.1 * torch.randn(num_samples, output_dim)

# Check dimensions
print(X.shape)  # Should be [1000, 10]
print(y.shape)  # Should be [1000, 1]


optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()

# IMPORTANT: extend youre model with backpack, otherwise the parameters will not be extended
# Extend the model and criterion with BackPACK
model = extend(model, use_converter=True) # Extend the loss function
criterion = extend(criterion)

torch.Size([1000, 10])
torch.Size([1000, 1])


Next we want to do a training with only one minibatch and one iteration for test purposes:

In [21]:
# Set the parameters for testing
batch_size = 32
n_epochs = 10000  # Only one epoch for testing
epoch_count = 0 # for print statements
# List to store all a_star values
a_star_values = []

# Select one minibatch from the dataset
X_batch = X[:batch_size]
y_batch = y[:batch_size]

# Training loop with only one minibatch
for epoch in range(n_epochs):
    
    print_bool = True if epoch_count % 1000 == 0 else False
    # Forward pass
    y_pred = model(X_batch)
    loss = criterion(y_pred, y_batch)

    with backpack(
        HMP(),      # possible problem with comma, check if run fails
    ):
        loss.backward()

    V = 1
    if print_bool:
        print("Epoch Number:", epoch_count + 1)
        for name, param in model.named_parameters():
            # Access parameter name
            print("Parameter Name:", name)
            print("Parameter Dimension:", param.shape)
            vec = rand(V, *param.shape)
            print("*param.shape             ", *param.shape)
            print("vec.shape:               ", vec.shape)
            print(".hmp(vec).shape:         ", param.hmp(vec).shape)

    g_list = []  # List to store flattened gradients
    d_list = []  # List to store flattened directions d_t
    hessian_d_products = []  # To store d^T H d for each parameter
    d_g_products = []  # To store -d^T g for each parameter

    # Loop over parameters to compute and apply hmp in chunks
    for param in model.parameters():
        if param.grad is not None:
            # Flatten and store the gradient for this parameter
            g_list.append(param.grad.view(-1))

            # Initialize optimizer state if it doesn't exist
        if param not in optimizer.state:
            optimizer.state[param]['exp_avg'] = torch.zeros_like(param.data)
            optimizer.state[param]['exp_avg_sq'] = torch.zeros_like(param.data)
            optimizer.state[param]['step'] = torch.tensor(0)  # Initialize step as tensor

        # Ensure 'step' is a tensor
        if isinstance(optimizer.state[param]['step'], int):
            optimizer.state[param]['step'] = torch.tensor(optimizer.state[param]['step'])

        # Get Adam's internal state (first and second moments)
        m_t = optimizer.state[param]['exp_avg']        # First moment (moving average of gradients)
        v_t = optimizer.state[param]['exp_avg_sq']     # Second moment (moving average of squared gradients)

        # Bias correction for moments
        beta1, beta2 = optimizer.defaults['betas']
        optimizer.state[param]['step'] += 1  # Increment step as tensor
        t = optimizer.state[param]['step'].item()  # Convert step tensor to int

        m_t_hat = m_t / (1 - beta1**t)  # Bias-corrected first moment
        v_t_hat = v_t / (1 - beta2**t)  # Bias-corrected second moment

        # Compute Adam's update direction d_t
        d_t = -m_t_hat / (torch.sqrt(v_t_hat) + optimizer.defaults['eps'])

        # Flatten the direction d_t and append it to the list
        d_list.append(d_t.view(-1))

        # Compute -d^T g for this parameter
        d_flat = d_t.view(-1)  # Flatten d
        g_flat = param.grad.view(-1)  # Flatten the gradient g
        d_g_product = -torch.dot(d_flat, g_flat)  # Compute -d^T g
        d_g_products.append(d_g_product.item())  # Store the result as a scalar
        
        if print_bool:
            print(f"-d^T g for {param.shape}: {d_g_product}")

        # Reshape d_t to match the parameter's shape
        vec = d_t.view(*param.shape)

        # Perform the param.hmp(d) operation to compute H * d
        hmp_output = param.hmp(vec)

        # Compute d^T * H * d for this parameter
        d_hmp_d = torch.dot(vec.view(-1), hmp_output.view(-1))  # Equivalent to d^T H d
        hessian_d_products.append(d_hmp_d.item())  # Store the result as a scalar

        if print_bool:
            print(f"param.hmp(d).shape: {hmp_output.shape}")
            print(f"d^T H d for {param.shape}: {d_hmp_d}")

    # Now hessian_d_products contains d^T H d for each parameter.
    total_d_H_d = sum(hessian_d_products)
    if print_bool:
        print(f"Total d^T H d across all parameters: {total_d_H_d}")

    # Now d_g_products contains -d^T g for each parameter.
    total_d_g = sum(d_g_products)
    if print_bool:
        print(f"Total -d^T g across all parameters: {total_d_g}")

    # avoid division by zero
    epsilon = 1e-8  # Small value to prevent division by zero
    a_star = total_d_g / (total_d_H_d + epsilon)
    # Store the a_star value
    a_star_values.append(a_star)
    
    if print_bool:
        print("a star: ", a_star)
    

    # Optimizer step (updates the parameters)
    optimizer.step()
    
    epoch_count = epoch_count + 1

print("Training complete!")

# After the loop finishes
a_star_tensor = torch.tensor(a_star_values)  # Convert list to a PyTorch tensor

# Calculate statistics
max_a_star = torch.max(a_star_tensor)
min_a_star = torch.min(a_star_tensor)
std_a_star = torch.std(a_star_tensor)

# Print the results
print(f"Maximum a_star value: {max_a_star.item()}")
print(f"Minimum a_star value: {min_a_star.item()}")
print(f"Standard deviation of a_star values: {std_a_star.item()}")


Epoch Number: 1
Parameter Name: 0.weight
Parameter Dimension: torch.Size([5, 10])
*param.shape              5 10
vec.shape:                torch.Size([1, 5, 10])
.hmp(vec).shape:          torch.Size([1, 5, 10])
Parameter Name: 0.bias
Parameter Dimension: torch.Size([5])
*param.shape              5
vec.shape:                torch.Size([1, 5])
.hmp(vec).shape:          torch.Size([1, 5])
Parameter Name: 2.weight
Parameter Dimension: torch.Size([1, 5])
*param.shape              1 5
vec.shape:                torch.Size([1, 1, 5])
.hmp(vec).shape:          torch.Size([1, 1, 5])
Parameter Name: 2.bias
Parameter Dimension: torch.Size([1])
*param.shape              1
vec.shape:                torch.Size([1, 1])
.hmp(vec).shape:          torch.Size([1, 1])
-d^T g for torch.Size([5, 10]): 2186.820556640625
param.hmp(d).shape: torch.Size([5, 10])
d^T H d for torch.Size([5, 10]): 4.350311279296875
-d^T g for torch.Size([5]): 2216.300048828125
param.hmp(d).shape: torch.Size([5])
d^T H d for torch.S