Dynamic Routing with the Default Dispatcher and 2-D Tensors

In [19]:
import torch
import brt.nn as nn
from brt.router import ScatterRouter, GatherRouter

# Gate network passed to both scatter routers of DynamicRouting below: maps
# 10 input features to 2 outputs (presumably one score per route, since the
# model is built with route_num=2 — confirm), with ReLU clamping negatives to 0.
route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU())


class DynamicRouting(nn.Module):
    """Send two inputs through four linear experts via scatter/gather routers.

    ``x`` is split by ``scatter_router_0`` and ``y`` by ``scatter_router_1``,
    both using the default dispatcher with threshold routing. Route 0 of each
    input feeds a 10 -> 10 linear expert (``expert1`` for ``x``, ``expert3``
    for ``y``); route 1 feeds a 10 -> 20 linear expert (``expert2``,
    ``expert4``). ``gather_router_0`` merges the two route-0 outputs and
    ``gather_router_1`` (``route_method="restore"``) merges the two route-1
    outputs.

    Args:
        route_num: number of routes given to every router. ``forward`` only
            reads routes 0 and 1; the notebook builds the model with 2.
    """

    def __init__(self, route_num):
        super().__init__()
        # Both scatter routers get identical settings, including the same
        # module-level ``route_func`` object (presumably shared gate weights —
        # confirm this is intended).
        scatter_config = dict(
            route_num=route_num,
            route_func=route_func,
            route_method="threshold",
            route_method_args=0,
        )
        self.scatter_router_0 = ScatterRouter(**scatter_config)
        self.scatter_router_1 = ScatterRouter(**scatter_config)

        # Route-0 experts keep 10 features; route-1 experts widen to 20.
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)

        self.gather_router_0 = GatherRouter(route_num=route_num)
        self.gather_router_1 = GatherRouter(route_num=route_num, route_method="restore")

    def forward(self, x, y):
        """Scatter ``x`` and ``y``, run the experts, and gather per route.

        Returns:
            A pair ``(route-0 merge, route-1 merge)`` produced by
            ``gather_router_0`` and ``gather_router_1`` respectively.
        """
        x_routes, x_tags, _ = self.scatter_router_0(x)
        # NOTE(review): only this router's ``loads`` is forwarded to both gather
        # routers; the value from scatter_router_0 is discarded — confirm intended.
        y_routes, y_tags, loads = self.scatter_router_1(y)

        x_0 = self.expert1(x_routes[0])
        x_1 = self.expert2(x_routes[1])
        y_0 = self.expert3(y_routes[0])
        y_1 = self.expert4(y_routes[1])

        merged_0 = self.gather_router_0([x_0, y_0], [x_tags[0], y_tags[0]], loads)
        merged_1 = self.gather_router_1([x_1, y_1], [x_tags[1], y_tags[1]], loads)
        return merged_0, merged_1


# Push ten random (2, 10) batches through the model and print both output shapes.
dy_model = DynamicRouting(2)
for i in range(10):
    x, y = torch.randn((2, 10)), torch.randn((2, 10))
    x, y = dy_model(x, y)
    print(x.shape, y.shape, sep="\n")


torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([0, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])


Dynamic Routing with the Residual Dispatcher and 2-D Tensors

In [20]:
import torch
import brt.nn as nn
from brt.router.dispatcher import ResidualDispatcher
from brt.router import ScatterRouter, GatherRouter


# Gate network passed to both scatter routers of DynamicRouting below: maps
# 10 input features to 2 outputs (presumably one score per route, since the
# model is built with route_num=2 — confirm), with ReLU clamping negatives to 0.
route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU())


class DynamicRouting(nn.Module):
    """Send two inputs through four linear experts using residual dispatchers.

    Same topology as the default-dispatcher cell, but both scatter routers use
    ``ResidualDispatcher`` with ``residual_route=0``. Route 0 of each input
    feeds a 10 -> 10 linear expert (``expert1`` for ``x``, ``expert3`` for
    ``y``); route 1 feeds a 10 -> 20 linear expert (``expert2``, ``expert4``).
    ``gather_router_0`` merges the route-0 outputs and ``gather_router_1``
    (``route_method="restore"``) merges the route-1 outputs.

    Args:
        route_num: number of routes given to every router. ``forward`` only
            reads routes 0 and 1; the notebook builds the model with 2.
    """

    def __init__(self, route_num):
        super().__init__()
        # Identical settings for both scatter routers, including the same
        # module-level ``route_func`` object (presumably shared gate weights —
        # confirm this is intended).
        scatter_config = dict(
            route_num=route_num,
            route_func=route_func,
            route_method="threshold",
            route_method_args=0,
            dispatcher_cls=ResidualDispatcher,
            residual_route=0,
        )
        self.scatter_router_0 = ScatterRouter(**scatter_config)
        self.scatter_router_1 = ScatterRouter(**scatter_config)

        # Route-0 experts keep 10 features; route-1 experts widen to 20.
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)

        self.gather_router_0 = GatherRouter(route_num=route_num)
        self.gather_router_1 = GatherRouter(route_num=route_num, route_method="restore")

    def forward(self, x, y):
        """Scatter ``x`` and ``y``, run the experts, and gather per route.

        Returns:
            A pair ``(route-0 merge, route-1 merge)`` produced by
            ``gather_router_0`` and ``gather_router_1`` respectively.
        """
        x_routes, x_tags, _ = self.scatter_router_0(x)
        # NOTE(review): only this router's ``loads`` is forwarded to both gather
        # routers; the value from scatter_router_0 is discarded — confirm intended.
        y_routes, y_tags, loads = self.scatter_router_1(y)

        x_0 = self.expert1(x_routes[0])
        x_1 = self.expert2(x_routes[1])
        y_0 = self.expert3(y_routes[0])
        y_1 = self.expert4(y_routes[1])

        merged_0 = self.gather_router_0([x_0, y_0], [x_tags[0], y_tags[0]], loads)
        merged_1 = self.gather_router_1([x_1, y_1], [x_tags[1], y_tags[1]], loads)
        return merged_0, merged_1


