Dynamic Routing with the Default Dispatcher and 2-D Tensors

In [1]:
import torch
import brt.nn as nn
from brt.routers import ScatterRouter, GatherRouter
from brt.routers.proto_tensor import ProtoTensor

# Gating network shared by both scatter routers below: a linear layer maps the
# last (size-10) dimension to 2 scores, and ReLU clamps negative scores to 0.
# NOTE(review): 2 outputs presumably means one score per destination
# (dst_num == 2 in this cell) — confirm.
route_func = nn.Sequential(
    nn.Linear(10, 2),
    nn.ReLU(),
)


class DynamicRouting(nn.Module):
    """Two-input toy model exercising scatter/gather routing (2-D inputs).

    Each of ``x`` and ``y`` is split by its own ScatterRouter. Both routers use
    threshold routing built around the shared module-level ``route_func``.
    Branch 0 of each input goes through a width-preserving 10 -> 10 linear
    expert, and branch 1 through a widening 10 -> 20 expert. The two 10-wide
    results (one from x, one from y) are merged by ``gather_router_0``. The
    two 20-wide results are merged by ``gather_router_1``.

    NOTE(review): ``forward`` only consumes branches 0 and 1, so ``dst_num``
    is presumably expected to be 2 — confirm before using other values.
    """

    def __init__(self, dst_num):
        super().__init__()
        # Both scatter routers are configured identically; each call gets its
        # own copy of these keyword arguments.
        scatter_options = dict(
            dst_num=dst_num,
            route_func=route_func,
            route_method="threshold",
            threshold=0,
        )
        self.scatter_router_0 = ScatterRouter(**scatter_options)
        self.scatter_router_1 = ScatterRouter(**scatter_options)
        # Experts for x: keep width (10 -> 10) and widen (10 -> 20).
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        # Experts for y, mirroring the x experts.
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)
        # The recorded output of this cell keeps all 4 input rows and zero-fills
        # two of them. That suggests sparse=False pads rows that were not
        # routed — confirm against the brt docs.
        gather_options = dict(dst_num=dst_num, sparse=False)
        self.gather_router_0 = GatherRouter(**gather_options)
        self.gather_router_1 = GatherRouter(**gather_options)

    def forward(self, x, y):
        """Route ``x`` and ``y``, run the experts, and regroup results by width.

        Returns ``(narrow, wide)``. ``narrow`` gathers the two 10-wide expert
        outputs and ``wide`` gathers the two 20-wide expert outputs.
        """
        x_branches = self.scatter_router_0(x)
        y_branches = self.scatter_router_1(y)
        x_narrow = self.expert1(x_branches[0])
        x_wide = self.expert2(x_branches[1])
        y_narrow = self.expert3(y_branches[0])
        y_wide = self.expert4(y_branches[1])
        narrow = self.gather_router_0([x_narrow, y_narrow])
        wide = self.gather_router_1([x_wide, y_wide])
        return narrow, wide


# Run one random (4, 10) batch per input through a two-destination model and
# print the gathered 10-wide output. `device` is the single placement knob.
# The recorded output below was produced on CPU; set it to
# torch.device("cuda") to repeat the run on a GPU.
device = torch.device("cpu")

dy_model = DynamicRouting(2)
dy_model.to(device)

x = torch.randn((4, 10)).to(device)
y = torch.randn((4, 10)).to(device)

# The outputs rebind x / y, matching the original cell's namespace.
x, y = dy_model(x, y)
print(x)


tensor([[ 0.2027, -0.2118, -0.0993,  0.0267, -0.2270,  0.2822, -0.1599,  0.2881,
         -0.2039, -0.1834],
        [ 0.6500,  0.0399, -0.4081,  0.0031,  0.7953,  0.6627,  0.4448,  0.3174,
         -0.1215, -0.6479],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000]], grad_fn=<AliasBackward0>)


Dynamic Routing with a Residual Router and 2-D Tensors

In [4]:
import torch
import brt.nn as nn
from brt.routers import ScatterRouter, GatherRouter

# Gating network shared by both residual scatter routers below: a linear layer
# maps the last (size-10) dimension to 2 scores, and ReLU zeroes negatives.
# NOTE(review): 2 outputs presumably means one score per destination
# (dst_num == 2 in this cell) — confirm.
route_func = nn.Sequential(
    nn.Linear(10, 2),
    nn.ReLU(),
)


class DynamicRouting(nn.Module):
    """Two-input toy model exercising residual scatter routing (2-D inputs).

    This is the same topology as the default-dispatcher variant. Both scatter
    routers additionally pass ``residual_dst=0``, which presumably makes branch
    0 the residual destination (per this section's title) — confirm against
    the brt docs. Branch 0 of each input uses a 10 -> 10 linear expert and
    branch 1 a 10 -> 20 expert. The gather routers merge the x/y results of
    equal width.

    NOTE(review): ``forward`` only consumes branches 0 and 1, so ``dst_num``
    is presumably expected to be 2 — confirm before using other values.
    """

    def __init__(self, dst_num):
        super().__init__()
        # Identical settings for both scatter routers; each call receives its
        # own copy of these keyword arguments.
        scatter_options = dict(
            dst_num=dst_num,
            route_func=route_func,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
        )
        self.scatter_router_0 = ScatterRouter(**scatter_options)
        self.scatter_router_1 = ScatterRouter(**scatter_options)
        # Experts for x: keep width (10 -> 10) and widen (10 -> 20).
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        # Experts for y, mirroring the x experts.
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)
        # Both gather routers use GatherRouter's default options.
        self.gather_router_0 = GatherRouter(dst_num=dst_num)
        self.gather_router_1 = GatherRouter(dst_num=dst_num)

    def forward(self, x, y):
        """Route ``x`` and ``y``, run the experts, and regroup results by width.

        Returns ``(narrow, wide)``: the gathered 10-wide and 20-wide expert
        outputs. The recorded output shows their first dimension varying
        between runs.
        """
        x_branches = self.scatter_router_0(x)
        y_branches = self.scatter_router_1(y)
        x_narrow = self.expert1(x_branches[0])
        x_wide = self.expert2(x_branches[1])
        y_narrow = self.expert3(y_branches[0])
        y_wide = self.expert4(y_branches[1])
        narrow = self.gather_router_0([x_narrow, y_narrow])
        wide = self.gather_router_1([x_wide, y_wide])
        return narrow, wide


