In [2]:
import os
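# Hide all GPUs so that TensorFlow runs on the CPU; this must be set before TensorFlow is imported.
os.environ["CUDA_VISIBLE_DEVICES"] = "-1"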

import numpy as np
import tensorflow as tf

Let's start by downloading and importing the helper code that gets the data and defines the model:

In [3]:
!wget -Nq https://raw.githubusercontent.com/MicrosoftDocs/tensorflow-learning-path/main/intro-keras/kintro.py
from kintro import *

To make a prediction, we pass some data to the model and do a single forward pass through the network. If the code below is unclear to you, go back to module 1, where we explain it in detail. Unlike during testing, we don't need to call the loss function, because we're no longer evaluating how well the model is doing. Instead, we call `softmax` to convert the model's raw outputs (logits) into probabilities between 0 and 1 that sum to 1 for each image, and then take the `argmax` of that vector to get the index of the predicted label.
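To make the softmax-then-argmax step concrete, here is a NumPy-only sketch with made-up logits for a batch of two images and four classes. The `softmax` helper is our own and only mirrors what `tf.nn.softmax(..., axis=1)` computes in the `predict` function; it is not TensorFlow's implementation. Because softmax is strictly increasing within each row, the argmax of the probabilities matches the argmax of the raw logits, so softmax matters mainly when you want the probabilities themselves.

```python
import numpy as np


def softmax(logits: np.ndarray, axis: int = 1) -> np.ndarray:
    # Subtract the row-wise max for numerical stability; this does not change the result.
    shifted = logits - np.max(logits, axis=axis, keepdims=True)
    exps = np.exp(shifted)
    # Normalize so that each row sums to 1.
    return exps / np.sum(exps, axis=axis, keepdims=True)


# Made-up logits: one row per image, one column per class.
logits = np.array([[2.0, -1.0, 0.5, 0.0],
                   [-0.3, 0.1, 3.2, 1.0]])

probabilities = softmax(logits, axis=1)
predicted_indices = np.argmax(probabilities, axis=1)

print(probabilities.sum(axis=1))      # each row sums to 1, up to floating-point rounding
print(predicted_indices)              # [0 2]
print(np.argmax(logits, axis=1))      # [0 2], same as the argmax of the probabilities
```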

As in the training and testing sections, once we're done debugging, we can add the `@tf.function` decorator to get the performance benefits of graph execution.

In [4]:
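@tf.function
def predict(model: tf.keras.Model, X: np.ndarray) -> tf.Tensor:
  # Single forward pass; training=False runs layers such as dropout in inference mode.
  y_prime = model(X, training=False)
  # Convert each row of raw outputs (logits) into probabilities that sum to 1.
  probabilities = tf.nn.softmax(y_prime, axis=1)
  # The index of the highest probability in each row is the predicted label.
  predicted_indices = tf.math.argmax(input=probabilities, axis=1)
  return predicted_indices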

Typically, during inference, we give the model input data that it hasn't seen before. In this sample, for simplicity, we input the first 3 images of the test dataset instead. 

In [5]:
batch_size = 64

# Build the model and load the previously saved weights.
model = get_model()
model.load_weights('outputs/weights')

# Take the first batch of the test dataset and keep only its first 3 images and labels.
(_, test_dataset) = get_data(batch_size)
(X_batch, actual_index_batch) = next(test_dataset.as_numpy_iterator())
X = X_batch[0:3, :, :]
actual_indices = actual_index_batch[0:3]

predicted_indices = predict(model, X)

# Map label indices to class names and compare actual vs. predicted labels.
print('\nPredicting:')
for (actual_index, predicted_index) in zip(actual_indices, predicted_indices):
  actual_name = labels_map[actual_index]
  predicted_name = labels_map[predicted_index.numpy()]
  print(f'Actual: {actual_name}, Predicted: {predicted_name}')


Predicting:
Actual: Bag, Predicted: Bag
Actual: Sandal, Predicted: Sandal
Actual: Sandal, Predicted: Sandal
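The slicing and index-to-name lookup in the last cell don't depend on TensorFlow, so they can be sketched with NumPy alone. This sketch assumes the standard Fashion-MNIST class names for a stand-in `labels_map` (the real one is defined in `kintro.py` and may be spelled differently), and it uses a zero-filled batch plus hand-picked labels and predictions in place of real test data and model output.

```python
import numpy as np

# Assumed Fashion-MNIST class names; the notebook's labels_map comes from kintro.py.
labels_map = {0: 'T-Shirt', 1: 'Trouser', 2: 'Pullover', 3: 'Dress', 4: 'Coat',
              5: 'Sandal', 6: 'Shirt', 7: 'Sneaker', 8: 'Bag', 9: 'Ankle Boot'}

# Stand-in batch: 64 blank 28x28 images with made-up labels (not real test data).
X_batch = np.zeros((64, 28, 28), dtype=np.float32)
actual_index_batch = np.array([8, 5, 5] + [0] * 61)

# Keep only the first three images and labels, as the notebook does.
X = X_batch[0:3, :, :]
actual_indices = actual_index_batch[0:3]

# Made-up predictions; in the notebook these come from predict(model, X).
predicted_indices = np.array([8, 5, 7])

pairs = []
for actual_index, predicted_index in zip(actual_indices, predicted_indices):
    # int() turns NumPy integer scalars into plain Python ints for the dict lookup.
    actual_name = labels_map[int(actual_index)]
    predicted_name = labels_map[int(predicted_index)]
    pairs.append((actual_name, predicted_name))
    print(f'Actual: {actual_name}, Predicted: {predicted_name}')

print(X.shape)  # (3, 28, 28)
```

The third made-up prediction is deliberately wrong so that the printout shows what a mismatch looks like.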
