In [4]:
from timeit import timeit
from pprint import pprint

import torch
from torch import nn
from torch.utils.benchmark import Timer

from brt.jit import make_jit_kernel

from archs.fuse import TunedKernel, FusedLayer


In [5]:
from itertools import chain, combinations
import more_itertools as mit
from more_itertools import set_partitions

# Cache of measured kernel timings shared by the cells below.
# Keys are (batch_sizes, ranks) pairs of tuples, e.g. ((4, 6), (1, 5));
# values are the scaled mean runtimes reported by the benchmark Timer.
jit_kernel_info = dict()


### Single Kernels


In [6]:
NUM_FEATURE = 18

# One 3x3 convolution (same number of input and output channels) is reused for
# every benchmarked kernel in this notebook.
conv2d = nn.Conv2d(NUM_FEATURE, NUM_FEATURE, 3, padding=1).eval().cuda()
subnet_bs = sorted(
    [6, 7, 12, 27, 8, 8, 8, 12, 12, 4]
)  # [4, 6, 7, 8, 8, 8, 12, 12, 12, 27]

# Time a standalone TunedKernel for every distinct batch size and every rank in
# 1..5, storing each measurement under the key ((bs,), (rank,)).
for bs in set(subnet_bs):
    for rank in range(1, 6):
        inout_shape = [bs, NUM_FEATURE, 32, 32]
        # Contents are irrelevant for timing, so an uninitialised tensor is used.
        x = torch.empty(inout_shape, device="cuda")
        kernel = TunedKernel(conv2d, inout_shape, inout_shape, rank)
        time = (
            Timer(
                "kernel(x)",
                setup="import torch; torch.cuda.synchronize();",
                # Hand the objects to the Timer explicitly instead of importing
                # them from __main__, so the cell does not depend on where it runs.
                globals={"kernel": kernel, "x": x},
            )
            .timeit(100)
            .mean
            # NOTE(review): 10e6 == 1e7, so if `.mean` is in seconds this is in
            # units of 0.1 us rather than us. Kept unchanged so every entry in
            # jit_kernel_info shares one scale -- confirm the intended unit.
            * 10e6
        )
        jit_kernel_info[((bs,), (rank,))] = time

pprint(jit_kernel_info)


