In [2]:
import torch
from brt.common import log
import brt
import brt.nn as nn
from brt.router import RandomScatterRouter, GatherRouter
from brt.frontend import build_graph
from brt.primitive.helper import symbolize

# Set the "frontend", "backend" and "ir" loggers to INFO level.
# NOTE(review): these names presumably map to brt's graph-building stages —
# confirm against brt.common.log.
log.set_level("frontend", "INFO")
log.set_level("backend", "INFO")
log.set_level("ir", "INFO")


class MoE(nn.Module):
    """Two-expert mixture-of-experts layer wired together with brt routers.

    The layer owns a ``RandomScatterRouter`` (2 routes, ``gran_dim=10``), two
    10 -> 10 linear experts, and a ``GatherRouter`` (2 routes, ``gran_dim=10``,
    ``reduction="add"``) that combines the expert outputs.
    """

    def __init__(self):
        super().__init__()
        # Both routers and both experts share the same route count / width.
        route_num = 2
        gran_dim = 10
        self.scatter_router = RandomScatterRouter(route_num=route_num, gran_dim=gran_dim)
        self.expert1 = nn.Linear(gran_dim, gran_dim)
        self.expert2 = nn.Linear(gran_dim, gran_dim)
        self.gather_router = GatherRouter(
            route_num=route_num, gran_dim=gran_dim, reduction="add"
        )

    def forward(self, x):
        """Scatter ``x`` into two routes, run one expert per route, gather.

        The scatter router yields three values: the per-route inputs plus the
        reverse indices and reverse shape that the gather router consumes.
        """
        scattered = self.scatter_router(x)
        routed, reverse_indice, reverse_shape = scattered
        # expert1 is evaluated before expert2, route 0 before route 1.
        expert_outputs = [self.expert1(routed[0]), self.expert2(routed[1])]
        return self.gather_router(expert_outputs, reverse_indice, reverse_shape)


class MoEModel(nn.Module):
    """Thin wrapper module whose only child is a single ``MoE`` layer."""

    def __init__(self):
        super().__init__()
        self.moe = MoE()

    def forward(self, x):
        """Pass ``x`` through the wrapped ``MoE`` layer and return its output."""
        moe_output = self.moe(x)
        return moe_output


# Build the wrapper model and a batch of two identical 10-feature rows.
moe_model = MoEModel()
x = torch.Tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
# Eager forward pass. Note that `x` is rebound to the model output here, so the
# original input tensor is no longer reachable under this name afterwards.
x = moe_model(x)
print(x)
# Script the symbolized model with TorchScript and print the resulting graph.
# NOTE(review): `symbolize` presumably swaps brt routers for scriptable
# stand-ins — confirm in brt.primitive.helper.
script_moe_model = torch.jit.script(symbolize(moe_model))
sm_graph = script_moe_model.graph
print(sm_graph)
print(moe_model.__class__.__name__)
# Build brt's graph representation of the model; the result is kept in
# `ir_moe_model`.
ir_moe_model = build_graph(moe_model)


[tensor([], size=(0, 10)), tensor([[ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.],
        [ 1.,  2.,  3.,  4.,  5.,  6.,  7.,  8.,  9., 10.]])]
[None, tensor([0, 1])]
tensor([ 2, 10])
tensor([[ 4.4681, -3.5787, -0.1673, -2.9927, -1.4656, -1.1043, -2.0624,  4.4407,
         -0.3591,  0.4295],
        [ 4.4681, -3.5787, -0.1673, -2.9927, -1.4656, -1.1043, -2.0624,  4.4407,
         -0.3591,  0.4295]], grad_fn=<ViewBackward0>)
graph(%self : __torch__.___torch_mangle_16.MoEModel,
      %x.1 : Tensor):
  %moe : __torch__.___torch_mangle_15.MoE = prim::GetAttr[name="moe"](%self)
  %4 : Tensor = prim::CallMethod[name="forward"](%moe, %x.1) # /tmp/ipykernel_3843804/1530002903.py:40:15
  return (%4)

MoEModel


In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import brt.nn

import torch
import brt


class m_model(nn.Module):
    """Flattened two-expert MoE whose submodules carry a ``_moe__`` prefix.

    NOTE(review): this looks like machine-generated code for the ``MoE``
    example — confirm its origin. The routers here are built with
    ``route_num=2`` only, and the gather side uses ``RandomGatherRouter``;
    verify both against the model this is meant to mirror.
    """

    def __init__(self):
        super().__init__()
        self._moe__scatter_router = brt.router.scatter_router.RandomScatterRouter(
            route_num=2
        )
        # expert2 is created before expert1; the order is kept as-is.
        self._moe__expert2 = torch.nn.modules.linear.Linear(
            in_features=10, out_features=10
        )
        self._moe__expert1 = torch.nn.modules.linear.Linear(
            in_features=10, out_features=10
        )
        self._moe__gather_router = brt.router.gather_router.RandomGatherRouter(
            route_num=2
        )
        # Every submodule attribute name mapped to None; what consumes this
        # table is not visible here.
        self._mapping_ = {
            '_moe__scatter_router': None,
            '_moe__expert2': None,
            '_moe__expert1': None,
            '_moe__gather_router': None,
        }

    def forward(self, x__1):
        """Scatter ``x__1``, run both experts, and gather their outputs.

        Element 0 of the scatter result holds the per-route inputs; elements 1
        and 2 are forwarded unchanged to the gather router.
        """
        scattered = self._moe__scatter_router(x__1)
        per_route_inputs = scattered[0]
        # Route 1 / expert2 is evaluated first, then route 0 / expert1.
        expert2_out = self._moe__expert2(per_route_inputs[1])
        expert1_out = self._moe__expert1(per_route_inputs[0])
        # The gather router still receives outputs in expert1, expert2 order.
        return self._moe__gather_router(
            [expert1_out, expert2_out], scattered[1], scattered[2]
        )
    
