## Mixed-precision Training

<a href="https://colab.research.google.com/github/ntt123/pax/blob/main/examples/notebooks/mixed_precision.ipynb" target="_top"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" style="vertical-align:text-bottom"></a>

### Introduction

This tutorial shows how to train a U-Net image segmentation network with mixed precision. You will learn to:

- load an image dataset with `tensorflow_datasets` and apply data augmentation with `tf.image`,
- define a modified U-Net in Pax,
- define a mixed-precision policy for our training.


*Note*: this tutorial partially follows https://www.tensorflow.org/tutorials/images/segmentation.

### Oxford-IIIT Pet Dataset

We will use `tensorflow_datasets` to load the Pet dataset. Each data record includes a color image and a label image.

Test images are resized to `(128, 128)`. 

Train images are resized to `(150, 150)` and then randomly cropped to size `(128, 128)` using the `tf.image.random_crop` function.
Moreover, we use `tf.image.random_flip_{left_right,up_down}`  functions for additional data augmentations.

In [None]:
# uncomment to install Pax
# !pip3 install -q git+https://github.com/ntt123/pax#egg=pax[test]

In [None]:
from collections import OrderedDict

import jax
import jax.numpy as jnp
import jmp
import matplotlib.pyplot as plt
import opax
import pax
import tensorflow as tf
import tensorflow_datasets as tfds

# Seed Pax's RNG key with a fixed value (presumably so that parameter
# initialisation is reproducible across runs -- confirm against the Pax docs).
pax.seed_rng_key(42)

BATCH_SIZE = 64  # batch size used by both the training and the test pipelines
BUFFER_SIZE = 1000  # shuffle-buffer size of the training pipeline
TRAINING_STEPS = 2_000  # number of training batches consumed by the training loop

In [None]:
def normalize(input_image, input_mask):
    """Rescale an image to [-1, 1] and shift its mask labels down by one.

    Args:
        input_image: image tensor; values in [0, 255] are mapped to [-1, 1].
        input_mask: label tensor; every label is decreased by 1.

    Returns:
        The `(image, mask)` pair after rescaling / shifting.
    """
    # [0, 255] -> [0, 1] -> [-1, 1]
    rescaled = tf.cast(input_image, tf.float32) / 255.0
    rescaled = rescaled * 2.0 - 1.0
    input_mask -= 1
    return rescaled, input_mask


def load_image(datapoint):
    """Resize one dataset record to 128x128 and normalize it.

    Args:
        datapoint: mapping with the keys "image" and "segmentation_mask".

    Returns:
        The normalized `(image, mask)` pair produced by `normalize`.
    """
    target_size = (128, 128)
    image = tf.image.resize(datapoint["image"], target_size)
    # The mask is resized with nearest-neighbour so no new label values
    # are produced by interpolation.
    mask = tf.image.resize(
        datapoint["segmentation_mask"],
        target_size,
        method=tf.image.ResizeMethod.NEAREST_NEIGHBOR,
    )
    return normalize(image, mask)


def train_load_image(datapoint):
    """Resize one training record to 150x150 without normalizing it.

    Both the image and the mask are resized with nearest-neighbour
    interpolation. Normalization is not applied here.

    Args:
        datapoint: mapping with the keys "image" and "segmentation_mask".

    Returns:
        The resized `(image, mask)` pair.
    """
    nearest = tf.image.ResizeMethod.NEAREST_NEIGHBOR
    size = (150, 150)
    image = tf.image.resize(datapoint["image"], size, method=nearest)
    mask = tf.image.resize(datapoint["segmentation_mask"], size, method=nearest)
    return image, mask


