In [1]:
from timeit import timeit
from pprint import pprint

import torch
from torch import nn
from torch.utils.benchmark import Timer

from brt.jit import make_jit_kernel

from archs.fuse import TunedKernel, FusedLayer, set_objective_func


  from .autonotebook import tqdm as notebook_tqdm


In [2]:
from itertools import chain, combinations
import more_itertools as mit
from more_itertools import set_partitions

# Timing table shared by the benchmark cells.
# Key: (tuple of batch sizes, tuple of ranks); value: the scaled mean time
# measured for that kernel configuration.
jit_kernel_info = dict()

# Objective handed to archs.fuse (presumably controls how tuned kernels are
# ranked -- confirm in archs.fuse). Alternative kept for quick switching:
# set_objective_func("most_efficient")
set_objective_func("fastest")


<archs.fuse._ObjectiveFuncContext at 0x7f6dd08eb820>

### Single Kernels


In [3]:
# Benchmark the standalone tuned conv2d kernel for every distinct batch size
# and every candidate rank (1..5), recording each result in `jit_kernel_info`
# under the key ((bs,), (rank,)).
NUM_FEATURE = 8

conv2d = nn.Conv2d(NUM_FEATURE, NUM_FEATURE, 3, padding=1).eval().cuda()
subnet_bs = sorted(
    [6, 7, 12, 27, 8, 8, 8, 12, 12, 4]
)  # [4, 6, 7, 8, 8, 8, 12, 12, 12, 27]

# Visit the distinct batch sizes in ascending order so reruns fill the table
# in a deterministic order (a bare set has no guaranteed iteration order).
for bs in sorted(set(subnet_bs)):
    # The shape and the input tensor depend only on the batch size, so they
    # are built once per batch size rather than once per rank.
    inout_shape = [bs, NUM_FEATURE, 32, 32]
    x = torch.empty(inout_shape, device="cuda")
    for rank in range(1, 6):
        kernel = TunedKernel(conv2d, inout_shape, inout_shape, rank)
        # Pass the objects to Timer through `globals` instead of importing
        # them from __main__, so the measurement does not depend on which
        # module namespace this code happens to run in.
        #
        # NOTE(review): 10e6 == 1e7, so if Timer's mean is in seconds the
        # stored value is in units of 0.1 us, not us. The factor is left
        # unchanged because the other benchmark cells write to the same table
        # with the same factor and their values are compared to these ones;
        # change it everywhere at once if microseconds were intended.
        time = (
            Timer(
                "kernel(x)",
                setup="import torch; torch.cuda.synchronize();",
                globals={"kernel": kernel, "x": x},
            )
            .timeit(100)
            .mean
            * 10e6
        )
        jit_kernel_info[((bs,), (rank,))] = time

pprint(jit_kernel_info)