# Push ten random (2, 10) batches through the model and print both output shapes.
dy_model = DynamicRouting(2)
for i in range(10):
    x, y = torch.randn((2, 10)), torch.randn((2, 10))
    x, y = dy_model(x, y)
    print(x.shape, y.shape, sep="\n")


torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])


Dynamic Routing with the Residual Dispatcher and 4-D Tensors

In [25]:
import torch
from brt.common import log
import brt.nn as nn
from brt.router.dispatcher import ResidualDispatcher
from brt.router import ScatterRouter, GatherRouter


# Gate network for the 4-D model below. It applies a 1x1 conv from 4 to 2
# channels, then global average pooling to a 1x1 spatial map, then a 1x1 conv
# over the 2 channels. For an (N, 4, H, W) input it therefore yields
# (N, 2, 1, 1).
# NOTE(review): the 2-D cells hand ScatterRouter (N, 2) scores, but here the
# trailing singleton dims remain. This cell's recorded output shows
# gather_router_0 returning 0 rows on every iteration. Confirm that the
# "threshold" route method handles this shape (e.g. by flattening to (N, 2)).
route_func = nn.Sequential(
    nn.Conv2d(4, 2, 1),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Conv2d(2, 2, 1),
)


class DynamicRouting(nn.Module):
    """Send two 4-D inputs through four 1x1-conv experts using residual dispatchers.

    Both scatter routers use ``ResidualDispatcher`` with ``residual_route=0``.
    Route 0 of each input feeds a 4 -> 4 channel 1x1 conv (``expert1`` for
    ``x``, ``expert3`` for ``y``); route 1 feeds a 4 -> 8 channel 1x1 conv
    (``expert2``, ``expert4``). ``gather_router_0`` merges the route-0 outputs
    and ``gather_router_1`` (``route_method="restore"``) merges the route-1
    outputs.

    Args:
        route_num: number of routes given to every router. ``forward`` only
            reads routes 0 and 1; the notebook builds the model with 2.
    """

    def __init__(self, route_num):
        super().__init__()
        # Identical settings for both scatter routers, including the same
        # module-level ``route_func`` object (presumably shared gate weights —
        # confirm this is intended).
        scatter_config = dict(
            route_num=route_num,
            route_func=route_func,
            route_method="threshold",
            route_method_args=0,
            dispatcher_cls=ResidualDispatcher,
            residual_route=0,
        )
        self.scatter_router_0 = ScatterRouter(**scatter_config)
        self.scatter_router_1 = ScatterRouter(**scatter_config)

        # Route-0 experts keep 4 channels; route-1 experts widen to 8.
        self.expert1 = nn.Conv2d(4, 4, 1)
        self.expert2 = nn.Conv2d(4, 8, 1)
        self.expert3 = nn.Conv2d(4, 4, 1)
        self.expert4 = nn.Conv2d(4, 8, 1)

        self.gather_router_0 = GatherRouter(route_num=route_num)
        self.gather_router_1 = GatherRouter(route_num=route_num, route_method="restore")

    def forward(self, x, y):
        """Scatter ``x`` and ``y``, run the experts, and gather per route.

        Returns:
            A pair ``(route-0 merge, route-1 merge)`` produced by
            ``gather_router_0`` and ``gather_router_1`` respectively.
        """
        x_routes, x_tags, _ = self.scatter_router_0(x)
        # NOTE(review): only this router's ``loads`` is forwarded to both gather
        # routers; the value from scatter_router_0 is discarded — confirm intended.
        y_routes, y_tags, loads = self.scatter_router_1(y)

        x_0 = self.expert1(x_routes[0])
        x_1 = self.expert2(x_routes[1])
        y_0 = self.expert3(y_routes[0])
        y_1 = self.expert4(y_routes[1])

        merged_0 = self.gather_router_0([x_0, y_0], [x_tags[0], y_tags[0]], loads)
        merged_1 = self.gather_router_1([x_1, y_1], [x_tags[1], y_tags[1]], loads)
        return merged_0, merged_1


# Push ten random (4, 4, 10, 10) batches through the model and print both output shapes.
dy_model = DynamicRouting(2)
for i in range(10):
    x, y = torch.randn((4, 4, 10, 10)), torch.randn((4, 4, 10, 10))
    x, y = dy_model(x, y)
    print(x.shape, y.shape, sep="\n")


torch.Size([0, 4, 10, 10])
torch.Size([4, 8, 10, 10])
torch.Size([0, 4, 10, 10])
torch.Size([4, 8, 10, 10])
torch.Size([0, 4, 10, 10])
torch.Size([4, 8, 10, 10])
torch.Size([0, 4, 10, 10])
torch.Size([4, 8, 10, 10])
torch.Size([0, 4, 10, 10])
torch.Size([4, 8, 10, 10])
torch.Size([0, 4, 10, 10])
torch.Size([4, 8, 10, 10])
torch.Size([0, 4, 10, 10])
torch.Size([4, 8, 10, 10])
torch.Size([0, 4, 10, 10])
torch.Size([4, 8, 10, 10])
torch.Size([0, 4, 10, 10])
torch.Size([4, 8, 10, 10])
torch.Size([0, 4, 10, 10])
torch.Size([4, 8, 10, 10])
