Load the trained model

In [58]:
import tensorflow as tf

# Restore the trained model from its `.keras` archive; the path is resolved
# relative to the notebook's current working directory.
model = tf.keras.models.load_model(filepath='models/svhn_2digits_model.keras')

Batch generator that streams test data from the HDF5 file

In [59]:
import h5py
import numpy as np
import random

def data_generator(h5_dataset_path, batch_size, shuffle=True, seed=None):
    """Yield ``(images, labels)`` batches from an HDF5 file in an endless loop.

    The file must contain ``'images'`` and ``'labels'`` datasets with the same
    number of rows; both are sliced in lockstep along the first axis. Each pass
    visits every contiguous batch exactly once. Only the *order* of batches is
    shuffled -- samples inside a batch keep their on-disk order -- and the
    last batch of the file may be smaller than ``batch_size``.

    The HDF5 file stays open while the generator is alive and is closed when
    the generator is closed or garbage-collected (e.g. after ``break`` in a
    ``for`` loop).

    Args:
        h5_dataset_path: Path to the HDF5 file.
        batch_size: Maximum number of samples per batch; must be positive.
        shuffle: Randomize the batch order on every pass (default True).
        seed: Optional seed for a private RNG so the batch order is
            reproducible. With ``None`` (default) the global ``random`` module
            is used, exactly as before.

    Yields:
        ``(batch_images, batch_labels)`` slices read from the file.

    Raises:
        ValueError: on the first ``next()`` if ``batch_size`` is not positive,
            if the file holds no samples (the endless loop would otherwise
            spin forever without yielding anything), or if the two datasets
            differ in length (batches would silently be misaligned).
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    # Global module keeps the original behaviour (honours random.seed elsewhere);
    # a private Random instance makes the order reproducible without side effects.
    rng = random if seed is None else random.Random(seed)

    with h5py.File(h5_dataset_path, 'r') as f:
        images = f['images']
        labels = f['labels']
        num_samples = images.shape[0]
        if num_samples == 0:
            raise ValueError(f"{h5_dataset_path!r} contains no samples")
        if labels.shape[0] != num_samples:
            raise ValueError(
                f"'images' has {num_samples} rows but 'labels' has {labels.shape[0]}"
            )
        num_batches = -(-num_samples // batch_size)  # integer ceil division

        # Infinite on purpose: the caller decides when to stop.
        while True:
            batch_order = list(range(num_batches))
            if shuffle:
                rng.shuffle(batch_order)
            for i in batch_order:
                start = i * batch_size
                end = min(start + batch_size, num_samples)
                yield images[start:end], labels[start:end]

Predict on a test batch and show the results

In [60]:
import matplotlib.pyplot as plt

def checkout_predict_result(model, batch_images, batch_labels, num_examples=10, ncols=5):
    """Show a grid of images titled with their true label, prediction and confidence.

    Only the first ``num_examples`` images of the batch (fewer if the batch is
    smaller) are passed to ``model.predict`` and plotted. Short batches
    therefore no longer raise ``IndexError``, and no prediction work is wasted
    on images that are never displayed.

    Labels are decoded with ``argmax`` over axis 2, so ``batch_labels`` and the
    model output are presumably ``(N, num_digits, num_classes)`` per-digit
    one-hot vectors / scores -- TODO confirm against the training code. The
    decoded label is the per-position class indices concatenated as a string.

    The displayed "Conf" is the minimum over digit positions of the model's
    top class score (a probability if the model ends in softmax), i.e. the
    confidence of its least certain digit.

    Args:
        model: object with a ``predict(images)`` method (e.g. a Keras model).
        batch_images: array of images, indexed along the first axis.
        batch_labels: labels aligned with ``batch_images``.
        num_examples: maximum number of images to plot (default 10).
        ncols: number of grid columns (default 5).
    """
    num_shown = min(num_examples, len(batch_images))
    if num_shown == 0:
        return  # nothing to plot; avoid calling predict on an empty batch

    images = batch_images[:num_shown]
    # Predict only what is displayed rather than the whole batch.
    predictions = model.predict(images)

    original_labels = [''.join(map(str, row)) for row in np.argmax(batch_labels[:num_shown], axis=2)]
    predicted_labels = [''.join(map(str, row)) for row in np.argmax(predictions, axis=2)]
    predicted_confs = np.min(np.max(predictions, axis=2), axis=1)

    nrows = -(-num_shown // ncols)  # integer ceil division
    fig, axes = plt.subplots(nrows, ncols, figsize=(3 * ncols, 3.5 * nrows), squeeze=False)
    for ax in axes.flat:
        ax.axis('off')  # also blanks unused cells when num_shown % ncols != 0
    for ax, image, true_label, pred_label, conf in zip(
        axes.flat, images, original_labels, predicted_labels, predicted_confs
    ):
        if image.ndim == 3 and image.shape[-1] == 1:
            image = image[..., 0]  # imshow rejects (H, W, 1); drop the singleton channel
        ax.imshow(image, cmap='gray')
        ax.set_title(f'Original: {true_label}\nPredicted: {pred_label}\nConf: {conf:.2f}')
    plt.show()

In [63]:
# The generator loops forever, so pull just its first (randomly ordered) batch
# and visualise it.
batch_images, batch_labels = next(
    data_generator('svhn_dataset/svhn_2digits_test.h5', batch_size=1000)
)
checkout_predict_result(model, batch_images, batch_labels)

8356
