In [1]:
import tensorflow as tf
import numpy as np
import random
import os

seed=1234
def seed_everything(seed=1234):
    random.seed(seed)
    np.random.seed(seed)
    os.environ['pythonhashseed'] = str(seed)
    tf.random.set_random_seed(seed)
seed_everything(seed)

# Load data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)

''' Create the model - VERSION 1 - without variable scoping - tested in Colab
def weight_variable(shape):
  initial = tf.truncated_normal(shape, stddev=0.1)
  return tf.Variable(initial)
def bias_variable(shape):
  initial = tf.constant(0.1, shape=shape)
  return tf.Variable(initial)
def conv2d(x, W):
  return tf.nn.conv2d(x, W, strides=[1, 1, 1, 1], padding='SAME')
def max_pool_2x2(x):
  return tf.nn.max_pool(x, ksize=[1, 2, 2, 1],
                        strides=[1, 2, 2, 1], padding='SAME')
W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])
x = tf.placeholder(tf.float32, [None, 784], name="x-input")
x_image = tf.reshape(x, [-1,28,28,1])
h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])
h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)
h_pool2 = max_pool_2x2(h_conv2)
W_fc1 = weight_variable([7 * 7 * 64, 1024])
b_fc1 = bias_variable([1024])
h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
h_fc1 = tf.nn.relu(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)
keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)
W_fc2 = weight_variable([1024, 10])
b_fc2 = bias_variable([10])
y=tf.nn.softmax(tf.matmul(h_fc1_drop, W_fc2) + b_fc2)
'''

''' VERSION 2 - using variable scope - tested in google colab '''
def conv_relu(input, kernel_shape, bias_shape):
    # Create variable named "weights".
    weights = tf.get_variable("weights", kernel_shape,
        initializer=tf.truncated_normal_initializer(stddev=0.1))
    # Create variable named "biases".
    biases = tf.get_variable("biases", bias_shape,
        initializer=tf.constant_initializer(0.1))
    conv = tf.nn.conv2d(input, weights,
        strides=[1, 1, 1, 1], padding='SAME')
    return tf.nn.relu(conv + biases)

def dense(input, weight_shape, bias_shape):
    weights = tf.get_variable("weights", weight_shape,
                              initializer=tf.truncated_normal_initializer(stddev=0.1))
    biases = tf.get_variable("biases", bias_shape,
        initializer=tf.constant_initializer(0.1))
    input_flat = tf.reshape(input, [-1, weight_shape[0]])
    return tf.matmul(input_flat, weights) + biases

x = tf.placeholder(tf.float32, [None, 784], name="x-input")
x_image = tf.reshape(x, [-1,28,28,1])

with tf.variable_scope("conv1"):
    # Variables created here will be named "conv1/weights", "conv1/biases".
    h_conv1 = conv_relu(x_image, [5, 5, 1, 32], [32])
    h_pool1 = tf.nn.max_pool(h_conv1, ksize=[1, 2, 2, 1],
                        strides=[1, 2, 2, 1], padding='SAME')

with tf.variable_scope("conv2"):
    # Variables created here will be named "conv2/weights", "conv2/biases".
    h_conv2 = conv_relu(h_pool1, [5, 5, 32, 64], [64])
    h_pool2 = tf.nn.max_pool(h_conv2, ksize=[1, 2, 2, 1],
                        strides=[1, 2, 2, 1], padding='SAME')

with tf.variable_scope("dense1"):
    # Variables created here will be named "dense1/weights", "dense1/biases".
    h_fc1 = tf.nn.relu(dense(h_pool2, [7 * 7 * 64, 1024], [1024]))

keep_prob = tf.placeholder(tf.float32)
h_fc1_drop = tf.nn.dropout(h_fc1, keep_prob)

with tf.variable_scope("dense2"):
    # Variables created here will be named "dense1/weights", "dense1/biases".
    h_fc2 = tf.nn.relu(dense(h_fc1_drop, [1024, 10], [10]))
    y = tf.nn.softmax(h_fc2)
