<img src="https://raw.githubusercontent.com/google/jax/main/images/jax_logo_250px.png" width="300" height="300" align="center"/><br>

I hope you all enjoyed the first JAX tutorial where we discussed **DeviceArray** and some other fundamental concepts in detail. This is the fifth tutorial in this series, and today we will discuss another important concept specific to JAX. If you haven't looked at the previous tutorials, I highly suggest going through them once. Here are the links:

1. [TF_JAX_Tutorials - Part 1](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part1)
2. [TF_JAX_Tutorials - Part 2](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part2)
3. [TF_JAX_Tutorials - Part 3](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part3)
4. [TF_JAX_Tutorials - Part 4 (JAX and DeviceArray)](https://www.kaggle.com/aakashnain/tf-jax-tutorials-part-4-jax-and-devicearray)


Without any further delay, let's jump in and talk about **pure functions**, along with code examples.

# Pure Functions

According to [Wikipedia](https://en.wikipedia.org/wiki/Pure_function), a function is pure if:
1. The function returns the same values when invoked with the same inputs
2. There are no side effects observed on a function call

Although the definition looks pretty simple, without examples it can be hard to comprehend and can sound very vague (especially to beginners). The first point is clear, but what does a **`side-effect`** mean? What constitutes a side effect? What can you do to avoid side effects?

I could list every rule here, and you could try to "fit" them all in your head to make sure that you aren't writing anything with a side effect. Instead, I prefer working through examples so that everyone can understand the "why" part more easily. So, let's look at a few examples of common mistakes that create side effects.

```python
import numpy as np

import jax
import jax.numpy as jnp

from jax import grad
from jax import jit
from jax import lax
from jax import random

# IPython/Jupyter-only setting; drop this line when running as a plain script
%config IPCompleter.use_jedi = False
```

# Case 1: Globals

```python
# A global variable
counter = 5

def add_global_value(x):
    """
    A function that relies on the global variable `counter` for
    doing some computation.
    """
    return x + counter
```

```python
x = 2

# We will `jit` the function so that it runs as a JAX-transformed
# function and not like a normal Python function
y = jit(add_global_value)(x)
print("Global variable value: ", counter)
print(f"First call to the function with input {x} with global variable value {counter} returned {y}")

# Someone updated the global value later in the code
counter = 10

# Call the function again
y = jit(add_global_value)(x)
print("\nGlobal variable changed value: ", counter)
print(f"Second call to the function with input {x} with global variable value {counter} returned {y}")
```

```
Global variable value:  5
First call to the function with input 2 with global variable value 5 returned 7

Global variable changed value:  10
Second call to the function with input 2 with global variable value 10 returned 7
```


Wait...What??? What just happened?

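When you `jit` your function, JAX tracing kicks in: JAX runs your Python code once with abstract placeholders for the arguments and records the operations. The global `counter` is not an argument, so its value at that moment (`5`) gets baked into the compiled computation as a constant. On subsequent calls, JAX reuses this **`cached`** compiled computation instead of tracing again, so it keeps seeing `counter = 5`, unless, for example:
1. The type (dtype) of the argument has changed, or
2. The shape of the argument has changed

In either case JAX traces and compiles the function again, picking up whatever value the global has at that point. Let's see it in action.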

```python
# Change the type of the argument passed to the function
# In this case we will change int to float (2 -> 2.0)
x = 2.0
y = jit(add_global_value)(x)
print(f"Third call to the function with input {x} with global variable value {counter} returned {y}")
```

```
Third call to the function with input 2.0 with global variable value 10 returned 12.0
```


```python
# Change the shape of the argument
x = jnp.array([2])

# Changing global variable value again
counter = 15

# Call the function again
y = jit(add_global_value)(x)
print(f"Fourth call to the function with input {x} with global variable value {counter} returned {y}")
```

```
Fourth call to the function with input [2] with global variable value 15 returned [17]
```


What if I don't `jit` my function in the first place?  ¯\_(ツ)_/¯ <br>
Let's take an example of that as well. We are in no hurry!

```python
def apply_sin_to_global():
    return jnp.sin(jnp.pi / counter)

y = apply_sin_to_global()
print("Global variable value: ", counter)
print(f"First call to the function with global variable value {counter} returned {y}")


# Change the global value again
counter = 90
y = apply_sin_to_global()
print("\nGlobal variable value: ", counter)
print(f"Second call to the function with global variable value {counter} returned {y}")
```

```
Global variable value:  15
First call to the function with global variable value 15 returned 0.20791170001029968

Global variable value:  90
Second call to the function with global variable value 90 returned 0.03489949554204941
```


*`Hooraaayy! Problem solved! You can use JIT, I won't!`*  If you are thinking in this direction, then it's time to remember two things:

1. We are using JAX so that we can transform our native Python code to make it run **faster**
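2. We can only achieve 1) if we compile (using the term loosely here) the code so that it can run on **XLA**, the compiler used by JAX

Hence, avoid using `globals` in your computation. A function that reads a global can return different outputs for the same inputs, and once it is compiled it silently keeps using whatever value the global had at trace time. Globals introduce **impurity**.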
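The usual fix is to make the value an explicit input. Below is a minimal sketch; `add_value` and `offset` are hypothetical names used only for illustration. Because `offset` is now an argument, JAX traces it as an input rather than baking it in as a constant, so a new value is picked up on every call without needing a retrace.

```python
import jax
import jax.numpy as jnp


@jax.jit
def add_value(x, offset):
    # `offset` is an explicit argument, so it is traced as an input,
    # not baked into the compiled computation as a constant
    return x + offset


offset = 5
print(add_value(2, offset))   # computed with offset = 5

offset = 10
print(add_value(2, offset))   # computed with offset = 10; same shapes/dtypes, so no retrace
```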

# Case 2: Iterators

We will take a very simple example to see the side effect. We will add the numbers from `0 to 4` in two different ways:
1. Passing an actual array of numbers to a function
2. Passing an **`iterator`** object to the same function

```python
# A function that takes an actual array object
# and adds all the elements present in it
def add_elements(array, start, end, initial_value=0):
    def loop_fn(i, val):
        return val + array[i]
    return lax.fori_loop(start, end, loop_fn, initial_value)


# Define an array object
array = jnp.arange(5)
print("Array: ", array)
print("Adding all the array elements gives: ", add_elements(array, 0, len(array), 0))


# Redefining the same function, but this time it takes an
# iterator object as an input
def add_elements(iterator, start, end, initial_value=0):
    def loop_fn(i, val):
        return val + next(iterator)
    return lax.fori_loop(start, end, loop_fn, initial_value)


# Define an iterator
iterator = iter(np.arange(5))
print("\n\nIterator: ", iterator)
print("Adding all the elements gives: ", add_elements(iterator, 0, 5, 0))
```

```
Array:  [0 1 2 3 4]
Adding all the array elements gives:  10


Iterator:  <iterator object at 0x7f373c3c5b50>
Adding all the elements gives:  0
```


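Why did the result turn out to be zero in the second case?<br>
This is because an `iterator` introduces an **external state** to retrieve the next value. `lax.fori_loop` traces `loop_fn` to build the loop body instead of calling it as plain Python on every iteration, so `next(iterator)` runs at trace time, not once per iteration. In this run it returned the first element, `0`, which was baked into the loop body as a constant and added on every iteration.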
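One way around this, sketched below under the assumption that the iterator is finite and fits in memory, is to drain it into an array *before* handing it to JAX. The loop then reads from an explicit input via the loop index, which is exactly what the first `add_elements` did. `add_from_iterator` is a hypothetical helper name.

```python
import numpy as np
import jax.numpy as jnp
from jax import lax


def add_elements(array, start, end, initial_value=0):
    # The loop state (index `i` and running sum `val`) is carried explicitly
    def loop_fn(i, val):
        return val + array[i]
    return lax.fori_loop(start, end, loop_fn, initial_value)


def add_from_iterator(iterator):
    # Hypothetical helper: consume the external iterator eagerly in plain
    # Python, then pass JAX a concrete array instead of the iterator itself
    values = jnp.asarray(list(iterator))
    return add_elements(values, 0, len(values), 0)


print(add_from_iterator(iter(np.arange(5))))
```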

# Case 3: IO

Let's take one more example, a very **unusual** one that can turn your functions impure.

```python
def return_as_it_is(x):
    """Returns the same element doing nothing. A function that isn't
    using `globals` or any `iterator`
    """
    print("I have received the value")
    return x


# First call to the function
print(f"Value returned on first call: {jit(return_as_it_is)(2)}\n")

# Second call to the function with a different value
print(f"Value returned on second call: {jit(return_as_it_is)(4)}")
```

```
I have received the value
Value returned on first call: 2

Value returned on second call: 4
```


Did you notice that? The statement **`I have received the value`** didn't get printed on the subsequent call. <br>
At this point, most people would literally say `Well, this is insane! I am not using globals, no iterators, nothing at all and there is still a side effect? How is that even possible?`

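The thing is that your function still has a side effect: the **print** statement! Printing writes to the standard output stream, which is state that lives outside the function. When JAX traces the function, the Python `print` runs as part of tracing; the compiled computation that gets reused on later calls (here `4` has the same shape and dtype as `2`) contains only the array operations, so the print never happens again. The returned value is still correct, but the function violates the second principle: it has an observable side effect that does not behave consistently across calls.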
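If you genuinely need to see values from inside a jitted function, recent JAX versions provide `jax.debug.print`, which is staged into the compiled computation instead of running only at trace time. A minimal sketch contrasting it with Python's `print`, assuming a JAX version that ships the `jax.debug` module:

```python
import jax


@jax.jit
def return_as_it_is(x):
    # Python side effect: runs only while JAX traces the function
    print("Tracing return_as_it_is")
    # Staged into the compiled computation: runs on every call
    jax.debug.print("I have received the value {}", x)
    return x


return_as_it_is(2)  # traced, compiled, then executed
return_as_it_is(4)  # compiled computation reused; only jax.debug.print fires
```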


In a nutshell, to keep a function pure, don't use anything that depends on an **external state**. The word **external** is important because you can use stateful objects internally and still keep the function pure. Let's look at an example of this as well.

# Pure functions with stateful objects

```python
# Function that uses stateful objects, but only internally, and is still pure
def pure_function_with_stateful_objects(array):
    array_dict = {}
    for i in range(len(array)):
        array_dict[i] = array[i] + 10
    return array_dict


array = jnp.arange(5)

# First call to the function
print(f"Value returned on first call: {jit(pure_function_with_stateful_objects)(array)}")

# Second call to the function with the same value
print(f"\nValue returned on second call: {jit(pure_function_with_stateful_objects)(array)}")
```

```
Value returned on first call: {0: DeviceArray(10, dtype=int32), 1: DeviceArray(11, dtype=int32), 2: DeviceArray(12, dtype=int32), 3: DeviceArray(13, dtype=int32), 4: DeviceArray(14, dtype=int32)}

Value returned on second call: {0: DeviceArray(10, dtype=int32), 1: DeviceArray(11, dtype=int32), 2: DeviceArray(12, dtype=int32), 3: DeviceArray(13, dtype=int32), 4: DeviceArray(14, dtype=int32)}
```


So, to keep things **pure**, remember not to use anything inside a function that depends on any **external state**, including IO. If you do, transforming the function can give you unexpected results: the transformed function silently reuses a computation that was traced with stale values, or without the side effect you expected, and you end up wasting a lot of time debugging. That is ironic, because pure functions are supposed to be easy to debug.

# Why pure functions?

A natural question that comes to mind is: why does JAX rely on pure functions in the first place? Frameworks like TensorFlow, PyTorch, and MXNet don't impose this requirement on code running in their default eager (imperative) mode. <br>
You are probably also thinking: using pure functions is such a headache, I never have to deal with these nuances in TF/Torch.

Well, if you are thinking that, you aren't alone, but before jumping to any conclusions, consider the advantages of relying on pure functions.

### 1. Easy to debug

The fact that a function is pure implies that you don't need to look beyond the scope of the pure function. All you need to focus on is the arguments, the logic inside the function, and the returned value. That's it! Same inputs => Same outputs


### 2. Easy to parallelize

Let's say you have three functions A, B, and C and there is a computation involved like this one:<br>
 <div style="font-style: italic; text-align: center;">
 `res = A(x) + B(y) + C(z)` <br>
 </div>
 
Because all the functions are pure, you don't have to worry about any dependency on an external or shared state. A, B, and C don't depend on each other in terms of how they are executed: each one receives its argument and always returns the same output for it. Hence you can easily offload the computation to many threads, cores, devices, etc. The only thing the compiler has to ensure is that the results of all the functions (A, B, and C in this case) are available before they are summed and assigned to `res`.


### 3. Caching or Memoization

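We saw in the above examples that once we compile a function with `jit`, JAX caches the compiled computation and reuses it on subsequent calls whose arguments have the same shapes and dtypes, skipping tracing and compilation entirely. This reuse is only safe for pure functions, and it can make the whole program a lot faster. More generally, because a pure function always maps the same inputs to the same outputs, its results can also be memoized.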
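You can observe this reuse directly. The sketch below deliberately breaks purity in one small way, incrementing a Python counter during tracing, purely as an instrument: the counter goes up only when JAX actually traces (and compiles) the function, not on cache hits. `trace_count` and `square_plus_one` are hypothetical names.

```python
import jax
import jax.numpy as jnp

trace_count = 0  # Python-side counter, used only to observe tracing


@jax.jit
def square_plus_one(x):
    global trace_count
    trace_count += 1  # runs only while JAX traces the function
    return x * x + 1


square_plus_one(jnp.float32(2.0))  # first call: traced and compiled
square_plus_one(jnp.float32(3.0))  # same shape and dtype: cached computation reused
square_plus_one(jnp.ones(3))       # new shape (3,): traced and compiled again
print(trace_count)  # 2
```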


### 4. Functional Composition

When functions are pure, you can `chain` (compose) them to solve complex problems much more easily. For example, in JAX you will see patterns like this very often:
<div style="font-style: italic; text-align: center;">
jit(vmap(grad(..)))
</div>
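As a small illustration (the function `f` is a hypothetical example), composing the three transformations gives a compiled, vectorized derivative in one line. Each transformation takes a function and returns a new function, which works cleanly precisely because the input function is pure.

```python
import jax
import jax.numpy as jnp


def f(x):
    # A pure scalar function: its derivative is 2 * x
    return x ** 2


# grad: derivative of f for a single scalar input
# vmap: map that derivative over a batch of inputs
# jit:  compile the whole composed function with XLA
df_batched = jax.jit(jax.vmap(jax.grad(f)))

print(df_batched(jnp.arange(3.0)))  # [0. 2. 4.]
```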

### 5. Referential transparency

An expression is called referentially transparent if it can be replaced with its corresponding value (and vice-versa) without changing the program's behavior. This can only be achieved when the function is pure. It is especially helpful when doing algebra (which is all we do in ML). For example, consider the expression<br>
 <div style="font-style: italic; text-align: center;">
  x = 5 <br>
  y = 5 <br> 
 z = x + y <br>
 </div>
 
Now you can replace `x + y` with `z` (or with `10`) anywhere in your code without changing its behavior, because `+` is pure. The same holds for any expression whose value comes from a pure function.
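A tiny sketch of what this buys you, and what breaks without it. `add_pure`, `add_impure`, and `offset` are hypothetical names: replacing a call to the pure function with its value leaves the result unchanged, while doing the same for the impure one gives a different answer once the global it reads changes.

```python
def add_pure(x, y):
    return x + y


offset = 0  # global state read by the impure function


def add_impure(x, y):
    return x + y + offset


# Pure: the call and its value are interchangeable
z = add_pure(5, 5)
assert add_pure(5, 5) * 2 == z * 2

# Impure: the "same" call no longer matches the value saved earlier
w = add_impure(5, 5)  # 10 while offset is 0
offset = 1
print(add_impure(5, 5) * 2 == w * 2)  # False: 22 vs 20
```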

That's it for Part-5! We will look into other building blocks in the next few chapters, and then we will dive into building neural networks in JAX! 

**References:**<br>
1. https://jax.readthedocs.io/en/latest/
2. https://alvinalexander.com/scala/fp-book/benefits-of-pure-functions/
3. https://www.sitepoint.com/what-is-referential-transparency/#referentialtransparencyinmaths