Add cuda::ptx::tensormap_{replace,cp_fenceproxy} #1441

Merged: 2 commits, Feb 28, 2024
libcudacxx/docs/ptx.md: 203 changes (199 additions, 4 deletions)
@@ -461,7 +461,7 @@ __device__ static inline uint32_t getctarank(
| [`cp.async.bulk.prefetch.tensor`] | No |
| [`cp.async.bulk.commit_group`] | CTK-FUTURE, CCCL v2.4.0 |
| [`cp.async.bulk.wait_group`] | CTK-FUTURE, CCCL v2.4.0 |
| [`tensormap.replace`] | No |
| [`tensormap.replace`] | CTK-FUTURE, CCCL v2.4.0 |

[`cp.async`]: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async
[`cp.async.commit_group`]: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-commit-group
@@ -474,7 +474,7 @@ __device__ static inline uint32_t getctarank(
[`cp.async.bulk.prefetch.tensor`]: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-prefetch-tensor
[`cp.async.bulk.commit_group`]: #cpasyncbulkcommit_group
[`cp.async.bulk.wait_group`]: #cpasyncbulkwait_group
[`tensormap.replace`]: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-tensormap-replace
[`tensormap.replace`]: #tensormapreplace


#### `cp.async.bulk`
@@ -822,6 +822,182 @@ template <int N32>
__device__ static inline void cp_async_bulk_wait_group_read(
cuda::ptx::n32_t<N32> N);
```

#### `tensormap.replace`

- PTX ISA: [`tensormap.replace`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-tensormap-replace)

**tensormap_replace**:
```cuda
// tensormap.replace.tile.global_address.space.b1024.b64 [tm_addr], new_val; // PTX ISA 83, SM_90a
// .space = { .global }
template <typename B64>
__device__ static inline void tensormap_replace_global_address(
cuda::ptx::space_global_t,
void* tm_addr,
B64 new_val);

// tensormap.replace.tile.global_address.space.b1024.b64 [tm_addr], new_val; // PTX ISA 83, SM_90a
// .space = { .shared::cta }
template <typename B64>
__device__ static inline void tensormap_replace_global_address(
cuda::ptx::space_shared_t,
void* tm_addr,
B64 new_val);

// tensormap.replace.tile.rank.space.b1024.b32 [tm_addr], new_val; // PTX ISA 83, SM_90a
// .space = { .global }
template <typename B32>
__device__ static inline void tensormap_replace_rank(
cuda::ptx::space_global_t,
void* tm_addr,
B32 new_val);

// tensormap.replace.tile.rank.space.b1024.b32 [tm_addr], new_val; // PTX ISA 83, SM_90a
// .space = { .shared::cta }
template <typename B32>
__device__ static inline void tensormap_replace_rank(
cuda::ptx::space_shared_t,
void* tm_addr,
B32 new_val);

// tensormap.replace.tile.box_dim.space.b1024.b32 [tm_addr], ord, new_val; // PTX ISA 83, SM_90a
// .space = { .global }
template <int N32, typename B32>
__device__ static inline void tensormap_replace_box_dim(
cuda::ptx::space_global_t,
void* tm_addr,
cuda::ptx::n32_t<N32> ord,
B32 new_val);

// tensormap.replace.tile.box_dim.space.b1024.b32 [tm_addr], ord, new_val; // PTX ISA 83, SM_90a
// .space = { .shared::cta }
template <int N32, typename B32>
__device__ static inline void tensormap_replace_box_dim(
cuda::ptx::space_shared_t,
void* tm_addr,
cuda::ptx::n32_t<N32> ord,
B32 new_val);

// tensormap.replace.tile.global_dim.space.b1024.b32 [tm_addr], ord, new_val; // PTX ISA 83, SM_90a
// .space = { .global }
template <int N32, typename B32>
__device__ static inline void tensormap_replace_global_dim(
cuda::ptx::space_global_t,
void* tm_addr,
cuda::ptx::n32_t<N32> ord,
B32 new_val);

// tensormap.replace.tile.global_dim.space.b1024.b32 [tm_addr], ord, new_val; // PTX ISA 83, SM_90a
// .space = { .shared::cta }
template <int N32, typename B32>
__device__ static inline void tensormap_replace_global_dim(
cuda::ptx::space_shared_t,
void* tm_addr,
cuda::ptx::n32_t<N32> ord,
B32 new_val);

// tensormap.replace.tile.global_stride.space.b1024.b64 [tm_addr], ord, new_val; // PTX ISA 83, SM_90a
// .space = { .global }
template <int N32, typename B64>
__device__ static inline void tensormap_replace_global_stride(
cuda::ptx::space_global_t,
void* tm_addr,
cuda::ptx::n32_t<N32> ord,
B64 new_val);

// tensormap.replace.tile.global_stride.space.b1024.b64 [tm_addr], ord, new_val; // PTX ISA 83, SM_90a
// .space = { .shared::cta }
template <int N32, typename B64>
__device__ static inline void tensormap_replace_global_stride(
cuda::ptx::space_shared_t,
void* tm_addr,
cuda::ptx::n32_t<N32> ord,
B64 new_val);

// tensormap.replace.tile.element_stride.space.b1024.b32 [tm_addr], ord, new_val; // PTX ISA 83, SM_90a
// .space = { .global }
template <int N32, typename B32>
__device__ static inline void tensormap_replace_element_size(
cuda::ptx::space_global_t,
void* tm_addr,
cuda::ptx::n32_t<N32> ord,
B32 new_val);

// tensormap.replace.tile.element_stride.space.b1024.b32 [tm_addr], ord, new_val; // PTX ISA 83, SM_90a
// .space = { .shared::cta }
template <int N32, typename B32>
__device__ static inline void tensormap_replace_element_size(
cuda::ptx::space_shared_t,
void* tm_addr,
cuda::ptx::n32_t<N32> ord,
B32 new_val);

// tensormap.replace.tile.elemtype.space.b1024.b32 [tm_addr], new_val; // PTX ISA 83, SM_90a
// .space = { .global }
template <int N32>
__device__ static inline void tensormap_replace_elemtype(
cuda::ptx::space_global_t,
void* tm_addr,
cuda::ptx::n32_t<N32> new_val);

// tensormap.replace.tile.elemtype.space.b1024.b32 [tm_addr], new_val; // PTX ISA 83, SM_90a
// .space = { .shared::cta }
template <int N32>
__device__ static inline void tensormap_replace_elemtype(
cuda::ptx::space_shared_t,
void* tm_addr,
cuda::ptx::n32_t<N32> new_val);

// tensormap.replace.tile.interleave_layout.space.b1024.b32 [tm_addr], new_val; // PTX ISA 83, SM_90a
// .space = { .global }
template <int N32>
__device__ static inline void tensormap_replace_interleave_layout(
cuda::ptx::space_global_t,
void* tm_addr,
cuda::ptx::n32_t<N32> new_val);

// tensormap.replace.tile.interleave_layout.space.b1024.b32 [tm_addr], new_val; // PTX ISA 83, SM_90a
// .space = { .shared::cta }
template <int N32>
__device__ static inline void tensormap_replace_interleave_layout(
cuda::ptx::space_shared_t,
void* tm_addr,
cuda::ptx::n32_t<N32> new_val);

// tensormap.replace.tile.swizzle_mode.space.b1024.b32 [tm_addr], new_val; // PTX ISA 83, SM_90a
// .space = { .global }
template <int N32>
__device__ static inline void tensormap_replace_swizzle_mode(
cuda::ptx::space_global_t,
void* tm_addr,
cuda::ptx::n32_t<N32> new_val);

// tensormap.replace.tile.swizzle_mode.space.b1024.b32 [tm_addr], new_val; // PTX ISA 83, SM_90a
// .space = { .shared::cta }
template <int N32>
__device__ static inline void tensormap_replace_swizzle_mode(
cuda::ptx::space_shared_t,
void* tm_addr,
cuda::ptx::n32_t<N32> new_val);

// tensormap.replace.tile.fill_mode.space.b1024.b32 [tm_addr], new_val; // PTX ISA 83, SM_90a
// .space = { .global }
template <int N32>
__device__ static inline void tensormap_replace_fill_mode(
cuda::ptx::space_global_t,
void* tm_addr,
cuda::ptx::n32_t<N32> new_val);

// tensormap.replace.tile.fill_mode.space.b1024.b32 [tm_addr], new_val; // PTX ISA 83, SM_90a
// .space = { .shared::cta }
template <int N32>
__device__ static inline void tensormap_replace_fill_mode(
cuda::ptx::space_shared_t,
void* tm_addr,
cuda::ptx::n32_t<N32> new_val);
```
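
For orientation, here is a minimal usage sketch. It is not part of the generated documentation above; it assumes compilation for `sm_90a` with a toolkit that provides PTX ISA 8.3, and the kernel and pointer names are hypothetical.

```cuda
// Sketch only: patch two fields of a 1024-bit tensor map living in global memory.
#include <cuda/ptx>
#include <cstdint>

__global__ void patch_tensormap(void* tmap_in_global, const void* new_base)
{
  // Replace the 64-bit global address field.
  cuda::ptx::tensormap_replace_global_address(
    cuda::ptx::space_global,
    tmap_in_global,
    reinterpret_cast<std::uint64_t>(new_base));

  // Replace the box dimension at ordinal 0 with a (hypothetical) value of 64.
  cuda::ptx::tensormap_replace_box_dim(
    cuda::ptx::space_global,
    tmap_in_global,
    cuda::ptx::n32_t<0>{},
    64);
}
```

The `space_global`/`space_shared` tag selects whether `tm_addr` points to global or shared::cta memory, and `n32_t<N>` passes the field ordinal as a compile-time constant.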

### [9.7.9. Texture Instructions](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#texture-instructions)

| Instruction | Available in libcu++ |
@@ -1139,7 +1315,7 @@ __device__ static inline void red_async(
| [`cp.async.mbarrier.arrive`] | No |
| [`mbarrier.test_wait/mbarrier.try_wait`] | CTK-FUTURE, CCCL v2.3.0 |
| [`mbarrier.pending_count`] | No |
| [`tensormap.cp_fenceproxy`] | No |
| [`tensormap.cp_fenceproxy`] | CTK-FUTURE, CCCL v2.4.0 |

[`mbarrier.init`]: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-init
[`mbarrier.inval`]: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-inval
@@ -1150,7 +1326,7 @@ __device__ static inline void red_async(
[`cp.async.mbarrier.arrive`]: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-cp-async-mbarrier-arrive
[`mbarrier.test_wait/mbarrier.try_wait`]: #mbarriertest_waitmbarriertry_wait
[`mbarrier.pending_count`]: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-mbarrier-pending-count
[`tensormap.cp_fenceproxy`]: https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-tensormap-cp-fenceproxy
[`tensormap.cp_fenceproxy`]: #tensormapcp_fenceproxy



@@ -1410,6 +1586,25 @@ __device__ static inline bool mbarrier_try_wait_parity(
const uint32_t& suspendTimeHint);
```

#### `tensormap.cp_fenceproxy`

- PTX ISA: [`tensormap.cp_fenceproxy`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-tensormap-cp-fenceproxy)

**tensormap_cp_fenceproxy**:
```cuda
// tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.sem.scope.sync.aligned [dst], [src], size; // PTX ISA 83, SM_90
// .sem = { .release }
// .scope = { .cta, .cluster, .gpu, .sys }
template <int N32, cuda::ptx::dot_scope Scope>
__device__ static inline void tensormap_cp_fenceproxy(
cuda::ptx::sem_release_t,
cuda::ptx::scope_t<Scope> scope,
void* dst,
const void* src,
cuda::ptx::n32_t<N32> size);
```
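
Combining this with `tensormap.replace` above, a hedged end-to-end sketch (again not part of the PR itself, with hypothetical names, compiled for `sm_90a`):

```cuda
// Sketch only: update a tensor map staged in shared memory, then publish it to
// its global-memory home and fence the tensormap proxy.
#include <cuda/ptx>
#include <cstdint>

__global__ void publish_tensormap(void* tmap_in_global, const void* new_base)
{
  // Staging buffer; assume it already holds a valid 128-byte tensor map
  // (for example, copied from global memory earlier in the kernel).
  __shared__ alignas(128) unsigned char tmap_staging[128];

  // Patch the global address field of the staged copy.
  cuda::ptx::tensormap_replace_global_address(
    cuda::ptx::space_shared,
    tmap_staging,
    reinterpret_cast<std::uint64_t>(new_base));

  // Copy the staged tensor map back to global memory and release-fence the
  // tensormap proxy at GPU scope so later tensor-map reads observe the update.
  cuda::ptx::tensormap_cp_fenceproxy(
    cuda::ptx::sem_release,
    cuda::ptx::scope_gpu,
    tmap_in_global,
    tmap_staging,
    cuda::ptx::n32_t<128>{});
}
```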


### [9.7.13. Warp Level Matrix Multiply-Accumulate Instructions](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-multiply-accumulate-instructions)

| Instruction | Available in libcu++ |
@@ -105,4 +105,13 @@
#endif
#endif // __cccl_ptx_isa >= 800

// NVRTC uses its own <nv/target> header, so we need to manually tell it when we expect SM90a to be available
#if defined(_CCCL_COMPILER_NVRTC) && !defined(NV_HAS_FEATURE_SM_90a)
#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL))
#define NV_HAS_FEATURE_SM_90a NV_PROVIDES_SM_90
#else // ^^^ SM90a ^^^ / vvv !SM90a vvv
#define NV_HAS_FEATURE_SM_90a NV_NO_TARGET
#endif //
#endif // _CCCL_COMPILER_NVRTC && !NV_HAS_FEATURE_SM_90a

#endif // __CCCL_PTX_ISA_H_
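
For context, a sketch of how such a feature macro is usually consumed (an assumption about typical usage, not something this diff adds): the `<nv/target>` dispatch macros select code paths based on it, in the same way under NVCC and, with the definition above, under NVRTC.

```cuda
// Assumption: a typical consumer of NV_HAS_FEATURE_SM_90a via <nv/target>.
#include <nv/target>

__device__ int selected_path;

__device__ void dispatch_on_sm90a()
{
  NV_IF_ELSE_TARGET(NV_HAS_FEATURE_SM_90a,
    (selected_path = 1;),  // emitted only when targeting sm_90a (e.g. -arch=sm_90a)
    (selected_path = 0;)   // every other target, including plain sm_90
  );
}
```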