**Stragglers and non-monotonic learning dynamics in feed-forward neural
networks**

Simone Ciceri, Lorenzo Cassani, Matteo Osella, Pietro Rotondo, Filippo Valle, Marco Gherardi

    Copyright (C) 2022 Marco Gherardi and Università degli Studi di Milano

    This program is free software: you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation, either version 3 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program.  If not, see <http://www.gnu.org/licenses/>.

# INSTRUCTIONS

Evaluate the SETUP cell first, then the DEMOs.

To run the code on data sets different from those provided (MNIST, KMNIST, FashionMNIST, CIFAR10), change the code in the utility load_data (SETUP cell). The four data tensors (data, labels, test_data, test_labels) must contain the training set and test set. The shape of data and test_data is (number_of_elements, 1, X, Y), where, for instance, X=Y=28 for *MNIST. The shape of labels and test_labels is (number_of_elements). The variable input_size must be the size of each element of the data set. NOTE: the dataset used should be standardized, as is done in the SETUP cell.

# SETUP

In [None]:
import torch
import torch.nn.functional as F
import torchvision
import numpy as np
import matplotlib.pyplot as plt
from IPython import display


######################
# DATASET PARAMETERS #
######################
# Number of elements in the data set
PDATA = 8192
# Data block to use within the full data set
DATA_BLOCK = 1
# Cutoff for the computation of the variance in the standardisation
EPSILON = 1e-9
# The dataset (MNIST, KMNIST, FashionMNIST, CIFAR10)
tdDATASET = torchvision.datasets.MNIST
######################


# Check if GPU is present and set device
# (`device` is the torch.device that load_data and the DEMO cells move tensors and models to)
if torch.cuda.is_available():
  device = torch.device('cuda')
  print("using GPU")
  # IPython shell escape (only valid when run inside a notebook): runs the nvidia-smi command
  !nvidia-smi
else:
  device = torch.device("cpu")
  print("using CPU")

# Download the dataset
dataset = tdDATASET("/content/", train=True, download=True,
                    transform=torchvision.transforms.ToTensor())

# Standardize the data
# The statistics are computed only on the selected data block, i.e. on the
# elements of the training set with index n such that in_block(n) is True.
in_block = lambda n : (DATA_BLOCK-1)*PDATA <= n < DATA_BLOCK*PDATA
# Fetch (and convert to tensors) only the elements inside the block, and do it ONCE:
# the same concatenated tensor is reused for both the mean and the spread.
block_images = torch.cat([dataset[n][0] for n in range(len(dataset)) if in_block(n)])
data_means = torch.mean(block_images, dim=0)
# NOTE: despite its name, data_vars holds the square root of the variance (the standard deviation)
data_vars = torch.sqrt(torch.var(block_images, dim=0))
del block_images  # only needed for the two statistics above
# EPSILON keeps the denominator non-zero where the variance vanishes
transf = lambda x : (x - data_means)/(data_vars+EPSILON)

# Transform shared by the training and the test set:
# conversion to tensor, followed by the standardization defined by transf
standardize = torchvision.transforms.Compose([torchvision.transforms.ToTensor(), transf])

# The training set, served in order (no shuffling) in batches of PDATA elements
dataset = tdDATASET("/content/", train=True, download=True, transform=standardize)
train_loader = torch.utils.data.DataLoader(dataset, batch_size=PDATA, shuffle=False)

# The test set, served as one single shuffled batch containing all its elements
testset = tdDATASET("/content/", train=False, download=True, transform=standardize)
test_loader = torch.utils.data.DataLoader(testset, batch_size=len(testset), shuffle=True)


############ CLASSES ############

