In [1]:
import torch
import torch.nn as nn
from brt.router import ScatterRouter, GatherRouter
from brt import Annotator


class DynamicRouting(nn.Module):
    """Minimal scatter -> expert -> gather model on 2-D inputs.

    A linear+ReLU gate scores each row for two paths. A ``threshold``
    scatter router dispatches the rows across the two experts (presumably a
    row goes to every path whose score passes the threshold — confirm against
    brt's threshold protocol). A ``combine`` gather router then merges the
    two expert outputs.
    """

    def __init__(self):
        super().__init__()
        # gate: 10 input features -> 2 non-negative path scores
        self.route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU())
        self.annotator = Annotator([0])
        self.scatter_router = ScatterRouter(
            protocol_type="threshold", fabric_type="dispatch",
        )
        # both experts keep the 10-feature width, so the gather merges equal-width rows
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 10)
        self.gather_router = GatherRouter(fabric_type="combine")

    def forward(self, x):
        tagged = self.annotator(x)
        scores = self.route_func(tagged)
        branches = self.scatter_router(tagged, scores)
        first_out = self.expert1(branches[0])
        second_out = self.expert2(branches[1])
        return self.gather_router([first_out, second_out])


from brt import symbolic_trace

# Move the model to the GPU and switch to eval mode once, before any inputs are built.
dy_model = DynamicRouting().cuda().eval()

for trial in range(1):
    in_x = torch.randn((4, 10)).cuda()
    print(f"in_x: {in_x}")
    with torch.inference_mode():
        results = dy_model(in_x)
    print(results)

# Trace the same (CUDA, eval-mode) model that was just executed.
traced_dy_model = symbolic_trace(dy_model)


  from .autonotebook import tqdm as notebook_tqdm


in_x: tensor([[-0.3391, -0.7984, -0.1172, -0.6895, -0.3774, -0.0508,  0.7085,  0.6831,
          0.3684,  0.0477],
        [ 0.1209,  0.1663, -0.7751, -1.3780,  0.0835,  0.7720,  1.0615,  2.0556,
         -0.5559,  0.0347],
        [ 0.6359, -1.6873,  0.4640, -0.5502, -0.6227, -0.9929,  0.8678,  1.6215,
          1.6093,  1.6121],
        [-0.4614,  0.1946,  0.5785,  0.4891, -0.8515,  0.6766, -0.2434, -0.5327,
          0.4611, -0.8400]], device='cuda:0')
GridTensor([[-0.6194, -0.8369,  0.1094, -0.1937,  0.4171,  0.0740,  0.2162,
              0.8904, -0.8934, -0.8794]], device='cuda:0')
tag_stack: [tensor([4], device='cuda:0', dtype=torch.int32)]
load stack: [tensor([1], device='cuda:0', dtype=torch.int32)]
extra_attr_dict: {}




Dynamic Routing with Default Dispatcher and 2-D tensor


In [2]:
import torch
import torch.nn as nn
from brt.router import ScatterRouter, GatherRouter
from brt import Annotator


class DynamicRouting(nn.Module):
    """Two-input routing demo that prints what each scatter router produced.

    ``x`` and ``y`` are gated by the same ``route_func`` but scattered by
    separate threshold routers. Path-0 outputs of both inputs (10 wide) are
    merged by ``gather_router_0``. Path-1 outputs (20 wide) are merged by
    ``gather_router_1``.
    """

    def __init__(self):
        super().__init__()
        self.route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU())
        self.annotator = Annotator([0])
        self.scatter_router_0 = ScatterRouter(
            protocol_type="threshold", fabric_type="dispatch",
        )
        self.scatter_router_1 = ScatterRouter(
            protocol_type="threshold", fabric_type="dispatch",
        )
        # path-0 experts keep width 10; path-1 experts widen to 20
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)
        self.gather_router_0 = GatherRouter(fabric_type="combine")
        self.gather_router_1 = GatherRouter(fabric_type="combine")

    def forward(self, x, y):
        x_tagged = self.annotator(x)
        y_tagged = self.annotator(y)
        x_scores = self.route_func(x_tagged)
        y_scores = self.route_func(y_tagged)

        x_paths = self.scatter_router_0(x_tagged, x_scores)
        print("=======scatter results=======")
        for idx in range(2):
            print(f"routed results x_{idx}: {x_paths[idx]}")
        y_paths = self.scatter_router_1(y_tagged, y_scores)
        for idx in range(2):
            print(f"routed results y_{idx}: {y_paths[idx]}")

        x_narrow = self.expert1(x_paths[0])
        x_wide = self.expert2(x_paths[1])
        y_narrow = self.expert3(y_paths[0])
        y_wide = self.expert4(y_paths[1])
        # merge same-path outputs across the two inputs
        return (
            self.gather_router_0([x_narrow, y_narrow]),
            self.gather_router_1([x_wide, y_wide]),
        )


# Move the model to the GPU and switch to eval mode once, up front.
dy_model = DynamicRouting().cuda().eval()

for trial in range(1):
    in_x = torch.randn((4, 10)).cuda()
    in_y = torch.randn((4, 10)).cuda()
    print("=======init inputs=======")
    print(in_x)
    print(in_y)

    with torch.inference_mode():
        results = dy_model(in_x, in_y)
    print("=======final results=======")
    print(results)


tensor([[-1.1749, -0.4041,  0.8114, -0.4443,  0.0035,  0.9211,  1.3677, -1.4528,
          0.8938,  1.5142],
        [ 0.0856,  1.5622, -0.6412,  1.0929, -0.4518, -1.1815,  0.0976,  0.3586,
         -1.1757,  0.4791],
        [-0.5621, -0.8564,  1.2285, -0.1026, -2.1095,  0.8689,  0.2859, -0.4830,
          1.8036,  1.0846],
        [-0.8266, -0.8269,  0.7378, -2.4536,  1.6403,  0.5254, -1.0185,  1.1363,
         -0.9528,  0.5939]], device='cuda:0')
tensor([[ 0.5976, -0.6808,  1.2866,  0.5877, -0.7214,  0.2938,  0.8010, -0.3186,
         -0.0925,  0.6383],
        [-0.5206, -1.3267,  0.6838, -1.4214, -0.7063, -1.7512, -0.3362,  0.2202,
          0.4085,  0.5689],
        [-0.3622, -1.0420, -1.5170,  1.5639, -0.2891, -0.5698, -1.0670,  0.3363,
          1.1127, -0.0434],
        [ 0.5764, -0.1310,  0.3552,  0.9449,  0.2095, -0.3974,  0.3978, -0.5305,
          2.3572,  0.5395]], device='cuda:0')
routed results x_0: GridTensor([[-0.5621, -0.8564,  1.2285, -0.1026, -2.1095,  0.8689,  0.28

Dynamic Routing with Residual Router and 2-D tensor


In [3]:
import torch
import torch.nn as nn
from brt.router import ScatterRouter, GatherRouter
from brt import Annotator


class DynamicRouting(nn.Module):
    """Two-input routing demo whose scatter routers declare a residual path.

    Both routers use the threshold protocol with ``residual_path=0``. This
    presumably sends rows that select no path to path 0 instead of dropping
    them — confirm against brt's threshold protocol. The four expert outputs
    are printed before the same-path outputs of ``x`` and ``y`` are merged.
    """

    def __init__(self):
        super().__init__()
        self.route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU())
        self.annotator = Annotator([0])
        self.scatter_router_0 = ScatterRouter(
            protocol_type="threshold",
            fabric_type="dispatch",
            protocol_kwargs={"residual_path": 0},
        )
        self.scatter_router_1 = ScatterRouter(
            protocol_type="threshold",
            fabric_type="dispatch",
            protocol_kwargs={"residual_path": 0},
        )
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)
        self.gather_router_0 = GatherRouter(fabric_type="combine")
        self.gather_router_1 = GatherRouter(fabric_type="combine")

    def forward(self, x, y):
        x_tagged = self.annotator(x)
        y_tagged = self.annotator(y)
        x_scores = self.route_func(x_tagged)
        y_scores = self.route_func(y_tagged)
        x_paths = self.scatter_router_0(x_tagged, x_scores)
        y_paths = self.scatter_router_1(y_tagged, y_scores)

        x_narrow = self.expert1(x_paths[0])
        x_wide = self.expert2(x_paths[1])
        y_narrow = self.expert3(y_paths[0])
        y_wide = self.expert4(y_paths[1])

        print("=======gather inputs=======")
        labelled = (("x_0", x_narrow), ("y_0", y_narrow), ("x_1", x_wide), ("y_1", y_wide))
        for label, value in labelled:
            print(f"{label}: {value}")

        return (
            self.gather_router_0([x_narrow, y_narrow]),
            self.gather_router_1([x_wide, y_wide]),
        )


dy_model = DynamicRouting().cuda().eval()

for trial in range(1):
    # keep inputs and outputs under separate names; x/y end up holding the results
    in_x = torch.randn((3, 10)).cuda()
    in_y = torch.randn((3, 10)).cuda()

    with torch.inference_mode():
        x, y = dy_model(in_x, in_y)

    print(x)
    print(y)


x_0: GridTensor([[ 0.1091, -0.9170,  0.4356, -0.3602, -0.4261, -0.7068, -0.3975,
             -0.4991,  0.8470,  0.9959],
            [-0.3384, -0.9381, -0.2078,  0.1462,  0.0573, -0.2604, -0.5532,
             -0.3802,  0.9303,  0.5089]], device='cuda:0')
tag_stack: [tensor([1, 3], device='cuda:0', dtype=torch.int32)]
load stack: [tensor([2], device='cuda:0', dtype=torch.int32)]
extra_attr_dict: {}
y_0: GridTensor([[-1.3767, -0.2169,  0.2297,  0.3003, -0.0279, -0.2192,  0.5845,
             -0.6014, -0.2571, -0.6309],
            [-1.0285,  0.4932, -0.6910,  1.0688, -0.4522,  0.1841, -0.1227,
             -1.5571, -0.1911, -0.1539],
            [-0.2268,  0.4516,  0.9427, -0.1094,  0.0927, -1.0385,  0.5325,
              0.1608, -0.1912, -0.4807]], device='cuda:0')
tag_stack: [tensor([1, 2, 3], device='cuda:0', dtype=torch.int32)]
load stack: [tensor([3], device='cuda:0', dtype=torch.int32)]
extra_attr_dict: {}
x_1: GridTensor([[-0.3686, -0.2849,  0.4628, -0.3495, -0.5090,  0.2628,  0

Dynamic Routing with Residual Router and 2-D tensor, also dispatching the gate scores


In [4]:
import torch
import torch.nn as nn
from brt.router import ScatterRouter, GatherRouter
from brt import Annotator


class DynamicRouting(nn.Module):
    """Two-input routing demo where router 0 also dispatches its gate scores.

    ``scatter_router_0`` is built with ``dispatch_score=True`` and the
    ``residual_threshold`` protocol (``residual_path=0``, ``threshold=0.5``).
    It returns ``(routed_inputs, routed_scores)``, and each ``x`` expert output
    is scaled by the scores routed to its path. ``scatter_router_1`` is a plain
    threshold router.
    """

    def __init__(self):
        super().__init__()
        self.route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU())
        self.annotator = Annotator([0])
        self.scatter_router_0 = ScatterRouter(
            dispatch_score=True,
            protocol_type="residual_threshold",
            fabric_type="dispatch",
            protocol_kwargs={"residual_path": 0, "threshold": 0.5},
        )
        self.scatter_router_1 = ScatterRouter(
            protocol_type="threshold", fabric_type="dispatch",
        )
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)
        self.gather_router_0 = GatherRouter(fabric_type="combine")
        self.gather_router_1 = GatherRouter(fabric_type="combine")

    def forward(self, x, y):
        x_tagged = self.annotator(x)
        y_tagged = self.annotator(y)
        x_scores = self.route_func(x_tagged)
        y_scores = self.route_func(y_tagged)
        x_paths, x_path_scores = self.scatter_router_0(x_tagged, x_scores)
        y_paths = self.scatter_router_1(y_tagged, y_scores)

        # NOTE(review): assumes each dispatched score tensor broadcasts against
        # its expert output (e.g. one score column per row) — confirm in brt.
        x_narrow = self.expert1(x_paths[0])
        x_narrow = x_path_scores[0] * x_narrow
        x_wide = self.expert2(x_paths[1])
        x_wide = x_path_scores[1] * x_wide
        y_narrow = self.expert3(y_paths[0])
        y_wide = self.expert4(y_paths[1])

        return (
            self.gather_router_0([x_narrow, y_narrow]),
            self.gather_router_1([x_wide, y_wide]),
        )


dy_model = DynamicRouting().cuda().eval()

with torch.inference_mode():
    for trial in range(1):
        in_x = torch.randn((3, 10)).cuda()
        in_y = torch.randn((3, 10)).cuda()

        x, y = dy_model(in_x, in_y)

        print(x)
        print(y)


GridTensor([[ 0.9916,  0.3189, -0.1299, -1.2428, -1.0203,  0.4072, -0.3563,
             -1.3071, -0.7796,  0.2188]], device='cuda:0')
tag_stack: [tensor([1], device='cuda:0', dtype=torch.int32)]
load stack: [tensor([1], device='cuda:0', dtype=torch.int32)]
extra_attr_dict: {}
GridTensor([[-0.5785,  0.1877,  0.5272,  0.6567,  0.7705,  0.0801, -0.3151,
             -0.9948, -0.0836,  1.4111, -0.3184,  0.1045, -0.1758, -0.3024,
              0.7401,  0.7916,  0.1828, -1.0672,  0.7433,  0.3995],
            [ 0.7639,  0.8094, -1.4858, -0.9996, -0.9652,  0.3988,  0.5664,
             -0.5325, -0.1345, -1.1464,  0.2782, -0.1364,  0.2080,  0.9035,
              0.7719, -0.6144, -1.2163,  1.3061, -0.0617, -0.3396],
            [ 0.3861,  0.2587, -0.1797, -0.1789,  0.3552,  0.1392, -0.0086,
             -0.2502,  0.0438, -0.3467, -0.3436,  0.0415, -0.2150,  0.4548,
             -0.2964, -0.1291, -0.2322, -0.0439, -0.3934, -0.3841]],
           device='cuda:0')
tag_stack: [tensor([1, 2, 3], dev

Dynamic Routing with Threshold Router (score dispatch) and 4-D tensor


In [5]:
import torch
import torch.nn as nn
from brt.router import ScatterRouter, GatherRouter
from brt import Annotator


class DynamicRouting(nn.Module):
    """Two-input threshold routing over 4-D inputs with a pooled conv gate.

    The gate network maps an ``(N, 4, H, W)`` input to ``(N, dst_num, 1, 1)``,
    which is flattened to ``(N, dst_num)`` path scores. ``scatter_router_0`` is
    built with ``dispatch_score=True`` and returns ``(routed, scores)``; the
    scores are discarded here. Path-0 outputs of ``x`` and ``y`` are merged by
    ``gather_router_0``, and path-1 outputs by ``gather_router_1``.

    Args:
        dst_num: number of routing paths. Each router has exactly two expert
            branches wired below, so only ``2`` is supported.

    Raises:
        ValueError: if ``dst_num`` is not 2. Previously the argument was
            silently ignored and a 2-path model was always built.
    """

    def __init__(self, dst_num):
        super().__init__()
        if dst_num != 2:
            raise ValueError(
                f"DynamicRouting wires exactly 2 expert paths per router; got dst_num={dst_num}"
            )
        self.dst_num = dst_num
        self.annotator = Annotator([0])
        # 1x1 conv to one channel per path, global average pool, then a 1x1 conv
        self.route_func = nn.Sequential(
            nn.Conv2d(4, dst_num, 1),
            nn.AdaptiveAvgPool2d((1, 1)),
            nn.Conv2d(dst_num, dst_num, 1),
        )
        self.scatter_router_0 = ScatterRouter(
            dispatch_score=True, protocol_type="threshold", fabric_type="dispatch",
        )
        self.scatter_router_1 = ScatterRouter(
            protocol_type="threshold", fabric_type="dispatch",
        )
        # path-0 experts keep 4 channels; path-1 experts widen to 8
        self.expert1 = nn.Conv2d(4, 4, 1)
        self.expert2 = nn.Conv2d(4, 8, 1)
        self.expert3 = nn.Conv2d(4, 4, 1)
        self.expert4 = nn.Conv2d(4, 8, 1)
        self.gather_router_0 = GatherRouter(fabric_type="combine")
        self.gather_router_1 = GatherRouter(fabric_type="combine")

    def forward(self, x, y):
        x = self.annotator(x)
        y = self.annotator(y)
        # (N, dst_num, 1, 1) -> (N, dst_num) score matrix expected by the routers
        gates_x = self.route_func(x).view(-1, self.dst_num)
        gates_y = self.route_func(y).view(-1, self.dst_num)
        route_results_x, _ = self.scatter_router_0(x, gates_x)
        route_results_y = self.scatter_router_1(y, gates_y)
        x_0 = self.expert1(route_results_x[0])
        x_1 = self.expert2(route_results_x[1])
        y_0 = self.expert3(route_results_y[0])
        y_1 = self.expert4(route_results_y[1])
        x = self.gather_router_0([x_0, y_0])
        y = self.gather_router_1([x_1, y_1])
        return x, y


dy_model = DynamicRouting(2).cuda().eval()

with torch.inference_mode():
    for step in range(10):
        # fresh (N=3, C=4, H=2, W=2) inputs every step; only checks that routing runs
        in_x = torch.randn((3, 4, 2, 2)).cuda()
        in_y = torch.randn((3, 4, 2, 2)).cuda()
        x, y = dy_model(in_x, in_y)


In [6]:
import torch
import torch.nn as nn
from brt.router import ScatterRouter, GatherRouter
from brt import Annotator


class DynamicRouting(nn.Module):
    """Two chained routing stages on 2-D inputs.

    Stage 1 scatters rows over two experts and gathers them back
    (``x_gather``). Stage 2 re-gates ``x_gather`` with a threshold router
    declaring ``residual_path=0``. Only path 0 has an expert (``expert3``);
    path 1 is forwarded to the gather untouched. Both stage outputs are
    returned.
    """

    def __init__(self):
        super().__init__()
        # stage 1
        self.route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU())
        self.annotator = Annotator([0])
        self.scatter_router = ScatterRouter(
            protocol_type="threshold", fabric_type="dispatch",
        )
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 10)
        self.gather_router = GatherRouter(fabric_type="combine")
        # stage 2
        self.route_func_1 = nn.Sequential(nn.Linear(10, 2), nn.ReLU())
        self.scatter_router_1 = ScatterRouter(
            protocol_type="threshold",
            fabric_type="dispatch",
            protocol_kwargs={"residual_path": 0},
        )
        self.expert3 = nn.Linear(10, 10)
        self.gather_router_1 = GatherRouter(fabric_type="combine")

    def forward(self, x):
        tagged = self.annotator(x)

        stage1_paths = self.scatter_router(tagged, self.route_func(tagged))
        x_gather = self.gather_router(
            [self.expert1(stage1_paths[0]), self.expert2(stage1_paths[1])]
        )

        stage2_paths = self.scatter_router_1(x_gather, self.route_func_1(x_gather))
        x_gather_1 = self.gather_router_1(
            [self.expert3(stage2_paths[0]), stage2_paths[1]]
        )

        return x_gather, x_gather_1


from brt import symbolic_trace

# Move the model to the GPU and switch to eval mode once, not on every iteration.
dy_model = DynamicRouting().cuda().eval()

for i in range(10):
    in_x = torch.randn((4, 10)).cuda()
    with torch.inference_mode():
        results = dy_model(in_x)
    # Tags are integer row ids, so compare them exactly. torch.equal also rejects
    # shape mismatches, whereas torch.allclose broadcasts and uses float
    # tolerances (e.g. a (1,) tag vs an empty (0,) tag would wrongly pass).
    assert torch.equal(results[0].tag, results[1].tag), (
        f"stage-2 tags {results[1].tag} differ from stage-1 tags {results[0].tag}"
    )

# Trace the same (CUDA, eval-mode) model that was just exercised.
traced_dy_model = symbolic_trace(dy_model)
