# ✂️ Model surgery

This tutorial provides a basic review of model surgery techniques. Because models are [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) (nested data structures such as `tuple` or `dict`), the same techniques apply to any pytree, not just neural network layers.

In [1]:
# !pip install git+https://github.com/ASEM000/serket --quiet

In [2]:
import serket as sk
import jax
import jax.numpy as jnp
import re

## `AtIndexer` basics

`serket.AtIndexer` wraps any pytree (nested container) and manipulates its content out of place: every change is applied to a _new_ instance of the pytree, and the original is left untouched.

The following example demonstrates this point:

In [3]:
pytree1 = [1, [2, 3], 4]
indexer: sk.AtIndexer = sk.AtIndexer(pytree1)
pytree2 = indexer[...].get()  # get the whole pytree using ...
print(f"{pytree1=}, {pytree2=}")
# pytree1 and pytree2 are equal, but they are not the same object
# because `get` returns a new pytree instead of the original one
print(f"pytree1 is pytree2 = {pytree1 is pytree2}")

pytree1=[1, [2, 3], 4], pytree2=[1, [2, 3], 4]
pytree1 is pytree2 = False


Note that each `[ ]` selects at a certain depth: `[a][b]` selects `a` at depth 1 and `b` at depth 2.
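To make the depth-wise semantics concrete, here is a plain-Python sketch of an out-of-place update along a path. `set_at_path` is a hypothetical helper written for illustration only; it is not part of serket, it handles only plain `dict`, `list`, and `tuple` containers, and it is not how `AtIndexer` is implemented internally.

```python
from typing import Any, Sequence


def set_at_path(tree: Any, path: Sequence[Any], value: Any) -> Any:
    """Return a copy of `tree` with the node at `path` replaced by `value`.

    Each element of `path` selects one level deeper, mirroring `[a][b]`:
    path[0] is used at depth 1, path[1] at depth 2, and so on.
    Only the containers along the path are rebuilt; the input is not mutated.
    """
    if not path:
        return value
    key, rest = path[0], path[1:]
    if isinstance(tree, dict):
        new = dict(tree)  # shallow copy of this level only
        new[key] = set_at_path(tree[key], rest, value)
        return new
    if isinstance(tree, (list, tuple)):  # plain lists/tuples, not namedtuples
        items = list(tree)
        items[key] = set_at_path(tree[key], rest, value)
        return type(tree)(items)
    raise TypeError(f"cannot index into leaf {tree!r} with {key!r}")


tree = [1, [2, 3], 4]
new_tree = set_at_path(tree, [1, 0], 100)
print(tree, new_tree)  # [1, [2, 3], 4] [1, [100, 3], 4]
```

Rebuilding only the containers on the path keeps the update cheap: untouched siblings such as `4` are reused rather than copied.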

### Integer indexing

`serket.AtIndexer` can edit pytrees by integer paths.

In [4]:
pytree1 = [1, [2, 3], 4]
indexer: sk.AtIndexer = sk.AtIndexer(pytree1)
pytree2 = indexer[1][0].set(100)  # like pytree1[1][0] = 100, but out of place

print(f"{pytree1=}, {pytree2=}")

pytree1=[1, [2, 3], 4], pytree2=[1, [100, 3], 4]
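For comparison, the same edit in plain Python needs an explicit copy. This sketch uses only the standard-library `copy` module; it illustrates why a shallow copy is not enough for nested containers, and it makes no claim about how `AtIndexer` copies internally.

```python
import copy

pytree1 = [1, [2, 3], 4]

# Deep copy, then mutate the copy: pytree1 stays intact.
pytree2 = copy.deepcopy(pytree1)
pytree2[1][0] = 100
print(f"{pytree1=}, {pytree2=}")
# pytree1=[1, [2, 3], 4], pytree2=[1, [100, 3], 4]

# Pitfall: a shallow copy shares the inner list, so the edit leaks into pytree1.
shallow = list(pytree1)
shallow[1][0] = 100
print(f"{pytree1=}")
# pytree1=[1, [100, 3], 4]
```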


### Named path indexing

`serket.AtIndexer` can edit pytrees by named paths.

In [5]:
pytree1 = {"a": -1, "b": {"c": 2, "d": 3}, "e": -4, "f": {"g": 7, "h": 8}}
indexer: sk.AtIndexer = sk.AtIndexer(pytree1)

In [6]:
# example 1: set the value of pytree1["b"]["c"] to 200
pytree2 = indexer["b"]["c"].set(200)
pytree2

{'a': -1, 'b': {'c': 200, 'd': 3}, 'e': -4, 'f': {'g': 7, 'h': 8}}

In [7]:
# example 2: set the value of pytree1["b"] to 100
pytree3 = indexer["b"].set(100)
pytree3

{'a': -1, 'b': 100, 'e': -4, 'f': {'g': 7, 'h': 8}}

In [8]:
# example 3: set _all leaves_ of the "b" subtree to 100
pytree4 = indexer["b"][...].set(100)
pytree4

{'a': -1, 'b': {'c': 100, 'd': 100}, 'e': -4, 'f': {'g': 7, 'h': 8}}

In [9]:
# example 4: set _all leaves_ of pytree1["b"] _and_ pytree1["f"] to 100
pytree5 = indexer["b", "f"][...].set(100)
pytree5

{'a': -1, 'b': {'c': 100, 'd': 100}, 'e': -4, 'f': {'g': 100, 'h': 100}}
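Examples 2 and 3 differ only in the trailing `[...]`: without it the whole subtree is replaced by one value, with it every leaf inside the subtree is replaced. The plain-Python sketch below reproduces example 4 for a top-level `dict`; `map_leaves` and `set_leaves_under` are hypothetical helpers for illustration, not serket functions.

```python
from typing import Any, Callable, Iterable


def map_leaves(func: Callable[[Any], Any], tree: Any) -> Any:
    """Apply `func` to every leaf of a nested dict/list/tuple, out of place."""
    if isinstance(tree, dict):
        return {k: map_leaves(func, v) for k, v in tree.items()}
    if isinstance(tree, (list, tuple)):
        return type(tree)(map_leaves(func, v) for v in tree)
    return func(tree)


def set_leaves_under(tree: dict, keys: Iterable[str], value: Any) -> dict:
    """Roughly mimic `indexer[k1, k2][...].set(value)` for a top-level dict."""
    keys = set(keys)
    return {
        # selected subtrees keep their structure; only their leaves change
        k: map_leaves(lambda _: value, v) if k in keys else v
        for k, v in tree.items()
    }


pytree1 = {"a": -1, "b": {"c": 2, "d": 3}, "e": -4, "f": {"g": 7, "h": 8}}
print(set_leaves_under(pytree1, ["b", "f"], 100))
# {'a': -1, 'b': {'c': 100, 'd': 100}, 'e': -4, 'f': {'g': 100, 'h': 100}}
```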

### Masked indexing

`serket.AtIndexer` can also edit pytree entries by a boolean mask: given a mask with the same structure as the pytree, nodes marked `True` are edited and all other nodes are left untouched. The following example sets all negative entries to 0:

In [10]:
pytree1 = {"a": -1, "b": {"c": 2, "d": 3}, "e": -4}
mask = jax.tree_util.tree_map(lambda x: x < 0, pytree1)  # True where the leaf is negative
indexer: sk.AtIndexer = sk.AtIndexer(pytree1)
pytree2 = indexer[mask].set(0)
pytree2

{'a': 0, 'b': {'c': 2, 'd': 3}, 'e': 0}
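Conceptually, a masked update walks the pytree and the mask in lockstep and picks, leaf by leaf, between the new value and the old one. The plain-Python sketch below shows that idea; `tree_map2` is a hypothetical helper, and it assumes the mask really has the same structure as the tree (it does not validate this).

```python
from typing import Any, Callable


def tree_map2(func: Callable[[Any, Any], Any], tree: Any, mask: Any) -> Any:
    """Walk `tree` and a same-structured `mask` together, out of place."""
    if isinstance(tree, dict):
        return {k: tree_map2(func, tree[k], mask[k]) for k in tree}
    if isinstance(tree, (list, tuple)):
        return type(tree)(tree_map2(func, t, m) for t, m in zip(tree, mask))
    return func(tree, mask)


pytree1 = {"a": -1, "b": {"c": 2, "d": 3}, "e": -4}
# Same structure as pytree1; equivalent to evaluating `x < 0` on each leaf.
mask = {"a": True, "b": {"c": False, "d": False}, "e": True}

# Leaves marked True are replaced by 0; leaves marked False are kept as-is.
pytree2 = tree_map2(lambda leaf, m: 0 if m else leaf, pytree1, mask)
print(pytree2)  # {'a': 0, 'b': {'c': 2, 'd': 3}, 'e': 0}
```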

Other features include `get`, `apply`, `scan`, `reduce`, and `pluck`. Check the documentation for more examples.
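As a rough intuition for a reduction over selected entries, the plain-Python sketch below flattens a pytree, keeps the negative leaves, and folds them with a binary function. It is an analogue only: serket's own `reduce` signature and leaf ordering are not shown here, and this `leaves` helper (hypothetical) visits dict keys in insertion order, whereas JAX flattens dicts in sorted key order.

```python
import functools
import operator
from typing import Any, List


def leaves(tree: Any) -> List[Any]:
    """Flatten a nested dict/list/tuple into a list of leaves (dicts in insertion order)."""
    if isinstance(tree, dict):
        return [leaf for v in tree.values() for leaf in leaves(v)]
    if isinstance(tree, (list, tuple)):
        return [leaf for v in tree for leaf in leaves(v)]
    return [tree]


pytree1 = {"a": -1, "b": {"c": 2, "d": 3}, "e": -4}

# Select the negative leaves, then fold them with a binary function (here: addition).
selected = [x for x in leaves(pytree1) if x < 0]
total = functools.reduce(operator.add, selected, 0)
print(selected, total)  # [-1, -4] -5
```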

## `serket` layers surgery

Similarly, `serket` layers are pytrees like the ones above, and `TreeClass` embeds an `AtIndexer` under the `.at` property. This design lets name/index-based and boolean-mask-based updates compose. The next example demonstrates this point.
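The idea of a layer that is updated by producing a new instance can be sketched with a frozen standard-library dataclass. `Layer` and `at_set` below are hypothetical names for illustration; this is a rough analogue of the out-of-place pattern, not serket's `TreeClass` or its `.at` implementation.

```python
import dataclasses
from typing import Any


@dataclasses.dataclass(frozen=True)
class Layer:
    """A hypothetical immutable layer; not serket's TreeClass."""

    weight: Any
    bias: Any

    def at_set(self, name: str, value: Any) -> "Layer":
        # Out-of-place update: build a new instance, leave `self` unchanged.
        return dataclasses.replace(self, **{name: value})


layer1 = Layer(weight=[[1.0, 2.0]], bias=[0.0])
layer2 = layer1.at_set("bias", [1.0])
print(layer1.bias, layer2.bias)  # [0.0] [1.0]

try:
    layer1.bias = [5.0]  # direct attribute assignment is rejected on a frozen dataclass
except dataclasses.FrozenInstanceError:
    print("layer1 is immutable")
```

Because unchanged fields are reused, `layer2.weight` is the very same object as `layer1.weight`; only the replaced field differs.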


In [11]:
# basic network with two dense (linear) layers: an encoder and a decoder
class Net(sk.TreeClass):
    def __init__(self, in_features: int, out_features: int, *, key: jax.Array):
        k1, k2 = jax.random.split(key)
        W1 = jax.random.normal(k1, (out_features, in_features))
        W2 = jax.random.normal(k2, (out_features, out_features))

        self.encoder = {"weight": W1, "bias": jnp.zeros((out_features,))}
        self.decoder = {"weight": W2, "bias": jnp.zeros((in_features,))}

    def __call__(self, x):
        x = x @ self.encoder["weight"] + self.encoder["bias"]
        x = x @ self.decoder["weight"] + self.decoder["bias"]
        return x


net1 = Net(3, 5, key=jax.random.PRNGKey(0))
print(f"{net1=}")

net1=Net(
  encoder={
    weight:f32[5,3](μ=0.30, σ=0.90, ∈[-1.44,1.84]), 
    bias:f32[5](μ=0.00, σ=0.00, ∈[0.00,0.00])
  }, 
  decoder={
    weight:f32[5,5](μ=-0.16, σ=0.75, ∈[-1.78,1.20]), 
    bias:f32[3](μ=0.00, σ=0.00, ∈[0.00,0.00])
  }
)


Now, suppose we want to restrict the `weight` of both `encoder` and `decoder` to the range `[-0.2, 0.2]` by clipping out-of-bound values.

In [12]:
# example 1: clip the `weight` of `encoder` and `decoder` to [-0.2, 0.2]
net2 = net1.at["encoder", "decoder"]["weight"].apply(lambda x: jnp.clip(x, -0.2, 0.2))
net2

Net(
  encoder={
    bias:f32[5](μ=0.00, σ=0.00, ∈[0.00,0.00]), 
    weight:f32[5,3](μ=0.04, σ=0.18, ∈[-0.20,0.20])
  }, 
  decoder={
    bias:f32[3](μ=0.00, σ=0.00, ∈[0.00,0.00]), 
    weight:f32[5,5](μ=-0.02, σ=0.18, ∈[-0.20,0.20])
  }
)
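The same selection can also be phrased as a pattern over dotted paths such as `encoder.weight`, using the standard-library `re` module. `apply_where` and `clip` are hypothetical helpers, and plain lists of floats stand in for arrays; this sketch illustrates the selection logic and is not serket's API.

```python
import re
from typing import Any, Callable, List, Tuple


def apply_where(tree: Any, pattern: str, func: Callable[[Any], Any],
                path: Tuple[str, ...] = ()) -> Any:
    """Apply `func` to leaves whose dotted path fully matches `pattern`, out of place."""
    if isinstance(tree, dict):
        return {k: apply_where(v, pattern, func, path + (str(k),)) for k, v in tree.items()}
    dotted = ".".join(path)  # e.g. "encoder.weight"; non-dict values are treated as leaves
    return func(tree) if re.fullmatch(pattern, dotted) else tree


def clip(values: List[float], lo: float, hi: float) -> List[float]:
    """Clip a flat list of floats to [lo, hi]."""
    return [min(max(v, lo), hi) for v in values]


# Toy parameters: lists stand in for weight/bias arrays.
params = {
    "encoder": {"weight": [-1.5, 0.1, 1.8], "bias": [0.0, 0.0]},
    "decoder": {"weight": [-1.0, 0.05, 0.9], "bias": [0.0]},
}
clipped = apply_where(params, r"(encoder|decoder)\.weight", lambda w: clip(w, -0.2, 0.2))
print(clipped["encoder"]["weight"], clipped["decoder"]["weight"])
# [-0.2, 0.1, 0.2] [-0.2, 0.05, 0.2]
```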

In [13]:
# example 2: load pretrained weights for `encoder`
pretrained = {"weight": jnp.ones((5, 3)) * 100.0, "bias": jnp.ones((5,)) * 100.0}
net3 = net1.at["encoder"].set(pretrained)
net3

Net(
  encoder={
    weight:f32[5,3](μ=100.00, σ=0.00, ∈[100.00,100.00]), 
    bias:f32[5](μ=100.00, σ=0.00, ∈[100.00,100.00])
  }, 
  decoder={
    weight:f32[5,5](μ=-0.16, σ=0.75, ∈[-1.78,1.20]), 
    bias:f32[3](μ=0.00, σ=0.00, ∈[0.00,0.00])
  }
)
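Example 2 replaces the whole `encoder` subtree, so a replacement with missing or extra keys would silently change the model's structure. As a precaution, one might compare the key structure before loading. `same_structure` is a hypothetical helper for nested dicts only; it does not compare leaf shapes or dtypes, and strings stand in for arrays.

```python
from typing import Any


def same_structure(a: Any, b: Any) -> bool:
    """True if two nested dicts have identical keys at every level (leaves not compared)."""
    if isinstance(a, dict) and isinstance(b, dict):
        return a.keys() == b.keys() and all(same_structure(a[k], b[k]) for k in a)
    # At least one side is a leaf: structures match only if both are leaves.
    return not isinstance(a, dict) and not isinstance(b, dict)


encoder = {"weight": "W1", "bias": "b1"}  # stand-ins for arrays
pretrained_ok = {"bias": "p_b", "weight": "p_W"}  # key order does not matter
pretrained_bad = {"weight": "p_W"}  # missing "bias"

print(same_structure(encoder, pretrained_ok))   # True
print(same_structure(encoder, pretrained_bad))  # False
```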