# Base NN class
# (computes the metric observables given a latent_representation, a.k.a. the activations of a hidden layer)
class myNN(torch.nn.Module):
  """Base class computing the metric observables of a latent representation.

  Subclasses must override latent_representation(); radii() then measures,
  in that representation, the two sets of points labelled 0 and 1.
  """

  def __init__(self):
    super().__init__()

  def latent_representation(self, X):
    """Return the latent representation of X (to be implemented by subclasses).

    Raises:
      NotImplementedError: always, in this base class. (Returning None here
        would only surface later as an obscure failure inside radii().)
    """
    raise NotImplementedError(
      f"{type(self).__name__} must implement latent_representation()")

  # radii() computes both the radii and the distance between centers of mass
  def radii(self, data, labels):
    """Compute the radii of the two classes and the distance between their centers.

    The rows of `data` with label 0 and with label 1 are mapped through
    latent_representation(), each resulting vector is normalized to unit
    length, and then, for each class:
      - the center is the mean of the normalized vectors;
      - the radius is the root mean squared distance of the vectors from the center.

    Args:
      data: tensor of inputs, indexed along its first dimension by `labels`
      labels: tensor of class labels (only the values 0 and 1 are used)

    Returns:
      (radius, distance), where radius is a tuple of two 0-dim tensors
      (class 0, class 1) and distance is a Python float.
    """
    with torch.no_grad():
      centers, radius = [], []
      for cls in (0, 1):
        points = data[labels==cls]
        npoints = points.shape[0]

        # latent representation, normalized to unit length
        points = torch.nn.functional.normalize(self.latent_representation(points), dim=1)

        # center of mass and radius of this class
        center = torch.mean(points, dim=0)
        centers.append(center)
        radius.append(torch.sqrt(torch.sum(torch.square(points-center))/npoints))

      distance = torch.norm(centers[0]-centers[1]).item()
    return tuple(radius), distance


# Derived class implementing a fully-connected NN
# - in_size: size of input
# - K: number of layers
# - N_hid: number of units in hidden layers (all equal)
# - latent: ordinal number of hidden layer where the observables are computed
class NN_KHL(myNN):
  """Fully-connected network with tanh hidden units and 2 linear output units.

  in_size is the input size, K the number of layers, N_hid the width shared by
  all hidden layers, and latent the ordinal number of the hidden layer whose
  activations are returned by latent_representation().
  """

  def __init__(self, in_size, K, N_hid, latent):
    super().__init__()
    self.latent = latent
    # Sizes at the layer boundaries: input, hidden widths (at least one), 2 outputs
    sizes = [in_size] + [N_hid]*max(K-1, 1) + [2]
    self.layers = torch.nn.ModuleList(
      [torch.nn.Linear(n_in, n_out, bias=True) for n_in, n_out in zip(sizes[:-1], sizes[1:])])

  def latent_representation(self, X):
    # Flatten every input to a vector, then apply the first `latent` layers (tanh after each)
    hidden = X.view(-1, self.layers[0].in_features)
    for idx in range(self.latent):
      hidden = torch.tanh(self.layers[idx](hidden))
    return hidden

  def forward(self, X):
    # Continue from the latent representation through the remaining layers;
    # tanh follows every layer except the last one, whose output is returned as is
    out = self.latent_representation(X)
    last = len(self.layers)-1
    for idx in range(self.latent, last+1):
      out = self.layers[idx](out)
      if idx != last:
        out = torch.tanh(out)
    return out


############ UTILITIES ############

# Applies Gaussian noise to a tensor
def apply_noise(noise, x):
  """Return a sample of x plus independent Gaussian noise with standard deviation `noise`."""
  std = noise*torch.ones_like(x)
  return torch.normal(mean=x, std=std)


# Returns data points and labels of the training and test set
def load_data(data_block):
  """Load one block of the training set and the whole test set onto `device`.

  Args:
    data_block: 1-based ordinal number of the batch of train_loader to return

  Returns:
    (data, labels, test_data, test_labels). Labels are binarized (label modulo 2).
    When the dataset is CIFAR10 the three colour channels are summed into a single
    channel, which is kept as a dimension of size 1, so that data.shape[2] and
    data.shape[3] are the image sizes as for the single-channel datasets.

  Raises:
    ValueError: if data_block is smaller than 1
    StopIteration: if train_loader has fewer than data_block batches
  """
  if data_block < 1:
    raise ValueError(f"data_block must be >= 1, got {data_block}")

  # Advance the (unshuffled) training loader up to the requested block
  loader_it = iter(train_loader)
  for _ in range(data_block):
    data, labels = next(loader_it)
  test_data, test_labels = next(iter(test_loader))

  data, labels = data.to(device), labels.to(device)
  test_data, test_labels = test_data.to(device), test_labels.to(device)

  # When using CIFAR10, convert to greyscale (1 channel)
  # unsqueeze(1) restores the channel dimension removed by the sum:
  # without it the tensors would be 3-dimensional and data.shape[3] would not exist
  if tdDATASET == torchvision.datasets.CIFAR10:
    data = (data[:,0,:,:]+data[:,1,:,:]+data[:,2,:,:]).unsqueeze(1)
    test_data = (test_data[:,0,:,:]+test_data[:,1,:,:]+test_data[:,2,:,:]).unsqueeze(1)

  # Binarize class labels (NOTE: labels are 0,1 here but +1,-1 in the manuscript)
  labels %= 2
  test_labels %= 2

  return data, labels, test_data, test_labels


