In [5]:
# Import necessary libraries
import tensorflow as tf
from tensorflow.keras import layers, models
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical
import numpy as np

# Precomputed Lookup Table for log2 values
LOG2_LOOKUP_TABLE = {i: np.log2(i + 1e-8) for i in range(1, 256)}  # For values 1-255

def lookup_log2(x):
    """
    Approximate log2 using a precomputed lookup table for values in [1, 255].
    """
    x_clipped = tf.clip_by_value(x, 1, 255)  # Clip values to [1, 255]
    x_int = tf.cast(x_clipped, tf.int32)
    lookup_tensor = tf.constant([LOG2_LOOKUP_TABLE[i] for i in range(1, 256)], dtype=tf.float32)
    log2_approx = tf.gather(lookup_tensor, x_int - 1)  # Subtract 1 since index starts at 0
    return log2_approx

# Taylor series expansion for log(1 + x)
def taylor_log1p(x, terms=5):
    """
    Compute log(1 + x) using Taylor series expansion.
    """
    result = tf.zeros_like(x)
    for n in range(1, terms + 1):
        term = tf.pow(-1.0, n + 1) * tf.pow(x, n) / n
        result += term
    return result

# Quantization function using Taylor approximation and lookup
def log2_quantize(x, method="floor", fractional_bits=3):
    """
    Quantize log2(x) using lookup tables and Taylor series.
    """
    log2_x = lookup_log2(x)
    small_x = x - tf.ones_like(x)  # Adjust for values close to 1
    refined_log = taylor_log1p(small_x, terms=5)
    log2_x = tf.where(x < 2.0, refined_log, log2_x)

    if method == "floor":
        return tf.floor(log2_x)
    elif method == "round":
        integer_part = tf.floor(log2_x)
        fractional_part = log2_x - integer_part
        threshold = (2**fractional_bits - 1) / (2**fractional_bits)
        return integer_part + tf.cast(fractional_part >= threshold, tf.float32)
    else:
        raise ValueError("Invalid quantization method. Choose 'floor' or 'round'.")

# Custom CNN with Your Desired Architecture and Log Quantization
class LogCNN(tf.keras.Model):
    def __init__(self, input_shape, num_classes):
        super(LogCNN, self).__init__()
        self.conv1a = layers.Conv2D(64, (3, 3), padding='same', activation='relu', input_shape=input_shape)
        self.bn1a = layers.BatchNormalization()
        self.conv1b = layers.Conv2D(64, (3, 3), padding='same', activation='relu')
        self.pool1 = layers.MaxPooling2D((2, 2))
        self.dropout1 = layers.Dropout(0.25)
        
        self.conv2a = layers.Conv2D(128, (3, 3), padding='same', activation='relu')
        self.bn2a = layers.BatchNormalization()
        self.conv2b = layers.Conv2D(128, (3, 3), padding='same', activation='relu')
        self.pool2 = layers.MaxPooling2D((2, 2))
        self.dropout2 = layers.Dropout(0.25)

        self.conv3a = layers.Conv2D(256, (3, 3), padding='same', activation='relu')
        self.bn3a = layers.BatchNormalization()
        self.conv3b = layers.Conv2D(256, (3, 3), padding='same', activation='relu')
        self.pool3 = layers.MaxPooling2D((2, 2))
        self.dropout3 = layers.Dropout(0.25)

        self.flatten = layers.Flatten()
        self.fc1 = layers.Dense(512, activation='relu')
        self.bn4 = layers.BatchNormalization()
        self.dropout4 = layers.Dropout(0.5)
        self.fc2 = layers.Dense(num_classes, activation='softmax')

    def call(self, inputs):
        # Quantize inputs
        x = log2_quantize(inputs, method="floor")
        
        # First Convolutional Block
        x = self.conv1a(x)
        x = self.bn1a(x)
        x = log2_quantize(x)
        x = self.conv1b(x)
        x = self.pool1(x)
        x = self.dropout1(x)

        # Second Convolutional Block
        x = self.conv2a(x)
        x = self.bn2a(x)
        x = log2_quantize(x)
        x = self.conv2b(x)
        x = self.pool2(x)
        x = self.dropout2(x)

        # Third Convolutional Block
        x = self.conv3a(x)
        x = self.bn3a(x)
        x = log2_quantize(x)
        x = self.conv3b(x)
        x = self.pool3(x)
        x = self.dropout3(x)

        # Fully Connected Layers
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.bn4(x)
        x = self.dropout4(x)
        return self.fc2(x)

# Main Execution
if __name__ == "__main__":
    # Load MNIST dataset
    (x_train, y_train), (x_test, y_test) = mnist.load_data()

    # Preprocess the data
    x_train = x_train.reshape((-1, 28, 28, 1)).astype("float32") / 255.0
    x_test = x_test.reshape((-1, 28, 28, 1)).astype("float32") / 255.0

    # One-hot encode the labels
    y_train = to_categorical(y_train, 10)
    y_test = to_categorical(y_test, 10)

    # Define input shape and number of classes
    input_shape = (28, 28, 1)
    num_classes = 10

    # Instantiate and compile the model
    model = LogCNN(input_shape, num_classes)
    model.compile(optimizer="adam", loss="categorical_crossentropy", metrics=["accuracy"])

    # Train the model
    model.fit(x_train, y_train, batch_size=64, epochs=10, validation_split=0.1)

    # Evaluate the model
    test_loss, test_acc = model.evaluate(x_test, y_test, verbose=2)
    print(f"Test Accuracy: {test_acc:.4f}")


Epoch 1/10




[1m844/844[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m99s[0m 115ms/step - accuracy: 0.3616 - loss: 2.5565 - val_accuracy: 0.6940 - val_loss: 0.9257
Epoch 2/10
[1m844/844[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m102s[0m 121ms/step - accuracy: 0.6657 - loss: 0.9665 - val_accuracy: 0.7250 - val_loss: 0.8181
Epoch 3/10
[1m844/844[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m102s[0m 121ms/step - accuracy: 0.7109 - loss: 0.8375 - val_accuracy: 0.7657 - val_loss: 0.7090
Epoch 4/10
[1m844/844[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m101s[0m 120ms/step - accuracy: 0.7340 - loss: 0.7785 - val_accuracy: 0.7617 - val_loss: 0.7184
Epoch 5/10
[1m844/844[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m101s[0m 119ms/step - accuracy: 0.7450 - loss: 0.7487 - val_accuracy: 0.7913 - val_loss: 0.6349
Epoch 6/10
[1m844/844[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m103s[0m 122ms/step - accuracy: 0.7555 - loss: 0.7149 - val_accuracy: 0.7822 - val_loss: 0.6478
Epoch 7/10
[1m8