In [23]:
import equinox as eqx
import jax
def print_per_layer(model):
    """Print, for every 1-D conv layer found in `model`, the `[0]` slice of its weight.

    `model` is walked as a pytree; `eqx.nn.Conv1d` and `eqx.nn.ConvTranspose1d`
    nodes are treated as leaves so that each layer is visited as a whole, and
    every other leaf is ignored. The collected slices are printed as one list.
    """
    conv_types = (eqx.nn.Conv1d, eqx.nn.ConvTranspose1d)

    def _is_conv(node):
        return isinstance(node, conv_types)

    first_slices = []
    for node in jax.tree_util.tree_leaves(model, is_leaf=_is_conv):
        if _is_conv(node):
            # Index 0 along the leading axis of the layer's weight array.
            first_slices.append(node.weight[0])

    print(first_slices)

# @eqx.filter_jit  (left disabled; under jit the print below would run only while tracing)
@eqx.filter_value_and_grad(has_aux=True)
def calculate_loss(model, x):
    """Reconstruction loss of `model` on the batch `x`.

    The model is applied to each batch element with `jax.vmap`. The squared
    error is averaged over the last axis and the resulting per-sample values
    are then *summed* (not averaged), so the scalar grows with the batch size.

    Because of the `filter_value_and_grad(has_aux=True)` decorator, calling
    this function yields `((loss, y), grads)`, where `y` is the batched model
    output and `grads` is the gradient of `loss` with respect to `model`.
    """
    reconstruction = jax.vmap(model)(x)
    residual = x - reconstruction

    # Alternative kept for reference: L1 norm of the residual along the last axis.
    # per_sample = jax.numpy.linalg.norm(residual, ord=1, axis=-1)
    per_sample = jax.numpy.mean(residual ** 2, axis=-1)
    print(per_sample.shape)

    total = jax.numpy.sum(per_sample)
    return total, reconstruction

# @eqx.filter_jit
def make_step(model, optimizer, opt_state, x):
    """Perform one optimisation step on the batch `x`.

    Computes the loss and gradients with `calculate_loss`, asks `optimizer`
    for the parameter updates and applies them to `model`. For inspection,
    `print_per_layer` is called three times: on the model before the update,
    on the updates themselves, and on the model after the update.

    Returns `(loss, y, opt_state, model)`: the scalar loss, the model output
    for `x`, the new optimizer state and the updated model.
    """
    (loss_value, reconstruction), gradients = calculate_loss(model, x)
    print_per_layer(model)

    param_updates, next_opt_state = optimizer.update(gradients, opt_state, model)
    print_per_layer(param_updates)

    updated_model = eqx.apply_updates(model, param_updates)
    print_per_layer(updated_model)

    return loss_value, reconstruction, next_opt_state, updated_model

In [24]:
import equinox as eqx
import jax
import optax



# Training hyperparameters.
# Leading (batch) dimension of the input array built for the training loop.
batch_size = 8
# NOTE(review): not referenced anywhere in the visible cells — confirm whether it is still needed.
epochs = 5
# Learning rate passed to optax.adam.
learning_rate = 1e-3

class EncodecModel(eqx.Module):
    """Minimal 1-D convolutional autoencoder.

    A strided `Conv1d` (1 -> 2 channels, kernel 3, stride 2, "SAME" padding)
    followed by a strided `ConvTranspose1d` (2 -> 1 channels, same kernel,
    stride and padding), each followed by a ReLU.
    """

    # Encoder: strided convolution, 1 input channel -> 2 output channels.
    enc: eqx.nn.Conv1d
    # Decoder: strided transposed convolution, 2 input channels -> 1 output channel.
    dec: eqx.nn.ConvTranspose1d

    def __init__(self, key=None):
        """Build the encoder and decoder layers.

        Args:
            key: JAX PRNG key used to initialise both layers. It is split in
                two so that each layer receives its own key. Although the
                parameter has a `None` default, a key is required.

        Raises:
            TypeError: if `key` is None.
        """
        if key is None:
            # Fail fast with a clear message instead of handing None to jax.random.split.
            raise TypeError(
                "EncodecModel requires a PRNG key, e.g. "
                "EncodecModel(key=jax.random.PRNGKey(0))"
            )

        enc_key, dec_key = jax.random.split(key, 2)
        self.enc = eqx.nn.Conv1d(1, 2, kernel_size=3, stride=2, padding="SAME", key=enc_key)
        self.dec = eqx.nn.ConvTranspose1d(2, 1, kernel_size=3, stride=2, padding="SAME", key=dec_key)

    def __call__(self, x):
        """Encode then decode a single (unbatched) input and return the result.

        The final ReLU means the returned values are never negative.
        """
        x = self.enc(x)
        x = jax.nn.relu(x)
        x = self.dec(x)
        x = jax.nn.relu(x)

        return x

# Split the seed key in two; only `grab1` is consumed in this cell.
grab1, grab2 = jax.random.split(jax.random.PRNGKey(1), 2)

model = EncodecModel(key=grab1)

optimizer = optax.adam(learning_rate)
opt_state = optimizer.init(model)

# The training batch is constant (all ones), so build it once instead of on
# every iteration of the loop.
x = jax.numpy.ones((batch_size, 1, 100))

# Number of optimisation steps to run.
num_steps = 100

for i in range(num_steps):
    losses, y, opt_state, model = make_step(model, optimizer, opt_state, x)

    print(f"Losses : {losses}")



(8, 1)
[Array([[-0.32698175,  0.39774594,  0.46568155]], dtype=float32), Array([[ 0.20344223, -0.28761053, -0.10750052],
       [ 0.04928516,  0.2901747 , -0.07114627]], dtype=float32)]
[Array([[-0.00099999, -0.00099999, -0.00099999]], dtype=float32), Array([[0.00099999, 0.00099999, 0.00099999],
       [0.00099999, 0.00099999, 0.00099999]], dtype=float32)]
[Array([[-0.32798174,  0.39674595,  0.46468157]], dtype=float32), Array([[ 0.20444222, -0.28661054, -0.10650052],
       [ 0.05028515,  0.29117468, -0.07014628]], dtype=float32)]
Losses : 3.7925984859466553
(8, 1)
[Array([[-0.32798174,  0.39674595,  0.46468157]], dtype=float32), Array([[ 0.20444222, -0.28661054, -0.10650052],
       [ 0.05028515,  0.29117468, -0.07014628]], dtype=float32)]
[Array([[-0.0009993, -0.0009993, -0.0009993]], dtype=float32), Array([[0.00099959, 0.0009995 , 0.00099959],
       [0.00100012, 0.00100004, 0.00100012]], dtype=float32)]
[Array([[-0.32898104,  0.39574665,  0.46368226]], dtype=float32), Array([[ 0.2