Skip to content

Commit

Permalink
Fix learning rate calculation in pretrain (#1435)
Browse files Browse the repository at this point in the history
  • Loading branch information
rasbt committed May 23, 2024
1 parent daffef0 commit 66a797a
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 2 deletions.
2 changes: 1 addition & 1 deletion extensions/thunder/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -284,7 +284,7 @@ def fit(
break

# determine and set the learning rate for this iteration
lr = get_lr(optimizer.param_groups[0]["lr"], state["iter_num"], warmup_iters, max_iters, train.min_lr)
lr = get_lr(optimizer.defaults["lr"], state["iter_num"], warmup_iters, max_iters, train.min_lr)
for param_group in optimizer.param_groups:
param_group["lr"] = lr

Expand Down
2 changes: 1 addition & 1 deletion litgpt/pretrain.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ def fit(
break

# determine and set the learning rate for this iteration
lr = get_lr(optimizer.param_groups[0]["lr"], state["iter_num"], warmup_iters, max_iters, train.min_lr)
lr = get_lr(optimizer.defaults["lr"], state["iter_num"], warmup_iters, max_iters, train.min_lr)
for param_group in optimizer.param_groups:
param_group["lr"] = lr

Expand Down

0 comments on commit 66a797a

Please sign in to comment.