In [None]:
from tensorflow.keras.datasets import mnist
from tensorflow.keras.utils import to_categorical

# Load the MNIST dataset
(x_train, y_train), (x_test, y_test) = mnist.load_data()

# Preprocess the data: reshape the images to (N, 28, 28, 1), convert them to
# float32 and divide by 255; one-hot encode the labels with to_categorical.
x_train = x_train.reshape(-1, 28, 28, 1).astype('float32') / 255
x_test = x_test.reshape(-1, 28, 28, 1).astype('float32') / 255
# NOTE(review): `num_classes` is not defined in the cells visible here --
# presumably set in an earlier cell; confirm it exists on a fresh kernel.
y_train = to_categorical(y_train, num_classes)
y_test = to_categorical(y_test, num_classes)

# Train the model (a single epoch, holding out 20% of the training data for
# validation).
# NOTE(review): `model` is not defined in the cells visible here either --
# presumably a compiled Keras model built in an earlier cell; confirm.
model.fit(x_train, y_train, epochs=1, batch_size=32, validation_split=0.2)

# Evaluate the model on the test split and report the second returned value
# as the accuracy.
test_loss, test_acc = model.evaluate(x_test, y_test)
print(f'Test accuracy: {test_acc}')

In [None]:
import tf2onnx
import onnx

# Convert the Keras model to ONNX format
# `model` must still be the Keras model here; the second return value of
# from_keras is discarded. Opset 13 is requested explicitly.
onnx_model, _ = tf2onnx.convert.from_keras(model, opset=13)

# Save the ONNX model to a file
# The file is written relative to the current working directory and is the
# same path the interpreter cell below loads.
onnx.save_model(onnx_model, "simple_cnn_model.onnx")

In [None]:
import numpy as np
import onnx
import onnx.numpy_helper as numpy_helper
import matplotlib.pyplot as plt
from tensorflow.keras.datasets import mnist

def relu(x):
    """Element-wise rectified linear unit (ONNX ``Relu``).

    Args:
        x: array-like input.

    Returns:
        The element-wise maximum of ``0`` and ``x``.
    """
    rectified = np.maximum(0, x)
    return rectified

def add(a, b):
    """Element-wise sum of two tensors with NumPy broadcasting (ONNX ``Add``).

    Args:
        a: first operand.
        b: second operand.

    Returns:
        ``a + b`` computed by ``np.add``.
    """
    total = np.add(a, b)
    return total

def matmul(a, b):
    """Matrix product with ONNX ``MatMul`` (i.e. ``numpy.matmul``) semantics.

    For 1-D and 2-D operands this is the ordinary dot product. For operands
    with more than two dimensions the leading dimensions are treated as
    broadcast batch dimensions, which is what ONNX specifies; ``np.dot`` would
    instead contract the last axis of ``a`` with the second-to-last axis of
    ``b`` and produce a higher-rank result.

    Args:
        a: left operand.
        b: right operand.

    Returns:
        The (possibly batched) matrix product of ``a`` and ``b``.
    """
    return np.matmul(a, b)

def reshape(tensor, shape, allowzero=False):
    """Reshape ``tensor`` following the ONNX ``Reshape`` conventions.

    Args:
        tensor: array-like data to reshape.
        shape: target shape (int, sequence or integer array). ``-1`` marks one
            dimension to be inferred. Unless ``allowzero`` is set, a ``0`` at
            position ``i`` means "keep dimension ``i`` of the input".
        allowzero: when true, ``0`` is taken literally (an empty dimension).

    Returns:
        The reshaped array.

    Raises:
        ValueError: if a ``0`` entry has no matching input dimension, or the
            target shape is incompatible with the number of elements.
    """
    tensor = np.asarray(tensor)
    target = [int(dim) for dim in np.asarray(shape).reshape(-1)]
    if not allowzero:
        for index, dim in enumerate(target):
            if dim != 0:
                continue
            if index >= tensor.ndim:
                raise ValueError(
                    f"shape entry {index} is 0 but the input only has {tensor.ndim} dimensions"
                )
            # ONNX: a zero copies the corresponding input dimension.
            target[index] = tensor.shape[index]
    return np.reshape(tensor, target)

