In [5]:
import torch
import torch.nn as nn
from brt.runtime import log
from brt.app import RandScatter
from brt.router import GatherRouter


class MoE(nn.Module):
    """Toy mixture-of-experts block used to exercise BRT routing and tracing.

    Each round scatters the input over two identity experts via
    ``RandScatter(path_num=2)`` and merges the expert outputs back with
    ``GatherRouter``; the merged result feeds the next round.

    Args:
        iteration: number of scatter -> experts -> gather rounds. Defaults to
            2 (the value previously hard-coded). With 0 the input is returned
            unchanged inside the result structure.
        ret: selects the second element of the return value: ``1`` yields an
            empty list, any other value yields ``None``. Defaults to 1.
    """

    def __init__(self, iteration: int = 2, ret: int = 1):
        super().__init__()
        # path_num must equal the number of experts indexed in forward().
        self.rand_scatter = RandScatter(path_num=2)
        self.expert1 = nn.Identity()
        self.expert2 = nn.Identity()
        self.gather_router = GatherRouter()
        self.iteration = iteration
        self.ret = ret

    def forward(self, x):
        """Run ``self.iteration`` routing rounds over ``x``.

        Returns:
            ``[[x], []]`` when ``self.ret == 1``, otherwise ``[[x], None]``,
            where ``x`` is the output of the last gather. NOTE(review): the two
            variants presumably exercise how the tracer handles an empty list
            vs. ``None`` in a nested return value — confirm.
        """
        for _ in range(self.iteration):
            route_results = self.rand_scatter(x)
            x_0 = self.expert1(route_results[0])
            x_1 = self.expert2(route_results[1])
            x = self.gather_router([x_0, x_1])
        if self.ret == 1:
            return [[x], []]
        return [[x], None]

class SimpleModel(nn.Module):
    """Thin wrapper that runs a ``MoE`` block and unwraps its primary output."""

    def __init__(self):
        super().__init__()
        self.moe = MoE()

    def forward(self, x):
        """Return the first tensor of the first output group produced by ``self.moe``."""
        moe_outputs = self.moe(x)
        primary_group = moe_outputs[0]
        return primary_group[0]

# Seed torch's global RNG so the random routing (and the splits printed below)
# repeats across Restart & Run All.
# NOTE(review): assumes RandScatter / rand_gate draw from torch's global RNG — confirm.
torch.manual_seed(0)

moe_model = SimpleModel()

# 4 rows of 10 values (0..39), so each row is easy to follow through the routing.
indata = torch.arange(0, 40, dtype=torch.float32).view(4, 10)
outdata = moe_model(indata)


[ProtoTensor(
data: ProtoTensor([[0., 1., 2., 3., 4., 5., 6., 7., 8., 9.]])
tag_stack: [tensor([[0]])]
load stack: [4]), ProtoTensor(
data: ProtoTensor([[10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
             [20., 21., 22., 23., 24., 25., 26., 27., 28., 29.],
             [30., 31., 32., 33., 34., 35., 36., 37., 38., 39.]])
tag_stack: [tensor([[1],
        [2],
        [3]])]
load stack: [4])]
[ProtoTensor(
data: ProtoTensor([[10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
             [20., 21., 22., 23., 24., 25., 26., 27., 28., 29.]])
tag_stack: [tensor([[1],
        [2]])]
load stack: [4]), ProtoTensor(
data: ProtoTensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
             [30., 31., 32., 33., 34., 35., 36., 37., 38., 39.]])
tag_stack: [tensor([[0],
        [3]])]
load stack: [4])]


In [3]:
from brt.trace.graph import GraphTracer
from torch.fx.graph_module import GraphModule
from brt.runtime import BRT_CACHE_PATH
# NOTE(review): this cell reads `moe_model` and `indata`, which are defined in the
# cell labelled In [5] — executed after this one (In [3]). Reorder so the notebook
# survives Restart & Run All.
tracer = GraphTracer()
print(moe_model)

# Trace the eager model into an FX graph and wrap it in a GraphModule whose
# class name mirrors the traced object.
graph = tracer.trace(moe_model)
if isinstance(moe_model, torch.nn.Module):
    name = moe_model.__class__.__name__
else:
    name = moe_model.__name__
graph_module = GraphModule(tracer.root, graph, name)

# Inspect the captured graph and the Python source FX generated from it.
print(graph_module.graph)
print(graph_module.code)

# Execute the traced module on the same input the eager model received, then export it.
outdata = graph_module(indata)
print(outdata)
graph_module.to_folder(BRT_CACHE_PATH / "transformed_model")

SimpleModel(
  (moe): MoE(
    (rand_scatter): RandScatter(
      (scatter_router): ScatterRouter(
        (protocol): TopKProtocol()
        (fabric): DispatchFabric()
      )
    )
    (expert1): Identity()
    (expert2): Identity()
    (gather_router): GatherRouter(
      (fabric): CombineFabric()
    )
  )
)
graph():
    %x : [#users=2] = placeholder[target=x]
    %rand_gate : [#users=1] = call_function[target=brt.app.rand.rand_gate](args = (%x, 2), kwargs = {})
    %moe_rand_scatter_scatter_router : [#users=2] = call_module[target=moe.rand_scatter.scatter_router](args = (%x, %rand_gate), kwargs = {})
    %getitem : [#users=1] = call_function[target=operator.getitem](args = (%moe_rand_scatter_scatter_router, 0), kwargs = {})
    %moe_expert1 : [#users=1] = call_module[target=moe.expert1](args = (%getitem,), kwargs = {})
    %getitem_1 : [#users=1] = call_function[target=operator.getitem](args = (%moe_rand_scatter_scatter_router, 1), kwargs = {})
    %moe_expert2 : [#users=1] = call



In [7]:
from torch.fx.passes.graph_drawer import FxGraphDrawer

# Render the traced graph to SVG and save it next to the notebook.
# NOTE(review): FxGraphDrawer relies on pydot, and create_svg() needs the
# Graphviz `dot` binary on PATH.
graph_drawer = FxGraphDrawer(graph_module, "brt_model")
with open("a.svg", mode="wb") as svg_file:
    dot_graph = graph_drawer.get_dot_graph()
    svg_file.write(dot_graph.create_svg())
