In [24]:
import popart
import numpy as np
from typing import *
import time
import onnx
from itertools import chain
from onnx.helper import make_attribute, make_node

np.random.seed(1)

In [2]:
def make_an_anchor(onnx_model_builder: popart.Builder, anchor_return_type: str = "ALL"):
    return {output_name: popart.AnchorReturnType(anchor_return_type) 
                for output_name in onnx_model_builder.getOutputTensorIds()}

def set_batch_size(shape: List, batch_size = 1):
    """ shape is like [unknow_batch_size, n1, n2, ...] """
    shape[0] = batch_size if shape[0] == 0 else shape[0]
    return shape

def convert_popart_dtype(dtype: str):
    dtype_conv_dic = {
        "int64": "INT32",
        "int32": "INT32",
        "float32": "FLOAT",
        "float16": "FLOAT16",
        "float64": "FLOAT",
    }

    return dtype_conv_dic[dtype]

def convert_numpy_dtype(dtype: str):
    dtype_conv_dic = {
        "int64": np.int32, 
        "int32": np.int32, 
        "float32": np.float32,
        "float": np.float32,
        "float64": np.float32,
    }
    return dtype_conv_dic[dtype.lower()]


def add_shapeinfo_from_onnx(onnx_model_builder: popart.Builder, batch_size = 1, batch_per_step = 1):
    inputs_tensor_id = onnx_model_builder.getInputTensorIds()
    outputs_tensor_id = onnx_model_builder.getOutputTensorIds()

    print(inputs_tensor_id)
    inputs_shapes = [set_batch_size(onnx_model_builder.getTensorShape(i), 
                                batch_size=batch_size * batch_per_step) for i in inputs_tensor_id]
    print(inputs_shapes)
    inputs_dtypes = [convert_popart_dtype(onnx_model_builder.getTensorDtypeString(i)) for i in inputs_tensor_id]
    print(inputs_dtypes)

    inputs_tensor_shapes = [set_batch_size(onnx_model_builder.getTensorShape(i), 
                                batch_size=batch_size) for i in inputs_tensor_id + outputs_tensor_id]
    inputs_tensor_dtypes = [convert_popart_dtype(onnx_model_builder.getTensorDtypeString(i)) for i in inputs_tensor_id + outputs_tensor_id]

    inputShapeInfo = popart.InputShapeInfo()

    for tid, tshape, tdype in zip(inputs_tensor_id + outputs_tensor_id, inputs_tensor_shapes, inputs_tensor_dtypes):
        inputShapeInfo.add(tid, popart.TensorInfo(tdype, tshape))

    return inputs_tensor_id, inputShapeInfo, inputs_shapes, inputs_dtypes


def fake_dataset(inputs_tensor_id, inputs_shapes, inputs_dtypes, num_samples = 100):
    for _ in range(num_samples):
        yield { i: np.random.randint(12, size=s).astype(convert_numpy_dtype(d)) for i, s, d in zip(inputs_tensor_id, inputs_shapes, inputs_dtypes) }


In [3]:
def run(builder, opts, batch_size = 1, batch_per_step = 1, n_sample = None):

    global_batch_size = batch_per_step * batch_size
    n_sample = n_sample or global_batch_size

    # builder = popart.Builder("qtc35/model.onnx")
    # builder = popart.Builder("subqtc-manually.onnx")
    # builder = popart.Builder("qtc211-tf-ipu/qtcv211-int32-manual.onnx")
    # builder = popart.Builder("reproducer/sub_qtcv211.onnx")
    anchors = make_an_anchor(builder)
    inputs_tensor_id, inputShapeInfo, inputs_shapes, inputs_dtypes = add_shapeinfo_from_onnx(builder, batch_size=batch_size, batch_per_step = batch_per_step)
    # anchors = {output_name: popart.AnchorReturnType("All") for output_name in builder.getOutputTensorIds() }
    dataflow = popart.DataFlow(batch_per_step, anchors)
    device = popart.DeviceManager().acquireAvailableDevice(2)
    # device = popart.DeviceManager().createCpuDevice()

    # opts = popart.SessionOptions()
    # opts.virtualGraphMode = popart.VirtualGraphMode.Manual
    # opts.enablePipelining = True
    # partials_type = "half"
    # opts.partialsTypeMatMuls = partials_type
    # opts.convolutionOptions = {'partialsType': partials_type}
    # opts.groupHostSync = True

    # builder.virtualGraph("Reshape_1:0", 0)
    # builder.virtualGraph("Reshape_2:0", 1)  
    # builder.virtualGraph("Reshape:0", 2)
    # builder.virtualGraph("prob", 3)
    # builder.pipelineStage("Reshape_1:0", 0)
    # builder.pipelineStage("Reshape_2:0", 1)
    # builder.pipelineStage("Reshape:0", 2)
    # builder.pipelineStage("prob", 3)

    # session = popart.InferenceSession(builder.getModelProto(), dataflow, device, inputShapeInfo)
    session = popart.InferenceSession(builder.getModelProto(), dataflow, device, 
                                    inputShapeInfo=inputShapeInfo,
                                    userOptions=opts)

    session.prepareDevice()
    anchors = session.initAnchorArrays()

    durations = []

    for feed_dict in fake_dataset(inputs_tensor_id, inputs_shapes, inputs_dtypes, num_samples=n_sample):

        print("[qtc-inference] Starting batch inference")
        stepio = popart.PyStepIO(feed_dict, anchors)
        # start = time.perf_counter()
        session.run(stepio)
        for k, v in anchors.items():
            print(v)
        # duration = time.perf_counter() - start

        # durations.append(duration / global_batch_size)

    # np_dur = np.array(durations[10:]).mean()
    # print(f"Latency: {np_dur} s/sample(mean)")

    # for k, v in anchors.items():
    #     print(v)


In [4]:
builder = popart.Builder("qtc35/model.onnx")

In [5]:
opts = popart.SessionOptions()
opts.virtualGraphMode = popart.VirtualGraphMode.Auto
# opts.enablePipelining = True
partials_type = "half"
opts.partialsTypeMatMuls = partials_type
opts.convolutionOptions = {'partialsType': partials_type}

In [6]:
m = onnx.load("qtc35/model.onnx")

In [7]:
reload_qt_ns = [n.output for n in m.graph.node if n.name.startswith("reload_qt/")]

In [8]:
reload_ns = [n.output for n in m.graph.node if n.name.startswith("reload/")]

In [9]:
rest_ns = [n.output for n in m.graph.node if not n.name.startswith("reload")]

In [21]:
def ipuconfig_in_attributes(node):
    for a in node.attribute:
        if a.name == "__ipu_number":
            return True
    return False

In [25]:
ipuconfig_in_attributes(make_node("Cast", inputs=["a"], outputs=["c"], to = 7, __ipu_number = 1))

