Dynamic Routing with the Default Dispatcher and a 2-D Tensor

In [1]:
import torch
import torch.nn as nn
from brt.router import ScatterRouter, GatherRouter


class DynamicRouting(nn.Module):
    """Route two 2-D inputs through four linear experts with threshold routers.

    A shared gating network scores every row of ``x`` and ``y``. Each input is
    then split by its own ScatterRouter into two routes: route 0 feeds a
    10 -> 10 expert and route 1 feeds a 10 -> 20 expert. Outputs are
    recombined crosswise: the route-0 results of x and y are gathered into the
    first returned tensor, and the route-1 results into the second.
    """

    def __init__(self):
        super().__init__()

        def threshold_dispatcher():
            # Threshold protocol with the default "dispatch" fabric.
            return ScatterRouter(protocol_type="threshold", fabric_type="dispatch")

        # Gate scores: 10 input features -> 2 non-negative values (ReLU).
        self.route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU())
        self.scatter_router_0 = threshold_dispatcher()  # splits x
        self.scatter_router_1 = threshold_dispatcher()  # splits y
        # Experts are built in this fixed order so parameter initialisation
        # draws from the RNG in the same sequence.
        self.expert1 = nn.Linear(10, 10)  # x, route 0: keep width
        self.expert2 = nn.Linear(10, 20)  # x, route 1: upsample
        self.expert3 = nn.Linear(10, 10)  # y, route 0: keep width
        self.expert4 = nn.Linear(10, 20)  # y, route 1: upsample
        self.gather_router_0 = GatherRouter(fabric_type="combine")  # route-0 outputs
        self.gather_router_1 = GatherRouter(fabric_type="combine")  # route-1 outputs

    def forward(self, x, y):
        """Return ``(gathered route-0 outputs, gathered route-1 outputs)``."""
        x_gates, y_gates = self.route_func(x), self.route_func(y)
        x_routes = self.scatter_router_0(x, x_gates)
        y_routes = self.scatter_router_1(y, y_gates)

        x_keep, x_up = self.expert1(x_routes[0]), self.expert2(x_routes[1])
        y_keep, y_up = self.expert3(y_routes[0]), self.expert4(y_routes[1])

        return self.gather_router_0([x_keep, y_keep]), self.gather_router_1([x_up, y_up])


# Build the model, keep it on the CPU (Module.cpu() returns the module itself),
# and push a single random batch of 4 rows through it.
dy_model = DynamicRouting().cpu()

for i in range(1):
    x, y = torch.randn((4, 10)).cpu(), torch.randn((4, 10)).cpu()
    x, y = dy_model(x, y)
    print(x, y, sep="\n")


tensor([[1, 0],
        [2, 1],
        [3, 0],
        [0, 0]])
tensor([[0, 0],
        [1, 1],
        [3, 2],
        [0, 3]])
tensor([[ 0.3660, -0.0140, -0.1118, -0.4344,  0.0198,  0.4701,  0.2797,  0.0335,
         -0.0127, -0.2996],
        [ 0.7227, -0.3174,  0.5219,  0.4925,  0.8064,  0.4889,  0.3464, -0.8427,
          0.2820,  0.7845],
        [-0.3810,  0.7461,  0.8341, -0.6514, -0.5469,  0.4411,  0.7746,  1.1438,
          1.1653,  0.2097],
        [-0.6370,  0.1650,  0.0581, -0.3994,  0.5916, -0.9249,  0.3834, -0.1862,
          0.4555,  0.5315]], grad_fn=<AliasBackward0>)
