In [19]:
import torch
import torch.nn as nn
from brt.runtime import log
from brt.app import RandScatter
from brt.router import GatherRouter


class MoE(nn.Module):
    """Two-path mixture-of-experts demo layer.

    The input is handed to a ``RandScatter`` built with ``path_num=2``; the
    first two routed results each go through their own ``nn.Identity``
    expert, and the pair of expert outputs is merged by a ``GatherRouter``.
    """

    def __init__(self):
        super().__init__()
        # Submodule attribute names are kept stable because traced graphs
        # address these modules by their dotted path (e.g. "moe.gather_router").
        self.rand_scatter = RandScatter(path_num=2)
        self.expert1 = nn.Identity()
        self.expert2 = nn.Identity()
        self.gather_router = GatherRouter()
        # Plain integer attributes; not read by forward().
        self.iteration = 1
        self.ret = 1

    def forward(self, x):
        """Scatter ``x``, apply one expert per path, and gather the results."""
        routed = self.rand_scatter(x)
        # Path 0 is consumed before path 1, matching the expert numbering.
        expert_outputs = [
            self.expert1(routed[0]),
            self.expert2(routed[1]),
        ]
        return self.gather_router(expert_outputs)

class SimpleModel(nn.Module):
    """Minimal wrapper model whose only submodule is a single ``MoE`` layer."""

    def __init__(self):
        super().__init__()
        self.moe = MoE()

    def forward(self, x):
        """Return the output of the wrapped ``MoE`` layer applied to ``x``."""
        return self.moe(x)


# Instantiate the demo model and run one small batch through it.
moe_model = SimpleModel()

# The values 0..39 as float32, arranged as 4 rows of 10 columns.
indata = torch.arange(40, dtype=torch.float32).reshape(4, 10)
outdata = moe_model(indata)


In [28]:
from brt.trace.graph import GraphTracer
from torch.fx.graph_module import GraphModule
from brt.runtime import BRT_CACHE_PATH
# Trace the model and wrap the resulting graph in a GraphModule so the
# generated Python source can be inspected.
tracer = GraphTracer()
graph = tracer.trace(moe_model)
name = (
    moe_model.__class__.__name__
    if isinstance(moe_model, torch.nn.Module)
    else moe_model.__name__
)
graph_module = GraphModule(tracer.root, graph, name)
print(graph_module.code)
# named_modules() returns a lazy generator; no visible cell iterates over it.
models = graph_module.named_modules()

# Dotted path of the gather router call to rewrite, and the index of the
# single expert output that should remain connected to it.
GATHER_TARGET = "moe.gather_router"
KEPT_EXPERT_INDEX = 1

# NOTE(review): `graph` is edited in place below. `graph_module` was built from
# this same Graph object, so it presumably still references the edited graph;
# only its already-printed code reflects the pre-edit state. Confirm before
# reusing `graph_module` after this cell.
for node in graph.nodes:
    # Match on the op as well as the target so only the module call itself is
    # rewritten, never some other node kind that happens to share the target.
    if node.op == "call_module" and node.target == GATHER_TARGET:
        # The router takes one positional argument: a list of expert outputs.
        new_args = ([node.args[0][KEPT_EXPERT_INDEX]],)
        node.args = new_args
        print(node.args)

# Drop the nodes that no longer feed the output, then validate the rewritten
# graph before generating code from it.
graph.eliminate_dead_code()
graph.lint()
new_graph_module = GraphModule(tracer.root, graph, name)

print(new_graph_module.code)




torch.fx._symbolic_trace.wrap("brt_app_rand_rand_gate")

def forward(self, x):
    rand_gate = brt_app_rand_rand_gate(x, 2)
    moe_rand_scatter_scatter_router = self.moe.rand_scatter.scatter_router(x, rand_gate);  x = rand_gate = None
    getitem = moe_rand_scatter_scatter_router[0]
    moe_expert1 = self.moe.expert1(getitem);  getitem = None
    getitem_1 = moe_rand_scatter_scatter_router[1];  moe_rand_scatter_scatter_router = None
    moe_expert2 = self.moe.expert2(getitem_1);  getitem_1 = None
    moe_gather_router = self.moe.gather_router([moe_expert1, moe_expert2]);  moe_expert1 = moe_expert2 = None
    return moe_gather_router
    
([moe_expert2],)

torch.fx._symbolic_trace.wrap("brt_app_rand_rand_gate")

def forward(self, x):
    rand_gate = brt_app_rand_rand_gate(x, 2)
    moe_rand_scatter_scatter_router = self.moe.rand_scatter.scatter_router(x, rand_gate);  x = rand_gate = None
    getitem_1 = moe_rand_scatter_scatter_router[1];  moe_rand_scatter_scatter_router = None
    mo

In [7]:
from torch.fx.passes.graph_drawer import FxGraphDrawer

# Draw the traced module's graph and save it as an SVG in the working directory.
graph_drawer = FxGraphDrawer(graph_module, "brt_model")
# Render first: opening "a.svg" for writing truncates it, so rendering inside
# the `with` block would leave an empty file behind if rendering failed.
svg_bytes = graph_drawer.get_dot_graph().create_svg()
with open("a.svg", "wb") as f:
    f.write(svg_bytes)
