Define the `DigitClassificationInterface`. This interface ensures that every model exposes a `predict` method that takes a 28×28×1 image as input and returns a single integer: the predicted digit.

In [15]:
from abc import ABC, abstractmethod
import numpy as np

import tensorflow as tf
from tensorflow.keras import layers, models

from abc import ABC, abstractmethod
import numpy as np
from sklearn.ensemble import RandomForestClassifier

# Suppress TensorFlow warnings: after this call, TensorFlow's Python logger
# only emits messages at ERROR level or above.
# NOTE(review): this presumably does not silence native (C++) startup logs
# (controlled by the TF_CPP_MIN_LOG_LEVEL environment variable) nor Python
# `warnings` raised by Keras — confirm if output is still noisy.
tf.compat.v1.logging.set_verbosity(tf.compat.v1.logging.ERROR)





In [10]:
class DigitClassificationInterface(ABC):
    """Abstract contract shared by every digit classifier in this notebook.

    Concrete classes must override ``predict``; the ABC machinery refuses to
    instantiate a subclass that does not.
    """

    @abstractmethod
    def predict(self, image: np.ndarray) -> int:
        """Classify a single image and return the predicted digit.

        Args:
            image: the image to classify; callers in this notebook pass
                arrays of shape ``(28, 28, 1)``.

        Returns:
            The predicted digit as an integer.
        """

Implement the `CNNClassifier`. Note that this is a mock implementation: the network is built but never trained, and `predict` returns a random digit, because this notebook does not focus on model training or architecture details.

In [11]:
class CNNClassifier(DigitClassificationInterface):
    """Mock convolutional classifier for 28x28x1 digit images.

    A small Keras CNN is built at construction time but is never trained or
    invoked; ``predict`` returns a random digit instead.
    """

    def __init__(self):
        self.model = self._build_model()

    def _build_model(self):
        """Build (but do not compile or train) the CNN.

        Returns:
            An uncompiled ``Sequential`` model mapping a ``(28, 28, 1)`` input
            to a 10-way softmax.
        """
        model = models.Sequential([
            # Declare the input with an explicit Input layer rather than
            # ``input_shape=`` on the first Conv2D: Keras 3 warns about the
            # latter, and both forms produce the same architecture.
            layers.Input(shape=(28, 28, 1)),
            layers.Conv2D(32, (3, 3), activation='relu'),
            layers.MaxPooling2D((2, 2)),
            layers.Conv2D(64, (3, 3), activation='relu'),
            layers.MaxPooling2D((2, 2)),
            layers.Conv2D(64, (3, 3), activation='relu'),
            layers.Flatten(),
            layers.Dense(64, activation='relu'),
            layers.Dense(10, activation='softmax')
        ])
        return model

    def predict(self, image: np.ndarray) -> int:
        """Return a random digit in ``[0, 9]``; *image* is ignored (mock).

        ``np.argmax`` yields a NumPy integer scalar, so it is converted to a
        built-in ``int`` to honour the declared return type.
        """
        return int(np.argmax(np.random.rand(10)))

In [16]:
class DigitClassificationInterface(ABC):
    """Interface every digit-classification model must implement.

    Subclasses that do not override ``predict`` cannot be instantiated.
    """

    @abstractmethod
    def predict(self, image: np.ndarray) -> int:
        """Return the digit predicted for *image*.

        ``DigitClassifier.predict`` only forwards images of shape
        ``(28, 28, 1)`` to implementations of this method.
        """

class CNNClassifier(DigitClassificationInterface):
    """Mock convolutional classifier for 28x28x1 digit images.

    A small ``tf.keras`` CNN is built at construction time but is never
    trained or invoked; ``predict`` returns a random digit instead.
    """

    def __init__(self):
        self.model = self._build_model()

    def _build_model(self):
        """Build (but do not compile or train) the CNN.

        Returns:
            An uncompiled ``tf.keras.Sequential`` model mapping a
            ``(28, 28, 1)`` input to a 10-way softmax.
        """
        model = tf.keras.Sequential([
            # Declare the input with an explicit Input layer rather than
            # ``input_shape=`` on the first Conv2D: Keras 3 warns about the
            # latter, and both forms produce the same architecture.
            tf.keras.layers.Input(shape=(28, 28, 1)),
            tf.keras.layers.Conv2D(32, (3, 3), activation='relu'),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
            tf.keras.layers.MaxPooling2D((2, 2)),
            tf.keras.layers.Conv2D(64, (3, 3), activation='relu'),
            tf.keras.layers.Flatten(),
            tf.keras.layers.Dense(64, activation='relu'),
            tf.keras.layers.Dense(10, activation='softmax')
        ])
        return model

    def predict(self, image: np.ndarray) -> int:
        """Return a random digit in ``[0, 9]``; *image* is ignored (mock).

        ``np.argmax`` yields a NumPy integer scalar, so it is converted to a
        built-in ``int`` to honour the declared return type.
        """
        return int(np.argmax(np.random.rand(10)))

