In [29]:
import logging
import absl.logging
# Detach absl's own handler and attach a no-op handler in its place.
absl.logging.get_absl_logger().removeHandler(absl.logging._absl_handler)
absl.logging.get_absl_logger().addHandler(logging.NullHandler())
absl.logging.set_verbosity('info')  # Set your desired verbosity level

import os
# NOTE(review): this is set before `import jax` (which happens in a later cell),
# presumably so JAX picks the flag up at import time and logs every
# trace/compile step — confirm against the JAX configuration docs.
os.environ['JAX_LOG_COMPILES'] = '1'

In [30]:
import numpy as np

In [31]:
class Context:
    """Temporarily replace a model's ``params`` attribute.

    On entry the model's current ``params`` is saved and the replacement is
    installed; on exit the saved value is put back. Exceptions raised inside
    the ``with`` body are not suppressed.

    Args:
        model: Any object exposing a ``params`` attribute.
        params: Value to install as ``model.params`` inside the ``with`` block.
    """

    def __init__(self, model, params):
        self.model = model
        self.params = params
        self.old_params = None

    def __enter__(self):
        """Install the replacement params and return the patched model."""
        self.old_params = self.model.params
        self.model.params = self.params
        # Return the model so `with Context(m, p) as patched:` binds something
        # useful; previously the target of `as` was always None.
        return self.model

    def __exit__(self, type, value, traceback):
        """Restore the params that were active before ``__enter__``."""
        self.model.params = self.old_params

In [32]:
class ElementMult:
    """Element-wise scaling layer: multiplies its input by ``params['weights']``."""

    def __init__(self):
        # Starts with no parameters; 'weights' must be set before calling.
        self.params = {}

    def _context(self, params):
        """Build the context manager that temporarily installs ``params`` here."""
        return Context(self, params)

    def __call__(self, x):
        """Return ``x * weights`` using whatever params are currently installed."""
        weights = self.params['weights']
        return x * weights

    def apply(self, params, x):
        """Run the forward pass with ``params`` swapped in for this one call."""
        swapped = self._context(params)
        with swapped:
            result = self.__call__(x)
        return result

In [33]:
model = ElementMult()
# Install a (2, 3) weight matrix directly on the instance's params dict.
model.params['weights'] = np.array([[5, 6, 7], [0, 5, 0]])
model.params['weights'].shape

(2, 3)

In [34]:
# Shape (3,) input; NumPy broadcasts it against each row of the (2, 3) weights.
sample_input = np.array([1, 1, 2])
model(sample_input)

array([[ 5,  6, 14],
       [ 0,  5,  0]])

In [35]:
new_params = np.array([[1, 1, 1], [0, 5, 0]])
new_params_dict = {'weights': new_params}

# `apply` runs the forward pass with `new_params_dict` swapped in for the
# duration of the call; the model's own params are put back afterwards.
model.apply(new_params_dict, sample_input)

array([[1, 1, 2],
       [0, 5, 0]])

# Dense

In [36]:
import sys
# NOTE(review): absolute, machine-specific paths to a local JaxMao checkout;
# this cell will not work on another machine without adjusting them.
sys.path.append('/home/jaxmao/jaxmao_branches/JaxMao/')
sys.path.append('/home/jaxmao/jaxmao_branches/JaxMao/jaxmao')

import jax
import jax.numpy as jnp
import jaxmao
# The star import is where the cells below get `HeNormal` and
# `zeros_initializer` from (neither is defined in this notebook).
from jaxmao.initializers import *
from jaxmao.layers import Layer, Activation

