diff --git a/libs/onnx-mlir b/libs/onnx-mlir index f6cff23..95edc72 160000 --- a/libs/onnx-mlir +++ b/libs/onnx-mlir @@ -1 +1 @@ -Subproject commit f6cff2380f78e67397c5a262cf0b2034975596fb +Subproject commit 95edc72aa98e9eebe0f48a8d42d92a71fe615b19 diff --git a/mlir-assigner/include/mlir-assigner/parser/evaluator.hpp b/mlir-assigner/include/mlir-assigner/parser/evaluator.hpp index 3b126f8..0332d4d 100644 --- a/mlir-assigner/include/mlir-assigner/parser/evaluator.hpp +++ b/mlir-assigner/include/mlir-assigner/parser/evaluator.hpp @@ -265,6 +265,10 @@ namespace zk_ml_toolchain { auto val = var_value(assignmnt, toConvert).data; nil::blueprint::components::FixedPoint out(val, 16); return out.to_double(); + // auto Lhs = frames.back().locals[mlir::hash_value(operation.getLhs())]; + // auto Rhs = frames.back().locals[mlir::hash_value(operation.getRhs())]; + // auto Result = frames.back().locals[mlir::hash_value(operation.getResult())]; + // std::cout << toFixpoint(Lhs) << " * " << toFixpoint(Rhs) << " = " << toFixpoint(Result) << "\n"; } void handleArithOperation(Operation *op) { @@ -472,14 +476,14 @@ namespace zk_ml_toolchain { typename BlueprintFieldType::value_type field_constant = value; auto val = put_into_assignment(field_constant); - frames.back().locals.insert(std::make_pair(mlir::hash_value(operation.getResult()), val)); + frames.back().locals[mlir::hash_value(operation.getResult())] = val; } else if (constantValue.isa()) { double d = llvm::dyn_cast(constantValue).getValueAsDouble(); nil::blueprint::components::FixedPoint fixed(d); auto value = put_into_assignment(fixed.get_value()); // this insert is ok, since this should never change, so we // don't override it if it is already there - frames.back().locals.insert(std::make_pair(mlir::hash_value(operation.getResult()), value)); + frames.back().locals[mlir::hash_value(operation.getResult())] = value; } else { logger << constantValue; UNREACHABLE("unhandled constant"); @@ -489,6 +493,7 @@ namespace zk_ml_toolchain { auto opHash = mlir::hash_value(operation->getOperand(0)); Type casteeType = operation->getOperand(0).getType(); if (casteeType.isa()) { + UNREACHABLE("I SHOULD NOT WORK"); auto i = frames.back().locals.find(opHash); assert(i != frames.back().locals.end()); frames.back().constant_values[mlir::hash_value(operation.getResult())] = resolve_number(i->second); @@ -497,7 +502,7 @@ namespace zk_ml_toolchain { assert(index != frames.back().constant_values.end()); typename BlueprintFieldType::value_type field_constant = index->second; auto val = put_into_assignment(field_constant); - frames.back().locals.insert(std::make_pair(mlir::hash_value(operation.getResult()), val)); + frames.back().locals[mlir::hash_value(operation.getResult())] = val; } else { UNREACHABLE("unsupported Index Cast"); } @@ -524,8 +529,6 @@ namespace zk_ml_toolchain { std::uint32_t start_row = assignmnt.allocated_rows(); if (math::ExpOp operation = llvm::dyn_cast(op)) { handle_fixedpoint_exp_component(operation, frames.back(), bp, assignmnt, start_row); - } else if (math::Exp2Op operation = llvm::dyn_cast(op)) { - UNREACHABLE("TODO: component for exp2 not ready"); } else if (math::LogOp operation = llvm::dyn_cast(op)) { handle_fixedpoint_log_component(operation, frames.back(), bp, assignmnt, start_row); } else if (math::PowFOp operation = llvm::dyn_cast(op)) { @@ -874,8 +877,8 @@ namespace zk_ml_toolchain { auto m = nil::blueprint::memref(dims, type.getElementType()); auto hash = mlir::hash_value(operation.getMemref()); auto insert_res = frames.back().memrefs.insert({hash, m}); - assert(insert_res.second); // Reallocating over an existing memref - // should not happen ATM + // assert(insert_res.second); // Reallocating over an existing memref + // should not happen ATM logger.debug("inserting memref with hash %x", size_t(hash)); } else if (memref::AllocaOp operation = llvm::dyn_cast(op)) { // TACEO_TODO: handle cleanup of these stack memrefs diff --git a/mlir-assigner/tests/Ops/Onnx/Split/.ignore b/mlir-assigner/tests/Models/NanoGPT/.ignore similarity index 100% rename from mlir-assigner/tests/Ops/Onnx/Split/.ignore rename to mlir-assigner/tests/Models/NanoGPT/.ignore diff --git a/mlir-assigner/tests/Models/NanoGPT/nanoGPT-new.onnx b/mlir-assigner/tests/Models/NanoGPT/nanoGPT-new.onnx new file mode 100644 index 0000000..9b8fa0a Binary files /dev/null and b/mlir-assigner/tests/Models/NanoGPT/nanoGPT-new.onnx differ diff --git a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/LRN/LRNSimple.mlir b/mlir-assigner/tests/Ops/NeedsBlueprintComponent/LRN/LRNSimple.mlir deleted file mode 100644 index df0a25f..0000000 --- a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/LRN/LRNSimple.mlir +++ /dev/null @@ -1,37 +0,0 @@ -#map = affine_map<(d0) -> (0, d0 - 1)> -#map1 = affine_map<(d0) -> (5, d0 + 2)> -module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-pc-linux-gnu", "onnx-mlir.symbol-postfix" = "lrnsimple.mlir"} { - func.func @main_graph(%arg0: memref<5x5x5x5xf32>) -> memref<5x5x5x5xf32> attributes {input_names = ["in_a"], llvm.emit_c_interface, output_names = ["out_a"]} { - %cst = arith.constant 0.000000e+00 : f32 - %cst_0 = arith.constant 1.000000e+00 : f32 - %cst_1 = arith.constant 3.33333337E-5 : f32 - %cst_2 = arith.constant 7.500000e-01 : f32 - %alloc = memref.alloc() {alignment = 16 : i64} : memref<5x5x5x5xf32> - affine.for %arg1 = 0 to 5 { - affine.for %arg2 = 0 to 5 { - affine.for %arg3 = 0 to 5 { - affine.for %arg4 = 0 to 5 { - %alloc_3 = memref.alloc() : memref - affine.store %cst, %alloc_3[] : memref - affine.for %arg5 = max #map(%arg2) to min #map1(%arg2) { - %6 = affine.load %arg0[%arg1, %arg5, %arg3, %arg4] : memref<5x5x5x5xf32> - %7 = arith.mulf %6, %6 : f32 - %8 = affine.load %alloc_3[] : memref - %9 = arith.addf %8, %7 : f32 - affine.store %9, %alloc_3[] : memref - } - %0 = affine.load %arg0[%arg1, %arg2, %arg3, %arg4] : memref<5x5x5x5xf32> - %1 = affine.load %alloc_3[] : memref - %2 = arith.mulf %1, %cst_1 : f32 - %3 = arith.addf %2, %cst_0 : f32 - %4 = math.powf %3, %cst_2 : f32 - %5 = arith.divf %0, %4 : f32 - affine.store %5, %alloc[%arg1, %arg2, %arg3, %arg4] : memref<5x5x5x5xf32> - } - } - } - } - return %alloc : memref<5x5x5x5xf32> - } - "krnl.entry_point"() {func = @main_graph, numInputs = 1 : i32, numOutputs = 1 : i32, signature = "[ { \22type\22 : \22f32\22 , \22dims\22 : [5 , 5 , 5 , 5] , \22name\22 : \22in_a\22 }\0A\0A]\00@[ { \22type\22 : \22f32\22 , \22dims\22 : [5 , 5 , 5 , 5] , \22name\22 : \22out_a\22 }\0A\0A]\00"} : () -> () -} diff --git a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowPublicBase.json b/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowPublicBase.json deleted file mode 100644 index 2459e46..0000000 --- a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowPublicBase.json +++ /dev/null @@ -1 +0,0 @@ -[{"memref": {"data": [0.25299072265625, 0.5795745849609375, 0.419647216796875, 0.5677490234375, 0.5795135498046875, 0.6259613037109375, 0.6096954345703125, 0.8977203369140625, 0.34307861328125, 0.5006256103515625], "dims": [1, 10], "type": "f32"}}] diff --git a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowPublicBase.mlir b/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowPublicBase.mlir deleted file mode 100644 index 7f0c664..0000000 --- a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowPublicBase.mlir +++ /dev/null @@ -1,17 +0,0 @@ -module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-pc-linux-gnu", "onnx-mlir.symbol-postfix" = "powpublicbase.mlir"} { - func.func @main_graph(%arg0: memref<1x10xf32>) -> memref<1x10xf32> attributes {input_names = ["in_b"], llvm.emit_c_interface, output_names = ["out_a"]} { - %c0 = arith.constant 0 : index - %0 = "krnl.global"() {name = "constant_0", shape = [1, 10], value = dense<2.000000e+00> : tensor<1x10xf32>} : () -> memref<1x10xf32> - %alloc = memref.alloc() {alignment = 16 : i64} : memref<1x10xf32> - affine.for %arg1 = 0 to 1 { - affine.for %arg2 = 0 to 10 { - %1 = affine.load %0[%c0, %arg2] : memref<1x10xf32> - %2 = affine.load %arg0[%c0, %arg2] : memref<1x10xf32> - %3 = math.powf %1, %2 : f32 - affine.store %3, %alloc[%arg1, %arg2] : memref<1x10xf32> - } - } - return %alloc : memref<1x10xf32> - } - "krnl.entry_point"() {func = @main_graph, numInputs = 1 : i32, numOutputs = 1 : i32, signature = "[ { \22type\22 : \22f32\22 , \22dims\22 : [1 , 10] , \22name\22 : \22in_b\22 }\0A\0A]\00@[ { \22type\22 : \22f32\22 , \22dims\22 : [1 , 10] , \22name\22 : \22out_a\22 }\0A\0A]\00"} : () -> () -} diff --git a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowPublicBase.onnx b/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowPublicBase.onnx deleted file mode 100644 index dc46445..0000000 Binary files a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowPublicBase.onnx and /dev/null differ diff --git a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowPublicBase.res b/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowPublicBase.res deleted file mode 100644 index 50583d1..0000000 --- a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowPublicBase.res +++ /dev/null @@ -1,3 +0,0 @@ -Result: -memref<1x10xf32>[4, 4, 4, 4, 4, 4, 4, 4, 4, 4] -ADD THE ROWS HERE \ No newline at end of file diff --git a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowPublicExponent.json b/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowPublicExponent.json deleted file mode 100644 index c7abcfe..0000000 --- a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowPublicExponent.json +++ /dev/null @@ -1 +0,0 @@ -[{"memref": {"data": [0.919586181640625, 0.2015838623046875, 0.2564697265625, 0.3241424560546875, 0.3890228271484375, 0.4170989990234375, 0.6596832275390625, 0.7839508056640625, 0.1458587646484375, 0.4071502685546875], "dims": [1, 10], "type": "f32"}}, {"memref": {"data": [3, 3, 3, 3, 3, 3, 3, 3, 3, 3], "dims": [1, 10], "type": "f32"}}] \ No newline at end of file diff --git a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowPublicExponent.mlir b/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowPublicExponent.mlir deleted file mode 100644 index 4c2d8cc..0000000 --- a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowPublicExponent.mlir +++ /dev/null @@ -1,17 +0,0 @@ -module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-pc-linux-gnu", "onnx-mlir.symbol-postfix" = "powpublicexponent.mlir"} { - func.func @main_graph(%arg0: memref<1x10xf32>) -> memref<1x10xf32> attributes {input_names = ["in_a"], llvm.emit_c_interface, output_names = ["out_a"]} { - %c0 = arith.constant 0 : index - %0 = "krnl.global"() {name = "constant_0", shape = [1, 10], value = dense<3.000000e+00> : tensor<1x10xf32>} : () -> memref<1x10xf32> - %alloc = memref.alloc() {alignment = 16 : i64} : memref<1x10xf32> - affine.for %arg1 = 0 to 1 { - affine.for %arg2 = 0 to 10 { - %1 = affine.load %arg0[%c0, %arg2] : memref<1x10xf32> - %2 = affine.load %0[%c0, %arg2] : memref<1x10xf32> - %3 = math.powf %1, %2 : f32 - affine.store %3, %alloc[%arg1, %arg2] : memref<1x10xf32> - } - } - return %alloc : memref<1x10xf32> - } - "krnl.entry_point"() {func = @main_graph, numInputs = 1 : i32, numOutputs = 1 : i32, signature = "[ { \22type\22 : \22f32\22 , \22dims\22 : [1 , 10] , \22name\22 : \22in_a\22 }\0A\0A]\00@[ { \22type\22 : \22f32\22 , \22dims\22 : [1 , 10] , \22name\22 : \22out_a\22 }\0A\0A]\00"} : () -> () -} diff --git a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowPublicExponent.onnx b/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowPublicExponent.onnx deleted file mode 100644 index 8baeb09..0000000 Binary files a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowPublicExponent.onnx and /dev/null differ diff --git a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowPublicExponent.res b/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowPublicExponent.res deleted file mode 100644 index e4cf932..0000000 --- a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowPublicExponent.res +++ /dev/null @@ -1,3 +0,0 @@ -Result: -memref<1x10xf32>[0.7776377201080322, 0.00819157250225544, 0.016869736835360527, 0.03405710682272911, 0.058874230831861496, 0.07256337255239487, 0.2870822548866272, 0.4817996025085449, 0.0031031130347400904, 0.06749384850263596] -ADD THE ROWS HERE \ No newline at end of file diff --git a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowSimple.json b/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowSimple.json deleted file mode 100644 index f1aa3c6..0000000 --- a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowSimple.json +++ /dev/null @@ -1 +0,0 @@ -[{"memref": {"data": [0.6160430908203125, 0.445709228515625, 0.0120391845703125, 0.0461273193359375, 0.1510009765625, 0.1910400390625, 0.03076171875, 0.043548583984375, 0.2318572998046875, 0.4149627685546875], "dims": [1, 10], "type": "f32"}}, {"memref": {"data": [0.05859375, 0.8663330078125, 0.753448486328125, 0.1318359375, 0.8713531494140625, 0.0277557373046875, 0.5159149169921875, 0.480560302734375, 0.127532958984375, 0.01123046875], "dims": [1, 10], "type": "f32"}}] \ No newline at end of file diff --git a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowSimple.mlir b/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowSimple.mlir deleted file mode 100644 index 9899787..0000000 --- a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowSimple.mlir +++ /dev/null @@ -1,16 +0,0 @@ -module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-pc-linux-gnu", "onnx-mlir.symbol-postfix" = "powsimple.mlir"} { - func.func @main_graph(%arg0: memref<1x10xf32>, %arg1: memref<1x10xf32>) -> memref<1x10xf32> attributes {input_names = ["in_a", "in_b"], llvm.emit_c_interface, output_names = ["out_a"]} { - %c0 = arith.constant 0 : index - %alloc = memref.alloc() {alignment = 16 : i64} : memref<1x10xf32> - affine.for %arg2 = 0 to 1 { - affine.for %arg3 = 0 to 10 { - %0 = affine.load %arg0[%c0, %arg3] : memref<1x10xf32> - %1 = affine.load %arg1[%c0, %arg3] : memref<1x10xf32> - %2 = math.powf %0, %1 : f32 - affine.store %2, %alloc[%arg2, %arg3] : memref<1x10xf32> - } - } - return %alloc : memref<1x10xf32> - } - "krnl.entry_point"() {func = @main_graph, numInputs = 2 : i32, numOutputs = 1 : i32, signature = "[ { \22type\22 : \22f32\22 , \22dims\22 : [1 , 10] , \22name\22 : \22in_a\22 }\0A , { \22type\22 : \22f32\22 , \22dims\22 : [1 , 10] , \22name\22 : \22in_b\22 }\0A\0A]\00@[ { \22type\22 : \22f32\22 , \22dims\22 : [1 , 10] , \22name\22 : \22out_a\22 }\0A\0A]\00"} : () -> () -} diff --git a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowSimple.res b/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowSimple.res deleted file mode 100644 index 7b23894..0000000 --- a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowSimple.res +++ /dev/null @@ -1,3 +0,0 @@ -Result: -memref<1x10xf32>[0.9720140099525452, 0.49654868245124817, 0.03579552844166756, 0.6665944457054138, 0.19257567822933197, 0.9550961256027222, 0.16593657433986664, 0.22179152071475983, 0.8299362659454346, 0.9901706576347351] -ADD THE ROWS HERE \ No newline at end of file diff --git a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/LRN/LRNSimple.json b/mlir-assigner/tests/Ops/Onnx/LRN/LRNSimple.json similarity index 100% rename from mlir-assigner/tests/Ops/NeedsBlueprintComponent/LRN/LRNSimple.json rename to mlir-assigner/tests/Ops/Onnx/LRN/LRNSimple.json diff --git a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/LRN/LRNSimple.onnx b/mlir-assigner/tests/Ops/Onnx/LRN/LRNSimple.onnx similarity index 100% rename from mlir-assigner/tests/Ops/NeedsBlueprintComponent/LRN/LRNSimple.onnx rename to mlir-assigner/tests/Ops/Onnx/LRN/LRNSimple.onnx diff --git a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/LRN/LRNSimple.res b/mlir-assigner/tests/Ops/Onnx/LRN/LRNSimple.res similarity index 99% rename from mlir-assigner/tests/Ops/NeedsBlueprintComponent/LRN/LRNSimple.res rename to mlir-assigner/tests/Ops/Onnx/LRN/LRNSimple.res index 6ec8fec..4e07953 100644 --- a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/LRN/LRNSimple.res +++ b/mlir-assigner/tests/Ops/Onnx/LRN/LRNSimple.res @@ -1,3 +1,3 @@ Result: memref<5x5x5x5xf32>[0.7278932332992554, 0.6214843392372131, 0.04486075043678284, 0.5573806762695312, 0.32497990131378174, 0.16202950477600098, 0.10778684914112091, 0.6184303164482117, 0.8605485558509827, 0.5976501703262329, 0.6005876064300537, 0.9977467656135559, 0.6964892148971558, 0.0393524132668972, 0.18226537108421326, 0.9929409623146057, 0.7867917418479919, 0.7318156957626343, 0.7927106022834778, 0.4756576418876648, 0.7881949543952942, 0.873757541179657, 0.9118134379386902, 0.5997382402420044, 0.4006440043449402, 0.3631528615951538, 0.10258357971906662, 0.28429877758026123, 0.49082186818122864, 0.33846741914749146, 0.9258153438568115, 0.6670033931732178, 0.40259668231010437, 0.8632484078407288, 0.22752146422863007, 0.7391752004623413, 0.1604723483324051, 0.534450113773346, 0.07579036802053452, 0.3999868929386139, 0.11204248666763306, 0.12461519986391068, 0.2646084129810333, 0.2809653580188751, 0.4014509320259094, 0.21358677744865417, 0.4826231896877289, 0.32524630427360535, 0.45301175117492676, 0.6632726788520813, 0.18619997799396515, 0.31135398149490356, 0.5822421312332153, 0.2676512598991394, 0.3687865138053894, 0.5280604362487793, 0.8388807773590088, 0.3169954717159271, 0.18903528153896332, 0.05653373897075653, 0.5130502581596375, 0.24725285172462463, 0.8300331830978394, 0.10533127188682556, 0.7131791114807129, 0.05421419441699982, 0.6597151756286621, 0.9105149507522583, 0.8467022776603699, 0.6256607174873352, 0.6421872973442078, 0.15716129541397095, 0.5465914607048035, 0.07407940179109573, 0.31948354840278625, 0.6954917907714844, 0.31794580817222595, 0.9019707441329956, 0.4001818597316742, 0.3037501573562622, 0.2384621948003769, 0.6208935976028442, 0.6906333565711975, 0.6783787608146667, 0.0026855398900806904, 0.26035836338996887, 0.0573573112487793, 0.6743569374084473, 0.1989128440618515, 0.5162548422813416, 0.4428943395614624, 0.6231961250305176, 0.3191256523132324, 0.5480440855026245, 0.35910630226135254, 0.3525032103061676, 0.907708466053009, 0.49446865916252136, 0.9379674792289734, 0.29094821214675903, 0.7614390850067139, 0.05587754026055336, 0.7101358771324158, 0.898904025554657, 0.7967995405197144, 0.15937772393226624, 0.08523476123809814, 0.023345669731497765, 0.443168967962265, 0.31846538186073303, 0.17848160862922668, 0.5192072987556458, 0.2700159549713135, 0.31288042664527893, 0.2847875654697418, 0.6781657934188843, 0.6072272658348083, 0.7508115172386169, 0.7231602072715759, 0.7193944454193115, 0.27032339572906494, 0.325172483921051, 0.2747628092765808, 0.024611886590719223, 0.8308859467506409, 0.5247609615325928, 0.8531917333602905, 0.8896939754486084, 0.8693831562995911, 0.799577534198761, 0.351544052362442, 0.5421406030654907, 0.9846714735031128, 0.8690822124481201, 0.20356711745262146, 0.6575030088424683, 0.5275320410728455, 0.43147990107536316, 0.16633577644824982, 0.4719867408275604, 0.17873863875865936, 0.21566730737686157, 0.5561984181404114, 0.4335269331932068, 0.8640255331993103, 0.5693444013595581, 0.795304000377655, 0.9438491463661194, 0.12722772359848022, 0.4985607862472534, 0.19149304926395416, 0.38307666778564453, 0.7697944641113281, 0.13671599328517914, 0.1972135454416275, 0.489306777715683, 0.12568402290344238, 0.7915934920310974, 0.7271488308906555, 0.17866241931915283, 0.6216742992401123, 0.6942474842071533, 0.7531900405883789, 0.20174995064735413, 0.9558576941490173, 0.7739101052284241, 0.16495957970619202, 0.8771531581878662, 0.5880044102668762, 0.3461625576019287, 0.39127615094184875, 0.32385745644569397, 0.5828137993812561, 0.07311971485614777, 0.3761166036128998, 0.8260281682014465, 0.47406870126724243, 0.0609121210873127, 0.17266660928726196, 0.7752405405044556, 0.7889885902404785, 0.7254471182823181, 0.21520480513572693, 0.197855606675148, 0.736923098564148, 0.33701276779174805, 0.9057586789131165, 0.05430495738983154, 0.528255045413971, 0.2890221178531647, 0.4235239624977112, 0.835067093372345, 0.48812779784202576, 0.16686487197875977, 0.8384672403335571, 0.18324118852615356, 0.7355251908302307, 0.7541176080703735, 0.44770151376724243, 0.9721867442131042, 0.5769413709640503, 0.4298918843269348, 0.19694305956363678, 0.617051362991333, 0.1345927119255066, 0.7520824670791626, 0.6207066178321838, 0.5355159044265747, 0.6852415800094604, 0.5107929110527039, 0.6806468367576599, 0.424528032541275, 0.46807992458343506, 0.030959203839302063, 0.6059139966964722, 0.7529751658439636, 0.8771200180053711, 0.5574277639389038, 0.9417346119880676, 0.2430983930826187, 0.4110446572303772, 0.8356397747993469, 0.14782476425170898, 0.6230294108390808, 0.16062529385089874, 0.8025164008140564, 0.66911780834198, 0.6252984404563904, 0.7077785730361938, 0.9442077279090881, 0.21777011454105377, 0.6641098260879517, 0.3939012587070465, 0.6231403350830078, 0.03832982853055, 0.6593327522277832, 0.7920677661895752, 0.9846814870834351, 0.9253951907157898, 0.6119729280471802, 0.27976539731025696, 0.8153699040412903, 0.9992043972015381, 0.8325477242469788, 0.03155512735247612, 0.9752534627914429, 0.7711237668991089, 0.22843889892101288, 0.7298721075057983, 0.14959698915481567, 0.22184722125530243, 0.028426872566342354, 0.16306829452514648, 0.826932430267334, 0.84047931432724, 0.0020293905399739742, 0.4596202075481415, 0.32154208421707153, 0.2795962393283844, 0.5161988735198975, 0.03225705027580261, 0.05116254836320877, 0.08146639168262482, 0.5507227182388306, 0.4238487482070923, 0.3586706221103668, 0.4213614761829376, 0.9806835651397705, 0.8174975514411926, 0.8365917205810547, 0.007354725617915392, 0.30619651079177856, 0.7518637180328369, 0.8119068145751953, 0.41922467947006226, 0.09361156821250916, 0.5936681628227234, 0.7484479546546936, 0.4122476577758789, 0.5197723507881165, 0.7472350001335144, 0.47742679715156555, 0.8297513127326965, 0.8668643236160278, 0.43973129987716675, 0.1952620893716812, 0.3657649755477905, 0.36372244358062744, 0.7894496321678162, 0.8712038397789001, 0.3811710774898529, 0.8794804215431213, 0.4876638352870941, 0.6612811088562012, 0.3447282314300537, 0.2363416999578476, 0.34162312746047974, 0.36426448822021484, 0.018233943730592728, 0.911026120185852, 0.6428091526031494, 0.1333734691143036, 0.900969922542572, 0.7314406037330627, 0.6574113965034485, 0.6554860472679138, 0.13053575158119202, 0.9948126077651978, 0.03271406516432762, 0.7808041572570801, 0.9527639746665955, 0.4927588105201721, 0.05137522146105766, 0.5756227970123291, 0.5585896372795105, 0.800747275352478, 0.8320572972297668, 0.7954227328300476, 0.2768053412437439, 0.8709347248077393, 0.479422926902771, 0.6942396759986877, 0.4955762028694153, 0.12986749410629272, 0.5718523263931274, 0.6263899803161621, 0.8454587459564209, 0.062315624207258224, 0.6243836879730225, 0.3506953716278076, 0.036406517028808594, 0.8566885590553284, 0.37074580788612366, 0.45027998089790344, 0.9670377969741821, 0.341736763715744, 0.5615126490592957, 0.8598028421401978, 0.30022865533828735, 0.56345134973526, 0.38731542229652405, 0.4687231779098511, 0.17987996339797974, 0.33950570225715637, 0.9052965044975281, 0.6320690512657166, 0.7446611523628235, 0.8660732507705688, 0.05198666453361511, 0.019820580258965492, 0.27898845076560974, 0.012115261517465115, 0.0333862267434597, 0.6596696376800537, 0.9409095644950867, 0.786883533000946, 0.048277921974658966, 0.4397698938846588, 0.6395928859710693, 0.6957913041114807, 0.8080136775970459, 0.4593903422355652, 0.808382511138916, 0.13835106790065765, 0.2738315463066101, 0.787551999092102, 0.27645668387413025, 0.8052689433097839, 0.291166752576828, 0.3153308928012848, 0.26069337129592896, 0.7119103670120239, 0.6567647457122803, 0.0671234056353569, 0.9255020022392273, 0.23857039213180542, 0.8584741353988647, 0.4019520878791809, 0.17483505606651306, 0.43140166997909546, 0.39650893211364746, 0.84713214635849, 0.310851126909256, 0.8375082612037659, 0.5629988312721252, 0.7894384264945984, 0.7872607707977295, 0.0441126823425293, 0.153716579079628, 0.8411281108856201, 0.20248284935951233, 0.5612584948539734, 0.5936828851699829, 0.7252930402755737, 0.7441900372505188, 0.43259409070014954, 0.7825880646705627, 0.8235298991203308, 0.9059191346168518, 0.877435028553009, 0.2704758048057556, 0.06547368317842484, 0.9235319495201111, 0.06953423470258713, 0.8654426336288452, 0.6671262383460999, 0.986158549785614, 0.14294281601905823, 0.2637840211391449, 0.14072978496551514, 0.6559731364250183, 0.4648387134075165, 0.660089373588562, 0.31380489468574524, 0.7879214882850647, 0.4629151225090027, 0.2849327027797699, 0.24317598342895508, 0.6937985420227051, 0.2660629153251648, 0.7345828413963318, 0.6400701999664307, 0.41597503423690796, 0.5626490712165833, 0.9554193615913391, 0.29590561985969543, 0.588887095451355, 0.7627281546592712, 0.048080191016197205, 0.14051547646522522, 0.7140352725982666, 0.9243345856666565, 0.5540916919708252, 0.8576998710632324, 0.48911622166633606, 0.10357453674077988, 0.9536210894584656, 0.9169061779975891, 0.9038464426994324, 0.9054132699966431, 0.5166712999343872, 0.9998325109481812, 0.3714355230331421, 0.2572549879550934, 0.25648409128189087, 0.6456860303878784, 0.52471923828125, 0.3416067957878113, 0.44220805168151855, 0.0031737543176859617, 0.9654694199562073, 0.6050001382827759, 0.919278621673584, 0.4499180316925049, 0.0681602880358696, 0.8617586493492126, 0.16752219200134277, 0.9100353717803955, 0.6804191470146179, 0.2961232662200928, 0.6133525371551514, 0.6846535205841064, 0.6912758946418762, 0.8871551752090454, 0.6880263686180115, 0.1596360206604004, 0.3552289307117462, 0.8890159130096436, 0.8521535396575928, 0.047500304877758026, 0.5418910384178162, 0.6921565532684326, 0.7258775234222412, 0.33736610412597656, 0.12651042640209198, 0.9349095225334167, 0.06442201137542725, 0.04818623512983322, 0.1640157848596573, 0.6333248615264893, 0.11414884775876999, 0.29569923877716064, 0.9087606072425842, 0.802605390548706, 0.4697840213775635, 0.9649648070335388, 0.017806798219680786, 0.21044647693634033, 0.21647192537784576, 0.21426112949848175, 0.25570622086524963, 0.3250560760498047, 0.2621251046657562, 0.3417591452598572, 0.43530064821243286, 0.03060890920460224, 0.9255367517471313, 0.3222452998161316, 0.6089098453521729, 0.22045864164829254, 0.2687370777130127, 0.9711988568305969, 0.6184008121490479, 0.2053210735321045, 0.20724321901798248, 0.23474019765853882, 0.6825464963912964, 0.3703716993331909, 0.9476388692855835, 0.4972802698612213, 0.9452897310256958, 0.9426154494285583, 0.6566086411476135, 0.42493101954460144, 0.4494061768054962, 0.5775091648101807, 0.8260934948921204, 0.7288474440574646, 0.5760579705238342, 0.32612326741218567, 0.27207884192466736, 0.6760498285293579, 0.9127902984619141, 0.8660804629325867, 0.7313582301139832, 0.17770102620124817, 0.08324854075908661, 0.3094593286514282, 0.43767106533050537, 0.5272268652915955, 0.3390115201473236, 0.32055941224098206, 0.6263797283172607, 0.3020850121974945, 0.22321480512619019, 0.27064982056617737, 0.4963389039039612, 0.18685530126094818, 0.9261666536331177, 0.8783203363418579, 0.22251419723033905, 0.3957551419734955, 0.7215035557746887, 0.39923417568206787, 0.48971834778785706, 0.1336180418729782, 0.8970968127250671, 0.2994769811630249, 0.5821309089660645, 0.2440454363822937, 0.7321764230728149, 0.8305322527885437, 0.25053268671035767, 0.9310356974601746, 0.6625810861587524, 0.8958381414413452, 0.9956691265106201, 0.19268515706062317, 0.37126579880714417, 0.8857021927833557, 0.7745158076286316, 0.11367655545473099, 0.5972689390182495, 0.3289646804332733, 0.006927219219505787, 0.682144045829773, 0.616004228591919, 0.5030558109283447, 0.8441178202629089, 0.8211175799369812, 0.9272254109382629, 0.08941292017698288, 0.22637709975242615, 0.6737891435623169, 0.039000775665044785, 0.580018937587738, 0.46313032507896423, 0.2427460253238678, 0.15040172636508942, 0.5328301787376404, 0.0040434845723211765, 0.17439980804920197, 0.39727550745010376, 0.9827837347984314, 0.5315169095993042, 0.6142618060112, 0.4925041198730469, 0.05076461285352707, 0.9144818782806396, 0.8888466358184814, 0.620768666267395, 0.07974159717559814, 0.34225088357925415, 0.7060471177101135, 0.09371748566627502, 0.1316651701927185, 0.8879294991493225, 0.5142166018486023, 0.7842377424240112, 0.8046817779541016, 0.23466262221336365, 0.8932574391365051, 0.9359369874000549, 0.46795371174812317, 0.9637311697006226, 0.3642718493938446, 0.9210003018379211, 0.1938772201538086, 0.9677274227142334, 0.8152120113372803, 0.8667967915534973, 0.23879826068878174, 0.8540950417518616, 0.258508563041687, 0.5231489539146423, 0.6585100293159485, 0.18745402991771698, 0.2615649402141571, 0.2875477075576782, 0.4155407249927521, 0.09849541634321213, 0.5848528742790222, 0.20886071026325226, 0.7990167140960693, 0.6296374797821045] -ADD THE ROWS HERE \ No newline at end of file +11375 diff --git a/mlir-assigner/tests/Ops/Onnx/Pow/PowPublicBase.json b/mlir-assigner/tests/Ops/Onnx/Pow/PowPublicBase.json new file mode 100644 index 0000000..d38cb1b --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/Pow/PowPublicBase.json @@ -0,0 +1 @@ +[{"memref": {"data": [0.30682373046875, 0.9177398681640625, 0.4326171875, 0.9183197021484375, 0.5850677490234375, 0.0698394775390625, 0.4090728759765625, 0.6262054443359375, 0.700897216796875, 0.240814208984375], "dims": [1, 10], "type": "f32"}}] diff --git a/mlir-assigner/tests/Ops/Onnx/Pow/PowPublicBase.onnx b/mlir-assigner/tests/Ops/Onnx/Pow/PowPublicBase.onnx new file mode 100644 index 0000000..d5a3177 Binary files /dev/null and b/mlir-assigner/tests/Ops/Onnx/Pow/PowPublicBase.onnx differ diff --git a/mlir-assigner/tests/Ops/Onnx/Pow/PowPublicBase.res b/mlir-assigner/tests/Ops/Onnx/Pow/PowPublicBase.res new file mode 100644 index 0000000..4082705 --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/Pow/PowPublicBase.res @@ -0,0 +1,3 @@ +Result: +memref<1x10xf32>[1.2369813341732077, 1.8891534186937171, 1.3496798062693975, 1.8899128414945385, 1.500109432837632, 1.0495998926682173, 1.327832230510485, 1.5434999586846203, 1.6255153918076044, 1.1816593624298102] +90 diff --git a/mlir-assigner/tests/Ops/Onnx/Pow/PowPublicExponent.json b/mlir-assigner/tests/Ops/Onnx/Pow/PowPublicExponent.json new file mode 100644 index 0000000..728173c --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/Pow/PowPublicExponent.json @@ -0,0 +1 @@ +[{"memref": {"data": [0.773895263671875, 0.1213226318359375, 0.6561126708984375, 0.5098114013671875, 0.9202880859375, 0.887786865234375, 0.8052520751953125, 0.0045166015625, 0.202392578125, 0.90814208984375], "dims": [1, 10], "type": "f32"}}] diff --git a/mlir-assigner/tests/Ops/Onnx/Pow/PowPublicExponent.onnx b/mlir-assigner/tests/Ops/Onnx/Pow/PowPublicExponent.onnx new file mode 100644 index 0000000..60b96ad Binary files /dev/null and b/mlir-assigner/tests/Ops/Onnx/Pow/PowPublicExponent.onnx differ diff --git a/mlir-assigner/tests/Ops/Onnx/Pow/PowPublicExponent.res b/mlir-assigner/tests/Ops/Onnx/Pow/PowPublicExponent.res new file mode 100644 index 0000000..1c20f09 --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/Pow/PowPublicExponent.res @@ -0,0 +1,3 @@ +Result: +memref<1x10xf32>[0.4634966254234314, 0.0017857697093859315, 0.28244590759277344, 0.13250389695167542, 0.7794197201728821, 0.6997230052947998, 0.5221503376960754, 9.213727025780827e-08, 0.008290558122098446, 0.748964786529541] +20 diff --git a/mlir-assigner/tests/Ops/Onnx/Pow/PowSimple.json b/mlir-assigner/tests/Ops/Onnx/Pow/PowSimple.json new file mode 100644 index 0000000..760dd16 --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/Pow/PowSimple.json @@ -0,0 +1 @@ +[{"memref": {"data": [0.618316650390625, 0.2113800048828125, 0.17266845703125, 0.4499969482421875, 0.57080078125, 0.42755126953125, 0.326995849609375, 0.157623291015625, 0.4703216552734375, 0.42181396484375], "dims": [1, 10], "type": "f32"}}, {"memref": {"data": [0.988739013671875, 0.0914764404296875, 0.5084228515625, 0.6395263671875, 0.6468505859375, 0.5218505859375, 0.6641387939453125, 0.9370269775390625, 0.58465576171875, 0.564178466796875], "dims": [1, 10], "type": "f32"}}] \ No newline at end of file diff --git a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowSimple.onnx b/mlir-assigner/tests/Ops/Onnx/Pow/PowSimple.onnx similarity index 100% rename from mlir-assigner/tests/Ops/NeedsBlueprintComponent/Pow/PowSimple.onnx rename to mlir-assigner/tests/Ops/Onnx/Pow/PowSimple.onnx diff --git a/mlir-assigner/tests/Ops/Onnx/Pow/PowSimple.res b/mlir-assigner/tests/Ops/Onnx/Pow/PowSimple.res new file mode 100644 index 0000000..bcfdfee --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/Pow/PowSimple.res @@ -0,0 +1,3 @@ +Result: +memref<1x10xf32>[0.6216731667518616, 0.8674795627593994, 0.4094318747520447, 0.6000927686691284, 0.6957959532737732, 0.6418461799621582, 0.4759799540042877, 0.1770714968442917, 0.643374502658844, 0.614470899105072] +90 diff --git a/mlir-assigner/tests/Ops/Onnx/Pow/PowSqrt.json b/mlir-assigner/tests/Ops/Onnx/Pow/PowSqrt.json new file mode 100644 index 0000000..2443c93 --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/Pow/PowSqrt.json @@ -0,0 +1 @@ +[{"memref": {"data": [0.05419921875, 0.2551422119140625, 0.490814208984375, 0.924285888671875, 0.701416015625, 0.7657470703125, 0.064849853515625, 0.7630615234375, 0.920074462890625, 0.2911224365234375], "dims": [1, 10], "type": "f32"}}] diff --git a/mlir-assigner/tests/Ops/Onnx/Pow/PowSqrt.onnx b/mlir-assigner/tests/Ops/Onnx/Pow/PowSqrt.onnx new file mode 100644 index 0000000..ccc3227 Binary files /dev/null and b/mlir-assigner/tests/Ops/Onnx/Pow/PowSqrt.onnx differ diff --git a/mlir-assigner/tests/Ops/Onnx/Pow/PowSqrt.res b/mlir-assigner/tests/Ops/Onnx/Pow/PowSqrt.res new file mode 100644 index 0000000..2a343ae --- /dev/null +++ b/mlir-assigner/tests/Ops/Onnx/Pow/PowSqrt.res @@ -0,0 +1,3 @@ +Result: +memref<1x10xf32>[0.23280726373195648, 0.5051160454750061, 0.7005813121795654, 0.9613978862762451, 0.8375058174133301, 0.8750697374343872, 0.2546563446521759, 0.8735339045524597, 0.9592050909996033, 0.5395576357841492] +20 diff --git a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/CumSum/CumSumSimple.json b/mlir-assigner/tests/Ops/Problematic/CumSum/CumSumSimple.json similarity index 100% rename from mlir-assigner/tests/Ops/NeedsBlueprintComponent/CumSum/CumSumSimple.json rename to mlir-assigner/tests/Ops/Problematic/CumSum/CumSumSimple.json diff --git a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/CumSum/CumSumSimple.mlir b/mlir-assigner/tests/Ops/Problematic/CumSum/CumSumSimple.mlir similarity index 58% rename from mlir-assigner/tests/Ops/NeedsBlueprintComponent/CumSum/CumSumSimple.mlir rename to mlir-assigner/tests/Ops/Problematic/CumSum/CumSumSimple.mlir index 17b90b8..8d19a86 100644 --- a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/CumSum/CumSumSimple.mlir +++ b/mlir-assigner/tests/Ops/Problematic/CumSum/CumSumSimple.mlir @@ -1,37 +1,39 @@ module attributes {llvm.data_layout = "e-m:e-p270:32:32-p271:32:32-p272:64:64-i64:64-f80:128-n8:16:32:64-S128", llvm.target_triple = "x86_64-pc-linux-gnu", "onnx-mlir.symbol-postfix" = "cumsumsimple.mlir"} { func.func @main_graph(%arg0: memref<1x25xf32>) -> memref<1x25xf32> attributes {input_names = ["in_a"], llvm.emit_c_interface, output_names = ["out_a"]} { - %cst = arith.constant 0.000000e+00 : f32 + %cst = arith.constant 0.693147182 : f32 + %cst_0 = arith.constant 0.000000e+00 : f32 %c0 = arith.constant 0 : index %alloc = memref.alloc() {alignment = 16 : i64} : memref<1x25xf32> - %alloc_0 = memref.alloc() {alignment = 16 : i64} : memref<1x25xf32> + %alloc_1 = memref.alloc() {alignment = 16 : i64} : memref<1x25xf32> affine.for %arg1 = 0 to 1 { affine.for %arg2 = 0 to 25 { %0 = affine.load %arg0[%arg1, %arg2] : memref<1x25xf32> - affine.store %0, %alloc_0[%arg1, %arg2] : memref<1x25xf32> + affine.store %0, %alloc_1[%arg1, %arg2] : memref<1x25xf32> } } affine.for %arg1 = 0 to 5 { %0 = arith.index_cast %arg1 : index to i64 %1 = arith.sitofp %0 : i64 to f32 - %2 = math.exp2 %1 : f32 - %3 = arith.fptosi %2 : f32 to i64 - %4 = arith.index_cast %3 : i64 to index + %2 = arith.mulf %1, %cst : f32 + %3 = math.exp %2 : f32 + %4 = arith.fptosi %3 : f32 to i64 + %5 = arith.index_cast %4 : i64 to index affine.for %arg2 = 0 to 1 { affine.for %arg3 = 0 to 25 { - %5 = affine.load %alloc_0[%arg2, %arg3] : memref<1x25xf32> - %6 = arith.subi %arg3, %4 : index - %7 = arith.cmpi sge, %6, %c0 : index - %8 = arith.select %7, %6, %arg3 : index - %9 = memref.load %alloc_0[%arg2, %8] : memref<1x25xf32> - %10 = arith.select %7, %9, %cst : f32 - %11 = arith.addf %5, %10 : f32 - affine.store %11, %alloc[%arg2, %arg3] : memref<1x25xf32> + %6 = affine.load %alloc_1[%arg2, %arg3] : memref<1x25xf32> + %7 = arith.subi %arg3, %5 : index + %8 = arith.cmpi sge, %7, %c0 : index + %9 = arith.select %8, %7, %arg3 : index + %10 = memref.load %alloc_1[%arg2, %9] : memref<1x25xf32> + %11 = arith.select %8, %10, %cst_0 : f32 + %12 = arith.addf %6, %11 : f32 + affine.store %12, %alloc[%arg2, %arg3] : memref<1x25xf32> } } affine.for %arg2 = 0 to 1 { affine.for %arg3 = 0 to 25 { - %5 = affine.load %alloc[%arg2, %arg3] : memref<1x25xf32> - affine.store %5, %alloc_0[%arg2, %arg3] : memref<1x25xf32> + %6 = affine.load %alloc[%arg2, %arg3] : memref<1x25xf32> + affine.store %6, %alloc_1[%arg2, %arg3] : memref<1x25xf32> } } } diff --git a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/CumSum/CumSumSimple.onnx b/mlir-assigner/tests/Ops/Problematic/CumSum/CumSumSimple.onnx similarity index 100% rename from mlir-assigner/tests/Ops/NeedsBlueprintComponent/CumSum/CumSumSimple.onnx rename to mlir-assigner/tests/Ops/Problematic/CumSum/CumSumSimple.onnx diff --git a/mlir-assigner/tests/Ops/NeedsBlueprintComponent/CumSum/CumSumSimple.res b/mlir-assigner/tests/Ops/Problematic/CumSum/CumSumSimple.res similarity index 100% rename from mlir-assigner/tests/Ops/NeedsBlueprintComponent/CumSum/CumSumSimple.res rename to mlir-assigner/tests/Ops/Problematic/CumSum/CumSumSimple.res diff --git a/mlir-assigner/tests/Ops/Onnx/OneHot/OneHotFloat.json b/mlir-assigner/tests/Ops/Problematic/OneHot/OneHotFloat.json similarity index 100% rename from mlir-assigner/tests/Ops/Onnx/OneHot/OneHotFloat.json rename to mlir-assigner/tests/Ops/Problematic/OneHot/OneHotFloat.json diff --git a/mlir-assigner/tests/Ops/Onnx/OneHot/OneHotFloat.onnx b/mlir-assigner/tests/Ops/Problematic/OneHot/OneHotFloat.onnx similarity index 100% rename from mlir-assigner/tests/Ops/Onnx/OneHot/OneHotFloat.onnx rename to mlir-assigner/tests/Ops/Problematic/OneHot/OneHotFloat.onnx diff --git a/mlir-assigner/tests/Ops/Onnx/OneHot/OneHotFloat.res b/mlir-assigner/tests/Ops/Problematic/OneHot/OneHotFloat.res similarity index 100% rename from mlir-assigner/tests/Ops/Onnx/OneHot/OneHotFloat.res rename to mlir-assigner/tests/Ops/Problematic/OneHot/OneHotFloat.res diff --git a/mlir-assigner/tests/Ops/Onnx/OneHot/OneHotSimple.json b/mlir-assigner/tests/Ops/Problematic/OneHot/OneHotSimple.json similarity index 100% rename from mlir-assigner/tests/Ops/Onnx/OneHot/OneHotSimple.json rename to mlir-assigner/tests/Ops/Problematic/OneHot/OneHotSimple.json diff --git a/mlir-assigner/tests/Ops/Onnx/OneHot/OneHotSimple.onnx b/mlir-assigner/tests/Ops/Problematic/OneHot/OneHotSimple.onnx similarity index 100% rename from mlir-assigner/tests/Ops/Onnx/OneHot/OneHotSimple.onnx rename to mlir-assigner/tests/Ops/Problematic/OneHot/OneHotSimple.onnx diff --git a/mlir-assigner/tests/Ops/Onnx/OneHot/OneHotSimple.res b/mlir-assigner/tests/Ops/Problematic/OneHot/OneHotSimple.res similarity index 100% rename from mlir-assigner/tests/Ops/Onnx/OneHot/OneHotSimple.res rename to mlir-assigner/tests/Ops/Problematic/OneHot/OneHotSimple.res diff --git a/mlir-assigner/tests/Ops/Onnx/ReverseSequence/ReverseSequenceSimple.json b/mlir-assigner/tests/Ops/Problematic/ReverseSequence/ReverseSequenceSimple.json similarity index 100% rename from mlir-assigner/tests/Ops/Onnx/ReverseSequence/ReverseSequenceSimple.json rename to mlir-assigner/tests/Ops/Problematic/ReverseSequence/ReverseSequenceSimple.json diff --git a/mlir-assigner/tests/Ops/Onnx/ReverseSequence/ReverseSequenceSimple.onnx b/mlir-assigner/tests/Ops/Problematic/ReverseSequence/ReverseSequenceSimple.onnx similarity index 100% rename from mlir-assigner/tests/Ops/Onnx/ReverseSequence/ReverseSequenceSimple.onnx rename to mlir-assigner/tests/Ops/Problematic/ReverseSequence/ReverseSequenceSimple.onnx diff --git a/mlir-assigner/tests/Ops/Onnx/ReverseSequence/ReverseSequenceSimple.res b/mlir-assigner/tests/Ops/Problematic/ReverseSequence/ReverseSequenceSimple.res similarity index 100% rename from mlir-assigner/tests/Ops/Onnx/ReverseSequence/ReverseSequenceSimple.res rename to mlir-assigner/tests/Ops/Problematic/ReverseSequence/ReverseSequenceSimple.res diff --git a/mlir-assigner/tests/Ops/Onnx/Trilu/TriluOffset.json b/mlir-assigner/tests/Ops/Problematic/Trilu/TriluOffset.json similarity index 100% rename from mlir-assigner/tests/Ops/Onnx/Trilu/TriluOffset.json rename to mlir-assigner/tests/Ops/Problematic/Trilu/TriluOffset.json diff --git a/mlir-assigner/tests/Ops/Onnx/Trilu/TriluOffset.onnx b/mlir-assigner/tests/Ops/Problematic/Trilu/TriluOffset.onnx similarity index 100% rename from mlir-assigner/tests/Ops/Onnx/Trilu/TriluOffset.onnx rename to mlir-assigner/tests/Ops/Problematic/Trilu/TriluOffset.onnx diff --git a/mlir-assigner/tests/Ops/Onnx/Trilu/TriluOffset.res b/mlir-assigner/tests/Ops/Problematic/Trilu/TriluOffset.res similarity index 100% rename from mlir-assigner/tests/Ops/Onnx/Trilu/TriluOffset.res rename to mlir-assigner/tests/Ops/Problematic/Trilu/TriluOffset.res diff --git a/mlir-assigner/tests/Ops/Onnx/Trilu/TriluSimple.json b/mlir-assigner/tests/Ops/Problematic/Trilu/TriluSimple.json similarity index 100% rename from mlir-assigner/tests/Ops/Onnx/Trilu/TriluSimple.json rename to mlir-assigner/tests/Ops/Problematic/Trilu/TriluSimple.json diff --git a/mlir-assigner/tests/Ops/Onnx/Trilu/TriluSimple.onnx b/mlir-assigner/tests/Ops/Problematic/Trilu/TriluSimple.onnx similarity index 100% rename from mlir-assigner/tests/Ops/Onnx/Trilu/TriluSimple.onnx rename to mlir-assigner/tests/Ops/Problematic/Trilu/TriluSimple.onnx diff --git a/mlir-assigner/tests/Ops/Onnx/Trilu/TriluSimple.res b/mlir-assigner/tests/Ops/Problematic/Trilu/TriluSimple.res similarity index 100% rename from mlir-assigner/tests/Ops/Onnx/Trilu/TriluSimple.res rename to mlir-assigner/tests/Ops/Problematic/Trilu/TriluSimple.res diff --git a/zkml-onnx-compiler/CMakeLists.txt b/zkml-onnx-compiler/CMakeLists.txt index 0679561..533b2f5 100644 --- a/zkml-onnx-compiler/CMakeLists.txt +++ b/zkml-onnx-compiler/CMakeLists.txt @@ -1,3 +1,28 @@ cmake_minimum_required(VERSION 3.19.0) -add_subdirectory(src) +add_onnx_mlir_library(ZkMLIRTransform + src/Passes/mlir/Transform/PowFToGenericExpPass.cpp + src/Passes/mlir/Transform/ElimCopySignPass.cpp + + LINK_LIBS PRIVATE + + MLIRTransforms + MLIRFuncDialect + MLIRMathDialect + ) + +add_onnx_mlir_executable(zkml-onnx-compiler + src/zkml-onnx-compiler.cpp + + LINK_LIBS PRIVATE + + OMCompilerOptions + OMCompilerUtils + ZkMLIRTransform + + MLIRZkMlDialect + MLIROpenMPToLLVMIRTranslation + ) + +target_include_directories(ZkMLIRTransform PUBLIC include) +target_include_directories(zkml-onnx-compiler PUBLIC include) diff --git a/zkml-onnx-compiler/src/Passes/mlir/Transform/ElimCopySign.cpp.inc b/zkml-onnx-compiler/include/Passes/mlir/Transform/ElimCopySign.cpp.inc similarity index 100% rename from zkml-onnx-compiler/src/Passes/mlir/Transform/ElimCopySign.cpp.inc rename to zkml-onnx-compiler/include/Passes/mlir/Transform/ElimCopySign.cpp.inc diff --git a/zkml-onnx-compiler/src/Passes/mlir/Transform/ElimCopySign.td b/zkml-onnx-compiler/include/Passes/mlir/Transform/ElimCopySign.td similarity index 90% rename from zkml-onnx-compiler/src/Passes/mlir/Transform/ElimCopySign.td rename to zkml-onnx-compiler/include/Passes/mlir/Transform/ElimCopySign.td index 10168bd..de99be8 100644 --- a/zkml-onnx-compiler/src/Passes/mlir/Transform/ElimCopySign.td +++ b/zkml-onnx-compiler/include/Passes/mlir/Transform/ElimCopySign.td @@ -14,4 +14,4 @@ def ElimCopySign : Pat< [(HasOneUse:$a)] >; -#endif // ZK_ML_TOOLCHAIN_ELIM_COPY_SIGN_TD \ No newline at end of file +#endif // ZK_ML_TOOLCHAIN_ELIM_COPY_SIGN_TD diff --git a/zkml-onnx-compiler/include/Passes/mlir/Transform/ElimCopySignPass.hpp b/zkml-onnx-compiler/include/Passes/mlir/Transform/ElimCopySignPass.hpp new file mode 100644 index 0000000..4ac2990 --- /dev/null +++ b/zkml-onnx-compiler/include/Passes/mlir/Transform/ElimCopySignPass.hpp @@ -0,0 +1,28 @@ +#ifndef ZK_ML_TOOLCHAIN_ELIM_COPY_SIGN_PASS +#define ZK_ML_TOOLCHAIN_ELIM_COPY_SIGN_PASS + +#include "mlir/Pass/Pass.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Dialect/Math/IR/Math.h" +#include "mlir/Dialect/Arith/IR/Arith.h" + +namespace mlir { + namespace zk_ml { + class ElimCopySignPass : public mlir::PassWrapper> { + private: + void runOnOperation() override; + + StringRef getArgument() const final { + return "elim-copysign-pass"; + } + StringRef getDescription() const final { + return "Eliminates redundant copysign operations that follow an frem operation"; + } + }; + std::unique_ptr createElimCopySignPass(); +#include "ElimCopySign.cpp.inc" + + } // namespace zk_ml +} // namespace mlir + +#endif // ZK_ML_TOOLCHAIN_ELIM_COPY_SIGN_PASS diff --git a/zkml-onnx-compiler/include/Passes/mlir/Transform/PowFToGenericExpPass.hpp b/zkml-onnx-compiler/include/Passes/mlir/Transform/PowFToGenericExpPass.hpp new file mode 100644 index 0000000..fcfc3ba --- /dev/null +++ b/zkml-onnx-compiler/include/Passes/mlir/Transform/PowFToGenericExpPass.hpp @@ -0,0 +1,25 @@ +#ifndef ZK_ML_TOOLCHAIN_POWF_TO_GENERIC_EXP_PASS +#define ZK_ML_TOOLCHAIN_POWF_TO_GENERIC_EXP_PASS + +#include "mlir/Pass/Pass.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Dialect/Math/IR/Math.h" + +namespace mlir { + namespace zk_ml { + class PowFToGenericExpPass : public mlir::PassWrapper> { + private: + void runOnOperation() override; + + StringRef getArgument() const final { + return "powf-to-generic-exp-pass"; + } + + StringRef getDescription() const final { + return "Rewrites powf calls to generic form and transforms to simpler terms when exponent is {0.5, 2, 3}"; + } + }; + std::unique_ptr createPowFToGenericExpPass(); + } // namespace zk_ml +} // namespace mlir +#endif // ZK_ML_TOOLCHAIN_POWF_TO_GENERIC_EXP_PASS diff --git a/zkml-onnx-compiler/src/CMakeLists.txt b/zkml-onnx-compiler/src/CMakeLists.txt deleted file mode 100644 index 1d400b1..0000000 --- a/zkml-onnx-compiler/src/CMakeLists.txt +++ /dev/null @@ -1,16 +0,0 @@ -add_subdirectory(Passes) - -add_onnx_mlir_executable(zkml-onnx-compiler - zkml-onnx-compiler.cpp - - LINK_LIBS PRIVATE - - OMCompilerOptions - OMCompilerUtils - ZkMLIRAnalysis - ZkMLIRConversion - ZkMLIRTransform - - MLIRZkMlDialect - MLIROpenMPToLLVMIRTranslation - ) diff --git a/zkml-onnx-compiler/src/Passes/CMakeLists.txt b/zkml-onnx-compiler/src/Passes/CMakeLists.txt deleted file mode 100644 index 5a46e22..0000000 --- a/zkml-onnx-compiler/src/Passes/CMakeLists.txt +++ /dev/null @@ -1 +0,0 @@ -add_subdirectory(mlir) diff --git a/zkml-onnx-compiler/src/Passes/mlir/Analysis/CMakeLists.txt b/zkml-onnx-compiler/src/Passes/mlir/Analysis/CMakeLists.txt deleted file mode 100644 index abd0d46..0000000 --- a/zkml-onnx-compiler/src/Passes/mlir/Analysis/CMakeLists.txt +++ /dev/null @@ -1,15 +0,0 @@ -add_onnx_mlir_library(ZkMLIRAnalysis - PrintPass.cpp - CountPass.cpp - - LINK_LIBS PRIVATE - - OMSpecializedKernelOpInterface - OMCompilerOptions - OMONNXOps - OMSupport - MLIRTransforms - MLIRAffineUtils - ) - - diff --git a/zkml-onnx-compiler/src/Passes/mlir/Analysis/CountPass.cpp b/zkml-onnx-compiler/src/Passes/mlir/Analysis/CountPass.cpp deleted file mode 100644 index b31b9fa..0000000 --- a/zkml-onnx-compiler/src/Passes/mlir/Analysis/CountPass.cpp +++ /dev/null @@ -1,276 +0,0 @@ -#include "CountPass.h" - -int64_t zk_ml::evalAffineExpr(AffineExpr expr, ArrayRef dims, ArrayRef symbols) { - int64_t lhs = 0, rhs = 0; - if (auto bin = expr.dyn_cast()) { - lhs = evalAffineExpr(bin.getLHS(), dims, symbols); - rhs = evalAffineExpr(bin.getRHS(), dims, symbols); - } - switch (expr.getKind()) { - case AffineExprKind::Add: - return lhs + rhs; - case AffineExprKind::Mul: - return lhs * rhs; - case AffineExprKind::Mod: - return mod(lhs, rhs); - case AffineExprKind::FloorDiv: - return floorDiv(lhs, rhs); - case AffineExprKind::CeilDiv: - return ceilDiv(lhs, rhs); - case AffineExprKind::Constant: - return expr.cast().getValue(); - case AffineExprKind::DimId: - return dims[expr.cast().getPosition()]; - case AffineExprKind::SymbolId: - return symbols[expr.cast().getPosition()]; - default: - llvm_unreachable("must be one of AffineExprKind"); - } -} - -bool zk_ml::evalIntegerSet(IntegerSet set, ArrayRef dims, ArrayRef symbols) { - // according to mlir/lib/IR/IntegerSetDetail.h constraints are either - // an equality (affine_expr == 0) or an inequality (affine_expr >= 0). - // Nevertheless, according to https://mlir.llvm.org/docs/Dialects/Affine/ - // a Constraint in an IntegerSet may be one of: - // affine_expr <= affine_expr - // affine_expr == affine_expr - // affine_expr >= affine_expr - // we have to stick to code anyway but somehow strange - ArrayRef constraints = set.getConstraints(); - for (unsigned i = 0; i < constraints.size(); ++i) { - int64_t constraint = evalAffineExpr(constraints[i], dims, symbols); - if (set.isEq(i)) { - llvm::outs() << "we have a equality????\n"; - exit(-1); - } else { - if (constraint < 0) { - return false; - } - } - } - return true; -} -bool zk_ml::evalIntegerSet(IntegerSet set, ArrayRef operands) { - return evalIntegerSet(set, operands.take_front(set.getNumDims()), operands.drop_front(set.getNumDims())); -} -SmallVector zk_ml::evalAffineMap(AffineMap map, ArrayRef dims, ArrayRef symbols) { - SmallVector result; - for (auto expr : map.getResults()) { - result.push_back(evalAffineExpr(expr, dims, symbols)); - } - return result; -} - -llvm::SmallVector zk_ml::evalAffineMap(AffineMap map, ArrayRef operands) { - return evalAffineMap(map, operands.take_front(map.getNumDims()), operands.drop_front(map.getNumDims())); -} - -// END COPY - -StringRef zk_ml::CountPass::getArgument() const { - return "count-pass"; -} -StringRef zk_ml::CountPass::getDescription() const { - return "Does some counting - lets see what"; -} -void zk_ml::CountPass::runOnOperation() { - Operation *op = getOperation(); - countDepth(op); - for (auto elem : this->counter) { - llvm::outs() << elem.first << ": " << elem.second << "\n"; - } -} - -template -T zk_ml::CountPass::castFromAttr(Attribute attr) { - T result = llvm::dyn_cast(attr); - assert(result); - return result; -} - -int64_t zk_ml::CountPass::getMaxFromVector(llvm::SmallVector v) { - assert(!v.empty()); - int64_t currentMax = v[0]; - for (unsigned i = 1; i < v.size(); ++i) { - if (currentMax < v[i]) - currentMax = v[i]; - } - return currentMax; -} -int64_t zk_ml::CountPass::getMinFromVector(llvm::SmallVector v) { - assert(!v.empty()); - int64_t currentMin = v[0]; - for (unsigned i = 1; i < v.size(); ++i) { - if (currentMin > v[i]) - currentMin = v[i]; - } - return currentMin; -} - -void zk_ml::CountPass::printIndent(unsigned offset) { - if (DEBUG_FLAG) { - assert(indent >= offset); - for (unsigned i = 0; i < indent - offset; ++i) - llvm::outs() << " "; - } -} - -void zk_ml::CountPass::doAffineFor(Operation *op, int64_t from, int64_t to, int64_t step) { - assert(from < to); - assert(step); - assert(op->getRegions().size() == 1); - assert(op->getRegions()[0].hasOneBlock()); - assert(op->getRegions()[0].getArguments().size() == 1); - printIndent(); - DEBUG("for (" << from << "->" << to << " step " << step << ")"); - indent++; - llvm::hash_code counterHash = hash_value(op->getRegions()[0].getArguments()[0]); - DEBUG("inserting hash: " << counterHash << ":" << from); - this->values.insert(std::make_pair(counterHash, from)); - while (from < to) { - for (Region ®ion : op->getRegions()) - countRegion(region); - from += step; - DEBUG("updating hash: " << counterHash << ":" << from); - this->values.insert(std::make_pair(counterHash, from)); - printIndent(1); - DEBUG(from << "->" << to); - DEBUG("for done! go next iteration.."); - } - this->values.erase(counterHash); - DEBUG("deleting: " << counterHash); - indent--; -} - -template -void zk_ml::CountPass::printSmallvector(llvm::SmallVector &v) { - if (DEBUG_FLAG) { - llvm::outs() << "v["; - for (auto c : v) - llvm::outs() << c << ","; - llvm::outs() << "]\n"; - } -} - -int64_t zk_ml::CountPass::evaluateForParameter(AffineMap &affineMap, llvm::SmallVector &operands, bool from) { - if (affineMap.isConstant()) { - return affineMap.getResult(0).cast().getValue(); - } else { - assert(affineMap.getNumInputs() == operands.size()); - llvm::SmallVector inVector(affineMap.getNumInputs()); - for (unsigned i = 0; i < affineMap.getNumInputs(); ++i) { - llvm::hash_code hash = hash_value(operands[i]); - DEBUG("looking for: " << hash); - if (values.find(hash) == values.end()) { - DEBUG(affineMap); - DEBUG("CANNOT FIND " << hash_value(operands[i])); - DEBUG("CANNOT FIND " << operands[i]); - exit(0); - } else { - assert(values.find(hash) != values.end()); - assert(values.count(hash)); - inVector[i] = this->values[hash]; - } - } - llvm::SmallVector eval = evalAffineMap(affineMap, inVector); - return from ? getMaxFromVector(eval) : getMinFromVector(eval); - } -} - -void zk_ml::CountPass::countDepth(Operation *op) { - // Print the operation itself and some of its properties - // Print the operation attributes - std::string opName = op->getName().getIdentifier().str(); - // printIndent(); - // DEBUG("visiting " << opName); - if (opName == AFFINE_FOR) { - DEBUG("visiting affine for!"); - assert(op->getAttrs().size() == 3); - AffineMap fromMap = castFromAttr(op->getAttrs()[0].getValue()).getAffineMap(); - int64_t step = llvm::dyn_cast(op->getAttrs()[1].getValue()).getInt(); - AffineMap toMap = castFromAttr(op->getAttrs()[2].getValue()).getAffineMap(); - assert(fromMap.getNumInputs() + toMap.getNumInputs() == op->getNumOperands()); - llvm::SmallVector operandsFrom(op->getOperands().begin(), - op->getOperands().begin() + fromMap.getNumInputs()); - llvm::SmallVector operandsTo(op->getOperands().begin() + fromMap.getNumInputs(), - op->getOperands().end()); - int64_t from = evaluateForParameter(fromMap, operandsFrom, true); - int64_t to = evaluateForParameter(toMap, operandsTo, false); - doAffineFor(op, from, to, step); - } else if (opName == AFFINE_IF) { - DEBUG("visiting affine if!"); - assert(op->getAttrs().size() == 1); - IntegerSet condition = castFromAttr(op->getAttrs()[0].getValue()).getValue(); - // IntegerSet condition = op->getAttrs()[0].getValue(); - // assert(op->getNumOperands() == condition.getNumInputs()); - llvm::SmallVector operands(op->getNumOperands()); - DEBUG(op->getAttrs()[0].getValue()); - int i = 0; - for (auto operand : op->getOperands()) { - llvm::hash_code hash = hash_value(operand); - assert(values.find(hash) != values.end()); - assert(values.count(hash)); - int64_t test = this->values[hash]; - operands[i++] = test; - } - if (evalIntegerSet(condition, operands)) { - countRegion(op->getRegion(0)); - } else { - countRegion(op->getRegion(1)); - } - } else if (opName == "affine.apply" || opName == "affine.min") { - DEBUG("got affine.apply"); - assert(op->getResults().size() == 1); - assert(op->getAttrs().size() == 1); - AffineMap applyMap = castFromAttr(op->getAttrs()[0].getValue()).getAffineMap(); - llvm::SmallVector operands(op->getOperands().begin(), op->getOperands().end()); - int64_t result = evaluateForParameter(applyMap, operands, false); - values.insert(std::make_pair(hash_value(op->getResults()[0]), result)); - } else if (opName == "affine.max") { - DEBUG("got affine.apply"); - assert(op->getResults().size() == 1); - assert(op->getAttrs().size() == 1); - AffineMap applyMap = castFromAttr(op->getAttrs()[0].getValue()).getAffineMap(); - llvm::SmallVector operands(op->getOperands().begin(), op->getOperands().end()); - int64_t result = evaluateForParameter(applyMap, operands, true); - values.insert(std::make_pair(hash_value(op->getResults()[0]), result)); - } else if (opName == ARITH_CONST) { - assert(op->getNumResults() == 1); - assert(op->getAttrs().size() == 1); - Attribute contantValue = op->getAttrs()[0].getValue(); - if (contantValue.isa()) { - int64_t value = llvm::dyn_cast(contantValue).getInt(); - values.insert(std::make_pair(hash_value(op->getResult(0)), value)); - } else { - DEBUG("ignoring non int constant"); - } - } else { - auto operationIter = this->counter.find(opName); - if (operationIter != this->counter.end()) { - (*operationIter).second++; - // std::cout << "increasing " << opName << std::endl; - } else { - this->counter.insert(std::make_pair(opName, 1)); - // std::cout << "inserting " << opName << std::endl; - } - - // Recurse into each of the regions attached to the operation. - for (Region ®ion : op->getRegions()) - countRegion(region); - } -} - -void zk_ml::CountPass::countRegion(Region ®ion) { - for (Block &block : region.getBlocks()) - countBlock(block); -} - -void zk_ml::CountPass::countBlock(Block &block) { - for (Operation &op : block.getOperations()) - countDepth(&op); -} - -std::unique_ptr zk_ml::createCountPass() { - return std::make_unique(); -} diff --git a/zkml-onnx-compiler/src/Passes/mlir/Analysis/CountPass.h b/zkml-onnx-compiler/src/Passes/mlir/Analysis/CountPass.h deleted file mode 100644 index 879552f..0000000 --- a/zkml-onnx-compiler/src/Passes/mlir/Analysis/CountPass.h +++ /dev/null @@ -1,72 +0,0 @@ -#pragma once - -#include "mlir/Pass/Pass.h" -#include "mlir/IR/Attributes.h" -#include "mlir/IR/IntegerSet.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/AffineExpr.h" -#include "mlir/Support/MathExtras.h" -#include -#include -#include -using namespace mlir; - -#define AFFINE_FOR "affine.for" -#define AFFINE_IF "affine.if" -#define ARITH_CONST "arith.constant" - -#define DEBUG_FLAG false -#define DEBUG(X) \ - if (DEBUG_FLAG) \ - llvm::outs() << X << "\n" - -namespace zk_ml -{ - - // TODO link to mlir-hlo so that we do not have to copy-paste - - int64_t evalAffineExpr(AffineExpr expr, ArrayRef dims, ArrayRef symbols); - bool evalIntegerSet(IntegerSet set, ArrayRef dims, ArrayRef symbols); - bool evalIntegerSet(IntegerSet set, ArrayRef operands); - SmallVector evalAffineMap(AffineMap map, ArrayRef dims, ArrayRef symbols); - llvm::SmallVector evalAffineMap(AffineMap map, ArrayRef operands); - - // END COPY - - class CountPass - : public mlir::PassWrapper> - { - private: - unsigned indent = 0; - std::unordered_map counter; - std::map values; - - StringRef getArgument() const final; - StringRef getDescription() const final; - void runOnOperation() override; - - template - T castFromAttr(Attribute attr); - - int64_t getMaxFromVector(llvm::SmallVector v); - int64_t getMinFromVector(llvm::SmallVector v); - - void printIndent(unsigned offset = 0); - - void doAffineFor(Operation *op, int64_t from, int64_t to, int64_t step); - - template - void printSmallvector(llvm::SmallVector &v); - - int64_t evaluateForParameter(AffineMap &affineMap, llvm::SmallVector &operands, bool from); - - void countDepth(Operation *op); - - void countRegion(Region ®ion); - - void countBlock(Block &block); - }; - - std::unique_ptr createCountPass(); -} // namespace zk_ml - diff --git a/zkml-onnx-compiler/src/Passes/mlir/Analysis/PrintPass.cpp b/zkml-onnx-compiler/src/Passes/mlir/Analysis/PrintPass.cpp deleted file mode 100644 index 3de8cc3..0000000 --- a/zkml-onnx-compiler/src/Passes/mlir/Analysis/PrintPass.cpp +++ /dev/null @@ -1,120 +0,0 @@ - -#include "PrintPass.h" - -StringRef zk_ml_toolchain::PrintPass::getArgument() const { - return "print-pass"; -} - -StringRef zk_ml_toolchain::PrintPass::getDescription() const { - return "Prints some Debug Information (copied from Tutorial)"; -} - -void zk_ml_toolchain::PrintPass::runOnOperation() { - Operation *op = getOperation(); - resetIndent(); - std::vector typeIds; - printOperation(op, typeIds); -} - -void zk_ml_toolchain::PrintPass::printVector(std::vector &typeIds) { - std::cout << "["; - for (auto element : typeIds) { - std::cout << element << ", "; - } - std::cout << "]" << std::endl; -} - -void zk_ml_toolchain::PrintPass::printOperation(Operation *op, std::vector &typeIds) { - // Print the operation itself and some of its properties - std::string opName = op->getName().getIdentifier().str(); - if (opName == "krnl.gloabl") { - printIndent() << "visiting: krnl.global"; - return; - } - unsigned numOperands = op->getNumOperands(); - unsigned numResults = op->getNumResults(); - printIndent() << "visiting op: '" << op->getName() << "' with " << numOperands << " operands and " << numResults - << " results\n"; - // Print the operation attributes - if (!op->getAttrs().empty()) { - printIndent() << op->getAttrs().size() << " attributes:\n"; - for (NamedAttribute attr : op->getAttrs()) - printIndent() << " - '" << attr.getName().getValue() << "' : '" << attr.getValue() << "'\n"; - } - - // Recurse into each of the regions attached to the operation. - printIndent() << " " << op->getNumRegions() << " nested regions:\n"; - if (opName == "arith.constant") { - if (numResults != 1) { - std::cout << "whaaaat" << std::endl; - exit(0); - } - llvm::hash_code hash = hash_value(op->getResults()[0]); - std::cout << hash << std::endl; - printVector(typeIds); - if (std::find(typeIds.begin(), typeIds.end(), hash) != typeIds.end()) { - std::cout << "whaaaaaaaaat already in vector" << std::endl; - std::cout << *(std::find(typeIds.begin(), typeIds.end(), hash)) << std::endl; - } else { - typeIds.emplace_back(hash); - } - } else if (opName == "affine.for" && numOperands > 0) { - OperandRange operands = op->getOperands(); - for (uint64_t i = 0; i < operands.size(); ++i) { - llvm::hash_code hash = hash_value(operands[i].getType()); - if (std::find(typeIds.begin(), typeIds.end(), hash) == typeIds.end()) { - std::cout << "whaaaaaaaaat not in vector" << std::endl; - exit(0); - } - } - } - auto indent = pushIndent(); - for (Region ®ion : op->getRegions()) - printRegion(region, typeIds); -} - -void zk_ml_toolchain::PrintPass::printRegion(Region ®ion, std::vector &typeIds) { - // A region does not hold anything by itself other than a list of blocks. - printIndent() << "Region with " << region.getBlocks().size() << " blocks:\n"; - auto indent = pushIndent(); - for (Block &block : region.getBlocks()) - printBlock(block, typeIds); -} - -void zk_ml_toolchain::PrintPass::printBlock(Block &block, std::vector &typeIds) { - // Print the block intrinsics properties (basically: argument list) - printIndent() << "Block with " << block.getNumArguments() << " arguments, " << block.getNumSuccessors() - << " successors, and " - // Note, this `.size()` is traversing a linked-list and is O(n). - << block.getOperations().size() << " operations\n"; - - // Block main role is to hold a list of Operations: let's recurse. - auto indent = pushIndent(); - for (Operation &op : block.getOperations()) - printOperation(&op, typeIds); -} - -zk_ml_toolchain::PrintPass::IdentRAII::IdentRAII(int &indent) : indent(indent) { -} - -zk_ml_toolchain::PrintPass::IdentRAII::~IdentRAII() { - --indent; -} - -void zk_ml_toolchain::PrintPass::resetIndent() { - indent = 0; -} - -zk_ml_toolchain::PrintPass::IdentRAII zk_ml_toolchain::PrintPass::pushIndent() { - return IdentRAII(++indent); -} - -llvm::raw_ostream &zk_ml_toolchain::PrintPass::printIndent() { - for (int i = 0; i < indent; ++i) - llvm::outs() << " "; - return llvm::outs(); -} - -std::unique_ptr zk_ml_toolchain::createPrintPass() { - return std::make_unique(); -} diff --git a/zkml-onnx-compiler/src/Passes/mlir/Analysis/PrintPass.h b/zkml-onnx-compiler/src/Passes/mlir/Analysis/PrintPass.h deleted file mode 100644 index 2412940..0000000 --- a/zkml-onnx-compiler/src/Passes/mlir/Analysis/PrintPass.h +++ /dev/null @@ -1,53 +0,0 @@ -#ifndef ZK_ML_TOOLCHAIN_PRINT_PASS -#define ZK_ML_TOOLCHAIN_PRINT_PASS - -#include "mlir/IR/BuiltinOps.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/TypeID.h" -#include -#include -#include -#include - -#include "llvm/ADT/Hashing.h" -using namespace mlir; - -namespace zk_ml_toolchain -{ - class PrintPass : public mlir::PassWrapper> - { - StringRef getArgument() const final; - - StringRef getDescription() const final; - - void runOnOperation() override; - - void printVector(std::vector &typeIds); - - void printOperation(Operation *op, std::vector &typeIds); - - void printRegion(Region ®ion, std::vector &typeIds); - - void printBlock(Block &block, std::vector &typeIds); - - int indent; - - struct IdentRAII - { - int &indent; - IdentRAII(int &indent); - ~IdentRAII(); - }; - - void resetIndent(); - - IdentRAII pushIndent(); - - llvm::raw_ostream &printIndent(); - }; - - std::unique_ptr createPrintPass(); - -} // namespace zk_ml_toolchain - -#endif // ZK_ML_TOOLCHAIN_PRINT_PASS \ No newline at end of file diff --git a/zkml-onnx-compiler/src/Passes/mlir/CMakeLists.txt b/zkml-onnx-compiler/src/Passes/mlir/CMakeLists.txt deleted file mode 100644 index 2378387..0000000 --- a/zkml-onnx-compiler/src/Passes/mlir/CMakeLists.txt +++ /dev/null @@ -1,3 +0,0 @@ -add_subdirectory(Analysis) -add_subdirectory(Conversion) -add_subdirectory(Transform) diff --git a/zkml-onnx-compiler/src/Passes/mlir/Conversion/AffineFullUnrollPass.cpp b/zkml-onnx-compiler/src/Passes/mlir/Conversion/AffineFullUnrollPass.cpp deleted file mode 100644 index 684b4b4..0000000 --- a/zkml-onnx-compiler/src/Passes/mlir/Conversion/AffineFullUnrollPass.cpp +++ /dev/null @@ -1,20 +0,0 @@ -#include "mlir/Dialect/Affine/LoopUtils.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" - -#include "AffineFullUnrollPass.h" - -using mlir::affine::AffineForOp; -using mlir::affine::loopUnrollFull; - -void zk_ml::AffineFullUnrollPass::runOnOperation() { - getOperation().walk([&](AffineForOp op) { - if (failed(loopUnrollFull(op))) { - op.emitError("unrolling failed"); - signalPassFailure(); - } - }); -} -std::unique_ptr zk_ml::createFullUnrollPass() { - return std::make_unique(); -} diff --git a/zkml-onnx-compiler/src/Passes/mlir/Conversion/AffineFullUnrollPass.h b/zkml-onnx-compiler/src/Passes/mlir/Conversion/AffineFullUnrollPass.h deleted file mode 100644 index df50cd3..0000000 --- a/zkml-onnx-compiler/src/Passes/mlir/Conversion/AffineFullUnrollPass.h +++ /dev/null @@ -1,24 +0,0 @@ -#pragma once -#include "mlir/Pass/Pass.h" - -// This class unrolls all affine for loops using the C++ API. Usually, this -// is NOT the correct way to handle things, but for more minor transformations -// and analyses it is sufficient. -// -// It can be seen as a tutorial pass. Check the AffineFullUnrollPattern -// for how to use the Rewrite Engine. - -using namespace mlir; -namespace zk_ml { - class AffineFullUnrollPass: public PassWrapper> { - private: - void runOnOperation() override; // implemented in AffineFullUnroll.cpp - - StringRef getArgument() const final { return "affine-full-unroll"; } - - StringRef getDescription() const final { - return "Fully unroll all affine loops"; - } - }; - std::unique_ptr createFullUnrollPass(); -}// zk ml diff --git a/zkml-onnx-compiler/src/Passes/mlir/Conversion/AffineFullUnrollPattern.cpp b/zkml-onnx-compiler/src/Passes/mlir/Conversion/AffineFullUnrollPattern.cpp deleted file mode 100644 index b81f62b..0000000 --- a/zkml-onnx-compiler/src/Passes/mlir/Conversion/AffineFullUnrollPattern.cpp +++ /dev/null @@ -1,33 +0,0 @@ -#include "mlir/Dialect/Affine/LoopUtils.h" -#include "mlir/Dialect/Affine/IR/AffineOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#include "AffineFullUnrollPattern.h" - -using mlir::affine::AffineForOp; -using mlir::affine::loopUnrollFull; - -namespace { - struct AffineFullUnrollPattern : public OpRewritePattern { - - AffineFullUnrollPattern(MLIRContext *context) : OpRewritePattern(context, /*benefit=*/1) { - } - - LogicalResult matchAndRewrite(AffineForOp op, PatternRewriter &rewriter) const override { - return loopUnrollFull(op); - } - }; -} // namespace - -void zk_ml::AffineFullUnrollPassAsPatternRewrite::runOnOperation() { - mlir::RewritePatternSet patterns(&getContext()); - patterns.add(&getContext()); - // One could use GreedyRewriteConfig here to slightly tweak the behavior of - // the pattern application. - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); -} -std::unique_ptr zk_ml::createFullUnrollPassPatternRewriter() { - return std::make_unique(); -} diff --git a/zkml-onnx-compiler/src/Passes/mlir/Conversion/AffineFullUnrollPattern.h b/zkml-onnx-compiler/src/Passes/mlir/Conversion/AffineFullUnrollPattern.h deleted file mode 100644 index af003a1..0000000 --- a/zkml-onnx-compiler/src/Passes/mlir/Conversion/AffineFullUnrollPattern.h +++ /dev/null @@ -1,20 +0,0 @@ -#pragma once - -#include "mlir/Pass/Pass.h" - -using namespace mlir; -namespace zk_ml { - - class AffineFullUnrollPassAsPatternRewrite: public PassWrapper> { - private: - void runOnOperation() override; - - StringRef getArgument() const final { return "affine-full-unroll-with-pattern"; } - - StringRef getDescription() const final { - return "Fully unroll all affine loops with pattern rewrite"; - } - }; - std::unique_ptr createFullUnrollPassPatternRewriter(); - -}// zk ml diff --git a/zkml-onnx-compiler/src/Passes/mlir/Conversion/CMakeLists.txt b/zkml-onnx-compiler/src/Passes/mlir/Conversion/CMakeLists.txt deleted file mode 100644 index 86a48bc..0000000 --- a/zkml-onnx-compiler/src/Passes/mlir/Conversion/CMakeLists.txt +++ /dev/null @@ -1,17 +0,0 @@ -add_onnx_mlir_library(ZkMLIRConversion - AffineFullUnrollPass.cpp - AffineFullUnrollPattern.cpp - #RewriteMulOps.cpp - - LINK_LIBS PRIVATE - - OMSpecializedKernelOpInterface - OMCompilerOptions - OMONNXOps - OMSupport - MLIRTransforms - MLIRAffineUtils - OMMlirDialects - ) - - diff --git a/zkml-onnx-compiler/src/Passes/mlir/Transform/CMakeLists.txt b/zkml-onnx-compiler/src/Passes/mlir/Transform/CMakeLists.txt deleted file mode 100644 index 6c40bc0..0000000 --- a/zkml-onnx-compiler/src/Passes/mlir/Transform/CMakeLists.txt +++ /dev/null @@ -1,12 +0,0 @@ -add_onnx_mlir_library(ZkMLIRTransform - ElimCopySignPass.cpp - - LINK_LIBS PRIVATE - - OMSpecializedKernelOpInterface - OMCompilerOptions - OMONNXOps - OMSupport - MLIRTransforms - MLIRAffineUtils - ) diff --git a/zkml-onnx-compiler/src/Passes/mlir/Transform/ElimCopySignPass.cpp b/zkml-onnx-compiler/src/Passes/mlir/Transform/ElimCopySignPass.cpp index 93f6868..cd2e111 100644 --- a/zkml-onnx-compiler/src/Passes/mlir/Transform/ElimCopySignPass.cpp +++ b/zkml-onnx-compiler/src/Passes/mlir/Transform/ElimCopySignPass.cpp @@ -1,20 +1,14 @@ +#include +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "ElimCopySignPass.h" +using namespace mlir; -StringRef zk_ml_toolchain::ElimCopySignPass::getArgument() const { - return "elim-copysign-pass"; -} - -StringRef zk_ml_toolchain::ElimCopySignPass::getDescription() const { - return "Eliminates redundant copysign operations that follow an frem operation"; -} - -void zk_ml_toolchain::ElimCopySignPass::runOnOperation() { +void zk_ml::ElimCopySignPass::runOnOperation() { mlir::RewritePatternSet patterns(&getContext()); patterns.add(&getContext()); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } -std::unique_ptr zk_ml_toolchain::createElimCopySignPass() { +std::unique_ptr mlir::zk_ml::createElimCopySignPass() { return std::make_unique(); } diff --git a/zkml-onnx-compiler/src/Passes/mlir/Transform/ElimCopySignPass.h b/zkml-onnx-compiler/src/Passes/mlir/Transform/ElimCopySignPass.h deleted file mode 100644 index 82a61c6..0000000 --- a/zkml-onnx-compiler/src/Passes/mlir/Transform/ElimCopySignPass.h +++ /dev/null @@ -1,37 +0,0 @@ -#ifndef ZK_ML_TOOLCHAIN_ELIM_COPY_SIGN_PASS -#define ZK_ML_TOOLCHAIN_ELIM_COPY_SIGN_PASS - -#include "mlir/IR/BuiltinOps.h" -#include "mlir/IR/PatternMatch.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Support/TypeID.h" -#include "mlir/Dialect/Math/IR/Math.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include -#include -#include -#include -#include - -#include "llvm/ADT/Hashing.h" -using namespace mlir; - -namespace zk_ml_toolchain -{ - class ElimCopySignPass : public mlir::PassWrapper> - { - StringRef getArgument() const final; - - StringRef getDescription() const final; - - void runOnOperation() override; - }; - - std::unique_ptr createElimCopySignPass(); - -#include "ElimCopySign.cpp.inc" - -} // namespace zk_ml_toolchain - -#endif // ZK_ML_TOOLCHAIN_ELIM_COPY_SIGN_PASS diff --git a/zkml-onnx-compiler/src/Passes/mlir/Transform/PowFToGenericExpPass.cpp b/zkml-onnx-compiler/src/Passes/mlir/Transform/PowFToGenericExpPass.cpp new file mode 100644 index 0000000..b66f85c --- /dev/null +++ b/zkml-onnx-compiler/src/Passes/mlir/Transform/PowFToGenericExpPass.cpp @@ -0,0 +1,70 @@ +#include + +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" +#include "mlir/Dialect/Arith/IR/Arith.h" + +using namespace mlir; +using mlir::math::PowFOp; +using mlir::math::Exp2Op; +namespace { + struct PowFToGenericRewritePattern : public OpRewritePattern { + + PowFToGenericRewritePattern(MLIRContext *context) : OpRewritePattern(context, /*benefit=*/1) { + } + + LogicalResult matchAndRewrite(PowFOp Op, PatternRewriter &Rewriter) const override { + Location Loc = NameLoc::get(StringAttr::get(Op->getContext(), "math.powf"), Op->getLoc()); + // check if exp is a constant + auto Base = Op->getOperand(0); + auto Exp = Op->getOperand(1); + Operation *DefiningOp = Exp.getDefiningOp(); + if (arith::ConstantFloatOp ConstOp = llvm::dyn_cast(*DefiningOp)) { + APFloat Constant = ConstOp.value(); + double d = Constant.convertToDouble(); + if (d == 0.5 || d == 1.0 || d == 2.0 || d == 3.0) + llvm_unreachable("I do not think this can happen but we want to assert to catch IF it happens"); + } + // just ordinary rewrite + // a^b becomes exp(ln(a)*b) + Value LnA = Rewriter.create(Loc, Base); + Value NewExp = Rewriter.create(Loc, LnA, Exp); + Rewriter.replaceOp(Op, Rewriter.create(Loc, NewExp)); + return success(); + } + }; + + struct Exp2ToGenericRewritePattern : public OpRewritePattern { + + Exp2ToGenericRewritePattern(MLIRContext *context) : OpRewritePattern(context, /*benefit=*/1) { + } + + LogicalResult matchAndRewrite(Exp2Op Op, PatternRewriter &Rewriter) const override { + Location Loc = NameLoc::get(StringAttr::get(Op->getContext(), "math.exp2"), Op->getLoc()); + auto Exp = Op->getOperand(0); + assert(Exp.getType().isa() && "Exponent must be float for exp2"); + FloatType FTy = Exp.getType().cast(); + // 2^b becomes exp(ln(2)*b) + // so we create constant for ln(2) and go from there + //0.6931471805599453094172321214581765680755001343602552541206800094 + // Value Ln2 = Rewriter.create(Loc, APFloat(FTy.getFloatSemantics(), "0.69314718055994530941"), FTy); + // Value NewExp = Rewriter.create(Loc, Ln2, Exp); + // Rewriter.replaceOp(Op, Rewriter.create(Loc, NewExp)); + Value Base = Rewriter.create(Loc, APFloat(FTy.getFloatSemantics(), "2.0"), FTy); + Value LnA = Rewriter.create(Loc, Base); + Value NewExp = Rewriter.create(Loc, LnA, Exp); + Rewriter.replaceOp(Op, Rewriter.create(Loc, NewExp)); + llvm::errs() << "base: " << Base << "\n"; + return success(); + } + }; +} // namespace + +void mlir::zk_ml::PowFToGenericExpPass::runOnOperation() { + mlir::RewritePatternSet patterns(&getContext()); + patterns.add(&getContext()); + patterns.add(&getContext()); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); +} +std::unique_ptr mlir::zk_ml::createPowFToGenericExpPass() { + return std::make_unique(); +} diff --git a/zkml-onnx-compiler/src/zkml-onnx-compiler.cpp b/zkml-onnx-compiler/src/zkml-onnx-compiler.cpp index 1c02b90..0f832b7 100644 --- a/zkml-onnx-compiler/src/zkml-onnx-compiler.cpp +++ b/zkml-onnx-compiler/src/zkml-onnx-compiler.cpp @@ -15,7 +15,8 @@ #include "src/Version/Version.hpp" #include "mlir/Dialect/zkml/ZkMlDialect.h" -#include "Passes/mlir/Transform/ElimCopySignPass.h" +#include +#include #define STDOUT_MARKER "stdout" @@ -41,12 +42,12 @@ bool hasEnding(std::string const &fullString, std::string const &ending) { } } -std::string dirName(StringRef inputFilename) { +std::string dirName(llvm::StringRef inputFilename) { llvm::SmallVector path(inputFilename.begin(), inputFilename.end()); llvm::sys::path::remove_filename(path); return std::string(path.data(), path.size()); } -int loadOnnxFile(StringRef inputFilename, mlir::MLIRContext &context, mlir::OwningOpRef &module, +int loadOnnxFile(llvm::StringRef inputFilename, mlir::MLIRContext &context, mlir::OwningOpRef &module, std::string *errorMessage) { // we use default options for now from onnx-mlir, lets see if we need // something else @@ -135,7 +136,7 @@ int main(int argc, char **argv) { // context.appendDialectRegistry(onnx_mlir::registerDialects(onnx_mlir::maccel)); // context.loadAllAvailableDialects(); // onnx_mlir::registerDialects(context); - context.getOrLoadDialect(); + context.getOrLoadDialect(); mlir::OwningOpRef module; std::string errorMessage; @@ -149,11 +150,15 @@ int main(int argc, char **argv) { bool EmitMLIR = EmitLevel::zkMLIR == EmitLevel || EmitLevel::MLIR == EmitLevel; onnx_mlir::configurePasses(); mlir::PassManager pm(module.get()->getName(), mlir::OpPassManager::Nesting::Implicit); - if (EmitLevel == EmitLevel::ONNX) { + if (EmitLevel::ONNX == EmitLevel) { onnx_mlir::addPasses(module, pm, onnx_mlir::EmissionTargetType::EmitONNXIR, outputFilename); } else { - onnx_mlir::addPasses(module, pm, onnx_mlir::EmissionTargetType::EmitMLIR, outputFilename, EmitLevel == EmitLevel::zkMLIR); - pm.addPass(zk_ml_toolchain::createElimCopySignPass()); + onnx_mlir::addPasses(module, pm, onnx_mlir::EmissionTargetType::EmitMLIR, outputFilename, + EmitLevel == EmitLevel::zkMLIR); + if (EmitLevel::zkMLIR == EmitLevel) { + pm.addPass(mlir::zk_ml::createElimCopySignPass()); + pm.addPass(mlir::zk_ml::createPowFToGenericExpPass()); + } if (!EmitMLIR) { // third parameter here is optional in onnx-mlir. Maybe we should do that // too?