In [2]:
import torch
import torch.nn as nn
from brt.router import ScatterRouter, GatherRouter
from brt import annotate


class DynamicRouting(nn.Module):
    def __init__(self):
        super().__init__()
        self.route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU())
        self.scatter_router = ScatterRouter(
            protocol_type="threshold", fabric_type="dispatch",
        )
        self.expert1 = nn.Linear(10, 10)  # keep
        self.expert2 = nn.Linear(10, 10)  # upsample
        self.gather_router = GatherRouter(fabric_type="combine")

    def forward(self, x):
        x = annotate(x, [0])
        x_gates = self.route_func(x)
        route_results_x = self.scatter_router(x, x_gates)
        x_0 = self.expert1(route_results_x[0])
        x_1 = self.expert2(route_results_x[1])
        print(f"x_0: {x_0}")
        print(f"x_1: {x_1}")
        x = self.gather_router([x_0, x_1])
        return x


dy_model = DynamicRouting()

for i in range(1):
    in_x = torch.randn((4, 10)).cuda()
    print(f"in_x: {in_x}")
    dy_model = dy_model.cuda().eval()
    with torch.inference_mode():
        results = dy_model(in_x.cuda())
    print(results)

from brt import symbolic_trace

traced_dy_model = symbolic_trace(dy_model)


in_x: tensor([[ 0.1721,  0.2055, -0.8656,  1.4035,  0.3877,  1.3031, -1.0075, -1.0683,
          1.9448,  0.2695],
        [-0.6468, -1.2882,  0.1866,  2.0845,  0.1092,  0.5862,  0.9875, -0.5144,
          0.7807,  0.2720],
        [-0.2088,  1.6116,  0.9699, -1.0278,  0.7947,  1.2612, -1.0729,  0.9025,
         -0.9402, -0.6119],
        [-0.6897, -0.0112,  0.8530, -0.1993, -0.3786, -0.1821,  0.8073,  0.1916,
          0.1905, -0.5488]], device='cuda:0')
GridTensor([[0, 1],
            [1, 1],
            [0, 1],
            [1, 1]], device='cuda:0', dtype=torch.int32)
tag_stack: [None]
load stack: [None]
extra_attr_dict: {} None False True False
tensor([[0, 1],
        [1, 2],
        [0, 3],
        [2, 4]], device='cuda:0', dtype=torch.int32) tensor([2, 4], device='cuda:0', dtype=torch.int32)
x_0: GridTensor([[-0.7577,  1.0721,  0.1135, -0.1744, -0.3359, -0.3534, -0.3849,
             -0.5662, -0.9574,  0.4290],
            [-0.6462, -0.1764,  0.3958,  0.2417, -0.5441,  0.2635, -0.

RuntimeError: 'len' is not supported in symbolic tracing by default. If you want this call to be recorded, please call torch.fx.wrap('len') at module scope

Dynamic Routing with Default Dispatcher and 2-D tensor


In [None]:
import torch
import torch.nn as nn
from brt.router import ScatterRouter, GatherRouter
import brt


class DynamicRouting(nn.Module):
    def __init__(self):
        super().__init__()
        self.route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU())
        self.scatter_router_0 = ScatterRouter(
            protocol_type="threshold", fabric_type="dispatch",
        )
        self.scatter_router_1 = ScatterRouter(
            protocol_type="threshold", fabric_type="dispatch",
        )
        self.expert1 = nn.Linear(10, 10)  # keep
        self.expert2 = nn.Linear(10, 20)  # upsample
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)
        self.gather_router_0 = GatherRouter(fabric_type="combine")
        self.gather_router_1 = GatherRouter(fabric_type="combine")

    def forward(self, x, y):
        x = brt.annotate(x, [0])
        y = brt.annotate(y, [0])
        x_gates = self.route_func(x)
        y_gates = self.route_func(y)
        route_results_x = self.scatter_router_0(x, x_gates)
        route_results_y = self.scatter_router_1(y, y_gates)
        x_0 = self.expert1(route_results_x[0])
        x_1 = self.expert2(route_results_x[1])
        y_0 = self.expert3(route_results_y[0])
        y_1 = self.expert4(route_results_y[1])
        x = self.gather_router_0([x_0, y_0])
        y = self.gather_router_1([x_1, y_1])
        return route_results_x, route_results_y


dy_model = DynamicRouting()

for i in range(1):
    in_x = torch.randn((4, 10)).cuda()
    in_y = torch.randn((4, 10)).cuda()
    print(in_x)
    print(in_y)

    dy_model = dy_model.cuda().eval()
    results = dy_model(in_x.cuda(), in_y.cuda())
    print(results)


Dynamic Routing with Residual Router and 2-D tensor


In [None]:
import torch
import torch.nn as nn
from brt.router import ScatterRouter, GatherRouter


class DynamicRouting(nn.Module):
    def __init__(self):
        super().__init__()
        self.route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU())
        self.scatter_router_0 = ScatterRouter(
            protocol_type="threshold", fabric_type="dispatch",
        )
        self.scatter_router_1 = ScatterRouter(
            protocol_type="threshold", fabric_type="dispatch",
        )
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)
        self.gather_router_0 = GatherRouter(fabric_type="combine")
        self.gather_router_1 = GatherRouter(fabric_type="combine")

    def forward(self, x, y):
        x_gates = self.route_func(x)
        y_gates = self.route_func(y)
        route_results_x = self.scatter_router_0(x, x_gates)
        route_results_y = self.scatter_router_1(y, y_gates)
        x_0 = self.expert1(route_results_x[0])
        x_1 = self.expert2(route_results_x[1])
        y_0 = self.expert3(route_results_y[0])
        y_1 = self.expert4(route_results_y[1])
        x = self.gather_router_0([x_0, y_0])
        y = self.gather_router_1([x_1, y_1])
        return x, y