# Trains a model and returns errors, the metric quantities, misclassified examples at each epoch
def train_and_measure(model, data, labels, test_data, test_labels, optimizer, criterion, epochs):
  """Train `model` with full-batch steps and record observables after every step.

  Args:
    model: network to train; must also provide radii(data, labels)
    data, labels: training inputs and their labels (0 or 1)
    test_data, test_labels: test inputs and their labels (0 or 1)
    optimizer: optimizer acting on the parameters of model
    criterion: loss function called as criterion(model(data), labels)
    epochs: number of training steps

  Returns:
    (results_run, misclassified_examples_list):
      - results_run has one row per epoch:
        [epoch, train_error, test_error, radius_0, radius_1, distance]
      - misclassified_examples_list has one boolean tensor per epoch, marking
        the training examples misclassified after that epoch's step
  """
  results_run = []
  misclassified_examples_list = []

  for epoch in range(epochs):
    optimizer.zero_grad()
    loss = criterion(model(data), labels)
    loss.backward()
    optimizer.step()

    # Predictions after the update: computed once, and without building the autograd graph
    with torch.no_grad():
      train_predictions = torch.argmax(model(data), dim=1)
      test_predictions = torch.argmax(model(test_data), dim=1)

    # Compute errors and metric observables
    misclassified_examples = train_predictions != labels
    train_error = torch.sum(misclassified_examples).item()/data.shape[0]
    # normalized by the size of the test set actually passed in (not by a global)
    test_error = torch.sum(test_predictions != test_labels).item()/test_data.shape[0]
    radii, distance = model.radii(data, labels)
    results_run.append([epoch, train_error, test_error, radii[0].item(), radii[1].item(), distance])
    misclassified_examples_list.append(misclassified_examples)

  return results_run, misclassified_examples_list


# Trains and stops when the inversion point is reached
def train_stop_at_inversion(model, data, labels, optimizer, criterion, min_epochs=20, max_epochs=None):
  """Train `model` in place until the inversion point is reached.

  The inversion point is reached when the first radius returned by
  model.radii(data, labels) stops decreasing from one epoch to the next.

  Args:
    model: network to train; must also provide radii(data, labels)
    data, labels: training inputs and their labels
    optimizer: optimizer acting on the parameters of model
    criterion: loss function called as criterion(model(data), labels)
    min_epochs: number of initial epochs during which training never halts
      (avoids being fooled by initial fluctuations)
    max_epochs: optional upper bound on the number of epochs; None (default)
      means no bound, i.e. training continues for as long as the radius decreases

  Returns:
    None (the trained state is left in `model`).
  """
  radius, radius_prev = 0, 0
  count = 0

  # This cycle trains until the inversion point is reached.
  # The inversion point is reached when the first radius starts increasing.
  while radius<radius_prev or count<min_epochs:
    # Safeguard against a radius that never stops decreasing
    if max_epochs is not None and count>=max_epochs:
      break

    count += 1
    optimizer.zero_grad()
    loss = criterion(model(data), labels)
    loss.backward()
    optimizer.step()

    radius_prev = radius
    (radius, _), _ = model.radii(data, labels)


# DEMO

## 1 - Non-monotonic dynamics of the metric quantities and invariance of the inversion point
Computes the metric quantities as functions of the training error, throughout the training dynamics.

**Expected output** - A plot showing the nonmonotonic dynamics of the metric quantities, where the inversion point is stable across different runs. (Training error on the x axis, manifold radii and inter-manifold distance on the y axis.) Similar results can be obtained for different architectures, by changing the parameters DEPTH and WIDTH.

**Corresponding figures in the manuscript** - 1(c) and 1(d)