True

In [31]:
make_attribute("__ipu_number", 0)

name: "__ipu_number"
i: 0
type: INT

In [32]:
for node in m.graph.node:
    if node.name.startswith("reload_qt/"):
        if not ipuconfig_in_attributes(node):
            node.attribute.append(make_attribute("__ipu_number", 0))

In [33]:
for node in m.graph.node:
    if node.name.startswith("reload/"):
        if not ipuconfig_in_attributes(node):
            node.attribute.append(make_attribute("__ipu_number", 1))

In [34]:
for node in m.graph.node:
    if not node.name.startswith("reload"):
        if not ipuconfig_in_attributes(node):
            node.attribute.append(make_attribute("__ipu_number", 1))

In [37]:
[(i, n) for i, n in enumerate(m.graph.node) if "Concat__453" in n.name ]

[(72,
  input: "const_fold_opt__1614"
  input: "const_fold_opt__1618"
  output: "Concat__453:0"
  name: "Concat__453"
  op_type: "Concat"
  attribute {
    name: "axis"
    i: 0
    type: INT
  }
  attribute {
    name: "__ipu_number"
    i: 1
    type: INT
  }
  domain: "")]

In [41]:
m.graph.node[72].attribute[1]

name: "__ipu_number"
i: 0
type: INT

In [13]:
m.graph.node[6].HasExtension("attribute")

KeyError: 'attribute'

In [20]:
m.graph.node[6].attribute.append

<function RepeatedCompositeContainer.append>

In [20]:
for Ore in chain(*reload_ns):
    try:
        builder.virtualGraph(Oqt, 1)
    except:
        pass

In [18]:
for Oqt in chain(*reload_qt_ns):
    builder.virtualGraph(Oqt, 0)

In [30]:
mk = onnx.load_from_string(builder.getModelProto())

In [23]:
mk.graph.node[64]

input: "reload_qt/SequenceMask/Cast:0"
output: "reload_qt/SequenceMask/Less__1199:0"
name: "reload_qt/SequenceMask/Less__1199"
op_type: "Cast"
attribute {
  name: "to"
  i: 1
  type: INT
}
attribute {
  name: "__ipu_number"
  i: 0
  type: INT
}

In [25]:
for Oqt in chain(*rest_ns[1:]):
    builder.virtualGraph(Oqt, 1)

In [26]:
rest_ns[0]

['Concat__453:0']

In [103]:
builder.virtualGraph('Concat__453:0', 0)

popart_exception: Node already has attribute __ipu_number.

In [28]:
opts = popart.SessionOptions()
opts.virtualGraphMode = popart.VirtualGraphMode.Manual
# opts.enablePipelining = True
partials_type = "half"
opts.partialsTypeMatMuls = partials_type
opts.convolutionOptions = {'partialsType': partials_type}

In [29]:
run(builder, opts)

['TensorDict/StandardKvParser_4:0', 'TensorDict/StandardKvParser_1:0', 'TensorDict/StandardKvParser_6:0', 'TensorDict/StandardKvParser_8:0']
[[1, 512], [1, 64], [1, 512], [1, 64]]
['INT32', 'INT32', 'INT32', 'INT32']


popart_exception: Either all Ops in the main graph must have their virtual graph ids set, or none must. Op count per virtual graph id
  -1 : 790
  0 : 790
  1 : 42