''' VERSION 2 ends here '''

# Add summary ops to collect data
# w_hist = tf.summary.histogram("weights", W)
# b_hist = tf.summary.histogram("biases", b)
# y_hist = tf.summary.histogram("y", y)

# Define loss and optimizer
y_ = tf.placeholder(tf.float32, [None, 10], name="y-input")
# More name scopes will clean up the graph representation
with tf.name_scope("xent") as scope:
  cross_entropy = -tf.reduce_sum(y_*tf.log(y))
  ce_summ = tf.summary.scalar("cross_entropy", cross_entropy)
with tf.name_scope("train") as scope:
  train_step = tf.train.AdamOptimizer(1e-4).minimize(cross_entropy)

with tf.name_scope("test") as scope:
  correct_prediction = tf.equal(tf.argmax(y,1), tf.argmax(y_,1))
  accuracy = tf.reduce_mean(tf.cast(correct_prediction, tf.float32))
  accuracy_summary = tf.summary.scalar("accuracy", accuracy)

# Merge all the summaries and write them out to /tmp/mnist_lin
merged = tf.summary.merge_all()
sess = tf.Session()
writer = tf.summary.FileWriter("./mnist_cnn", sess.graph_def)
init = tf.global_variables_initializer()
sess.run(init)

# Train the model, and feed in test data and record summaries every 10 steps
for i in range(10000):
  batch = mnist.train.next_batch(100)
  if i%100 == 0:
    train_accuracy = sess.run(accuracy, feed_dict={
        x:batch[0], y_: batch[1], keep_prob: 1.0})
    print("step %d, training accuracy %g"%(i, train_accuracy))
  sess.run(train_step, feed_dict={x: batch[0], y_: batch[1], keep_prob: 0.5})

print("test accuracy %g"%sess.run(accuracy, feed_dict={
    x: mnist.test.images, y_: mnist.test.labels, keep_prob: 1.0}))

Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please write your own downloading logic.
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-images-idx3-ubyte.gz
Instructions for updating:
Please use tf.data to implement this functionality.
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Instructions for updating:
Please use tf.one_hot on tensors.
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
Instructions for updating:
Please use alternatives such as official/mnist/dataset.py from tensorflow/models.
Instructions for updating:
Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.
step 0, training accuracy 0.14
step 100, training accuracy 0.08
step 200, training accuracy 0.08
step 300, training accuracy 0.1
step 400, training accuracy 0.36
step 500, training accur

In [0]:
from jax import jit, grad, vmap, random
from functools import partial
import jax
import jax.numpy as np
from jax.experimental import stax # neural network library
from jax.experimental.stax import GeneralConv, Conv, Dense, MaxPool, Relu, Flatten, LogSoftmax, LeakyRelu, Dropout # neural network layers
from jax.nn import softmax
from jax.nn.initializers import zeros
import matplotlib.pyplot as plt # visualization
import numpy as onp
from jax.experimental import optimizers
from jax.tree_util import tree_multimap  # Element-wise manipulation of collections of numpy arrays

In [3]:
# Load data
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets('MNIST_data', one_hot=True)
rng = random.PRNGKey(1)
print(mnist.train.images.shape, mnist.train.labels.shape, mnist.test.images.shape, mnist.test.labels.shape)

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz
(55000, 784) (55000, 10) (10000, 784) (10000, 10)


In [4]:
#Define model
net_init, net_apply = stax.serial(
    Conv(32, (5,5), padding='SAME'),
    Relu, 
    MaxPool((2,2), strides=(2,2), padding='SAME'),
    Conv(64, (5,5), strides=(1,1), padding='SAME'), 
    Relu, 
    MaxPool((2,2), strides=(2,2), padding='SAME'), Flatten,
    Dense(1024), #Dropout(0.5),
    Dense(10)
)

