
Merge pull request #264 from allenai/LayerNormAffine-ManualLayerNorm-TurnedOffForSafety

Layer norm affine manual layer norm turned off for safety
dirkgr authored Sep 12, 2023
2 parents a49f4ec + 8d094b6 commit 26e17c3
Showing 9 changed files with 102 additions and 121 deletions.
7 changes: 4 additions & 3 deletions configs/v1-mix-medium-mcli.yaml
@@ -10,7 +10,7 @@ wandb:
model:
d_model: 4096
n_heads: 16
- n_layers: 30
+ n_layers: 29
mlp_ratio: 8
alibi: true
alibi_bias_max: 8.0
@@ -20,7 +20,8 @@ model:
include_bias: false
block_type: sequential
layer_norm_type: low_precision
- layer_norm_with_affine: false
+ layer_norm_with_affine: true # workaround for the layer norm bug
+ bias_for_layer_norm: true # workaround for the layer norm bug
activation_type: swiglu
residual_dropout: 0.0
embedding_dropout: 0.0
@@ -30,7 +31,7 @@ model:
eos_token_id: 0
pad_token_id: 1
init_device: meta
- init_fn: mitchell
+ init_fn: normal

compile: null # causes instability on AMD GPUs

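The last hunk above flips `init_fn` from `mitchell` to `normal` (the same swap appears in the other configs in this commit). A rough, hypothetical sketch of the distinction is below; the real formulas and constants live in the OLMo code base, and the fixed 0.02 standard deviation for `normal` and the plain fan-in scaling for `mitchell` are illustrative assumptions only.

```python
import math

import torch.nn as nn


def init_normal(module: nn.Linear, std: float = 0.02) -> None:
    """Assumed 'normal'-style init: every weight drawn from one fixed-std Gaussian."""
    nn.init.normal_(module.weight, mean=0.0, std=std)
    if module.bias is not None:
        nn.init.zeros_(module.bias)


def init_mitchell(module: nn.Linear) -> None:
    """Assumed 'mitchell'-style init: the std shrinks with the layer's fan-in."""
    std = 1.0 / math.sqrt(module.in_features)
    nn.init.normal_(module.weight, mean=0.0, std=std)
    if module.bias is not None:
        nn.init.zeros_(module.bias)


layer = nn.Linear(4096, 4096, bias=False)
init_mitchell(layer)  # std ≈ 1/sqrt(4096) ≈ 0.0156 for d_model = 4096
```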
33 changes: 5 additions & 28 deletions configs/v1-mix-medium.yaml
@@ -10,7 +10,7 @@ wandb:
model:
d_model: 4096
n_heads: 16
- n_layers: 30
+ n_layers: 29
mlp_ratio: 8
alibi: true
alibi_bias_max: 8.0
@@ -20,7 +20,8 @@ model:
include_bias: false
block_type: sequential
layer_norm_type: low_precision
- layer_norm_with_affine: false
+ layer_norm_with_affine: true # workaround for the layer norm bug
+ bias_for_layer_norm: true # workaround for the layer norm bug
activation_type: swiglu
residual_dropout: 0.0
embedding_dropout: 0.0
@@ -51,7 +52,7 @@ scheduler:
data:
paths: ${path.glob:${oc.env:FLASH_DIR,no_exist}/preprocessed/olmo-mix/v1-sample/gpt-neox-20b-pii-special/books/*.npy,${oc.env:FLASH_DIR,no_exist}/preprocessed/olmo-mix/v1-sample/gpt-neox-20b-pii-special/c4/*.npy,${oc.env:FLASH_DIR,no_exist}/preprocessed/olmo-mix/v1-sample/gpt-neox-20b-pii-special/common-crawl/*/*.npy,${oc.env:FLASH_DIR,no_exist}/preprocessed/olmo-mix/v1-sample/gpt-neox-20b-pii-special/s2/*.npy,${oc.env:FLASH_DIR,no_exist}/preprocessed/olmo-mix/v1-sample/gpt-neox-20b-pii-special/stack/*.npy,${oc.env:FLASH_DIR,no_exist}/preprocessed/olmo-mix/v1-sample/gpt-neox-20b-pii-special/wiki/*.npy}
pad_direction: right
- num_workers: 1
+ num_workers: 0
drop_last: true
pin_memory: true
prefetch_factor: 16
@@ -65,7 +66,7 @@ tokenizer:
save_folder: ${path.choose:${oc.env:FLASH_DIR,no_exist}/checkpoints,/results}/${oc.env:SLURM_JOB_ID,${run_name}}
save_overwrite: false
# Sharded checkpoints (best for restarts)
- save_interval: 5000
+ save_interval: 1000
save_num_checkpoints_to_keep: -1
# Unsharded checkpoints (for final storage)
save_interval_unsharded: null # getting errors on LUMI right now
@@ -88,30 +89,6 @@ eval_interval: ${save_interval}
eval_subset_num_batches: -1
device_eval_batch_size: ${device_train_microbatch_size}
evaluators:
- ##########################
- # Perplexity evaluations #
- ##########################
- # TODO: do we care about c4 and RP validation? We don't have these tokenized at the moment.
- # - label: c4-validation
- # subset_num_batches: 10
- # data:
- # paths: ${path.glob:${path.choose:${oc.env:SCRATCH_DIR,no_exist}/pretraining_data/preprocessed,/net/nfs.cirrascale/allennlp/llm-data}/c4/en/c4-validation.*.npy}
- # num_workers: 2
- # drop_last: true
- # pin_memory: true
- # persistent_workers: true
- # prefetch_factor: 4

- # - label: rp-validation
- # subset_num_batches: 10
- # data:
- # paths: ${path.glob:${path.choose:${oc.env:SCRATCH_DIR,no_exist}/pretraining_data/preprocessed,/net/nfs.cirrascale/allennlp/llm-data}/redpajama/redpajama-validation.npy}
- # num_workers: 2
- # drop_last: true
- # pin_memory: true
- # persistent_workers: true
- # prefetch_factor: 4

# lump all the small datasets together (we still get separate metrics).
- label: all-small-ppl-validation
data:
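The `data.paths` and `save_folder` values above are OmegaConf interpolations: the built-in `${oc.env:VAR,default}` resolver plus project-specific `path.glob` and `path.choose` helpers. A minimal sketch of how such resolvers behave is below; the two custom resolvers are hypothetical stand-ins written for this note, since the real implementations live elsewhere in the OLMo code base.

```python
from glob import glob
from pathlib import Path

from omegaconf import OmegaConf

# Hypothetical stand-ins for the project's custom resolvers.
OmegaConf.register_new_resolver(
    "path.glob", lambda *patterns: [p for pattern in patterns for p in sorted(glob(pattern))]
)
OmegaConf.register_new_resolver(
    "path.choose", lambda *paths: next((p for p in paths if Path(p).exists()), paths[-1])
)

cfg = OmegaConf.create(
    {"paths": "${path.glob:${oc.env:FLASH_DIR,no_exist}/preprocessed/olmo-mix/*.npy}"}
)
# ${oc.env:FLASH_DIR,no_exist} resolves first (the env var, or the literal default),
# then path.glob expands the resulting pattern.
print(cfg.paths)
```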
7 changes: 4 additions & 3 deletions configs/v1-mix-small-mcli.yaml
@@ -20,7 +20,8 @@ model:
include_bias: false
block_type: sequential
layer_norm_type: low_precision
- layer_norm_with_affine: false
+ layer_norm_with_affine: true # workaround for the layer norm bug
+ bias_for_layer_norm: true # workaround for the layer norm bug
activation_type: swiglu
residual_dropout: 0.0
embedding_dropout: 0.0
@@ -30,7 +31,7 @@ model:
eos_token_id: 0
pad_token_id: 1
init_device: meta
- init_fn: mitchell
+ init_fn: normal

compile: null # causes instability on AMD GPUs

@@ -44,7 +45,7 @@ optimizer:

scheduler:
name: cosine_with_warmup
- t_warmup: 2000
+ t_warmup: 5000
alpha_f: 0.1

tokenizer:
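The scheduler hunk above stretches `t_warmup` from 2,000 to 5,000 steps (the identical change appears in `configs/v1-mix-small.yaml` below). A minimal sketch of a `cosine_with_warmup` schedule with an `alpha_f: 0.1` floor; the exact curve OLMo uses, and the peak learning rate and total step count chosen here, are assumptions for illustration.

```python
import math


def cosine_with_warmup(step: int, max_lr: float, t_warmup: int, t_max: int, alpha_f: float = 0.1) -> float:
    """Assumed shape: linear warmup to max_lr, then cosine decay to alpha_f * max_lr."""
    if step < t_warmup:
        return max_lr * step / t_warmup
    progress = min(1.0, (step - t_warmup) / max(1, t_max - t_warmup))
    min_lr = alpha_f * max_lr
    return min_lr + 0.5 * (max_lr - min_lr) * (1.0 + math.cos(math.pi * progress))


# Example values only: raising t_warmup from 2000 to 5000 delays the peak.
for s in (0, 2_000, 5_000, 100_000):
    print(s, round(cosine_with_warmup(s, max_lr=3e-4, t_warmup=5_000, t_max=400_000), 8))
```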
7 changes: 4 additions & 3 deletions configs/v1-mix-small.yaml
@@ -20,7 +20,8 @@ model:
include_bias: false
block_type: sequential
layer_norm_type: low_precision
- layer_norm_with_affine: false
+ layer_norm_with_affine: true # workaround for the layer norm bug
+ bias_for_layer_norm: true # workaround for the layer norm bug
activation_type: swiglu
residual_dropout: 0.0
embedding_dropout: 0.0
@@ -30,7 +31,7 @@ model:
eos_token_id: 0
pad_token_id: 1
init_device: meta
- init_fn: mitchell
+ init_fn: normal

compile: null # causes instability on AMD GPUs

@@ -44,7 +45,7 @@ optimizer:

scheduler:
name: cosine_with_warmup
- t_warmup: 2000
+ t_warmup: 5000
alpha_f: 0.1

data:
16 changes: 15 additions & 1 deletion olmo/config.py
@@ -169,6 +169,11 @@ class LayerNormType(StrEnum):
A low-precision version of RMSNorm.
"""

amd_compatible = "amd_compatible"
"""
LayerNorm implemented manually to work around an issue with ROCm.
"""


class ActivationType(StrEnum):
gelu = "gelu"
@@ -301,7 +306,8 @@ class ModelConfig(BaseConfig):
"""
Whether to include bias and weight parameters for the layer norms.
This only affects layer norms that are immediately followed by a linear layer in the forward pass.
- Other layer norms, such as those applied to attention keys and queries, will always include an elementwise affine transform.
+ Other layer norms, such as those applied to attention keys and queries, will always include an elementwise
+ affine transform.
"""

max_sequence_length: int = 1024
@@ -316,6 +322,14 @@ class ModelConfig(BaseConfig):
models tend to have near 0 bias terms anyway.
"""

+ bias_for_layer_norm: Optional[bool] = None
+ """
+ Whether or not to include bias parameters in layer norm.
+ This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in
+ layer norm.
+ When this is None (the default), it inherits the setting from include_bias.
+ """

scale_logits: bool = False
"""
If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
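The new `bias_for_layer_norm` field is deliberately tri-state: an explicit `true`/`false` wins, and `None` (the default) falls back to `include_bias`. The sketch below mirrors the `use_bias` fallback that the layer-norm classes in `olmo/model.py` apply; the dataclass is a trimmed stand-in for illustration, not the real `ModelConfig`.

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class TinyModelConfig:  # trimmed stand-in for olmo.config.ModelConfig
    include_bias: bool = False
    layer_norm_with_affine: bool = True
    bias_for_layer_norm: Optional[bool] = None


def layer_norm_uses_bias(cfg: TinyModelConfig) -> bool:
    # An explicit bias_for_layer_norm wins; None falls back to include_bias.
    use_bias = cfg.bias_for_layer_norm
    if use_bias is None:
        use_bias = cfg.include_bias
    # A bias parameter is only created when the affine transform is on at all.
    return cfg.layer_norm_with_affine and use_bias


# The "workaround" configs set layer_norm_with_affine: true and bias_for_layer_norm: true,
# so layer norms get a bias even though include_bias stays false for the linear layers.
print(layer_norm_uses_bias(TinyModelConfig(include_bias=False, bias_for_layer_norm=True)))  # True
```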
74 changes: 57 additions & 17 deletions olmo/model.py
@@ -72,6 +72,8 @@ def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> Lay
return RMSLayerNorm(config, size=size, low_precision=False, **kwargs)
elif config.layer_norm_type == LayerNormType.low_precision_rms:
return RMSLayerNorm(config, size=size, low_precision=True, **kwargs)
+ elif config.layer_norm_type == LayerNormType.amd_compatible:
+ return AMDLayerNorm(config, size=size, **kwargs)
else:
raise NotImplementedError(f"Not sure how to handle '{config.layer_norm_type}' LayerNorm type")

@@ -108,27 +110,22 @@ def __init__(
super().__init__(config)
self.normalized_shape = (size or config.d_model,)
self.eps = 1e-05
- self.low_precision = low_precision

- # We always have weight and bias even if they are turned off/set to 1 and 0, because ROCm has a
- # bug where F.layer_norm() crashes during the backwards pass when no bias was given.
- # When they are turned off, they need to be buffers, because FSDP can't handle the situation
- # where some parameters don't require gradients.

if elementwise_affine is None:
elementwise_affine = self.config.layer_norm_with_affine
- weight = torch.ones(self.normalized_shape, device=config.init_device)
if elementwise_affine:
self.register_parameter("weight", nn.Parameter(weight))
else:
self.register_buffer("weight", weight, persistent=False)

needs_bias = elementwise_affine and self.config.include_bias
bias = torch.zeros(self.normalized_shape, device=config.init_device)
if needs_bias:
self.register_parameter("bias", nn.Parameter(bias))
self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device))
use_bias = self.config.bias_for_layer_norm
if use_bias is None:
use_bias = self.config.include_bias
if use_bias:
self.bias = nn.Parameter(torch.zeros(self.normalized_shape, device=config.init_device))
else:
self.register_parameter("bias", None)
else:
self.register_buffer("bias", bias, persistent=False)
self.register_parameter("bias", None)
self.register_parameter("weight", None)
self.low_precision = low_precision

def forward(self, x: torch.Tensor) -> torch.Tensor:
if self.low_precision:
@@ -146,6 +143,46 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)


+ class AMDLayerNorm(LayerNormBase):
+ """
+ LayerNorm implemented using PyTorch primitives.
+ We do this to work around a bug in the PyTorch/ROCm implementation of layer norm that fails with a
+ segfault when the bias is not present.
+ """

+ def __init__(self, config: ModelConfig, size: Optional[int] = None, elementwise_affine: Optional[bool] = None):
+ super().__init__(config)
+ self.normalized_shape = (size or config.d_model,)
+ self.eps = 1e-05

+ if elementwise_affine is None:
+ elementwise_affine = self.config.layer_norm_with_affine
+ if elementwise_affine:
+ self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device))
+ use_bias = self.config.bias_for_layer_norm
+ if use_bias is None:
+ use_bias = self.config.include_bias
+ if use_bias:
+ self.bias = nn.Parameter(torch.zeros(self.normalized_shape, device=config.init_device))
+ else:
+ self.register_parameter("bias", None)
+ else:
+ self.register_parameter("bias", None)
+ self.register_parameter("weight", None)

+ def forward(self, x: torch.Tensor) -> torch.Tensor:
+ var, mean = torch.var_mean(x, dim=-1, correction=0, keepdim=True)
+ var.add_(self.eps)
+ var.sqrt_()
+ x = (x - mean) / var
+ if self.weight is not None:
+ x.mul_(self.weight)
+ if self.bias is not None:
+ x.add_(self.bias)
+ return x


class RMSLayerNorm(LayerNorm):
"""
RMS layer norm, a simplified :class:`LayerNorm` implementation that can optionally run
@@ -167,7 +204,10 @@ def __init__(
elementwise_affine = self.config.layer_norm_with_affine
if elementwise_affine:
self.weight = nn.Parameter(torch.ones(self.config.d_model))
- if self.config.include_bias:
+ use_bias = self.config.bias_for_layer_norm
+ if use_bias is None:
+ use_bias = self.config.include_bias
+ if use_bias:
self.bias = nn.Parameter(torch.zeros(self.config.d_model))
else:
self.register_parameter("bias", None)
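`AMDLayerNorm` builds the normalization out of `torch.var_mean` so the backward pass never goes through `F.layer_norm`, which is what segfaults on ROCm when no bias is supplied. Below is a quick standalone check that the manual arithmetic matches `F.layer_norm` when weight and bias are present; it is a sketch written for this note, not a test from the repository.

```python
import torch
import torch.nn.functional as F


def manual_layer_norm(x: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    # Same arithmetic as AMDLayerNorm.forward: biased variance, then scale and shift.
    var, mean = torch.var_mean(x, dim=-1, correction=0, keepdim=True)
    x = (x - mean) / torch.sqrt(var + eps)
    return x * weight + bias


d_model = 16
x = torch.randn(2, 3, d_model)
weight, bias = torch.ones(d_model), torch.zeros(d_model)

reference = F.layer_norm(x, (d_model,), weight=weight, bias=bias, eps=1e-5)
assert torch.allclose(manual_layer_norm(x, weight, bias), reference, atol=1e-5)
```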
58 changes: 0 additions & 58 deletions scripts/v1-mix-medium-on-lumi-no-flash.sh

This file was deleted.

10 changes: 3 additions & 7 deletions scripts/v1-mix-medium-on-lumi.sh
@@ -2,12 +2,12 @@
#SBATCH --job-name=v1-mix-medium
#SBATCH --account=project_462000229
#SBATCH --output=/pfs/lustref1/flash/project_462000229/logs/%j.log
- #SBATCH --nodes=128 # Total number of nodes
+ #SBATCH --nodes=32 # Total number of nodes
#SBATCH --ntasks-per-node=8
#SBATCH --gpus-per-node=8 # Allocate one gpu per MPI rank
#SBATCH --cpus-per-task=6
#SBATCH --time=48:00:00
- #SBATCH --time-min=24:00:00
+ #SBATCH --time-min=8:00:00
#SBATCH --mem=0 # All memory on the node
#SBATCH --partition=standard-g

@@ -36,7 +36,6 @@ export SINGULARITYENV_LD_LIBRARY_PATH=/usr/local/lib:/opt/cray/libfabric/1.15.2.
# Try playing with max_split_size_mb if you run into OOM errors.
export PYTORCH_HIP_ALLOC_CONF=max_split_size_mb:128

- run_name=adamw-normal-init-long-warmup
srun \
--cpus-per-task=$SLURM_CPUS_PER_TASK \
--distribution=block:block \
@@ -50,7 +49,4 @@ srun \
-B /usr/lib64/libcxi.so.1:/usr/lib64/libcxi.so.1 \
-B /usr/lib64/libjson-c.so.3:/usr/lib64/libjson-c.so.3 \
$PROJECT_DIR/containers/$OLMO_CONTAINER \
- python scripts/train.py configs/v1-mix-medium.yaml \
- --run_name=$run_name \
- --model.init_fn=normal \
- --scheduler.t_warmup=5000 ${@}
+ python scripts/train.py configs/v1-mix-medium.yaml --run_name=${SLURM_JOB_ID} ${@}
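The explicit `--model.init_fn=normal` and `--scheduler.t_warmup=5000` overrides disappear from the launch command because those values are now baked into `configs/v1-mix-medium.yaml`; only `--run_name=${SLURM_JOB_ID}` remains. Assuming `scripts/train.py` merges dotted CLI overrides into the YAML config with OmegaConf (the actual parsing in the repository may differ), the two forms are equivalent:

```python
from omegaconf import OmegaConf

# Hypothetical reconstruction of how dotted CLI overrides merge into the YAML config.
base = OmegaConf.create({"model": {"init_fn": "mitchell"}, "scheduler": {"t_warmup": 2000}})
overrides = OmegaConf.from_dotlist(["model.init_fn=normal", "scheduler.t_warmup=5000"])
merged = OmegaConf.merge(base, overrides)

assert merged.model.init_fn == "normal" and merged.scheduler.t_warmup == 5000
# Writing the same values directly into the YAML makes the extra CLI flags redundant.
```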
