Skip to content

Commit d342565

Browse files
authored
Merge pull request #2517 from huggingface/legacy_adadamw_update
Update legacy AdamW impl so it has a multi-tensor impl like NAdamW (n…
2 parents 03f4f4d + 8343358 commit d342565

File tree

3 files changed

+340
-42
lines changed

3 files changed

+340
-42
lines changed

tests/test_optim.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,7 @@ def test_optim_factory(optimizer):
298298
assert isinstance(opt_info, OptimInfo)
299299

300300
lr = (1e-2,) * 4
301-
if optimizer in ('mars', 'nadam', 'claprop', 'crmsproptf', 'cadafactorbv', 'csgdw', 'clamb'):
301+
if optimizer in ('mars', 'nadam', 'claprop', 'crmsproptf', 'cadafactorbv', 'csgdw', 'csgdc', 'clamb'):
302302
lr = (1e-3,) * 4
303303
elif optimizer in ('cmars',):
304304
lr = (1e-4,) * 4
@@ -378,7 +378,7 @@ def test_sgd(optimizer):
378378
_test_model(optimizer, dict(lr=1e-3))
379379

380380

381-
@pytest.mark.parametrize('optimizer', ['adamw', 'adam', 'nadam', 'adamax', 'nadamw'])
381+
@pytest.mark.parametrize('optimizer', ['adamw', 'adam', 'nadam', 'adamax', 'nadamw', 'adamwlegacy', 'adamc'])
382382
def test_adam(optimizer):
383383
_test_rosenbrock(
384384
lambda params: create_optimizer_v2(params, optimizer, lr=5e-2)

0 commit comments

Comments
 (0)