-
Notifications
You must be signed in to change notification settings - Fork 38
Closed
Description
Opening it here because there is a possibility that we are emitting the ops incorrectly.
module @reactant_bmm_fd attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
// Scalar body for a broadcasted `*` over mixed f32/f64 operands: promotes
// %arg0 to f64, multiplies, and returns the product plus both original
// operands (the primal inputs are passed through alongside the result).
func.func private @"*_broadcast_scalar"(%arg0: tensor<f32> {enzymexla.memory_effects = []}, %arg1: tensor<f64> {enzymexla.memory_effects = []}) -> (tensor<f64>, tensor<f32>, tensor<f64>) attributes {enzymexla.memory_effects = []} {
%0 = stablehlo.convert %arg0 : (tensor<f32>) -> tensor<f64>
%1 = stablehlo.multiply %0, %arg1 : tensor<f64>
return %1, %arg0, %arg1 : tensor<f64>, tensor<f32>, tensor<f64>
}
// Scalar body for a broadcasted `+`: converts %arg0 (f32) to f64, adds
// %arg1, and returns the sum plus both original operands.
func.func private @"+_broadcast_scalar"(%arg0: tensor<f32> {enzymexla.memory_effects = []}, %arg1: tensor<f64> {enzymexla.memory_effects = []}) -> (tensor<f64>, tensor<f32>, tensor<f64>) attributes {enzymexla.memory_effects = []} {
%0 = stablehlo.convert %arg0 : (tensor<f32>) -> tensor<f64>
%1 = stablehlo.add %0, %arg1 : tensor<f64>
return %1, %arg0, %arg1 : tensor<f64>, tensor<f32>, tensor<f64>
}
// Scalar body for a broadcasted `-`: converts %arg0 (f32) to f64, subtracts
// %arg1 from it, and returns the difference plus both original operands.
func.func private @"-_broadcast_scalar"(%arg0: tensor<f32> {enzymexla.memory_effects = []}, %arg1: tensor<f64> {enzymexla.memory_effects = []}) -> (tensor<f64>, tensor<f32>, tensor<f64>) attributes {enzymexla.memory_effects = []} {
%0 = stablehlo.convert %arg0 : (tensor<f32>) -> tensor<f64>
%1 = stablehlo.subtract %0, %arg1 : tensor<f64>
return %1, %arg0, %arg1 : tensor<f64>, tensor<f32>, tensor<f64>
}
// TypeCast{Float64} applied to an operand that is already f64: a no-op
// pass-through.
func.func private @"Reactant.TracedUtils.TypeCast{Float64}()_broadcast_scalar"(%arg0: tensor<f64> {enzymexla.memory_effects = []}) -> tensor<f64> attributes {enzymexla.memory_effects = []} {
return %arg0 : tensor<f64>
}
// TypeCast{Float64} on an f32 operand: converts to f64 and also returns the
// original f32 input (primal pass-through).
func.func private @"Reactant.TracedUtils.TypeCast{Float64}()_broadcast_scalar_1"(%arg0: tensor<f32> {enzymexla.memory_effects = []}) -> (tensor<f64>, tensor<f32>) attributes {enzymexla.memory_effects = []} {
%0 = stablehlo.convert %arg0 : (tensor<f32>) -> tensor<f64>
return %0, %arg0 : tensor<f64>, tensor<f32>
}
// Scalar identity: returns its f64 input unchanged.
func.func private @identity_broadcast_scalar(%arg0: tensor<f64> {enzymexla.memory_effects = []}) -> tensor<f64> attributes {enzymexla.memory_effects = []} {
return %arg0 : tensor<f64>
}
// Unbatched scalar-output body that @main maps over a leading batch axis via
// `enzyme.batch`. Computes sum(dot_general(cast(%arg1), cast(broadcast(%arg0))))
// as a single f64 scalar.
// NOTE(review): the call site in @main (`enzyme.batch @"unbatched_#7"(%9)
// {batch_shape = [240]}`) passes ONE operand of tensor<240x3x4x5x2xf64>, but
// this signature declares TWO parameters of different element types and
// shapes (f32 5x4x5x1 and f64 3x4x5x2). The operand count/types do not line
// up with this declaration — possibly the incorrect op emission suspected in
// this issue. TODO confirm against the enzyme.batch calling convention.
func.func private @"unbatched_#7"(%arg0: tensor<5x4x5x1xf32> {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}, %arg1: tensor<3x4x5x2xf64> {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}) -> tensor<f64> attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
%cst = stablehlo.constant dense<0.000000e+00> : tensor<f64>
// Widen the trailing size-1 axis of %arg0 from 1 to 2.
%0 = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2, 3] : (tensor<5x4x5x1xf32>) -> tensor<5x4x5x2xf32>
// Elementwise f64 casts, expressed as batched scalar calls.
%1 = enzyme.batch @"Reactant.TracedUtils.TypeCast{Float64}()_broadcast_scalar"(%arg1) {batch_shape = array<i64: 3, 4, 5, 2>} : (tensor<3x4x5x2xf64>) -> tensor<3x4x5x2xf64>
%2:2 = enzyme.batch @"Reactant.TracedUtils.TypeCast{Float64}()_broadcast_scalar_1"(%0) {batch_shape = array<i64: 5, 4, 5, 2>} : (tensor<5x4x5x2xf32>) -> (tensor<5x4x5x2xf64>, tensor<5x4x5x2xf32>)
// Batched matmul: batch dims of sizes 2 and 5 (dims 3,2 on both sides),
// contracting the size-4 dim; result layout is [batch..., lhs-free, rhs-free].
%3 = stablehlo.dot_general %1, %2#0, batching_dims = [3, 2] x [3, 2], contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] : (tensor<3x4x5x2xf64>, tensor<5x4x5x2xf64>) -> tensor<2x5x3x5xf64>
%4 = stablehlo.transpose %3, dims = [2, 3, 1, 0] : (tensor<2x5x3x5xf64>) -> tensor<3x5x5x2xf64>
// Batched identity — has no effect on the values.
%5 = enzyme.batch @identity_broadcast_scalar(%4) {batch_shape = array<i64: 3, 5, 5, 2>} : (tensor<3x5x5x2xf64>) -> tensor<3x5x5x2xf64>
// Sum over every dimension, yielding a scalar.
%6 = stablehlo.reduce(%5 init: %cst) applies stablehlo.add across dimensions = [0, 1, 2, 3] : (tensor<3x5x5x2xf64>, tensor<f64>) -> tensor<f64>
return %6 : tensor<f64>
}
// Scalar f64 subtraction; returns the difference plus both original operands.
func.func private @"-_broadcast_scalar_1"(%arg0: tensor<f64> {enzymexla.memory_effects = []}, %arg1: tensor<f64> {enzymexla.memory_effects = []}) -> (tensor<f64>, tensor<f64>, tensor<f64>) attributes {enzymexla.memory_effects = []} {
%0 = stablehlo.subtract %arg0, %arg1 : tensor<f64>
return %0, %arg0, %arg1 : tensor<f64>, tensor<f64>, tensor<f64>
}
// Scalar f64 division; returns the quotient plus both original operands.
func.func private @"/_broadcast_scalar"(%arg0: tensor<f64> {enzymexla.memory_effects = []}, %arg1: tensor<f64> {enzymexla.memory_effects = []}) -> (tensor<f64>, tensor<f64>, tensor<f64>) attributes {enzymexla.memory_effects = []} {
%0 = stablehlo.divide %arg0, %arg1 : tensor<f64>
return %0, %arg0, %arg1 : tensor<f64>, tensor<f64>, tensor<f64>
}
// Duplicate of @"*_broadcast_scalar": f32 promoted to f64, multiplied by
// %arg1; returns the product plus both original operands.
func.func private @"*_broadcast_scalar_1"(%arg0: tensor<f32> {enzymexla.memory_effects = []}, %arg1: tensor<f64> {enzymexla.memory_effects = []}) -> (tensor<f64>, tensor<f32>, tensor<f64>) attributes {enzymexla.memory_effects = []} {
%0 = stablehlo.convert %arg0 : (tensor<f32>) -> tensor<f64>
%1 = stablehlo.multiply %0, %arg1 : tensor<f64>
return %1, %arg0, %arg1 : tensor<f64>, tensor<f32>, tensor<f64>
}
// Duplicate of @"+_broadcast_scalar": f32 promoted to f64, added to %arg1;
// returns the sum plus both original operands.
func.func private @"+_broadcast_scalar_1"(%arg0: tensor<f32> {enzymexla.memory_effects = []}, %arg1: tensor<f64> {enzymexla.memory_effects = []}) -> (tensor<f64>, tensor<f32>, tensor<f64>) attributes {enzymexla.memory_effects = []} {
%0 = stablehlo.convert %arg0 : (tensor<f32>) -> tensor<f64>
%1 = stablehlo.add %0, %arg1 : tensor<f64>
return %1, %arg0, %arg1 : tensor<f64>, tensor<f32>, tensor<f64>
}
// Duplicate of @"-_broadcast_scalar": f32 promoted to f64, %arg1 subtracted;
// returns the difference plus both original operands.
func.func private @"-_broadcast_scalar_2"(%arg0: tensor<f32> {enzymexla.memory_effects = []}, %arg1: tensor<f64> {enzymexla.memory_effects = []}) -> (tensor<f64>, tensor<f32>, tensor<f64>) attributes {enzymexla.memory_effects = []} {
%0 = stablehlo.convert %arg0 : (tensor<f32>) -> tensor<f64>
%1 = stablehlo.subtract %0, %arg1 : tensor<f64>
return %1, %arg0, %arg1 : tensor<f64>, tensor<f32>, tensor<f64>
}
// Duplicate of @"...TypeCast{Float64}()_broadcast_scalar_1": f32 -> f64
// convert, also returning the original f32 input.
func.func private @"Reactant.TracedUtils.TypeCast{Float64}()_broadcast_scalar_2"(%arg0: tensor<f32> {enzymexla.memory_effects = []}) -> (tensor<f64>, tensor<f32>) attributes {enzymexla.memory_effects = []} {
%0 = stablehlo.convert %arg0 : (tensor<f32>) -> tensor<f64>
return %0, %arg0 : tensor<f64>, tensor<f32>
}
// TypeCast{Float64} on an already-f64 operand: no-op pass-through.
func.func private @"Reactant.TracedUtils.TypeCast{Float64}()_broadcast_scalar_3"(%arg0: tensor<f64> {enzymexla.memory_effects = []}) -> tensor<f64> attributes {enzymexla.memory_effects = []} {
return %arg0 : tensor<f64>
}
// Duplicate of @identity_broadcast_scalar: returns its f64 input unchanged.
func.func private @identity_broadcast_scalar_1(%arg0: tensor<f64> {enzymexla.memory_effects = []}) -> tensor<f64> attributes {enzymexla.memory_effects = []} {
return %arg0 : tensor<f64>
}
// Mirror of @"unbatched_#7" with the argument roles swapped: %arg0 is the
// f32 3x4x5x2 operand and %arg1 the f64 5x4x5x1 operand. Computes
// sum(dot_general(cast(%arg0), broadcast(%arg1))) as a single f64 scalar.
// NOTE(review): as with @"unbatched_#7", the call site in @main
// (`enzyme.batch @"unbatched_#8"(%25) {batch_shape = [200]}`) passes ONE
// tensor<200x5x4x5x1xf64> operand to this TWO-parameter function — the
// operand count/types do not match this declaration. TODO confirm.
func.func private @"unbatched_#8"(%arg0: tensor<3x4x5x2xf32> {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}, %arg1: tensor<5x4x5x1xf64> {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}) -> tensor<f64> attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
%cst = stablehlo.constant dense<0.000000e+00> : tensor<f64>
// Widen the trailing size-1 axis of %arg1 from 1 to 2.
%0 = stablehlo.broadcast_in_dim %arg1, dims = [0, 1, 2, 3] : (tensor<5x4x5x1xf64>) -> tensor<5x4x5x2xf64>
// Elementwise f64 casts expressed as batched scalar calls.
%1:2 = enzyme.batch @"Reactant.TracedUtils.TypeCast{Float64}()_broadcast_scalar_2"(%arg0) {batch_shape = array<i64: 3, 4, 5, 2>} : (tensor<3x4x5x2xf32>) -> (tensor<3x4x5x2xf64>, tensor<3x4x5x2xf32>)
%2 = enzyme.batch @"Reactant.TracedUtils.TypeCast{Float64}()_broadcast_scalar_3"(%0) {batch_shape = array<i64: 5, 4, 5, 2>} : (tensor<5x4x5x2xf64>) -> tensor<5x4x5x2xf64>
// Batched matmul: batch dims of sizes 2 and 5, contracting the size-4 dim.
%3 = stablehlo.dot_general %1#0, %2, batching_dims = [3, 2] x [3, 2], contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] : (tensor<3x4x5x2xf64>, tensor<5x4x5x2xf64>) -> tensor<2x5x3x5xf64>
%4 = stablehlo.transpose %3, dims = [2, 3, 1, 0] : (tensor<2x5x3x5xf64>) -> tensor<3x5x5x2xf64>
// Batched identity — no effect on the values.
%5 = enzyme.batch @identity_broadcast_scalar_1(%4) {batch_shape = array<i64: 3, 5, 5, 2>} : (tensor<3x5x5x2xf64>) -> tensor<3x5x5x2xf64>
// Sum over every dimension, yielding a scalar.
%6 = stablehlo.reduce(%5 init: %cst) applies stablehlo.add across dimensions = [0, 1, 2, 3] : (tensor<3x5x5x2xf64>, tensor<f64>) -> tensor<f64>
return %6 : tensor<f64>
}
// Duplicate of @"-_broadcast_scalar_1": scalar f64 subtraction with primal
// pass-through of both operands.
func.func private @"-_broadcast_scalar_3"(%arg0: tensor<f64> {enzymexla.memory_effects = []}, %arg1: tensor<f64> {enzymexla.memory_effects = []}) -> (tensor<f64>, tensor<f64>, tensor<f64>) attributes {enzymexla.memory_effects = []} {
%0 = stablehlo.subtract %arg0, %arg1 : tensor<f64>
return %0, %arg0, %arg1 : tensor<f64>, tensor<f64>, tensor<f64>
}
// Duplicate of @"/_broadcast_scalar": scalar f64 division with primal
// pass-through of both operands.
func.func private @"/_broadcast_scalar_1"(%arg0: tensor<f64> {enzymexla.memory_effects = []}, %arg1: tensor<f64> {enzymexla.memory_effects = []}) -> (tensor<f64>, tensor<f64>, tensor<f64>) attributes {enzymexla.memory_effects = []} {
%0 = stablehlo.divide %arg0, %arg1 : tensor<f64>
return %0, %arg0, %arg1 : tensor<f64>, tensor<f64>, tensor<f64>
}
// Finite-difference driver (module @reactant_bmm_fd): perturbs every element
// of %arg0 (2x5x4x3, 120 elements) and %arg1 (1x5x4x5, 100 elements) by ±h
// and evaluates the scalar functions @"unbatched_#7" / @"unbatched_#8" on
// each perturbed input, producing central-difference derivative tensors of
// the inputs' shapes.
// NOTE(review): both enzyme.batch calls below (%10 and %26) pass a single
// operand while the callees declare two parameters — see the notes on
// @"unbatched_#7" / @"unbatched_#8"; this looks like the suspect emission.
func.func @main(%arg0: tensor<2x5x4x3xf32> {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}, %arg1: tensor<1x5x4x5xf32> {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}) -> (tensor<2x5x4x3xf64>, tensor<1x5x4x5xf64>) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
// Divisor and step constants. Note 1.2831061023768837e-5 is exactly
// 2 * 6.4155305118844185e-6, consistent with (f(x+h) - f(x-h)) / (2h).
%cst = stablehlo.constant dense<1.2831061023768837E-5> : tensor<100xf64>
%cst_0 = stablehlo.constant dense<6.4155305118844185E-6> : tensor<100x100xf64>
// Scatter indices: little-endian i64 pairs that appear to encode (i, i) for
// i = 0..99 — the diagonal positions of a 100x100 matrix. (The dense literal
// below was split across two lines by the page scrape.)
%c = stablehlo.constant dense<"0x000000000000000000000000000000000100000000000000010000000000000002000000000000000200000000000000030000000000000003000000000000000400000000000000040000000000000005000000000000000500000000000000060000000000000006000000000000000700000000000000070000000000000008000000000000000800000000000000090000000000000009000000000000000A000000000000000A000000000000000B000000000000000B000000000000000C000000000000000C000000000000000D000000000000000D000000000000000E000000000000000E000000000000000F000000000000000F00000000000000100000000000000010000000000000001100000000000000110000000000000012000000000000001200000000000000130000000000000013000000000000001400000000000000140000000000000015000000000000001500000000000000160000000000000016000000000000001700000000000000170000000000000018000000000000001800000000000000190000000000000019000000000000001A000000000000001A000000000000001B000000000000001B000000000000001C000000000000001C000000000000001D000000000000001D000000000000001E000000000000001E000000000000001F000000000000001F00000000000000200000000000000020000000000000002100000000000000210000000000000022000000000000002200000000000000230000000000000023000000000000002400000000000000240000000000000025000000000000002500000000000000260000000000000026000000000000002700000000000000270000000000000028000000000000002800000000000000290000000000000029000000000000002A000000000000002A000000000000002B000000000000002B000000000000002C000000000000002C000000000000002D000000000000002D000000000000002E000000000000002E000000000000002F000000000000002F00000000000000300000000000000030000000000000003100000000000000310000000000000032000000000000003200000000000000330000000000000033000000000000003400000000000000340000000000000035000000000000003500000000000000360000000000000036000000000000003700000000000000370000000000000038000000000000003800000000000000390000000000000039000000000000003A000000000000003A000000000000003B000000000000003B000000000000003C000000000000003C000000000000003D0000000000000
03D000000000000003E000000000000003E000000000000003F000000000000003F00000000000000400000000000000040000000000000004100000000000000410000000000000042000000000000004200000000000000430000000000000043000000000000004400000000000000440000000000000045000000000000004500000000000000460000000000000046000000000000004700000000000000470000000000000048000000000000004800000000000000490000000000000049000000000000004A000000000000004A000000000000004B000000000000004B000000000000004C000000000000004C000000000000004D000000000000004D000000000000004E000000000000004E000000000000004F000000000000004F00000000000000500000000000000050000000000000005100000000000000510000000000000052000000000000005200000000000000530000000000000053000000000000005400000000000000540000000000000055000000000000005500000000000000560000000000000056000000000000005700000000000000570000000000000058000000000000005800000000000000590000000000000059000000000000005A000000000000005A000000000000005B000000000000005B000000000000005C000000000000005C000000000000005D000000000000005D000000000000005E000000000000005E000000000000005F000000000000005F0000000000000060000000000000006000000000000000610000000000000061000000000000006200000000000000620000000000000063000000000000006300000000000000"> : tensor<100x2xi64>
%cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<100xf32>
// 120-element analogues of the constants above, for %arg0's 120 elements.
%cst_2 = stablehlo.constant dense<1.2831061023768837E-5> : tensor<120xf64>
%cst_3 = stablehlo.constant dense<6.4155305118844185E-6> : tensor<120x120xf64>
%cst_4 = stablehlo.constant dense<0.000000e+00> : tensor<100x100xf32>
%cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<120x120xf32>
%cst_6 = stablehlo.constant dense<1.000000e+00> : tensor<120xf32>
// Scatter indices encoding (i, i) for i = 0..119 (split by the scrape).
%c_7 = stablehlo.constant dense<"0x000000000000000000000000000000000100000000000000010000000000000002000000000000000200000000000000030000000000000003000000000000000400000000000000040000000000000005000000000000000500000000000000060000000000000006000000000000000700000000000000070000000000000008000000000000000800000000000000090000000000000009000000000000000A000000000000000A000000000000000B000000000000000B000000000000000C000000000000000C000000000000000D000000000000000D000000000000000E000000000000000E000000000000000F000000000000000F00000000000000100000000000000010000000000000001100000000000000110000000000000012000000000000001200000000000000130000000000000013000000000000001400000000000000140000000000000015000000000000001500000000000000160000000000000016000000000000001700000000000000170000000000000018000000000000001800000000000000190000000000000019000000000000001A000000000000001A000000000000001B000000000000001B000000000000001C000000000000001C000000000000001D000000000000001D000000000000001E000000000000001E000000000000001F000000000000001F00000000000000200000000000000020000000000000002100000000000000210000000000000022000000000000002200000000000000230000000000000023000000000000002400000000000000240000000000000025000000000000002500000000000000260000000000000026000000000000002700000000000000270000000000000028000000000000002800000000000000290000000000000029000000000000002A000000000000002A000000000000002B000000000000002B000000000000002C000000000000002C000000000000002D000000000000002D000000000000002E000000000000002E000000000000002F000000000000002F00000000000000300000000000000030000000000000003100000000000000310000000000000032000000000000003200000000000000330000000000000033000000000000003400000000000000340000000000000035000000000000003500000000000000360000000000000036000000000000003700000000000000370000000000000038000000000000003800000000000000390000000000000039000000000000003A000000000000003A000000000000003B000000000000003B000000000000003C000000000000003C000000000000003D00000000000
0003D000000000000003E000000000000003E000000000000003F000000000000003F00000000000000400000000000000040000000000000004100000000000000410000000000000042000000000000004200000000000000430000000000000043000000000000004400000000000000440000000000000045000000000000004500000000000000460000000000000046000000000000004700000000000000470000000000000048000000000000004800000000000000490000000000000049000000000000004A000000000000004A000000000000004B000000000000004B000000000000004C000000000000004C000000000000004D000000000000004D000000000000004E000000000000004E000000000000004F000000000000004F00000000000000500000000000000050000000000000005100000000000000510000000000000052000000000000005200000000000000530000000000000053000000000000005400000000000000540000000000000055000000000000005500000000000000560000000000000056000000000000005700000000000000570000000000000058000000000000005800000000000000590000000000000059000000000000005A000000000000005A000000000000005B000000000000005B000000000000005C000000000000005C000000000000005D000000000000005D000000000000005E000000000000005E000000000000005F000000000000005F00000000000000600000000000000060000000000000006100000000000000610000000000000062000000000000006200000000000000630000000000000063000000000000006400000000000000640000000000000065000000000000006500000000000000660000000000000066000000000000006700000000000000670000000000000068000000000000006800000000000000690000000000000069000000000000006A000000000000006A000000000000006B000000000000006B000000000000006C000000000000006C000000000000006D000000000000006D000000000000006E000000000000006E000000000000006F000000000000006F000000000000007000000000000000700000000000000071000000000000007100000000000000720000000000000072000000000000007300000000000000730000000000000074000000000000007400000000000000750000000000000075000000000000007600000000000000760000000000000077000000000000007700000000000000"> : tensor<120x2xi64>
%cst_8 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
// Scatter 1.0 at each (i, i) position -> 120x120 identity-like matrix whose
// columns are the perturbation basis vectors e_i.
%0 = "stablehlo.scatter"(%cst_5, %c_7, %cst_6) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = true}> ({
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
stablehlo.return %cst_8 : tensor<f32>
}) : (tensor<120x120xf32>, tensor<120x2xi64>, tensor<120xf32>) -> tensor<120x120xf32>
// Scale the basis by the step size h -> h * e_i per column.
%1:3 = enzyme.batch @"*_broadcast_scalar"(%0, %cst_3) {batch_shape = array<i64: 120, 120>} : (tensor<120x120xf32>, tensor<120x120xf64>) -> (tensor<120x120xf64>, tensor<120x120xf32>, tensor<120x120xf64>)
// Replicate %arg0 across a trailing 120 axis (one copy per perturbed
// element); note broadcast_in_dim with dims = [3, 2, 1, 0] also reverses
// the axis order.
%2 = stablehlo.broadcast_in_dim %arg0, dims = [3, 2, 1, 0] : (tensor<2x5x4x3xf32>) -> tensor<3x4x5x2x120xf32>
// Reshape each perturbation column to the input's (reversed) shape.
%3 = stablehlo.transpose %1#0, dims = [1, 0] : (tensor<120x120xf64>) -> tensor<120x120xf64>
%4 = stablehlo.reshape %3 : (tensor<120x120xf64>) -> tensor<120x2x5x4x3xf64>
%5 = stablehlo.transpose %4, dims = [4, 3, 2, 1, 0] : (tensor<120x2x5x4x3xf64>) -> tensor<3x4x5x2x120xf64>
// x + h*e_i and x - h*e_i for every basis direction i.
%6:3 = enzyme.batch @"+_broadcast_scalar"(%2, %5) {batch_shape = array<i64: 3, 4, 5, 2, 120>} : (tensor<3x4x5x2x120xf32>, tensor<3x4x5x2x120xf64>) -> (tensor<3x4x5x2x120xf64>, tensor<3x4x5x2x120xf32>, tensor<3x4x5x2x120xf64>)
%7:3 = enzyme.batch @"-_broadcast_scalar"(%2, %5) {batch_shape = array<i64: 3, 4, 5, 2, 120>} : (tensor<3x4x5x2x120xf32>, tensor<3x4x5x2x120xf64>) -> (tensor<3x4x5x2x120xf64>, tensor<3x4x5x2x120xf32>, tensor<3x4x5x2x120xf64>)
// Stack the +h and -h batches (240 = 2 * 120) and move the batch axis first.
%8 = stablehlo.concatenate %6#0, %7#0, dim = 4 : (tensor<3x4x5x2x120xf64>, tensor<3x4x5x2x120xf64>) -> tensor<3x4x5x2x240xf64>
%9 = stablehlo.transpose %8, dims = [4, 0, 1, 2, 3] : (tensor<3x4x5x2x240xf64>) -> tensor<240x3x4x5x2xf64>
// NOTE(review): single operand passed to the two-parameter @"unbatched_#7"
// — operand count/types do not match the callee; TODO confirm.
%10 = enzyme.batch @"unbatched_#7"(%9) {batch_shape = array<i64: 240>} : (tensor<240x3x4x5x2xf64>) -> tensor<240xf64>
// Central difference: (f(x+h) - f(x-h)) / (2h), reshaped to %arg0's shape.
%11 = stablehlo.slice %10 [0:120] : (tensor<240xf64>) -> tensor<120xf64>
%12 = stablehlo.slice %10 [120:240] : (tensor<240xf64>) -> tensor<120xf64>
%13:3 = enzyme.batch @"-_broadcast_scalar_1"(%11, %12) {batch_shape = array<i64: 120>} : (tensor<120xf64>, tensor<120xf64>) -> (tensor<120xf64>, tensor<120xf64>, tensor<120xf64>)
%14:3 = enzyme.batch @"/_broadcast_scalar"(%13#0, %cst_2) {batch_shape = array<i64: 120>} : (tensor<120xf64>, tensor<120xf64>) -> (tensor<120xf64>, tensor<120xf64>, tensor<120xf64>)
%15 = stablehlo.reshape %14#0 : (tensor<120xf64>) -> tensor<2x5x4x3xf64>
// Same construction for %arg1 with 100 perturbation directions and
// @"unbatched_#8" (200 = 2 * 100 batch evaluations).
%16 = "stablehlo.scatter"(%cst_4, %c, %cst_1) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = true}> ({
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
stablehlo.return %cst_8 : tensor<f32>
}) : (tensor<100x100xf32>, tensor<100x2xi64>, tensor<100xf32>) -> tensor<100x100xf32>
%17:3 = enzyme.batch @"*_broadcast_scalar_1"(%16, %cst_0) {batch_shape = array<i64: 100, 100>} : (tensor<100x100xf32>, tensor<100x100xf64>) -> (tensor<100x100xf64>, tensor<100x100xf32>, tensor<100x100xf64>)
%18 = stablehlo.broadcast_in_dim %arg1, dims = [3, 2, 1, 0] : (tensor<1x5x4x5xf32>) -> tensor<5x4x5x1x100xf32>
%19 = stablehlo.transpose %17#0, dims = [1, 0] : (tensor<100x100xf64>) -> tensor<100x100xf64>
%20 = stablehlo.reshape %19 : (tensor<100x100xf64>) -> tensor<100x1x5x4x5xf64>
%21 = stablehlo.transpose %20, dims = [4, 3, 2, 1, 0] : (tensor<100x1x5x4x5xf64>) -> tensor<5x4x5x1x100xf64>
%22:3 = enzyme.batch @"+_broadcast_scalar_1"(%18, %21) {batch_shape = array<i64: 5, 4, 5, 1, 100>} : (tensor<5x4x5x1x100xf32>, tensor<5x4x5x1x100xf64>) -> (tensor<5x4x5x1x100xf64>, tensor<5x4x5x1x100xf32>, tensor<5x4x5x1x100xf64>)
%23:3 = enzyme.batch @"-_broadcast_scalar_2"(%18, %21) {batch_shape = array<i64: 5, 4, 5, 1, 100>} : (tensor<5x4x5x1x100xf32>, tensor<5x4x5x1x100xf64>) -> (tensor<5x4x5x1x100xf64>, tensor<5x4x5x1x100xf32>, tensor<5x4x5x1x100xf64>)
%24 = stablehlo.concatenate %22#0, %23#0, dim = 4 : (tensor<5x4x5x1x100xf64>, tensor<5x4x5x1x100xf64>) -> tensor<5x4x5x1x200xf64>
%25 = stablehlo.transpose %24, dims = [4, 0, 1, 2, 3] : (tensor<5x4x5x1x200xf64>) -> tensor<200x5x4x5x1xf64>
// NOTE(review): single operand passed to the two-parameter @"unbatched_#8"
// — same mismatch as above; TODO confirm.
%26 = enzyme.batch @"unbatched_#8"(%25) {batch_shape = array<i64: 200>} : (tensor<200x5x4x5x1xf64>) -> tensor<200xf64>
%27 = stablehlo.slice %26 [0:100] : (tensor<200xf64>) -> tensor<100xf64>
%28 = stablehlo.slice %26 [100:200] : (tensor<200xf64>) -> tensor<100xf64>
%29:3 = enzyme.batch @"-_broadcast_scalar_3"(%27, %28) {batch_shape = array<i64: 100>} : (tensor<100xf64>, tensor<100xf64>) -> (tensor<100xf64>, tensor<100xf64>, tensor<100xf64>)
%30:3 = enzyme.batch @"/_broadcast_scalar_1"(%29#0, %cst) {batch_shape = array<i64: 100>} : (tensor<100xf64>, tensor<100xf64>) -> (tensor<100xf64>, tensor<100xf64>, tensor<100xf64>)
%31 = stablehlo.reshape %30#0 : (tensor<100xf64>) -> tensor<1x5x4x5xf64>
// Return the two derivative tensors, one per input.
return %15, %31 : tensor<2x5x4x3xf64>, tensor<1x5x4x5xf64>
}
}
Metadata
Metadata
Assignees
Labels
No labels