In [1]:
import torch
import torchvision.models as models
import torch.nn as nn
import onnx
from onnx import shape_inference
import sys
from tabulate import tabulate
from onnx import onnx_ml_pb2 as xpb2
import onnx.helper as helper
from onnx import numpy_helper
import numpy as np
from onnx import TensorProto
import onnxruntime as ort

### Build a GEMM

In [2]:
# Reference model: one native ONNX Gemm node with default attributes, used
# later as the ground truth for the tiled implementation.
M = N = K = 128

gemm = helper.make_node(
    'Gemm',
    inputs=['A', 'B', 'C'],
    outputs=['OUT'],
)

# The graph is created with its only node already in place.
graph_def = helper.make_graph(
    nodes=[gemm],
    name='Gemm',
    inputs=[
        helper.make_tensor_value_info(tensor_name, onnx.TensorProto.FLOAT, dims)
        for tensor_name, dims in (('A', [M, N]), ('B', [N, K]), ('C', [K]))
    ],
    outputs=[helper.make_tensor_value_info('OUT', onnx.TensorProto.FLOAT, [M, K])],
)

# Wrap the graph in a model and write it next to the other generated models.
model_def = helper.make_model(graph_def, producer_name='simple_gemm')
onnx.save(model_def, './models/simple_gemm.onnx')

### Build a GEMM Layer with Bias, Limited to 64x64x64 Blocks

In [3]:
def Build_GEMM(M, N, K, Block_Size, model_name, input_A_name, input_B_name, input_C_name, output_name):
    """Write a tiled ONNX model for ``A @ B + C`` to ``./models/<model_name>.onnx``.

    The graph declares ``A`` as ``[M, N]``, ``B`` as ``[N, K]``, the bias ``C``
    as ``[K]`` and the output as ``[M, K]`` (all FLOAT).  Each extent is cut
    into ``ceil(extent / Block_Size)`` pieces by ``Split`` nodes, every pair of
    tiles gets its own ``MatMul``, the partial products of one output tile are
    summed with ``Add`` nodes, and the output tiles are reassembled with
    ``Concat`` before the bias is added.

    Intermediate tensors use fixed names (``A0``, ``A0/1``, ``B0``, ``C0/1/2``,
    ``temp0/1/2``, ``output_wt_bias``), so the caller-supplied tensor names
    must not coincide with them.

    Returns ``None``; the only result is the saved file.
    """
    def tiles_along(extent):
        # Ceiling division: number of tiles needed to cover `extent`.
        return (extent + Block_Size - 1) // Block_Size

    M_blocks, N_blocks, K_blocks = tiles_along(M), tiles_along(N), tiles_along(K)

    float_type = onnx.TensorProto.FLOAT
    graph_def = helper.make_graph(
        nodes=[],
        name='my_model',
        inputs=[
            helper.make_tensor_value_info(tensor, float_type, dims)
            for tensor, dims in (
                (input_A_name, [M, N]),
                (input_B_name, [N, K]),
                (input_C_name, [K]),
            )
        ],
        outputs=[helper.make_tensor_value_info(output_name, float_type, [M, K])],
    )

    def emit(op_type, inputs, outputs, **attributes):
        # Nodes are appended one at a time, in the order they must execute.
        graph_def.node.extend([helper.make_node(op_type, inputs, outputs, **attributes)])

    def split_into_tiles(source, prefix, row_count, col_count):
        # First cut along axis 0 into row strips, then cut each strip along axis 1.
        row_names = [f'{prefix}{r}' for r in range(row_count)]
        emit('Split', [source], row_names, axis=0, num_outputs=row_count)
        for r, row_name in enumerate(row_names):
            tile_names = [f'{prefix}{r}/{c}' for c in range(col_count)]
            emit('Split', [row_name], tile_names, axis=1, num_outputs=col_count)

    split_into_tiles(input_A_name, 'A', M_blocks, N_blocks)
    split_into_tiles(input_B_name, 'B', N_blocks, K_blocks)

    # One MatMul per (output row tile, output column tile, inner tile).
    for row in range(M_blocks):
        for col in range(K_blocks):
            for inner in range(N_blocks):
                emit('MatMul', [f'A{row}/{inner}', f'B{inner}/{col}'], [f'C{row}/{col}/{inner}'])

    # Sum the partial products belonging to each output tile.
    for row in range(M_blocks):
        for col in range(K_blocks):
            tile = f'C{row}/{col}'
            pending = [f'{tile}/{inner}' for inner in range(N_blocks)]
            if N_blocks == 1:
                # A single partial product only needs renaming.
                emit('Identity', [pending[0]], [tile])
                continue
            step = 0
            while len(pending) > 2:
                # Fold the two oldest operands and queue the result at the back.
                scratch = f'temp{row}/{col}/{step}'
                emit('Add', pending[:2], [scratch])
                pending = pending[2:] + [scratch]
                step += 1
            emit('Add', [pending[0], pending[1]], [tile])

    # Reassemble: tiles -> row strips (axis 1) -> full matrix (axis 0).
    for row in range(M_blocks):
        emit('Concat', [f'C{row}/{col}' for col in range(K_blocks)], [f'C{row}'], axis=1)
    emit('Concat', [f'C{row}' for row in range(M_blocks)], ['output_wt_bias'], axis=0)

    # The bias is added once, to the fully assembled product.
    emit('Add', ['output_wt_bias', input_C_name], [output_name])

    model_def = helper.make_model(graph_def, producer_name=model_name)
    onnx.save(model_def, './models/' + model_name + '.onnx')

