In [189]:
import functools
import operator
import tensorflow as tf
from tensorflow.keras import layers
import tensorflow.sparse as sparse

import os
import sys
import inspect
import importlib
module_path = os.path.abspath(os.path.join('..'))
if module_path not in sys.path:
    sys.path.append(module_path)

from network import vlayers
importlib.reload(vlayers)
pass

In [524]:
data = tf.reshape(tf.range(24), [4, 3, 2])
print(data)
x = tf.constant(data)
result = tf.gather_nd(x, [[[0, 0], [0, 1]], [[1, 0], [1, 1]]], batch_dims=0)
result

tf.Tensor(
[[[ 0  1]
  [ 2  3]
  [ 4  5]]

 [[ 6  7]
  [ 8  9]
  [10 11]]

 [[12 13]
  [14 15]
  [16 17]]

 [[18 19]
  [20 21]
  [22 23]]], shape=(4, 3, 2), dtype=int32)


<tf.Tensor: shape=(2, 2, 2), dtype=int32, numpy=
array([[[0, 1],
        [2, 3]],

       [[6, 7],
        [8, 9]]])>

In [548]:
def flatten(inputs, dims_to_flatten):
    """Flatten given dimensions of tensor"""
    input_shape = inputs.shape
    rank = input_shape.rank
    batch_dims = input_shape[:rank-dims_to_flatten]
    non_batch_dims = input_shape[-dims_to_flatten:]
    
    if tf.executing_eagerly():
        # Full static shape is guaranteed to be available.
        # Performance: Using `constant_op` is much faster than passing a list.
        flattened_shape = tf.concat([batch_dims, [-1]], 0)
        return tf.reshape(inputs, flattened_shape)
    else:
        last_dim = int(functools.reduce(operator.mul, non_batch_dims))
        flattened_shape = tf.concat([[-1], batch_dims[1:], [last_dim]])
        return tf.reshape(inputs, flattened_shape)
    
def get_input_shape(input_shape, padding):
    """Get shape of input feature tensor"""
    input_shape = tf.constant(input_shape)
    return tf.concat([input_shape[-3:-1] + 2 * padding, input_shape[-1:]], axis=0)

