Start out with vanilla training of VGG11 on CIFAR10, with some learning rate and number of epochs.

In [None]:
from copy import deepcopy
from vgg_linear_fit_utils import *
from matplotlib import pyplot as plt


# --- Experiment configuration ---
use_random_labels = False  # forwarded unchanged to get_dataloader for both splits
model_name = 'vgg11_bn'  # 'resnet20', 'vgg13_bn',...
epochs = 30
learning_rate = 0.01

# Build the network selected by `model_name`.
model = get_model(model_name)

# Training split drives optimisation; the held-out split is used for evaluation.
# The test loader must request train=False: with train=True it would simply
# re-serve the training data, so reported "test" metrics would not measure
# generalisation.
train_loader = get_dataloader(use_random_labels, train=True)
test_loader = get_dataloader(use_random_labels, train=False)

# Snapshot the first-layer weights before any optimisation step so they can be
# subtracted from the trained weights later. deepcopy protects the snapshot in
# case get_first_layer_weights returns a reference to the live parameters.
model_initialization = deepcopy(get_first_layer_weights(model))

train_model(model, train_loader, test_loader, num_epochs=epochs, lr=learning_rate)

Calculate the patch PCA and get the eigenvalues $\lambda_i^2$ and the components $\{u_i\}_{i=1}^p$, where $p$ is the patch dimension.

In [None]:
# Stack the inputs of every training batch into a single tensor.
images = torch.cat([b[0] for b in train_loader])

# A random subset of 10K images is enough for the PCA; all 50K are not needed.
keep = torch.randperm(images.shape[0])[:10000]
images = images[keep]
num_images = images.size(0)

# Slide a size-3, stride-1 window along dims 1, 2 and 3 in turn.
# NOTE(review): assumes images are (N, 3, H, W), so each window is a
# 3x3x3 patch (27 values) -- TODO confirm against get_dataloader.
patches = images
for dim in (1, 2, 3):
    patches = patches.unfold(dim, 3, 1)
patches = patches.to(torch.float64)

# One row per patch: merge the three window dims, then all leading dims.
patch_matrix = patches.flatten(-3).flatten(0, -2)

# torch.pca_lowrank returns the tuple (U, S, V); S is used as the lambdas and
# V (the last entry) as the principal components.
pca = torch.pca_lowrank(patch_matrix, q=27)
lambdas, components = pca[1], pca[-1]
lambda_squared = lambdas * lambdas

Fit the theoretical formula by searching over the number of steps of gradient descent $t$. Plot the theoretical energy profile vs the model energy profile. 
Note that if we assume the model was initialized sufficiently close to $0$, then we should subtract the initialization from the first layer's weights.

In [None]:
# First-layer weights after training.
trained_first_layer = get_first_layer_weights(model)

# Energy profile of the trained first layer, computed against the PCA
# components with the initial weights subtracted.
model_energy_profile = calc_energy_profile(
    trained_first_layer,
    model_initialization,
    components,
    subtract_init=True,
    normalize=True,
)

# Fit the theoretical formula to the measured profile; returns the fitted
# profile together with the best correlation achieved.
theoretic_profile, max_correlation = fit_formula_to_model(
    model_energy_profile,
    lambdas,
    lambda_squared,
    num_images,
)

# Plot both profiles rescaled to a peak value of 1 so their shapes are comparable.
fig, ax = plt.subplots()
ax.plot(model_energy_profile / model_energy_profile.max(), label=model_name)
ax.plot(theoretic_profile / theoretic_profile.max(), label="Formula Fit")
ax.set_title(f"Best Theoretical Fit\ncorrelation={max_correlation:.2f}")
ax.legend()
plt.show()