In [3]:
import jax
import jax.numpy as jnp
import numpy as np

jax_version = jax.__version__  # record which JAX release this notebook ran on
print("using jax", jax_version)

using jax 0.6.2


In [4]:
# Creating arrays: jnp mirrors the NumPy array constructors.

# A 2x5 matrix of zeros in single precision.
a = jnp.zeros(shape=(2, 5), dtype=jnp.float32)
print(a)

# The integers 0..5 -- like range(), arange starts at 0 and excludes the stop value.
b = jnp.arange(0, 6)
print(b)

[[0. 0. 0. 0. 0.]
 [0. 0. 0. 0. 0.]]
[0 1 2 3 4 5]


In [5]:
# Immutable tensors: JAX arrays cannot be modified in place, so `b[0] = 1`
# is not allowed. The functional `.at[...].set(...)` API instead returns an
# updated copy and leaves `b` unchanged (pure-function style).
b_new = b.at[0].set(1)

for label, arr in (('Original Array:', b), ('Changed Array:', b_new)):
    print(label, arr)

Original Array: [0 1 2 3 4 5]
Changed Array: [1 1 2 3 4 5]


In [6]:
# Core JAX code should live in pure functions whose only effect is their return value.

def simple_graph(x):
    """Toy computation graph: mean((x + 2) ** 2 + 3).

    Args:
        x: a JAX array.

    Returns:
        A scalar array holding the mean of ``(x + 2) ** 2 + 3`` over all elements.
    """
    shifted = x + 2
    squared = shifted ** 2
    offset = squared + 3
    return offset.mean()

# Evaluate the graph on the float32 input [0., 1., 2.].
inp = jnp.arange(0, 3, dtype=jnp.float32)
result = simple_graph(inp)
print('Input', inp)
print('Output', result)



Input [0. 1. 2.]
Output 12.666667


In [7]:
# Trace simple_graph with a concrete input to obtain its jaxpr, JAX's
# intermediate representation of the computation.
#
# General form of a jaxpr:
#   { lambda Var* ; Var+ .
#     let Eqn*
#     in [Expr+] }
# Variables before ';' are constants and those after it are the inputs.
# Each Eqn applies one primitive, and the final tuple lists the outputs.
traced_graph = jax.make_jaxpr(simple_graph)
traced_graph(inp)


{ lambda ; a:f32[3]. let
    b:f32[3] = add a 2.0:f32[]
    c:f32[3] = integer_pow[y=2] b
    d:f32[3] = add c 3.0:f32[]
    e:f32[] = reduce_sum[axes=(0,)] d
    f:f32[] = div e 3.0:f32[]
  in (f,) }

In [8]:
# Automatic differentiation: jax.grad transforms simple_graph into a new
# function that returns d(output)/d(input), one entry per input element.
# Analytically, d/dx mean((x + 2)**2 + 3) = 2 * (x + 2) / len(x).
grad_function = jax.grad(simple_graph, argnums=0)
gradients = grad_function(inp)
print('Gradient:', gradients)

Gradient: [1.3333334 2.        2.6666667]


In [11]:
# jax.grad returns a new function whose repr reuses the wrapped function's name,
# which is why it prints as simple_graph in the output.
print('grad function', grad_function, sep=' ')

grad function <function simple_graph at 0x1179e5ea0>


In [18]:
# value_and_grad evaluates the function and its gradient in a single call,
# returning the tuple (value, gradient).
val_grad_function = jax.value_and_grad(simple_graph, argnums=0)
val_grad_function(inp)

(Array(12.666667, dtype=float32),
 Array([1.3333334, 2.       , 2.6666667], dtype=float32))

In [None]:
#Building XOR Classifier
%pip install flax

Collecting flax
  Downloading flax-0.10.7-py3-none-any.whl.metadata (11 kB)
Collecting msgpack (from flax)
  Downloading msgpack-1.1.1-cp310-cp310-macosx_11_0_arm64.whl.metadata (8.4 kB)
