# Split-GEMM algorithm: compute one GEMM as two column-wise sub-GEMMs

In [None]:
import numpy as np
from numpy.testing import assert_allclose 


In [None]:
# GEMM problem sizes used below: A is (M, K), B is stored as (N, K),
# and C / Y are (M, N).
M, K, N = 1, 8, 10

In [None]:

# Define the matrices.
# A is (M, K); B is stored already transposed as (N, K), so the GEMM uses B.T;
# C is (M, N) and is the additive term scaled by beta.
A = np.random.rand(M, K).astype(np.float32)
B = np.random.rand(N, K).astype(np.float32)  # stored transposed: (N, K)
C = np.random.rand(M, N).astype(np.float32)  # (M, N)

# Parameters for GEMM: Y = alpha * A @ B.T + beta * C
alpha = 1.0
beta = 1.0

# Split the output-column (N) dimension into two parts.
p = C.shape[1] // 2  # split column index, i.e. N // 2

# Splitting the columns of Y corresponds to splitting the rows of the
# transposed B and the columns of C at the same index p.
B1 = B[:p, :]  # first p rows of transposed B      -> (p, K)
B2 = B[p:, :]  # remaining N - p rows of transposed B -> (N - p, K)
C1 = C[:, :p]  # (M, p)
C2 = C[:, p:]  # (M, N - p)

# Perform the two independent sub-GEMMs.
Y1 = alpha * np.dot(A, B1.T) + beta * C1
Y2 = alpha * np.dot(A, B2.T) + beta * C2

# Concatenate the partial results along the column axis.
Y = np.hstack((Y1, Y2))

# Reference: the un-split GEMM. The split only changes which columns each
# sub-GEMM produces, so the results must agree up to float32 rounding.
# (The previous rtol=1e-0 allowed 100% relative error and could not catch a
# wrong split.)
Y0 = alpha * np.dot(A, B.T) + beta * C
assert_allclose(Y0, Y, rtol=1e-5, atol=1e-6)

# Print the shapes of the partial, combined and reference results.
print("Shape of Y1:", Y1.shape)
print("Shape of Y2:", Y2.shape)
print("Shape of Y:", Y.shape)
print("Shape of Y0:", Y0.shape)


In [None]:
import os

import onnx
from onnx import helper, TensorProto, numpy_helper
import numpy as np

# Graph interface: A is the only runtime input; B and C are baked into the
# model as initializers below.
input_A = helper.make_tensor_value_info('A', TensorProto.FLOAT, A.shape)
output_Y = helper.make_tensor_value_info('Y', TensorProto.FLOAT, Y0.shape)

# Create initializers for B and C from the same arrays used by the NumPy
# reference computation.
initializer_B = numpy_helper.from_array(B, name='B')
initializer_C = numpy_helper.from_array(C, name='C')

# Gemm computes Y = alpha * A' @ B' + beta * C, where A' = A.T if transA else A
# (likewise for B'). B is stored as (N, K), hence transB=1.
# alpha/beta come from the NumPy reference above instead of being hard-coded,
# so the model and Y0 cannot silently drift apart. float() keeps the attribute
# type FLOAT (as Gemm declares) even if alpha/beta are set to Python ints.
gemm_node = helper.make_node(
    'Gemm',
    inputs=['A', 'B', 'C'],
    outputs=['Y'],
    alpha=float(alpha),
    beta=float(beta),
    transA=0,
    transB=1,
)

# Create the graph (GraphProto).
graph_def = helper.make_graph(
    [gemm_node],
    'gemm_test',
    [input_A],
    [output_Y],
    [initializer_B, initializer_C],
)

# Create the model (ModelProto).
model_def = helper.make_model(graph_def, producer_name='gemm_example')

# onnx.save opens the target file directly and does not create parent
# directories, so make sure 'onnx/' exists first.
os.makedirs('onnx', exist_ok=True)
onnx.save(model_def, 'onnx/gemm_model.onnx')

# Show the model in Netron

In [None]:
import os
import pathlib
import onnx
import netron
import numpy as np

In [None]:

def show_model(model_file_name, itf='10.217.184.110', port=8098):
    """Serve *model_file_name* in the Netron viewer and return *port*.

    ``itf`` and ``port`` are passed to ``netron.start`` together as its
    ``address`` tuple. The returned value is the ``port`` argument exactly as
    given; nothing reported back by Netron is returned.

    NOTE(review): the default ``itf`` is a hard-coded, host-specific IP
    address; confirm it is appropriate for the machine running this notebook.
    """
    bind_address = (itf, port)
    netron.start(file=model_file_name, address=bind_address)
    return port

# Path of the model file to inspect.
# Alternative model: pathlib.Path("onnx/submodel_46_48.onnx")
input_path = pathlib.Path("onnx") / "gemm_model.onnx"
onnx_model = onnx.load(input_path)
# Validate the loaded model; onnx.checker raises if it is malformed.
onnx.checker.check_model(onnx_model)

def get_opset(model):
    """Return the opset version of the default ONNX operator domain of *model*.

    ``model.opset_import`` can list several operator domains (e.g. a vendor
    domain alongside the standard one) in any order, so blindly taking the
    first entry may report a custom domain's version. The default domain is
    spelled ``""`` or ``"ai.onnx"``. If no default-domain entry is present,
    the first entry's version is returned, matching the previous behaviour.

    Raises:
        ValueError: if ``model.opset_import`` is empty.
    """
    entries = list(model.opset_import)
    if not entries:
        raise ValueError("model has no opset_import entries")
    for entry in entries:
        if entry.domain in ("", "ai.onnx"):
            return entry.version
    # No standard-domain entry: fall back to the first one.
    return entries[0].version

In [None]:
# Display the opset version reported by get_opset as a string.
str(get_opset(onnx_model))

In [None]:
# Launch Netron on the model file. Pass the path itself rather than building
# "./" + str(path): that concatenation turns an absolute path ("/x") into a
# relative one (".//x").
port = show_model(str(input_path))

# Inference with ONNX Runtime

In [None]:
import onnx
from onnx import helper, TensorProto, numpy_helper
import onnxruntime as ort
import numpy as np

# Since ONNX Runtime 1.9, builds with more than one execution provider require
# `providers` to be set explicitly. The CPU provider is always available and is
# what a CPU-only build uses by default anyway.
session = ort.InferenceSession(
    'onnx/gemm_model.onnx', providers=['CPUExecutionProvider']
)

# The model's only graph input is 'A'; B and C are stored as initializers.
input_dict = {'A': A}

# Run the inference; passing None requests every graph output.
outputs = session.run(None, input_dict)

# Compare the ONNX Runtime result with the NumPy reference. The previous
# rtol=1e-0 allowed 100% relative error and could not detect a wrong result;
# use a tolerance sized for float32 rounding instead.
assert_allclose(Y0, outputs[0], rtol=1e-5, atol=1e-6)