In [3]:
import cv2
import numpy as np
import tensorflow as tf

class DIGITS_CLASSIFIER():
    """Multi-digit image classifier wrapping a TFLite model's ``serving_default`` signature.

    ``predict`` converts a batch of images to 32x32 single-channel floats, runs the
    model, and decodes its output into one digit string plus one confidence score
    per image.
    """

    # ``dsize`` passed to ``cv2.resize``; preprocess_image emits (N, 32, 32, 1) batches.
    INPUT_SIZE = (32, 32)
    # Signature output that holds the stacked per-digit scores.
    OUTPUT_KEY = 'tf.stack'

    def __init__(self, model_path) -> None:
        """Load the TFLite model at ``model_path`` and bind its ``serving_default`` signature.

        Args:
            model_path: Path to the ``.tflite`` model file.
        """
        interpreter = tf.lite.Interpreter(model_path)
        interpreter.allocate_tensors()
        self.digits_classifier = interpreter.get_signature_runner('serving_default')

    @staticmethod
    def _to_grayscale(image, index):
        """Return ``image`` as a 2-D grayscale array; ``index`` is only used in error messages.

        Raises:
            ValueError: if ``image`` is None.
        """
        if image is None:
            raise ValueError(
                f"image at index {index} is None (did cv2.imread fail to read the file?)")
        image = np.asarray(image)
        if image.ndim == 2:
            # Already grayscale; COLOR_BGR2GRAY would reject a single-channel input.
            return image
        if image.ndim == 3 and image.shape[2] == 1:
            return image[:, :, 0]
        return cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)

    def preprocess_image(self, images):
        """Convert a batch of images into the model's input tensor.

        Args:
            images: Iterable of numpy images: BGR colour (as returned by
                ``cv2.imread``), (H, W, 1), or already-grayscale (H, W).

        Returns:
            float32 array of shape (N, 32, 32, 1), pixel values divided by 255
            (i.e. scaled to [0, 1] for 8-bit input).

        Raises:
            ValueError: if any entry is None.
        """
        gray_images = [self._to_grayscale(image, index) for index, image in enumerate(images)]
        width, height = self.INPUT_SIZE
        if not gray_images:
            # np.array([]) would be a (0,) array; keep the batch rank consistent instead.
            return np.empty((0, height, width, 1), dtype=np.float32)
        resized_images = np.stack(
            [cv2.resize(gray_image, self.INPUT_SIZE) for gray_image in gray_images])
        normalized_images = resized_images.astype(np.float32) / 255.0
        # Add the trailing channel axis expected by the model input.
        return np.expand_dims(normalized_images, axis=-1)

    def postprocess_predictions(self, predictions):
        """Decode the signature runner's output into digit strings and confidences.

        Args:
            predictions: Mapping returned by the signature runner. The scores are
                read from ``predictions['tf.stack']`` and are expected to be shaped
                (batch, n_digits, n_classes).

        Returns:
            ``(labels, confs)``. ``labels`` is a list of strings, one per image,
            formed by concatenating the arg-max class index of each digit position.
            ``confs`` is an array with, per image, the lowest of its per-digit
            maximum scores, so a label is only as confident as its weakest digit.
        """
        scores = np.asarray(predictions[self.OUTPUT_KEY])
        predicted_labels = [''.join(map(str, row)) for row in np.argmax(scores, axis=2)]
        predicted_confs = np.min(np.max(scores, axis=2), axis=1)
        return predicted_labels, predicted_confs

    def predict(self, images):
        """Classify a batch of images.

        Args:
            images: Iterable of images accepted by ``preprocess_image``.

        Returns:
            ``(labels, confs)`` as produced by ``postprocess_predictions``. An empty
            input returns ``([], <empty array>)`` without invoking the model.
        """
        images = list(images)
        if not images:
            return [], np.empty(0, dtype=np.float32)
        predictions = self.digits_classifier(input_1=self.preprocess_image(images))
        return self.postprocess_predictions(predictions)

In [4]:
# Model location, kept as a named constant so it can be changed in one place.
MODEL_PATH = 'models/svhn_2digits_model.tflite'
digits_classifier = DIGITS_CLASSIFIER(MODEL_PATH)

In [7]:
# Images to classify. cv2.imread returns None instead of raising when a file
# cannot be read, so check explicitly before running the model.
TEST_IMAGE_PATHS = ['images4test/2.png', 'custom_images/6.png']

images = [cv2.imread(path) for path in TEST_IMAGE_PATHS]
unreadable_paths = [path for path, image in zip(TEST_IMAGE_PATHS, images) if image is None]
if unreadable_paths:
    raise FileNotFoundError(f"cv2.imread could not read: {unreadable_paths}")

predicted_labels, predicted_confs = digits_classifier.predict(images)
print(predicted_labels)
print(predicted_confs)

['38', '29']
[0.9425901 0.9984584]
