fsdp config dict fix, todo list, add torchdistx support
winglian committed Apr 30, 2023
1 parent 9190ada commit ad2b48c
Showing 3 changed files with 24 additions and 3 deletions.
10 changes: 10 additions & 0 deletions TODO.md
@@ -0,0 +1,10 @@
# todo list

- [ ] Validation of parameters for combinations that won't work



## things that are known not to work

- FSDP offload and gradient_checkpointing - https://github.com/pytorch/pytorch/issues/82203
- adamw_bnb_8bit doesn't play well with FSDP offload
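As a rough illustration of the first TODO item, a validation pass could reject the combinations listed under "things that are known not to work". This is a minimal sketch, not code from this commit; the `validate_config` name and the exact `cfg` fields it reads are assumptions.

```python
# Minimal sketch only: validate_config and the cfg fields used here are
# assumptions, not part of this commit.
def validate_config(cfg):
    fsdp = getattr(cfg, "fsdp", None) or ""
    fsdp_offload = "offload" in fsdp
    if fsdp_offload and getattr(cfg, "gradient_checkpointing", False):
        # https://github.com/pytorch/pytorch/issues/82203
        raise ValueError("FSDP offload cannot be combined with gradient_checkpointing")
    if fsdp_offload and getattr(cfg, "optimizer", None) == "adamw_bnb_8bit":
        raise ValueError("adamw_bnb_8bit is not supported with FSDP offload")
```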
5 changes: 5 additions & 0 deletions src/axolotl/utils/models.py
@@ -179,6 +179,11 @@ def load_model(
m.scales = m.scales.half()
m.bias = m.bias.half()

if torch.cuda.device_count() > 1 and int(os.getenv("WORLD_SIZE", "1")) > 1:
model.is_parallelizable = True
model.model_parallel = True


# TODO resume_from_checkpoint handling
return model, tokenizer, lora_config

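For context on the models.py change: the two flags are only set when more than one GPU is visible and the job runs with more than one process; `WORLD_SIZE` is populated by launchers such as torchrun or accelerate. A small standalone sketch of the same condition, illustrative only:

```python
import os

import torch

# Mirrors the condition added in load_model(): both multiple visible GPUs and a
# multi-process launch (WORLD_SIZE is set by launchers like torchrun) are
# required before the model-parallel flags are enabled.
multi_gpu = torch.cuda.device_count() > 1
multi_rank = int(os.getenv("WORLD_SIZE", "1")) > 1
print("model parallel flags enabled:", multi_gpu and multi_rank)
```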
12 changes: 9 additions & 3 deletions src/axolotl/utils/trainer.py
@@ -1,5 +1,7 @@
import importlib
import math
import os
import sys
from pathlib import Path

import bitsandbytes as bnb
@@ -35,9 +37,9 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):
else:
training_arguments_kwargs["gradient_checkpointing"] = cfg.gradient_checkpointing
if cfg.fsdp:
training_arguments_kwargs["fsdp"] = cfg.fsdp.split(" ")
if cfg.fsdp_transformer_layer_cls_to_wrap:
training_arguments_kwargs["fsdp_transformer_layer_cls_to_wrap"] = cfg.fsdp_transformer_layer_cls_to_wrap
training_arguments_kwargs["fsdp"] = cfg.fsdp
if cfg.fsdp_config:
training_arguments_kwargs["fsdp_config"] = dict(cfg.fsdp_config)


# deepspeed
@@ -73,6 +75,10 @@ def setup_trainer(cfg, train_dataset, eval_dataset, model, tokenizer):

trainer_kwargs = {}

if cfg.optimizer == "adamw_anyprecision":
if Path(cfg.torchdistx_path).exists():
sys.path.append(cfg.torchdistx_path)
torchdistx = importlib.import_module('torchdistx')
if cfg.optimizer == "adam8bit" and not cfg.load_4bit and not "deepspeed" in training_arguments_kwargs:
decay_parameters = get_parameter_names(model, [nn.LayerNorm])
decay_parameters = [name for name in decay_parameters if "bias" not in name]
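Putting the trainer.py changes together: `cfg.fsdp` is now forwarded unmodified to the `fsdp` training argument, `cfg.fsdp_config` is converted to a plain dict for `fsdp_config`, and a local torchdistx checkout is added to `sys.path` when the `adamw_anyprecision` optimizer is selected. A hedged sketch of a config fragment that would exercise both paths; the values, the path, and the exact `fsdp_config` key names are placeholders that depend on the transformers version in use, not defaults from this commit.

```python
# Illustrative config fragment only; values and key names are assumptions.
cfg_fragment = {
    # passed straight through as the fsdp training argument
    "fsdp": "full_shard auto_wrap",
    # converted with dict(...) and passed as fsdp_config
    "fsdp_config": {
        "fsdp_transformer_layer_cls_to_wrap": "LlamaDecoderLayer",
    },
    # selecting this optimizer triggers the torchdistx import
    "optimizer": "adamw_anyprecision",
    # appended to sys.path if the directory exists
    "torchdistx_path": "/path/to/torchdistx",
}
```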
