In [1]:
import torch
from brt._C.router import (
    generate_indices_and_loads,
    dispatch_with_indices_and_loads,
    split_fused_cells_to_paths,
    fuse_split_cells_from_paths,
    combine_with_indices_and_loads,
)

path_num = 4
gates = torch.randn((8, path_num)).cuda()
topk_indices = torch.topk(gates, k=1, dim=1).indices

hot_mask = (
    torch.zeros(gates.size(0), path_num, dtype=torch.int32, device=gates.device,)
    .scatter_(1, topk_indices, 1)
    .cuda()
)

supported_capacities = torch.Tensor([1, 2, 4, 8, 16]).int().cuda()
# supported_capacities = None

route_indices, dst_loads = generate_indices_and_loads(
    hot_mask, supported_capacities, capacity_padding=True, is_tag_index=False
)
print(route_indices)
print(dst_loads)


  from .autonotebook import tqdm as notebook_tqdm


tensor([[0, 0, 1, 0],
        [1, 0, 0, 0],
        [0, 0, 2, 0],
        [2, 0, 0, 0],
        [3, 0, 0, 0],
        [0, 0, 3, 0],
        [0, 0, 0, 1],
        [0, 0, 4, 0]], device='cuda:0', dtype=torch.int32)
tensor([4, 0, 4, 1], device='cuda:0', dtype=torch.int32)


In [2]:
in_data = torch.randn((8, path_num)).cuda()
print(in_data)
print(route_indices)
print(dst_loads)
# (out_data_1,) = dispatch_with_indices_and_loads(in_data, route_indices, dst_loads)
# print(out_data_1)
# split_data = split_fused_cells_to_paths(out_data_1, dst_loads)
# print(split_data)

out_data_1, tags = dispatch_with_indices_and_loads(
    in_data, route_indices, dst_loads, tag_generating=True
)
print(tags)

split_data, split_loads, split_tags = split_fused_cells_to_paths(
    out_data_1, dst_loads, is_load_split=True, is_tag_split=True, tags=tags
)
print(split_tags)


# results = (
#     torch.zeros((9, 4), dtype=torch.float32)
#     .cuda()
#     .scatter_reduce(0, global_dst_indices, fused_data, reduce="sum")
# )

# out_data_1, tags = dispatch_with_indices_and_loads(
#     in_data, route_indices, dst_loads, max_path_padding=True, tag_generating=True
# )
# split_data = split_fused_cells_to_paths(out_data_1, dst_loads, max_path_padding=True,tags=tags)
# print(split_data)# print(out_data_1)

# out_data_2 = dispatch_with_indices_and_loads(
#     in_data, route_indices, dst_loads, is_1d_routing=False
# )
# print(out_data_2)
# torch.allclose(out_data_1[0], out_data_2)


tensor([[-0.4411, -0.6808,  0.1197, -0.7383],
        [ 1.0994, -0.3781,  0.4088, -0.0084],
        [-0.9510,  1.0840,  0.2321, -0.4861],
        [-1.0944,  0.5339,  1.2363, -0.2070],
        [-0.2036, -1.4755,  0.5498,  0.9185],
        [-0.4672, -1.4138,  0.2818, -1.5813],
        [ 0.3960,  1.5728,  0.1456, -1.3696],
        [ 0.1105, -0.5637, -0.3096, -0.2232]], device='cuda:0')
tensor([[0, 0, 1, 0],
        [1, 0, 0, 0],
        [0, 0, 2, 0],
        [2, 0, 0, 0],
        [3, 0, 0, 0],
        [0, 0, 3, 0],
        [0, 0, 0, 1],
        [0, 0, 4, 0]], device='cuda:0', dtype=torch.int32)
