 
 namespace gpu {
 
-static Logger kGpuLog = {stdout, "", kInfo};
-
 #ifndef NDEBUG
 static constexpr bool kDebug = true;
 #else
@@ -37,7 +35,15 @@ struct Array {
 };
 
 /**
- * @brief Represents the shape of a tensor.
+ * @brief Represents the shape of a tensor.
+ *
+ * The rank of the tensor is the number of dimensions in the shape. The data
+ * array stores the size of each dimension. For now, we limit the rank to 8 to
+ * avoid dynamic allocation.
+ *
+ * @code
+ * Shape shape = {256, 256};
+ * @endcode
  */
 struct Shape {
   static constexpr size_t kMaxRank = 8; // Maximum rank of a tensor, avoids
@@ -60,6 +66,16 @@ struct Shape {
   }
 };
 
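
A minimal usage sketch of Shape, assuming the rank/data fields described in the new doc comment (the field names are an assumption; only kMaxRank is visible in this hunk):

    gpu::Shape shape = {256, 256}; // rank-2 shape for a 256 x 256 tensor
    // shape.rank == 2; shape.data[0] == 256; shape.data[1] == 256 (assumed fields)
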
+/**
+ * @brief Returns the number of elements in a tensor with the given shape,
+ * which is equal to the product of the dimensions.
+ * @param[in] shape Shape of the tensor
+ * @return Number of elements in the tensor
+ *
+ * @code
+ * size({256, 256}) -> 65536
+ * @endcode
+ */
 inline size_t size(const Shape &shape) {
   size_t numels = 1;
   for (size_t i = 0; i < shape.rank; i++) {
@@ -68,27 +84,41 @@ inline size_t size(const Shape &shape) {
   return numels;
 }
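
Since size() is just the product of the dimensions, a quick worked example:

    gpu::Shape s = {2, 3, 4};
    size_t n = gpu::size(s); // n == 24 (2 * 3 * 4)
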
 
+
 /**
  * @brief Represents a tensor on the GPU, which is a buffer of values with a
  * shape.
+ *
+ * @code
+ * Tensor tensor = createTensor(ctx, {256, 256}, kf32);
+ * @endcode
  */
 struct Tensor {
   Array data;
   Shape shape;
 };
 
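
Putting Tensor together with the helpers above, a sketch of creating a tensor and querying its element count (createContext appears later in this header):

    gpu::Context ctx = gpu::createContext();
    gpu::Tensor t = gpu::createTensor(ctx, {256, 256}, gpu::kf32);
    size_t n = gpu::size(t.shape); // 65536 elements
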
+/**
+ * @brief Represents a non-owning view into a tensor specifying an offset and
+ * a subspan. This is useful for specifying a slice of a tensor on the GPU
+ * without copying the data.
+ *
+ * @code
+ * TensorView view = {tensor, 0, 256};
+ * @endcode
+ */
 struct TensorView {
   Tensor data; // non-owning view
   size_t offset = 0;
   size_t span = 0;
 };
 
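
For instance, a sketch of viewing the second half of a 512-element tensor without copying (whether offset and span count elements or bytes is not shown in this hunk; elements are assumed here):

    gpu::Tensor t = gpu::createTensor(ctx, {512}, gpu::kf32);
    gpu::TensorView secondHalf = {t, 256, 256}; // offset 256, span 256
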
 /**
- * @brief Represents a collection of non-overlapping views into tensors.
+ * @brief Represents an ordered collection of WGPUBuffers (wrapped as tensors,
+ * non-overlapping views, or arrays) for the purpose of binding them to a
+ * kernel operation to make them accessible to the shader code.
  *
- * Since Tensor wraps a WGPUBuffer and WGPUBuffer is effectively a reference to
- * a GPU buffer, performing operations on Bindings elements (writing /
- * copying buffers) is tantamount to working with pointers to GPU buffers.
+ * The ordering of the bindings should match the binding indices in the shader.
  */
 template <std::size_t N> struct Bindings {
   std::array<Tensor, N> data;
@@ -112,6 +142,14 @@ template <std::size_t N> struct Bindings {
     }
   }
 
+  Bindings(const std::initializer_list<Array> &init) {
+    size_t i = 0;
+    for (const Array &a : init) {
+      data[i].data = a;      // wrap each raw Array in a Tensor; shape left empty
+      viewOffsets[i] = 0;
+      viewSpans[i] = a.size; // view spans the whole buffer by default
+      ++i;
+    }
+  }
+
   Tensor &operator[](std::size_t index) { return data[index]; }
   const Tensor &operator[](std::size_t index) const { return data[index]; }
 };
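
The ordering requirement is easiest to see next to a shader. A hedged sketch with a minimal WGSL copy kernel (the shader body and the tensor names inp/out are illustrative, not from this file); the nth element of the Bindings must correspond to @binding(n):

    const char *kCopy = R"(
    @group(0) @binding(0) var<storage, read_write> inp: array<f32>;
    @group(0) @binding(1) var<storage, read_write> out: array<f32>;
    @compute @workgroup_size(256)
    fn main(@builtin(global_invocation_id) gid: vec3<u32>) {
      out[gid.x] = inp[gid.x];
    }
    )";

    gpu::Bindings bindings{inp, out}; // inp -> @binding(0), out -> @binding(1)
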
@@ -125,6 +163,14 @@ template <typename... Args> Bindings(Args...) -> Bindings<sizeof...(Args)>;
 struct Context; // Forward declaration so that TensorPool can have a pointer to
                 // Context
 
+/**
+ * @brief Represents a pool of tensors to manage GPU resources. The pool is
+ * responsible for managing the lifetime of the tensors and freeing them when
+ * the pool is destroyed.
+ *
+ * Most users do not need to interact with TensorPool directly; the Context
+ * struct contains a TensorPool member that simplifies lifetime management of
+ * GPU resources.
+ */
 struct TensorPool {
   inline TensorPool(Context *ctx) : ctx(ctx), data() {};
   Context *ctx;
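
A sketch of the lifetime model this implies (scope-bound, RAII-style cleanup):

    {
      gpu::Context ctx = gpu::createContext();
      gpu::Tensor t = gpu::createTensor(ctx, {256}, gpu::kf32);
      // ... use t ...
    } // ctx and its TensorPool are destroyed; t's GPU buffer is freed here
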
@@ -306,7 +352,10 @@ struct Context {
  * @param[in] dtype Data type of the tensor (e.g. kf32)
  * @param[in] usage Usage flags for the tensor buffer
  * @return Tensor instance representing the created tensor
- * @example Tensor tensor = createTensor(pool, device, {256, 256}, kf32);
+ *
+ * @code
+ * Tensor tensor = createTensor(pool, device, {256, 256}, kf32);
+ * @endcode
  */
 inline Tensor
 createTensor(TensorPool &pool, WGPUDevice &device, const Shape &shape,
@@ -347,7 +396,10 @@ createTensor(TensorPool &pool, WGPUDevice &device, const Shape &shape,
  * @param[in] shape Shape of the tensor
  * @param[in] dtype Data type of the tensor (e.g. kf32)
  * @return Tensor instance representing the created tensor
- * @example Tensor tensor = createTensor(ctx, {256, 256}, kf32);
+ *
+ * @code
+ * Tensor tensor = createTensor(ctx, {256, 256}, kf32);
+ * @endcode
  */
 inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype) {
   return createTensor(ctx.pool, ctx.device, shape, dtype);
@@ -366,7 +418,10 @@ inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype) {
  * @param[in] dtype Data type of the tensor (e.g. kf32)
  * @param[in] data Initial data to populate the tensor with
  * @return Tensor instance representing the created tensor
- * @example Tensor tensor = createTensor(ctx, {256, 256}, kf32, data);
+ *
+ * @code
+ * Tensor tensor = createTensor(ctx, {256, 256}, kf32, data);
+ * @endcode
  */
 inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
                            float *data) {
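
A short sketch of this data-initialized overload, which uploads host values at creation time:

    std::array<float, 4> host = {1.0f, 2.0f, 3.0f, 4.0f};
    gpu::Tensor t = gpu::createTensor(ctx, {2, 2}, gpu::kf32, host.data());
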
@@ -388,7 +443,10 @@ inline Tensor createTensor(Context &ctx, const Shape &shape, NumType dtype,
  *
  * @param[in] pool TensorPool instance to manage the tensor
  * @param[in] tensor Tensor instance to free
- * @example FreeTensor(pool, tensor);
+ *
+ * @code
+ * FreeTensor(pool, tensor);
+ * @endcode
  */
 inline void FreeTensor(TensorPool &pool, Tensor tensor) {
   if (tensor.data.buffer) {
@@ -430,7 +488,10 @@ inline TensorPool::~TensorPool() {
  * @param[in] str String to mutate with substitution replacements.
  * @param[in] from Substring to replace
  * @param[in] to Substring to replace with
- * @example replaceAll(str, "{{workgroupSize}}", "256");
+ *
+ * @code
+ * replaceAll(str, "{{workgroupSize}}", "256");
+ * @endcode
  */
 inline void replaceAll(std::string &str, const std::string &from,
                        const std::string &to) {
@@ -448,7 +509,10 @@ inline void replaceAll(std::string &str, const std::string &from,
  * @param[in] str String to mutate with substitution replacements.
  * @param[in] reps Vector of pairs of substrings to replace and their
  * replacements.
- * @example replaceAll(str, {{"{{workgroupSize}}", "256"}, {"{{precision}}",
- * "f32"}});
+ *
+ * @code
+ * replaceAll(str, {{"{{workgroupSize}}", "256"}, {"{{precision}}", "f32"}});
+ * @endcode
  */
 inline void replaceAll(std::string &str,
@@ -473,7 +537,10 @@ inline void replaceAll(std::string &str,
  * @param[in] precision Data type precision for the shader. As with
  * workgroupSize, precision is stored as a field in the ShaderCode instance
  * that is returned by createShader().
- * @example ShaderCode code = createShader(kPuzzle1, {256, 1, 1}, kf32);
+ *
+ * @code
+ * ShaderCode code = createShader(kPuzzle1, {256, 1, 1}, kf32);
+ * @endcode
  */
 inline ShaderCode createShader(const char *shaderTemplate,
                                const Shape &workgroupSize = {256, 1, 1},
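
What the template substitution amounts to, as a sketch (the template literal here is illustrative, not one of the library's shaders):

    std::string tpl = "@compute @workgroup_size({{workgroupSize}}) fn main() {}";
    gpu::replaceAll(tpl, "{{workgroupSize}}", "256");
    // tpl == "@compute @workgroup_size(256) fn main() {}"
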
@@ -495,13 +562,24 @@ inline ShaderCode createShader(const char *shaderTemplate,
  * @param[in] shaderTemplate Shader template string with placeholders
  * @param[in] workgroupSize Workgroup size in the x dimension
  * @param[in] precision Data type precision for the shader
- * @example ShaderCode code = createShader(kPuzzle1, 256, kf32);
+ *
+ * @code
+ * ShaderCode code = createShader(kPuzzle1, 256, kf32);
+ * @endcode
  */
 inline ShaderCode createShader(const char *shaderTemplate, size_t workgroupSize,
                                NumType precision = kf32) {
   return createShader(shaderTemplate, Shape{workgroupSize, 1, 1}, precision);
 }
 
+/**
+ * @brief Checks a condition and logs an error message if the condition is
+ * false. In debug mode, it will also exit the program with an error code.
+ * @param[in] condition The condition to check.
+ * @param[in] message The error message to log if the condition is false.
+ * @param[in] file The source file where the check is performed.
+ * @param[in] line The line number in the source file where the check is
+ * performed.
+ */
 inline void check(bool condition, const char *message,
                   const char *file = "unknown", int line = -1) {
   if constexpr (kDebug) {
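
A typical call site, using the standard macros so the failure location is reported:

    gpu::check(t.data.buffer != nullptr, "Buffer allocation failed",
               __FILE__, __LINE__);
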
@@ -532,7 +610,10 @@ inline void check(bool condition, const char *message,
  * (optional)
  * @param[in] devDescriptor Device descriptor for the WebGPU device (optional)
  * @return Context instance representing the created GPU context
- * @example Context ctx = createContext();
+ *
+ * @code
+ * Context ctx = createContext();
+ * @endcode
  */
 inline Context createContext(const WGPUInstanceDescriptor &desc = {},
                              const WGPURequestAdapterOptions &adapterOpts = {},
@@ -626,7 +707,10 @@ inline void wait(Context &ctx, std::future<void> &future) {
  * @param[in] tensor Tensor instance representing the GPU buffer to copy from
  * @param[out] data Pointer to the CPU memory to copy the data to
  * @param[in] bufferSize Size of the data buffer in bytes
- * @example toCPU(ctx, tensor, data, bufferSize);
+ *
+ * @code
+ * toCPU(ctx, tensor, data, bufferSize);
+ * @endcode
  */
 inline void toCPU(Context &ctx, Tensor &tensor, float *data,
                   size_t bufferSize) {
@@ -691,7 +775,10 @@ inline void toCPU(Context &ctx, Tensor &tensor, float *data,
  * @param[in] ctx Context instance to manage the operation
  * @param[in] tensor Tensor instance representing the GPU buffer to copy from
  * @param[out] data Array of floats to copy the data to
- * @example toCPU(ctx, tensor, data);
+ *
+ * @code
+ * toCPU(ctx, tensor, data);
+ * @endcode
  */
 template <size_t N>
 void toCPU(Context &ctx, Tensor &tensor, std::array<float, N> &data) {
@@ -707,7 +794,10 @@ void toCPU(Context &ctx, Tensor &tensor, std::array<float, N>& data) {
  * @param[in] data Pointer to the CPU memory to copy from
  * @param[in] buffer WGPUBuffer instance representing the GPU buffer to copy to
  * @param[in] size Size of the data buffer in bytes
- * @example toGPU(ctx, data, buffer, size);
+ *
+ * @code
+ * toGPU(ctx, data, buffer, size);
+ * @endcode
  */
 inline void toGPU(Context &ctx, const void *data, WGPUBuffer buffer,
                   size_t size) {
@@ -720,7 +810,10 @@ inline void toGPU(Context &ctx, const void *data, WGPUBuffer buffer,
  * @param[in] ctx Context instance to manage the operation
  * @param[in] data Pointer to the CPU memory to copy from
  * @param[in] tensor Tensor instance representing the GPU buffer to copy to
- * @example toGPU(ctx, data, tensor);
+ *
+ * @code
+ * toGPU(ctx, data, tensor);
+ * @endcode
  */
 inline void toGPU(Context &ctx, const float *data, Tensor &tensor) {
   wgpuQueueWriteBuffer(ctx.queue, tensor.data.buffer, 0, data,
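
The two directions compose into the usual host/device roundtrip; a sketch using the overloads above:

    std::array<float, 4> host = {1.0f, 2.0f, 3.0f, 4.0f};
    gpu::Tensor t = gpu::createTensor(ctx, {4}, gpu::kf32);
    gpu::toGPU(ctx, host.data(), t); // upload
    std::array<float, 4> back;
    gpu::toCPU(ctx, t, back);        // download
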
@@ -747,7 +840,10 @@ inline void toGPU(Context &ctx, Params ¶ms, Kernel &op) {
  * reused for a dispatch.
  * @param[in] device WGPUDevice instance to manage the operation
  * @param[in] op Kernel instance representing the kernel to reset
- * @example resetCommandBuffer(device, op);
+ *
+ * @code
+ * resetCommandBuffer(device, op);
+ * @endcode
  */
 inline void resetCommandBuffer(WGPUDevice &device, Kernel &op) {
   {
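
Command buffers are single-use in WebGPU, so a kernel must be reset before each re-dispatch. A sketch of the implied reuse loop, using the wait() helper visible in a hunk header above (dispatchKernel's promise-based signature appears at the end of this section):

    for (int step = 0; step < 3; ++step) {
      std::promise<void> promise;
      std::future<void> future = promise.get_future();
      gpu::dispatchKernel(ctx, kernel, promise);
      gpu::wait(ctx, future);                      // block until the GPU finishes
      gpu::resetCommandBuffer(ctx.device, kernel); // re-arm for the next dispatch
    }
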
@@ -813,7 +909,10 @@ inline Shape cdiv(Shape total, Shape group) {
  * arbitrary types to be passed as parameters.
  * @param[in] paramsSize Size of the parameters buffer in bytes.
  * @return Kernel instance representing the created kernel
- * @example Kernel kernel = createKernel(ctx, shader, dataBindings, numInputs,
- * output, nThreads, params, paramsSize);
+ *
+ * @code
+ * Kernel kernel = createKernel(ctx, shader, dataBindings, numInputs,
+ *                              output, nThreads, params, paramsSize);
+ * @endcode
  */
 inline Kernel createKernel(Context &ctx, const ShaderCode &shader,
@@ -965,7 +1064,10 @@ inline Kernel createKernel(Context &ctx, const ShaderCode &shader,
  * @param[in] params Optional parameters for the kernel. If the kernel does not
  * have any parameters, use NoParam.
  * @return Kernel instance representing the created kernel
- * @example Kernel kernel = createKernel(ctx, shader, tensorData, output,
- * nWorkgroups, params);
+ *
+ * @code
+ * Kernel kernel = createKernel(ctx, shader, tensorData, output,
+ *                              nWorkgroups, params);
+ * @endcode
  */
 template <typename ParamsType = NoParam, size_t numInputs>
@@ -1000,7 +1102,10 @@ Kernel createKernel(Context &ctx, const ShaderCode &shader,
  * @param[in] ctx Context instance to manage the kernel, from which the queue
  * for the GPU is obtained
  * @param[in] kernel Kernel instance to dispatch
- * @example dispatchKernel(ctx, kernel);
+ *
+ * @code
+ * dispatchKernel(ctx, kernel, promise);
+ * @endcode
  */
 inline void dispatchKernel(Context &ctx, Kernel &kernel,
                            std::promise<void> &promise) {
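
Tying the section together, an end-to-end sketch of the workflow this header supports. Not all of createKernel's parameter types are visible in these hunks, so the call mirrors the @code example above; kShaderTemplate and nWorkgroups are hypothetical names:

    gpu::Context ctx = gpu::createContext();
    std::array<float, 256> hostData{}; // zero-initialized input
    gpu::Tensor input = gpu::createTensor(ctx, {256}, gpu::kf32, hostData.data());
    gpu::Tensor output = gpu::createTensor(ctx, {256}, gpu::kf32);
    gpu::ShaderCode code = gpu::createShader(kShaderTemplate, 256, gpu::kf32);
    gpu::Kernel kernel = gpu::createKernel(ctx, code, gpu::Bindings{input},
                                           output, nWorkgroups);
    std::promise<void> promise;
    std::future<void> future = promise.get_future();
    gpu::dispatchKernel(ctx, kernel, promise);
    gpu::wait(ctx, future);
    std::array<float, 256> result;
    gpu::toCPU(ctx, output, result);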