Collecting optax (from flax)
  Downloading optax-0.2.5-py3-none-any.whl.metadata (7.5 kB)
Collecting orbax-checkpoint (from flax)
  Downloading orbax_checkpoint-0.11.20-py3-none-any.whl.metadata (2.3 kB)
Collecting tensorstore (from flax)
  Downloading tensorstore-0.1.76-cp310-cp310-macosx_11_0_arm64.whl.metadata (21 kB)
Collecting rich>=11.1 (from flax)
  Downloading rich-14.1.0-py3-none-any.whl.metadata (18 kB)
Collecting treescope>=0.1.7 (from flax)
  Downloading treescope-0.1.9-py3-none-any.whl.metadata (6.6 kB)
Collecting markdown-it-py>=2.2.0 (from rich>=11.1->flax)
  Downloading markdown_it_py-3.0.0-py3-none-any.whl.metadata (6.9 kB)
Collecting absl-py>=0.7.1 (from optax->flax)
  Downloading absl_py-2.3.1-py3-none-any.whl.metadata (3.3 kB)
Collecting chex>=0.1.87 (from optax->flax)
  Downloading

In [22]:
import flax
from flax import linen as nn

# Flax uses lazy initialization: submodules are not built when the Python
# object is constructed, but later, before the module is first used.

class MyModule(nn.Module):
    """Template for a Flax linen module.

    Hyperparameters (hidden dimension, number of layers, ...) are declared as
    dataclass fields directly in the class body, in the form ``varname: vartype``.
    """

    def setup(self):
        """Declare submodules here.

        Flax calls this lazily -- once, before the model is called or its
        attributes are accessed -- rather than at construction time.
        """

    def __call__(self, x):
        """Perform the module's computation on ``x`` (template only: returns None)."""


# Parameters are kept in a separate pytree rather than as attributes of the
# module object. Because fields are declared dataclass-style, there is no need
# to write a PyTorch-style __init__().



In [23]:
# Simple classifier, explicit style: layers are declared up front in `setup`.
class SimpleClassifier(nn.Module):
    """Two-layer MLP: Dense(num_hidden) -> tanh -> Dense(num_output)."""
    num_hidden: int  # width of the hidden Dense layer
    num_output: int  # number of output features

    def setup(self):
        # The attribute names linear1 / linear2 also name these submodules'
        # parameters, so they are kept as-is.
        self.linear1 = nn.Dense(features=self.num_hidden)
        self.linear2 = nn.Dense(features=self.num_output)

    def __call__(self, x):
        """Apply both Dense layers with a tanh nonlinearity in between."""
        hidden = nn.tanh(self.linear1(x))
        return self.linear2(hidden)


In [24]:
# Instead of declaring layers explicitly in `setup`, decorate `__call__` with
# `nn.compact` and create the submodules inline where they are used.
# NOTE(review): this redefinition shadows the setup-based SimpleClassifier from the
# previous cell, and its fields are named num_hiddens/num_outputs rather than
# num_hidden/num_output -- code written against the earlier class will not match.

class SimpleClassifier(nn.Module):
    """Two-layer MLP classifier (compact style): Dense -> tanh -> Dense.

    Attributes:
        num_hiddens: width of the hidden Dense layer.
        num_outputs: number of output features; no activation is applied to them.
    """
    num_hiddens: int
    num_outputs: int

    @nn.compact  # lets Flax register submodules that are created inside __call__
    def __call__(self, x):
        """Apply the classifier to ``x`` and return the final layer's raw output."""
        # Each nn.Dense(...) constructs a layer module, which must then be *called*
        # on x to produce activations. Assigning the module itself to x (the
        # previous code) would hand nn.tanh a Module instead of an array.
        x = nn.Dense(features=self.num_hiddens)(x)
        x = nn.tanh(x)
        x = nn.Dense(features=self.num_outputs)(x)
        return x
    