# ✂️ Model surgery

This tutorial provides a basic review of model surgery techniques. Because models are [pytrees](https://jax.readthedocs.io/en/latest/pytrees.html) (nested data structures such as `tuple` or `dict`), the same techniques apply to any pytree, not just to neural network layers.
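To see why the pytree view matters, the short sketch below uses plain `jax.tree_util` (not serket) on a hypothetical model-like `dict`. JAX splits any such nested container into a flat list of leaves and a structure description, and can rebuild a new container from edited leaves. This is the same kind of out-of-place editing that the rest of this tutorial performs.

```python
import jax

# A nested container of plain Python values is already a pytree.
model_like = {"layer1": {"weight": [1.0, -2.0], "bias": 0.0}, "layer2": (3.0, 4.0)}

# tree_flatten separates the leaves from the container structure (the treedef).
# Dict keys are visited in sorted order: layer1.bias, layer1.weight, layer2.
leaves, treedef = jax.tree_util.tree_flatten(model_like)
print(leaves)  # [0.0, 1.0, -2.0, 3.0, 4.0]

# tree_unflatten builds a *new* pytree with the same structure from edited leaves.
doubled = jax.tree_util.tree_unflatten(treedef, [2 * x for x in leaves])
print(doubled)
```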

In [None]:
!pip install git+https://github.com/ASEM000/serket --quiet

## `AtIndexer` basics

`serket.AtIndexer` wraps any pytree to manipulate its contents out-of-place: every change is applied to a _new_ instance of the pytree, and the original is left untouched. The following example demonstrates this point:

In [1]:
import serket as sk
pytree1 = [1, [2, 3], 4]
pytree2 = sk.AtIndexer(pytree1)[1][0].set(100)  # like pytree1[1][0] = 100, but returns a new pytree
print(pytree2)
# [1, [100, 3], 4]
pytree1 is pytree2  # test out-of-place update

[1, [100, 3], 4]


False
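To make the out-of-place semantics concrete, here is a minimal pure-Python sketch of a path-based set. `set_at` is a hypothetical helper written for illustration only; it is not part of serket and is not how `AtIndexer` is implemented. It handles only plain `dict`, `list`, and `tuple` nodes, and it copies just the containers along the path, so untouched branches are shared with the original.

```python
from typing import Any, Sequence


def set_at(tree: Any, path: Sequence[Any], value: Any) -> Any:
    """Hypothetical helper: return a copy of `tree` with the node at `path` replaced."""
    if not path:
        return value  # reached the target position: replace it
    key, rest = path[0], path[1:]
    if isinstance(tree, dict):
        new = dict(tree)  # shallow copy keeps sibling entries shared
        new[key] = set_at(tree[key], rest, value)
        return new
    if isinstance(tree, (list, tuple)):
        items = list(tree)  # shallow copy of this level only
        items[key] = set_at(tree[key], rest, value)
        return tuple(items) if isinstance(tree, tuple) else items
    raise TypeError(f"cannot index into {type(tree).__name__}")


pytree1 = [1, [2, 3], 4]
pytree2 = set_at(pytree1, [1, 0], 100)
print(pytree2)  # [1, [100, 3], 4]
print(pytree1)  # [1, [2, 3], 4]  (the original is unchanged)
```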

`serket.AtIndexer` can also edit pytree entries with a boolean mask: given a mask with the same structure as the pytree, leaves marked `True` are edited and all other leaves are left untouched. The following example sets all negative entries to 0:

In [2]:
import serket as sk
import jax

pytree1 = {"a": -1, "b": {"c": 2, "d": 3}, "e": -4}
mask = jax.tree_util.tree_map(lambda x: x < 0, pytree1)  # True where the leaf is negative
pytree2 = sk.AtIndexer(pytree1)[mask].set(0)
print(pytree2)
# {'a': 0, 'b': {'c': 2, 'd': 3}, 'e': 0}
pytree1 is pytree2  # test out-of-place update
# False

{'a': 0, 'b': {'c': 2, 'd': 3}, 'e': 0}


False
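Conceptually, a mask-based `.set` pairs each leaf with its mask entry. A rough equivalent in plain JAX (a sketch for intuition, not serket's implementation) maps over the mask and the pytree together with `jax.tree_util.tree_map`, which accepts several pytrees of the same structure:

```python
import jax

pytree1 = {"a": -1, "b": {"c": 2, "d": 3}, "e": -4}

# Build a mask pytree with the same structure: True where the leaf should change.
mask = jax.tree_util.tree_map(lambda x: x < 0, pytree1)

# Walk both trees in lockstep and replace the leaves whose mask entry is True.
pytree2 = jax.tree_util.tree_map(lambda m, x: 0 if m else x, mask, pytree1)
print(pytree2)  # {'a': 0, 'b': {'c': 2, 'd': 3}, 'e': 0}
```

The Python `if m` works here only because every leaf is a scalar; for array leaves, an elementwise `jnp.where(m, 0, x)` would be the natural replacement.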

## `serket` layers surgery

Similarly, `serket` layers are pytrees, as above. However, `AtIndexer` is embedded in `TreeClass` as the `.at` property, a design that enables powerful composition of name/index-based and boolean-based updates. The next example demonstrates this point.


In [3]:
import serket as sk
import jax
import jax.numpy as jnp


# basic convnet with two convolutional layers
class ConvNet(sk.TreeClass):
    def __init__(self, indim, outdim, key):
        k1, k2 = jax.random.split(key)
        self.conv1 = sk.nn.Conv2D(indim, outdim, 3, key=k1)
        self.conv2 = sk.nn.Conv2D(outdim, 1, 1, key=k2)

    def __call__(self, x):
        x = self.conv1(x)
        x = jax.nn.relu(x)
        x = self.conv2(x)
        return x


cnn1 = ConvNet(3, 10, jax.random.PRNGKey(0))

# note that `ConvNet` is composed of two submodules, `conv1` and `conv2`
print(sk.tree_diagram(cnn1, depth=2))

ConvNet
├── .conv1:Conv2D
│   ├── .in_features=3
│   ├── .out_features=10
│   ├── .kernel_size=(...)
│   ├── .strides=(...)
│   ├── .padding=same
│   ├── .dilation=(...)
│   ├── .weight_init=glorot_uniform
│   ├── .bias_init=zeros
│   ├── .groups=1
│   ├── .weight=f32[10,3,3,3](μ=-0.00, σ=0.11, ∈[-0.18,0.18])
│   └── .bias=f32[10,1,1](μ=0.00, σ=0.00, ∈[0.00,0.00])
└── .conv2:Conv2D
    ├── .in_features=10
    ├── .out_features=1
    ├── .kernel_size=(...)
    ├── .strides=(...)
    ├── .padding=same
    ├── .dilation=(...)
    ├── .weight_init=glorot_uniform
    ├── .bias_init=zeros
    ├── .groups=1
    ├── .weight=f32[1,10,1,1](μ=-0.18, σ=0.29, ∈[-0.53,0.31])
    └── .bias=f32[1,1,1](μ=0.00, σ=0.00, ∈[0.00,0.00])
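The `.at` design can be pictured with a small pure-Python sketch. `Indexer`, `_apply`, and `TinyModel` below are hypothetical names that only mimic the calling convention used in this tutorial (`model.at[...][...].apply(fn)`); they are not serket's classes and they omit features such as boolean masks. Each `[...]` records one level of the path, a tuple of keys fans out to several sibling branches, and `apply` rebuilds only the containers along the selected paths before wrapping the result in a new model.

```python
from typing import Any, Callable


def _apply(tree: Any, path: tuple, fn: Callable[[Any], Any]) -> Any:
    if not path:
        return fn(tree)  # reached a selected node: transform it
    keys, rest = path[0], path[1:]
    new = dict(tree)  # copy this level; unselected keys stay shared
    for key in keys:
        new[key] = _apply(tree[key], rest, fn)
    return new


class Indexer:
    """Hypothetical `.at`-style indexer over nested dicts (illustration only)."""

    def __init__(self, tree: dict, path: tuple = (), wrap: Callable = lambda t: t):
        self.tree, self.path, self.wrap = tree, path, wrap

    def __getitem__(self, keys):
        # A tuple such as ("conv1", "conv2") selects several siblings at this level.
        keys = keys if isinstance(keys, tuple) else (keys,)
        return Indexer(self.tree, self.path + (keys,), self.wrap)

    def apply(self, fn: Callable[[Any], Any]):
        return self.wrap(_apply(self.tree, self.path, fn))

    def set(self, value: Any):
        return self.apply(lambda _: value)


class TinyModel:
    def __init__(self, params: dict):
        self.params = params

    @property
    def at(self) -> Indexer:
        # Each access creates a fresh indexer whose results are new TinyModel instances.
        return Indexer(self.params, wrap=TinyModel)


def zero_out_of_range(w):
    return [0.0 if abs(v) > 0.2 else v for v in w]


model1 = TinyModel({
    "conv1": {"weight": [0.1, -0.3], "bias": [0.0]},
    "conv2": {"weight": [0.5, 0.05], "bias": [0.0]},
})
model2 = model1.at["conv1", "conv2"]["weight"].apply(zero_out_of_range)
print(model2.params["conv1"]["weight"])  # [0.1, 0.0]
print(model2.params["conv2"]["weight"])  # [0.0, 0.05]
```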


Now, suppose we want to restrict the `weight` of both layers to the range `[-0.2, 0.2]` by setting out-of-range values to zero. This combines name-based selection (`conv1.weight` and `conv2.weight`) with value-based masking (the entries where `x < -0.2` or `x > 0.2`). The following example achieves this by _composition_: `.at` selects the weights by name, and `.apply` runs a function that zeroes the out-of-range entries.

In [4]:
def set_to_zero(x):
    # zero every element of x that lies outside the range [-0.2, 0.2]
    return jnp.where(x < -0.2, 0, jnp.where(x > 0.2, 0, x))


# `["conv1", "conv2"]` selects both layers, then `["weight"]` selects each layer's weight
cnn2 = cnn1.at["conv1", "conv2"]["weight"].apply(set_to_zero)

# `conv2.weight` is now within [-0.2, 0.2]; `conv1.weight` was already in range, so it is unchanged
print(sk.tree_diagram(cnn2, depth=2))

ConvNet
├── .conv1:Conv2D
│   ├── .in_features=3
│   ├── .out_features=10
│   ├── .kernel_size=(...)
│   ├── .strides=(...)
│   ├── .padding=same
│   ├── .dilation=(...)
│   ├── .weight_init=glorot_uniform
│   ├── .bias_init=zeros
│   ├── .groups=1
│   ├── .weight=f32[10,3,3,3](μ=-0.00, σ=0.11, ∈[-0.18,0.18])
│   └── .bias=f32[10,1,1](μ=0.00, σ=0.00, ∈[0.00,0.00])
└── .conv2:Conv2D
    ├── .in_features=10
    ├── .out_features=1
    ├── .kernel_size=(...)
    ├── .strides=(...)
    ├── .padding=same
    ├── .dilation=(...)
    ├── .weight_init=glorot_uniform
    ├── .bias_init=zeros
    ├── .groups=1
    ├── .weight=f32[1,10,1,1](μ=-0.02, σ=0.08, ∈[-0.17,0.14])
    └── .bias=f32[1,1,1](μ=0.00, σ=0.00, ∈[0.00,0.00])
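For comparison, the same name-plus-value composition can be written in plain JAX on a nested `dict` of arrays. This is a sketch under assumptions: the `params` dictionary is a hypothetical stand-in for `ConvNet`'s weights (random values, not the ones printed above), and it relies on `jax.tree_util.tree_map_with_path`, which requires a reasonably recent JAX release. Without `.at`, the path-matching logic has to be written by hand.

```python
import jax
import jax.numpy as jnp

# Hypothetical parameters shaped like ConvNet's weights and biases (random values).
k1, k2 = jax.random.split(jax.random.PRNGKey(0))
params = {
    "conv1": {"weight": 0.1 * jax.random.normal(k1, (10, 3, 3, 3)),
              "bias": jnp.zeros((10, 1, 1))},
    "conv2": {"weight": 0.3 * jax.random.normal(k2, (1, 10, 1, 1)),
              "bias": jnp.zeros((1, 1, 1))},
}


def clip_weights(path, leaf):
    # `path` is a tuple of DictKey entries, e.g. (DictKey('conv1'), DictKey('weight')).
    if getattr(path[-1], "key", None) == "weight":  # name-based selection
        return jnp.where(jnp.abs(leaf) > 0.2, 0.0, leaf)  # value-based masking
    return leaf  # biases and any other leaves pass through unchanged


new_params = jax.tree_util.tree_map_with_path(clip_weights, params)
```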