Ops with no virtual graph id :  
  101 (ai.onnx.Reshape:5)
  106 (ai.onnx.Expand:8)
  109 (ai.onnx.Reshape:5)
  110 (ai.onnx.Gather:11)
  115 (ai.onnx.Reshape:5)
  120 (ai.onnx.Equal:11)
  121 (ai.onnx.Not:1)
  122 (ai.onnx.Cast:9)
  123 (ai.onnx.ReduceSum:11)
  124 (ai.onnx.Reshape:5)
  125 (ai.onnx.Cast:9)
  126 (ai.onnx.Cast:9)
  127 (ai.onnx.Less:9)
  128 (ai.onnx.Cast:9)
  129 (ai.onnx.Reshape:5)
  130 (ai.onnx.Mul:7)
  131 (ai.onnx.Reshape:5)
  132 (ai.onnx.Sub:7)
  133 (ai.onnx.Mul:7)
  963 (ai.onnx.OneHot:11)
  964 (ai.onnx.MatMul:9)
  965 (ai.onnx.Reshape:5)
  966 (ai.onnx.Add:7)
  967 (ai.onnx.Add:7)
  968 (ai.onnx.ReduceMean:11)
  969 (ai.onnx.Cast:9)
  970 (ai.onnx.Cast:9)
  971 (ai.onnx.Sub:7)
  972 (ai.onnx.Mul:7)
  973 (ai.onnx.GlobalAveragePool:1)
  974 (ai.onnx.Cast:9)
  975 (ai.onnx.Add:7)
  976 (ai.onnx.Sqrt:6)
  977 (ai.onnx.Reciprocal:6)
  978 (ai.onnx.Mul:7)
  979 (ai.onnx.Mul:7)
  980 (ai.onnx.Sub:7)
  981 (ai.onnx.Mul:7)
  982 (ai.onnx.Add:7)
  995 (ai.onnx.Reshape:5)
  996 (ai.onnx.MatMul:9)
  997 (ai.onnx.Add:7)
  998 (ai.onnx.Reshape:5)
  999 (ai.onnx.Transpose:1)
  1000 (ai.onnx.MatMul:9)
  1001 (ai.onnx.Add:7)
  1002 (ai.onnx.Reshape:5)
  1003 (ai.onnx.Transpose:1)
  1004 (ai.onnx.MatMul:9)
  1005 (ai.onnx.Add:7)
  1006 (ai.onnx.Reshape:5)
  1007 (ai.onnx.Transpose:1)
  1008 (ai.onnx.MatMul:9)
  1009 (ai.onnx.Mul:7)
  1010 (ai.onnx.Add:7)
  1011 (ai.onnx.Softmax:11)
  1012 (ai.onnx.MatMul:9)
  1013 (ai.onnx.Transpose:1)
  1014 (ai.onnx.Reshape:5)
  1015 (ai.onnx.MatMul:9)
  1016 (ai.onnx.Add:7)
  1017 (ai.onnx.Add:7)
  1018 (ai.onnx.ReduceMean:11)
  1019 (ai.onnx.Cast:9)
  1020 (ai.onnx.Cast:9)
  1021 (ai.onnx.Sub:7)
  1022 (ai.onnx.Mul:7)
  1023 (ai.onnx.ReduceMean:11)
  1024 (ai.onnx.Cast:9)
  1025 (ai.onnx.Add:7)
  1026 (ai.onnx.Sqrt:6)
  1027 (ai.onnx.Reciprocal:6)
  1028 (ai.onnx.Mul:7)
  1029 (ai.onnx.Mul:7)
  1030 (ai.onnx.Sub:7)
  1031 (ai.onnx.Mul:7)
  1032 (ai.onnx.Add:7)
  1033 (ai.onnx.MatMul:9)
  1034 (ai.onnx.Add:7)
  1035 (ai.onnx.Div:7)
  1036 (ai.onnx.Erf:9)
  1037 (ai.onnx.Add:7)
  1038 (ai.onnx.Mul:7)
  1039 (ai.onnx.Mul:7)
  1040 (ai.onnx.MatMul:9)
  1041 (ai.onnx.Add:7)
  1042 (ai.onnx.Add:7)
  1043 (ai.onnx.ReduceMean:11)
  1044 (ai.onnx.Cast:9)
  1045 (ai.onnx.Cast:9)
  1046 (ai.onnx.Sub:7)
  1047 (ai.onnx.Mul:7)
  1048 (ai.onnx.ReduceMean:11)
  1049 (ai.onnx.Cast:9)
  1050 (ai.onnx.Add:7)
  1051 (ai.onnx.Sqrt:6)
  1052 (ai.onnx.Reciprocal:6)
  1053 (ai.onnx.Mul:7)
  1054 (ai.onnx.Mul:7)
  1055 (ai.onnx.Sub:7)
  1056 (ai.onnx.Mul:7)
  1057 (ai.onnx.Add:7)
  1058 (ai.onnx.MatMul:9)
  1059 (ai.onnx.Add:7)
  1060 (ai.onnx.Reshape:5)
  1061 (ai.onnx.Transpose:1)
  1062 (ai.onnx.MatMul:9)
  1063 (ai.onnx.Add:7)
  1064 (ai.onnx.Reshape:5)
  1065 (ai.onnx.Transpose:1)
  1066 (ai.onnx.MatMul:9)
  1067 (ai.onnx.Add:7)
  1068 (ai.onnx.Reshape:5)
  1069 (ai.onnx.Transpose:1)
  1070 (ai.onnx.MatMul:9)
  1071 (ai.onnx.Mul:7)
  1072 (ai.onnx.Add:7)
  1073 (ai.onnx.Softmax:11)
  1074 (ai.onnx.MatMul:9)
  1075 (ai.onnx.Transpose:1)
  1076 (ai.onnx.Reshape:5)
  1077 (ai.onnx.MatMul:9)
  1078 (ai.onnx.Add:7)
  1079 (ai.onnx.Add:7)
  1080 (ai.onnx.ReduceMean:11)
  1081 (ai.onnx.Cast:9)
  1082 (ai.onnx.Cast:9)
  1083 (ai.onnx.Sub:7)
  1084 (ai.onnx.Mul:7)
  1085 (ai.onnx.ReduceMean:11)
  1086 (ai.onnx.Cast:9)
  1087 (ai.onnx.Add:7)
  1088 (ai.onnx.Sqrt:6)
  1089 (ai.onnx.Reciprocal:6)
  1090 (ai.onnx.Mul:7)
  1091 (ai.onnx.Mul:7)
  1092 (ai.onnx.Sub:7)
  1093 (ai.onnx.Mul:7)
  1094 (ai.onnx.Add:7)
  1095 (ai.onnx.MatMul:9)
  1096 (ai.onnx.Add:7)
  1097 (ai.onnx.Div:7)
  1098 (ai.onnx.Erf:9)
  1099 (ai.onnx.Add:7)
  1100 (ai.onnx.Mul:7)
  1101 (ai.onnx.Mul:7)
  1102 (ai.onnx.MatMul:9)
  1103 (ai.onnx.Add:7)
  1104 (ai.onnx.Add:7)
  1105 (ai.onnx.ReduceMean:11)
  1106 (ai.onnx.Cast:9)
  1107 (ai.onnx.Cast:9)
  1108 (ai.onnx.Sub:7)
  1109 (ai.onnx.Mul:7)
  1110 (ai.onnx.ReduceMean:11)
  1111 (ai.onnx.Cast:9)
  1112 (ai.onnx.Add:7)
  1113 (ai.onnx.Sqrt:6)
  1114 (ai.onnx.Reciprocal:6)
  1115 (ai.onnx.Mul:7)
  1116 (ai.onnx.Mul:7)
  1117 (ai.onnx.Sub:7)
  1118 (ai.onnx.Mul:7)
  1119 (ai.onnx.Add:7)
  1120 (ai.onnx.MatMul:9)
  1121 (ai.onnx.Add:7)
  1122 (ai.onnx.Reshape:5)
  1123 (ai.onnx.Transpose:1)
  1124 (ai.onnx.MatMul:9)
  1125 (ai.onnx.Add:7)
  1126 (ai.onnx.Reshape:5)
  1127 (ai.onnx.Transpose:1)
  1128 (ai.onnx.MatMul:9)
  1129 (ai.onnx.Add:7)
  1130 (ai.onnx.Reshape:5)
  1131 (ai.onnx.Transpose:1)
  1132 (ai.onnx.MatMul:9)
  1133 (ai.onnx.Mul:7)
  1134 (ai.onnx.Add:7)
  1135 (ai.onnx.Softmax:11)
  1136 (ai.onnx.MatMul:9)
  1137 (ai.onnx.Transpose:1)
  1138 (ai.onnx.Reshape:5)
  1139 (ai.onnx.MatMul:9)
  1140 (ai.onnx.Add:7)
  1141 (ai.onnx.Add:7)
  1142 (ai.onnx.ReduceMean:11)
  1143 (ai.onnx.Cast:9)
  1144 (ai.onnx.Cast:9)
  1145 (ai.onnx.Sub:7)
  1146 (ai.onnx.Mul:7)
  1147 (ai.onnx.ReduceMean:11)
  1148 (ai.onnx.Cast:9)
  1149 (ai.onnx.Add:7)
  1150 (ai.onnx.Sqrt:6)
  1151 (ai.onnx.Reciprocal:6)
  1152 (ai.onnx.Mul:7)
  1153 (ai.onnx.Mul:7)
  1154 (ai.onnx.Sub:7)
  1155 (ai.onnx.Mul:7)
  1156 (ai.onnx.Add:7)
  1157 (ai.onnx.MatMul:9)
  1158 (ai.onnx.Add:7)
  1159 (ai.onnx.Div:7)
  1160 (ai.onnx.Erf:9)
  1161 (ai.onnx.Add:7)
  1162 (ai.onnx.Mul:7)
  1163 (ai.onnx.Mul:7)
  1164 (ai.onnx.MatMul:9)
  1165 (ai.onnx.Add:7)
  1166 (ai.onnx.Add:7)
  1167 (ai.onnx.ReduceMean:11)
  1168 (ai.onnx.Cast:9)
  1169 (ai.onnx.Cast:9)
  1170 (ai.onnx.Sub:7)
  1171 (ai.onnx.Mul:7)
  1172 (ai.onnx.ReduceMean:11)
  1173 (ai.onnx.Cast:9)
  1174 (ai.onnx.Add:7)
  1175 (ai.onnx.Sqrt:6)
  1176 (ai.onnx.Reciprocal:6)
  1177 (ai.onnx.Mul:7)
  1178 (ai.onnx.Mul:7)
  1179 (ai.onnx.Sub:7)
  1180 (ai.onnx.Mul:7)
  1181 (ai.onnx.Add:7)
  1182 (ai.onnx.MatMul:9)
  1183 (ai.onnx.Add:7)
  1184 (ai.onnx.Reshape:5)
  1185 (ai.onnx.Transpose:1)
  1186 (ai.onnx.MatMul:9)
  1187 (ai.onnx.Add:7)
  1188 (ai.onnx.Reshape:5)
  1189 (ai.onnx.Transpose:1)
  1190 (ai.onnx.MatMul:9)
  1191 (ai.onnx.Add:7)
  1192 (ai.onnx.Reshape:5)
  1193 (ai.onnx.Transpose:1)
  1194 (ai.onnx.MatMul:9)
  1195 (ai.onnx.Mul:7)
  1196 (ai.onnx.Add:7)
  1197 (ai.onnx.Softmax:11)
  1198 (ai.onnx.MatMul:9)
  1199 (ai.onnx.Transpose:1)
  1200 (ai.onnx.Reshape:5)
  1201 (ai.onnx.MatMul:9)
  1202 (ai.onnx.Add:7)
  1203 (ai.onnx.Add:7)
  1204 (ai.onnx.ReduceMean:11)
  1205 (ai.onnx.Cast:9)
  1206 (ai.onnx.Cast:9)
  1207 (ai.onnx.Sub:7)
  1208 (ai.onnx.Mul:7)
  1209 (ai.onnx.ReduceMean:11)
  1210 (ai.onnx.Cast:9)
  1211 (ai.onnx.Add:7)
  1212 (ai.onnx.Sqrt:6)
  1213 (ai.onnx.Reciprocal:6)
  1214 (ai.onnx.Mul:7)
  1215 (ai.onnx.Mul:7)
  1216 (ai.onnx.Sub:7)
  1217 (ai.onnx.Mul:7)
  1218 (ai.onnx.Add:7)
  1219 (ai.onnx.MatMul:9)
  1220 (ai.onnx.Add:7)
  1221 (ai.onnx.Div:7)
  1222 (ai.onnx.Erf:9)
  1223 (ai.onnx.Add:7)
  1224 (ai.onnx.Mul:7)
  1225 (ai.onnx.Mul:7)
  1226 (ai.onnx.MatMul:9)
  1227 (ai.onnx.Add:7)
  1228 (ai.onnx.Add:7)
  1229 (ai.onnx.ReduceMean:11)
  1230 (ai.onnx.Cast:9)
  1231 (ai.onnx.Cast:9)
  1232 (ai.onnx.Sub:7)
  1233 (ai.onnx.Mul:7)
  1234 (ai.onnx.ReduceMean:11)
  1235 (ai.onnx.Cast:9)
  1236 (ai.onnx.Add:7)
  1237 (ai.onnx.Sqrt:6)
  1238 (ai.onnx.Reciprocal:6)
  1239 (ai.onnx.Mul:7)
  1240 (ai.onnx.Mul:7)
  1241 (ai.onnx.Sub:7)
  1242 (ai.onnx.Mul:7)
  1243 (ai.onnx.Add:7)
  1244 (ai.onnx.MatMul:9)
  1245 (ai.onnx.Add:7)
  1246 (ai.onnx.Reshape:5)
  1247 (ai.onnx.Transpose:1)
  1248 (ai.onnx.MatMul:9)
  1249 (ai.onnx.Add:7)
  1250 (ai.onnx.Reshape:5)
  1251 (ai.onnx.Transpose:1)
  1252 (ai.onnx.MatMul:9)
  1253 (ai.onnx.Add:7)
  1254 (ai.onnx.Reshape:5)
  1255 (ai.onnx.Transpose:1)
  1256 (ai.onnx.MatMul:9)
  1257 (ai.onnx.Mul:7)
  1258 (ai.onnx.Add:7)
  1259 (ai.onnx.Softmax:11)
  1260 (ai.onnx.MatMul:9)
  1261 (ai.onnx.Transpose:1)
  1262 (ai.onnx.Reshape:5)
  1263 (ai.onnx.MatMul:9)
  1264 (ai.onnx.Add:7)
  1265 (ai.onnx.Add:7)
  1266 (ai.onnx.ReduceMean:11)
  1267 (ai.onnx.Cast:9)
  1268 (ai.onnx.Cast:9)
  1269 (ai.onnx.Sub:7)
  1270 (ai.onnx.Mul:7)
  1271 (ai.onnx.ReduceMean:11)
  1272 (ai.onnx.Cast:9)
  1273 (ai.onnx.Add:7)
  1274 (ai.onnx.Sqrt:6)
  1275 (ai.onnx.Reciprocal:6)
  1276 (ai.onnx.Mul:7)
  1277 (ai.onnx.Mul:7)
  1278 (ai.onnx.Sub:7)
  1279 (ai.onnx.Mul:7)
  1280 (ai.onnx.Add:7)
  1281 (ai.onnx.MatMul:9)
  1282 (ai.onnx.Add:7)
  1283 (ai.onnx.Div:7)
  1284 (ai.onnx.Erf:9)
  1285 (ai.onnx.Add:7)
  1286 (ai.onnx.Mul:7)
  1287 (ai.onnx.Mul:7)
  1288 (ai.onnx.MatMul:9)
  1289 (ai.onnx.Add:7)
  1290 (ai.onnx.Add:7)
  1291 (ai.onnx.ReduceMean:11)
  1292 (ai.onnx.Cast:9)
  1293 (ai.onnx.Cast:9)
  1294 (ai.onnx.Sub:7)
  1295 (ai.onnx.Mul:7)
  1296 (ai.onnx.ReduceMean:11)
  1297 (ai.onnx.Cast:9)
  1298 (ai.onnx.Add:7)
  1299 (ai.onnx.Sqrt:6)
  1300 (ai.onnx.Reciprocal:6)
  1301 (ai.onnx.Mul:7)
  1302 (ai.onnx.Mul:7)
  1303 (ai.onnx.Sub:7)
  1304 (ai.onnx.Mul:7)
  1305 (ai.onnx.Add:7)
  1306 (ai.onnx.MatMul:9)
  1307 (ai.onnx.Add:7)
  1308 (ai.onnx.Reshape:5)
  1309 (ai.onnx.Transpose:1)
  1310 (ai.onnx.MatMul:9)
  1311 (ai.onnx.Add:7)
  1312 (ai.onnx.Reshape:5)
  1313 (ai.onnx.Transpose:1)
  1314 (ai.onnx.MatMul:9)
  1315 (ai.onnx.Add:7)
  1316 (ai.onnx.Reshape:5)
  1317 (ai.onnx.Transpose:1)
  1318 (ai.onnx.MatMul:9)
  1319 (ai.onnx.Mul:7)
  1320 (ai.onnx.Add:7)
  1321 (ai.onnx.Softmax:11)
  1322 (ai.onnx.MatMul:9)
  1323 (ai.onnx.Transpose:1)
  1324 (ai.onnx.Reshape:5)
  1325 (ai.onnx.MatMul:9)
  1326 (ai.onnx.Add:7)
  1327 (ai.onnx.Add:7)
  1328 (ai.onnx.ReduceMean:11)
  1329 (ai.onnx.Cast:9)
  1330 (ai.onnx.Cast:9)
  1331 (ai.onnx.Sub:7)
  1332 (ai.onnx.Mul:7)
  1333 (ai.onnx.ReduceMean:11)
  1334 (ai.onnx.Cast:9)
  1335 (ai.onnx.Add:7)
  1336 (ai.onnx.Sqrt:6)
  1337 (ai.onnx.Reciprocal:6)
  1338 (ai.onnx.Mul:7)
  1339 (ai.onnx.Mul:7)
  1340 (ai.onnx.Sub:7)
  1341 (ai.onnx.Mul:7)
  1342 (ai.onnx.Add:7)
  1343 (ai.onnx.MatMul:9)
  1344 (ai.onnx.Add:7)
  1345 (ai.onnx.Div:7)
  1346 (ai.onnx.Erf:9)
  1347 (ai.onnx.Add:7)
  1348 (ai.onnx.Mul:7)
  1349 (ai.onnx.Mul:7)
  1350 (ai.onnx.MatMul:9)
  1351 (ai.onnx.Add:7)
  1352 (ai.onnx.Add:7)
  1353 (ai.onnx.ReduceMean:11)
  1354 (ai.onnx.Cast:9)
  1355 (ai.onnx.Cast:9)
  1356 (ai.onnx.Sub:7)
  1357 (ai.onnx.Mul:7)
  1358 (ai.onnx.ReduceMean:11)
  1359 (ai.onnx.Cast:9)
  1360 (ai.onnx.Add:7)
  1361 (ai.onnx.Sqrt:6)
  1362 (ai.onnx.Reciprocal:6)
  1363 (ai.onnx.Mul:7)
  1364 (ai.onnx.Mul:7)
  1365 (ai.onnx.Sub:7)
  1366 (ai.onnx.Mul:7)
  1367 (ai.onnx.Add:7)
  1368 (ai.onnx.MatMul:9)
  1369 (ai.onnx.Add:7)
  1370 (ai.onnx.Reshape:5)
  1371 (ai.onnx.Transpose:1)
  1372 (ai.onnx.MatMul:9)
  1373 (ai.onnx.Add:7)
  1374 (ai.onnx.Reshape:5)
  1375 (ai.onnx.Transpose:1)
  1376 (ai.onnx.MatMul:9)
  1377 (ai.onnx.Add:7)
  1378 (ai.onnx.Reshape:5)
  1379 (ai.onnx.Transpose:1)
  1380 (ai.onnx.MatMul:9)
  1381 (ai.onnx.Mul:7)
  1382 (ai.onnx.Add:7)
  1383 (ai.onnx.Softmax:11)
  1384 (ai.onnx.MatMul:9)
  1385 (ai.onnx.Transpose:1)
  1386 (ai.onnx.Reshape:5)
  1387 (ai.onnx.MatMul:9)
  1388 (ai.onnx.Add:7)
  1389 (ai.onnx.Add:7)
  1390 (ai.onnx.ReduceMean:11)
  1391 (ai.onnx.Cast:9)
  1392 (ai.onnx.Cast:9)
  1393 (ai.onnx.Sub:7)
  1394 (ai.onnx.Mul:7)
  1395 (ai.onnx.ReduceMean:11)
  1396 (ai.onnx.Cast:9)
  1397 (ai.onnx.Add:7)
  1398 (ai.onnx.Sqrt:6)
  1399 (ai.onnx.Reciprocal:6)
  1400 (ai.onnx.Mul:7)
  1401 (ai.onnx.Mul:7)
  1402 (ai.onnx.Sub:7)
  1403 (ai.onnx.Mul:7)
  1404 (ai.onnx.Add:7)
  1405 (ai.onnx.MatMul:9)
  1406 (ai.onnx.Add:7)
  1407 (ai.onnx.Div:7)
  1408 (ai.onnx.Erf:9)
  1409 (ai.onnx.Add:7)
  1410 (ai.onnx.Mul:7)
  1411 (ai.onnx.Mul:7)
  1412 (ai.onnx.MatMul:9)
  1413 (ai.onnx.Add:7)
  1414 (ai.onnx.Add:7)
  1415 (ai.onnx.ReduceMean:11)
  1416 (ai.onnx.Cast:9)
  1417 (ai.onnx.Cast:9)
  1418 (ai.onnx.Sub:7)
  1419 (ai.onnx.Mul:7)
  1420 (ai.onnx.ReduceMean:11)
  1421 (ai.onnx.Cast:9)
  1422 (ai.onnx.Add:7)
  1423 (ai.onnx.Sqrt:6)
  1424 (ai.onnx.Reciprocal:6)
  1425 (ai.onnx.Mul:7)
  1426 (ai.onnx.Mul:7)
  1427 (ai.onnx.Sub:7)
  1428 (ai.onnx.Mul:7)
  1429 (ai.onnx.Add:7)
  1430 (ai.onnx.MatMul:9)
  1431 (ai.onnx.Add:7)
  1432 (ai.onnx.Reshape:5)
  1433 (ai.onnx.Transpose:1)
  1434 (ai.onnx.MatMul:9)
  1435 (ai.onnx.Add:7)
  1436 (ai.onnx.Reshape:5)
  1437 (ai.onnx.Transpose:1)
  1438 (ai.onnx.MatMul:9)
  1439 (ai.onnx.Add:7)
  1440 (ai.onnx.Reshape:5)
  1441 (ai.onnx.Transpose:1)
  1442 (ai.onnx.MatMul:9)
  1443 (ai.onnx.Mul:7)
  1444 (ai.onnx.Add:7)
  1445 (ai.onnx.Softmax:11)
  1446 (ai.onnx.MatMul:9)
  1447 (ai.onnx.Transpose:1)
  1448 (ai.onnx.Reshape:5)
  1449 (ai.onnx.MatMul:9)
  1450 (ai.onnx.Add:7)
  1451 (ai.onnx.Add:7)
  1452 (ai.onnx.ReduceMean:11)
  1453 (ai.onnx.Cast:9)
  1454 (ai.onnx.Cast:9)
  1455 (ai.onnx.Sub:7)
  1456 (ai.onnx.Mul:7)
  1457 (ai.onnx.ReduceMean:11)
  1458 (ai.onnx.Cast:9)
  1459 (ai.onnx.Add:7)
  1460 (ai.onnx.Sqrt:6)
  1461 (ai.onnx.Reciprocal:6)
  1462 (ai.onnx.Mul:7)
  1463 (ai.onnx.Mul:7)
  1464 (ai.onnx.Sub:7)
  1465 (ai.onnx.Mul:7)
  1466 (ai.onnx.Add:7)
  1467 (ai.onnx.MatMul:9)
  1468 (ai.onnx.Add:7)
  1469 (ai.onnx.Div:7)
  1470 (ai.onnx.Erf:9)
  1471 (ai.onnx.Add:7)
  1472 (ai.onnx.Mul:7)
  1473 (ai.onnx.Mul:7)
  1474 (ai.onnx.MatMul:9)
  1475 (ai.onnx.Add:7)
  1476 (ai.onnx.Add:7)
  1477 (ai.onnx.ReduceMean:11)
  1478 (ai.onnx.Cast:9)
  1479 (ai.onnx.Cast:9)
  1480 (ai.onnx.Sub:7)
  1481 (ai.onnx.Mul:7)
  1482 (ai.onnx.ReduceMean:11)
  1483 (ai.onnx.Cast:9)
  1484 (ai.onnx.Add:7)
  1485 (ai.onnx.Sqrt:6)
  1486 (ai.onnx.Reciprocal:6)
  1487 (ai.onnx.Mul:7)
  1488 (ai.onnx.Mul:7)
  1489 (ai.onnx.Sub:7)
  1490 (ai.onnx.Mul:7)
  1491 (ai.onnx.Add:7)
  1492 (ai.onnx.MatMul:9)
  1493 (ai.onnx.Add:7)
  1494 (ai.onnx.Reshape:5)
  1495 (ai.onnx.Transpose:1)
  1496 (ai.onnx.MatMul:9)
  1497 (ai.onnx.Add:7)
  1498 (ai.onnx.Reshape:5)
  1499 (ai.onnx.Transpose:1)
  1500 (ai.onnx.MatMul:9)
  1501 (ai.onnx.Add:7)
  1502 (ai.onnx.Reshape:5)
  1503 (ai.onnx.Transpose:1)
  1504 (ai.onnx.MatMul:9)
  1505 (ai.onnx.Mul:7)
  1506 (ai.onnx.Add:7)
  1507 (ai.onnx.Softmax:11)
  1508 (ai.onnx.MatMul:9)
  1509 (ai.onnx.Transpose:1)
  1510 (ai.onnx.Reshape:5)
  1511 (ai.onnx.MatMul:9)
  1512 (ai.onnx.Add:7)
  1513 (ai.onnx.Add:7)
  1514 (ai.onnx.ReduceMean:11)
  1515 (ai.onnx.Cast:9)
  1516 (ai.onnx.Cast:9)
  1517 (ai.onnx.Sub:7)
  1518 (ai.onnx.Mul:7)
  1519 (ai.onnx.ReduceMean:11)
  1520 (ai.onnx.Cast:9)
  1521 (ai.onnx.Add:7)
  1522 (ai.onnx.Sqrt:6)
  1523 (ai.onnx.Reciprocal:6)
  1524 (ai.onnx.Mul:7)
  1525 (ai.onnx.Mul:7)
  1526 (ai.onnx.Sub:7)
  1527 (ai.onnx.Mul:7)
  1528 (ai.onnx.Add:7)
  1529 (ai.onnx.MatMul:9)
  1530 (ai.onnx.Add:7)
  1531 (ai.onnx.Div:7)
  1532 (ai.onnx.Erf:9)
  1533 (ai.onnx.Add:7)
  1534 (ai.onnx.Mul:7)
  1535 (ai.onnx.Mul:7)
  1536 (ai.onnx.MatMul:9)
  1537 (ai.onnx.Add:7)
  1538 (ai.onnx.Add:7)
  1539 (ai.onnx.ReduceMean:11)
  1540 (ai.onnx.Cast:9)
  1541 (ai.onnx.Cast:9)
  1542 (ai.onnx.Sub:7)
  1543 (ai.onnx.Mul:7)
  1544 (ai.onnx.ReduceMean:11)
  1545 (ai.onnx.Cast:9)
  1546 (ai.onnx.Add:7)
  1547 (ai.onnx.Sqrt:6)
  1548 (ai.onnx.Reciprocal:6)
  1549 (ai.onnx.Mul:7)
  1550 (ai.onnx.Mul:7)
  1551 (ai.onnx.Sub:7)
  1552 (ai.onnx.Mul:7)
  1553 (ai.onnx.Add:7)
  1554 (ai.onnx.MatMul:9)
  1555 (ai.onnx.Add:7)
  1556 (ai.onnx.Reshape:5)
  1557 (ai.onnx.Transpose:1)
  1558 (ai.onnx.MatMul:9)
  1559 (ai.onnx.Add:7)
  1560 (ai.onnx.Reshape:5)
  1561 (ai.onnx.Transpose:1)
  1562 (ai.onnx.MatMul:9)
  1563 (ai.onnx.Add:7)
  1564 (ai.onnx.Reshape:5)
  1565 (ai.onnx.Transpose:1)
  1566 (ai.onnx.MatMul:9)
  1567 (ai.onnx.Mul:7)
  1568 (ai.onnx.Add:7)
  1569 (ai.onnx.Softmax:11)
  1570 (ai.onnx.MatMul:9)
  1571 (ai.onnx.Transpose:1)
  1572 (ai.onnx.Reshape:5)
  1573 (ai.onnx.MatMul:9)
  1574 (ai.onnx.Add:7)
  1575 (ai.onnx.Add:7)
  1576 (ai.onnx.ReduceMean:11)
  1577 (ai.onnx.Cast:9)
  1578 (ai.onnx.Cast:9)
  1579 (ai.onnx.Sub:7)
  1580 (ai.onnx.Mul:7)
  1581 (ai.onnx.ReduceMean:11)
  1582 (ai.onnx.Cast:9)
  1583 (ai.onnx.Add:7)
  1584 (ai.onnx.Sqrt:6)
  1585 (ai.onnx.Reciprocal:6)
  1586 (ai.onnx.Mul:7)
  1587 (ai.onnx.Mul:7)
  1588 (ai.onnx.Sub:7)
  1589 (ai.onnx.Mul:7)
  1590 (ai.onnx.Add:7)
  1591 (ai.onnx.MatMul:9)
  1592 (ai.onnx.Add:7)
  1593 (ai.onnx.Div:7)
  1594 (ai.onnx.Erf:9)
  1595 (ai.onnx.Add:7)
  1596 (ai.onnx.Mul:7)
  1597 (ai.onnx.Mul:7)
  1598 (ai.onnx.MatMul:9)
  1599 (ai.onnx.Add:7)
  1600 (ai.onnx.Add:7)
  1601 (ai.onnx.ReduceMean:11)
  1602 (ai.onnx.Cast:9)
  1603 (ai.onnx.Cast:9)
  1604 (ai.onnx.Sub:7)
  1605 (ai.onnx.Mul:7)
  1606 (ai.onnx.ReduceMean:11)
  1607 (ai.onnx.Cast:9)
  1608 (ai.onnx.Add:7)
  1609 (ai.onnx.Sqrt:6)
  1610 (ai.onnx.Reciprocal:6)
  1611 (ai.onnx.Mul:7)
  1612 (ai.onnx.Mul:7)
  1613 (ai.onnx.Sub:7)
  1614 (ai.onnx.Mul:7)
  1615 (ai.onnx.Add:7)
  1616 (ai.onnx.MatMul:9)
  1617 (ai.onnx.Add:7)
  1618 (ai.onnx.Reshape:5)
  1619 (ai.onnx.Transpose:1)
  1620 (ai.onnx.MatMul:9)
  1621 (ai.onnx.Add:7)
  1622 (ai.onnx.Reshape:5)
  1623 (ai.onnx.Transpose:1)
  1624 (ai.onnx.MatMul:9)
  1625 (ai.onnx.Add:7)
  1626 (ai.onnx.Reshape:5)
  1627 (ai.onnx.Transpose:1)
  1628 (ai.onnx.MatMul:9)
  1629 (ai.onnx.Mul:7)
  1630 (ai.onnx.Add:7)
  1631 (ai.onnx.Softmax:11)
  1632 (ai.onnx.MatMul:9)
  1633 (ai.onnx.Transpose:1)
  1634 (ai.onnx.Reshape:5)
  1635 (ai.onnx.MatMul:9)
  1636 (ai.onnx.Add:7)
  1637 (ai.onnx.Add:7)
  1638 (ai.onnx.ReduceMean:11)
  1639 (ai.onnx.Cast:9)
  1640 (ai.onnx.Cast:9)
  1641 (ai.onnx.Sub:7)
  1642 (ai.onnx.Mul:7)
  1643 (ai.onnx.ReduceMean:11)
  1644 (ai.onnx.Cast:9)
  1645 (ai.onnx.Add:7)
  1646 (ai.onnx.Sqrt:6)
  1647 (ai.onnx.Reciprocal:6)
  1648 (ai.onnx.Mul:7)
  1649 (ai.onnx.Mul:7)
  1650 (ai.onnx.Sub:7)
  1651 (ai.onnx.Mul:7)
  1652 (ai.onnx.Add:7)
  1653 (ai.onnx.MatMul:9)
  1654 (ai.onnx.Add:7)
  1655 (ai.onnx.Div:7)
  1656 (ai.onnx.Erf:9)
  1657 (ai.onnx.Add:7)
  1658 (ai.onnx.Mul:7)
  1659 (ai.onnx.Mul:7)
  1660 (ai.onnx.MatMul:9)
  1661 (ai.onnx.Add:7)
  1662 (ai.onnx.Add:7)
  1663 (ai.onnx.ReduceMean:11)
  1664 (ai.onnx.Cast:9)
  1665 (ai.onnx.Cast:9)
  1666 (ai.onnx.Sub:7)
  1667 (ai.onnx.Mul:7)
  1668 (ai.onnx.ReduceMean:11)
  1669 (ai.onnx.Cast:9)
  1670 (ai.onnx.Add:7)
  1671 (ai.onnx.Sqrt:6)
  1672 (ai.onnx.Reciprocal:6)
  1673 (ai.onnx.Mul:7)
  1674 (ai.onnx.Mul:7)
  1675 (ai.onnx.Sub:7)
  1676 (ai.onnx.Mul:7)
  1677 (ai.onnx.Add:7)
  1678 (ai.onnx.MatMul:9)
  1679 (ai.onnx.Add:7)
  1680 (ai.onnx.Reshape:5)
  1681 (ai.onnx.Transpose:1)
  1682 (ai.onnx.MatMul:9)
  1683 (ai.onnx.Add:7)
  1684 (ai.onnx.Reshape:5)
  1685 (ai.onnx.Transpose:1)
  1686 (ai.onnx.MatMul:9)
  1687 (ai.onnx.Add:7)
  1688 (ai.onnx.Reshape:5)
  1689 (ai.onnx.Transpose:1)
  1690 (ai.onnx.MatMul:9)
  1691 (ai.onnx.Mul:7)
  1692 (ai.onnx.Add:7)
  1693 (ai.onnx.Softmax:11)
  1694 (ai.onnx.MatMul:9)
  1695 (ai.onnx.Transpose:1)
  1696 (ai.onnx.Reshape:5)
  1697 (ai.onnx.MatMul:9)
  1698 (ai.onnx.Add:7)
  1699 (ai.onnx.Add:7)
  1700 (ai.onnx.ReduceMean:11)
  1701 (ai.onnx.Cast:9)
  1702 (ai.onnx.Cast:9)
  1703 (ai.onnx.Sub:7)
  1704 (ai.onnx.Mul:7)
  1705 (ai.onnx.ReduceMean:11)
  1706 (ai.onnx.Cast:9)
  1707 (ai.onnx.Add:7)
  1708 (ai.onnx.Sqrt:6)
  1709 (ai.onnx.Reciprocal:6)
  1710 (ai.onnx.Mul:7)
  1711 (ai.onnx.Mul:7)
  1712 (ai.onnx.Sub:7)
  1713 (ai.onnx.Mul:7)
  1714 (ai.onnx.Add:7)
  1715 (ai.onnx.MatMul:9)
  1716 (ai.onnx.Add:7)
  1717 (ai.onnx.Div:7)
  1718 (ai.onnx.Erf:9)
  1719 (ai.onnx.Add:7)
  1720 (ai.onnx.Mul:7)
  1721 (ai.onnx.Mul:7)
  1722 (ai.onnx.MatMul:9)
  1723 (ai.onnx.Add:7)
  1724 (ai.onnx.Add:7)
  1725 (ai.onnx.ReduceMean:11)
  1726 (ai.onnx.Cast:9)
  1727 (ai.onnx.Cast:9)
  1728 (ai.onnx.Sub:7)
  1729 (ai.onnx.Mul:7)
  1730 (ai.onnx.ReduceMean:11)
  1731 (ai.onnx.Cast:9)
  1732 (ai.onnx.Add:7)
  1733 (ai.onnx.Sqrt:6)
  1734 (ai.onnx.Reciprocal:6)
  1735 (ai.onnx.Mul:7)
  1736 (ai.onnx.Mul:7)
  1737 (ai.onnx.Sub:7)
  1738 (ai.onnx.Mul:7)
  1739 (ai.onnx.Add:7)
  1740 (ai.onnx.Reshape:5)
  1741 (ai.onnx.Slice:11)
  1742 (ai.onnx.Reshape:5)
  1743 (ai.onnx.MatMul:9)
  1744 (ai.onnx.Add:7)
  1745 (ai.onnx.Tanh:6)