{((4,), (1,)): 76.71370112802833,
 ((4,), (2,)): 76.66559831704944,
 ((4,), (3,)): 76.72369829379022,
 ((4,), (4,)): 79.82650131452829,
 ((4,), (5,)): 78.78750038798898,
 ((6,), (1,)): 75.86000137962401,
 ((6,), (2,)): 76.4872005674988,
 ((6,), (3,)): 77.31569930911063,
 ((6,), (4,)): 75.33709867857397,
 ((6,), (5,)): 77.88689981680363,
 ((7,), (1,)): 79.88949946593493,
 ((7,), (2,)): 80.22819820325822,
 ((7,), (3,)): 77.30669749435037,
 ((7,), (4,)): 76.91799837630242,
 ((7,), (5,)): 83.97430065087974,
 ((8,), (1,)): 82.66670047305524,
 ((8,), (2,)): 76.09949971083552,
 ((8,), (3,)): 77.87380018271507,
 ((8,), (4,)): 76.3449992518872,
 ((8,), (5,)): 78.04009946994483,
 ((12,), (1,)): 78.08419759385288,
 ((12,), (2,)): 77.99899904057384,
 ((12,), (3,)): 80.83030115813017,
 ((12,), (4,)): 95.04499903414398,
 ((12,), (5,)): 102.02800040133297,
 ((27,), (1,)): 85.99010179750621,
 ((27,), (2,)): 93.129399465397,
 ((27,), (3,)): 112.57579899393022,
 ((27,), (4,)): 111.34850210510194,
 ((27,

### Searching Group & Rank

In [5]:
# Greedy left-to-right grouping over the sorted batch sizes. The current group
# keeps absorbing the next subnet (trying ranks 1..5 for the newcomer) as long
# as the best fused kernel is faster than running the current group and the
# newcomer's best standalone kernel separately.
i = 0
greedy_partition = []
while i < len(subnet_bs):
    # Every group starts from a single subnet at rank 1.
    cur_subnet_bs = (subnet_bs[i],)
    cur_ranks = (1,)
    # Disabled alternative: start from the fastest measured rank instead.
    # (cur_subnet_bs, cur_ranks), _ = min(
    #     [x for x in jit_kernel_info.items() if x[0][0] == (subnet_bs[i],)],
    #     key=lambda x: x[1],
    # )
    print(f"\tNEW\t {cur_subnet_bs} {cur_ranks}")
    i = i + 1

    while i < len(subnet_bs):
        new_subnet_bs = cur_subnet_bs + (subnet_bs[i],)
        new_inout_shapes = [[bs, NUM_FEATURE, 32, 32] for bs in new_subnet_bs]
        new_x = [torch.empty(shp, device="cuda") for shp in new_inout_shapes]
        rank_times = []
        for rank in range(1, 6):
            # Ranks already chosen for the group are kept; only the rank of
            # the newly appended subnet is searched.
            new_ranks = cur_ranks + (rank,)
            new_kernel_rank = FusedLayer(
                [conv2d] * len(new_subnet_bs),
                new_inout_shapes,
                new_inout_shapes,
                new_ranks,
            )
            # Objects are passed through `globals` rather than imported from
            # __main__. The 10e6 (== 1e7) factor matches the one used when the
            # single-kernel entries were stored, so the values stay comparable.
            new_time_rank = (
                Timer(
                    "new_kernel_rank(new_x)",
                    setup="import torch; torch.cuda.synchronize();",
                    globals={"new_kernel_rank": new_kernel_rank, "new_x": new_x},
                )
                .timeit(100)
                .mean
                * 10e6
            )
            jit_kernel_info[(new_subnet_bs, new_ranks)] = new_time_rank
            rank_times.append((new_ranks, new_time_rank))
        new_ranks, new_time = min(rank_times, key=lambda x: x[1])
        # Cost of not fusing: the current group as measured, plus the
        # newcomer's fastest standalone kernel over all ranks.
        old_time = jit_kernel_info[(cur_subnet_bs, cur_ranks)] + min(
            [jit_kernel_info[((subnet_bs[i],), (rank,))] for rank in range(1, 6)]
        )
        print(
            f"{old_time:.3f}->{new_time:.3f}, {new_time-old_time:.3f}, {100 * (new_time/old_time-1):.3f}%, {new_ranks[-1]}"
        )
        if new_time < old_time:
            cur_subnet_bs = new_subnet_bs
            cur_ranks = new_ranks
            print(f"\tAPPEND\t {cur_subnet_bs} {cur_ranks}")
        else:
            # Fusing did not help: leave `i` on this subnet so it opens the
            # next group.
            break
        i = i + 1

    print(f"\tCLOSE\t {cur_subnet_bs} {cur_ranks}")
    greedy_partition.append((cur_subnet_bs, cur_ranks))

print(greedy_partition)
# Partitions noted from earlier runs (the result is timing-dependent, so a
# rerun may print something different):
# [[4, 6], [7], [8, 8, 8, 12, 12, 12], [27]]
# [((4, 6), (1, 1)), ((7, 8, 8, 8), (1, 2, 2, 2)), ((12, 12, 12, 27), (1, 3, 3, 4))]


	NEW	 (4,) (1,)
152.051->59.592, -92.458, -60.808%, 1
	APPEND	 (4, 6) (1, 1)
136.510->73.846, -62.664, -45.904%, 1
	APPEND	 (4, 6, 7) (1, 1, 1)
149.946->82.686, -67.260, -44.856%, 4
	APPEND	 (4, 6, 7, 8) (1, 1, 1, 4)
158.785->97.756, -61.029, -38.435%, 4
	APPEND	 (4, 6, 7, 8, 8) (1, 1, 1, 4, 4)
173.856->111.181, -62.674, -36.050%, 4
	APPEND	 (4, 6, 7, 8, 8, 8) (1, 1, 1, 4, 4, 4)
189.180->142.784, -46.396, -24.525%, 1
	APPEND	 (4, 6, 7, 8, 8, 8, 12) (1, 1, 1, 4, 4, 4, 1)
220.783->159.496, -61.288, -27.759%, 1
	APPEND	 (4, 6, 7, 8, 8, 8, 12, 12) (1, 1, 1, 4, 4, 4, 1, 1)
237.495->191.513, -45.982, -19.361%, 1
	APPEND	 (4, 6, 7, 8, 8, 8, 12, 12, 12) (1, 1, 1, 4, 4, 4, 1, 1, 1)
277.503->249.518, -27.985, -10.084%, 2
	APPEND	 (4, 6, 7, 8, 8, 8, 12, 12, 12, 27) (1, 1, 1, 4, 4, 4, 1, 1, 1, 2)
	CLOSE	 (4, 6, 7, 8, 8, 8, 12, 12, 12, 27) (1, 1, 1, 4, 4, 4, 1, 1, 1, 2)
[((4, 6, 7, 8, 8, 8, 12, 12, 12, 27), (1, 1, 1, 4, 4, 4, 1, 1, 1, 2))]


In [7]:
# Baseline total: every subnet run on its own with the rank-1 kernel.
print(sum(jit_kernel_info[((bs,), (1,))] for bs in subnet_bs))
# Fused total: one measured time per group produced by the greedy search.
print(sum(jit_kernel_info[group] for group in greedy_partition))
# Totals presumably noted from an earlier run with the "fastest" objective:
## 2434.2245975276455
## 2193.2725998340175


800.7059979718179
249.5181019185111


### Searching Group (Rank = 1)

In [8]:
# Greedy left-to-right grouping over the sorted batch sizes with every subnet
# fixed at rank 1. The current group keeps absorbing the next subnet as long
# as the fused kernel is faster than running the current group and the
# newcomer's rank-1 kernel separately.
i = 0
greedy_partition = []
while i < len(subnet_bs):
    # Every group starts from a single subnet at rank 1.
    cur_subnet_bs = (subnet_bs[i],)
    cur_ranks = (1,)
    print(f"\tNEW\t {cur_subnet_bs} {cur_ranks}")
    i = i + 1

    while i < len(subnet_bs):
        new_subnet_bs = cur_subnet_bs + (subnet_bs[i],)
        new_inout_shapes = [[bs, NUM_FEATURE, 32, 32] for bs in new_subnet_bs]
        new_x = [torch.empty(shp, device="cuda") for shp in new_inout_shapes]
        rank_times = []
        # Only rank 1 is tried here; the one-element range keeps the cell
        # structurally parallel to the variant that searches ranks 1..5.
        for rank in range(1, 2):
            new_ranks = cur_ranks + (rank,)
            new_kernel_rank = FusedLayer(
                [conv2d] * len(new_subnet_bs),
                new_inout_shapes,
                new_inout_shapes,
                new_ranks,
            )
            # Objects are passed through `globals` rather than imported from
            # __main__. The 10e6 (== 1e7) factor matches the one used when the
            # single-kernel entries were stored, so the values stay comparable.
            new_time_rank = (
                Timer(
                    "new_kernel_rank(new_x)",
                    setup="import torch; torch.cuda.synchronize();",
                    globals={"new_kernel_rank": new_kernel_rank, "new_x": new_x},
                )
                .timeit(100)
                .mean
                * 10e6
            )
            jit_kernel_info[(new_subnet_bs, new_ranks)] = new_time_rank
            rank_times.append((new_ranks, new_time_rank))
        new_ranks, new_time = min(rank_times, key=lambda x: x[1])
        # Cost of not fusing: the current group as measured, plus the
        # newcomer's standalone rank-1 kernel.
        old_time = (
            jit_kernel_info[(cur_subnet_bs, cur_ranks)]
            + jit_kernel_info[((subnet_bs[i],), (1,))]
        )
        print(
            f"{old_time:.3f}->{new_time:.3f}, {new_time-old_time:.3f}, {100 * (new_time/old_time-1):.3f}%, {new_ranks[-1]}"
        )
        if new_time < old_time:
            cur_subnet_bs = new_subnet_bs
            cur_ranks = new_ranks
            print(f"\tAPPEND\t {cur_subnet_bs} {cur_ranks}")
        else:
            # Fusing did not help: leave `i` on this subnet so it opens the
            # next group.
            break
        i = i + 1

    print(f"\tCLOSE\t {cur_subnet_bs} {cur_ranks}")
    greedy_partition.append((cur_subnet_bs, cur_ranks))

print(greedy_partition)
# Partition noted from an earlier run (the result is timing-dependent, so a
# rerun may print something different):
# [[4, 6], [7], [8, 8, 8, 12, 12, 12], [27]]


	NEW	 (4,) (1,)
152.574->59.734, -92.840, -60.849%, 1
	APPEND	 (4, 6) (1, 1)
139.623->73.159, -66.464, -47.603%, 1
	APPEND	 (4, 6, 7) (1, 1, 1)
155.826->103.094, -52.732, -33.840%, 1
	APPEND	 (4, 6, 7, 8) (1, 1, 1, 1)
185.761->103.390, -82.371, -44.343%, 1
	APPEND	 (4, 6, 7, 8, 8) (1, 1, 1, 1, 1)
186.056->109.478, -76.578, -41.159%, 1
	APPEND	 (4, 6, 7, 8, 8, 8) (1, 1, 1, 1, 1, 1)
187.562->136.492, -51.071, -27.229%, 1
	APPEND	 (4, 6, 7, 8, 8, 8, 12) (1, 1, 1, 1, 1, 1, 1)
214.576->152.011, -62.565, -29.158%, 1
	APPEND	 (4, 6, 7, 8, 8, 8, 12, 12) (1, 1, 1, 1, 1, 1, 1, 1)
230.095->181.061, -49.034, -21.310%, 1
	APPEND	 (4, 6, 7, 8, 8, 8, 12, 12, 12) (1, 1, 1, 1, 1, 1, 1, 1, 1)
267.051->231.056, -35.996, -13.479%, 1
	APPEND	 (4, 6, 7, 8, 8, 8, 12, 12, 12, 27) (1, 1, 1, 1, 1, 1, 1, 1, 1, 1)
	CLOSE	 (4, 6, 7, 8, 8, 8, 12, 12, 12, 27) (1, 1, 1, 1, 1, 1, 1, 1, 1, 1)
[((4, 6, 7, 8, 8, 8, 12, 12, 12, 27), (1, 1, 1, 1, 1, 1, 1, 1, 1, 1))]


In [9]:
# Baseline total: every subnet run on its own with the rank-1 kernel.
print(sum(jit_kernel_info[((bs,), (1,))] for bs in subnet_bs))
# Fused total: one measured time per group produced by the greedy search.
print(sum(jit_kernel_info[group] for group in greedy_partition))

# Totals presumably noted from an earlier run with the "fastest" objective:
## 2434.2245975276455
## 2235.7680994900875


800.7059979718179
231.05560103431344


### Searching Rank (bs = 8)


In [10]:
# Fuse `num_models` identical subnets (same batch size, one shared rank) for
# 2..20 models and ranks 1..5, and compare each fused kernel against running
# the fastest standalone kernel `num_models` times.
bs = 8

# Fastest standalone time for this batch size over all ranks. The loops below
# only add fused entries to the table, so this is looked up once.
best_single_time = min(jit_kernel_info[((bs,), (rank,))] for rank in range(1, 6))

for num_models in range(2, 21):
    print(f"====== {num_models = } ====================================")
    # Batch sizes, shapes, inputs and the unfused baseline depend only on
    # `num_models`, so they are prepared once per model count.
    new_subnet_bs = (bs,) * num_models
    new_inout_shapes = [[bs, NUM_FEATURE, 32, 32] for _ in new_subnet_bs]
    new_x = [torch.empty(shp, device="cuda") for shp in new_inout_shapes]
    old_time = best_single_time * num_models

    for rank in range(1, 6):
        new_ranks = (rank,) * num_models
        # print(f"\tNEW\t {bs=} {rank=} x {num_models}")
        new_kernel = FusedLayer(
            [conv2d] * len(new_subnet_bs),
            new_inout_shapes,
            new_inout_shapes,
            new_ranks,
        )
        # Objects are passed through `globals` rather than imported from
        # __main__. The 10e6 (== 1e7) factor matches the one used when the
        # single-kernel entries were stored, so the values stay comparable.
        new_time = (
            Timer(
                "new_kernel(new_x)",
                setup="import torch; torch.cuda.synchronize();",
                globals={"new_kernel": new_kernel, "new_x": new_x},
            )
            .timeit(100)
            .mean
            * 10e6
        )
        jit_kernel_info[(new_subnet_bs, new_ranks)] = new_time
        print(
            f"{old_time:.3f}->{new_time:.3f}, {new_time-old_time:.3f}, {100 * (new_time/old_time-1):.3f}%, {new_ranks[-1]}"
        )


152.199->67.498, -84.701, -55.651%, 1
152.199->65.088, -87.111, -57.235%, 2
152.199->69.417, -82.782, -54.391%, 3
152.199->66.809, -85.390, -56.104%, 4
152.199->77.118, -75.081, -49.331%, 5
228.298->82.603, -145.696, -63.818%, 1
228.298->67.035, -161.263, -70.637%, 2
228.298->79.899, -148.400, -65.003%, 3
228.298->82.614, -145.685, -63.813%, 4
228.298->77.641, -150.657, -65.991%, 5
304.398->96.848, -207.550, -68.184%, 1
304.398->81.203, -223.195, -73.323%, 2
304.398->93.796, -210.602, -69.187%, 3
304.398->96.436, -207.961, -68.319%, 4
304.398->98.927, -205.471, -67.501%, 5
380.497->110.147, -270.350, -71.052%, 1
380.497->82.226, -298.271, -78.390%, 2
380.497->104.564, -275.934, -72.519%, 3
380.497->109.419, -271.079, -71.243%, 4
380.497->99.115, -281.382, -73.951%, 5
456.597->119.752, -336.845, -73.773%, 1
456.597->96.858, -359.739, -78.787%, 2
456.597->124.674, -331.923, -72.695%, 3
456.597->127.004, -329.593, -72.185%, 4
456.597->119.051, -337.546, -73.926%, 5
532.696->150.704, -381.