# ✂️ Surgery

This notebook provides tree editing (surgery) recipes using `at`. `at` encapsulates a pytree and provides a way to edit it out-of-place. Under the hood, `at` uses `jax.tree_util` or `optree` to traverse the tree and apply the provided function to the selected nodes.

```python
import sepes as sp
import re
tree = dict(key_1=[1, -2, 3], key_2=[4, 5, 6], other=[7, 8, 9])
where = re.compile("key_.*")  # select all keys starting with "key_"
func = lambda node: sum(map(abs, node))  # sum of absolute values
sp.at(tree)[where].apply(func)
# {'key_1': 6, 'key_2': 15, 'other': [7, 8, 9]}
```
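The traverse-select-apply idea can be made concrete with a minimal pure-Python sketch. It is not how `sepes` is implemented. `map_with_path` and `apply_where` are hypothetical helpers, only dicts and lists count as containers, and selection happens per leaf. The `sp.at` call above instead applies `func` to each selected subtree as a whole.

```python
import re
from typing import Any, Callable


def map_with_path(func: Callable[[tuple, Any], Any], tree: Any, path: tuple = ()) -> Any:
    """Rebuild `tree`, replacing each leaf with func(path, leaf).

    Hypothetical helper in the spirit of a path-aware tree_map; only
    dicts and lists are treated as containers in this sketch.
    """
    if isinstance(tree, dict):
        return {k: map_with_path(func, v, path + (k,)) for k, v in tree.items()}
    if isinstance(tree, list):
        return [map_with_path(func, v, path + (i,)) for i, v in enumerate(tree)]
    return func(path, tree)


def apply_where(tree: Any, selected: Callable[[tuple], bool], func: Callable[[Any], Any]) -> Any:
    # apply `func` only to leaves whose path is selected; other leaves pass through
    return map_with_path(lambda path, leaf: func(leaf) if selected(path) else leaf, tree)


tree = dict(key_1=[1, -2, 3], key_2=[4, 5, 6], other=[7, -8, 9])
pattern = re.compile("key_.*")
# select the leaves whose top-level key matches the pattern
result = apply_where(tree, lambda path: bool(pattern.fullmatch(path[0])), abs)
print(result)  # {'key_1': [1, 2, 3], 'key_2': [4, 5, 6], 'other': [7, -8, 9]}
```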


```
!pip install sepes
```

## Out-of-place editing

Out-of-place means that the original tree is not modified. Instead, a new tree is created with the changes. This is the default behavior of `at`. The following example demonstrates this behavior:

```python
import sepes as sp

pytree1 = [1, [2, 3], 4]
pytree2 = sp.at(pytree1)[...].get()  # get the whole pytree using ...
print(f"{pytree1=}, {pytree2=}")
# pytree1 and pytree2 are equal, but they are not the same object
# because pytree2 is a copy of pytree1
print(f"pytree1 is pytree2 = {pytree1 is pytree2}")
# Output:
# pytree1=[1, [2, 3], 4], pytree2=[1, [2, 3], 4]
# pytree1 is pytree2 = False
```
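One way to get out-of-place edits without deep-copying everything is path copying: rebuild only the containers on the path to the edited node and share everything else with the input. The sketch below assumes this strategy purely for illustration. The hypothetical `set_at` helper handles only dicts and lists, and the notebook does not say whether `sepes` shares untouched subtrees this way.

```python
from typing import Any


def set_at(tree: Any, path: tuple, value: Any) -> Any:
    """Return a copy of `tree` with the node at `path` replaced by `value`.

    Hypothetical helper using path copying: only the containers on the
    path are rebuilt; untouched subtrees are shared with the input.
    """
    if not path:
        return value
    key, rest = path[0], path[1:]
    if isinstance(tree, dict):
        new = dict(tree)  # shallow copy: children are shared
    elif isinstance(tree, list):
        new = list(tree)
    else:
        raise TypeError(f"cannot index {type(tree).__name__} with {key!r}")
    new[key] = set_at(tree[key], rest, value)
    return new


old = [[1], [2, 3], [4]]
new = set_at(old, (1, 0), 100)
print(old)  # [[1], [2, 3], [4]], the original is untouched
print(new)  # [[1], [100, 3], [4]]
print(new is old, new[1] is old[1], new[0] is old[0])  # False False True
```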


## Matching keys

### Match all

Use `...` to match all keys.

The following example applies the `plus_one` function to every leaf of the tree. This is equivalent to `pytree2 = jax.tree_util.tree_map(plus_one, pytree1)`.

```python
import sepes as sp

pytree1 = [1, [2, 3], 4]
plus_one = lambda x: x + 1
pytree2 = sp.at(pytree1)[...].apply(plus_one)
pytree2
# [2, [3, 4], 5]
```
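For intuition about what `[...]` selects, here is a hypothetical pure-Python `tree_map` restricted to lists, tuples, and dicts. It is a sketch, not `jax.tree_util.tree_map`; for instance, it treats `None` as a leaf, whereas JAX treats `None` as an empty subtree.

```python
from typing import Any, Callable


def tree_map(func: Callable[[Any], Any], tree: Any) -> Any:
    """Apply `func` to every leaf, rebuilding lists, tuples, and dicts.

    Hypothetical stand-in for jax.tree_util.tree_map, limited to
    built-in containers; anything else (including None) is a leaf here.
    """
    if isinstance(tree, dict):
        return {k: tree_map(func, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map(func, v) for v in tree)
    return func(tree)


plus_one = lambda x: x + 1
print(tree_map(plus_one, [1, [2, 3], 4]))  # [2, [3, 4], 5]
```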

### Integer indexing

`at` can edit pytrees by integer paths.

```python
import sepes as sp

pytree1 = [1, [2, 3], 4]
# like pytree1[1][0] = 100, but returns a new pytree instead of mutating pytree1
pytree2 = sp.at(pytree1)[1][0].set(100)
pytree2
# [1, [100, 3], 4]
```

### Named path indexing

`at` can edit pytrees by named paths, i.e. string keys such as dictionary keys.

```python
import sepes as sp

pytree1 = {"a": -1, "b": {"c": 2, "d": 3}, "e": -4, "f": {"g": 7, "h": 8}}
# like pytree1["b"] = 100, but out-of-place: pytree1 is left unchanged
pytree2 = sp.at(pytree1)["b"].set(100)
pytree2
# {'a': -1, 'b': 100, 'e': -4, 'f': {'g': 7, 'h': 8}}
```

### Regular expression indexing

`at` can edit pytrees by matching regular expressions against their keys.

```python
import sepes as sp
import re

pytree1 = {
    "key_1": 1,
    "key_2": {"key_3": 3, "key_4": 4},
    "key_5": 5,
    "key_6": {"key_7": 7, "key_8": 8},
}
# set entries whose key matches key_1 through key_5 to 100;
# key_2 matches, so its whole subtree is replaced
pattern = re.compile(r"key_[1-5]")
pytree2 = sp.at(pytree1)[pattern].set(100)
pytree2
# {'key_1': 100, 'key_2': 100, 'key_5': 100, 'key_6': {'key_7': 7, 'key_8': 8}}
```
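The selection rule in this example can be approximated by a short pure-Python sketch. The helper `set_matching_keys` is hypothetical and makes assumptions the notebook does not state: it only inspects the keys of a single dict level and uses `re.Pattern.fullmatch`, while `sepes` may match keys differently.

```python
import re
from typing import Any


def set_matching_keys(tree: dict, pattern: re.Pattern, value: Any) -> dict:
    """Hypothetical helper: return a new dict in which every entry whose
    (string) key fully matches `pattern` is replaced by `value`.

    A matching entry is replaced wholesale, subtree and all; keys nested
    inside non-matching entries are not searched in this sketch.
    """
    return {
        key: value if isinstance(key, str) and pattern.fullmatch(key) else node
        for key, node in tree.items()
    }


pytree1 = {
    "key_1": 1,
    "key_2": {"key_3": 3, "key_4": 4},
    "key_5": 5,
    "key_6": {"key_7": 7, "key_8": 8},
}
pytree2 = set_matching_keys(pytree1, re.compile(r"key_[1-5]"), 100)
print(pytree2)
# {'key_1': 100, 'key_2': 100, 'key_5': 100, 'key_6': {'key_7': 7, 'key_8': 8}}
```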

### Mask indexing

`at` can edit pytree entries by a boolean mask: given a mask with the same structure as the pytree, nodes marked `True` are edited and all other nodes are left untouched.

The following example sets all negative leaves to zero.

```python
import sepes as sp
import jax

pytree1 = {"a": -1, "b": {"c": 2, "d": 3}, "e": -4}
# the mask has the same structure as pytree1; True marks the leaves to edit
mask = jax.tree_util.tree_map(lambda x: x < 0, pytree1)
pytree2 = sp.at(pytree1)[mask].set(0)
pytree2
# {'a': 0, 'b': {'c': 2, 'd': 3}, 'e': 0}
```
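Mask indexing amounts to walking the tree and its mask in lockstep. The sketch below assumes plain nested dicts and lists; `leaf_map` and `set_by_mask` are hypothetical helpers, not part of `sepes` or JAX.

```python
from typing import Any, Callable


def leaf_map(func: Callable[[Any], Any], tree: Any) -> Any:
    # rebuild nested dicts/lists with func applied to each leaf (used to build the mask)
    if isinstance(tree, dict):
        return {k: leaf_map(func, v) for k, v in tree.items()}
    if isinstance(tree, list):
        return [leaf_map(func, v) for v in tree]
    return func(tree)


def set_by_mask(tree: Any, mask: Any, value: Any) -> Any:
    """Return a new tree where each leaf whose mask entry is True becomes `value`.

    Assumes `mask` has exactly the same nested dict/list structure as `tree`.
    """
    if isinstance(tree, dict):
        return {k: set_by_mask(v, mask[k], value) for k, v in tree.items()}
    if isinstance(tree, list):
        return [set_by_mask(v, m, value) for v, m in zip(tree, mask)]
    return value if mask else tree


pytree1 = {"a": -1, "b": {"c": 2, "d": 3}, "e": -4}
mask = leaf_map(lambda x: x < 0, pytree1)
print(mask)  # {'a': True, 'b': {'c': False, 'd': False}, 'e': True}
pytree2 = set_by_mask(pytree1, mask, 0)
print(pytree2)  # {'a': 0, 'b': {'c': 2, 'd': 3}, 'e': 0}
```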

## Composition

`at` can chain any of the key types above (integer paths, named paths, regular expressions, and masks) to narrow down the selection.

The following example sets a node to 100 only if it satisfies both conditions:
- Name starts with "key_"
- Value is positive

```python
import sepes as sp
import re
import jax

pytree1 = {"key_1": 1, "key_2": -2, "value_1": 1, "value_2": 2}
pattern = re.compile(r"key_.*")
mask = jax.tree_util.tree_map(lambda x: x > 0, pytree1)
pytree2 = sp.at(pytree1)[pattern][mask].set(100)
pytree2
# {'key_1': 100, 'key_2': -2, 'value_1': 1, 'value_2': 2}
```
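Judging from the output above, chaining `[pattern][mask]` keeps only the nodes that satisfy both selections. A minimal sketch of that intersection for a flat dict follows; `set_where_both` is a hypothetical helper and assumes the pattern is matched with `fullmatch`.

```python
import re
from typing import Any


def set_where_both(tree: dict, pattern: re.Pattern, mask: dict, value: Any) -> dict:
    """Hypothetical helper for a flat dict: replace an entry only if its key
    fully matches `pattern` AND its mask entry is True (the intersection
    of the two selections)."""
    return {
        key: value if (pattern.fullmatch(key) and mask[key]) else node
        for key, node in tree.items()
    }


pytree1 = {"key_1": 1, "key_2": -2, "value_1": 1, "value_2": 2}
pattern = re.compile(r"key_.*")
mask = {key: node > 0 for key, node in pytree1.items()}
pytree2 = set_where_both(pytree1, pattern, mask, 100)
print(pytree2)  # {'key_1': 100, 'key_2': -2, 'value_1': 1, 'value_2': 2}
```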