-
Notifications
You must be signed in to change notification settings - Fork 38
Closed
Description
using Enzyme
using Reactant
"""
    wind_stress_init(cpu=true)

Build the 63×63 all-ones input field used by this reproducer.

With `cpu == true` (the default) the plain array is returned as is; otherwise
it is first converted with `Reactant.to_rarray`.
"""
function wind_stress_init(cpu=true)
    host_field = ones(63, 63)
    return !cpu ? Reactant.to_rarray(host_field) : host_field
end
# Toy forward model for this reproducer: seeds a padded work array `v` with
# `wind_stress`, runs three traced iterations of a slice/add/copy/broadcast
# update on `u` and `v`, and returns the sum of all entries of `u`.
# `wind_stress` is copied into a 63×63 window of `v`, so it is expected to be
# 63×63.
function estimate_tracer_error(wind_stress)
# Work arrays are allocated via `similar(wind_stress, ...)` and zero-filled.
u = similar(wind_stress, 63, 63, 16)
fill!(u, 0)
v = similar(wind_stress, 78, 78, 31)
fill!(v, 0)
# Seed: write `wind_stress` into the 63×63 window v[8:70, 8:70, 15].
copyto!(@view(v[8:end-8, 8:end-8, 15]), wind_stress)
# Fixed three-iteration loop, wrapped in Reactant's `@trace`.
@trace track_numbers=false for _ = 1:3
# 63×63×16 interior window of `v` (indices 8:70, 8:70, 8:23).
vsl = @view(v[8:end-8, 8:end-8, 8:end-8])
# interior = interior + u. `u` already has the window's size, so
# `broadcast_to_size` is presumably a no-op on its shape here — TODO confirm.
copyto!(vsl, Reactant.Ops.add(v[8:end-8, 8:end-8, 8:end-8], Reactant.TracedUtils.broadcast_to_size(u, size(vsl))))
# Refill `u` from a same-sized window of `v` shifted by +1 along dim 1 and
# -1 along dim 2 relative to the interior window.
copyto!(u, v[9:end-7, 7:end-9, 8:end-8])
# Overwrite layer 2 of `u` with a copy of its layer 8.
copyto!(@view(u[:, :, 2]), u[:, :, 8])
# Overwrite the whole interior window with the 63×63 slice v[8:70, 8:70, 9]
# expanded to the window's 63×63×16 size (presumably repeated along dim 3 —
# confirm against `broadcast_to_size`).
copyto!(vsl, Reactant.TracedUtils.broadcast_to_size(v[8:end-8, 8:end-8, 9], size(vsl)))
end
# Sum-reduce `u` over all three dimensions, starting from zero.
# NOTE(review): despite the variable name, nothing is squared or averaged here.
mean_sq_surface_u = Reactant.Ops.reduce(u, zero(eltype(u)), [1, 2, 3], +)
return mean_sq_surface_u
end
"""
    differentiate_tracer_error(J)

Differentiate `estimate_tracer_error` at `J` with Enzyme in reverse mode
(strong-zero variant).

Returns a 2-tuple: whatever `autodiff` returns, followed by the shadow array
that was paired with `J` through `Duplicated` (created here as `zero(J)`).
"""
function differentiate_tracer_error(J)
    shadow = zero(J)
    mode = set_strong_zero(Enzyme.Reverse)
    derivs = autodiff(mode, estimate_tracer_error, Active, Duplicated(J, shadow))
    return derivs, shadow
