Skip to content

Commit e334f43

Browse files
committed
Avoid long file names fro sgd updates
1 parent 632475b commit e334f43

File tree

1 file changed

+3
-3
lines changed

1 file changed

+3
-3
lines changed

lib/train.ml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -111,9 +111,9 @@ let sgd_one ~learning_rate ?(momentum = 0.0) ?(weight_decay = 0.0) ?(nesterov =
111111
p =- learning_rate * sgd_delta ~logic:".")]
112112

113113
let sgd_update ~learning_rate ?momentum ?weight_decay ?nesterov loss =
114-
loss.Tensor.params |> Set.to_list
115-
|> List.map ~f:(sgd_one ~learning_rate ?momentum ?weight_decay ?nesterov)
116-
|> Asgns.sequence
114+
let f = sgd_one ~learning_rate ?momentum ?weight_decay ?nesterov in
115+
let comp = Set.to_list loss.Tensor.params |> List.map ~f |> Asgns.sequence in
116+
{comp with asgns = Asgns.Block_comment ("sgd_update", comp.asgns)}
117117

118118
(** All and only bindings with associated ranges are iterated, with the binding's initial value
119119
lost. Bindings without ranges remain at their initial values. *)

0 commit comments

Comments
 (0)