Skip to content

Commit

Permalink
Merge pull request #17 from DifferentiableUniverseInitiative/u/ASKabalan/transpose_ops
Browse files Browse the repository at this point in the history

Implement Transpose primitives
  • Loading branch information
EiffL authored Jun 11, 2024
2 parents d2fb678 + 1d1a779 commit 80172aa
Show file tree
Hide file tree
Showing 11 changed files with 743 additions and 311 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ pybind11_add_module(_jaxdecomp
src/jaxdecomp.cc
src/grid_descriptor_mgr.cc
src/fft.cu
src/transpose.cu
)

set_target_properties(_jaxdecomp PROPERTIES CUDA_ARCHITECTURES "${CUDECOMP_CUDA_CC_LIST}")
Expand Down
10 changes: 10 additions & 0 deletions include/grid_descriptor_mgr.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#include "fft.h"
#include "halo.h"
#include "logger.hpp"
#include "transpose.h"
#include <cstddef>
#include <cudecomp.h>
#include <memory>
Expand Down Expand Up @@ -34,6 +35,12 @@ class GridDescriptorManager {
HRESULT createHaloExecutor(haloDescriptor_t& descriptor, size_t& work_size,
std::shared_ptr<HaloExchange<double>>& executor);

// Obtain the single-precision transpose executor matching `descriptor`,
// writing the required scratch-buffer size to `work_size`.
// NOTE(review): presumably creates-and-caches on first use via the
// m_TransposeDescriptors32 map declared below — confirm against
// grid_descriptor_mgr.cc (definition not visible here).
HRESULT createTransposeExecutor(transposeDescriptor& descriptor, size_t& work_size,
std::shared_ptr<Transpose<float>>& executor);

// Double-precision overload; same contract with Transpose<double> and the
// m_TransposeDescriptors64 cache.
HRESULT createTransposeExecutor(transposeDescriptor& descriptor, size_t& work_size,
std::shared_ptr<Transpose<double>>& executor);

inline cudecompHandle_t getHandle() const { return m_Handle; }

void finalize();
Expand All @@ -56,6 +63,9 @@ class GridDescriptorManager {
std::unordered_map<haloDescriptor_t, std::shared_ptr<HaloExchange<double>>> m_HaloDescriptors64;
std::unordered_map<haloDescriptor_t, std::shared_ptr<HaloExchange<float>>> m_HaloDescriptors32;

std::unordered_map<transposeDescriptor, std::shared_ptr<Transpose<double>>> m_TransposeDescriptors64;
std::unordered_map<transposeDescriptor, std::shared_ptr<Transpose<float>>> m_TransposeDescriptors32;

public:
GridDescriptorManager(GridDescriptorManager const&) = delete;
void operator=(GridDescriptorManager const&) = delete;
Expand Down
73 changes: 73 additions & 0 deletions include/transpose.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
#ifndef _JAX_DECOMP_TRANSPOSE_H_
#define _JAX_DECOMP_TRANSPOSE_H_

#include "checks.h"
#include <array>
#include <cstdint>
#include <cudecomp.h>
#include <pthread.h>

namespace jaxdecomp {

enum class TransposeType { TRANSPOSE_XY, TRANSPOSE_YZ, TRANSPOSE_ZY, TRANSPOSE_YX, UNKNOWN_TRANSPOSE };

class transposeDescriptor {
public:
TransposeType transpose_type = TransposeType::UNKNOWN_TRANSPOSE;
cudecompGridDescConfig_t config;
bool double_precision = false;

transposeDescriptor() = default;
transposeDescriptor(const transposeDescriptor& other) = default;
~transposeDescriptor() = default;

transposeDescriptor(cudecompGridDescConfig_t& config, const TransposeType& type, const bool& double_precision)
: config(config), transpose_type(type), double_precision(double_precision) {}

bool operator==(const transposeDescriptor& other) const {
return (config.gdims[0] == other.config.gdims[0] && config.gdims[1] == other.config.gdims[1] &&
config.gdims[2] == other.config.gdims[2] && config.pdims[0] == other.config.pdims[0] &&
config.pdims[1] == other.config.pdims[1] && double_precision == other.double_precision &&
config.transpose_comm_backend == other.config.transpose_comm_backend &&
config.halo_comm_backend == other.config.halo_comm_backend);
}
};

// Executor for cuDecomp pencil transposes at a fixed precision
// (real_t = float or double). Instances are created and cached by
// GridDescriptorManager (a friend, so it can reach the private state).
template <typename real_t> class Transpose {
friend class GridDescriptorManager;

public:
Transpose() = default;
// Default destructor does NOT free m_GridConfig; Release(handle) below
// appears responsible for cuDecomp cleanup — NOTE(review): confirm the
// manager always calls Release before destruction.
~Transpose() = default;

// Build the cuDecomp grid descriptor for `transpose_desc` and report the
// required workspace size in `work_size`.
HRESULT get_transpose_descriptor(cudecompHandle_t handle, size_t& work_size, transposeDescriptor& transpose_desc);
// Perform the transpose described by `desc` on `stream`, using the
// XLA-provided `buffers` (input/output/workspace pointers).
HRESULT transpose(cudecompHandle_t handle, transposeDescriptor desc, cudaStream_t stream, void** buffers);
// Release cuDecomp resources associated with m_GridConfig.
HRESULT Release(cudecompHandle_t handle);

private:
cudecompGridDesc_t m_GridConfig;         // opaque cuDecomp grid descriptor
cudecompGridDescConfig_t m_GridDescConfig; // config used to build m_GridConfig
int64_t m_WorkSize;                      // workspace size in elements/bytes — confirm unit in transpose.cu
// Debug-only helper: dumps a device array for inspection. Do not call in
// production paths.
void inspect_device_array(void* data, bool transposed, cudaStream_t stream);
};

} // namespace jaxdecomp
namespace std {
// Hash for transposeDescriptor so it can key the unordered_map caches in
// GridDescriptorManager. Hashes exactly the fields operator== compares
// (transpose_type is intentionally excluded, matching equality).
template <> struct hash<jaxdecomp::transposeDescriptor> {
  size_t operator()(const jaxdecomp::transposeDescriptor& desc) const {
    // Boost-style hash_combine: unlike the previous plain XOR, the combiner
    // is order-sensitive, so permuted gdims/pdims values hash differently and
    // equal sub-hashes no longer cancel each other out to zero.
    size_t seed = 0;
    auto combine = [&seed](size_t h) { seed ^= h + 0x9e3779b9 + (seed << 6) + (seed >> 2); };
    combine(std::hash<int32_t>()(desc.config.gdims[0]));
    combine(std::hash<int32_t>()(desc.config.gdims[1]));
    combine(std::hash<int32_t>()(desc.config.gdims[2]));
    combine(std::hash<int32_t>()(desc.config.pdims[0]));
    combine(std::hash<int32_t>()(desc.config.pdims[1]));
    combine(std::hash<bool>()(desc.double_precision));
    combine(std::hash<int>()(desc.config.transpose_comm_backend));
    combine(std::hash<int>()(desc.config.halo_comm_backend));
    return seed;
  }
};
} // namespace std

#endif // _JAX_DECOMP_TRANSPOSE_H_
17 changes: 9 additions & 8 deletions jaxdecomp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,15 @@
import jaxdecomp.fft as fft
from jaxdecomp.fft import pfft3d, pifft3d

from ._src import ( # transposeXtoY, transposeYtoX, transposeYtoZ, transposeZtoY
from ._src import ( # transposeXtoY, transposeYtoX, transposeYtoZ, transposeZtoY,
HALO_COMM_MPI, HALO_COMM_MPI_BLOCKING, HALO_COMM_NCCL, HALO_COMM_NVSHMEM,
HALO_COMM_NVSHMEM_BLOCKING, TRANSPOSE_COMM_MPI_A2A, TRANSPOSE_COMM_MPI_P2P,
TRANSPOSE_COMM_MPI_P2P_PL, TRANSPOSE_COMM_NCCL, TRANSPOSE_COMM_NCCL_PL,
TRANSPOSE_COMM_NVSHMEM, TRANSPOSE_COMM_NVSHMEM_PL, HaloCommBackend,
TRANSPOSE_COMM_NVSHMEM, TRANSPOSE_COMM_NVSHMEM_PL, TRANSPOSE_XY,
TRANSPOSE_YX, TRANSPOSE_YZ, TRANSPOSE_ZY, HaloCommBackend,
TransposeCommBackend, finalize, get_autotuned_config, get_pencil_info,
halo_exchange, init, make_config, slice_pad, slice_unpad)
halo_exchange, init, make_config, slice_pad, slice_unpad, transposeXtoY,
transposeYtoX, transposeYtoZ, transposeZtoY)

try:
__version__ = version("jaxDecomp")
Expand All @@ -30,11 +32,10 @@
"slice_unpad",
"pfft3d",
"pifft3d",
# Transpose functions are still in development
# "transposeXtoY",
# "transposeYtoZ",
# "transposeZtoY",
# "transposeYtoX",
"transposeXtoY",
"transposeYtoZ",
"transposeZtoY",
"transposeYtoX",
]


Expand Down
3 changes: 2 additions & 1 deletion jaxdecomp/_src/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@
TRANSPOSE_COMM_MPI_A2A, TRANSPOSE_COMM_MPI_P2P,
TRANSPOSE_COMM_MPI_P2P_PL, TRANSPOSE_COMM_NCCL,
TRANSPOSE_COMM_NCCL_PL, TRANSPOSE_COMM_NVSHMEM,
TRANSPOSE_COMM_NVSHMEM_PL, HaloCommBackend,
TRANSPOSE_COMM_NVSHMEM_PL, TRANSPOSE_XY, TRANSPOSE_YX,
TRANSPOSE_YZ, TRANSPOSE_ZY, HaloCommBackend,
TransposeCommBackend)

