<a href="https://colab.research.google.com/github/Hashhhhhhhh/JAX-Playground/blob/main/Pure_and_Impure_functions.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Pure and Impure Functions

# Understanding Pure and Impure Functions with JAX

# ## What is a Pure Function?
# A pure function always returns the same output for the same inputs, depends only on its explicit arguments, and has no side effects such as modifying external variables or printing.

# ## What is an Impure Function?
# An impure function has side effects (modifying global variables, printing to the console, writing to files) or depends on external state, in which case calling it twice with the same arguments can return different results.

import jax
import jax.numpy as jnp

# Pure function example
def pure_multiply(x, y):
    return x * y

print("Pure function output:")
print(pure_multiply(jnp.array([1, 2, 3]), 2))
print(pure_multiply(jnp.array([1, 2, 3]), 2))

# Impure function example: modifying a global variable
counter = 0
def impure_increment():
    global counter
    counter += 1
    return counter

print("\nImpure function output:")
print(impure_increment())
print(impure_increment())

# Impure function example: printing (side effect)
def impure_print(x):
    print("Printing from function:", x)
    return x * 2

print("\nCalling impure_print:")
result = impure_print(3)
print("Result:", result)

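# ## Why does JAX require pure functions and immutable arrays?
# - JAX transformations such as jax.jit, jax.grad, and jax.vmap trace a function and assume its output depends only on its inputs. This is what lets them compile, differentiate, and vectorize code reliably.
# - jax.jit caches the traced computation, so Python side effects (printing, updating globals) run only while the function is traced, not on every call.
# - Immutable arrays mean every array operation returns a new array and leaves the original unchanged.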

# Assigning to an element of a JAX array in place raises a TypeError, because JAX arrays are immutable.

# Create a JAX array and attempt to change it (commented out to avoid runtime error)
# x = jnp.array([1, 2, 3])
# x[0] = 10  # This will raise an error because JAX arrays are immutable

# Instead, use the .at[] syntax, which returns a new array with the modified value and leaves x unchanged
x = jnp.array([1, 2, 3])
x_new = x.at[0].set(10)
print("\nOriginal array:", x)
print("New array after set operation:", x_new)


Pure function output:
[2 4 6]
[2 4 6]

Impure function output:
1
2

Calling impure_print:
Printing from function: 3
Result: 6

Original array: [1 2 3]
New array after set operation: [10  2  3]
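Because `pure_multiply` has no side effects, JAX's transformations can be applied to it directly. The sketch below, assuming a recent JAX release, compiles it with `jax.jit`, differentiates it with `jax.grad`, and batches it with `jax.vmap`. The lambda wrapper (needed because `jax.grad` requires a scalar output) and the batch shape are illustrative choices, not part of the original notebook.

```python
import jax
import jax.numpy as jnp

# Same pure function as in the notebook: output depends only on x and y.
def pure_multiply(x, y):
    return x * y

x = jnp.array([1.0, 2.0, 3.0])

# jax.jit: a pure function gives the same result whether compiled or not.
eager = pure_multiply(x, 2.0)
compiled = jax.jit(pure_multiply)(x, 2.0)

# jax.grad needs a scalar output, so sum the product.
# d/dx sum(x * y) = y for every element of x.
grad_x = jax.grad(lambda x, y: jnp.sum(pure_multiply(x, y)))(x, 2.0)

# jax.vmap: map over the rows of a batch while keeping y fixed (in_axes=(0, None)).
batch = jnp.arange(6.0).reshape(2, 3)
batched = jax.vmap(pure_multiply, in_axes=(0, None))(batch, 2.0)

print("eager:   ", eager)
print("compiled:", compiled)
print("grad:    ", grad_x)
print("batched:\n", batched)
```

None of these transformations needs special handling here, precisely because the function reads nothing but its arguments.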
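The global-counter example behaves differently once it is compiled. A minimal sketch, assuming default `jax.jit` caching: the Python body of `impure_increment` runs only when JAX traces it, so the increment happens once and its result is frozen into the compiled function as a constant. `counter` and `impure_increment` are redefined here so the sketch runs on its own.

```python
import jax

# The notebook's impure function: reads and writes a global variable.
counter = 0

def impure_increment():
    global counter
    counter += 1
    return counter

jit_increment = jax.jit(impure_increment)

# First call: JAX traces the Python body, so counter becomes 1
# and the returned value 1 is baked into the compiled function.
a = jit_increment()

# Second call: the cached compiled function is reused; the Python body
# does not run again, so counter stays at 1 and the result is still 1.
b = jit_increment()

print("results:", a, b)
print("global counter:", counter)
```

Without `jax.jit` the two calls return 1 and 2, as in the output above; with it, the side effect silently disappears after the first call.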
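The same caching affects `impure_print`: under `jax.jit`, the Python `print` runs only during tracing and shows an abstract tracer rather than the concrete value. For printing runtime values inside compiled code, JAX provides `jax.debug.print`. A minimal sketch; `debug_double` is a hypothetical helper name, and `jax.effects_barrier()` is called only to make sure the debug output has been flushed before continuing.

```python
import jax

# The notebook's impure_print, now compiled with jax.jit.
def impure_print(x):
    print("Printing from function:", x)
    return x * 2

jit_print = jax.jit(impure_print)
r1 = jit_print(3)  # prints once, during tracing; x is a tracer, not 3
r2 = jit_print(4)  # same shape/dtype as before: cached, nothing is printed

# jax.debug.print is staged into the compiled computation,
# so it prints the actual runtime value on every call.
@jax.jit
def debug_double(x):
    jax.debug.print("Runtime value: {}", x)
    return x * 2

r3 = debug_double(3)
r4 = debug_double(4)
jax.effects_barrier()  # wait for pending debug prints to complete

print("Results:", r1, r2, r3, r4)
```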
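The usual functional alternative to a global counter is to pass the state in as an argument and return the updated state. `pure_increment` below is a hypothetical name introduced for this sketch; because it is pure, every call under `jax.jit` computes a fresh result instead of replaying a value captured at trace time.

```python
import jax

# Hypothetical pure rewrite of impure_increment: the state is an explicit
# argument and the new state is returned, so no global variable is touched.
@jax.jit
def pure_increment(counter):
    return counter + 1

state = 0
state = pure_increment(state)  # 1
state = pure_increment(state)  # 2
print("Counter after two calls:", state)
```

The caller now owns the state, which is the pattern JAX code generally uses for things like model parameters and optimizer state.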
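Instead of leaving the failing assignment commented out, it can be run inside a `try` block to see what JAX raises. A minimal sketch; the exact error message wording varies between JAX versions, so only the exception type is relied on here, and `error_raised` is just a flag added for illustration.

```python
import jax.numpy as jnp

x = jnp.array([1, 2, 3])
error_raised = False

try:
    x[0] = 10  # item assignment is not supported: JAX arrays are immutable
except TypeError as err:
    error_raised = True
    print("TypeError:", err)

print("x is unchanged:", x)
```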
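The `.at[]` syntax supports more than `set`. The sketch below also uses `add`, `multiply`, and slice indexing, and shows the same update inside `jax.jit`. The JAX documentation notes that under `jit` the compiler can often perform such an update in place when the original array is not used afterward; the Python-level semantics stay immutable either way. `set_first` is a hypothetical helper name.

```python
import jax
import jax.numpy as jnp

x = jnp.array([1, 2, 3])

# Each .at[...] method returns a new array; x itself never changes.
added = x.at[1].add(5)            # [1, 7, 3]
multiplied = x.at[2].multiply(4)  # [1, 2, 12]
slice_set = x.at[1:].set(0)       # [1, 0, 0]

# The same syntax works inside a jit-compiled function.
@jax.jit
def set_first(arr, value):
    return arr.at[0].set(value)

updated = set_first(x, 10)        # [10, 2, 3]

print("added:     ", added)
print("multiplied:", multiplied)
print("slice set: ", slice_set)
print("jit set:   ", updated)
print("original:  ", x)
```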
