In [1]:
import torch
import torch.nn as nn
from brt.runtime import log
from brt.app.rand import RandScatter
from brt.router import GatherRouter


class MoE(nn.Module):
    def __init__(self):
        super().__init__()
        self.rand_scatter = RandScatter(path_num=2)
        self.expert1 = nn.Identity()
        self.expert2 = nn.Identity()
        self.gather_router = GatherRouter()
        self.iteration = 1
        self.ret = 1

    def forward(self, x):
        route_results = self.rand_scatter(x)
        x_0 = self.expert1(route_results[0])
        x_1 = self.expert2(route_results[1])
        x = self.gather_router([x_0, x_1])
        return x

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.moe = MoE()

    def forward(self, x):
        x = self.moe(x)
        return x


moe_model = SimpleModel()

indata = torch.arange(0, 40, dtype=torch.float32).view(4, 10)
outdata = moe_model(indata)
print(outdata)


  from .autonotebook import tqdm as notebook_tqdm


Starting scatter_router_1
4
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
        [20., 21., 22., 23., 24., 25., 26., 27., 28., 29.],
        [30., 31., 32., 33., 34., 35., 36., 37., 38., 39.]])


In [2]:
from brt.trace.graph import GraphTracer
from torch.fx.graph_module import GraphModule
from brt.runtime import BRT_CACHE_PATH
tracer = GraphTracer()
graph = tracer.trace(moe_model)
name = moe_model.__class__.__name__ if isinstance(moe_model, torch.nn.Module) else moe_model.__name__
graph_module= GraphModule(tracer.root, graph, name)

from torch.fx.passes.graph_drawer import FxGraphDrawer

graph_drawer = FxGraphDrawer(graph_module, "brt_model")
with open("a.svg", "wb") as f:
    f.write(graph_drawer.get_dot_graph().create_svg())




In [3]:
print(graph_module.code)
models = graph_module.named_modules()
# for node in graph.nodes:
#     print(node.target, node.args , node.users)

for node in graph.nodes:
    if node.target == "moe.gather_router":
        print(node.args)
        new_args = ([node.args[0][1]],)
        node.args = new_args
        print(node.args)

graph.eliminate_dead_code()
new_graph_module = GraphModule(tracer.root, graph, name)

print(new_graph_module.code)


from torch.fx.passes.graph_drawer import FxGraphDrawer

graph_drawer = FxGraphDrawer(new_graph_module, "new_brt_model")
with open("b.svg", "wb") as f:
    f.write(graph_drawer.get_dot_graph().create_svg())



torch.fx._symbolic_trace.wrap("brt_app_rand_rand_gate")

def forward(self, x):
    rand_gate = brt_app_rand_rand_gate(x, 2)
    moe_rand_scatter_scatter_router = self.moe.rand_scatter.scatter_router(x, rand_gate);  x = rand_gate = None
    getitem = moe_rand_scatter_scatter_router[0]
    moe_expert1 = self.moe.expert1(getitem);  getitem = None
    getitem_1 = moe_rand_scatter_scatter_router[1];  moe_rand_scatter_scatter_router = None
    moe_expert2 = self.moe.expert2(getitem_1);  getitem_1 = None
    moe_gather_router = self.moe.gather_router([moe_expert1, moe_expert2]);  moe_expert1 = moe_expert2 = None
    return moe_gather_router
    
([moe_expert1, moe_expert2],)
([moe_expert2],)

torch.fx._symbolic_trace.wrap("brt_app_rand_rand_gate")

def forward(self, x):
    rand_gate = brt_app_rand_rand_gate(x, 2)
    moe_rand_scatter_scatter_router = self.moe.rand_scatter.scatter_router(x, rand_gate);  x = rand_gate = None
    getitem_1 = moe_rand_scatter_scatter_router[1];  moe_rand_scatte

In [4]:
from brt.trace.graph import GraphTracer
from torch.fx.graph_module import GraphModule
from brt.runtime import BRT_CACHE_PATH
tracer = GraphTracer()
graph = tracer.trace(moe_model)
name = moe_model.__class__.__name__
graph_module= GraphModule(tracer.root, graph, name)
models = graph_module.named_modules()

