In [None]:
import math

import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np
import pax
import tensorflow as tf
import tensorflow_datasets as tfds
from pax.nets import ResNet18
from tqdm.auto import tqdm

from pretrained_resnet18 import (IMAGENET_MEAN, IMAGENET_STD,
                                 load_pretrained_resnet18)


In [None]:
# --- Experiment configuration ---------------------------------------------
batch_size = 32  # examples per batch in both the train and test pipelines
test_data_size = 2048  # number of leading examples held out as the test set
num_training_steps = 10_000  # number of training batches to draw
learning_rate = 1e-4  # AdamW learning rate
weight_decay = 1e-4  # AdamW weight decay
classes = ["cat", "dog"]  # class names, indexed by the integer label

# Seed every source of randomness used by the notebook. pax's RNG key was
# already seeded; the tf.data pipeline additionally shuffles, flips and crops
# with TensorFlow random ops, which draw from TensorFlow's own global seed.
seed = 42
pax.seed_rng_key(seed)
tf.random.set_seed(seed)


In [None]:
@tf.function
def image_resize(batch):
    """Rescale an example's image and resize it to 255x255.

    Args:
        batch: mapping holding at least an "image" and a "label" entry.

    Returns:
        A dict with only two keys: "image" (cast to float32, divided by 255
        and resized to 255x255) and "label" (passed through unchanged).
    """
    # Presumably the input pixels are 0..255, so this maps them to [0, 1].
    scaled = tf.cast(batch["image"], tf.float32) / 255.0
    return {
        "image": tf.image.resize(scaled, (255, 255)),
        "label": batch["label"],
    }


# Take the "train" split of cats_vs_dogs, then run `image_resize` over every
# example with tf.data choosing the degree of parallelism.
cat_dog_data = tfds.load("cats_vs_dogs")["train"]
cat_dog_data = cat_dog_data.map(image_resize, num_parallel_calls=tf.data.AUTOTUNE)


@tf.function
def train_data_processing(batch):
    """Normalise and augment one batch of training images.

    Args:
        batch: mapping with a batched "image" tensor (channels last) and the
            matching "label" tensor.

    Returns:
        A dict with "image" (normalised with IMAGENET_MEAN / IMAGENET_STD,
        randomly flipped left-right and randomly cropped to
        (batch_size, 224, 224, 3)) and "label" (unchanged).
    """
    with tf.device("/GPU:0"):
        # Reshape the per-channel statistics so they broadcast over
        # (batch, height, width, channel).
        mean = tf.constant(IMAGENET_MEAN, dtype=tf.float32)[None, None, None, :]
        std = tf.constant(IMAGENET_STD, dtype=tf.float32)[None, None, None, :]
        normalized = (batch["image"] - mean) / std
        flipped = tf.image.random_flip_left_right(normalized)
        # NOTE(review): the crop is requested for the whole 4-D batch at once,
        # so presumably every image in the batch shares one crop window.
        # Confirm if per-image crops are wanted.
        cropped = tf.image.random_crop(flipped, (batch_size, 224, 224, 3))
        return {"image": cropped, "label": batch["label"]}


@tf.function
def test_data_processing(batch):
    """Normalise one batch of test images; no flipping or cropping is applied.

    Args:
        batch: mapping with a batched "image" tensor (channels last) and the
            matching "label" tensor.

    Returns:
        A dict with "image" (normalised with IMAGENET_MEAN / IMAGENET_STD)
        and "label" (unchanged).
    """
    # Reshape the per-channel statistics so they broadcast over
    # (batch, height, width, channel).
    mean = tf.constant(IMAGENET_MEAN, dtype=tf.float32)[None, None, None, :]
    std = tf.constant(IMAGENET_STD, dtype=tf.float32)[None, None, None, :]
    return {
        "image": (batch["image"] - mean) / std,
        "label": batch["label"],
    }


# Held-out evaluation set: the first `test_data_size` examples, batched and
# normalised. No shuffling or augmentation is applied.
test_data = (
    cat_dog_data
    .take(test_data_size)
    .batch(batch_size)
    .map(test_data_processing)
)


# Training stream: everything after the held-out examples.
train_data = (
    cat_dog_data
    .skip(test_data_size)  # leave out the examples reserved for test_data
    .repeat()  # cycle through the remaining examples indefinitely
    .shuffle(20 * batch_size)  # shuffle buffer worth 20 batches of examples
    .batch(batch_size)
    .map(train_data_processing)  # batch-level normalisation + augmentation
    .take(num_training_steps)  # end the stream after this many batches
    .prefetch(tf.data.AUTOTUNE)
)


In [None]:
def plot_images_with_label(images, labels):
    """Show every image in a near-square grid, captioned with its label.

    The grid has ``ceil(sqrt(n))`` columns and as many rows as needed, so no
    image is dropped when ``n`` is not a perfect square (for perfect squares
    the layout is the same N x N grid as before).

    Args:
        images: sequence of images accepted by ``plt.imshow``.
        labels: sequence of captions; ``labels[i]`` is drawn above
            ``images[i]``, so it needs at least ``len(images)`` entries.

    Returns:
        None. The figure is displayed with ``plt.show()``.
    """
    count = len(images)
    if count == 0:
        # Nothing to draw; a zero-sized figure is not valid.
        return

    cols = math.ceil(math.sqrt(count))
    rows = math.ceil(count / cols)
    # One inch per tile, matching the original figure scale.
    plt.figure(figsize=(cols, rows))
    for i in range(count):
        plt.subplot(rows, cols, i + 1)
        plt.imshow(images[i])
        # Place the caption just above the top-left corner of the image.
        plt.text(0, -2, labels[i])
        plt.axis("off")
    plt.show()