def conv(x, w, b=None, strides=(1, 1), pads=(0, 0, 0, 0)):
    """Plain 2-D convolution (ONNX ``Conv`` without dilations or groups).

    Args:
        x: input of shape (batch, in_channels, height, width).
        w: kernels of shape (out_channels, in_channels, kernel_h, kernel_w).
        b: optional bias array with one value per output channel.
        strides: (stride_height, stride_width).
        pads: zero padding as (height_begin, width_begin, height_end,
            width_end), the ONNX ``pads`` ordering.

    Returns:
        float32 array of shape (batch, out_channels, out_height, out_width).

    Raises:
        ValueError: if the channel count of ``x`` does not match ``w``.
            Grouped convolutions are not supported, and letting NumPy
            broadcast a mismatch would silently give wrong numbers.
    """
    batch_size, in_channels, in_height, in_width = x.shape
    out_channels, kernel_channels, kernel_height, kernel_width = w.shape
    if kernel_channels != in_channels:
        raise ValueError(
            f"conv: input has {in_channels} channels but kernels expect {kernel_channels}"
        )
    stride_height, stride_width = strides
    pad_height_begin, pad_width_begin, pad_height_end, pad_width_end = pads

    out_height = (in_height + pad_height_begin + pad_height_end - kernel_height) // stride_height + 1
    out_width = (in_width + pad_width_begin + pad_width_end - kernel_width) // stride_width + 1

    y = np.zeros((batch_size, out_channels, out_height, out_width), dtype=np.float32)

    x_padded = np.pad(
        x,
        ((0, 0), (0, 0), (pad_height_begin, pad_height_end), (pad_width_begin, pad_width_end)),
        mode='constant',
    )

    for i in range(out_height):
        h_start = i * stride_height
        h_end = h_start + kernel_height
        for j in range(out_width):
            w_start = j * stride_width
            w_end = w_start + kernel_width
            window = x_padded[:, :, h_start:h_end, w_start:w_end]
            # Contract channel and both kernel axes for every output channel
            # at once: (batch, C, kh, kw) x (out, C, kh, kw) -> (batch, out).
            y[:, :, i, j] = np.tensordot(window, w, axes=([1, 2, 3], [1, 2, 3]))

    if b is not None:
        # Broadcast the per-channel bias over batch and spatial axes.
        y += b.reshape(1, -1, 1, 1)

    return y

def maxpool(x, kernel_shape, strides=(1, 1), pads=(0, 0, 0, 0)):
    """2-D max pooling (ONNX ``MaxPool`` without dilations or ceil_mode).

    Args:
        x: input array of shape (batch, channels, height, width).
        kernel_shape: (kernel_height, kernel_width).
        strides: (stride_height, stride_width).
        pads: padding as (height_begin, width_begin, height_end, width_end).

    Returns:
        float32 array of shape (batch, channels, out_height, out_width).
    """
    batch_size, in_channels, in_height, in_width = x.shape
    kernel_height, kernel_width = kernel_shape
    stride_height, stride_width = strides
    pad_height_begin, pad_width_begin, pad_height_end, pad_width_end = pads

    out_height = (in_height + pad_height_begin + pad_height_end - kernel_height) // stride_height + 1
    out_width = (in_width + pad_width_begin + pad_width_end - kernel_width) // stride_width + 1

    # Padding cells must never win the max. Padding with 0 would do exactly
    # that for windows whose real values are all negative, so pad with the
    # smallest value the dtype can hold instead.
    if np.issubdtype(x.dtype, np.floating):
        pad_value = -np.inf
    elif np.issubdtype(x.dtype, np.integer):
        pad_value = np.iinfo(x.dtype).min
    else:
        pad_value = 0  # other dtypes (e.g. bool): 0/False is already the minimum

    x_padded = np.pad(
        x,
        ((0, 0), (0, 0), (pad_height_begin, pad_height_end), (pad_width_begin, pad_width_end)),
        mode='constant',
        constant_values=pad_value,
    )

    y = np.zeros((batch_size, in_channels, out_height, out_width), dtype=np.float32)

    for i in range(out_height):
        h_start = i * stride_height
        h_end = h_start + kernel_height
        for j in range(out_width):
            w_start = j * stride_width
            w_end = w_start + kernel_width
            window = x_padded[:, :, h_start:h_end, w_start:w_end]
            y[:, :, i, j] = np.max(window, axis=(2, 3))

    return y

def transpose(x, perm):
    """Permute the axes of ``x`` (ONNX ``Transpose``).

    Args:
        x: array-like input.
        perm: axis order handed straight to ``np.transpose``.

    Returns:
        ``x`` with its axes rearranged according to ``perm``.
    """
    permuted = np.transpose(x, perm)
    return permuted

