In [1]:
import os
import sys

# Directory that contains the project's `nn` package (imported in the next cell).
# Override it with the FLINK_GNN_PYTHON_PATH environment variable so the notebook
# is not tied to one machine; the original location is kept as the default.
PROJECT_PYTHON_DIR = os.environ.get(
    'FLINK_GNN_PYTHON_PATH',
    '/home/rustambaku13/Documents/Warwick/flink-streaming-gnn/python',
)

# Guard against duplicate entries when this cell is re-run in the same kernel.
if PROJECT_PYTHON_DIR not in sys.path:
    sys.path.append(PROJECT_PYTHON_DIR)

In [2]:
import jax
from flax.linen import relu, softmax, sigmoid
from nn.jax.multi_layer_dense import MultiLayerDense

In [3]:
# Model built from feature sizes 16, 32, 7 with activations relu, relu, softmax
# (presumably paired one-per-layer -- confirm against MultiLayerDense).
predict_fn = MultiLayerDense(
    features=[16, 32, 7],
    activations=[relu, relu, softmax],
)

In [4]:
# Initialise the model parameters from a length-7 sample input.
# Note: the same PRNGKey(0) seeds both the initialisation and the sample input.
init_key = jax.random.PRNGKey(0)
sample_embedding = jax.random.uniform(jax.random.PRNGKey(0), (7,))
predict_fn_params = predict_fn.init(init_key, sample_embedding)



In [5]:
def loss(parameters, embedding, label):
    """Negative log-likelihood of a single example, summed over all outputs.

    The model output for ``embedding`` is combined element-wise with ``label``
    as ``prediction * label + (1 - prediction) * (1 - label)``; for a 0/1 label
    this selects ``prediction`` where the label is 1 and ``1 - prediction``
    where it is 0. The result is the negated sum of the logs of those values.

    Args:
        parameters: parameter pytree passed to ``predict_fn.apply``.
        embedding: input for a single example.
        label: target with the same shape as the model output.

    Returns:
        Scalar loss for the example.
    """
    predicted = predict_fn.apply(parameters, embedding)
    # Element-wise probability the model assigns to the observed label values.
    likelihood = predicted * label + (1 - predicted) * (1 - label)
    log_likelihood = jax.numpy.log(likelihood)
    return -jax.numpy.sum(log_likelihood)
def batch_loss(parameters, batch_embeddings, batch_labels):
    """Total ``loss`` over a batch of examples.

    ``loss`` is vectorised with ``jax.vmap`` along the leading axis of
    ``batch_embeddings`` and ``batch_labels``; ``parameters`` is not mapped
    and is shared by every example.

    Returns:
        Scalar sum of the per-example losses.
    """
    per_example_loss = jax.vmap(loss, in_axes=(None, 0, 0))
    example_losses = per_example_loss(parameters, batch_embeddings, batch_labels)
    return jax.numpy.sum(example_losses)

In [6]:
# One one-hot row per class: a 7x7 identity matrix (7 matches the size of the
# last entry in predict_fn's `features`).
NUM_CLASSES = 7
labels = jax.numpy.eye(NUM_CLASSES)

In [7]:
# 7x7 matrix of uniform samples drawn from a fixed key; later cells pass it as
# the `batch_embeddings` argument of `batch_loss`.
logits_key = jax.random.PRNGKey(0)
logits = jax.random.uniform(logits_key, shape=(7, 7))

In [54]:
# Hyper-parameters of the plain gradient-descent loop below.
NUM_STEPS = 100
LEARNING_RATE = 0.03

# value_and_grad yields the loss and its gradient with respect to the first
# argument (the parameters) from one forward/backward pass, instead of running
# the forward pass twice via separate jax.vjp and jax.grad calls.
loss_and_grad = jax.value_and_grad(batch_loss)

for _ in range(NUM_STEPS):
    loss_value, delta = loss_and_grad(predict_fn_params, logits, labels)

    # Loss measured before this step's update is applied.
    print(loss_value)

    # Gradient-descent step applied leaf-by-leaf over the parameter pytree.
    # jax.tree_util.tree_map accepts several trees and replaces the deprecated
    # jax.tree_multimap.
    predict_fn_params = jax.tree_util.tree_map(
        lambda param, grad: param - (grad * LEARNING_RATE),
        predict_fn_params,
        delta,
    )

