-
Notifications
You must be signed in to change notification settings - Fork 38
Closed
Description
Opening it here because there is a possibility that we are emitting the ops incorrectly.
module @reactant_bmm_fd attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
// Scalar kernel for `*` on mixed precision: widens the f32 operand to f64, then multiplies.
// Results: (product, original f32 operand, original f64 operand).
func.func private @"*_broadcast_scalar"(
    %arg0: tensor<f32> {enzymexla.memory_effects = []},
    %arg1: tensor<f64> {enzymexla.memory_effects = []}
) -> (tensor<f64>, tensor<f32>, tensor<f64>)
    attributes {enzymexla.memory_effects = []} {
  // Promote the f32 side so the multiply runs entirely in f64.
  %lhs_f64 = stablehlo.convert %arg0 : (tensor<f32>) -> tensor<f64>
  %product = stablehlo.multiply %lhs_f64, %arg1 : tensor<f64>
  return %product, %arg0, %arg1 : tensor<f64>, tensor<f32>, tensor<f64>
}
// Scalar kernel for `+` on mixed precision: widens the f32 operand to f64, then adds.
// Results: (sum, original f32 operand, original f64 operand).
func.func private @"+_broadcast_scalar"(
    %arg0: tensor<f32> {enzymexla.memory_effects = []},
    %arg1: tensor<f64> {enzymexla.memory_effects = []}
) -> (tensor<f64>, tensor<f32>, tensor<f64>)
    attributes {enzymexla.memory_effects = []} {
  // Promote the f32 side so the addition runs entirely in f64.
  %lhs_f64 = stablehlo.convert %arg0 : (tensor<f32>) -> tensor<f64>
  %sum = stablehlo.add %lhs_f64, %arg1 : tensor<f64>
  return %sum, %arg0, %arg1 : tensor<f64>, tensor<f32>, tensor<f64>
}
// Scalar kernel for `-` on mixed precision: widens the f32 operand to f64, then subtracts
// the f64 operand from it.
// Results: (difference, original f32 operand, original f64 operand).
func.func private @"-_broadcast_scalar"(
    %arg0: tensor<f32> {enzymexla.memory_effects = []},
    %arg1: tensor<f64> {enzymexla.memory_effects = []}
) -> (tensor<f64>, tensor<f32>, tensor<f64>)
    attributes {enzymexla.memory_effects = []} {
  // Promote the f32 side so the subtraction runs entirely in f64.
  %lhs_f64 = stablehlo.convert %arg0 : (tensor<f32>) -> tensor<f64>
  %difference = stablehlo.subtract %lhs_f64, %arg1 : tensor<f64>
  return %difference, %arg0, %arg1 : tensor<f64>, tensor<f32>, tensor<f64>
}
// TypeCast{Float64} scalar kernel for an operand that is already f64: the input is
// returned unchanged, no conversion op is emitted.
func.func private @"Reactant.TracedUtils.TypeCast{Float64}()_broadcast_scalar"(
    %arg0: tensor<f64> {enzymexla.memory_effects = []}
) -> tensor<f64>
    attributes {enzymexla.memory_effects = []} {
  return %arg0 : tensor<f64>
}
// TypeCast{Float64} scalar kernel for an f32 operand.
// Results: (operand converted to f64, original f32 operand).
func.func private @"Reactant.TracedUtils.TypeCast{Float64}()_broadcast_scalar_1"(
    %arg0: tensor<f32> {enzymexla.memory_effects = []}
) -> (tensor<f64>, tensor<f32>)
    attributes {enzymexla.memory_effects = []} {
  %widened = stablehlo.convert %arg0 : (tensor<f32>) -> tensor<f64>
  return %widened, %arg0 : tensor<f64>, tensor<f32>
}
// Scalar identity kernel: returns its f64 argument unchanged.
func.func private @identity_broadcast_scalar(
    %arg0: tensor<f64> {enzymexla.memory_effects = []}
) -> tensor<f64>
    attributes {enzymexla.memory_effects = []} {
  return %arg0 : tensor<f64>
}
// Un-batched function body: casts both operands to f64, contracts them with a batched
// dot_general, and sums every element of the result into a single f64 scalar.
//   %arg0: 5x4x5x1 f32 operand (its size-1 last axis is broadcast to 2)
//   %arg1: 3x4x5x2 f64 operand
func.func private @"unbatched_#7"(
    %arg0: tensor<5x4x5x1xf32> {enzymexla.memory_effects = ["read", "write", "allocate", "free"]},
    %arg1: tensor<3x4x5x2xf64> {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}
) -> tensor<f64>
    attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
  // Expand the size-1 last axis of %arg0 so it matches the last axis of %arg1.
  %rhs_wide = stablehlo.broadcast_in_dim %arg0, dims = [0, 1, 2, 3] : (tensor<5x4x5x1xf32>) -> tensor<5x4x5x2xf32>
  // Element-wise TypeCast{Float64} on each operand.
  %lhs_f64 = enzyme.batch @"Reactant.TracedUtils.TypeCast{Float64}()_broadcast_scalar"(%arg1) {batch_shape = array<i64: 3, 4, 5, 2>} : (tensor<3x4x5x2xf64>) -> tensor<3x4x5x2xf64>
  %rhs_cast:2 = enzyme.batch @"Reactant.TracedUtils.TypeCast{Float64}()_broadcast_scalar_1"(%rhs_wide) {batch_shape = array<i64: 5, 4, 5, 2>} : (tensor<5x4x5x2xf32>) -> (tensor<5x4x5x2xf64>, tensor<5x4x5x2xf32>)
  // Batch over axes 3 and 2, contract axis 1; result axes are (batch 2, batch 5, lhs 3, rhs 5).
  %contracted = stablehlo.dot_general %lhs_f64, %rhs_cast#0, batching_dims = [3, 2] x [3, 2], contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] : (tensor<3x4x5x2xf64>, tensor<5x4x5x2xf64>) -> tensor<2x5x3x5xf64>
  // Move the batch axes to the back: 2x5x3x5 -> 3x5x5x2.
  %reordered = stablehlo.transpose %contracted, dims = [2, 3, 1, 0] : (tensor<2x5x3x5xf64>) -> tensor<3x5x5x2xf64>
  %passed = enzyme.batch @identity_broadcast_scalar(%reordered) {batch_shape = array<i64: 3, 5, 5, 2>} : (tensor<3x5x5x2xf64>) -> tensor<3x5x5x2xf64>
  // Sum over all four axes, starting from 0.0.
  %zero = stablehlo.constant dense<0.000000e+00> : tensor<f64>
  %total = stablehlo.reduce(%passed init: %zero) applies stablehlo.add across dimensions = [0, 1, 2, 3] : (tensor<3x5x5x2xf64>, tensor<f64>) -> tensor<f64>
  return %total : tensor<f64>
}
// Scalar kernel for `-` on two f64 operands.
// Results: (%arg0 - %arg1, original %arg0, original %arg1).
func.func private @"-_broadcast_scalar_1"(
    %arg0: tensor<f64> {enzymexla.memory_effects = []},
    %arg1: tensor<f64> {enzymexla.memory_effects = []}
) -> (tensor<f64>, tensor<f64>, tensor<f64>)
    attributes {enzymexla.memory_effects = []} {
  %difference = stablehlo.subtract %arg0, %arg1 : tensor<f64>
  return %difference, %arg0, %arg1 : tensor<f64>, tensor<f64>, tensor<f64>
}
// Scalar kernel for `/` on two f64 operands.
// Results: (%arg0 / %arg1, original %arg0, original %arg1).
func.func private @"/_broadcast_scalar"(
    %arg0: tensor<f64> {enzymexla.memory_effects = []},
    %arg1: tensor<f64> {enzymexla.memory_effects = []}
) -> (tensor<f64>, tensor<f64>, tensor<f64>)
    attributes {enzymexla.memory_effects = []} {
  %quotient = stablehlo.divide %arg0, %arg1 : tensor<f64>
  return %quotient, %arg0, %arg1 : tensor<f64>, tensor<f64>, tensor<f64>
}
// Scalar kernel for `*` on mixed precision: widens the f32 operand to f64, then multiplies.
// Results: (product, original f32 operand, original f64 operand).
func.func private @"*_broadcast_scalar_1"(
    %arg0: tensor<f32> {enzymexla.memory_effects = []},
    %arg1: tensor<f64> {enzymexla.memory_effects = []}
) -> (tensor<f64>, tensor<f32>, tensor<f64>)
    attributes {enzymexla.memory_effects = []} {
  // Promote the f32 side so the multiply runs entirely in f64.
  %lhs_f64 = stablehlo.convert %arg0 : (tensor<f32>) -> tensor<f64>
  %product = stablehlo.multiply %lhs_f64, %arg1 : tensor<f64>
  return %product, %arg0, %arg1 : tensor<f64>, tensor<f32>, tensor<f64>
}
// Scalar kernel for `+` on mixed precision: widens the f32 operand to f64, then adds.
// Results: (sum, original f32 operand, original f64 operand).
func.func private @"+_broadcast_scalar_1"(
    %arg0: tensor<f32> {enzymexla.memory_effects = []},
    %arg1: tensor<f64> {enzymexla.memory_effects = []}
) -> (tensor<f64>, tensor<f32>, tensor<f64>)
    attributes {enzymexla.memory_effects = []} {
  // Promote the f32 side so the addition runs entirely in f64.
  %lhs_f64 = stablehlo.convert %arg0 : (tensor<f32>) -> tensor<f64>
  %sum = stablehlo.add %lhs_f64, %arg1 : tensor<f64>
  return %sum, %arg0, %arg1 : tensor<f64>, tensor<f32>, tensor<f64>
}
// Scalar kernel for `-` on mixed precision: widens the f32 operand to f64, then subtracts
// the f64 operand from it.
// Results: (difference, original f32 operand, original f64 operand).
func.func private @"-_broadcast_scalar_2"(
    %arg0: tensor<f32> {enzymexla.memory_effects = []},
    %arg1: tensor<f64> {enzymexla.memory_effects = []}
) -> (tensor<f64>, tensor<f32>, tensor<f64>)
    attributes {enzymexla.memory_effects = []} {
  // Promote the f32 side so the subtraction runs entirely in f64.
  %lhs_f64 = stablehlo.convert %arg0 : (tensor<f32>) -> tensor<f64>
  %difference = stablehlo.subtract %lhs_f64, %arg1 : tensor<f64>
  return %difference, %arg0, %arg1 : tensor<f64>, tensor<f32>, tensor<f64>
}
// TypeCast{Float64} scalar kernel for an f32 operand.
// Results: (operand converted to f64, original f32 operand).
func.func private @"Reactant.TracedUtils.TypeCast{Float64}()_broadcast_scalar_2"(
    %arg0: tensor<f32> {enzymexla.memory_effects = []}
) -> (tensor<f64>, tensor<f32>)
    attributes {enzymexla.memory_effects = []} {
  %widened = stablehlo.convert %arg0 : (tensor<f32>) -> tensor<f64>
  return %widened, %arg0 : tensor<f64>, tensor<f32>
}
// TypeCast{Float64} scalar kernel for an operand that is already f64: the input is
// returned unchanged, no conversion op is emitted.
func.func private @"Reactant.TracedUtils.TypeCast{Float64}()_broadcast_scalar_3"(
    %arg0: tensor<f64> {enzymexla.memory_effects = []}
) -> tensor<f64>
    attributes {enzymexla.memory_effects = []} {
  return %arg0 : tensor<f64>
}
// Scalar identity kernel: returns its f64 argument unchanged.
func.func private @identity_broadcast_scalar_1(
    %arg0: tensor<f64> {enzymexla.memory_effects = []}
) -> tensor<f64>
    attributes {enzymexla.memory_effects = []} {
  return %arg0 : tensor<f64>
}
// Un-batched function body: casts both operands to f64, contracts them with a batched
// dot_general, and sums every element of the result into a single f64 scalar.
//   %arg0: 3x4x5x2 f32 operand
//   %arg1: 5x4x5x1 f64 operand (its size-1 last axis is broadcast to 2)
func.func private @"unbatched_#8"(
    %arg0: tensor<3x4x5x2xf32> {enzymexla.memory_effects = ["read", "write", "allocate", "free"]},
    %arg1: tensor<5x4x5x1xf64> {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}
) -> tensor<f64>
    attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
  // Expand the size-1 last axis of %arg1 so it matches the last axis of %arg0.
  %rhs_wide = stablehlo.broadcast_in_dim %arg1, dims = [0, 1, 2, 3] : (tensor<5x4x5x1xf64>) -> tensor<5x4x5x2xf64>
  // Element-wise TypeCast{Float64} on each operand.
  %lhs_cast:2 = enzyme.batch @"Reactant.TracedUtils.TypeCast{Float64}()_broadcast_scalar_2"(%arg0) {batch_shape = array<i64: 3, 4, 5, 2>} : (tensor<3x4x5x2xf32>) -> (tensor<3x4x5x2xf64>, tensor<3x4x5x2xf32>)
  %rhs_f64 = enzyme.batch @"Reactant.TracedUtils.TypeCast{Float64}()_broadcast_scalar_3"(%rhs_wide) {batch_shape = array<i64: 5, 4, 5, 2>} : (tensor<5x4x5x2xf64>) -> tensor<5x4x5x2xf64>
  // Batch over axes 3 and 2, contract axis 1; result axes are (batch 2, batch 5, lhs 3, rhs 5).
  %contracted = stablehlo.dot_general %lhs_cast#0, %rhs_f64, batching_dims = [3, 2] x [3, 2], contracting_dims = [1] x [1], precision = [DEFAULT, DEFAULT] : (tensor<3x4x5x2xf64>, tensor<5x4x5x2xf64>) -> tensor<2x5x3x5xf64>
  // Move the batch axes to the back: 2x5x3x5 -> 3x5x5x2.
  %reordered = stablehlo.transpose %contracted, dims = [2, 3, 1, 0] : (tensor<2x5x3x5xf64>) -> tensor<3x5x5x2xf64>
  %passed = enzyme.batch @identity_broadcast_scalar_1(%reordered) {batch_shape = array<i64: 3, 5, 5, 2>} : (tensor<3x5x5x2xf64>) -> tensor<3x5x5x2xf64>
  // Sum over all four axes, starting from 0.0.
  %zero = stablehlo.constant dense<0.000000e+00> : tensor<f64>
  %total = stablehlo.reduce(%passed init: %zero) applies stablehlo.add across dimensions = [0, 1, 2, 3] : (tensor<3x5x5x2xf64>, tensor<f64>) -> tensor<f64>
  return %total : tensor<f64>
}
// Scalar kernel for `-` on two f64 operands.
// Results: (%arg0 - %arg1, original %arg0, original %arg1).
func.func private @"-_broadcast_scalar_3"(
    %arg0: tensor<f64> {enzymexla.memory_effects = []},
    %arg1: tensor<f64> {enzymexla.memory_effects = []}
) -> (tensor<f64>, tensor<f64>, tensor<f64>)
    attributes {enzymexla.memory_effects = []} {
  %difference = stablehlo.subtract %arg0, %arg1 : tensor<f64>
  return %difference, %arg0, %arg1 : tensor<f64>, tensor<f64>, tensor<f64>
}
// Scalar kernel for `/` on two f64 operands.
// Results: (%arg0 / %arg1, original %arg0, original %arg1).
func.func private @"/_broadcast_scalar_1"(
    %arg0: tensor<f64> {enzymexla.memory_effects = []},
    %arg1: tensor<f64> {enzymexla.memory_effects = []}
) -> (tensor<f64>, tensor<f64>, tensor<f64>)
    attributes {enzymexla.memory_effects = []} {
  %quotient = stablehlo.divide %arg0, %arg1 : tensor<f64>
  return %quotient, %arg0, %arg1 : tensor<f64>, tensor<f64>, tensor<f64>
}
// Entry point: maps two f32 inputs to f64 outputs of the same shapes (%15 for %arg0, %31 for %arg1).
// Both halves of the body run the same pipeline: scatter 1.0 into a zero square matrix, scale it,
// add it to / subtract it from a broadcast copy of the input, evaluate an un-batched function over
// the stacked copies, subtract the two halves of the result and divide by a constant.
// NOTE(review): this looks like a central finite difference (the divisors %cst / %cst_2 are twice
// the scale factors %cst_0 / %cst_3) — confirm against the code that generated this module.
// NOTE(review): the literals for %c and %c_7 below are cut off in this paste (no hex payload or
// closing quote), so the module cannot be parsed as shown.
func.func @main(%arg0: tensor<2x5x4x3xf32> {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}, %arg1: tensor<1x5x4x5xf32> {enzymexla.memory_effects = ["read", "write", "allocate", "free"]}) -> (tensor<2x5x4x3xf64>, tensor<1x5x4x5xf64>) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
%cst = stablehlo.constant dense<1.2831061023768837E-5> : tensor<100xf64>
%cst_0 = stablehlo.constant dense<6.4155305118844185E-6> : tensor<100x100xf64>
%c = stablehlo.constant dense<"0xtensor<100x2xi64>
%cst_1 = stablehlo.constant dense<1.000000e+00> : tensor<100xf32>
%cst_2 = stablehlo.constant dense<1.2831061023768837E-5> : tensor<120xf64>
%cst_3 = stablehlo.constant dense<6.4155305118844185E-6> : tensor<120x120xf64>
%cst_4 = stablehlo.constant dense<0.000000e+00> : tensor<100x100xf32>
%cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<120x120xf32>
%cst_6 = stablehlo.constant dense<1.000000e+00> : tensor<120xf32>
%c_7 = stablehlo.constant dense<"0xtensor<120x2xi64>
%cst_8 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
// ---- First output: derived from %arg0 (2x5x4x3, 120 elements) ----
// Scatter into the 120x120 zero matrix at the index pairs held in %c_7. The update region
// ignores both of its block arguments and always yields %cst_8 (1.0).
%0 = "stablehlo.scatter"(%cst_5, %c_7, %cst_6) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = true}> ({
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
stablehlo.return %cst_8 : tensor<f32>
}) : (tensor<120x120xf32>, tensor<120x2xi64>, tensor<120xf32>) -> tensor<120x120xf32>
// Element-wise product of the scattered matrix with the constant scale matrix %cst_3.
%1:3 = enzyme.batch @"*_broadcast_scalar"(%0, %cst_3) {batch_shape = array<i64: 120, 120>} : (tensor<120x120xf32>, tensor<120x120xf64>) -> (tensor<120x120xf64>, tensor<120x120xf32>, tensor<120x120xf64>)
// Replicate %arg0 along a new trailing axis of size 120, with its four axes reversed.
%2 = stablehlo.broadcast_in_dim %arg0, dims = [3, 2, 1, 0] : (tensor<2x5x4x3xf32>) -> tensor<3x4x5x2x120xf32>
// Bring the scaled matrix into the same 3x4x5x2x120 layout as %2.
%3 = stablehlo.transpose %1#0, dims = [1, 0] : (tensor<120x120xf64>) -> tensor<120x120xf64>
%4 = stablehlo.reshape %3 : (tensor<120x120xf64>) -> tensor<120x2x5x4x3xf64>
%5 = stablehlo.transpose %4, dims = [4, 3, 2, 1, 0] : (tensor<120x2x5x4x3xf64>) -> tensor<3x4x5x2x120xf64>
// Element-wise input + matrix and input - matrix.
%6:3 = enzyme.batch @"+_broadcast_scalar"(%2, %5) {batch_shape = array<i64: 3, 4, 5, 2, 120>} : (tensor<3x4x5x2x120xf32>, tensor<3x4x5x2x120xf64>) -> (tensor<3x4x5x2x120xf64>, tensor<3x4x5x2x120xf32>, tensor<3x4x5x2x120xf64>)
%7:3 = enzyme.batch @"-_broadcast_scalar"(%2, %5) {batch_shape = array<i64: 3, 4, 5, 2, 120>} : (tensor<3x4x5x2x120xf32>, tensor<3x4x5x2x120xf64>) -> (tensor<3x4x5x2x120xf64>, tensor<3x4x5x2x120xf32>, tensor<3x4x5x2x120xf64>)
// Stack the plus and minus variants along the last axis (120 + 120 = 240), then move that axis to the front.
%8 = stablehlo.concatenate %6#0, %7#0, dim = 4 : (tensor<3x4x5x2x120xf64>, tensor<3x4x5x2x120xf64>) -> tensor<3x4x5x2x240xf64>
%9 = stablehlo.transpose %8, dims = [4, 0, 1, 2, 3] : (tensor<3x4x5x2x240xf64>) -> tensor<240x3x4x5x2xf64>
// NOTE(review): unbatched_#7 is declared with two parameters (tensor<5x4x5x1xf32>,
// tensor<3x4x5x2xf64>) but this call passes a single operand. With the leading batch axis of
// 240 removed, %9 has the type of the second parameter; nothing is supplied for the first
// (f32) parameter. This may be the incorrectly emitted op the report refers to — confirm.
%10 = enzyme.batch @"unbatched_#7"(%9) {batch_shape = array<i64: 240>} : (tensor<240x3x4x5x2xf64>) -> tensor<240xf64>
// Entries 0..119 come from the added copies, 120..239 from the subtracted copies.
%11 = stablehlo.slice %10 [0:120] : (tensor<240xf64>) -> tensor<120xf64>
%12 = stablehlo.slice %10 [120:240] : (tensor<240xf64>) -> tensor<120xf64>
// Element-wise (%11 - %12) / %cst_2, reshaped back to the shape of %arg0.
%13:3 = enzyme.batch @"-_broadcast_scalar_1"(%11, %12) {batch_shape = array<i64: 120>} : (tensor<120xf64>, tensor<120xf64>) -> (tensor<120xf64>, tensor<120xf64>, tensor<120xf64>)
%14:3 = enzyme.batch @"/_broadcast_scalar"(%13#0, %cst_2) {batch_shape = array<i64: 120>} : (tensor<120xf64>, tensor<120xf64>) -> (tensor<120xf64>, tensor<120xf64>, tensor<120xf64>)
%15 = stablehlo.reshape %14#0 : (tensor<120xf64>) -> tensor<2x5x4x3xf64>
// ---- Second output: derived from %arg1 (1x5x4x5, 100 elements); same steps with 100 in place of 120 ----
// Scatter into the 100x100 zero matrix at the index pairs held in %c; the update region again
// always yields %cst_8 (1.0).
%16 = "stablehlo.scatter"(%cst_4, %c, %cst_1) <{indices_are_sorted = false, scatter_dimension_numbers = #stablehlo.scatter<inserted_window_dims = [0, 1], scatter_dims_to_operand_dims = [0, 1], index_vector_dim = 1>, unique_indices = true}> ({
^bb0(%arg2: tensor<f32>, %arg3: tensor<f32>):
stablehlo.return %cst_8 : tensor<f32>
}) : (tensor<100x100xf32>, tensor<100x2xi64>, tensor<100xf32>) -> tensor<100x100xf32>
// Element-wise product of the scattered matrix with the constant scale matrix %cst_0.
%17:3 = enzyme.batch @"*_broadcast_scalar_1"(%16, %cst_0) {batch_shape = array<i64: 100, 100>} : (tensor<100x100xf32>, tensor<100x100xf64>) -> (tensor<100x100xf64>, tensor<100x100xf32>, tensor<100x100xf64>)
// Replicate %arg1 along a new trailing axis of size 100, with its four axes reversed.
%18 = stablehlo.broadcast_in_dim %arg1, dims = [3, 2, 1, 0] : (tensor<1x5x4x5xf32>) -> tensor<5x4x5x1x100xf32>
// Bring the scaled matrix into the same 5x4x5x1x100 layout as %18.
%19 = stablehlo.transpose %17#0, dims = [1, 0] : (tensor<100x100xf64>) -> tensor<100x100xf64>
%20 = stablehlo.reshape %19 : (tensor<100x100xf64>) -> tensor<100x1x5x4x5xf64>
%21 = stablehlo.transpose %20, dims = [4, 3, 2, 1, 0] : (tensor<100x1x5x4x5xf64>) -> tensor<5x4x5x1x100xf64>
// Element-wise input + matrix and input - matrix.
%22:3 = enzyme.batch @"+_broadcast_scalar_1"(%18, %21) {batch_shape = array<i64: 5, 4, 5, 1, 100>} : (tensor<5x4x5x1x100xf32>, tensor<5x4x5x1x100xf64>) -> (tensor<5x4x5x1x100xf64>, tensor<5x4x5x1x100xf32>, tensor<5x4x5x1x100xf64>)
%23:3 = enzyme.batch @"-_broadcast_scalar_2"(%18, %21) {batch_shape = array<i64: 5, 4, 5, 1, 100>} : (tensor<5x4x5x1x100xf32>, tensor<5x4x5x1x100xf64>) -> (tensor<5x4x5x1x100xf64>, tensor<5x4x5x1x100xf32>, tensor<5x4x5x1x100xf64>)
// Stack the plus and minus variants along the last axis (100 + 100 = 200), then move that axis to the front.
%24 = stablehlo.concatenate %22#0, %23#0, dim = 4 : (tensor<5x4x5x1x100xf64>, tensor<5x4x5x1x100xf64>) -> tensor<5x4x5x1x200xf64>
%25 = stablehlo.transpose %24, dims = [4, 0, 1, 2, 3] : (tensor<5x4x5x1x200xf64>) -> tensor<200x5x4x5x1xf64>
// NOTE(review): unbatched_#8 is declared with two parameters (tensor<3x4x5x2xf32>,
// tensor<5x4x5x1xf64>) but this call passes a single operand. With the leading batch axis of
// 200 removed, %25 has the type of the second parameter; nothing is supplied for the first
// (f32) parameter. Same pattern as the unbatched_#7 call — confirm.
%26 = enzyme.batch @"unbatched_#8"(%25) {batch_shape = array<i64: 200>} : (tensor<200x5x4x5x1xf64>) -> tensor<200xf64>
// Entries 0..99 come from the added copies, 100..199 from the subtracted copies.
%27 = stablehlo.slice %26 [0:100] : (tensor<200xf64>) -> tensor<100xf64>
%28 = stablehlo.slice %26 [100:200] : (tensor<200xf64>) -> tensor<100xf64>
// Element-wise (%27 - %28) / %cst, reshaped back to the shape of %arg1.
%29:3 = enzyme.batch @"-_broadcast_scalar_3"(%27, %28) {batch_shape = array<i64: 100>} : (tensor<100xf64>, tensor<100xf64>) -> (tensor<100xf64>, tensor<100xf64>, tensor<100xf64>)
%30:3 = enzyme.batch @"/_broadcast_scalar_1"(%29#0, %cst) {batch_shape = array<i64: 100>} : (tensor<100xf64>, tensor<100xf64>) -> (tensor<100xf64>, tensor<100xf64>, tensor<100xf64>)
%31 = stablehlo.reshape %30#0 : (tensor<100xf64>) -> tensor<1x5x4x5xf64>
return %15, %31 : tensor<2x5x4x3xf64>, tensor<1x5x4x5xf64>
}
}
Metadata
Metadata
Assignees
Labels
No labels