In [27]:
import onnx
import onnx_graphsurgeon as gs
from pathlib import Path
import numpy as np
import copy

In [28]:
def newPath(path):
    """Return the sibling path ``<parent>/<stem>_modified<suffix>`` for ``path``.

    Example: './models/yolov8n.onnx' -> Path('models/yolov8n_modified.onnx').
    Only the last suffix is kept after the marker ('a.tar.gz' -> 'a.tar_modified.gz').
    """
    original = Path(path)
    modified_name = "{}_modified{}".format(original.stem, original.suffix)
    return original.parent.joinpath(modified_name)

In [29]:
def updateDicts(graph):
    """Build index lookups over ``graph.nodes``.

    Returns a tuple ``(outputsDict, inputsDict, namesDict)``:
      - outputsDict: tensor name -> index of the node that lists it as an output.
      - inputsDict:  tensor name -> index of a node that consumes it. When several
        nodes consume the same tensor, the LAST one in ``graph.nodes`` wins.
      - namesDict:   node name -> node index.

    The indices refer to positions in ``graph.nodes`` at call time, so rebuild the
    dicts after nodes are inserted or deleted.
    """
    outputsDict, inputsDict, namesDict = {}, {}, {}
    for idx, node in enumerate(graph.nodes):
        namesDict[node.name] = idx
        outputsDict.update((tensor.name, idx) for tensor in node.outputs)
        inputsDict.update((tensor.name, idx) for tensor in node.inputs)
    return outputsDict, inputsDict, namesDict

In [30]:
def outToIdx(outputs, mapping=None):
    """Map tensor names to the index of the node that produces each one.

    Args:
        outputs: iterable of tensor names (e.g. the set returned by ``searchElem``).
        mapping: tensor name -> producer node index. Defaults to the notebook-level
            ``outputsDict`` (built by ``updateDicts``) so existing calls keep working;
            pass it explicitly to avoid depending on hidden global state.

    Returns:
        List of node indices, in the iteration order of ``outputs``.

    Raises:
        KeyError: if a name has no producer in ``mapping``.
    """
    if mapping is None:
        mapping = outputsDict
    return [mapping[name] for name in outputs]

In [31]:
def searchElem(graph, start, opType, maxDepth=7, mapping=None):
    """Collect output names of the nearest ``opType`` nodes upstream of a node.

    Starting at ``graph.nodes[start]`` (negative indices allowed), walk backwards
    through producer nodes. As soon as a node with ``op == opType`` is reached,
    its output tensor names are collected and that branch stops. Branches
    are cut off after ``maxDepth`` nodes, so more distant matches are not found.

    Args:
        graph: graph whose ``nodes`` expose ``op``, ``inputs`` and ``outputs``.
        start: index into ``graph.nodes`` to start from.
        opType: op name to look for, e.g. 'Conv'.
        maxDepth: maximum number of nodes visited along any path (<= 0 -> nothing).
        mapping: tensor name -> producer node index. Defaults to the notebook-level
            ``outputsDict`` built by ``updateDicts``.

    Returns:
        A set of tensor names. It is always a set now; previously the base cases
        returned lists. Iteration order of a str set is not stable across runs.
    """
    if mapping is None:
        mapping = outputsDict

    # `<= 0` (not `not maxDepth`) so a negative depth cannot walk the graph unbounded.
    if maxDepth <= 0:
        return set()

    node = graph.nodes[start]
    if node.op == opType:
        return {tensor.name for tensor in node.outputs}

    found = set()
    for tensor in node.inputs:
        # Graph inputs and initializers have no producer node; skip them.
        producer = mapping.get(tensor.name)
        if producer is not None:
            found |= searchElem(graph, producer, opType, maxDepth - 1, mapping)
    return found

In [32]:
def getOutput(graph, convIdxs):
    """Return the output tensors of the nodes at ``convIdxs``, concatenated in order."""
    return [
        tensor
        for idx in convIdxs
        for tensor in graph.nodes[idx].outputs
    ]

In [33]:
def delNodes(graph, nodesToDel):
    """Delete nodes from ``graph`` by index, then call ``graph.cleanup()``.

    Args:
        graph: graph with a mutable ``nodes`` list and a ``cleanup()`` method.
        nodesToDel: iterable of node indices (negative indices allowed).

    Each node is deleted at most once, even if its index appears several times or
    is given both as a positive and a negative index. Deletion runs from the
    highest index to the lowest, so earlier deletions do not shift later ones.
    The previous version deleted a shifted (wrong) node on duplicate indices.
    """
    count = len(graph.nodes)
    # Normalise negative indices up front so -1 and count-1 collapse to one entry.
    unique = {idx + count if idx < 0 else idx for idx in nodesToDel}

    for idx in sorted(unique, reverse=True):
        del graph.nodes[idx]

    graph.cleanup()

