In [2]:
from brt.jit.tvm import TVMTuner
from brt.common import BRT_CACHE_PATH
from brt.common import log

# Raise the "jit" logger to DEBUG so the tuner's diagnostic messages are shown
# in the cell output.
log.set_level("jit", "DEBUG")

# Import an ONNX netlet by name into a TVM tuner and export a kernel template
# for it.
# NOTE(review): where the template is written is decided inside TVMTuner and is
# not visible here -- presumably under BRT_KERNEL_TEMPLATE_PATH; confirm.
tuner = TVMTuner()
onnx_model = "sparse_fusion_2_thor_model"
# The import must happen before the export: the export call takes no arguments
# and therefore relies on state set up by the import.
tuner.import_onnx_netlet(onnx_model)
tuner.export_netlet_template()

Get devices for measurement successfully!
DEBUG:brainstorm.jit:kernel args: [[8, 64, 64], [8, 64, 64], [8, 64, 64]]


In [None]:
from brt.jit import CUDACompiler
from brt.common import BRT_KERNEL_TEMPLATE_PATH
import torch

# Smoke test: compile the kernel template "<kernel_name>.cu" found under
# BRT_KERNEL_TEMPLATE_PATH and launch it once on all-ones tensors.
kernel_name = "sparse_fusion_2_thor_model"

kernel_template_filename = str(BRT_KERNEL_TEMPLATE_PATH / (kernel_name + ".cu"))

# Read through a context manager so the file handle is closed deterministically
# instead of being left open until garbage collection.
with open(kernel_template_filename, "r") as template_file:
    kernel_template_source = template_file.read()

# No keyword substitutions are supplied for this template (keyword_dict=None).
kernel_func = CUDACompiler.generate_kernel(
    keyword_dict=None, template=kernel_template_source
)

# Inputs and output buffer, all (8, 64, 64) on the GPU.
data = torch.ones((8, 64, 64), device="cuda")
weight = torch.ones((8, 64, 64), device="cuda")
outdata = torch.ones((8, 64, 64), device="cuda")

# NOTE(review): argument order (data, weight, outdata) is assumed to match the
# kernel signature in the template -- confirm against the .cu file.
kernel_func(data, weight, outdata)
print(outdata.shape)

In [6]:
from brt.jit.compiler import CUDACompiler
from brt.common import BRT_KERNEL_TEMPLATE_PATH
import time
import torch

# Compile the "sample" kernel template with explicit keyword substitutions and
# compare the host-side latency of the first call against a second call.
kernel_name = "sample"

kernel_template_filename = str(BRT_KERNEL_TEMPLATE_PATH / (kernel_name + ".cu"))

# Read through a context manager so the file handle is closed deterministically
# instead of being left open until garbage collection.
with open(kernel_template_filename, "r") as template_file:
    kernel_template_source = template_file.read()

kernel_func = CUDACompiler.generate_kernel(
    {"batch_num": 2, "num_samples": 2}, kernel_template_source,
)
data = torch.ones((8, 64, 64), device="cuda")
weight = torch.ones((8, 64, 64), device="cuda")
outdata = torch.ones((8, 64, 64), device="cuda")

# CUDA work is normally queued asynchronously, so the host clock must only be
# read after the device has drained; otherwise the measurement covers just the
# launch overhead. perf_counter() is used because it is monotonic and has a
# higher resolution than time().
torch.cuda.synchronize()
start_stamp = time.perf_counter()
kernel_func(data, weight, outdata)
torch.cuda.synchronize()
end_stamp = time.perf_counter()
# NOTE(review): the first call is expected to be much slower than later ones,
# presumably because it includes one-off compilation/loading -- confirm.
print("first time: {:.3f}".format((end_stamp - start_stamp) * 1000))

torch.cuda.synchronize()
start_stamp = time.perf_counter()
kernel_func(data, weight, outdata)
torch.cuda.synchronize()
end_stamp = time.perf_counter()
print("second time: {:.3f}".format((end_stamp - start_stamp) * 1000))


first time: 481.813
second time: 0.067


