In [1]:
from tensorwrap import nn
import tensorwrap as tf

In [2]:
class Linear(nn.layers.Layer):
    """Fully connected layer: ``call`` returns ``inputs @ kernel + bias``.

    Weights are not created in ``__init__``; they are made in ``build``, which
    needs the incoming input (or its shape) to size the kernel.

    Args:
        units: size of the output feature dimension; used as the last axis of
            ``kernel`` and as the length of ``bias``.
    """

    def __init__(self, units) -> None:
        # Base-class setup first (the original note says this is needed for JIT compatibility).
        super().__init__()
        self.units = units

    def build(self, input_shape: tuple) -> None:
        """Create ``kernel`` and ``bias`` and hand both to the base ``build``.

        ``kernel`` has shape ``[tf.shape(input_shape), units]`` and uses the
        ``'glorot_uniform'`` initializer; ``bias`` has shape ``[units]`` and
        uses ``'zeros'``.

        NOTE(review): ``tf.shape`` is applied to the argument, so despite the
        ``tuple`` annotation it is presumably the layer input itself (or
        something tensorwrap can take a shape of) — confirm against tensorwrap.

        NOTE(review): if this first runs inside a ``@tf.function`` trace, the
        value returned by ``tf.shape`` is a tracer and ``add_weights`` rejects
        it as a shape. That is the ``(Traced<...>, 10)`` TypeError raised by
        the custom training loop in this notebook.
        """
        fan_in = tf.shape(input_shape)
        kernel_shape = [fan_in, self.units]
        bias_shape = [self.units]
        self.kernel = self.add_weights(kernel_shape, initializer='glorot_uniform')
        self.bias = self.add_weights(bias_shape, initializer='zeros')
        # Register the new weights with the base Layer (presumably as its parameters).
        super().build(self.kernel, self.bias)

    def call(self, inputs):
        """Affine map of ``inputs``.

        Defined as ``call`` rather than ``__call__``; no ``tf.function`` is needed here.
        """
        return inputs @ self.kernel + self.bias

In [3]:
# Custom Linear(10) layer followed by the built-in Dense(1) output layer.
model = tf.nn.Sequential([Linear(10), tf.nn.layers.Dense(1)])

In [4]:
# Column vector built from tf.range(1, int(1e5)), scaled down by 1e5.
x = tf.expand_dims(tf.range(1, int(1e5)), axis=1) / 1e5
# Regression target: the input shifted by a constant 10.
y = x + 10

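# Custom Training Loop

A hand-written loop that runs the model and applies optimizer updates manually. As executed here it fails while being traced by `@tf.function` (see the `TypeError` output below).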

In [5]:
# Objective and update rule used by the hand-written loop.
optimizer = tf.nn.optimizers.gradient_descent()
loss = tf.nn.losses.mse

In [6]:
@tf.function
def training_loop():
    """Run 100 forward passes of ``model`` on ``x`` and pass each prediction to ``training``.

    Reads the module-level ``x``, ``y`` and ``model``. They are only read
    here, never rebound, so no ``global`` declaration is needed.

    FIXME(review): this currently fails while tracing with
    ``TypeError: Shapes must be 1D sequences of concrete values ...``.
    ``model(x)`` below is the model's first call in this notebook, so
    ``Linear.build`` runs inside the trace. It receives a traced value for
    the kernel's first dimension, giving ``(<traced>, 10)``. Presumably
    calling the model once eagerly before the first ``training_loop()``
    avoids this — TODO confirm with tensorwrap.
    """
    epochs = 100
    # Plain Python range: the loop counter is unused, so there is no reason to
    # iterate a device array just to count; Python loops are unrolled under jit.
    for _ in range(epochs):
        y_pred = model(x)
        training(y_pred, y)

def training(y_pred, y):
    """Apply one optimizer update to ``model.layers`` using a gradient of ``loss``.

    Args:
        y_pred: model predictions for the current step.
        y: target values.

    FIXME(review): ``tf.grad(loss)(y, y_pred)`` — if ``tf.grad`` follows
    ``jax.grad`` semantics, it differentiates ``loss`` with respect to its
    first positional argument, i.e. the targets ``y``, not the model
    parameters. The result is then not a parameter gradient. For this update
    to train anything, the loss must be expressed as a function of the
    model's weights.

    NOTE(review): assigning ``model.layers`` is a Python side effect. When
    called from a ``@tf.function`` (apparently ``jax.jit``, judging by the
    traceback), it executes only while tracing, not on every compiled call.
    """
    grads = tf.grad(loss)(y, y_pred)
    # `model` is not rebound here (only its attribute is set), so it resolves
    # to the module-level object without a `global` declaration.
    model.layers = optimizer.apply_gradients(grads, model.layers)

In [7]:
%timeit training_loop()

TypeError: Shapes must be 1D sequences of concrete values of integer type, got (Traced<ShapedArray(int32[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>, 10).
If using `jit`, try using `static_argnums` or applying `jit` to smaller subfunctions.

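# Prebuilt training loop

The same model and data, trained with the built-in `compile`/`fit` API for comparison with the custom loop above.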

In [8]:
model.compile(
    loss=tf.nn.losses.mse,
    optimizer=tf.nn.optimizers.gradient_descent(),
    metrics=tf.nn.losses.mae
)
%timeit model.fit(x, x + 10, epochs=100, verbose=0)

143 ms ± 1.74 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