# Run ten random (2, 10) batches through the residual model and report the
# gathered output shapes. As the recorded output shows, the first dimension
# of each result varies between iterations.
# Prefer CUDA, which produced the recorded output, but fall back to CPU so the
# cell also survives Restart & Run All on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dy_model = DynamicRouting(2)
dy_model.to(device)
for _ in range(10):
    x = torch.randn((2, 10)).to(device)
    y = torch.randn((2, 10)).to(device)

    x, y = dy_model(x, y)

    print(x.shape)
    print(y.shape)


torch.Size([2, 10])
torch.Size([1, 20])
torch.Size([2, 10])
torch.Size([1, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([1, 20])
torch.Size([2, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([0, 20])
torch.Size([2, 10])
torch.Size([0, 20])
torch.Size([1, 10])
torch.Size([2, 20])
torch.Size([2, 10])
torch.Size([2, 20])


Dynamic Routing with a Residual Router and 4-D Tensors

In [6]:
import torch
import brt.nn as nn
from brt.routers import ScatterRouter, GatherRouter


# Gating network for 4-channel feature maps, shared by both scatter routers:
#   1x1 conv (4 -> 2 channels) -> global average pool -> 1x1 conv (2 -> 2).
# Output shape is (batch, 2, 1, 1), presumably one score per destination
# (dst_num == 2 in this cell).
# NOTE(review): the original comment read "keep up down -> keep down", which
# presumably means the candidate routes were reduced from three to two —
# confirm.
route_func = nn.Sequential(
    nn.Conv2d(4, 2, 1),  # channel reduction
    nn.AdaptiveAvgPool2d((1, 1)),  # collapse spatial dims to 1x1
    nn.Conv2d(2, 2, 1),  # mix the pooled scores
)


class DynamicRouting(nn.Module):
    """Two-input toy model exercising residual scatter routing (4-D inputs).

    Each 4-channel input is split by its own ScatterRouter. Both routers use
    threshold routing around the shared conv ``route_func``, with
    ``residual_dst=0``, which presumably marks branch 0 as the residual
    destination — confirm against the brt docs. Branch 0 goes through a 1x1
    conv keeping 4 channels, and branch 1 through a 1x1 conv producing 8
    channels. The gather routers merge the x/y results with matching channel
    counts.

    NOTE(review): ``forward`` only consumes branches 0 and 1, so ``dst_num``
    is presumably expected to be 2 — confirm before using other values.
    """

    def __init__(self, dst_num):
        super().__init__()
        # Identical settings for both scatter routers; each call receives its
        # own copy of these keyword arguments.
        scatter_options = dict(
            dst_num=dst_num,
            route_func=route_func,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
        )
        self.scatter_router_0 = ScatterRouter(**scatter_options)
        self.scatter_router_1 = ScatterRouter(**scatter_options)
        # Experts for x: keep 4 channels, or expand to 8 (1x1 kernels).
        self.expert1 = nn.Conv2d(4, 4, 1)
        self.expert2 = nn.Conv2d(4, 8, 1)
        # Experts for y, mirroring the x experts.
        self.expert3 = nn.Conv2d(4, 4, 1)
        self.expert4 = nn.Conv2d(4, 8, 1)
        # gather_router_0 uses the default options. gather_router_1 is
        # explicitly sparse; in the recorded output, its result's first
        # dimension is usually 0 or 1.
        self.gather_router_0 = GatherRouter(dst_num=dst_num)
        self.gather_router_1 = GatherRouter(dst_num=dst_num, sparse=True)

    def forward(self, x, y):
        """Route ``x`` and ``y``, run the experts, and regroup by channel count.

        Returns ``(narrow, wide)``: the gathered 4-channel and 8-channel expert
        outputs.
        """
        x_branches = self.scatter_router_0(x)
        y_branches = self.scatter_router_1(y)
        x_narrow = self.expert1(x_branches[0])
        x_wide = self.expert2(x_branches[1])
        y_narrow = self.expert3(y_branches[0])
        y_wide = self.expert4(y_branches[1])
        narrow = self.gather_router_0([x_narrow, y_narrow])
        wide = self.gather_router_1([x_wide, y_wide])
        return narrow, wide


# Run ten random (2, 4, 2, 2) batches through the 4-D residual model and
# report the gathered output shapes. As the recorded output shows, the first
# dimension of the sparse 8-channel result varies between iterations.
# Prefer CUDA, which produced the recorded output, but fall back to CPU so the
# cell also survives Restart & Run All on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

dy_model = DynamicRouting(2)
dy_model.to(device)
for _ in range(10):
    x = torch.randn((2, 4, 2, 2)).to(device)
    y = torch.randn((2, 4, 2, 2)).to(device)
    x, y = dy_model(x, y)
    print(x.shape)
    print(y.shape)


torch.Size([2, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([1, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([0, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([1, 8, 2, 2])
torch.Size([2, 4, 2, 2])
torch.Size([0, 8, 2, 2])