23.153784
23.079899
23.007263
22.935833
22.86558
22.796467
22.72847
22.66156
22.595722
22.530933
22.467182
22.404453
22.342735
22.28202
22.222298
22.163568
22.105824
22.049057
21.993273
21.938465
21.884636
21.83178
21.7799
21.728992
21.679066
21.630112
21.58213
21.535128
21.489098
21.444038
21.399954
21.356838
21.314692
21.273506
21.233284
21.19402
21.155706
21.11834
21.081911
21.04642
21.011852
20.978203
20.945461
20.91362
20.882668
20.852596
20.823393
20.795044
20.76754
20.740868
20.715014
20.689968
20.66571
20.642227
20.619513
20.597542
20.576305
20.55579
20.535976
20.51685
20.498398
20.4806
20.463446
20.446915
20.431
20.415676
20.400932
20.386753
20.373127
20.360031
20.347454
20.33538
20.3238
20.312689
20.30204
20.291838
20.282066
20.272713
20.263767
20.255207
20.24703
20.239216
20.231754
20.224632
20.217834
20.211355
20.20518
20.199295
20.193691
20.188358
20.183285
20.178461
20.173874
20.169518
20.165379
20.161453
20.157726
20.154192
20.15084
20.147665


In [8]:
# Evaluate the batch loss for the current parameters.
batch_loss(
    parameters=predict_fn_params,
    batch_embeddings=logits,
    batch_labels=labels,
)

DeviceArray(20.200226, dtype=float32)

In [12]:
# Display the current parameter pytree as the cell's output.
predict_fn_params

FrozenDict({
    params: {
        layers_0: {
            kernel: DeviceArray([[-0.40598792, -0.20192443,  0.1822813 , -0.6701383 ,
                           0.7234791 ,  0.48052135,  0.11202915,  0.7503722 ,
                          -0.3069905 ,  0.17787433,  0.05712227, -0.00219361,
                          -0.43479165,  0.48816282,  0.3167736 , -0.56967115],
                         [-0.51099217, -0.1250201 , -0.50434786, -0.31360218,
                          -0.42523575,  0.07905342, -0.36755604,  0.5149487 ,
                          -0.0250216 ,  0.0767649 , -0.02069273,  0.16237336,
                           0.05648867,  0.20248288, -0.4029955 , -0.27893478],
                         [ 0.39262852,  0.27463427,  0.41443932,  0.2553303 ,
                           0.29250273, -0.31036663,  0.46665376, -0.0289441 ,
                          -0.5784792 ,  0.10486496,  0.07086333, -0.46257764,
                           0.16107841, -0.59642285,  0.0359646 ,  0.73806477],
      

In [14]:
# Differentiate batch_loss with respect to both its first argument (the
# parameters) and its second argument (the inputs held in `logits`).
grad_wrt_params_and_inputs = jax.grad(batch_loss, argnums=(0, 1))
param_grads, logit_grads = grad_wrt_params_and_inputs(predict_fn_params, logits, labels)


In [16]:
# Display the gradient of batch_loss with respect to its second argument (`logits`).
logit_grads

DeviceArray([[-0.1721328 , -0.12458567,  0.24019186,  0.15012556,
              -0.1415945 ,  0.13709623,  0.03365368],
             [-0.21368742, -0.0111424 ,  0.10205607, -0.04278085,
              -0.09789167, -0.20932147, -0.00832144],
             [ 0.25060448,  0.1436313 ,  0.06111101,  0.08825464,
              -0.01146407,  0.07025305,  0.0851654 ],
             [ 0.3798127 ,  0.00584622, -0.20414168, -0.09796152,
               0.18164098,  0.19643015,  0.1549829 ],
             [ 0.27007744,  0.06186572, -0.11354411,  0.05876795,
               0.1672815 ,  0.15384679, -0.01412305],
             [ 0.02611538,  0.08763807, -0.10241678,  0.04101484,
               0.05331423,  0.0128478 , -0.16085328],
             [-0.04424264, -0.09154289, -0.00784902, -0.13900608,
              -0.09918598, -0.2248016 , -0.04589351]], dtype=float32)

In [18]:
# Scratch 3x3 matrix of ones used by the stacking/indexing experiments below.
a = jax.numpy.ones(shape=(3, 3))

In [20]:
# Scratch length-3 vector of ones.
b = jax.numpy.ones(shape=(3,))

In [21]:
# Show the vector and the matrix together as a tuple.
(b, a)

(DeviceArray([1., 1., 1.], dtype=float32),
 DeviceArray([[1., 1., 1.],
              [1., 1., 1.],
              [1., 1., 1.]], dtype=float32))

In [23]:
# Stack the vector `b` underneath the matrix `a` as an extra row.
jax.numpy.vstack([a, b])

DeviceArray([[1., 1., 1.],
             [1., 1., 1.],
             [1., 1., 1.],
             [1., 1., 1.]], dtype=float32)

In [25]:
# Take row 1 of `a` and give it a leading axis of length 1.
a[1][None, :]

DeviceArray([[1., 1., 1.]], dtype=float32)