In [35]:
len(reload_ns) + len(reload_qt_ns) + len(rest_ns)

1683

In [36]:
len(mk.graph.initializer)

444

In [37]:
mk.graph.node[34]

input: "reload/bert/encoder/mul:0"
output: "reload/bert/encoder/layer_0/attention/self/ExpandDims:0"
name: "reload/bert/encoder/layer_0/attention/self/ExpandDims"
op_type: "Unsqueeze"
attribute {
  name: "axes"
  ints: 1
  type: INTS
}

In [104]:
builder = popart.Builder(m.SerializeToString())

In [71]:
opts = popart.SessionOptions()
opts.virtualGraphMode = popart.VirtualGraphMode.Manual
opts.enablePipelining = True
partials_type = "half"
opts.partialsTypeMatMuls = partials_type
opts.convolutionOptions = {'partialsType': partials_type}

In [105]:
run(builder=builder, opts=opts)

['TensorDict/StandardKvParser_4:0', 'TensorDict/StandardKvParser_1:0', 'TensorDict/StandardKvParser_6:0', 'TensorDict/StandardKvParser_8:0']
[[1, 512], [1, 64], [1, 512], [1, 64]]
['INT32', 'INT32', 'INT32', 'INT32']


popart_exception: For pipelining, depth (batchesPerStep) must equal at least the number of pipeline stages (2)

