In [None]:
import os
import numpy as np
import tensorflow as tf
import random

from utils.data_loader import get_train_test, get_train_test_fashion
from utils.openmax import create_model, get_activations, compute_openmax
from utils.openmax_utils import image_show, get_openmax_predict

In [None]:
# Report how many physical GPUs TensorFlow can see in this environment.
physical_gpu_count = len(tf.config.list_physical_devices('GPU'))
print('Num GPUs Available: ', physical_gpu_count)

In [None]:
# Restrict TensorFlow to the first physical GPU (if there is one) and enable
# on-demand memory growth for it. A RuntimeError from these calls (which the
# TensorFlow GPU guide documents for configuring devices after they have been
# initialised) is printed instead of propagated.
gpus = tf.config.list_physical_devices('GPU')
if len(gpus) > 0:
    first_gpu = gpus[0]
    try:
        tf.config.set_visible_devices(first_gpu, 'GPU')
        tf.config.experimental.set_memory_growth(first_gpu, True)
        logical_gpus = tf.config.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPU")
    except RuntimeError as err:
        print(err)

In [None]:
# Seed shared by Python's `random`, NumPy and TensorFlow, and the side length
# of the square, single-channel images fed to the model.
SEED, IMG_DIM = 0, 28

In [None]:
# Seed every RNG this notebook touches (Python, TensorFlow, NumPy) so that the
# randomly sampled test images below are reproducible across runs.
for _seed_fn in (random.seed, tf.random.set_seed, np.random.seed):
    _seed_fn(SEED)

In [None]:
# Load the pre-trained Keras classifier (HDF5 file); every later cell reuses it.
MODEL_PATH = 'models/mnist.h5'
model = tf.keras.models.load_model(MODEL_PATH)

In [None]:
# Print the loaded classifier's architecture (layers, output shapes, parameter counts).
model.summary()

In [None]:
# Only the first element (the training split) is used here; the second element
# (presumably the test split) is discarded. It is bound to a named throwaway
# rather than `_`: rebinding `_` at notebook top level shadows IPython's
# output-history variable and stops IPython from updating `_`/`__`/`___`.
train_ds, _unused_test_ds = get_train_test(training=True)

In [None]:
# Build the OpenMax model from the classifier and the training data.
# NOTE(review): the trailing comment says this only needs to run once, which
# suggests create_model persists something (e.g. per-class statistics later used
# by compute_openmax) — TODO confirm, and guard this call behind a check for that
# artifact so Restart & Run All does not redo the work every time.
create_model(model, train_ds) # Only need to run this once

In [None]:
test_ds = get_train_test(training=False)
x_test, y_test = test_ds.get_all()

for i in range(5):
    random_char = np.random.randint(0, len(x_test))

    test_x1 = x_test[random_char]
    test_y1 = y_test[random_char]

    logits, softmax = get_activations(
        test_x1.reshape(-1, IMG_DIM, IMG_DIM, 1), model)

    openmax, _ = compute_openmax(logits)
    print(f'SoftMax Sum: {np.sum(softmax)}')
    print(f'OpenMax Sum: {np.sum(openmax)}')
    print(f'True Label: {test_y1}')
    print(f'SoftMax Label: {np.argmax(softmax)}')
    print(f'OpenMax Label: {get_openmax_predict(openmax)}')
    image_show(test_x1)

In [None]:
test_u_ds = get_train_test_fashion(training=False)
x_test, _ = test_u_ds.get_all()

for i in range(5):
    random_char = np.random.randint(0, len(x_test))

    test_x1 = x_test[random_char]

    logits, softmax = get_activations(
        test_x1.reshape(-1, IMG_DIM, IMG_DIM, 1), model)

    openmax, _ = compute_openmax(logits)
    print(f'SoftMax Sum: {np.sum(softmax)}')
    print(f'OpenMax Sum: {np.sum(openmax)}')
    print(f'True Label: 10')
    print(f'SoftMax Label: {np.argmax(softmax)}')
    print(f'OpenMax Label: {get_openmax_predict(openmax)}')
    image_show(test_x1)