In [32]:
import jax
import jax.numpy as jnp
from jax import random, grad, value_and_grad
from jax.nn import relu, softmax
import einops
from showmethetypes import SMTT
from optax import sgd, apply_updates

In [8]:
tt = SMTT()

In [None]:
"""
What am I doing?
I need a loss function that takes in the weights of an MLP embedding (or I guess the whole MLP) and gives you the loss (maybe across the whole dataset

How can I start out?
1. by making a linear regression
2. making a logistic regression (adding a nonlinearity)
3. Making the rest of the MLP

1.
y = ax + b
"""

In [2]:
class regressor:
    def __init__(self):
        initializer = jax.nn.initializers.he_normal()
        self.params = initializer(jax.random.PRNGKey(42), (1, 2), jnp.float32)

    def forward(self, data):
        return self.params[0] * data + self.params[1]

In [3]:
initializer = jax.nn.initializers.he_normal()
params = initializer(jax.random.PRNGKey(43), (1, 2), jnp.float32)

In [4]:
params

Array([[-1.940299  ,  0.13722903]], dtype=float32)

In [5]:
uniform_init = jax.nn.initializers.uniform(100)
true_params = uniform_init(jax.random.PRNGKey(43), (1, 2), jnp.float32)
true_a = true_params[1]
true_b = true_params[0]

In [16]:
xs = uniform_init(jax.random.PRNGKey(43), (1000, 1), jnp.float32)
ys = true_a * xs + true_b

In [18]:
def mse(pred, target):
    return jnp.mean((pred - target) ** 2)

In [19]:
mse(jnp.array([2, 1]), jnp.array([1]))

Array(0.5, dtype=float32)

In [26]:
def loss_fn(params, loss_fn, xs, ys):
    preds = params[0] * xs + params[1]
    return loss_fn(preds, ys)

In [20]:
grad(mse)(jnp.array([2, 1], dtype=jnp.float32), jnp.array([1, 1], dtype=jnp.float32))

Array([1., 0.], dtype=float32)

In [22]:
model.forward(xs)

Array([[ -78.47896 ,   19.433558],
       [-236.83218 ,   58.646194],
       [-162.63159 ,   40.27208 ],
       ...,
       [ -71.13399 ,   17.61474 ],
       [ -38.326168,    9.490619],
       [-221.83112 ,   54.93152 ]], dtype=float32)

In [39]:
model = regressor()
params = model.params
solver = sgd(learning_rate=3e-9)
optimizer = solver.init(params)
epochs = 10

In [40]:
for epoch in range(epochs):
    loss, grad = value_and_grad(loss_fn)(params, mse, xs, ys)
    print(f"Loss: {loss}")
    updates, opt_state = solver.update(grad, optimizer, params)
    params = apply_updates(params, updates)

Loss: 5147373.0
Loss: 5383.748046875
Loss: 5.630852699279785
Loss: 0.005902229808270931
Loss: 6.154059519758448e-06
Loss: 3.2639984937077315e-08
Loss: 0.0
Loss: 0.0
Loss: 0.0
Loss: 0.0


In [41]:
"""
Hooray! We've got a linear regressor
"""

"\nHooray! We've got a linear regressor\n"