In [1]:
import torch
import torch.nn as nn
from brt.common import log
from brt.routers.app import RandScatterRouter
from brt.routers import GatherRouter


# Set the brt log level to INFO for the "frontend", "backend" and "ir"
# categories, in that order.
for _component in ("frontend", "backend", "ir"):
    log.set_level(_component, "INFO")


class MoE(nn.Module):
    """Two-path mixture-of-experts block wired together with BRT routers.

    The block owns a ``RandScatterRouter`` and a ``GatherRouter`` (both built
    with ``path_num=2``) plus two ``nn.Linear(10, 10)`` experts, one per path.
    """

    def __init__(self):
        super().__init__()
        # Sub-modules are registered in the same order as before so parameter
        # initialisation order and state_dict layout are unchanged.
        self.scatter_router = RandScatterRouter(path_num=2)
        self.expert1 = nn.Linear(in_features=10, out_features=10)
        self.expert2 = nn.Linear(in_features=10, out_features=10)
        self.gather_router = GatherRouter(path_num=2)

    def forward(self, x):
        """Scatter ``x``, run each path through its expert, gather the results.

        Element 0 of the scatter router's result goes to ``expert1`` and
        element 1 goes to ``expert2``; both expert outputs are then handed to
        the gather router as a two-element list, whose result is returned.
        """
        routed = self.scatter_router(x)
        # Left-to-right evaluation keeps the original call order:
        # routed[0] -> expert1, then routed[1] -> expert2.
        expert_outputs = [self.expert1(routed[0]), self.expert2(routed[1])]
        return self.gather_router(expert_outputs)


class MoEModel(nn.Module):
    """Top-level model that holds a single ``MoE`` block as ``self.moe``."""

    def __init__(self):
        super().__init__()
        self.moe = MoE()

    def forward(self, x):
        """Pass ``x`` through the wrapped ``MoE`` block and return its result."""
        moe_output = self.moe(x)
        return moe_output


# Seed torch's global RNG before building the model so the randomly
# initialised nn.Linear expert weights are the same on every
# Restart & Run All. Without a seed the printed tensor changes on every run.
# NOTE(review): whether RandScatterRouter's routing also draws from torch's
# global RNG is not visible in this notebook -- confirm.
torch.manual_seed(0)
moe_model = MoEModel()

# Three rows of ten consecutive floats: the values 0..29 viewed as (3, 10).
indata = torch.arange(0, 30, dtype=torch.float32).view(3, 10)
outdata = moe_model(indata)
print(outdata)


tensor([[ -0.4084,  -3.6441,   1.4626,   1.5102,  -0.4632,  -4.9812,  -3.4973,
           7.3618,  -3.3022,  -5.5773],
        [ 13.0121,  -5.2027,  12.9716,   2.6283,  11.0787,   1.8211,  -5.1149,
          -0.7463,   3.1213,  -6.2672],
        [ 21.5346,  -8.4781,  23.2726,   4.0588,  16.2145,   0.3744,  -8.7244,
           0.7261,   7.0585, -12.7260]], grad_fn=<AliasBackward0>)


In [2]:
# Show which module actually defines RandScatterRouter; the notebook imports
# it through the brt.routers.app package.
print(RandScatterRouter.__module__)

brt.routers.app.rand


In [3]:
from brt.transform.tracer import BRTTRacer
from torch.fx.graph_module import GraphModule
from brt.common import BRT_CACHE_PATH
# Trace the model with the BRT tracer to obtain a graph for it.
tracer = BRTTRacer()
graph = tracer.trace(moe_model)

# Name the generated module after the traced object: the class name for an
# nn.Module instance, otherwise the object's own __name__.
if isinstance(moe_model, torch.nn.Module):
    name = moe_model.__class__.__name__
else:
    name = moe_model.__name__

# Wrap the traced graph in a GraphModule rooted at the tracer's root module.
graph_module = GraphModule(tracer.root, graph, name)

# Inspect the traced graph and the Python code generated from it.
print(graph_module.graph)
print(graph_module.code)

# Run the traced module on the same input that was given to the eager model.
outdata = graph_module(indata)
print(outdata)

# Write the traced module out to a folder under the BRT cache directory.
graph_module.to_folder(BRT_CACHE_PATH / "transformed_model")

graph():
    %x : [#users=2] = placeholder[target=x]
    %rand_gate : [#users=2] = call_function[target=brt.routers.app.rand.rand_gate](args = (%x, 2), kwargs = {})
    %moe_scatter_router_scatter_router_protocol : [#users=1] = call_module[target=moe.scatter_router.scatter_router.protocol](args = (%rand_gate,), kwargs = {})
    %moe_scatter_router_scatter_router_fabric : [#users=2] = call_module[target=moe.scatter_router.scatter_router.fabric](args = (%x, %moe_scatter_router_scatter_router_protocol, %rand_gate), kwargs = {})
    %getitem : [#users=1] = call_function[target=operator.getitem](args = (%moe_scatter_router_scatter_router_fabric, 0), kwargs = {})
    %moe_expert1 : [#users=1] = call_module[target=moe.expert1](args = (%getitem,), kwargs = {})
    %getitem_1 : [#users=1] = call_function[target=operator.getitem](args = (%moe_scatter_router_scatter_router_fabric, 1), kwargs = {})
    %moe_expert2 : [#users=1] = call_module[target=moe.expert2](args = (%getitem_1,), kwargs = {})
 

