
Commit

Merge pull request #377 from allenai/Muennighoff/split-model-comps
Small simplifications
Muennighoff committed Nov 22, 2023
2 parents 80b081b + bb78df5 commit dcdadc5
Showing 1 changed file with 10 additions and 18 deletions.
28 changes: 10 additions & 18 deletions olmo/model.py
@@ -121,9 +121,7 @@ def __init__(
         self.config = config
         self.eps = eps
         self.normalized_shape = (size or config.d_model,)
-        if elementwise_affine is None:
-            elementwise_affine = self.config.layer_norm_with_affine
-        if elementwise_affine:
+        if elementwise_affine or (elementwise_affine is None and self.config.layer_norm_with_affine):
             self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device))
             use_bias = self.config.bias_for_layer_norm
             if use_bias is None:
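The collapsed condition above preserves the old three-state behavior: an explicit True enables the affine weight, an explicit False disables it, and None falls back to config.layer_norm_with_affine. A minimal standalone truth-table check (a sketch outside the diff, using stand-in functions rather than the module itself):

    # Standalone sketch: confirm the refactored condition matches the old two-step logic
    # for every combination of elementwise_affine and the config default.
    def old_logic(elementwise_affine, config_default):
        if elementwise_affine is None:
            elementwise_affine = config_default
        return bool(elementwise_affine)

    def new_logic(elementwise_affine, config_default):
        return bool(elementwise_affine or (elementwise_affine is None and config_default))

    for flag in (True, False, None):
        for default in (True, False):
            assert old_logic(flag, default) == new_logic(flag, default)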
@@ -151,7 +149,7 @@ def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> Lay
         elif config.layer_norm_type == LayerNormType.amd_compatible:
             return AMDLayerNorm(config, size=size, **kwargs)
         else:
-            raise NotImplementedError(f"Not sure how to handle '{config.layer_norm_type}' LayerNorm type")
+            raise NotImplementedError(f"Unknown LayerNorm type: '{config.layer_norm_type}'")

     def _cast_if_autocast_enabled(self, tensor: torch.Tensor, dtype: Optional[torch.dtype] = None) -> torch.Tensor:
         # NOTE: `is_autocast_enabled()` only checks for CUDA autocast, so we use the separate function
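For context, the `_cast_if_autocast_enabled` helper referenced here has to distinguish CPU and CUDA autocast because, in the PyTorch releases targeted at the time, `torch.is_autocast_enabled()` only reported CUDA autocast. A rough sketch of what such a helper can look like (illustrative only, not necessarily the file's exact implementation):

    import torch

    def cast_if_autocast_enabled(tensor: torch.Tensor) -> torch.Tensor:
        # CUDA and CPU autocast state are tracked by separate functions in these torch APIs.
        if tensor.device.type == "cuda" and torch.is_autocast_enabled():
            return tensor.to(dtype=torch.get_autocast_gpu_dtype())
        elif tensor.device.type == "cpu" and torch.is_autocast_cpu_enabled():
            return tensor.to(dtype=torch.get_autocast_cpu_dtype())
        return tensor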
@@ -311,8 +309,7 @@ def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
         return torch.cat((-x2, x1), dim=-1)

     def apply_rotary_pos_emb(self, pos_sin: torch.Tensor, pos_cos: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
-        out = (t * pos_cos) + (self.rotate_half(t) * pos_sin)
-        return out.to(t.dtype)
+        return ((t * pos_cos) + (self.rotate_half(t) * pos_sin)).to(t.dtype)

     def forward(self, q: torch.Tensor, k: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
         if self.config.rope_full_precision:
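The rotary-embedding change only folds the intermediate `out` variable into the return expression. For readers unfamiliar with the formula being compressed: rotary position embeddings rotate pairs of query/key dimensions by a position-dependent angle, expressed as t*cos + rotate_half(t)*sin. A standalone sketch of building the sin/cos tables and applying them (the chunk-based rotate_half and the dimension layout are illustrative assumptions; the module's own rotate_half may split the head dimension differently):

    import torch

    def rotate_half(x: torch.Tensor) -> torch.Tensor:
        # Split the head dimension in half and swap the halves with a sign flip.
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_pos_emb(pos_sin, pos_cos, t):
        return ((t * pos_cos) + (rotate_half(t) * pos_sin)).to(t.dtype)

    d_head, seq_len = 16, 8
    inv_freq = 1.0 / (10000 ** (torch.arange(0, d_head, 2).float() / d_head))
    positions = torch.arange(seq_len).float()
    angles = torch.einsum("s,d->sd", positions, inv_freq)   # (seq_len, d_head/2)
    emb = torch.cat((angles, angles), dim=-1)                # (seq_len, d_head)
    pos_sin, pos_cos = emb.sin()[None, None], emb.cos()[None, None]

    q = torch.randn(1, 4, seq_len, d_head)                   # (batch, n_heads, seq_len, d_head)
    q_rot = apply_rotary_pos_emb(pos_sin, pos_cos, q)
    assert q_rot.shape == q.shape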
@@ -357,7 +354,7 @@ def build(cls, config: ModelConfig) -> Activation:
         elif config.activation_type == ActivationType.swiglu:
             return SwiGLU(config)
         else:
-            raise NotImplementedError(f"not sure how to handle activation type '{config.activation_type}'")
+            raise NotImplementedError(f"Unknown activation: '{config.activation_type}'")


 class GELU(nn.GELU):
@@ -572,11 +569,7 @@ def attention(
             k = torch.cat((past_key, k), dim=-2)
             v = torch.cat((past_value, v), dim=-2)

-        if use_cache:
-            present = (k, v)
-        else:
-            present = None
-
+        present = (k, v) if use_cache else None
         query_len, key_len = q.shape[-2], k.shape[-2]  # could be different if layer_past not None

         if self.config.rope:
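The `present` change is a one-line conditional expression replacing the old if/else. The surrounding context is a standard KV cache: past keys and values are concatenated along the sequence axis, and the updated pair is returned when caching is enabled so the next decoding step can reuse it. A minimal sketch of that pattern (hypothetical shapes and function name, not the module's real signature):

    import torch
    from typing import Optional, Tuple

    def step(q, k, v, layer_past: Optional[Tuple[torch.Tensor, torch.Tensor]], use_cache: bool):
        # Tensors are (batch, n_heads, seq_len, head_dim); dim=-2 is the sequence axis.
        if layer_past is not None:
            past_key, past_value = layer_past
            k = torch.cat((past_key, k), dim=-2)
            v = torch.cat((past_value, v), dim=-2)
        present = (k, v) if use_cache else None
        query_len, key_len = q.shape[-2], k.shape[-2]  # key_len grows as the cache grows
        return present, query_len, key_len

    q = torch.randn(1, 4, 1, 16)
    k = v = torch.randn(1, 4, 1, 16)
    cache, _, _ = step(q, k, v, None, use_cache=True)
    cache, _, key_len = step(q, k, v, cache, use_cache=True)
    assert key_len == 2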
@@ -629,7 +622,7 @@ def build(cls, layer_id: int, config: ModelConfig, cache: BufferCache) -> OlmoBl
         elif config.block_type == BlockType.llama:
             return OlmoLlamaBlock(layer_id, config, cache)
         else:
-            raise NotImplementedError(f"not sure how to handle block type '{config.block_type}'")
+            raise NotImplementedError(f"Unknown block type: '{config.block_type}'")


 class OlmoSequentialBlock(OlmoBlock):
@@ -719,10 +712,9 @@ def __init__(self, layer_id: int, config: ModelConfig, cache: BufferCache):
         super().__init__(layer_id, config, cache)
         self.norm = LayerNorm.build(config)
         # Fused attention and feed-forward projection.
-        # NOTE: we could also fuse the attention and feed-forward output projections
-        # but we found that didn't help, possibly because of the overhead of joining the `att`
-        # and `ff` activations together.
-        # See https://github.com/allenai/LLM/pull/79 for details.
+        # NOTE: we could also fuse the attention and feed-forward output projections but we
+        # found that didn't help, possibly because of the overhead of joining the `att` and
+        # `ff` activations together. See https://github.com/allenai/LLM/pull/79 for details.
         if config.multi_query_attention:
             self.fused_dims = (
                 config.d_model,
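The `fused_dims` tuple being built here (truncated by the diff view) records the output widths of the projections that share one fused matmul, so the combined output can later be carved back into its pieces with `torch.split`. A hedged sketch of that idea (the name `fused_proj` and the example widths are illustrative, not the block's actual attributes; with multi-query attention the key/value slices are only one head wide):

    import torch
    import torch.nn as nn

    d_model, head_dim, hidden = 64, 16, 256
    # Hypothetical widths: full-width queries, single-head keys/values, feed-forward hidden.
    fused_dims = (d_model, head_dim, head_dim, hidden)
    fused_proj = nn.Linear(d_model, sum(fused_dims), bias=False)

    x = torch.randn(2, 8, d_model)
    q, k, v, ff = fused_proj(x).split(fused_dims, dim=-1)
    assert q.shape[-1] == d_model and k.shape[-1] == head_dim and ff.shape[-1] == hidden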
@@ -1420,7 +1412,7 @@ def generate(
             `(batch_size, 1, seq_len + tokens_to_generate, seq_len + tokens_to_generate)`,
             the same as for the forward method except only one shape is excepted here.

-        For an explanation of the other arguments, see the :class:`BeamSearch` class.
+        For an explanation of the other arguments, see :class:`BeamSearch`.
         """
         beam_search = BeamSearch(
             self.config.eos_token_id,
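The docstring edit is cosmetic, but the attention_bias shape it references, (batch_size, 1, seq_len + tokens_to_generate, seq_len + tokens_to_generate), must cover the prompt plus every token that will be generated. A hedged sketch of building a causal bias with that shape (illustrative only; the model presumably constructs its own default when no bias is passed):

    import torch

    def causal_attention_bias(batch_size: int, seq_len: int, tokens_to_generate: int) -> torch.Tensor:
        total_len = seq_len + tokens_to_generate
        # True where attention is allowed (lower triangle), converted to an additive bias.
        allowed = torch.tril(torch.ones(total_len, total_len, dtype=torch.bool))
        bias = torch.zeros(total_len, total_len)
        bias.masked_fill_(~allowed, float("-inf"))
        return bias.view(1, 1, total_len, total_len).expand(batch_size, 1, total_len, total_len)

    bias = causal_attention_bias(batch_size=2, seq_len=5, tokens_to_generate=3)
    assert bias.shape == (2, 1, 8, 8)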
