In [None]:
import tensorflow as tf
import tensorflow_datasets as tfds
from flax import jax_utils
import jax
import ml_collections

import input_pipeline
import train
from configs import resnet_v1 as config_lib

In [None]:
%matplotlib inline
import matplotlib
import matplotlib.pyplot as plt

import numpy as np

In [None]:
# Make no GPU visible to TensorFlow, so TF (used here only for the tf.data
# input pipeline) runs on CPU. Presumably this keeps TF from reserving GPU
# memory that JAX needs — confirm.
tf.config.experimental.set_visible_devices(devices=[], device_type="GPU")

In [None]:
# Load the config that was used to train the checkpoint.
# The module is resolved from `config_name` at runtime, so another file under
# `configs/` can be selected by changing that one string.
import importlib

config_name = "resnet_v1"
config_module_path = f'configs.{config_name}'
config = importlib.import_module(config_module_path).get_config()

In [None]:
# Show the loaded configuration as this cell's output (bare last expression).
config

In [None]:
# Build train/eval input iterators from the config used for training.
# `config.batch_size` is the global batch size; each JAX process loads
# `local_batch_size` examples of it.
if config.batch_size % jax.device_count() > 0:
    # Fail loudly instead of letting the `//` below silently drop examples.
    # Divisibility by the device count also implies an even split across
    # processes (assuming every process drives the same number of devices).
    raise ValueError('Batch size must be divisible by the number of devices')
local_batch_size = config.batch_size // jax.process_count()
input_dtype = train.get_input_dtype(config.half_precision)

dataset_builder = tfds.builder(config.dataset, data_dir=config.dataset_dir)
# Downloads/prepares the dataset into `config.dataset_dir` if needed.
dataset_builder.download_and_prepare()
train_iter = train.create_input_iter(
    dataset_builder, local_batch_size, input_dtype, train=True, config=config
)
eval_iter = train.create_input_iter(
    dataset_builder, local_batch_size, input_dtype, train=False, config=config
)

In [None]:
def display(display_list, num_images=None):
    """Plot images from a batch side by side, titled with their labels.

    Note: defining this function shadows IPython's built-in ``display`` for
    the rest of the notebook session.

    Args:
      display_list: sequence whose first element is a batch of images
        (indexable per image; each item must be accepted by
        ``tf.keras.utils.array_to_img``). An optional second element holds
        the matching labels, indexable in the same order as the images.
      num_images: how many images from the batch to show. Defaults to
        ``len(display_list)`` to keep the previous panel count. The value is
        always capped at the number of images in the batch.
    """
    if not display_list:
        return

    images = display_list[0]
    labels = display_list[1] if len(display_list) > 1 else None

    if num_images is None:
        num_images = len(display_list)
    # Never index past the end of the batch.
    num_images = min(num_images, len(images))
    if num_images <= 0:
        return

    # squeeze=False keeps `axes` 2-D even when only one panel is drawn.
    _, axes = plt.subplots(1, num_images, figsize=(10, 10), squeeze=False)
    for i, ax in enumerate(axes[0]):
        ax.imshow(tf.keras.utils.array_to_img(images[i]))
        if labels is not None:
            ax.set_title(f'Label: {labels[i]}')
        ax.axis('off')
    plt.show()

In [None]:
# Preview a few training batches, one figure per batch.
# Images and labels are both indexed with [0] so they stay aligned (this
# matches the eval preview). Presumably the leading axis is the per-device
# shard added by the input pipeline — TODO confirm.
for _ in range(5):
    train_batch = next(train_iter)
    image = train_batch["image"]
    label = train_batch["label"]
    display([np.array(image[0]), np.array(label[0])])

In [None]:
# Preview the first few evaluation batches, one figure per batch.
# Only the first slice along each array's leading axis is plotted, for both
# images and labels.
for _ in range(5):
    eval_batch = next(eval_iter)
    image, label = eval_batch["image"], eval_batch["label"]
    first_images = np.array(image[0])
    first_labels = np.array(label[0])
    display([first_images, first_labels])