In [1]:
"""Implemented with Tensorflow v1.5"""

import numpy as np
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data

In [2]:
# Fetch MNIST into the local 'mnist' directory; labels stay integer class ids (not one-hot).
mnist = input_data.read_data_sets(train_dir='mnist', one_hot=False)

Successfully downloaded train-images-idx3-ubyte.gz 9912422 bytes.
Extracting mnist/train-images-idx3-ubyte.gz
Successfully downloaded train-labels-idx1-ubyte.gz 28881 bytes.
Extracting mnist/train-labels-idx1-ubyte.gz
Successfully downloaded t10k-images-idx3-ubyte.gz 1648877 bytes.
Extracting mnist/t10k-images-idx3-ubyte.gz
Successfully downloaded t10k-labels-idx1-ubyte.gz 4542 bytes.
Extracting mnist/t10k-labels-idx1-ubyte.gz


In [3]:
# Training hyperparameters shared by the cells below.
batch_size = 32           # examples per minibatch
learning_rate = 0.0001    # step size handed to the optimizer
max_train_itr = 100000    # training stops once the global step reaches this value

In [5]:
def sparse_fully_connected(inpt, n_out, n_params=None, density=1., seed=None):
    """Multiplies `inpt` with a learnable sparse matrix.

    A flat variable named 'weights' with `n_params` entries is created via
    `tf.get_variable`. Unless the matrix is fully dense, those entries are
    scattered to distinct, randomly chosen positions of an otherwise all-zero
    `[n_in, n_out]` matrix, where `n_in` is the static size of the last axis of
    `inpt`. The positions are drawn with NumPy while the graph is being built
    and are not learnable.

    :param inpt: tf.Tensor whose last dimension is statically known.
    :param n_out: int, size of the output vector
    :param n_params: int, number of learnable parameters in the matrix. If None,
        the number of parameters is equal to `density` times the number of parameters
        that a dense matrix would have (rounded down).
    :param density: float in (0, 1.]; number of parameters relative to the number of
        parameters in a dense matrix. Only used when `n_params` is None.
    :param seed: optional int. When given, the non-zero positions are drawn from a
        private `np.random.RandomState(seed)`, so rebuilding the graph yields the
        same sparsity pattern (the pattern itself is not stored in any variable).
        When None, the global NumPy random state is used.

    :return: tf.Tensor, the result of `tf.matmul(inpt, w)`
    :raises AssertionError: if `density` or the resulting `n_params` is out of range.
    """
    n_in = int(inpt.shape[-1])
    shape = [n_in, n_out]
    dense_n_params = int(np.prod(shape))

    if n_params is None:
        assert 0 < density <= 1., 'density must be in (0, 1], got %r' % (density,)
        n_params = int(dense_n_params * density)

    # Also catches a density so small that it rounds down to zero parameters.
    assert 0. < n_params <= dense_n_params, (
        'n_params must be in (0, %d], got %r' % (dense_n_params, n_params))

    params = tf.get_variable('weights', shape=[n_params], dtype=tf.float32, trainable=True)

    if n_params == dense_n_params:
        # Every position is learnable: a plain reshape gives the dense matrix.
        w = tf.reshape(params, shape)
    else:
        # Pick `n_params` distinct flat positions and convert them to (row, col) pairs.
        rng = np.random if seed is None else np.random.RandomState(seed)
        linear_idx = rng.choice(dense_n_params, size=n_params, replace=False)
        idx = np.stack(np.unravel_index(linear_idx, shape), 1)
        # Positions that receive no parameter stay zero.
        w = tf.scatter_nd(idx, params, shape)

    # Hint matmul to treat `w` as sparse only when under half of its entries are learnable.
    w_is_sparse = (n_params < 0.5 * dense_n_params)

    return tf.matmul(inpt, w, b_is_sparse=w_is_sparse)

In [13]:
# Clear the default graph before (re)building the model in this cell.
tf.reset_default_graph()

# Fixed-size minibatch of flattened 28x28 images and their integer labels.
x = tf.placeholder(dtype=tf.float32, shape=[batch_size, 28 * 28], name='img')
y = tf.placeholder(dtype=tf.int32, shape=[batch_size], name='label')

# One sparse layer maps the pixels directly to 10 logits using 1000 weights.
# (Alternative parameterisation: pass density=.1 instead of n_params=1000.)
logits = sparse_fully_connected(inpt=x, n_out=10, n_params=1000)

In [14]:
# Batch mean of the per-example softmax cross-entropy against the integer labels.
loss = tf.reduce_mean(
    tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y, logits=logits)
)

In [15]:
# Adam optimizer configured with the `learning_rate` hyperparameter.
opt = tf.train.AdamOptimizer(learning_rate=learning_rate)

In [16]:
# Step counter handed to the optimizer, and the op that performs one update on `loss`.
global_step = tf.train.get_or_create_global_step()
train_step = opt.minimize(loss=loss, global_step=global_step)

In [17]:
# Create the TensorFlow session that the following cells use to run graph ops.
sess = tf.Session()

In [18]:
# Run the initializer op for all global variables before training starts.
sess.run(tf.global_variables_initializer())

In [19]:
# Begin at the graph's current global step, so re-running this cell continues
# training rather than counting from zero again.
train_itr = sess.run(global_step)
while train_itr < max_train_itr:
    # Draw the next minibatch and bind it to the placeholders.
    xx, yy = mnist.train.next_batch(batch_size)
    fd = {x: xx, y: yy}

    # One optimizer update; the step counter and the batch loss are fetched alongside it.
    train_itr, l, _ = sess.run(fetches=[global_step, loss, train_step], feed_dict=fd)

    # Report the loss every 1000 steps.
    if not train_itr % 1000:
        print('%s %s' % (train_itr, l))

1000 1.9707838
2000 1.7411225
3000 1.3305104
4000 1.3550026
5000 1.0721931
6000 1.1515383
7000 0.9379637
8000 0.9047595
9000 0.9809159
10000 0.714977
11000 0.9139421
12000 0.6471793
13000 0.86941516
14000 0.78030103


KeyboardInterrupt: 

In [None]:
# List every trainable variable with its shape and element count, then print
# the overall number of trainable parameters.
total_n_params = 0
for v in tf.trainable_variables():
    shape = v.shape.as_list()
    n_params = int(np.prod(shape))
    total_n_params += n_params
    print('\t%s %s %s' % (v.name, shape, n_params))
# The original printed the undefined name `t` here, which raises NameError.
print('Total number of trainable parameters: %s' % total_n_params)
    