In [1]:
import torch
import torch.nn as nn

from typing import Type, List, Union


from brt.runtime.proto_tensor import (
    collect_proto_attr_stack,
    init_proto_tensor,
    make_proto_tensor_from,
)
from brt.jit import make_jit_kernel


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
class FusedLayer(nn.Module):
    def __init__(
        self,
        models: Union[List[nn.Module], List[nn.Sequential]],
        input_shapes: List[torch.Size],
        output_shapes: List[torch.Size],
    ):
        super().__init__()
        models = nn.ModuleList(models)
        print(models)
        self.num_submodels = len(models)
        assert len(input_shapes) == self.num_submodels
        assert len(output_shapes) == self.num_submodels
        for i, model in enumerate(models):
            for name, tensor in model.named_parameters(f"m{i}"):
                self.register_parameter(name.replace(".", "_"), tensor)
            for name, tensor in model.named_buffers(f"m{i}"):
                self.register_buffer(name.replace(".", "_"), tensor)
        self.input_shapes = input_shapes
        self.output_shapes = output_shapes
        sample_inputs = [torch.randn(shp).cuda() for shp in input_shapes]
        self.fused_kernel = make_jit_kernel(
            models, sample_inputs, opt_level="hetero_fuse"
        )
        print([[n, t.shape] for n, t in self.named_parameters()])
        self.ACTIVE_BLOCKS = [1] * self.num_submodels
        # Conv2dBiasPReLU or Conv2dBias or ConvTranspose2dBias
        if isinstance(models[0], nn.Sequential):
            conv2d = models[0][0]
            prelu = models[0][1]
            if (
                isinstance(conv2d, nn.Conv2d)
                and conv2d.bias is not None
                and isinstance(prelu, nn.PReLU)
            ):
                self.module_name = "Conv2dBiasPReLU"
            else:
                raise NotImplementedError(f"{models}")
        elif isinstance(models[0], nn.Conv2d) and models[0].bias is not None:
            self.module_name = "Conv2dBias"
        elif isinstance(models[0], nn.ConvTranspose2d) and models[0].bias is not None:
            self.module_name = "ConvTranspose2dBias"
        else:
            self.module_name = "ERROR"
            raise NotImplementedError(f"{models}")

        self.inputs_templete = {}
        self.inputs_templete["forward"] = []
        if self.module_name == "Conv2dBiasPReLU":
            for i in range(len(models)):
                self.inputs_templete["forward"].extend(
                    [
                        None,
                        self.get_parameter(f"m{i}_0_weight"),
                        None,
                        self.get_parameter(f"m{i}_0_bias"),
                        self.get_parameter(f"m{i}_1_weight"),
                    ]
                )
            self.input_indices = [i * 5 for i in range(self.num_submodels)]
            self.output_indices = [i * 5 + 2 for i in range(self.num_submodels)]
        elif (
            self.module_name == "Conv2dBias"
            or self.module_name == "ConvTranspose2dBias"
        ):
            for i in range(len(models)):
                self.inputs_templete["forward"].extend(
                    [
                        None,
                        self.get_parameter(f"m{i}_weight"),
                        None,
                        self.get_parameter(f"m{i}_bias"),
                    ]
                )
            self.input_indices = [i * 4 for i in range(self.num_submodels)]
            self.output_indices = [i * 4 + 2 for i in range(self.num_submodels)]
        # elif self.module_name == "ConvTranspose2dBias":
        else:
            raise NotImplementedError(f"{self.module_name}")
        self.forward(sample_inputs)

    def forward(self, inputs: List[torch.Tensor]):
        for i in range(self.num_submodels):
            self.inputs_templete["forward"][self.input_indices[i]] = inputs[i]
            self.inputs_templete["forward"][self.output_indices[i]] = torch.empty(
                self.output_shapes[i], device="cuda"
            )
        self.fused_kernel(
            *self.inputs_templete["forward"], active_blocks=self.ACTIVE_BLOCKS
        )
        torch.cuda.synchronize()
        outputs = [
            self.inputs_templete["forward"][index] for index in self.output_indices
        ]
        for i in range(self.num_submodels):
            self.inputs_templete["forward"][self.input_indices[i]] = None
            self.inputs_templete["forward"][self.output_indices[i]] = None
        return outputs

    def extra_repr(self):
        return self.module_name


In [3]:
IN_SHAPE = [34, 3, 32, 32]
OUT_SHAPE = [34, 16, 32, 32]


In [4]:
raw = nn.Sequential(
    nn.Conv2d(in_channels=3, out_channels=16, kernel_size=5, stride=1, padding=2),
    nn.PReLU(),
).cuda()