In [37]:
class SimpleDense:
    """Minimal dense layer computing ``x @ weights`` plus an optional bias.

    Parameters are created eagerly in ``__init__`` and stored in ``self.params``
    under ``'weights'`` and, when ``use_bias`` is true, ``'biases'``.

    Args:
        key: PRNG key handed to the initializers.
        in_channels: Size of the input feature axis.
        out_channels: Size of the output feature axis.
        activation: Passed to ``Activation`` and stored on ``self.activation``.
            NOTE(review): it is not applied in ``__call__`` — confirm whether
            that is intentional for this simplified layer.
        weights_initializer: Called as ``initializer(key, shape)`` for the weights.
        bias_initializer: Called as ``initializer(key, shape)`` for the biases.
        use_bias: Whether to create a bias vector and add it in ``__call__``.
    """

    def __init__(
        self, key,
        in_channels,
        out_channels,
        activation='relu',
        weights_initializer=HeNormal(),
        bias_initializer=zeros_initializer,
        use_bias=True,
    ):
        self.shapes = dict()
        self.initializers = dict()
        self.params = dict()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_bias = use_bias
        self.activation = Activation(activation)

        self.shapes['weights'] = (in_channels, out_channels)
        self.initializers['weights'] = weights_initializer
        if use_bias:
            self.shapes['biases'] = (out_channels,)
            self.initializers['biases'] = bias_initializer
            # NOTE(review): the same `key` is also used for the weights below.
            # Harmless with the default zeros initializer, but a random bias
            # initializer would draw from the same key as the weights; kept
            # as-is so existing seeds reproduce the same weights.
            self.params['biases'] = self.initializers['biases'](key, self.shapes['biases'])

        self.params['weights'] = self.initializers['weights'](key, self.shapes['weights'])

    def __call__(self, x):
        """Return ``x @ weights``, plus the bias when one is configured and present."""
        out = jnp.dot(x, self.params['weights'])
        # The bias was initialised in __init__ but never used in the forward
        # pass. Add it only when it is actually in the current params, so a
        # params dict holding just 'weights' keeps working.
        if self.use_bias and 'biases' in self.params:
            out = out + self.params['biases']
        return out

In [38]:
key = jax.random.key(42)

simple_dense = SimpleDense(key, 5, 4, use_bias=False)
print('params:\n', simple_dense.params)

sample_input = np.ones((16, 5))
# Store the jitted bound method on the instance. Note that `simple_dense(x)`
# would NOT use it: implicit calls look `__call__` up on the class and skip
# instance attributes, so the plain Python method would run instead.
simple_dense.__call__ = jax.jit(simple_dense.__call__)
# Call the instance attribute explicitly so the jitted function is what runs.
print('output:\n', simple_dense.__call__(sample_input))

params:
 {'weights': Array([[-0.09176069,  1.1012427 , -0.8198022 ,  0.98706996],
       [-0.5258742 , -0.90167457, -0.2776343 , -0.1228469 ],
       [ 0.3729634 , -0.19016472, -0.36814284,  0.10331149],
       [ 0.1593394 ,  0.18784064, -0.3971092 , -0.01390252],
       [ 0.36769372, -0.01763257,  1.3538171 , -0.07520011]],      dtype=float32)}