**Expected run time** - On a Tesla T4, around 2 seconds per run (or ~20 seconds with N_RUNS=10 as below). On the CPU, around 10x longer. (These figures are for a 2-layer NN with 20 hidden units and 500 epochs.)

In [None]:
#################################
# MODEL AND TRAINING PARAMETERS #
#################################
DEPTH = 2
WIDTH = 20
LATENT = 1 # ordinal number of hidden layer where the observables are computed
N_RUNS = 10 # number of runs from independent initializations
EPOCHS = 500
LEARNING_RATE = 0.1
OPTIMIZER = torch.optim.SGD
#################################

# Training and test sets; every element is flattened to input_size values by the network
data, labels, test_data, test_labels = load_data(DATA_BLOCK)
input_size = data.shape[2]*data.shape[3] # 32*32 for CIFAR, 28*28 for *MNIST

# Containers for the results
radii_data, distances_data, losses_data = [], [], []
results = []

# N_RUNS independent training runs, each starting from a freshly initialized model
for niter in range(N_RUNS):
  model = NN_KHL(input_size, DEPTH, WIDTH, LATENT).to(device)
  optimizer = OPTIMIZER(model.parameters(), lr=LEARNING_RATE)
  criterion = torch.nn.CrossEntropyLoss()

  # Train and keep only the per-epoch observables
  results_run, _ = train_and_measure(model, data, labels, test_data, test_labels, optimizer, criterion, EPOCHS)
  results.append(results_run)

# Plot results: for every run, each observable as a function of the training error
arr_res = np.array(results)
curves = (
  (3, "#3a5a4070"), # first radius VS training error
  (4, "#67671570"), # second radius VS training error
  (5, "#06467570"), # distance VS training error
)
for run in arr_res:
  for column, color in curves:
    plt.plot(run[:,1], run[:,column], color=color)

## 2 - Removal of stragglers from the training set eliminates the nonmonotonicity

Prunes the training set by removing stragglers, defined as those elements of the training set that are still misclassified at the inversion point.

**Expected output** - A plot showing that the dynamics on the pruned training set is monotonic: no inversion point exists. (Epochs on the x axis, manifold radii and inter-manifold distance on the y axis.)

**Corresponding figures in the manuscript** - 2(a)

**Expected run time** - On a Tesla T4, around 3.5 seconds per run (or ~35 seconds with N_RUNS=10 as below). On the CPU, around 10x longer. (These figures are for a 2-layer NN with 20 hidden units and 300 epochs.)

In [None]:
#################################
# MODEL AND TRAINING PARAMETERS #
#################################
DEPTH = 2
WIDTH = 20
LATENT = 1 # ordinal number of hidden layer where the observables are computed
N_RUNS = 10 # number of runs from independent initializations
EPOCHS = 300
LEARNING_RATE = 0.1
OPTIMIZER = torch.optim.SGD
#################################


results = []

# N_RUNS independent experiments
for niter in range(N_RUNS):

  # The data are reloaded at every run, because the training set is pruned below
  data, labels, test_data, test_labels = load_data(DATA_BLOCK)
  input_size = data.shape[2]*data.shape[3] # 32*32 for CIFAR, 28*28 for *MNIST

  # Two models sharing the same initial parameters:
  # `model` identifies the stragglers, `model2` is trained on the pruned training set
  # (NOTE: using differently initialized models gives almost identical results)
  model = NN_KHL(input_size, DEPTH, WIDTH, LATENT).to(device)
  model2 = NN_KHL(input_size, DEPTH, WIDTH, LATENT).to(device)
  initial_state = model.state_dict()
  model2.load_state_dict(initial_state)

  criterion = torch.nn.CrossEntropyLoss()

  # First training run, halted at the inversion point
  optimizer = OPTIMIZER(model.parameters(), lr=LEARNING_RATE)
  train_stop_at_inversion(model, data, labels, optimizer, criterion)

  # The stragglers are the training points still misclassified at the inversion point
  stragglers = torch.argmax(model(data), dim=1) != labels

  # Keep only the non-stragglers
  keep = ~stragglers
  data, labels = data[keep], labels[keep]

  # Second training run, from the same initial state, on the pruned training set
  optimizer = OPTIMIZER(model2.parameters(), lr=LEARNING_RATE)
  results_run, _ = train_and_measure(model2, data, labels, test_data, test_labels, optimizer, criterion, EPOCHS)
  results.append(results_run)


