
Add triton implementation of layer norm #260

Draft: wants to merge 34 commits into main from petew/triton
Commits (34):
25086b6  add triton implementation of layer norm (epwalsh, Sep 6, 2023)
495cd73  add note about source (epwalsh, Sep 6, 2023)
f8854ac  reorganize (epwalsh, Sep 6, 2023)
31b4ed9  decorate tests (epwalsh, Sep 6, 2023)
aeeb623  add missing params to test (epwalsh, Sep 6, 2023)
6dddf66  add bfloat16 test (epwalsh, Sep 6, 2023)
978d29a  clean up tests (epwalsh, Sep 6, 2023)
eb8b1d8  more dtypes (epwalsh, Sep 6, 2023)
4f1e66c  adjust test tolerance (epwalsh, Sep 7, 2023)
d3c565d  fix (epwalsh, Sep 7, 2023)
f5cf624  Merge branch 'main' into petew/triton (epwalsh, Sep 7, 2023)
7718032  add triton build script (epwalsh, Sep 7, 2023)
cafa951  upload to s3 (epwalsh, Sep 7, 2023)
ea2ba9c  clean up tests (epwalsh, Sep 7, 2023)
54d0a14  increase tolerance (epwalsh, Sep 7, 2023)
ace4266  add lumi test script (epwalsh, Sep 7, 2023)
add5f7a  change names (epwalsh, Sep 7, 2023)
a9066e1  fix lints (epwalsh, Sep 7, 2023)
8344dc5  shorten names (epwalsh, Sep 7, 2023)
e1f54fc  Merge branch 'main' into petew/triton (epwalsh, Sep 20, 2023)
4c5c6cd  Merge branch 'petew/layer-norm' into petew/triton (epwalsh, Sep 20, 2023)
0a06206  add TritonLayerNorm class (epwalsh, Sep 20, 2023)
84df0f2  fix merge conflicts (epwalsh, Sep 20, 2023)
6ce3ed3  clean up (epwalsh, Sep 20, 2023)
dd5628d  update build script (epwalsh, Sep 20, 2023)
7415a72  Merge branch 'main' into petew/triton (epwalsh, Sep 21, 2023)
95547f1  add note about clearing out the cache (epwalsh, Sep 21, 2023)
570dd77  try no affines anywhere (epwalsh, Sep 22, 2023)
ffb697a  add warning (epwalsh, Sep 22, 2023)
17d27db  auto remove triton cache (epwalsh, Sep 22, 2023)
9b7519c  fix cleaning triton cache (epwalsh, Sep 22, 2023)
63bb3e5  add triton LN with linear (epwalsh, Sep 22, 2023)
940ea7b  fix (epwalsh, Sep 22, 2023)
2ce967d  add back elementwise affine to QK norm (epwalsh, Sep 22, 2023)
.flake8 (2 additions, 0 deletions)
@@ -9,6 +9,8 @@ ignore =
W503
# line too long, who cares?
E501
# don't assign a lambda expression
E731

exclude =
.venv
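For context, E731 is the flake8 rule against assigning a lambda expression to a name. A minimal illustration of the pattern the new ignore permits (not code from this PR):

    # Without the E731 ignore, flake8 would flag this assignment
    # and suggest a def instead.
    double = lambda x: 2 * x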
olmo/config.py (5 additions, 0 deletions)
@@ -174,6 +174,11 @@ class LayerNormType(StrEnum):
LayerNorm implemented manually to work around an issue with ROCm.
"""

triton = "triton"
"""
A triton implementation of layer norm.
"""


class ActivationType(StrEnum):
gelu = "gelu"
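For illustration, a hypothetical way to select the new layer norm type through the model config (this assumes the remaining ModelConfig fields have usable defaults; it is not code from this PR):

    from olmo.config import LayerNormType, ModelConfig

    # Hypothetical: pick the Triton layer norm via the enum member added above.
    config = ModelConfig(layer_norm_type=LayerNormType.triton)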
olmo/model.py (29 additions, 0 deletions)
@@ -98,6 +98,8 @@ def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> LayerNormBase:
return RMSLayerNorm(config, size=size, low_precision=True, **kwargs)
elif config.layer_norm_type == LayerNormType.amd_compatible:
return AMDLayerNorm(config, size=size, **kwargs)
elif config.layer_norm_type == LayerNormType.triton:
return TritonLayerNorm(config, size=size, **kwargs)
else:
raise NotImplementedError(f"Not sure how to handle '{config.layer_norm_type}' LayerNorm type")

@@ -184,6 +186,33 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x.to(og_dtype)


class TritonLayerNorm(LayerNormBase):
def __init__(
self,
config: ModelConfig,
size: Optional[int] = None,
elementwise_affine: Optional[bool] = None,
eps: float = 1e-05,
):
super().__init__(config, size=size, elementwise_affine=elementwise_affine, eps=eps)
try:
from .triton import layer_norm as triton_layer_norm # type: ignore

self._layer_norm = triton_layer_norm
except ModuleNotFoundError:
raise OlmoConfigurationError(
f"{self.__class__.__name__} is not available. Please check if you have triton installed"
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
og_dtype = x.dtype
x = self._cast_if_autocast_enabled(x, dtype=torch.float32)
with torch.autocast(enabled=False, device_type=x.device.type):
return self._layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps).to(
og_dtype
)


class RMSLayerNorm(LayerNormBase):
"""
RMS layer norm, a simplified :class:`LayerNorm` implementation that can optionally run
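The kernel itself lives in olmo/triton/layer_norm.py, which this view does not show. As a rough reference, here is a minimal forward-only sketch in the style of the standard Triton layer-norm tutorial; the kernel name, the driver signature, and the assumption that weight and bias are always provided are mine, not necessarily what the PR implements, and a real training implementation would also need a backward pass:

    import torch
    import triton
    import triton.language as tl


    @triton.jit
    def _layer_norm_fwd(
        X, Y, W, B,        # pointers: input, output, weight, bias
        stride,            # elements between consecutive rows of X and Y
        N,                 # row length (the normalized dimension)
        eps,
        BLOCK_SIZE: tl.constexpr,
    ):
        # Each program instance normalizes one row of the flattened input.
        row = tl.program_id(0)
        X += row * stride
        Y += row * stride
        cols = tl.arange(0, BLOCK_SIZE)
        mask = cols < N
        x = tl.load(X + cols, mask=mask, other=0.0).to(tl.float32)
        # Mean and variance over the row, computed in fp32.
        mean = tl.sum(x, axis=0) / N
        diff = tl.where(mask, x - mean, 0.0)
        var = tl.sum(diff * diff, axis=0) / N
        rstd = 1.0 / tl.sqrt(var + eps)
        x_hat = (x - mean) * rstd
        # Affine transform; this sketch assumes weight and bias are given.
        w = tl.load(W + cols, mask=mask, other=1.0)
        b = tl.load(B + cols, mask=mask, other=0.0)
        tl.store(Y + cols, x_hat * w + b, mask=mask)


    def layer_norm(x, normalized_shape, weight, bias, eps=1e-5):
        # normalized_shape is accepted for API parity; this sketch only
        # normalizes over the last dimension and assumes x is contiguous.
        x_arg = x.reshape(-1, x.shape[-1])
        M, N = x_arg.shape
        y = torch.empty_like(x_arg)
        # BLOCK_SIZE must be a power of two that covers a full row.
        BLOCK_SIZE = triton.next_power_of_2(N)
        _layer_norm_fwd[(M,)](
            x_arg, y, weight, bias, x_arg.stride(0), N, eps, BLOCK_SIZE=BLOCK_SIZE
        )
        return y.view_as(x)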
olmo/triton/__init__.py (3 additions, 0 deletions)
@@ -0,0 +1,3 @@
from .layer_norm import layer_norm

__all__ = ["layer_norm"]