In [9]:
import os
import torch
from torch.utils.cpp_extension import load

In [10]:
# Extension sources live in ./build_cuda, resolved against the kernel's current
# working directory: main.cpp first, then the CUDA translation unit.
cpp_source, cuda_source = (
    os.path.join(os.getcwd(), 'build_cuda', filename)
    for filename in ('main.cpp', 'cuda.cu')
)

In [11]:
# JIT-compile the C++ and CUDA sources into an importable module.
# Flip verbose to True to see the compiler/ninja output when debugging the build.
square_matrix_extension = load(
    name="square_matrix_extension",
    sources=[cpp_source, cuda_source],
    with_cuda=True,
    verbose=False,
)

In [12]:
def time_pytorch_function(func, input, warmup=10, iters=1):
    """Measure GPU time of ``func(input)`` with CUDA events.

    CUDA launches are asynchronous, so a host-side timer (e.g. the ``time``
    module) would mostly measure launch overhead. CUDA events are recorded on
    the current stream, so ``elapsed_time`` reflects GPU execution between them.

    Args:
        func: Callable invoked as ``func(input)``.
        input: Argument passed to ``func`` on every call.
        warmup: Number of untimed calls made first, to absorb one-off costs
            before measuring. Defaults to 10.
        iters: Number of timed calls. The mean per call is returned.
            Defaults to 1 (a single timed call).

    Returns:
        Mean elapsed GPU time per call, in milliseconds (float).

    Raises:
        ValueError: If ``warmup`` is negative or ``iters`` is less than 1.
    """
    # Validate before touching CUDA so bad arguments fail fast on any machine.
    if warmup < 0:
        raise ValueError(f"warmup must be >= 0, got {warmup}")
    if iters < 1:
        raise ValueError(f"iters must be >= 1, got {iters}")

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    for _ in range(warmup):
        func(input)

    start.record()
    for _ in range(iters):
        func(input)
    end.record()
    # elapsed_time is only valid once `end` has actually been reached on the GPU.
    torch.cuda.synchronize()
    return start.elapsed_time(end) / iters

In [13]:
b = torch.randn(10000, 10000).cuda()

In [14]:
def square_2(a):
    """Return ``a`` multiplied by itself via the ``*`` operator (benchmark baseline)."""
    product = a * a
    return product

def square_3(a):
    """Call the compiled extension's ``square_matrix`` on ``a``.

    NOTE(review): presumably computes the same result as ``torch.square`` (it is
    benchmarked against it). Confirm against build_cuda/cuda.cu.
    """
    kernel = square_matrix_extension.square_matrix
    return kernel(a)

In [15]:
# Profile each implementation of the square with the same input and timing
# routine; one loop instead of three copy-pasted sections keeps them in sync.
implementations = [
    ("torch.square", torch.square),
    ("a * a", square_2),
    ("square_extension", square_3),
]

for label, fn in implementations:
    print("=======================================")
    print(f"Profiling {label}")
    print("Total time: ", time_pytorch_function(fn, b), "ms")

Profiling torch.square
Total time:  3.6659200191497803 ms
Profiling a * a
Total time:  3.6577279567718506 ms
Profiling square_extension
Total time:  3.334144115447998 ms
