In [2]:
from timeit import timeit
from pprint import pprint

import torch
from torch import nn
from torch.utils.benchmark import Timer

from brt.jit import make_jit_kernel

from archs.fuse import TunedKernel, FusedLayer


In [3]:
from itertools import chain, combinations
import more_itertools as mit
from more_itertools import set_partitions

jit_kernel_info = {}


### Single Kernels


In [4]:
conv2d = nn.Conv2d(36, 36, 3, padding=1).eval().cuda()
subnet_bs = sorted(
    [6, 7, 12, 27, 8, 8, 8, 12, 12, 4]
)  # [4, 6, 7, 8, 8, 8, 12, 12, 12, 27]

for bs in set(subnet_bs):
    for rank in range(1, 6):
        inout_shape = [bs, 36, 32, 32]
        x = torch.empty(inout_shape, device="cuda")
        kernel = TunedKernel(conv2d, inout_shape, inout_shape, rank)
        time = (
            Timer(
                f"kernel(x)",
                setup="from __main__ import kernel, x; import torch; torch.cuda.synchronize();",
            )
            .timeit(100)
            .mean
            * 10e6
        )
        jit_kernel_info[((bs,), (rank,))] = time

pprint(jit_kernel_info)


{((4,), (1,)): 169.04701478779316,
 ((4,), (2,)): 176.61810852587223,
 ((4,), (3,)): 185.1822016760707,
 ((4,), (4,)): 193.07979382574558,
 ((4,), (5,)): 194.45340149104595,
 ((6,), (1,)): 210.38518752902746,
 ((6,), (2,)): 237.27540392428637,
 ((6,), (3,)): 337.4476917088032,
 ((6,), (4,)): 261.84439193457365,
 ((6,), (5,)): 243.706488981843,
 ((7,), (1,)): 261.0732102766633,
 ((7,), (2,)): 437.7095028758049,
 ((7,), (3,)): 425.736210308969,
 ((7,), (4,)): 447.4033135920763,
 ((7,), (5,)): 453.0767910182476,
 ((8,), (1,)): 263.5668031871319,
 ((8,), (2,)): 270.4497193917632,
 ((8,), (3,)): 284.30260717868805,
 ((8,), (4,)): 287.6819111406803,
 ((8,), (5,)): 298.7565938383341,
 ((12,), (1,)): 334.4644093886018,
 ((12,), (2,)): 366.01140163838863,
 ((12,), (3,)): 385.5569055303931,
 ((12,), (4,)): 398.6133029684424,
 ((12,), (5,)): 396.01740427315235,
 ((27,), (1,)): 774.1475012153387,
 ((27,), (2,)): 726.6403874382377,
 ((27,), (3,)): 655.2041042596102,
 ((27,), (4,)): 715.604797005653

### Searching Group & Rank

In [5]:
i = 0
greedy_partition = []
while i < len(subnet_bs):
    cur_subnet_bs = (subnet_bs[i],)
    cur_ranks = (1,)
    print(f"\tNEW\t {cur_subnet_bs} {cur_ranks}")
    i = i + 1

    while i < len(subnet_bs):
        cur_time = jit_kernel_info[(cur_subnet_bs, cur_ranks)]
        new_subnet_bs = cur_subnet_bs + (subnet_bs[i],)
        new_inout_shapes = [[bs, 36, 32, 32] for bs in new_subnet_bs]
        new_x = [torch.empty(shp, device="cuda") for shp in new_inout_shapes]
        rank_times = []
        for rank in range(1, 6):
            new_ranks = cur_ranks + (rank,)
            new_kernel_rank = FusedLayer(
                [conv2d] * len(new_subnet_bs),
                new_inout_shapes,
                new_inout_shapes,
                new_ranks,
            )
            new_time_rank = (
                Timer(
                    f"new_kernel_rank(new_x)",
                    setup="from __main__ import new_kernel_rank, new_x; import torch; torch.cuda.synchronize();",
                )
                .timeit(100)
                .mean
                * 10e6
            )
            jit_kernel_info[(new_subnet_bs, new_ranks)] = new_time_rank
            rank_times.append((new_ranks, new_time_rank))
        new_ranks, new_time = min(rank_times, key=lambda x: x[1])
        old_time = (
            jit_kernel_info[(cur_subnet_bs, cur_ranks)]
            + min([jit_kernel_info[((subnet_bs[i],), (rank,))] for rank in range(1, 6)])
        )
        print(
            f"{old_time:.3f}->{new_time:.3f}, {new_time-old_time:.3f}, {100 * (new_time/old_time-1):.3f}%, {new_ranks[-1]}"
        )
        if new_time < old_time:
            cur_subnet_bs = new_subnet_bs
            cur_ranks = new_ranks
            print(f"\tAPPEND\t {cur_subnet_bs} {cur_ranks}")
        else:
            break
        i = i + 1

    print(f"\tCLOSE\t {cur_subnet_bs} {cur_ranks}")
    greedy_partition.append((cur_subnet_bs, cur_ranks))

print(greedy_partition)
# [[4, 6], [7], [8, 8, 8, 12, 12, 12], [27]]
# [((4, 6), (1, 1)), ((7, 8, 8, 8), (1, 2, 2, 2)), ((12, 12, 12, 27), (1, 3, 3, 4))]


	NEW	 (4,) (1,)
379.432->349.132, -30.300, -7.986%, 1
	APPEND	 (4, 6) (1, 1)
610.205->700.643, 90.438, 14.821%, 3
	CLOSE	 (4, 6) (1, 1)
	NEW	 (7,) (1,)
524.640->509.755, -14.885, -2.837%, 2
	APPEND	 (7, 8) (1, 2)
773.322->685.944, -87.378, -11.299%, 2
	APPEND	 (7, 8, 8) (1, 2, 2)
949.511->876.967, -72.544, -7.640%, 2
	APPEND	 (7, 8, 8, 8) (1, 2, 2, 2)
1211.431->1226.181, 14.750, 1.218%, 2
	CLOSE	 (7, 8, 8, 8) (1, 2, 2, 2)
	NEW	 (12,) (1,)
668.929->642.889, -26.040, -3.893%, 3
	APPEND	 (12, 12) (1, 3)
977.353->879.911, -97.442, -9.970%, 3
	APPEND	 (12, 12, 12) (1, 3, 3)
1535.115->1549.141, 14.026, 0.914%, 4
	CLOSE	 (12, 12, 12) (1, 3, 3)
	NEW	 (27,) (1,)
	CLOSE	 (27,) (1,)
[((4, 6), (1, 1)), ((7, 8, 8, 8), (1, 2, 2, 2)), ((12, 12, 12), (1, 3, 3)), ((27,), (1,))]


In [6]:
print(f"{sum([jit_kernel_info[((bs, ), (1, ))] for bs in subnet_bs])}")
print(f"{sum([jit_kernel_info[info] for info in greedy_partition])}")


3208.7465515360236
2880.1569249480963


### Searching Group (Rank = 1)

In [9]:
i = 0
greedy_partition = []
while i < len(subnet_bs):
    cur_subnet_bs = (subnet_bs[i],)
    cur_ranks = (1,)
    print(f"\tNEW\t {cur_subnet_bs} {cur_ranks}")
    i = i + 1

    while i < len(subnet_bs):
        cur_time = jit_kernel_info[(cur_subnet_bs, cur_ranks)]
        new_subnet_bs = cur_subnet_bs + (subnet_bs[i],)
        new_inout_shapes = [[bs, 36, 32, 32] for bs in new_subnet_bs]
        new_x = [torch.empty(shp, device="cuda") for shp in new_inout_shapes]
        rank_times = []
        for rank in range(1, 2):
            new_ranks = cur_ranks + (rank,)
            new_kernel_rank = FusedLayer(
                [conv2d] * len(new_subnet_bs),
                new_inout_shapes,
                new_inout_shapes,
                new_ranks,
            )
            new_time_rank = (
                Timer(
                    f"new_kernel_rank(new_x)",
                    setup="from __main__ import new_kernel_rank, new_x; import torch; torch.cuda.synchronize();",
                )
                .timeit(100)
                .mean
                * 10e6
            )
            jit_kernel_info[(new_subnet_bs, new_ranks)] = new_time_rank
            rank_times.append((new_ranks, new_time_rank))
        new_ranks, new_time = min(rank_times, key=lambda x: x[1])
        old_time = (
            jit_kernel_info[(cur_subnet_bs, cur_ranks)]
            + jit_kernel_info[((subnet_bs[i],), (1,))]
        )
        print(
            f"{old_time:.3f}->{new_time:.3f}, {new_time-old_time:.3f}, {100 * (new_time/old_time-1):.3f}%, {new_ranks[-1]}"
        )
        if new_time < old_time:
            cur_subnet_bs = new_subnet_bs
            cur_ranks = new_ranks
            print(f"\tAPPEND\t {cur_subnet_bs} {cur_ranks}")
        else:
            break
        i = i + 1

    print(f"\tCLOSE\t {cur_subnet_bs} {cur_ranks}")
    greedy_partition.append((cur_subnet_bs, cur_ranks))

print(greedy_partition)
# [[4, 6], [7], [8, 8, 8, 12, 12, 12], [27]]


	NEW	 (4,) (1,)
376.892->348.704, -28.187, -7.479%, 1
	APPEND	 (4, 6) (1, 1)
609.790->845.561, 235.771, 38.664%, 1
	CLOSE	 (4, 6) (1, 1)
	NEW	 (7,) (1,)
524.485->713.922, 189.437, 36.119%, 1
	CLOSE	 (7,) (1,)
	NEW	 (8,) (1,)
526.800->465.021, -61.779, -11.727%, 1
	APPEND	 (8, 8) (1, 1)
728.421->686.387, -42.033, -5.770%, 1
	APPEND	 (8, 8, 8) (1, 1, 1)
1020.591->997.386, -23.205, -2.274%, 1
	APPEND	 (8, 8, 8, 12) (1, 1, 1, 1)
1331.590->1245.289, -86.301, -6.481%, 1
	APPEND	 (8, 8, 8, 12, 12) (1, 1, 1, 1, 1)
1579.493->1521.756, -57.737, -3.655%, 1
	APPEND	 (8, 8, 8, 12, 12, 12) (1, 1, 1, 1, 1, 1)
2295.558->2570.368, 274.809, 11.971%, 1
	CLOSE	 (8, 8, 8, 12, 12, 12) (1, 1, 1, 1, 1, 1)
	NEW	 (27,) (1,)
	CLOSE	 (27,) (1,)
[((4, 6), (1, 1)), ((7,), (1,)), ((8, 8, 8, 12, 12, 12), (1, 1, 1, 1, 1, 1)), ((27,), (1,))]


In [10]:
print(f"{sum([jit_kernel_info[((bs, ), (1, ))] for bs in subnet_bs])}")
print(f"{sum([jit_kernel_info[info] for info in greedy_partition])}")


3204.59078066051
2905.34810628742


### Searching Rank (bs = 12)


In [None]:
# i = 0
# greedy_partition = []

bs = 12
num_models = 3

new_subnet_bs = (bs,) * num_models

for rank in range(1, 6):
    new_ranks = (rank,) * num_models
    print(f"\tNEW\t {cur_subnet_bs} {cur_ranks}")
    new_inout_shapes = [[bs, 36, 32, 32] for bs in new_subnet_bs]
    new_kernel = FusedLayer(
        [conv2d] * len(cur_subnet_bs),
        new_inout_shapes,
        new_inout_shapes,
        new_ranks,
    )
    new_x = [torch.empty(shp, device="cuda") for shp in new_inout_shapes]
    new_time = (
        Timer(
            f"new_kernel(new_x)",
            setup="from __main__ import new_kernel, new_x; import torch; torch.cuda.synchronize();",
        )
        .timeit(100)
        .mean
        * 10e6
    )
    jit_kernel_info[(new_subnet_bs, new_ranks)] = new_time
    old_time = jit_kernel_info[((bs,), (rank,))] * 3
    print(
        f"{old_time:.3f}->{new_time:.3f}, {new_time-old_time:.3f}, {100 * (new_time/old_time-1):.3f}%, {new_ranks[-1]}"
    )
