Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion src/art/megatron/jobs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Literal

from pydantic import BaseModel
from pydantic import BaseModel, Field

from .. import types
from ..preprocessing.pack import DiskPackedTensors
Expand Down Expand Up @@ -31,6 +31,7 @@ class MegatronSFTTrainingJob(BaseModel):
grad_accumulation_sequences: int | None = None
weight_decay: float = 0.0
max_grad_norm: float = 1.0
internal_checkpoint_interval: int | None = Field(default=None, ge=1)
log_path: str = DEFAULT_TRAINING_LOG_PATH


Expand Down
35 changes: 32 additions & 3 deletions src/art/megatron/merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,16 @@
save_file = safetensors_torch.save_file


def merge_lora_adapter(lora_path: str) -> None:
base_dir = Path(lora_path)
def _load_adapter_shards(
base_dir: Path,
) -> tuple[
dict[str, torch.Tensor],
list[Path],
list[Path],
]:
shard_filenames = sorted(base_dir.glob("adapter_model-*-of-*.safetensors"))
if not shard_filenames:
return
raise FileNotFoundError(f"No adapter shards found in {base_dir}")

shard_files_by_suffix = {
path.name.removeprefix("adapter_model-").removesuffix(".safetensors"): path
Expand Down Expand Up @@ -93,6 +98,30 @@ def merge_lora_adapter(lora_path: str) -> None:
concat_dim = 1 if "lora_A" in key else 0
tensor = torch.cat(ordered_shards, dim=concat_dim)
adapter_model[key] = tensor
return adapter_model, shard_filenames, manifest_filenames


def load_lora_adapter_state_dict(lora_path: str) -> dict[str, torch.Tensor]:
    """Load a LoRA adapter's tensors from *lora_path* as a state dict.

    Prefers a pre-merged ``adapter_model.safetensors`` file when one exists;
    otherwise reconstructs the state dict from the per-rank
    ``adapter_model-*-of-*.safetensors`` shards via ``_load_adapter_shards``.

    Raises:
        FileNotFoundError: if neither a merged file nor any shards are present
            (propagated from ``_load_adapter_shards``).
    """
    root = Path(lora_path)
    merged_path = root / "adapter_model.safetensors"
    if not merged_path.exists():
        # No merged checkpoint yet — stitch the shards back together.
        state_dict, _shards, _manifests = _load_adapter_shards(root)
        return state_dict
    with safe_open(merged_path, framework="pt") as tensor_file:
        return {name: tensor_file.get_tensor(name) for name in tensor_file.keys()}


def merge_lora_adapter(lora_path: str) -> None:
base_dir = Path(lora_path)
try:
adapter_model, shard_filenames, manifest_filenames = _load_adapter_shards(
base_dir
)
except FileNotFoundError:
return

adapter_model_path = base_dir / "adapter_model.safetensors"
save_file(adapter_model, adapter_model_path)
Expand Down
25 changes: 18 additions & 7 deletions src/art/megatron/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,7 @@
MegatronTrainingJob,
)
from art.megatron.lora import apply_lora_adapters
from art.megatron.merge import merge_lora_adapter
from art.megatron.merge import load_lora_adapter_state_dict, merge_lora_adapter
from art.megatron.offload import (
OffloadState,
clear_optimizer_state,
Expand All @@ -66,7 +66,6 @@
safetensors = importlib.import_module("safetensors")
safetensors_torch = importlib.import_module("safetensors.torch")
safe_open = safetensors.safe_open
load_file = safetensors_torch.load_file
save_file = safetensors_torch.save_file

DEFAULT_MODEL_IDENTIFIER = "Qwen/Qwen3-30B-A3B-Instruct-2507"
Expand Down Expand Up @@ -496,6 +495,7 @@ def run_megatron_sft_job(
grad_accumulation_sequences = resolve_global_grad_accumulation_sequences(
job.grad_accumulation_sequences
)
checkpoint_interval = job.internal_checkpoint_interval

for batch_idx in range(job.num_batches):
batch_start_time = time.perf_counter()
Expand Down Expand Up @@ -550,6 +550,20 @@ def run_megatron_sft_job(
)
batch_time = time.perf_counter() - batch_start_time
tokens_per_second = global_tokens / batch_time if batch_time > 0 else 0.0
completed_batches = batch_idx + 1

if (
checkpoint_interval is not None
and completed_batches < job.num_batches
and completed_batches % checkpoint_interval == 0
):
_save_lora_and_optimizer(
runtime,
adapter_model=adapter_model,
lora_path=job.lora_path,
optimizer_state_path=job.optimizer_state_path,
)
torch.distributed.barrier() # type: ignore[possibly-missing-attribute]

if runtime.rank == 0:
with open(job.log_path, "a+", encoding="utf-8") as log_file:
Expand Down Expand Up @@ -609,11 +623,8 @@ def _load_lora_and_optimizer(
lora_path: str,
optimizer_state_path: str,
) -> dict[str, torch.Tensor]:
adapter_model_path = os.path.join(lora_path, "adapter_model.safetensors")
if not os.path.exists(adapter_model_path):
raise FileNotFoundError(f"No adapter model found at {adapter_model_path}")
print0(runtime.rank, "Loading adapter model from", adapter_model_path)
adapter_model = load_file(adapter_model_path)
print0(runtime.rank, "Loading adapter model from", lora_path)
adapter_model = load_lora_adapter_state_dict(lora_path)
load_adapter_into_model(runtime.model, adapter_model, runtime.optimizer)

optimizer_shard_path = os.path.join(
Expand Down
Loading