{((4,), (1,)): 91.36011358350515,
 ((4,), (2,)): 105.07579427212477,
 ((4,), (3,)): 114.63670525699852,
 ((4,), (4,)): 147.4882010370493,
 ((4,), (5,)): 117.2125106677413,
 ((6,), (1,)): 152.44551468640566,
 ((6,), (2,)): 164.629309438169,
 ((6,), (3,)): 181.20440654456615,
 ((6,), (4,)): 183.78919921815395,
 ((6,), (5,)): 159.58589501678944,
 ((7,), (1,)): 115.07151648402213,
 ((7,), (2,)): 130.11470437049866,
 ((7,), (3,)): 135.02179645001888,
 ((7,), (4,)): 144.40750237554312,
 ((7,), (5,)): 146.48227952420712,
 ((8,), (1,)): 130.3850905969739,
 ((8,), (2,)): 157.31549356132746,
 ((8,), (3,)): 161.39321960508823,
 ((8,), (4,)): 168.92231069505215,
 ((8,), (5,)): 171.2535973638296,
 ((12,), (1,)): 154.13069631904364,
 ((12,), (2,)): 155.76770529150963,
 ((12,), (3,)): 163.19360584020615,
 ((12,), (4,)): 260.5539979413152,
 ((12,), (5,)): 267.45800860226154,
 ((27,), (1,)): 230.38950748741624,
 ((27,), (2,)): 237.57601156830788,
 ((27,), (3,)): 310.3171940892935,
 ((27,), (4,)): 305.6

### Searching Group & Rank

In [15]:
# Greedy grouping over the sorted batch sizes.
# A group starts with the next unassigned subnet (its rank fixed to 1). The
# following subnet is appended while the best fused kernel (over ranks 1..5 for
# the newcomer) is faster than the current group plus the newcomer's best
# standalone kernel. Requires the single-kernel timings in jit_kernel_info.
i = 0
greedy_partition = []
while i < len(subnet_bs):
    cur_subnet_bs = (subnet_bs[i],)
    cur_ranks = (1,)
    print(f"\tNEW\t {cur_subnet_bs} {cur_ranks}")
    i = i + 1

    while i < len(subnet_bs):
        # Measured time of the group as it currently stands.
        cur_time = jit_kernel_info[(cur_subnet_bs, cur_ranks)]
        new_subnet_bs = cur_subnet_bs + (subnet_bs[i],)
        new_inout_shapes = [[bs, NUM_FEATURE, 32, 32] for bs in new_subnet_bs]
        new_x = [torch.empty(shp, device="cuda") for shp in new_inout_shapes]
        rank_times = []
        # Ranks already chosen for the group stay fixed; only the newcomer's
        # rank is searched.
        for rank in range(1, 6):
            new_ranks = cur_ranks + (rank,)
            new_kernel_rank = FusedLayer(
                [conv2d] * len(new_subnet_bs),
                new_inout_shapes,
                new_inout_shapes,
                new_ranks,
            )
            new_time_rank = (
                Timer(
                    "new_kernel_rank(new_x)",
                    setup="import torch; torch.cuda.synchronize();",
                    # Pass the objects explicitly rather than importing them
                    # from __main__.
                    globals={"new_kernel_rank": new_kernel_rank, "new_x": new_x},
                )
                .timeit(100)
                .mean
                # NOTE(review): 10e6 == 1e7 (0.1 us units if `.mean` is in
                # seconds); kept to match the scale of the existing entries in
                # jit_kernel_info -- confirm the intended unit.
                * 10e6
            )
            jit_kernel_info[(new_subnet_bs, new_ranks)] = new_time_rank
            rank_times.append((new_ranks, new_time_rank))
        new_ranks, new_time = min(rank_times, key=lambda rt: rt[1])
        # Cost of NOT fusing: current group plus the newcomer's best standalone kernel.
        old_time = cur_time + min(
            jit_kernel_info[((subnet_bs[i],), (r,))] for r in range(1, 6)
        )
        print(
            f"{old_time:.3f}->{new_time:.3f}, {new_time-old_time:.3f}, {100 * (new_time/old_time-1):.3f}%, {new_ranks[-1]}"
        )
        if new_time < old_time:
            cur_subnet_bs = new_subnet_bs
            cur_ranks = new_ranks
            print(f"\tAPPEND\t {cur_subnet_bs} {cur_ranks}")
        else:
            # Leave i untouched so the rejected subnet starts the next group.
            break
        i = i + 1

    print(f"\tCLOSE\t {cur_subnet_bs} {cur_ranks}")
    greedy_partition.append((cur_subnet_bs, cur_ranks))

print(greedy_partition)
# Reference results noted by the author (presumably from earlier runs; the
# outcome depends on measured timings and can differ between runs):
# [[4, 6], [7], [8, 8, 8, 12, 12, 12], [27]]
# [((4, 6), (1, 1)), ((7, 8, 8, 8), (1, 2, 2, 2)), ((12, 12, 12, 27), (1, 3, 3, 4))]


	NEW	 (4,) (1,)
243.806->185.716, -58.090, -23.826%, 5
	APPEND	 (4, 6) (1, 5)
300.787->218.651, -82.136, -27.307%, 3
	APPEND	 (4, 6, 7) (1, 5, 3)
349.037->336.577, -12.459, -3.570%, 1
	APPEND	 (4, 6, 7, 8) (1, 5, 3, 1)
466.962->376.330, -90.633, -19.409%, 1
	APPEND	 (4, 6, 7, 8, 8) (1, 5, 3, 1, 1)
506.715->471.055, -35.660, -7.037%, 1
	APPEND	 (4, 6, 7, 8, 8, 8) (1, 5, 3, 1, 1, 1)
625.186->577.343, -47.843, -7.653%, 1
	APPEND	 (4, 6, 7, 8, 8, 8, 12) (1, 5, 3, 1, 1, 1, 1)
731.474->642.814, -88.660, -12.121%, 1
	APPEND	 (4, 6, 7, 8, 8, 8, 12, 12) (1, 5, 3, 1, 1, 1, 1, 1)
796.945->708.776, -88.168, -11.063%, 1
	APPEND	 (4, 6, 7, 8, 8, 8, 12, 12, 12) (1, 5, 3, 1, 1, 1, 1, 1, 1)
939.166->891.293, -47.873, -5.097%, 2
	APPEND	 (4, 6, 7, 8, 8, 8, 12, 12, 12, 27) (1, 5, 3, 1, 1, 1, 1, 1, 1, 2)
	CLOSE	 (4, 6, 7, 8, 8, 8, 12, 12, 12, 27) (1, 5, 3, 1, 1, 1, 1, 1, 1, 2)
[((4, 6, 7, 8, 8, 8, 12, 12, 12, 27), (1, 5, 3, 1, 1, 1, 1, 1, 1, 2))]


In [16]:
# Baseline: every subnet runs its own standalone rank-1 kernel.
unfused_total = sum(jit_kernel_info[((bs,), (1,))] for bs in subnet_bs)
print(f"{unfused_total}")
# Total over the fused groups selected by the greedy search.
fused_total = sum(jit_kernel_info[group] for group in greedy_partition)
print(f"{fused_total}")


1442.8140129894018
891.2930963560938


### Searching Group (Rank = 1)

In [17]:
# Greedy grouping over the sorted batch sizes with every rank pinned to 1.
# A group starts with the next unassigned subnet. The following subnet is
# appended while the fused rank-1 kernel is faster than the current group plus
# the newcomer's standalone rank-1 kernel. Requires the single-kernel timings
# in jit_kernel_info.
i = 0
greedy_partition = []
while i < len(subnet_bs):
    cur_subnet_bs = (subnet_bs[i],)
    cur_ranks = (1,)
    print(f"\tNEW\t {cur_subnet_bs} {cur_ranks}")
    i = i + 1

    while i < len(subnet_bs):
        # Measured time of the group as it currently stands.
        cur_time = jit_kernel_info[(cur_subnet_bs, cur_ranks)]
        new_subnet_bs = cur_subnet_bs + (subnet_bs[i],)
        new_inout_shapes = [[bs, NUM_FEATURE, 32, 32] for bs in new_subnet_bs]
        new_x = [torch.empty(shp, device="cuda") for shp in new_inout_shapes]
        rank_times = []
        # range(1, 2) yields only rank 1; the loop shape mirrors the rank-search
        # variant of this cell.
        for rank in range(1, 2):
            new_ranks = cur_ranks + (rank,)
            new_kernel_rank = FusedLayer(
                [conv2d] * len(new_subnet_bs),
                new_inout_shapes,
                new_inout_shapes,
                new_ranks,
            )
            new_time_rank = (
                Timer(
                    "new_kernel_rank(new_x)",
                    setup="import torch; torch.cuda.synchronize();",
                    # Pass the objects explicitly rather than importing them
                    # from __main__.
                    globals={"new_kernel_rank": new_kernel_rank, "new_x": new_x},
                )
                .timeit(100)
                .mean
                # NOTE(review): 10e6 == 1e7 (0.1 us units if `.mean` is in
                # seconds); kept to match the scale of the existing entries in
                # jit_kernel_info -- confirm the intended unit.
                * 10e6
            )
            jit_kernel_info[(new_subnet_bs, new_ranks)] = new_time_rank
            rank_times.append((new_ranks, new_time_rank))
        new_ranks, new_time = min(rank_times, key=lambda rt: rt[1])
        # Cost of NOT fusing: current group plus the newcomer's rank-1 kernel.
        old_time = cur_time + jit_kernel_info[((subnet_bs[i],), (1,))]
        print(
            f"{old_time:.3f}->{new_time:.3f}, {new_time-old_time:.3f}, {100 * (new_time/old_time-1):.3f}%, {new_ranks[-1]}"
        )
        if new_time < old_time:
            cur_subnet_bs = new_subnet_bs
            cur_ranks = new_ranks
            print(f"\tAPPEND\t {cur_subnet_bs} {cur_ranks}")
        else:
            # Leave i untouched so the rejected subnet starts the next group.
            break
        i = i + 1

    print(f"\tCLOSE\t {cur_subnet_bs} {cur_ranks}")
    greedy_partition.append((cur_subnet_bs, cur_ranks))

print(greedy_partition)
# Reference grouping noted by the author (presumably from an earlier run; the
# outcome depends on measured timings and can differ between runs):
# [[4, 6], [7], [8, 8, 8, 12, 12, 12], [27]]


	NEW	 (4,) (1,)
243.806->195.972, -47.834, -19.620%, 1
	APPEND	 (4, 6) (1, 1)
311.044->262.423, -48.620, -15.631%, 1
	APPEND	 (4, 6, 7) (1, 1, 1)
392.808->345.376, -47.433, -12.075%, 1
	APPEND	 (4, 6, 7, 8) (1, 1, 1, 1)
475.761->422.329, -53.432, -11.231%, 1
	APPEND	 (4, 6, 7, 8, 8) (1, 1, 1, 1, 1)
552.714->477.586, -75.129, -13.593%, 1
	APPEND	 (4, 6, 7, 8, 8, 8) (1, 1, 1, 1, 1, 1)
631.716->562.003, -69.714, -11.036%, 1
	APPEND	 (4, 6, 7, 8, 8, 8, 12) (1, 1, 1, 1, 1, 1, 1)
716.133->618.895, -97.238, -13.578%, 1
	APPEND	 (4, 6, 7, 8, 8, 8, 12, 12) (1, 1, 1, 1, 1, 1, 1, 1)
773.026->687.632, -85.394, -11.047%, 1
	APPEND	 (4, 6, 7, 8, 8, 8, 12, 12, 12) (1, 1, 1, 1, 1, 1, 1, 1, 1)
918.021->896.292, -21.730, -2.367%, 1
	APPEND	 (4, 6, 7, 8, 8, 8, 12, 12, 12, 27) (1, 1, 1, 1, 1, 1, 1, 1, 1, 1)
	CLOSE	 (4, 6, 7, 8, 8, 8, 12, 12, 12, 27) (1, 1, 1, 1, 1, 1, 1, 1, 1, 1)
[((4, 6, 7, 8, 8, 8, 12, 12, 12, 27), (1, 1, 1, 1, 1, 1, 1, 1, 1, 1))]


In [10]:
# Baseline: every subnet runs its own standalone rank-1 kernel.
unfused_total = sum(jit_kernel_info[((bs,), (1,))] for bs in subnet_bs)
print(f"{unfused_total}")
# Total over the groups currently stored in greedy_partition.
fused_total = sum(jit_kernel_info[group] for group in greedy_partition)
print(f"{fused_total}")


1442.8140129894018
891.228998079896


### Searching Rank (bs = 12)


In [14]:
# Sweep a uniform rank (1..5) over a fused group of `num_models` identical
# subnets and compare each fused time with `num_models` runs of the best
# standalone kernel for that batch size.
bs = 12
num_models = 3

new_subnet_bs = (bs,) * num_models

for rank in range(1, 6):
    new_ranks = (rank,) * num_models
    print(f"\tNEW\t {new_subnet_bs} {new_ranks}")
    new_inout_shapes = [[b, NUM_FEATURE, 32, 32] for b in new_subnet_bs]
    new_kernel = FusedLayer(
        [conv2d] * len(new_subnet_bs),
        new_inout_shapes,
        new_inout_shapes,
        new_ranks,
    )
    new_x = [torch.empty(shp, device="cuda") for shp in new_inout_shapes]
    new_time = (
        Timer(
            "new_kernel(new_x)",
            setup="import torch; torch.cuda.synchronize();",
            # Pass the objects explicitly rather than importing them from __main__.
            globals={"new_kernel": new_kernel, "new_x": new_x},
        )
        .timeit(100)
        .mean
        # NOTE(review): 10e6 == 1e7 (0.1 us units if `.mean` is in seconds);
        # kept to match the scale of the existing entries in jit_kernel_info --
        # confirm the intended unit.
        * 10e6
    )
    jit_kernel_info[(new_subnet_bs, new_ranks)] = new_time
    # Unfused baseline scales with the group size (was a hard-coded 3).
    old_time = min(jit_kernel_info[((bs,), (r,))] for r in range(1, 6)) * num_models
    print(
        f"{old_time:.3f}->{new_time:.3f}, {new_time-old_time:.3f}, {100 * (new_time/old_time-1):.3f}%, {new_ranks[-1]}"
    )


	NEW	 (12, 12, 12) (1, 1, 1)
462.392->295.493, -166.899, -36.095%, 1
	NEW	 (12, 12, 12) (2, 2, 2)
462.392->303.034, -159.358, -34.464%, 2
	NEW	 (12, 12, 12) (3, 3, 3)
462.392->423.969, -38.423, -8.310%, 3
	NEW	 (12, 12, 12) (4, 4, 4)
462.392->399.059, -63.333, -13.697%, 4
	NEW	 (12, 12, 12) (5, 5, 5)
462.392->390.600, -71.793, -15.526%, 5
