In [1]:
import functools
import operator
import tensorflow as tf
import tensorflow.sparse as sparse

In [73]:
def flatten(inputs, dims_to_flatten):
    """Collapse the trailing ``dims_to_flatten`` dimensions of ``inputs`` into one.

    Leading ("batch") dimensions are kept unchanged and the last
    ``dims_to_flatten`` dimensions are merged into a single trailing dimension.
    For example, a ``(3, 3, 2)`` tensor flattened over 3 dims becomes ``(18,)``,
    and a ``(B, H, W, C)`` tensor flattened over 3 dims becomes ``(B, H*W*C)``.
    The same code path is used eagerly and inside ``tf.function``.

    Args:
        inputs: Tensor (or value convertible to one) with a statically known rank.
        dims_to_flatten: Number of trailing dims to merge. Either a Python int or
            a scalar integer tensor whose value is statically known
            (e.g. ``tf.constant(3)``).

    Returns:
        The reshaped tensor, of rank ``rank - dims_to_flatten + 1``.

    Raises:
        ValueError: if the rank of ``inputs`` is unknown, if ``dims_to_flatten``
            is not statically known, or if it lies outside ``[0, rank]``.
    """
    inputs = tf.convert_to_tensor(inputs)

    # Accept both Python ints and constant tensors; a symbolic tensor with no
    # static value cannot be used to slice the shape.
    num_dims = tf.get_static_value(dims_to_flatten)
    if num_dims is None:
        raise ValueError("`dims_to_flatten` must have a statically known value.")
    num_dims = int(num_dims)

    rank = inputs.shape.rank
    if rank is None:
        raise ValueError("`inputs` must have a statically known rank.")
    if not 0 <= num_dims <= rank:
        raise ValueError(
            f"`dims_to_flatten` must be in [0, {rank}], got {num_dims}.")

    # Use the dynamic shape so that unknown (None) batch dims also work in graph
    # mode; -1 lets reshape infer the product of the flattened dims.
    batch_shape = tf.shape(inputs)[:rank - num_dims]
    flattened_shape = tf.concat([batch_shape, [-1]], axis=0)
    return tf.reshape(inputs, flattened_shape)

