In [1]:
import torch
import torch.nn as nn
from brt.runtime import log
from brt.app.rand import RandScatter
from brt.router import GatherRouter


class MoE(nn.Module):
    """Two-path mixture-of-experts layer with identity experts.

    The input is split across two paths by ``RandScatter``, each path goes
    through its own ``nn.Identity`` expert, and ``GatherRouter`` merges the
    per-path results back into one output.
    """

    def __init__(self):
        super().__init__()
        # Scatter router producing two routed outputs (indexed [0] and [1] in forward).
        self.rand_scatter = RandScatter(path_num=2)
        self.expert1 = nn.Identity()
        self.expert2 = nn.Identity()
        # Receives the list of expert outputs and merges them.
        self.gather_router = GatherRouter()
        # NOTE(review): `iteration` and `ret` are never read in this cell.
        self.iteration = 1
        self.ret = 1

    def forward(self, x):
        routed = self.rand_scatter(x)
        # Index (rather than tuple-unpack) the routed outputs so that FX tracing
        # emits getitem/expert nodes in path order; the pruning cells below match
        # on `moe.gather_router` and its two-element input list.
        expert_outputs = [self.expert1(routed[0]), self.expert2(routed[1])]
        return self.gather_router(expert_outputs)

class SimpleModel(nn.Module):
    """Top-level wrapper whose forward pass is a single ``MoE`` layer."""

    def __init__(self):
        super().__init__()
        self.moe = MoE()

    def forward(self, x):
        return self.moe(x)


moe_model = SimpleModel()

# A 4x10 float tensor holding 0..39 in row-major order, so rows are easy to
# recognise after routing.
indata = torch.arange(40, dtype=torch.float32).reshape(4, 10)
outdata = moe_model(indata)
print(outdata)


tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
        [20., 21., 22., 23., 24., 25., 26., 27., 28., 29.],
        [30., 31., 32., 33., 34., 35., 36., 37., 38., 39.]])


In [2]:
from brt.trace.graph import GraphTracer
from torch.fx.graph_module import GraphModule
from brt.runtime import BRT_CACHE_PATH
tracer = GraphTracer()
graph = tracer.trace(moe_model)
if isinstance(moe_model, torch.nn.Module):
    name = moe_model.__class__.__name__
else:
    name = moe_model.__name__
graph_module = GraphModule(tracer.root, graph, name)

from torch.fx.passes.graph_drawer import FxGraphDrawer

# Render the traced graph to a.svg for visual inspection of the routing structure.
graph_drawer = FxGraphDrawer(graph_module, "brt_model")
with open("a.svg", "wb") as svg_file:
    dot_graph = graph_drawer.get_dot_graph()
    svg_file.write(dot_graph.create_svg())




In [3]:
print(graph_module.code)

# Re-trace into a fresh graph instead of editing `graph` from the previous cell.
# `graph_module` wraps that same Graph object, so editing it in place leaves
# `graph_module.code` stale. It also makes this cell non-idempotent: on a second
# run the gather router has only one input left, and `node.args[0][1]` raises
# IndexError.
pruned_tracer = GraphTracer()
pruned_graph = pruned_tracer.trace(moe_model)
# Wrap the fresh graph before editing it, as the previous cell does for `graph`.
# This sets the graph's owning module, which eliminate_dead_code() may consult
# when checking whether call_module nodes are removable (depends on torch version).
pruned_owner = GraphModule(pruned_tracer.root, pruned_graph, name)

for node in pruned_graph.nodes:
    if node.op == "call_module" and node.target == "moe.gather_router":
        print(node.args)
        # Keep only the second expert's output as the gather input; the first
        # expert's branch is left without users.
        node.args = ([node.args[0][1]],)
        print(node.args)

# Drop the now-unused first-expert branch (its getitem and expert1 nodes).
pruned_graph.eliminate_dead_code()
new_graph_module = GraphModule(pruned_tracer.root, pruned_graph, name)

print(new_graph_module.code)

from torch.fx.passes.graph_drawer import FxGraphDrawer

graph_drawer = FxGraphDrawer(new_graph_module, "new_brt_model")
with open("b.svg", "wb") as f:
    f.write(graph_drawer.get_dot_graph().create_svg())



torch.fx._symbolic_trace.wrap("brt_app_rand_rand_gate")

def forward(self, x):
    rand_gate = brt_app_rand_rand_gate(x, 2)
    moe_rand_scatter_scatter_router = self.moe.rand_scatter.scatter_router(x, rand_gate);  x = rand_gate = None
    getitem = moe_rand_scatter_scatter_router[0]
    moe_expert1 = self.moe.expert1(getitem);  getitem = None
    getitem_1 = moe_rand_scatter_scatter_router[1];  moe_rand_scatter_scatter_router = None
    moe_expert2 = self.moe.expert2(getitem_1);  getitem_1 = None
    moe_gather_router = self.moe.gather_router([moe_expert1, moe_expert2]);  moe_expert1 = moe_expert2 = None
    return moe_gather_router
    