output:
 [[ 0.2823616   0.17961144 -0.5088715   0.8784319 ]
 [ 0.2823616   0.17961144 -0.5088715   0.8784319 ]
 [ 0.2823616   0.17961144 -0.5088715   0.8784319 ]
 [ 0.2823616   0.17961144 -0.5088715   0.8784319 ]
 [ 0.2823616   0.17961144 -0.5088715   0.8784319 ]
 [ 0.2823616   0.17961144 -0.5088715   0.8784319 ]
 [ 0.2823616   0.17961144 -0.5088715   0.8784319 ]
 [ 0.2823616   0.17961144 -0.5088715   0.8784319 ]
 [ 0.2823616   0.17961144 -0.5088715   0.8784319 ]
 [ 0.2823616   0.17961144 -0.5088715   0.8784319 ]
 [ 0.2823616   0.17961144 -0.5088715   0.8784319 ]
 [ 0.2823616   0.17961144 -0.5088715   0.8784319 ]
 [ 0.2823616   0.17961144 -0.50

In [39]:
# Inspect the instance attribute: it now holds the jitted (PjitFunction) wrapper.
simple_dense.__call__

<PjitFunction of <bound method SimpleDense.__call__ of <__main__.SimpleDense object at 0x7f1433f9e590>>>

In [40]:
# Replace the layer's params with a fresh (5, 4) sample from a normal distribution.
simple_dense.params = {'weights': jax.random.normal(jax.random.key(0), (5, 4))}

# Rebinding `params` does not touch the `__call__` instance attribute.
print(simple_dense.__call__)

<PjitFunction of <bound method SimpleDense.__call__ of <__main__.SimpleDense object at 0x7f1433f9e590>>>


In [41]:
# NOTE(review): `simple_dense(...)` resolves `__call__` on the class, not on the
# instance, so this runs the plain un-jitted method and reads the current
# params — which is why the output reflects the new weights.
simple_dense(sample_input)

Array([[ 3.2032537 , -0.81059504,  0.75257564,  1.5382664 ],
       [ 3.2032537 , -0.81059504,  0.75257564,  1.5382664 ],
       [ 3.2032537 , -0.81059504,  0.75257564,  1.5382664 ],
       [ 3.2032537 , -0.81059504,  0.75257564,  1.5382664 ],
       [ 3.2032537 , -0.81059504,  0.75257564,  1.5382664 ],
       [ 3.2032537 , -0.81059504,  0.75257564,  1.5382664 ],
       [ 3.2032537 , -0.81059504,  0.75257564,  1.5382664 ],
       [ 3.2032537 , -0.81059504,  0.75257564,  1.5382664 ],
       [ 3.2032537 , -0.81059504,  0.75257564,  1.5382664 ],
       [ 3.2032537 , -0.81059504,  0.75257564,  1.5382664 ],
       [ 3.2032537 , -0.81059504,  0.75257564,  1.5382664 ],
       [ 3.2032537 , -0.81059504,  0.75257564,  1.5382664 ],
       [ 3.2032537 , -0.81059504,  0.75257564,  1.5382664 ],
       [ 3.2032537 , -0.81059504,  0.75257564,  1.5382664 ],
       [ 3.2032537 , -0.81059504,  0.75257564,  1.5382664 ],
       [ 3.2032537 , -0.81059504,  0.75257564,  1.5382664 ]],      dtype=float32)

In [42]:
# The instance attribute is still the same jitted wrapper after the call above.
simple_dense.__call__

<PjitFunction of <bound method SimpleDense.__call__ of <__main__.SimpleDense object at 0x7f1433f9e590>>>

### limitation of jit

Different input shapes: the same layer is called with a `(32, 5)` batch instead of `(16, 5)`.

In [43]:
# Same layer, batch of 32 rows instead of 16.
sample_input2 = np.ones((32, 5))
simple_dense(sample_input2)

Array([[ 3.203254  , -0.810595  ,  0.75257564,  1.5382662 ],
       [ 3.203254  , -0.810595  ,  0.75257564,  1.5382662 ],
       [ 3.203254  , -0.810595  ,  0.75257564,  1.5382662 ],
       [ 3.203254  , -0.810595  ,  0.75257564,  1.5382662 ],
       [ 3.203254  , -0.810595  ,  0.75257564,  1.5382662 ],
       [ 3.203254  , -0.810595  ,  0.75257564,  1.5382662 ],
       [ 3.203254  , -0.810595  ,  0.75257564,  1.5382662 ],
       [ 3.203254  , -0.810595  ,  0.75257564,  1.5382662 ],
       [ 3.203254  , -0.810595  ,  0.75257564,  1.5382662 ],
       [ 3.203254  , -0.810595  ,  0.75257564,  1.5382662 ],
       [ 3.203254  , -0.810595  ,  0.75257564,  1.5382662 ],
       [ 3.203254  , -0.810595  ,  0.75257564,  1.5382662 ],
       [ 3.203254  , -0.810595  ,  0.75257564,  1.5382662 ],
       [ 3.203254  , -0.810595  ,  0.75257564,  1.5382662 ],
       [ 3.203254  , -0.810595  ,  0.75257564,  1.5382662 ],
       [ 3.203254  , -0.810595  ,  0.75257564,  1.5382662 ],
       [ 3.203254  , -0.

# context manager

In [44]:
class Context:
    """Temporarily replace a model's ``params`` attribute.

    On entry the model's current ``params`` is saved and the replacement is
    installed; on exit the saved value is put back. Exceptions raised inside
    the ``with`` body are not suppressed.

    Args:
        model: Any object exposing a ``params`` attribute.
        params: Value to install as ``model.params`` inside the ``with`` block.
    """

    def __init__(self, model, params):
        self.model = model
        self.params = params
        self.old_params = None

    def __enter__(self):
        """Install the replacement params and return the patched model."""
        self.old_params = self.model.params
        self.model.params = self.params
        # Return the model so `with Context(m, p) as patched:` binds something
        # useful; previously the target of `as` was always None.
        return self.model

    def __exit__(self, type, value, traceback):
        """Restore the params that were active before ``__enter__``."""
        self.model.params = self.old_params

class Module:
    """Base class providing ``apply``: a forward pass under temporary params."""

    def __init__(self):
        """Nothing to set up; subclasses create their own ``params``."""

    def _context(self, params):
        """Build the context manager that temporarily installs ``params`` here."""
        return Context(self, params)

    def apply(self, params, x):
        """Call the module on ``x`` with ``params`` swapped in for this one call."""
        swapped = self._context(params)
        with swapped:
            result = self.__call__(x)
        return result

class ContextDense(Module):
    """Dense layer usable with the context-manager based ``apply``.

    Computes ``x @ weights`` plus an optional bias. Parameters are created
    eagerly in ``__init__`` and stored in ``self.params`` under ``'weights'``
    and, when ``use_bias`` is true, ``'biases'``.

    Args:
        key: PRNG key handed to the initializers.
        in_channels: Size of the input feature axis.
        out_channels: Size of the output feature axis.
        activation: Passed to ``Activation`` and stored on ``self.activation``.
            NOTE(review): it is not applied in ``__call__`` — confirm whether
            that is intentional for this simplified layer.
        weights_initializer: Called as ``initializer(key, shape)`` for the weights.
        bias_initializer: Called as ``initializer(key, shape)`` for the biases.
        use_bias: Whether to create a bias vector and add it in ``__call__``.
    """

    def __init__(
        self, key,
        in_channels,
        out_channels,
        activation='relu',
        weights_initializer=HeNormal(),
        bias_initializer=zeros_initializer,
        use_bias=True,
    ):
        super().__init__()

        self.shapes = dict()
        self.initializers = dict()
        self.params = dict()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.use_bias = use_bias
        self.activation = Activation(activation)

        self.shapes['weights'] = (in_channels, out_channels)
        self.initializers['weights'] = weights_initializer
        if use_bias:
            self.shapes['biases'] = (out_channels,)
            self.initializers['biases'] = bias_initializer
            # NOTE(review): the same `key` is also used for the weights below.
            # Harmless with the default zeros initializer, but a random bias
            # initializer would draw from the same key as the weights; kept
            # as-is so existing seeds reproduce the same weights.
            self.params['biases'] = self.initializers['biases'](key, self.shapes['biases'])

        self.params['weights'] = self.initializers['weights'](key, self.shapes['weights'])

    def __call__(self, x):
        """Return ``x @ weights``, plus the bias when one is configured and present."""
        out = jnp.dot(x, self.params['weights'])
        # The bias was initialised in __init__ but never used in the forward
        # pass. Add it only when it is actually in the current params, so a
        # temporary params dict holding just 'weights' keeps working.
        if self.use_bias and 'biases' in self.params:
            out = out + self.params['biases']
        return out

In [45]:
dense = ContextDense(jax.random.key(42), 5, 4, use_bias=False)
print('params:\n', dense.params)

# Forward pass using the layer's own params.
dense(sample_input)

params:
 {'weights': Array([[-0.09176069,  1.1012427 , -0.8198022 ,  0.98706996],
       [-0.5258742 , -0.90167457, -0.2776343 , -0.1228469 ],
       [ 0.3729634 , -0.19016472, -0.36814284,  0.10331149],
       [ 0.1593394 ,  0.18784064, -0.3971092 , -0.01390252],
       [ 0.36769372, -0.01763257,  1.3538171 , -0.07520011]],      dtype=float32)}


Array([[ 0.2823616 ,  0.17961144, -0.5088715 ,  0.8784319 ],
       [ 0.2823616 ,  0.17961144, -0.5088715 ,  0.8784319 ],
       [ 0.2823616 ,  0.17961144, -0.5088715 ,  0.8784319 ],
       [ 0.2823616 ,  0.17961144, -0.5088715 ,  0.8784319 ],
       [ 0.2823616 ,  0.17961144, -0.5088715 ,  0.8784319 ],
       [ 0.2823616 ,  0.17961144, -0.5088715 ,  0.8784319 ],
       [ 0.2823616 ,  0.17961144, -0.5088715 ,  0.8784319 ],
       [ 0.2823616 ,  0.17961144, -0.5088715 ,  0.8784319 ],
       [ 0.2823616 ,  0.17961144, -0.5088715 ,  0.8784319 ],
       [ 0.2823616 ,  0.17961144, -0.5088715 ,  0.8784319 ],
       [ 0.2823616 ,  0.17961144, -0.5088715 ,  0.8784319 ],
       [ 0.2823616 ,  0.17961144, -0.5088715 ,  0.8784319 ],
       [ 0.2823616 ,  0.17961144, -0.5088715 ,  0.8784319 ],
       [ 0.2823616 ,  0.17961144, -0.5088715 ,  0.8784319 ],
       [ 0.2823616 ,  0.17961144, -0.5088715 ,  0.8784319 ],
       [ 0.2823616 ,  0.17961144, -0.5088715 ,  0.8784319 ]],      dtype=float32)

In [46]:
# Weights with the same (5, 4) shape as the layer's own, but different values.
new_params_dict = {'weights': jax.random.normal(jax.random.key(22), (5, 4))}
dense.apply(new_params_dict, sample_input)

Array([[-1.8329227, -0.6777953, -1.0361731, -1.9786481],
       [-1.8329227, -0.6777953, -1.0361731, -1.9786481],
       [-1.8329227, -0.6777953, -1.0361731, -1.9786481],
       [-1.8329227, -0.6777953, -1.0361731, -1.9786481],
       [-1.8329227, -0.6777953, -1.0361731, -1.9786481],
       [-1.8329227, -0.6777953, -1.0361731, -1.9786481],
       [-1.8329227, -0.6777953, -1.0361731, -1.9786481],
       [-1.8329227, -0.6777953, -1.0361731, -1.9786481],
       [-1.8329227, -0.6777953, -1.0361731, -1.9786481],
       [-1.8329227, -0.6777953, -1.0361731, -1.9786481],
       [-1.8329227, -0.6777953, -1.0361731, -1.9786481],
       [-1.8329227, -0.6777953, -1.0361731, -1.9786481],
       [-1.8329227, -0.6777953, -1.0361731, -1.9786481],
       [-1.8329227, -0.6777953, -1.0361731, -1.9786481],
       [-1.8329227, -0.6777953, -1.0361731, -1.9786481],
       [-1.8329227, -0.6777953, -1.0361731, -1.9786481]], dtype=float32)

In [49]:
# Weights with a different output width (16 instead of 4). `apply` performs no
# shape validation against the layer's declared shapes, so this simply yields
# 16 output columns.
new_params_dict = {'weights': jax.random.normal(jax.random.key(22), (5, 16))}
dense.apply(new_params_dict, sample_input)

Finished tracing + transforming <lambda> for pjit in 0.0010955333709716797 sec
Finished tracing + transforming fn for pjit in 0.0015070438385009766 sec
Finished tracing + transforming fn for pjit in 0.001615285873413086 sec
Finished tracing + transforming _uniform for pjit in 0.013531923294067383 sec
Finished tracing + transforming _normal_real for pjit in 0.016244173049926758 sec
Finished tracing + transforming _normal for pjit in 0.018434524536132812 sec
Compiling _normal for with global shapes and types [ShapedArray(key<fry>[])]. Argument mapping: (GSPMDSharding({replicated}),).
Finished tracing + transforming ravel for pjit in 0.001171112060546875 sec
Finished tracing + transforming threefry_2x32 for pjit in 0.004179477691650391 sec
Finished tracing + transforming _threefry_random_bits_original for pjit in 0.006880044937133789 sec
Finished jaxpr to MLIR module conversion jit(_normal) in 0.020771503448486328 sec
Finished XLA compilation of jit(_normal) in 0.13342642784118652 sec
Fin

Finished jaxpr to MLIR module conversion jit(dot) in 0.002245187759399414 sec
Finished XLA compilation of jit(dot) in 0.17280793190002441 sec


Array([[-0.7190635 ,  2.7677035 , -1.1061106 ,  3.614097  ,  3.0745487 ,
         2.2379713 ,  2.447216  ,  0.38021234,  5.288913  , -2.5442004 ,
        -0.650793  ,  1.577282  ,  2.1357284 , -1.0383855 , -0.44585156,
        -0.99454516],
       [-0.7190635 ,  2.7677035 , -1.1061106 ,  3.614097  ,  3.0745487 ,
         2.2379713 ,  2.447216  ,  0.38021234,  5.288913  , -2.5442004 ,
        -0.650793  ,  1.577282  ,  2.1357284 , -1.0383855 , -0.44585156,
        -0.99454516],
       [-0.7190635 ,  2.7677035 , -1.1061106 ,  3.614097  ,  3.0745487 ,
         2.2379713 ,  2.447216  ,  0.38021234,  5.288913  , -2.5442004 ,
        -0.650793  ,  1.577282  ,  2.1357284 , -1.0383855 , -0.44585156,
        -0.99454516],
       [-0.7190635 ,  2.7677035 , -1.1061106 ,  3.614097  ,  3.0745487 ,
         2.2379713 ,  2.447216  ,  0.38021234,  5.288913  , -2.5442004 ,
        -0.650793  ,  1.577282  ,  2.1357284 , -1.0383855 , -0.44585156,
        -0.99454516],
       [-0.7190635 ,  2.7677035 , -1