def softmax(x, axis=-1):
    """Softmax along ``axis`` (ONNX ``Softmax``).

    The maximum along ``axis`` is subtracted before exponentiating so that
    large inputs do not overflow.

    Args:
        x: input array.
        axis: axis along which the outputs sum to one.

    Returns:
        Array of the same shape as ``x``.
    """
    shifted = x - np.max(x, axis=axis, keepdims=True)
    e_x = np.exp(shifted)
    normaliser = e_x.sum(axis=axis, keepdims=True)
    return e_x / normaliser

def shape(x):
    """Return the dimensions of ``x`` as a 1-D int64 array (ONNX ``Shape``).

    The dtype is fixed to int64 for two reasons: ``np.array(())`` would turn
    the empty shape of a rank-0 input into a float64 array, and the default
    integer type is platform dependent.

    Args:
        x: array whose ``.shape`` is read.

    Returns:
        int64 array with one entry per dimension of ``x``.
    """
    return np.array(x.shape, dtype=np.int64)

def gather(x, indices, axis=0):
    """Pick entries of ``x`` along ``axis`` (ONNX ``Gather``).

    Args:
        x: data array.
        indices: integer index or array of indices.
        axis: axis of ``x`` to index into.

    Returns:
        The result of ``np.take(x, indices, axis=axis)``.
    """
    selected = np.take(x, indices, axis=axis)
    return selected

def cast(x, to):
    """Convert ``x`` to the dtype named by an ONNX data-type code (``Cast``).

    Args:
        x: NumPy array to convert.
        to: integer type code from 1 to 15, looked up in the table below.

    Returns:
        A copy of ``x`` with the requested dtype.

    Raises:
        KeyError: if ``to`` is not one of the codes in the table.
    """
    # Listed in code order: position 0 is code 1, position 14 is code 15.
    numpy_types = (
        np.float32,     # 1
        np.uint8,       # 2
        np.int8,        # 3
        np.uint16,      # 4
        np.int16,       # 5
        np.int32,       # 6
        np.int64,       # 7
        np.str_,        # 8
        np.bool_,       # 9
        np.float16,     # 10
        np.double,      # 11
        np.uint32,      # 12
        np.uint64,      # 13
        np.complex64,   # 14
        np.complex128,  # 15
    )
    code_to_dtype = dict(enumerate(numpy_types, start=1))
    target_dtype = code_to_dtype[to]
    return x.astype(target_dtype)

def reduceprod(x, axis=None, keepdims=False):
    """Product of the elements of ``x`` over the given axes (ONNX ``ReduceProd``).

    Args:
        x: input array.
        axis: ``None`` to reduce over every axis, a single int, or a
            list/tuple/array of ints. An empty sequence performs no reduction.
        keepdims: keep reduced axes as size-1 dimensions when truthy.

    Returns:
        The reduced array (or scalar when everything is reduced).
    """
    keepdims = bool(keepdims)
    if axis is None:
        return np.prod(x, keepdims=keepdims)
    if isinstance(axis, (list, tuple, np.ndarray)):
        # Reduce all requested axes in one call. Reducing them one after
        # another with keepdims=False would shift the index of every later
        # axis and multiply over the wrong dimension.
        axis = tuple(int(a) for a in axis)
    return np.prod(x, axis=axis, keepdims=keepdims)

def unsqueeze(x, axes):
    """Insert size-1 dimensions into ``x`` (ONNX ``Unsqueeze``).

    Args:
        x: input array.
        axes: iterable of integer positions in the *output* shape at which to
            insert the new dimensions; negative values count from the end of
            the output shape. An empty iterable leaves the shape unchanged.

    Returns:
        ``x`` with the extra dimensions added.
    """
    # Hand every axis to expand_dims in one call so that each one is
    # normalised against the final rank. Inserting them one at a time resolves
    # negative axes against an intermediate rank and puts them in the wrong
    # place (e.g. axes=[-2, -1] on a 1-D input).
    positions = tuple(int(axis) for axis in axes)
    return np.expand_dims(x, positions)

def concat(*tensors, axis=0):
    """Join tensors along ``axis`` (ONNX ``Concat``).

    Scalars are promoted to one-element arrays first so they can be joined
    with 1-D tensors.

    Args:
        *tensors: the pieces to join, in order.
        axis: axis along which to concatenate.

    Returns:
        The concatenated array.
    """
    promoted = tuple(map(np.atleast_1d, tensors))
    return np.concatenate(promoted, axis=axis)

