Skip to content

[Fix] Update floormod simplification logic for better expression matching#17764

Closed
Ghosts381937 wants to merge 1 commit intoapache:mainfrom
Ghosts381937:main
Closed

[Fix] Update floormod simplification logic for better expression matching#17764
Ghosts381937 wants to merge 1 commit intoapache:mainfrom
Ghosts381937:main

Conversation

@Ghosts381937
Copy link
Contributor

Description

Update the floormod simplification rule to correctly handle expressions of the form floormod(c1*x, c2*x) by simplifying them to floormod(c1, c2). This enhancement enables better optimization of expressions that contain common factors, which frequently appear in transformer model computations.

Test Case

from tvm import tir
from tvm.arith import Analyzer
from tvm.tir.op import floormod

# Define symbolic variable
past_decoder_sequence_length = tir.Var("past_decoder_sequence_length", "int64")

# Create expressions with common factor
expr = tir.IntImm("int64", 64) * (past_decoder_sequence_length + tir.IntImm("int64", 1))
divisor = tir.IntImm("int64", 31) * (past_decoder_sequence_length + tir.IntImm("int64", 1))

# Create Analyzer
analyzer = Analyzer()

# Before: returns unsimplified expression
# After: correctly simplifies to 2
print(analyzer.simplify(floormod(expr, divisor)))

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant