
Enable both Pytorch native AMP and Nvidia APEX AMP for SRU #158

Conversation

visionscaper

Hi!

I was happily using SRUs with PyTorch native AMP, but when I started experimenting with training using Microsoft DeepSpeed I bumped into an issue.

Basically, the issue is that FP16 training using DeepSpeed doesn't work for either GRUs or SRUs. However, when using Nvidia APEX AMP, DeepSpeed training with GRUs does work.

So, based on the tips in one of your issues, I started looking into how I could enable both PyTorch native AMP and Nvidia APEX AMP for SRUs, so that I could train models based on SRUs using DeepSpeed.

That is why I created this pull request. Basically, I found that by making the code simpler, I can make SRUs work with both methods of AMP.

Now amp_recurrence_fp16 can be used for both types of AMP. When amp_recurrence_fp16=True, the tensors are cast to float16; otherwise nothing special happens. So, I also removed the torch.cuda.amp.autocast(enabled=False) region; I might be wrong, but it seems that we don't need it.
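
For concreteness, a minimal sketch of that idea (a hypothetical helper of my own, not the exact diff in this PR): cast everything to float16 only when amp_recurrence_fp16=True, pass the tensors through untouched otherwise, and drop the autocast(enabled=False) region entirely.

    from typing import List, Optional

    import torch
    from torch import Tensor

    def maybe_cast_for_recurrence(tensors: List[Optional[Tensor]],
                                  amp_recurrence_fp16: bool) -> List[Optional[Tensor]]:
        # Hypothetical helper: with amp_recurrence_fp16=True all tensors are cast to
        # float16 (a no-op for tensors that are already fp16); otherwise they are
        # returned unchanged, with no autocast(enabled=False) region involved.
        if not amp_recurrence_fp16:
            return tensors
        return [t.half() if t is not None else None for t in tensors]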

I did some tests with my own code and it works in the different scenarios of interest:

  • Using PyTorch native AMP, not using DeepSpeed
  • Not using PyTorch native AMP, not using DeepSpeed
  • Using Nvidia APEX AMP, using DeepSpeed
  • Not using Nvidia APEX AMP, using DeepSpeed

It would be beneficial if we could verify this with an official test in the SRU repo, maybe by repurposing language_model/train_lm.py?

@taolei87
Contributor

cc @hpasapp

Hi @visionscaper , thank you so much for contributing to the repo.

We set the autocast(enabled=False) block to keep the recurrence kernel running in fp32 even when AMP is used. In other words, if amp_recurrence_fp16=False, the recurrence kernel always uses fp32, whether or not AMP is used; if amp_recurrence_fp16=True, the recurrence kernel uses fp16 when AMP is enabled. We keep the recurrence kernel in fp32 for better precision, although we haven't compared whether this makes a difference in real cases.

Does the autocast() block make SRU incompatible with APEX AMP?

@hpasapp
Collaborator

hpasapp commented Jan 29, 2021

@visionscaper

  • Your new if block is interesting, but I believe the conditional is reversed? That is, when AMP is enabled, previous operations will (I think) have already converted all tensors to float16, and what we want is to convert the tensors to float32 during the recurrence, in case the extra precision improves performance somewhat.
    • You may wonder why we cast to float16 when amp_recurrence_fp16=True rather than just leaving the tensors as float16. The answer is that if they're already float16, the casts become no-ops, which doesn't affect performance, whereas if some of them aren't, omitting the cast would cause a runtime error. Always doing the cast decreases coupling with the operations creating the incoming tensors, so that if some upstream operation produces a float32 output, now or in the future, this won't cause a runtime error. It also means the recurrence is easy to reason about, because we know for sure which tensor types go into it, without having to check empirically or look up the AMP types of all upstream operations.
  • I'm surprised that the getattr(torch, 'is_autocast_enabled', lambda: False)() check is evaluating to True. Am I right in guessing that this attribute exists in APEX, not just in AMP, and it's just the more specific torch.cuda.amp.autocast(...) that doesn't exist in APEX? Is there something in APEX that is equivalent to torch.cuda.amp.autocast(...)? (A quick check of the flag is sketched below.)
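
A quick way to inspect that flag with PyTorch-native autocast (as noted further down in this thread, APEX AMP does not set it):

    import torch

    # torch.is_autocast_enabled() reflects PyTorch-native autocast state only.
    print(torch.is_autocast_enabled())        # False outside any autocast region
    with torch.cuda.amp.autocast():
        print(torch.is_autocast_enabled())    # True inside a native autocast region
    # APEX AMP patches individual functions instead of setting this flag, so it
    # stays False when only APEX AMP is active (as confirmed later in this thread).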

@visionscaper
Author

Hi @taolei87, @hpasapp,

Thanks for your replies! I'm still trying to figure out what precisely is going on and what the reasoning is behind some choices.

  • If you want the recurrence kernel to use FP32 no matter what, why allow FP16 inputs at all? Do you mean you want to keep the weights in FP32?

  • I did some debugging tonight and found the following (this applies both to native PyTorch AMP and to APEX AMP):

Inside elementwise_recurrence_gpu, before any casting, the types of the input arguments are as follows:

Type of U : torch.float16
Type of x : torch.float32
Type of weight_c : torch.float32
Type of bias : torch.float32
Type of c_init : torch.float32
Type of scale_x : None
Type of dropout_mask_c : torch.float32

In this case U is torch.float16 because it is calculated as U = x.mm(self.weight) in sru.modules.SRUCell.compute_U, where x = input. The input is torch.float32 in my case because it is a token embedding, which is always computed in FP32. Consequently, since we are using AMP (either version), U will be torch.float16.
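
A minimal repro of that dtype mix under native autocast (assuming a CUDA device; the sizes are arbitrary):

    import torch

    x = torch.randn(16, 128, device="cuda")        # fp32, e.g. a token embedding
    weight = torch.randn(128, 256, device="cuda")  # fp32 parameter
    with torch.cuda.amp.autocast():
        U = x.mm(weight)                           # matmul is autocast to fp16
    print(U.dtype, x.dtype, weight.dtype)          # torch.float16 torch.float32 torch.float32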

So now I also better understand why I originally had the issue with Nvidia APEX AMP. Looking at the original code ...

    in_autocast = getattr(torch, 'is_autocast_enabled', lambda: False)()
    if in_autocast:
        with torch.cuda.amp.autocast(enabled=False):
            cast = torch.Tensor.half if amp_recurrence_fp16 else torch.Tensor.float

            U = cast(U)
            x = cast(x)
            weight_c = cast(weight_c)
            bias = cast(bias)
            c_init = cast(c_init)
            scale_x = cast(scale_x) if scale_x is not None else scale_x
            dropout_mask_c = cast(dropout_mask_c) if dropout_mask_c is not None else dropout_mask_c

            return SRU_Compute_GPU.apply(
                U,
                x,
                weight_c,
                bias,
                c_init,
                activation_type,
                hidden_size,
                bidirectional,
                has_skip_term,
                scale_x,
                dropout_mask_c,
                mask_pad
            )
    else:
        return SRU_Compute_GPU.apply(
            U,
            x,
            weight_c,
            bias,
            c_init,
            activation_type,
            hidden_size,
            bidirectional,
            has_skip_term,
            scale_x,
            dropout_mask_c,
            mask_pad
        )

... in_autocast evaluates to False (I checked), simply because we are using Nvidia APEX AMP instead of PyTorch AMP.
This implies that we are calling SRU_Compute_GPU.apply() directly with mixed-precision arguments, which in my opinion leads to the error I reported. The same issue will occur with PyTorch AMP while keeping amp_recurrence_fp16=False (I checked this too).

I'm wondering: could it be that, because the recurrence kernel is a custom C/C++ implementation, autocasting stops working inside it? Do you agree that this is what causes the error?

In the code in my pull request this is remedied by simply ensuring we cast to torch.float16 when amp_recurrence_fp16=True.

Now the last point is: do we need the torch.cuda.amp.autocast(enabled=False) region?
I even doubt that it makes any difference, not only from a numerical precision point of view, but also whether the block affects the C/C++ code at all. What's your view on this?

In conclusion, cases can occur where the recurrence kernel is called with arguments of mixed precision.
Autocasting doesn't seem to apply inside the recurrence kernel (because it is implemented in C/C++?), and with mixed-precision input arguments this seems to lead to RuntimeError: expected scalar type Half but found Float.
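
For what it's worth, a small guard like the one below (my own hypothetical helper, not repo code) would surface the problem with a clearer message than the kernel's RuntimeError:

    from typing import Optional

    import torch
    from torch import Tensor

    def check_uniform_dtype(*tensors: Optional[Tensor]) -> None:
        # The templated CUDA kernel needs every tensor in a single dtype, so fail
        # fast instead of hitting "expected scalar type Half but found Float".
        dtypes = {t.dtype for t in tensors if t is not None}
        if len(dtypes) > 1:
            raise TypeError(f"recurrence kernel expects a single dtype, got {sorted(map(str, dtypes))}")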

PS 1: @hpasapp I understand that casting is just an identity function if the input tensor already has the right type. And yes, it creates a decoupling, but that decoupling should also apply to any AMP method, not just PyTorch native AMP, right?

PS 2: I would be happy to put the torch.cuda.amp.autocast(enabled=False) block back into the code, but I have doubts about whether it makes sense. Please let me know what you think.

@hpasapp
Collaborator

hpasapp commented Jan 30, 2021

Ok, so:

  1. As you show, the tensor types arriving at this point of the code are a mixture of float16 and float32. The float16 is because the previous matrix multiplication, for U, was autocast by AMP into float16. The other tensors arrive unchanged. I was wrong to say they all arrive as float16: in fact many/most arrive as float32.
  2. For many geometries, the execution time of SRU is dominated by the matrix multiplication for U. Thus, as long as the matrix multiplication is carried out in float16, AMP will show a good speedup.
  3. The recurrence is a custom kernel. The custom kernel is templated, so it runs either entirely in float16 or entirely in float32. Generally, the recurrence is I/O bound: there is little computation, and most of the kernel time plausibly goes into moving data to and from the kernel. In this scenario, if the tensors are float16, the amount of data to move is half that for float32, and there is a speedup.
  • Note that converting between float16 and float32 requires a new kernel launch for each tensor, which adds its own latency.
  • When using the float32 recurrence with AMP turned on, running nvprof showed some kernels with signatures like void at::native::unrolled_elementwise_kernel<at::native::copy_device_to_device(...) adding significantly to the latency.
  • I guess these might be converting U from float16 to float32? At least, this is my current hypothesis/belief. (A rough profiling sketch follows below.)
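
For anyone who wants to reproduce this without nvprof, a rough sketch using the PyTorch autograd profiler (tensor sizes are arbitrary):

    import torch

    # Converting a large fp16 activation to fp32 and back launches separate
    # elementwise copy kernels, which show up as extra CUDA time in the profile.
    U = torch.randn(2048, 4096, device="cuda", dtype=torch.float16)
    with torch.autograd.profiler.profile(use_cuda=True) as prof:
        U32 = U.float()   # fp16 -> fp32 copy kernel
        U16 = U32.half()  # fp32 -> fp16 copy kernel
    print(prof.key_averages().table(sort_by="cuda_time_total"))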

…CUDA kernel needs either all Float32 or all Float16 tensors. Disabled AMP block is back because AMP is moot for the recurrence kernel
@visionscaper
Author

The recurrence is a custom kernel. The custom kernel is templated, so runs either all in float16 or all in float32.

Aha! This is new information to me (I haven't written any custom CUDA kernels myself yet). This also explains why we need to cast all the tensors to either Float32 or Float16.

So do I understand correctly that we need autocast(enabled=False) because autocasting is moot for the custom kernel?

Bringing this all together, we could update the code as follows:

from typing import List, Optional

import torch
from torch import Tensor


@torch.jit.unused
def elementwise_recurrence_gpu(U: Tensor,
                               x: Tensor,
                               weight_c: Tensor,
                               bias: Tensor,
                               c_init: Tensor,
                               activation_type: int,
                               hidden_size: int,
                               bidirectional: bool,
                               has_skip_term: bool,
                               scale_x: Optional[Tensor] = None,
                               dropout_mask_c: Optional[Tensor] = None,
                               mask_pad: Optional[Tensor] = None,
                               amp_recurrence_fp16: bool = False) -> List[Tensor]:
    """Elementwise forward operation of SRU on GPU.

    """
    from .cuda_functional import SRU_Compute_GPU

    cast = torch.Tensor.half if amp_recurrence_fp16 else torch.Tensor.float

    U = cast(U)
    x = cast(x)
    weight_c = cast(weight_c)
    bias = cast(bias)
    c_init = cast(c_init)
    scale_x = cast(scale_x) if scale_x is not None else scale_x
    dropout_mask_c = cast(dropout_mask_c) if dropout_mask_c is not None else dropout_mask_c

    in_autocast = getattr(torch, 'is_autocast_enabled', lambda: False)()
    if in_autocast:
        with torch.cuda.amp.autocast(enabled=False):
            return SRU_Compute_GPU.apply(
                U,
                x,
                weight_c,
                bias,
                c_init,
                activation_type,
                hidden_size,
                bidirectional,
                has_skip_term,
                scale_x,
                dropout_mask_c,
                mask_pad
            )
    else:
        return SRU_Compute_GPU.apply(
            U,
            x,
            weight_c,
            bias,
            c_init,
            activation_type,
            hidden_size,
            bidirectional,
            has_skip_term,
            scale_x,
            dropout_mask_c,
            mask_pad
        )

The line cast = torch.Tensor.half if amp_recurrence_fp16 else torch.Tensor.float ensures that all tensors have the same precision. In this way it works for both PyTorch AMP and Nvidia APEX AMP. Further, as you mentioned before, the casting is a no-op when the tensor already has the right precision.

Further, we (again) disable PyTorch native AMP when it is used by the user, because it is moot when calling the recurrence kernel.

Because I think this updated fix makes a lot of sense, I committed it to my fork. I also checked that it still works with PyTorch AMP and with Nvidia APEX AMP (with DeepSpeed).

@hpasapp @taolei87 please let me know what you think of this new fix and if we need to do more testing.

@hpasapp
Collaborator

hpasapp commented Jan 30, 2021

Per my understanding from https://pytorch.org/docs/stable/notes/amp_examples.html#autocast-and-custom-autograd-functions, we should first disable autocast and then do any casts we would like. I'm not sure what happens if we do the cast first, but I worry we might then be "out of spec", and things might fail under some circumstances.
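
Concretely, that docs page describes the custom_fwd/custom_bwd decorators; a sketch of what the pattern might look like for the recurrence function (bodies elided, so this is only an illustration of the pattern, not the repo's actual implementation):

    import torch
    from torch.cuda.amp import custom_fwd, custom_bwd

    class SRU_Compute_GPU(torch.autograd.Function):
        @staticmethod
        @custom_fwd(cast_inputs=torch.float32)  # autocast is disabled inside forward()
        def forward(ctx, U, x, weight_c, bias, c_init, *args):
            # Tensor arguments arrive cast to fp32, regardless of upstream autocasting.
            ...

        @staticmethod
        @custom_bwd  # backward() runs with the same autocast state as forward()
        def backward(ctx, *grad_outputs):
            ...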

Googling briefly for using custom CUDA functions with APEX, https://github.com/NVIDIA/apex/tree/master/apex/amp#annotating-user-functions appears to describe what to do. We can either add annotations to the function to say that it should run in half or in float32, or we can call a registration function directly. Given that we have a runtime parameter to decide whether the recurrence should be float32 or float16, it might not make sense to use the annotations, so calling the registration function might be the way to go. What do you think?
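
A sketch of the registration route (APEX's register_half_function / register_float_function take a module object and a function name; the module path here is an assumption):

    from apex import amp

    import sru.ops  # assumed location of elementwise_recurrence_gpu; adjust if it lives elsewhere

    def register_sru_recurrence(amp_recurrence_fp16: bool) -> None:
        # Register the recurrence wrapper with APEX AMP so that its tensor arguments
        # are cast to one dtype before the call (presumably this needs to happen
        # before amp.initialize() is called).
        if amp_recurrence_fp16:
            amp.register_half_function(sru.ops, "elementwise_recurrence_gpu")
        else:
            amp.register_float_function(sru.ops, "elementwise_recurrence_gpu")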

@visionscaper
Author

Hi @hpasapp,

I adapted the implementation as follows:

  • When PyTorch AMP is used: casting of the tensors is done, as before, inside the disabled autocast block.
  • When APEX AMP is available: an FP16 and an FP32 function are registered; depending on amp_recurrence_fp16, the former or the latter is used.
  • Otherwise: no casting is done.

In my opinion the code is more complex now; I liked my previous fix better, and I'm not sure the changes are really required. But if you want to be sure it is done the right way, I guess this is better.

For now I have put my changes in a separate branch, which can be found here:
visionscaper@feedd77

Let me know what you think.

@visionscaper
Author

visionscaper commented Jan 31, 2021

@hpasapp @taolei87

I did some more experiments concerning the use of DeepSpeed, APEX AMP and PyTorch AMP, on a system with 3x 1080 Ti and on an AWS instance with 4x T4. Here I'm training a seq2seq conversational model, with customizations to make it multi-turn. The model uses SRUs (with the latest fix). The CUDA version is 11.0.

                                                              3x 1080 Ti                    4x T4
                                                              Memory/GPU   Samples/s/GPU    Memory/GPU   Samples/s/GPU
DeepSpeed      APEX AMP config (FP16)                         ~4.3 GB      ~8.5             ~4.5 GB      ~13.3
DeepSpeed      no APEX AMP config / APEX installed (FP32)     ~3.3 GB      ~12              ~3.6 GB      ~10
DeepSpeed      no APEX installed (FP32)                       ~3.3 GB      ~12              ~3.6 GB      ~9.5
No DeepSpeed   PyTorch AMP (FP16)                             ~7.1 GB      ~3.5             ~7.4 GB      ~4.7
No DeepSpeed   no PyTorch AMP (FP32)                          ~6.5 GB      ~4.4             ~6.8 GB      ~4.0

Some conclusions:

  • DeepSpeed speeds up training a lot (a factor of 2-4)
  • DeepSpeed uses much less memory (about a factor of 2)
  • Oddly enough, memory usage increases for both PyTorch AMP and APEX AMP
  • On older/consumer GPUs (1080 Ti) FP16 training is actually slower than FP32 training; on enterprise-grade hardware (T4) it is the other way around

Any idea why the memory usage increases?
What is your experience with FP16 training of models that use SRUs? Does it also have poor memory performance?

@visionscaper
Author

^ Could the higher memory usage be because of the casting from FP32 => FP16?

@hpasapp
Collaborator

hpasapp commented Jan 31, 2021

  1. execution speed vs matrix dimensions

As far as execution speed goes, note that AMP is strongly affected by whether all dimensions are a multiple of 8. I'm not saying this is the cause of some of your results, but it could be. Attached is a screenshot of the results of some experiments where we multiply two matrices A and B, of dimension n x n, and vary n:

[Screenshot: matrix multiplication time vs n, with and without AMP]

You can see that changing n from 512 to 510 decreases the execution time without AMP, but with AMP enabled, reducing n from 512 to 510 actually increases the execution time by nearly a factor of 3 :O

For SRU, the simplest way to ensure this is to make all dimensions of each tensor a multiple of 8: sequence length, batch size, embedding size.

In practice, the relevant matrix multiplication is the one that generates U:

U = x.mm(self.weight)
So the following quantities should be multiples of 8:

  • seq_len * batch_size
  • input_size
  • output_size * num_matrices
    where num_matrices is often 3, and so this means output_size should most likely be a multiple of 8.
  2. 1080 Ti vs T4

Yes, older GPUs didn't have the specialized tensor cores. The V100 and T4 both work fine with AMP.

  3. memory usage

That AMP increases memory usage is actually new information to me. I don't know if the casting operation itself increases memory usage, but perhaps some buffers are in memory twice: once as fp32 and once as fp16?

  • You might also consider experimenting with the fp16 recurrence. This might decrease memory usage, since the result of the matrix multiplication won't need to be converted back into fp32. You'd probably want to run experiments on actual performance (ppl, accuracy, etc.) to check/confirm/convince yourself that using an fp16 recurrence doesn't affect it. (A rough usage example follows below.)
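
A rough example of such an experiment, assuming amp_recurrence_fp16 is exposed on the SRU constructor as discussed in this thread, and keeping every dimension a multiple of 8:

    import torch
    from sru import SRU

    # Hypothetical setup: fp16 recurrence under native AMP, all dims multiples of 8.
    rnn = SRU(input_size=512, hidden_size=512, num_layers=4, amp_recurrence_fp16=True).cuda()
    x = torch.randn(128, 32, 512, device="cuda")   # (seq_len, batch, input_size)
    with torch.cuda.amp.autocast():
        output, state = rnn(x)
    print(f"peak memory: {torch.cuda.max_memory_allocated() / 2**20:.1f} MiB")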

@taoleicn
Contributor

taoleicn commented Mar 3, 2021

Closing this PR. This is addressed in the dev branch and will be merged in later:
#167

@taoleicn taoleicn closed this Mar 3, 2021
@visionscaper
Author

You're welcome.

@taoleicn
Contributor

Hi @visionscaper,
I read the commit that you posted here earlier.

It seems you explicitly register the recurrence operators for APEX, so that APEX will autocast the tensors to the appropriate types? I noticed that in this PR you don't have the registrations. Wouldn't not registering the operators break APEX AMP training?

@taoleicn taoleicn reopened this Apr 11, 2021
@taoleicn taoleicn closed this May 16, 2021