In [None]:
import torch
import torch.fx
import torch.nn as nn


class WrapModule(nn.Module):
    """Identity module used as a nested, user-defined submodule in the fx demo.

    It owns no parameters. ``forward`` hands back the exact tensor object it
    was given.
    """

    # No custom __init__: nn.Module.__init__ is all the setup this module needs.

    def forward(self, x):
        # Pass-through; no copy is made.
        return x


class SampleModule(nn.Module):
    """Top-level module for the torch.fx tracing demo.

    Its forward delegates entirely to a nested ``WrapModule`` stored as
    ``self.wrap``.
    """

    def __init__(self):
        super(SampleModule, self).__init__()
        self.wrap = WrapModule()

    def forward(self, x):
        wrapped = self.wrap(x)
        return wrapped


# Symbolically trace the toy model with torch.fx and show the recorded graph.
sample_module = SampleModule()
traced_module = torch.fx.symbolic_trace(root=sample_module)
print(traced_module.graph)

In [8]:
import torch
import torch.nn as nn
from brt.common import log
from brt.routers.app import RandScatterRouter
from brt.routers import GatherRouter


# Verbosity applied to every BRT logging component; tune it here in one place.
BRT_LOG_LEVEL = "INFO"
for _brt_component in ("frontend", "backend", "ir"):
    log.set_level(_brt_component, BRT_LOG_LEVEL)


class MoE(nn.Module):
    """Two-expert mixture-of-experts block built from BRT routers.

    ``scatter_router`` (a RandScatterRouter, presumably assigning inputs to
    paths at random) produces per-path results indexed 0 and 1. Each path is
    fed to its own ``nn.Linear(10, 10)`` expert, and ``gather_router`` merges
    the two expert outputs.
    """

    def __init__(self):
        super().__init__()
        # Registration order is kept as-is: it fixes the order in which the
        # Linear layers consume the RNG at init, and the state_dict key order.
        self.scatter_router = RandScatterRouter(path_num=2)
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 10)
        self.gather_router = GatherRouter(path_num=2)

    def forward(self, x):
        routed = self.scatter_router(x)
        # The list literal evaluates left to right, so path 0 / expert1 still
        # runs before path 1 / expert2.
        expert_outputs = [self.expert1(routed[0]), self.expert2(routed[1])]
        return self.gather_router(expert_outputs)


class MoEModel(nn.Module):
    """Thin top-level wrapper that holds the MoE block under the attribute ``moe``."""

    def __init__(self):
        super().__init__()
        self.moe = MoE()

    def forward(self, x):
        moe_out = self.moe(x)
        return moe_out


# Fix torch's global RNG before building the model so the nn.Linear expert
# weights (and, assuming RandScatterRouter samples from torch's RNG -- TODO
# confirm, the routing decisions) are identical on every Restart & Run All.
SEED = 0
torch.manual_seed(SEED)
moe_model = MoEModel()
# Three rows of ten features holding 0..29 in row-major order.
indata = torch.arange(0, 30, dtype=torch.float32).view(3, 10)
outdata = moe_model(indata)
print(outdata)


tensor([[ -0.8853,  -2.4841,   2.3834,  -4.2445,   3.2223,  -3.9983,   1.2346,
          -0.1389,  -6.2976,  -1.4735],
        [ -6.7856,  -1.6684,   6.8806,  -8.2467,   5.0737, -10.9107,   3.6119,
          -7.4432, -16.8607,  -7.0406],
        [-12.6859,  -0.8526,  11.3779, -12.2490,   6.9252, -17.8231,   5.9893,
         -14.7474, -27.4239, -12.6077]], grad_fn=<AliasBackward0>)


In [9]:
from brt.transform.tracer import BRTTRacer
from torch.fx.graph_module import GraphModule
# Trace the model with BRT's tracer and wrap the graph in an executable GraphModule.
tracer = BRTTRacer()
graph = tracer.trace(moe_model)
# Class name for the generated GraphModule: the module's class name, or the
# callable's __name__ when a plain function is traced.
name = moe_model.__class__.__name__ if isinstance(moe_model, torch.nn.Module) else moe_model.__name__
graph_module = GraphModule(tracer.root, graph, name)
print(graph_module.graph)

# The scatter router routes randomly, so eager and traced outputs are only
# comparable when both passes start from the same RNG state. This assumes
# RandScatterRouter draws from torch's global RNG -- TODO confirm. If it does
# not, the check below reports False instead of raising.
COMPARE_SEED = 0
torch.manual_seed(COMPARE_SEED)
eager_outdata = moe_model(indata)
torch.manual_seed(COMPARE_SEED)
# `outdata` keeps holding the traced result, as before.
outdata = graph_module(indata)
print(outdata)
print("traced matches eager:", torch.allclose(outdata, eager_outdata))


graph():
    %x : [#users=1] = placeholder[target=x]
    %moe_scatter_router : [#users=2] = call_module[target=moe.scatter_router](args = (%x,), kwargs = {})
    %getitem : [#users=1] = call_function[target=operator.getitem](args = (%moe_scatter_router, 0), kwargs = {})
    %moe_expert1 : [#users=1] = call_module[target=moe.expert1](args = (%getitem,), kwargs = {})
    %getitem_1 : [#users=1] = call_function[target=operator.getitem](args = (%moe_scatter_router, 1), kwargs = {})
    %moe_expert2 : [#users=1] = call_module[target=moe.expert2](args = (%getitem_1,), kwargs = {})
    %moe_gather_router : [#users=1] = call_module[target=moe.gather_router](args = ([%moe_expert1, %moe_expert2],), kwargs = {})
    return moe_gather_router
tensor([[ -0.8853,  -2.4841,   2.3834,  -4.2445,   3.2223,  -3.9983,   1.2346,
          -0.1389,  -6.2976,  -1.4735],
        [ -0.7183,  -9.7317,   3.7233,   9.1266, -12.8416,  -2.9161, -14.1668,
           5.3655,  -0.9843,   7.6397],
        [-12.6859,  -0