# ☝️ Out-of-place indexing

PyTreeClass implements out-of-place indexing, similar to the `.at` property of `jax.numpy` arrays: every operation returns a new instance of the pytree.

In [None]:
!pip install pytreeclass

In [2]:
import jax 
import jax.numpy as jnp
from pytreeclass import treeclass , tree_viz

### `.at[].get()`

Similar to `jax.numpy.ndarray.at[].get()`, the return value is a new instance of the pytree.
For array leaves, the operation is equivalent to `array[condition]`. For non-array leaves, the leaf is kept if it satisfies the condition and is replaced by `None` otherwise.

**`.at[].get()`** on array leaves



In [3]:
@treeclass
class array_leaves:
    """Tree whose two leaves are both integer arrays."""

    a: jnp.ndarray = jnp.array([1, 2, 3, 4, 5])
    b: jnp.ndarray = jnp.array([10, 20, 30, 40, 50])


# Build the tree from its default field values and display it.
x = array_leaves()
print(x)



array_leaves(
  a=
    [1 2 3 4 5],
  b=
    [10 20 30 40 50])


In [4]:
# Keep only the array entries that are strictly less than 20.
print(x.at[x < 20].get())

array_leaves(
  a=
    [1 2 3 4 5],
  b=
    [10])


In [5]:
# Keep only the array entries that are strictly greater than 20
# (`a` has none, so it comes back empty).
print(x.at[x > 20].get())

array_leaves(
  a=
    [],
  b=
    [30 40 50])


**`.at[].get()`** on non-array leaves


In [6]:
@treeclass
class nonarray_leaves:
    """Tree whose leaves are plain Python scalars (no arrays)."""

    a: int = 10
    b: float = 1.0


# Build the tree from its default field values and display it.
x = nonarray_leaves()
print(x)

nonarray_leaves(
  a=
    10,
  b=
    1.0)


In [7]:
# Keep the leaves that are strictly less than 10; the others become None
# (`a` is exactly 10, so it is replaced by None, while `b` is kept).
print(x.at[x < 10].get())

nonarray_leaves(
  a=
    None,
  b=
    1.0)


**`.at[].get()`** on array and non-array leaves

In [8]:
@treeclass
class general_leaves:
    """Tree mixing scalar leaves (`a`, `b`) with an array leaf (`c`)."""

    a: int = 10
    b: float = 1.0
    c: jnp.ndarray = jnp.array([1, 2, 3, 4, 5])


# Build the tree from its default field values and display it.
x = general_leaves()
print(x)

general_leaves(
  a=
    10,
  b=
    1.0,
  c=
    [1 2 3 4 5])


In [9]:
# `a` is replaced by None because it is not less than 3; `b` (1.0) is kept.
# The array leaf `c` gets the equivalent of c[c < 3].
# Uses the `.at[].get()` form that this section documents, like the other cells.
print(x.at[x < 3].get())

general_leaves(
  a=
    None,
  b=
    1.0,
  c=
    [1 2])


### `.at[].set()`

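Similar to `jax.numpy.ndarray.at[].set()`, the return value is a new instance of the pytree; the original is left unchanged.
For array leaves, the operation is equivalent to `array.at[condition].set(value)`: entries that satisfy the condition are replaced by the set value. For non-array leaves, the leaf is replaced by the set value if it satisfies the condition and is left untouched otherwise.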

**`.at[].set()`** on array leaves



In [10]:
@treeclass
class array_leaves:
    """Tree with two integer-array leaves (same fields as in the `.at[].get()` section)."""

    a: jnp.ndarray = jnp.array([1, 2, 3, 4, 5])
    b: jnp.ndarray = jnp.array([10, 20, 30, 40, 50])


# Build the tree from its default field values and display it.
x = array_leaves()
print(x)

array_leaves(
  a=
    [1 2 3 4 5],
  b=
    [10 20 30 40 50])


In [11]:
# Replace every array entry that is strictly less than 20 with 0.
print(x.at[x < 20].set(0))

array_leaves(
  a=
    [0 0 0 0 0],
  b=
    [ 0 20 30 40 50])


In [12]:
# Replace every array entry that is strictly greater than 20 with 0.
print(x.at[x > 20].set(0))

array_leaves(
  a=
    [1 2 3 4 5],
  b=
    [10 20  0  0  0])


**`.at[].set()`** on non-array leaves


In [13]:
@treeclass
class nonarray_leaves:
    """Tree with plain scalar leaves (same fields as in the `.at[].get()` section)."""

    a: int = 10
    b: float = 1.0


# Build the tree from its default field values and display it.
x = nonarray_leaves()
print(x)

nonarray_leaves(
  a=
    10,
  b=
    1.0)


In [14]:
# Replace every leaf that is strictly less than 10 with 0
# (`a` is exactly 10, so it is left as is; `b` becomes 0).
print(x.at[x < 10].set(0))

nonarray_leaves(
  a=
    10,
  b=
    0)


**`.at[].set()`** on array and non-array leaves

In [15]:
@treeclass
class general_leaves:
    """Tree mixing scalar and array leaves (same fields as in the `.at[].get()` section)."""

    a: int = 10
    b: float = 1.0
    c: jnp.ndarray = jnp.array([1, 2, 3, 4, 5])


# Build the tree from its default field values and display it.
x = general_leaves()
print(x)

general_leaves(
  a=
    10,
  b=
    1.0,
  c=
    [1 2 3 4 5])


In [16]:
# `a` is untouched because it is not less than 3; `b` (1.0) is replaced by 0.
# The array leaf `c` gets the equivalent of c.at[c < 3].set(0).
print(x.at[x < 3].set(0))

general_leaves(
  a=
    10,
  b=
    0,
  c=
    [0 0 3 4 5])