def sigmoid(x):
    """Element-wise logistic function ``1 / (1 + exp(-x))`` (ONNX ``Sigmoid``).

    Args:
        x: input array.

    Returns:
        Array of the same shape as ``x``.
    """
    denominator = 1 + np.exp(-x)
    return 1 / denominator

def gemm(a, b, c=None, alpha=1.0, beta=1.0, trans_a=False, trans_b=False):
    """General matrix multiply ``alpha * A' @ B' + beta * C`` (ONNX ``Gemm``).

    Args:
        a: left matrix.
        b: right matrix.
        c: optional addend, added in place to the product after scaling.
        alpha: scale applied to the matrix product.
        beta: scale applied to ``c``.
        trans_a: use ``a.T`` instead of ``a`` when truthy.
        trans_b: use ``b.T`` instead of ``b`` when truthy.

    Returns:
        The scaled product, plus ``beta * c`` when ``c`` is given.
    """
    lhs = a.T if trans_a else a
    rhs = b.T if trans_b else b
    y = alpha * np.dot(lhs, rhs)
    if c is None:
        return y
    y += beta * c
    return y

# Dispatch table: ONNX op_type string -> the NumPy implementation that
# execute_node calls for a node of that type.
operator_map = dict(
    Relu=relu,
    Add=add,
    MatMul=matmul,
    Reshape=reshape,
    Conv=conv,
    MaxPool=maxpool,
    Transpose=transpose,
    Softmax=softmax,
    Shape=shape,
    Gather=gather,
    Cast=cast,
    ReduceProd=reduceprod,
    Unsqueeze=unsqueeze,
    Concat=concat,
    Sigmoid=sigmoid,
    Gemm=gemm,
    # Add more operators
)

def execute_node(node, inputs):
    """Run one ONNX node and store its outputs in ``inputs``.

    Args:
        node: ONNX node exposing ``op_type``, ``input``, ``output`` and
            ``attribute``.
        inputs: dict mapping tensor names to values. It is read for the node's
            operands and updated in place with the node's outputs.

    Raises:
        KeyError: if ``node.op_type`` has no entry in ``operator_map``, if an
            operand name is missing from ``inputs``, or if a ``Cast`` node has
            no ``to`` attribute.
    """
    # An empty name in node.input stands for an omitted optional input.
    input_tensors = [inputs[input_name] if input_name else None for input_name in node.input]
    print(f"Executing {node.op_type} with inputs: {[getattr(t, 'shape', None) for t in input_tensors]}")

    op = operator_map[node.op_type]

    # Attributes are looked up by name: their order inside a node is not
    # guaranteed, so positional access can pick up the wrong one.
    attrs = {attr.name: attr for attr in node.attribute}

    def int_attr(name, default):
        return attrs[name].i if name in attrs else default

    def ints_attr(name, default):
        return list(attrs[name].ints) if name in attrs else default

    def axes_operand(default):
        # Newer opsets pass `axes` as the second input (Unsqueeze since 13,
        # Reduce* since 18); older ones use an attribute. Accept both.
        if len(input_tensors) > 1 and input_tensors[1] is not None:
            return [int(a) for a in np.asarray(input_tensors[1]).reshape(-1)]
        return ints_attr('axes', default)

    if node.op_type == 'Reshape':
        shape_tensor = input_tensors[1]
        if isinstance(shape_tensor, list):
            shape_tensor = np.array(shape_tensor)
        output_tensors = op(input_tensors[0], shape_tensor)
    elif node.op_type == 'Conv':
        # NOTE: dilations, group and auto_pad are not forwarded to conv().
        strides = tuple(ints_attr('strides', [1, 1]))
        pads = tuple(ints_attr('pads', [0, 0, 0, 0]))
        b = input_tensors[2] if len(input_tensors) > 2 else None
        output_tensors = op(input_tensors[0], input_tensors[1], b, strides, pads)
    elif node.op_type == 'MaxPool':
        kernel_shape = ints_attr('kernel_shape', [2, 2])
        strides = tuple(ints_attr('strides', [1, 1]))
        pads = tuple(ints_attr('pads', [0, 0, 0, 0]))
        output_tensors = op(input_tensors[0], kernel_shape, strides, pads)
    elif node.op_type == 'Transpose':
        # Without `perm` ONNX reverses the axes, which is what
        # np.transpose does when given None.
        perm = ints_attr('perm', None)
        output_tensors = op(input_tensors[0], perm)
    elif node.op_type == 'Gather':
        indices = input_tensors[1]
        axis = int_attr('axis', 0)
        output_tensors = op(input_tensors[0], indices, axis)
    elif node.op_type == 'Cast':
        to = attrs['to'].i  # The data type to cast to (ONNX data type enum)
        output_tensors = op(input_tensors[0], to)
    elif node.op_type == 'ReduceProd':
        axes = axes_operand(None)
        if not axes and not int_attr('noop_with_empty_axes', 0):
            axes = None  # no axes given: reduce over every dimension
        keepdims = bool(int_attr('keepdims', 1))  # ONNX default is keepdims=1
        output_tensors = op(input_tensors[0], axes, keepdims)
    elif node.op_type == 'Unsqueeze':
        axes = axes_operand([])
        output_tensors = op(input_tensors[0], axes)
    elif node.op_type == 'Softmax':
        # -1 is the default axis from opset 13 on, the opset this notebook
        # exports with.
        axis = int_attr('axis', -1)
        output_tensors = op(input_tensors[0], axis)
    elif node.op_type == 'Concat':
        axis = int_attr('axis', 0)
        output_tensors = op(*input_tensors, axis=axis)
    elif node.op_type == 'Sigmoid':
        output_tensors = op(input_tensors[0])
    elif node.op_type == 'Gemm':
        alpha = attrs['alpha'].f if 'alpha' in attrs else 1.0
        beta = attrs['beta'].f if 'beta' in attrs else 1.0
        trans_a = int_attr('transA', 0)
        trans_b = int_attr('transB', 0)
        c = input_tensors[2] if len(input_tensors) > 2 else None
        output_tensors = op(input_tensors[0], input_tensors[1], c, alpha, beta, trans_a, trans_b)
    else:
        # Operators without attributes handled here: pass the operands through.
        output_tensors = op(*input_tensors)

    if not isinstance(output_tensors, tuple):
        output_tensors = (output_tensors,)

    for output_name, output_tensor in zip(node.output, output_tensors):
        inputs[output_name] = output_tensor
    print(f"Produced output for {node.op_type}: {output_tensors[0].shape}")
    

