# Supplementary Material to [pytorch-friendly introduction to JAX+TPU](https://www.kaggle.com/roguekk007/pytorch-friendly-comprehensive-jax-tpu-intro/edit)
Let's dive right in!

####  Set up Kaggle / JAX Runtime (GPU will suffice for this kernel)

In [None]:
import os
import warnings
import tensorflow as tf

!pip install --upgrade jax jaxlib git+https://github.com/deepmind/optax.git flax -q
    
if 'TPU_NAME' in os.environ and 'KAGGLE_DATA_PROXY_TOKEN' in os.environ:
    use_tpu = True
    
    import requests 
    from jax.config import config
    if 'TPU_DRIVER_MODE' not in globals():
        url = 'http:' + os.environ['TPU_NAME'].split(':')[1] + ':8475/requestversion/tpu_driver_nightly'
        resp = requests.post(url)
        TPU_DRIVER_MODE = 1
    config.FLAGS.jax_xla_backend = "tpu_driver"
    config.FLAGS.jax_backend_target = os.environ['TPU_NAME']
    print('Registered (Kaggle) TPU:', config.FLAGS.jax_backend_target)
else:
    use_tpu = False
!pip install --upgrade git+https://github.com/n2cholas/jax-resnet.git -q

In [None]:
import functools
from tqdm.notebook import tqdm

import jax
from jax import lax 

import flax
import flax.linen as nn
from flax.core.frozen_dict import FrozenDict
from flax.training.common_utils import shard, shard_prng_key
from flax.serialization import to_state_dict, from_state_dict,\
                        msgpack_serialize, msgpack_restore, from_bytes

import optax
import msgpack

from jax_resnet import pretrained_resnet
from jax_resnet.common import slice_variables

# Apologies for shamelessly suppressing the warnings here
import warnings
warnings.filterwarnings("ignore")

Here are some typical model manipulations; this notebook walks through each of them:
1. Creating a backbone
2. Removing the head from a pretrained backbone and appending our own head

### To think in the way of JAX / Flax, we must first **decouple a model's architecture (application function) and parameters**. 

Let that sink in...

From now on, a better (and compiler-friendly) way to think about `model.forward(x)` is `model_forward_fn(x, params)`, where `model_forward_fn` depends only on the model architecture and `params` holds the model parameters.
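
To make the idea concrete, here is a minimal sketch in plain JAX, without Flax. The linear model, the `loss_fn` helper and the parameter names `w` and `b` are made up for illustration; the point is only that the forward function is pure, so JAX can differentiate and compile it with respect to `params`.

```python
import jax
import jax.numpy as jnp


def model_forward_fn(x, params):
    """A pure function: the output depends only on its arguments."""
    return x @ params["w"] + params["b"]


key_w, key_x = jax.random.split(jax.random.PRNGKey(0))
params = {"w": jax.random.normal(key_w, (3, 2)), "b": jnp.zeros(2)}
x = jax.random.normal(key_x, (4, 3))

y = model_forward_fn(x, params)


def loss_fn(x, params):
    # Hypothetical loss, only here to have something to differentiate.
    return jnp.mean(model_forward_fn(x, params) ** 2)


# Differentiate with respect to the second argument (the parameters).
grads = jax.jit(jax.grad(loss_fn, argnums=1))(x, params)
print(y.shape, grads["w"].shape, grads["b"].shape)
```

The gradients come back as a dictionary with the same structure as `params`, which is what makes the "parameters as data" view convenient.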

In [None]:
# Create a dummy input
# Always split first: sample with the subkey and carry the new key forward!
random_key = jax.random.PRNGKey(0)
random_key, subkey = jax.random.split(random_key)
x = jax.random.normal(subkey, (2, 128, 128, 3))
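
Why the insistence on splitting? A small sketch (the variable names are illustrative) shows that JAX random functions are deterministic in their key, so reusing a key silently reproduces the same numbers:

```python
import jax
import jax.numpy as jnp

key = jax.random.PRNGKey(0)

# Reusing the same key reproduces exactly the same "random" numbers.
a = jax.random.normal(key, (3,))
b = jax.random.normal(key, (3,))

# Splitting yields a fresh subkey to consume and a new key to carry forward.
key, subkey = jax.random.split(key)
c = jax.random.normal(subkey, (3,))

print("same key, same sample:", bool(jnp.array_equal(a, b)))
print("new subkey, same sample:", bool(jnp.array_equal(a, c)))
```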

# 1. Create a backbone: 
### Simple example
We begin with a simple CNN. 

Note that `nn.Dense` is given only its output size (`features`, the analogue of PyTorch's `out_features`) in the definition. How does Flax know the input size?

In [None]:
import flax
import flax.linen as nn
from typing import Sequence
from flax.core import freeze, unfreeze

# Straightforward definition
class simpleCNN(nn.Module):
    # Constructor arguments are declared as typed dataclass fields
    channels : Sequence[int]
    num_classes : int
        
    def setup(self):
        self.conv_layers = [nn.Conv(channel, kernel_size=(3, 3))
                            for channel in self.channels]
        # For the linear layer we give only the output size;
        # the input size is inferred from the sample input at init
        self.linear_layer = nn.Dense(self.num_classes)
    
    def __call__(self, x):
        for conv_layer in self.conv_layers:
            x = conv_layer(x)
            x = nn.relu(x)
        # Global average pooling over the spatial (H, W) axes; inputs are NHWC
        x = x.mean(axis=(1, 2))
        x = self.linear_layer(x)
        return x
    
model = simpleCNN(channels=[4, 8, 16], num_classes=12)
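
Flax convolutions use a channels-last (NHWC) layout, unlike PyTorch's NCHW, so global average pooling reduces over axes 1 and 2. A minimal shape check on a dummy feature map (the sizes are arbitrary):

```python
import jax.numpy as jnp

# A batch of 2 feature maps, 8x8 spatial size, 16 channels (NHWC).
feature_map = jnp.ones((2, 8, 8, 16))

# Average over height and width, keeping batch and channel axes.
pooled = feature_map.mean(axis=(1, 2))
print(pooled.shape)
```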

## Remarks on Construction
* Declare constructor arguments (`channels`, `num_classes`) as typed class fields (in PyTorch this is done in `__init__`)
* Use `setup` to register submodules, variables, and parameters
* `__call__` defines the model's behavior, but **to apply the model to data, call `model.apply`**


## On using the model:
Since model weights and application are decoupled, we need to:
* Call `model.init` *with a random key and a sample input* to instantiate the model parameters
* Use `apply_fn=model.apply` to do a forward pass with *any* model of the same architecture: we only need to vary `params`
* What about `model.eval()`, `model.train()` and batch statistics (which change with each training forward pass)? See the section on managing randomness and mutable parameters below

In [None]:
random_key, subkey = jax.random.split(random_key)
params = model.init(subkey, x)
print("Parameter:", params.keys(), type(params))

apply_fn = model.apply
model_output = apply_fn(params, x)
print("Model output:", model_output.shape)

## Compact module definition
The `@nn.compact` interface merges module declaration and application: each submodule is declared inline, at the point where it is applied.

For example, instead of
```
def setup(...
    self.conv = nn.Conv(16, (3, 3))
...
def __call__(self, x): ...
    x = self.conv(x)
``` 
equivalently write:
```
@nn.compact
def __call__(self, x): ...
    x = nn.Conv(16, (3, 3))(x)
```

Here is an example:

In [None]:
class simpleCNN_compact(nn.Module):
    # Constructor arguments are declared as typed dataclass fields
    channels : Sequence[int]
    num_classes : int
    
    @nn.compact
    def __call__(self, x):
        for i in range(len(self.channels)):
            x = nn.Conv(self.channels[i], kernel_size=(3, 3))(x)
            x = nn.relu(x)
        # Global average pooling over the spatial (H, W) axes; inputs are NHWC
        x = x.mean(axis=(1, 2))
        x = nn.Dense(self.num_classes)(x)
        return x
    
model = simpleCNN_compact(channels=[4, 8, 16], num_classes=12)
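# Split again rather than reusing the subkey consumed earlier
random_key, subkey = jax.random.split(random_key)
params = model.init(subkey, x)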
print("Parameter:", params.keys(), type(params))

apply_fn = model.apply
model_output = apply_fn(params, x)
print("Model output:", model_output.shape)

## Manage randomness (dropout) and mutable parameters (batchnorm)
Let's add `BatchNorm` and `Dropout` to our model: 

In [None]:
# Compact definition with BatchNorm and Dropout
class complexCNN(nn.Module):
    # We specify constructing elements here
    channels : Sequence[int]
    num_classes : int
    train : bool = True
    
    @nn.compact
    def __call__(self, x):
        for i in range(len(self.channels)):
            x = nn.Conv(self.channels[i], kernel_size=(3, 3))(x)
            x = nn.Dropout(rate=0.5, deterministic=not self.train)(x)
            x = nn.BatchNorm(use_running_average=not self.train)(x)
            x = nn.relu(x)
        # Global average pooling over the spatial (H, W) axes; inputs are NHWC
        x = x.mean(axis=(1, 2))
        x = nn.Dense(self.num_classes)(x)
        return x
    
model = complexCNN(channels=[4, 8, 16], num_classes=12)

### 🔪Why `BatchNorm` and `Dropout` are problematic in JAX🔪
* `Dropout` and `BatchNorm` **behave differently during training and evaluation**
* `Dropout` introduces stochastic behavior, which we want to control explicitly through random keys
* `BatchNorm` introduces running statistics, which change with each forward pass during training and therefore cannot live inside a pure function
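
Stripped of Flax, the two issues look like this in plain JAX. The helpers `dropout` and `batch_norm_train` below are simplified illustrations written for this note, not Flax's implementations: randomness enters only through an explicit key, and the running statistic is passed in and returned instead of being mutated in place.

```python
import jax
import jax.numpy as jnp


def dropout(key, x, rate):
    """Training-mode dropout: randomness comes only from the explicit key."""
    keep = jax.random.bernoulli(key, 1.0 - rate, x.shape)
    return jnp.where(keep, x / (1.0 - rate), 0.0)


def batch_norm_train(x, running_mean, momentum=0.9, eps=1e-5):
    """Normalize with batch statistics and return the updated running mean."""
    mean, var = x.mean(axis=0), x.var(axis=0)
    new_running_mean = momentum * running_mean + (1.0 - momentum) * mean
    return (x - mean) / jnp.sqrt(var + eps), new_running_mean


key = jax.random.PRNGKey(0)
key, data_key, drop_key = jax.random.split(key, 3)
x = jax.random.normal(data_key, (32, 4)) + 3.0

same_a = dropout(drop_key, x, rate=0.5)
same_b = dropout(drop_key, x, rate=0.5)  # same key, same mask

running_mean = jnp.zeros(4)
# The state goes in as an argument and comes out as a return value.
y, running_mean = batch_norm_train(x, running_mean)
print(bool(jnp.array_equal(same_a, same_b)), running_mean.shape)
```

Flax's `rngs=` and `mutable=` arguments are bookkeeping for exactly these two patterns.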

In [None]:
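# Generate separate keys for parameter init, init-time dropout and the forward pass
random_key, key1, key2, key3 = jax.random.split(random_key, 4)

params = model.init({'params': key1, 'dropout': key2}, x)
print("Parameter:", params.keys(), type(params))
params = unfreeze(params)

# With `mutable`, apply returns (output, updated collections); the second
# element is a dict keyed by collection name, i.e. {'batch_stats': ...}
x_out, updates = model.apply(params, x, rngs={'dropout': key3},
                             mutable=['batch_stats'])
params['batch_stats'] = updates['batch_stats']
params = freeze(params)
print("Model output:", x_out.shape)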

Wow! That's quite a mouthful, let's break this down. 

#### Now our `params` contains a second top-level key, `batch_stats`!
### Remarks:
1. For dropout, pass a dictionary of random keys `{'params': key1, 'dropout': key2}` to `model.init`
2. During the forward pass, pass a fresh dropout key to `model.apply` via `rngs={'dropout': key3}` (don't forget to split your keys each time!)
3. Specify `mutable='batch_stats'` during the forward pass so that the running statistics can be updated; `apply` then returns a tuple `(output, updated_collections)`, where `updated_collections['batch_stats']` holds the new statistics
4. **Introduce a field `train` which switches between training and evaluation behavior.** Alternatively, we can assume that we're in `eval` mode whenever nothing is mutable (see the comprehensive intro notebook)
5. Flax manages parameters in `FrozenDict`s, which are immutable: `unfreeze` them before modifying entries in place and `freeze` them again afterwards

Also refer to [this excellent source](https://colab.research.google.com/github/google/flax/blob/main/docs/notebooks/linen_intro.ipynb#scrollTo=BBrbcEdCnQ4o) for more information.

Let us test that the forward function behaves as expected: in each mode we run the same input twice with different rng keys.

In [None]:
import jax.numpy as jnp

random_key, key1, key2 = jax.random.split(random_key, 3)

# Flax modules are frozen dataclasses, so `model.train = ...` raises an error;
# build one instance per mode instead (both share the same `params`)
train_model = complexCNN(channels=[4, 8, 16], num_classes=12, train=True)
eval_model = complexCNN(channels=[4, 8, 16], num_classes=12, train=False)

# Training mode: dropout is active, so different keys give different outputs
x_out1, stats1 = train_model.apply(params, x, rngs={'dropout': key1}, mutable=['batch_stats'])
x_out2, stats2 = train_model.apply(params, x, rngs={'dropout': key2}, mutable=['batch_stats'])
print('Difference between outputs for training forward pass:', jnp.abs(x_out1 - x_out2).sum().item())

# Evaluation mode: dropout is disabled and the running statistics are used,
# so the outputs agree regardless of the key and nothing needs to be mutable
x_out1 = eval_model.apply(params, x, rngs={'dropout': key1})
x_out2 = eval_model.apply(params, x, rngs={'dropout': key2})
print('Difference between outputs for evaluation forward pass:', jnp.abs(x_out1 - x_out2).sum().item())


# Head Surgery
We have access to an ImageNet-pretrained backbone with 3-channel input and 1000-class output, but we want a 6-channel model with 24-class output. What should we do?

The catch, again, is that weights are decoupled from application: we can declare one module inside another as in PyTorch, but the pretrained weights then have to be copied into the new parameter tree by hand.

In [None]:
# Let us grab a backbone and truncate it
from jax_resnet import pretrained_resnet, pretrained_resnest

resnet_template, backbone_params = pretrained_resnest(50)
backbone = resnet_template()
backbone = nn.Sequential(backbone.layers[:-2])

# The number of in_channels and num_classes we need
in_channels = 6
num_classes = 24

random_key, k1, k2 = jax.random.split(random_key, 3)
x_3 = jax.random.normal(k1, (1, 128, 128, 3))
x_6 = jax.random.normal(k2, (1, 128, 128, 6))

### Let's start with head surgery

We begin by creating a model wrapper which "wraps around the backbone"

In [None]:
class AddHeadtoBackbone(nn.Module):
    backbone : nn.Sequential
    num_classes : int
        
    def setup(self):
        # Use the module's own field, not the global variable of the same name
        self.head = nn.Dense(self.num_classes)
    
    def __call__(self, x):
        x = self.backbone(x)
        # Global average pooling over the spatial (H, W) axes; features are NHWC
        x = x.mean(axis=(1, 2))
        return self.head(x)
        
model = AddHeadtoBackbone(backbone=backbone, num_classes=num_classes)

The next step in creating a model in Flax is always to initialize it.

Note: because we provide a sample input, we do not even need to know how many channels the backbone outputs.

In [None]:
random_key, subkey = jax.random.split(random_key)
params = unfreeze(model.init(subkey, x_3))
params['params'].keys(), params['batch_stats'].keys()

We next load the pretrained weights into `params`. Because this modifies the parameter tree in place, it must be unfrozen first (which we did with `unfreeze` in the cell above).

In [None]:
params['params']['backbone'] = backbone_params['params']
params['batch_stats']['backbone'] = backbone_params['batch_stats']
params = freeze(params)

In [None]:
# Let's try a forward pass
x_out = model.apply(params, x_3, mutable=False)
# Great! we have what we want
x_out.shape, x_out.dtype