In [None]:
# This developer tutorial discusses patterns -- sequences of operations that can be matched and replaced with traceable functions. 
# It's a work-in-progress, and it currently only discusses how patterns can be constructed and how they're matched,
#   along with some related utilities.

In [1]:
# Imports the modules, classes, and functions we'll need for this tutorial
import torch

import thunder
from thunder.core.patterns import Pattern, bind_names, numbered_ancestors
from thunder.core.proxies import TensorProxy
from thunder.core.symbol import BoundSymbol

In [3]:
# To match a pattern, start by creating a Pattern object.
# Matchers are then attached to it with p.match(...), as shown in the next cell.
p = Pattern()

In [4]:
# Then define one or more "matchers" that determine if a BoundSymbol is a "match", and add them to the 
#   pattern using its match() method

# The matcher signature not only accepts a BoundSymbol to review, but also a list of BoundSymbols that were
#   already matched by the pattern, and a match_ctx dictionary that contains whatever state you like from previous matches
# The matcher returns a (bool, dict | None) tuple. The bool is True if the BoundSymbol should be matched, and False
#   otherwise. When it is True, the dict is used to update the match_ctx for future matches (when it is False, the
#   examples below return None instead). This will be clearer in a moment with an example.
# The following matcher is very permissive, and it matches any add operation.
def add_matcher(bsym: BoundSymbol, *, previously_matched: list[BoundSymbol], match_ctx: dict) -> tuple[bool, None | dict]:
    # Permissive matcher: any BoundSymbol whose symbol is named 'add' is accepted, regardless of its
    #   arguments or of what was matched before. It records nothing, so its context update is empty.
    is_add = bsym.sym.name == 'add'
    return (True, {}) if is_add else (False, None)

# Example inputs for the programs traced below; the traces record them as "cpu f32[2, 2]"
a, b = torch.randn((2, 2)), torch.randn((2, 2))

# An example program that performs an addition and then a subtraction on the same inputs
def foo(a, b):
    c, d = a + b, a - b
    return c, d
# Traces foo with the example tensors; the pattern is matched against this trace
trc = thunder.trace()(foo, a, b)

# Registers add_matcher with the pattern, so the pattern matches any addition
p.match(add_matcher)

# Calling the Pattern object on a trace returns a list of matches.
# Each match is a list of (int, BoundSymbol) tuples, where int is the
#   position of the BoundSymbol in the trace.
matches = p(trc)

# In this case there is just one match -- the addition (the subtraction is not an 'add', so it isn't matched)
print(matches)

