diff --git a/configs/v1-mix-medium.yaml b/configs/v1-mix-medium.yaml index b9023c218..b3ee1a1e8 100644 --- a/configs/v1-mix-medium.yaml +++ b/configs/v1-mix-medium.yaml @@ -10,7 +10,7 @@ wandb: model: d_model: 4096 n_heads: 16 - n_layers: 29 + n_layers: 30 mlp_ratio: 8 alibi: true alibi_bias_max: 8.0 diff --git a/olmo/train.py b/olmo/train.py index 65411ac2a..0ea662cb9 100644 --- a/olmo/train.py +++ b/olmo/train.py @@ -11,6 +11,7 @@ from dataclasses import dataclass, field from itertools import islice from pathlib import Path +from pstats import SortKey from typing import Any, Deque, Dict, List, Optional, TextIO, Tuple import numpy as np @@ -980,7 +981,7 @@ def fit(self): profiler.enable() elif self.global_step == 8: profiler.disable() - profiler.print_stats() + profiler.print_stats(sort=SortKey.CUMULATIVE) profiler = None else: log.info("Training loop complete") diff --git a/scripts/run_with_environment.sh b/scripts/run_with_environment.sh index facf88c7b..d69027d88 100755 --- a/scripts/run_with_environment.sh +++ b/scripts/run_with_environment.sh @@ -4,11 +4,7 @@ set -euo pipefail -# Redirect stdout and stderr so that we get a prefix with the node name export NODENAME=$(hostname -s) -exec > >(trap "" INT TERM; sed -u "s/^/$NODENAME out: /") -exec 2> >(trap "" INT TERM; sed -u "s/^/$NODENAME err: /" >&2) - export MASTER_ADDR=$(scontrol show hostnames | head -n 1) export MASTER_PORT=39591 export WORLD_SIZE=$SLURM_NTASKS @@ -18,6 +14,10 @@ export LOCAL_WORLD_SIZE=$SLURM_NTASKS_PER_NODE export LOCAL_RANK=$SLURM_LOCALID export NODE_RANK=$((($RANK - $LOCAL_RANK) / $LOCAL_WORLD_SIZE)) +# Redirect stdout and stderr so that we get a prefix with the node name +exec > >(trap "" INT TERM; sed -u "s/^/$NODENAME:$LOCAL_RANK out: /") +exec 2> >(trap "" INT TERM; sed -u "s/^/$NODENAME:$LOCAL_RANK err: /" >&2) + if [ $SLURM_LOCALID -eq 0 ] ; then rm -rf /dev/shm/* || true rocm-smi || true # rocm-smi returns exit code 2 even when it succeeds