In [1]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

from tensorflow.keras import datasets
from groupy.gconv.tensorflow_gconv.splitgconv2d import gconv2d, gconv2d_util

In [63]:
(X, Y), (X_t, Y_t) = datasets.mnist.load_data()

# Normalize pixel values to be between 0 and 1
X, X_t = X / 255.0, X_t / 255.0
X = X.reshape(-1, 28, 28, 1)
X_t = X_t.reshape(-1, 28, 28, 1)

In [164]:
ksize = 3
batch_normalization = True
activation = 'relu'
dropout = .3
batch_size = 128
k = None

In [268]:
graph = tf.Graph()
with graph.as_default():
    '''Input Data'''
    x = tf.placeholder(tf.float32, shape=(batch_size, 28, 28, 1))
    x_labels = tf.placeholder(tf.float32, shape=(batch_size))
    #tf_valid_dataset = tf.constant(valid_dataset)
    tf_test_dataset = tf.constant(X_t, shape=X_t.shape)
    tf_test_labels = tf.constant(Y_t, shape=Y_t.shape)
    
    '''Training Computation'''
    # L1: GCONV[None,28,28,1] ==> [NONE,28,28,10] 
    gconv_indices, gconv_shape_info, w_shape = \
    gconv2d_util(h_input='Z2', h_output='C4', in_channels=1, out_channels=10, ksize=ksize)
    w = tf.Variable(tf.truncated_normal(w_shape, stddev=1.))
    y = gconv2d(input=x, filter=w, strides=[1, 1, 1, 1], padding='SAME',
                gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)
    # BN
    if batch_normalization:
        y = tf.keras.layers.BatchNormalization(axis=-1)(y)
    # Activation
    if activation:
        y = tf.nn.relu(y)
        
    # L2: GCONV[None,28,28,10] ==> [NONE,28,28,10] 
    gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
        h_input='C4', h_output='C4', in_channels=10, out_channels=10, ksize=3)
    w = tf.Variable(tf.truncated_normal(w_shape, stddev=1.))
    y = gconv2d(input=y, filter=w, strides=[1, 1, 1, 1], padding='SAME',
                gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)
    # BN
    if batch_normalization:
        y = tf.keras.layers.BatchNormalization(axis=-1)(y)
    # Activation
    if activation:
        y = tf.nn.relu(y)
        
    ### Here should be some kind of max pool?
    
    #y = tf.reshape(y, (ys[0], ys[1] * ys[2], ys[3]))
    #y = tf.keras.layers.MaxPool2D(pool_size=(2,2), strides=2, padding='VALID')(y)
    #y = tf.reshape(y, (-1))
    
    # L3: Same as L2
    gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
        h_input='C4', h_output='C4', in_channels=10, out_channels=10, ksize=3)
    w = tf.Variable(tf.truncated_normal(w_shape, stddev=1.))
    y = gconv2d(input=y, filter=w, strides=[1, 1, 1, 1], padding='SAME',
                gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)
    # BN
    if batch_normalization:
        y = tf.keras.layers.BatchNormalization(axis=-1)(y)
    # Activation
    if activation:
        y = tf.nn.relu(y)
    
    # L4: Same as L2
    gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
        h_input='C4', h_output='C4', in_channels=10, out_channels=10, ksize=3)
    w = tf.Variable(tf.truncated_normal(w_shape, stddev=1.))
    y = gconv2d(input=y, filter=w, strides=[1, 1, 1, 1], padding='SAME',
                gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)
    # BN
    if batch_normalization:
        y = tf.keras.layers.BatchNormalization(axis=-1)(y)
    # Activation
    if activation:
        y = tf.nn.relu(y)
    
    # L5: Same as L2
    gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
        h_input='C4', h_output='C4', in_channels=10, out_channels=10, ksize=3)
    w = tf.Variable(tf.truncated_normal(w_shape, stddev=1.))
    y = gconv2d(input=y, filter=w, strides=[1, 1, 1, 1], padding='SAME',
                gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)
    # BN
    if batch_normalization:
        y = tf.keras.layers.BatchNormalization(axis=-1)(y)
    # Activation
    if activation:
        y = tf.nn.relu(y)
    
    # L6: Same as L2
    gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
        h_input='C4', h_output='C4', in_channels=10, out_channels=10, ksize=3)
    w = tf.Variable(tf.truncated_normal(w_shape, stddev=1.))
    y = gconv2d(input=y, filter=w, strides=[1, 1, 1, 1], padding='SAME',
                gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)
    # BN
    if batch_normalization:
        y = tf.keras.layers.BatchNormalization(axis=-1)(y)
    # Activation
    if activation:
        y = tf.nn.relu(y)

        
    # Top Layer: Same as L2 except for norm, and activation
    gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
        h_input='C4', h_output='C4', in_channels=10, out_channels=10, ksize=3)
    w = tf.Variable(tf.truncated_normal(w_shape, stddev=1.))
    y = gconv2d(input=y, filter=w, strides=[1, 1, 1, 1], padding='SAME',
                gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)
    
    # Max Poolings 
    y = tf.math.reduce_max(y, axis=-3, keepdims=False)
    y = tf.math.reduce_max(y, axis=-1, keepdims=False)
    y = tf.math.reduce_max(y, axis=-1, keepdims=False)
    # Loss 
    loss = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits_v2(logits=y, labels=x_labels))
    '''Optimizer'''
    global_step = tf.Variable(0)  # count the number of steps taken.
    start_learning_rate = 0.5
   # learning_rate = tf.train.exponential_decay(start_learning_rate, global_step, 100000, 0.96, staircase=True)
    optimizer = tf.train.AdamOptimizer().minimize(loss, global_step=global_step)
    
    