# Plot results: for every run, each observable as a function of the epoch
arr_res = np.array(results)
curves = (
  (3, "#3a5a4070"), # first radius VS epochs
  (4, "#67671570"), # second radius VS epochs
  (5, "#06467570"), # distance VS epochs
)
for run in arr_res:
  for column, color in curves:
    plt.plot(run[:,0], run[:,column], color=color)

## 3 - Removing stragglers from the training set affects the generalisation error

Prunes the training set by removing the sets S(t), containing the misclassified examples at epoch t, and retrains.

**Expected output** - A plot showing the test error (y axis) obtained after training on the pruned training set. On the x axis is the training error reached at the epoch t that defines S(t).

**Corresponding figures in the manuscript** - 2(c)

**Expected run time** - On a Tesla T4, around 6.5 minutes with N_RUNS=10 as below.

In [None]:
#################################
# MODEL AND TRAINING PARAMETERS #
#################################
DEPTH = 2
WIDTH = 20
LATENT = 1 # ordinal number of hidden layer where the observables are computed
N_RUNS = 10 # number of runs from independent initializations
EPOCHS = 300 # maximum number of epochs reached in the identification of misclassified examples
LOG_EPOCH_SKIP = 0.1 # logarithmic spacing between identifications of misclassified examples used for measuring the test error
TEST_EPOCHS = 500 # number of epochs to measure the test error
NOISE_SIGMA = 1. # standard deviation of the noise added to the test set
LEARNING_RATE = 0.2
OPTIMIZER = torch.optim.SGD
#################################


results = []

# Perform N_RUNS independent experiments
for niter in range(N_RUNS):

  # Load data (the test set is corrupted with Gaussian noise)
  dataALL, labelsALL, test_data, test_labels = load_data(DATA_BLOCK)
  test_data = apply_noise(NOISE_SIGMA, test_data)
  input_size = dataALL.shape[2]*dataALL.shape[3] # 32*32 for CIFAR, 28*28 for *MNIST

  # Use two identically initialized models, one for identifying stragglers, the other to train on the modified training set
  # (NOTE: using differently initialized models gives almost identical results)
  # `model` is never trained in this cell: it only holds the initial parameters,
  # which are copied into model2 before each training.
  model = NN_KHL(input_size, DEPTH, WIDTH, LATENT).to(device)
  model2 = NN_KHL(input_size, DEPTH, WIDTH, LATENT).to(device)
  initial_state = model.state_dict()
  model2.load_state_dict(initial_state)

  # Identify the misclassified examples at each epoch
  optimizer = OPTIMIZER(model2.parameters(), lr=LEARNING_RATE)
  criterion = torch.nn.CrossEntropyLoss()
  results_identification_run, misclassified_examples = train_and_measure(model2, dataALL, labelsALL, test_data, test_labels, optimizer, criterion, EPOCHS)

  # One entry per pruned training set: (epoch, training error at that epoch, final test error)
  results_test_epoch = []

  # Cycle over the (logarithmically spaced) epochs defining the pruned training sets
  last_log_epoch = 0.
  for epoch, mask in enumerate(misclassified_examples):

    if np.log(epoch+1)-last_log_epoch > LOG_EPOCH_SKIP:
      last_log_epoch = np.log(epoch+1)

      # Prune the training set by removing the stragglers
      data, labels = dataALL[torch.logical_not(mask)], labelsALL[torch.logical_not(mask)]

      # restore initial state
      model2.load_state_dict(initial_state)

      optimizer = OPTIMIZER(model2.parameters(), lr=LEARNING_RATE)
      criterion = torch.nn.CrossEntropyLoss()

      # Train from scratch
      results_run, _ = train_and_measure(model2, data, labels, test_data, test_labels, optimizer, criterion, TEST_EPOCHS)
      results_test_epoch.append( (epoch, results_identification_run[epoch][1], results_run[-1][2]) )

  results.append(results_test_epoch)

  # Plot all results up to this run.
  # The figure is cleared first: otherwise the curves of the previous runs
  # would be drawn again on top of the copies already present in the figure.
  plt.clf()
  for curve in results:
    curve = np.array(curve)
    plt.plot(curve[:,1], curve[:,2])
  # Refresh the cell output once per run, after all curves have been drawn
  display.clear_output(wait=True)
  display.display(plt.gcf())

