Merge pull request FlagAI-Open#266 from shunxing1234/master
add optimizers
ftgreat committed Mar 16, 2023
2 parents 1f72cf5 + 83a9bf4 commit 2e0347e
Showing 2 changed files with 32 additions and 1 deletion.
28 changes: 28 additions & 0 deletions flagai/optimizers.py
@@ -103,6 +103,34 @@ def get_optimizer(param_groups,
                               lr=lr,
                               relative_step=False,
                               warmup_init=False)
+    elif optimizer == 'adamw':
+        from torch.optim import AdamW
+        optimizer = AdamW(param_groups,
+                          lr=lr,
+                          weight_decay=weight_decay,
+                          betas=(adam_beta1, adam_beta2),
+                          eps=adam_eps)
+    elif optimizer == 'lion':
+        from lion_pytorch import Lion
+        optimizer = Lion(param_groups,
+                         lr=lr,
+                         weight_decay=weight_decay,
+                         betas=(adam_beta1, adam_beta2)
+                         )
+    elif optimizer == 'adan':
+        from adan import Adan
+        optimizer = Adan(param_groups,
+                         lr=lr,
+                         weight_decay=weight_decay,
+                         betas=(adam_beta1, adam_beta2, 0.99),
+                         eps=adam_eps)
+    elif optimizer == 'lamb':
+        from torch_optimizer import Lamb
+        optimizer = Lamb(param_groups,
+                         lr=lr,
+                         weight_decay=weight_decay,
+                         betas=(adam_beta1, adam_beta2),
+                         eps=adam_eps)
     else:
         raise NotImplementedError
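For reference, the new branches lean on three third-party packages: lion_pytorch (pip install lion-pytorch), adan (the official Adan repository), and torch_optimizer (pip install torch-optimizer). Adan takes a three-element betas tuple, which is why that branch hard-codes 0.99 as beta3, and Lion has no eps term, which is why that branch omits one. A minimal sketch of driving one of these constructors directly on a toy model — the model, batch, and hyperparameter values here are illustrative, not part of this commit:

    import torch
    from torch import nn
    from lion_pytorch import Lion  # pip install lion-pytorch

    # Toy model and batch, just to exercise one optimizer step.
    model = nn.Linear(10, 2)
    x, y = torch.randn(8, 10), torch.randint(0, 2, (8,))

    # Lion takes (params, lr, betas, weight_decay) like AdamW but no eps,
    # since its update uses only the sign of the interpolated momentum.
    optimizer = Lion(model.parameters(), lr=1e-4, betas=(0.9, 0.99), weight_decay=0.01)

    loss = nn.functional.cross_entropy(model(x), y)
    loss.backward()
    optimizer.step()
    optimizer.zero_grad()

Lion is typically run with a learning rate several times smaller than AdamW's, so callers switching optimizer strings may want to retune lr as well.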

5 changes: 4 additions & 1 deletion flagai/trainer.py
@@ -159,6 +159,7 @@ def __init__(
         deepspeed_config=None,
         model_parallel_size=1,
         training_script="train.py",
+        optimizer_type='adam',
     ):
 
         if timers is not None:
@@ -185,6 +186,8 @@ def __init__(
         self.eval_interval = eval_interval
         self.tokenizer = tokenizer
 
+        self.optimizer_type = optimizer_type
+
         # model checkpointing
         self.save_dir = save_dir
         self.save_interval = save_interval
@@ -471,7 +474,7 @@ def train(self,
                 cpu_optimizer=False,
                 cpu_torch_adam=False,
                 fp16=self.fp16,
-                optimizer='adam') # if not self.fp16 else 'adafactor')
+                optimizer=self.optimizer_type) # if not self.fp16 else 'adafactor')
 
         if lr_scheduler == None and optimizer != None and self.warm_up > 0 and 'deepspeed' not in self.env_type and self.epochs > 0:
             if self.env_type == 'bmtrain':
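Together the two files thread the new optimizer_type argument from the Trainer constructor down to get_optimizer, so callers pick an optimizer by name while the default 'adam' keeps existing behavior. A minimal usage sketch — the surrounding constructor arguments (env_type, epochs, lr) are illustrative values for the pre-existing Trainer API, not part of this diff:

    from flagai.trainer import Trainer

    # 'lion' routes get_optimizer into the new lion_pytorch branch;
    # leaving optimizer_type unset keeps the previous hard-coded 'adam'.
    trainer = Trainer(env_type='pytorch',
                      epochs=1,
                      lr=1e-4,
                      optimizer_type='lion')

The string must match one of the branches in get_optimizer ('adamw', 'lion', 'adan', 'lamb', or the options handled earlier in the function); anything else hits the final else and raises NotImplementedError.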