def train_data_aug(image, label):
    """Randomly flip and crop an image together with its mask, then normalize.

    Args:
        image: 3-channel image tensor.
        label: 1-channel mask tensor with the same spatial size as `image`.

    Returns:
        The normalized `(image, label)` pair, cropped to 128x128.
    """
    # Stack image and mask along the channel axis so that every random
    # transform is applied identically to both.
    stacked = tf.concat((image, label), axis=-1)
    stacked = tf.image.random_flip_left_right(stacked)
    stacked = tf.image.random_flip_up_down(stacked)
    stacked = tf.image.random_crop(stacked, size=(128, 128, 4))
    # Channels 0..2 hold the image, the remaining channel holds the mask.
    return normalize(stacked[..., :3], stacked[..., 3:])


# Load the Oxford-IIIT Pet dataset (any 3.x version) together with its metadata.
dataset, info = tfds.load("oxford_iiit_pet:3.*.*", with_info=True)

# Resize every record; augmentation of training records happens further below.
train_images = dataset["train"].map(
    train_load_image,
    num_parallel_calls=tf.data.AUTOTUNE,
)
test_images = dataset["test"].map(
    load_image,
    num_parallel_calls=tf.data.AUTOTUNE,
)

# Training stream: cache the resized records, repeat indefinitely, shuffle,
# augment each sample, then batch and prefetch one batch ahead.
train_batches = train_images.cache().repeat().shuffle(BUFFER_SIZE, seed=42)
train_batches = train_batches.map(train_data_aug, num_parallel_calls=tf.data.AUTOTUNE)
train_batches = train_batches.batch(BATCH_SIZE).prefetch(1)

# Test stream: full batches only (the trailing partial batch is dropped).
test_batches = test_images.batch(BATCH_SIZE, drop_remainder=True)


def display(display_list):
    """Plot the given arrays side by side in a single row.

    Args:
        display_list: sequence of up to three image-like arrays, titled in
            order "Input Image", "True Mask" and "Predicted Mask".
    """
    captions = ["Input Image", "True Mask", "Predicted Mask"]
    n_panels = len(display_list)

    plt.figure(figsize=(5, 5))
    for idx, item in enumerate(display_list):
        plt.subplot(1, n_panels, idx + 1)
        plt.title(captions[idx])
        plt.imshow(tf.keras.preprocessing.image.array_to_img(item))
        plt.axis("off")
    plt.show()


# Preview the first sample of two augmented training batches.
for images, masks in train_batches.take(2):
    sample_image = images[0]
    sample_mask = masks[0]
    # Map the image from [-1, 1] back to [0, 1] for plotting.
    display([sample_image * 0.5 + 0.5, sample_mask])


### U-Net Image Segmentation

We will implement U-Net, a famous image segmentation network, described in https://arxiv.org/abs/1505.04597. 

One major difference between our version and the original U-Net is that we use `SAME` padding instead of `VALID` padding. 

As a result, our network is not a fully convolutional network. However, this is not important in our case because we use a fixed image size of `(128, 128)`.

In [None]:
def double_conv(in_channels, out_channels, name=None):
    """Build two 3x3 `SAME`-padded convolutions, each followed by a ReLU.

    Args:
        in_channels: number of input channels of the first convolution.
        out_channels: number of output channels of both convolutions.
        name: optional name given to the resulting `pax.nn.Sequential`.

    Returns:
        A `pax.nn.Sequential` of conv -> relu -> conv -> relu.
    """
    first = pax.nn.Conv2D(
        in_channels, out_channels, (3, 3), padding="SAME", name="conv_1"
    )
    second = pax.nn.Conv2D(
        out_channels, out_channels, (3, 3), padding="SAME", name="conv_2"
    )
    layers = (first, jax.nn.relu, second, jax.nn.relu)
    return pax.nn.Sequential(*layers, name=name)