# Problem size and tile width for the generated model; M, N and K are read
# again by the inference-test cell to size its random inputs.
M = N = K = 128
BLOCK_SIZE = 64
Build_GEMM(
    M, N, K, BLOCK_SIZE,
    model_name='my_model',
    input_A_name='A',
    input_B_name='B',
    input_C_name='C',
    output_name='OUT',
)

### Load My Model

In [4]:
# Read both generated models back from disk, then validate each with the ONNX checker.
onnx_model, simple = (
    onnx.load(model_path, load_external_data=False)
    for model_path in ("./models/my_model.onnx", "./models/simple_gemm.onnx")
)
onnx.checker.check_model(onnx_model)
onnx.checker.check_model(simple)

### Inference Test

In [5]:
# Run the tiled model and the single-Gemm reference on identical random inputs.
onnx_session = ort.InferenceSession("./models/my_model.onnx")
onnx_session2 = ort.InferenceSession("./models/simple_gemm.onnx")

# Feed names are read from the tiled model and reused for the reference session.
model_inputs = onnx_session.get_inputs()
A_name, B_name, C_name = model_inputs[0].name, model_inputs[1].name, model_inputs[2].name

A = np.random.rand(M, N).astype(np.float32)
B = np.random.rand(N, K).astype(np.float32)
C = np.random.rand(K).astype(np.float32)
feeds = {A_name: A, B_name: B, C_name: C}

onnx_output = onnx_session.run(None, feeds)
onnx_output2 = onnx_session2.run(None, feeds)

# run() returns a list of outputs, so the stacked array carries a leading axis.
tiled_result = np.array(onnx_output)
reference_result = np.array(onnx_output2)

print("Inference Test for My GEMM")
print(np.allclose(tiled_result, reference_result, atol = 0))
print(tiled_result.shape)

Inference Test for My GEMM
True
(1, 128, 128)


### Load Alexnet

In [6]:
# Load the exported AlexNet, validate it, and run ONNX shape inference on it.
alexnet_path = "./models/alexnet.onnx"
alexnet_model = onnx.load(alexnet_path, load_external_data=False)
onnx.checker.check_model(alexnet_model)

# infer_shapes returns a new model; alexnet_model itself is what later cells read.
inferred_model = shape_inference.infer_shapes(alexnet_model)
print('shape inference complete ...')

shape inference complete ...


### Get GEMM Layer names

In [7]:
# Gather every Gemm node once (in graph order), then project out the fields
# the rewrite step needs: node name, input tensor names, output tensor names.
gemm_nodes = [found for found in alexnet_model.graph.node if found.op_type == 'Gemm']

