Skip to content

[2nd Order AD] Regularization term in loss function #449

@avik-pal

Description

@avik-pal

xref https://discourse.julialang.org/t/second-order-gradient-with-lux-zygote-cuda-enzyme/124301

# Reproduction setup: a small LeNet-style CNN run through Reactant (XLA) with
# Enzyme providing (second-order) automatic differentiation.
using Lux, Random, OneHotArrays
using Reactant, Enzyme

# Two conv+maxpool stages followed by a dense MLP head (2-class output).
model = Chain(
    Conv((5, 5), 1 => 6, relu),
    MaxPool((2, 2)),
    Conv((5, 5), 6 => 16, relu),
    MaxPool((2, 2)),
    FlattenLayer(3),  # collapse the first 3 (spatial/channel) dims, keep batch
    Chain(
        Dense(256 => 128, relu),
        Dense(128 => 84, relu),
        Dense(84 => 2)
    )
)

dev = reactant_device(; force=true)  # require the Reactant device (error if absent)

# Parameters and layer state, moved to the Reactant device.
ps, st = Lux.setup(Random.default_rng(), model) |> dev;

x = randn(Float32, 28, 28, 1, 32) |> dev;  # input batch; presumably WHCN layout — confirm
δ = randn(Float32, 28, 28, 1, 32) |> dev;  # regression target for the input gradient ∂loss/∂x
y = onehotbatch(rand((1, 2), 32), 1:2) |> dev;  # random one-hot labels over 2 classes

const celoss = CrossEntropyLoss(; logits=true)  # expects raw logits, not probabilities
const regloss = MSELoss()

"""
    loss_function(model, ps, st, x, y)

Return the cross-entropy loss (computed on logits via `celoss`) of `model`'s
prediction for input `x` against the one-hot targets `y`. The updated layer
state returned by the model call is discarded.
"""
function loss_function(model, ps, st, x, y)
    ŷ = first(model(x, ps, st))
    return celoss(ŷ, y)
end

"""
    ∂xloss_function(model, ps, st, x, δ, y)

Primal loss plus a gradient-regularization term: an MSE penalty between the
input gradient `∂loss/∂x` (computed with reverse-mode Enzyme, all other
arguments held constant) and the target perturbation `δ`.
"""
function ∂xloss_function(model, ps, st, x, δ, y)
    grads = Enzyme.gradient(
        Reverse, loss_function, Const(model), Const(ps), Const(st), x, Const(y)
    )
    ∂x = grads[4]  # gradient w.r.t. the 4th argument, `x`
    return regloss(∂x, δ) + loss_function(model, ps, st, x, y)
end

# First-order program: inspect the optimized HLO (this one compiles fine).
@code_hlo ∂xloss_function(model, ps, st, x, δ, y)

"""
    ∂∂xloss_function(model, ps, st, x, δ, y)

Gradient of `∂xloss_function` with respect to the parameters `ps` — a
second-order AD call, since it differentiates through the inner
`Enzyme.gradient` in `∂xloss_function`.
"""
function ∂∂xloss_function(model, ps, st, x, δ, y)
    all_grads = Enzyme.gradient(
        Reverse, ∂xloss_function, Const(model), ps, Const(st), Const(x), Const(δ), Const(y)
    )
    return all_grads[2]  # gradient w.r.t. the 2nd argument, `ps`
end

# Second-order program, unoptimized: this dumps the IR module shown below.
@code_hlo optimize=false ∂∂xloss_function(model, ps, st, x, δ, y)

# Second-order program, optimized: this is the failing call — it emits the
# "could not compute the adjoint for this operation enzyme.push" errors below.
@code_hlo ∂∂xloss_function(model, ps, st, x, δ, y)

# error: could not compute the adjoint for this operation "enzyme.push"(%56, %125) : (!enzyme.Cache<tensor<32xf32>>, tensor<32xf32>) -> ()
# loc("subtract"("/mnt/software/lux/Reactant.jl/src/Ops.jl":193:0)): error: could not compute the adjoint for this operation "enzyme.push"(%51, %123) : (!enzyme.Cache<tensor<2x32xf32>>, tensor<2x32xf32>) -> ()
# error: could not compute the adjoint for this operation "enzyme.push"(%42, %arg8) : (!enzyme.Cache<tensor<84x2xf32>>, tensor<84x2xf32>) -> ()
# error: could not compute the adjoint for this operation "enzyme.push"(%34, %arg6) : (!enzyme.Cache<tensor<128x84xf32>>, tensor<128x84xf32>) -> ()
# error: could not compute the adjoint for this operation "enzyme.push"(%26, %arg4) : (!enzyme.Cache<tensor<256x128xf32>>, tensor<256x128xf32>) -> ()
# error: could not compute the adjoint for this operation "enzyme.push"(%19, %104) : (!enzyme.Cache<tensor<8x8x16x32xf32>>, tensor<8x8x16x32xf32>) -> ()
# loc("reverse"("/mnt/software/lux/Reactant.jl/src/Ops.jl":1038:0)): error: could not compute the adjoint for this operation "enzyme.push"(%11, %99) : (!enzyme.Cache<tensor<5x5x6x16xf32>>, tensor<5x5x6x16xf32>) -> ()
# error: could not compute the adjoint for this operation "enzyme.push"(%8, %97) : (!enzyme.Cache<tensor<24x24x6x32xf32>>, tensor<24x24x6x32xf32>) -> ()
# loc("reverse"("/mnt/software/lux/Reactant.jl/src/Ops.jl":1038:0)): error: could not compute the adjoint for this operation "enzyme.push"(%0, %92) : (!enzyme.Cache<tensor<5x5x1x6xf32>>, tensor<5x5x1x6xf32>) -> ()
Unoptimized IR

