In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F

Note: MPS and CUDA devices are optimized for reduced-precision (float16) computation, which gives higher performance, usually without significantly sacrificing accuracy.

To balance efficiency and precision, you can use automatic mixed precision (AMP) for training. This technique allows the model to use float16 where possible and float32 when necessary for stability.

In [2]:
class LayerNormFp32(nn.LayerNorm):
    """
    Subclass of torch's LayerNorm that performs the normalization in float32
    and returns the result in the input's original dtype.

    The input and the affine parameters (when present) are all cast to
    float32 for the computation, so the layer works whether its parameters
    are stored in float32 or were cast to a lower precision along with the
    rest of the model (e.g. via ``model.half()``).
    """

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply layer normalization to ``x`` using float32 arithmetic.

        Args:
            x: Input tensor whose trailing dimensions match
                ``self.normalized_shape``.

        Returns:
            The normalized tensor, cast back to the dtype ``x`` had on entry.
        """
        # Remember the caller's dtype so the output can be cast back to it.
        orig_type = x.dtype

        # F.layer_norm needs the input and the affine parameters to share a
        # dtype. The parameters may be None (elementwise_affine=False) or may
        # be stored in a reduced precision, so normalize them to float32 too.
        # `.to` returns the same tensor when it is already float32.
        weight = self.weight.to(torch.float32) if self.weight is not None else None
        bias = self.bias.to(torch.float32) if self.bias is not None else None

        normalized = F.layer_norm(
            x.to(torch.float32),
            self.normalized_shape,
            weight,
            bias,
            self.eps,
        )

        # The result already lives on the input's device; only the dtype
        # needs to be restored.
        return normalized.to(orig_type)