dy_model = DynamicRouting()
dy_model.cuda()

for i in range(1):
    x = torch.randn((3, 10)).cuda()
    y = torch.randn((3, 10)).cuda()

    x, y = dy_model(x, y)

    print(x)
    print(y)


Dynamic Routing with Residual Router and 2-D tensor while routing gates


In [None]:
import torch
import torch.nn as nn
from brt.router import ScatterRouter, GatherRouter


class DynamicRouting(nn.Module):
    def __init__(self):
        super().__init__()
        self.route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU())
        self.scatter_router_0 = ScatterRouter(
            dispatch_score=True, protocol_type="threshold", fabric_type="dispatch",
        )
        self.scatter_router_1 = ScatterRouter(
            protocol_type="threshold", fabric_type="dispatch",
        )
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)
        self.gather_router_0 = GatherRouter(fabric_type="combine")
        self.gather_router_1 = GatherRouter(fabric_type="combine")

    def forward(self, x, y):
        x_gates = self.route_func(x)
        y_gates = self.route_func(y)
        route_results_x, route_gates_x = self.scatter_router_0(x, x_gates)
        route_results_y = self.scatter_router_1(y, y_gates)
        x_0 = self.expert1(route_results_x[0])
        x_0 = route_gates_x[0] * x_0
        x_1 = self.expert2(route_results_x[1])
        x_1 = route_gates_x[1] * x_1
        y_0 = self.expert3(route_results_y[0])
        y_1 = self.expert4(route_results_y[1])
        x = self.gather_router_0([x_0, y_0])
        y = self.gather_router_1([x_1, y_1])
        return x, y


dy_model = DynamicRouting()
dy_model.cuda()
for i in range(10):
    x = torch.randn((3, 10)).cuda()
    y = torch.randn((3, 10)).cuda()

    x, y = dy_model(x, y)

    # print(x.shape)
    # print(y.shape)


Dynamic Routing with Residual Router and 4-D tensor


In [None]:
import torch
import torch.nn as nn
from brt.router import ScatterRouter, GatherRouter


route_func = (
    lambda x: nn.Sequential(
        nn.Conv2d(4, 2, 1), nn.AdaptiveAvgPool2d((1, 1)), nn.Conv2d(2, 2, 1),
    )
    .cuda()(x)
    .view(-1, 2)
)  # [bs x dst_num x 1 x 1] keep up down -> keep down


class DynamicRouting(nn.Module):
    def __init__(self, dst_num):
        super().__init__()
        self.route_func = nn.Sequential(
            nn.Conv2d(4, 2, 1), nn.AdaptiveAvgPool2d((1, 1)), nn.Conv2d(2, 2, 1),
        )
        self.scatter_router_0 = ScatterRouter(
            dispatch_score=True, protocol_type="threshold", fabric_type="dispatch",
        )
        self.scatter_router_1 = ScatterRouter(
            protocol_type="threshold", fabric_type="dispatch",
        )
        self.expert1 = nn.Conv2d(4, 4, 1)
        self.expert2 = nn.Conv2d(4, 8, 1)
        self.expert3 = nn.Conv2d(4, 4, 1)
        self.expert4 = nn.Conv2d(4, 8, 1)
        self.gather_router_0 = GatherRouter(fabric_type="combine")
        self.gather_router_1 = GatherRouter(fabric_type="combine")

    def forward(self, x, y):
        gates_x = self.route_func(x).view(-1, 2)
        gates_y = self.route_func(y).view(-1, 2)
        route_results_x, _ = self.scatter_router_0(x, gates_x)
        route_results_y = self.scatter_router_1(y, gates_y)
        x_0 = self.expert1(route_results_x[0])
        x_1 = self.expert2(route_results_x[1])
        y_0 = self.expert3(route_results_y[0])
        y_1 = self.expert4(route_results_y[1])
        x = self.gather_router_0([x_0, y_0])
        y = self.gather_router_1([x_1, y_1])
        return x, y


dy_model = DynamicRouting(2)
dy_model.cuda()
for i in range(10):
    x = torch.randn((3, 4, 2, 2)).cuda()
    y = torch.randn((3, 4, 2, 2)).cuda()
    x, y = dy_model(x, y)
    print(x.shape)
    print(y.shape)
