In [None]:
import os

# Pretrained checkpoint to fine-tune from, expected locally as `<checkpoint_name>.npz`.
checkpoint_name = "imagenet21k_ViT-B_32"
assert os.path.exists(f"{checkpoint_name}.npz"), (
    f"Checkpoint {checkpoint_name}.npz not found in the working directory."
)

# Model architecture; should match the architecture of `checkpoint_name`.
# Kept as its own variable so the Colab form default is one of the listed options.
model_name = "ViT-B_32"  # @param ["ViT-B_32", "Mixer-B_16"]


from absl import logging
import flax
import jax

# from matplotlib import pyplot as plt
import numpy as np
import optax
import tqdm
import tensorflow as tf
import pandas as pd
from matplotlib import pyplot as plt

# Show absl log messages at INFO level and above.
logging.set_verbosity(logging.INFO)

# Lists the devices visible to this process (this is the cell's displayed output).
# A CPU/GPU runtime typically shows a single device; a TPU runtime shows
# several (8 cores on the classic Colab TPU runtime).
jax.local_devices()

In [None]:
import sys

if not any(entry == "./vision_transformer" for entry in sys.path):
    # Make the local `vision_transformer` checkout importable (presumably it
    # provides the `vit_jax` package imported in the next cell). The guard
    # keeps re-runs of this cell from appending duplicate entries.
    sys.path.append("./vision_transformer")


# Helper functions for images.

# Per-dataset class names, indexed by the integer class label.
labelnames = {
    "skin": ("benign", "malignant"),
}


def make_label_getter(dataset):
    """Return a callable that maps a class index of `dataset` to a display name.

    Datasets listed in `labelnames` map to their class-name tuple; any other
    dataset falls back to the generic string "label=<index>". The
    `labelnames` lookup happens on every call, not when the getter is built.
    """

    def label_to_name(label):
        if dataset not in labelnames:
            return f"label={label}"
        return labelnames[dataset][label]

    return label_to_name


def show_img(img, ax=None, title=None):
    """Draw a single image on a matplotlib Axes with tick marks hidden.

    Args:
      img: image array accepted by `Axes.imshow` (e.g. HxWx3 floats in [0, 1],
        as produced by `show_img_grid`'s denormalization).
      ax: Axes to draw on; defaults to the current Axes (`plt.gca()`).
      title: optional Axes title; skipped when empty or None.
    """
    if ax is None:
        ax = plt.gca()
    ax.imshow(img)
    ax.set_xticks([])
    ax.set_yticks([])
    if title:
        ax.set_title(title)


def show_img_grid(imgs, titles):
    """Plot images on a roughly square grid of subplots.

    Args:
      imgs: sized sequence of images normalized to [-1, 1] (the range produced
        by `preprocess_image`); each is mapped back to [0, 1] before drawing.
      titles: iterable of per-image titles, consumed in step with `imgs`.

    Grid cells left over (when the count is not a perfect square) are hidden.
    Does nothing when `imgs` is empty.
    """
    if len(imgs) == 0:
        return
    n = int(np.ceil(len(imgs) ** 0.5))
    # squeeze=False keeps `axs` 2-D even for a 1x1 grid, where plt.subplots
    # would otherwise return a bare Axes that cannot be indexed.
    _, axs = plt.subplots(n, n, figsize=(3 * n, 3 * n), squeeze=False)
    flat_axes = axs.ravel()
    num_drawn = 0
    for ax, img, title in zip(flat_axes, imgs, titles):
        show_img((img + 1) / 2, ax, title)  # Denormalize [-1, 1] -> [0, 1].
        num_drawn += 1
    # Hide the unused trailing cells instead of leaving empty frames.
    for ax in flat_axes[num_drawn:]:
        ax.axis("off")


"""### Load dataset"""


def preprocess_image(image_path, label, skin, img_size, base_dir="Fitzpatric_subset"):
    """Load one image from disk and prepare it as model input.

    Intended to run inside a `tf.data` map, so the arguments may be tensors.

    Args:
      image_path: image file stem relative to `base_dir`, without the ".jpg"
        extension (the extension is appended here).
      label: passed through unchanged (a one-hot vector in this notebook).
      skin: skin-color value; presumably numeric — TODO confirm CSV dtype.
      img_size: side length of the square output image.
      base_dir: directory that `image_path` is relative to.

    Returns:
      (image, label, skin): `image` is a float32 [img_size, img_size, 3]
      tensor scaled to [-1, 1]; `skin` is cast to int32.
    """
    # Builds "<base_dir>/<image_path>.jpg".
    full_path = tf.strings.join([base_dir, "/", image_path, ".jpg"])

    # Decode to 3-channel uint8 so grayscale JPEGs also come out as RGB.
    image = tf.image.decode_jpeg(tf.io.read_file(full_path), channels=3)

    # Resizing returns float32; bilinear interpolation keeps values within the
    # original [0, 255] range.
    image = tf.image.resize(image, (img_size, img_size))

    # Map [0, 255] -> [-1, 1]; `show_img_grid` inverts this with (x + 1) / 2.
    image = image / 127.5 - 1.0

    # Explicit cast instead of Python int(): inside tf.data `skin` is a
    # symbolic tensor, and int() on it only works if AutoGraph rewrites it.
    return image, label, tf.cast(skin, tf.int32)