([moe_expert1, moe_expert2],)
([moe_expert2],)

torch.fx._symbolic_trace.wrap("brt_app_rand_rand_gate")

def forward(self, x):
    rand_gate = brt_app_rand_rand_gate(x, 2)
    moe_rand_scatter_scatter_router = self.moe.rand_scatter.scatter_router(x, rand_gate);  x = rand_gate = None
    getitem_1 = moe_rand_scatter_scatter_router[1];  moe_rand_scatte

In [4]:
from brt.trace.graph import GraphTracer
from torch.fx.graph_module import GraphModule
from brt.runtime import BRT_CACHE_PATH
tracer = GraphTracer()
graph = tracer.trace(moe_model)
name = moe_model.__class__.__name__
# Wrapping the graph sets its owning module before it is edited and pruned below.
graph_module = GraphModule(tracer.root, graph, name)

for node in graph.nodes:
    # Match only the router call itself, not any other node kind that could
    # carry the same target string.
    if node.op == "call_module" and node.target == "moe.gather_router":
        # Feed the gather router only the second expert's output.
        node.args = ([node.args[0][1]],)
        print(node.args)

graph.eliminate_dead_code()
new_graph_module = GraphModule(tracer.root, graph, name)

indata = torch.arange(0, 40, dtype=torch.float32).view(4, 10)
# Which rows come back zeroed appears to depend on the random gate, so seed
# right before the forward pass to make the printed output reproducible.
# NOTE(review): assumes the brt rand gate draws from torch's global RNG - confirm.
torch.manual_seed(0)
outdata = new_graph_module(indata)
print(outdata)

([moe_expert2],)
tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.],
        [ 0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.,  0.]])


In [12]:
import torch
import torch.nn as nn
from brt.runtime import log
from brt.app.rand import RandScatter
from brt.router import GatherRouter


class MoE(nn.Module):
    """Two-path MoE layer whose routers are constructed with ``capturing=True``.

    Same structure as the earlier MoE definition (RandScatter -> two identity
    experts -> GatherRouter). The cell below reads ``load_history`` from both
    routers after a forward pass.
    """

    def __init__(self):
        super().__init__()
        self.rand_scatter = RandScatter(path_num=2, capturing=True)
        self.expert1 = nn.Identity()
        self.expert2 = nn.Identity()
        # NOTE(review): exact meaning of fabric_kwargs={"sparse": True} is defined
        # by brt's GatherRouter - not visible here.
        self.gather_router = GatherRouter(
            fabric_kwargs={"sparse": True}, capturing=True
        )
        # NOTE(review): `iteration` and `ret` are never read in this cell.
        self.iteration = 1
        self.ret = 1

    def forward(self, x):
        routed = self.rand_scatter(x)
        # Index the routed outputs in path order (no tuple unpacking) so the
        # traced graph keeps the getitem -> expert ordering per path.
        expert_outputs = [self.expert1(routed[0]), self.expert2(routed[1])]
        return self.gather_router(expert_outputs)


class SimpleModel(nn.Module):
    """Top-level wrapper whose forward pass is a single (capturing) ``MoE`` layer."""

    def __init__(self):
        super().__init__()
        self.moe = MoE()

    def forward(self, x):
        return self.moe(x)


moe_model = SimpleModel()

# 4x10 float tensor holding 0..39 in row-major order.
indata = torch.arange(40, dtype=torch.float32).reshape(4, 10)
outdata = moe_model(indata)
print(outdata)

# Load history of the scatter router, then of the gather router, read after
# the forward pass above.
for router in (moe_model.moe.rand_scatter.scatter_router, moe_model.moe.gather_router):
    print(router.load_history)


tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
        [20., 21., 22., 23., 24., 25., 26., 27., 28., 29.],
        [30., 31., 32., 33., 34., 35., 36., 37., 38., 39.]])
tensor([0., 4.], dtype=torch.float64)
tensor([0., 4.], dtype=torch.float64)


In [15]:
from brt.passes import get_pass

# Look up the registered "dead_path_eliminate" pass and apply it to the model
# whose routers were run in capturing mode above.
eliminate_pass_cls = get_pass("dead_path_eliminate")
eliminate_pass = eliminate_pass_cls(moe_model)
eliminate_pass.run_on_graph()
new_moe_model = eliminate_pass.finalize()
print(new_moe_model.code)

# Run the rewritten model on the same 4x10 input as before.
indata = torch.arange(40, dtype=torch.float32).reshape(4, 10)
outdata = new_moe_model(indata)
print(outdata)



torch.fx._symbolic_trace.wrap("brt_app_rand_rand_gate")

