In [1]:
import tensorflow as tf

In [2]:
def generate_selection_tensor(shape, range_start, range_stop, dtype="float32"):
    """Build a ``shape[0] x shape[1]`` 0/1 matrix that is 1 on the diagonal
    positions ``range_start`` .. ``range_stop`` (both inclusive) and 0 elsewhere.

    Right-multiplying a row vector by this matrix keeps the entries whose index
    lies in that range and zeroes all others.

    Args:
        shape: indexable with at least two entries giving (rows, cols).
        range_start: first diagonal index set to 1.
        range_stop: last diagonal index set to 1 (inclusive). An empty range
            (``range_stop < range_start``) yields an all-zero matrix.
        dtype: dtype of the returned tensor.

    Raises:
        IndexError: if a diagonal index in the range falls outside the matrix.
    """
    # shape[1] is looked up once per row, so a 0-row shape never touches it.
    matrix = [[0] * shape[1] for _ in range(shape[0])]
    for diag_index in range(range_start, range_stop + 1):
        matrix[diag_index][diag_index] = 1
    return tf.constant(matrix, dtype=dtype)

In [3]:
class Model:
    """Recurrent node network evaluated by repeated propagation through a weight matrix.

    All nodes live on a single axis of length ``num_nodes = weights.shape[0]``:

    * ``[0, input_shape)`` -- input nodes, clamped to the values passed to ``fit``;
    * ``[input_shape, input_shape + output_shape)`` -- output nodes;
    * ``[input_shape + output_shape, num_nodes)`` -- hidden nodes.

    Node values are contracted with the first axis of ``weights``, so
    ``weights[i, j]`` is the weight of the connection from node ``i`` to node ``j``.
    """

    def __init__(
        self,
        input_shape,
        output_shape,
        weights,
        iterations,
        hidden_activation_fn,
        output_activation_fn,
        dtype="float32"
    ):
        """
        Args:
            input_shape: number of input nodes (int).
            output_shape: number of output nodes (int).
            weights: square ``(num_nodes, num_nodes)`` weight matrix, e.g. a
                ``tf.Variable``; its dtype should match ``dtype``.
            iterations: number of propagation steps performed by ``fit``.
            hidden_activation_fn: applied to the hidden nodes' pre-activations,
                a ``(batch, num_hidden)`` tensor.
            output_activation_fn: applied to the output nodes' pre-activations,
                a ``(batch, output_shape)`` tensor.
            dtype: dtype of the node tensor and of the selection tensors.

        Raises:
            ValueError: if ``weights`` is not a square matrix or has fewer than
                ``input_shape + output_shape`` nodes.
        """
        self.input_shape = input_shape
        self.output_shape = output_shape
        self.weights = weights
        self.iterations = iterations
        self.hidden_activation_fn = hidden_activation_fn
        self.output_activation_fn = output_activation_fn
        self.dtype = dtype

        num_nodes = self.weights.shape[0]
        if len(self.weights.shape) != 2 or self.weights.shape[1] != num_nodes:
            raise ValueError(
                f"weights must be a square (num_nodes, num_nodes) matrix, got shape {self.weights.shape}"
            )
        if self.input_shape + self.output_shape > num_nodes:
            raise ValueError(
                f"input_shape + output_shape ({self.input_shape + self.output_shape}) "
                f"exceeds the number of nodes in weights ({num_nodes})"
            )

        # Diagonal 0/1 masks for each node region. `fit` slices the regions
        # directly and does not need them; they are kept for backward
        # compatibility and now honour self.dtype (previously always float32).
        self._input_selection_tensor = generate_selection_tensor(
            self.weights.shape, 0, self.input_shape - 1, dtype=self.dtype
        )
        self._output_selection_tensor = generate_selection_tensor(
            self.weights.shape,
            self.input_shape,
            self.input_shape + self.output_shape - 1,
            dtype=self.dtype,
        )
        self._hidden_selection_tensor = generate_selection_tensor(
            self.weights.shape,
            self.input_shape + self.output_shape,
            num_nodes - 1,
            dtype=self.dtype,
        )

    def fit(self, x):
        """Run ``self.iterations`` propagation steps and return the output nodes.

        Despite its name this performs no training: it is a forward pass that
        never modifies ``self.weights``.

        Each step computes ``pre = nodes @ weights``; the output and hidden
        regions become their activation of ``pre`` (each activation sees only
        its own region), while the input nodes are reset to ``x``.

        Args:
            x: rank-2 ``(batch_size, input_shape)`` tensor-like convertible to
                ``self.dtype``. The batch dimension may be statically unknown
                (e.g. ``None`` in a ``tf.function`` input signature).

        Returns:
            ``(batch_size, output_shape)`` tensor of output-node values after the
            last step (zeros when ``iterations`` is 0).

        Raises:
            ValueError: if ``x`` is not rank 2 or its last dimension is not
                ``input_shape``.
        """
        x = tf.convert_to_tensor(x, dtype=self.dtype)
        if x.shape.rank is not None and x.shape.rank != 2:
            raise ValueError(f"x must be rank 2 (batch_size, input_shape), got shape {x.shape}")
        if x.shape.rank == 2 and x.shape[1] is not None and x.shape[1] != self.input_shape:
            raise ValueError(f"x must have {self.input_shape} columns, got shape {x.shape}")

        num_nodes = self.weights.shape[0]
        output_stop = self.input_shape + self.output_shape
        # Input nodes hold x, all other nodes start at zero. Padding (rather
        # than scattering with a Python-built index list) also works when the
        # batch size is only known at run time.
        node_tensor = tf.pad(x, [[0, 0], [0, num_nodes - self.input_shape]])
        for _ in tf.range(self.iterations):
            neuron_outputs = tf.tensordot(node_tensor, self.weights, axes=[[1], [0]])
            # Activate each region on its own slice so an activation with
            # f(0) != 0 (sigmoid, softplus, ...) or a vector-wise one (softmax)
            # never leaks into or is distorted by other regions; inputs stay x.
            node_tensor = tf.concat(
                [
                    x,
                    self.output_activation_fn(neuron_outputs[:, self.input_shape:output_stop]),
                    self.hidden_activation_fn(neuron_outputs[:, output_stop:]),
                ],
                axis=1,
            )
        return node_tensor[:, self.input_shape:output_stop]

    def _generate_node_tensor_indicies_for_inputs(self, batch_size):
        """Return scatter indices ``[[[i, j] for each input node j] for each row i]``.

        Shaped ``(batch_size, input_shape, 2)``, suitable for
        ``tf.tensor_scatter_nd_update``. ``fit`` no longer uses this (it pads
        ``x`` instead); kept for backward compatibility. Only rank-2 inputs of
        the form ``(batch_size, data)`` are supported.
        """
        return [[[i, j] for j in range(self.input_shape)] for i in range(batch_size)]

