In [2]:
import jax
from jax import lax,random,numpy as jnp

import flax
from flax.core import freeze, unfreeze
from flax import linen as nn
from flax.training import train_state

# import haiku as hk

import optax


from torchvision.datasets import MNIST
from torch.utils.data import DataLoader

import functools
from typing import Any,Callable,Sequence,Optional

import numpy as np
import matplotlib.pyplot as plt

In [None]:
model = nn.Dense(features=5)

seed = 23

key1, key2 = jax.random.split(jax.random.PRNGKey(seed))

x = jax.random.normal(key1, (10,))

y, params = model.init_with_output(key2,x)

print(y)
print(params)
print(jax.tree.map(lambda x: x.shape, params))

y = jnp.dot(params['params']['kernel'].transpose(),x)
print(y)
y_test = model.apply(params, x)
print(y_test)

[-1.2137768  -0.36633098  0.06677867 -2.014398   -0.32107347]
{'params': {'kernel': Array([[-0.27781105,  0.1656284 , -0.06729937, -0.19215076,  0.321123  ],
       [ 0.26235792,  0.27014345, -0.28854737, -0.1615871 ,  0.31653988],
       [-0.40874493,  0.6885968 ,  0.46351194, -0.22216477, -0.04983266],
       [ 0.09898845,  0.02624258, -0.32574084, -0.71759474,  0.45787135],
       [-0.49562743, -0.0781584 ,  0.07065256, -0.5920491 , -0.41872957],
       [-0.2380879 ,  0.27913874, -0.68323755,  0.13979784,  0.04806783],
       [ 0.24915032, -0.13843682, -0.0688496 , -0.03453913,  0.5834356 ],
       [ 0.02020426,  0.19953905,  0.42730477,  0.10579848, -0.04245503],
       [ 0.21156143, -0.5489475 ,  0.69323635, -0.24123473,  0.02384394],
       [-0.1420826 ,  0.5460323 ,  0.07612853,  0.49402887, -0.25343806]],      dtype=float32), 'bias': Array([0., 0., 0., 0., 0.], dtype=float32)}}
{'params': {'bias': (5,), 'kernel': (10, 5)}}
[-1.213777   -0.36633098  0.06677864 -2.014398   -0.321

In [40]:
n_samples = 150
x_dim = 2
y_dim = 1
key, w_key,b_key = random.split(random.PRNGKey(seed),num=3)
noise_weight = 0.01

W = random.normal(w_key, (x_dim,y_dim))
b = random.normal(b_key, (y_dim,))

true_params = freeze({'params': {'bias': b, 'kernel': W}})

key, x_key, noise_key = random.split(key, num=3)
xs = random.normal(x_key, (n_samples, x_dim))
ys = jnp.dot(xs, W)+b + noise_weight*random.normal(noise_key, (n_samples, y_dim))

# print(jnp.dot(xs, W),b, jnp.dot(xs,W)+b)
# print(b)
print(xs.shape,ys.shape)


(150, 2) (150, 1)


In [44]:
def make_mse_loss(xs, ys):
    def mse_loss(params):
        def square_loss(x,y):
            pred = model.apply(params,x)
            return jnp.inner(y-pred,y-pred)
        
        return jnp.mean(jax.vmap(square_loss)(xs,ys),axis=0)
    return jax.jit(mse_loss)

mse_loss = make_mse_loss(xs,ys)
value_and_grad_fn = jax.value_and_grad(mse_loss)

In [48]:
model = nn.Dense(features=y_dim)
params = model.init(key,xs)

epochs = 50
lr = 0.1
log_epoch = 10

for epoch in range(epochs):
    loss, grads = value_and_grad_fn(params)

    params = jax.tree.map(lambda p, g: p - lr*g, params, grads)

    if epoch % log_epoch == 0:
        print(f"epoch: {epoch}, loss: {loss}")

epoch: 0, loss: 1.2882237434387207
epoch: 10, loss: 0.015024389140307903
epoch: 20, loss: 0.00026953377528116107
epoch: 30, loss: 9.36854921746999e-05
epoch: 40, loss: 9.150459663942456e-05


In [50]:
opt_sgd = optax.sgd(learning_rate=lr)
opt_state = opt_sgd.init(params)

In [53]:
for epoch in range(epochs):
    loss, grads = value_and_grad_fn(params)
    updates, opt_state = opt_sgd.update(grads, opt_state)
    params = optax.apply_updates(params,updates)

    if epoch % log_epoch == 0:
        print(f"epoch: {epoch}, loss: {loss}")

epoch: 0, loss: 9.147590026259422e-05
epoch: 10, loss: 9.147550008492544e-05
epoch: 20, loss: 9.147547098109499e-05
epoch: 30, loss: 9.147546370513737e-05
epoch: 40, loss: 9.147546370513737e-05


In [57]:
class MLP(nn.Module):
    features: Sequence[int]

    def setup(self):
        self.layers = [nn.Dense(n) for n in self.features]
    
    def __call__(self, x):
        activation = x
        for i, layer in enumerate(self.layers):
            activation = layer(activation)
            if i != len(self.features)-1:
                activation = nn.relu(activation)

        return activation
    
x_key, init_key = random.split(random.PRNGKey(seed))

model = MLP(features=[16,8,2])
x = random.uniform(x_key, (4,4))
params = model.init(init_key, x)
y = model.apply(params, x)
print(x, y)

[[0.3073057  0.04725921 0.33156574 0.95880425]
 [0.00381255 0.16904569 0.6073719  0.01276541]
 [0.66728354 0.32034397 0.2885779  0.7781857 ]
 [0.6977296  0.9792644  0.44513798 0.21770954]] [[ 0.33665472 -0.37799144]
 [ 0.21020135 -0.13624829]
 [ 0.4062944  -0.30691484]
 [ 0.5687312  -0.42791107]]
