In [289]:
import onnx
import onnx_graphsurgeon as gs
from pathlib import Path
import numpy as np

In [290]:
def newPath(path):

    OldPath = Path(path)

    return OldPath.parent / f"{OldPath.stem}_modified{OldPath.suffix}"

In [291]:
def updateDicts(graph):

    outputsDict = {}
    inputsDict = {}
    namesDict = {}

    for index in range(len(graph.nodes)):
        
        namesDict[graph.nodes[index].name] = index

        for output in graph.nodes[index].outputs:
            outputsDict[output.name] = index
        
        for inpt in graph.nodes[index].inputs:
            inputsDict[inpt.name] = index
    
    return outputsDict, inputsDict, namesDict

In [292]:
def outToIdx(outputs):
    global outputsDict
    
    outputsIndex = []
    for output in outputs:
        outputsIndex.append(outputsDict[output])
    return outputsIndex

In [293]:
def searchElem(graph, start, opType, maxDepth=7):
    global outputsDict

    if not maxDepth:
        return []
    
    node = graph.nodes[start]
    if node.op == opType:
        outputs =[]
        for output in node.outputs:
            outputs.append(output.name)
        return outputs
    
    inputs = []
    for inPut in node.inputs:
        if inPut.name in outputsDict:
            inputs.extend(searchElem(graph, outputsDict[inPut.name], opType, maxDepth - 1))
    
    return set(inputs)

In [294]:
def getOutput(graph, convIdxs):
    
    outputs = []

    for i in convIdxs:
        outputs.extend(graph.nodes[i].outputs)
    
    return outputs

In [295]:
def delNodes(graph, nodesToDel):
    
    nodesToDel = list(nodesToDel)
    nodesToDel.sort(reverse=True)

    for i in nodesToDel:
        del graph.nodes[i]

    graph.cleanup()

In [None]:
def createNode(graph, newNodesInfo, inputs, shapeEnd=None):
    global suffix

    if newNodesInfo:
        nodeInfo = newNodesInfo.pop(0)
        outputs = nodeInfo["outputs"][0]

        if nodeInfo["inputsConst"]:
            minNode = gs.Node(
                op="Constant",
                name=f"/model.22/Constant_{suffix}",
                attrs={"value": np.array([0], dtype=np.float32)},
                outputs=[gs.Variable(name=f"/model.22/Constant_{suffix}", dtype=np.float32)]
            )
            graph.nodes.append(minNode)
            suffix += 1

            maxNode = gs.Node(
                op="Constant",
                name=f"/model.22/Constant_{suffix}",
                attrs={"value": np.array([1], dtype=np.float32)},
                outputs=[gs.Variable(name=f"/model.22/Constant_{suffix}", dtype=np.float32)]
            )
            graph.nodes.append(maxNode)
            suffix += 1

            inputs.extend([minNode.outputs, maxNode.outputs])
        
        shape = outputs["shape"]
        if shape:
            shape.extend([shapeEnd] * 2)
        print(shape)
        output = [gs.Variable(
            name=outputs["name"] + str(suffix), 
            dtype=outputs["dtype"], 
            shape=shape
        )]
        suffix += 1

        graph.outputs.extend(output * outputs["graphOutput"])

        if nodeInfo["attrs"]:
            graph.nodes.append(gs.Node(
                op=nodeInfo["opType"],
                name=f"/model.22/{nodeInfo["opType"]}_" + str(suffix),
                attrs=nodeInfo["attrs"],
                inputs=inputs,
                outputs=output,
            ))
            
        else:
            graph.nodes.append(gs.Node(
                op=nodeInfo["opType"],
                name=f"/model.22/{nodeInfo["opType"]}",
                inputs=inputs,
                outputs=output,
            ))
        suffix += 1

        createNode(graph, newNodesInfo, output, shapeEnd)
    

In [297]:
newElements = [
    {
        "opType" : 'Sigmoid',
        "attrs" : None,
        "inputsConst": False,
        "outputs" : [
            {
                "name" : "onnx::ReduceSum_36",
                "graphOutput" : True,
                "shape" : [1, 80],
                "dtype" : "float32",
            }
        ]
    },
    {
        "opType" : 'ReduceSum',
        "attrs" : {
            "axes": [1],
            "keepdims": 1,
        },
        "inputsConst": False,
        "outputs" : [
            {
                "name" : "/model.22/ReduceSum_1_output_0",
                "graphOutput" : False,
                "shape" : None,
                "dtype" : "float32",
            }
        ]
    },
    {
        "opType" : 'Clip',
        "attrs" : None,
        "inputsConst": True,
        "outputs" : [
            {
                "name" : "369",
                "graphOutput" : True,
                "shape" : [1, 1],
                "dtype" : "float32",
            }
        ]
    }
]

suffix = 300

In [298]:
minIn = gs.Constant(
    name="/model.22/Constant_4_output_0",
    values=np.array([0], dtype=np.float32)
)

In [299]:
modelPath = './models/yolo11s.onnx'

newModelPath = newPath(modelPath)

In [300]:
model = onnx.load(modelPath)
graph = gs.import_onnx(model)

In [301]:
outputsDict, inputsDict, namesDict = updateDicts(graph)

In [302]:
convOuts = searchElem(graph, -1, 'Conv')
convOutIdxs = outToIdx(convOuts)
print(convOuts)
print(convOutIdxs)


{'/model.23/cv3.2/cv3.2.2/Conv_output_0', '/model.23/cv2.0/cv2.0.2/Conv_output_0', '/model.23/cv2.1/cv2.1.2/Conv_output_0', '/model.23/cv3.1/cv3.1.2/Conv_output_0', '/model.23/cv2.2/cv2.2.2/Conv_output_0', '/model.23/cv3.0/cv3.0.2/Conv_output_0'}
[299, 200, 241, 257, 292, 213]


In [303]:
nodesToDel = set()
for i in convOuts:
    nodesToDel.add(inputsDict[i])
print(nodesToDel)

{259, 300, 215}


In [304]:
# print("Узлов в графе:", len(graph.nodes))
# print("Входы:", graph.inputs)
# print("Выходы:", graph.nodes)

In [305]:
newOutputs = getOutput(graph, convOutIdxs)
graph.outputs = []
for output in newOutputs:
    print(output)
    createNode(graph, newElements, [output], shapeEnd=20)

delNodes(graph, nodesToDel)

Variable (/model.23/cv3.2/cv3.2.2/Conv_output_0): (shape=[1, 80, 20, 20], dtype=float32)
[1, 80, 20, 20]
None
[1, 1, 20, 20]
Variable (/model.23/cv2.0/cv2.0.2/Conv_output_0): (shape=[1, 64, 80, 80], dtype=float32)
Variable (/model.23/cv2.1/cv2.1.2/Conv_output_0): (shape=[1, 64, 40, 40], dtype=float32)
Variable (/model.23/cv3.1/cv3.1.2/Conv_output_0): (shape=[1, 80, 40, 40], dtype=float32)
Variable (/model.23/cv2.2/cv2.2.2/Conv_output_0): (shape=[1, 64, 20, 20], dtype=float32)
Variable (/model.23/cv3.0/cv3.0.2/Conv_output_0): (shape=[1, 80, 80, 80], dtype=float32)


In [306]:
onnx.save(gs.export_onnx(graph), newModelPath)