In [1]:
from brt.router.utils import generate_dst_indices
import torch

# Route each of the 4 samples to its single highest-scoring destination
# (top-1) and pass the resulting mask to generate_dst_indices.
dst_num = 4
gates = torch.randn((4, 4)).cuda()
topk_indices = torch.topk(gates, k=1, dim=1).indices

# One-hot routing mask: row i holds a 1 in the column selected for sample i.
# The mask is allocated on gates.device, so no further device transfer is needed.
route_indices = torch.zeros(
    gates.size(0),
    dst_num,
    dtype=torch.int64,
    device=gates.device,
).scatter_(1, topk_indices, 1)

# This demo passes no capacity list. To try one, assign for example
# torch.tensor([1, 2, 4, 8, 16], dtype=torch.int32, device=gates.device) instead.
supported_capacities = None

local_indices, dst_loads = generate_dst_indices(route_indices, supported_capacities)
print(local_indices)
print(dst_loads)


tensor([[1, 0, 0, 0],
        [0, 1, 0, 0],
        [2, 0, 0, 0],
        [3, 0, 0, 0]], device='cuda:0', dtype=torch.int32)
tensor([3, 1, 0, 0], dtype=torch.int32)


In [2]:
from brt.router.utils import generate_dst_indices, generate_src_indices
import torch
import numpy as np
# Same top-1 routing demo as before, this time with an int32 routing mask.
dst_num = 4
gates = torch.randn((4, 4)).cuda()
topk_indices = torch.topk(gates, k=1, dim=1).indices

# One-hot routing mask: row i holds a 1 in the column selected for sample i.
# The mask is allocated on gates.device, so no further device transfer is needed.
route_indices = torch.zeros(
    gates.size(0),
    dst_num,
    dtype=torch.int32,
    device=gates.device,
).scatter_(1, topk_indices, 1)

# This demo passes no capacity list. To try one, assign for example
# torch.tensor([1, 2, 4, 8, 16], dtype=torch.int32, device=gates.device) instead.
supported_capacities = None

local_indices, dst_loads = generate_dst_indices(route_indices, supported_capacities)
print(local_indices)
print(dst_loads)
# Also show the Python type of the returned loads object.
print(type(dst_loads))


tensor([[0, 0, 0, 1],
        [1, 0, 0, 0],
        [2, 0, 0, 0],
        [0, 0, 0, 2]], device='cuda:0', dtype=torch.int32)
tensor([2, 0, 0, 2], dtype=torch.int32)
<class 'torch.Tensor'>


In [3]:
from brt._C.router import dispatch_with_dst_indices_1d
import torch

# Create the input directly on the GPU so it stays a leaf tensor. Calling
# .cuda() on a CPU tensor that requires grad produces a non-leaf copy, and
# autograd does not populate .grad on non-leaf tensors by default.
in_data = torch.randn(4, 4, 4, dtype=torch.float32, device="cuda", requires_grad=True)

# Total number of routed samples across all destinations.
total_load = torch.sum(dst_loads).item()

new_gates = torch.ones_like(gates, dtype=torch.float32).cuda()

print(in_data)
print(local_indices)
print(dst_loads)
# The dispatch call returns its result tensor, so no output buffer is
# preallocated beforehand.
out_data = dispatch_with_dst_indices_1d(
    in_data,
    local_indices,
    dst_loads,
    auto_pad=True,
)

print(out_data)


tensor([[[-0.6390,  0.2580, -0.6136,  0.4543],
         [ 0.8854,  0.9437, -1.1913,  0.0093],
         [-0.5327, -0.2249,  0.4053, -0.5767],
         [ 0.1819,  1.3430,  2.1227,  1.0500]],

        [[-0.6257, -1.6932, -0.6594, -0.2160],
         [-0.7684,  1.9193, -2.0850, -1.3683],
         [ 0.4937,  0.1979,  0.4986, -0.0726],
         [-1.9005, -0.5340,  0.0559,  0.4355]],

        [[-0.1116,  1.2155,  0.7450,  0.9826],
         [-0.2844, -1.1684,  0.7935,  1.3454],
         [ 1.7312, -0.7616,  0.2194, -0.4549],
         [-0.4971,  1.6603, -0.3589, -0.6392]],

        [[ 0.3859,  0.8998,  0.2174,  2.2742],
         [ 1.3869, -0.0438, -0.2435, -0.4630],
         [-0.9164,  0.4059, -1.9830, -0.1341],
         [ 1.3370, -2.0032,  0.6529,  0.1511]]], device='cuda:0',
       grad_fn=<ToCopyBackward0>)
tensor([[0, 0, 0, 1],
        [1, 0, 0, 0],
        [2, 0, 0, 0],
        [0, 0, 0, 2]], device='cuda:0', dtype=torch.int32)
tensor([2, 0, 0, 2], dtype=torch.int32)
tensor([[[-0.6257, -1.69

In [4]:
from brt._C.router import combine_with_src_indices

# Show the dispatched tensor followed by the indices it was built from.
print(out_data, local_indices, sep="\n")

# Every row of the score matrix repeats the same four values; it is passed
# to the combine call as its `gates` argument.
score = torch.tensor([[0.1, 10, 0.1, 1]] * 4, dtype=torch.float32).cuda()

final_data = combine_with_src_indices(
    out_data,
    local_indices,
    dst_loads,
    gates=score,
    auto_pad=True,
)

print(final_data)
# Report whether the combined result matches the original input, then its shape.
print(torch.allclose(final_data, in_data), final_data.shape, sep="\n")


tensor([[[-0.6257, -1.6932, -0.6594, -0.2160],
         [-0.7684,  1.9193, -2.0850, -1.3683],
         [ 0.4937,  0.1979,  0.4986, -0.0726],
         [-1.9005, -0.5340,  0.0559,  0.4355]],

        [[-0.1116,  1.2155,  0.7450,  0.9826],
         [-0.2844, -1.1684,  0.7935,  1.3454],
         [ 1.7312, -0.7616,  0.2194, -0.4549],
         [-0.4971,  1.6603, -0.3589, -0.6392]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000],
         [ 0.0000,  0.0000,  0.0000,  0.0000]],

        [[ 0.0000,  0.0000,  0.0000,  0.0000],
   