In [34]:
def getNewName(name='', preffix=''):
    """Return a unique name built from ``name`` and the module-level ``counts`` dict.

    Each call increments ``counts[name]`` and returns ``'<name>_<n>'``, or
    ``'/<preffix>/<name>_<n>'`` when ``preffix`` is non-empty. With the default
    empty ``name`` the result is just ``'_<n>'``.
    """
    global counts

    counts[name] = counts.get(name, 0) + 1
    numbered = f"{name}_{counts[name]}"

    return f"/{preffix}/{numbered}" if preffix else numbered

In [35]:
def getClipInputs():
    """Create the extra Clip inputs: scalar float32 constants 0 (min) and 1 (max).

    Each constant gets a fresh name from ``getNewName("Constant")``, created in
    min-then-max order.
    """
    return [
        gs.Constant(
            name=getNewName("Constant"),
            values=np.array(bound, dtype=np.float32),
        )
        for bound in (0, 1)
    ]

In [36]:
def createOutput(opType='', graphOutput=False, name='', dtype="float32", shape=None):
    """Create a new ``gs.Variable`` with a generated unique name.

    The name comes from ``getNewName``:
      - ``opType`` given -> '/<preffix>/<opType>_N', where ``preffix`` is the
        module-level global set by the caller;
      - else ``name`` given -> '<name>_N';
      - else -> '_N'.

    ``graphOutput`` is accepted but ignored, so a node-info "outputs" dict can be
    passed straight in with ``**``.
    """
    global preffix

    if opType:
        tensorName = getNewName(opType, preffix)
    elif name:
        tensorName = getNewName(name)
    else:
        tensorName = getNewName()

    return gs.Variable(name=tensorName, dtype=dtype, shape=shape)

In [37]:
def createNode(graph, newNodesInfo, inputs, shapeEnd=None):
    """Append a linear chain of nodes described by ``newNodesInfo`` to ``graph``.

    The first described node consumes ``inputs``. Each later node consumes the
    single output tensor of the node before it. Outputs whose description has a
    truthy "graphOutput" are also appended to ``graph.outputs``.

    Args:
        graph: onnx_graphsurgeon graph, extended in place (nodes and outputs).
        newNodesInfo: list of node descriptions with keys "opType", "attrs",
            "inputsConst" and "outputs". Only ``outputs[0]`` is used, and it is
            passed to ``createOutput``. The list and its dicts are NOT modified.
            The previous recursive version popped entries and shared the
            template's ``attrs`` dict with every created node.
        inputs: tensors fed to the first node. Not modified.
        shapeEnd: appended twice to each output "shape" that is set. Presumably
            the trailing spatial dims (H, W) — TODO confirm for non-square inputs.
    """
    currentInputs = list(inputs)

    for nodeInfo in newNodesInfo:
        # Deep copy so extending "shape" below never mutates the shared template.
        outputs = copy.deepcopy(nodeInfo["outputs"][0])

        # Fresh list: never extend the caller's list or a previous node's output list.
        nodeInputs = list(currentInputs)
        if nodeInfo["inputsConst"]:
            # Clip-style node: append the constant min/max inputs.
            nodeInputs.extend(getClipInputs())

        if outputs.get("shape"):
            outputs["shape"].extend([shapeEnd] * 2)

        output = [createOutput(**outputs)]
        print(output)

        if outputs.get("graphOutput"):
            graph.outputs.extend(output)

        nodeBase = {
            "op": nodeInfo["opType"],
            "name": getNewName(nodeInfo["opType"]),
            "inputs": nodeInputs,
            "outputs": output,
        }

        if nodeInfo["attrs"]:
            # Copy so created nodes don't share (and can't mutate) the template dict.
            nodeBase["attrs"] = copy.deepcopy(nodeInfo["attrs"])

        graph.nodes.append(gs.Node(**nodeBase))

        # The next node in the chain consumes this node's output.
        currentInputs = output
    

In [38]:
# Template for the chain of nodes createNode appends after each selected conv output:
#   Sigmoid -> ReduceSum -> Clip
# Per entry: "opType" is the ONNX op; "attrs" its attributes (None/empty -> none set);
# "inputsConst" appends the Clip min/max constants (0 and 1) from getClipInputs();
# "outputs"[0] holds the keyword arguments for createOutput ("shape", when set, gets
# shapeEnd appended twice by createNode; "graphOutput" also exposes it as a graph output).
newElements = [
    {
        "opType" : 'Sigmoid',
        "attrs" : None,
        "inputsConst": False,
        "outputs" : [
            {
                # Explicitly named graph output; getNewName appends a _N counter.
                "name" : "onnx::ReduceSum",
                "graphOutput" : True,
                "shape" : [1, 80],
                "dtype" : "float32",
            }
        ]
    },
    {
        "opType" : 'ReduceSum',
        # NOTE(review): ReduceSum takes `axes` as an attribute only up to opset 12;
        # from opset 13 it is a second (tensor) input. Confirm the model's opset.
        "attrs" : {
            "axes": [1],
            "keepdims": 1,
        },
        "inputsConst": False,
        "outputs" : [
            {
                # Internal tensor named "/<preffix>/ReduceSum_N"; not a graph output.
                "opType" : 'ReduceSum',
                "graphOutput" : False,
                "shape" : None,
                "dtype" : "float32",
            }
        ]
    },
    {
        "opType" : 'Clip',
        "attrs" : None,
        "inputsConst": True,
        "outputs" : [
            {
                # NOTE(review): neither "name" nor "opType" is set, so createOutput falls
                # back to getNewName() and this graph output is named just "_N" (see
                # "Variable (_1)" in the run output). Consider an explicit name.
                "graphOutput" : True,
                "shape" : [1, 1],
                "dtype" : "float32",
            }
        ]
    }
]

