# 👆 Filtering PyTrees

`PyTreeClass` offers four ways to filter through the `.at[]` property:

- Filter by value
- Filter by field name
- Filter by field type
- Filter by field metadata

To enable filtering and masking, `PyTreeClass` uses `tree_map` to broadcast standard Python operators over every leaf of a `PyTree`.
For example:

```python
import jax.numpy as jnp
import pytreeclass as pytc

@pytc.treeclass
class Test:
    a: int = 1
    b: jnp.ndarray = jnp.array([1, 2, 3])

t = Test()

print(t+1)  # Test(a=2,b=[2 3 4])
print(t+t)  # Test(a=2,b=[2 4 6])
print(t>1)  # Test(a=False,b=[False  True  True])
```

```
Test(a=2,b=[2 3 4])
Test(a=2,b=[2 4 6])
Test(a=False,b=[False  True  True])
```
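The same three results can be reproduced with plain `jax.tree_util.tree_map`, which helps to see what "broadcasting an operator over a PyTree" means. This is a minimal sketch that uses a `dict` as a hypothetical stand-in for `Test`; no `PyTreeClass` code is involved.

```python
import jax
import jax.numpy as jnp

# A plain dict is already a pytree; its leaves are the int 1 and the array.
tree = {"a": 1, "b": jnp.array([1, 2, 3])}

# "tree + 1": apply the operation to every leaf independently.
plus_one = jax.tree_util.tree_map(lambda x: x + 1, tree)

# "tree + tree": pair up the leaves of two trees with the same structure.
doubled = jax.tree_util.tree_map(lambda x, y: x + y, tree, tree)

# "tree > 1": a comparison yields a tree of booleans, i.e. a mask.
mask = jax.tree_util.tree_map(lambda x: x > 1, tree)

print(plus_one)
print(doubled)
print(mask)
```

The boolean tree produced by the comparison is what the `.at[]` filters below consume.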


The full list of implemented operations:
```python
__abs__ 
__add__ 
__radd__ 
__and__ 
__rand__ 
__eq__ 
__floordiv__ 
__ge__
__gt__
__inv__ 
__invert__ 
__le__
__lshift__ 
__lt__
__matmul__ 
__mod__ 
__mul__ 
__rmul__ 
__ne__ 
__neg__ 
__not__
__or__ 
__pos__ 
__pow__ 
__rshift__ 
__sub__ 
__rsub__ 
__truediv__ 
__xor__ 
```
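One way to attach a list of operators like this to a class is to generate each dunder method from Python's `operator` module and route it through `tree_map`. The sketch below assumes a hypothetical `Tree` dataclass registered with `jax.tree_util.register_pytree_node`; it is not `PyTreeClass`'s actual implementation and wires up only a handful of operators.

```python
import dataclasses
import operator

import jax
import jax.numpy as jnp


def _unary(op):
    # Apply a one-argument operator to every leaf of the tree.
    def method(self):
        return jax.tree_util.tree_map(op, self)
    return method


def _binary(op):
    # Pair leaves when the other operand is a tree of the same type,
    # otherwise broadcast the scalar/array to every leaf.
    def method(self, other):
        if isinstance(other, type(self)):
            return jax.tree_util.tree_map(op, self, other)
        return jax.tree_util.tree_map(lambda x: op(x, other), self)
    return method


@dataclasses.dataclass(eq=False)  # eq=False: leave __eq__ free for leafwise use
class Tree:
    # Hypothetical class mirroring the `Test` example.
    a: int = 1
    b: jnp.ndarray = dataclasses.field(default_factory=lambda: jnp.array([1, 2, 3]))


# Tell JAX how to split a Tree into leaves and rebuild it.
jax.tree_util.register_pytree_node(
    Tree,
    lambda t: ((t.a, t.b), None),
    lambda _, children: Tree(*children),
)

# Generate a few binary and unary dunders from the operator module.
for name in ("add", "sub", "mul", "gt", "lt"):
    setattr(Tree, f"__{name}__", _binary(getattr(operator, name)))
Tree.__neg__ = _unary(operator.neg)

t = Tree()
print(t + 1)
print(t > 1)
```

Generating methods in a loop keeps the long list above in one place instead of writing each dunder by hand.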

Next, let's walk through an example that demonstrates filtering. Suppose you have the following multilayer perceptron (MLP) class.

We use the `metadata` argument of `dataclasses.field` to give some meaning to each field. Here we simply describe `l1` as the initial layer and mark the position of `l2` as the head. The contents of the metadata `dict` are entirely up to you.

```python
import jax
from jax import numpy as jnp
import pytreeclass as pytc
from dataclasses import field

@pytc.treeclass
class Linear:
    weight: jnp.ndarray
    bias: jnp.ndarray

    def __init__(self, key, in_dim, out_dim):
        # Normally distributed weights with He-style scaling sqrt(2 / in_dim)
        self.weight = jax.random.normal(key, shape=(in_dim, out_dim)) * jnp.sqrt(2 / in_dim)
        self.bias = jnp.ones((1, out_dim))

    def __call__(self, x):
        return x @ self.weight + self.bias

@pytc.treeclass
class StackedLinear:
    # Field metadata is free-form; it is used later for filtering
    l1: Linear = field(metadata={"description": "initial"})
    l2: Linear = field(metadata={"position": "head"})

    def __init__(self, key, in_dim, out_dim, hidden_dim):
        keys = jax.random.split(key, 3)

        self.l1 = Linear(key=keys[0], in_dim=in_dim, out_dim=hidden_dim)
        self.l2 = Linear(key=keys[2], in_dim=hidden_dim, out_dim=out_dim)

    def __call__(self, x):
        x = self.l1(x)
        x = jax.nn.tanh(x)
        x = self.l2(x)
        return x

model = StackedLinear(in_dim=1, out_dim=1, hidden_dim=5, key=jax.random.PRNGKey(0))
```

Before doing any filtering, let's look at the model's raw values using its pretty-printed `__str__`.

```python
print(model)
```

```
StackedLinear(
  l1=Linear(
    weight=[[-1.6248673  -2.8383057   1.3969219   1.3169124  -0.40784812]],
    bias=[[1. 1. 1. 1. 1.]]
  ),
  l2=Linear(
    weight=
      [[ 0.98507565]
       [ 0.99815285]
       [-1.0687716 ]
       [-0.19255024]
       [-1.2108876 ]],
    bias=[[1.]]
  )
)
```


## Boolean filtering

### Filter by Value
Now suppose we want to select all negative values. In `numpy` we would write `array[array<0]`. `PyTreeClass` builds on this familiar idea and does the same with the functional `.at` property, which returns a new filtered copy of the model instead of modifying it in place.

* Get all negative values

```python
print(model.at[model<0].get())
```

```
StackedLinear(
  l1=Linear(
    weight=[-1.6248673  -2.8383057  -0.40784812],
    bias=[]
  ),
  l2=Linear(
    weight=[-1.0687716  -0.19255024 -1.2108876 ],
    bias=[]
  )
)
```
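With plain JAX, a `get` like this can be sketched as two `tree_map` calls: one builds the mask tree, the other applies NumPy-style boolean indexing leaf by leaf. The nested `dict` below is a hypothetical stand-in for `model` with small, made-up values; boolean indexing of this kind only works eagerly (outside `jax.jit`), because the result size depends on the data.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters standing in for `model`, as a nested dict pytree.
params = {
    "l1": {"weight": jnp.array([[-1.5, 2.0, -0.5]]), "bias": jnp.ones((1, 3))},
    "l2": {"weight": jnp.array([[0.5], [-2.0], [1.0]]), "bias": jnp.ones((1, 1))},
}

# Step 1: a comparison broadcast over the tree gives a matching tree of masks.
mask = jax.tree_util.tree_map(lambda x: x < 0, params)

# Step 2: index each leaf with its mask, as in array[array < 0].
# Results are flattened to 1-D; leaves with no match become empty arrays.
negatives = jax.tree_util.tree_map(lambda x, m: x[m], params, mask)
print(negatives)
```

This mirrors the flattened weights and empty biases in the output above.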


* Set negative values to 0

```python
print(model.at[model<0].set(0))
```

```
StackedLinear(
  l1=Linear(
    weight=[[0.        0.        1.3969219 1.3169124 0.       ]],
    bias=[[1. 1. 1. 1. 1.]]
  ),
  l2=Linear(
    weight=
      [[0.98507565]
       [0.99815285]
       [0.        ]
       [0.        ]
       [0.        ]],
    bias=[[1.]]
  )
)
```
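Unlike `get`, `set` keeps every array at its original shape, as the output above shows. With plain JAX this can be sketched with `jnp.where`, which picks between the new value and the old one entry by entry. The `params` dict is again a hypothetical stand-in for `model`.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters standing in for `model`.
params = {
    "l1": {"weight": jnp.array([[-1.5, 2.0, -0.5]]), "bias": jnp.ones((1, 3))},
    "l2": {"weight": jnp.array([[0.5], [-2.0], [1.0]]), "bias": jnp.ones((1, 1))},
}

# Mask tree with the same structure as params.
mask = jax.tree_util.tree_map(lambda x: x < 0, params)

# Masked entries become 0, the rest are copied; shapes are preserved.
# A new tree is returned and params itself is left unchanged.
zeroed = jax.tree_util.tree_map(lambda x, m: jnp.where(m, 0.0, x), params, mask)
print(zeroed)
```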


* Apply f(x)=x^2 to negative values

```python
print(model.at[model<0].apply(lambda x:x**2))
```

```
StackedLinear(
  l1=Linear(
    weight=[[2.6401937  8.05598    1.3969219  1.3169124  0.16634008]],
    bias=[[1. 1. 1. 1. 1.]]
  ),
  l2=Linear(
    weight=
      [[0.98507565]
       [0.99815285]
       [1.1422727 ]
       [0.03707559]
       [1.4662486 ]],
    bias=[[1.]]
  )
)
```
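A masked `apply` generalizes `set` from a constant to a function. The sketch below uses a hypothetical `masked_apply` helper built on `jnp.where`; note that `fn` is evaluated on the whole array and only the masked entries are kept, so it must be safe to compute on every entry.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters standing in for `model`.
params = {
    "l1": {"weight": jnp.array([[-1.5, 2.0, -0.5]]), "bias": jnp.ones((1, 3))},
    "l2": {"weight": jnp.array([[0.5], [-2.0], [1.0]]), "bias": jnp.ones((1, 1))},
}


def masked_apply(tree, predicate, fn):
    # Hypothetical helper: use fn(x) where predicate(x) holds, else keep x.
    return jax.tree_util.tree_map(
        lambda x: jnp.where(predicate(x), fn(x), x), tree
    )


squared = masked_apply(params, lambda x: x < 0, lambda x: x**2)
print(squared)
```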


* Sum all negative values

Here we use the `reduce` method, whose function takes two arguments: the accumulated value and the current value.

```python
print(model.at[model<0].reduce(lambda acc,cur: acc+jnp.sum(cur)))
```

```
-7.3432307
```
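A reduction of this kind can be sketched with `functools.reduce` over the leaves of the filtered tree. Starting the accumulator at `0.0` is an assumption of this sketch; empty leaves contribute `jnp.sum([]) == 0`, so they do not affect the total. The `params` dict is a hypothetical stand-in for `model`.

```python
import functools

import jax
import jax.numpy as jnp

# Hypothetical parameters standing in for `model`.
params = {
    "l1": {"weight": jnp.array([[-1.5, 2.0, -0.5]]), "bias": jnp.ones((1, 3))},
    "l2": {"weight": jnp.array([[0.5], [-2.0], [1.0]]), "bias": jnp.ones((1, 1))},
}

# Keep only the negative entries of each leaf (flattened, possibly empty).
negatives = jax.tree_util.tree_map(lambda x: x[x < 0], params)

# Fold (acc, cur) over the leaves in flattening order, starting from 0.0.
total = functools.reduce(
    lambda acc, cur: acc + jnp.sum(cur),
    jax.tree_util.tree_leaves(negatives),
    0.0,
)
print(total)
```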


### Filter by field name

In models where several nested layers share the same name, it can be helpful to select layers by name. Filtering by name becomes even more powerful when combined with other masks.

Let's first see how we can use `.at[].{get,set,apply,reduce}` with string filtering.

* Get all fields named l1

```python
print(model.at[model == "l1"].get())
```

```
StackedLinear(
  l1=Linear(
    weight=[-1.6248673  -2.8383057   1.3969219   1.3169124  -0.40784812],
    bias=[1. 1. 1. 1. 1.]
  ),
  l2=Linear(weight=[],bias=[])
)
```
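Conceptually, a name filter is just another boolean mask: all `True` for every leaf under `l1` and all `False` elsewhere. Indexing with an all-`False` mask yields an empty array, which is where the empty `l2` leaves come from. This sketch assumes a nested `dict` in place of the model and a hypothetical `name_mask` helper; it is not how `PyTreeClass` implements `model == "l1"`.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters standing in for `model`.
params = {
    "l1": {"weight": jnp.array([[-1.5, 2.0, -0.5]]), "bias": jnp.ones((1, 3))},
    "l2": {"weight": jnp.array([[0.5], [-2.0], [1.0]]), "bias": jnp.ones((1, 1))},
}


def name_mask(tree, target):
    # Hypothetical helper: for each top-level entry, build boolean arrays
    # shaped like its leaves, all True if the name matches, else all False.
    masks = {}
    for name, sub in tree.items():
        hit = name == target
        masks[name] = jax.tree_util.tree_map(
            lambda x: jnp.full(x.shape, hit, dtype=bool), sub
        )
    return masks


selected = jax.tree_util.tree_map(lambda x, m: x[m], params, name_mask(params, "l1"))
print(selected)
```

Because name masks and value masks are both boolean trees of the same shape, they could be combined with `&` leaf by leaf before indexing.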


Note that selected values are returned flattened, and if no element of an array satisfies the condition, an empty array is returned for that leaf.
Similarly, we can filter by field type and by field metadata, as follows.

### Filter by field type

* Get all fields of Linear type

```python
print(model.at[model == Linear].get())
```

```
StackedLinear(
  l1=Linear(
    weight=[-1.6248673  -2.8383057   1.3969219   1.3169124  -0.40784812],
    bias=[1. 1. 1. 1. 1.]
  ),
  l2=Linear(
    weight=[ 0.98507565  0.99815285 -1.0687716  -0.19255024 -1.2108876 ],
    bias=[1.]
  )
)
```
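Since both fields of `StackedLinear` are `Linear`, every value is selected above. To show a case where the type filter actually excludes something, the sketch below adds a plain `scale` leaf next to two layers. It assumes a hypothetical `Linear` `NamedTuple` (NamedTuples are pytrees in JAX without registration) and a hypothetical `type_mask` helper that uses the `is_leaf` argument of `tree_map` to visit each `Linear` node whole.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class Linear(NamedTuple):
    # Hypothetical stand-in for the document's Linear layer.
    weight: jnp.ndarray
    bias: jnp.ndarray


params = {
    "l1": Linear(jnp.array([[-1.5, 2.0, -0.5]]), jnp.ones((1, 3))),
    "l2": Linear(jnp.array([[0.5], [-2.0], [1.0]]), jnp.ones((1, 1))),
    "scale": jnp.array([2.0]),  # a plain leaf that is not a Linear
}


def type_mask(tree, cls):
    def node_mask(node):
        # All True inside a cls node, all False for any other leaf.
        hit = isinstance(node, cls)
        return jax.tree_util.tree_map(
            lambda x: jnp.full(x.shape, hit, dtype=bool), node
        )

    # is_leaf stops flattening at cls nodes so node_mask receives them whole.
    return jax.tree_util.tree_map(
        node_mask, tree, is_leaf=lambda n: isinstance(n, cls)
    )


selected = jax.tree_util.tree_map(lambda x, m: x[m], params, type_mask(params, Linear))
print(selected)
```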


### Filter by field metadata

* Get all fields with `{"description": "First layer"}` in their metadata

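```python
print(model.at[model == {"description": "First layer"}].get())
```

```
StackedLinear(l1=Linear(weight=[],bias=[]),l2=Linear(weight=[],bias=[]))
```

Every leaf comes back empty because no field of `StackedLinear` carries `{"description": "First layer"}` in its metadata: `l1` is tagged `{"description": "initial"}` and `l2` is tagged `{"position": "head"}`.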
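Field metadata is stored on the dataclass fields themselves, so a metadata query can be sketched with `dataclasses.fields`. The matching rule here (every key/value pair of the query must appear in a field's metadata) is an assumption of this sketch, and `Spec` is a hypothetical class that copies the metadata of `StackedLinear`.

```python
import dataclasses


@dataclasses.dataclass
class Spec:
    # Hypothetical class copying the field metadata of StackedLinear.
    l1: object = dataclasses.field(default=None, metadata={"description": "initial"})
    l2: object = dataclasses.field(default=None, metadata={"position": "head"})


def fields_matching(cls, query):
    # Assumed rule: a field matches when every key/value pair in `query`
    # is present in that field's metadata.
    return [
        f.name
        for f in dataclasses.fields(cls)
        if all(k in f.metadata and f.metadata[k] == v for k, v in query.items())
    ]


print(fields_matching(Spec, {"description": "First layer"}))  # []
print(fields_matching(Spec, {"description": "initial"}))  # ['l1']
```

Under this rule, querying with `{"description": "initial"}` instead would select `l1`.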