In [65]:
selected_names = ["reload/bert/embeddings", "reload/bert/encoder/layer_0/", "reload/bert/encoder/layer_1/", "reload/bert/encoder/layer_2/"]

In [66]:
for node in m.graph.node:
    if any([node.name.startswith(s) for s in selected_names]):
        print(node.name)
        # node.attribute[-1].i = 0

reload/bert/embeddings/Reshape_1
reload/bert/encoder/Shape
reload/bert/encoder/Shape__459
reload/bert/encoder/strided_slice
reload/bert/encoder/ones/packed_Concat__467
reload/bert/encoder/ones__468
reload/bert/encoder/ones
reload/bert/encoder/Reshape/shape_Concat__472
reload/bert/encoder/Reshape__734
reload/bert/embeddings/ExpandDims
reload/bert/embeddings/embedding_lookup
reload/bert/embeddings/Shape
reload/bert/embeddings/Shape__473
reload/bert/embeddings/strided_slice
reload/bert/embeddings/Reshape/shape_Concat__481
reload/bert/embeddings/Reshape__482
reload/bert/embeddings/Reshape
reload/bert/embeddings/Shape_1
reload/bert/embeddings/Shape_1__483
reload/bert/embeddings/strided_slice_1
reload/bert/embeddings/Reshape_2/shape_Concat__491
reload/bert/embeddings/Reshape_2__492
reload/bert/encoder/Reshape
reload/bert/encoder/mul
reload/bert/encoder/layer_0/attention/self/ExpandDims
reload/bert/encoder/layer_9/attention/self/sub
reload/bert/encoder/layer_9/attention/self/mul_1
reload/bert