In [7]:
from brt.common import BRT_KERNEL_TEMPLATE_PATH
from brt.jit.generic import GenericFunction
from brt.jit.compiler import CUDACompiler
import torch
# Run the "sample" kernel template through GenericFunction, compile the
# resulting code and time it with CUDA events on the default stream.
kernel_name = "sample"

kernel_template_filename = str(BRT_KERNEL_TEMPLATE_PATH / (kernel_name + ".cu"))

# Read through a context manager so the file handle is closed deterministically
# instead of being left open until garbage collection.
with open(kernel_template_filename, "r") as template_file:
    kernel_template_source = template_file.read()

generic_function = GenericFunction(kernel_template_source)

# NOTE(review): "global" presumably selects the __global__ (entry-point) form
# of the generated kernel -- confirm against GenericFunction.get_code.
code = generic_function.get_code("global")

data = torch.ones((8, 64, 64), device="cuda")
weight = torch.ones((8, 64, 64), device="cuda")
outdata = torch.ones((8, 64, 64), device="cuda")
func = CUDACompiler.generate_kernel(None, code)

# Number of launches averaged in the steady-state measurement.
NUM_TIMED_RUNS = 100

# Drain any pending GPU work so it is not attributed to the first measurement.
torch.cuda.synchronize()
stream = torch.cuda.default_stream()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

# First call, timed on its own.
start_event.record(stream)
func(data, weight, outdata)
end_event.record(stream)
# elapsed_time() is only valid once both events have completed.
stream.synchronize()
print("first time: {:.3f}".format(start_event.elapsed_time(end_event)))

# Steady state: average over NUM_TIMED_RUNS back-to-back launches.
start_event.record(stream)
for _ in range(NUM_TIMED_RUNS):
    func(data, weight, outdata)
end_event.record(stream)
stream.synchronize()
print(
    "forward time: {:.3f}".format(
        start_event.elapsed_time(end_event) / NUM_TIMED_RUNS
    )
)

first time: 478.322
forward time: 0.005


In [11]:
import torch
from brt.jit import BlockFuser
from brt.common import BRT_KERNEL_TEMPLATE_PATH, log
from brt.jit.generic import GenericFunction
from brt.jit.compiler import CUDACompiler
import time

# Fuse three copies of the "sample" kernel with BlockFuser, save the fused
# source next to the templates, compile it and time it with CUDA events.
log.set_level("jit", "DEBUG")

kernel_name = "sample"

kernel_template_filename = str(BRT_KERNEL_TEMPLATE_PATH / (kernel_name + ".cu"))

# Read through a context manager so the file handle is closed deterministically
# instead of being left open until garbage collection.
with open(kernel_template_filename, "r") as template_file:
    kernel_template_source = template_file.read()

block_fuser = BlockFuser(
    [kernel_template_source, kernel_template_source, kernel_template_source]
)

code = block_fuser.get_code()

# Keep a copy of the fused source on disk ("processed_<kernel_name>.cu") so it
# can be inspected; this overwrites any file left by a previous run.
processed_template_fname = str(
    BRT_KERNEL_TEMPLATE_PATH / ("processed_" + kernel_name + ".cu")
)
with open(processed_template_fname, "w") as f:
    f.write(code)

fused_matmul = CUDACompiler.generate_kernel(None, code)

# One (data, weight, outdata) triple per fused kernel.
data_0 = torch.ones((8, 64, 64), device="cuda")
weight_0 = torch.ones((8, 64, 64), device="cuda")
outdata_0 = torch.ones((8, 64, 64), device="cuda")
data_1 = torch.ones((8, 64, 64), device="cuda")
weight_1 = torch.ones((8, 64, 64), device="cuda")
outdata_1 = torch.ones((8, 64, 64), device="cuda")
data_2 = torch.ones((8, 64, 64), device="cuda")
weight_2 = torch.ones((8, 64, 64), device="cuda")
outdata_2 = torch.ones((8, 64, 64), device="cuda")

# Pack the arguments once, in the same order for every launch, so the two call
# sites below cannot drift apart.
fused_args = (
    data_0,
    weight_0,
    outdata_0,
    data_1,
    weight_1,
    outdata_1,
    data_2,
    weight_2,
    outdata_2,
)

# Number of launches averaged in the steady-state measurement.
NUM_TIMED_RUNS = 100