class UNet(pax.Module):
    """A U-Net with a five-level contracting path and a four-level expanding path.

    Contracting path: level `i` (0..4) applies a `double_conv` producing
    `64 * 2**i` channels; every level except the last one is followed by a
    max pooling with window and stride `(2, 2, 1)`.

    Expanding path: for levels 3 down to 0, a stride-2 transposed convolution
    halves the channel count, the result is concatenated with the feature map
    saved at the same level of the contracting path, and a `double_conv`
    reduces the channels to `64 * 2**i`.

    A final 1x1 convolution maps the 64 remaining channels to `out_channels`.
    """

    def __init__(self, in_channels: int, out_channels: int):
        super().__init__()

        # Contracting path. Modules are created level by level, in the same
        # order in which they are stored.
        encoder = []
        for level in range(5):
            width = 64 * (2**level)
            stage = OrderedDict()
            stage["double_conv"] = double_conv(
                in_channels if level == 0 else width // 2,
                width,
                name=f"down_{level}",
            )
            if level == 4:
                # The deepest level is not followed by a pooling step.
                stage["downsampling"] = None
            else:
                stage["downsampling"] = pax.utils.Lambda(
                    lambda x: pax.nn.max_pool(
                        x, (2, 2, 1), (2, 2, 1), padding="VALID"
                    ),
                    name="down_sample",
                )
            encoder.append(stage)
        self.register_module_subtree("down_path", encoder)

        # Expanding path, from the deepest level back up to level 0.
        decoder = []
        for level in reversed(range(4)):
            width = 64 * (2**level)
            stage = OrderedDict()
            stage["upsampling"] = pax.nn.Conv2DTranspose(
                2 * width, width, 2, 2, padding="VALID", name="up_sample"
            )
            stage["double_conv"] = double_conv(2 * width, width, name=f"up_{level}")
            decoder.append(stage)
        self.register_module_subtree("up_path", decoder)

        self.output = pax.nn.Conv2D(
            64, out_channels, 1, padding="VALID", name="output"
        )

    def __call__(self, x):
        """Run the network on `x` and return the per-pixel output of the 1x1 conv."""
        skips = []
        deepest = len(self.down_path) - 1
        for level, stage in enumerate(self.down_path):
            x = stage["double_conv"](x)
            if level < deepest:
                # Remember the feature map for the skip connection, then pool.
                skips.append(x)
                x = stage["downsampling"](x)

        # The expanding path consumes the saved feature maps deepest-first.
        for stage, skip in zip(self.up_path, reversed(skips)):
            x = stage["upsampling"](x)
            x = jnp.concatenate((skip, x), axis=-1)
            x = stage["double_conv"](x)

        return self.output(x)

### Mixed-precision Policy

The idea of mixed-precision training is to compute different parts of the network with different floating-point precisions. Usually, we can compute linear functions with `float16` precision and non-linear activation functions with `float32` precision.

We will compute `Conv2D` and `Conv2DTranspose` modules with `float16` precision while keeping all the parameters and activations in `float32` precision. We also use the optimizer with `float32` precision.


Because `float16` has limited resolution near `0`, small activation gradients often cannot be represented and underflow to `0`, which is undesirable. 
To fix this problem, we scale the loss value by a large factor, which shifts the activation gradients into the range representable by `float16`. 

We use the `jmp.DynamicLossScale` by default, which will automatically select and adjust the scaling factor.

**Note**: 

* We apply `mp_policy_fn` to all submodules of `unet` recursively with the `apply` method. 
* We apply `optimizer_policy` to the `optimizer` using the `mixed_precision` method with `method_name="step"`.

In [None]:
# `half` is the low-precision dtype selected by jmp; `full` is float32.
half, full = jmp.half_dtype(), jnp.float32

# Convolutions: full-precision parameters and outputs, half-precision compute.
linear_policy = jmp.Policy(param_dtype=full, compute_dtype=half, output_dtype=full)

# Output layer: its result feeds log_softmax, which needs full precision.
# NOTE(review): this is currently identical to `linear_policy`.
output_policy = jmp.Policy(param_dtype=full, compute_dtype=half, output_dtype=full)

# The optimizer needs full precision everywhere.
optimizer_policy = jmp.Policy(param_dtype=full, compute_dtype=full, output_dtype=full)