# Instantiate the flattened model and run it on two identical 10-feature rows.
moe = m_model()
x = torch.Tensor([[1, 2, 3, 4, 5, 6, 7, 8, 9, 10], [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]])
# `x` is rebound to the model output; the input tensor is not kept.
x = moe(x)

In [1]:
import torch
from brt.common import log
import brt
import brt.nn as nn
from brt.router import MultiplexScatterRouter, AggregateGatherRouter
from brt.frontend import build_graph
from brt.primitive.helper import symbolize
from brt.backend.pytorch import model_to_script

# Set the "frontend", "backend" and "ir" loggers to INFO level.
# NOTE(review): these names presumably map to brt's graph-building stages —
# confirm against brt.common.log.
log.set_level("frontend", "INFO")
log.set_level("backend", "INFO")
log.set_level("ir", "INFO")


class DynamicRouting(nn.Module):
    """Two independent scatter -> experts -> gather pipelines, one per input.

    The ``x`` pipeline uses ``scatter_router_0``, ``expert1``/``expert3``
    (10 -> 10) and ``gather_router_0``; the ``y`` pipeline uses
    ``scatter_router_1``, ``expert2``/``expert4`` (10 -> 20) and
    ``gather_router_1``.

    Args:
        route_num: route count handed to every scatter and gather router.
            ``forward`` indexes routes 0 and 1 only.
    """

    def __init__(self, route_num):
        super().__init__()
        gran_dim = 10
        self.scatter_router_0 = MultiplexScatterRouter(
            route_num=route_num, gran_dim=gran_dim
        )
        self.scatter_router_1 = MultiplexScatterRouter(
            route_num=route_num, gran_dim=gran_dim
        )
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)
        self.gather_router_0 = AggregateGatherRouter(route_num=route_num)
        self.gather_router_1 = AggregateGatherRouter(route_num=route_num)

    def forward(self, x, y):
        """Route ``x`` and ``y`` through their own experts; return both results.

        Returns:
            A ``(x_out, y_out)`` pair holding the outputs of
            ``gather_router_0`` and ``gather_router_1`` respectively.
        """
        x_routes = self.scatter_router_0(x)
        y_routes = self.scatter_router_1(y)
        # Experts run in numeric order (1, 2, 3, 4), alternating between inputs.
        x_first = self.expert1(x_routes[0])
        y_first = self.expert2(y_routes[0])
        x_second = self.expert3(x_routes[1])
        y_second = self.expert4(y_routes[1])
        gathered_x = self.gather_router_0([x_first, x_second])
        gathered_y = self.gather_router_1([y_first, y_second])
        return gathered_x, gathered_y


# Build the model with two routes per router.
dy_model = DynamicRouting(2)

# Two random inputs of shape (2, 10). No seed is set in this cell, so the
# printed values differ from run to run.
x = torch.randn((2, 10))
y = torch.randn((2, 10))
print(x)
print(y)

# Eager forward pass; `x` and `y` are rebound to the model outputs.
x, y = dy_model(x, y)


tensor([[-0.0183, -0.1135, -0.2785, -0.5676, -0.5613,  0.8165, -1.1355, -1.8280,
          0.7946,  0.3524],
        [ 1.4881, -0.2153, -0.9044,  0.6914, -0.7613,  1.6971, -1.8312, -0.1526,
         -0.8194, -1.7015]])
tensor([[-3.9519e-01, -1.3141e-03, -6.6944e-01, -2.7586e-01,  1.6676e+00,
          1.5905e+00, -8.4527e-01, -6.1096e-02,  1.0313e-01,  1.8895e-01],
        [-3.3388e-01, -7.2469e-01, -1.1977e+00, -1.3448e+00, -1.6773e+00,
         -7.3344e-01, -7.3135e-01, -1.6939e+00, -1.6048e-01, -4.5391e-01]])
tensor([[0.0000, 0.0000],
        [0.2624, 1.0384]], grad_fn=<ReluBackward0>)
tensor([[False],
        [ True]])
tensor([[False],
        [ True]])
route_results: [tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 1.4881, -0.2153, -0.9044,  0.6914, -0.7613,  1.6971, -1.8312, -0.1526,
         -0.8194, -1.7015]]), tensor([[ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,

In [None]:

# Script the symbolized model with TorchScript and print its graph.
# Depends on `dy_model` and the imports from the cell that defines
# DynamicRouting.
script_dy_model = torch.jit.script(symbolize(dy_model))
sm_graph = script_dy_model.graph
print(sm_graph)
# Build brt's graph representation of the model.
ir_dy_model = build_graph(dy_model)

# Turn the graph representation back into a script and print it.
# NOTE(review): the output is presumably Python source text — confirm in
# brt.backend.pytorch.
model_script = model_to_script(ir_dy_model)
print(model_script)