# Introduction

<a href="https://colab.research.google.com/github/ntt123/pax/blob/main/examples/notebooks/fine_tuning_resnet18.ipynb" target="_top"><img alt="Open In Colab" src="https://colab.research.google.com/assets/colab-badge.svg" style="vertical-align:text-bottom"></a>

This example shows how easy it is to fine-tune a `resnet18` model with `Pax`.

We will use the pretrained `resnet18` weights from the ``torchvision`` library. See ``pretrained_resnet18.py`` for more details on how we convert a `resnet18` model from PyTorch to Pax.

**Note**: to run this example, you will need to install `pytorch` and `torchvision`: 
```bash
    $ pip3 install torch torchvision
```

OK, how do we fine-tune `resnet18`?


```python
resnet18 = load_pretrained_resnet18()
resnet18 = resnet18.freeze()
resnet18.logits = pax.nn.Linear(resnet18.logits.in_dim, 2)
resnet18 = resnet18.train()
```

Freeze the model, then replace the ``logits`` module with a new linear module. That is all we need.

# Dataloader

We use `tensorflow_datasets` to load the `cats_vs_dogs` dataset.

**Note**: The ``data_processing`` function is executed on ``GPU`` to make the dataloader faster (even though this dataloader is still NOT fast enough).

In [None]:
# uncomment to install Pax
# !pip3 install -q git+https://github.com/ntt123/pax#egg=pax[test]

In [None]:
import math

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import opax
import pax
import tensorflow as tf
import tensorflow_datasets as tfds
from pax.nets import ResNet18
from tqdm.auto import tqdm

from pretrained_resnet18 import (IMAGENET_MEAN, IMAGENET_STD,
                                 load_pretrained_resnet18)


In [None]:
# Training configuration.
# A smaller batch is used when JAX runs on the CPU backend; any other backend
# gets the larger batch.
batch_size = 16 if jax.default_backend() == "cpu" else 128
test_data_size = 2048  # number of examples taken for the test split
num_training_steps = 1_000
learning_rate = 1e-3
weight_decay = 1e-4
classes = ["cat", "dog"]  # class names, indexed by the integer label
# Fix the Pax RNG seed (presumably for reproducible initialisation).
pax.seed_rng_key(42)

In [None]:
@tf.function
def image_resize(batch):
    """Resize the image of a dataset element to 224x224.

    Args:
        batch: mapping with an "image" and a "label" entry.

    Returns:
        A new dict holding the resized "image" and the untouched "label".
    """
    resized = tf.image.resize(batch["image"], (224, 224))
    return {"image": resized, "label": batch["label"]}


# Load the "train" split and resize every image up front so that examples
# share one spatial size before they are batched.
cat_dog_data = (
    tfds.load("cats_vs_dogs")["train"]
    .map(image_resize)
)


@tf.function
def data_processing(batch):
    """Convert images to float32, divide by 255 and normalise them.

    The normalisation subtracts ``IMAGENET_MEAN`` and divides by
    ``IMAGENET_STD``. The work is placed on the first GPU device.

    Args:
        batch: mapping with an "image" and a "label" entry; the image is
            assumed to be a batched, channels-last tensor -- TODO confirm.

    Returns:
        A new dict holding the normalised "image" and the untouched "label".
    """
    with tf.device('/device:gpu:0'):
        mean = tf.constant(IMAGENET_MEAN, dtype=tf.float32)
        std = tf.constant(IMAGENET_STD, dtype=tf.float32)
        pixels = tf.cast(batch["image"], tf.float32) / 255.0
        # Three leading singleton axes let the statistics broadcast over
        # every axis of the image tensor except the last one.
        normalized = (pixels - mean[None, None, None, :]) / std[None, None, None, :]
        return {"image": normalized, "label": batch["label"]}


# Test split: the first `test_data_size` examples, in full batches only.
test_data = cat_dog_data.take(test_data_size)
test_data = test_data.batch(batch_size, drop_remainder=True)
test_data = test_data.map(data_processing)


# Training stream: everything after the test examples, repeated and shuffled,
# truncated to exactly `num_training_steps` batches.
train_data = cat_dog_data.skip(test_data_size)
train_data = train_data.repeat()
train_data = train_data.shuffle(20 * batch_size)
train_data = train_data.batch(batch_size)
train_data = train_data.map(data_processing)
train_data = train_data.take(num_training_steps)
train_data = train_data.prefetch(1)