def mp_policy_fn(mod):
    """Return `mod` with the matching mixed-precision policy applied, if any.

    The module named "output" gets `output_policy`; any other `Conv2D` or
    `Conv2DTranspose` gets `linear_policy`; everything else is returned as is.
    """
    if mod.name == "output":
        return mod.mixed_precision(output_policy)
    if isinstance(mod, (pax.nn.Conv2D, pax.nn.Conv2DTranspose)):
        return mod.mixed_precision(linear_policy)
    # No policy applies: leave the module unchanged.
    return mod


# Choose exactly one loss-scaling strategy:
# loss_scale = jmp.StaticLossScale(2 ** 10)
# loss_scale = jmp.NoOpLossScale()
loss_scale = jmp.DynamicLossScale(jmp.half_dtype()(2 ** 15), period=2000)

# With `NoOpLossScale` the model and the optimizer stay in plain full precision.
use_mixed_precision = not isinstance(loss_scale, jmp.NoOpLossScale)

unet = UNet(3, 3)
if use_mixed_precision:
    unet = unet.apply(mp_policy_fn)

optimizer = opax.adam(1e-3)(unet.parameters())
if use_mixed_precision:
    optimizer = optimizer.mixed_precision(optimizer_policy, method_name="step")

print(unet.summary())

Our loss function and update function have an additional `loss_scale` parameter to scale the loss and unscale the gradient when needed.

**Note**: we pass `all_finite=grads_finite` to the `optimizer.step` method so that the optimizer skips the update whenever the gradients contain non-finite (`nan` or `inf`) values.

In [None]:
def loss_fn(params, model, loss_scale: jmp.LossScale, inputs) -> pax.utils.LossFnOutput:
    """Compute the scaled mean cross-entropy of `model` on one batch.

    Args:
        params: parameters loaded into `model` before the forward pass.
        model: the network to evaluate.
        loss_scale: loss scale used to scale the returned loss.
        inputs: `(image, label)` pair; class ids are read from `label[..., 0]`.

    Returns:
        `(scaled_loss, (scaled_loss, model))`. Note that the auxiliary loss is
        the *scaled* loss as well; callers must unscale it themselves.
    """
    images, masks = inputs
    model = model.update(params)
    logits = model(images)

    targets = jax.nn.one_hot(masks[..., 0], num_classes=logits.shape[-1])
    log_probs = jax.nn.log_softmax(logits, axis=-1)
    # Log-likelihood of the true class at every position.
    true_class_ll = jnp.sum(targets * log_probs, axis=-1)

    scaled_loss = loss_scale.scale(-jnp.mean(true_class_ll))
    return scaled_loss, (scaled_loss, model)


@pax.jit
def test_loss_fn(model, loss_scale: jmp.LossScale, inputs):
    """Return the unscaled loss of `model` in eval mode on one batch."""
    eval_model = model.eval()
    scaled_loss, _ = loss_fn(eval_model.parameters(), eval_model, loss_scale, inputs)
    # `loss_fn` returns a scaled loss; undo the scaling for reporting.
    return loss_scale.unscale(scaled_loss)


