In [1]:
import torch
from torch.nn.functional import conv3d

In [2]:
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

In [3]:
import torch
import time
from typing import Any, Tuple
import triton.testing

def time_op(op):
    return triton.testing.do_bench(lambda: op(), warmup=15, rep=100)


In [4]:
def timestats(fn):
    times = []
    for i in range(10):
        dt = time_op(fn)
        times.append(dt)
    # print stats
    print(f"min:\t{min(times)}")
    print(f"max:\t{max(times)}")
    print(f"mean:\t{sum(times) / len(times)}")


In [5]:
conv3d_compiled = torch.compile(conv3d, mode="max-autotune")

In [30]:
# dilated conv 
# B, C, T, H, W = 128, 3, 4, 1024, 1024
input_shape = (8, 384, 3, 34, 34)
weight_shape = (384, 384, 3, 3, 3)
B, C, T, H, W = 8, 384, 3, 34, 34
input = torch.randn(input_shape, device=device, dtype=torch.float16)
kernel = torch.randn(weight_shape, device=device, dtype=torch.float16)
bias = torch.randn(C, device=device, dtype=torch.float16).contiguous()

args = (input, kernel)
kwargs = dict(padding=0, dilation=1, stride=1, bias=bias)
def fn():
    return conv3d(*args, **kwargs)

In [31]:
from torch.profiler import profile, schedule, ProfilerActivity

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    schedule=schedule(wait=1, warmup=1, active=1, repeat=1),
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
) as prof:
    for i in range(10):
        fn()
        prof.step()
prof.export_chrome_trace("conv3d.json")

In [25]:
timestats(lambda: conv3d(*args, **kwargs))

min:	10.735214339362251
max:	11.516579985618591
mean:	11.049642984072367


In [26]:
timestats(lambda: conv3d_compiled(*args, **kwargs))

min:	0.7348821630349031
max:	0.7364604794979095
mean:	0.7359658483492362


In [27]:
import torch

torch.manual_seed(42)
assert torch.cuda.is_available()
device = torch.device("cuda")
input_shape = (8, 384, 3, 34, 34)
weight_shape = (384, 384, 3, 3, 3)

input_tensor = torch.randn(
    input_shape, device=device, dtype=torch.float16
).requires_grad_(False)
weight_tensor = torch.randn(
    weight_shape, device=device, dtype=torch.float16
).requires_grad_(False)
bias_tensor = (
    torch.randn(weight_tensor.shape[1], device=device, dtype=torch.float16)
    .requires_grad_(False)
    .reshape(1, -1, 1, 1, 1)
    .to(memory_format=torch.channels_last_3d)
)

In [28]:
import triton.testing
def time_op(op):
    return triton.testing.do_bench(lambda: op(), warmup=15, rep=100)
bias_for_torch = bias_tensor.view(-1)

print(time_op(lambda: torch.nn.functional.conv3d(input_tensor, weight_tensor, bias=bias_for_torch, padding=0, stride=1, dilation=1)))

11.034243941307068
