<a href="https://colab.research.google.com/github/CalculatedContent/WeightWatcher-Examples/blob/main/WW-MLP3-BatchSizes.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

This notebook shows that
# Alpha decreases with increasing Test Accuracy

## But if alpha < 2, the model may be overtrained

### This is the simplest possible modern MLP, 3 layers trained on MNIST

- Using Keras, following this sample code:
    - https://www.kaggle.com/sathianpong/3-ways-to-implement-mlp-with-keras

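- trained with early stopping on the training loss,
    - up to 1000 epochs (but usually < 100)
    - `min_delta = 0.0001`, `patience = 3`
- all layers are frozen, except for layer 1 (the first Dense layer, whose weight matrix is 784 x 300); note that `build_model()` as written below leaves every layer trainable
- the batch size is varied to induce the effect at very small batch sizes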


Note: To apply the theory in this way, it is necessary to:
- train all models for as long / to as high an accuracy as possible
- start from the exact same initial conditions

We don't necessarily need to freeze the other layers, but this is done for clarity here.  







## Keras Sequential API

In [None]:
import os
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import seaborn as sns
import tensorflow as tf
from IPython.display import Image, clear_output

# Apply seaborn's default plot theme. set_theme() is the preferred name;
# sns.set() is only a deprecated alias for it.
sns.set_theme()
tf.__version__

'2.15.0'

In [None]:
# https://stackoverflow.com/questions/32419510/how-to-get-reproducible-results-in-keras

# https://tensorflow.google.cn/api_docs/python/tf/config/experimental/enable_op_determinism

%env CUBLAS_WORKSPACE_CONFIG=:4096:8

import random


def reset_random_seeds(seed_value=1):
    """Seed every RNG this notebook touches and force deterministic TF ops.

    Called before building each model so that all runs start from identical
    initial conditions.

    Args:
        seed_value: integer seed shared by Python, NumPy and TensorFlow.
    """
    # NOTE: PYTHONHASHSEED is read when an interpreter starts, so this only
    # influences processes launched afterwards, not the running kernel.
    os.environ['PYTHONHASHSEED'] = str(seed_value)

    # TensorFlow global seed, then Keras' helper (which seeds Python, NumPy and
    # TF together); the explicit NumPy / `random` calls below are redundant but
    # harmless since they use the same seed.
    tf.random.set_seed(seed_value)
    tf.keras.utils.set_random_seed(seed_value)
    np.random.seed(seed_value)
    random.seed(seed_value)

    # Make TF kernels deterministic (pairs with CUBLAS_WORKSPACE_CONFIG above).
    tf.config.experimental.enable_op_determinism()


reset_random_seeds()

env: CUBLAS_WORKSPACE_CONFIG=:4096:8


In [None]:
!pip install weightwatcher

Collecting weightwatcher
  Downloading weightwatcher-0.7.5-py3-none-any.whl (79 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/79.8 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━[0m [32m71.7/79.8 kB[0m [31m3.9 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m79.8/79.8 kB[0m [31m2.0 MB/s[0m eta [36m0:00:00[0m
Collecting powerlaw (from weightwatcher)
  Downloading powerlaw-1.5-py3-none-any.whl (24 kB)
Installing collected packages: powerlaw, weightwatcher
Successfully installed powerlaw-1.5 weightwatcher-0.7.5


### Download the MNIST dataset

In [None]:
# Handwritten-digit MNIST (not Fashion-MNIST).
data = tf.keras.datasets.mnist

(X_train, y_train), (X_test, y_test) = data.load_data()
# MNIST labels are the digits 0-9, so the class names are just the digit strings.
class_names = [str(digit) for digit in range(10)]

clear_output()  # clear any download progress output from load_data()
print("X_train shape :", X_train.shape)
print("X_test shape :", X_test.shape)

X_train shape : (60000, 28, 28)
X_test shape : (10000, 28, 28)


### Rescale the data

In [None]:
# Scale pixel values from [0, 255] to [0, 1].
# Guarded so that re-running this cell does not divide by 255 a second time.
if X_train.max() > 1.0:
    X_train = X_train / 255.0
if X_test.max() > 1.0:
    X_test = X_test / 255.0

### Build the simplest MLP, need 3 layers

- input layer
- hidden layer
- output layer

Initialize all models with the same weight matrices

In [None]:
def build_model(trainable_layers=None):
  """Build and compile the 3-layer MLP (784 -> 300 -> 100 -> 10) for 28x28 inputs.

  All RNGs are reseeded and every Dense kernel uses GlorotNormal(seed=1), so
  each call starts from the same initial weights.

  Args:
    trainable_layers: optional collection of indices into `model.layers`
      (0 = Flatten, 1 = Dense(300), 2 = Dense(100), 3 = Dense(10)) that should
      stay trainable; every other layer is frozen. None (the default) keeps
      every layer trainable.

  Returns:
    A compiled tf.keras Sequential model (sparse categorical cross-entropy
    loss, SGD with default settings, sparse categorical accuracy metric).
  """
  reset_random_seeds()

  initializer = tf.keras.initializers.GlorotNormal(seed=1)

  model = tf.keras.models.Sequential([
      tf.keras.layers.Flatten(input_shape=[28, 28]),
      tf.keras.layers.Dense(300, activation='relu', kernel_initializer=initializer),
      tf.keras.layers.Dense(100, activation='relu', kernel_initializer=initializer),
      tf.keras.layers.Dense(10, activation='softmax', kernel_initializer=initializer),
  ])

  # Freeze layers BEFORE compile(): Keras fixes which weights are trainable at
  # compile time, so `trainable` changes made afterwards are ignored by fit()
  # until the model is compiled again.
  if trainable_layers is not None:
    keep = set(trainable_layers)
    for idx, layer in enumerate(model.layers):
      layer.trainable = idx in keep

  # Specify the loss function, optimizer, and metrics.
  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      optimizer=tf.keras.optimizers.SGD(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
  )

  return model



In [None]:
# Build one model (build_model() also reseeds all RNGs) and print its layer
# output shapes and trainable / non-trainable parameter counts.
model = build_model()
model.summary()

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
 flatten (Flatten)           (None, 784)               0         
                                                                 
 dense (Dense)               (None, 300)               235500    
                                                                 
 dense_1 (Dense)             (None, 100)               30100     
                                                                 
 dense_2 (Dense)             (None, 10)                1010      
                                                                 
Total params: 266610 (1.02 MB)
Trainable params: 235500 (919.92 KB)
Non-trainable params: 31110 (121.52 KB)
_________________________________________________________________


In [None]:


# One-epoch smoke test at batch size 16 from the seeded initial weights;
# the expected log line is recorded below.
model = build_model()
history = model.fit(X_train, y_train, batch_size=16, epochs=1, verbose=2)

### Should be:

3750/3750 - 9s - loss: 0.6044 - sparse_categorical_accuracy: 0.7953 - 9s/epoch - 2ms/step






In [None]:
# Repeat the identical one-epoch run as a reproducibility check: build_model()
# reseeds every RNG and op determinism is enabled, so the final training loss
# should match the run in the previous cell.
previous_loss = history.history['loss'][-1]

model = build_model()
history = model.fit(
  X_train, y_train, epochs=1, batch_size=16, verbose=2
)

repeat_loss = history.history['loss'][-1]
assert np.isclose(repeat_loss, previous_loss), (
  f"training is not reproducible: {repeat_loss} vs {previous_loss}"
)

In [None]:
# Train the model
def train_model(bs):

    model = build_model()

    early_stopping_callback = tf.keras.callbacks.EarlyStopping(monitor='loss', patience=3, verbose=2, min_delta=0.0001, restore_best_weights=True)
    history = model.fit(
        X_train, y_train, epochs=1000, batch_size=bs,
        validation_split=0.1,
        verbose=2, callbacks=[early_stopping_callback]
    )
    return model, history

In [None]:

# Full early-stopped training run at batch size 16.
model, H = train_model(16)

In [None]:
# Metric names recorded per epoch (training loss/accuracy plus their val_ counterparts).
H.history.keys()

In [None]:
# Per-epoch training and validation curves for the bs=16 run, legend outside the axes.
ax = pd.DataFrame(H.history).plot()
ax.set_xlabel("epoch")
ax.legend(bbox_to_anchor=(1.05, 1), loc=2)
plt.show()

In [None]:
# Imported here rather than in the top imports cell because weightwatcher is
# pip-installed partway through this notebook.
import weightwatcher as ww
# Wrap the bs=16 model trained above and summarise the layers WeightWatcher sees.
watcher = ww.WeightWatcher(model=model)
watcher.describe()

In [None]:
# Full WeightWatcher analysis of the bs=16 model, with plotting enabled.
watcher.analyze(plot=True)

### We can now vary the hyperparameters (here, the batch size) and see how the test accuracy changes
 > defaults: https://keras.io/api/optimizers/sgd/

The goal is to get some significant variation in the test (validation) accuracies, and see how our weightwatcher metrics compare

In [None]:
# Changes just the learning rate first
all_reaults = []
#all_bs = [256,128,64,13,16,8,4,2, 1]
all_bs = [1,2,4,8,16,32]


mlp3_results = None
models = {}
histories = {}
for bs in all_bs:
    layer_ids = [1]
    modelname = 'MLP3'
    print(f"Analyzing {modelname} bs={bs}")
    model, H = train_model(bs)
    watcher = ww.WeightWatcher(model=model)
    results = watcher.analyze(layers=layer_ids)

    val_acc = H.history['val_sparse_categorical_accuracy'][-1]

    histories[bs] = H
    models[bs]= model

    results['modelname'] = modelname
    results['batch_size'] = bs
    results['H_val_acc'] = H.history['val_sparse_categorical_accuracy'][-1]


    score = model.evaluate(X_train, y_train, verbose = 0)
    results['train_accuracy'] = score[1]
    score = model.evaluate(X_test, y_test, verbose = 0)
    results['test_accuracy'] = score[1]


    pd.DataFrame(H.history).plot()
    plt.xlabel("epoch")
    plt.legend(bbox_to_anchor=(1.05, 1), loc=2)
    plt.title(f"batch size {bs}")
    plt.show()

    #filename = f"/content/drive/My Drive/MLP3/model_bs{bs}"
    #model.save(filename)

    if mlp3_results is None:
        mlp3_results = results
    else:
        mlp3_results = pd.concat((mlp3_results, results))

#filename = f"/content/drive/My Drive/MLP3/mlp3_results.feather"
#mlp3_results.to_feather(filename)

In [None]:
# Test accuracy for each batch size in the sweep.
ax = mlp3_results.plot.scatter(x='batch_size', y='test_accuracy')
ax.set_title("Batch Size vs Test Accuracy")
ax.set_xlabel("Batch Size")
ax.set_ylabel("Test Accuracy")
plt.show()

In [None]:
# WeightWatcher alpha of the analysed layer against train and test accuracy,
# one labelled scatter plot each.
for metric, label in (('train_accuracy', "Train Accuracy"),
                      ('test_accuracy', "Test Accuracy")):
    ax = mlp3_results.plot.scatter(x='alpha', y=metric)
    ax.set_title(rf"WeightWatcher PL Alpha ($\alpha$) metric vs {label}")
    ax.set_xlabel(r"Alpha ($\alpha$) metric")
    ax.set_ylabel(label)
    plt.show()

In [None]:
# Headline result: lower alpha should accompany higher test accuracy
# (with alpha < 2 a possible sign of over-training).
ax = mlp3_results.plot.scatter(x='alpha', y='test_accuracy')
ax.set_title(r"WeightWatcher PL Alpha ($\alpha$) metric vs Test Accuracy")
ax.set_xlabel(r"Alpha ($\alpha$) metric")
ax.set_ylabel("Test Accuracy")
plt.show()