def get_full_output_shape(input_shape, kernel_shape, strides, use_bias):
    """Get shape of output tensor"""
    vector_dim = tf.reduce_prod(kernel_shape[:-1])
    if use_bias:
        vector_dim += 1
    input_shape = tf.constant(input_shape[-3:])
    kernel_shape = tf.constant(kernel_shape[-4:-1])
    strides = tf.concat([tf.constant(strides), [1]], axis=0)
    # Convolution layer output shape formula
    output_shape = ((input_shape - kernel_shape) // strides) + 1
    # Add filters 
    output_shape *= tf.concat([1, 1, kernel_shape[-1]], axis=0)
    # Add vector dimension
    output_shape = tf.concat([[-1, vector_dim], output_shape], axis=0)
    return output_shape

def get_output_shape(input_shape, kernel_shape, strides):
    """Get shape of output feature tensor"""
    input_map_shape = tf.constant(input_shape[-3:-1])
    input_depth_shape = tf.constant(input_shape[-1:])
    kernel_map_shape = tf.constant(kernel_shape[-4:-2])
    kernel_depth_shape = tf.constant(kernel_shape[-2:-1])
    # Convolution layer result formula
    output_map_shape = ((input_map_shape - kernel_map_shape) // strides) + 1
    # Set output depth
    output_depth_shape = ((input_depth_shape - kernel_depth_shape) + 1) * kernel_shape[-1]
    output_shape = tf.concat([output_map_shape, output_depth_shape], axis=0)
    return output_shape

def iterate_sparsed_weight_indices(weight_shape, input_shape, output_shape, kernel_shape, strides, use_bias):
    """Iterate over indices of non-zero elements in sparsed weight matrix"""
    for j in range(weight_shape[-2]):
        # Compute position in output tensor
        chan_num = j % output_shape[-1]
        col_num = (j // output_shape[-1]) % output_shape[-2]
        row_num = (j // output_shape[-1]) // output_shape[-2]
        # Compute row in weight tensor to start with
        offset = (row_num * input_shape[-2] * strides[-2] + col_num \
                * strides[-1]) * input_shape[-1] + chan_num
        for i in range(tf.cast(weight_shape[-3], tf.int32) - offset):
            # Compute position in input tensor
            chan_num = i % input_shape[-1]
            col_num = (i // input_shape[-1]) % input_shape[-2]
            row_num = (i // input_shape[-1]) // input_shape[-2]
            
            if chan_num < kernel_shape[-2] and col_num < kernel_shape[-3] and row_num < kernel_shape[-4]:
                for f in range(weight_shape[-1]):
                    yield (i + int(offset), j, f)
    
    # Append bias
    bias_i = int(weight_shape[-3] - 1)
    for j in range(weight_shape[-2]):
        for f in range(weight_shape[-1]):
            yield (bias_i, j, f)


def iterate_input_gather_indices(weight_shape, input_shape, output_shape, kernel_shape, strides, use_bias, vector_input=False):
    position_nums = dict()
    kernel_shape
    
    for j in range(weight_shape[-2]):
        # Compute position in output tensor
        chan_num = j % output_shape[-1]
        col_num = (j // output_shape[-1]) % output_shape[-2]
        row_num = (j // output_shape[-1]) // output_shape[-2]
        # Compute row in weight tensor to start with
        offset = (row_num * input_shape[-2] * strides[-2] + col_num \
                * strides[-1]) * input_shape[-1] + chan_num
        
        for i in range(tf.cast(weight_shape[-3], tf.int32) - offset):
            # Compute position in input tensor
            chan_num = i % input_shape[-1]
            col_num = (i // input_shape[-1]) % input_shape[-2]
            row_num = (i // input_shape[-1]) // input_shape[-2]
            
            if (chan_num < kernel_shape[-2] and col_num < kernel_shape[-3] and row_num < kernel_shape[-4]):
                position = (i + int(offset),)
                if vector_input:
                    if position not in position_nums:
                        position_nums[position] = 0
                    yield position + (position_nums[position],)
                    position_nums[position] += 1
                else:
                    yield position
        # Set bias
        if use_bias:
            if vector_input:
                yield (int(weight_shape[-3]-1), 0)
            else:
                yield (int(weight_shape[-3]-1),)
            
            
            
            
def get_sparsed_weight_params(input_shape, output_shape, kernel_shape, strides, use_bias, vector_input=False):
    """Get shape and indices of non-zero elements for sparsed weight matrix"""
    output_shape = output_shape // tf.concat([[1, 1], kernel_shape[-1:]], axis=0)
    # Compute weight shape
    input_flat_len = tf.reduce_prod(input_shape) + (1 if use_bias else 0)
    output_flat_len = tf.reduce_prod(output_shape)
    weight_shape = tf.concat([input_flat_len, output_flat_len, tf.reduce_prod(kernel_shape[-1:])], axis=0)
    weight_shape = tf.cast(weight_shape, tf.int64)
    
    # Get indices of non-zero elements in sparsed weight matrix
    sparsed_indices = tf.constant(
        list(iterate_sparsed_weight_indices(weight_shape, input_shape, output_shape, kernel_shape, strides, use_bias)),
        dtype=tf.int64
    )
    
    gather_indices = tf.constant(
        list(iterate_input_gather_indices(weight_shape, input_shape, output_shape, kernel_shape, strides, use_bias, vector_input)),
        dtype=tf.int64
    )
    vector_dim = tf.reduce_prod(kernel_shape[:-1])
    if use_bias:
        # Bias must be taken into account
        vector_dim += 1
    gather_shape = tf.concat([tf.reduce_prod(output_shape), vector_dim, tf.constant(2 if vector_input else 1)], axis=0)
    gather_indices = tf.reshape(gather_indices, gather_shape)
    gather_indices = tf.transpose(gather_indices, perm=[1,0,2])
    gather_shape = gather_indices.shape
    
    return weight_shape, sparsed_indices, gather_indices


def get_dense_output_params(weight_shape, output_shape, kernel_shape, use_bias):
    """Get shape and indices of elements for dense output matrix"""
    vector_dim = tf.reduce_prod(kernel_shape[:-1])
    if use_bias:
        # Bias must be taken into account
        vector_dim += 1
    shape = tf.concat([[vector_dim], output_shape], axis=0)
    shape = tf.cast(shape, tf.int64)
    
    indices = [[v, i, j, f] for v in range(shape[-4]) for i in range(shape[-3]) for j in range(shape[-2]) for f in range(shape[-1])]
    indices = tf.constant(
        indices,
        dtype=tf.int64
    )
    
    return shape, indices


def concat_biases_fun(inputs_rank, axis=-1):
    """Add bias to each input vector"""
    # Inputs shape can be partially known, so
    # Get inputs slice with current dimension equals one
    slice_begin = tf.zeros(inputs_rank, dtype=tf.int32)
    slice_size = tf.concat([tf.fill([inputs_rank + axis], -1), tf.constant([1]), tf.fill([-axis - 1], -1)], 0)
    
    def concat_biases(inputs):
        nonlocal slice_begin
        nonlocal slice_size
        nonlocal axis
        inputs_slice = tf.slice(inputs, slice_begin, slice_size)
        # Create biases shaped like inputs slice
        biases = tf.ones_like(inputs_slice, dtype=inputs.dtype)
        # Concatenate inputs with biases
        return tf.concat([inputs, biases], axis)
    
    return concat_biases


def multiply_sparsed(x_flat, weight, dense_indices, dense_shape):
    """Multiply separate (non-batched) input with sparsed weight tensor"""
    output = x_flat * weight
    # Rearrange and convet to dense
    return tf.reshape(output.values, dense_shape)
#     dense = sparse.SparseTensor(dense_indices, output.values, dense_shape)
#     return sparse.to_dense(dense)


def get_sparsed_weight(kernel, weight_shape, sparsed_indices, bias=None):
    use_bias = bias is not None
    
    # Get kernel values
    sparsed_values = tf.reshape(kernel, [-1])
    sparsed_values = tf.tile(sparsed_values, weight_shape[-2:-1])
    if use_bias:
        # Get bias values
        sparsed_bias = tf.reshape(bias, [-1])
        sparsed_bias = tf.tile(sparsed_bias, weight_shape[-2:-1])
        sparsed_values = tf.concat([sparsed_values, sparsed_bias], axis=0)
    
    # Initialize sparsed weight tensor
    weight = sparse.SparseTensor(sparsed_indices, sparsed_values, weight_shape)
    weight = sparse.reorder(weight)
    
    return weight

In [546]:
class VInputConv(layers.Layer):
    """Input vector layer for convolutional networks"""
    def __init__(self, filter_shape, num_filters=1, kernel_type="convolution", strides=(1,1), padding=0, weight_initializer="random_normal"):
        super().__init__()
        self.filter_shape = filter_shape
        self.num_filters = num_filters
        self.kernel_type = kernel_type
        self.strides = strides
        self.padding = padding
        self.weight_initializer = weight_initializer
    
    def build(self, input_shape):
        if self.kernel_type == "convolution":
            kernel_shape = tf.concat([self.filter_shape, input_shape[-1:], [self.num_filters]], axis=0)
            self.use_bias = True
            bias_shape = kernel_shape[-1:]
        else: # if kernel_type == "pooling"
            kernel_shape = tf.concat([self.filter_shape, [1, 1]], axis=0)
            self.use_bias = False
        
        self.kernel = self.add_weight(
            shape=kernel_shape,
            initializer=self.weight_initializer
        )
        if self.use_bias:
            self.bias = self.add_weight(
                shape=bias_shape,
                initializer=self.weight_initializer
            )
        
        if self.padding > 0:
            input_rank = input_shape.rank
            self.paddings = tf.concat([tf.zeros([input_rank - 3], dtype=tf.int32), [self.padding, self.padding], [0]], axis=0)
            self.paddings = tf.stack([self.paddings, self.paddings], axis=1)
        
        self.full_output_shape = get_full_output_shape(input_shape, kernel_shape, self.strides, self.use_bias)
        
        padded_input_shape = get_input_shape(input_shape, self.padding)
        output_shape = get_output_shape(padded_input_shape, kernel_shape, self.strides)
        
        self.weight_shape, self.sparsed_indices, self.gather_indices = get_sparsed_weight_params(
            padded_input_shape, output_shape, kernel_shape, self.strides, self.use_bias
        )
        self.dense_shape, self.dense_indices = get_dense_output_params(
            self.weight_shape, output_shape, kernel_shape, self.use_bias
        )
        
        
        
        self.flattened_weight = tf.reshape(self.kernel, tf.concat([tf.reduce_prod(self.kernel.shape[:-1]), 1, self.kernel.shape[-1]], axis=0))
        if self.use_bias:
            # Get bias values
            bias = tf.reshape(self.bias, tf.concat([1, 1, self.bias.shape[-1]], axis=0))
            self.flattened_weight = tf.concat([self.flattened_weight, bias], axis=0)
            
        self.weight_shape = tf.concat([[1], self.gather_indices.shape[-2:-1], [1]], axis=0)
        
        self.concat_biases = concat_biases_fun(len(input_shape) - 2, axis=-1)
    
    def call(self, inputs):
        x = inputs
        if self.padding > 0:
            x = tf.pad(x, self.paddings)
        
        
#         x_flat = tf.expand_dims(tf.expand_dims(flatten(x, tf.constant(3)), -1), -1)
#         if self.use_bias:
#             x_flat = concat_biases(x_flat, axis=-3)
        
        x_flat = flatten(x, tf.constant(3))
        if self.use_bias:
            x_flat = self.concat_biases(x_flat)
        
#         x = tf.map_fn(
#             lambda x: tf.gather_nd(x, self.gather_indices), 
#             x_flat
#         )
        x = tf.transpose(x_flat, perm=[1,0])
        x = tf.gather_nd(x, self.gather_indices) # TODO: depends on rank
        x = tf.transpose(x, perm=[2,0,1])
        
        x = tf.expand_dims(x, -1)
    
        weight = tf.tile(self.flattened_weight, self.weight_shape)
        
        y = x * weight
        print(y.shape)
        print(self.full_output_shape)
        return tf.reshape(y, self.full_output_shape)
        #gather_indices
        
#         weight = get_sparsed_weight(
#             self.kernel, 
#             self.weight_shape, 
#             self.sparsed_indices, 
#             bias=self.bias if self.use_bias else None
#         )
        
#         output = tf.map_fn(
#             lambda x: multiply_sparsed(x, weight, self.dense_indices, self.dense_shape), 
#             x_flat
#         )
        
#         return output

    
class VOutputConv(layers.Layer):
    """Output vector layer for convolutional networks"""
    

In [549]:
# import timeit
batch_size = 64
x_dim = 9
x = tf.reshape(tf.range([batch_size*x_dim*x_dim*2], dtype=tf.float32), shape=(batch_size,x_dim,x_dim,2))
strides = (1,1)
padding = 0
num_filters=2
filter_dim=2

tries = 100

tf.config.run_functions_eagerly(True)

layer = VInputConv((filter_dim,filter_dim), num_filters=num_filters, kernel_type="convolution", strides=strides, padding=padding)
print(layer(x).shape)
# new_time = timeit.timeit(lambda: layer(x), number=tries)

# layer = layers.Conv2D(num_filters, filter_dim, activation='relu', strides=strides, padding="valid")
# print(layer(x).shape)
# old_time = timeit.timeit(lambda: layer(x), number=tries)

# print(new_time / old_time)

(64, 9, 64, 2)
tf.Tensor([-1  9  8  8  2], shape=(5,), dtype=int32)
(64, 9, 8, 8, 2)


In [364]:
layer = layers.Conv2D(num_filters, filter_dim, activation='relu', strides=strides, padding="valid")
print(layer(x).shape)
timeit.timeit(lambda: layer(x), number=100)

(50, 8, 8, 2)


0.03154229999927338

In [173]:
x = tf.reshape(tf.range([2*3*3*2], dtype=tf.float32), shape=(2,3,3,2))
# kernel = tf.constant([[[[1]], [[2]]], [[[3]], [[4]]]], dtype=tf.float32)
kernel = tf.Variable(tf.reshape(tf.range([2*2*2*2], dtype=tf.float32), shape=(2,2,2,2)))
strides = (1,1)
padding = 0

input_shape = get_input_shape(x.shape, padding)
output_shape = get_output_shape(input_shape, kernel.shape, strides, padding)

weight_shape, sparsed_indices = get_sparsed_weight_params(input_shape, output_shape, kernel.shape, strides)

dense_shape, dense_indices = get_dense_output_params(weight_shape, output_shape, kernel.shape, False)

def multiply_sparsed(x_flat):
    """Bad code, but sparse-to-dence broadcasting is not working"""
    output = x_flat * weight
    dense = sparse.SparseTensor(dense_indices, output.values, dense_shape)
    return sparse.to_dense(dense)

with tf.GradientTape() as tape:
    sparsed_values = tf.reshape(kernel, [-1])
    sparsed_values = tf.tile(sparsed_values, weight_shape[-2:-1])
    weight = sparse.SparseTensor(sparsed_indices, sparsed_values, weight_shape)
    weight = sparse.reorder(weight)
    x_flat = tf.expand_dims(tf.expand_dims(flatten(x, tf.constant(3)), -1), -1)
    result = tf.map_fn(multiply_sparsed, x_flat)
    loss = tf.reduce_mean(result)

print(tape.gradient(loss, kernel))

tf.Tensor(
[[[[0.8125 0.8125]
   [0.875  0.875 ]]

  [[0.9375 0.9375]
   [1.     1.    ]]]


 [[[1.1875 1.1875]
   [1.25   1.25  ]]

  [[1.3125 1.3125]
   [1.375  1.375 ]]]], shape=(2, 2, 2, 2), dtype=float32)


In [6]:
conv = tf.keras.layers.Conv2D(2,3,activation="relu", input_shape=(28,28,3))

In [12]:
conv(tf.reshape(tf.range(28*28*3, dtype=tf.float32) * 0.1, (1, 28,28,3)))
conv.get_weights()

[array([[[[ 0.14504266, -0.27120164],
          [-0.28834227,  0.0261265 ],
          [ 0.07041037, -0.2354877 ]],
 
         [[-0.19154891,  0.3179202 ],
          [ 0.01622459,  0.1081526 ],
          [-0.3642441 , -0.32064518]],
 
         [[-0.20420814,  0.02284172],
          [-0.20973232, -0.03993881],
          [-0.32509318,  0.01819271]]],
 
 
        [[[ 0.16470277,  0.2110464 ],
          [ 0.22549063, -0.0138002 ],
          [-0.10054016, -0.2771492 ]],
 
         [[-0.12416469, -0.32693505],
          [ 0.23607647, -0.23215224],
          [ 0.02976117, -0.17566273]],
 
         [[-0.0430896 ,  0.20499647],
          [-0.10320622, -0.16313514],
          [-0.2966457 , -0.3647445 ]]],
 
 
        [[[-0.31194293,  0.28944844],
          [ 0.2014187 , -0.3381625 ],
          [ 0.10150501, -0.28951746]],
 
         [[-0.10838193, -0.25806308],
          [ 0.26325548, -0.02701598],
          [ 0.00530359,  0.21974504]],
 
         [[ 0.14539486,  0.32093298],
          [-0.050225