### Optimizer: SGD (no momentum)

In this notebook I will attempt to observe the values for the optimal alpha given the quadratic approximation of the loss for a mini batch 
$L(\theta) \approx q(\theta):=\frac{1}{2}\left(\theta-\theta_0\right)^{\top} H\left(\theta-\theta_0\right)+\left(\theta-\theta_0\right)^{\top} \cdot g+c$  
 
 
along a direction d  

$h(\alpha)=q\left(\theta_0+\alpha \cdot d\right)$  
 
 
which is
$\alpha^*=\frac{-d^{\top} g}{d^{\top} H d}$

Define a simple Model:

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim

# Define a simple neural network
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.fc1 = nn.Linear(10, 5)
        self.fc2 = nn.Linear(5, 1)

    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Create a simple dataset
X = torch.randn(1000, 10)
y = torch.randn(1000, 1)

# Initialize the model and optimizer
model = SimpleNet()
optimizer = optim.Adam(model.parameters(), lr=0.001)
criterion = nn.MSELoss()

# Training loop
n_epochs = 100
batch_size = 32

Train the model using SGD with no momentum and observe d (in this case d is just the negative gradient!)

In [2]:
for epoch in range(n_epochs):
    for i in range(0, len(X), batch_size):
        # Get mini-batch
        X_batch = X[i:i+batch_size]
        y_batch = y[i:i+batch_size]

        # Forward pass
        y_pred = model(X_batch)
        loss = criterion(y_pred, y_batch)

        # Backward pass
        optimizer.zero_grad()
        loss.backward()

        # Compute the direction d for SGD without momentum
        d = {}
        for name, param in model.named_parameters():
            if param.grad is not None:
                # Direction for SGD is just the negative gradient
                d[name] = -param.grad
                

        # Optimizer step (SGD step)
        optimizer.step()

        # Print the direction for the first parameter (for demonstration)
        if i == 0:
            print(f"Epoch {epoch}, Batch 0:")
            for name, direction in d.items():
                print(f"Direction for {name}: {direction.norm()}")

    # Print epoch loss
    if epoch % 10 == 0:
        print(f"Epoch {epoch}, Loss: {loss.item()}")

print("Training complete!")


Epoch 0, Batch 0:
Direction for fc1.weight: 0.45693156123161316
Direction for fc1.bias: 0.15073071420192719
Direction for fc2.weight: 0.28940239548683167
Direction for fc2.bias: 0.2857346832752228
Epoch 0, Loss: 1.8471349477767944
Epoch 1, Batch 0:
Direction for fc1.weight: 0.4315951466560364
Direction for fc1.bias: 0.13745233416557312
Direction for fc2.weight: 0.28389236330986023
Direction for fc2.bias: 0.2965902090072632
Epoch 2, Batch 0:
Direction for fc1.weight: 0.4073067605495453
Direction for fc1.bias: 0.13742394745349884
Direction for fc2.weight: 0.278739333152771
Direction for fc2.bias: 0.29924631118774414
Epoch 3, Batch 0:
Direction for fc1.weight: 0.4134412705898285
Direction for fc1.bias: 0.13696777820587158
Direction for fc2.weight: 0.27363935112953186
Direction for fc2.bias: 0.2999235987663269
Epoch 4, Batch 0:
Direction for fc1.weight: 0.4085977077484131
Direction for fc1.bias: 0.12323563545942307
Direction for fc2.weight: 0.2690861225128174
Direction for fc2.bias: 0.2996

Epoch 92, Batch 0:
Direction for fc1.weight: 0.5784114599227905
Direction for fc1.bias: 0.17094971239566803
Direction for fc2.weight: 0.41419145464897156
Direction for fc2.bias: 0.5169454216957092
Epoch 93, Batch 0:
Direction for fc1.weight: 0.580637514591217
Direction for fc1.bias: 0.17888881266117096
Direction for fc2.weight: 0.4149934947490692
Direction for fc2.bias: 0.516630232334137
Epoch 94, Batch 0:
Direction for fc1.weight: 0.5835019946098328
Direction for fc1.bias: 0.1796799600124359
Direction for fc2.weight: 0.4162227213382721
Direction for fc2.bias: 0.5159736275672913
Epoch 95, Batch 0:
Direction for fc1.weight: 0.5974218845367432
Direction for fc1.bias: 0.18651297688484192
Direction for fc2.weight: 0.4173998534679413
Direction for fc2.bias: 0.5158080458641052
Epoch 96, Batch 0:
Direction for fc1.weight: 0.6001431345939636
Direction for fc1.bias: 0.18716883659362793
Direction for fc2.weight: 0.4179894030094147
Direction for fc2.bias: 0.5148852467536926
Epoch 97, Batch 0:
Dir