GEMM_Layer_names = [found.name for found in gemm_nodes]
GEMM_Layer_input_names = [found.input for found in gemm_nodes]
GEMM_Layer_output_names = [found.output for found in gemm_nodes]

for listing in (GEMM_Layer_names, GEMM_Layer_input_names, GEMM_Layer_output_names):
    print(listing)

['/classifier/classifier.1/Gemm', '/classifier/classifier.4/Gemm', '/classifier/classifier.6/Gemm']
[['/Flatten_output_0', 'learned_10', 'learned_11'], ['/classifier/classifier.2/Relu_output_0', 'learned_12', 'learned_13'], ['/classifier/classifier.5/Relu_output_0', 'learned_14', 'learned_15']]
[['/classifier/classifier.1/Gemm_output_0'], ['/classifier/classifier.4/Gemm_output_0'], ['output1']]


### Build My GEMM Layer  

In [8]:
def Build_GEMM2(M, N, K, Block_Size, index, input_A_name, input_B_name, input_C_name, output_name, graph_def):
    """Append a tiled stand-in for one Gemm layer to ``graph_def``.

    The emitted nodes compute ``A @ transpose(B) + C`` into ``output_name``:
    ``input_B_name`` is first passed through a ``Transpose`` node (no ``perm``
    attribute), after which ``A`` is tiled as an ``M`` x ``N`` operand and the
    transposed ``B`` as an ``N`` x ``K`` operand.  Each extent is cut into
    ``ceil(extent / Block_Size)`` pieces; tiles are multiplied with ``MatMul``,
    partial products summed with ``Add``, and output tiles rejoined with
    ``Concat`` before the bias ``input_C_name`` is added.

    Every intermediate tensor name is prefixed with ``"<index> "`` so that
    several layers can be expanded into the same graph without name clashes.

    ``graph_def`` is mutated in place; nothing is returned.
    """
    def tiles_along(extent):
        # Ceiling division: number of tiles needed to cover `extent`.
        return (extent + Block_Size - 1) // Block_Size

    M_blocks, N_blocks, K_blocks = tiles_along(M), tiles_along(N), tiles_along(K)

    def scoped(label):
        # Namespace an intermediate tensor name under this layer's index.
        return f'{index} {label}'

    def emit(op_type, inputs, outputs, **attributes):
        # Nodes are appended one at a time, in the order they must execute.
        graph_def.node.extend([helper.make_node(op_type, inputs, outputs, **attributes)])

    def split_into_tiles(source, prefix, row_count, col_count):
        # First cut along axis 0 into row strips, then cut each strip along axis 1.
        row_names = [scoped(f'{prefix}{r}') for r in range(row_count)]
        emit('Split', [source], row_names, axis=0, num_outputs=row_count)
        for r, row_name in enumerate(row_names):
            tile_names = [scoped(f'{prefix}{r}/{c}') for c in range(col_count)]
            emit('Split', [row_name], tile_names, axis=1, num_outputs=col_count)

    split_into_tiles(input_A_name, 'A', M_blocks, N_blocks)

    # B is transposed before tiling so that it can be split as an N x K operand.
    transposed_B = scoped('BT')
    emit('Transpose', [input_B_name], [transposed_B])
    split_into_tiles(transposed_B, 'B', N_blocks, K_blocks)

    # One MatMul per (output row tile, output column tile, inner tile).
    for row in range(M_blocks):
        for col in range(K_blocks):
            for inner in range(N_blocks):
                emit(
                    'MatMul',
                    [scoped(f'A{row}/{inner}'), scoped(f'B{inner}/{col}')],
                    [scoped(f'C{row}/{col}/{inner}')],
                )

    # Sum the partial products belonging to each output tile.
    for row in range(M_blocks):
        for col in range(K_blocks):
            tile = scoped(f'C{row}/{col}')
            pending = [f'{tile}/{inner}' for inner in range(N_blocks)]
            if N_blocks == 1:
                # A single partial product only needs renaming.
                emit('Identity', [pending[0]], [tile])
                continue
            step = 0
            while len(pending) > 2:
                # Fold the two oldest operands and queue the result at the back.
                scratch = scoped(f'temp{row}/{col}/{step}')
                emit('Add', pending[:2], [scratch])
                pending = pending[2:] + [scratch]
                step += 1
            emit('Add', [pending[0], pending[1]], [tile])

    # Reassemble: tiles -> row strips (axis 1) -> full matrix (axis 0).
    for row in range(M_blocks):
        emit(
            'Concat',
            [scoped(f'C{row}/{col}') for col in range(K_blocks)],
            [scoped(f'C{row}')],
            axis=1,
        )
    assembled = scoped('output_wt_bias')
    emit('Concat', [scoped(f'C{row}') for row in range(M_blocks)], [assembled], axis=0)

    # The bias is added once, to the fully assembled product.
    emit('Add', [assembled, input_C_name], [output_name])