def get_conv_fun(input_shape, kernel_shape, strides):
    """Build a "valid" (unpadded) convolution implemented as a sparse matmul.

    A single ``(kh, kw, kc)`` kernel slides over the last three axes
    ``(H, W, C)`` of the input. The two spatial axes use ``strides``. The
    channel axis always uses stride 1, so the result has ``C - kc + 1`` output
    channels.

    The sparsity pattern of the ``(out_h * out_w * out_c, H * W * C)`` matrix is
    computed once, at build time, from static shapes. Each call fills in the
    kernel values and multiplies that matrix by the flattened input.

    Args:
        input_shape: Shape of the input (``TensorShape``, tuple or list). Its last
            three dims ``(H, W, C)`` must be statically known; any leading dims
            are treated as batch dims.
        kernel_shape: Shape of the kernel. Its last three dims ``(kh, kw, kc)``
            are used and must be statically known.
        strides: Spatial stride, either an int or an ``(stride_h, stride_w)`` pair.

    Returns:
        ``conv_fun(inputs, kernel, bias)``. It returns a tensor of shape
        ``batch_dims + (out_h * out_w * out_c, 1)``, which is
        ``(out_h * out_w * out_c, 1)`` for an unbatched input. Output positions
        are ordered row-major over ``(out_h, out_w, out_c)``.

    Raises:
        ValueError: if a shape is too short or not fully known, if a stride is
            not positive, or if the kernel does not fit inside the input.
    """
    def _last_three_dims(shape, name):
        # The last three dims as Python ints (None marks an unknown dim).
        dims = list(shape)[-3:]
        if len(dims) < 3 or any(d is None for d in dims):
            raise ValueError(
                f"`{name}` must end in three known dims, got {shape}.")
        return [int(d) for d in dims]

    in_h, in_w, in_c = _last_three_dims(input_shape, "input_shape")
    k_h, k_w, k_c = _last_three_dims(kernel_shape, "kernel_shape")

    if isinstance(strides, (tuple, list)):
        stride_h, stride_w = (int(s) for s in strides)
    else:
        stride_h = stride_w = int(strides)
    if stride_h < 1 or stride_w < 1:
        raise ValueError(f"`strides` must be positive, got {strides}.")

    out_h = (in_h - k_h) // stride_h + 1
    out_w = (in_w - k_w) // stride_w + 1
    out_c = in_c - k_c + 1
    if min(k_h, k_w, k_c) < 1 or min(out_h, out_w, out_c) < 1:
        raise ValueError(
            f"Kernel {(k_h, k_w, k_c)} does not fit input {(in_h, in_w, in_c)}.")

    in_flat_len = in_h * in_w * in_c
    out_flat_len = out_h * out_w * out_c
    kernel_len = k_h * k_w * k_c

    # Output element (r, c, ch) reads the input elements
    # (r*stride_h + dr, c*stride_w + dc, ch + dch) for every kernel tap
    # (dr, dc, dch). All flat indices are row-major. Iterating the taps
    # row-major produces strictly increasing column indices within each row,
    # so the indices are in canonical order and line up with the row-major
    # flattening of the kernel values.
    index_pairs = []
    out_row = 0
    for r in range(out_h):
        for c in range(out_w):
            for ch in range(out_c):
                for dr in range(k_h):
                    for dc in range(k_w):
                        base = ((r * stride_h + dr) * in_w
                                + c * stride_w + dc) * in_c + ch
                        index_pairs.extend(
                            [out_row, base + dch] for dch in range(k_c))
                out_row += 1

    indices = tf.constant(index_pairs, dtype=tf.int64)
    dense_shape = [out_flat_len, in_flat_len]

    def conv_fun(inputs, kernel, bias):
        """Apply the convolution to ``inputs`` with ``kernel``, then add ``bias``."""
        inputs = tf.convert_to_tensor(inputs)
        kernel = tf.convert_to_tensor(kernel)
        kernel_size = kernel.shape.num_elements()
        if kernel_size is not None and kernel_size != kernel_len:
            raise ValueError(
                f"Expected a kernel with {kernel_len} elements, got {kernel_size}.")

        # Every output row uses the whole kernel, in row-major kernel order.
        values = tf.tile(tf.reshape(kernel, [-1]), [out_flat_len])
        kernel_matrix = tf.sparse.SparseTensor(indices, values, dense_shape)

        # One row per example; leading dims of `inputs` are treated as batch dims.
        flat_inputs = tf.reshape(inputs, [-1, in_flat_len])
        # (out_flat_len, in_flat_len) @ (in_flat_len, N) -> (out_flat_len, N)
        product = tf.sparse.sparse_dense_matmul(
            kernel_matrix, flat_inputs, adjoint_b=True)

        output_shape = tf.concat(
            [tf.shape(inputs)[:-3], [out_flat_len, 1]], axis=0)
        return tf.reshape(tf.transpose(product), output_shape) + bias

    return conv_fun
        

In [74]:
# Toy example: a 3x3x2 input holding 0..17, a 2x2x1 kernel, stride 1.
x = tf.reshape(tf.range(3 * 3 * 2, dtype=tf.float32), shape=(3, 3, 2))
kernel = tf.constant([[[1], [2]], [[3], [4]]], dtype=tf.float32)
strides = 1
bias = 1
conv_fun = get_conv_fun(x.shape, kernel.shape, strides)
result = conv_fun(x, kernel, bias)

# Cross-check against TF's dense conv3d. Treat (H, W, C) as a 3-D volume with a
# single input/output channel; the channel axis always uses stride 1 in
# get_conv_fun.
reference = tf.nn.conv3d(
    x[tf.newaxis, ..., tf.newaxis],
    kernel[..., tf.newaxis, tf.newaxis],
    strides=[1, strides, strides, 1, 1],
    padding="VALID",
) + bias
tf.debugging.assert_near(result, tf.reshape(reference, [-1, 1]))
result

<tf.Tensor: shape=(8, 1), dtype=float32, numpy=
array([[ 55.],
       [ 65.],
       [ 75.],
       [ 85.],
       [115.],
       [125.],
       [135.],
       [145.]], dtype=float32)>