class RandomForestClassifierModel(DigitClassificationInterface):
    """Mock random-forest digit classifier.

    ``self.model`` is a freshly constructed ``RandomForestClassifier`` that is
    never fitted or queried; ``predict`` returns a random digit.
    """

    def __init__(self):
        self.model = RandomForestClassifier()

    def predict(self, image: np.ndarray) -> int:
        """Return a random digit in ``[0, 9]``; *image* is not inspected.

        A real implementation would flatten the image into a single feature
        row (``image.reshape(1, -1)``, i.e. 784 features for a 28x28x1 input)
        and pass it to ``self.model.predict`` once the forest has been fitted.
        The mock does not use that row, so it is not computed.
        """
        return int(np.random.randint(0, 10))

class RandomClassifier(DigitClassificationInterface):
    """Baseline that guesses a digit uniformly at random."""

    def predict(self, image: np.ndarray) -> int:
        """Return a random digit in ``[0, 9]``; *image* is not inspected.

        The preprocessing associated with this model is a 10x10 centre crop
        (``image[9:19, 9:19]`` on a 28x28 input). The random baseline never
        looks at the pixels, so the crop is not computed.
        """
        return np.random.randint(0, 10)

class DigitClassifier:
    """Facade that selects a digit-classification model by name.

    Supported algorithm names are ``'cnn'``, ``'rf'`` and ``'rand'``.
    """

    def __init__(self, algorithm: str):
        """Instantiate the model registered under *algorithm*.

        Raises:
            ValueError: if *algorithm* is not one of the supported names.
        """
        self.algorithm = algorithm
        self.model = self._get_model(algorithm)

    def _get_model(self, algorithm: str) -> 'DigitClassificationInterface':
        """Return a new model instance for *algorithm*.

        The return annotation is a string (forward reference) so this class
        does not depend on the interface being defined first.

        Raises:
            ValueError: if *algorithm* is not one of the supported names.
        """
        if algorithm == 'cnn':
            return CNNClassifier()
        if algorithm == 'rf':
            return RandomForestClassifierModel()
        if algorithm == 'rand':
            return RandomClassifier()
        raise ValueError("Unsupported algorithm. Choose from 'cnn', 'rf', 'rand'.")

    def predict(self, image: np.ndarray) -> int:
        """Validate *image* and return the selected model's predicted digit.

        Args:
            image: array-like of shape ``(28, 28, 1)``; nested lists are
                accepted and converted to an ndarray.

        Returns:
            The predicted digit as a built-in ``int``.

        Raises:
            ValueError: if *image* does not have shape ``(28, 28, 1)``.
        """
        # np.asarray is a no-op for ndarrays and lets array-likes (e.g. lists)
        # through the shape check instead of failing with AttributeError.
        image = np.asarray(image)
        if image.shape != (28, 28, 1):
            raise ValueError("Input image must have shape (28, 28, 1).")
        # Some models (e.g. argmax-based ones) return NumPy integer scalars;
        # coerce so the declared ``-> int`` contract holds for every backend.
        return int(self.model.predict(image))



In [17]:
# Example usage: run each supported algorithm on the same random 28x28x1
# image and print its prediction.
if __name__ == "__main__":
    dummy_image = np.random.rand(28, 28, 1)

    # (algorithm name, printed label) pairs, in the order they are run.
    for algorithm, label in (
        ('cnn', "CNN Prediction:"),
        ('rf', "RF Prediction:"),
        ('rand', "Random Prediction:"),
    ):
        classifier = DigitClassifier(algorithm=algorithm)
        print(label, classifier.predict(dummy_image))

CNN Prediction: 8
RF Prediction: 8
Random Prediction: 8