module {
  func.func private @"+_broadcast_scalar"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @relu_broadcast_scalar(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.compare  LT, %arg0, %cst : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %1 = stablehlo.select %0, %cst_0, %arg0 : tensor<i1>, tensor<f32>
    return %1, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @"+_broadcast_scalar1"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @relu_broadcast_scalar1(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.compare  LT, %arg0, %cst : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %1 = stablehlo.select %0, %cst_0, %arg0 : tensor<i1>, tensor<f32>
    return %1, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @"+_broadcast_scalar2"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @relu_broadcast_scalar2(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.compare  LT, %arg0, %cst : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %1 = stablehlo.select %0, %cst_0, %arg0 : tensor<i1>, tensor<f32>
    return %1, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @"+_broadcast_scalar3"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @relu_broadcast_scalar3(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.compare  LT, %arg0, %cst : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %1 = stablehlo.select %0, %cst_0, %arg0 : tensor<i1>, tensor<f32>
    return %1, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @"+_broadcast_scalar4"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @identity_broadcast_scalar(%arg0: tensor<f32>) -> tensor<f32> {
    return %arg0 : tensor<f32>
  }
  func.func private @"-_broadcast_scalar"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.subtract %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @exp_fast_broadcast_scalar(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %0 = stablehlo.exponential %arg0 : tensor<f32>
    return %0, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @log_fast_broadcast_scalar(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %0 = stablehlo.log %arg0 : tensor<f32>
    return %0, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @"-_broadcast_scalar1"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.subtract %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @"*_broadcast_scalar"(%arg0: tensor<i1>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<i1>, tensor<f32>) {
    %0 = stablehlo.convert %arg0 : (tensor<i1>) -> tensor<f32>
    %1 = stablehlo.multiply %0, %arg1 : tensor<f32>
    return %1, %arg0, %arg1 : tensor<f32>, tensor<i1>, tensor<f32>
  }
  func.func private @identity_broadcast_scalar1(%arg0: tensor<f32>) -> tensor<f32> {
    return %arg0 : tensor<f32>
  }
  func.func private @"-_broadcast_scalar2"(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %0 = stablehlo.negate %arg0 : tensor<f32>
    return %0, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @"Const{typeof(loss_function)}(Main.loss_function)_autodiff"(%arg0: tensor<6x1x5x5xf32>, %arg1: tensor<6xf32>, %arg2: tensor<16x6x5x5xf32>, %arg3: tensor<16xf32>, %arg4: tensor<256x128xf32>, %arg5: tensor<128xf32>, %arg6: tensor<128x84xf32>, %arg7: tensor<84xf32>, %arg8: tensor<84x2xf32>, %arg9: tensor<2xf32>, %arg10: tensor<32x1x28x28xf32>, %arg11: tensor<32x2xi1>) -> (tensor<f32>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>) {
    %0 = stablehlo.transpose %arg0, dims = [3, 2, 1, 0] : (tensor<6x1x5x5xf32>) -> tensor<5x5x1x6xf32>
    %1 = stablehlo.transpose %arg1, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %2 = stablehlo.transpose %arg2, dims = [3, 2, 1, 0] : (tensor<16x6x5x5xf32>) -> tensor<5x5x6x16xf32>
    %3 = stablehlo.transpose %arg3, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %4 = stablehlo.transpose %arg4, dims = [1, 0] : (tensor<256x128xf32>) -> tensor<128x256xf32>
    %5 = stablehlo.transpose %arg5, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %6 = stablehlo.transpose %arg6, dims = [1, 0] : (tensor<128x84xf32>) -> tensor<84x128xf32>
    %7 = stablehlo.transpose %arg7, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %8 = stablehlo.transpose %arg8, dims = [1, 0] : (tensor<84x2xf32>) -> tensor<2x84xf32>
    %9 = stablehlo.transpose %arg9, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %10 = stablehlo.transpose %arg10, dims = [3, 2, 1, 0] : (tensor<32x1x28x28xf32>) -> tensor<28x28x1x32xf32>
    %11 = stablehlo.transpose %arg11, dims = [1, 0] : (tensor<32x2xi1>) -> tensor<2x32xi1>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<24x24x6x32xf32>
    %12 = stablehlo.reverse %0, dims = [0, 1] : tensor<5x5x1x6xf32>
    %13 = stablehlo.convolution(%10, %12) dim_numbers = [0, 1, f, b]x[0, 1, i, o]->[0, 1, f, b], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<28x28x1x32xf32>, tensor<5x5x1x6xf32>) -> tensor<24x24x6x32xf32>
    %14 = stablehlo.transpose %1, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %15 = stablehlo.reshape %14 : (tensor<6xf32>) -> tensor<1x6x1x1xf32>
    %16 = stablehlo.transpose %15, dims = [3, 2, 1, 0] : (tensor<1x6x1x1xf32>) -> tensor<1x1x6x1xf32>
    %17 = stablehlo.broadcast_in_dim %16, dims = [0, 1, 2, 3] : (tensor<1x1x6x1xf32>) -> tensor<24x24x6x32xf32>
    %18:3 = enzyme.batch @"+_broadcast_scalar"(%13, %17) {batch_shape = array<i64: 24, 24, 6, 32>} : (tensor<24x24x6x32xf32>, tensor<24x24x6x32xf32>) -> (tensor<24x24x6x32xf32>, tensor<24x24x6x32xf32>, tensor<24x24x6x32xf32>)
    %19:2 = enzyme.batch @relu_broadcast_scalar(%18#0) {batch_shape = array<i64: 24, 24, 6, 32>} : (tensor<24x24x6x32xf32>) -> (tensor<24x24x6x32xf32>, tensor<24x24x6x32xf32>)
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<12x12x6x32xf32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<12x12x6x32xf32>
    %cst_2 = stablehlo.constant dense<0xFF800000> : tensor<f32>
    %20 = "stablehlo.reduce_window"(%19#0, %cst_2) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({
    ^bb0(%arg12: tensor<f32>, %arg13: tensor<f32>):
      %86 = stablehlo.maximum %arg12, %arg13 : tensor<f32>
      stablehlo.return %86 : tensor<f32>
    }) : (tensor<24x24x6x32xf32>, tensor<f32>) -> tensor<12x12x6x32xf32>
    %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor<8x8x16x32xf32>
    %21 = stablehlo.reverse %2, dims = [0, 1] : tensor<5x5x6x16xf32>
    %22 = stablehlo.convolution(%20, %21) dim_numbers = [0, 1, f, b]x[0, 1, i, o]->[0, 1, f, b], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<12x12x6x32xf32>, tensor<5x5x6x16xf32>) -> tensor<8x8x16x32xf32>
    %23 = stablehlo.transpose %3, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %24 = stablehlo.reshape %23 : (tensor<16xf32>) -> tensor<1x16x1x1xf32>
    %25 = stablehlo.transpose %24, dims = [3, 2, 1, 0] : (tensor<1x16x1x1xf32>) -> tensor<1x1x16x1xf32>
    %26 = stablehlo.broadcast_in_dim %25, dims = [0, 1, 2, 3] : (tensor<1x1x16x1xf32>) -> tensor<8x8x16x32xf32>
    %27:3 = enzyme.batch @"+_broadcast_scalar1"(%22, %26) {batch_shape = array<i64: 8, 8, 16, 32>} : (tensor<8x8x16x32xf32>, tensor<8x8x16x32xf32>) -> (tensor<8x8x16x32xf32>, tensor<8x8x16x32xf32>, tensor<8x8x16x32xf32>)
    %28:2 = enzyme.batch @relu_broadcast_scalar1(%27#0) {batch_shape = array<i64: 8, 8, 16, 32>} : (tensor<8x8x16x32xf32>) -> (tensor<8x8x16x32xf32>, tensor<8x8x16x32xf32>)
    %cst_4 = stablehlo.constant dense<0.000000e+00> : tensor<4x4x16x32xf32>
    %cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<4x4x16x32xf32>
    %cst_6 = stablehlo.constant dense<0xFF800000> : tensor<f32>
    %29 = "stablehlo.reduce_window"(%28#0, %cst_6) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({
    ^bb0(%arg12: tensor<f32>, %arg13: tensor<f32>):
      %86 = stablehlo.maximum %arg12, %arg13 : tensor<f32>
      stablehlo.return %86 : tensor<f32>
    }) : (tensor<8x8x16x32xf32>, tensor<f32>) -> tensor<4x4x16x32xf32>
    %cst_7 = stablehlo.constant dense<0.000000e+00> : tensor<128x32xf32>
    %30 = stablehlo.transpose %29, dims = [3, 2, 1, 0] : (tensor<4x4x16x32xf32>) -> tensor<32x16x4x4xf32>
    %31 = stablehlo.reshape %30 : (tensor<32x16x4x4xf32>) -> tensor<32x256xf32>
    %32 = stablehlo.transpose %31, dims = [1, 0] : (tensor<32x256xf32>) -> tensor<256x32xf32>
    %33 = stablehlo.dot_general %4, %32, contracting_dims = [1] x [0] : (tensor<128x256xf32>, tensor<256x32xf32>) -> tensor<128x32xf32>
    %34 = stablehlo.transpose %5, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %35 = stablehlo.reshape %34 : (tensor<128xf32>) -> tensor<1x128xf32>
    %36 = stablehlo.transpose %35, dims = [1, 0] : (tensor<1x128xf32>) -> tensor<128x1xf32>
    %37 = stablehlo.broadcast_in_dim %36, dims = [0, 1] : (tensor<128x1xf32>) -> tensor<128x32xf32>
    %38:3 = enzyme.batch @"+_broadcast_scalar2"(%33, %37) {batch_shape = array<i64: 128, 32>} : (tensor<128x32xf32>, tensor<128x32xf32>) -> (tensor<128x32xf32>, tensor<128x32xf32>, tensor<128x32xf32>)
    %39:2 = enzyme.batch @relu_broadcast_scalar2(%38#0) {batch_shape = array<i64: 128, 32>} : (tensor<128x32xf32>) -> (tensor<128x32xf32>, tensor<128x32xf32>)
    %cst_8 = stablehlo.constant dense<0.000000e+00> : tensor<84x32xf32>
    %40 = stablehlo.dot_general %6, %39#0, contracting_dims = [1] x [0] : (tensor<84x128xf32>, tensor<128x32xf32>) -> tensor<84x32xf32>
    %41 = stablehlo.transpose %7, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %42 = stablehlo.reshape %41 : (tensor<84xf32>) -> tensor<1x84xf32>
    %43 = stablehlo.transpose %42, dims = [1, 0] : (tensor<1x84xf32>) -> tensor<84x1xf32>
    %44 = stablehlo.broadcast_in_dim %43, dims = [0, 1] : (tensor<84x1xf32>) -> tensor<84x32xf32>
    %45:3 = enzyme.batch @"+_broadcast_scalar3"(%40, %44) {batch_shape = array<i64: 84, 32>} : (tensor<84x32xf32>, tensor<84x32xf32>) -> (tensor<84x32xf32>, tensor<84x32xf32>, tensor<84x32xf32>)
    %46:2 = enzyme.batch @relu_broadcast_scalar3(%45#0) {batch_shape = array<i64: 84, 32>} : (tensor<84x32xf32>) -> (tensor<84x32xf32>, tensor<84x32xf32>)
    %cst_9 = stablehlo.constant dense<0.000000e+00> : tensor<2x32xf32>
    %47 = stablehlo.dot_general %8, %46#0, contracting_dims = [1] x [0] : (tensor<2x84xf32>, tensor<84x32xf32>) -> tensor<2x32xf32>
    %48 = stablehlo.broadcast_in_dim %9, dims = [0] : (tensor<2xf32>) -> tensor<2x32xf32>
    %49:3 = enzyme.batch @"+_broadcast_scalar4"(%47, %48) {batch_shape = array<i64: 2, 32>} : (tensor<2x32xf32>, tensor<2x32xf32>) -> (tensor<2x32xf32>, tensor<2x32xf32>, tensor<2x32xf32>)
    %cst_10 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_11 = stablehlo.constant dense<1.1920929E-7> : tensor<f32>
    %cst_12 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_13 = stablehlo.constant dense<0.000000e+00> : tensor<2x32xf32>
    %cst_14 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_15 = stablehlo.constant dense<0xFF800000> : tensor<f32>
    %50 = enzyme.batch @identity_broadcast_scalar(%49#0) {batch_shape = array<i64: 2, 32>} : (tensor<2x32xf32>) -> tensor<2x32xf32>
    %51 = stablehlo.reduce(%50 init: %cst_15) applies stablehlo.maximum across dimensions = [0] : (tensor<2x32xf32>, tensor<f32>) -> tensor<32xf32>
    %52 = stablehlo.transpose %51, dims = [0] : (tensor<32xf32>) -> tensor<32xf32>
    %53 = stablehlo.reshape %52 : (tensor<32xf32>) -> tensor<32x1xf32>
    %54 = stablehlo.transpose %53, dims = [1, 0] : (tensor<32x1xf32>) -> tensor<1x32xf32>
    %55 = stablehlo.broadcast_in_dim %54, dims = [0, 1] : (tensor<1x32xf32>) -> tensor<2x32xf32>
    %56:3 = enzyme.batch @"-_broadcast_scalar"(%49#0, %55) {batch_shape = array<i64: 2, 32>} : (tensor<2x32xf32>, tensor<2x32xf32>) -> (tensor<2x32xf32>, tensor<2x32xf32>, tensor<2x32xf32>)
    %cst_16 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %57:2 = enzyme.batch @exp_fast_broadcast_scalar(%56#0) {batch_shape = array<i64: 2, 32>} : (tensor<2x32xf32>) -> (tensor<2x32xf32>, tensor<2x32xf32>)
    %58 = stablehlo.reduce(%57#0 init: %cst_16) applies stablehlo.add across dimensions = [0] : (tensor<2x32xf32>, tensor<f32>) -> tensor<32xf32>
    %59 = stablehlo.transpose %58, dims = [0] : (tensor<32xf32>) -> tensor<32xf32>
    %60 = stablehlo.reshape %59 : (tensor<32xf32>) -> tensor<32x1xf32>
    %61 = stablehlo.transpose %60, dims = [1, 0] : (tensor<32x1xf32>) -> tensor<1x32xf32>
    %62:2 = enzyme.batch @log_fast_broadcast_scalar(%61) {batch_shape = array<i64: 1, 32>} : (tensor<1x32xf32>) -> (tensor<1x32xf32>, tensor<1x32xf32>)
    %63 = stablehlo.broadcast_in_dim %62#0, dims = [0, 1] : (tensor<1x32xf32>) -> tensor<2x32xf32>
    %64:3 = enzyme.batch @"-_broadcast_scalar1"(%57#1, %63) {batch_shape = array<i64: 2, 32>} : (tensor<2x32xf32>, tensor<2x32xf32>) -> (tensor<2x32xf32>, tensor<2x32xf32>, tensor<2x32xf32>)
    %65:3 = enzyme.batch @"*_broadcast_scalar"(%11, %64#0) {batch_shape = array<i64: 2, 32>} : (tensor<2x32xi1>, tensor<2x32xf32>) -> (tensor<2x32xf32>, tensor<2x32xi1>, tensor<2x32xf32>)
    %cst_17 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %66 = enzyme.batch @identity_broadcast_scalar1(%65#0) {batch_shape = array<i64: 2, 32>} : (tensor<2x32xf32>) -> tensor<2x32xf32>
    %67 = stablehlo.reduce(%66 init: %cst_17) applies stablehlo.add across dimensions = [0] : (tensor<2x32xf32>, tensor<f32>) -> tensor<32xf32>
    %68 = stablehlo.transpose %67, dims = [0] : (tensor<32xf32>) -> tensor<32xf32>
    %69 = stablehlo.reshape %68 : (tensor<32xf32>) -> tensor<32x1xf32>
    %70 = stablehlo.transpose %69, dims = [1, 0] : (tensor<32x1xf32>) -> tensor<1x32xf32>
    %cst_18 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %71:2 = enzyme.batch @"-_broadcast_scalar2"(%70) {batch_shape = array<i64: 1, 32>} : (tensor<1x32xf32>) -> (tensor<1x32xf32>, tensor<1x32xf32>)
    %72 = stablehlo.reduce(%71#0 init: %cst_18) applies stablehlo.add across dimensions = [0, 1] : (tensor<1x32xf32>, tensor<f32>) -> tensor<f32>
    %cst_19 = stablehlo.constant dense<3.200000e+01> : tensor<f32>
    %73 = stablehlo.divide %72, %cst_19 : tensor<f32>
    %74 = stablehlo.transpose %0, dims = [3, 2, 1, 0] : (tensor<5x5x1x6xf32>) -> tensor<6x1x5x5xf32>
    %75 = stablehlo.transpose %1, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %76 = stablehlo.transpose %2, dims = [3, 2, 1, 0] : (tensor<5x5x6x16xf32>) -> tensor<16x6x5x5xf32>
    %77 = stablehlo.transpose %3, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %78 = stablehlo.transpose %4, dims = [1, 0] : (tensor<128x256xf32>) -> tensor<256x128xf32>
    %79 = stablehlo.transpose %5, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %80 = stablehlo.transpose %6, dims = [1, 0] : (tensor<84x128xf32>) -> tensor<128x84xf32>
    %81 = stablehlo.transpose %7, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %82 = stablehlo.transpose %8, dims = [1, 0] : (tensor<2x84xf32>) -> tensor<84x2xf32>
    %83 = stablehlo.transpose %9, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %84 = stablehlo.transpose %10, dims = [3, 2, 1, 0] : (tensor<28x28x1x32xf32>) -> tensor<32x1x28x28xf32>
    %85 = stablehlo.transpose %65#1, dims = [1, 0] : (tensor<2x32xi1>) -> tensor<32x2xi1>
    return %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85 : tensor<f32>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>
  }
  func.func private @l2_distance_loss_broadcast_scalar(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.subtract %arg0, %arg1 : tensor<f32>
    %1 = stablehlo.abs %0 : tensor<f32>
    %2 = stablehlo.multiply %1, %1 : tensor<f32>
    return %2, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @identity_broadcast_scalar2(%arg0: tensor<f32>) -> tensor<f32> {
    return %arg0 : tensor<f32>
  }
  func.func private @"+_broadcast_scalar5"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @relu_broadcast_scalar4(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.compare  LT, %arg0, %cst : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %1 = stablehlo.select %0, %cst_0, %arg0 : tensor<i1>, tensor<f32>
    return %1, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @"+_broadcast_scalar6"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @relu_broadcast_scalar5(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.compare  LT, %arg0, %cst : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %1 = stablehlo.select %0, %cst_0, %arg0 : tensor<i1>, tensor<f32>
    return %1, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @"+_broadcast_scalar7"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @relu_broadcast_scalar6(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.compare  LT, %arg0, %cst : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %1 = stablehlo.select %0, %cst_0, %arg0 : tensor<i1>, tensor<f32>
    return %1, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @"+_broadcast_scalar8"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @relu_broadcast_scalar7(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %0 = stablehlo.compare  LT, %arg0, %cst : (tensor<f32>, tensor<f32>) -> tensor<i1>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %1 = stablehlo.select %0, %cst_0, %arg0 : tensor<i1>, tensor<f32>
    return %1, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @"+_broadcast_scalar9"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.add %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @identity_broadcast_scalar3(%arg0: tensor<f32>) -> tensor<f32> {
    return %arg0 : tensor<f32>
  }
  func.func private @"-_broadcast_scalar3"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.subtract %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @exp_fast_broadcast_scalar1(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %0 = stablehlo.exponential %arg0 : tensor<f32>
    return %0, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @log_fast_broadcast_scalar1(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %0 = stablehlo.log %arg0 : tensor<f32>
    return %0, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @"-_broadcast_scalar4"(%arg0: tensor<f32>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<f32>, tensor<f32>) {
    %0 = stablehlo.subtract %arg0, %arg1 : tensor<f32>
    return %0, %arg0, %arg1 : tensor<f32>, tensor<f32>, tensor<f32>
  }
  func.func private @"*_broadcast_scalar1"(%arg0: tensor<i1>, %arg1: tensor<f32>) -> (tensor<f32>, tensor<i1>, tensor<f32>) {
    %0 = stablehlo.convert %arg0 : (tensor<i1>) -> tensor<f32>
    %1 = stablehlo.multiply %0, %arg1 : tensor<f32>
    return %1, %arg0, %arg1 : tensor<f32>, tensor<i1>, tensor<f32>
  }
  func.func private @identity_broadcast_scalar4(%arg0: tensor<f32>) -> tensor<f32> {
    return %arg0 : tensor<f32>
  }
  func.func private @"-_broadcast_scalar5"(%arg0: tensor<f32>) -> (tensor<f32>, tensor<f32>) {
    %0 = stablehlo.negate %arg0 : tensor<f32>
    return %0, %arg0 : tensor<f32>, tensor<f32>
  }
  func.func private @"Const{typeof(\E2\88\82xloss_function)}(Main.\E2\88\82xloss_function)_autodiff"(%arg0: tensor<6x1x5x5xf32>, %arg1: tensor<6xf32>, %arg2: tensor<16x6x5x5xf32>, %arg3: tensor<16xf32>, %arg4: tensor<256x128xf32>, %arg5: tensor<128xf32>, %arg6: tensor<128x84xf32>, %arg7: tensor<84xf32>, %arg8: tensor<84x2xf32>, %arg9: tensor<2xf32>, %arg10: tensor<32x1x28x28xf32>, %arg11: tensor<32x1x28x28xf32>, %arg12: tensor<32x2xi1>) -> (tensor<f32>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>) {
    %0 = stablehlo.transpose %arg0, dims = [3, 2, 1, 0] : (tensor<6x1x5x5xf32>) -> tensor<5x5x1x6xf32>
    %1 = stablehlo.transpose %arg1, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %2 = stablehlo.transpose %arg2, dims = [3, 2, 1, 0] : (tensor<16x6x5x5xf32>) -> tensor<5x5x6x16xf32>
    %3 = stablehlo.transpose %arg3, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %4 = stablehlo.transpose %arg4, dims = [1, 0] : (tensor<256x128xf32>) -> tensor<128x256xf32>
    %5 = stablehlo.transpose %arg5, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %6 = stablehlo.transpose %arg6, dims = [1, 0] : (tensor<128x84xf32>) -> tensor<84x128xf32>
    %7 = stablehlo.transpose %arg7, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %8 = stablehlo.transpose %arg8, dims = [1, 0] : (tensor<84x2xf32>) -> tensor<2x84xf32>
    %9 = stablehlo.transpose %arg9, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %10 = stablehlo.transpose %arg10, dims = [3, 2, 1, 0] : (tensor<32x1x28x28xf32>) -> tensor<28x28x1x32xf32>
    %11 = stablehlo.transpose %arg11, dims = [3, 2, 1, 0] : (tensor<32x1x28x28xf32>) -> tensor<28x28x1x32xf32>
    %12 = stablehlo.transpose %arg12, dims = [1, 0] : (tensor<32x2xi1>) -> tensor<2x32xi1>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<28x28x1x32xf32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %13 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor<f32>) -> tensor<28x28x1x32xf32>
    %cst_2 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %14 = stablehlo.transpose %0, dims = [3, 2, 1, 0] : (tensor<5x5x1x6xf32>) -> tensor<6x1x5x5xf32>
    %15 = stablehlo.transpose %1, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %16 = stablehlo.transpose %2, dims = [3, 2, 1, 0] : (tensor<5x5x6x16xf32>) -> tensor<16x6x5x5xf32>
    %17 = stablehlo.transpose %3, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %18 = stablehlo.transpose %4, dims = [1, 0] : (tensor<128x256xf32>) -> tensor<256x128xf32>
    %19 = stablehlo.transpose %5, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %20 = stablehlo.transpose %6, dims = [1, 0] : (tensor<84x128xf32>) -> tensor<128x84xf32>
    %21 = stablehlo.transpose %7, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %22 = stablehlo.transpose %8, dims = [1, 0] : (tensor<2x84xf32>) -> tensor<84x2xf32>
    %23 = stablehlo.transpose %9, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %24 = stablehlo.transpose %10, dims = [3, 2, 1, 0] : (tensor<28x28x1x32xf32>) -> tensor<32x1x28x28xf32>
    %25 = stablehlo.transpose %12, dims = [1, 0] : (tensor<2x32xi1>) -> tensor<32x2xi1>
    %26 = stablehlo.transpose %13, dims = [3, 2, 1, 0] : (tensor<28x28x1x32xf32>) -> tensor<32x1x28x28xf32>
    %27:13 = enzyme.autodiff @"Const{typeof(loss_function)}(Main.loss_function)_autodiff"(%14, %15, %16, %17, %18, %19, %20, %21, %22, %23, %24, %25, %cst_2, %26) {activity = [#enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>], ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>]} : (tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>, tensor<f32>, tensor<32x1x28x28xf32>) -> (tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>, tensor<32x1x28x28xf32>)
    %28 = stablehlo.transpose %27#0, dims = [3, 2, 1, 0] : (tensor<6x1x5x5xf32>) -> tensor<5x5x1x6xf32>
    %29 = stablehlo.transpose %27#1, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %30 = stablehlo.transpose %27#2, dims = [3, 2, 1, 0] : (tensor<16x6x5x5xf32>) -> tensor<5x5x6x16xf32>
    %31 = stablehlo.transpose %27#3, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %32 = stablehlo.transpose %27#4, dims = [1, 0] : (tensor<256x128xf32>) -> tensor<128x256xf32>
    %33 = stablehlo.transpose %27#5, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %34 = stablehlo.transpose %27#6, dims = [1, 0] : (tensor<128x84xf32>) -> tensor<84x128xf32>
    %35 = stablehlo.transpose %27#7, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %36 = stablehlo.transpose %27#8, dims = [1, 0] : (tensor<84x2xf32>) -> tensor<2x84xf32>
    %37 = stablehlo.transpose %27#9, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %38 = stablehlo.transpose %27#10, dims = [3, 2, 1, 0] : (tensor<32x1x28x28xf32>) -> tensor<28x28x1x32xf32>
    %39 = stablehlo.transpose %27#11, dims = [1, 0] : (tensor<32x2xi1>) -> tensor<2x32xi1>
    %40 = stablehlo.transpose %27#12, dims = [3, 2, 1, 0] : (tensor<32x1x28x28xf32>) -> tensor<28x28x1x32xf32>
    %41:3 = enzyme.batch @l2_distance_loss_broadcast_scalar(%40, %11) {batch_shape = array<i64: 28, 28, 1, 32>} : (tensor<28x28x1x32xf32>, tensor<28x28x1x32xf32>) -> (tensor<28x28x1x32xf32>, tensor<28x28x1x32xf32>, tensor<28x28x1x32xf32>)
    %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %42 = enzyme.batch @identity_broadcast_scalar2(%41#0) {batch_shape = array<i64: 28, 28, 1, 32>} : (tensor<28x28x1x32xf32>) -> tensor<28x28x1x32xf32>
    %43 = stablehlo.reduce(%42 init: %cst_3) applies stablehlo.add across dimensions = [0, 1, 2, 3] : (tensor<28x28x1x32xf32>, tensor<f32>) -> tensor<f32>
    %cst_4 = stablehlo.constant dense<2.508800e+04> : tensor<f32>
    %44 = stablehlo.divide %43, %cst_4 : tensor<f32>
    %cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<24x24x6x32xf32>
    %45 = stablehlo.reverse %28, dims = [0, 1] : tensor<5x5x1x6xf32>
    %46 = stablehlo.convolution(%38, %45) dim_numbers = [0, 1, f, b]x[0, 1, i, o]->[0, 1, f, b], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<28x28x1x32xf32>, tensor<5x5x1x6xf32>) -> tensor<24x24x6x32xf32>
    %47 = stablehlo.transpose %29, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %48 = stablehlo.reshape %47 : (tensor<6xf32>) -> tensor<1x6x1x1xf32>
    %49 = stablehlo.transpose %48, dims = [3, 2, 1, 0] : (tensor<1x6x1x1xf32>) -> tensor<1x1x6x1xf32>
    %50 = stablehlo.broadcast_in_dim %49, dims = [0, 1, 2, 3] : (tensor<1x1x6x1xf32>) -> tensor<24x24x6x32xf32>
    %51:3 = enzyme.batch @"+_broadcast_scalar5"(%46, %50) {batch_shape = array<i64: 24, 24, 6, 32>} : (tensor<24x24x6x32xf32>, tensor<24x24x6x32xf32>) -> (tensor<24x24x6x32xf32>, tensor<24x24x6x32xf32>, tensor<24x24x6x32xf32>)
    %52:2 = enzyme.batch @relu_broadcast_scalar4(%51#0) {batch_shape = array<i64: 24, 24, 6, 32>} : (tensor<24x24x6x32xf32>) -> (tensor<24x24x6x32xf32>, tensor<24x24x6x32xf32>)
    %cst_6 = stablehlo.constant dense<0.000000e+00> : tensor<12x12x6x32xf32>
    %cst_7 = stablehlo.constant dense<0.000000e+00> : tensor<12x12x6x32xf32>
    %cst_8 = stablehlo.constant dense<0xFF800000> : tensor<f32>
    %53 = "stablehlo.reduce_window"(%52#0, %cst_8) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({
    ^bb0(%arg13: tensor<f32>, %arg14: tensor<f32>):
      %121 = stablehlo.maximum %arg13, %arg14 : tensor<f32>
      stablehlo.return %121 : tensor<f32>
    }) : (tensor<24x24x6x32xf32>, tensor<f32>) -> tensor<12x12x6x32xf32>
    %cst_9 = stablehlo.constant dense<0.000000e+00> : tensor<8x8x16x32xf32>
    %54 = stablehlo.reverse %30, dims = [0, 1] : tensor<5x5x6x16xf32>
    %55 = stablehlo.convolution(%53, %54) dim_numbers = [0, 1, f, b]x[0, 1, i, o]->[0, 1, f, b], window = {stride = [1, 1], pad = [[0, 0], [0, 0]], rhs_dilate = [1, 1]} {batch_group_count = 1 : i64, feature_group_count = 1 : i64} : (tensor<12x12x6x32xf32>, tensor<5x5x6x16xf32>) -> tensor<8x8x16x32xf32>
    %56 = stablehlo.transpose %31, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %57 = stablehlo.reshape %56 : (tensor<16xf32>) -> tensor<1x16x1x1xf32>
    %58 = stablehlo.transpose %57, dims = [3, 2, 1, 0] : (tensor<1x16x1x1xf32>) -> tensor<1x1x16x1xf32>
    %59 = stablehlo.broadcast_in_dim %58, dims = [0, 1, 2, 3] : (tensor<1x1x16x1xf32>) -> tensor<8x8x16x32xf32>
    %60:3 = enzyme.batch @"+_broadcast_scalar6"(%55, %59) {batch_shape = array<i64: 8, 8, 16, 32>} : (tensor<8x8x16x32xf32>, tensor<8x8x16x32xf32>) -> (tensor<8x8x16x32xf32>, tensor<8x8x16x32xf32>, tensor<8x8x16x32xf32>)
    %61:2 = enzyme.batch @relu_broadcast_scalar5(%60#0) {batch_shape = array<i64: 8, 8, 16, 32>} : (tensor<8x8x16x32xf32>) -> (tensor<8x8x16x32xf32>, tensor<8x8x16x32xf32>)
    %cst_10 = stablehlo.constant dense<0.000000e+00> : tensor<4x4x16x32xf32>
    %cst_11 = stablehlo.constant dense<0.000000e+00> : tensor<4x4x16x32xf32>
    %cst_12 = stablehlo.constant dense<0xFF800000> : tensor<f32>
    %62 = "stablehlo.reduce_window"(%61#0, %cst_12) <{padding = dense<0> : tensor<4x2xi64>, window_dilations = array<i64: 1, 1, 1, 1>, window_dimensions = array<i64: 2, 2, 1, 1>, window_strides = array<i64: 2, 2, 1, 1>}> ({
    ^bb0(%arg13: tensor<f32>, %arg14: tensor<f32>):
      %121 = stablehlo.maximum %arg13, %arg14 : tensor<f32>
      stablehlo.return %121 : tensor<f32>
    }) : (tensor<8x8x16x32xf32>, tensor<f32>) -> tensor<4x4x16x32xf32>
    %cst_13 = stablehlo.constant dense<0.000000e+00> : tensor<128x32xf32>
    %63 = stablehlo.transpose %62, dims = [3, 2, 1, 0] : (tensor<4x4x16x32xf32>) -> tensor<32x16x4x4xf32>
    %64 = stablehlo.reshape %63 : (tensor<32x16x4x4xf32>) -> tensor<32x256xf32>
    %65 = stablehlo.transpose %64, dims = [1, 0] : (tensor<32x256xf32>) -> tensor<256x32xf32>
    %66 = stablehlo.dot_general %32, %65, contracting_dims = [1] x [0] : (tensor<128x256xf32>, tensor<256x32xf32>) -> tensor<128x32xf32>
    %67 = stablehlo.transpose %33, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %68 = stablehlo.reshape %67 : (tensor<128xf32>) -> tensor<1x128xf32>
    %69 = stablehlo.transpose %68, dims = [1, 0] : (tensor<1x128xf32>) -> tensor<128x1xf32>
    %70 = stablehlo.broadcast_in_dim %69, dims = [0, 1] : (tensor<128x1xf32>) -> tensor<128x32xf32>
    %71:3 = enzyme.batch @"+_broadcast_scalar7"(%66, %70) {batch_shape = array<i64: 128, 32>} : (tensor<128x32xf32>, tensor<128x32xf32>) -> (tensor<128x32xf32>, tensor<128x32xf32>, tensor<128x32xf32>)
    %72:2 = enzyme.batch @relu_broadcast_scalar6(%71#0) {batch_shape = array<i64: 128, 32>} : (tensor<128x32xf32>) -> (tensor<128x32xf32>, tensor<128x32xf32>)
    %cst_14 = stablehlo.constant dense<0.000000e+00> : tensor<84x32xf32>
    %73 = stablehlo.dot_general %34, %72#0, contracting_dims = [1] x [0] : (tensor<84x128xf32>, tensor<128x32xf32>) -> tensor<84x32xf32>
    %74 = stablehlo.transpose %35, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %75 = stablehlo.reshape %74 : (tensor<84xf32>) -> tensor<1x84xf32>
    %76 = stablehlo.transpose %75, dims = [1, 0] : (tensor<1x84xf32>) -> tensor<84x1xf32>
    %77 = stablehlo.broadcast_in_dim %76, dims = [0, 1] : (tensor<84x1xf32>) -> tensor<84x32xf32>
    %78:3 = enzyme.batch @"+_broadcast_scalar8"(%73, %77) {batch_shape = array<i64: 84, 32>} : (tensor<84x32xf32>, tensor<84x32xf32>) -> (tensor<84x32xf32>, tensor<84x32xf32>, tensor<84x32xf32>)
    %79:2 = enzyme.batch @relu_broadcast_scalar7(%78#0) {batch_shape = array<i64: 84, 32>} : (tensor<84x32xf32>) -> (tensor<84x32xf32>, tensor<84x32xf32>)
    %cst_15 = stablehlo.constant dense<0.000000e+00> : tensor<2x32xf32>
    %80 = stablehlo.dot_general %36, %79#0, contracting_dims = [1] x [0] : (tensor<2x84xf32>, tensor<84x32xf32>) -> tensor<2x32xf32>
    %81 = stablehlo.broadcast_in_dim %37, dims = [0] : (tensor<2xf32>) -> tensor<2x32xf32>
    %82:3 = enzyme.batch @"+_broadcast_scalar9"(%80, %81) {batch_shape = array<i64: 2, 32>} : (tensor<2x32xf32>, tensor<2x32xf32>) -> (tensor<2x32xf32>, tensor<2x32xf32>, tensor<2x32xf32>)
    %cst_16 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_17 = stablehlo.constant dense<1.1920929E-7> : tensor<f32>
    %cst_18 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_19 = stablehlo.constant dense<0.000000e+00> : tensor<2x32xf32>
    %cst_20 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_21 = stablehlo.constant dense<0xFF800000> : tensor<f32>
    %83 = enzyme.batch @identity_broadcast_scalar3(%82#0) {batch_shape = array<i64: 2, 32>} : (tensor<2x32xf32>) -> tensor<2x32xf32>
    %84 = stablehlo.reduce(%83 init: %cst_21) applies stablehlo.maximum across dimensions = [0] : (tensor<2x32xf32>, tensor<f32>) -> tensor<32xf32>
    %85 = stablehlo.transpose %84, dims = [0] : (tensor<32xf32>) -> tensor<32xf32>
    %86 = stablehlo.reshape %85 : (tensor<32xf32>) -> tensor<32x1xf32>
    %87 = stablehlo.transpose %86, dims = [1, 0] : (tensor<32x1xf32>) -> tensor<1x32xf32>
    %88 = stablehlo.broadcast_in_dim %87, dims = [0, 1] : (tensor<1x32xf32>) -> tensor<2x32xf32>
    %89:3 = enzyme.batch @"-_broadcast_scalar3"(%82#0, %88) {batch_shape = array<i64: 2, 32>} : (tensor<2x32xf32>, tensor<2x32xf32>) -> (tensor<2x32xf32>, tensor<2x32xf32>, tensor<2x32xf32>)
    %cst_22 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %90:2 = enzyme.batch @exp_fast_broadcast_scalar1(%89#0) {batch_shape = array<i64: 2, 32>} : (tensor<2x32xf32>) -> (tensor<2x32xf32>, tensor<2x32xf32>)
    %91 = stablehlo.reduce(%90#0 init: %cst_22) applies stablehlo.add across dimensions = [0] : (tensor<2x32xf32>, tensor<f32>) -> tensor<32xf32>
    %92 = stablehlo.transpose %91, dims = [0] : (tensor<32xf32>) -> tensor<32xf32>
    %93 = stablehlo.reshape %92 : (tensor<32xf32>) -> tensor<32x1xf32>
    %94 = stablehlo.transpose %93, dims = [1, 0] : (tensor<32x1xf32>) -> tensor<1x32xf32>
    %95:2 = enzyme.batch @log_fast_broadcast_scalar1(%94) {batch_shape = array<i64: 1, 32>} : (tensor<1x32xf32>) -> (tensor<1x32xf32>, tensor<1x32xf32>)
    %96 = stablehlo.broadcast_in_dim %95#0, dims = [0, 1] : (tensor<1x32xf32>) -> tensor<2x32xf32>
    %97:3 = enzyme.batch @"-_broadcast_scalar4"(%90#1, %96) {batch_shape = array<i64: 2, 32>} : (tensor<2x32xf32>, tensor<2x32xf32>) -> (tensor<2x32xf32>, tensor<2x32xf32>, tensor<2x32xf32>)
    %98:3 = enzyme.batch @"*_broadcast_scalar1"(%39, %97#0) {batch_shape = array<i64: 2, 32>} : (tensor<2x32xi1>, tensor<2x32xf32>) -> (tensor<2x32xf32>, tensor<2x32xi1>, tensor<2x32xf32>)
    %cst_23 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %99 = enzyme.batch @identity_broadcast_scalar4(%98#0) {batch_shape = array<i64: 2, 32>} : (tensor<2x32xf32>) -> tensor<2x32xf32>
    %100 = stablehlo.reduce(%99 init: %cst_23) applies stablehlo.add across dimensions = [0] : (tensor<2x32xf32>, tensor<f32>) -> tensor<32xf32>
    %101 = stablehlo.transpose %100, dims = [0] : (tensor<32xf32>) -> tensor<32xf32>
    %102 = stablehlo.reshape %101 : (tensor<32xf32>) -> tensor<32x1xf32>
    %103 = stablehlo.transpose %102, dims = [1, 0] : (tensor<32x1xf32>) -> tensor<1x32xf32>
    %cst_24 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %104:2 = enzyme.batch @"-_broadcast_scalar5"(%103) {batch_shape = array<i64: 1, 32>} : (tensor<1x32xf32>) -> (tensor<1x32xf32>, tensor<1x32xf32>)
    %105 = stablehlo.reduce(%104#0 init: %cst_24) applies stablehlo.add across dimensions = [0, 1] : (tensor<1x32xf32>, tensor<f32>) -> tensor<f32>
    %cst_25 = stablehlo.constant dense<3.200000e+01> : tensor<f32>
    %106 = stablehlo.divide %105, %cst_25 : tensor<f32>
    %107 = stablehlo.add %44, %106 : tensor<f32>
    %108 = stablehlo.transpose %28, dims = [3, 2, 1, 0] : (tensor<5x5x1x6xf32>) -> tensor<6x1x5x5xf32>
    %109 = stablehlo.transpose %29, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %110 = stablehlo.transpose %30, dims = [3, 2, 1, 0] : (tensor<5x5x6x16xf32>) -> tensor<16x6x5x5xf32>
    %111 = stablehlo.transpose %31, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %112 = stablehlo.transpose %32, dims = [1, 0] : (tensor<128x256xf32>) -> tensor<256x128xf32>
    %113 = stablehlo.transpose %33, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %114 = stablehlo.transpose %34, dims = [1, 0] : (tensor<84x128xf32>) -> tensor<128x84xf32>
    %115 = stablehlo.transpose %35, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %116 = stablehlo.transpose %36, dims = [1, 0] : (tensor<2x84xf32>) -> tensor<84x2xf32>
    %117 = stablehlo.transpose %37, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %118 = stablehlo.transpose %38, dims = [3, 2, 1, 0] : (tensor<28x28x1x32xf32>) -> tensor<32x1x28x28xf32>
    %119 = stablehlo.transpose %41#2, dims = [3, 2, 1, 0] : (tensor<28x28x1x32xf32>) -> tensor<32x1x28x28xf32>
    %120 = stablehlo.transpose %98#1, dims = [1, 0] : (tensor<2x32xi1>) -> tensor<32x2xi1>
    return %107, %108, %109, %110, %111, %112, %113, %114, %115, %116, %117, %118, %119, %120 : tensor<f32>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>
  }
  func.func @main(%arg0: tensor<6x1x5x5xf32>, %arg1: tensor<6xf32>, %arg2: tensor<16x6x5x5xf32>, %arg3: tensor<16xf32>, %arg4: tensor<256x128xf32>, %arg5: tensor<128xf32>, %arg6: tensor<128x84xf32>, %arg7: tensor<84xf32>, %arg8: tensor<84x2xf32>, %arg9: tensor<2xf32>, %arg10: tensor<32x1x28x28xf32>, %arg11: tensor<32x1x28x28xf32>, %arg12: tensor<32x2xi1>) -> (tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>) {
    %0 = stablehlo.transpose %arg0, dims = [3, 2, 1, 0] : (tensor<6x1x5x5xf32>) -> tensor<5x5x1x6xf32>
    %1 = stablehlo.transpose %arg1, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %2 = stablehlo.transpose %arg2, dims = [3, 2, 1, 0] : (tensor<16x6x5x5xf32>) -> tensor<5x5x6x16xf32>
    %3 = stablehlo.transpose %arg3, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %4 = stablehlo.transpose %arg4, dims = [1, 0] : (tensor<256x128xf32>) -> tensor<128x256xf32>
    %5 = stablehlo.transpose %arg5, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %6 = stablehlo.transpose %arg6, dims = [1, 0] : (tensor<128x84xf32>) -> tensor<84x128xf32>
    %7 = stablehlo.transpose %arg7, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %8 = stablehlo.transpose %arg8, dims = [1, 0] : (tensor<84x2xf32>) -> tensor<2x84xf32>
    %9 = stablehlo.transpose %arg9, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %10 = stablehlo.transpose %arg10, dims = [3, 2, 1, 0] : (tensor<32x1x28x28xf32>) -> tensor<28x28x1x32xf32>
    %11 = stablehlo.transpose %arg11, dims = [3, 2, 1, 0] : (tensor<32x1x28x28xf32>) -> tensor<28x28x1x32xf32>
    %12 = stablehlo.transpose %arg12, dims = [1, 0] : (tensor<32x2xi1>) -> tensor<2x32xi1>
    %cst = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_0 = stablehlo.constant dense<0.000000e+00> : tensor<5x5x1x6xf32>
    %cst_1 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %13 = stablehlo.broadcast_in_dim %cst_1, dims = [] : (tensor<f32>) -> tensor<5x5x1x6xf32>
    %cst_2 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_3 = stablehlo.constant dense<0.000000e+00> : tensor<6xf32>
    %cst_4 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %14 = stablehlo.broadcast_in_dim %cst_4, dims = [] : (tensor<f32>) -> tensor<6xf32>
    %cst_5 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_6 = stablehlo.constant dense<0.000000e+00> : tensor<5x5x6x16xf32>
    %cst_7 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %15 = stablehlo.broadcast_in_dim %cst_7, dims = [] : (tensor<f32>) -> tensor<5x5x6x16xf32>
    %cst_8 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_9 = stablehlo.constant dense<0.000000e+00> : tensor<16xf32>
    %cst_10 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %16 = stablehlo.broadcast_in_dim %cst_10, dims = [] : (tensor<f32>) -> tensor<16xf32>
    %cst_11 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_12 = stablehlo.constant dense<0.000000e+00> : tensor<128x256xf32>
    %cst_13 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %17 = stablehlo.broadcast_in_dim %cst_13, dims = [] : (tensor<f32>) -> tensor<128x256xf32>
    %cst_14 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_15 = stablehlo.constant dense<0.000000e+00> : tensor<128xf32>
    %cst_16 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %18 = stablehlo.broadcast_in_dim %cst_16, dims = [] : (tensor<f32>) -> tensor<128xf32>
    %cst_17 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_18 = stablehlo.constant dense<0.000000e+00> : tensor<84x128xf32>
    %cst_19 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %19 = stablehlo.broadcast_in_dim %cst_19, dims = [] : (tensor<f32>) -> tensor<84x128xf32>
    %cst_20 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_21 = stablehlo.constant dense<0.000000e+00> : tensor<84xf32>
    %cst_22 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %20 = stablehlo.broadcast_in_dim %cst_22, dims = [] : (tensor<f32>) -> tensor<84xf32>
    %cst_23 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_24 = stablehlo.constant dense<0.000000e+00> : tensor<2x84xf32>
    %cst_25 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %21 = stablehlo.broadcast_in_dim %cst_25, dims = [] : (tensor<f32>) -> tensor<2x84xf32>
    %cst_26 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %cst_27 = stablehlo.constant dense<0.000000e+00> : tensor<2xf32>
    %cst_28 = stablehlo.constant dense<0.000000e+00> : tensor<f32>
    %22 = stablehlo.broadcast_in_dim %cst_28, dims = [] : (tensor<f32>) -> tensor<2xf32>
    %cst_29 = stablehlo.constant dense<1.000000e+00> : tensor<f32>
    %23 = stablehlo.transpose %0, dims = [3, 2, 1, 0] : (tensor<5x5x1x6xf32>) -> tensor<6x1x5x5xf32>
    %24 = stablehlo.transpose %1, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %25 = stablehlo.transpose %2, dims = [3, 2, 1, 0] : (tensor<5x5x6x16xf32>) -> tensor<16x6x5x5xf32>
    %26 = stablehlo.transpose %3, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %27 = stablehlo.transpose %4, dims = [1, 0] : (tensor<128x256xf32>) -> tensor<256x128xf32>
    %28 = stablehlo.transpose %5, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %29 = stablehlo.transpose %6, dims = [1, 0] : (tensor<84x128xf32>) -> tensor<128x84xf32>
    %30 = stablehlo.transpose %7, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %31 = stablehlo.transpose %8, dims = [1, 0] : (tensor<2x84xf32>) -> tensor<84x2xf32>
    %32 = stablehlo.transpose %9, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %33 = stablehlo.transpose %10, dims = [3, 2, 1, 0] : (tensor<28x28x1x32xf32>) -> tensor<32x1x28x28xf32>
    %34 = stablehlo.transpose %11, dims = [3, 2, 1, 0] : (tensor<28x28x1x32xf32>) -> tensor<32x1x28x28xf32>
    %35 = stablehlo.transpose %12, dims = [1, 0] : (tensor<2x32xi1>) -> tensor<32x2xi1>
    %36 = stablehlo.transpose %13, dims = [3, 2, 1, 0] : (tensor<5x5x1x6xf32>) -> tensor<6x1x5x5xf32>
    %37 = stablehlo.transpose %14, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %38 = stablehlo.transpose %15, dims = [3, 2, 1, 0] : (tensor<5x5x6x16xf32>) -> tensor<16x6x5x5xf32>
    %39 = stablehlo.transpose %16, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %40 = stablehlo.transpose %17, dims = [1, 0] : (tensor<128x256xf32>) -> tensor<256x128xf32>
    %41 = stablehlo.transpose %18, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %42 = stablehlo.transpose %19, dims = [1, 0] : (tensor<84x128xf32>) -> tensor<128x84xf32>
    %43 = stablehlo.transpose %20, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %44 = stablehlo.transpose %21, dims = [1, 0] : (tensor<2x84xf32>) -> tensor<84x2xf32>
    %45 = stablehlo.transpose %22, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %46:23 = enzyme.autodiff @"Const{typeof(\E2\88\82xloss_function)}(Main.\E2\88\82xloss_function)_autodiff"(%23, %24, %25, %26, %27, %28, %29, %30, %31, %32, %33, %34, %35, %cst_29, %36, %37, %38, %39, %40, %41, %42, %43, %44, %45) {activity = [#enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>], ret_activity = [#enzyme<activity enzyme_activenoneed>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_active>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>, #enzyme<activity enzyme_const>]} : (tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>, tensor<f32>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>) -> (tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, 
tensor<84x2xf32>, tensor<2xf32>)
    %47 = stablehlo.transpose %46#0, dims = [3, 2, 1, 0] : (tensor<6x1x5x5xf32>) -> tensor<5x5x1x6xf32>
    %48 = stablehlo.transpose %46#1, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %49 = stablehlo.transpose %46#2, dims = [3, 2, 1, 0] : (tensor<16x6x5x5xf32>) -> tensor<5x5x6x16xf32>
    %50 = stablehlo.transpose %46#3, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %51 = stablehlo.transpose %46#4, dims = [1, 0] : (tensor<256x128xf32>) -> tensor<128x256xf32>
    %52 = stablehlo.transpose %46#5, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %53 = stablehlo.transpose %46#6, dims = [1, 0] : (tensor<128x84xf32>) -> tensor<84x128xf32>
    %54 = stablehlo.transpose %46#7, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %55 = stablehlo.transpose %46#8, dims = [1, 0] : (tensor<84x2xf32>) -> tensor<2x84xf32>
    %56 = stablehlo.transpose %46#9, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %57 = stablehlo.transpose %46#10, dims = [3, 2, 1, 0] : (tensor<32x1x28x28xf32>) -> tensor<28x28x1x32xf32>
    %58 = stablehlo.transpose %46#11, dims = [3, 2, 1, 0] : (tensor<32x1x28x28xf32>) -> tensor<28x28x1x32xf32>
    %59 = stablehlo.transpose %46#12, dims = [1, 0] : (tensor<32x2xi1>) -> tensor<2x32xi1>
    %60 = stablehlo.transpose %46#13, dims = [3, 2, 1, 0] : (tensor<6x1x5x5xf32>) -> tensor<5x5x1x6xf32>
    %61 = stablehlo.transpose %46#14, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %62 = stablehlo.transpose %46#15, dims = [3, 2, 1, 0] : (tensor<16x6x5x5xf32>) -> tensor<5x5x6x16xf32>
    %63 = stablehlo.transpose %46#16, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %64 = stablehlo.transpose %46#17, dims = [1, 0] : (tensor<256x128xf32>) -> tensor<128x256xf32>
    %65 = stablehlo.transpose %46#18, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %66 = stablehlo.transpose %46#19, dims = [1, 0] : (tensor<128x84xf32>) -> tensor<84x128xf32>
    %67 = stablehlo.transpose %46#20, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %68 = stablehlo.transpose %46#21, dims = [1, 0] : (tensor<84x2xf32>) -> tensor<2x84xf32>
    %69 = stablehlo.transpose %46#22, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %70 = stablehlo.transpose %60, dims = [3, 2, 1, 0] : (tensor<5x5x1x6xf32>) -> tensor<6x1x5x5xf32>
    %71 = stablehlo.transpose %61, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %72 = stablehlo.transpose %62, dims = [3, 2, 1, 0] : (tensor<5x5x6x16xf32>) -> tensor<16x6x5x5xf32>
    %73 = stablehlo.transpose %63, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %74 = stablehlo.transpose %64, dims = [1, 0] : (tensor<128x256xf32>) -> tensor<256x128xf32>
    %75 = stablehlo.transpose %65, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %76 = stablehlo.transpose %66, dims = [1, 0] : (tensor<84x128xf32>) -> tensor<128x84xf32>
    %77 = stablehlo.transpose %67, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %78 = stablehlo.transpose %68, dims = [1, 0] : (tensor<2x84xf32>) -> tensor<84x2xf32>
    %79 = stablehlo.transpose %69, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %80 = stablehlo.transpose %47, dims = [3, 2, 1, 0] : (tensor<5x5x1x6xf32>) -> tensor<6x1x5x5xf32>
    %81 = stablehlo.transpose %48, dims = [0] : (tensor<6xf32>) -> tensor<6xf32>
    %82 = stablehlo.transpose %49, dims = [3, 2, 1, 0] : (tensor<5x5x6x16xf32>) -> tensor<16x6x5x5xf32>
    %83 = stablehlo.transpose %50, dims = [0] : (tensor<16xf32>) -> tensor<16xf32>
    %84 = stablehlo.transpose %51, dims = [1, 0] : (tensor<128x256xf32>) -> tensor<256x128xf32>
    %85 = stablehlo.transpose %52, dims = [0] : (tensor<128xf32>) -> tensor<128xf32>
    %86 = stablehlo.transpose %53, dims = [1, 0] : (tensor<84x128xf32>) -> tensor<128x84xf32>
    %87 = stablehlo.transpose %54, dims = [0] : (tensor<84xf32>) -> tensor<84xf32>
    %88 = stablehlo.transpose %55, dims = [1, 0] : (tensor<2x84xf32>) -> tensor<84x2xf32>
    %89 = stablehlo.transpose %56, dims = [0] : (tensor<2xf32>) -> tensor<2xf32>
    %90 = stablehlo.transpose %57, dims = [3, 2, 1, 0] : (tensor<28x28x1x32xf32>) -> tensor<32x1x28x28xf32>
    %91 = stablehlo.transpose %58, dims = [3, 2, 1, 0] : (tensor<28x28x1x32xf32>) -> tensor<32x1x28x28xf32>
    %92 = stablehlo.transpose %59, dims = [1, 0] : (tensor<2x32xi1>) -> tensor<32x2xi1>
    return %70, %71, %72, %73, %74, %75, %76, %77, %78, %79, %80, %81, %82, %83, %84, %85, %86, %87, %88, %89, %90, %91, %92 : tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<6x1x5x5xf32>, tensor<6xf32>, tensor<16x6x5x5xf32>, tensor<16xf32>, tensor<256x128xf32>, tensor<128xf32>, tensor<128x84xf32>, tensor<84xf32>, tensor<84x2xf32>, tensor<2xf32>, tensor<32x1x28x28xf32>, tensor<32x1x28x28xf32>, tensor<32x2xi1>
  }
}

Metadata

Assignees

No one assigned

    Labels

    No labels

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions