diff --git a/configs/v1-mix-medium-mcli.yaml b/configs/v1-mix-medium-mcli.yaml
index ef75d15e6..6ad2eafe0 100644
--- a/configs/v1-mix-medium-mcli.yaml
+++ b/configs/v1-mix-medium-mcli.yaml
@@ -10,7 +10,7 @@ wandb:
 model:
   d_model: 4096
   n_heads: 16
-  n_layers: 30
+  n_layers: 29
   mlp_ratio: 8
   alibi: true
   alibi_bias_max: 8.0
@@ -20,7 +20,8 @@ model:
   include_bias: false
   block_type: sequential
   layer_norm_type: low_precision
-  layer_norm_with_affine: false
+  layer_norm_with_affine: true # workaround for the layer norm bug
+  bias_for_layer_norm: true # workaround for the layer norm bug
   activation_type: swiglu
   residual_dropout: 0.0
   embedding_dropout: 0.0
@@ -30,7 +31,7 @@ model:
   eos_token_id: 0
   pad_token_id: 1
   init_device: meta
-  init_fn: mitchell
+  init_fn: normal

 compile: null # causes instability on AMD GPUs

diff --git a/configs/v1-mix-medium.yaml b/configs/v1-mix-medium.yaml
index 84e958ac3..258cf13c8 100644
--- a/configs/v1-mix-medium.yaml
+++ b/configs/v1-mix-medium.yaml
@@ -10,7 +10,7 @@ wandb:
 model:
   d_model: 4096
   n_heads: 16
-  n_layers: 30
+  n_layers: 29
   mlp_ratio: 8
   alibi: true
   alibi_bias_max: 8.0
@@ -20,7 +20,8 @@ model:
   include_bias: false
   block_type: sequential
   layer_norm_type: low_precision
-  layer_norm_with_affine: false
+  layer_norm_with_affine: true # workaround for the layer norm bug
+  bias_for_layer_norm: true # workaround for the layer norm bug
   activation_type: swiglu
   residual_dropout: 0.0
   embedding_dropout: 0.0
@@ -51,7 +52,7 @@ scheduler:
 data:
   paths: ${path.glob:${oc.env:FLASH_DIR,no_exist}/preprocessed/olmo-mix/v1-sample/gpt-neox-20b-pii-special/books/*.npy,${oc.env:FLASH_DIR,no_exist}/preprocessed/olmo-mix/v1-sample/gpt-neox-20b-pii-special/c4/*.npy,${oc.env:FLASH_DIR,no_exist}/preprocessed/olmo-mix/v1-sample/gpt-neox-20b-pii-special/common-crawl/*/*.npy,${oc.env:FLASH_DIR,no_exist}/preprocessed/olmo-mix/v1-sample/gpt-neox-20b-pii-special/s2/*.npy,${oc.env:FLASH_DIR,no_exist}/preprocessed/olmo-mix/v1-sample/gpt-neox-20b-pii-special/stack/*.npy,${oc.env:FLASH_DIR,no_exist}/preprocessed/olmo-mix/v1-sample/gpt-neox-20b-pii-special/wiki/*.npy}
   pad_direction: right
-  num_workers: 1
+  num_workers: 0
   drop_last: true
   pin_memory: true
   prefetch_factor: 16
@@ -65,7 +66,7 @@ tokenizer:
 save_folder: ${path.choose:${oc.env:FLASH_DIR,no_exist}/checkpoints,/results}/${oc.env:SLURM_JOB_ID,${run_name}}
 save_overwrite: false
 # Sharded checkpoints (best for restarts)
-save_interval: 5000
+save_interval: 1000
 save_num_checkpoints_to_keep: -1
 # Unsharded checkpoints (for final storage)
 save_interval_unsharded: null # getting errors on LUMI right now
@@ -88,30 +89,6 @@ eval_interval: ${save_interval}
 eval_subset_num_batches: -1
 device_eval_batch_size: ${device_train_microbatch_size}
 evaluators:
-  ##########################
-  # Perplexity evaluations #
-  ##########################
-  # TODO: do we care about c4 and RP validation? We don't have these tokenized at the moment.
-  # - label: c4-validation
-  #   subset_num_batches: 10
-  #   data:
-  #     paths: ${path.glob:${path.choose:${oc.env:SCRATCH_DIR,no_exist}/pretraining_data/preprocessed,/net/nfs.cirrascale/allennlp/llm-data}/c4/en/c4-validation.*.npy}
-  #     num_workers: 2
-  #     drop_last: true
-  #     pin_memory: true
-  #     persistent_workers: true
-  #     prefetch_factor: 4
-
-  # - label: rp-validation
-  #   subset_num_batches: 10
-  #   data:
-  #     paths: ${path.glob:${path.choose:${oc.env:SCRATCH_DIR,no_exist}/pretraining_data/preprocessed,/net/nfs.cirrascale/allennlp/llm-data}/redpajama/redpajama-validation.npy}
-  #     num_workers: 2
-  #     drop_last: true
-  #     pin_memory: true
-  #     persistent_workers: true
-  #     prefetch_factor: 4
-
   # lump all the small datasets together (we still get separate metrics).
   - label: all-small-ppl-validation
     data:
diff --git a/configs/v1-mix-small-mcli.yaml b/configs/v1-mix-small-mcli.yaml
index f2d68e93b..d700d6f1a 100644
--- a/configs/v1-mix-small-mcli.yaml
+++ b/configs/v1-mix-small-mcli.yaml
@@ -20,7 +20,8 @@ model:
   include_bias: false
   block_type: sequential
   layer_norm_type: low_precision
-  layer_norm_with_affine: false
+  layer_norm_with_affine: true # workaround for the layer norm bug
+  bias_for_layer_norm: true # workaround for the layer norm bug
   activation_type: swiglu
   residual_dropout: 0.0
   embedding_dropout: 0.0
@@ -30,7 +31,7 @@ model:
   eos_token_id: 0
   pad_token_id: 1
   init_device: meta
-  init_fn: mitchell
+  init_fn: normal

 compile: null # causes instability on AMD GPUs

@@ -44,7 +45,7 @@ optimizer:

 scheduler:
   name: cosine_with_warmup
-  t_warmup: 2000
+  t_warmup: 5000
   alpha_f: 0.1

 tokenizer:
diff --git a/configs/v1-mix-small.yaml b/configs/v1-mix-small.yaml
index b376b78a7..f4ffe4ecc 100644
--- a/configs/v1-mix-small.yaml
+++ b/configs/v1-mix-small.yaml
@@ -20,7 +20,8 @@ model:
   include_bias: false
   block_type: sequential
   layer_norm_type: low_precision
-  layer_norm_with_affine: false
+  layer_norm_with_affine: true # workaround for the layer norm bug
+  bias_for_layer_norm: true # workaround for the layer norm bug
   activation_type: swiglu
   residual_dropout: 0.0
   embedding_dropout: 0.0
@@ -30,7 +31,7 @@ model:
   eos_token_id: 0
   pad_token_id: 1
   init_device: meta
-  init_fn: mitchell
+  init_fn: normal

 compile: null # causes instability on AMD GPUs

@@ -44,7 +45,7 @@ optimizer:

 scheduler:
   name: cosine_with_warmup
-  t_warmup: 2000
+  t_warmup: 5000
   alpha_f: 0.1

 data:
diff --git a/olmo/config.py b/olmo/config.py
index b015e25b2..90d92c0b6 100644
--- a/olmo/config.py
+++ b/olmo/config.py
@@ -169,6 +169,11 @@ class LayerNormType(StrEnum):
     A low-precision version of RMSNorm.
     """

+    amd_compatible = "amd_compatible"
+    """
+    LayerNorm implemented manually to work around an issue with ROCm.
+    """
+

 class ActivationType(StrEnum):
     gelu = "gelu"
@@ -301,7 +306,8 @@ class ModelConfig(BaseConfig):
     """
     Whether to include bias and weight parameters for the layer norms.
     This only affects layer norms that are immediately followed by a linear layer in the forward pass.
-    Other layer norms, such as those applied to attention keys and queries, will always include an elementwise affine transform.
+    Other layer norms, such as those applied to attention keys and queries, will always include an elementwise
+    affine transform.
     """

     max_sequence_length: int = 1024
@@ -316,6 +322,14 @@ class ModelConfig(BaseConfig):
     models tend to have near 0 bias terms anyway.
     """

+    bias_for_layer_norm: Optional[bool] = None
+    """
+    Whether or not to include bias parameters in layer norm.
+    This is separate from the include_bias parameter, because of a ROCm crash when biases are disabled in
+    layer norm.
+    When this is None (the default), it inherits the setting from include_bias.
+    """
+
     scale_logits: bool = False
     """
     If ``True``, scale the output logits by ``1 / sqrt(d_model)``.
diff --git a/olmo/model.py b/olmo/model.py
index 3dc6e7699..bd8b11881 100644
--- a/olmo/model.py
+++ b/olmo/model.py
@@ -72,6 +72,8 @@ def build(cls, config: ModelConfig, size: Optional[int] = None, **kwargs) -> Lay
             return RMSLayerNorm(config, size=size, low_precision=False, **kwargs)
         elif config.layer_norm_type == LayerNormType.low_precision_rms:
             return RMSLayerNorm(config, size=size, low_precision=True, **kwargs)
+        elif config.layer_norm_type == LayerNormType.amd_compatible:
+            return AMDLayerNorm(config, size=size, **kwargs)
         else:
             raise NotImplementedError(f"Not sure how to handle '{config.layer_norm_type}' LayerNorm type")

@@ -108,27 +110,22 @@ def __init__(
         super().__init__(config)
         self.normalized_shape = (size or config.d_model,)
         self.eps = 1e-05
-        self.low_precision = low_precision
-
-        # We always have weight and bias even if they are turned off/set to 1 and 0, because ROCm has a
-        # bug where F.layer_norm() crashes during the backwards pass when no bias was given.
-        # When they are turned off, they need to be buffers, because FSDP can't handle the situation
-        # where some parameters don't require gradients.
         if elementwise_affine is None:
             elementwise_affine = self.config.layer_norm_with_affine
-        weight = torch.ones(self.normalized_shape, device=config.init_device)
         if elementwise_affine:
-            self.register_parameter("weight", nn.Parameter(weight))
-        else:
-            self.register_buffer("weight", weight, persistent=False)
-
-        needs_bias = elementwise_affine and self.config.include_bias
-        bias = torch.zeros(self.normalized_shape, device=config.init_device)
-        if needs_bias:
-            self.register_parameter("bias", nn.Parameter(bias))
+            self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device))
+            use_bias = self.config.bias_for_layer_norm
+            if use_bias is None:
+                use_bias = self.config.include_bias
+            if use_bias:
+                self.bias = nn.Parameter(torch.zeros(self.normalized_shape, device=config.init_device))
+            else:
+                self.register_parameter("bias", None)
         else:
-            self.register_buffer("bias", bias, persistent=False)
+            self.register_parameter("bias", None)
+            self.register_parameter("weight", None)
+        self.low_precision = low_precision

     def forward(self, x: torch.Tensor) -> torch.Tensor:
         if self.low_precision:
@@ -146,6 +143,46 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return F.layer_norm(x, self.normalized_shape, weight=self.weight, bias=self.bias, eps=self.eps)


+class AMDLayerNorm(LayerNormBase):
+    """
+    LayerNorm implemented using PyTorch primitives.
+
+    We do this to work around a bug in the PyTorch/ROCm implementation of layer norm that fails with a
+    segfault when the bias is not present.
+ """ + + def __init__(self, config: ModelConfig, size: Optional[int] = None, elementwise_affine: Optional[bool] = None): + super().__init__(config) + self.normalized_shape = (size or config.d_model,) + self.eps = 1e-05 + + if elementwise_affine is None: + elementwise_affine = self.config.layer_norm_with_affine + if elementwise_affine: + self.weight = nn.Parameter(torch.ones(self.normalized_shape, device=config.init_device)) + use_bias = self.config.bias_for_layer_norm + if use_bias is None: + use_bias = self.config.include_bias + if use_bias: + self.bias = nn.Parameter(torch.zeros(self.normalized_shape, device=config.init_device)) + else: + self.register_parameter("bias", None) + else: + self.register_parameter("bias", None) + self.register_parameter("weight", None) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + var, mean = torch.var_mean(x, dim=-1, correction=0, keepdim=True) + var.add_(self.eps) + var.sqrt_() + x = (x - mean) / var + if self.weight is not None: + x.mul_(self.weight) + if self.bias is not None: + x.add_(self.bias) + return x + + class RMSLayerNorm(LayerNorm): """ RMS layer norm, a simplified :class:`LayerNorm` implementation that can optionally run @@ -167,7 +204,10 @@ def __init__( elementwise_affine = self.config.layer_norm_with_affine if elementwise_affine: self.weight = nn.Parameter(torch.ones(self.config.d_model)) - if self.config.include_bias: + use_bias = self.config.bias_for_layer_norm + if use_bias is None: + use_bias = self.config.include_bias + if use_bias: self.bias = nn.Parameter(torch.zeros(self.config.d_model)) else: self.register_parameter("bias", None) diff --git a/scripts/v1-mix-medium-on-lumi-no-flash.sh b/scripts/v1-mix-medium-on-lumi-no-flash.sh deleted file mode 100644 index 68fbb63e3..000000000 --- a/scripts/v1-mix-medium-on-lumi-no-flash.sh +++ /dev/null @@ -1,58 +0,0 @@ -#!/bin/bash -#SBATCH --job-name=v1-mix-medium -#SBATCH --account=project_462000229 -#SBATCH --output=/scratch/project_462000229/logs/%j.log -#SBATCH --nodes=128 # Total number of nodes -#SBATCH --ntasks-per-node=8 -#SBATCH --gpus-per-node=8 # Allocate one gpu per MPI rank -#SBATCH --cpus-per-task=6 -#SBATCH --time=48:00:00 -#SBATCH --time-min=24:00:00 -#SBATCH --mem=0 # All memory on the node -#SBATCH --partition=standard-g - -module load LUMI/22.08 partition/G - -export OLMO_CONTAINER=llm-lumi_latest.sif - -export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK -export MPICH_GPU_SUPPORT_ENABLED=1 -export NCCL_SOCKET_IFNAME=hsn -export NCCL_NET_GDR_LEVEL=3 -export MIOPEN_USER_DB_PATH=/tmp/${USER}-miopen-cache-${SLURM_JOB_ID} -export MIOPEN_CUSTOM_CACHE_DIR=${MIOPEN_USER_DB_PATH} -export CXI_FORK_SAFE=1 -export CXI_FORK_SAFE_HP=1 -export FI_CXI_DISABLE_CQ_HUGETLB=1 - -# We need to set this to avoid "Cassini Event Queue overflow detected." errors. -export FI_CXI_DEFAULT_CQ_SIZE=131072 - -#export NCCL_DEBUG=INFO -export PYTHONPATH=.:${PYTHONPATH} -export ROCM_PATH=/opt/rocm -export SINGULARITYENV_LD_LIBRARY_PATH=/usr/local/lib:/opt/cray/libfabric/1.15.2.0/lib64 - -# Try playing with max_split_size_mb if you run into OOM errors. 
-export PYTORCH_HIP_ALLOC_CONF=max_split_size_mb:128
-
-run_name=adamw-normal-init-long-warmup
-srun \
-  --cpus-per-task=$SLURM_CPUS_PER_TASK \
-  --distribution=block:block \
-  --kill-on-bad-exit \
-  scripts/run_with_environment.sh \
-    singularity exec \
-    -B"$PROJECT_DIR:$PROJECT_DIR" \
-    -B"$SCRATCH_DIR:$SCRATCH_DIR" \
-    -B /opt/cray:/opt/cray \
-    -B /usr/lib64/libcxi.so.1:/usr/lib64/libcxi.so.1 \
-    -B /usr/lib64/libjson-c.so.3:/usr/lib64/libjson-c.so.3 \
-    $PROJECT_DIR/containers/$OLMO_CONTAINER \
-    python scripts/train.py configs/v1-mix-medium-mcli.yaml \
-      --save_folder=$SCRATCH_DIR/checkpoints/$run_name
-      --run_name=$run_name \
-      --model.init_fn=normal \
-      --scheduler.t_warmup=5000 ${@}
-
-    # -B"$FLASH_DIR:$FLASH_DIR" \
diff --git a/scripts/v1-mix-medium-on-lumi.sh b/scripts/v1-mix-medium-on-lumi.sh
index 66ac2ce96..776808e2a 100644
--- a/scripts/v1-mix-medium-on-lumi.sh
+++ b/scripts/v1-mix-medium-on-lumi.sh
@@ -2,12 +2,12 @@
 #SBATCH --job-name=v1-mix-medium
 #SBATCH --account=project_462000229
 #SBATCH --output=/pfs/lustref1/flash/project_462000229/logs/%j.log
-#SBATCH --nodes=128 # Total number of nodes
+#SBATCH --nodes=32 # Total number of nodes
 #SBATCH --ntasks-per-node=8
 #SBATCH --gpus-per-node=8 # Allocate one gpu per MPI rank
 #SBATCH --cpus-per-task=6
 #SBATCH --time=48:00:00
-#SBATCH --time-min=24:00:00
+#SBATCH --time-min=8:00:00
 #SBATCH --mem=0 # All memory on the node
 #SBATCH --partition=standard-g

@@ -36,7 +36,6 @@ export SINGULARITYENV_LD_LIBRARY_PATH=/usr/local/lib:/opt/cray/libfabric/1.15.2.
 # Try playing with max_split_size_mb if you run into OOM errors.
 export PYTORCH_HIP_ALLOC_CONF=max_split_size_mb:128

-run_name=adamw-normal-init-long-warmup
 srun \
   --cpus-per-task=$SLURM_CPUS_PER_TASK \
   --distribution=block:block \
@@ -50,7 +49,4 @@ srun \
     -B /usr/lib64/libcxi.so.1:/usr/lib64/libcxi.so.1 \
     -B /usr/lib64/libjson-c.so.3:/usr/lib64/libjson-c.so.3 \
     $PROJECT_DIR/containers/$OLMO_CONTAINER \
-    python scripts/train.py configs/v1-mix-medium.yaml \
-      --run_name=$run_name \
-      --model.init_fn=normal \
-      --scheduler.t_warmup=5000 ${@}
+    python scripts/train.py configs/v1-mix-medium.yaml --run_name=${SLURM_JOB_ID} ${@}
diff --git a/tests/model_test.py b/tests/model_test.py
index 7770c350a..e0d46ea31 100644
--- a/tests/model_test.py
+++ b/tests/model_test.py
@@ -6,6 +6,7 @@
 from olmo import BlockType, LayerNorm, Olmo, Tokenizer, TrainConfig
 from olmo.config import PaddingDirection
 from olmo.data import DataCollator
+from olmo.model import AMDLayerNorm


 @pytest.mark.parametrize(
@@ -395,6 +396,7 @@ def test_layer_norm(train_config: TrainConfig, elementwise_affine: bool, include
     train_config.model.layer_norm_with_affine = elementwise_affine
     train_config.model.include_bias = include_bias
     ln = LayerNorm.build(train_config.model)
+    amd_ln = AMDLayerNorm(train_config.model)

     needs_weight = elementwise_affine
     needs_bias = elementwise_affine and include_bias
@@ -402,21 +404,28 @@ def test_layer_norm(train_config: TrainConfig, elementwise_affine: bool, include
         if needs_weight:
             weight = torch.randn(train_config.model.d_model)
             ln.weight.copy_(weight)
+            amd_ln.weight.copy_(weight)
         else:
             weight = None
         if needs_bias:
             bias = torch.randn(train_config.model.d_model)
             ln.bias.copy_(bias)
+            amd_ln.bias.copy_(bias)
         else:
             bias = None

     assert ln.bias is None or ln.bias.requires_grad == needs_bias
     assert ln.weight is None or ln.weight.requires_grad == needs_weight
+    assert amd_ln.bias is None or amd_ln.bias.requires_grad == needs_bias
+    assert amd_ln.weight is None or amd_ln.weight.requires_grad == needs_weight

     x = torch.randn(16, 1024, train_config.model.d_model)
     x.requires_grad = False
+    y_expected = F.layer_norm(x, [train_config.model.d_model], weight, bias)
     y_actual = ln(x)
-    y_expected = F.layer_norm(x, [train_config.model.d_model], weight, bias)
+    torch.testing.assert_close(y_actual, y_expected)
+
+    y_actual = amd_ln(x)
     torch.testing.assert_close(y_actual, y_expected)
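
For reference, the arithmetic that the new AMDLayerNorm.forward performs can be checked against F.layer_norm with a few lines of plain PyTorch, outside the repo. The sketch below is a standalone illustration, not part of the diff: the helper name manual_layer_norm is hypothetical, and it assumes a PyTorch version whose torch.var_mean accepts the correction keyword, as the code above does. It mirrors the comparison the updated test_layer_norm makes.

# Standalone sketch (not from the diff above): reproduce the AMDLayerNorm math with
# torch primitives and compare it to F.layer_norm, the oracle the test uses.
import torch
import torch.nn.functional as F


def manual_layer_norm(x, weight=None, bias=None, eps=1e-5):
    # Normalize over the last dimension using var/mean primitives instead of
    # calling F.layer_norm, the call that crashes on ROCm when no bias is given.
    var, mean = torch.var_mean(x, dim=-1, correction=0, keepdim=True)
    x = (x - mean) / torch.sqrt(var + eps)
    if weight is not None:
        x = x * weight
    if bias is not None:
        x = x + bias
    return x


if __name__ == "__main__":
    d_model = 64
    x = torch.randn(4, 16, d_model)
    weight = torch.randn(d_model)
    # No bias on purpose: this is the case the workaround exists for.
    y_manual = manual_layer_norm(x, weight=weight, bias=None)
    y_ref = F.layer_norm(x, [d_model], weight, None)
    torch.testing.assert_close(y_manual, y_ref)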