### Optimizer: Adam

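In this notebook I try to observe the optimal step size $\alpha^*$ along Adam's update direction $d$. Around the current parameters $\theta_0$, the mini-batch loss is approximated by the quadratic model

$L(\theta) \approx q(\theta):=\frac{1}{2}\left(\theta-\theta_0\right)^{\top} H\left(\theta-\theta_0\right)+\left(\theta-\theta_0\right)^{\top} g+c$

where $g$ and $H$ are the gradient and the Hessian of the mini-batch loss at $\theta_0$. Along a direction $d$ this reduces to a function of the step size only,

$h(\alpha)=q\left(\theta_0+\alpha \cdot d\right)$

and the optimal step size is the stationary point of $h$:

$\alpha^*=\frac{-d^{\top} g}{d^{\top} H d}$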
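
Spelling out the last step: with $\theta-\theta_0=\alpha d$, the model is a one-dimensional quadratic in $\alpha$, so the stationary point follows from $h'(\alpha)=0$, and the second derivative shows when it is a minimum:

$$
\begin{aligned}
h(\alpha) &= \tfrac{1}{2}\,\alpha^{2}\, d^{\top} H d + \alpha\, d^{\top} g + c, \\
h'(\alpha) &= \alpha\, d^{\top} H d + d^{\top} g = 0 \quad\Longrightarrow\quad \alpha^{*} = \frac{-d^{\top} g}{d^{\top} H d}, \\
h''(\alpha) &= d^{\top} H d .
\end{aligned}
$$

So $\alpha^*$ minimizes the model only if $d^{\top} H d > 0$; if $d^{\top} H d < 0$ it is a maximizer of the model along $d$.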


First we define a simple model:

In [16]:
import torch
from torch import rand
import torch.nn as nn
import torch.optim as optim
from backpack import backpack, extend
from backpack.extensions import GGNMP, HMP, PCHMP


# Define a simple neural network using Sequential
# (nn.ReLU as a module instead of torch.relu, so BackPACK sees every operation as a module)
model = nn.Sequential(
    nn.Linear(10, 5),
    nn.ReLU(),  # Use nn.ReLU module directly in Sequential
    nn.Linear(5, 1)
)

# Create a simple random regression dataset
X = torch.randn(1000, 10)
y = torch.randn(1000, 1)

criterion = nn.MSELoss()

# IMPORTANT: extend your model and loss function with BackPACK, otherwise the parameters
# will not receive the BackPACK quantities (e.g. param.hmp) during the backward pass
model = extend(model, use_converter=True)  # Extend the model
criterion = extend(criterion)  # Extend the loss function

# Create the optimizer after extending, so it holds the parameters of the extended model
optimizer = optim.Adam(model.parameters(), lr=0.001)

# Training loop
n_epochs = 1
batch_size = 32

### Explanation for the Matrix-free multiplication

(1) Number of parameters: p  
We assume the model has p parameters. Then the Hessian has dimension p x p.  
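
For the network used in this notebook, p can be counted directly. The short sketch below re-creates the same architecture in plain PyTorch and compares the p x p entries of the full Hessian with the entries kept by a block-diagonal approximation with one block per parameter tensor (weight and bias separately), which is the structure discussed under (2).

```python
import torch.nn as nn

# Same architecture as in the notebook
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1))

# Number of entries of each parameter tensor = side length of its Hessian block
block_sizes = {name: p.numel() for name, p in model.named_parameters()}
p = sum(block_sizes.values())

print(block_sizes)                               # {'0.weight': 50, '0.bias': 5, '2.weight': 5, '2.bias': 1}
print(p, p * p)                                  # 61 3721: entries of the full Hessian
print(sum(n * n for n in block_sizes.values()))  # 2551: entries in the diagonal blocks only
```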

 
(2) Block-diagonal Hessian: hmp  
BackPACK provides matrix-free multiplication with the block-diagonal Hessian. This is a version of the Hessian in which only the blocks on the diagonal are computed; all other entries are treated as zero. In BackPACK these remaining entries do not exist at all: for every parameter tensor (each weight and each bias is its own block), a function is stored that performs the multiplication with that block.  

The multiplication with a block is available via param.hmp(mat). The input is not a flat vector but a tensor of shape [V, *param.shape], where V is the number of vectors multiplied at once. Suppose the current block of the Hessian has dimension 50 x 50 because the parameter has shape 5 x 10. Then the input has shape V x 5 x 10 (for a single vector: 1 x 5 x 10). Conceptually, each of the V slices is flattened to a 50-vector, multiplied with the 50 x 50 block, and reshaped back, so the result again has shape V x 5 x 10.
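
The block-diagonal structure matters for the denominator of $\alpha^*$: if $d$ is split into per-parameter pieces $d_i$, then $d^{\top} H d \approx \sum_i d_i^{\top} H_i d_i$, so each block product only ever sees its own parameter's slice of $d$. The sketch below illustrates this in plain PyTorch, not BackPACK: the blocks are random symmetric stand-ins (an assumption for illustration, not Hessians of a real model), and `make_block_mp` and `block_quadratic_form` are hypothetical helpers that mimic the [V, *shape] convention described above.

```python
import torch

torch.manual_seed(0)

# Parameter shapes of Sequential(Linear(10, 5), ReLU(), Linear(5, 1)); one block per tensor
shapes = {"0.weight": (5, 10), "0.bias": (5,), "2.weight": (1, 5), "2.bias": (1,)}


def random_symmetric(n):
    """Random symmetric n x n matrix, a stand-in for a real Hessian block."""
    a = torch.randn(n, n)
    return (a + a.T) / 2


def make_block_mp(block, shape):
    """Matrix-free product with one block: [V, *shape] -> [V, *shape]."""
    def mp(mat):
        V = mat.shape[0]
        flat = mat.reshape(V, -1)      # [V, n]: flatten each of the V slices
        out = flat @ block.T           # multiply every slice with the n x n block
        return out.reshape(V, *shape)  # back to [V, *shape]
    return mp


blocks = {name: random_symmetric(torch.Size(s).numel()) for name, s in shapes.items()}
mps = {name: make_block_mp(blocks[name], s) for name, s in shapes.items()}


def block_quadratic_form(d):
    """d^T H_bd d = sum_i d_i^T H_i d_i, using only the per-block products."""
    return sum((d_i * mps[name](d_i.unsqueeze(0))[0]).sum() for name, d_i in d.items())


# A direction d, split into one tensor per parameter
d = {name: torch.randn(*s) for name, s in shapes.items()}

# Reference: assemble the p x p block-diagonal matrix explicitly (zeros off the blocks)
H_bd = torch.block_diag(*blocks.values())
d_flat = torch.cat([d_i.reshape(-1) for d_i in d.values()])
print(block_quadratic_form(d).item(), (d_flat @ H_bd @ d_flat).item())  # equal up to rounding
```

Only the per-block functions are needed; the 61 x 61 matrix is built here purely as a check.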

Example:

Network: Sequential(Linear(10, 5), ReLU(), Linear(5, 1))

The weight of the first layer has *param.shape = 5 x 10 (i.e. 50 parameters, so its Hessian block is 50 x 50).  
We can multiply this block with a tensor of shape 1 x 5 x 10 (V = 1, i.e. one vector with the shape of the weight) via  
param.hmp(v)
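
To check the shape convention without BackPACK, the sketch below computes the exact 50 x 50 Hessian block of the mini-batch loss with respect to `0.weight` using `torch.autograd.functional.hessian` (a swapped-in technique, not BackPACK), holding all other parameters fixed. Assuming BackPACK's HMP block for this parameter is this exact per-parameter block, `param.hmp(vec)` should return the same product. The helpers `loss_wrt_first_weight` and `weight_block_mp` are hypothetical, and the data are random stand-ins.

```python
import torch
import torch.nn as nn
from torch.autograd.functional import hessian

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1))
criterion = nn.MSELoss()
X_batch, y_batch = torch.randn(32, 10), torch.randn(32, 1)

# All parameters except the first layer's weight are held fixed
b1 = model[0].bias.detach()
W2, b2 = model[2].weight.detach(), model[2].bias.detach()


def loss_wrt_first_weight(W):
    """Mini-batch loss as a function of the first layer's weight (shape 5 x 10) only."""
    hidden = torch.relu(X_batch @ W.T + b1)
    return criterion(hidden @ W2.T + b2, y_batch)


# Exact Hessian block for "0.weight": shape [5, 10, 5, 10], reshaped to 50 x 50
H_block = hessian(loss_wrt_first_weight, model[0].weight.detach()).reshape(50, 50)


def weight_block_mp(mat):
    """Multiply each of the V slices of a [V, 5, 10] tensor with the 50 x 50 block."""
    V = mat.shape[0]
    return (mat.reshape(V, 50) @ H_block.T).reshape(V, 5, 10)


vec = torch.rand(1, 5, 10)         # V = 1, as in the notebook
print(weight_block_mp(vec).shape)  # torch.Size([1, 5, 10])

# V = 50 unit vectors: slice k of the result is H e_k, i.e. column k of the block
H_recovered = weight_block_mp(torch.eye(50).reshape(50, 5, 10)).reshape(50, 50)
print(torch.allclose(H_recovered, H_block.T))  # True
```

The leading V dimension therefore lets one call multiply the block with several vectors at once; with V equal to the block size, the whole block can be read out.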


Then we train with Adam and observe its update direction d, together with the mini-batch gradient g:
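
A compact, self-contained sketch of the quantities this cell is after, under stated assumptions: Adam with default settings (one parameter group, no weight decay, no AMSGrad), and the full mini-batch Hessian accessed through an exact Hessian-vector product by double backpropagation (`torch.autograd.grad` with `create_graph=True`) instead of BackPACK's block-diagonal `hmp`, so this $\alpha^*$ refers to the full Hessian rather than its block-diagonal approximation. The helper `adam_direction` is hypothetical: it rebuilds the direction of the next step from the current gradient without modifying `optimizer.state`, and the loop checks it against the parameter change that `optimizer.step()` actually makes before plugging it into $\alpha^* = -d^{\top} g / d^{\top} H d$.

```python
import torch
import torch.nn as nn
import torch.optim as optim

torch.manual_seed(0)
model = nn.Sequential(nn.Linear(10, 5), nn.ReLU(), nn.Linear(5, 1))
criterion = nn.MSELoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
X_batch, y_batch = torch.randn(32, 10), torch.randn(32, 1)
params = list(model.parameters())


def adam_direction(optimizer, params, grads):
    """Direction d of the NEXT Adam step (param <- param + lr * d); optimizer.state is not modified."""
    beta1, beta2 = optimizer.defaults["betas"]
    eps = optimizer.defaults["eps"]
    directions = []
    for p, g in zip(params, grads):
        state = optimizer.state.get(p, {})  # .get() does not create a state entry
        m = beta1 * state.get("exp_avg", torch.zeros_like(g)) + (1 - beta1) * g
        v = beta2 * state.get("exp_avg_sq", torch.zeros_like(g)) + (1 - beta2) * g * g
        t = int(state["step"]) + 1 if "step" in state else 1
        m_hat, v_hat = m / (1 - beta1**t), v / (1 - beta2**t)
        directions.append(-m_hat / (v_hat.sqrt() + eps))
    return directions


alphas, errors = [], []
for step in range(3):
    loss = criterion(model(X_batch), y_batch)
    grads = torch.autograd.grad(loss, params, create_graph=True)  # keep graph for H @ d
    d = adam_direction(optimizer, params, [g.detach() for g in grads])

    # d^T g, and d^T H d via an exact Hessian-vector product (differentiate g^T d once more)
    dTg = sum((g * d_i).sum() for g, d_i in zip(grads, d))
    Hd = torch.autograd.grad(dTg, params)
    dHd = sum((d_i * hd_i).sum() for d_i, hd_i in zip(d, Hd))
    alphas.append((-dTg / dHd).item())

    # Take the real Adam step and compare the parameter change with lr * d
    before = [p.detach().clone() for p in params]
    for p, g in zip(params, grads):
        p.grad = g.detach()
    optimizer.step()
    lr = optimizer.param_groups[0]["lr"]
    errors.append(max(((p.detach() - b) - lr * d_i).abs().max().item()
                      for p, b, d_i in zip(params, before, d)))
    print(f"step {step}: d^T H d = {dHd.item():.4g}, alpha* = {alphas[-1]:.4g}, "
          f"max |delta theta - lr * d| = {errors[-1]:.1e}")
```

If $d^{\top} H d \le 0$ the quadratic model has no minimum along $d$. The full Hessian of this ReLU network need not be positive semi-definite, which is one reason to also look at the GGN (GGNMP) and positive-curvature (PCHMP) products computed in the notebook.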

In [22]:
# Lists for the flattened Adam directions and gradients of the current mini-batch
d_list = []
g_list = []

# Adam hyperparameters (single parameter group, default settings: no weight decay, no AMSGrad)
beta1, beta2 = optimizer.defaults['betas']
eps = optimizer.defaults['eps']

for epoch in range(n_epochs):
    for i in range(0, len(X), batch_size):
        # Get mini-batch
        X_batch = X[i:i+batch_size]
        y_batch = y[i:i+batch_size]

        # Forward pass
        y_pred = model(X_batch)
        loss = criterion(y_pred, y_batch)

        # Backward pass
        optimizer.zero_grad()  # Reset gradients

        # Backward pass with BackPACK. Afterwards every parameter holds matrix-free products with the
        # block-diagonal Hessian (HMP), the generalized Gauss-Newton matrix (GGNMP) and the
        # positive-curvature Hessian (PCHMP). PCHMP removes negative curvature: "clip" sets negative
        # eigenvalues to zero, "abs" replaces them by their absolute values.
        with backpack(
            HMP(),
            GGNMP(),
            PCHMP(savefield="pchmp_clip", modify="clip"),
            PCHMP(savefield="pchmp_abs", modify="abs"),
        ):
            loss.backward()

        # Number of vectors V multiplied at once: inputs to the products have shape [V, *param.shape]
        V = 1

        # Inspect the shapes of the BackPACK quantities on the first mini-batch only
        if i == 0:
            for name, param in model.named_parameters():
                # Access parameter name, dimension, and values
                print("Parameter Name:", name)
                print("Parameter Dimension:", param.shape)
                print("Parameter Values (Tensor):", param)

                vec = rand(V, *param.shape)
                print(name)
                print("*param.shape             ", *param.shape)
                print(".grad.shape:             ", param.grad.shape)
                print("vec.shape:               ", vec.shape)
                print(".hmp(vec).shape:         ", param.hmp(vec).shape)
                print(".ggnmp(vec).shape:       ", param.ggnmp(vec).shape)
                print(".pchmp_clip(vec).shape:  ", param.pchmp_clip(vec).shape)
                print(".pchmp_abs(vec).shape:   ", param.pchmp_abs(vec).shape)

        # Clear lists for directions and gradients
        d_list.clear()
        g_list.clear()

        for param in model.parameters():
            if param.grad is not None:
                # Flatten and store the gradient for this parameter
                grad = param.grad.detach()
                g_list.append(grad.reshape(-1))

                # Adam's internal state before this step (empty before the first step).
                # Read it without modifying it: incrementing 'step' here would make
                # optimizer.step() count every step twice and corrupt the bias correction.
                state = optimizer.state.get(param, {})
                m_prev = state.get('exp_avg', torch.zeros_like(grad))     # First moment
                v_prev = state.get('exp_avg_sq', torch.zeros_like(grad))  # Second moment
                t = int(state['step']) + 1 if 'step' in state else 1      # Index of the upcoming step

                # Moments after including the current gradient, as optimizer.step() will compute them
                m_t = beta1 * m_prev + (1 - beta1) * grad
                v_t = beta2 * v_prev + (1 - beta2) * grad * grad

                # Bias correction for moments
                m_t_hat = m_t / (1 - beta1**t)  # Bias-corrected first moment
                v_t_hat = v_t / (1 - beta2**t)  # Bias-corrected second moment

                # Adam's update direction: optimizer.step() moves the parameter by lr * d_t
                d_t = -m_t_hat / (torch.sqrt(v_t_hat) + eps)

                # Flatten the direction d_t and append it to the list
                d_list.append(d_t.reshape(-1))

        # Concatenate all flattened directions into a single vector of size (1, p)
        d_vector = torch.cat(d_list).view(1, -1)  # Shape (1, p)

        # Concatenate all flattened gradients into a single vector of size (p, 1)
        g_vector = torch.cat(g_list).view(-1, 1)  # Shape (p, 1)

        # Optimizer step (updates the parameters by lr * d_vector)
        optimizer.step()

        # Print the direction vector and gradient for the first mini-batch (for demonstration)
        if i == 0:
            print(f"Epoch {epoch}")
            print("d_vector shape:               ", d_vector.shape)
            print(f"Direction vector d (norm): {d_vector.norm()}")  # Print norm of the vector
            print(f"First 10 values of d: {d_vector[0, :10]}")  # First 10 values of d
            print(f"Gradient vector g (norm): {g_vector.norm()}")  # Print norm of the gradient vector
            print(f"First 10 values of g: {g_vector[:10]}")  # First 10 values of the gradient

    # Print epoch loss every 10 epochs
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item()}")

print("Training complete!")


Parameter Name: 0.weight
Parameter Dimension: torch.Size([5, 10])
Parameter Values (Tensor): Parameter containing:
tensor([[-0.1441, -0.2765,  0.0555, -0.1091,  0.2671,  0.1135, -0.0795, -0.1714,
         -0.1754,  0.0497],
        [ 0.0260,  0.1523,  0.0323, -0.0600,  0.0187,  0.1121, -0.2201,  0.1883,
          0.0256, -0.2442],
        [-0.0895, -0.3066,  0.2458,  0.1361, -0.0063, -0.2038, -0.1190,  0.0204,
         -0.0769, -0.2011],
        [ 0.1055, -0.1719,  0.0927,  0.0599,  0.0679,  0.1853,  0.0574, -0.2768,
         -0.2832, -0.0211],
        [-0.0834,  0.1070,  0.0634, -0.0230,  0.2120, -0.2448, -0.1284,  0.2943,
          0.0832,  0.1892]], requires_grad=True)
0.weight
*param.shape              5 10
.grad.shape:              torch.Size([5, 10])
vec.shape:                torch.Size([1, 5, 10])
.hmp(vec).shape:          torch.Size([1, 5, 10])
.ggnmp(vec).shape:        torch.Size([1, 5, 10])
.pchmp_clip(vec).shape:   torch.Size([1, 5, 10])
.pchmp_abs(vec).shape:    torch.Size(