In [236]:
from tqdm import tqdm

In [280]:
NUM_EPOCHS = 3
pbar = tqdm(range(NUM_EPOCHS))

with tf.Session(graph=graph) as session:
    tf.initialize_all_variables().run()
    print("Initialized")
    for epoch in pbar:
        l_hist = []
        for inputs, labels in get_minibatch(X, Y, batch_size=batch_size, shuffle=True, drop_last=True):
            feed_dict = {x: inputs, x_labels: labels}
            _, l = session.run([optimizer, loss], feed_dict=feed_dict)
            
            l_hist.append(l)
            
            
            l_hist_m = np.mean(l_hist[-50:])
            
            pbar.set_description('Loss ' + str(l_hist_m))
    
    

  0%|          | 0/3 [00:00<?, ?it/s]

Initialized


Loss 43377010000000.0:   0%|          | 0/3 [01:47<?, ?it/s]


KeyboardInterrupt: 

In [235]:
def get_minibatch(x, labels, batch_size=64, shuffle=True, drop_last=True):
    idx = np.arange(len(x))
    if shuffle:
        np.random.shuffle(idx)
    if drop_last:
        n_batches = len(idx) // batch_size
    else:
        n_batches = int(np.ceil(len(idx) / batch_size))
    for b in range(n_batches):
        left_idx  = b*batch_size
        right_idx = min((b+1)*batch_size, len(idx))
        batch_idx = idx[left_idx:right_idx]
        yield x[batch_idx], labels[batch_idx]

In [2]:
# Construct graph
x = tf.placeholder(tf.float32, [None, 9, 9, 3])

gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
    h_input='Z2', h_output='C4', in_channels=3, out_channels=64, ksize=3)
w = tf.Variable(tf.truncated_normal(w_shape, stddev=1.))
y = gconv2d(input=x, filter=w, strides=[1, 1, 1, 1], padding='SAME',
            gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)

Instructions for updating:
Colocations handled automatically by placer.


In [3]:
y.shape

TensorShape([Dimension(None), Dimension(9), Dimension(9), Dimension(256)])

In [5]:
gconv_indices, gconv_shape_info, w_shape = gconv2d_util(
    h_input='C4', h_output='C4', in_channels=64, out_channels=64, ksize=3)
w = tf.Variable(tf.truncated_normal(w_shape, stddev=1.))
y = gconv2d(input=y, filter=w, strides=[1, 1, 1, 1], padding='SAME',
            gconv_indices=gconv_indices, gconv_shape_info=gconv_shape_info)

In [6]:
y.shape

TensorShape([Dimension(None), Dimension(9), Dimension(9), Dimension(256)])