### Observing alpha 
 
 In this notebook I will attempt to efficiently calculate and observe the optimal step size  during a training loop derived as: 
 
$\alpha^*=\frac{-d^{\top} g}{d^{\top} H d}$
 

In [15]:
import torch
import torch.nn as nn
from torch.nn.utils import parameters_to_vector
from torch.nn.utils import vector_to_parameters
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from backpack import backpack, extend
from backpack.extensions import (
    GGNMP,
    HMP,
    KFAC,
    KFLR,
    KFRA,
    PCHMP,
    BatchDiagGGNExact,
    BatchDiagGGNMC,
    BatchDiagHessian,
    BatchGrad,
    BatchL2Grad,
    DiagGGNExact,
    DiagGGNMC,
    DiagHessian,
    SqrtGGNExact,
    SqrtGGNMC,
    SumGradSquared,
    Variance,
)

# Load the MNIST dataset
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

train_dataset = datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = datasets.MNIST(root='./data', train=False, transform=transform)

train_loader = DataLoader(dataset=train_dataset, batch_size=64, shuffle=True)
test_loader = DataLoader(dataset=test_dataset, batch_size=1000, shuffle=False)


In [16]:
class SimpleNet(nn.Module):
    def __init__(self):
        super(SimpleNet, self).__init__()
        self.network = nn.Sequential(
            nn.Flatten(),
            nn.Linear(784, 128),  # Input size is 784 (28x28), hidden layer size is 128
            nn.ReLU(),
            nn.Linear(128, 64),   # Second hidden layer size is 64
            nn.ReLU(),
            nn.Linear(64, 10)     # Output layer for 10 classes
        )

    def forward(self, x):
        return self.network(x)

# Instantiate the model
model = SimpleNet()
model = extend(model, use_converter=True) # Extend the loss function

In [58]:
def train(model, device, train_loader, optimizer, criterion, epoch):
    model.train()
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.to(device), target.to(device)
        
        # save current parameters
        theta_0 = parameters_to_vector(model.parameters())
        
        # Zero the gradients
        optimizer.zero_grad()
        
        # Forward pass
        output = model(data)
        loss = criterion(output, target)
        
        # Backward pass and optimization
        with backpack(
        GGNMP(),
            ):
            loss.backward()
            
        # Extract gradients and convert to a vector
        gradients = parameters_to_vector(param.grad for param in model.parameters() if param.grad is not None)
        
        optimizer.step()
        
        
        d_unnormalized = parameters_to_vector(model.parameters()) - theta_0
        
        d_normalized = d_unnormalized / torch.norm(d_unnormalized)
        
        position = 0 # only for printing
        if batch_idx % 100 == 0:
            print(f'Train Epoch: {epoch} [{batch_idx * len(data)}/{len(train_loader.dataset)} '
                  f'({100. * batch_idx / len(train_loader):.0f}%)]\tLoss: {loss.item():.6f}')
            
            print("Current parameter vector length:", len(theta_0), ", values:", 
                  theta_0[position:position + 5])  # observe a different position each time
            print("d vector unnormalized length:", len(d_unnormalized), ", norm:", torch.norm(d_unnormalized),
                  ", values:", d_unnormalized[position:position + 5])
            print("d vector normalized length:", len(d_normalized), ", norm:", torch.norm(d_normalized),
                  ", values:", d_normalized[position:position + 5])
            print("gradients vector length:", len(gradients), ", values:", gradients[position: position + 5])
        
        total_params = sum(p.numel() for p in model.parameters()) # only for printing
        position = (position + 5) % total_params
def test(model, device, test_loader, criterion):
    model.eval()
    test_loss = 0
    correct = 0
    with torch.no_grad():
        for data, target in test_loader:
            data, target = data.to(device), target.to(device)
            output = model(data)
            test_loss += criterion(output, target).item()
            pred = output.argmax(dim=1, keepdim=True)
            correct += pred.eq(target.view_as(pred)).sum().item()

    test_loss /= len(test_loader.dataset)
    accuracy = 100. * correct / len(test_loader.dataset)
    
    print(f'\nTest set: Average loss: {test_loss:.4f}, Accuracy: {correct}/{len(test_loader.dataset)} '
          f'({accuracy:.2f}%)\n')


In [59]:
# Use CUDA if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Move model to the selected device
model.to(device)

# Set up optimizer and loss function
optimizer = optim.SGD(model.parameters(), lr=0.01, momentum=0)
criterion = nn.CrossEntropyLoss()
criterion = extend(criterion) # for backpack

In [60]:
# Number of epochs
num_epochs = 1

# Training and testing loop
for epoch in range(1, num_epochs + 1):
    train(model, device, train_loader, optimizer, criterion, epoch)
    test(model, device, test_loader, criterion)


Current parameter vector length: 109386 , values: tensor([-0.0224,  0.0142, -0.0080,  0.0190, -0.0165], grad_fn=<SliceBackward0>)
d vector unnormalized length: 109386 , norm: tensor(0.0052, grad_fn=<LinalgVectorNormBackward0>) , values: tensor([4.4238e-06, 4.4229e-06, 4.4229e-06, 4.4238e-06, 4.4238e-06],
       grad_fn=<SliceBackward0>)
d vector normalized length: 109386 , norm: tensor(1.0000, grad_fn=<LinalgVectorNormBackward0>) , values: tensor([0.0008, 0.0008, 0.0008, 0.0008, 0.0008], grad_fn=<SliceBackward0>)
gradients vector length: 109386 , values: tensor([-0.0004, -0.0004, -0.0004, -0.0004, -0.0004])
Current parameter vector length: 109386 , values: tensor([-0.0224,  0.0142, -0.0080,  0.0190, -0.0165], grad_fn=<SliceBackward0>)
d vector unnormalized length: 109386 , norm: tensor(0.0064, grad_fn=<LinalgVectorNormBackward0>) , values: tensor([-3.8482e-06, -3.8482e-06, -3.8482e-06, -3.8482e-06, -3.8482e-06],
       grad_fn=<SliceBackward0>)
d vector normalized length: 109386 , norm