In [1]:

import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random

In [25]:
step_size = 0.01
num_epochs = 10

shape = [100,100]
key = random.PRNGKey(0)
inputs = random.uniform(key=key,shape=shape,minval=-1,maxval=1)
targets = random.randint(key=key,shape=[shape[0],1],minval=0,maxval=1)
params = [(random.uniform(key=key,shape=[shape[1],1],minval=-1,maxval=1),random.uniform(key=key,shape=[1,shape[1]],minval=-1,maxval=1)) for i in range(10)]

In [2]:

# 定义模型
def predict(params, inputs):
    for W, b in params:
        outputs = jnp.dot(inputs, W) + b
        inputs = jnp.tanh(outputs)
    return outputs

In [3]:

# 定义损失函数
def loss(params, inputs, targets):
    preds = predict(params, inputs)
    return jnp.mean((preds - targets)**2)

In [6]:
# 定义优化器
def update(params, inputs, targets):
    grads = grad(loss)(params, inputs, targets)
    return [(W - step_size * dW, b - step_size * db)
            for (W, b), (dW, db) in zip(params, grads)]

In [26]:

# 训练循环
for epoch in range(num_epochs):
    params = update(params, inputs, targets)

In [27]:
print(params)

[(Array([[-0.95241666],
       [ 0.70544076],
       [ 0.62643695],
       [ 0.02805257],
       [-0.65654397],
       [ 0.60537314],
       [ 0.02492619],
       [-0.30323124],
       [ 0.0105381 ],
       [-0.3258958 ],
       [-0.7826352 ],
       [-0.78958726],
       [ 0.6765473 ],
       [ 0.5797305 ],
       [-0.31880307],
       [ 0.66985464],
       [-0.50848746],
       [-0.5722525 ],
       [-0.95153546],
       [ 0.12348461],
       [-0.43866467],
       [ 0.8873291 ],
       [ 0.22429991],
       [ 0.47667766],
       [ 0.04839611],
       [ 0.30933452],
       [-0.17974472],
       [-0.51942706],
       [ 0.48886132],
       [-0.92910147],
       [ 0.70202804],
       [-0.95130944],
       [-0.05520535],
       [ 0.45413613],
       [-0.2988913 ],
       [ 0.25483418],
       [ 0.2215507 ],
       [-0.8694854 ],
       [ 0.6183858 ],
       [-0.57384324],
       [ 0.2930646 ],
       [-0.35099697],
       [ 0.10777664],
       [ 0.76996136],
       [ 0.9182422 ],
       [