def forward(self, x):
    rand_gate = brt_app_rand_rand_gate(x, 2)
    moe_rand_scatter_scatter_router = self.moe.rand_scatter.scatter_router(x, rand_gate);  x = rand_gate = None
    getitem = moe_rand_scatter_scatter_router[0];  moe_rand_scatter_scatter_router = None
    moe_expert1 = self.moe.expert1(getitem);  getitem = None
    moe_gather_router = self.moe.gather_router([moe_expert1]);  moe_expert1 = None
    return moe_gather_router
    
ProtoTensor([[10., 11., 12., 13., 14., 15., 16., 17., 18., 19.]])
tag_stack: [tensor([[1]])]
load stack: [4])


In [1]:
import torch
import torch.nn as nn
import torch.fx as fx

class SimpleNet(nn.Module):
    """Passes each of two inputs through one shared ``nn.Identity`` module."""

    def __init__(self) -> None:
        super().__init__()
        # The misspelled attribute name is kept on purpose: the FX rewrite cell
        # matches call_module nodes by node.target == "indentity".
        self.indentity = nn.Identity()

    def forward(self, y, z):
        # Two separate calls (y first, then z) trace to `indentity` and `indentity_1`.
        return self.indentity(y), self.indentity(z)

simple_net = SimpleNet()
# Seed torch's global RNG so the random inputs, and therefore the outputs
# printed here and after the FX rewrite below, are reproducible on re-run.
torch.manual_seed(0)
x = torch.randn(2, 3)
z = torch.randn(2, 3)
# `y` holds the (y, z) output tuple returned by SimpleNet.forward.
y = simple_net(x, z)
print(y)


(tensor([[ 1.5824, -0.4830,  1.0120],
        [ 0.9769, -0.0374,  1.3430]]), tensor([[ 0.2445,  1.5441, -0.0383],
        [ 1.4352,  0.0041,  0.6622]]))


In [2]:
from brt.runtime.memory_plan import EventCollector, EventEmitter
from brt.trace.graph import symbolic_trace
from torch.fx import Node
import operator

graph_m = symbolic_trace(simple_net)

# Register the event modules as submodules of the traced module so that
# call_module nodes can refer to them by attribute name.
event_emitter = EventEmitter(1)
event_collector = EventCollector(1)
graph_m.event_emitter = event_emitter
# NOTE(review): event_collector is attached but no node inserted below calls it.
graph_m.event_collector = event_collector

traced_graph = graph_m.graph
node: Node
for node in traced_graph.nodes:
    if node.target != "indentity":
        continue
    # Insert `event_emitter(node)` directly after each identity call ...
    with traced_graph.inserting_after(node):
        emit_node = traced_graph.call_module("event_emitter", (node,))
    # ... then take element 0 of the emitter's result directly after that.
    with traced_graph.inserting_after(emit_node):
        first_output = traced_graph.call_function(operator.getitem, (emit_node, 0))
    # Point every other user of the identity output at the getitem node; the
    # emitter itself must keep consuming the original identity output.
    node.replace_all_uses_with(first_output, lambda user: user != emit_node)

    traced_graph.lint()

graph_m.recompile()
print(list(graph_m.graph.nodes))
print(graph_m.code)
y = graph_m(x, z)

print(y)

[y, z, indentity, event_emitter, getitem, indentity_1, event_emitter_1, getitem_1, output]



def forward(self, y, z):
    indentity = self.indentity(y);  y = None
    event_emitter = self.event_emitter(indentity);  indentity = None
    getitem = event_emitter[0];  event_emitter = None
    indentity_1 = self.indentity(z);  z = None
    event_emitter_1 = self.event_emitter(indentity_1);  indentity_1 = None
    getitem_1 = event_emitter_1[0];  event_emitter_1 = None
    return (getitem, getitem_1)
    
(tensor([[ 1.5824, -0.4830,  1.0120],
        [ 0.9769, -0.0374,  1.3430]]), tensor([[ 0.2445,  1.5441, -0.0383],
        [ 1.4352,  0.0041,  0.6622]]))


In [11]:

class InjectedSimpleNet(nn.Module):
    """Hand-written variant of SimpleNet with event emitter/collector calls added.

    NOTE(review): the emitted events `e1`, `e2` are not passed anywhere in
    forward, and the meaning of the collector arguments `(y, 0, 1)` is defined by
    brt's EventCollector - not visible here.
    """

    def __init__(self) -> None:
        super().__init__()
        self.event_collector = EventCollector(2)
        # forward() calls self.event_emitter, which was never defined, so the
        # first forward raised AttributeError. forward unpacks three values
        # (tensor plus two events), so two events are requested here.
        # NOTE(review): assumes EventEmitter(n) returns (tensor, event_1, ..., event_n),
        # consistent with the EventEmitter(1) result being indexed with [0] in the
        # FX rewrite cell - confirm against brt.runtime.memory_plan.
        self.event_emitter = EventEmitter(2)
        self.indentity = nn.Identity()

    def forward(self, y, z):
        y, e1, e2 = self.event_emitter(y)
        y = self.indentity(y)
        z = self.indentity(z)
        y = self.event_collector(y, 0, 1)
        return y, z
