In [None]:
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from vision_transformer.constants import CIFAR10_CLASSES
from vision_transformer.model import custom_layers
from vision_transformer.data_pipeline import input_dataset

In [None]:
# Checkpoint under evaluation (path is relative to this notebook's directory).
CHECKPOINT_PATH = '../experiments/run008/checkpoints/ckpt-033.h5'
# Class names referenced by the saved model config that must be mapped explicitly
# for Keras to deserialize the .h5 file. AddPositionEmbs is this project's custom
# layer; RandomNormal is mapped here as in the original load call.
VIT_CUSTOM_OBJECTS = {
    'AddPositionEmbs': custom_layers.AddPositionEmbs,
    'RandomNormal': tf.keras.initializers.RandomNormal,
}
vit = tf.keras.models.load_model(CHECKPOINT_PATH, custom_objects=VIT_CUSTOM_OBJECTS)

In [None]:
# Raw test arrays: x_test is used for plotting images, y_test is the label
# reference for the metrics and the plot below.
_, _, x_test, y_test = input_dataset.get_cifar10_raw_data()
# Dataset split that is passed to vit.predict in the next cell.
# NOTE(review): predictions are compared element-wise against y_test, so ds_test
# must yield test samples in the same (unshuffled) order as x_test/y_test —
# confirm in input_dataset.
_, _, ds_test = input_dataset.get_cifar10_data_splits()
print(*(arr.shape for arr in (x_test, y_test)))

In [None]:
# Per-class scores for every test sample, in the order ds_test yields them.
predictions = vit.predict(ds_test, verbose=1)
print(predictions.shape)
# The metrics and the sample grid below index predictions and y_test in lockstep,
# so a length mismatch (e.g. samples dropped by the input pipeline's batching)
# should fail loudly here rather than produce a confusing error later.
assert len(predictions) == len(y_test), (
    f"got {len(predictions)} predictions for {len(y_test)} labels"
)

In [None]:
# accuracy: fraction of test samples whose highest-scoring class equals the integer label
accuracy = tf.keras.metrics.SparseCategoricalAccuracy()
accuracy.update_state(y_test, predictions)
accuracy.result()

In [None]:
# top 5 accuracy: fraction of test samples whose label is among the 5 highest-scoring classes
top5_acc = tf.keras.metrics.SparseTopKCategoricalAccuracy(k=5)
top5_acc.update_state(y_test, predictions)
top5_acc.result()

In [None]:
# Visual spot check: a grid of random test images titled with the true label and
# the model's top-1 prediction (red titles mark misclassifications).
n_rows, n_cols = 3, 3

fig, axes = plt.subplots(n_rows, n_cols, figsize=(4 * n_cols, 4 * n_rows))
axes = np.ravel(axes)

# Flatten so both (N,) and (N, 1) label arrays yield a scalar class id per sample;
# indexing CIFAR10_CLASSES with a 1-element array would otherwise fail or mis-render.
labels_flat = np.ravel(y_test)

# Fixed seed -> same grid on every Restart & Run All; sampling without
# replacement guarantees no image is shown twice.
rng = np.random.default_rng(0)
n_samples = min(axes.size, len(labels_flat))
sample_indices = rng.choice(len(labels_flat), size=n_samples, replace=False)

for ax, sample_idx in zip(axes, sample_indices):
    true_id = int(labels_flat[sample_idx])
    pred_id = int(np.argmax(predictions[sample_idx]))
    label = CIFAR10_CLASSES[true_id]
    pred = CIFAR10_CLASSES[pred_id]
    ax.imshow(x_test[sample_idx])
    ax.set_title(f"label '{label}' - pred '{pred}'",
                 color='black' if true_id == pred_id else 'red')
    ax.axis('off')

# Hide any panels left empty when the test set is smaller than the grid.
for ax in axes[n_samples:]:
    ax.axis('off')

fig.tight_layout()
plt.show()