Revert "[mta] APEX style Fused Adam (pytorch#81705)"
This reverts commit 7a6c4d0.

Reverted pytorch#81705 on behalf of https://github.com/dagitses because it broke internal builds; details to come.
pytorchmergebot authored and alvgaona committed Oct 11, 2022
1 parent b810784 commit 3288f03
Showing 13 changed files with 76 additions and 934 deletions.
19 changes: 0 additions & 19 deletions aten/src/ATen/native/cuda/ForeachFunctors.cuh
@@ -47,25 +47,6 @@ __device__ bool init_args(
  return all_aligned;
}

template<int depth, typename T>
__device__ bool init_args(
    T** args,
    FusedOptimizerTensorListMetadata<depth>& tl,
    int chunk_idx,
    int chunk_size,
    int tensor_loc) {
  bool all_aligned = true;
  for (int i = 0; i < depth; i++) {
    args[i] = (T*)tl.addresses[i][tensor_loc];
    args[i] += chunk_idx * chunk_size;

    if (!is_aligned(args[i])) {
      all_aligned = false;
    }
  }
  return all_aligned;
}

template<int depth, typename T>
__device__ void load_args(T r_args[][kILP], T** args, int i_start, int chunk_size, int n) {
#pragma unroll
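
The hunk above removes the init_args overload that takes a FusedOptimizerTensorListMetadata: it turns the flat per-tensor address table into chunk-offset base pointers and reports whether every pointer is aligned well enough for vectorized loads. For orientation, a minimal stand-alone sketch of the kind of alignment test this feeds — the kILP value and the helper name are assumptions, not taken from the diff:

// Sketch only: ILP-wide vectorized loads move kILP elements per transaction, which
// is safe only when a chunk's base pointer sits on a (sizeof(T) * kILP)-byte
// boundary (16 bytes for float with kILP == 4).
#include <cstdint>

constexpr int kILP = 4;  // assumed ILP width, matching the common foreach kernels

template <typename T>
bool aligned_for_vectorized_load(const T* p) {
  return reinterpret_cast<std::uintptr_t>(p) % (sizeof(T) * kILP) == 0;
}
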
37 changes: 0 additions & 37 deletions aten/src/ATen/native/cuda/FusedAdamKernel.cu

This file was deleted.
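
FusedAdamKernel.cu — by its name, the kernel entry point for the fused path — is removed outright and its contents are not reproduced in this view. For orientation only, the per-element update that any fused Adam kernel ultimately applies to each parameter chunk is the standard Adam step; a minimal host-side sketch (names and signature are illustrative, not the deleted file's code):

#include <cmath>

// Standard Adam update for a single element. A fused kernel applies this math to
// whole chunks of many parameters in one launch instead of one kernel per tensor.
void adam_step(float& param, float grad, float& exp_avg, float& exp_avg_sq,
               float lr, float beta1, float beta2, float eps, long step) {
  exp_avg = beta1 * exp_avg + (1.0f - beta1) * grad;
  exp_avg_sq = beta2 * exp_avg_sq + (1.0f - beta2) * grad * grad;
  const float bias_correction1 = 1.0f - std::pow(beta1, static_cast<float>(step));
  const float bias_correction2 = 1.0f - std::pow(beta2, static_cast<float>(step));
  const float denom = std::sqrt(exp_avg_sq / bias_correction2) + eps;
  param -= lr * (exp_avg / bias_correction1) / denom;
}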

68 changes: 0 additions & 68 deletions aten/src/ATen/native/cuda/MultiTensorApply.cuh
@@ -24,7 +24,6 @@ __device__ __forceinline__ void load_store(T* dst, T* src, int dst_offset, int src_offset)
  ((LT*)dst)[dst_offset] = ((LT*)src)[src_offset];
}

// TODO(crcrpar): Add `n>5` for `low prec params & their higher prec copy`
// TensorListMetadata has to be < 4KB - the limit for kernel launch argument
static constexpr int depth_to_max_tensors[5] = {110, 64, 48, 36, 30};
static constexpr int depth_to_max_blocks[5] = {320, 320, 320, 320, 320};
@@ -39,18 +38,6 @@ template<int n> struct TensorListMetadata
  int start_tensor_this_launch;
};

// NOTE(crcrpar): This is a conservative resolution to handle `state_steps`
// whose each element is `at::Tensor` of 1 element representing the number of `step`s called so far.
template<int n> struct FusedOptimizerTensorListMetadata
{
  void* addresses[n][depth_to_max_tensors[n-1]];
  int numel_for_tensor[depth_to_max_tensors[n-1]];
  void* state_steps_addresses[depth_to_max_tensors_scalarlist[n-1]];
  unsigned char block_to_tensor[depth_to_max_blocks[n-1]];
  int block_to_chunk[depth_to_max_blocks[n-1]];
  int start_tensor_this_launch;
};

template<typename scalar_vals_t, int n> struct TensorListScalarListMetadata
{
  void* addresses[n][depth_to_max_tensors_scalarlist[n-1]];
@@ -197,61 +184,6 @@ void multi_tensor_apply(
      }
    }
  }
}

template<int depth, typename T, typename... ArgTypes>
void multi_tensor_apply_for_fused_optimizer(
    std::vector<std::vector<at::Tensor>>& tensor_lists,
    at::TensorList state_steps,
    T callable,
    ArgTypes... args) {
  TORCH_CHECK(tensor_lists.size() == depth, "Number of tensor lists has to match the depth");
  const auto num_tensors = tensor_lists[0].size();
  FusedOptimizerTensorListMetadata<depth> tensorListMeta;

  int loc_block_info = 0;
  int loc_tensor_info = 0;
  for (const auto & tensor_index : c10::irange(num_tensors)) {
    tensorListMeta.state_steps_addresses[loc_tensor_info] = state_steps[tensor_index].data_ptr();
    tensorListMeta.numel_for_tensor[loc_tensor_info] = tensor_lists[0][tensor_index].numel();
    for (const auto & d : c10::irange(depth)) {
      tensorListMeta.addresses[d][loc_tensor_info] = tensor_lists[d][tensor_index].data_ptr();
    }
    loc_tensor_info++;

    const auto chunks = (tensor_lists[0][tensor_index].numel() + kChunkSize - 1) / kChunkSize;
    for (const auto & chunk : c10::irange(chunks)) {
      tensorListMeta.block_to_tensor[loc_block_info] = loc_tensor_info - 1;
      tensorListMeta.block_to_chunk[loc_block_info] = chunk;
      loc_block_info++;

      const auto tensor_full = (loc_tensor_info == depth_to_max_tensors[depth - 1] && chunk == chunks - 1);
      const auto blocks_full = loc_block_info == depth_to_max_blocks[depth - 1];
      const auto last_chunk = (tensor_index == num_tensors - 1 && chunk == chunks - 1);

      if (tensor_full || blocks_full || last_chunk) {
        multi_tensor_apply_kernel<<<loc_block_info, kBlockSize, 0, at::cuda::getCurrentCUDAStream()>>>(
            tensorListMeta,
            callable,
            args...);
        C10_CUDA_KERNEL_LAUNCH_CHECK();

        // Reset.
        loc_block_info = 0;
        if (chunk == chunks - 1) {
          loc_tensor_info = 0;
        } else {
          tensorListMeta.numel_for_tensor[0] = tensorListMeta.numel_for_tensor[loc_tensor_info - 1];
          tensorListMeta.state_steps_addresses[0] = tensorListMeta.state_steps_addresses[loc_tensor_info - 1];
          for (const auto & d : c10::irange(depth)) {
            tensorListMeta.addresses[d][0] = tensorListMeta.addresses[d][loc_tensor_info - 1];
          }
          loc_tensor_info = 1;
        }
      }
    }
  }
}

} // namespace
}} // at::native
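
The comment kept above notes that the per-launch metadata has to stay under 4KB, the limit for kernel launch arguments, which is why depth_to_max_tensors shrinks as the depth grows. A back-of-the-envelope check of that budget for depth 2, assuming a 64-bit build (this check is illustrative and not part of the diff):

// Rough payload of a TensorListMetadata<2>-style struct:
//   addresses: 2 * 64 pointers           = 1024 bytes
//   numel_for_tensor: 64 ints            =  256 bytes
//   block_to_tensor: 320 bytes           =  320 bytes
//   block_to_chunk: 320 ints             = 1280 bytes
//   start_tensor_this_launch: 1 int      =    4 bytes
// Total ≈ 2.9KB, comfortably under the 4KB kernel-argument limit.
static_assert(
    2 * 64 * sizeof(void*) + 64 * sizeof(int) + 320 * sizeof(unsigned char) +
        320 * sizeof(int) + sizeof(int) < 4096,
    "per-launch metadata must fit in the 4KB kernel-argument limit");
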
52 changes: 0 additions & 52 deletions aten/src/ATen/native/cuda/fused_adam_amsgrad_impl.cu

This file was deleted.

24 changes: 0 additions & 24 deletions aten/src/ATen/native/cuda/fused_adam_amsgrad_impl.cuh

This file was deleted.

51 changes: 0 additions & 51 deletions aten/src/ATen/native/cuda/fused_adam_impl.cu

This file was deleted.
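
fused_adam_impl.cu — presumably the layer that grouped params, grads, exp_avgs and exp_avg_sqs into depth-4 tensor lists and handed them, with the per-tensor step counters, to multi_tensor_apply_for_fused_optimizer — is deleted without its contents being shown. The scheduling idea in that helper is simple to illustrate: every tensor is split into kChunkSize-element chunks, and each launched block is assigned one (tensor, chunk) pair via block_to_tensor and block_to_chunk. A self-contained sketch of that mapping (the kChunkSize value is assumed and the tensor sizes are arbitrary examples):

#include <cstddef>
#include <cstdio>
#include <vector>

int main() {
  constexpr int kChunkSize = 65536;  // assumed; the real constant lives in MultiTensorApply.cuh
  const std::vector<long> numels = {100000, 20000, 70000};  // example tensor sizes

  int block = 0;
  for (std::size_t t = 0; t < numels.size(); ++t) {
    const long chunks = (numels[t] + kChunkSize - 1) / kChunkSize;  // ceiling division
    for (long c = 0; c < chunks; ++c) {
      // Mirrors block_to_tensor / block_to_chunk: launched block `block`
      // processes chunk `c` of tensor `t`.
      std::printf("block %d -> tensor %zu, chunk %ld\n", block++, t, c);
    }
  }
  return 0;
}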

23 changes: 0 additions & 23 deletions aten/src/ATen/native/cuda/fused_adam_impl.cuh

This file was deleted.