for node in graph.nodes:
    if node.target == "moe.gather_router":
        new_args = ([node.args[0][1]],)
        node.args = new_args
        print(node.args)

graph.eliminate_dead_code()
new_graph_module = GraphModule(tracer.root, graph, name)
indata = torch.arange(0, 40, dtype=torch.float32).view(4, 10)
outdata = new_graph_module(indata)
print(outdata)

([moe_expert2],)
4
tensor([[ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [30., 31., 32., 33., 34., 35., 36., 37., 38., 39.]])


In [5]:
import torch
import torch.nn as nn
from brt.runtime import log
from brt.app.rand import RandScatter
from brt.router import GatherRouter


class MoE(nn.Module):
    def __init__(self):
        super().__init__()
        self.rand_scatter = RandScatter(path_num=2, capturing=True)
        self.expert1 = nn.Identity()
        self.expert2 = nn.Identity()
        self.gather_router = GatherRouter(
            fabric_kwargs={"sparse": True}, capturing=True
        )
        self.iteration = 1
        self.ret = 1

    def forward(self, x):
        route_results = self.rand_scatter(x)
        x_0 = self.expert1(route_results[0])
        x_1 = self.expert2(route_results[1])
        x = self.gather_router([x_0, x_1])
        return x


class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.moe = MoE()

    def forward(self, x):
        x = self.moe(x)
        return x


moe_model = SimpleModel()

indata = torch.arange(0, 40, dtype=torch.float32).view(4, 10)
outdata = moe_model(indata)
print(outdata)
print(moe_model.moe.rand_scatter.scatter_router.load_history)
print(moe_model.moe.gather_router.load_history)


Starting scatter_router_1
4
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
        [20., 21., 22., 23., 24., 25., 26., 27., 28., 29.],
        [30., 31., 32., 33., 34., 35., 36., 37., 38., 39.]])
[1. 3.]
[1. 3.]


In [6]:
from brt.passes import get_pass

eliminate_pass_cls = get_pass("dead_path_eliminate")
eliminate_pass = eliminate_pass_cls(moe_model)
eliminate_pass.run_on_graph()
new_moe_model = eliminate_pass.finalize()
print(new_moe_model.code)
indata = torch.arange(0, 40, dtype=torch.float32).view(4, 10)
outdata = new_moe_model(indata)
print(outdata)


ImportError: cannot import name 'get_pass' from 'brt.passes' (/home/yichuanjiaoda/brainstorm_project/brainstorm/python/brt/passes/__init__.py)

In [None]:
import torch
import torch.nn as nn
import torch.fx as fx