# Drain any pending GPU work so it is not attributed to the first measurement.
torch.cuda.synchronize()
stream = torch.cuda.default_stream()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

# First call, timed on its own.
start_event.record(stream)
fused_matmul(*fused_args)
end_event.record(stream)
# elapsed_time() is only valid once both events have completed.
stream.synchronize()
print("first time: {:.3f}".format(start_event.elapsed_time(end_event)))

# Steady state: average over NUM_TIMED_RUNS back-to-back launches.
start_event.record(stream)
for _ in range(NUM_TIMED_RUNS):
    fused_matmul(*fused_args)
end_event.record(stream)
stream.synchronize()
print(
    "forward time: {:.3f}".format(
        start_event.elapsed_time(end_event) / NUM_TIMED_RUNS
    )
)


DEBUG:brainstorm.jit:Fusing blocks from 0 to 255 for 0-th block
DEBUG:brainstorm.jit:Fusing blocks from 256 to 511 for 1-th block
DEBUG:brainstorm.jit:Fusing blocks from 512 to 767 for 2-th block
first time: 1329.271
forward time: 0.008


In [12]:
import torch
from brt.jit import BlockFuser
from brt.common import BRT_KERNEL_TEMPLATE_PATH, log
from brt.jit.generic import GenericFunction
from brt.jit.compiler import CUDACompiler
import time

# Fuse two copies of the "sample" kernel with BlockFuser, save the fused source
# next to the templates, compile it and time it with CUDA events.
log.set_level("jit", "DEBUG")

kernel_name = "sample"

kernel_template_filename = str(BRT_KERNEL_TEMPLATE_PATH / (kernel_name + ".cu"))

# Read through a context manager so the file handle is closed deterministically
# instead of being left open until garbage collection.
with open(kernel_template_filename, "r") as template_file:
    kernel_template_source = template_file.read()

block_fuser = BlockFuser([kernel_template_source, kernel_template_source])

code = block_fuser.get_code()

# Keep a copy of the fused source on disk ("processed_<kernel_name>.cu") so it
# can be inspected; this overwrites any file left by a previous run.
processed_template_fname = str(
    BRT_KERNEL_TEMPLATE_PATH / ("processed_" + kernel_name + ".cu")
)
with open(processed_template_fname, "w") as f:
    f.write(code)

fused_matmul = CUDACompiler.generate_kernel(None, code)

# One (data, weight, outdata) triple per fused kernel.
data_0 = torch.ones((8, 64, 64), device="cuda")
weight_0 = torch.ones((8, 64, 64), device="cuda")
outdata_0 = torch.ones((8, 64, 64), device="cuda")
data_1 = torch.ones((8, 64, 64), device="cuda")
weight_1 = torch.ones((8, 64, 64), device="cuda")
outdata_1 = torch.ones((8, 64, 64), device="cuda")

# Pack the arguments once, in the same order for every launch, so the two call
# sites below cannot drift apart.
fused_args = (data_0, weight_0, outdata_0, data_1, weight_1, outdata_1)

# Number of launches averaged in the steady-state measurement.
NUM_TIMED_RUNS = 100

# Drain any pending GPU work so it is not attributed to the first measurement.
torch.cuda.synchronize()
stream = torch.cuda.default_stream()
start_event = torch.cuda.Event(enable_timing=True)
end_event = torch.cuda.Event(enable_timing=True)

# First call, timed on its own.
start_event.record(stream)
fused_matmul(*fused_args)
end_event.record(stream)
# elapsed_time() is only valid once both events have completed.
stream.synchronize()
print("first time: {:.3f}".format(start_event.elapsed_time(end_event)))

# Steady state: average over NUM_TIMED_RUNS back-to-back launches.
start_event.record(stream)
for _ in range(NUM_TIMED_RUNS):
    fused_matmul(*fused_args)
end_event.record(stream)
stream.synchronize()
print(
    "forward time: {:.3f}".format(
        start_event.elapsed_time(end_event) / NUM_TIMED_RUNS
    )
)


DEBUG:brainstorm.jit:Fusing blocks from 0 to 255 for 0-th block
DEBUG:brainstorm.jit:Fusing blocks from 256 to 511 for 1-th block
first time: 912.221
forward time: 0.006