### Build My AlexNet

In [10]:
# Start from an empty graph that shares AlexNet's inputs, outputs and weights;
# nodes are copied over below, with each Gemm expanded into tiled MatMuls.
my_alexnet = onnx.helper.make_graph(
    nodes=[],
    name="my_alexnet",
    inputs=alexnet_model.graph.input,
    outputs=alexnet_model.graph.output,
    initializer=alexnet_model.graph.initializer,
)

Block_Size = 64

# (M, N, K) of each Gemm to expand, listed in the same order as GEMM_Layer_names.
GEMM_SHAPES = [
    (1, 9216, 4096),
    (1, 4096, 4096),
    (1, 4096, 1000),
]

# Node name -> (position in the GEMM_Layer_* lists, shape).  setdefault keeps
# the first entry for a repeated name, matching a top-down if/elif chain.
gemm_plan = {}
for slot, (layer_name, shape) in enumerate(zip(GEMM_Layer_names, GEMM_SHAPES)):
    gemm_plan.setdefault(layer_name, (slot, shape))

for node in alexnet_model.graph.node:
    plan = gemm_plan.get(node.name)
    if plan is None:
        # Not one of the Gemm layers being expanded: copy the node unchanged.
        my_alexnet.node.extend([node])
        continue

    slot, (M, N, K) = plan
    input1, input2, input3 = GEMM_Layer_input_names[slot]
    output = GEMM_Layer_output_names[slot][0]
    # `slot` doubles as the name prefix that keeps each layer's intermediates apart.
    Build_GEMM2(M, N, K, Block_Size, slot, input1, input2, input3, output, my_alexnet)

# Create a new model with the modified graph.
# NOTE(review): make_model() is called without opset_imports, so the model is
# stamped with the installed onnx package's default opset rather than
# alexnet_model.opset_import -- confirm the copied nodes are valid under it.
modified_model = onnx.helper.make_model(my_alexnet)
modified_model_path = "./models/modified_alexnet.onnx"
onnx.save(modified_model, modified_model_path)

### Correctness Verification

In [23]:
# Reload the rewritten network from disk so the checker validates the saved file.
modified_alexnet = onnx.load(
    "./models/modified_alexnet.onnx",
    load_external_data=False,
)
onnx.checker.check_model(modified_alexnet)

In [14]:
# One runtime session per model: the original AlexNet and the tiled rewrite.
onnx_session, mod_onnx_session = (
    ort.InferenceSession(model_path)
    for model_path in ("./models/alexnet.onnx", "./models/modified_alexnet.onnx")
)

In [20]:
# Feed one random image batch to both networks and compare their outputs.
input_name = onnx_session.get_inputs()[0].name

# Named `image_batch` rather than `input`: a module-level `input` would shadow
# the builtin input() for every cell executed afterwards in this kernel.
image_batch = np.random.rand(1, 3, 224, 224).astype(np.float32)
feeds = {input_name: image_batch}

onnx_output = onnx_session.run(None, feeds)
mod_onnx_output = mod_onnx_session.run(None, feeds)

# run() returns a list of outputs, so the stacked array carries a leading axis.
print(np.allclose(np.array(mod_onnx_output), np.array(onnx_output), atol = 1e-6))
print(np.array(mod_onnx_output).shape)

True
(1, 1, 1000)


### Misunderstanding: Blocked GEMM as a PyTorch Layer

