4 changes: 4 additions & 0 deletions .gitignore
@@ -0,0 +1,4 @@
*~
*.o
build/
*.pyc
34 changes: 23 additions & 11 deletions CMakeLists.txt
@@ -1,4 +1,4 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -39,6 +39,10 @@ list(APPEND CMAKE_MODULE_PATH ${CUDA_PATH}/lib64)
find_package(CUDA REQUIRED)

# setting compiler flags
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -Wall")

if (SM STREQUAL 70 OR
SM STREQUAL 75 OR
SM STREQUAL 61 OR
@@ -49,15 +53,21 @@ set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_${SM},code=\\\"s
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DWMMA")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DWMMA")
endif()

set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -Wall")
message("-- Assign GPU architecture (sm=${SM})")

else()
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -gencode=arch=compute_60,code=\\\"sm_60,compute_60\\\" -rdc=true")
message("-- Unknown or unsupported GPU architecture (set sm=60)")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} \
-gencode=arch=compute_60,code=\\\"sm_60,compute_60\\\" \
-gencode=arch=compute_61,code=\\\"sm_61,compute_61\\\" \
-gencode=arch=compute_70,code=\\\"sm_70,compute_70\\\" \
-gencode=arch=compute_75,code=\\\"sm_75,compute_75\\\" \
-rdc=true")
set(CMAKE_C_FLAGS "${CMAKE_C_FLAGS} -DWMMA")
set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DWMMA")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DWMMA")
message("-- Assign GPU architecture (sm=60,61,70,75)")
endif()

set(CMAKE_C_FLAGS_DEBUG "${CMAKE_C_FLAGS_DEBUG} -Wall -O0")
set(CMAKE_CXX_FLAGS_DEBUG "${CMAKE_CXX_FLAGS_DEBUG} -Wall -O0")
set(CMAKE_CUDA_FLAGS_DEBUG "${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall")
@@ -72,8 +82,8 @@ if(CMAKE_CXX_STANDARD STREQUAL "11")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} --std=c++11")
endif()

set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -O3")
set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler -O3")
set(CMAKE_CXX_FLAGS_RELEASE "${CMAKE_CXX_FLAGS_RELEASE} -O3")
set(CMAKE_CUDA_FLAGS_RELEASE "${CMAKE_CUDA_FLAGS_RELEASE} -Xcompiler -O3")

set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
set(CMAKE_LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib)
@@ -110,12 +120,14 @@ add_subdirectory(tools)
add_subdirectory(fastertransformer)
add_subdirectory(sample)


if(BUILD_TF)
add_custom_target(copy ALL COMMENT "Copying tensorflow test scripts")
add_custom_command(TARGET copy
POST_BUILD
COMMAND cp ${PROJECT_SOURCE_DIR}/sample/tensorflow/*.py ${PROJECT_SOURCE_DIR}/build/
)
COMMAND cp ${PROJECT_SOURCE_DIR}/sample/tensorflow/utils ${PROJECT_SOURCE_DIR}/build/ -r
COMMAND cp ${PROJECT_SOURCE_DIR}/sample/tensorflow/scripts ${PROJECT_SOURCE_DIR}/build/ -r
COMMAND cp ${PROJECT_SOURCE_DIR}/sample/tensorflow_bert ${PROJECT_SOURCE_DIR}/build/ -r
)
endif()

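Taken together, the CMakeLists.txt changes above move -O3 from the global C++/CUDA flags into the Release-only flags and widen the fallback path: a supported SM value still selects a single architecture at configure time (for example, something like cmake -DSM=70 -DCMAKE_BUILD_TYPE=Release .., the exact invocation depending on the project's build instructions), while an unset or unrecognized SM now generates code for sm_60, sm_61, sm_70 and sm_75 with WMMA enabled instead of sm_60 alone. The TensorFlow copy step also picks up the utils, scripts and tensorflow_bert directories in addition to the top-level Python scripts.
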
943 changes: 867 additions & 76 deletions README.md

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion fastertransformer/CMakeLists.txt
@@ -1,4 +1,4 @@
# Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
# Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
76 changes: 49 additions & 27 deletions fastertransformer/allocator.h
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019, NVIDIA CORPORATION. All rights reserved.
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
@@ -35,34 +35,39 @@
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#endif

namespace fastertransformer{
namespace fastertransformer
{


class IAllocator{
public:
virtual void* malloc(size_t size) const = 0;
virtual void free(void* ptr) const = 0;
class IAllocator
{
public:
virtual void *malloc(size_t size) const = 0;
virtual void free(void *ptr) const = 0;
};

template<AllocatorType AllocType_>
template <AllocatorType AllocType_>
class Allocator;

template<>
class Allocator<AllocatorType::CUDA> : public IAllocator{
template <>
class Allocator<AllocatorType::CUDA> : public IAllocator
{
const int device_id_;
public:
Allocator(int device_id): device_id_(device_id){}

void* malloc(size_t size) const {
void* ptr = nullptr;
public:
Allocator(int device_id) : device_id_(device_id) {}

void *malloc(size_t size) const
{
void *ptr = nullptr;
int o_device = 0;
check_cuda_error(get_set_device(device_id_, &o_device));
check_cuda_error(cudaMalloc(&ptr, size));
check_cuda_error(get_set_device(o_device));
return ptr;
}

void free(void* ptr) const {

void free(void *ptr) const
{
int o_device = 0;
check_cuda_error(get_set_device(device_id_, &o_device));
check_cuda_error(cudaFree(ptr));
@@ -71,34 +76,51 @@ class Allocator<AllocatorType::CUDA> : public IAllocator{
}
};


//TODO: allocator of TensorFlow
// You can add context to constructor
#ifdef GOOGLE_CUDA
using namespace tensorflow;
template<>
class Allocator<AllocatorType::TF> : public IAllocator{
template <>
class Allocator<AllocatorType::TF> : public IAllocator
{
OpKernelContext *context_;
public:
Allocator(OpKernelContext *context): context_(context){}
std::vector<Tensor> *allocated_tensor_vector;

public:
Allocator(OpKernelContext *context) : context_(context)
{
allocated_tensor_vector = new std::vector<Tensor>;
}

void* malloc(size_t size) const {
void *malloc(size_t size) const
{
Tensor buf;
long long int buf_size = (long long int)size;
tensorflow::Status status = context_->allocate_temp(DT_UINT8, TensorShape{buf_size}, &buf);
allocated_tensor_vector->push_back(buf);

if(status != tensorflow::Status::OK())
if (status != tensorflow::Status::OK())
throw std::runtime_error("TF error: context->allocate_temp failed");

auto flat = buf.flat<uint8>();
void* ptr = (void*)flat.data();
void *ptr = (void *)flat.data();
cudaMemset(ptr, 0, buf_size);
return ptr;
}

void free(void* ptr) const {

void free(void *ptr) const
{
#ifndef NDEBUG
printf("call from allocator free\n");
#endif
return;
}

~Allocator()
{
allocated_tensor_vector->clear();
delete allocated_tensor_vector;
}
};
#endif
}//namespace fastertransformer
} //namespace fastertransformer
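A minimal usage sketch for the CUDA specialization above (not taken from this PR; it assumes a CUDA-capable device 0 and the include layout used elsewhere in the diff): malloc temporarily switches to the allocator's device, calls cudaMalloc, and restores the previously active device, and free does the same around cudaFree.

#include <cuda_runtime.h>
#include "fastertransformer/allocator.h"

using namespace fastertransformer;

void allocator_example()
{
    // Bind the allocator to device 0 (a hypothetical choice for this sketch).
    Allocator<AllocatorType::CUDA> allocator(0);

    // Request a 1024-float scratch buffer; cudaMalloc runs with device 0 active
    // and the previously active device is restored afterwards.
    float *buf = reinterpret_cast<float *>(allocator.malloc(1024 * sizeof(float)));

    // ... launch kernels that read and write buf ...

    // Release the buffer; cudaFree is likewise issued with device 0 active.
    allocator.free(buf);
}
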
121 changes: 121 additions & 0 deletions fastertransformer/beamsearch_opennmt.h
@@ -0,0 +1,121 @@
/*
* Copyright (c) 2020, NVIDIA CORPORATION. All rights reserved.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

/**
* BeamSearch OpenNMT
**/

#pragma once

#include <cuda_runtime.h>
#include "fastertransformer/allocator.h"
#include "fastertransformer/cuda/cuda_kernels.h"
#include "fastertransformer/cuda/open_attention.h"
#include "fastertransformer/cuda/decoding_kernel_check.h"

namespace fastertransformer
{

template <typename T>
void BeamSearch_OpenNMT(
float *log_probs, float *cum_log_probs, bool *finished,
T **key_cache, T **value_cache,
int *parent_ids,
int *sequence_length,
int *word_ids,
int *ids,
int *output_ids,
const int batch_size, const int beam_width,
const int vocab_size, const int hidden_dim, const int step,
const int cache_size, const int decoder_layers, cudaStream_t stream,
const int end_id,
int *finished_count)
{
#ifdef NDEBUG
/* adding cum_log_probs to log_probs */
broadcast_kernelLauncher(log_probs, cum_log_probs, batch_size, beam_width, vocab_size, stream);
#else
broadcast_kernelLauncher(log_probs, cum_log_probs, batch_size, beam_width, vocab_size, stream);
cudaDeviceSynchronize();
check_cuda_error(cudaGetLastError());

/*
Users can check the broadcast_kernel by broadcast_kernel_check.
broadcast_kernel_check will compare the results of GPU and CPU.
Note that broadcast_kernel_check contains broadcast_kernelLauncher and users do not need to call it again.
*/
// broadcast_kernel_check(log_probs, cum_log_probs, batch_size, beam_width, vocab_size, stream);
#endif

#ifdef NDEBUG
/* Use two rounds of kernels to pick the topK values for each batch */
topK(log_probs, ids, batch_size, beam_width, vocab_size, stream);
#else
topK(log_probs, ids, batch_size, beam_width, vocab_size, stream);
cudaDeviceSynchronize();
check_cuda_error(cudaGetLastError());

/*
Users can check the topK by topK_check.
topK_check will compare the results of GPU and CPU.
Note that topK_check contains topK and users do not need to call it again.
*/
// topK_kernel_check(log_probs, ids, batch_size, beam_width, vocab_size, stream);
#endif

#ifdef NDEBUG
update(log_probs, cum_log_probs, ids, finished,
parent_ids, sequence_length, word_ids, output_ids,
batch_size, beam_width, vocab_size, stream,
end_id, finished_count);
#else
update(log_probs, cum_log_probs, ids, finished,
parent_ids, sequence_length, word_ids, output_ids,
batch_size, beam_width, vocab_size, stream,
end_id, finished_count);
cudaDeviceSynchronize();
check_cuda_error(cudaGetLastError());

/*
Users can check the update by update_kernel_check.
update_kernel_check will compare the results of GPU and CPU.
Note that update_kernel_check contains update and users do not need to call it again.
*/
// update_kernel_check(log_probs, cum_log_probs, ids, finished, parent_ids, sequence_length, word_ids, output_ids,
// batch_size, beam_width, vocab_size, stream, end_id, finished_count);
#endif

#ifdef NDEBUG
update_KV_cache<T>(key_cache, value_cache, parent_ids, batch_size,
beam_width, hidden_dim, step, cache_size,
decoder_layers, stream);
#else
update_KV_cache<T>(key_cache, value_cache, parent_ids, batch_size,
beam_width, hidden_dim, step, cache_size,
decoder_layers, stream);
cudaDeviceSynchronize();
check_cuda_error(cudaGetLastError());

/*
Users can check the update_KV_cache by update_KV_cache_kernel_check.
update_KV_cache_kernel_check will compare the results of GPU and CPU.
Note that update_KV_cache_kernel_check contains update_KV_cache and users do not need to call it again.
*/
// update_KV_cache_kernel_check(key_cache, value_cache, parent_ids, batch_size, beam_width, hidden_dim, step, cache_size, decoder_layers, stream);
#endif
}

} // namespace fastertransformer
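
Each of the four stages above (broadcast, topK, update, update_KV_cache) follows the same scaffolding: in a release build (NDEBUG defined) only the kernel launch runs, while a debug build synchronizes and checks for errors after every launch, with an optional *_check variant that reruns the work and compares the GPU result against a CPU reference. A condensed, self-contained sketch of that pattern, using a stand-in kernel and an inlined error check rather than the project's launchers and check_cuda_error helper:

#include <cstdio>
#include <cuda_runtime.h>

__global__ void scale_kernel(float *data, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        data[i] *= 2.0f;
}

void launch_stage(float *d_data, int n, cudaStream_t stream)
{
    // Release and debug builds both launch the kernel asynchronously.
    scale_kernel<<<(n + 255) / 256, 256, 0, stream>>>(d_data, n);
#ifndef NDEBUG
    // Debug builds block until the kernel finishes and surface launch errors
    // immediately, at the cost of serializing the stream.
    cudaDeviceSynchronize();
    cudaError_t err = cudaGetLastError();
    if (err != cudaSuccess)
        printf("CUDA error: %s\n", cudaGetErrorString(err));
#endif
}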