In [None]:
def plot_images_with_label(images, labels):
    """Draw images in a square grid, each one captioned with its label.

    The grid side is the integer part of ``sqrt(len(images))``; images beyond
    ``side * side`` are not drawn.

    Args:
        images: indexable collection of images accepted by ``plt.imshow``.
        labels: indexable collection of captions, aligned with ``images``.
    """
    side = int(math.sqrt(len(images)))
    plt.figure(figsize=(side, side))
    for cell in range(side * side):
        plt.subplot(side, side, cell + 1)
        plt.imshow(images[cell])
        # Caption sits just above the top-left corner of the image.
        plt.text(0, -2, labels[cell])
        plt.axis("off")
    plt.show()


def show_test_images():
    """Plot the first test batch with the class name of each image."""
    test_batch = next(iter(test_data))
    # Undo the mean/std normalisation so pixel values are displayable again.
    unnormalized = test_batch["image"].numpy() * IMAGENET_STD + IMAGENET_MEAN
    # Bounds are passed positionally: the `a_min`/`a_max` keywords were
    # deprecated and later removed from `jnp.clip`, while the positional form
    # works on both old and new JAX releases.
    images = jnp.clip(unnormalized, 0.0, 1.0)
    labels = [classes[i] for i in test_batch["label"].numpy().tolist()]
    plot_images_with_label(images, labels)


# Preview the first test batch together with its dataset labels.
show_test_images()

# Loss function and accuracy metric

In [None]:
def loss_fn(params: ResNet18, model: ResNet18, inputs) -> pax.utils.LossFnOutput:
    """Compute the mean softmax cross-entropy of ``model`` on one batch.

    Args:
        params: parameters to load into ``model`` before the forward pass.
        model: the network to evaluate.
        inputs: mapping with an "image" and a "label" entry.

    Returns:
        ``(loss, (loss, model))``: the scalar loss, plus an auxiliary tuple
        with the same loss and the model after the forward pass.
    """
    model = pax.update_params(model, params=params)
    # Move the last image axis to position 1 before calling the network.
    nchw_image = jnp.transpose(inputs["image"], (0, 3, 1, 2))
    logits = model(nchw_image)

    log_probs = jax.nn.log_softmax(logits, axis=-1)
    targets = jax.nn.one_hot(inputs["label"], num_classes=log_probs.shape[-1])
    # The one-hot mask keeps only the log-probability of the labelled class.
    target_log_prob = jnp.sum(targets * log_probs, axis=-1)
    loss = -jnp.mean(target_log_prob)
    return loss, (loss, model)

def test_loss_fn(model: ResNet18, inputs):
    """Return the loss of ``model`` on one batch, with the model in eval mode."""
    eval_model = model.eval()
    loss, _ = loss_fn(eval_model.parameters(), eval_model, inputs)
    return loss

# Training step built from `loss_fn`, plus its jitted version.
update_fn = pax.utils.build_update_fn(loss_fn)
fast_update_fn = pax.jit(update_fn)

# Jitted evaluation loss.
fast_test_loss_fn = pax.jit(test_loss_fn)

@pax.jit
def predict(model, images):
    """Return the arg-max class index for every image in ``images``.

    The model is switched to eval mode and the last image axis is moved to
    position 1 before the forward pass.
    """
    eval_model = model.eval()
    nchw_images = jnp.transpose(images, (0, 3, 1, 2))
    return jnp.argmax(eval_model(nchw_images), axis=-1)

