Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Small simplifications #377

Merged
merged 7 commits into from
Nov 22, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 10 additions & 18 deletions olmo/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,9 +121,7 @@ def __init__(
self.config = config
self.eps = eps
self.normalized_shape = (size or config.d_model,)
if elementwise_affine is None:
elementwise_affine = self.config.layer_norm_with_affine
if elementwise_affine:
if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine):
self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device))
use_bias = self.config.bias_for_layer_norm
if use_bias is None:
Expand Down Expand Up @@ -151,7 +149,7 @@ def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> Lay
elif config.layer_norm_type == LayerNormType.amd_compatible:
return AMDLayerNorm(config, size=size, **kwargs)
else:
raise NotImplementedError(f"Not sure how to handle '{config.layer_norm_type}' LayerNorm type")
raise NotImplementedError(f"Unknown LayerNorm type: '{config.layer_norm_type}'")

def _cast_if_autocast_enabled(self, tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
# NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
Expand Down Expand Up @@ -311,8 +309,7 @@ def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
return torch.cat((-x2, x1), dim=-1)

def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
out = (t * pos_cos) + (self.rotate_half(t) * pos_sin)
return out.to(t.dtype)
return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype)

def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
if self.config.rope_full_precision:
Expand Down Expand Up @@ -357,7 +354,7 @@ def build(cls, config: ModelConfig) -> Activation:
elif config.activation_type == ActivationType.swiglu:
return SwiGLU(config)
else:
raise NotImplementedError(f"not sure how to handle activation type '{config.activation_type}'")
raise NotImplementedError(f"Unknown activation: '{config.activation_type}'")


class GELU(nn.GELU):
Expand Down Expand Up @@ -572,11 +569,7 @@ def attention(
k = torch.cat((past_key, k), dim=-2)
v = torch.cat((past_value, v), dim=-2)

if use_cache:
present = (k, v)
else:
present = None

present = (k, v) if use_cache else None
query_len, key_len = q.shape[-2], k.shape[-2] # could be different if layer_past not None

if self.config.rope:
Expand Down Expand Up @@ -629,7 +622,7 @@ def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> OlmoBl
elif config.block_type == BlockType.llama:
return OlmoLlamaBlock(layer_id, config, cache)
else:
raise NotImplementedError(f"not sure how to handle block type '{config.block_type}'")
raise NotImplementedError(f"Unknown block type: '{config.block_type}'")


class OlmoSequentialBlock(OlmoBlock):
Expand Down Expand Up @@ -719,10 +712,9 @@ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
super().__init__(layer_id, config, cache)
self.norm = LayerNorm.build(config)
# Fused attention and feed-forward projection.
# NOTE: we could also fuse the attention and feed-forward output projections
# but we found that didn't help, possibly because of the overhead of joining the `att`
# and `ff` activations together.
# See https://github.com/allenai/LLM/pull/79 for details.
# NOTE: we could also fuse the attention and feed-forward output projections but we
# found that didn't help, possibly because of the overhead of joining the `att` and
# `ff` activations together. See https://github.com/allenai/LLM/pull/79 for details.
if config.multi_query_attention:
self.fused_dims = (
config.d_model,
Expand Down Expand Up @@ -1420,7 +1412,7 @@ def generate(
`(batch_size, 1, seq_len + tokens_to_generate, seq_len + tokens_to_generate)`,
the same as for the forward method except only one shape is excepted here.

For an explanation of the other arguments, see the :class:`BeamSearch` class.
For an explanation of the other arguments, see :class:`BeamSearch`.
"""
beam_search = BeamSearch(
self.config.eos_token_id,
Expand Down
Loading