In [36]:
import torch
from brt._C.router import (
    generate_indices_and_loads,
    dispatch_with_indices_and_loads,
    split_fused_cells_to_paths,
    fuse_split_cells_from_paths,
    combine_with_indices_and_loads,
)

path_num = 4
# Random gate scores: 8 samples, one score per candidate path.
gates = torch.randn((8, path_num)).cuda()
# Every sample is routed to its 2 highest-scoring paths.
topk_indices = torch.topk(gates, k=2, dim=1).indices

# 0/1 routing mask of shape (samples, paths): 1 where a sample selected a path.
# It is allocated on the gates' device, so no extra `.cuda()` transfer is needed.
hot_mask = torch.zeros(
    gates.size(0), path_num, dtype=torch.int32, device=gates.device
).scatter_(1, topk_indices, 1)

# Build the int32 capacity list directly on the target device instead of going
# through the legacy float `torch.Tensor(...)` constructor and converting.
supported_capacities = torch.tensor(
    [1, 2, 4, 8], dtype=torch.int32, device=gates.device
)
# supported_capacities = None

# NOTE(review): with capacity_padding=True the returned loads presumably get
# rounded up to one of `supported_capacities` -- confirm against the brt router docs.
route_indices, dst_loads = generate_indices_and_loads(
    hot_mask, supported_capacities, capacity_padding=True, is_tag_index=False
)
print(route_indices)
print(dst_loads)


tensor([[0, 0, 1, 1],
        [0, 1, 0, 2],
        [1, 2, 0, 0],
        [0, 3, 2, 0],
        [0, 4, 3, 0],
        [0, 0, 4, 3],
        [2, 0, 5, 0],
        [0, 5, 6, 0]], device='cuda:0', dtype=torch.int32)
tensor([2, 8, 8, 4], device='cuda:0', dtype=torch.int32)


In [37]:
# Random payload to route: one row per sample (8 rows, `path_num` features each).
in_data = torch.randn((8, path_num)).cuda()
# Show the payload next to the routing plan produced by the previous cell.
print(in_data, route_indices, dst_loads, sep="\n")

# Disabled variant: dispatch without generating tags.
# (out_data_1,) = dispatch_with_indices_and_loads(in_data, route_indices, dst_loads)
# print(out_data_1)
# split_data = split_fused_cells_to_paths(out_data_1, dst_loads)
# print(split_data)

# Dispatch the samples and also request the per-row tags.
# NOTE(review): tags look like 1-based source-sample ids with 0 marking padded
# rows -- confirm against the brt router docs.
out_data_1, tags = dispatch_with_indices_and_loads(
    in_data,
    route_indices,
    dst_loads,
    tag_generating=True,
)
print(tags)

# Cut the fused buffer into one chunk per path, splitting loads and tags alongside.
split_data, split_loads, split_tags = split_fused_cells_to_paths(
    out_data_1,
    dst_loads,
    is_load_split=True,
    is_tag_split=True,
    tags=tags,
)
print(split_tags)


# results = (
#     torch.zeros((9, 4), dtype=torch.float32)
#     .cuda()
#     .scatter_reduce(0, global_dst_indices, fused_data, reduce="sum")
# )

# out_data_1, tags = dispatch_with_indices_and_loads(
#     in_data, route_indices, dst_loads, max_path_padding=True, tag_generating=True
# )
# split_data = split_fused_cells_to_paths(out_data_1, dst_loads, max_path_padding=True,tags=tags)
# print(split_data)# print(out_data_1)

# out_data_2 = dispatch_with_indices_and_loads(
#     in_data, route_indices, dst_loads, is_1d_routing=False
# )
# print(out_data_2)
# torch.allclose(out_data_1[0], out_data_2)


tensor([[-1.3374, -0.6613, -0.5485, -2.4033],
        [-0.2603,  0.7424,  0.4300, -0.2981],
        [-0.8049, -0.3197, -1.3546,  2.2464],
        [-0.1298, -0.4746,  0.2384, -0.4281],
        [-0.5104,  0.2878, -1.3444, -1.3628],
        [ 0.4549,  0.3370,  0.3600,  0.9183],
        [-0.3438, -0.5142, -1.2949,  0.5509],
        [-0.9167,  0.3661,  1.9040,  0.5151]], device='cuda:0')
tensor([[0, 0, 1, 1],
        [0, 1, 0, 2],
        [1, 2, 0, 0],
        [0, 3, 2, 0],
        [0, 4, 3, 0],
        [0, 0, 4, 3],
        [2, 0, 5, 0],
        [0, 5, 6, 0]], device='cuda:0', dtype=torch.int32)
