In [1]:
from pathlib import Path
from typing import Optional, Union

import torch
from torch.utils.cpp_extension import load_inline

In [8]:
# Configuration constants
SIZE = 1_000  # number of elements in each input vector
KERNEL_DIR = Path("..") / "kernels"  # kernel sources; relative path, resolved against the current working directory

In [4]:
# utils

def compile_ext(
    cuda_source: Union[str, Path],
    cpp_headers: str,
    ext_name: str,
    func: list,
    extra_cuda_cflags: Optional[list] = None,
):
    """Read a CUDA source file and JIT-compile it into a PyTorch extension.

    Args:
        cuda_source: Path (str or ``Path``) to the ``.cu`` file. This is a file
            path, not source text; the file's contents are what get compiled.
        cpp_headers: C++ declarations of the functions to expose, passed to
            ``load_inline`` as ``cpp_sources``.
        ext_name: Name of the extension module to build.
        func: Names of the functions to generate Python bindings for.
        extra_cuda_cflags: Extra flags forwarded to the CUDA compiler.
            Defaults to ``["-O2"]`` (the previously hard-coded value).

    Returns:
        The module returned by ``torch.utils.cpp_extension.load_inline``.

    Raises:
        FileNotFoundError: if ``cuda_source`` does not exist.
    """
    # Explicit UTF-8 so the text read does not depend on the platform's default encoding.
    cuda_text = Path(cuda_source).read_text(encoding="utf-8")
    # Fresh list per call rather than a shared mutable default.
    cuda_flags = ["-O2"] if extra_cuda_cflags is None else list(extra_cuda_cflags)

    return load_inline(
        name=ext_name,
        cpp_sources=cpp_headers,
        cuda_sources=cuda_text,
        functions=func,
        with_cuda=True,
        extra_cuda_cflags=cuda_flags,
    )


def tensor_details(tensor: torch.Tensor, name: str, head: int = 10):
    print("*" * 50)
    print(f"Tensor {name}")
    print(f"\t Shape: {tensor.shape}")
    print(f"\t Dtype: {tensor.dtype}")
    print(f"\t Device: {tensor.device}")
    print(f"Sample:\n {tensor[:head]}\n")

### Main: compile and run the vector-addition extension

In [6]:
# Create input data: two identical float32 ramps [0, 1, ..., SIZE - 1] on the GPU
v_a = torch.arange(SIZE, dtype=torch.float32, device="cuda")
v_b = v_a.clone()  # same values, independent storage

In [7]:
# Summarize both input vectors
tensor_details(v_a, name="A")
tensor_details(v_b, name="B")

**************************************************
Tensor A
	 Shape: torch.Size([1000])
	 Dtype: torch.float32
	 Device: cuda:0
Sample:
 tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], device='cuda:0')

**************************************************
Tensor B
	 Shape: torch.Size([1000])
	 Dtype: torch.float32
	 Device: cuda:0
Sample:
 tensor([0., 1., 2., 3., 4., 5., 6., 7., 8., 9.], device='cuda:0')



In [9]:
# Set up the CUDA and C++ sources.
# `cuda_source` is a *path* to the kernel file; `cpp_source` is the C++ prototype
# handed to load_inline (presumably matching a definition in the .cu file).
cuda_source = KERNEL_DIR.joinpath("vec_addition.cu")
cpp_source = "torch::Tensor vector_addition(torch::Tensor vec_a, torch::Tensor vec_b);"

In [10]:
# Compile the extension (JIT build through compile_ext / load_inline)
ext = compile_ext(
    cuda_source=cuda_source,
    cpp_headers=cpp_source,
    ext_name="vector_ext",
    func=["vector_addition"],
)

In [11]:
# Use the extension, then validate it against PyTorch's own elementwise addition
output = ext.vector_addition(v_a, v_b)
# NOTE: v_a and v_b are created with identical values, so this check cannot
# distinguish a correct kernel from one that adds the same operand to itself.
torch.testing.assert_close(output, v_a + v_b)

In [12]:
tensor_details(tensor=output, name="Output", head=10)  # inspect the kernel's result

**************************************************
Tensor Output
	 Shape: torch.Size([1000])
	 Dtype: torch.float32
	 Device: cuda:0
Sample:
 tensor([ 0.,  2.,  4.,  6.,  8., 10., 12., 14., 16., 18.], device='cuda:0')