In [4]:
# A 7-node network: nodes 0-2 are inputs, 3-4 outputs, 5-6 hidden.
# weights[i][j] is the connection from node i to node j.
model = Model(
    input_shape=3,
    output_shape=2,
    weights=tf.Variable(
        [
            [0, 0, 0, 0, 0, 2, 0],
            [0, 0, 0, 0, 0, 2, 0],
            [0, 0, 0, 0, 2, 0, 2],
            [0, 0, 0, 0, 0, 2, 0],
            [0, 0, 0, 0, 0, 0, 0],
            [0, 0, 0, 2, 0, 0, 0],
            [0, 0, 0, 2, 0, 0, 2],
        ],
        dtype="float32",
    ),
    iterations=2,
    hidden_activation_fn=tf.identity,
    output_activation_fn=tf.identity,
)

In [5]:
# Two identical samples, one value per input node; already shaped (2, 3),
# so no reshape is needed.
x = tf.constant([[1, 2, 3], [1, 2, 3]], dtype="float32")
x

<tf.Tensor: shape=(2, 3), dtype=float32, numpy=
array([[1., 2., 3.],
       [1., 2., 3.]], dtype=float32)>

In [6]:
# Forward pass in eager mode: propagates x through the network for
# `model.iterations` steps and returns the output-node values (weights are not updated).
model.fit(x)

<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[24.,  6.],
       [24.,  6.]], dtype=float32)>

In [7]:
@tf.function
def fit_model(x):
    """Graph-compiled forward pass of the notebook-global ``model``.

    NOTE(review): this closes over the global name ``model``; re-run this cell
    after rebinding ``model`` so no previously traced graph is reused.
    """
    predictions = model.fit(x)
    return predictions

In [8]:
# Same computation as the eager `model.fit(x)` call, traced into a graph by tf.function;
# the result should match it.
fit_model(x)

<tf.Tensor: shape=(2, 2), dtype=float32, numpy=
array([[24.,  6.],
       [24.,  6.]], dtype=float32)>