pth_file_path = r"/home/v-louyang/brainstorm_project/brainstorm/benchmark/classsr/experiments/pre_trained_models/ClassSR_FSRCNN.pth"
full_state_dict = torch.load(pth_file_path)
# print(full_state_dict.keys())
state_dict = {
    key[15:]: full_state_dict[key]
    for key in [
        "net1.head_conv.0.weight",
        "net1.head_conv.0.bias",
        "net1.head_conv.1.weight",
    ]
}
print(state_dict.keys())
raw.load_state_dict(state_dict)
raw.eval()

fused = FusedLayer([raw], [IN_SHAPE], [OUT_SHAPE])
fused.cuda()
fused.eval()


dict_keys(['0.weight', '0.bias', '1.weight'])
ModuleList(
  (0): Sequential(
    (0): Conv2d(3, 16, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (1): PReLU(num_parameters=1)
  )
)
[['m0_0_weight', torch.Size([16, 3, 5, 5])], ['m0_0_bias', torch.Size([16])], ['m0_1_weight', torch.Size([1])]]


FusedLayer(Conv2dBiasPReLU)

In [5]:
x = torch.randn(IN_SHAPE).cuda()
# print(f"######## x: {x}")

y_raw = raw(x)
# print(f"######## raw: {y_raw}")

y_fused = fused([x])[0]
# print(f"######## fused: {y_fused}")

torch.allclose(y_fused, y_raw, rtol=0.1)


False

In [11]:
inputs = [
        torch.randn(shp, device="cuda")
        for shp in [
            [34, 12, 32, 32],
        ]
    ]

raw_nets = nn.ModuleList([
            nn.Sequential(
                nn.Conv2d(
                    in_channels=12, out_channels=och, kernel_size=3, stride=1, padding=1
                ),
                nn.PReLU(),
            ).cuda()
            for och in [12]
        ]).cuda()

fused_conv2d = make_jit_kernel(
    raw_nets[0],
    inputs[0],
    # opt_level="hetero_fuse",
)

fused_inputs = [
    inputs[0],
    raw_nets[0][0].weight,
    torch.empty([34, 12, 32, 32]).cuda(),
    raw_nets[0][0].bias.expand(12),
    raw_nets[0][1].weight.expand(12),
]

conv2d_out0 = raw_nets[0](inputs[0])

fused_conv2d(*fused_inputs)

fused_conv2d_out0 = fused_inputs[2]

print(torch.allclose(conv2d_out0, fused_conv2d_out0, rtol=0.1))
print(fused_conv2d_out0)



RuntimeError: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call,so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.

In [None]:
# inputs = [
#         torch.randn(shp, device="cuda")
#         for shp in [
#             [34, 3, 32, 32],
#             [38, 3, 32, 32],
#             [29, 3, 32, 32],
#         ]
#     ]

# raw_nets = nn.ModuleList([
#             nn.Sequential(
#                 nn.Conv2d(
#                     in_channels=3, out_channels=och, kernel_size=5, stride=1, padding=2
#                 ),
#                 nn.PReLU(),
#             ).cuda()
#             for och in [16, 36, 56]
#         ]).cuda()

# fused_conv2d = make_jit_kernel(
#     raw_nets,
#     inputs,
#     # opt_level="hetero_fuse",
# )

# fused_inputs = [
#     inputs[0],
#     raw_nets[0][0].weight,
#     torch.empty([34, 16, 32, 32]).cuda(),
#     raw_nets[0][0].bias,
#     raw_nets[0][1].weight,
#     inputs[1],
#     raw_nets[1][0].weight,
#     torch.empty([38, 36, 32, 32]).cuda(),
#     raw_nets[1][0].bias,
#     raw_nets[1][1].weight,
#     inputs[2],
#     raw_nets[2][0].weight,
#     torch.empty([29, 56, 32, 32]).cuda(),
#     raw_nets[2][0].bias,
#     raw_nets[2][1].weight,
# ]

# conv2d_out0 = raw_nets[0](inputs[0])
# conv2d_out1 = raw_nets[1](inputs[1])
# conv2d_out2 = raw_nets[2](inputs[2])

# fused_conv2d(*fused_inputs, active_blocks=[1, 1, 1])

# fused_conv2d_out0 = fused_inputs[2]
# fused_conv2d_out1 = fused_inputs[7]
# fused_conv2d_out2 = fused_inputs[12]

# # print(fused_conv2d_out0)
# print(torch.allclose(conv2d_out0, fused_conv2d_out0, rtol=0.1))
# print(torch.allclose(conv2d_out1, fused_conv2d_out1, rtol=0.1))
# print(torch.allclose(conv2d_out2, fused_conv2d_out2, rtol=0.1))

