Dynamic Routing with the Default Dispatcher and 2-D Tensors

In [1]:
import torch
import brt.nn as nn
from brt.router import ScatterRouter, GatherRouter
from brt.router.flow_tensor import FlowTensor

# Seed before any parameter initialisation or torch.randn call so that a
# fresh "Restart & Run All" reproduces the same gate/expert weights and inputs.
SEED = 0
torch.manual_seed(SEED)

# Gate shared by the scatter routers: maps a 10-feature sample to one score per
# destination (2 scores). ReLU clamps the scores to be >= 0.
route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU())


class DynamicRouting(nn.Module):
    """Toy two-input model built from brt scatter/gather routers (2-D inputs).

    Each input is split by its own ScatterRouter ("threshold" routing with
    threshold 0). Destinations 0 and 1 of each input go through their own
    Linear expert. The two gather routers then combine the results by
    destination index:

    - destination 0 of both inputs -> 10-feature output;
    - destination 1 of both inputs -> 20-feature output.

    Args:
        dst_num: number of destinations passed to every router. ``forward``
            only reads destinations 0 and 1, so this presumably has to be 2
            (TODO confirm how brt behaves for other values).
        gate: gating module given to both scatter routers as ``route_func``.
            Defaults to the module-level ``route_func`` so ``DynamicRouting(2)``
            behaves exactly as before. The same instance is handed to both
            routers, so their gate weights are shared.
    """

    def __init__(self, dst_num, gate=None):
        super().__init__()
        if gate is None:
            # Backward-compatible default: the notebook-level gate.
            gate = route_func
        self.scatter_router_0 = ScatterRouter(
            dst_num=dst_num,
            route_func=gate,
            route_method="threshold",
            threshold=0,
        )
        self.scatter_router_1 = ScatterRouter(
            dst_num=dst_num,
            route_func=gate,
            route_method="threshold",
            threshold=0,
        )
        # Experts for destination 0 (10 features out) and destination 1 (20 features out).
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)
        self.gather_router_0 = GatherRouter(dst_num=dst_num, sparse=False)
        self.gather_router_1 = GatherRouter(dst_num=dst_num, sparse=False)

    def forward(self, x, y):
        """Scatter ``x`` and ``y``, run the experts, and gather per destination.

        Assumes ``x`` and ``y`` are (batch, 10), matching the ``in_features``
        of the experts (TODO confirm for other gates).

        Returns:
            A pair: the gathered outputs of expert1/expert3 (10 features) and
            of expert2/expert4 (20 features).
        """
        route_results_x = self.scatter_router_0(x)
        route_results_y = self.scatter_router_1(y)
        x_0 = self.expert1(route_results_x[0])
        x_1 = self.expert2(route_results_x[1])
        y_0 = self.expert3(route_results_y[0])
        y_1 = self.expert4(route_results_y[1])
        # Gather by destination index, not by input: both destination-0 branches
        # form the first output, both destination-1 branches form the second.
        gathered_0 = self.gather_router_0([x_0, y_0])
        gathered_1 = self.gather_router_1([x_1, y_1])
        return gathered_0, gathered_1


# This cell runs the default-dispatcher model on the CPU.
device = torch.device("cpu")
dy_model = DynamicRouting(2)
dy_model.to(device)
for _ in range(10):
    x = torch.randn((4, 10)).to(device)
    y = torch.randn((4, 10)).to(device)
    # Bind the outputs to new names instead of overwriting the inputs.
    out_x, out_y = dy_model(x, y)
    print(out_x)


tensor([[ 0.2709,  0.5970,  0.0935,  0.1688, -0.0648, -0.1933, -0.7397, -0.0809,
          0.2628, -0.2490],
        [-0.2226,  0.4430,  0.3782, -0.5954, -0.6260, -0.0501, -0.3067, -0.1330,
          0.4923, -0.0642],
        [-0.1741, -0.0975,  0.1915, -0.8805, -0.2971,  0.0463, -0.0851, -0.0773,
         -0.1696, -0.1624],
        [-0.3684,  0.5725,  1.7869, -0.2038,  0.9539,  1.3754,  0.0187,  1.4874,
         -1.5636,  1.4867]], grad_fn=<AliasBackward0>)
