# Figure 6: MNIST and CIFAR-10 learning dynamics

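This notebook reproduces Figure 6 of the paper "Learning dynamics of linear denoising autoencoders" (ICML 2018). For MNIST and CIFAR-10, it compares the theoretical learning dynamics of a linear autoencoder with those of trained models in three settings: no regularisation, weight decay ($\gamma$), and Gaussian input noise ($\sigma^2$).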

In [8]:
import numpy as np
import matplotlib.pyplot as plt
import torch
from torchvision import transforms
from torchvision.datasets import MNIST
from torchvision.datasets import CIFAR10
from collections import OrderedDict

# custom imports
from src.linear_ae_net.linear_ae_net import LinearAutoEncoder
from src.linear_ae_net.dynamics import theoretical_learning_dynamics

## --- MNIST ---

### Load MNIST data

In [9]:
# cast to tensor (the transform is only applied when indexing the dataset; the
# raw pixel array is read directly below)
trans = transforms.Compose([transforms.ToTensor()])

# if not exist, download mnist dataset
train_set = MNIST(root="../data", train=True, transform=trans, download=True)

# Recent torchvision releases store the raw images in `.data` and keep
# `.train_data` only as a deprecated alias; fall back to it on old releases.
raw_images_mnist = train_set.data if hasattr(train_set, "data") else train_set.train_data

# scale pixel values from [0, 255] to [0, 1] and flatten each image into a row
x_train = np.asarray(raw_images_mnist).astype('float32') / 255.
x_train_mnist = x_train.reshape(len(x_train), -1)
x_train_mnist.shape

(60000, 784)


### Train models

In [10]:
# set parameters
num_samples = 500   # number of training images used for every model
epochs = 10
lr = 0.01
reg_param = 0.5     # weight-decay strength (labelled gamma in the plots)
var_param = 0.5     # Gaussian input-noise parameter (labelled sigma^2 in the plots)
# one model per setting: (no regularisation, weight decay only, noise only)
reg = [0.0, reg_param, 0.0]
var = [0.0, 0.0, var_param]
num_trials = 3
hidden_dim = 256
assert len(reg) == len(var) == num_trials, "one (reg, var) pair per trial expected"

mnist_models = []

# convert to pytorch tensors; unlike torch.from_numpy, torch.as_tensor also
# accepts a tensor (returning it unchanged), so re-running this cell is safe
x_train_mnist = torch.as_tensor(x_train_mnist)

# set seed
np.random.seed(123)
torch.manual_seed(321)

# train one linear autoencoder per (reg, var) setting
for t in range(num_trials):
    laeModel = LinearAutoEncoder()
    laeModel.train(x_train_mnist[:num_samples], None, input_dim=x_train_mnist.shape[1],
                   n_epoch=epochs, hidden_dim=hidden_dim, learning_rate=lr,
                   reg_param=reg[t], noise='Gaussian', noise_scale=var[t], verbose=True)
    mnist_models.append(laeModel)

iteration:  0 training loss:  42.19104296875
iteration:  0 training loss:  42.19203712843126
iteration:  0 training loss:  42.191035160214575


### Compute theoretical dynamics

In [11]:
# compute MNIST dynamics: closed-form predictions for the same three settings
# that were trained above (plain, weight decay only, input noise only)
x_train_mnist_np = x_train_mnist.cpu().numpy()
x_subset_mnist = x_train_mnist_np[:num_samples, :]
theory_kwargs_mnist = dict(n_epoch=epochs, lr=lr)

theoretical_dynamics = theoretical_learning_dynamics(
    x_subset_mnist, x_subset_mnist, var=0, reg=0, **theory_kwargs_mnist)
theoretical_dynamics_reg = theoretical_learning_dynamics(
    x_subset_mnist, x_subset_mnist, var=0, reg=reg_param, **theory_kwargs_mnist)
theoretical_dynamics_noise = theoretical_learning_dynamics(
    x_subset_mnist, x_subset_mnist, var=var_param, reg=0, **theory_kwargs_mnist)

### Plot results

In [12]:
# create dynamics plot: top row = MNIST (this cell), bottom row = CIFAR-10
# (filled in by a later cell); left column = weight decay, right = noise
slices = (0, 3, 7, 15, 31)  # column indices of the dynamics arrays to draw
fig, [(ax1, ax2), (ax3, ax4)] = plt.subplots(2, 2, figsize=(12, 8), sharey='row', sharex='col')
axes = [ax1, ax2]

# plot theoretical dynamics. Raw strings are required: '\g', '\s' and '\c'
# are not valid Python escapes (SyntaxWarning on Python 3.12+).
ax1.plot(theoretical_dynamics[:, slices], c='blue',
         label=r'Theory ($\gamma = 0$)')
ax1.plot(theoretical_dynamics_reg[:, slices], c='orange',
         label=r'Theory ($\gamma = $' + str(reg_param) + ')')
ax2.plot(theoretical_dynamics[:, slices], c='blue',
         label=r'Theory ($\sigma^2 = 0$)')
ax2.plot(theoretical_dynamics_noise[:, slices], c='darkgreen',
         label=r'Theory ($\sigma^2 = $' + str(var_param) + ')')

# get actual dynamics ("strenghts" is the attribute name used by LinearAutoEncoder)
actual_dynamics = mnist_models[0].strenghts.cpu().numpy()
actual_dynamics_reg = mnist_models[1].strenghts.cpu().numpy()
actual_dynamics_noise = mnist_models[2].strenghts.cpu().numpy()

# plot simulated dynamics
# NOTE(review): assumes the models record strengths every 10 epochs, so that
# len(x_p) == actual_dynamics.shape[0] -- confirm against LinearAutoEncoder.
x_p = np.arange(0, epochs+1, 10)
for s in slices:
    ax1.scatter(x_p, actual_dynamics[:, s], c='blue',
                marker='x', label=r'Actual ($\gamma = 0$)')
    ax1.scatter(x_p, actual_dynamics_reg[:, s], c='orange',
                marker='x', label=r'Actual ($\gamma = $' + str(reg_param) + ')')
    ax2.scatter(x_p, actual_dynamics[:, s], c='blue',
                marker='x', label=r'Actual ($\sigma^2 = 0$)')
    ax2.scatter(x_p, actual_dynamics_noise[:, s], c='darkgreen',
                marker='x', label=r'Actual ($\sigma^2 = $' + str(var_param) + ')')

# set plot titles and axis labels. Both column titles are set here because the
# figure is rendered at the end of this cell.
ax1.set_ylabel(r'$w_2 \cdot w_1$', fontsize=15)
ax1.set_title('Weight decay')
ax2.set_title('Noise')
ax2.yaxis.set_label_position('right')
ax2.set_ylabel('MNIST')

# remove duplicate labels: every slice adds the same label again, so keep one
# legend entry per label text
for ax in axes:
    handles, labels = ax.get_legend_handles_labels()
    by_label = OrderedDict(zip(labels, handles))
    ax.legend(by_label.values(), by_label.keys(), loc='lower right')

## --- CIFAR-10 ---

### Load CIFAR-10 data

In [13]:
# cast to tensor (the transform is only applied when indexing the dataset; the
# raw pixel array is read directly below)
trans = transforms.Compose([transforms.ToTensor()])
train_set = CIFAR10(root="../data", train=True, transform=trans, download=True)

# torchvision >= 0.3 stores the raw images in `.data`; `.train_data` only exists
# in older releases, so fall back to it for backward compatibility.
raw_images_cifar10 = train_set.data if hasattr(train_set, "data") else train_set.train_data

# scale pixel values from [0, 255] to [0, 1] and flatten each image into a row
x_train = np.asarray(raw_images_cifar10).astype('float32') / 255
x_train_cifar10 = x_train.reshape(len(x_train), -1)
print('x_train shape:', x_train_cifar10.shape)

Files already downloaded and verified
x_train shape: (50000, 3072)


### Train models

In [None]:
# set parameters
num_samples = 300   # number of training images used for every model
epochs = 10
lr = 0.001
reg_param = 0.5     # weight-decay strength (labelled gamma in the plots)
var_param = 0.5     # Gaussian input-noise parameter (labelled sigma^2 in the plots)
# one model per setting: (no regularisation, weight decay only, noise only)
reg = [0.0, reg_param, 0.0]
var = [0.0, 0.0, var_param]
num_trials = 3
hidden_dim = 512
assert len(reg) == len(var) == num_trials, "one (reg, var) pair per trial expected"

cifar10_models = []

# convert to pytorch tensors; unlike torch.from_numpy, torch.as_tensor also
# accepts a tensor (returning it unchanged), so re-running this cell is safe
x_train_cifar10 = torch.as_tensor(x_train_cifar10)

# set seed
np.random.seed(123)
torch.manual_seed(321)

# train one linear autoencoder per (reg, var) setting
for t in range(num_trials):
    laeModel = LinearAutoEncoder()
    laeModel.train(x_train_cifar10[:num_samples], None, input_dim=x_train_cifar10.shape[1],
                   n_epoch=epochs, hidden_dim=hidden_dim,
                   learning_rate=lr, reg_param=reg[t],
                   noise='Gaussian', noise_scale=var[t], verbose=True)
    cifar10_models.append(laeModel)

iteration:  0 training loss:  434.49658854166665
iteration:  0 training loss:  434.50452747529994
iteration:  0 training loss:  434.49664074591044


### Compute theoretical dynamics

In [None]:
# compute CIFAR-10 dynamics: closed-form predictions for the same three
# settings that were trained above (plain, weight decay only, input noise only)
U0_CIFAR10 = 1.5e-6  # value passed as `u0` to every call below
x_train_cifar10_np = x_train_cifar10.cpu().numpy()
x_subset_cifar10 = x_train_cifar10_np[:num_samples, :]
theory_kwargs_cifar10 = dict(n_epoch=epochs, lr=lr, u0=U0_CIFAR10)

theoretical_dynamics = theoretical_learning_dynamics(
    x_subset_cifar10, x_subset_cifar10, var=0, reg=0, **theory_kwargs_cifar10)
theoretical_dynamics_reg = theoretical_learning_dynamics(
    x_subset_cifar10, x_subset_cifar10, var=0, reg=reg_param, **theory_kwargs_cifar10)
theoretical_dynamics_noise = theoretical_learning_dynamics(
    x_subset_cifar10, x_subset_cifar10, var=var_param, reg=0, **theory_kwargs_cifar10)

### Plot results

In [None]:
# plot theoretical dynamics on the bottom (CIFAR-10) row of the figure
ax3.plot(theoretical_dynamics[:, slices], c='blue')
ax3.plot(theoretical_dynamics_reg[:, slices], c='orange')
ax4.plot(theoretical_dynamics[:, slices], c='blue')
ax4.plot(theoretical_dynamics_noise[:, slices], c='darkgreen')

# get actual dynamics ("strenghts" is the attribute name used by LinearAutoEncoder)
actual_dynamics = cifar10_models[0].strenghts.cpu().numpy()
actual_dynamics_reg = cifar10_models[1].strenghts.cpu().numpy()
actual_dynamics_noise = cifar10_models[2].strenghts.cpu().numpy()

# plot simulated dynamics on the CIFAR-10 panels (ax3/ax4), not the MNIST ones
# NOTE(review): with epochs=10 a step of 100 yields a single x position; this
# assumes the models record strengths every 100 epochs so that
# len(x_p) == actual_dynamics.shape[0] -- confirm against LinearAutoEncoder.
x_p = np.arange(0, epochs+1, 100)
for s in slices:
    ax3.scatter(x_p, actual_dynamics[:, s], c='blue', marker='x')
    ax3.scatter(x_p, actual_dynamics_reg[:, s], c='orange', marker='x')
    ax4.scatter(x_p, actual_dynamics[:, s], c='blue', marker='x')
    ax4.scatter(x_p, actual_dynamics_noise[:, s], c='darkgreen', marker='x')

# set plot titles and axis labels (raw string: '\c' is not a valid escape)
ax2.set_title('Noise')
ax3.set_ylabel(r'$w_2 \cdot w_1$', fontsize=15)
ax3.set_xlabel('t (epoch)', fontsize=10)
ax4.set_xlabel('t (epoch)', fontsize=10)
ax4.yaxis.set_label_position('right')
ax4.set_ylabel('CIFAR-10')

# The figure was created in an earlier cell. With the default inline backend it
# was already rendered and closed there, so plt.show() would display nothing.
# Ending the cell with the Figure object renders the completed 2x2 figure.
fig