In [73]:
for node in m.graph.node:
    if node.name.startswith("reload/bert/encoder/") and "layer_" not in node.name:
        # print(node.name)
        node.attribute[-1].i = 0

In [84]:
for node in m.graph.node:
    if node.name.startswith("reload/") and "bert" not in node.name:
        # print(node.name)
        node.attribute[-1].i = 0

In [79]:
[(i, n) for i, n in enumerate(m.graph.node) if n.name == "reload_qt/bert/encoder/Reshape_13"]

[(851,
  input: "reload_qt/bert/encoder/layer_11/output/LayerNorm/batchnorm/add_1:0"
  input: "reload_qt/bert/encoder/Reshape_13__1389:0"
  output: "reload_qt/bert/encoder/Reshape_13:0"
  name: "reload_qt/bert/encoder/Reshape_13"
  op_type: "Reshape"
  attribute {
    name: "__ipu_number"
    i: 1
    type: INT
  })]

In [77]:
m.graph.node[1648].attribute[-1].i = 1

In [88]:
[(i, n) for i, n in enumerate(m.graph.node) if n.name == "reload/bert/encoder/layer_0/attention/self/add" ]

[(918,
  input: "reload/bert/encoder/layer_0/attention/self/Mul:0"
  input: "reload/bert/encoder/layer_9/attention/self/mul_1:0"
  output: "reload/bert/encoder/layer_0/attention/self/add:0"
  name: "reload/bert/encoder/layer_0/attention/self/add"
  op_type: "Add"
  attribute {
    name: "__ipu_number"
    i: 0
    type: INT
  })]

