In [1]:
import numpy as np
import torch
import torch.autograd as autograd

In [2]:
sample_size = 10000
sample_dim = 10
noise_scale = 1.0
train_split = 0.8
theta_range = 10
data_range = 100
epoch = 50000
lr = 1e-5

In [3]:
def get_data(sample_size, sample_dim, noise_scale, train_split):
    theta = np.random.uniform(-theta_range, theta_range, sample_dim)
    X = np.random.uniform(-data_range, data_range, (sample_size, sample_dim))
    y = np.dot(X, theta) + noise_scale * np.random.normal(0, 1, sample_size)
    train_size = int(sample_size * train_split)
    X_train, y_train = X[:train_size], y[:train_size]
    X_test, y_test = X[train_size:], y[train_size:]
    print("Actual theta: ", theta)
    return X_train, y_train, X_test, y_test

In [4]:
X_train, y_train, X_test, y_test = get_data(sample_size, sample_dim, noise_scale, train_split)
X_train, y_train, X_test, y_test = torch.tensor(X_train), torch.tensor(y_train), torch.tensor(X_test), torch.tensor(y_test)

def get_loss(theta, X=X_test, y=y_test):
    return torch.mean((torch.matmul(X, theta) - y)**2)

Actual theta:  [ 8.28176164 -5.65855218 -1.59705275 -0.69252703  1.72765861 -5.55793897
 -8.37143006  5.51600464 -5.68358575 -4.99899531]


In [5]:
theta = np.random.random(sample_dim)
theta = torch.tensor(theta, requires_grad=True)
for _ in range(epoch):
    loss = get_loss(theta, X_train, y_train)
    loss.backward()
    theta.data -= lr * theta.grad
    theta.grad.zero_()
print(f"Estimated theta: {theta.data}")
hessian = autograd.functional.hessian(get_loss, theta)
#print(f"Hessian: {hessian}")

Estimated theta: tensor([ 8.2816, -5.6583, -1.5969, -0.6924,  1.7276, -5.5581, -8.3716,  5.5158,
        -5.6839, -4.9991], dtype=torch.float64)


In [6]:
all_entries = hessian.abs().sum()
diagonal_entries = torch.diag(hessian).abs().sum()
print(f"Sum of all entries: {all_entries}")
print(f"Sum of diagonal entries: {diagonal_entries}")
print(f"Ratio: {diagonal_entries / all_entries}")
# Observation: The ratio is proportional to sample_size/sample_dim.

Sum of all entries: 77716.0611666939
Sum of diagonal entries: 66605.85415767039
Ratio: 0.857041043482722