tensor([[-0.4431, -1.2263,  0.1920, -0.3768, -0.9342, -0.6178, -0.9474,  1.0413,
          0.9288, -0.8641, -0.3729, -1.6866, -0.0978,  0.6561, -1.8517, -0.3970,
         -0.0591,  0.0233, -0.3774, -0.2812],
        [-0.0112, -0.0064,  0.9926, -0.0563,  0.3931, -0.1251, -0.3660, -0.7686,
         -2.4168, -0.1195, -1.0198,  0.7313,  0.2253,  0.3907,  1.5573,  0.3501,
          0.0618, -0.2955,  0.7401,  

In [2]:
# Move the same model to the GPU (Module.cuda() returns the module itself),
# run one random batch there, then show what the first scatter router
# recorded in its `load_history`.
dy_model = dy_model.cuda()

for i in range(1):
    x, y = torch.randn((4, 10)).cuda(), torch.randn((4, 10)).cuda()
    x, y = dy_model(x, y)
    print(x, y, sep="\n")

print(dy_model.scatter_router_0.load_history)

GenerateSrcIndices: 4 2
hot_mask: 0x7fd1f1602a00
src_indices: 0x7fd1f1603000
loads: 0x7fd1f1603200
supported_capacities: (nil)
tensor([[0, 1],
        [0, 2],
        [0, 3],
        [0, 0]], device='cuda:0', dtype=torch.int32)
tensor([1, 3], device='cuda:0', dtype=torch.int32)
tensor([[0, 1],
        [0, 2],
        [0, 3],
        [0, 0]], device='cuda:0', dtype=torch.int32)
GenerateSrcIndices: 4 2
hot_mask: 0x7fd1f1603000
src_indices: 0x7fd1f1603800
loads: 0x7fd1f1603a00
supported_capacities: (nil)
tensor([[2, 0],
        [3, 1],
        [0, 2],
        [0, 3]], device='cuda:0', dtype=torch.int32)
tensor([2, 4], device='cuda:0', dtype=torch.int32)
tensor([[2, 0],
        [3, 1],
        [0, 2],
        [0, 3]], device='cuda:0', dtype=torch.int32)
tensor([[ 0.3605, -0.6698, -0.0489,  0.2956,  0.0193,  0.4296,  0.7702,  0.1680,
          0.4649,  0.6566],
        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,
          0.0000,  0.0000],
        [-0.6247, -0.5

Dynamic Routing with a Residual Router and a 2-D Tensor

In [3]:
import torch
import torch.nn as nn  # imported here so the cell does not depend on cell 1 having run

from brt.router import ScatterRouter, GatherRouter
from brt.router.proto_tensor import ProtoTensor

# Gating network shared by both scatter routers below:
# 10 input features -> 2 non-negative scores (ReLU), placed on the GPU.
route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU()).cuda()


class DynamicRouting(nn.Module):
    """Threshold routing with a residual destination over two 2-D inputs.

    ``x`` and ``y`` are each scored by the module-level ``route_func`` and split
    by their own ScatterRouter into two routes. Route 0 feeds a 10 -> 10 expert
    and route 1 a 10 -> 20 expert. The route-0 outputs of x and y are gathered
    together, as are the route-1 outputs.

    NOTE(review): when this cell last ran (In [3]), the installed ``brt``
    rejected ``dst_num`` with ``TypeError``. The first cell of this notebook
    builds routers with ``protocol_type=`` / ``fabric_type=`` instead, so these
    keyword arguments presumably target an older brt API — confirm against the
    installed version.

    Args:
        dst_num: number of routing destinations, forwarded to every router.
    """

    def __init__(self, dst_num):
        super().__init__()
        # Register the shared gating network as a submodule (as cell 1 does with
        # self.route_func). Its weights then appear in parameters()/state_dict()
        # and follow .cuda()/.cpu() together with the rest of the model.
        self.route_func = route_func
        # threshold=0 / residual_dst=0: presumably rows whose scores clear no
        # threshold fall back to destination 0 — confirm with the brt docs.
        self.scatter_router_0 = ScatterRouter(
            dst_num=dst_num,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
        )
        self.scatter_router_1 = ScatterRouter(
            dst_num=dst_num,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
        )
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)
        self.gather_router_0 = GatherRouter(dst_num=dst_num)
        self.gather_router_1 = GatherRouter(dst_num=dst_num)

    def forward(self, x, y):
        """Return ``(gathered route-0 outputs, gathered route-1 outputs)``."""
        x_gates = self.route_func(x)
        y_gates = self.route_func(y)
        route_results_x = self.scatter_router_0(x, x_gates)
        route_results_y = self.scatter_router_1(y, y_gates)
        x_0 = self.expert1(route_results_x[0])
        x_1 = self.expert2(route_results_x[1])
        y_0 = self.expert3(route_results_y[0])
        y_1 = self.expert4(route_results_y[1])
        x = self.gather_router_0([x_0, y_0])
        y = self.gather_router_1([x_1, y_1])
        return x, y


# Two destinations; run ten random 3-row batches on the GPU and report the
# shapes of the two gathered outputs (Module.cuda() returns the module itself).
dy_model = DynamicRouting(2).cuda()

for i in range(10):
    x, y = torch.randn((3, 10)).cuda(), torch.randn((3, 10)).cuda()
    x, y = dy_model(x, y)
    print(x.shape, y.shape, sep="\n")


TypeError: __init__() got an unexpected keyword argument 'dst_num'

Dynamic Routing with a Residual Router and a 2-D Tensor, Also Routing the Gate Scores

In [12]:
import torch
import torch.nn as nn  # imported here so the cell does not depend on cell 1 having run
from brt.router import ScatterRouter, GatherRouter

# Gating network shared by both scatter routers below:
# 10 input features -> 2 non-negative scores (ReLU), placed on the GPU.
route_func = nn.Sequential(nn.Linear(10, 2), nn.ReLU()).cuda()


class DynamicRouting(nn.Module):
    """Residual threshold routing where x's expert outputs are scaled by their gates.

    ``scatter_router_0`` is built with ``routing_gates=True``, so it returns a
    ``(routed tensors, routed gates)`` pair. Each x-expert output is multiplied
    by the gate values routed alongside it. ``y`` is routed without gates.
    Route-0 outputs of x and y are gathered together, as are route-1 outputs.

    Args:
        dst_num: number of routing destinations, forwarded to every router.
    """

    def __init__(self, dst_num):
        super().__init__()
        # Register the shared gating network as a submodule (as cell 1 does with
        # self.route_func). Its weights then appear in parameters()/state_dict()
        # and follow .cuda()/.cpu() together with the rest of the model.
        self.route_func = route_func
        self.scatter_router_0 = ScatterRouter(
            dst_num=dst_num,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
            routing_gates=True,  # also return the gate values per route
        )
        self.scatter_router_1 = ScatterRouter(
            dst_num=dst_num,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
        )
        self.expert1 = nn.Linear(10, 10)
        self.expert2 = nn.Linear(10, 20)
        self.expert3 = nn.Linear(10, 10)
        self.expert4 = nn.Linear(10, 20)
        self.gather_router_0 = GatherRouter(dst_num=dst_num)
        self.gather_router_1 = GatherRouter(dst_num=dst_num)

    def forward(self, x, y):
        """Return ``(gathered route-0 outputs, gathered route-1 outputs)``."""
        x_gates = self.route_func(x)
        y_gates = self.route_func(y)
        route_results_x, route_gates_x = self.scatter_router_0(x, x_gates)
        route_results_y = self.scatter_router_1(y, y_gates)
        # Scale each x-expert output by the gates that were routed with it.
        x_0 = route_gates_x[0] * self.expert1(route_results_x[0])
        x_1 = route_gates_x[1] * self.expert2(route_results_x[1])
        y_0 = self.expert3(route_results_y[0])
        y_1 = self.expert4(route_results_y[1])
        x = self.gather_router_0([x_0, y_0])
        y = self.gather_router_1([x_1, y_1])
        return x, y


# Two destinations; run ten random 3-row batches on the GPU and report the
# shapes of the two gathered outputs (Module.cuda() returns the module itself).
dy_model = DynamicRouting(2).cuda()

for i in range(10):
    x, y = torch.randn((3, 10)).cuda(), torch.randn((3, 10)).cuda()
    x, y = dy_model(x, y)
    print(x.shape, y.shape, sep="\n")


torch.Size([3, 10])
torch.Size([1, 20])


KeyboardInterrupt: 

Dynamic Routing with a Residual Router and a 4-D Tensor

In [None]:
import torch
import torch.nn as nn
from brt.router import ScatterRouter, GatherRouter


# Gating network for 4-D inputs, built once (not per call) so gate scores come
# from a single, trainable set of weights.
# Shapes: (N, 4, H, W) -> Conv2d -> (N, 2, H, W) -> AdaptiveAvgPool2d
#         -> (N, 2, 1, 1) -> Conv2d -> (N, 2, 1, 1) -> Flatten -> (N, 2),
# i.e. one score per destination for each sample.
route_func = nn.Sequential(
    nn.Conv2d(4, 2, 1),
    nn.AdaptiveAvgPool2d((1, 1)),
    nn.Conv2d(2, 2, 1),
    nn.Flatten(),
).cuda()


class DynamicRouting(nn.Module):
    """Residual threshold routing of two 4-D inputs through 1x1-conv experts.

    ``x`` and ``y`` are each scored by ``route_func`` and split by their own
    ScatterRouter into two routes. Route 0 feeds a 4 -> 4 channel conv and
    route 1 a 4 -> 8 channel conv. Route-0 outputs of x and y are gathered
    together, and route-1 outputs are gathered with ``sparse=True``.

    Args:
        dst_num: number of routing destinations, forwarded to every router.
    """

    def __init__(self, dst_num):
        super().__init__()
        # Keep a handle on the gating function. When it is an nn.Module, this
        # also registers its weights with the model (parameters/state_dict/.cuda()).
        self.route_func = route_func
        # routing_gates=True is deliberately NOT set. forward() indexes the
        # scatter output as a list of routes, whereas routing_gates=True returns
        # a (routes, gates) pair (see the "routing gates" section). That would
        # hand a list to expert1.
        self.scatter_router_0 = ScatterRouter(
            dst_num=dst_num,
            route_func=route_func,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
        )
        self.scatter_router_1 = ScatterRouter(
            dst_num=dst_num,
            route_func=route_func,
            route_method="threshold",
            threshold=0,
            residual_dst=0,
        )
        self.expert1 = nn.Conv2d(4, 4, 1)
        self.expert2 = nn.Conv2d(4, 8, 1)
        self.expert3 = nn.Conv2d(4, 4, 1)
        self.expert4 = nn.Conv2d(4, 8, 1)
        self.gather_router_0 = GatherRouter(dst_num=dst_num)
        self.gather_router_1 = GatherRouter(dst_num=dst_num, sparse=True)

    def forward(self, x, y):
        """Return ``(gathered route-0 outputs, gathered route-1 outputs)``."""
        route_results_x = self.scatter_router_0(x, self.route_func(x))
        # y is routed by its own gate scores (previously it reused x's).
        route_results_y = self.scatter_router_1(y, self.route_func(y))
        x_0 = self.expert1(route_results_x[0])
        x_1 = self.expert2(route_results_x[1])
        y_0 = self.expert3(route_results_y[0])
        y_1 = self.expert4(route_results_y[1])
        x = self.gather_router_0([x_0, y_0])
        y = self.gather_router_1([x_1, y_1])
        return x, y


# Two destinations; run ten random (3, 4, 2, 2) batches on the GPU and report
# the shapes of the two gathered outputs (Module.cuda() returns the module itself).
dy_model = DynamicRouting(2).cuda()

for i in range(10):
    x, y = torch.randn((3, 4, 2, 2)).cuda(), torch.randn((3, 4, 2, 2)).cuda()
    x, y = dy_model(x, y)
    print(x.shape, y.shape, sep="\n")
