<a href="https://colab.research.google.com/github/IandRover/NTK_MNIST/blob/main/NTK_MNIST.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Motivation

I am very interested in the mathematics and training dynamics behind the Neural Tangent Kernel (NTK) as applied to MNIST, and I am also curious about how Gaussian processes (GPs) work. However, I could not find experimental results for the NTK on the MNIST dataset online, so I am sharing the code I put together in the hope that it promotes understanding of the NTK.

This work is primarily based on [this notebook](https://github.com/erees1/NNGP/blob/master/nngp_experiments.ipynb), an implementation of the paper "[Deep Neural Networks as Gaussian Processes](https://arxiv.org/abs/1711.00165)". The original author used TensorFlow to write the neural network model and the NTK; I rewrote it using [JAX](https://github.com/google/jax) and [Neural Tangents](https://github.com/google/neural-tangents), and also borrowed heavily from [this example script](https://github.com/google/jax/blob/main/examples/mnist_classifier.py).

# Setup Environment

In [1]:
# Install Neural Tangents from PyPI (no version is pinned here).
!pip install neural-tangents
# Clone the JAX repository to obtain its `examples` directory, which supplies
# the `datasets` module imported later via `from examples import datasets`.
!git clone https://github.com/google/jax.git
import os, shutil
# Move the examples directory into the working directory so it is importable as `examples`.
!mv jax/examples examples

Looking in indexes: https://pypi.org/simple, https://us-python.pkg.dev/colab-wheels/public/simple/
Collecting neural-tangents
  Downloading neural_tangents-0.6.0-py2.py3-none-any.whl (241 kB)
[K     |████████████████████████████████| 241 kB 5.1 MB/s 
Collecting frozendict>=2.3
  Downloading frozendict-2.3.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (99 kB)
[K     |████████████████████████████████| 99 kB 9.5 MB/s 
[?25hCollecting tf2jax>=0.3.0
  Downloading tf2jax-0.3.0-py3-none-any.whl (63 kB)
[K     |████████████████████████████████| 63 kB 2.2 MB/s 
Installing collected packages: tf2jax, frozendict, neural-tangents
Successfully installed frozendict-2.3.4 neural-tangents-0.6.0 tf2jax-0.3.0
Cloning into 'jax'...
remote: Enumerating objects: 73876, done.[K
remote: Counting objects: 100% (203/203), done.[K
remote: Compressing objects: 100% (113/113), done.[K
remote: Total 73876 (delta 105), reused 164 (delta 89), pack-reused 73673[K
Receiving objects: 100% (73876/7

In [2]:
import itertools
import pandas as pd

import numpy.random as npr
import numpy

import jax.numpy as jnp
from jax import jit, grad, random
from jax.example_libraries import optimizers, stax
from jax.example_libraries.stax import Dense, Relu, LogSoftmax
from examples import datasets

from neural_tangents import stax as nt_stax
import neural_tangents as nt

In [3]:
from bokeh.plotting import figure, output_notebook, show, output_file, save
from bokeh.layouts import gridplot
from bokeh.palettes import Category10, RdBu, Bokeh, RdYlGn, RdGy, RdYlBu, Spectral
from bokeh.models import Legend
import bokeh.io
# Reset Bokeh's output settings, then direct subsequent `show()` calls to render inside the notebook.
bokeh.io.reset_output()
bokeh.io.output_notebook()

# Define Useful Functions

In [4]:
def loss(params, batch, predict):
  """Return the negated batch mean of the target-weighted predictions.

  Args:
    params: Parameters forwarded unchanged to `predict`.
    batch: An `(inputs, targets)` pair. `targets` is multiplied elementwise
      with the predictions, so it must have the same shape as the output of
      `predict` (presumably one-hot labels -- confirm against the data loader).
    predict: Callable invoked as `predict(params, inputs)`.

  Returns:
    `-mean(sum(predict(params, inputs) * targets, axis=1))`. When `predict`
    returns log-probabilities and `targets` is one-hot, this is the
    cross-entropy loss.
  """
  inputs, targets = batch
  per_example = jnp.sum(predict(params, inputs) * targets, axis=1)
  return -jnp.mean(per_example)

def accuracy(params, batch, predict):
  """Return the fraction of rows whose arg-max prediction matches the target.

  Args:
    params: Parameters forwarded unchanged to `predict`.
    batch: An `(inputs, targets)` pair; the class of each row of `targets` is
      taken to be its arg-max along axis 1.
    predict: Callable invoked as `predict(params, inputs)`; the predicted class
      of each row is the arg-max of its output along axis 1.

  Returns:
    The mean of the elementwise comparison between predicted and target classes.
  """
  inputs, targets = batch
  predicted_class = jnp.argmax(predict(params, inputs), axis=1)
  target_class = jnp.argmax(targets, axis=1)
  hits = predicted_class == target_class
  return jnp.mean(hits)

def get_nn(width):
    """Build the fully connected ReLU classifier used for the finite-width runs.

    The network is `Dense(width) -> Relu -> Dense(width) -> Relu -> Dense(10)
    -> LogSoftmax`, assembled with `stax.serial`.

    Args:
        width: Number of units in each of the two hidden layers.

    Returns:
        The `(init_random_params, predict)` pair produced by `stax.serial`.
    """
    hidden_layers = []
    for _ in range(2):
        hidden_layers.extend([Dense(width), Relu])
    return stax.serial(*hidden_layers, Dense(10), LogSoftmax)

def data_stream(train_images_n, train_labels_n):
    """Yield an endless stream of shuffled `(images, labels)` mini-batches.

    Each pass draws a fresh permutation of the examples from a `RandomState`
    seeded with 0 (so the stream is reproducible) and yields
    `len(train_labels_n) // batch_size` batches of exactly `batch_size`
    examples; the remaining `len(train_labels_n) % batch_size` examples of
    that permutation are skipped.

    Args:
        train_images_n: Indexable array of training inputs (supports indexing
            with an integer index array).
        train_labels_n: Indexable array of labels aligned with `train_images_n`.

    Yields:
        `(train_images_n[idx], train_labels_n[idx])` for each batch of indices.

    Raises:
        ValueError: On the first `next()` call if there are fewer examples
            than the module-level `batch_size`. Previously this case produced
            zero batches per pass and the generator spun forever without
            yielding.
    """
    num_train = len(train_labels_n)
    num_batches = num_train // batch_size
    if num_batches == 0:
        # Without this guard the `while True` loop below would never yield.
        raise ValueError(
            f"data_stream needs at least batch_size={batch_size} examples, got {num_train}")
    rng = npr.RandomState(0)
    while True:
        perm = rng.permutation(num_train)
        for i in range(num_batches):
            batch_idx = perm[i * batch_size:(i + 1) * batch_size]
            yield train_images_n[batch_idx], train_labels_n[batch_idx]

# Jitted single optimizer step: differentiate `loss` with respect to the current
# parameters and apply the optimizer update for step index `i`.
# NOTE(review): this helper is not called anywhere in the visible notebook and
# cannot run as written:
#   * `loss` is defined with three parameters (params, batch, predict) but is
#     called here with only two;
#   * `opt_update` and `get_params` are bound only as local variables inside
#     `run_NN`, not at module level.
# `run_NN` performs its own optimizer step instead of calling this function.
@jit
def update(i, opt_state, batch): return opt_update(i, grad(loss)(get_params(opt_state), batch), opt_state)

def df_append(df, data):
    """Append `data` to `df` in place as a new row labelled `len(df)`.

    Args:
        df: DataFrame to extend; it is modified in place.
        data: Row values, one per column of `df`.

    Returns:
        None.
    """
    next_label = len(df)
    df.loc[next_label] = data

In [5]:
# Load MNIST through the JAX examples helper.
train_images, train_labels, test_images, test_labels = datasets.mnist()
# Evaluate on a fixed subset: the first 1000 test examples.
test_images = test_images[:1000]
test_labels = test_labels[:1000]

# Experiment grid.
widths = [2 ** exponent for exponent in range(6, 12)]        # hidden-layer widths: 64 ... 2048
dataset_size = [2 ** exponent for exponent in range(6, 12)]  # training-set sizes: 64 ... 2048
times = [10 ** exponent for exponent in range(1, 8)]         # kernel training times: 1e1 ... 1e7
batch_size = 64                                              # mini-batch size for the NN runs

downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-images-idx3-ubyte.gz to /tmp/jax_example_data/
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/train-labels-idx1-ubyte.gz to /tmp/jax_example_data/
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-images-idx3-ubyte.gz to /tmp/jax_example_data/
downloaded https://storage.googleapis.com/cvdf-datasets/mnist/t10k-labels-idx1-ubyte.gz to /tmp/jax_example_data/


# Train Neural Network (NN)

In [6]:
def run_NN(num_epochs=20, step_size=0.001, mass=0.9):
    """Train the finite-width ReLU network for every (width, dataset size) pair.

    For each width in `widths` and each size in `dataset_size`, a network from
    `get_nn(width)` is initialised with `PRNGKey(0)` (the same key for every
    configuration), trained with SGD with momentum on the first `num_train`
    training examples, and scored on that training subset and on the
    module-level test set.

    Args:
        num_epochs: Number of passes over the training subset (default 20).
        step_size: Constant learning rate for `optimizers.momentum`.
        mass: Momentum coefficient for `optimizers.momentum`.

    Returns:
        A DataFrame with one row per configuration and the columns
        `activation`, `dataset_size`, `width`, `training_accuracy` and
        `test_accuracy`. Accuracies are stored as plain Python floats.
    """
    df_columns = ['activation', 'dataset_size', 'width', 'training_accuracy', 'test_accuracy']
    # Collect rows in a list and build the frame once instead of growing it row by row.
    records = []

    for width in widths:
        print(f"Running width {width}")

        for num_train in dataset_size:
            init_random_params, predict = get_nn(width)
            rng = random.PRNGKey(0)

            train_images_n, train_labels_n = train_images[:num_train], train_labels[:num_train]

            num_batches = num_train // batch_size
            batches = data_stream(train_images_n, train_labels_n)

            opt_init, opt_update, get_params = optimizers.momentum(step_size=step_size, mass=mass)
            _, init_params = init_random_params(rng, (-1, 28 * 28))
            opt_state = opt_init(init_params)
            itercount = itertools.count()

            # Compile the optimizer step once per configuration; every batch has
            # the same shape, so the compiled step is reused for all iterations.
            @jit
            def train_step(i, opt_state, batch):
                params = get_params(opt_state)
                return opt_update(i, grad(loss)(params, batch, predict), opt_state)

            for _ in range(num_epochs):
                for _ in range(num_batches):
                    # `next(itercount)` supplies the integer step index the
                    # optimizer expects (previously the iterator object itself
                    # was passed via `iter(itercount)`).
                    opt_state = train_step(next(itercount), opt_state, next(batches))

            params = get_params(opt_state)
            train_acc = float(accuracy(params, (train_images_n, train_labels_n), predict))
            test_acc = float(accuracy(params, (test_images, test_labels), predict))
            records.append(["relu", num_train, width, train_acc, test_acc])

    return pd.DataFrame(records, columns=df_columns)

In [7]:
# Train every (width, dataset size) combination and keep the resulting accuracy table for plotting.
nn_evaluate = run_NN()

Running width 64
Running width 128
Running width 256
Running width 512
Running width 1024
Running width 2048


# Obtain Neural Tangent Kernel (NTK)

In [8]:
def nt_accuracy(prediction, test_labels):
    """Return the fraction of rows whose arg-max prediction matches the label.

    Args:
        prediction: 2-D array of per-class scores, one row per example.
        test_labels: 2-D array aligned with `prediction`; the class of each row
            is its arg-max along axis 1.

    Returns:
        The number of matching rows divided by `prediction.shape[0]`.
    """
    predicted_class = jnp.argmax(prediction, axis=1)
    true_class = jnp.argmax(test_labels, axis=1)
    num_correct = jnp.sum(predicted_class == true_class)
    return num_correct / prediction.shape[0]

def NTK(get='nngp'):
    """Evaluate infinite-width kernel predictions for every training-set size.

    For each size in `dataset_size`, a closed-form predictor is built with
    `nt.predict.gradient_descent_mse_ensemble` on the first `num_train`
    training examples. It is evaluated at every training time in `times` and
    at infinite time (`t=None`), on both the training subset and the
    module-level test set.

    Args:
        get: Which kernel's prediction to request from the Neural Tangents
            predictor. The default `'nngp'` preserves this notebook's original
            behaviour. NOTE(review): despite the function name, the default
            evaluates the NNGP kernel; pass `get='ntk'` for Neural Tangent
            Kernel predictions.

    Returns:
        A DataFrame with the columns `activation`, `time`, `dataset_size`,
        `width`, `training_accuracy` and `test_accuracy`. Infinite training
        time is recorded as `time == 0`, and `width` is always 0 (the kernel
        has no finite width). Accuracies are stored as plain Python floats.
    """
    # Only the analytic kernel function is needed; the finite-network
    # init/apply functions are discarded. The readout layer has 10 outputs to
    # match the 10-class network built by `get_nn`.
    _, _, kernel_fn = nt_stax.serial(
        nt_stax.Dense(64), nt_stax.Relu(), nt_stax.Dense(64), nt_stax.Relu(),
        nt_stax.Dense(10))

    df_columns = ['activation', 'time', 'dataset_size', 'width', 'training_accuracy', 'test_accuracy']
    # Collect rows in a list and build the frame once instead of growing it row by row.
    records = []

    # Finite training times first, then `None` for the infinite-time limit.
    eval_times = list(times) + [None]

    for num_train in dataset_size:
        train_images_n, train_labels_n = train_images[:num_train], train_labels[:num_train]

        print(f"Running dataset size {num_train}")
        predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, train_images_n, train_labels_n)

        for t in eval_times:
            train_prediction = predict_fn(t, train_images_n, get=get)
            test_prediction = predict_fn(t, test_images, get=get)
            train_acc = float(nt_accuracy(train_prediction, train_labels_n))
            test_acc = float(nt_accuracy(test_prediction, test_labels))
            # The plotting cell selects the infinite-time rows with `time == 0`.
            recorded_time = 0 if t is None else t
            records.append(["relu", recorded_time, num_train, 0, train_acc, test_acc])

    return pd.DataFrame(records, columns=df_columns)

In [9]:
# Compute kernel predictions for every dataset size and training time; keep the accuracy table for plotting.
gp_evaluate = NTK()

Running dataset size 64
Running dataset size 128
Running dataset size 256
Running dataset size 512
Running dataset size 1024
Running dataset size 2048


# Plot Results


In [12]:
# Axis labels shared by both panels.
x_label = 'Training Dataset Size'
y_label = 'Test Accuracy'

# Nominal figure dimensions. The grid below is laid out with
# sizing_mode='scale_both', so presumably only the 4:3 ratio matters -- TODO confirm.
w, h = 40, 30
s1 = figure(title='MNIST, NTK', x_axis_label=x_label, y_axis_label=y_label,
            plot_width=w, plot_height=h)
s2 = figure(title='MNIST, NN', x_axis_label=x_label, y_axis_label=y_label,
            plot_width=w, plot_height=h)

# One palette entry per curve in each panel.
color_gp = Bokeh[len(times)]
color_nn = RdBu[len(widths)]

# Left panel: one kernel curve per finite training time.
for i, time in enumerate(times):
    gp_relu = gp_evaluate.loc[(gp_evaluate["time"] == time) & (gp_evaluate["activation"] == 'relu')]
    s1.line(
        gp_relu["dataset_size"], gp_relu.test_accuracy,
        color=color_gp[i], line_width=2, legend_label=f'NTK time: 1e{i+1} sec',
    )

# Right panel: one trained-network curve per hidden width.
for i, width in enumerate(widths):
    nn_relu = nn_evaluate.loc[(nn_evaluate["activation"] == 'relu') & (nn_evaluate["width"] == width)]
    s2.line(
        nn_relu["dataset_size"], nn_relu.test_accuracy,
        color=color_nn[i], line_width=2, legend_label=f'NN width: {width}',
    )

# The infinite-time kernel result (stored with time == 0) is drawn as a dashed
# black reference line on both panels.
gp_relu = gp_evaluate.loc[(gp_evaluate["time"] == 0) & (gp_evaluate["activation"] == 'relu')]
for s in (s1, s2):
    s.line(
        gp_relu["dataset_size"], gp_relu.test_accuracy,
        color="black", line_dash=[2], line_width=2, legend_label='NTK time: ∞ sec',
    )
grid = gridplot([s1, s2], ncols=2, plot_width=w, plot_height=h, sizing_mode='scale_both')

# Legend styling, and a y-axis anchored at zero, for both panels.
for s in (s1, s2):
    s.legend.location = 'bottom_right'
    s.legend.spacing = -5
    s.legend.orientation = "vertical"
    s.legend.label_text_font_size = "12px"
    s.y_range.start = 0

show(grid)

# To save file, please uncomment these few lines.
# output_file(filename="Compare_NN_and_NTK.html")
# save(grid)