In [1]:
from brt.routers.inference import make_homo_proto_tensor_cls

# NOTE(review): this call sits before the router imports below; presumably it has to
# run first so the routers pick up the homogeneous proto-tensor class — TODO confirm.
make_homo_proto_tensor_cls()

from brt.routers.app import RandHomoFusedScatterRouter
from brt.routers.inference import HomoFusedGatherRouter
import brt.frontend.nn as nn
import torch


class FusedMoE(nn.Module):
    """Round-trip demo module: a fused scatter router feeding a fused gather router.

    No expert computation is placed between the two routers; ``forward`` only
    passes the input through the scatter router and then the gather router.
    """

    def __init__(self, expert_num):
        """Create both routers with ``expert_num`` destinations each.

        Args:
            expert_num: Forwarded as ``dst_num`` to the scatter and gather routers.
        """
        super().__init__()
        # Powers of two from 2 through 8192 (2**1 .. 2**13).
        capacities = [1 << exponent for exponent in range(1, 14)]
        self.scatter_router = RandHomoFusedScatterRouter(
            dst_num=expert_num, supported_capacities=capacities
        )
        self.gather_router = HomoFusedGatherRouter(dst_num=expert_num)

    def forward(self, inputs):
        """Return ``inputs`` after the scatter router followed by the gather router."""
        scattered = self.scatter_router(inputs)
        return self.gather_router(scattered)


# Build an 8-expert FusedMoE and move it onto the GPU.
fused_moe = FusedMoE(expert_num=8)
fused_moe = fused_moe.cuda()

# Random (1024, 64) batch, generated first and then moved with .cuda().
input_tensor = torch.rand(1024, 64).cuda()
output_tensor = fused_moe(input_tensor)

# Scatter followed by gather is compared against the input; the recorded run printed True.
round_trip_matches = torch.allclose(output_tensor, input_tensor)
print(round_trip_matches)


generate_local_indices elapsed time: 0.205
route_with_local_indices elapsed time: 0.347
route_back_with_local_indices elapsed time: 0.189
True


In [1]:
import torch
import torch.nn as nn

class SimpleNet(nn.Module):
    """Two-layer toy network: a 10->10 linear layer followed by a 3x3 convolution."""

    def __init__(self) -> None:
        """Register the ``linear`` and ``conv`` submodules (in that order)."""
        super().__init__()
        self.linear = nn.Linear(10, 10)
        # 3 input channels, 3 output channels, kernel size 3.
        self.conv = nn.Conv2d(3, 3, 3)

    def forward(self, x):
        """Apply ``linear`` to ``x`` and then ``conv`` to the result."""
        return self.conv(self.linear(x))

# Reference model in eval mode; nn.Module.eval() returns the module itself.
simple_net = SimpleNet().eval()
in_data = torch.randn((1, 3, 10, 10))

# Baseline forward pass under inference mode.
with torch.inference_mode():
    origin_out_data = simple_net(in_data)

In [2]:
from brt.runtime.weight_load import WeightLoader

def _report_cuda_memory():
    """Print the current values of torch.cuda.memory_allocated() and memory_reserved()."""
    print(f"Allocated: {torch.cuda.memory_allocated()}")
    print(f"Reserved: {torch.cuda.memory_reserved()}")


# Measurement before the weight loader is touched.
_report_cuda_memory()

WeightLoader.init()

# Measurement right after WeightLoader.init().
_report_cuda_memory()

pinned_simple_net = WeightLoader.pin_memory(simple_net)
pinned_simple_net.eval()
with torch.inference_mode():
    pinned_out_data = pinned_simple_net(in_data)

# Measurement after pinning and one forward pass on `in_data`.
_report_cuda_memory()

Allocated: 0
Reserved: 0
Allocated: 0
Reserved: 0
Allocated: 0
Reserved: 0


In [3]:
# NOTE(review): WeightLoader.load presumably places the pinned model's weights on
# the GPU — TODO confirm against brt.runtime.weight_load. In the recorded run the
# allocated figure goes from 0 to 2048 after this call.
cuda_simple_net = WeightLoader.load(pinned_simple_net)
print(f"Allocated: {torch.cuda.memory_allocated()}")
print(f"Reserved: {torch.cuda.memory_reserved()}")


Allocated: 2048
Reserved: 2097152


In [4]:

cuda_simple_net.eval()
# Copy the input onto the GPU before calling the loaded model.
cuda_in_data = in_data.cuda()
with torch.inference_mode():
    cuda_out_data = cuda_simple_net(cuda_in_data)
# First measurement: taken while `cuda_in_data` and `cuda_out_data` are still referenced.
print(f"Allocated: {torch.cuda.memory_allocated()}")
print(f"Reserved: {torch.cuda.memory_reserved()}")

# Rebinding both names to None drops this cell's references to the GPU tensors, so
# the second measurement no longer includes them (recorded run: 4608 -> 2048 allocated,
# reserved unchanged).
cuda_in_data=None
cuda_out_data = None

print(f"Allocated: {torch.cuda.memory_allocated()}")
print(f"Reserved: {torch.cuda.memory_reserved()}")


Allocated: 4608
Reserved: 2097152
Allocated: 2048
Reserved: 2097152


In [5]:
# NOTE(review): WeightLoader.unload presumably moves the weights back off the GPU —
# TODO confirm. In the recorded run allocated drops to 0 here while reserved stays
# at 2097152.
unload_simple_net = WeightLoader.unload(cuda_simple_net)
print(f"Allocated: {torch.cuda.memory_allocated()}")
print(f"Reserved: {torch.cuda.memory_reserved()}")

unload_simple_net.eval()
# Forward pass uses the original `in_data`, not a `.cuda()` copy.
with torch.inference_mode():
    unload_out_data = unload_simple_net(in_data)

# empty_cache() releases unused cached blocks held by PyTorch's CUDA allocator, which
# is why the reserved figure also reaches 0 in the recorded run.
torch.cuda.empty_cache()
print(f"Allocated: {torch.cuda.memory_allocated()}")
print(f"Reserved: {torch.cuda.memory_reserved()}")


Allocated: 0
Reserved: 2097152
Allocated: 0
Reserved: 0