# Registering ops for XLA
Expand Down
25 changes: 15 additions & 10 deletions jaxdecomp/_src/spmd_ops.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
from abc import ABCMeta, abstractmethod
from dataclasses import dataclass
from functools import partial
from typing import Tuple

import jax.numpy as jnp
from jax import jit, lax
from jax._src.api import ShapeDtypeStruct
from jax._src.core import ShapedArray
from jax._src.typing import Array, ArrayLike
from jax import core
from jax._src import dispatch
from jax._src.interpreters import batching
from jax.experimental.custom_partitioning import custom_partitioning
from jax.lax import dynamic_slice
from jax.sharding import Mesh, NamedSharding
from jax.sharding import PartitionSpec as P
from jax.interpreters import mlir, xla

# Inspired by https://github.com/NVIDIA/TransformerEngine/blob/main/transformer_engine/jax/cpp_extensions.py

Expand Down Expand Up @@ -169,3 +163,14 @@ def name_of_wrapper_p():
else:
raise ValueError(
"register_primitive only accepts BasePrimitive or CustomParPrimitive")


# helper functions


def get_axis_size(sharding, index):
  """Return the number of devices sharding dimension ``index``.

  Args:
    sharding: A ``NamedSharding``-like object exposing ``spec`` (a sequence of
      mesh-axis names or ``None``) and ``mesh.shape`` (a mapping from axis
      name to device count).
    index: Positional dimension of the array whose sharded extent is queried.

  Returns:
    The mesh size along the axis that shards dimension ``index``, or ``1``
    when that dimension is unsharded (its spec entry is ``None``).
  """
  axis_name = sharding.spec[index]
  if axis_name is None:  # `is None`, not `== None`; unsharded -> replicated
    return 1
  return sharding.mesh.shape[axis_name]
Loading

0 comments on commit 80172aa

Please sign in to comment.