In [92]:
m.graph.node[918].attribute[-1].i = 0

In [95]:
[(i, n) for i, n in enumerate(m.graph.node) if n.name == "reload/bert/encoder/layer_9/attention/self/sub" ]

[(35,
  input: "reload/bert/encoder/layer_7/attention/self/sub/x:0"
  input: "reload/bert/encoder/layer_0/attention/self/ExpandDims:0"
  output: "reload/bert/encoder/layer_9/attention/self/sub:0"
  name: "reload/bert/encoder/layer_9/attention/self/sub"
  op_type: "Sub"
  attribute {
    name: "__ipu_number"
    i: 1
    type: INT
  })]

In [96]:
m.graph.node[35].attribute[-1].i = 0

In [97]:
[(i, n) for i, n in enumerate(m.graph.node) if n.name == "reload/bert/encoder/layer_9/attention/self/mul_1" ]

[(36,
  input: "reload/bert/encoder/layer_9/attention/self/sub:0"
  input: "reload/bert/encoder/layer_1/attention/self/mul_1/y:0"
  output: "reload/bert/encoder/layer_9/attention/self/mul_1:0"
  name: "reload/bert/encoder/layer_9/attention/self/mul_1"
  op_type: "Mul"
  attribute {
    name: "__ipu_number"
    i: 1
    type: INT
  })]

In [102]:
for node in m.graph.node:
    if node.name.startswith("reload_qt/bert/pooler/"):
        node.attribute[-1].i = 1

In [98]:
m.graph.node[36].attribute[-1].i = 0

In [85]:
onnx.checker.check_model(m,1)

In [107]:
onnx.save(m, "qtc35-onnx-pipeline/model-on-half.onnx")