In [11]:
import jax
import jax.numpy as jnp
from jax import random, grad, value_and_grad
from jax.nn import relu, softmax
import einops
from showmethetypes import SMTT
from optax import sgd, apply_updates

In [12]:
class regressor:
    def __init__(self):
        initializer = jax.nn.initializers.he_normal()
        self.W_in = initializer(jax.random.PRNGKey(42), (113, 64), jnp.float32)
        self.W_out = initializer(jax.random.PRNGKey(42), (128, 113), jnp.float32)
        self.nonlinearity = relu
        self.parameters = [self.W_in, self.W_out]

    def forward(self, data):
        l = self.W_in[data[0]]
        r = self.W_in[data[1]]

        emb = jnp.concatenate([l, r])
        return relu(emb) @ self.W_out

In [13]:
def mse(pred, target):
    return jnp.mean((pred - target) ** 2)

In [14]:
def loss_fn(params, loss_fn, xs, ys):
    W_in, W_out = params[0], params[1]
    l = W_in[xs[0]]
    r = W_in[xs[1]]

    emb = jnp.concatenate([l, r], axis=1)

    pred = relu(emb) @ W_out
    logits = jnp.mean(pred, axis=0)
    print(logits.shape)
    return -loss_fn(logits, ys)

In [15]:
oot = jnp.array([i for i in range(113)])
row = einops.repeat(oot, "r -> r c", c=113)
col = row.T
row = einops.rearrange(row, "r c -> (r c)")
col = einops.rearrange(col, "r c -> (r c)")
xs = jnp.stack([row, col])

ys = jax.nn.one_hot(row * col, 113)

In [16]:
print(ys.shape)

(12769, 113)


In [17]:
model = regressor()
params = model.parameters
solver = sgd(learning_rate=30, momentum=0.2)
optimizer = solver.init(params)
epochs = 10

In [20]:
ys[1]

Array([1., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.,
       0., 0., 0., 0., 0., 0., 0., 0., 0., 0., 0.], dtype=float32)

In [22]:
model.forward(jnp.array([0, 1]))

Array([ 0.04267131, -0.03805638, -0.17450848, -0.01736136, -0.00368068,
        0.16894473, -0.13843809,  0.09179612,  0.15945284, -0.19543816,
       -0.00710924, -0.07168756, -0.20479256,  0.04670107,  0.21704043,
       -0.04665428, -0.2114431 ,  0.22236507, -0.13043669,  0.11875629,
       -0.04315519, -0.04645056, -0.22517869, -0.329479  , -0.29496247,
       -0.23088296,  0.32569015,  0.08745132,  0.09596004, -0.06016586,
        0.14054558, -0.17426753, -0.11412124, -0.07116443, -0.20031852,
        0.06302016,  0.02788424, -0.00630307, -0.30521145,  0.00481466,
       -0.05119161, -0.07668882, -0.21699896, -0.15560207, -0.05269596,
       -0.21378863,  0.34839845,  0.15465759,  0.00795331,  0.17334522,
        0.01560128,  0.04274034, -0.05091236,  0.1054142 ,  0.24133027,
       -0.19889037, -0.03527186, -0.01436651, -0.04284835,  0.10262948,
       -0.11394368, -0.00486074, -0.06071401, -0.0544565 ,  0.23919182,
        0.00043497,  0.00769247, -0.14868787,  0.14559108,  0.24

In [24]:
loss_fn(model.parameters, mse, jnp.array([[0, 1]]), ys[1])

(113,)


Array(0.02196129, dtype=float32)

In [19]:
for epoch in range(epochs):
    loss, grad = value_and_grad(loss_fn)(params, mse, xs, ys)
    print(f"Loss: {loss}")
    updates, opt_state = solver.update(grad, optimizer, params)
    params = apply_updates(params, updates)

(113,)
Loss: 0.00833863951265812
(113,)
Loss: 0.0051587349735200405
(113,)
Loss: 0.0033300260547548532
(113,)
Loss: 0.0022443023044615984
(113,)
Loss: 0.0015877812402322888
(113,)
Loss: 0.00118648714851588
(113,)
Loss: 0.000939269841182977
(113,)
Loss: 0.0007863908540457487
(113,)
Loss: 0.000691707362420857
(113,)
Loss: 0.0006326906732283533
