In [None]:
!pip install ninja

Collecting ninja
  Downloading ninja-1.11.1.3-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl.metadata (5.3 kB)
Downloading ninja-1.11.1.3-py3-none-manylinux_2_12_x86_64.manylinux2010_x86_64.whl (422 kB)
[?25l   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/422.9 kB[0m [31m?[0m eta [36m-:--:--[0m[2K   [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m422.9/422.9 kB[0m [31m24.8 MB/s[0m eta [36m0:00:00[0m
[?25hInstalling collected packages: ninja
Successfully installed ninja-1.11.1.3


In [None]:
!nvidia-smi

Thu Feb 20 07:19:20 2025       
+-----------------------------------------------------------------------------------------+
| NVIDIA-SMI 550.54.15              Driver Version: 550.54.15      CUDA Version: 12.4     |
|-----------------------------------------+------------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id          Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |           Memory-Usage | GPU-Util  Compute M. |
|                                         |                        |               MIG M. |
|   0  Tesla T4                       Off |   00000000:00:04.0 Off |                    0 |
| N/A   42C    P8              9W /   70W |       0MiB /  15360MiB |      0%      Default |
|                                         |                        |                  N/A |
+-----------------------------------------+------------------------+----------------------+
                                                

In [None]:
cu_code = '''
# include <torch/extension.h>
# include <cuda.h>
# include <cuda_runtime.h>

__global__ void triangular_mm_kernel(const float* __restrict__ A,
                                      const float* __restrict__ B,
                                      float* __restrict__ C, const int N) {
  // Use 2D block configuration for better occupancy
  const int row = blockIdx.y * blockDim.y + threadIdx.y;
  const int col = blockIdx.x * blockDim.x + threadIdx.x;

  if (row < N && col < N) {
    if (col <= row) {
      // Lower triangle computation
      float sum = 0.0f;
      // Process elements in chunks to improve cache utilization
# pragma unroll 8
      for (int k = col; k <= row; k++) {
        sum += A[row * N + k] * B[k * N + col];
      }
      C[row * N + col] = sum;
    } else {
      // Upper triangle (set to zero)
      C[row * N + col] = 0.0f;
    }
  }
}

at::Tensor forward(at::Tensor A, at::Tensor B) {
  TORCH_CHECK(A.is_cuda(), "A must be a CUDA tensor");
  TORCH_CHECK(B.is_cuda(), "B must be a CUDA tensor");
  TORCH_CHECK(A.dim() == 2, "A must be a 2D tensor");
  TORCH_CHECK(B.dim() == 2, "B must be a 2D tensor");
  TORCH_CHECK(A.size(0) == A.size(1), "A must be square");
  TORCH_CHECK(B.size(0) == B.size(1), "B must be square");
  TORCH_CHECK(A.size(0) == B.size(0), "A and B must be the same size");

  int N = A.size(0);
  auto C = torch::empty_like(A);

  // Optimize thread count based on matrix size
  const int threadsPerBlock = 256;  // Increased thread count per block
  const int numBlocks = N;

  triangular_mm_kernel<<<numBlocks, threadsPerBlock>>>(
      A.data_ptr<float>(), B.data_ptr<float>(), C.data_ptr<float>(), N);

  cudaError_t err = cudaGetLastError();
  TORCH_CHECK(err == cudaSuccess, "CUDA kernel failed: ", cudaGetErrorString(err));
  return C;
}

PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
  m.def("forward", &forward,
        "Strided efficient triangular matrix multiplication (CUDA)");
}
'''

with open("tmp.cu", "w") as f:
  f.write(cu_code)

In [None]:
import torch
from torch.utils.cpp_extension import load
from triton.testing import do_bench

# make sure you have nvcc
cuda_fn = load(
    name="triangular_mm",
    sources=["tmp.cu"],
    extra_cuda_cflags=["-O3", "--use_fast_math"],
    with_cuda=True,
    verbose=True,
).forward

N = 4096

def trilmm(a, b): return torch.matmul(a, b).tril()

a = torch.randn(N, N, device="cuda")
b = torch.randn(N, N, device="cuda")

a = torch.tril(a)
b = torch.tril(b)

do_bench(lambda: cuda_fn(a, b).mean()) # do this once jic we need more warmup

# Normal testing
time_new = do_bench(lambda: cuda_fn(a, b))
print(f"Time taken: {time_new} ms")

time_old = do_bench(lambda: trilmm(a, b))
print(f"Time taken: {time_old} ms")

print(f"Speedup: {time_old / time_new}")

# Incease rep and do .mean() in case ^ is only capturing dispatches
time_new = do_bench(lambda: cuda_fn(a, b).mean(), rep=10000)
print(f"Time taken: {time_new} ms")

time_old = do_bench(lambda: trilmm(a, b).mean(), rep=10000)
print(f"Time taken: {time_old} ms")

print(f"Speedup: {time_old / time_new}") # should still see a drastic speedup

print(torch.allclose(cuda_fn(a, b), trilmm(a, b)))



Using /root/.cache/torch_extensions/py311_cu124 as PyTorch extensions root...
Creating extension directory /root/.cache/torch_extensions/py311_cu124/triangular_mm...
Detected CUDA files, patching ldflags
Emitting ninja build file /root/.cache/torch_extensions/py311_cu124/triangular_mm/build.ninja...
If this is not desired, please set os.environ['TORCH_CUDA_ARCH_LIST'].
Building extension module triangular_mm...
Allowing ninja to set a default number of workers... (overridable by setting the environment variable MAX_JOBS=N)
Loading extension module triangular_mm...


Time taken: 0.017692044377326965 ms
Time taken: 27.793136596679688 ms
Speedup: 1570.9397966635026
Time taken: 0.26734623312950134 ms
Time taken: 29.115659713745117 ms
Speedup: 108.90619019734466
True