tensor([[ 0.2088,  0.3297,  0.1851, -0.2540, -0.1622,  0.1434, -0.6789, -0.4777,
          0.6791, -0.3356],
        [-0.2280,  0.0704, -0.2100, -0.1877, -0.1578, -0.0770, -0.1275, -0.2176,
          0.1965, -0.0088],
        [-0.0908,  1.7897,  1.2684, -0.4034, -0.2312, -0.6456, -1.3710,  0.0659,
          1.1458, -0.7625],
        [-0.1892,  0.8616, -0.3557,  0.0357, -0.5238,  0.0441, -0.9366, -0.5431,
          0.7165, -0.2807]], grad_fn=<AliasBackward0>)
tensor([[-1.2850,  0.4802,  0.7560, -0.7699, -0.2171,  0.9582, -0.0330,  0

Dynamic Routing with a Residual Router and 2-D Tensors

In [2]:
import torch
import brt.nn as nn
from brt.router import ScatterRouter, GatherRouter

# Seed before any parameter initialisation or torch.randn call so this cell is
# reproducible on its own after a kernel restart.
SEED = 0
torch.manual_seed(SEED)

# Gate shared by the scatter routers: maps a 10-feature sample to one score per
# destination (2 scores). ReLU clamps the scores to be >= 0.
route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU())


class DynamicRouting(nn.Module):
    """Two-input routing model (2-D inputs) using residual scatter routers.

    Same wiring as the default-dispatcher variant, but both ScatterRouters are
    built with ``residual_dst=0``. Presumably this sends samples that no
    destination selects to destination 0 (TODO confirm brt semantics). The
    gather routers use GatherRouter's default settings.

    Args:
        dst_num: number of destinations passed to every router. ``forward``
            only reads destinations 0 and 1, so this presumably has to be 2.
        gate: gating module given to both scatter routers as ``route_func``.
            Defaults to the module-level ``route_func`` so ``DynamicRouting(2)``
            behaves exactly as before. Both routers share this one instance and
            therefore its weights.
    """

    def __init__(self, dst_num, gate=None):
        super().__init__()
        if gate is None:
            # Backward-compatible default: the notebook-level gate.
            gate = route_func
        self.scatter_router_0 = ScatterRouter(
            dst_num=dst_num,
            route_func=gate,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
        )
        self.scatter_router_1 = ScatterRouter(
            dst_num=dst_num,
            route_func=gate,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
        )
        # Experts for destination 0 (10 features out) and destination 1 (20 features out).
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)
        self.gather_router_0 = GatherRouter(dst_num=dst_num)
        self.gather_router_1 = GatherRouter(dst_num=dst_num)

    def forward(self, x, y):
        """Scatter ``x`` and ``y``, run the experts, and gather per destination.

        Assumes ``x`` and ``y`` are (batch, 10) (TODO confirm).

        Returns:
            A pair: the gathered outputs of expert1/expert3 (10 features) and
            of expert2/expert4 (20 features). The recorded run shows the row
            count of the second output varying between calls.
        """
        route_results_x = self.scatter_router_0(x)
        route_results_y = self.scatter_router_1(y)
        x_0 = self.expert1(route_results_x[0])
        x_1 = self.expert2(route_results_x[1])
        y_0 = self.expert3(route_results_y[0])
        y_1 = self.expert4(route_results_y[1])
        # Gather by destination index, not by input.
        gathered_0 = self.gather_router_0([x_0, y_0])
        gathered_1 = self.gather_router_1([x_1, y_1])
        return gathered_0, gathered_1


# Prefer the GPU (as in the recorded run), but fall back to the CPU so the cell
# does not crash on machines without CUDA.
# TODO confirm brt's residual routing supports CPU tensors.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dy_model = DynamicRouting(2)
dy_model.to(device)
for _ in range(10):
    # Sample with the CPU generator and then move, exactly as `.cuda()` did.
    x = torch.randn((2, 10)).to(device)
    y = torch.randn((2, 10)).to(device)
    # Bind the outputs to new names instead of overwriting the inputs.
    out_x, out_y = dy_model(x, y)

    print(out_x.shape)
    print(out_y.shape)


torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([1, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([0, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([1, 20])


Dynamic Routing with a Residual Router and 4-D Tensors

In [3]:
import torch
import brt.nn as nn
from brt.router import ScatterRouter, GatherRouter


# Seed before any parameter initialisation or torch.randn call so this cell is
# reproducible on its own after a kernel restart.
SEED = 0
torch.manual_seed(SEED)

# Convolutional gate for (batch, 4, H, W) inputs:
# - reduce to 2 channels;
# - average-pool to 1x1;
# - mix the 2 channels.
# The result is one score per destination, shaped (batch, 2, 1, 1).
route_func = nn.Sequential(
    nn.Conv2d(4, 2, 1),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Conv2d(2, 2, 1),
)


class DynamicRouting(nn.Module):
    """Two-input routing model for 4-D inputs using residual scatter routers.

    Both ScatterRouters use "threshold" routing (threshold 0) with
    ``residual_dst=0``. The experts are 1x1 convolutions:

    - destination 0: 4 -> 4 channels;
    - destination 1: 4 -> 8 channels.

    The gather routers combine the results by destination index.

    Args:
        dst_num: number of destinations passed to every router. ``forward``
            only reads destinations 0 and 1, so this presumably has to be 2.
        gate: gating module given to both scatter routers as ``route_func``.
            Defaults to the module-level ``route_func`` so ``DynamicRouting(2)``
            behaves exactly as before. Both routers share this one instance and
            therefore its weights.
    """

    def __init__(self, dst_num, gate=None):
        super().__init__()
        if gate is None:
            # Backward-compatible default: the notebook-level gate.
            gate = route_func
        self.scatter_router_0 = ScatterRouter(
            dst_num=dst_num,
            route_func=gate,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
        )
        self.scatter_router_1 = ScatterRouter(
            dst_num=dst_num,
            route_func=gate,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
        )
        # Destination-0 experts keep 4 channels; destination-1 experts expand to 8.
        self.expert1 = nn.Conv2d(4, 4, 1)
        self.expert2 = nn.Conv2d(4, 8, 1)
        self.expert3 = nn.Conv2d(4, 4, 1)
        self.expert4 = nn.Conv2d(4, 8, 1)
        # NOTE(review): gather_router_1 passes sparse=True explicitly while
        # gather_router_0 relies on the default; confirm the asymmetry is intended.
        self.gather_router_0 = GatherRouter(dst_num=dst_num)
        self.gather_router_1 = GatherRouter(dst_num=dst_num, sparse=True)

    def forward(self, x, y):
        """Scatter ``x`` and ``y``, run the conv experts, and gather per destination.

        Assumes ``x`` and ``y`` are (batch, 4, H, W), matching the experts'
        ``in_channels`` (TODO confirm).

        Returns:
            A pair: the gathered outputs of expert1/expert3 (4 channels) and
            of expert2/expert4 (8 channels).
        """
        route_results_x = self.scatter_router_0(x)
        route_results_y = self.scatter_router_1(y)
        x_0 = self.expert1(route_results_x[0])
        x_1 = self.expert2(route_results_x[1])
        y_0 = self.expert3(route_results_y[0])
        y_1 = self.expert4(route_results_y[1])
        # Gather by destination index, not by input.
        gathered_0 = self.gather_router_0([x_0, y_0])
        gathered_1 = self.gather_router_1([x_1, y_1])
        return gathered_0, gathered_1


# Prefer the GPU (as in the recorded run), but fall back to the CPU so the cell
# does not crash on machines without CUDA.
# TODO confirm brt's residual routing supports CPU tensors.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dy_model = DynamicRouting(2)
dy_model.to(device)
for _ in range(10):
    # Sample with the CPU generator and then move, exactly as `.cuda()` did.
    x = torch.randn((2, 4, 2, 2)).to(device)
    y = torch.randn((2, 4, 2, 2)).to(device)
    # Bind the outputs to new names instead of overwriting the inputs.
    out_x, out_y = dy_model(x, y)
    print(out_x.shape)
    # NOTE(review): in the recorded output this shape always has 0 rows, i.e.
    # destination 1 never seems to receive samples; worth checking the gate.
    print(out_y.shape)


torch.Size([2, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([0, 8, 2, 2])
