In [1]:
from brt.router.utils import generate_dst_indices
import torch

dst_num = 4
gates = torch.randn((4, 4)).cuda()
topk_indices = torch.topk(gates, k=1, dim=1).indices

route_indices = (
    torch.zeros(
        gates.size(0),
        dst_num,
        dtype=torch.int64,
        device=gates.device,
    )
    .scatter_(1, topk_indices, 1)
    .cuda()
)

supported_capacities = torch.Tensor([1, 2, 4, 8, 16]).int().cuda()
supported_capacities = None

local_indices, dst_loads = generate_dst_indices(route_indices, supported_capacities)
print(local_indices)
print(dst_loads)


tensor([[0, 1, 0, 0],
        [0, 0, 1, 0],
        [0, 2, 0, 0],
        [0, 0, 0, 1]], device='cuda:0', dtype=torch.int32)
tensor([0, 2, 1, 1], dtype=torch.int32)


In [3]:
from brt.router.utils import generate_dst_indices, generate_src_indices
import torch
import numpy as np
dst_num = 4
gates = torch.randn((4, 4)).cuda()
topk_indices = torch.topk(gates, k=1, dim=1).indices

route_indices = (
    torch.zeros(
        gates.size(0),
        dst_num,
        dtype=torch.int32,
        device=gates.device,
    )
    .scatter_(1, topk_indices, 1)
    .cuda()
)

supported_capacities = torch.Tensor([1, 2, 4, 8, 16]).int().cuda()
supported_capacities = None

local_indices, dst_loads = generate_dst_indices(route_indices, supported_capacities)
print(local_indices)
print(dst_loads)
print(type(dst_loads))


tensor([[0, 0, 0, 1],
        [0, 0, 0, 2],
        [1, 0, 0, 0],
        [0, 1, 0, 0]], device='cuda:0', dtype=torch.int32)
tensor([1, 1, 0, 2], dtype=torch.int32)
<class 'torch.Tensor'>


In [5]:
from brt._C.router import dispatch_with_dst_indices_1d
import torch

in_data = torch.randn(4, 4, 4, dtype=torch.float32, requires_grad=True).cuda()

total_load = torch.sum(dst_loads).item()
out_data = torch.zeros((total_load, 4), dtype=in_data.dtype).cuda()

new_gates = torch.ones_like(gates, dtype=torch.float32).cuda()

print(in_data)
print(local_indices)
print(dst_loads)
out_data = dispatch_with_dst_indices_1d(
    in_data,
    local_indices,
    dst_loads,
)

print(out_data)


tensor([[[ 0.5350,  0.2327,  0.5188, -1.6168],
         [ 0.4517, -0.0207, -0.3176, -0.4912],
         [-0.0086,  1.0315, -0.3159, -0.2336],
         [-0.7730,  1.6559, -0.1177,  0.3795]],

        [[ 0.6152, -0.4958,  0.9154,  0.6812],
         [ 1.5039, -0.7479,  0.1105, -0.2781],
         [ 0.6638,  1.2818,  1.7344,  0.6269],
         [-0.1343,  0.1003, -0.0755, -0.2001]],

        [[ 0.0110,  0.1355, -0.1765, -0.3073],
         [-0.0182,  0.7717, -0.5185,  1.3886],
         [ 1.0003, -0.6967,  0.2998,  1.2957],
         [ 0.7187, -0.2509,  0.1214, -1.1183]],

        [[-0.5182,  0.2225, -0.9024, -0.1015],
         [ 0.2034, -0.2301, -0.9605,  0.4297],
         [ 0.4747,  1.8300,  0.0021, -0.7760],
         [ 0.1383, -0.9376, -1.1034, -0.5206]]], device='cuda:0',
       grad_fn=<ToCopyBackward0>)
tensor([[0, 0, 0, 1],
        [0, 0, 0, 2],
        [1, 0, 0, 0],
        [0, 1, 0, 0]], device='cuda:0', dtype=torch.int32)
tensor([1, 1, 0, 2], dtype=torch.int32)
tensor([[[ 0.0110,  0.13

In [6]:
from brt._C.router import combine_with_src_indices

final_data = combine_with_src_indices(out_data, local_indices, dst_loads)

print(final_data)
print(torch.allclose(final_data, in_data))
print(final_data.shape)

tensor([[[ 0.5350,  0.2327,  0.5188, -1.6168],
         [ 0.4517, -0.0207, -0.3176, -0.4912],
         [-0.0086,  1.0315, -0.3159, -0.2336],
         [-0.7730,  1.6559, -0.1177,  0.3795]],

        [[ 0.6152, -0.4958,  0.9154,  0.6812],
         [ 1.5039, -0.7479,  0.1105, -0.2781],
         [ 0.6638,  1.2818,  1.7344,  0.6269],
         [-0.1343,  0.1003, -0.0755, -0.2001]],

        [[ 0.0110,  0.1355, -0.1765, -0.3073],
         [-0.0182,  0.7717, -0.5185,  1.3886],
         [ 1.0003, -0.6967,  0.2998,  1.2957],
         [ 0.7187, -0.2509,  0.1214, -1.1183]],

        [[-0.5182,  0.2225, -0.9024, -0.1015],
         [ 0.2034, -0.2301, -0.9605,  0.4297],
         [ 0.4747,  1.8300,  0.0021, -0.7760],
         [ 0.1383, -0.9376, -1.1034, -0.5206]]], device='cuda:0')
True
torch.Size([4, 4, 4])
