In [5]:
from timeit import timeit
from pprint import pprint

import torch
from torch import nn
from torch.utils.benchmark import Timer

from brt.jit import make_jit_kernel

from archs.fuse import TunedKernel, FusedLayer

In [6]:
from itertools import chain, combinations
import more_itertools as mit
from more_itertools import set_partitions

In [7]:
# Baseline measurements: time a standalone TunedKernel built from one
# 36->36 channel 3x3 convolution at every distinct subnet batch size.
# The results are stored in `jit_kernel_info`, keyed by str([batch_size]).
conv2d = nn.Conv2d(36, 36, 3, padding=1).eval().cuda()
subnet_bs = sorted([6, 7, 12, 27, 8, 8, 8, 12, 12, 4]) # [4, 6, 7, 8, 8, 8, 12, 12, 12, 27]

# Maps str(list of batch sizes) -> scaled mean run time, e.g. "[8]" -> 263.3.
jit_kernel_info = {}
# sorted(...) makes the measurement order deterministic; a bare set has no
# guaranteed iteration order.
for bs in sorted(set(subnet_bs)):
    inout_shape = [bs, 36, 32, 32]
    # Only the shape matters for timing, so the input is left uninitialised.
    x = torch.empty(inout_shape, device="cuda")
    kernel = TunedKernel(conv2d, inout_shape, inout_shape)
    # Hand the objects to the Timer through `globals=` instead of
    # `from __main__ import ...`, which only works when this code happens to
    # run as the `__main__` module.
    measurement = Timer(
        "kernel(x)",
        setup="import torch; torch.cuda.synchronize();",
        globals={"kernel": kernel, "x": x},
    ).timeit(100)
    # NOTE(review): 10e6 == 1e7, so this is seconds * 1e7, i.e. ten times the
    # value in microseconds. Presumably 1e6 was intended -- confirm. The
    # factor is left unchanged because the fusion cell further down compares
    # its own timings (scaled by the same 10e6) against these entries.
    jit_kernel_info[str([bs])] = measurement.mean * 10e6

pprint(jit_kernel_info)


{'[12]': 333.6390014737844,
 '[27]': 773.7475913017988,
 '[4]': 168.63897908478975,
 '[6]': 204.3970162048936,
 '[7]': 260.6593072414398,
 '[8]': 263.308291323483}


In [8]:
# Greedy, left-to-right grouping of the sorted batch sizes.
# Starting from the next unassigned batch size, keep appending the following
# one as long as a single FusedLayer over the enlarged group is measured to be
# faster than the current group's time plus the standalone time of the
# candidate. The first candidate that does not pay off closes the group and
# becomes the start of the next one.
# Every measured group is cached in `jit_kernel_info` under str(group).
i = 0
greedy_partition = []
while i < len(subnet_bs):
    cur_subnet_bs = [subnet_bs[i]]
    print(f"NEW\t {cur_subnet_bs}")
    i += 1

    while i < len(subnet_bs):
        new_subnet_bs = cur_subnet_bs + [subnet_bs[i]]
        new_inout_shapes = [[bs, 36, 32, 32] for bs in new_subnet_bs]
        new_x = [torch.empty(shp, device="cuda") for shp in new_inout_shapes]
        # The same conv2d module is repeated once per member of the group.
        new_kernel = FusedLayer([conv2d] * len(new_subnet_bs), new_inout_shapes, new_inout_shapes)
        # Objects are passed through `globals=` rather than
        # `from __main__ import ...`, so the measurement does not depend on
        # this code running as the `__main__` module.
        # NOTE(review): 10e6 == 1e7 (not 1e6); kept so that these values stay
        # comparable with the single-kernel entries already stored in
        # `jit_kernel_info`, which use the same factor -- confirm intent.
        new_time = Timer(
            "new_kernel(new_x)",
            setup="import torch; torch.cuda.synchronize();",
            globals={"new_kernel": new_kernel, "new_x": new_x},
        ).timeit(100).mean * 10e6
        jit_kernel_info[str(new_subnet_bs)] = new_time
        # Cost of NOT fusing: current group as-is plus the candidate alone.
        old_time = jit_kernel_info[str(cur_subnet_bs)] + jit_kernel_info[str([subnet_bs[i]])]
        print(f"      \t {old_time:.3f}->{new_time:.3f}, {new_time-old_time:.3f}, {100 * (new_time/old_time-1):.3f}%")
        if new_time >= old_time:
            # Fusing does not help; leave `i` on the candidate so the outer
            # loop opens the next group with it.
            break
        cur_subnet_bs = new_subnet_bs
        print(f"APPEND\t {cur_subnet_bs}")
        i += 1

    print(f"FINALLY\t {cur_subnet_bs}")
    greedy_partition.append(cur_subnet_bs)

print(greedy_partition)
# [[4, 6], [7], [8, 8, 8, 12, 12, 12], [27]]
    


NEW	 [4]
      	 373.036->349.433, -23.603, -6.327%
APPEND	 [4, 6]
      	 610.092->841.940, 231.848, 38.002%
FINALLY	 [4, 6]
NEW	 [7]
      	 523.968->730.111, 206.143, 39.343%
FINALLY	 [7]
NEW	 [8]
      	 526.617->487.886, -38.731, -7.355%
APPEND	 [8, 8]
      	 751.194->670.037, -81.157, -10.804%
APPEND	 [8, 8, 8]
      	 1003.676->996.740, -6.936, -0.691%
APPEND	 [8, 8, 8, 12]
      	 1330.379->1243.815, -86.564, -6.507%
APPEND	 [8, 8, 8, 12, 12]
      	 1577.454->1521.812, -55.642, -3.527%
APPEND	 [8, 8, 8, 12, 12, 12]
      	 2295.560->2585.217, 289.658, 12.618%
FINALLY	 [8, 8, 8, 12, 12, 12]
NEW	 [27]
FINALLY	 [27]
[[4, 6], [7], [8, 8, 8, 12, 12, 12], [27]]


In [10]:
# First line: summed time when every subnet runs as its own kernel.
# Second line: summed time when each group of the greedy partition runs as
# one fused kernel. Both use the measurements cached in `jit_kernel_info`.
separate_total = sum(jit_kernel_info[str([bs])] for bs in subnet_bs)
fused_total = sum(jit_kernel_info[str(group)] for group in greedy_partition)
print(separate_total)
print(fused_total)

3198.2847722247243
2905.6517872959375