def show_test_images():
    """Plot the first batch of `test_data` with its ground-truth class names."""
    test_batch = next(iter(test_data))
    # Undo the mean/std normalisation and clamp to the displayable [0, 1]
    # range. The bounds are passed positionally because the `a_min` / `a_max`
    # keyword names are deprecated in newer JAX releases, while positional
    # bounds work on old and new versions alike.
    images = jnp.clip(
        test_batch["image"].numpy() * IMAGENET_STD + IMAGENET_MEAN, 0.0, 1.0
    )
    labels = [classes[i] for i in test_batch["label"].numpy().tolist()]
    plot_images_with_label(images, labels)


# Sanity-check the input pipeline by displaying the first test batch.
show_test_images()


In [None]:
# Start from the pretrained ResNet-18, freeze it, then replace its `logits`
# layer with a fresh 2-output linear layer (one output per entry in `classes`)
# and switch the model to training mode.
# NOTE(review): presumably freeze() marks the existing weights as
# non-trainable so that only the new `logits` layer is optimised. Confirm
# against the pax version in use.
resnet18 = load_pretrained_resnet18()
resnet18 = resnet18.freeze()
resnet18.logits = pax.nn.Linear(resnet18.logits.in_dim, 2)
resnet18 = resnet18.train()

# AdamW optimiser built from the model's current parameters.
opt = pax.optim.adamw(
    resnet18.parameters(),
    learning_rate=learning_rate,
    weight_decay=weight_decay,
)


In [None]:
def loss_fn(params: ResNet18, model: ResNet18, inputs) -> pax.utils.LossFnOutput:
    """Mean softmax cross-entropy of the model on one batch.

    Args:
        params: parameter tree merged into `model` through `model.update`.
        model: network called on the image batch.
        inputs: mapping with "image" (channels-last batch; axes are permuted
            to (0, 3, 1, 2) before the forward pass) and "label" (integer
            class ids).

    Returns:
        ``(loss, (loss, model))``: the scalar loss, plus an auxiliary pair
        carrying the same loss and the updated model.
    """
    model = model.update(params)
    # Move the channel axis from last to second position for the network.
    channels_first = jnp.transpose(inputs["image"], (0, 3, 1, 2))
    log_probs = jax.nn.log_softmax(model(channels_first), axis=-1)

    # One-hot masking keeps only the log-probability of the true class.
    targets = jax.nn.one_hot(inputs["label"], num_classes=log_probs.shape[-1])
    true_class_log_prob = jnp.sum(targets * log_probs, axis=-1)

    loss = -jnp.mean(true_class_log_prob)
    return loss, (loss, model)


# Build a training-step function from `loss_fn` and JIT-compile it.
# The training loop calls it as `fast_update_fn(model, optimizer, batch)` and
# unpacks the result as `(loss, model, optimizer)`.
update_fn = pax.utils.build_update_fn(loss_fn)
fast_update_fn = jax.jit(update_fn)


In [None]:
# Train on every batch of `train_data`, printing the mean loss of each window
# of `log_every` steps. The loop rebinds `resnet18` and `opt`, so re-running
# this cell continues training from the current state rather than restarting.
log_every = 1000
losses = 0.0  # running sum of the per-step losses in the current window

for step, batch in enumerate(tqdm(train_data), 1):
    # Convert the TensorFlow tensors to NumPy arrays before handing them to
    # JAX. `jax.tree_util.tree_map` exists in every JAX release, whereas the
    # top-level `jax.tree_map` alias is deprecated/removed in newer ones.
    batch = jax.tree_util.tree_map(lambda x: x.numpy(), batch)
    loss, resnet18, opt = fast_update_fn(resnet18, opt, batch)
    losses = losses + loss

    if step % log_every == 0:
        mean_loss = losses / log_every
        losses = 0.0
        print(f"[step {step}]  loss {mean_loss:.3f}")


In [None]:
# Evaluate on the first test batch only: print true vs. predicted label ids
# and plot the images captioned with the predicted class names.
test_batch = next(iter(test_data))
# `jax.tree_util.tree_map` works on every JAX release; the top-level
# `jax.tree_map` alias is deprecated/removed in newer ones.
test_batch = jax.tree_util.tree_map(lambda x: x.numpy(), test_batch)
test_image, test_label = test_batch["image"], test_batch["label"]
# Channels-last -> channels-first, the same permutation used in `loss_fn`.
test_image = jnp.transpose(test_image, axes=(0, 3, 1, 2))

# `classes` comes from the configuration cell; it is not redefined here so the
# two definitions cannot drift apart.
logit = resnet18.eval()(test_image)
predicted_label = jnp.argmax(logit, axis=-1)

predicted_class = [classes[i] for i in predicted_label.tolist()]
# Row 0: ground-truth label ids, row 1: predicted label ids.
print(jnp.stack([test_label, predicted_label], axis=-1).T)

# Undo the mean/std normalisation and clamp to [0, 1] for display. The bounds
# are positional because the `a_min` / `a_max` keywords are deprecated in
# newer JAX releases.
to_plot_test_image = test_batch["image"]
to_plot_test_image = jnp.clip(
    to_plot_test_image * IMAGENET_STD + IMAGENET_MEAN, 0.0, 1.0
)
plot_images_with_label(to_plot_test_image, predicted_class)