@@ -714,6 +714,160 @@ func.func @test_recursively_speculatable_op_failure(%lb: index, %ub: index, %ste
   return
 }
 
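+// A speculatable op whose only side effect is a memory read is hoisted out of
+// the loop and guarded by an scf.if on the `ult(lb, ub)` loop-guard condition;
+// the else branch yields ub.poison for the case where the loop never executes.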
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success
+func.func @test_speculatable_op_with_read_side_effect_success(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: test.always_speculatable_op
+  // CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LB:.*]], %[[UB:.*]] : index
+  // CHECK-NEXT: scf.if %[[CMP]]
+  // CHECK-NEXT: test.speculatable_op_with_memread
+  // CHECK: else
+  // CHECK-NEXT: ub.poison : i32
+  // CHECK: scf.for %[[_:.*]] = %[[LB]] to %[[UB]]
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK-NOT: test.speculatable_op_with_memread
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always_speculate = "test.always_speculatable_op"() : () -> i32
+    %only_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %i_cast = arith.index_cast %i : index to i32
+    %add = arith.addi %acc, %i_cast : i32
+    %sum = arith.addi %add, %only_read : i32
+    scf.yield %sum : i32
+  }
+  return %sum_result : i32
+}
+
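+// Same as above, but the hoisted read has multiple results, so the guarding
+// scf.if yields one ub.poison value per result type.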
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_multiple_result_success
+func.func @test_speculatable_op_with_read_side_effect_multiple_result_success(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: test.always_speculatable_op
+  // CHECK-NEXT: %[[CMP:.*]] = arith.cmpi ult, %[[LB:.*]], %[[UB:.*]] : index
+  // CHECK-NEXT: scf.if %[[CMP]]
+  // CHECK-NEXT: test.speculatable_op_with_memread
+  // CHECK: else
+  // CHECK-NEXT: ub.poison : i32
+  // CHECK-NEXT: ub.poison : f32
+  // CHECK: scf.for %[[_:.*]] = %[[LB]] to %[[UB]]
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK-NOT: test.speculatable_op_with_memread
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always_speculate = "test.always_speculatable_op"() : () -> i32
+    %only_read:2 = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> (i32, f32)
+    %i_cast = arith.index_cast %i : index to i32
+    %add = arith.addi %acc, %i_cast : i32
+    %sum = arith.addi %add, %only_read#0 : i32
+    scf.yield %sum : i32
+  }
+  return %sum_result : i32
+}
+
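+// Loop-invariant ops that depend on hoisted reads are hoisted as well, each
+// wrapped in its own guarding scf.if.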
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_success_with_dependents
+func.func @test_speculatable_op_with_read_side_effect_success_with_dependents(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: %[[ALWAYS:.*]] = "test.always_speculatable_op"
+  // CHECK-NEXT: %[[CMP0:.*]] = arith.cmpi ult, %[[LB:.*]], %[[UB:.*]] : index
+  // CHECK-NEXT: %[[IF0:.*]] = scf.if %[[CMP0]]
+  // CHECK-NEXT: test.speculatable_op_with_memread
+  // CHECK: else
+  // CHECK-NEXT: ub.poison : i32
+  // CHECK: %[[CMP1:.*]] = arith.cmpi ult, %[[LB]], %[[UB]] : index
+  // CHECK-NEXT: %[[IF1:.*]] = scf.if %[[CMP1]]
+  // CHECK-NEXT: arith.addi %[[ALWAYS]], %[[IF0]]
+  // CHECK: else
+  // CHECK-NEXT: ub.poison : i32
+  // CHECK: %[[CMP2:.*]] = arith.cmpi ult, %[[LB]], %[[UB]] : index
+  // CHECK-NEXT: %[[IF2:.*]] = scf.if %[[CMP2]]
+  // CHECK-NEXT: test.speculatable_op_with_memread
+  // CHECK: else
+  // CHECK-NEXT: ub.poison : i32
+  // CHECK: %[[CMP3:.*]] = arith.cmpi ult, %[[LB]], %[[UB]] : index
+  // CHECK-NEXT: %{{.*}} = scf.if %[[CMP3]]
+  // CHECK-NEXT: arith.addi %[[IF1]], %[[IF2]]
+  // CHECK: else
+  // CHECK-NEXT: ub.poison : i32
+  // CHECK: scf.for %{{.*}} = %[[LB]] to %[[UB]]
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK-NOT: test.speculatable_op_with_memread
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always_speculate = "test.always_speculatable_op"() : () -> i32
+    %only_read_0 = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %add_0 = arith.addi %always_speculate, %only_read_0 : i32
+    %only_read_1 = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %add_1 = arith.addi %add_0, %only_read_1 : i32
+    %i_cast = arith.index_cast %i : index to i32
+    %sum = arith.addi %add_1, %i_cast : i32
+    scf.yield %sum : i32
+  }
+  return %sum_result : i32
+}
+
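+// A memory write anywhere in the loop body prevents hoisting the read.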
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_failure_due_to_write
+func.func @test_speculatable_op_with_read_side_effect_failure_due_to_write(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: test.always_speculatable_op
+  // CHECK-NEXT: scf.for
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK: test.speculatable_op_with_memread
+  // CHECK: test.speculatable_op_with_memwrite
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always_speculate = "test.always_speculatable_op"() : () -> i32
+    %only_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %i_cast = arith.index_cast %i : index to i32
+    %add = arith.addi %acc, %i_cast : i32
+    %sum = arith.addi %add, %only_read : i32
+    %write = "test.speculatable_op_with_memwrite"(%cst_42) : (tensor<64xi32>) -> i32
+    scf.yield %sum : i32
+  }
+  return %sum_result : i32
+}
+
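+// A write nested inside inner control flow (scf.for/scf.if) likewise prevents
+// hoisting the read.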
+// CHECK-LABEL: test_speculatable_op_with_read_side_effect_failure_due_to_nested_write
+func.func @test_speculatable_op_with_read_side_effect_failure_due_to_nested_write(%lb: index, %ub: index, %step: index) -> i32 {
+  // CHECK: test.always_speculatable_op
+  // CHECK-NEXT: scf.for
+  // CHECK-NOT: test.always_speculatable_op
+  // CHECK: test.speculatable_op_with_memread
+  // CHECK: scf.for
+  // CHECK: scf.if
+  // CHECK: test.speculatable_op_with_memwrite
+  %cst_0 = arith.constant 0 : i32
+  %cst_42 = arith.constant dense<42> : tensor<64xi32>
+  %ind_42 = arith.constant 42 : index
+  %sum_result = scf.for %i = %lb to %ub step %step iter_args(%acc = %cst_0) -> i32 {
+    %always_speculate = "test.always_speculatable_op"() : () -> i32
+    %only_read = "test.speculatable_op_with_memread"(%cst_42, %ind_42) : (tensor<64xi32>, index) -> i32
+    %i_cast = arith.index_cast %i : index to i32
+    %add = arith.addi %acc, %i_cast : i32
+    %sum = arith.addi %add, %only_read : i32
+    scf.for %j = %lb to %ub step %step {
+      %eq42 = arith.cmpi eq, %j, %ind_42 : index
+      scf.if %eq42 {
+        %always_write = "test.speculatable_op_with_memwrite"(%cst_42) : (tensor<64xi32>) -> i32
+      }
+    }
+    scf.yield %sum : i32
+  }
+  return %sum_result : i32
+}
+
 // -----
 
 func.func @speculate_tensor_dim_unknown_rank_unknown_dim(