In [None]:
class MyGEMMLayer(nn.Module):
    """Linear layer evaluated as a sum of ``block_size`` x ``block_size`` tile products.

    Computes ``x @ weight.T + bias`` by viewing the input and the transposed
    weight as grids of square tiles and accumulating one ``torch.matmul`` per
    (row tile, column tile, inner tile) triple.

    Args:
        input_size: number of input features (columns of ``x``).
        output_size: number of output features.
        block_size: edge length of a tile; the batch size, ``input_size`` and
            ``output_size`` must all be multiples of it.
        weight: tensor of shape ``(output_size, input_size)``, i.e. the layout
            of ``nn.Linear.weight``.  It is stored transposed in ``self.weight``.
        bias: tensor of shape ``(output_size,)``.

    If ``weight`` or ``bias`` is omitted, the module-level variable of the same
    name is used instead.  This fallback exists only so that existing
    ``MyGEMMLayer(input_size, output_size)`` calls keep working; new code
    should pass both tensors explicitly.
    """

    def __init__(self, input_size, output_size, block_size=64, weight=None, bias=None):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size
        self.block_size = block_size

        if weight is None or bias is None:
            # Backward-compatible fallback to the module-level tensors.
            legacy_scope = globals()
            try:
                if weight is None:
                    weight = legacy_scope["weight"]
                if bias is None:
                    bias = legacy_scope["bias"]
            except KeyError as missing:
                raise NameError(
                    f"name {missing.args[0]!r} is not defined; "
                    "pass `weight` and `bias` to MyGEMMLayer explicitly"
                ) from None

        # Stored as (input_size, output_size) so tiles can be multiplied directly.
        self.weight = torch.transpose(weight, 0, 1)
        self.bias = bias

    def forward(self, x):
        """Return ``x @ self.weight + self.bias`` for a 2-D ``x`` of shape ``(M, input_size)``.

        Raises:
            ValueError: if ``x`` does not have ``input_size`` columns, or if the
                batch size, ``input_size`` or ``output_size`` is not a multiple
                of ``block_size``.
        """
        M, features = x.size()
        N = self.input_size
        K = self.output_size
        block_size = self.block_size

        if features != N:
            raise ValueError(f"expected input with {N} features, got {features}")
        for label, extent in (("batch size", M), ("input_size", N), ("output_size", K)):
            if extent % block_size != 0:
                raise ValueError(
                    f"{label} ({extent}) must be a multiple of block_size ({block_size})"
                )

        # View both operands as grids of tiles: [tile row, row, tile column, column].
        x_blocks = x.view(M // block_size, block_size, N // block_size, block_size)
        weight_blocks = self.weight.view(N // block_size, block_size, K // block_size, block_size)

        # new_zeros keeps the accumulator on the input's device and in its dtype.
        output = x.new_zeros(M, K)

        # Accumulate the tile products into the matching output tile.
        for i in range(M // block_size):
            for k in range(K // block_size):
                for j in range(N // block_size):
                    output[i * block_size: (i + 1) * block_size, k * block_size: (k + 1) * block_size] += torch.matmul(
                        x_blocks[i, :, j, :], weight_blocks[j, :, k, :]
                    )
        output += self.bias

        return output


# Example usage: compare the blocked layer against a stock nn.Linear.
x_size, input_size, output_size = 256, 192, 128

# Random batch of x_size rows with input_size features each.
input_tensor = torch.randn(x_size, input_size)

# Reference layer.  Its parameters are bound to the module-level names
# `weight` and `bias`, which MyGEMMLayer picks up when it is constructed
# without explicit tensors.
linear = nn.Linear(input_size, output_size)
weight, bias = linear.weight, linear.bias
print(weight.size())
print(bias.size())

output_tensor = linear(input_tensor)

gemm_layer = MyGEMMLayer(input_size, output_size)
output_tensor2 = gemm_layer(input_tensor)

# Three checks, in order: transposed weight, bias, and forward output.
tolerance = 1e-5
comparisons = (
    (torch.transpose(linear.weight, 0, 1), gemm_layer.weight),
    (linear.bias, gemm_layer.bias),
    (output_tensor, output_tensor2),
)
for expected, actual in comparisons:
    print(torch.allclose(expected, actual, atol=tolerance))