WRN-L-10 CNN — train a Keras Wide ResNet end to end with SGD and MSE loss on a 400-image training subset.

In [None]:
import os
import tensorflow as tf
from tensorflow.keras.datasets import cifar10, mnist, fashion_mnist

tf.random.set_seed(1234)

# Per-op device-placement logging prints one "Executing op ..." line for every
# op executed, flooding cell output during fit/evaluate. Set to True only when
# debugging which device the ops run on.
LOG_DEVICE_PLACEMENT = False
tf.debugging.set_log_device_placement(LOG_DEVICE_PLACEMENT)


def regularized_padded_conv(*args, **kwargs):
    """Create a bias-free, 'same'-padded Conv2D layer.

    Positional and keyword arguments are forwarded to ``tf.keras.layers.Conv2D``.
    The layer always uses the module-level ``_regularizer`` as kernel
    regularizer and He-normal kernel initialisation. Passing any of
    ``padding``/``kernel_regularizer``/``kernel_initializer``/``use_bias``
    in ``kwargs`` raises TypeError (duplicate keyword).
    """
    fixed_options = dict(padding='same', kernel_regularizer=_regularizer,
                         kernel_initializer='he_normal', use_bias=False)
    return tf.keras.layers.Conv2D(*args, **kwargs, **fixed_options)


def bn_relu(x):
    """Apply a new BatchNormalization layer, then a new ReLU layer, to ``x``."""
    normalized = tf.keras.layers.BatchNormalization()(x)
    activated = tf.keras.layers.ReLU()(normalized)
    return activated


def shortcut(x, filters, stride, mode):
    """Build the skip-connection branch of a residual block.

    Args:
        x: Keras tensor entering the block; its last axis is treated as channels.
        filters: channel count produced by the residual branch.
        stride: spatial stride applied by the residual branch.
        mode: 'A' (subsample + zero-pad channels), 'B' (strided 1x1 conv) or
            'B_original' (strided 1x1 conv followed by BatchNormalization).

    Returns:
        A tensor with the same shape as the residual branch output.

    Raises:
        KeyError: if a projection is needed and ``mode`` is not recognised.
    """
    # Identity is only valid when neither the channel count nor the spatial
    # size changes; a strided block with unchanged channels still needs the
    # shortcut to be subsampled, otherwise ``x + residual`` mismatches.
    if x.shape[-1] == filters and stride == 1:
        return x
    if mode == 'B':
        return regularized_padded_conv(filters, 1, strides=stride)(x)
    if mode == 'B_original':
        x = regularized_padded_conv(filters, 1, strides=stride)(x)
        return tf.keras.layers.BatchNormalization()(x)
    if mode == 'A':
        # A 1x1 max-pool with stride s keeps every s-th pixel, matching the
        # output size of a 'same'-padded conv with the same stride.
        subsampled = tf.keras.layers.MaxPool2D(1, stride)(x) if stride > 1 else x
        return tf.pad(subsampled,
                      paddings=[(0, 0), (0, 0), (0, 0), (0, filters - x.shape[-1])])
    raise KeyError("Parameter shortcut_type not recognized!")
    

def original_block(x, filters, stride=1, **kwargs):
    """Post-activation residual block: conv-BN-ReLU-conv-BN, add shortcut, ReLU.

    Extra keyword arguments (e.g. ``preact_block`` from ``group_of_blocks``)
    are accepted and ignored. A module-level ``_shortcut_type`` of 'B' is
    mapped to 'B_original' so the projection shortcut gets BatchNormalization.
    """
    residual = regularized_padded_conv(filters, 3, strides=stride)(x)
    residual = regularized_padded_conv(filters, 3)(bn_relu(residual))
    residual = tf.keras.layers.BatchNormalization()(residual)

    if _shortcut_type == 'B':
        mode = 'B_original'
    else:
        mode = _shortcut_type
    skip = shortcut(x, filters, stride, mode=mode)
    return tf.keras.layers.ReLU()(skip + residual)
    
    
def preactivation_block(x, filters, stride=1, preact_block=False):
    """Pre-activation residual block: BN-ReLU-conv, BN-ReLU-conv, plus shortcut.

    When ``preact_block`` is True the shortcut starts from the BN-ReLU output
    instead of the raw input. If the module-level ``_dropout`` is non-zero, a
    Dropout layer with that rate follows the first convolution.
    """
    activated = bn_relu(x)
    skip_source = activated if preact_block else x

    residual = regularized_padded_conv(filters, 3, strides=stride)(activated)
    if _dropout:
        residual = tf.keras.layers.Dropout(_dropout)(residual)
    residual = regularized_padded_conv(filters, 3)(bn_relu(residual))

    skip = shortcut(skip_source, filters, stride, mode=_shortcut_type)
    return skip + residual


def bootleneck_block(x, filters, stride=1, preact_block=False):
    """Pre-activation bottleneck block: 1x1 reduce, 3x3 (strided), 1x1 expand.

    The inner width is ``filters // _bootleneck_width`` (module global). When
    ``preact_block`` is True the shortcut starts from the BN-ReLU output
    instead of the raw input.
    """
    activated = bn_relu(x)
    skip_source = activated if preact_block else x
    inner_width = filters // _bootleneck_width

    reduced = regularized_padded_conv(inner_width, 1)(activated)
    spatial = regularized_padded_conv(inner_width, 3, strides=stride)(bn_relu(reduced))
    expanded = regularized_padded_conv(filters, 1)(bn_relu(spatial))

    skip = shortcut(skip_source, filters, stride, mode=_shortcut_type)
    return skip + expanded


def group_of_blocks(x, block_type, num_blocks, filters, stride, block_idx=0):
    """Stack ``num_blocks`` blocks built by ``block_type`` with ``filters`` channels.

    Only the first block receives ``stride`` and the ``preact_block`` flag.
    The flag is True when the module global ``_preact_shortcuts`` is truthy
    or this is the first group (``block_idx == 0``). The remaining blocks use
    their default stride and flag.
    """
    preact_block = bool(_preact_shortcuts or block_idx == 0)
    x = block_type(x, filters, stride, preact_block=preact_block)
    for _ in range(num_blocks - 1):
        x = block_type(x, filters)
    return x


def Resnet(input_shape, n_classes, l2_reg=1e-4, group_sizes=(2, 2, 2), features=(16, 32, 64), strides=(1, 2, 2),
           shortcut_type='B', block_type='preactivated', first_conv={"filters": 16, "kernel_size": 3, "strides": 1},
           dropout=0, cardinality=1, bootleneck_width=4, preact_shortcuts=True):
    """Build an uncompiled Keras functional ResNet producing ``n_classes`` logits.

    The network is a stem conv (``first_conv`` kwargs), then one group of blocks
    per zipped ``(group_sizes, features, strides)`` entry. It ends with global
    average pooling and a Dense layer without activation.
    ``block_type`` selects 'preactivated', 'bootleneck' or 'original' blocks
    (KeyError otherwise). 'original' applies BN-ReLU after the stem; the other
    types apply it after the last group.

    Side effect: overwrites the module globals that the block builders read
    (``_regularizer``, ``_shortcut_type``, ``_dropout``, ``_cardinality``,
    ``_bootleneck_width``, ``_preact_shortcuts``).
    """
    global _regularizer, _shortcut_type, _preact_projection, _dropout, _cardinality, _bootleneck_width, _preact_shortcuts
    # Block builders read these globals instead of taking them as arguments.
    _regularizer = tf.keras.regularizers.l2(l2_reg)
    _shortcut_type = shortcut_type
    _dropout = dropout                    # used in Wide ResNets
    _cardinality = cardinality            # used in ResNeXts
    _bootleneck_width = bootleneck_width  # used in ResNeXts and bootleneck blocks
    _preact_shortcuts = preact_shortcuts

    selected_block = {
        'preactivated': preactivation_block,
        'bootleneck': bootleneck_block,
        'original': original_block,
    }[block_type]
    is_original = block_type == 'original'

    inputs = tf.keras.layers.Input(shape=input_shape)
    flow = regularized_padded_conv(**first_conv)(inputs)
    if is_original:
        flow = bn_relu(flow)

    stages = zip(group_sizes, features, strides)
    for block_idx, (group_size, feature, stride) in enumerate(stages):
        flow = group_of_blocks(flow, block_type=selected_block, num_blocks=group_size,
                               filters=feature, stride=stride, block_idx=block_idx)

    if not is_original:
        flow = bn_relu(flow)

    pooled = tf.keras.layers.GlobalAveragePooling2D()(flow)
    outputs = tf.keras.layers.Dense(n_classes, kernel_regularizer=_regularizer)(pooled)
    return tf.keras.Model(inputs=inputs, outputs=outputs)


def load_weights_func(model, model_name):
    """Load weights from ``saved_models/<model_name>.tf`` into ``model`` if present.

    A missing checkpoint (``tf.errors.NotFoundError``) prints a message instead
    of raising. The (possibly updated) model is returned in both cases.
    """
    weights_path = os.path.join('saved_models', model_name + '.tf')
    try:
        model.load_weights(weights_path)
    except tf.errors.NotFoundError:
        print("No weights found for this model!")
    return model


In [None]:
def wide_resnet(N, K, block_type='original', shortcut_type='A', dropout=0, l2_reg=0,
                input_shape=(28, 28, 1), n_classes=10):
    """Build a WRN-N-K Keras model with three groups of 16K, 32K and 64K channels.

    Args:
        N: depth; (N - 4) must be divisible by 6, giving (N - 4) // 6 blocks
            per group (N = 10, 16, 22 -> 1, 2, 3 blocks).
        K: widening factor applied to each group's channel count.
        block_type, shortcut_type, dropout, l2_reg: forwarded to ``Resnet``.
        input_shape: per-example input shape. Use (28, 28, 1) for MNIST and
            Fashion-MNIST (the default) and (32, 32, 3) for CIFAR-10.
        n_classes: number of output logits.

    Returns:
        The uncompiled Keras model from ``Resnet``.
    """
    assert (N-4) % 6 == 0, "N-4 has to be divisible by 6"
    lpb = (N-4) // 6  # blocks per group

    return Resnet(input_shape=input_shape, n_classes=n_classes, l2_reg=l2_reg,
                  group_sizes=(lpb, lpb, lpb), features=(16*K, 32*K, 64*K), strides=(1, 2, 2),
                  first_conv={"filters": 16, "kernel_size": 3, "strides": 1},
                  shortcut_type=shortcut_type, block_type=block_type, dropout=dropout,
                  preact_shortcuts=True)

In [None]:
import jax
import jax.numpy as jnp
import numpy as np

# Sizes of the training / validation subsets used throughout this section.
N_TRAIN, N_VALID = 400, 40

# Pick exactly one dataset (keep the other two lines commented out).
#(x_train, y_train), (x_valid, y_valid) = fashion_mnist.load_data()
(x_train, y_train), (x_valid, y_valid) = mnist.load_data()
#(x_train, y_train), (x_valid, y_valid) = cifar10.load_data()

# MNIST / Fashion-MNIST only: add a trailing channel axis to the images and a
# trailing axis to the labels. Comment these 4 lines out for CIFAR-10.
x_train = x_train[:, :, :, np.newaxis]
x_valid = x_valid[:, :, :, np.newaxis]
y_train = y_train[:, np.newaxis]
y_valid = y_valid[:, np.newaxis]

# Scale pixels to [0, 1] as float32 NumPy arrays. This is the dtype the former
# NumPy -> jnp -> NumPy round trip produced under JAX's default 32-bit mode.
x_train_normalized = (x_train[:N_TRAIN] / 255.).astype(np.float32)
x_valid_normalized = (x_valid[:N_VALID] / 255.).astype(np.float32)

# One-hot encode the 10 classes. Flattening the labels first (rather than
# squeezing afterwards) keeps the batch axis even for a 1-element subset.
y_train_ohe = np.eye(10, dtype=np.float32)[y_train[:N_TRAIN].reshape(-1)]
y_valid_ohe = np.eye(10, dtype=np.float32)[y_valid[:N_VALID].reshape(-1)]


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [None]:
# Depth: N = 10, 16 or 22 gives 1, 2 or 3 residual blocks per group.
# Width: K = 1 or 2 multiplies every group's channel count.
WRN_model = wide_resnet(N=10, K=1)

# MSE between the raw Dense logits and one-hot targets.
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
WRN_model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.MeanSquaredError(),
    metrics=['accuracy'],
)

WRN_model.fit(
    x_train_normalized[0:400],
    y_train_ohe[0:400],
    epochs=80,
    validation_data=(x_valid_normalized[0:40], y_valid_ohe[0:40]),
)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Executing op Identity in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op __inference_train_function_6629 in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op ReadVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op Identity in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op __inference_train_function_6629 in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op AssignVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op ReadVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op Identity in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op ReadVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op Identity in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op AnonymousIteratorV2 in device /job:localhost/replica:0/task:0/device:CPU:0
Execu

<keras.callbacks.History at 0x7f8ce677bed0>

In [None]:
# NOTE: these 40 images are also the validation_data passed to fit, so this is
# not an independent held-out test score.
test_loss, test_acc = WRN_model.evaluate(
    x=x_valid_normalized[0:40],
    y=y_valid_ohe[0:40],
    verbose=2,
)

Executing op _EagerConst in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op _EagerConst in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op RangeDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op RepeatDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op MapDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op PrefetchDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op FlatMapDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op TensorDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op RepeatDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op ZipDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op ParallelMapDatasetV2 in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op OptionsDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op Rea

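WRN-L-10 CNN (Layerwise) — first train a model containing only the first group of blocks, then train the remaining groups on the features that trained group produces.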

In [None]:
import os
import tensorflow as tf
from tensorflow.keras.datasets import cifar10, mnist, fashion_mnist

tf.random.set_seed(1234)

# Per-op device-placement logging prints one "Executing op ..." line for every
# op executed, flooding cell output during fit/evaluate. Set to True only when
# debugging which device the ops run on.
LOG_DEVICE_PLACEMENT = False
tf.debugging.set_log_device_placement(LOG_DEVICE_PLACEMENT)


def regularized_padded_conv(*args, **kwargs):
    """Create a bias-free, 'same'-padded Conv2D layer.

    Positional and keyword arguments are forwarded to ``tf.keras.layers.Conv2D``.
    The layer always uses the module-level ``_regularizer`` as kernel
    regularizer and He-normal kernel initialisation. Passing any of
    ``padding``/``kernel_regularizer``/``kernel_initializer``/``use_bias``
    in ``kwargs`` raises TypeError (duplicate keyword).
    """
    fixed_options = dict(padding='same', kernel_regularizer=_regularizer,
                         kernel_initializer='he_normal', use_bias=False)
    return tf.keras.layers.Conv2D(*args, **kwargs, **fixed_options)


def bn_relu(x):
    """Apply a new BatchNormalization layer, then a new ReLU layer, to ``x``."""
    normalized = tf.keras.layers.BatchNormalization()(x)
    activated = tf.keras.layers.ReLU()(normalized)
    return activated


def shortcut(x, filters, stride, mode):
    """Build the skip-connection branch of a residual block.

    Args:
        x: Keras tensor entering the block; its last axis is treated as channels.
        filters: channel count produced by the residual branch.
        stride: spatial stride applied by the residual branch.
        mode: 'A' (subsample + zero-pad channels), 'B' (strided 1x1 conv) or
            'B_original' (strided 1x1 conv followed by BatchNormalization).

    Returns:
        A tensor with the same shape as the residual branch output.

    Raises:
        KeyError: if a projection is needed and ``mode`` is not recognised.
    """
    # Identity is only valid when neither the channel count nor the spatial
    # size changes; a strided block with unchanged channels still needs the
    # shortcut to be subsampled, otherwise ``x + residual`` mismatches.
    if x.shape[-1] == filters and stride == 1:
        return x
    if mode == 'B':
        return regularized_padded_conv(filters, 1, strides=stride)(x)
    if mode == 'B_original':
        x = regularized_padded_conv(filters, 1, strides=stride)(x)
        return tf.keras.layers.BatchNormalization()(x)
    if mode == 'A':
        # A 1x1 max-pool with stride s keeps every s-th pixel, matching the
        # output size of a 'same'-padded conv with the same stride.
        subsampled = tf.keras.layers.MaxPool2D(1, stride)(x) if stride > 1 else x
        return tf.pad(subsampled,
                      paddings=[(0, 0), (0, 0), (0, 0), (0, filters - x.shape[-1])])
    raise KeyError("Parameter shortcut_type not recognized!")
    

def original_block(x, filters, stride=1, **kwargs):
    """Post-activation residual block: conv-BN-ReLU-conv-BN, add shortcut, ReLU.

    Extra keyword arguments (e.g. ``preact_block`` from ``group_of_blocks``)
    are accepted and ignored. A module-level ``_shortcut_type`` of 'B' is
    mapped to 'B_original' so the projection shortcut gets BatchNormalization.
    """
    residual = regularized_padded_conv(filters, 3, strides=stride)(x)
    residual = regularized_padded_conv(filters, 3)(bn_relu(residual))
    residual = tf.keras.layers.BatchNormalization()(residual)

    if _shortcut_type == 'B':
        mode = 'B_original'
    else:
        mode = _shortcut_type
    skip = shortcut(x, filters, stride, mode=mode)
    return tf.keras.layers.ReLU()(skip + residual)
    
    
def preactivation_block(x, filters, stride=1, preact_block=False):
    """Pre-activation residual block: BN-ReLU-conv, BN-ReLU-conv, plus shortcut.

    When ``preact_block`` is True the shortcut starts from the BN-ReLU output
    instead of the raw input. If the module-level ``_dropout`` is non-zero, a
    Dropout layer with that rate follows the first convolution.
    """
    activated = bn_relu(x)
    skip_source = activated if preact_block else x

    residual = regularized_padded_conv(filters, 3, strides=stride)(activated)
    if _dropout:
        residual = tf.keras.layers.Dropout(_dropout)(residual)
    residual = regularized_padded_conv(filters, 3)(bn_relu(residual))

    skip = shortcut(skip_source, filters, stride, mode=_shortcut_type)
    return skip + residual


def bootleneck_block(x, filters, stride=1, preact_block=False):
    """Pre-activation bottleneck block: 1x1 reduce, 3x3 (strided), 1x1 expand.

    The inner width is ``filters // _bootleneck_width`` (module global). When
    ``preact_block`` is True the shortcut starts from the BN-ReLU output
    instead of the raw input.
    """
    activated = bn_relu(x)
    skip_source = activated if preact_block else x
    inner_width = filters // _bootleneck_width

    reduced = regularized_padded_conv(inner_width, 1)(activated)
    spatial = regularized_padded_conv(inner_width, 3, strides=stride)(bn_relu(reduced))
    expanded = regularized_padded_conv(filters, 1)(bn_relu(spatial))

    skip = shortcut(skip_source, filters, stride, mode=_shortcut_type)
    return skip + expanded


def group_of_blocks(x, block_type, num_blocks, filters, stride, block_idx=0):
    """Stack ``num_blocks`` blocks built by ``block_type`` with ``filters`` channels.

    Only the first block receives ``stride`` and the ``preact_block`` flag.
    The flag is True when the module global ``_preact_shortcuts`` is truthy
    or this is the first group (``block_idx == 0``). The remaining blocks use
    their default stride and flag.
    """
    preact_block = bool(_preact_shortcuts or block_idx == 0)
    x = block_type(x, filters, stride, preact_block=preact_block)
    for _ in range(num_blocks - 1):
        x = block_type(x, filters)
    return x


def Resnet(input_shape, n_classes, l2_reg=1e-4, group_sizes=(2, 2, 2), features=(16, 32, 64), strides=(1, 2, 2),
           shortcut_type='B', block_type='preactivated', first_conv={"filters": 16, "kernel_size": 3, "strides": 1},
           dropout=0, cardinality=1, bootleneck_width=4, preact_shortcuts=True):
    """Build an uncompiled Keras functional ResNet producing ``n_classes`` logits.

    The network is a stem conv (``first_conv`` kwargs), then one group of blocks
    per zipped ``(group_sizes, features, strides)`` entry. It ends with global
    average pooling and a Dense layer without activation.
    ``block_type`` selects 'preactivated', 'bootleneck' or 'original' blocks
    (KeyError otherwise). 'original' applies BN-ReLU after the stem; the other
    types apply it after the last group.

    Side effect: overwrites the module globals that the block builders read
    (``_regularizer``, ``_shortcut_type``, ``_dropout``, ``_cardinality``,
    ``_bootleneck_width``, ``_preact_shortcuts``).
    """
    global _regularizer, _shortcut_type, _preact_projection, _dropout, _cardinality, _bootleneck_width, _preact_shortcuts
    # Block builders read these globals instead of taking them as arguments.
    _regularizer = tf.keras.regularizers.l2(l2_reg)
    _shortcut_type = shortcut_type
    _dropout = dropout                    # used in Wide ResNets
    _cardinality = cardinality            # used in ResNeXts
    _bootleneck_width = bootleneck_width  # used in ResNeXts and bootleneck blocks
    _preact_shortcuts = preact_shortcuts

    selected_block = {
        'preactivated': preactivation_block,
        'bootleneck': bootleneck_block,
        'original': original_block,
    }[block_type]
    is_original = block_type == 'original'

    inputs = tf.keras.layers.Input(shape=input_shape)
    flow = regularized_padded_conv(**first_conv)(inputs)
    if is_original:
        flow = bn_relu(flow)

    stages = zip(group_sizes, features, strides)
    for block_idx, (group_size, feature, stride) in enumerate(stages):
        flow = group_of_blocks(flow, block_type=selected_block, num_blocks=group_size,
                               filters=feature, stride=stride, block_idx=block_idx)

    if not is_original:
        flow = bn_relu(flow)

    pooled = tf.keras.layers.GlobalAveragePooling2D()(flow)
    outputs = tf.keras.layers.Dense(n_classes, kernel_regularizer=_regularizer)(pooled)
    return tf.keras.Model(inputs=inputs, outputs=outputs)


def load_weights_func(model, model_name):
    """Load weights from ``saved_models/<model_name>.tf`` into ``model`` if present.

    A missing checkpoint (``tf.errors.NotFoundError``) prints a message instead
    of raising. The (possibly updated) model is returned in both cases.
    """
    weights_path = os.path.join('saved_models', model_name + '.tf')
    try:
        model.load_weights(weights_path)
    except tf.errors.NotFoundError:
        print("No weights found for this model!")
    return model

In [None]:
def wide_resnet(N, K, block_type='original', shortcut_type='A', dropout=0, l2_reg=0,
                input_shape=(28, 28, 1), n_classes=10):
    """Build the first stage of a layerwise WRN-N-K: a single group of 16K channels.

    Args:
        N: depth; (N - 4) must be divisible by 6, giving (N - 4) // 6 blocks
            in the group (N = 10, 16, 22 -> 1, 2, 3 blocks).
        K: widening factor applied to the group's channel count.
        block_type, shortcut_type, dropout, l2_reg: forwarded to ``Resnet``.
        input_shape: per-example input shape. Use (28, 28, 1) for MNIST and
            Fashion-MNIST (the default) and (32, 32, 3) for CIFAR-10.
        n_classes: number of output logits.

    Returns:
        The uncompiled Keras model from ``Resnet``.
    """
    assert (N-4) % 6 == 0, "N-4 has to be divisible by 6"
    lpb = (N-4) // 6  # blocks per group

    return Resnet(input_shape=input_shape, n_classes=n_classes, l2_reg=l2_reg,
                  group_sizes=(lpb,), features=(16*K,), strides=(1,),
                  first_conv={"filters": 16, "kernel_size": 3, "strides": 1},
                  shortcut_type=shortcut_type, block_type=block_type, dropout=dropout,
                  preact_shortcuts=True)

In [None]:
import jax
import jax.numpy as jnp
import numpy as np

# Sizes of the training / validation subsets used throughout this section.
N_TRAIN, N_VALID = 400, 40

# Pick exactly one dataset (keep the other two lines commented out).
#(x_train, y_train), (x_valid, y_valid) = fashion_mnist.load_data()
(x_train, y_train), (x_valid, y_valid) = mnist.load_data()
#(x_train, y_train), (x_valid, y_valid) = cifar10.load_data()

# MNIST / Fashion-MNIST only: add a trailing channel axis to the images and a
# trailing axis to the labels. Comment these 4 lines out for CIFAR-10.
x_train = x_train[:, :, :, np.newaxis]
x_valid = x_valid[:, :, :, np.newaxis]
y_train = y_train[:, np.newaxis]
y_valid = y_valid[:, np.newaxis]

# Scale pixels to [0, 1] as float32 NumPy arrays. This is the dtype the former
# NumPy -> jnp -> NumPy round trip produced under JAX's default 32-bit mode.
x_train_normalized = (x_train[:N_TRAIN] / 255.).astype(np.float32)
x_valid_normalized = (x_valid[:N_VALID] / 255.).astype(np.float32)

# One-hot encode the 10 classes. Flattening the labels first (rather than
# squeezing afterwards) keeps the batch axis even for a 1-element subset.
y_train_ohe = np.eye(10, dtype=np.float32)[y_train[:N_TRAIN].reshape(-1)]
y_valid_ohe = np.eye(10, dtype=np.float32)[y_valid[:N_VALID].reshape(-1)]


In [None]:
# Depth: N = 10, 16 or 22 gives 1, 2 or 3 residual blocks per group.
# Width: K = 1 or 2 multiplies the group's channel count.
WRN_model = wide_resnet(N=10, K=1)

# MSE between the raw Dense logits and one-hot targets.
optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)
WRN_model.compile(
    optimizer=optimizer,
    loss=tf.keras.losses.MeanSquaredError(),
    metrics=['accuracy'],
)

WRN_model.fit(
    x_train_normalized[0:400],
    y_train_ohe[0:400],
    epochs=80,
    validation_data=(x_valid_normalized[0:40], y_valid_ohe[0:40]),
)

[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Executing op ReadVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op Identity in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op __inference_train_function_14170 in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op ReadVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op Identity in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op __inference_train_function_14170 in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op AssignVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op ReadVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op Identity in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op ReadVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op Identity in device /job:localhost/replica:0/task:0/device:GPU:0
Executin

<keras.callbacks.History at 0x7f90b00cec10>

In [None]:
# Stage 1 is frozen: the tensor feeding its GlobalAveragePooling2D
# (layers[-3], before GAP and Dense) becomes the input features for stage 2.
output = WRN_model.layers[-3].output

int_model = tf.keras.Model(WRN_model.input, output)

train = int_model.predict(x_train_normalized[0:400])
valid = int_model.predict(x_valid_normalized[0:40])

# To vary width as per fixed layers, change K to previous K value of 1,2
K = 1

# NOTE(review): stage 1 was built with N=10 (1 block per group) while these
# groups use 2 blocks each — confirm this is intended.
group_sizes = (2, 2)
features = (32*K, 64*K)
strides = (2, 2)

# The block builders read module globals. Set them explicitly, to the values
# wide_resnet(N=10, K=1) used for stage 1, *before* building the blocks,
# rather than relying on whatever the last Resnet() call left behind.
_regularizer = tf.keras.regularizers.l2(0)
_shortcut_type = 'A'
_dropout = 0
_preact_shortcuts = True

selected_block = original_block

# Stage 2 starts from a fresh Input that holds the precomputed features.
# tf.keras expects functional-model inputs to come from tf.keras.Input;
# reusing stage 1's intermediate tensor would also make the new model share
# stage 1's layer.
inputs = tf.keras.Input(shape=train.shape[1:])

flow = inputs
for block_idx, (group_size, feature, stride) in enumerate(zip(group_sizes, features, strides)):
    flow = group_of_blocks(flow, block_type=selected_block, num_blocks=group_size,
                           block_idx=block_idx, filters=feature, stride=stride)

flow = tf.keras.layers.GlobalAveragePooling2D()(flow)
outputs = tf.keras.layers.Dense(10, kernel_regularizer=_regularizer)(flow)

WRN_model_new = tf.keras.Model(inputs, outputs)

optimizer = tf.keras.optimizers.SGD(learning_rate=0.1, momentum=0.9)

WRN_model_new.compile(optimizer=optimizer,
                      loss=tf.keras.losses.MeanSquaredError(),
                      metrics=['accuracy'])

WRN_model_new.fit(train, y_train_ohe[0:400], epochs=80, validation_data=(valid, y_valid_ohe[0:40]))


[1;30;43mStreaming output truncated to the last 5000 lines.[0m
Executing op ReadVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op Identity in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op __inference_train_function_27793 in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op ReadVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op Identity in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op __inference_train_function_27793 in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op ReadVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op Identity in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op ReadVariableOp in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op Identity in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op AnonymousIteratorV2 in device /job:localhost/replica:0/task:0/device:CPU:0
Execu

<keras.callbacks.History at 0x7f8cbe81bc10>

In [None]:
# NOTE: `valid` was also the validation_data passed to fit, so this is not an
# independent held-out test score.
test_loss, test_acc = WRN_model_new.evaluate(
    x=valid,
    y=y_valid_ohe[0:40],
    verbose=2,
)

Executing op _EagerConst in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op _EagerConst in device /job:localhost/replica:0/task:0/device:GPU:0
Executing op RangeDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op RepeatDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op MapDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op PrefetchDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op FlatMapDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op TensorDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op RepeatDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op ZipDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op ParallelMapDatasetV2 in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op OptionsDataset in device /job:localhost/replica:0/task:0/device:CPU:0
Executing op Rea

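WRN-L-10 Finite-Width NTK — predict with the empirical NTK of a neural-tangents (stax) Wide ResNet at random initialisation, using closed-form gradient-descent dynamics under MSE loss.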

In [None]:
# The runtime had pip 21.1.3; upgrade pip before installing neural-tangents below.
!pip install --upgrade pip

Collecting pip
  Downloading pip-22.1.1-py3-none-any.whl (2.1 MB)
[K     |████████████████████████████████| 2.1 MB 6.7 MB/s 
[?25hInstalling collected packages: pip
  Attempting uninstall: pip
    Found existing installation: pip 21.1.3
    Uninstalling pip-21.1.3:
      Successfully uninstalled pip-21.1.3
Successfully installed pip-22.1.1


In [None]:
# Point JAX at the Colab TPU runtime.
# NOTE(review): JAX's docs say to call this before any JAX computation.
# Earlier cells in this notebook already use jax.numpy, so run this section in
# a fresh runtime — confirm.
import jax.tools.colab_tpu
jax.tools.colab_tpu.setup_tpu()

In [None]:
!pip install neural-tangents

Collecting neural-tangents
  Downloading neural_tangents-0.5.0-py2.py3-none-any.whl (193 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/193.4 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m193.4/193.4 kB[0m [31m6.7 MB/s[0m eta [36m0:00:00[0m
Collecting frozendict>=2.3
  Downloading frozendict-2.3.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (99 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m99.0/99.0 kB[0m [31m11.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: frozendict, neural-tangents
Successfully installed frozendict-2.3.2 neural-tangents-0.5.0
[0m

In [None]:
from neural_tangents import stax
from jax import random

def WideResnetBlock(channels, strides=(1, 1), channel_mismatch=False):
  """One pre-activation residual block of the stax Wide ResNet.

  Main branch: ReLU -> 3x3 conv (``strides``) -> ReLU -> 3x3 conv.
  Shortcut branch: identity, or a 3x3 conv with ``strides`` when
  ``channel_mismatch`` is True (the input's channels/resolution may differ
  from the output's). The two branches are summed.
  """
  Main = stax.serial(
      stax.Relu(), stax.Conv(channels, (3, 3), strides, padding='SAME'),
      stax.Relu(), stax.Conv(channels, (3, 3), padding='SAME'))
  # The skip connection is what makes this a residual block; without it the
  # channel_mismatch flag passed by WideResnetGroup had no effect.
  if channel_mismatch:
    Shortcut = stax.Conv(channels, (3, 3), strides, padding='SAME')
  else:
    Shortcut = stax.Identity()
  return stax.serial(stax.FanOut(2),
                     stax.parallel(Main, Shortcut),
                     stax.FanInSum())

def WideResnetGroup(n, channels, strides=(1, 1)):
  """Chain ``n`` WideResnetBlocks with ``channels`` channels.

  The first block applies ``strides`` and is built with
  ``channel_mismatch=True``. The remaining ``n - 1`` blocks use unit stride.
  """
  first = WideResnetBlock(channels, strides, channel_mismatch=True)
  rest = [WideResnetBlock(channels, (1, 1)) for _ in range(n - 1)]
  return stax.serial(first, *rest)

def WideResnet(block_size, k, num_classes):
  """stax Wide ResNet returning the ``(init_fn, apply_fn, kernel_fn)`` triple.

  Layout: 3x3 stem conv with 16 channels, then three groups of ``block_size``
  blocks with 16k, 32k and 64k channels (the last two downsample by 2),
  then Flatten and a Dense readout with ``num_classes`` outputs
  (W_std=1., b_std=0.).
  NOTE(review): the Keras models above use global average pooling before
  their Dense layer, while this uses Flatten — confirm the difference is intended.
  """
  stem = stax.Conv(16, (3, 3), padding='SAME')
  groups = [
      WideResnetGroup(block_size, int(16 * k)),
      WideResnetGroup(block_size, int(32 * k), (2, 2)),
      WideResnetGroup(block_size, int(64 * k), (2, 2)),
  ]
  head = [stax.Flatten(), stax.Dense(num_classes, 1., 0.)]
  return stax.serial(stem, *groups, *head)

# k sets the WRN width (channel multiplier); block_size sets blocks per group.
init_fn, apply_fn, _ = WideResnet(
    block_size=1,
    k=1,
    num_classes=10,
)

# Input shape: (-1, 28, 28, 1) for MNIST / Fashion-MNIST, (-1, 32, 32, 3) for CIFAR-10.
_, net_params = init_fn(random.PRNGKey(0), input_shape=(-1, 28, 28, 1))

In [None]:
import tensorflow as tf
from tensorflow.keras.datasets import cifar10, mnist, fashion_mnist

In [None]:
import jax
import jax.numpy as jnp
import numpy as np

# Sizes of the training / validation subsets used by the NTK computation.
# With nt.batch(batch_size=5) both sizes must be divisible by the batch size.
N_TRAIN, N_VALID = 400, 40

# Pick exactly one dataset (keep the other two lines commented out).
#(x_train, y_train), (x_valid, y_valid) = fashion_mnist.load_data()
(x_train, y_train), (x_valid, y_valid) = mnist.load_data()
#(x_train, y_train), (x_valid, y_valid) = cifar10.load_data()

# MNIST / Fashion-MNIST only: add a trailing channel axis to the images and a
# trailing axis to the labels. Comment these 4 lines out for CIFAR-10.
x_train = x_train[:, :, :, np.newaxis]
x_valid = x_valid[:, :, :, np.newaxis]
y_train = y_train[:, np.newaxis]
y_valid = y_valid[:, np.newaxis]

# Scale pixels to [0, 1] and move them to JAX arrays.
x_train_normalized = jnp.array(x_train[:N_TRAIN] / 255.)
x_valid_normalized = jnp.array(x_valid[:N_VALID] / 255.)

# One-hot encode the 10 classes. Flattening the labels first (rather than
# squeezing afterwards) keeps the batch axis even for a 1-element subset.
y_train_ohe = jax.nn.one_hot(y_train[:N_TRAIN].reshape(-1), num_classes=10)
y_valid_ohe = jax.nn.one_hot(y_valid[:N_VALID].reshape(-1), num_classes=10)


Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [None]:
from neural_tangents import stax
import neural_tangents as nt

_TRAIN_TIME = 50

# Network outputs at initialisation: the starting point of the dynamics.
fx_train = apply_fn(net_params, x_train_normalized)
fx_test = apply_fn(net_params, x_valid_normalized)

# Empirical (finite-width) NTK, computed in batches of 5 across all devices.
ntk = nt.batch(nt.empirical_ntk_fn(apply_fn, vmap_axes=0), batch_size=5, device_count=-1)
g_dd = ntk(x_train_normalized, None, net_params)                # train x train
g_td = ntk(x_valid_normalized, x_train_normalized, net_params)  # valid x train

# Closed-form gradient-descent MSE dynamics, evaluated at t = _TRAIN_TIME.
predictor = nt.predict.gradient_descent_mse(g_dd, y_train_ohe)
fx_train, fx_test = predictor(_TRAIN_TIME, fx_train, fx_test, g_td)



In [None]:
# Fraction of validation images whose arg-max prediction matches the label.
predicted_class = jnp.argmax(fx_test, axis=1)
target_class = jnp.argmax(y_valid_ohe, axis=1)
acc = (predicted_class == target_class).mean()
print("Accuracy:", acc)

Accuracy: 0.175