in_shape = (-1, 28, 28,1)
out_shape = (-1, 10)
_, net_params = net_init(rng, in_shape)
#print(net_params[0][0].shape,net_params[0][1].shape,net_params[1][0].shape, net_params[1][1].shape, mnist.test.images[0].reshape(-1, 28, 28,1).shape)
# print(net_params)
dummy_inputs=mnist.test.images[0:5].reshape(-1, 28, 28,1)
out=net_apply(net_params,dummy_inputs)
print(dummy_inputs.shape, out.shape)

(5, 28, 28, 1) (5, 10)


In [0]:
#Define losses and optimizers
def loss(params, inputs, targets):
    # Computes average loss for the batch
    # print('Hello loss')
    # print(inputs.reshape(in_shape).shape, targets.reshape(out_shape).shape)
    predictions = softmax(net_apply(params, inputs.reshape(in_shape)), axis=-1)
    # print('Predicted')
    # print(inputs.shape, targets.shape, predictions.shape)
    return -np.sum(targets*np.log(predictions))

def batch_loss(params, inputs, targets):
    losses=vmap(partial(loss, params))(inputs, targets)
    # losses=loss(params, inputs, targets)
    # print(losses.shape)
    # print('Before batch mean')
    return np.mean(losses, axis=0)

def accuracy(params, inputs, targets):
    predictions = softmax(net_apply(params, inputs, mode='test'), axis=-1)
    correct_prediction = np.equal(np.argmax(targets,-1), np.argmax(predictions,-1))
    return np.mean(correct_prediction)

opt_init, opt_update, get_params = optimizers.adam(step_size=1e-4)  # this LR seems to be better than 1e-2 and 1e-4
out_shape, net_params = net_init(rng, in_shape)
opt_state = opt_init(net_params)

@jit
def step(i, opt_state, x, y):
    p = get_params(opt_state)
    g = grad(batch_loss)(p, x, y)
    l = batch_loss(p, x, y)
    # print('After batch mean')
    return opt_update(i, g, opt_state), l

In [6]:
#Training
losses=[]
for i in range(10000):
    batch_xs, batch_ys = mnist.train.next_batch(100)
    # print(batch_xs.reshape(-1,28,28,1).shape, batch_ys.shape)
    #print(batch_xs.reshape(-1,28,28,1,1).shape)
    if i % 100 == 0:
        # print(i, mnist.test.images.shape, mnist.test.labels.shape)
        current_params = get_params(opt_state)
        print("Accuracy at step %s: %g" % (i, accuracy(current_params, mnist.test.images.reshape(-1,28,28,1), mnist.test.labels)))
    opt_state, l = step(i, opt_state, batch_xs.reshape(-1,28,28,1), batch_ys)
    losses.append(l)
    #print(l)
net_params=get_params(opt_state)

Accuracy at step 0: 0.062
Accuracy at step 100: 0.9039
Accuracy at step 200: 0.9372
Accuracy at step 300: 0.9552
Accuracy at step 400: 0.9604
Accuracy at step 500: 0.9698
Accuracy at step 600: 0.9716
Accuracy at step 700: 0.9746
Accuracy at step 800: 0.9764
Accuracy at step 900: 0.9806
Accuracy at step 1000: 0.9803
Accuracy at step 1100: 0.9799
Accuracy at step 1200: 0.9823
Accuracy at step 1300: 0.9779
Accuracy at step 1400: 0.983
Accuracy at step 1500: 0.9785
Accuracy at step 1600: 0.9821
Accuracy at step 1700: 0.9831
Accuracy at step 1800: 0.9832
Accuracy at step 1900: 0.985
Accuracy at step 2000: 0.9853
Accuracy at step 2100: 0.9859
Accuracy at step 2200: 0.9849
Accuracy at step 2300: 0.985
Accuracy at step 2400: 0.9868
Accuracy at step 2500: 0.9861
Accuracy at step 2600: 0.9871
Accuracy at step 2700: 0.989
Accuracy at step 2800: 0.9894
Accuracy at step 2900: 0.9878
Accuracy at step 3000: 0.9889
Accuracy at step 3100: 0.9887
Accuracy at step 3200: 0.9863
Accuracy at step 3300: 0.98