This project is developed on a Turing architecture GPU. Modifications are required to run on other architectures.
I wanted to know how much overhead is added by Python-based ML frameworks like PyTorch, and how much faster we can possibly get by going lower-level (CUDA). It is also a good exercise for me to learn CUDA.
Coming from a Python/PyTorch/JAX-heavy background, this is my first time doing a large-ish CUDA project. Compared to many C++/CUDA repositories out there, I consider the code to be pretty clean and "pythonic", which hopefully makes it also approachable for other CUDA beginners.
I also wrote a blog post, "How Matrix Multiplication Works on the GPU", which you can read on HackMD or on Medium.
It's... pretty slow, at least for small networks. Even using PyTorch 2.0's `torch.compile` functionality (with `mode="max-autotune"` and `fullgraph=True`, which is supposed to remove all Python overhead), it can still be up to
This overhead goes down as the network gets larger, though it never completely goes away. It asymptotically approaches
There are a few reasons why PyTorch is (asymptotically) slower than CUDA:
- The main contributor is that my implementation uses fp16 accumulation for matrix multiplication, while PyTorch uses fp32. NVIDIA reports double(!) the theoretical throughput with fp16 accumulation. PyTorch defaults to fp32 for stability reasons, but I haven't encountered those issues in my runs.
- I tuned the hyperparameters for the CUDA implementation specifically for my hardware. I'm not sure if `max-autotune` does the same for PyTorch.
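To make the stability trade-off behind the first point concrete, here is a small pure-Python sketch (not code from this project) that emulates a dot product whose running sum is rounded to fp16 or to fp32 after every addition. The helper names `round_to` and `dot` and the input values are illustrative assumptions. The inputs are deliberately a worst case (a long sum of same-sign values); real activations are usually smaller and mixed-sign, which would be consistent with not seeing problems in practice.

```python
import struct


def round_to(fmt: str, x: float) -> float:
    """Round a Python float to IEEE half ('e') or single ('f') precision."""
    return struct.unpack(fmt, struct.pack(fmt, x))[0]


def dot(a, b, acc_fmt: str) -> float:
    """Dot product of fp16 inputs; the running sum is rounded to acc_fmt after each add."""
    acc = 0.0
    for x, y in zip(a, b):
        # The product of two fp16 values fits exactly in a Python float (double).
        prod = round_to("e", x) * round_to("e", y)
        acc = round_to(acc_fmt, acc + prod)
    return acc


# Illustrative worst case: 4096 identical positive terms.
n = 4096
a = [0.1] * n
b = [1.0] * n

exact = n * round_to("e", 0.1)   # reference sum of the fp16-rounded inputs
acc_fp32 = dot(a, b, "f")        # fp32 accumulator
acc_fp16 = dot(a, b, "e")        # fp16 accumulator

print(f"exact={exact}  fp32 accumulate={acc_fp32}  fp16 accumulate={acc_fp16}")
```

With these inputs the fp32 accumulator tracks the reference sum, while the fp16 accumulator stalls once the running sum is large enough that each new term is smaller than half the spacing between adjacent fp16 values.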
Note
I applied a few optimisations to both implementations.
- I preloaded all data into memory in order to minimise the host-device data transfer overhead.
- I allowed the PyTorch implementation to have a few warm-up steps before timing, to allow the JIT compiler to compile the graph.
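The warm-up idea can be sketched as a generic timing helper. This is a minimal illustration rather than the benchmark script used here: `benchmark`, `fake_step`, and the default iteration counts are hypothetical names and values chosen for the example.

```python
import time
from statistics import median


def benchmark(step, warmup=3, iters=20, sync=None):
    """Return the median wall-clock time of step() in seconds, after untimed warm-up calls."""
    for _ in range(warmup):
        step()  # untimed: lets one-off costs such as JIT compilation happen first
    if sync is not None:
        sync()  # make sure warm-up work has finished before timing starts
    times = []
    for _ in range(iters):
        start = time.perf_counter()
        step()
        if sync is not None:
            sync()  # GPU work is asynchronous, so wait before stopping the clock
        times.append(time.perf_counter() - start)
    return median(times)


# Stand-in workload: the first call is slow, like a training step that triggers compilation.
calls = {"n": 0}


def fake_step():
    calls["n"] += 1
    if calls["n"] == 1:
        time.sleep(0.05)  # simulated one-off compilation cost


t = benchmark(fake_step, warmup=2, iters=5)
print(f"median step time: {t:.6f} s over 5 timed calls ({calls['n']} calls in total)")

# With PyTorch on a GPU one would pass something like
#   benchmark(lambda: compiled_step(batch), sync=torch.cuda.synchronize)
# where compiled_step and batch are hypothetical names for a compiled training step and its input.
```

Because the slow first call lands in the warm-up phase, it does not affect the reported median.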
My implementation is also not perfect! For example, I didn't use vectorized loads for my element-wise kernels (e.g., `relu`). I expect this to give a pretty substantial speedup too.
Comparing the loss curves of the PyTorch and CUDA implementations, we can see that they are pretty much identical.