## Import relevant libraries

In [1]:
import jax
import jax.numpy as jnp
import time

## The lambda function

In [2]:
sq_fn = lambda x:jnp.square(x)

In [3]:
sq_fn

<function __main__.<lambda>(x)>

In [4]:
sq_fn(3)

Array(9, dtype=int32, weak_type=True)

In [5]:
def reg_sq_fn(x):
    return jnp.square(x)

In [6]:
reg_sq_fn

<function __main__.reg_sq_fn(x)>

In [7]:
reg_sq_fn(3)

Array(9, dtype=int32, weak_type=True)

1. Lambda functions are "anonymous" functions in Python: they are defined inline and their body is restricted to a single expression
2. We can see that these functions are "anonymous", or without a name, by printing out "sq_fn" and "reg_sq_fn" and comparing them: the lambda is reported as "<lambda>", while the regular function is reported by its name
3. The syntax of these lambda functions is of the form "lambda arguments: expression"
4. This syntax is convenient for automatic differentiation because "lambda x: f(x)" states explicitly that "x" is the argument, so a JAX transformation such as "jax.grad" or "jax.jacrev" can be applied to the lambda directly to differentiate "f" with respect to "x", without writing a separate named function
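A small check of both points might look like the sketch below. It uses `jax.grad` because `jnp.square` is scalar-valued here, and it passes the float `3.0` because `jax.grad` rejects integer inputs such as `3`. The name `d_sq` is introduced only for illustration.

```python
import jax
import jax.numpy as jnp

sq_fn = lambda x: jnp.square(x)

def reg_sq_fn(x):
    return jnp.square(x)

# A lambda has no name of its own; Python labels it "<lambda>"
print(sq_fn.__name__)      # <lambda>
print(reg_sq_fn.__name__)  # reg_sq_fn

# The lambda can be handed straight to a JAX transformation, no separate def needed
d_sq = jax.grad(lambda x: jnp.square(x))
print(d_sq(3.0))  # derivative of x**2 at x = 3.0, i.e. 6.0
```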

## Memory location

In [8]:
id(sq_fn)

1510304247584

In [9]:
id(reg_sq_fn)

1510304247904

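1. The built-in "id" function returns an integer that uniquely identifies an object during its lifetime; in CPython, the standard interpreter, this integer is the object's memory address, so it tells us where our functions are stored
2. Retrieving the memory location will be required to demonstrate a subtle point regarding our PINN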
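The PINN-specific point is not spelled out in this section, so the sketch below only illustrates the general behavior of `id`: binding a second name to a function does not copy it, while defining a new function, even with an identical body, creates a distinct object. The names `alias` and `other_sq_fn` are hypothetical.

```python
sq_fn = lambda x: x * x

# A second name bound to the same function object shares its identity
alias = sq_fn
print(id(alias) == id(sq_fn))  # True
print(alias is sq_fn)          # True: "is" compares identities

# A new lambda with an identical body is a different object
other_sq_fn = lambda x: x * x
print(other_sq_fn is sq_fn)    # False
```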

## List comprehension

In [10]:
y_lc = [jnp.square(x) for x in range(5)]

In [11]:
y_lc

[Array(0, dtype=int32, weak_type=True),
 Array(1, dtype=int32, weak_type=True),
 Array(4, dtype=int32, weak_type=True),
 Array(9, dtype=int32, weak_type=True),
 Array(16, dtype=int32, weak_type=True)]

In [12]:
y_reg = []
for x in range(5):
    y_reg.append(jnp.square(x))

In [13]:
y_reg

[Array(0, dtype=int32, weak_type=True),
 Array(1, dtype=int32, weak_type=True),
 Array(4, dtype=int32, weak_type=True),
 Array(9, dtype=int32, weak_type=True),
 Array(16, dtype=int32, weak_type=True)]

1. List comprehensions offer a compact syntax for creating lists by placing a one-line for loop inside square brackets
2. In this example we create a list containing the squares of the integers from 0 to 4 ("range(5)" stops before 5)
3. We create two lists, one with a list comprehension and one with a for loop, to contrast both approaches
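The sketch below confirms that both approaches give the same values, shows the optional `if` clause of a comprehension, and, for contrast, a single vectorized `jnp` call that needs no Python loop at all. The names `y_even` and `y_vec` are illustrative.

```python
import jax.numpy as jnp

# Comprehension and explicit loop from above
y_lc = [jnp.square(x) for x in range(5)]
y_reg = []
for x in range(5):
    y_reg.append(jnp.square(x))

# Both approaches produce the same values
print([int(v) for v in y_lc] == [int(v) for v in y_reg])  # True

# An optional "if" clause filters the loop variable: squares of even numbers only
y_even = [jnp.square(x) for x in range(5) if x % 2 == 0]
print([int(v) for v in y_even])  # [0, 4, 16]

# For plain numeric data, one vectorized call replaces the Python loop
y_vec = jnp.square(jnp.arange(5))
print(y_vec.tolist())  # [0, 1, 4, 9, 16]
```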

## The enumerate operation

In [14]:
for idx, item in enumerate(y_lc):
    print(idx, 2*idx, item)

0 0 0
1 2 1
2 4 4
3 6 9
4 8 16


1. The enumerate operation is used when we want to loop through a collection of items and simultaneously access the index (position) of each item
2. Here, we use enumerate to loop through a list and extract each item together with its index
3. We then print the index, the index doubled and the original item, which is the index squared
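To make the pairs that enumerate produces visible, the sketch below converts them to a list, shows the optional `start` argument, and writes out the more verbose index-based loop that enumerate replaces. The list `items` is a made-up example.

```python
items = ["a", "b", "c"]

# enumerate yields (index, item) pairs, counting from 0 by default
print(list(enumerate(items)))           # [(0, 'a'), (1, 'b'), (2, 'c')]

# The optional start argument shifts the counter, e.g. for 1-based labels
print(list(enumerate(items, start=1)))  # [(1, 'a'), (2, 'b'), (3, 'c')]

# Equivalent but more verbose loop over indices
for idx in range(len(items)):
    item = items[idx]
    print(idx, item)
```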

## zip objects

In [15]:
list_one = [1, 2, 3]
list_two = [99, 98, 97]
zip_obj = zip(list_one, list_two)

In [16]:
type(zip_obj)

zip

In [17]:
zip_obj

<zip at 0x15fa694fcc0>

In [18]:
for tuple_obj in zip_obj:
    (num_1, num_2) = tuple_obj
    print(num_1, num_2, num_1 + num_2)

1 99 100
2 98 100
3 97 100


1. zip objects are iterators of tuples
2. By an iterator, we mean an object that we can loop through, producing one item at a time; an iterator can only be consumed once, so after the loop above "zip_obj" is exhausted
3. Looping through a zip object therefore means looping through a sequence of tuples
4. zip objects are used when we want to simultaneously iterate, or loop, through multiple collections of items
5. Here, we loop through two lists simultaneously; "zip" pairs up the items of both lists that share the same index, so each tuple holds one item from each list
6. In this example, we loop through the zip object, unpack the numbers stored in each tuple and print them along with their sum
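The sketch below shows three practical consequences of zip being an iterator of tuples: the tuples can be unpacked directly in the `for` statement, a zip object is empty after one full pass, and zip stops at the shortest input.

```python
list_one = [1, 2, 3]
list_two = [99, 98, 97]

# Tuples can be unpacked directly in the for statement
for num_1, num_2 in zip(list_one, list_two):
    print(num_1, num_2, num_1 + num_2)

# A zip object is an iterator: it can be consumed only once
zip_obj = zip(list_one, list_two)
print(list(zip_obj))  # [(1, 99), (2, 98), (3, 97)]
print(list(zip_obj))  # [] because the iterator is already exhausted

# zip stops at the shortest input
print(list(zip([1, 2, 3], ["x", "y"])))  # [(1, 'x'), (2, 'y')]
```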

## Decorators and just-in-time (jit) compilers in JAX

In [19]:
def decorator(func):
    def wrapper():
        print("Something is happening before the function is called.")
        func()
        print("Something is happening after the function is called.")
    return wrapper

In [20]:
def print_hello_world():
    print("hello world!")

In [21]:
@decorator
def print_hi_world():
    print("hi world!")

In [22]:
print_hello_world()

hello world!


In [23]:
print_hi_world()

Something is happening before the function is called.
hi world!
Something is happening after the function is called.


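1. Decorators can be thought of as "enhancers" of functions: a decorator takes a function and returns a new function that extends its behavior, and the "@decorator" line above a definition is shorthand for "print_hi_world = decorator(print_hi_world)"
2. Decorators are a common way of invoking the just-in-time (jit) compiler in JAX via "@jax.jit" (calling "jax.jit(func)" directly is equivalent)
3. The jit compiler can lead to large speed-ups, especially for functions that are called many times, because JAX compiles the function on the first call and reuses the compiled version on later calls with inputs of the same shape and dtype
4. The jit compiler, along with vmap, is one of the two key features that make JAX very attractive for PINNs!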
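The `decorator` defined above only works for functions without arguments and discards their return value. A more general version, sketched below as one common pattern rather than the notebook's own code, forwards `*args`/`**kwargs`, returns the result, and uses `functools.wraps` to keep the original name. It then applies `@jax.jit` with the same syntax; `add` and `sum_of_squares` are illustrative names.

```python
import functools
import jax
import jax.numpy as jnp

def decorator(func):
    # functools.wraps copies func's name and docstring onto the wrapper
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        print("Something is happening before the function is called.")
        result = func(*args, **kwargs)  # forward arguments, keep the return value
        print("Something is happening after the function is called.")
        return result
    return wrapper

@decorator  # same as: add = decorator(add)
def add(a, b):
    return a + b

print(add(2, 3))  # prints the two messages, then 5

# jax.jit uses the same @ syntax; the function is compiled on the first call
# and the compiled version is reused for inputs of the same shape and dtype
@jax.jit
def sum_of_squares(v):
    return jnp.sum(jnp.square(v))

print(sum_of_squares(jnp.arange(4.0)))  # 0 + 1 + 4 + 9 = 14.0
```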

## Array slicing

In [24]:
array_example = [1, 2, 3, 4, 5]

In [25]:
array_example[:3]

[1, 2, 3]

In [26]:
array_example[2:4]

[3, 4]

In [27]:
array_example[:-1]

[1, 2, 3, 4]

1. Array slicing is used to literally "slice" an array (here, a Python list; JAX and NumPy arrays follow the same rules along each axis)
2. The syntax is array[first_index:last_index], which returns the elements starting at array[first_index] and going up to, but excluding, array[last_index]
3. Omitting the first index, as in array[:last_index], means the slice starts at the beginning of the array
4. Negative indices count from the end: array[-1] is the last element, so array[:-1] keeps everything from the beginning up to, but excluding, the last element
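The same rules carry over to JAX arrays, which also accept a third step field and one slice per axis. The sketch below uses small illustrative arrays `a` and `m`; the column slice `m[:, 0]` is the kind of indexing that picks one coordinate out of a batch of vectors.

```python
import jax.numpy as jnp

a = jnp.arange(10)          # [0, 1, ..., 9]
print(a[2:5].tolist())      # [2, 3, 4]
print(a[-3:].tolist())      # [7, 8, 9]: the last three elements
print(a[::2].tolist())      # [0, 2, 4, 6, 8]: the third field sets the step

# Each axis of a multi-dimensional array gets its own slice, separated by commas
m = jnp.arange(6).reshape(3, 2)   # [[0, 1], [2, 3], [4, 5]]
print(m[:, 0].tolist())     # [0, 2, 4]: every row, first column
print(m[1:, :].tolist())    # [[2, 3], [4, 5]]: drop the first row
```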

## Type hints

In [28]:
def add_two_integers(first_num:int, second_num:int) -> int:
    return first_num + second_num

In [29]:
add_two_integers(1, 2)

3

In [30]:
add_two_integers(1.1, 2.2)

3.3000000000000003

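1. Type hints are used to specify the intended Python type of variables and parameters
2. Here, the hints annotate the arguments and the return value of a function
3. The type hints are not enforced at runtime, as can be seen from the example, where two floats are accepted and a float is returned; static checkers such as mypy can use them to catch such mismatches before the code runs
4. The type hints mainly serve for clarity and improved readability of code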
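Where the hints actually end up can be seen in the sketch below: Python stores them on the function's `__annotations__` dictionary and otherwise ignores them when the function is called.

```python
def add_two_integers(first_num: int, second_num: int) -> int:
    return first_num + second_num

# The hints are stored as metadata on the function object
print(add_two_integers.__annotations__)
# {'first_num': <class 'int'>, 'second_num': <class 'int'>, 'return': <class 'int'>}

# Nothing checks them at runtime, so float arguments still work
print(add_two_integers(1.5, 2.5))  # 4.0
```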

## Random keys in JAX

In [31]:
key = jax.random.PRNGKey(0)

In [32]:
key, new_key = jax.random.split(key)

1. Random keys are required for reproducibility of code and for initializing neural networks; JAX does not keep a hidden global random state, so every random function takes a key explicitly
2. The "jax.random.PRNGKey(0)" line of code creates the initial key from the seed 0
3. A key should not be reused for independent random draws, because the same key always produces the same numbers; whenever fresh randomness is needed, for example to initialize each layer of a neural network, new keys are derived from the current one
4. The "jax.random.split(key)" line of code does this by "splitting" the initial key into two new keys: one overwrites "key" so that it can be split again later, while "new_key" is used for the next random draw
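The sketch below shows why splitting matters: reusing a key reproduces the same numbers, a subkey from `split` gives new ones, and `split` can also produce several keys at once, e.g. one per layer. The draws use `jax.random.normal`; the names `a`, `b`, `c` and `layer_keys` are illustrative.

```python
import jax

key = jax.random.PRNGKey(0)

# The same key always produces the same numbers: this gives reproducibility
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))
print(bool((a == b).all()))  # True

# split returns fresh keys: draw with the subkey, keep "key" for later splits
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))
print(bool((a == c).all()))  # False: a different key gives different numbers

# split can also return several keys at once, e.g. one per network layer
layer_keys = jax.random.split(key, 3)
print(len(layer_keys))  # 3
```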

## Vectorized mapping or vmap in JAX

In [33]:
def vec_norm(vec):
    # Squared Euclidean norm x^2 + y^2 of a single 2-vector
    x = vec[0]
    y = vec[1]
    n = jnp.square(x) + jnp.square(y)
    return n

In [34]:
theta = jnp.linspace(0, jnp.pi, 100000)
x = jnp.cos(theta).reshape(-1,1)
y = jnp.sin(theta).reshape(-1,1)
vecs = jnp.concatenate((x, y), axis=1)

In [35]:
start_time_reg = time.time()
list_norm_reg = [vec_norm(vec) for vec in vecs]
end_time_reg = time.time()
exec_time_reg = end_time_reg - start_time_reg

In [36]:
start_time_vm = time.time()
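# JAX dispatches work asynchronously; block_until_ready waits for the result,
# so the measured time includes the actual computation
list_norm_vm = jax.vmap(vec_norm)(vecs).block_until_ready()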
end_time_vm = time.time()
exec_time_vm = end_time_vm - start_time_vm

In [37]:
print(f"execution time (regular): {round(exec_time_reg,2)} seconds")
print(f"execution time (vmap): {round(exec_time_vm,2)} seconds")

execution time (regular): 35.98 seconds
execution time (vmap): 0.14 seconds


1. vmap vectorizes a function: it takes a function written for a single object, such as one vector, and returns a function that acts on an entire batch of such objects
2. For custom functions that are not vectorized "out-of-the-box", the most rudimentary way of applying them to a batch is a Python for loop (here, a list comprehension), as in the example shown here
3. Such loops create a speed bottleneck because every iteration dispatches separate small JAX operations from Python
4. vmap instead transforms the function so that the whole batch is processed with array operations, which is why it is so much faster here (about 36 seconds versus 0.14 seconds for 100,000 vectors in the run above)
5. The custom function shown here, which returns the squared norm of each vector, is "standard" enough to vectorize with "out-of-the-box" functions, and it has been used merely to illustrate a point!
6. In real problems, we will be faced with custom functions that are hard to vectorize by hand, and that is where the power of vmap really comes in handy!
7. When we go through the main code, we will be vectorizing a custom function that converts orientation given as a quaternion directly to angular representations, and vmap really helps there!
8. vmap also helps conceptually: we can formulate a function for a single scalar input and then vectorize it. For example, we might compute the velocity vector by defining a function that automatically differentiates the position vector at a given time "t", and then apply it to an entire array of temporal observations via vmap, as in the example below!
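For a function this simple, the "out-of-the-box" alternative mentioned in note 5 is a single reduction over axis 1. The sketch below, using a smaller batch of 1,000 vectors built with `jnp.stack` instead of the reshape/concatenate above, checks that vmap and the reduction agree and shows the `in_axes` argument, which selects the axis that vmap maps over.

```python
import jax
import jax.numpy as jnp

def vec_norm(vec):
    # Squared Euclidean norm of a single 2-vector
    return jnp.square(vec[0]) + jnp.square(vec[1])

theta = jnp.linspace(0, jnp.pi, 1000)
vecs = jnp.stack((jnp.cos(theta), jnp.sin(theta)), axis=1)  # shape (1000, 2)

# vmap maps vec_norm over axis 0 of vecs (in_axes=0 is the default)
norms_vmap = jax.vmap(vec_norm)(vecs)

# "Out-of-the-box" vectorized equivalent: reduce over axis 1 directly
norms_builtin = jnp.sum(jnp.square(vecs), axis=1)

print(norms_vmap.shape)                               # (1000,)
print(bool(jnp.allclose(norms_vmap, norms_builtin)))  # True
print(bool(jnp.allclose(norms_vmap, 1.0)))            # True, as cos^2 + sin^2 = 1

# in_axes picks the mapped axis: here each column of vecs.T is one vector
norms_cols = jax.vmap(vec_norm, in_axes=1)(vecs.T)
print(bool(jnp.allclose(norms_cols, norms_vmap)))     # True
```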

## Automatic differentiation using jacrev in JAX

In [38]:
dr_dt = jax.jacrev(lambda t:jnp.array([t, jnp.square(t)]))

In [39]:
t_val = jnp.linspace(0, 2, 10)

In [40]:
t_val[0]

Array(0., dtype=float32)

In [41]:
dr_dt(t_val[0])

Array([1., 0.], dtype=float32)

In [42]:
dr_dt_vec = jax.vmap(dr_dt)

In [43]:
dr_dt_vec(t_val)

Array([[1.        , 0.        ],
       [1.        , 0.44444445],
       [1.        , 0.8888889 ],
       [1.        , 1.3333334 ],
       [1.        , 1.7777778 ],
       [1.        , 2.2222223 ],
       [1.        , 2.6666667 ],
       [1.        , 3.1111112 ],
       [1.        , 3.5555556 ],
       [1.        , 4.        ]], dtype=float32)

1. Reverse-mode automatic differentiation in JAX, via "jax.jacrev", allows us to differentiate vector-valued functions
2. For example, position vectors are typically vector-valued functions of time whose differentiation with respect to time yields the velocity vector
3. The grad functionality within JAX only handles scalar-valued functions, which is why it is not applicable here
4. The "jac" in "jacrev" stands for "Jacobian", the matrix of all first-order partial derivatives, which generalizes the gradient to vector-valued functions
5. The "rev" in "jacrev" stands for "reverse-mode" automatic differentiation
6. Reverse-mode automatic differentiation is typically preferred when the number of inputs is much larger than the number of outputs; forward mode ("jax.jacfwd") is typically preferred in the opposite case
7. We are interested in representing physical quantities of interest with neural networks, which have a large number of parameters. However, we plan on using automatic differentiation explicitly to compute physical derivatives, where the "active" input is time; thus, for the purpose of obtaining physical derivatives, we are dealing with a single scalar input
8. Given that the neural network has a large number of parameters but the "active" input is just one single number, the choice between forward and reverse mode was not clear to me, and since both approaches provide accurate answers and speed was not a major concern, I decided to use reverse-mode automatic differentiation!
9. Here, we define the "dr_dt" function, which differentiates the vector-valued position r(t) = [t, t^2] with respect to time, a scalar input
10. The vmap functionality within JAX allows us to apply the "dr_dt" function to the entire temporal array, computing the velocity vector [1, 2t] for each discrete time instant in the temporal array
11. We will now discuss the same projectile motion example in the light of Object-oriented Programming (OOP)
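As a sanity check, the sketch below compares `jax.jacrev` with `jax.jacfwd` and with the analytic velocity [1, 2t]. For this particular shape (one scalar input, two outputs), forward mode needs a single Jacobian-vector product while reverse mode needs one vector-Jacobian product per output, so forward mode would usually be the cheaper option; at this size the difference is negligible and both give the same Jacobian. The names `position`, `v_rev`, `v_fwd` and `v_exact` are illustrative.

```python
import jax
import jax.numpy as jnp

# Position r(t) = [t, t^2]; its analytic velocity is dr/dt = [1, 2t]
position = lambda t: jnp.array([t, jnp.square(t)])

dr_dt_rev = jax.jacrev(position)  # reverse mode, as used above
dr_dt_fwd = jax.jacfwd(position)  # forward-mode alternative

t_val = jnp.linspace(0, 2, 10)
v_rev = jax.vmap(dr_dt_rev)(t_val)  # shape (10, 2)
v_fwd = jax.vmap(dr_dt_fwd)(t_val)
v_exact = jnp.stack((jnp.ones_like(t_val), 2 * t_val), axis=1)

print(bool(jnp.allclose(v_rev, v_exact)))  # True
print(bool(jnp.allclose(v_fwd, v_rev)))    # True: both modes agree
```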

## OOP basics

In [44]:
class TimeStep:
    def __init__(self, position, velocity):
        self.position = position  # Assume position is a tuple (x, y)
        self.velocity = velocity  # Assume velocity is a tuple (vx, vy)
        self.dt = 0.01

    def move(self):
        new_x = self.position[0] + self.velocity[0] * self.dt
        new_y = self.position[1] + self.velocity[1] * self.dt
        self.position = (new_x, new_y)

    def __call__(self):
        self.move()
        return self.position

class Motion(TimeStep):
    def __init__(self, position, velocity):
        super().__init__(position, velocity)

    def compute_trajectory(self, tf):
        vx = 1.0
        t = 0.0
        while t < tf:
            vy = 2 * t
            self.velocity = (vx, vy)
            self.move()
            t += self.dt

    def __call__(self, tf):
        self.compute_trajectory(tf)
        return self.position

In [45]:
projectile = Motion(position=(0.0, 0.0), velocity=(0.0, 0.0))
projectile_position_after_2s = projectile(tf=2.0)
print(f"Projectile position after 2 seconds: {round(projectile_position_after_2s[0], 2), round(projectile_position_after_2s[1], 2)}")

Projectile position after 2 seconds: (2.0, 3.98)


1. Classes can be thought of as blueprints for creating structured objects that bundle data (attributes) together with the functions that operate on that data (methods)
2. Objects are instances of such classes
3. The properties associated with such objects are called attributes
4. The "__init__" method initializes new objects and is often loosely called the constructor. Every time we create an object of the class with the syntax "example_object = ClassName(parameters)", the code inside the "__init__" method runs and initializes the attributes by making assignments based on the parameters passed to it
5. We have used the term "method", which can simply be thought of as a function defined inside a class
6. "self" is the conventional name (not a keyword) for the first parameter of a method, and it refers to the instance on which the method is called; it is standard practice to use self to assign attributes. For example, our TimeStep and Motion classes have instances with the properties of position and velocity, and "self.position" and "self.velocity" attach this information to each instance
7. Once we create an object, we might want to repetitively perform a set of instructions on it. Defining the "__call__" method makes an instance callable like a function, so "projectile(tf=2.0)" runs "compute_trajectory" and returns the final position. Because the position is stored on the instance, a second call continues from the position reached by the first call rather than restarting from the initial position, so a fresh object is needed for each trajectory that should start from the initial conditions. Note also that "compute_trajectory" overwrites the velocity at every step with vx = 1 and vy = 2t, so the initial velocity passed in has no effect on the result
8. We have also used the concept of "Parent" and "Child" classes in this example. The "Child" class inherits properties from the "Parent" class via the syntax "class Child(Parent)". Here, the Motion class has all the properties of the TimeStep class. The Motion class computes a trajectory, which requires stepping through time, and the TimeStep class has a method called "move" that carries out a single time step. By inheriting from TimeStep, the Motion class has access to "move" and uses it for stepping through time while computing the trajectory. Motion also defines its own "__call__", which overrides the one inherited from TimeStep
9. The built-in "super()" function gives access to the parent class: "super().__init__(position, velocity)" runs the initialization logic of the parent class, so that the child class can build on and extend the functionality of the parent class. Here, Motion's "__init__" only forwards its arguments to TimeStep's "__init__", so omitting it would give the same behavior
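The two ideas used most in the classes above, a child class that calls `super().__init__` and then adds its own attribute, and a callable instance whose state persists between calls, can be isolated in a smaller sketch. The classes `Particle`, `ChargedParticle` and `StepCounter` are hypothetical and exist only for this illustration.

```python
class Particle:  # hypothetical parent class
    def __init__(self, position):
        self.position = position

class ChargedParticle(Particle):  # hypothetical child class
    def __init__(self, position, charge):
        super().__init__(position)  # runs Particle.__init__, which sets self.position
        self.charge = charge        # extra attribute added by the child

class StepCounter:  # hypothetical callable class
    def __init__(self):
        self.count = 0

    def __call__(self):
        # The state lives on the instance, so it persists between calls
        self.count += 1
        return self.count

p = ChargedParticle(position=(0.0, 0.0), charge=-1.0)
print(p.position, p.charge)     # (0.0, 0.0) -1.0
print(isinstance(p, Particle))  # True: a child instance is also a parent instance

counter = StepCounter()
print(counter(), counter())     # 1 2
```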