# Per-name counters used by getNewName. Re-run this cell before re-running the
# graph-editing cells, otherwise generated names keep incrementing.
counts = {}


In [39]:
# Source model; the edited model is written to the same folder as <stem>_modified<suffix>.
modelPath = './models/yolov8n.onnx'

newModelPath = newPath(modelPath)

In [40]:
# Load the ONNX model and convert it to an editable onnx_graphsurgeon graph.
model = onnx.load(modelPath)
graph = gs.import_onnx(model)

In [41]:
# Index lookups over the freshly imported graph. Values are positions in graph.nodes,
# so rebuild them after any node insertion/deletion.
outputsDict, inputsDict, namesDict = updateDicts(graph)

In [42]:
# Find Conv outputs upstream of the last node. Only matches within searchElem's
# maxDepth (default 7) are found; the recorded run found a single one, so raise
# maxDepth if other detection heads are expected.
# NOTE(review): convOuts is a set of str, so its order may differ between runs.
convOuts = searchElem(graph, -1, 'Conv')
convOutIdxs = outToIdx(convOuts)
print(convOuts, convOutIdxs, sep="\n")


{'/model.22/cv3.0/cv3.0.2/Conv_output_0'}
[137]


In [43]:
# Nodes that consume the selected conv outputs; they are deleted later by delNodes.
# inputsDict keeps only the LAST consumer of each tensor. Any other consumer is
# presumably dropped by graph.cleanup() once it no longer reaches a graph output.
nodesToDel = {inputsDict[convOut] for convOut in convOuts}
print(nodesToDel)

{138}


In [44]:
newOutputs = getOutput(graph, convOutIdxs)
# Start from an empty output list; only the tensors chosen below become graph outputs.
graph.outputs = []

for output in newOutputs:
    # Module-level name read by createOutput (via `global preffix`) to prefix new
    # tensor names with the first path component of the conv output, e.g. "model.22".
    preffix = output.name.strip("/").split("/")[0]
    shape = output.shape
    if shape[1] != 64:
        # Append the newElements chain after this output. Passing a copy keeps
        # createNode from altering the newElements template list itself.
        createNode(graph, newElements.copy(), [output], shapeEnd=shape[-1])
        continue
    # NOTE(review): 64-channel outputs are exposed unchanged — presumably the box
    # regression branch; confirm this magic number against the model.
    graph.outputs.append(output)

# nodesToDel was computed before createNode appended nodes at the end of graph.nodes,
# so its indices still point at the original consumer nodes.
delNodes(graph, nodesToDel)

[Variable (onnx::ReduceSum_1): (shape=[1, 80, 80, 80], dtype=float32)]
[Variable (/model.22/ReduceSum_1): (shape=None, dtype=float32)]
[Variable (_1): (shape=[1, 1, 80, 80], dtype=float32)]
[W] colored module is not installed, will not use colors when logging. To enable colors, please install the colored module: python3 -m pip install colored
[W] Found distinct tensors that share the same name:
[id: 2210264036816] Variable (onnx::ReduceSum_1): (shape=[1, 80, 80, 80], dtype=float32)
[id: 2210275778896] Variable (onnx::ReduceSum_1): (shape=[1, 80, 80, 80], dtype=float32)
Note: Producer node(s) of first tensor:
[Sigmoid_1 (Sigmoid)
	Inputs: [
		Variable (/model.22/cv3.0/cv3.0.2/Conv_output_0): (shape=[1, 80, 80, 80], dtype=float32)
	]
	Outputs: [
		Variable (onnx::ReduceSum_1): (shape=[1, 80, 80, 80], dtype=float32)
	]]
Producer node(s) of second tensor:
[Sigmoid_1 (Sigmoid)
	Inputs: [
		Variable (/model.22/cv3.0/cv3.0.2/Conv_output_0): (shape=[1, 80, 80, 80], dtype=float32)
	]
	Outputs: 

In [45]:
# Export the edited graph to an ONNX model and save it next to the source model.
# NOTE(review): new nodes were appended after their producers, so node order is
# already topological; call graph.toposort() first if that ever stops being true.
onnx.save(gs.export_onnx(graph), newModelPath)