@avik-pal avik-pal commented Nov 9, 2024

The fallback generates quite inefficient IR:

julia> @code_hlo NNlib.make_causal_mask(x_ra)
Module:
module attributes {transform.with_named_sequence} {
  func.func @main(%arg0: tensor<4x4xf64>) -> tensor<4x4xi1> {
    %c = stablehlo.constant dense<3> : tensor<i64>
    %c_0 = stablehlo.constant dense<2> : tensor<i64>
    %c_1 = stablehlo.constant dense<false> : tensor<1x1xi1>
    %c_2 = stablehlo.constant dense<true> : tensor<1x1xi1>
    %c_3 = stablehlo.constant dense<[[true, true, true, true], [true, true, true, true], [true, false, true, true], [true, false, true, false]]> : tensor<4x4xi1>
    %c_4 = stablehlo.constant dense<1> : tensor<i64>
    %c_5 = stablehlo.constant dense<0> : tensor<i64>
    %0 = stablehlo.dynamic_update_slice %c_3, %c_2, %c_4, %c_5 : (tensor<4x4xi1>, tensor<1x1xi1>, tensor<i64>, tensor<i64>) -> tensor<4x4xi1>
    %1 = stablehlo.dynamic_update_slice %0, %c_1, %c_0, %c_5 : (tensor<4x4xi1>, tensor<1x1xi1>, tensor<i64>, tensor<i64>) -> tensor<4x4xi1>
    %2 = stablehlo.dynamic_update_slice %1, %c_1, %c, %c_5 : (tensor<4x4xi1>, tensor<1x1xi1>, tensor<i64>, tensor<i64>) -> tensor<4x4xi1>
    %3 = stablehlo.dynamic_update_slice %2, %c_1, %c_0, %c_4 : (tensor<4x4xi1>, tensor<1x1xi1>, tensor<i64>, tensor<i64>) -> tensor<4x4xi1>
    %4 = stablehlo.dynamic_update_slice %3, %c_1, %c, %c_4 : (tensor<4x4xi1>, tensor<1x1xi1>, tensor<i64>, tensor<i64>) -> tensor<4x4xi1>
    %5 = stablehlo.dynamic_update_slice %4, %c_1, %c, %c_0 : (tensor<4x4xi1>, tensor<1x1xi1>, tensor<i64>, tensor<i64>) -> tensor<4x4xi1>
    %6 = stablehlo.transpose %5, dims = [1, 0] : (tensor<4x4xi1>) -> tensor<4x4xi1>
    return %6 : tensor<4x4xi1>
  }
}

Now it is:

julia> @code_hlo NNlib.make_causal_mask(x_ra)
Module:
module attributes {transform.with_named_sequence} {
  func.func @main(%arg0: tensor<4x4xf64>) -> tensor<4x4xi1> {
    %c = stablehlo.constant dense<[[true, false, false, false], [true, true, false, false], [true, true, true, false], [true, true, true, true]]> : tensor<4x4xi1>
    return %c : tensor<4x4xi1>
  }
}
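
The mask depends only on the side length taken from `x`, not on the values in `x`, which is why it can be emitted as a single constant. A minimal sketch in plain Julia, assuming the convention that entry `(i, j)` is true when `i <= j`; `causal_mask` is a hypothetical helper for illustration, not the code from this PR:

```julia
# Hypothetical helper (not from the PR): build an n x n causal mask
# where entry (i, j) is true exactly when i <= j.
causal_mask(n::Integer) = [i <= j for i in 1:n, j in 1:n]

# Only the size of `x` along `dims` is used, never its contents.
causal_mask(x::AbstractArray; dims::Int=2) = causal_mask(size(x, dims))

mask = causal_mask(rand(4, 4))
```

Because nothing here reads the elements of `x`, a tracing compiler can evaluate the whole expression ahead of time instead of recording one update per masked entry.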

@github-actions github-actions bot left a comment


Reactant.jl Benchmarks

| Benchmark suite | Current: 2208689 | Previous: 3ba7c3e | Ratio |
|---|---|---|---|
| ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) | 6858075737 ns | 5787425685 ns | 1.18 |
| ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant | 5072161342 ns | 5292258390 ns | 0.96 |
| ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) | 5293042416 ns | 6086056532 ns | 0.87 |
| ViT base (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) | 7570320604 ns | 7587601119 ns | 1.00 |
| ViT base (256 x 256 x 3 x 32)/forward/CPU/Lux | 34386587369 ns | 28087750784 ns | 1.22 |
| ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) | 1567026340 ns | 1563822331 ns | 1.00 |
| ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant | 1547479117 ns | 1543677512 ns | 1.00 |
| ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) | 1538428547 ns | 1553822136 ns | 0.99 |
| ViT small (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) | 3290329710 ns | 3309603029 ns | 0.99 |
| ViT small (256 x 256 x 3 x 4)/forward/CPU/Lux | 2513790076 ns | 3236551447 ns | 0.78 |
| ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) | 2130550406 ns | 2198150190 ns | 0.97 |
| ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant | 2125615319 ns | 2155687426 ns | 0.99 |
| ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) | 2129131440 ns | 2192886728 ns | 0.97 |
| ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) | 3953179944 ns | 3908194881 ns | 1.01 |
| ViT tiny (256 x 256 x 3 x 32)/forward/CPU/Lux | 5651667419 ns | 5993416352 ns | 0.94 |
| ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) | 1417477536 ns | 1406808783.5 ns | 1.01 |
| ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant | 1418453633 ns | 1407299141 ns | 1.01 |
| ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) | 1410175057 ns | 1410969730 ns | 1.00 |
| ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) | 3155846862 ns | 3156311368 ns | 1.00 |
| ViT tiny (256 x 256 x 3 x 4)/forward/CPU/Lux | 1423440128 ns | 1099155376.5 ns | 1.30 |
| ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) | 1706126657 ns | 1727787162 ns | 0.99 |
| ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant | 1708603124 ns | 1727804980 ns | 0.99 |
| ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) | 1713281205 ns | 1711663111 ns | 1.00 |
| ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) | 3507448246 ns | 3460051766 ns | 1.01 |
| ViT tiny (256 x 256 x 3 x 16)/forward/CPU/Lux | 3101266599 ns | 3010659432 ns | 1.03 |
| ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) | 2243348680 ns | 2148427239 ns | 1.04 |
| ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant | 2222823475 ns | 2170426380 ns | 1.02 |
| ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) | 2254280155 ns | 2187259107 ns | 1.03 |
| ViT small (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) | 3982137313 ns | 3958804601 ns | 1.01 |
| ViT small (256 x 256 x 3 x 16)/forward/CPU/Lux | 6647248361.5 ns | 6647100753 ns | 1.00 |
| ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :after_enzyme) | 3057498046 ns | 3146044029 ns | 0.97 |
| ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant | 3036622857 ns | 3146912971 ns | 0.96 |
| ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :before_enzyme) | 2982009512 ns | 3047329260 ns | 0.98 |
| ViT small (256 x 256 x 3 x 32)/forward/CPU/Reactant (optimize = :only_enzyme) | 4852912583 ns | 4862728550 ns | 1.00 |
| ViT small (256 x 256 x 3 x 32)/forward/CPU/Lux | 20924167144 ns | 12794226734 ns | 1.64 |
| ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :after_enzyme) | 3190120181 ns | 3132478421 ns | 1.02 |
| ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant | 3234498742 ns | 3179953038 ns | 1.02 |
| ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :before_enzyme) | 3247140333 ns | 3185074336 ns | 1.02 |
| ViT base (256 x 256 x 3 x 16)/forward/CPU/Reactant (optimize = :only_enzyme) | 5060562695 ns | 5092564084 ns | 0.99 |
| ViT base (256 x 256 x 3 x 16)/forward/CPU/Lux | 15937052624 ns | 12253319305 ns | 1.30 |
| ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :after_enzyme) | 1865518703 ns | 1855345054 ns | 1.01 |
| ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant | 1841417056 ns | 1849809131 ns | 1.00 |
| ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :before_enzyme) | 1840848716 ns | 1855337197 ns | 0.99 |
| ViT base (256 x 256 x 3 x 4)/forward/CPU/Reactant (optimize = :only_enzyme) | 3581499468 ns | 3604644289 ns | 0.99 |
| ViT base (256 x 256 x 3 x 4)/forward/CPU/Lux | 3065045482 ns | 5868629461.5 ns | 0.52 |

This comment was automatically generated by workflow using github-action-benchmark.

wsmoses commented Nov 9, 2024

Separately, the first listing indicates that we need to write constant propagation for dynamic_update_slice. A dynamic_update_slice whose operands are all constants should never be left in the IR.
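
To illustrate what folding such an operation involves, here is a rough sketch in plain Julia rather than an actual MLIR pass. `fold_dynamic_update_slice` is a hypothetical name, the start indices are assumed to be 0-based as in the listing, and out-of-range starts are assumed to be clamped so the update fits inside the operand:

```julia
# Hypothetical helper (illustration only, not a Reactant or StableHLO API):
# evaluate a dynamic_update_slice whose operand, update, and start indices
# are all known constants, returning the resulting constant array.
function fold_dynamic_update_slice(operand::AbstractArray{T,N},
                                   update::AbstractArray{T,N},
                                   start::NTuple{N,Int}) where {T,N}
    # Assumption: 0-based start indices, clamped so the update stays in bounds.
    adjusted = ntuple(d -> clamp(start[d], 0, size(operand, d) - size(update, d)), N)
    # Convert to Julia's 1-based ranges covering the updated window.
    window = ntuple(d -> (adjusted[d] + 1):(adjusted[d] + size(update, d)), N)
    result = copy(operand)   # leave the original constant untouched
    result[window...] = update
    return result
end

base = fill(true, 4, 4)
folded = fold_dynamic_update_slice(base, fill(false, 1, 1), (2, 0))
```

Applying this repeatedly to a chain of updates collapses the chain into one constant, which is the effect a constant-propagation pattern would have on IR like the first listing.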

wsmoses commented Nov 9, 2024

cc @Pangoraw re the latter comment

@wsmoses wsmoses merged commit 6217123 into main Nov 9, 2024
22 of 32 checks passed
@wsmoses wsmoses deleted the ap/causal_mask branch November 9, 2024 18:39
Pangoraw pushed a commit to Pangoraw/Reactant.jl that referenced this pull request Nov 11, 2024