end
# Reactant-backed input: `cpu=false` makes `wind_stress_init` convert the array
# with `Reactant.to_rarray`.
rwind_stress = wind_stress_init(false)
@info "Compiling..."
# Zeroed counterpart of the input.
# NOTE(review): `dJ` is rebound by the gradient call further down before it is
# ever read, so this initial value looks unused.
dJ = make_zero(rwind_stress)
# Pass pipelines handed to Reactant's `optimize=` option.
#
# `pre_pass_pipeline` ends with the bare `enzyme` pass. `pass_pipeline` is the
# same text with a `{postpasses="..."}` option attached to that final `enzyme`
# pass, followed by `symbol-dce`.
#
# The text is assembled from adjacent literals joined with `*` so that the
# source can span several lines without putting newline characters into the
# pipeline; every break falls on a `;` or `,` boundary, so no pattern name is
# split.
pre_pass_pipeline =
    "mark-func-memory-effects,inline{default-pipeline=canonicalize max-iterations=4},propagate-constant-bounds,sroa-wrappers{instcombine=false instsimplify=true },canonicalize,sroa-wrappers{instcombine=false instsimplify=true },libdevice-funcs-raise,canonicalize,remove-duplicate-func-def,canonicalize,cse,canonicalize,enzyme-hlo-generate-td{patterns=" *
    "compare_op_canon<16>;transpose_transpose<16>;broadcast_in_dim_op_canon<16>;convert_op_canon<16>;dynamic_broadcast_in_dim_op_not_actually_dynamic<16>;chained_dynamic_broadcast_in_dim_canonicalization<16>;dynamic_broadcast_in_dim_all_dims_non_expanding<16>;noop_reduce_op_canon<16>;empty_reduce_op_canon<16>;dynamic_reshape_op_canon<16>;get_tuple_element_op_canon<16>;real_op_canon<16>;imag_op_canon<16>;conj_complex_negate<16>;get_dimension_size_op_canon<16>;gather_op_canon<16>;reshape_op_canon<16>;merge_consecutive_reshapes<16>;transpose_is_reshape<16>;zero_extent_tensor_canon<16>;cse_broadcast_in_dim<16>;cse_slice<16>;cse_transpose<16>;cse_convert<16>;cse_pad<16>;cse_dot_general<16>;cse_reshape<16>;cse_mul<16>;cse_div<16>;cse_add<16>;cse_subtract<16>;cse_min<16>;cse_max<16>;cse_neg<16>;cse_abs<16>;cse_concatenate<16>;concatenate_op_canon<16>(1024);select_op_canon<16>(1024);add_simplify<16>;sub_simplify<16>;and_simplify<16>;max_simplify<16>;min_simplify<16>;or_simplify<16>;xor_simplify<16>;mul_simplify<16>;div_simplify<16>;rem_simplify<16>;pow_simplify<16>;simplify_extend<16>;simplify_wrap<16>;simplify_rotate<16>;noop_slice<16>;noop_reverse<16>;slice_slice<16>;shift_right_logical_simplify<16>;pad_simplify<16>(1024);select_pad_to_dus<1>;and_pad_pad<1>;negative_pad_to_slice<16>;slice_simplify<16>;convert_simplify<16>;dynamic_slice_to_static<16>;dynamic_update_slice_elim<16>;concat_to_broadcast<16>;reduce_to_reshape<16>;broadcast_to_reshape<16>;slice_internal;iota_simplify<16>(1024);broadcast_in_dim_simplify<16>(1024);convert_concat<1>;dynamic_update_to_concat<1>;slice_of_dynamic_update<1>;" *
    "slice_elementwise<1>;slice_pad<1>;dot_reshape_dot<1>;concat_fuse<1>;pad_reshape_pad<1>;pad_pad<1>;concat_push_binop_add<1>;concat_push_binop_mul<1>;scatter_to_dynamic_update_slice<1>;reduce_concat<1>;slice_concat<1>;concat_slice<1>;select_op_used_within_if<1>;bin_broadcast_splat_add<1>;bin_broadcast_splat_subtract<1>;bin_broadcast_splat_div<1>;bin_broadcast_splat_mul<1>;dot_general_simplify<16>;transpose_simplify<16>;reshape_empty_broadcast<1>;add_pad_pad_to_concat<1>;broadcast_reshape<1>;concat_pad<1>;reduce_pad<1>;broadcast_pad<1>;zero_product_reshape_pad<1>;mul_zero_pad<1>;div_zero_pad<1>;binop_const_reshape_pad<1>;binop_const_pad_add<1>;binop_const_pad_subtract<1>;binop_const_pad_mul<1>;binop_const_pad_div<1>;binop_binop_pad_pad_add<1>;binop_binop_pad_pad_mul<1>;binop_pad_pad_add<1>;binop_pad_pad_subtract<1>;binop_pad_pad_mul<1>;binop_pad_pad_div<1>;binop_pad_pad_min<1>;binop_pad_pad_max<1>;unary_pad_push_convert<1>;unary_pad_push_tanh<1>;unary_pad_push_exp<1>;transpose_dot_reorder<1>;dot_transpose<1>;transpose_convolution<1>;convolution_transpose<1>;convert_convert_float<1>;concat_to_pad<1>;reshape_iota<1>;broadcast_reduce<1>;slice_dot_general<1>;if_inline<1>;if_to_select<1>;dynamic_gather_op_is_not_dynamic<16>;divide_sqrt_to_multiply_rsqrt<16>;associative_binary_op_reordering<1>;transpose_broadcast_in_dim_to_broadcast_in_dim<16>;replace_neg_add_with_subtract;binop_const_simplify;not_select_simplify;common_compare_expression_rewrite;compare_select_simplify;while_simplify<1>(1);if_remove_unused;transpose_reshape_to_broadcast;reshape_transpose_to_broadcast;dus_dus;dus_dus_concat;abs_positive_simplify;transpose_unary_transpose_abs;transpose_unary_transpose_neg;transpose_unary_transpose_sqrt;transpose_unary_transpose_rsqrt;transpose_unary_transpose_ceil;transpose_unary_transpose_convert;transpose_unary_transpose_cosine;transpose_unary_transpose_exp;transpose_unary_transpose_expm1;transpose_unary_transpose_log;transpose_unary_transpose_log1p;transpose_unary_transpose_sign;" *
    "transpose_unary_transpose_sine;transpose_unary_transpose_tanh;select_comp_iota_const_simplify<1>;sign_abs_simplify<1>;broadcastindim_is_reshape;slice_reduce_window<1>;while_deadresult;dus_licm(0);dus_pad;dus_concat;slice_dus_to_concat;slice_licm(0);pad_licm(0);elementwise_licm(0);concatenate_licm(0);slice_broadcast;associative_common_mul_op_reordering;slice_select_to_select_slice;pad_concat_to_concat_pad;slice_if;dus_to_i32;rotate_pad;slice_extend;concat_wrap;cse_extend<16>;cse_wrap<16>;cse_rotate<16>;cse_rotate<16>;concat_concat_axis_swap;concat_multipad;concat_concat_to_dus;speculate_if_pad_to_select;broadcast_iota_simplify;select_comp_iota_to_dus;compare_cleanup;broadcast_compare;not_compare;broadcast_iota;cse_iota;compare_iota_const_simplify;reshuffle_ands_compares;square_abs_simplify;divide_divide_simplify;concat_reshape_slice;full_reduce_reshape_or_transpose;concat_reshape_reduce;concat_elementwise;reduce_reduce;conj_real;select_broadcast_in_dim;if_op_lift_common_ops;involution_neg_simplify;involution_conj_simplify;involution_not_simplify;real_conj_simplify;conj_complex_simplify;split_convolution_into_reverse_convolution;scatter_multiply_simplify;unary_elementwise_scatter_simplify;gather_elementwise;chlo_inf_const_prop<16>;gamma_const_prop<16>;abs_const_prop<16>;log_const_prop<1>;log_plus_one_const_prop<1>;is_finite_const_prop;not_const_prop;neg_const_prop;sqrt_const_prop;rsqrt_const_prop;cos_const_prop;sin_const_prop;exp_const_prop;expm1_const_prop;tanh_const_prop;logistic_const_prop;conj_const_prop;ceil_const_prop;cbrt_const_prop;real_const_prop;imag_const_prop;round_const_prop;round_nearest_even_const_prop;sign_const_prop;floor_const_prop;tan_const_prop;add_const_prop;and_const_prop;atan2_const_prop;complex_const_prop;div_const_prop;max_const_prop;min_const_prop;mul_const_prop;or_const_prop;pow_const_prop;rem_const_prop;sub_const_prop;xor_const_prop;const_prop_through_barrier<16>;concat_const_prop<1>(1024);dynamic_update_slice_const_prop(1024);" *
    "scatter_update_computation_const_prop;gather_const_prop;dus_slice_simplify;reshape_concat;reshape_dus;dot_reshape_pad<1>;pad_dot_general<1>(0);pad_dot_general<1>(1);reshape_pad;reshape_wrap;reshape_rotate;reshape_extend;reshape_slice(1);reshape_elementwise(1);transpose_select;transpose_while;transpose_slice;transpose_concat;transpose_iota;transpose_reduce;transpose_reduce_window;transpose_dus;transpose_pad<1>;transpose_einsum<1>;transpose_wrap;transpose_extend;transpose_rotate;transpose_dynamic_slice;transpose_reverse;transpose_batch_norm_training;transpose_batch_norm_inference;transpose_batch_norm_grad;transpose_if;transpose_elementwise(1);no_nan_add_sub_simplify(0);lower_extend;lower_wrap;lower_rotate" *
    "},transform-interpreter,enzyme-hlo-remove-transform,enzyme"
pass_pipeline = pre_pass_pipeline * "{postpasses=\"arith-raise{stablehlo=true},canonicalize,cse,canonicalize,remove-unnecessary-enzyme-ops,enzyme-simplify-math,canonicalize,cse,canonicalize\"},symbol-dce"
# Start of the timed region. It spans both compiles and the three HLO printouts
# below, so `compile_toc` is not compile time alone.
tic = time()
# Compile the forward function with the full pipeline.
restimate_tracer_error = @compile optimize=pass_pipeline raise_first=true raise=true sync=true estimate_tracer_error(rwind_stress)
# Print the HLO of the forward function under the full pipeline.
println(@code_hlo optimize=pass_pipeline raise_first=true raise=true estimate_tracer_error(rwind_stress))
# Compile the gradient wrapper with the full pipeline.
rdifferentiate_tracer_error = @compile optimize=pass_pipeline raise_first=true raise=true sync=true differentiate_tracer_error(rwind_stress)
# Gradient HLO under `pre_pass_pipeline` only, i.e. without the `postpasses`
# option and the trailing `symbol-dce`.
println(@code_hlo optimize=pre_pass_pipeline raise_first=true raise=true differentiate_tracer_error(rwind_stress))
# Gradient HLO under the full pipeline.
println(@code_hlo optimize=pass_pipeline raise_first=true raise=true differentiate_tracer_error(rwind_stress))
# Wall-clock seconds elapsed since `tic`.
compile_toc = time() - tic
@show compile_toc
@info "Running..."
# Evaluate the compiled forward function on the Reactant input.
@show restimate_tracer_error(rwind_stress)
@info "Running non-reactant for comparison..."
@info "Initialized non-reactant model"
# Plain-array version of the input (`cpu=true` skips the Reactant conversion).
# NOTE(review): `vwind_stress` is not referenced again in the visible part of
# this script.
vwind_stress = wind_stress_init(true)
@info "Initialized non-reactant tracers and wind stress"
# Entry of the input whose sensitivity is inspected below.
i = 10
j = 10
# Run the compiled gradient; `dJ` is rebound to its second return value.
dedν, dJ = rdifferentiate_tracer_error(rwind_stress)
# Scalar read of a single entry of the result, wrapped in `@allowscalar`.
@allowscalar @show dJ[i, j]
# Produce finite-difference gradients for comparison:
# a central difference of the compiled forward function with respect to entry
# (i, j) of the input, using the relative perturbation ϵ * |x[i, j]|.
ϵ_list = [1e-1, 1e-2, 1e-3] #, 1e-4, 1e-5, 1e-6, 1e-7, 1e-8]
for ϵ in ϵ_list
    rwind_stressP = wind_stress_init(false)
    # Full width of the central difference, taken before the entry is perturbed.
    # The macro's *value* is bound here instead of assigning inside the macro,
    # so the binding is visible to the rest of the loop body regardless of how
    # `@allowscalar` wraps its argument. The name also avoids shadowing
    # `Base.diff`.
    fd_width = @allowscalar 2ϵ * abs(rwind_stressP[i, j])
    @allowscalar rwind_stressP[i, j] = rwind_stressP[i, j] + ϵ * abs(rwind_stressP[i, j])
    sq_surface_uP = restimate_tracer_error(rwind_stressP)

    rwind_stressM = wind_stress_init(false)
    @allowscalar rwind_stressM[i, j] = rwind_stressM[i, j] - ϵ * abs(rwind_stressM[i, j])
    sq_surface_uM = restimate_tracer_error(rwind_stressM)

    dsq_surface_u = (sq_surface_uP - sq_surface_uM) / fd_width
    @show ϵ, dsq_surface_u
end
[ Info: Compiling...
module @reactant_estimat... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
func.func @main(%arg0: tensor<63x63xf64>) -> tensor<f64> {
%cst = stablehlo.constant dense<0.000000e+00> : tensor<f64>
%c = stablehlo.constant dense<0> : tensor<i64>
%c_0 = stablehlo.constant dense<3> : tensor<i64>
%c_1 = stablehlo.constant dense<1> : tensor<i64>
%c_2 = stablehlo.constant dense<14> : tensor<i32>
%c_3 = stablehlo.constant dense<7> : tensor<i32>
%cst_4 = stablehlo.constant dense<0.000000e+00> : tensor<78x78x31xf64>
%cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<63x63x16xf64>
%0 = stablehlo.broadcast_in_dim %arg0, dims = [1, 0] : (tensor<63x63xf64>) -> tensor<63x63x1xf64>
%1 = stablehlo.dynamic_update_slice %cst_4, %0, %c_3, %c_3, %c_2 : (tensor<78x78x31xf64>, tensor<63x63x1xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<78x78x31xf64>
%2:3 = stablehlo.while(%iterArg = %c, %iterArg_6 = %1, %iterArg_7 = %cst_5) : tensor<i64>, tensor<78x78x31xf64>, tensor<63x63x16xf64> attributes {enzymexla.disable_min_cut}
cond {
%4 = stablehlo.compare LT, %iterArg, %c_0 : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %4 : tensor<i1>
} do {
%4 = stablehlo.add %iterArg, %c_1 : tensor<i64>
%5 = stablehlo.slice %iterArg_6 [7:70, 7:70, 7:23] : (tensor<78x78x31xf64>) -> tensor<63x63x16xf64>
%6 = stablehlo.add %5, %iterArg_7 : tensor<63x63x16xf64>
%7 = stablehlo.dynamic_update_slice %iterArg_6, %6, %c_3, %c_3, %c_3 : (tensor<78x78x31xf64>, tensor<63x63x16xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<78x78x31xf64>
%8 = stablehlo.slice %7 [8:71, 6:69, 14:15] : (tensor<78x78x31xf64>) -> tensor<63x63x1xf64>
%9 = stablehlo.slice %7 [8:71, 6:69, 7:8] : (tensor<78x78x31xf64>) -> tensor<63x63x1xf64>
%10 = stablehlo.slice %7 [8:71, 6:69, 9:23] : (tensor<78x78x31xf64>) -> tensor<63x63x14xf64>
%11 = stablehlo.concatenate %9, %8, %10, dim = 2 : (tensor<63x63x1xf64>, tensor<63x63x1xf64>, tensor<63x63x14xf64>) -> tensor<63x63x16xf64>
%12 = stablehlo.transpose %7, dims = [2, 1, 0] : (tensor<78x78x31xf64>) -> tensor<31x78x78xf64>
%13 = stablehlo.slice %12 [8:9, 7:70, 7:70] : (tensor<31x78x78xf64>) -> tensor<1x63x63xf64>
%14 = stablehlo.reshape %13 : (tensor<1x63x63xf64>) -> tensor<63x63xf64>
%15 = stablehlo.broadcast_in_dim %14, dims = [1, 0] : (tensor<63x63xf64>) -> tensor<63x63x16xf64>
%16 = stablehlo.dynamic_update_slice %iterArg_6, %15, %c_3, %c_3, %c_3 : (tensor<78x78x31xf64>, tensor<63x63x16xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<78x78x31xf64>
stablehlo.return %4, %16, %11 : tensor<i64>, tensor<78x78x31xf64>, tensor<63x63x16xf64>
}
%3 = stablehlo.reduce(%2#2 init: %cst) applies stablehlo.add across dimensions = [0, 1, 2] : (tensor<63x63x16xf64>, tensor<f64>) -> tensor<f64>
return %3 : tensor<f64>
}
}
module @reactant_differe... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
func.func private @"Const{typeof(estimate_tracer_error)}(Main.estimate_tracer_error)_autodiff"(%arg0: tensor<63x63xf64>) -> (tensor<f64>, tensor<63x63xf64>) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
%cst = stablehlo.constant dense<0.000000e+00> : tensor<f64>
%c = stablehlo.constant dense<0> : tensor<i64>
%c_0 = stablehlo.constant dense<3> : tensor<i64>
%c_1 = stablehlo.constant dense<1> : tensor<i64>
%c_2 = stablehlo.constant dense<14> : tensor<i32>
%c_3 = stablehlo.constant dense<7> : tensor<i32>
%cst_4 = stablehlo.constant dense<0.000000e+00> : tensor<78x78x31xf64>
%cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<63x63x16xf64>
%0 = stablehlo.broadcast_in_dim %arg0, dims = [1, 0] : (tensor<63x63xf64>) -> tensor<63x63x1xf64>
%1 = stablehlo.dynamic_update_slice %cst_4, %0, %c_3, %c_3, %c_2 : (tensor<78x78x31xf64>, tensor<63x63x1xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<78x78x31xf64>
%2:3 = stablehlo.while(%iterArg = %c, %iterArg_6 = %1, %iterArg_7 = %cst_5) : tensor<i64>, tensor<78x78x31xf64>, tensor<63x63x16xf64> attributes {enzymexla.disable_min_cut}
cond {
%4 = stablehlo.compare LT, %iterArg, %c_0 : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %4 : tensor<i1>
} do {
%4 = stablehlo.add %iterArg, %c_1 : tensor<i64>
%5 = stablehlo.slice %iterArg_6 [7:70, 7:70, 7:23] : (tensor<78x78x31xf64>) -> tensor<63x63x16xf64>
%6 = stablehlo.add %5, %iterArg_7 : tensor<63x63x16xf64>
%7 = stablehlo.dynamic_update_slice %iterArg_6, %6, %c_3, %c_3, %c_3 : (tensor<78x78x31xf64>, tensor<63x63x16xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<78x78x31xf64>
%8 = stablehlo.slice %7 [8:71, 6:69, 14:15] : (tensor<78x78x31xf64>) -> tensor<63x63x1xf64>
%9 = stablehlo.slice %7 [8:71, 6:69, 7:8] : (tensor<78x78x31xf64>) -> tensor<63x63x1xf64>
%10 = stablehlo.slice %7 [8:71, 6:69, 9:23] : (tensor<78x78x31xf64>) -> tensor<63x63x14xf64>
%11 = stablehlo.concatenate %9, %8, %10, dim = 2 : (tensor<63x63x1xf64>, tensor<63x63x1xf64>, tensor<63x63x14xf64>) -> tensor<63x63x16xf64>
%12 = stablehlo.transpose %7, dims = [2, 1, 0] : (tensor<78x78x31xf64>) -> tensor<31x78x78xf64>
%13 = stablehlo.slice %12 [8:9, 7:70, 7:70] : (tensor<31x78x78xf64>) -> tensor<1x63x63xf64>
%14 = stablehlo.reshape %13 : (tensor<1x63x63xf64>) -> tensor<63x63xf64>
%15 = stablehlo.broadcast_in_dim %14, dims = [1, 0] : (tensor<63x63xf64>) -> tensor<63x63x16xf64>
%16 = stablehlo.dynamic_update_slice %iterArg_6, %15, %c_3, %c_3, %c_3 : (tensor<78x78x31xf64>, tensor<63x63x16xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<78x78x31xf64>
stablehlo.return %4, %16, %11 : tensor<i64>, tensor<78x78x31xf64>, tensor<63x63x16xf64>
}
%3 = stablehlo.reduce(%2#2 init: %cst) applies stablehlo.add across dimensions = [0, 1, 2] : (tensor<63x63x16xf64>, tensor<f64>) -> tensor<f64>
return %3, %arg0 : tensor<f64>, tensor<63x63xf64>
}
func.func private @"diffeConst{typeof(estimate_tracer_error)}(Main.estimate_tracer_error)_autodiff"(%arg0: tensor<63x63xf64>, %arg1: tensor<f64>, %arg2: tensor<63x63xf64>) -> (tensor<63x63xf64>, tensor<63x63xf64>) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
%0 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<63x63x1xf64>>
%cst = arith.constant dense<0.000000e+00> : tensor<63x63x1xf64>
"enzyme.set"(%0, %cst) : (!enzyme.Gradient<tensor<63x63x1xf64>>, tensor<63x63x1xf64>) -> ()
%1 = "enzyme.init"() : () -> !enzyme.Cache<tensor<i32>>
%2 = "enzyme.init"() : () -> !enzyme.Cache<tensor<i32>>
%3 = "enzyme.init"() : () -> !enzyme.Cache<tensor<i32>>
%4 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<78x78x31xf64>>
%cst_0 = arith.constant dense<0.000000e+00> : tensor<78x78x31xf64>
"enzyme.set"(%4, %cst_0) : (!enzyme.Gradient<tensor<78x78x31xf64>>, tensor<78x78x31xf64>) -> ()
%5 = "enzyme.init"() : () -> !enzyme.Cache<tensor<i32>>
%6 = "enzyme.init"() : () -> !enzyme.Cache<tensor<i32>>
%7 = "enzyme.init"() : () -> !enzyme.Cache<tensor<i32>>
%8 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<78x78x31xf64>>
%cst_1 = arith.constant dense<0.000000e+00> : tensor<78x78x31xf64>
"enzyme.set"(%8, %cst_1) : (!enzyme.Gradient<tensor<78x78x31xf64>>, tensor<78x78x31xf64>) -> ()
%9 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<63x63x16xf64>>
%cst_2 = arith.constant dense<0.000000e+00> : tensor<63x63x16xf64>
"enzyme.set"(%9, %cst_2) : (!enzyme.Gradient<tensor<63x63x16xf64>>, tensor<63x63x16xf64>) -> ()
%10 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<63x63x14xf64>>
%cst_3 = arith.constant dense<0.000000e+00> : tensor<63x63x14xf64>
"enzyme.set"(%10, %cst_3) : (!enzyme.Gradient<tensor<63x63x14xf64>>, tensor<63x63x14xf64>) -> ()
%11 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<63x63x1xf64>>
%cst_4 = arith.constant dense<0.000000e+00> : tensor<63x63x1xf64>
"enzyme.set"(%11, %cst_4) : (!enzyme.Gradient<tensor<63x63x1xf64>>, tensor<63x63x1xf64>) -> ()
%12 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<63x63x1xf64>>
%cst_5 = arith.constant dense<0.000000e+00> : tensor<63x63x1xf64>
"enzyme.set"(%12, %cst_5) : (!enzyme.Gradient<tensor<63x63x1xf64>>, tensor<63x63x1xf64>) -> ()
%13 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<78x78x31xf64>>
%cst_6 = arith.constant dense<0.000000e+00> : tensor<78x78x31xf64>
"enzyme.set"(%13, %cst_6) : (!enzyme.Gradient<tensor<78x78x31xf64>>, tensor<78x78x31xf64>) -> ()
%14 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<63x63x16xf64>>
%cst_7 = arith.constant dense<0.000000e+00> : tensor<63x63x16xf64>
"enzyme.set"(%14, %cst_7) : (!enzyme.Gradient<tensor<63x63x16xf64>>, tensor<63x63x16xf64>) -> ()
%15 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<63x63x16xf64>>
%cst_8 = arith.constant dense<0.000000e+00> : tensor<63x63x16xf64>
"enzyme.set"(%15, %cst_8) : (!enzyme.Gradient<tensor<63x63x16xf64>>, tensor<63x63x16xf64>) -> ()
%16 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<63x63x16xf64>>
%cst_9 = arith.constant dense<0.000000e+00> : tensor<63x63x16xf64>
"enzyme.set"(%16, %cst_9) : (!enzyme.Gradient<tensor<63x63x16xf64>>, tensor<63x63x16xf64>) -> ()
%17 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<78x78x31xf64>>
%cst_10 = arith.constant dense<0.000000e+00> : tensor<78x78x31xf64>
"enzyme.set"(%17, %cst_10) : (!enzyme.Gradient<tensor<78x78x31xf64>>, tensor<78x78x31xf64>) -> ()
%18 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<78x78x31xf64>>
%cst_11 = arith.constant dense<0.000000e+00> : tensor<78x78x31xf64>
"enzyme.set"(%18, %cst_11) : (!enzyme.Gradient<tensor<78x78x31xf64>>, tensor<78x78x31xf64>) -> ()
%19 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<63x63x16xf64>>
%cst_12 = arith.constant dense<0.000000e+00> : tensor<63x63x16xf64>
"enzyme.set"(%19, %cst_12) : (!enzyme.Gradient<tensor<63x63x16xf64>>, tensor<63x63x16xf64>) -> ()
%20 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<63x63xf64>>
%cst_13 = arith.constant dense<0.000000e+00> : tensor<63x63xf64>
"enzyme.set"(%20, %cst_13) : (!enzyme.Gradient<tensor<63x63xf64>>, tensor<63x63xf64>) -> ()
%21 = "enzyme.init"() : () -> !enzyme.Gradient<tensor<f64>>
%cst_14 = arith.constant dense<0.000000e+00> : tensor<f64>
"enzyme.set"(%21, %cst_14) : (!enzyme.Gradient<tensor<f64>>, tensor<f64>) -> ()
%cst_15 = stablehlo.constant dense<0.000000e+00> : tensor<f64>
%c = stablehlo.constant dense<0> : tensor<i64>
%c_16 = stablehlo.constant dense<3> : tensor<i64>
%c_17 = stablehlo.constant dense<1> : tensor<i64>
%c_18 = stablehlo.constant dense<14> : tensor<i32>
%c_19 = stablehlo.constant dense<7> : tensor<i32>
%cst_20 = stablehlo.constant dense<0.000000e+00> : tensor<78x78x31xf64>
%cst_21 = stablehlo.constant dense<0.000000e+00> : tensor<63x63x16xf64>
%22 = stablehlo.broadcast_in_dim %arg0, dims = [1, 0] : (tensor<63x63xf64>) -> tensor<63x63x1xf64>
"enzyme.push"(%3, %c_19) : (!enzyme.Cache<tensor<i32>>, tensor<i32>) -> ()
"enzyme.push"(%2, %c_19) : (!enzyme.Cache<tensor<i32>>, tensor<i32>) -> ()
"enzyme.push"(%1, %c_18) : (!enzyme.Cache<tensor<i32>>, tensor<i32>) -> ()
%23 = stablehlo.dynamic_update_slice %cst_20, %22, %c_19, %c_19, %c_18 : (tensor<78x78x31xf64>, tensor<63x63x1xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<78x78x31xf64>
%24:3 = stablehlo.while(%iterArg = %c, %iterArg_30 = %23, %iterArg_31 = %cst_21) : tensor<i64>, tensor<78x78x31xf64>, tensor<63x63x16xf64> attributes {enzymexla.disable_min_cut}
cond {
%54 = stablehlo.compare LT, %iterArg, %c_16 : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %54 : tensor<i1>
} do {
%54 = stablehlo.add %iterArg, %c_17 : tensor<i64>
%55 = stablehlo.slice %iterArg_30 [7:70, 7:70, 7:23] : (tensor<78x78x31xf64>) -> tensor<63x63x16xf64>
%56 = stablehlo.add %55, %iterArg_31 : tensor<63x63x16xf64>
"enzyme.push"(%7, %c_19) : (!enzyme.Cache<tensor<i32>>, tensor<i32>) -> ()
"enzyme.push"(%6, %c_19) : (!enzyme.Cache<tensor<i32>>, tensor<i32>) -> ()
"enzyme.push"(%5, %c_19) : (!enzyme.Cache<tensor<i32>>, tensor<i32>) -> ()
%57 = stablehlo.dynamic_update_slice %iterArg_30, %56, %c_19, %c_19, %c_19 : (tensor<78x78x31xf64>, tensor<63x63x16xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<78x78x31xf64>
%58 = stablehlo.slice %57 [8:71, 6:69, 14:15] : (tensor<78x78x31xf64>) -> tensor<63x63x1xf64>
%59 = stablehlo.slice %57 [8:71, 6:69, 7:8] : (tensor<78x78x31xf64>) -> tensor<63x63x1xf64>
%60 = stablehlo.slice %57 [8:71, 6:69, 9:23] : (tensor<78x78x31xf64>) -> tensor<63x63x14xf64>
%61 = stablehlo.concatenate %59, %58, %60, dim = 2 : (tensor<63x63x1xf64>, tensor<63x63x1xf64>, tensor<63x63x14xf64>) -> tensor<63x63x16xf64>
%62 = stablehlo.transpose %57, dims = [2, 1, 0] : (tensor<78x78x31xf64>) -> tensor<31x78x78xf64>
%63 = stablehlo.slice %62 [8:9, 7:70, 7:70] : (tensor<31x78x78xf64>) -> tensor<1x63x63xf64>
%64 = stablehlo.reshape %63 : (tensor<1x63x63xf64>) -> tensor<63x63xf64>
%65 = stablehlo.broadcast_in_dim %64, dims = [1, 0] : (tensor<63x63xf64>) -> tensor<63x63x16xf64>
%66 = stablehlo.dynamic_update_slice %iterArg_30, %65, %c_19, %c_19, %c_19 : (tensor<78x78x31xf64>, tensor<63x63x16xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<78x78x31xf64>
stablehlo.return %54, %66, %61 : tensor<i64>, tensor<78x78x31xf64>, tensor<63x63x16xf64>
}
%25 = stablehlo.reduce(%24#2 init: %cst_15) applies stablehlo.add across dimensions = [0, 1, 2] : (tensor<63x63x16xf64>, tensor<f64>) -> tensor<f64>
cf.br ^bb1
^bb1: // pred: ^bb0
%26 = "enzyme.get"(%21) : (!enzyme.Gradient<tensor<f64>>) -> tensor<f64>
%27 = arith.addf %26, %arg1 : tensor<f64>
"enzyme.set"(%21, %27) : (!enzyme.Gradient<tensor<f64>>, tensor<f64>) -> ()
%28 = "enzyme.get"(%20) : (!enzyme.Gradient<tensor<63x63xf64>>) -> tensor<63x63xf64>
%29 = arith.addf %28, %arg2 : tensor<63x63xf64>
"enzyme.set"(%20, %29) : (!enzyme.Gradient<tensor<63x63xf64>>, tensor<63x63xf64>) -> ()
%cst_22 = arith.constant dense<0.000000e+00> : tensor<63x63x16xf64>
%30 = "enzyme.get"(%21) : (!enzyme.Gradient<tensor<f64>>) -> tensor<f64>
%cst_23 = arith.constant dense<0.000000e+00> : tensor<f64>
"enzyme.set"(%21, %cst_23) : (!enzyme.Gradient<tensor<f64>>, tensor<f64>) -> ()
%31 = stablehlo.broadcast_in_dim %30, dims = [] : (tensor<f64>) -> tensor<63x63x16xf64>
%32 = "enzyme.get"(%19) : (!enzyme.Gradient<tensor<63x63x16xf64>>) -> tensor<63x63x16xf64>
%33 = arith.addf %32, %31 : tensor<63x63x16xf64>
"enzyme.set"(%19, %33) : (!enzyme.Gradient<tensor<63x63x16xf64>>, tensor<63x63x16xf64>) -> ()
%c_24 = stablehlo.constant dense<3> : tensor<i64>
%c_25 = stablehlo.constant dense<0> : tensor<i64>
%34 = "enzyme.get"(%18) : (!enzyme.Gradient<tensor<78x78x31xf64>>) -> tensor<78x78x31xf64>
%35 = "enzyme.get"(%19) : (!enzyme.Gradient<tensor<63x63x16xf64>>) -> tensor<63x63x16xf64>
%cst_26 = arith.constant dense<0.000000e+00> : tensor<63x63x16xf64>
"enzyme.set"(%19, %cst_26) : (!enzyme.Gradient<tensor<63x63x16xf64>>, tensor<63x63x16xf64>) -> ()
%36:3 = stablehlo.while(%iterArg = %c_25, %iterArg_30 = %34, %iterArg_31 = %35) : tensor<i64>, tensor<78x78x31xf64>, tensor<63x63x16xf64>
cond {
%54 = stablehlo.compare LT, %iterArg, %c_24 : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %54 : tensor<i1>
} do {
%c_32 = stablehlo.constant dense<1> : tensor<i64>
%54 = stablehlo.add %iterArg, %c_32 : tensor<i64>
%cst_33 = arith.constant dense<0.000000e+00> : tensor<78x78x31xf64>
"enzyme.set"(%17, %cst_33) : (!enzyme.Gradient<tensor<78x78x31xf64>>, tensor<78x78x31xf64>) -> ()
%cst_34 = arith.constant dense<0.000000e+00> : tensor<63x63x16xf64>
"enzyme.set"(%16, %cst_34) : (!enzyme.Gradient<tensor<63x63x16xf64>>, tensor<63x63x16xf64>) -> ()
%cst_35 = arith.constant dense<0.000000e+00> : tensor<63x63x16xf64>
"enzyme.set"(%15, %cst_35) : (!enzyme.Gradient<tensor<63x63x16xf64>>, tensor<63x63x16xf64>) -> ()
%cst_36 = arith.constant dense<0.000000e+00> : tensor<63x63x16xf64>
"enzyme.set"(%14, %cst_36) : (!enzyme.Gradient<tensor<63x63x16xf64>>, tensor<63x63x16xf64>) -> ()
%cst_37 = arith.constant dense<0.000000e+00> : tensor<78x78x31xf64>
"enzyme.set"(%13, %cst_37) : (!enzyme.Gradient<tensor<78x78x31xf64>>, tensor<78x78x31xf64>) -> ()
%cst_38 = arith.constant dense<0.000000e+00> : tensor<63x63x1xf64>
"enzyme.set"(%12, %cst_38) : (!enzyme.Gradient<tensor<63x63x1xf64>>, tensor<63x63x1xf64>) -> ()
%cst_39 = arith.constant dense<0.000000e+00> : tensor<63x63x1xf64>
"enzyme.set"(%11, %cst_39) : (!enzyme.Gradient<tensor<63x63x1xf64>>, tensor<63x63x1xf64>) -> ()
%cst_40 = arith.constant dense<0.000000e+00> : tensor<63x63x14xf64>
"enzyme.set"(%10, %cst_40) : (!enzyme.Gradient<tensor<63x63x14xf64>>, tensor<63x63x14xf64>) -> ()
%cst_41 = arith.constant dense<0.000000e+00> : tensor<63x63x16xf64>
"enzyme.set"(%9, %cst_41) : (!enzyme.Gradient<tensor<63x63x16xf64>>, tensor<63x63x16xf64>) -> ()
"enzyme.set"(%8, %iterArg_30) : (!enzyme.Gradient<tensor<78x78x31xf64>>, tensor<78x78x31xf64>) -> ()
"enzyme.set"(%9, %iterArg_31) : (!enzyme.Gradient<tensor<63x63x16xf64>>, tensor<63x63x16xf64>) -> ()
%55 = "enzyme.get"(%9) : (!enzyme.Gradient<tensor<63x63x16xf64>>) -> tensor<63x63x16xf64>
%cst_42 = arith.constant dense<0.000000e+00> : tensor<63x63x16xf64>
"enzyme.set"(%9, %cst_42) : (!enzyme.Gradient<tensor<63x63x16xf64>>, tensor<63x63x16xf64>) -> ()
%56 = stablehlo.slice %55 [0:63, 0:63, 0:1] : (tensor<63x63x16xf64>) -> tensor<63x63x1xf64>
%57 = stablehlo.reshape %56 : (tensor<63x63x1xf64>) -> tensor<63x63x1xf64>
%58 = "enzyme.get"(%11) : (!enzyme.Gradient<tensor<63x63x1xf64>>) -> tensor<63x63x1xf64>
%59 = arith.addf %58, %57 : tensor<63x63x1xf64>
"enzyme.set"(%11, %59) : (!enzyme.Gradient<tensor<63x63x1xf64>>, tensor<63x63x1xf64>) -> ()
%60 = stablehlo.slice %55 [0:63, 0:63, 1:2] : (tensor<63x63x16xf64>) -> tensor<63x63x1xf64>
%61 = stablehlo.reshape %60 : (tensor<63x63x1xf64>) -> tensor<63x63x1xf64>
%62 = "enzyme.get"(%12) : (!enzyme.Gradient<tensor<63x63x1xf64>>) -> tensor<63x63x1xf64>
%63 = arith.addf %62, %61 : tensor<63x63x1xf64>
"enzyme.set"(%12, %63) : (!enzyme.Gradient<tensor<63x63x1xf64>>, tensor<63x63x1xf64>) -> ()
%64 = stablehlo.slice %55 [0:63, 0:63, 2:16] : (tensor<63x63x16xf64>) -> tensor<63x63x14xf64>
%65 = stablehlo.reshape %64 : (tensor<63x63x14xf64>) -> tensor<63x63x14xf64>
%66 = "enzyme.get"(%10) : (!enzyme.Gradient<tensor<63x63x14xf64>>) -> tensor<63x63x14xf64>
%67 = arith.addf %66, %65 : tensor<63x63x14xf64>
"enzyme.set"(%10, %67) : (!enzyme.Gradient<tensor<63x63x14xf64>>, tensor<63x63x14xf64>) -> ()
%68 = "enzyme.get"(%10) : (!enzyme.Gradient<tensor<63x63x14xf64>>) -> tensor<63x63x14xf64>
%cst_43 = arith.constant dense<0.000000e+00> : tensor<63x63x14xf64>
"enzyme.set"(%10, %cst_43) : (!enzyme.Gradient<tensor<63x63x14xf64>>, tensor<63x63x14xf64>) -> ()
%cst_44 = arith.constant dense<0.000000e+00> : tensor<f64>
%69 = stablehlo.pad %68, %cst_44, low = [8, 6, 9], high = [7, 9, 8], interior = [0, 0, 0] : (tensor<63x63x14xf64>, tensor<f64>) -> tensor<78x78x31xf64>
%70 = "enzyme.get"(%13) : (!enzyme.Gradient<tensor<78x78x31xf64>>) -> tensor<78x78x31xf64>
%71 = arith.addf %70, %69 : tensor<78x78x31xf64>
"enzyme.set"(%13, %71) : (!enzyme.Gradient<tensor<78x78x31xf64>>, tensor<78x78x31xf64>) -> ()
%72 = "enzyme.get"(%11) : (!enzyme.Gradient<tensor<63x63x1xf64>>) -> tensor<63x63x1xf64>
%cst_45 = arith.constant dense<0.000000e+00> : tensor<63x63x1xf64>
"enzyme.set"(%11, %cst_45) : (!enzyme.Gradient<tensor<63x63x1xf64>>, tensor<63x63x1xf64>) -> ()
%cst_46 = arith.constant dense<0.000000e+00> : tensor<f64>
%73 = stablehlo.pad %72, %cst_46, low = [8, 6, 7], high = [7, 9, 23], interior = [0, 0, 0] : (tensor<63x63x1xf64>, tensor<f64>) -> tensor<78x78x31xf64>
%74 = "enzyme.get"(%13) : (!enzyme.Gradient<tensor<78x78x31xf64>>) -> tensor<78x78x31xf64>
%75 = arith.addf %74, %73 : tensor<78x78x31xf64>
"enzyme.set"(%13, %75) : (!enzyme.Gradient<tensor<78x78x31xf64>>, tensor<78x78x31xf64>) -> ()
%76 = "enzyme.get"(%12) : (!enzyme.Gradient<tensor<63x63x1xf64>>) -> tensor<63x63x1xf64>
%cst_47 = arith.constant dense<0.000000e+00> : tensor<63x63x1xf64>
"enzyme.set"(%12, %cst_47) : (!enzyme.Gradient<tensor<63x63x1xf64>>, tensor<63x63x1xf64>) -> ()
%cst_48 = arith.constant dense<0.000000e+00> : tensor<f64>
%77 = stablehlo.pad %76, %cst_48, low = [8, 6, 14], high = [7, 9, 16], interior = [0, 0, 0] : (tensor<63x63x1xf64>, tensor<f64>) -> tensor<78x78x31xf64>
%78 = "enzyme.get"(%13) : (!enzyme.Gradient<tensor<78x78x31xf64>>) -> tensor<78x78x31xf64>
%79 = arith.addf %78, %77 : tensor<78x78x31xf64>
"enzyme.set"(%13, %79) : (!enzyme.Gradient<tensor<78x78x31xf64>>, tensor<78x78x31xf64>) -> ()
%80 = "enzyme.pop"(%7) : (!enzyme.Cache<tensor<i32>>) -> tensor<i32>
%81 = "enzyme.pop"(%6) : (!enzyme.Cache<tensor<i32>>) -> tensor<i32>
%82 = "enzyme.pop"(%5) : (!enzyme.Cache<tensor<i32>>) -> tensor<i32>
%83 = "enzyme.get"(%13) : (!enzyme.Gradient<tensor<78x78x31xf64>>) -> tensor<78x78x31xf64>
%cst_49 = arith.constant dense<0.000000e+00> : tensor<63x63x16xf64>
%84 = stablehlo.dynamic_update_slice %83, %cst_49, %80, %81, %82 : (tensor<78x78x31xf64>, tensor<63x63x16xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<78x78x31xf64>
%85 = "enzyme.get"(%17) : (!enzyme.Gradient<tensor<78x78x31xf64>>) -> tensor<78x78x31xf64>
%86 = arith.addf %85, %84 : tensor<78x78x31xf64>
"enzyme.set"(%17, %86) : (!enzyme.Gradient<tensor<78x78x31xf64>>, tensor<78x78x31xf64>) -> ()
%87 = stablehlo.dynamic_slice %83, %80, %81, %82, sizes = [63, 63, 16] : (tensor<78x78x31xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<63x63x16xf64>
%88 = "enzyme.get"(%14) : (!enzyme.Gradient<tensor<63x63x16xf64>>) -> tensor<63x63x16xf64>
%89 = arith.addf %88, %87 : tensor<63x63x16xf64>
"enzyme.set"(%14, %89) : (!enzyme.Gradient<tensor<63x63x16xf64>>, tensor<63x63x16xf64>) -> ()
%cst_50 = arith.constant dense<0.000000e+00> : tensor<78x78x31xf64>
"enzyme.set"(%13, %cst_50) : (!enzyme.Gradient<tensor<78x78x31xf64>>, tensor<78x78x31xf64>) -> ()
%90 = "enzyme.get"(%14) : (!enzyme.Gradient<tensor<63x63x16xf64>>) -> tensor<63x63x16xf64>
%cst_51 = arith.constant dense<0.000000e+00> : tensor<63x63x16xf64>
"enzyme.set"(%14, %cst_51) : (!enzyme.Gradient<tensor<63x63x16xf64>>, tensor<63x63x16xf64>) -> ()
%91 = "enzyme.get"(%15) : (!enzyme.Gradient<tensor<63x63x16xf64>>) -> tensor<63x63x16xf64>
%92 = arith.addf %91, %90 : tensor<63x63x16xf64>
"enzyme.set"(%15, %92) : (!enzyme.Gradient<tensor<63x63x16xf64>>, tensor<63x63x16xf64>) -> ()
%93 = "enzyme.get"(%16) : (!enzyme.Gradient<tensor<63x63x16xf64>>) -> tensor<63x63x16xf64>
%94 = arith.addf %93, %90 : tensor<63x63x16xf64>
"enzyme.set"(%16, %94) : (!enzyme.Gradient<tensor<63x63x16xf64>>, tensor<63x63x16xf64>) -> ()
%95 = "enzyme.get"(%15) : (!enzyme.Gradient<tensor<63x63x16xf64>>) -> tensor<63x63x16xf64>
%cst_52 = arith.constant dense<0.000000e+00> : tensor<63x63x16xf64>
"enzyme.set"(%15, %cst_52) : (!enzyme.Gradient<tensor<63x63x16xf64>>, tensor<63x63x16xf64>) -> ()
%cst_53 = arith.constant dense<0.000000e+00> : tensor<f64>
%96 = stablehlo.pad %95, %cst_53, low = [7, 7, 7], high = [8, 8, 8], interior = [0, 0, 0] : (tensor<63x63x16xf64>, tensor<f64>) -> tensor<78x78x31xf64>
%97 = "enzyme.get"(%17) : (!enzyme.Gradient<tensor<78x78x31xf64>>) -> tensor<78x78x31xf64>
%98 = arith.addf %97, %96 : tensor<78x78x31xf64>
"enzyme.set"(%17, %98) : (!enzyme.Gradient<tensor<78x78x31xf64>>, tensor<78x78x31xf64>) -> ()
%99 = "enzyme.get"(%17) : (!enzyme.Gradient<tensor<78x78x31xf64>>) -> tensor<78x78x31xf64>
%cst_54 = arith.constant dense<0.000000e+00> : tensor<78x78x31xf64>
"enzyme.set"(%17, %cst_54) : (!enzyme.Gradient<tensor<78x78x31xf64>>, tensor<78x78x31xf64>) -> ()
%100 = "enzyme.get"(%16) : (!enzyme.Gradient<tensor<63x63x16xf64>>) -> tensor<63x63x16xf64>
%cst_55 = arith.constant dense<0.000000e+00> : tensor<63x63x16xf64>
"enzyme.set"(%16, %cst_55) : (!enzyme.Gradient<tensor<63x63x16xf64>>, tensor<63x63x16xf64>) -> ()
stablehlo.return %54, %99, %100 : tensor<i64>, tensor<78x78x31xf64>, tensor<63x63x16xf64>
}
%37 = "enzyme.get"(%4) : (!enzyme.Gradient<tensor<78x78x31xf64>>) -> tensor<78x78x31xf64>
%38 = arith.addf %37, %36#1 : tensor<78x78x31xf64>
"enzyme.set"(%4, %38) : (!enzyme.Gradient<tensor<78x78x31xf64>>, tensor<78x78x31xf64>) -> ()
%39 = "enzyme.pop"(%3) : (!enzyme.Cache<tensor<i32>>) -> tensor<i32>
%40 = "enzyme.pop"(%2) : (!enzyme.Cache<tensor<i32>>) -> tensor<i32>
%41 = "enzyme.pop"(%1) : (!enzyme.Cache<tensor<i32>>) -> tensor<i32>
%42 = "enzyme.get"(%4) : (!enzyme.Gradient<tensor<78x78x31xf64>>) -> tensor<78x78x31xf64>
%43 = stablehlo.dynamic_slice %42, %39, %40, %41, sizes = [63, 63, 1] : (tensor<78x78x31xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<63x63x1xf64>
%44 = "enzyme.get"(%0) : (!enzyme.Gradient<tensor<63x63x1xf64>>) -> tensor<63x63x1xf64>
%45 = arith.addf %44, %43 : tensor<63x63x1xf64>
"enzyme.set"(%0, %45) : (!enzyme.Gradient<tensor<63x63x1xf64>>, tensor<63x63x1xf64>) -> ()
%cst_27 = arith.constant dense<0.000000e+00> : tensor<78x78x31xf64>
"enzyme.set"(%4, %cst_27) : (!enzyme.Gradient<tensor<78x78x31xf64>>, tensor<78x78x31xf64>) -> ()
%46 = "enzyme.get"(%0) : (!enzyme.Gradient<tensor<63x63x1xf64>>) -> tensor<63x63x1xf64>
%cst_28 = arith.constant dense<0.000000e+00> : tensor<63x63x1xf64>
"enzyme.set"(%0, %cst_28) : (!enzyme.Gradient<tensor<63x63x1xf64>>, tensor<63x63x1xf64>) -> ()
%cst_29 = arith.constant dense<0.000000e+00> : tensor<f64>
%47 = stablehlo.reduce(%46 init: %cst_29) applies stablehlo.add across dimensions = [2] : (tensor<63x63x1xf64>, tensor<f64>) -> tensor<63x63xf64>
%48 = stablehlo.reshape %47 : (tensor<63x63xf64>) -> tensor<63x63x1xf64>
%49 = stablehlo.transpose %48, dims = [1, 0, 2] : (tensor<63x63x1xf64>) -> tensor<63x63x1xf64>
%50 = stablehlo.reshape %49 : (tensor<63x63x1xf64>) -> tensor<63x63xf64>
%51 = "enzyme.get"(%20) : (!enzyme.Gradient<tensor<63x63xf64>>) -> tensor<63x63xf64>
%52 = arith.addf %51, %50 : tensor<63x63xf64>
"enzyme.set"(%20, %52) : (!enzyme.Gradient<tensor<63x63xf64>>, tensor<63x63xf64>) -> ()
%53 = "enzyme.get"(%20) : (!enzyme.Gradient<tensor<63x63xf64>>) -> tensor<63x63xf64>
return %arg0, %53 : tensor<63x63xf64>, tensor<63x63xf64>
}
// Public entry point of this module. It forwards the 63x63 input %arg0
// (tagged tf.aliasing_output = 1) to the autodiff callee together with a
// scalar 1.0 and an all-zero 63x63 tensor, then returns the callee's two
// results in swapped order.
// NOTE(review): the 1.0 is presumably the reverse-mode seed and the zero
// tensor the initial gradient accumulator; the callee's header is not in
// this part of the dump, so confirm against it.
func.func @main(%arg0: tensor<63x63xf64> {tf.aliasing_output = 1 : i32}) -> (tensor<63x63xf64>, tensor<63x63xf64>) {
// Scalar constant passed as the callee's second operand.
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f64>
// All-zero 63x63 tensor passed as the callee's third operand.
%cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<63x63xf64>
%0:2 = call @"diffeConst{typeof(estimate_tracer_error)}(Main.estimate_tracer_error)_autodiff"(%arg0, %cst, %cst_0) : (tensor<63x63xf64>, tensor<f64>, tensor<63x63xf64>) -> (tensor<63x63xf64>, tensor<63x63xf64>)
// Result order is swapped relative to the callee: %0#1 first, then %0#0.
return %0#1, %0#0 : tensor<63x63xf64>, tensor<63x63xf64>
}
}
// Second module in this dump. It defines the same two symbols as the module
// printed just before it (the "..._autodiff" function and @main), but this
// version contains only stablehlo/func ops: none of the enzyme.get /
// enzyme.set / enzyme.pop ops seen in the preceding module appear here.
// NOTE(review): presumably this is the IR after further lowering and
// simplification passes; confirm against the pass pipeline.
// The module name is elided ("...") in the dump as printed.
module @reactant_differe... attributes {mhlo.num_partitions = 1 : i64, mhlo.num_replicas = 1 : i64} {
// Differentiated function.
//   %arg0 : 63x63 input, returned unchanged as result 0.
//   %arg1 : scalar f64, broadcast below to form the loop's initial 63x63x16
//           carried value.
//   %arg2 : 63x63 tensor that the computed 63x63 result is added to
//           (result 1).
func.func private @"diffeConst{typeof(estimate_tracer_error)}(Main.estimate_tracer_error)_autodiff"(%arg0: tensor<63x63xf64>, %arg1: tensor<f64>, %arg2: tensor<63x63xf64>) -> (tensor<63x63xf64>, tensor<63x63xf64>) attributes {enzymexla.memory_effects = ["read", "write", "allocate", "free"]} {
// Constants: %c = 2 (initial value of the 4th loop-carried value),
// %c_0 = 7 (slice offset), %c_1 = 14 (last-dim offset of the final slice),
// %c_2 = 1 (step), %c_3 = 3 (loop bound), %c_4 = 0 (initial counter).
%c = stablehlo.constant dense<2> : tensor<i64>
%c_0 = stablehlo.constant dense<7> : tensor<i32>
%c_1 = stablehlo.constant dense<14> : tensor<i32>
%c_2 = stablehlo.constant dense<1> : tensor<i64>
%c_3 = stablehlo.constant dense<3> : tensor<i64>
%c_4 = stablehlo.constant dense<0> : tensor<i64>
// Zero scalar (pad value / reduce init) and zero tensors of both shapes.
%cst = stablehlo.constant dense<0.000000e+00> : tensor<f64>
%cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<63x63x16xf64>
%cst_6 = stablehlo.constant dense<0.000000e+00> : tensor<78x78x31xf64>
// Every element of the initial 63x63x16 carried value equals %arg1.
%0 = stablehlo.broadcast_in_dim %arg1, dims = [] : (tensor<f64>) -> tensor<63x63x16xf64>
// Loop with four carried values:
//   %iterArg   : counter, starts at 0, incremented by 1, runs while < 3
//                (three iterations).
//   %iterArg_7 : 78x78x31 tensor, starts as all zeros.
//   %iterArg_8 : 63x63x16 tensor, starts as the broadcast of %arg1.
//   %iterArg_9 : counter, starts at 2, decremented by 1 each iteration.
// NOTE(review): %iterArg_7 is never read inside the body; the 78x78x31 value
// yielded each iteration (%23) is built only from %iterArg_8. After the loop
// only %1#1 is consumed; %1#0, %1#2 and %1#3 are unused.
%1:4 = stablehlo.while(%iterArg = %c_4, %iterArg_7 = %cst_6, %iterArg_8 = %0, %iterArg_9 = %c) : tensor<i64>, tensor<78x78x31xf64>, tensor<63x63x16xf64>, tensor<i64>
cond {
%8 = stablehlo.compare LT, %iterArg, %c_3 : (tensor<i64>, tensor<i64>) -> tensor<i1>
stablehlo.return %8 : tensor<i1>
} do {
%8 = stablehlo.add %iterArg, %c_2 : tensor<i64>
// Split %iterArg_8 along its last dimension into plane 0, plane 1 and
// planes 2..15. Each reshape has identical input and output types.
%9 = stablehlo.slice %iterArg_8 [0:63, 0:63, 0:1] : (tensor<63x63x16xf64>) -> tensor<63x63x1xf64>
%10 = stablehlo.reshape %9 : (tensor<63x63x1xf64>) -> tensor<63x63x1xf64>
%11 = stablehlo.slice %iterArg_8 [0:63, 0:63, 1:2] : (tensor<63x63x16xf64>) -> tensor<63x63x1xf64>
%12 = stablehlo.reshape %11 : (tensor<63x63x1xf64>) -> tensor<63x63x1xf64>
%13 = stablehlo.slice %iterArg_8 [0:63, 0:63, 2:16] : (tensor<63x63x16xf64>) -> tensor<63x63x14xf64>
%14 = stablehlo.reshape %13 : (tensor<63x63x14xf64>) -> tensor<63x63x14xf64>
// Zero-pad each piece into a 78x78x31 tensor at offset (8, 6) in the first
// two dimensions and sum them. Last-dimension offsets: planes 2..15 start at
// 9, plane 0 at 7, plane 1 at 14.
// NOTE(review): this looks like the adjoint of the reproducer's
// copyto!(u, v[9:end-7, 7:end-9, 8:end-8]) combined with
// copyto!(@view(u[:, :, 2]), u[:, :, 8]); confirm.
%15 = stablehlo.pad %14, %cst, low = [8, 6, 9], high = [7, 9, 8], interior = [0, 0, 0] : (tensor<63x63x14xf64>, tensor<f64>) -> tensor<78x78x31xf64>
%16 = stablehlo.pad %10, %cst, low = [8, 6, 7], high = [7, 9, 23], interior = [0, 0, 0] : (tensor<63x63x1xf64>, tensor<f64>) -> tensor<78x78x31xf64>
%17 = stablehlo.add %15, %16 : tensor<78x78x31xf64>
%18 = stablehlo.pad %12, %cst, low = [8, 6, 14], high = [7, 9, 16], interior = [0, 0, 0] : (tensor<63x63x1xf64>, tensor<f64>) -> tensor<78x78x31xf64>
%19 = stablehlo.add %17, %18 : tensor<78x78x31xf64>
// %20: %19 with the 63x63x16 block at offset (7, 7, 7) overwritten by zeros.
// %21: that same block extracted from %19; it is also the next %iterArg_8.
// %22: %21 zero-padded back to 78x78x31 at the same (7, 7, 7) offset.
// %23 = %20 + %22, which puts the block values back where %20 cleared them.
%20 = stablehlo.dynamic_update_slice %19, %cst_5, %c_0, %c_0, %c_0 : (tensor<78x78x31xf64>, tensor<63x63x16xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<78x78x31xf64>
%21 = stablehlo.dynamic_slice %19, %c_0, %c_0, %c_0, sizes = [63, 63, 16] : (tensor<78x78x31xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<63x63x16xf64>
%22 = stablehlo.pad %21, %cst, low = [7, 7, 7], high = [8, 8, 8], interior = [0, 0, 0] : (tensor<63x63x16xf64>, tensor<f64>) -> tensor<78x78x31xf64>
%23 = stablehlo.add %20, %22 : tensor<78x78x31xf64>
%24 = stablehlo.subtract %iterArg_9, %c_2 : tensor<i64>
stablehlo.return %8, %23, %21, %24 : tensor<i64>, tensor<78x78x31xf64>, tensor<63x63x16xf64>, tensor<i64>
}
// Take the single 63x63x1 plane at offset (7, 7, 14) of the loop's 78x78x31
// result. In 1-based Julia indexing that is v[8:end-8, 8:end-8, 15], the view
// the reproducer fills from wind_stress.
%2 = stablehlo.dynamic_slice %1#1, %c_0, %c_0, %c_1, sizes = [63, 63, 1] : (tensor<78x78x31xf64>, tensor<i32>, tensor<i32>, tensor<i32>) -> tensor<63x63x1xf64>
// Sum over the size-1 last dimension to get 63x63, reshape back to 63x63x1,
// swap the first two dimensions, and reshape to 63x63 again.
%3 = stablehlo.reduce(%2 init: %cst) applies stablehlo.add across dimensions = [2] : (tensor<63x63x1xf64>, tensor<f64>) -> tensor<63x63xf64>
%4 = stablehlo.reshape %3 : (tensor<63x63xf64>) -> tensor<63x63x1xf64>
%5 = stablehlo.transpose %4, dims = [1, 0, 2] : (tensor<63x63x1xf64>) -> tensor<63x63x1xf64>
%6 = stablehlo.reshape %5 : (tensor<63x63x1xf64>) -> tensor<63x63xf64>
// Accumulate into %arg2 and return (unchanged input, accumulated result).
%7 = stablehlo.add %arg2, %6 : tensor<63x63xf64>
return %arg0, %7 : tensor<63x63xf64>, tensor<63x63xf64>
}
// Public entry point: calls the function above with a scalar 1.0 as %arg1 and
// an all-zero 63x63 tensor as %arg2, then swaps the results. The first value
// returned is therefore the callee's %7 and the second is the original input.
func.func @main(%arg0: tensor<63x63xf64> {tf.aliasing_output = 1 : i32}) -> (tensor<63x63xf64>, tensor<63x63xf64>) {
%cst = stablehlo.constant dense<1.000000e+00> : tensor<f64>
%cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<63x63xf64>
%0:2 = call @"diffeConst{typeof(estimate_tracer_error)}(Main.estimate_tracer_error)_autodiff"(%arg0, %cst, %cst_0) : (tensor<63x63xf64>, tensor<f64>, tensor<63x63xf64>) -> (tensor<63x63xf64>, tensor<63x63xf64>)
return %0#1, %0#0 : tensor<63x63xf64>, tensor<63x63xf64>
}
}
compile_toc = 5.39072585105896
[ Info: Running...
restimate_tracer_error(rwind_stress) = ConcreteIFRTArray{Float64, 0, Reactant.Sharding.ShardInfo{Reactant.Sharding.NoSharding, Nothing}, Nothing}(66736.0)
[ Info: Running non-reactant for comparison...
[ Info: Initialized non-reactant model
[ Info: Initialized non-reactant tracers and wind stress
dJ[i, j] = 2.0
(ϵ, dsq_surface_u) = (0.1, 18.000000000029104)
(ϵ, dsq_surface_u) = (0.01, 18.0000000007567)
(ϵ, dsq_surface_u) = (0.001, 18.00000001094304)
Metadata
Metadata
Assignees
Labels
No labels