# Plot all results (again on a cleared figure, to avoid duplicated curves)
display.clear_output()
plt.clf()
for curve in results:
  curve = np.array(curve)
  plt.plot(curve[:,1], curve[:,2])

## 4 - Stragglers are the most stable set of misclassified points at any given epoch

Computes the z-score, measuring the stability, between two independent runs, of the misclassified points evaluated at a given epoch t. For values of t in a specified range, performs N_RUNS iterations, training each time from scratch. The z-score is computed by averaging over the iterations.

**Expected output** - A plot showing the z-score (y axis) as a function of the (average) training error (x axis) reached at epoch t.

**Corresponding figures in the manuscript** - 2(d)

**Expected run time** - On a Tesla T4, around 40 minutes with parameters as below

In [None]:
#################################
# MODEL AND TRAINING PARAMETERS #
#################################
DEPTH = 2
WIDTH = 20
LATENT = 1 # ordinal number of hidden layer where the observables are computed
N_RUNS = 500 # number of runs from independent initializations
MIN_EPOCH = 5 # lowest epoch in the cycle over the epochs defining the misclassified examples
MAX_EPOCH = 70 # largest epoch
SKIP_EPOCH = 5 # increment
LEARNING_RATE = 0.2
OPTIMIZER = torch.optim.SGD
MAX_COMMON = 200 # maximum number of common examples considered in the hypergeometric model
#################################


# To compute the hypergeometric PDF
from scipy.stats import hypergeom


# Load data
data, labels, test_data, test_labels = load_data(DATA_BLOCK)
#test_data = apply_noise(NOISE_SIGMA, test_data)
input_size = data.shape[2]*data.shape[3] # 32*32 for CIFAR, 28*28 for *MNIST

# Population size of the hypergeometric model: the number of training examples actually loaded
n_train = data.shape[0]

# Support on which the hypergeometric PDF is evaluated (numbers of common instances)
x = np.arange(0,MAX_COMMON)

# One entry per epoch: [average training error, z-score]
results = []

# Cycle over the epoch at which misclassified points are evaluated
for misclassified_epoch in range(MIN_EPOCH, MAX_EPOCH, SKIP_EPOCH):

  # Will contain the number of common instances in the two realizations
  number_common = []

  # To measure the number of common instances in the hypergeometric model
  random_common = np.zeros(x.shape) # Will contain the probability distribution function

  # Training errors of ALL realizations of ALL iterations at this epoch.
  # (It must not be reset inside the cycle below, otherwise the average
  # would only include the two realizations of the last iteration.)
  training_errors = []

  # Cycle on N_RUNS independent iterations
  for niter in range(N_RUNS):

    misclassified_examples = []

    # Two independent realizations
    for _ in range(2):

      model = NN_KHL(input_size, DEPTH, WIDTH, LATENT).to(device)
      optimizer = OPTIMIZER(model.parameters(), lr=LEARNING_RATE)
      criterion = torch.nn.CrossEntropyLoss()

      # Train for misclassified_epoch epochs
      results_realization, misclassified_examples_realization = train_and_measure(model, data, labels, test_data, test_labels, optimizer, criterion, misclassified_epoch)

      # The training errors and the list of misclassified examples for the 2 realizations
      training_errors.append(results_realization[-1][1])
      misclassified_examples.append(misclassified_examples_realization[-1])

    # Measure the number of common instances between misclassified examples and compute the PDF in the hypergeometric model
    first, second = misclassified_examples
    number_common.append(torch.sum(first & second).item())
    # Overlap of two random subsets (with the observed sizes) of a population of n_train elements
    random_common += hypergeom(n_train, torch.sum(first).item(), torch.sum(second).item()).pmf(x)

  # Distance of the observed mean overlap from the mean of the random model, in units of the observed standard deviation
  zscore = (np.mean(number_common)-np.average(x, weights=random_common))/np.std(number_common)
  results.append( [np.mean(training_errors), zscore] )

# Plot the results
# (stored in arr_res, so that the training tensor `data` is not overwritten)
arr_res = np.array(results)
plt.scatter(arr_res[:,0], arr_res[:,1])