@pax.jit
def update_fn(model, optimizer, loss_scale, inputs, total_loss):
    """Run one training step and accumulate the unscaled loss.

    Args:
        model: the network being trained.
        optimizer: optimizer whose `step` method produces the new parameters.
        loss_scale: loss scale used by `loss_fn`; a `jmp.DynamicLossScale` is
            additionally adjusted according to whether the gradients are finite.
        inputs: `(image, label)` batch forwarded to `loss_fn`.
        total_loss: running sum of unscaled losses.

    Returns:
        `(total_loss, model, optimizer, loss_scale)` after the step.
    """
    grads, (scaled_loss, model) = jax.grad(loss_fn, has_aux=True)(
        model.parameters(), model, loss_scale, inputs
    )
    # Undo the scaling of both the gradients and the loss with the SAME loss
    # scale that `loss_fn` used. `loss_scale` may be replaced by `adjust`
    # below, and unscaling the loss with the adjusted scale would misreport
    # the loss on every step where the scale changes.
    grads = loss_scale.unscale(grads)
    loss = loss_scale.unscale(scaled_loss)
    # NOTE: gradients are deliberately not cast with
    # `linear_policy.cast_to_param`; it is not needed here.

    if isinstance(loss_scale, jmp.DynamicLossScale):
        grads_finite = jmp.all_finite(grads)
        # Adjust our loss scale depending on whether gradients were finite. The
        # loss scale will be periodically increased if gradients remain finite and
        # will be decreased if not.
        loss_scale = loss_scale.adjust(grads_finite)
        # Only apply our optimizer if grads are finite, if any element of any
        # gradient is non-finite the whole update is discarded
        # including the optimizer's internal states.
        params = optimizer.step(grads, model.parameters(), all_finite=grads_finite)
    else:
        # With static or no loss scaling just apply our optimizer.
        params = optimizer.step(grads, model.parameters())

    model = model.update(params)
    total_loss = total_loss + loss
    return total_loss, model, optimizer, loss_scale

Let us start training for `2000` steps.

In [None]:
from tqdm.auto import tqdm

# Report the averaged training loss and evaluate on the test set every
# `LOG_EVERY` training steps.
LOG_EVERY = 200

tr = tqdm(train_batches.take(TRAINING_STEPS).enumerate(1), desc="training")
total_loss = jnp.array(0.0, dtype=jnp.float32)

for step, batch in tr:
    # `enumerate` yields the step as a scalar tensor; convert it to a plain
    # int so that it formats cleanly in the log line below.
    step = int(step)
    # `jax.tree_util.tree_map` replaces the deprecated top-level
    # `jax.tree_map` alias, which newer JAX releases no longer provide.
    batch = jax.tree_util.tree_map(lambda x: x.numpy(), batch)
    total_loss, unet, optimizer, loss_scale = update_fn(
        unet, optimizer, loss_scale, batch, total_loss
    )
    if step % LOG_EVERY == 0:
        # Average the loss accumulated since the last report, then reset it.
        loss = total_loss / LOG_EVERY
        total_loss = jnp.array(0.0, dtype=jnp.float32)

        test_losses = 0.0
        for test_batch in tqdm(test_batches, desc="evaluating", leave=False):
            test_batch = jax.tree_util.tree_map(lambda x: x.numpy(), test_batch)
            test_losses = test_losses + test_loss_fn(unet, loss_scale, test_batch)
        test_loss = test_losses / len(test_batches)

        tr.write(
            f"[step {step}]  train loss {loss:.3f}  test loss {test_loss:.3f}  loss_scale  {loss_scale.loss_scale}"
        )

### Performance

Using mixed precision offers a `1.5x` speed-up when training our U-Net on a Tesla T4.

In [None]:
# On a Tesla T4
# - jmp.DynamicLossScale  .... 1.82 it/s  155% 
# - jmp.NoOpLossScale     .... 1.16 it/s  100%

### Evaluation

Let us plot the predicted label for images in the test set.

In [None]:
# Iterator over the test batches; each `next(test_iter)` yields the next batch.
test_iter = iter(test_batches)

In [None]:
# Fetch the next test batch and convert its tensors to NumPy arrays.
# `jax.tree_util.tree_map` replaces the deprecated top-level `jax.tree_map`
# alias, which newer JAX releases no longer provide.
test_batch = jax.tree_util.tree_map(lambda x: x.numpy(), next(test_iter))
logits = unet.eval()(test_batch[0])
# Predicted class per pixel: the index of the largest logit.
label = jnp.argmax(logits, axis=-1)
for i in range(3):
    # Show the input (mapped back to [0, 1]), the true mask and the prediction.
    display([test_batch[0][i] * 0.5 + 0.5, test_batch[1][i], label[i, :, :, None]])