[CI] Add clang-format script/action (hidet-org#120)
With https://github.com/CentML/hidet/pull/69 there will be a lot more
C++ code introduced into the runtime, so I think it's a good idea to have
some standardization. For now this only does formatting (no linting, which
takes more work to set up and has more opinions about right vs. wrong).

Summary of changes:
- Update `format.sh` to support formatting just Python, just C++, or both (a rough sketch of the intended usage follows this list)
- Add `clang-format` to the existing lint/format workflow
- Apply `clang-format` changes to existing code; I've set up the
configuration to try to minimize the number of changes and have excluded
the float16/bfloat16 code
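
For reference, here is a minimal sketch of what the updated `format.sh` dispatch could look like. The mode argument, tool invocations, and paths are assumptions for illustration; the actual script is not shown in this diff:

```bash
#!/bin/bash
# Hypothetical sketch of format.sh choosing between Python and C++ formatting.
# The mode argument, tools, and paths are assumptions, not the real interface.
set -e

MODE="${1:-all}"  # one of: python, cpp, all

if [[ "$MODE" == "python" || "$MODE" == "all" ]]; then
    # Format the Python sources (the CI workflow checks them with black).
    python -m black ./python/hidet
fi

if [[ "$MODE" == "cpp" || "$MODE" == "all" ]]; then
    # Format the C++ sources in place with the repo's clang-format config.
    find ./src ./include -iname '*.h' -o -iname '*.cpp' \
        | xargs clang-format -style=file:scripts/lint/.clang-format -i
fi
```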

Example workflow failure @ 4cc430c:
<img width="1155" alt="image"
src="https://github.com/CentML/hidet/assets/43303581/9566e9dd-bd01-4638-b556-11afaf7e6e52">
jacklee1792 committed Apr 5, 2024
1 parent 52cb6b7 commit c45ef79
Showing 26 changed files with 330 additions and 348 deletions.
7 changes: 7 additions & 0 deletions .github/workflows/lint.yaml
@@ -27,6 +27,8 @@ jobs:
pip install torch torchvision torchaudio
pip install -r requirements.txt
pip install -r requirements-dev.txt
sudo apt-get update
sudo apt-get install clang-format
- name: Format with black
run: |
# stop the build if format is not correct
@@ -36,3 +38,8 @@ jobs:
run: |
echo "Running with" $(pip freeze | grep "pylint")
python -m pylint --rcfile ./scripts/lint/pylintrc -j $(nproc) ./python/hidet
- name: Format with clang-format
run: |
echo "Running with" $(clang-format --version)
find ./src ./include -iname '*.h' -o -iname '*.cpp' \
| xargs clang-format -style=file:scripts/lint/.clang-format --dry-run -Werror
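
The check above can also be run locally before pushing. This mirrors the CI step, assuming it is run from the repository root so that `./src` and `./include` resolve; to apply the fixes instead of just reporting them, swap `--dry-run -Werror` for `-i`:

```bash
# Run the same clang-format check locally: --dry-run reports violations without
# modifying files, and -Werror turns those warnings into a non-zero exit code.
find ./src ./include -iname '*.h' -o -iname '*.cpp' \
    | xargs clang-format -style=file:scripts/lint/.clang-format --dry-run -Werror
```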
4 changes: 2 additions & 2 deletions include/hidet/runtime.h
@@ -11,8 +11,8 @@
// limitations under the License.
#pragma once

#include <hidet/runtime/common.h>
#include <hidet/runtime/callbacks.h>
#include <hidet/runtime/cuda/context.h>
#include <hidet/runtime/common.h>
#include <hidet/runtime/cpu/context.h>
#include <hidet/runtime/cuda/context.h>
#include <hidet/runtime/logging.h>
2 changes: 1 addition & 1 deletion include/hidet/runtime/callbacks.h
@@ -12,7 +12,7 @@
#include <cstdint>
#include <hidet/runtime/common.h>

DLL void register_callback(const char* name, void *func_ptr);
DLL void register_callback(const char *name, void *func_ptr);

DLL uint64_t allocate_cuda_storage(uint64_t nbytes);

3 changes: 1 addition & 2 deletions include/hidet/runtime/common.h
@@ -10,10 +10,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <iostream>
#include <cassert>
#include <iostream>

#ifndef DLL
#define DLL extern "C" __attribute__((visibility("default")))
#endif

4 changes: 2 additions & 2 deletions include/hidet/runtime/context.h
@@ -11,11 +11,11 @@
// limitations under the License.
#pragma once

#include <hidet/runtime/common.h>
#include <hidet/runtime/callbacks.h>
#include <hidet/runtime/common.h>

struct Workspace {
void* base;
void *base;
size_t allocated_nbytes;
Workspace() {
base = nullptr;
3 changes: 3 additions & 0 deletions include/hidet/runtime/cpu/bfloat16.h
@@ -9,6 +9,9 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// clang-format off

/**
* From PyTorch:
6 changes: 2 additions & 4 deletions include/hidet/runtime/cpu/context.h
@@ -14,12 +14,10 @@
#include <hidet/runtime/context.h>

struct CpuContext: BaseContext {
static CpuContext* global();
static CpuContext *global();
};


/**
* Request a workspace.
*/
DLL void* request_cpu_workspace(size_t nbytes, bool require_clean);

DLL void *request_cpu_workspace(size_t nbytes, bool require_clean);
3 changes: 3 additions & 0 deletions include/hidet/runtime/cpu/float16.h
@@ -9,6 +9,9 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// clang-format off

/**
* From PyTorch:
4 changes: 1 addition & 3 deletions include/hidet/runtime/cpu/float32.h
@@ -11,8 +11,6 @@
// limitations under the License.
#include <math.h>

static inline float rsqrtf(float x)
{
static inline float rsqrtf(float x) {
return 1.0f / sqrtf(x);
}

11 changes: 4 additions & 7 deletions include/hidet/runtime/cuda/complex.h
@@ -13,7 +13,7 @@
#include <cuComplex.h>
#define HIDET_HOST_DEVICE __host__ __device__ __forceinline__

template <typename T>
template<typename T>
struct Complex {
T real, imag;
Complex() = default;
@@ -22,14 +22,12 @@ struct Complex {
};

template<typename T>
HIDET_HOST_DEVICE
Complex<T> operator-(Complex<T> a) {
HIDET_HOST_DEVICE Complex<T> operator-(Complex<T> a) {
return {-a.real, -a.imag};
}

template<typename T>
HIDET_HOST_DEVICE
Complex<T> operator+(Complex<T> a, Complex<T> b) {
HIDET_HOST_DEVICE Complex<T> operator+(Complex<T> a, Complex<T> b) {
return {a.real + b.real, a.imag + b.imag};
}

@@ -40,8 +38,7 @@ HIDET_HOST_DEVICE Complex<T> operator-(Complex<T> a, Complex<T> b) {

template<typename T>
HIDET_HOST_DEVICE Complex<T> operator*(Complex<T> a, Complex<T> b) {
return {a.real * b.real - a.imag * b.imag,
a.real * b.imag + a.imag * b.real};
return {a.real * b.real - a.imag * b.imag, a.real * b.imag + a.imag * b.real};
}

template<typename T>
18 changes: 9 additions & 9 deletions include/hidet/runtime/cuda/context.h
@@ -10,47 +10,47 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <hidet/runtime/common.h>
#include <hidet/runtime/callbacks.h>
#include <hidet/runtime/common.h>
#include <hidet/runtime/context.h>
// #include <cuda_runtime.h>

struct CudaContext: BaseContext {
/* The cuda stream the kernels will be launched on. */
void* stream = nullptr;
void *stream = nullptr;

/* NCCL Comunicators*/
void ** nccl_comms = nullptr;
void **nccl_comms = nullptr;

int num_comms = 0;

/**
* Get the instance of cuda context.
*/
static CudaContext* global();
static CudaContext *global();
};

/**
* Set the cuda stream of cuda context.
*/
DLL void set_cuda_stream(void* stream);
DLL void set_cuda_stream(void *stream);

/**
* Get the cuda stream of cuda context.
*/
DLL void* get_cuda_stream();
DLL void *get_cuda_stream();

/**
* Request a workspace.
*/
DLL void* request_cuda_workspace(size_t nbytes, bool require_clean);
DLL void *request_cuda_workspace(size_t nbytes, bool require_clean);

/**
* Set required NCCL communicators of the context.
*/
DLL void set_nccl_comms(int num_comms, void** comm);
DLL void set_nccl_comms(int num_comms, void **comm);

/**
* Get the NCCL communicator by the index
*/
DLL void* get_nccl_comm(int idx);
DLL void *get_nccl_comm(int idx);
28 changes: 11 additions & 17 deletions include/hidet/runtime/cuda/cublas.h
@@ -15,29 +15,23 @@

#define HIDET_CUBLAS_MAX_GPUS 32

typedef void* cublasHandle_t;
typedef void *cublasHandle_t;

struct CublasContext {
cublasHandle_t handles[HIDET_CUBLAS_MAX_GPUS]; // cublas handle for each gpu on this node
static CublasContext* global();
cublasHandle_t handles[HIDET_CUBLAS_MAX_GPUS]; // cublas handle for each gpu on this node
static CublasContext *global();
static cublasHandle_t current_handle();
};

DLL void hidet_cublas_set_library_path(const char* path);
DLL void hidet_cublas_set_library_path(const char *path);

// kernel functions
DLL void hidet_cublas_gemm(
int m, int n, int k, int ta, int tb, int tc, void *ptr_a, void* ptr_b, void* ptr_c, bool trans_a, bool trans_b,
int compute_type
);
DLL void hidet_cublas_gemm(int m, int n, int k, int ta, int tb, int tc, void *ptr_a, void *ptr_b, void *ptr_c,
bool trans_a, bool trans_b, int compute_type);

DLL void hidet_cublas_strided_gemm(
int b, int m, int n, int k, int ta, int tb, int tc, void *ptr_a, void* ptr_b, void* ptr_c,
int64_t sa, int64_t sb, int64_t sc,
bool trans_a, bool trans_b, int compute_type
);
DLL void hidet_cublas_strided_gemm(int b, int m, int n, int k, int ta, int tb, int tc, void *ptr_a, void *ptr_b,
void *ptr_c, int64_t sa, int64_t sb, int64_t sc, bool trans_a, bool trans_b,
int compute_type);

DLL void hidet_cublas_batched_gemm(
int b, int m, int n, int k, int ta, int tb, int tc, void **ptr_a, void **ptr_b, void **ptr_c,
bool trans_a, bool trans_b, int compute_type
);
DLL void hidet_cublas_batched_gemm(int b, int m, int n, int k, int ta, int tb, int tc, void **ptr_a, void **ptr_b,
void **ptr_c, bool trans_a, bool trans_b, int compute_type);
11 changes: 5 additions & 6 deletions include/hidet/runtime/cuda/cuda.h
@@ -20,15 +20,14 @@ typedef enum {
cudaMemcpyDefault = 4
} cudaMemcpyKind;

typedef void* cudaStream_t;
typedef void *cudaStream_t;

DLL int hidet_cuda_device_count();
DLL int hidet_cuda_get_device();
DLL void hidet_cuda_set_device(int device);
DLL void* hidet_cuda_malloc(size_t size);
DLL void* hidet_cuda_malloc_async(size_t size, cudaStream_t stream);
DLL void *hidet_cuda_malloc(size_t size);
DLL void *hidet_cuda_malloc_async(size_t size, cudaStream_t stream);
DLL void hidet_cuda_free(void *devPtr);
DLL void hidet_cuda_free_async(void *devPtr, cudaStream_t stream);
DLL void hidet_cuda_memcpy(void* dst, const void* src, size_t count, cudaMemcpyKind kind);
DLL void hidet_cuda_memcpy_async(void* dst, const void* src, size_t count, cudaMemcpyKind kind, cudaStream_t stream);

DLL void hidet_cuda_memcpy(void *dst, const void *src, size_t count, cudaMemcpyKind kind);
DLL void hidet_cuda_memcpy_async(void *dst, const void *src, size_t count, cudaMemcpyKind kind, cudaStream_t stream);
34 changes: 12 additions & 22 deletions include/hidet/runtime/logging.h
@@ -26,26 +26,22 @@ struct ErrorState {
struct HidetException: std::exception {
std::string msg;

HidetException(std::string msg): msg(msg){}
HidetException(std::string msg) : msg(msg) {}

const char * what() const noexcept override {
const char *what() const noexcept override {
static std::string what_msg;
what_msg = this->msg;
return what_msg.c_str();
}
};


class FATALMessage {
std::ostringstream stream_;
public:
FATALMessage(const char* file, int line) {
this->stream_ << file << ":" << line << ": ";
}

std::ostringstream &stream() {
return this->stream_;
}
public:
FATALMessage(const char *file, int line) { this->stream_ << file << ":" << line << ": "; }

std::ostringstream &stream() { return this->stream_; }

[[noreturn]] ~FATALMessage() {
std::cerr << this->stream_.str() << std::endl;
@@ -55,21 +51,15 @@

DLL void hidet_set_last_error(const char *msg);

DLL const char * hidet_get_last_error();
DLL const char *hidet_get_last_error();

class ERRORMessage {
std::ostringstream stream_;
public:
ERRORMessage(const char* file, int line) {
this->stream_ << file << ":" << line << ": ";
}

std::ostringstream &stream() {
return this->stream_;
}
public:
ERRORMessage(const char *file, int line) { this->stream_ << file << ":" << line << ": "; }

~ERRORMessage() noexcept(false) {
throw HidetException(this->stream_.str().c_str());
}
};
std::ostringstream &stream() { return this->stream_; }

~ERRORMessage() noexcept(false) { throw HidetException(this->stream_.str().c_str()); }
};
