# Tutorial 0: Functional Programming and State

## Functional Programming

TensorNEAT uses functional programming because it is built on the JAX framework, which is designed around this paradigm.

Functional programming is a programming paradigm that treats computation as the evaluation of mathematical functions and avoids changing state and mutable data. Its main features include:

1. **Pure Functions**: The same input always produces the same output, with no side effects.
2. **Immutable Data**: Once data is created, it cannot be changed. All operations return new data.
3. **Higher-order Functions**: Functions can be passed as arguments to other functions or returned as values.
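
The three features above can be illustrated with a short sketch in plain Python. The names `add_bias`, `compose`, and `pipeline` are made up for this example and are not part of TensorNEAT or JAX.

```python
from functools import reduce


# 1. Pure function: the result depends only on the arguments, with no side effects.
def add_bias(values, bias):
    return tuple(v + bias for v in values)


# 2. Immutable data: a tuple cannot be changed in place, so we get a new tuple back.
original = (1, 2, 3)
shifted = add_bias(original, 10)


# 3. Higher-order function: `compose` takes functions and returns a new function
#    that applies them from left to right.
def compose(*funcs):
    return lambda x: reduce(lambda acc, f: f(acc), funcs, x)


pipeline = compose(lambda v: add_bias(v, 1), sum)
total = pipeline(original)

print(f"{original=}")  # original=(1, 2, 3), unchanged
print(f"{shifted=}")   # shifted=(11, 12, 13)
print(f"{total=}")     # total=9
```

JAX follows the same ideas: its arrays are immutable, and transformations such as `jax.jit` are higher-order functions that expect pure functions as input.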

## State

In TensorNEAT, we use `State` to manage the inputs and outputs of functions. `State` can be seen as a Python dictionary with additional methods.

Here are some common ways to use `State`.

In [34]:
# import State
from tensorneat.common import State

# create a new state
state_a = State()  # no arguments
state_b = State(a=1, b=2)  # kwargs

print(f"{state_a=}")
print(f"{state_b=}")

# get items from state, use dot notation
print(f"{state_b.a=}")
print(f"{state_b.b=}")


state_a=State ({})
state_b=State ({'a': 1, 'b': 2})
state_b.a=1
state_b.b=2


In [35]:
# add new items to the state, use register
state_a = state_a.register(a=1, b=2)
print(f"{state_a=}")

# We CANNOT register an item that already exists
# state_a = state_a.register(a=1)
# will raise ValueError(f"Key {key} already exists in state")

state_a=State ({'a': 1, 'b': 2})


In [36]:
# update the value of an item, use update
state_a = state_a.update(a=3, b=4)
print(f"{state_a=}")

# We CANNOT update an item that does not exist
# state_a = state_a.update(c=3)
# will raise ValueError(f"Key {key} does not exist in state")

state_a=State ({'a': 3, 'b': 4})


In [37]:
# State is immutable! We always create a new state, rather than modifying the existing one.

origin_state = State(a=1, b=2)
new_state = origin_state.update(a=3)
print(f"{origin_state=}")  # origin_state is not changed
print(f"{new_state=}")

# We cannot modify the state directly
# origin_state.a = 3
# will raise AttributeError("State is immutable")

origin_state=State ({'a': 1, 'b': 2})
new_state=State ({'a': 3, 'b': 2})


In [38]:
# State can be used in JAX functions
import jax


@jax.jit
def func(state):
    c = state.a + state.b  # fetch items from state
    state = state.update(a=c, b=c)  # update items in state
    state = state.register(c=c)  # add new item to state
    return state  # return state


new_state = func(state_a)
print(f"{new_state=}")

new_state=State ({'a': Array(7, dtype=int32, weak_type=True), 'b': Array(7, dtype=int32, weak_type=True), 'c': Array(7, dtype=int32, weak_type=True)})
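
For a custom container to pass through `jax.jit`, JAX has to know how to flatten it into arrays and rebuild it afterwards. One standard mechanism for this is registering the class as a pytree node. The sketch below uses a hypothetical `Pair` class to show the idea. That `State` is implemented in exactly this way is an assumption, not something this tutorial confirms.

```python
import jax
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class Pair:
    """Hypothetical two-field container, used only for illustration."""

    def __init__(self, a, b):
        self.a = a
        self.b = b

    def tree_flatten(self):
        # children: the values JAX should trace; aux_data: static metadata (none here)
        return (self.a, self.b), None

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # rebuild the container from the traced values
        return cls(*children)


@jax.jit
def add_fields(pair):
    c = pair.a + pair.b
    return Pair(c, c)  # return a new container instead of mutating the input


out = add_fields(Pair(3, 4))
print(int(out.a), int(out.b))  # 7 7
```

As in the `State` example, the values that come back from the jitted function are JAX arrays rather than Python integers.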


In [39]:
# Save the state to a file (using pickle) and load it back.
state = State(a=1, b=2, c=3, d=4)
state.save("tutorial_0_state.pkl")
loaded_state = State.load("tutorial_0_state.pkl")
print(f"{loaded_state=}")

loaded_state=State ({'a': 1, 'b': 2, 'c': 3, 'd': 4})
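
To make the behavior shown above more concrete, here is a minimal sketch of an immutable, dictionary-like container with `register` and `update` methods. `MiniState` is a hypothetical stand-in written for this tutorial. It is not TensorNEAT's actual `State` implementation, and it leaves out JAX support and saving/loading.

```python
class MiniState:
    """Hypothetical, simplified stand-in for State (illustration only)."""

    def __init__(self, **kwargs):
        # Bypass our own __setattr__ to store the backing dictionary once.
        object.__setattr__(self, "_data", dict(kwargs))

    def register(self, **kwargs):
        # Adding is only allowed for keys that do not exist yet.
        for key in kwargs:
            if key in self._data:
                raise ValueError(f"Key {key} already exists in state")
        return MiniState(**{**self._data, **kwargs})

    def update(self, **kwargs):
        # Updating is only allowed for keys that already exist.
        for key in kwargs:
            if key not in self._data:
                raise ValueError(f"Key {key} does not exist in state")
        return MiniState(**{**self._data, **kwargs})

    def __getattr__(self, name):
        # Called only when normal attribute lookup fails: look in the dictionary.
        data = object.__getattribute__(self, "_data")
        if name in data:
            return data[name]
        raise AttributeError(name)

    def __setattr__(self, name, value):
        raise AttributeError("State is immutable")

    def __repr__(self):
        return f"MiniState({self._data})"


s = MiniState(a=1, b=2)
t = s.update(a=3)      # returns a new object
u = t.register(c=5)    # returns another new object
print(s, t, u)
```

Both `register` and `update` build a fresh object from a copy of the old data, which is why the original state in the examples above never changes.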


## Objects in TensorNEAT

In the object-oriented programming (OOP) paradigm, both data and functions are stored in objects. 

In the functional programming style used by TensorNEAT, data is stored as JAX arrays (tensors), while functions are stored in objects.

For example, when we create a `genome` object, we are not creating a genome instance in the NEAT algorithm. We are actually defining a set of functions!

In [40]:
from tensorneat.genome import DefaultGenome

genome = DefaultGenome(
    num_inputs=3,
    num_outputs=1,
    max_nodes=5,
    max_conns=5,
)

`genome` only stores functions that define the operation of the genome in the NEAT algorithm. 
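
The same pattern can be sketched with a much smaller, hypothetical class. `TinyLinear` below is not part of TensorNEAT. It only illustrates an object that keeps configuration and functions, while the arrays are created by those functions and passed around explicitly.

```python
import jax
import jax.numpy as jnp


class TinyLinear:
    """Hypothetical example class: holds configuration and functions, no arrays."""

    def __init__(self, num_inputs, num_outputs):
        self.num_inputs = num_inputs
        self.num_outputs = num_outputs

    def initialize(self, randkey):
        # The data is created here and returned to the caller, not stored on self.
        return jax.random.normal(randkey, (self.num_inputs, self.num_outputs))

    def forward(self, weights, inputs):
        # The data comes in as an argument; the object only supplies the computation.
        return inputs @ weights


model = TinyLinear(num_inputs=3, num_outputs=1)
weights = model.initialize(jax.random.key(0))
outputs = model.forward(weights, jnp.array([1.0, 2.0, 3.0]))
print(f"{weights.shape=}, {outputs.shape=}")
```

Because `model` never holds the weights, the same object can be reused for any number of weight arrays, which is what makes it convenient to batch such functions over a whole population.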

To create a genome that can take part in computation, we need to do the following.

In [41]:
# set up the genome; this lets the genome class store some useful information in State
state = genome.setup()

# create a new genome
randkey = jax.random.key(0)
nodes, conns = genome.initialize(state, randkey)
print(f"{nodes=}")
print(f"{conns=}")
      

nodes=Array([[ 0.        ,  0.5097862 ,  1.        ,  0.        ,  0.        ],
       [ 1.        ,  0.9807121 ,  1.        ,  0.        ,  0.        ],
       [ 2.        , -0.8425486 ,  1.        ,  0.        ,  0.        ],
       [ 3.        , -0.53765106,  1.        ,  0.        ,  0.        ],
       [        nan,         nan,         nan,         nan,         nan]],      dtype=float32, weak_type=True)
conns=Array([[0.        , 3.        , 0.785558  ],
       [1.        , 3.        , 2.3734226 ],
       [2.        , 3.        , 0.07902155],
       [       nan,        nan,        nan],
       [       nan,        nan,        nan]],      dtype=float32, weak_type=True)


In [42]:
# calculate
inputs = jax.numpy.array([1, 2, 3])

transformed = genome.transform(state, nodes, conns)
outputs = genome.forward(state, transformed, inputs)

print(f"{outputs=}")

outputs=Array([5.231817], dtype=float32, weak_type=True)


## Conclusion
1. TensorNEAT uses the functional programming paradigm.
2. TensorNEAT provides `State` to manage data.
3. In TensorNEAT, objects are responsible for controlling functions, rather than storing data.