[[(2, t0 = ltorch.add(a, b, alpha=None)  # t0: "cpu f32[2, 2]"
  # t0 = prims.add(a, b)  # t0: "cpu f32[2, 2]")]]


In [5]:
def foo(a, b):
    # Two separate, identical additions of the same inputs
    c, d = a + b, a + b
    return c, d
# Re-traces the new foo, which now contains two additions
trc = thunder.trace()(foo, a, b)

# When the program is changed to include two additions, both additions
#   are matched and two matches are created -- each addition becomes its own separate match
#   of the single-matcher pattern p.
matches = p(trc)
print(matches)

[[(2, t0 = ltorch.add(a, b, alpha=None)  # t0: "cpu f32[2, 2]"
  # t0 = prims.add(a, b)  # t0: "cpu f32[2, 2]")], [(3, t1 = ltorch.add(a, b, alpha=None)  # t1: "cpu f32[2, 2]"
  # t1 = prims.add(a, b)  # t1: "cpu f32[2, 2]")]]


In [6]:
# In addition to matching a single operation, a pattern can match any number of sequential operations --
#   that is, operations that are immediately adjacent to each other. We do this by providing
#   max_times and (optionally) min_times arguments to match()
# Negative max_times values are interpreted as matching the pattern any number of times
# Matching multiple operations occurs greedily and before any additional matching can occur

# A fresh pattern whose only matcher may consume one or more adjacent additions
#   (min_times=1 requires at least one; max_times=-1 means there is no upper bound)
p = Pattern()
p.match(add_matcher, min_times=1, max_times=-1)

# trc is still the two-addition trace from the previous cell.
# The pattern now matches once, and the single match contains both additions
matches = p(trc)
print(matches)

[[(2, t0 = ltorch.add(a, b, alpha=None)  # t0: "cpu f32[2, 2]"
  # t0 = prims.add(a, b)  # t0: "cpu f32[2, 2]"), (3, t1 = ltorch.add(a, b, alpha=None)  # t1: "cpu f32[2, 2]"
  # t1 = prims.add(a, b)  # t1: "cpu f32[2, 2]")]]


In [7]:
# Multiple operations can also be matched by calling match() multiple times. The matchers
#   are evaluated in the order their match() calls were made.

def foo(a, b):
    # An addition followed by a subtraction, both taking (a, b) in the same order
    c, d = a + b, a - b
    return c, d
# Traces foo with the example tensors a and b; the add-then-sub pattern below is matched against this trace
trc = thunder.trace()(foo, a, b)

# Let's match an addition followed by a subtraction on the same inputs. This will also show how to update the match_ctx dict
#   and let us use the bind_names() utility.
def add_matcher(bsym: BoundSymbol, *, previously_matched: list[BoundSymbol], match_ctx: dict) -> tuple[bool, None | dict]:
    # Anything that isn't an addition is rejected up front
    if bsym.sym.name != 'add':
        return False, None

    # bind_names() produces an object whose attributes are named after the function's (Symbol's) parameters;
    #   reading an attribute yields the argument bound to that parameter
    bound = bind_names(bsym)

    # Records the addition's inputs in the context so that later matchers can read them from match_ctx
    return True, {'a': bound.a, 'b': bound.b}

def sub_matcher(bsym: BoundSymbol, *, previously_matched: list[BoundSymbol], match_ctx: dict) -> tuple[bool, None | dict]:
    # Anything that isn't a subtraction is rejected up front
    if bsym.sym.name != 'sub':
        return False, None

    bound = bind_names(bsym)

    # The addition's inputs, as stored in match_ctx by add_matcher
    add_a, add_b = match_ctx['a'], match_ctx['b']

    # Matches the sub only if its arguments are the very same objects (identity, not equality)
    #   as the addition's, and in the same order
    same_args = add_a is bound.a and add_b is bound.b
    return (True, {}) if same_args else (False, None)

# Matchers run in the order they were added: first an addition, then a subtraction
#   whose arguments are the same objects as the addition's. The following cells reuse this p.
p = Pattern()
p.match(add_matcher)
p.match(sub_matcher)

# Matches the addition and the subtraction
matches = p(trc)
print(matches)

[[(2, t0 = ltorch.add(a, b, alpha=None)  # t0: "cpu f32[2, 2]"
  # t0 = prims.add(a, b)  # t0: "cpu f32[2, 2]"), (3, t1 = ltorch.sub(a, b, alpha=None)  # t1: "cpu f32[2, 2]"
  # t1 = prims.sub(a, b)  # t1: "cpu f32[2, 2]")]]


In [None]:
# Another version of the above example that uses the previously_matched argument to decide whether to match
#   the subtraction

def foo(a, b):
    # An addition followed by a subtraction, both taking (a, b) in the same order
    c, d = a + b, a - b
    return c, d
# Traces foo with the example tensors a and b; the pattern below is matched against this trace
trc = thunder.trace()(foo, a, b)

def add_matcher(bsym: BoundSymbol, *, previously_matched: list[BoundSymbol], match_ctx: dict) -> tuple[bool, None | dict]:
    # Accepts every addition. Returning an empty dict leaves match_ctx without new entries --
    #   the context is just scratch space for you, and this matcher doesn't need it
    return (True, {}) if bsym.sym.name == 'add' else (False, None)

def sub_matcher(bsym: BoundSymbol, *, previously_matched: list[BoundSymbol], match_ctx: dict) -> tuple[bool, None | dict]:
    # Anything that isn't a subtraction is rejected up front
    if bsym.sym.name != 'sub':
        return False, None

    # previously_matched is a *list* of the BoundSymbols this pattern has already matched, not a
    #   single BoundSymbol, so the addition has to be taken out of it. With the pattern built below
    #   (add, then sub) the most recently matched symbol is the addition.
    # Guard against an empty list so this matcher can't raise if it is ever used first in a pattern.
    if not previously_matched:
        return False, None
    add_bsym = previously_matched[-1]

    my_bn = bind_names(bsym)
    add_bn = bind_names(add_bsym)

    # Matches the sub only if the arguments are the same as the addition's, and in the same order
    if add_bn.a is my_bn.a and add_bn.b is my_bn.b:
        return True, {}

    return False, None

# Same add-then-sub pattern as before, but here the subtraction is validated against the
#   previously matched addition instead of against values stored in match_ctx
p = Pattern()
p.match(add_matcher)
p.match(sub_matcher)

# Matches the addition and the subtraction
matches = p(trc)
print(matches)

In [8]:
# Operations in a pattern don't have to be next to each other, but they do have to be within a "window" of
#   the previous operation. Currently the window is 5 operations. Each operation also has to be 
#   "reorderable" to be "next to" operations that were already matched. This isn't always
#   possible. If an operation consumes an input that is not directly from a previously matched symbol, but
#   is derived from the output of a previously matched symbol, then it cannot be reordered adjacent to the 
#   other operations in the pattern.
# Let's see how this works with two examples.

# An operation between the addition and the subtraction doesn't stop the previous pattern
#   from matching as expected, because the operation producing d can be reordered to be
#   immediately after the operation producing c
def foo(a, b):
    # The addition producing x sits between the computations of c and d, and uses neither of them
    c, x, d = a + b, a + 2, a - b
    return c, d, x
trc = thunder.trace()(foo, a, b)

# p is still the add-then-sub pattern built in the previous cells.
# The match is the same as when x isn't computed, except that the subtraction now sits at trace position 4
matches = p(trc)
print(matches)

[[(2, t0 = ltorch.add(a, b, alpha=None)  # t0: "cpu f32[2, 2]"
  # t0 = prims.add(a, b)  # t0: "cpu f32[2, 2]"), (4, t2 = ltorch.sub(a, b, alpha=None)  # t2: "cpu f32[2, 2]"
  # t2 = prims.sub(a, b)  # t2: "cpu f32[2, 2]")]]


In [9]:
# Too many intervening operations push the subtraction out of the pattern-matching "window" and prevent
#   the match
# In the future we may expose an option to set the window larger -- share your thoughts by filing an issue!
def foo(a, b):
    c = a + b
    # Eight additions producing x, placed between the computations of c and d
    x = a + 2
    for _ in range(7):
        x = x + 2
    d = a - b
    return c, d, x
trc = thunder.trace()(foo, a, b)

# p is still the add-then-sub pattern built in the earlier cells.
# No matches because the computation of c and the computation of d are separated by too many operations
matches = p(trc)
print(matches)

[]


In [11]:
# The computation of e depends on the computation of d and the computation of c
def foo(a, b):
    # A chain of dependent operations: d is derived from c, and e from both c and d
    c = a + b
    d = c - 5
    e = c + d
    return e
# Traces foo: an add producing c, a sub producing d from c, then an add producing e from c and d
trc = thunder.trace()(foo, a, b)

def add_matcher(bsym: BoundSymbol, *, previously_matched: list[BoundSymbol], match_ctx: dict) -> tuple[bool, None | dict]:
    # Matches any addition, whatever its inputs; nothing is recorded in match_ctx
    if bsym.sym.name != 'add':
        return False, None
    return True, {}

# A pattern of two additions in a row (no subtraction in between)
p = Pattern()
p.match(add_matcher)
p.match(add_matcher)

# Attempting to match two additions fails, because the computation of e cannot be reordered next to the computation of c
matches = p(trc)
print(matches)

[]


In [12]:
# Including the subtraction in the pattern allows it to be matched
def sub_matcher(bsym: BoundSymbol, *, previously_matched: list[BoundSymbol], match_ctx: dict) -> tuple[bool, None | dict]:
    # Matches any subtraction, whatever its inputs; nothing is recorded in match_ctx
    if bsym.sym.name != 'sub':
        return False, None
    return True, {}

# add, sub, add: with the subtraction included, every operation in the match consumes only
#   inputs or outputs of operations already in the match, so no reordering is blocked
p = Pattern()
p.match(add_matcher)
p.match(sub_matcher)
p.match(add_matcher)

matches = p(trc)
print(matches)


[[(2, t0 = ltorch.add(a, b, alpha=None)  # t0: "cpu f32[2, 2]"
  # t0 = prims.add(a, b)  # t0: "cpu f32[2, 2]"), (3, t1 = ltorch.sub(t0, 5, alpha=None)  # t1: "cpu f32[2, 2]"
  # _ = prims.convert_element_type(5, float)
  # t1 = prims.sub(t0, 5.0)  # t1: "cpu f32[2, 2]"), (4, t2 = ltorch.add(t0, t1, alpha=None)  # t2: "cpu f32[2, 2]"
  # t2 = prims.add(t0, t1)  # t2: "cpu f32[2, 2]")]]