tensor([2, 8, 8, 4], device='cuda:0', dtype=torch.int32)
tensor([3, 7, 2, 3, 4, 5, 8, 0, 0, 0, 1, 4, 5, 6, 7, 8, 0, 0, 1, 2, 6, 0],
       device='cuda:0', dtype=torch.int32)
[tensor([3, 7], device='cuda:0', dtype=torch.int32), tensor([2, 3, 4, 5, 8, 0, 0, 0], device='cuda:0', dtype=torch.int32), tensor([1, 4, 5, 6, 7, 8, 0, 0], device='cuda:0', dtype=torch.int32), tensor([1, 2, 6, 0], device='cuda

In [38]:
# Re-fuse the per-path chunks into a single buffer, carrying their tags along.
fused_data, new_tags, global_seat_indices = fuse_split_cells_from_paths(
    split_data,
    is_tag_fuse=True,
    tags=split_tags,
)

print(fused_data, new_tags, global_seat_indices, sep="\n")

# Combine back to the original sample layout, indexing by tag.
final_data = combine_with_indices_and_loads(
    fused_data,
    global_seat_indices,
    dst_loads,
    tags=new_tags,
    is_tag_index=True,
)
print(final_data)
# NOTE(review): the 2x factor presumably comes from every sample being sent to
# two paths (top-2 routing) and the combine summing both copies -- confirm.
print(torch.allclose(final_data, 2 * in_data))





tensor([[-0.8049, -0.3197, -1.3546,  2.2464],
        [-0.3438, -0.5142, -1.2949,  0.5509],
        [-0.2603,  0.7424,  0.4300, -0.2981],
        [-0.8049, -0.3197, -1.3546,  2.2464],
        [-0.1298, -0.4746,  0.2384, -0.4281],
        [-0.5104,  0.2878, -1.3444, -1.3628],
        [-0.9167,  0.3661,  1.9040,  0.5151],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [-1.3374, -0.6613, -0.5485, -2.4033],
        [-0.1298, -0.4746,  0.2384, -0.4281],
        [-0.5104,  0.2878, -1.3444, -1.3628],
        [ 0.4549,  0.3370,  0.3600,  0.9183],
        [-0.3438, -0.5142, -1.2949,  0.5509],
        [-0.9167,  0.3661,  1.9040,  0.5151],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [-1.3374, -0.6613, -0.5485, -2.4033],
        [-0.2603,  0.7424,  0.4300, -0.2981],
        [ 0.4549,  0.3370,  0.3600,  0.9183],
        [ 0.0000,  0.0000,  0.0000

In [39]:
# Re-fuse the per-path chunks into a single buffer, this time carrying the
# per-path loads along instead of the tags.
fused_data, fused_loads = fuse_split_cells_from_paths(
    split_data, is_load_fuse=True, loads=split_loads
)

print(fused_data)
# Show this cell's own second result. The previous version printed `new_tags`
# and `global_seat_indices`, which are leftovers from the tag-based cell and
# are not produced here.
print(fused_loads)

# Combine back to the original sample layout using the route indices and fused loads.
final_data = combine_with_indices_and_loads(fused_data, route_indices, fused_loads)
print(final_data)
# NOTE(review): the 2x factor presumably comes from every sample being sent to
# two paths (top-2 routing) and the combine summing both copies -- confirm.
print(torch.allclose(final_data, in_data * 2))


tensor([[-0.8049, -0.3197, -1.3546,  2.2464],
        [-0.3438, -0.5142, -1.2949,  0.5509],
        [-0.2603,  0.7424,  0.4300, -0.2981],
        [-0.8049, -0.3197, -1.3546,  2.2464],
        [-0.1298, -0.4746,  0.2384, -0.4281],
        [-0.5104,  0.2878, -1.3444, -1.3628],
        [-0.9167,  0.3661,  1.9040,  0.5151],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [-1.3374, -0.6613, -0.5485, -2.4033],
        [-0.1298, -0.4746,  0.2384, -0.4281],
        [-0.5104,  0.2878, -1.3444, -1.3628],
        [ 0.4549,  0.3370,  0.3600,  0.9183],
        [-0.3438, -0.5142, -1.2949,  0.5509],
        [-0.9167,  0.3661,  1.9040,  0.5151],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [-1.3374, -0.6613, -0.5485, -2.4033],
        [-0.2603,  0.7424,  0.4300, -0.2981],
        [ 0.4549,  0.3370,  0.3600,  0.9183],
        [ 0.0000,  0.0000,  0.0000