In [28]:
import torch
from torch.utils.data import DataLoader, Subset, random_split
from torchvision import datasets, transforms
import torch.nn as nn
import torch.optim as optim
import os

In [32]:
path = os.path.dirname(os.path.abspath("__file__"))
data_path = os.path.join(path, "data")

## Number of computations

In [11]:
# How many computations to compute diagonal elements of the Hessian ?
def compute_necessary (nx, n1, n2, ny) :
    return n2*ny * (1 + n1 + nx*n1)

def compute_total (nx, n1, n2, ny) :
    return (nx*n1 + n1*n2 + n2*ny)**2

In [15]:
nx = 784
n1 = 100
n2 = 100
ny = 10

In [16]:
print("Necessary computations: ", format(compute_necessary(nx, n1, n2, ny), ".2e"))
print("Total computations: ", format(compute_total(nx, n1, n2, ny), ".2e"))
print("Ratio: ", format(compute_total(nx, n1, n2, ny) / compute_necessary(nx, n1, n2, ny), ".2e"))

Necessary computations:  7.85e+07
Total computations:  7.99e+09
Ratio:  1.02e+02


## Get MNIST loaders

In [27]:
batch_size = 128

In [34]:
MNIST_transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])

def get_MNIST_loaders(path, class_names, batch_size, download=False) :

    # load MNIST 
    mnist_train = datasets.MNIST(root=path, train=True, download=download, transform=MNIST_transform)
    mnist_test = datasets.MNIST(root=path, train=False, download=download, transform=MNIST_transform)


    # create a mask to filter indices for each label
    train_mask = torch.tensor([label in class_names for label in mnist_train.targets])
    test_mask = torch.tensor([label in class_names for label in mnist_test.targets])

    # Create Subset datasets for train, validation, and test
    train_dataset = Subset(mnist_train, torch.where(train_mask)[0])
    test_dataset = Subset(mnist_test, torch.where(test_mask)[0])

    # split train into train & validation
    train_size = int(0.7 * len(train_dataset))
    val_size = len(train_dataset) - train_size
    train_dataset, val_dataset = random_split(train_dataset, [train_size, val_size])

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True, drop_last=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False, drop_last=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False, drop_last=True)

    return train_loader, val_loader, test_loader

In [38]:
train_loader, val_loader, test_loader = get_MNIST_loaders(data_path, range(10), batch_size, download=False)

## Model

In [39]:
class SimpleMLP(nn.Module):
    def __init__(self, input_size, hidden_size, output_size):
        super(SimpleMLP, self).__init__()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.fc2 = nn.Linear(hidden_size, output_size)
    
    def forward(self, x):
        x = torch.relu(self.fc1(x))
        x = self.fc2(x)
        return x

# Example usage
input_size = 784
hidden_size = 100
output_size = 10
model = SimpleMLP(input_size, hidden_size, output_size)


## Compute second derivative while training

In [25]:
# Loss and optimizer
criterion = nn.MSELoss()
optimizer = optim.SGD(model.parameters(), lr=0.01)

# Training loop
num_epochs = 10
for epoch in range(num_epochs):
    # Forward pass
    outputs = model(x)
    loss = criterion(outputs, y)
    
    # Zero the parameter gradients
    optimizer.zero_grad()
    
    # Backward pass to compute gradients
    loss.backward(create_graph=True)
    
    # Compute second derivatives (diagonal of the Hessian)
    second_derivatives = []
    for param in model.parameters():
        if param.grad is not None:
            grad = param.grad
            second_derivative = torch.autograd.grad(
                grad,
                param,
                grad_outputs=torch.ones_like(grad),
                retain_graph=True,
                create_graph=True,
            )[0]
            second_derivatives.append(second_derivative)
    
    # Update weights
    optimizer.step()
    
    # Print the loss and second derivatives for the first epoch for demonstration
    if epoch == 8:
        print(f"Epoch [{epoch+1}/{num_epochs}], Loss: {loss.item()}")
        for i, sd in enumerate(second_derivatives):
            print(f"Second derivative of parameter {i}: {sd}")

# This example demonstrates the computation for one batch


Epoch [9/10], Loss: 0.3488340377807617
Second derivative of parameter 0: tensor([[ 0.1917,  0.1606,  0.6367,  0.5450, -0.0316, -0.0256,  0.1467, -0.0575,
         -0.0162, -0.2791],
        [ 0.3343,  0.2063,  0.9759,  0.8171, -0.0315, -0.0665,  0.3384, -0.0335,
          0.3485, -0.2396],
        [ 0.4028, -0.1661,  0.0139, -0.1765,  0.1429,  0.0278, -0.1045, -0.3401,
          0.3231,  1.0075],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [-0.1126,  0.0144, -0.0685, -0.0149, -0.0290, -0.0138,  0.0501,  0.1138,
          0.0394, -0.1757]], grad_fn=<TBackward0>)
Second derivative of parameter 1: tensor([ 0.0693,  0.2881,  0.3353,  0.0000, -0.0356], grad_fn=<ViewBackward0>)
Second derivative of parameter 2: tensor([[1.0449, 0.8003, 1.2290, 0.0000, 0.0243]], grad_fn=<TBackward0>)
Second derivative of parameter 3: tensor([2.])


In [23]:
for param in model.parameters() :
    print(param.shape)

torch.Size([5, 10])
torch.Size([5])
torch.Size([1, 5])
torch.Size([1])