def test_accuracy(model, test_data):
    """Compute the fraction of correctly classified examples in ``test_data``.

    Args:
        model: the network passed on to ``predict``.
        test_data: iterable of batches; each batch maps "image" and "label"
            to objects exposing ``.numpy()``.

    Returns:
        Number of correct predictions divided by the number of examples seen.

    Raises:
        ValueError: if ``test_data`` yields no examples.
    """
    num_correct_predictions, total = 0, 0
    for batch in test_data:
        # `jax.tree_util.tree_map` is the stable spelling; the top-level
        # `jax.tree_map` alias was deprecated and removed in newer JAX.
        batch = jax.tree_util.tree_map(lambda x: x.numpy(), batch)
        predicted_label = predict(model, batch["image"])
        num_correct_predictions = num_correct_predictions + jnp.sum(
            predicted_label == batch["label"]
        )
        total = total + batch["image"].shape[0]

    if total == 0:
        # Without this guard the function would fail on `int.item()` with an
        # unhelpful AttributeError.
        raise ValueError("test_data is empty; accuracy is undefined")
    return int(num_correct_predictions) / total

# Fine-tuning model

In [None]:
def train(log_every: int = 100):
    """Fine-tune a pretrained ResNet18 on ``train_data`` and return it.

    The pretrained model is frozen and its ``logits`` module is replaced by a
    fresh linear layer with one output per entry of ``classes``. The model is
    then optimised with AdamW, one step per batch of ``train_data``.

    Args:
        log_every: positive number of training steps between two progress
            reports. Each report prints the mean training loss since the
            previous report and the mean loss over ``test_data``.

    Returns:
        The fine-tuned model.
    """
    resnet18 = load_pretrained_resnet18()
    resnet18 = resnet18.freeze()
    # One output per class name instead of a hard-coded head size.
    resnet18.logits = pax.nn.Linear(resnet18.logits.in_dim, len(classes))
    resnet18 = resnet18.train()

    opt = opax.adamw(
        learning_rate=learning_rate,
        weight_decay=weight_decay,
    )(resnet18.parameters())

    running_loss = 0.0

    for step, batch in enumerate(tqdm(train_data), 1):
        # `jax.tree_util.tree_map` is the stable spelling; the top-level
        # `jax.tree_map` alias was deprecated and removed in newer JAX.
        batch = jax.tree_util.tree_map(lambda x: x.numpy(), batch)
        (resnet18, opt), loss = fast_update_fn((resnet18, opt), batch)
        running_loss = running_loss + loss

        if step % log_every == 0:
            train_loss = running_loss / log_every
            running_loss = 0.0

            # Distinct names keep the evaluation loop from clobbering the
            # training loop's `batch` and `loss`.
            total_test_loss, num_test_batches = 0.0, 0
            for test_batch in test_data:
                test_batch = jax.tree_util.tree_map(lambda x: x.numpy(), test_batch)
                total_test_loss = total_test_loss + fast_test_loss_fn(resnet18, test_batch)
                num_test_batches += 1
            # Counting batches avoids `len(test_data)`, which requires the
            # dataset cardinality to be known.
            if num_test_batches:
                test_loss = total_test_loss / num_test_batches
            else:
                test_loss = float("nan")
            print(f"[step {step}]  train loss {train_loss:.3f}  test loss {test_loss:.3f}")

    print("Final test accuracy:", test_accuracy(resnet18, test_data))
    return resnet18

# Run fine-tuning and keep the returned model for the evaluation cell below.
resnet18 = train()

# Evaluation

In [None]:
def plot_model_prediction(model):
    """Plot the first test batch, captioning each image with the predicted class.

    Args:
        model: the network used for prediction; it is evaluated through the
            jitted ``predict`` helper, which switches it to eval mode.
    """
    test_batch = next(iter(test_data))
    # `jax.tree_util.tree_map` is the stable spelling; the top-level
    # `jax.tree_map` alias was deprecated and removed in newer JAX.
    test_batch = jax.tree_util.tree_map(lambda x: x.numpy(), test_batch)
    test_image = test_batch["image"]

    # Reuse the shared helper instead of repeating transpose + arg-max here.
    predicted_label = predict(model, test_image)
    predicted_class = [classes[i] for i in predicted_label.tolist()]

    # Undo the mean/std normalisation so pixel values are displayable again.
    # Bounds are positional because the `a_min`/`a_max` keywords were
    # deprecated and later removed from `jnp.clip`.
    to_plot_test_image = jnp.clip(
        test_image * IMAGENET_STD + IMAGENET_MEAN, 0.0, 1.0
    )
    plot_images_with_label(to_plot_test_image, predicted_class)

# Show the fine-tuned model's predictions on the first test batch.
plot_model_prediction(resnet18)