In [2]:
import torch
import torch.nn as nn
from brt.runtime import log
from brt.app import RandScatter
from brt.router import GatherRouter


class MoE(nn.Module):
    """Minimal two-path mixture-of-experts demo built from BRT routing modules.

    The input is handed to a ``RandScatter`` created with ``rand_path_num=2``,
    each of the two scattered results goes through its own ``nn.Identity``
    expert, and a ``GatherRouter`` created with ``path_num=2`` combines the
    two expert outputs into the returned value.
    """

    def __init__(self):
        super().__init__()
        # Scatter stage with two outgoing paths.
        self.rand_scatter = RandScatter(rand_path_num=2)
        # Both experts are identities, so they pass their input through untouched.
        self.expert1 = nn.Identity()
        self.expert2 = nn.Identity()
        # Gather stage that takes the list of per-path outputs.
        self.gather_router = GatherRouter(path_num=2)

    def forward(self, x):
        """Scatter ``x``, run one expert per path, and gather the results.

        Args:
            x: Input passed to the scatter stage.

        Returns:
            Whatever ``self.gather_router`` produces from the two expert outputs.
        """
        scattered = self.rand_scatter(x)
        # Path 0 feeds expert1 and path 1 feeds expert2, evaluated in that order.
        expert_outputs = [
            self.expert1(scattered[0]),
            self.expert2(scattered[1]),
        ]
        return self.gather_router(expert_outputs)


# Build the demo model once; later cells reuse `moe_model` and `indata`.
moe_model = MoE()

# The float32 values 0..29 laid out as 3 rows of 10 columns.
indata = torch.arange(30, dtype=torch.float32).reshape(3, 10)

# Run the forward pass and show the resulting tensor.
outdata = moe_model(indata)
print(outdata)


tensor([[ 0.,  1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9.],
        [10., 11., 12., 13., 14., 15., 16., 17., 18., 19.],
        [20., 21., 22., 23., 24., 25., 26., 27., 28., 29.]])


In [4]:
import brt
# Module path that RandScatter reports for itself.
scatter_module_name = RandScatter.__module__
print(scatter_module_name)
# Keyword arguments stored as `trace_kwargs` on the model's inner scatter router.
scatter_trace_kwargs = moe_model.rand_scatter.scatter_router.trace_kwargs
print(scatter_trace_kwargs)

brt.routers.app.rand
{'path_num': 2, 'protocol_type': 'topk', 'fabric_type': 'dispatch', 'k': 1}


In [5]:
from brt.trace.graph import GraphTracer
from torch.fx.graph_module import GraphModule
from brt.runtime import BRT_CACHE_PATH
# Trace the model with BRT's GraphTracer to obtain a graph.
tracer = GraphTracer()
graph = tracer.trace(moe_model)

# Name the generated module after the model's class when it is an nn.Module,
# otherwise after the traced object itself.
if isinstance(moe_model, torch.nn.Module):
    name = moe_model.__class__.__name__
else:
    name = moe_model.__name__
graph_module = GraphModule(tracer.root, graph, name)

# Show the traced graph and the Python source generated from it.
print(graph_module.graph)
print(graph_module.code)

# Run the traced module on the same input as the eager model.
outdata = graph_module(indata)
print(outdata)

# Write the generated module out under the BRT cache directory.
graph_module.to_folder(BRT_CACHE_PATH / "transformed_model")

graph():
    %x : [#users=2] = placeholder[target=x]
    %rand_gate : [#users=1] = call_function[target=brt.routers.app.rand.rand_gate](args = (%x, 2), kwargs = {})
    %scatter_router_scatter_router : [#users=2] = call_module[target=scatter_router.scatter_router](args = (%x, %rand_gate), kwargs = {})
    %getitem : [#users=1] = call_function[target=operator.getitem](args = (%scatter_router_scatter_router, 0), kwargs = {})
    %expert1 : [#users=1] = call_module[target=expert1](args = (%getitem,), kwargs = {})
    %getitem_1 : [#users=1] = call_function[target=operator.getitem](args = (%scatter_router_scatter_router, 1), kwargs = {})
    %expert2 : [#users=1] = call_module[target=expert2](args = (%getitem_1,), kwargs = {})
    %gather_router : [#users=1] = call_module[target=gather_router](args = ([%expert1, %expert2],), kwargs = {})
    return gather_router

torch.fx._symbolic_trace.wrap("brt_routers_app_rand_rand_gate")

def forward(self, x):
    rand_gate = brt_routers_app_rand_rand



In [5]:
from typing import Callable
from brt.runtime.registry import Registry

class Foo1:
    """Empty base class used as the registry base type in this example."""


def register_Foo1(name: str) -> Callable:
    """Return the result of ``Registry.register_cls`` for ``name`` and ``Foo1``.

    The notebook applies the returned callable as a class decorator.

    Args:
        name: Key passed to ``Registry.register_cls``.

    Returns:
        The callable produced by ``Registry.register_cls(name, Foo1)``.
    """
    decorator = Registry.register_cls(name, Foo1)
    return decorator

def make_Foo1(name: str, **kwargs) -> Foo1:
    """Look up the class stored under ``name`` for ``Foo1`` and instantiate it.

    Args:
        name: Key passed to ``Registry.get_cls``.
        **kwargs: Keyword arguments forwarded to the looked-up class.

    Returns:
        The object created by calling the looked-up class with ``kwargs``.
    """
    registered_cls = Registry.get_cls(name, Foo1)
    instance = registered_cls(**kwargs)
    return instance


@register_Foo1("SubFoo1")
class SubFoo1(Foo1):
    """Example ``Foo1`` subclass decorated with ``register_Foo1("SubFoo1")``."""

@register_Foo1("Sub2Foo1")
class Sub2Foo1(Foo1):
    """Example ``Foo1`` subclass decorated with ``register_Foo1("Sub2Foo1")``."""

# Create one instance per registered name through the factory helper.
sub_foo1 = make_Foo1("SubFoo1")
sub_foo2 = make_Foo1("Sub2Foo1")

# One object per line, in creation order.
print(sub_foo1, sub_foo2, sep="\n")


<__main__.SubFoo1 object at 0x7f1605beda90>
<__main__.Sub2Foo1 object at 0x7f1605bed940>


In [None]:
import torch.nn

# Transformer built with its default constructor arguments, referenced through
# the `torch.nn` name bound by this cell's import.
model = torch.nn.Transformer()