def execute_graph(graph, input_data):
    """Evaluate an ONNX graph node by node and return its outputs.

    Args:
        graph: ONNX graph exposing ``input``, ``initializer``, ``node`` and
            ``output``.
        input_data: either a single array, which is fed to every graph input,
            or a dict mapping graph input names to arrays for graphs with
            more than one input.

    Returns:
        dict mapping each graph output name to its computed value.

    Raises:
        KeyError: if ``input_data`` is a dict without an entry for a graph
            input, or a graph output was never produced.
    """
    inputs = {}

    # Graph inputs that are also initializers are constants, not feeds. They
    # are filled from the initializer list below, so skip them here.
    initializer_names = {initializer.name for initializer in graph.initializer}
    feeds_by_name = isinstance(input_data, dict)

    # Use provided input data
    for input_tensor in graph.input:
        input_name = input_tensor.name
        if input_name in initializer_names:
            continue
        value = input_data[input_name] if feeds_by_name else input_data
        inputs[input_name] = value
        print(f"Loaded input {input_name} with shape {value.shape}")

    # Initialize tensors for constants (initializers)
    for initializer in graph.initializer:
        tensor = numpy_helper.to_array(initializer)
        inputs[initializer.name] = tensor
        print(f"Initialized tensor {initializer.name} with shape {tensor.shape}")

    # Execute nodes in the order they are stored in the graph; each call
    # adds its outputs to `inputs`.
    for node in graph.node:
        execute_node(node, inputs)

    # Extract outputs
    outputs = {output.name: inputs[output.name] for output in graph.output}
    return outputs

# Load the ONNX model
# NOTE(review): this rebinds `model`, which the training and export cells use
# as the Keras model; re-running those cells after this one will operate on
# the ONNX object instead.
model_path = 'simple_cnn_model.onnx'
model = onnx.load(model_path)

# Get the graph from the model
graph = model.graph

# Load the MNIST dataset
# Only the test split is kept. `x_test` / `y_test` from the training cell are
# overwritten here; `y_test` now holds whatever load_data returns rather than
# the to_categorical output.
(_, _), (x_test, y_test) = mnist.load_data()

# Preprocess the data
x_test = x_test.astype('float32') / 255
x_test = np.expand_dims(x_test, axis=-1)  # Add channel dimension

# Select a test sample
input_data = x_test[0:1]  # Selecting the first sample and keeping it as a batch

# Execute the graph with the input data
# execute_graph prints a trace line per input, initializer and node; the final
# print shows the dict of graph outputs.
outputs = execute_graph(graph, input_data)
print(outputs)