class SimpleNet(nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self.indentity = nn.Identity()

    def forward(self, y, z):
        y = self.indentity(y)
        z = self.indentity(z)
        return y, z

simple_net = SimpleNet()
x = torch.randn(2, 3)
z = torch.randn(2, 3)
y = simple_net(x,z)
print(y)


(tensor([[ 1.5824, -0.4830,  1.0120],
        [ 0.9769, -0.0374,  1.3430]]), tensor([[ 0.2445,  1.5441, -0.0383],
        [ 1.4352,  0.0041,  0.6622]]))


In [1]:
import torch
import torch.nn as nn
from brt.runtime import log
from brt.app.rand import RandScatter
from brt.router import GatherRouter


class MoE(nn.Module):
    def __init__(self):
        super().__init__()
        self.rand_scatter = RandScatter(path_num=2)
        self.expert1 = nn.Identity()
        self.expert2 = nn.Identity()
        self.gather_router = GatherRouter()
        self.iteration = 1
        self.ret = 1
        self.relu=nn.ReLU()

    def forward(self, x):
        route_results = self.rand_scatter(x)
        x_0 = self.expert1(route_results[0])
        x_1 = self.expert2(route_results[1])
        x_0=self.relu(x_0)
        x_1=self.relu(x_1)
        
        
        x = self.gather_router([x_0, x_1])
        
        x= self.relu(x)
        route_results = self.rand_scatter(x)
        x_0 = self.expert1(route_results[0])
        x_1 = self.expert2(route_results[1])
        x_0=self.relu(x_0)
        x_1=self.relu(x_1)
        
        
        x = self.gather_router([x_0, x_1])
        x=torch.mul(x,2)
        return x

class SimpleModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.moe = MoE()

    def forward(self, x):
        x = self.moe(x)
        return x


moe_model = SimpleModel()

indata = torch.arange(0, 40, dtype=torch.float32).view(4, 10)
outdata = moe_model(indata)
print(outdata)


  from .autonotebook import tqdm as notebook_tqdm


Starting scatter_router_1
score:  tensor([[-0.2148, -1.8816],
        [-0.7317,  1.6150],
        [-1.4599,  1.6989],
        [-0.2382,  1.2885]])
score:  tensor([[-0.2148, -1.8816],
        [-0.7317,  1.6150],
        [-1.4599,  1.6989],
        [-0.2382,  1.2885]])
tensor([[ 0.,  2.,  4.,  6.,  8., 10., 12., 14., 16., 18.],
        [20., 22., 24., 26., 28., 30., 32., 34., 36., 38.],
        [40., 42., 44., 46., 48., 50., 52., 54., 56., 58.],
        [60., 62., 64., 66., 68., 70., 72., 74., 76., 78.]])


In [2]:
from brt.trace.graph import GraphTracer
from torch.fx.graph_module import GraphModule
from brt.runtime import BRT_CACHE_PATH
tracer = GraphTracer()
graph = tracer.trace(moe_model)
name = moe_model.__class__.__name__ if isinstance(moe_model, torch.nn.Module) else moe_model.__name__
graph_module= GraphModule(tracer.root, graph, name)

from torch.fx.passes.graph_drawer import FxGraphDrawer

graph_drawer = FxGraphDrawer(graph_module, "brt_model")
with open("a.svg", "wb") as f:
    f.write(graph_drawer.get_dot_graph().create_svg())

score:  tensor([[-0.2148, -1.8816],
        [-0.7317,  1.6150],
        [-1.4599,  1.6989],
        [-0.2382,  1.2885]])
score:  tensor([[-0.2148, -1.8816],
        [-0.7317,  1.6150],
        [-1.4599,  1.6989],
        [-0.2382,  1.2885]])


In [3]:
print(graph_module.code)
models = graph_module.named_modules()
# for node in graph.nodes:
#     print(node.target, node.args , node.users)

for node in graph.nodes:
    if node.target == "moe.gather_router":
        print(node.args)
        new_args = ([node.args[0][1]],)
        node.args = new_args
        print(node.args)

graph.eliminate_dead_code()
new_graph_module = GraphModule(tracer.root, graph, name)

print(new_graph_module.code)


from torch.fx.passes.graph_drawer import FxGraphDrawer

graph_drawer = FxGraphDrawer(new_graph_module, "new_brt_model")
with open("b.svg", "wb") as f:
    f.write(graph_drawer.get_dot_graph().create_svg())





def forward(self, x):
    _tensor_constant0 = self._tensor_constant0
    moe_rand_scatter_scatter_router = self.moe.rand_scatter.scatter_router(x, _tensor_constant0);  x = _tensor_constant0 = None
    getitem = moe_rand_scatter_scatter_router[0]
    moe_expert1 = self.moe.expert1(getitem);  getitem = None
    getitem_1 = moe_rand_scatter_scatter_router[1];  moe_rand_scatter_scatter_router = None
    moe_expert2 = self.moe.expert2(getitem_1);  getitem_1 = None
    moe_relu = self.moe.relu(moe_expert1);  moe_expert1 = None
    moe_relu_1 = self.moe.relu(moe_expert2);  moe_expert2 = None
    moe_gather_router = self.moe.gather_router([moe_relu, moe_relu_1]);  moe_relu = moe_relu_1 = None
    moe_relu_2 = self.moe.relu(moe_gather_router);  moe_gather_router = None
    _tensor_constant1 = self._tensor_constant1
    moe_rand_scatter_scatter_router_1 = self.moe.rand_scatter.scatter_router(moe_relu_2, _tensor_constant1);  moe_relu_2 = _tensor_constant1 = None
    getitem_2 = moe_rand_scatte

In [4]:
m = nn.ReLU()
input = torch.randn(2)
print(input)
output = m(input)
print(output)

tensor([1.7595, 1.0445])
tensor([1.7595, 1.0445])


In [1]:
import torch
from torch.fx import symbolic_trace
import operator
from torch import nn

"""
How to Replace One Op With Another
1. Iterate through all Nodes in your GraphModule's Graph.
2. Determine if the current Node should be replaced. (Suggested: match
on the Node's ``target`` attribute).
3. Create a replacement Node and add it to the Graph.
4. Use the FX built-in ``replace_all_uses_with`` to replace all uses of
the current Node with the replacement.
5. Delete the old Node from the graph.
6. Call ``recompile`` on the GraphModule. This updates the generated
Python code to reflect the new Graph state.
Currently, FX does not provide any way to guarantee that replaced
operators are syntactically valid. It's up to the user to confirm that
any new operators will work with the existing operands.
The following code demonstrates an example of replacing any instance of
addition with a bitwise AND.
To examine how the Graph evolves during op replacement, add the
statement `print(traced.graph)` after the line you want to inspect.
Alternatively, call `traced.graph.print_tabular()` to see the IR in a
tabular format.
"""

# Define model
class NeuralNetwork(nn.Module):
    def __init__(self):
        super().__init__()
        self.indentify1= nn.Identity()
        self.indentify2= nn.Identity()
        
        
        
        self.flatten = nn.Flatten()
        self.linear_relu_stack = nn.Sequential(
            nn.Linear(28*28, 512),
            nn.ReLU(),
            nn.Linear(512, 512),
            nn.ReLU(),
            nn.Linear(512, 10)
        )

    def forward(self, x):
        x = self.flatten(x)
        logits = self.linear_relu_stack(x)
        return logits

# Sample module
class M(torch.nn.Module):
    def forward(self, x, y):
        return x + y, torch.add(x, y), x.add(y)

# Symbolically trace an instance of the module
traced = symbolic_trace(M())

# As demonstrated in the above example, there are several different ways
# to denote addition. The possible cases are:
#     1. `x + y` - A `call_function` Node with target `operator.add`.
#         We can match for equality on that `operator.add` directly.
#     2. `torch.add(x, y)` - A `call_function` Node with target
#         `torch.add`. Similarly, we can match this function directly.
#     3. `x.add(y)` - The Tensor method call, whose target we can match
#         as a string.

patterns = set([operator.add, torch.add, "add"])

print(traced.graph)
# Go through all the nodes in the Graph
for n in traced.graph.nodes:
    # If the target matches one of the patterns
    if any(n.target == pattern for pattern in patterns):
        # Set the insert point, add the new node, and replace all uses
        # of `n` with the new node
        with traced.graph.inserting_after(n):
            new_node = traced.graph.call_function(torch.bitwise_and, n.args, n.kwargs)
            import pdb; pdb.set_trace()
            n.replace_all_uses_with(new_node)
        # Remove the old node from the graph
        traced.graph.erase_node(n)

# Don't forget to recompile!
traced.recompile()
print(traced.graph)

  from .autonotebook import tqdm as notebook_tqdm


graph():
    %x : [#users=3] = placeholder[target=x]
    %y : [#users=3] = placeholder[target=y]
    %add : [#users=1] = call_function[target=operator.add](args = (%x, %y), kwargs = {})
    %add_1 : [#users=1] = call_function[target=torch.add](args = (%x, %y), kwargs = {})
    %add_2 : [#users=1] = call_method[target=add](args = (%x, %y), kwargs = {})
    return (add, add_1, add_2)
> [0;32m/tmp/ipykernel_32228/1852051842.py[0m(80)[0;36m<module>[0;34m()[0m
[0;32m     78 [0;31m            [0mnew_node[0m [0;34m=[0m [0mtraced[0m[0;34m.[0m[0mgraph[0m[0;34m.[0m[0mcall_function[0m[0;34m([0m[0mtorch[0m[0;34m.[0m[0mbitwise_and[0m[0;34m,[0m [0mn[0m[0;34m.[0m[0margs[0m[0;34m,[0m [0mn[0m[0;34m.[0m[0mkwargs[0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m     79 [0;31m            [0;32mimport[0m [0mpdb[0m[0;34m;[0m [0mpdb[0m[0;34m.[0m[0mset_trace[0m[0;34m([0m[0;34m)[0m[0;34m[0m[0;34m[0m[0m
[0m[0;32m---> 80 [0;31m            