def load_dataset_from_csv(
    csv_file, img_size, batch_size, mode, num_classes=None, shuffle_buffer_size=None
):
    """Build a batched `tf.data` pipeline of images, one-hot labels and skin values.

    The CSV must provide an `image_path` column (file stem, see
    `preprocess_image`), a `lesion` column of integer class indices and a
    `skin_color` column.

    Args:
      csv_file: path to the CSV describing the split.
      img_size: side length the images are resized to.
      batch_size: number of examples per batch.
      mode: "train" shuffles individual examples (reshuffled every epoch) and
        repeats forever; any other value yields a single ordered pass.
      num_classes: one-hot depth. Defaults to the number of distinct `lesion`
        values in this CSV, which is too small for a split missing a class —
        pass it explicitly to keep train/test label widths consistent.
      shuffle_buffer_size: number of examples held in the shuffle buffer in
        train mode; defaults to the whole split (cheap, only paths are buffered).

    Returns:
      A dataset of dicts {"image", "label", "skin"}, each with a leading axis of
      size 1 before the batch axis, e.g. image [1, batch, img_size, img_size, 3].
    """
    df = pd.read_csv(csv_file)

    if num_classes is None:
        num_classes = df["lesion"].nunique()
    # One vectorized one-hot for the whole column: float32 [N, num_classes].
    labels = tf.one_hot(df["lesion"].to_numpy(), num_classes)

    dataset = tf.data.Dataset.from_tensor_slices(
        (df["image_path"].values, labels, df["skin_color"].tolist())
    )

    if mode == "train":
        # Shuffle single examples while they are still just paths (cheap to
        # buffer), before the expensive decode and before batching, so batch
        # composition changes every epoch.
        buffer_size = shuffle_buffer_size or max(len(df), 1)
        dataset = dataset.shuffle(buffer_size=buffer_size).repeat()

    def to_example(img_path, label, skin):
        image, label, skin = preprocess_image(img_path, label, skin, img_size)
        return {"image": image, "label": label, "skin": skin}

    dataset = dataset.map(to_example, num_parallel_calls=tf.data.AUTOTUNE)
    dataset = dataset.batch(batch_size)

    # Prepend a device axis of size 1 to every field, giving the
    # [num_local_devices, local_batch_size, ...] layout for a single device.
    # NOTE(review): assumes one local device; multi-device runs would need each
    # batch reshaped to [num_devices, batch_size // num_devices, ...].
    def add_device_axis(batch):
        return {key: tf.expand_dims(value, axis=0) for key, value in batch.items()}

    dataset = dataset.map(add_device_axis)

    return dataset.prefetch(buffer_size=tf.data.AUTOTUNE)

In [None]:
from vit_jax import checkpoint
from vit_jax import input_pipeline
from vit_jax import utils
from vit_jax import models
from vit_jax import train
from vit_jax.configs import common as common_config
from vit_jax.configs import models as models_config

dataset = "skin"
batch_size = 512
img_size = 224  # square input resolution used for both splits
num_classes = 2  # benign / malignant, matching `labelnames["skin"]`
config = common_config.with_dataset(common_config.get_config(), dataset)
# NOTE(review): config.batch / config.pp.crop are not synced with batch_size /
# img_size — confirm the training code does not rely on them.

# Load training and test splits from CSV. The TFDS loaders
# (`input_pipeline.get_data_from_tfds`) are not used: these CSV pipelines
# replace them.
train_data = load_dataset_from_csv(
    "train_data.csv", img_size=img_size, batch_size=batch_size, mode="train"
)
test_data = load_dataset_from_csv(
    "test_data.csv", img_size=img_size, batch_size=batch_size, mode="test"
)
ds_train = train_data
ds_test = test_data

# Fetch a batch of test images for illustration purposes.
# Shape: [num_local_devices, local_batch_size, h, w, c]; the CSV pipeline
# always uses a single leading device axis of size 1.
batch = next(iter(ds_test.as_numpy_iterator()))
# The one-hot width is inferred from the CSV; make sure it agrees with the model.
assert batch["label"].shape[-1] == num_classes, (
    f"expected {num_classes} label columns, got {batch['label'].shape[-1]}"
)

In [None]:
# Show the first nine test images, titled with their class names.
images = batch["image"][0, :9]
labels = batch["label"][0, :9]
titles = map(make_label_getter(dataset), np.argmax(labels, axis=1))
show_img_grid(images, titles)

In [None]:
# Same as above, but with images drawn from the training split. Both splits go
# through the same `preprocess_image` resize/normalization (only the train
# pipeline is shuffled and repeated), so no cropping difference is expected.
train_batch = next(iter(ds_train.as_numpy_iterator()))
images, labels = train_batch["image"][0][:9], train_batch["label"][0][:9]
titles = map(make_label_getter(dataset), labels.argmax(axis=1))
show_img_grid(images, titles)