In this notebook I will attempt to observe the values for the optimal alpha given the quadratic approximation of the loss for a mini batch 
$L(\theta) \approx q(\theta):=\frac{1}{2}\left(\theta-\theta_0\right)^{\top} H\left(\theta-\theta_0\right)+\left(\theta-\theta_0\right)^{\top} \cdot g+c$  
 
 
along a direction d  

$h(\alpha)=q\left(\theta_0+\alpha \cdot d\right)$  
 
 
which is
$\alpha^*=\frac{-d^{\top} g}{d^{\top} H d}$


First we define a simple model:

In [None]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple neural network
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Create a simple dataset
X = torch.randn(1000, 10)
y = torch.randn(1000, 1)

# Initialize the model and optimizer
model = SimpleNet()
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()

# Training loop
n_epochs = 100
batch_size = 32

Then we train using Adam and try to observe d from adam:

In [11]:

for epoch in range(n_epochs):
    for i in range(0, len(X), batch_size):
        # Get mini-batch
        X_batch = X[i:i+batch_size]
        y_batch = y[i:i+batch_size]

        # Forward pass
        y_pred = model(X_batch)
        loss = criterion(y_pred, y_batch)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # Compute the direction d before the optimizer step
        d = {}
        for name, param in model.named_parameters():
            if param.grad is not None:
                if param not in optimizer.state:
                    # Initialize state if it doesn't exist
                    optimizer.state[param]['exp_avg'] = torch.zeros_like(param.data)
                    optimizer.state[param]['exp_avg_sq'] = torch.zeros_like(param.data)
                    optimizer.state[param]['step'] = torch.tensor(0)  # Initialize as tensor

                # Ensure 'step' is a tensor
                if isinstance(optimizer.state[param]['step'], int):
                    optimizer.state[param]['step'] = torch.tensor(optimizer.state[param]['step'])

                m_t = optimizer.state[param]['exp_avg']  # exponential moving avg of the gradient
                v_t = optimizer.state[param]['exp_avg_sq'] # exp moving avg of the squared gradient
        
                # Bias correction
                beta1, beta2 = optimizer.defaults['betas']
                optimizer.state[param]['step'] += 1  # Increment step as tensor
                t = optimizer.state[param]['step'].item()  # Convert step tensor to int
                # m_t und v_t sind richtung 0 biased bei ersten Iterationen, deshalb teilen wir durch 1-beta^t
                m_t_hat = m_t / (1 - beta1**t) 
                v_t_hat = v_t / (1 - beta2**t)
        
                # Compute direction
                # epsilon verhindert division by zero
                # Adam Step: Theta_t = Theta_t-1 - (alpha_t * m_hat_t / sqrt(v_hat_t) + epsilon) 
                # Dabei wird alpha auch berechnet, lassen wir hier aber aus
                d[name] = -m_t_hat / (torch.sqrt(v_t_hat) + optimizer.defaults['eps']) 

        # Optimizer step
        optimizer.step()


        # Print the direction for the first parameter (for demonstration)
        if i == 0:
            print(f"Epoch {epoch}, Batch 0:")
            for name, direction in d.items():
                print(f"Direction for {name}: {direction.norm()}")

    # Print epoch loss
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item()}")

print("Training complete!")

Epoch 0, Batch 0:
Direction for fc1.weight: 0.0
Direction for fc1.bias: 0.0
Direction for fc2.weight: 0.0
Direction for fc2.bias: 0.0
Epoch 0, Loss: 2.3547916412353516
Epoch 1, Batch 0:
Direction for fc1.weight: 4.373355865478516
Direction for fc1.bias: 2.5503242015838623
Direction for fc2.weight: 2.0329806804656982
Direction for fc2.bias: 1.2113134860992432
Epoch 2, Batch 0:
Direction for fc1.weight: 4.390522480010986
Direction for fc1.bias: 2.5649569034576416
Direction for fc2.weight: 1.7541961669921875
Direction for fc2.bias: 1.0527496337890625
Epoch 3, Batch 0:
Direction for fc1.weight: 4.001357555389404
Direction for fc1.bias: 2.2361676692962646
Direction for fc2.weight: 1.4126876592636108
Direction for fc2.bias: 0.8529586791992188
Epoch 4, Batch 0:
Direction for fc1.weight: 3.7569453716278076
Direction for fc1.bias: 1.7752281427383423
Direction for fc2.weight: 1.1067380905151367
Direction for fc2.bias: 0.6769903302192688
Epoch 5, Batch 0:
Direction for fc1.weight: 3.5326473712921

Epoch 49, Batch 0:
Direction for fc1.weight: 2.809729814529419
Direction for fc1.bias: 0.552729606628418
Direction for fc2.weight: 0.7040776014328003
Direction for fc2.bias: 0.052162088453769684
Epoch 50, Batch 0:
Direction for fc1.weight: 2.344369649887085
Direction for fc1.bias: 0.3589980900287628
Direction for fc2.weight: 0.7032755017280579
Direction for fc2.bias: 0.045316074043512344
Epoch 50, Loss: 1.8152008056640625
Epoch 51, Batch 0:
Direction for fc1.weight: 2.4689550399780273
Direction for fc1.bias: 0.44295021891593933
Direction for fc2.weight: 0.6964861154556274
Direction for fc2.bias: 0.04476465284824371
Epoch 52, Batch 0:
Direction for fc1.weight: 2.2377893924713135
Direction for fc1.bias: 0.3986944854259491
Direction for fc2.weight: 0.6900762319564819
Direction for fc2.bias: 0.044461071491241455
Epoch 53, Batch 0:
Direction for fc1.weight: 2.454497814178467
Direction for fc1.bias: 0.43716803193092346
Direction for fc2.weight: 0.6833941340446472
Direction for fc2.bias: 0.04

Epoch 95, Batch 0:
Direction for fc1.weight: 2.238910436630249
Direction for fc1.bias: 0.4503834545612335
Direction for fc2.weight: 0.5446817278862
Direction for fc2.bias: 0.01826356165111065
Epoch 96, Batch 0:
Direction for fc1.weight: 2.1691036224365234
Direction for fc1.bias: 0.5043172836303711
Direction for fc2.weight: 0.5452942252159119
Direction for fc2.bias: 0.0165786761790514
Epoch 97, Batch 0:
Direction for fc1.weight: 2.1111972332000732
Direction for fc1.bias: 0.48156583309173584
Direction for fc2.weight: 0.5437005758285522
Direction for fc2.bias: 0.016561727970838547
Epoch 98, Batch 0:
Direction for fc1.weight: 2.225637912750244
Direction for fc1.bias: 0.4308435618877411
Direction for fc2.weight: 0.5433929562568665
Direction for fc2.bias: 0.019601255655288696
Epoch 99, Batch 0:
Direction for fc1.weight: 2.031578540802002
Direction for fc1.bias: 0.513228178024292
Direction for fc2.weight: 0.5423434376716614
Direction for fc2.bias: 0.018667055293917656
Training complete!