tensor([4, 0, 4, 1], device='cuda:0', dtype=torch.int32)
tensor([2, 4, 5, 0, 1, 3, 6, 8, 7], device='cuda:0', dtype=torch.int32)
[tensor([2, 4, 5, 0], device='cuda:0', dtype=torch.int32), tensor([], device='cuda:0', dtype=torch.int32), tensor([1, 3, 6, 8], device='cuda:0', dtype=torch.int32), tensor([7], device='cuda:0', dtype=torch.int32)]


In [3]:
fused_data, new_tags, global_seat_indices = fuse_split_cells_from_paths(
    split_data, is_tag_fuse=True, tags=split_tags
)

print(fused_data)
print(new_tags)
print(global_seat_indices)

final_data = combine_with_indices_and_loads(
    fused_data, global_seat_indices, dst_loads, tags=new_tags, is_tag_index=True
)
print(final_data)
print(torch.allclose(final_data, in_data))





tensor([[ 1.0994, -0.3781,  0.4088, -0.0084],
        [-1.0944,  0.5339,  1.2363, -0.2070],
        [-0.2036, -1.4755,  0.5498,  0.9185],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [-0.4411, -0.6808,  0.1197, -0.7383],
        [-0.9510,  1.0840,  0.2321, -0.4861],
        [-0.4672, -1.4138,  0.2818, -1.5813],
        [ 0.1105, -0.5637, -0.3096, -0.2232],
        [ 0.3960,  1.5728,  0.1456, -1.3696]], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8], device='cuda:0', dtype=torch.int32)
tensor([2, 4, 5, 0, 1, 3, 6, 8, 7], device='cuda:0')
tensor([[-0.4411, -0.6808,  0.1197, -0.7383],
        [ 1.0994, -0.3781,  0.4088, -0.0084],
        [-0.9510,  1.0840,  0.2321, -0.4861],
        [-1.0944,  0.5339,  1.2363, -0.2070],
        [-0.2036, -1.4755,  0.5498,  0.9185],
        [-0.4672, -1.4138,  0.2818, -1.5813],
        [ 0.3960,  1.5728,  0.1456, -1.3696],
        [ 0.1105, -0.5637, -0.3096, -0.2232]], device='cuda:0')
True




In [4]:
fused_data, fused_loads = fuse_split_cells_from_paths(
    split_data, is_load_fuse=True, loads=split_loads
)

print(fused_data)
print(new_tags)
print(global_seat_indices)

final_data = combine_with_indices_and_loads(
    fused_data, route_indices, fused_loads
)
print(final_data)
print(torch.allclose(final_data, in_data))


tensor([[ 1.0994, -0.3781,  0.4088, -0.0084],
        [-1.0944,  0.5339,  1.2363, -0.2070],
        [-0.2036, -1.4755,  0.5498,  0.9185],
        [ 0.0000,  0.0000,  0.0000,  0.0000],
        [-0.4411, -0.6808,  0.1197, -0.7383],
        [-0.9510,  1.0840,  0.2321, -0.4861],
        [-0.4672, -1.4138,  0.2818, -1.5813],
        [ 0.1105, -0.5637, -0.3096, -0.2232],
        [ 0.3960,  1.5728,  0.1456, -1.3696]], device='cuda:0')
tensor([0, 1, 2, 3, 4, 5, 6, 7, 8], device='cuda:0', dtype=torch.int32)
tensor([2, 4, 5, 0, 1, 3, 6, 8, 7], device='cuda:0')
tensor([[-0.4411, -0.6808,  0.1197, -0.7383],
        [ 1.0994, -0.3781,  0.4088, -0.0084],
        [-0.9510,  1.0840,  0.2321, -0.4861],
        [-1.0944,  0.5339,  1.2363, -0.2070],
        [-0.2036, -1.4755,  0.5498,  0.9185],
        [-0.4672, -1.4138,  0.2818, -1.5813],
        [ 0.3960,  1.5728,  0.1456, -1.3696],
        [ 0.1105, -0.5637, -0.3096, -0.2232]], device='cuda:0')
True
