@@ -657,10 +657,12 @@ let grad_update ?(setup_for_parallel = false) loss =
 The OCANNL code for a single parameter update step, including options for weight decay and momentum:
 
 ``` ocaml
-  let sgd_one ~learning_rate ?(momentum = 0.0) ?(weight_decay = 0.0) p =
+  let sgd_one ~learning_rate ?(momentum = 0.0) ?(weight_decay = 0.0)
+      p =
     [%cd
       ~~(p "param sgd step";
-         (* Instead of adding a regularizer to the loss tensor, regularize here. *)
+         (* Instead of adding a regularizer to the loss tensor,
+            regularize here. *)
          {sgd_delta} =: p.grad + (!.weight_decay *. p);
          if Float.(momentum > 0.0) then
            (* Inline declarations of (non-differentiable) tensors. *)
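
For readers less familiar with the `%cd` syntax, here is a plain-OCaml sketch of the arithmetic that `sgd_one` encodes, with flat float arrays standing in for tensors. This is exposition only, not OCANNL API, and the momentum convention shown (`v <- momentum * v + delta`) is an assumption, since the hunk cuts off before the momentum branch's body:

``` ocaml
(* Exposition-only sketch of the update rule behind [sgd_one]:
     delta = grad + weight_decay * p
     if momentum > 0 then (v <- momentum * v + delta; delta = v)
     p <- p - learning_rate * delta
   Plain arrays stand in for tensors; not OCANNL code. *)
let sgd_step_arrays ~learning_rate ?(momentum = 0.0)
    ?(weight_decay = 0.0) ~p ~grad ~velocity () =
  Array.iteri
    (fun i g ->
      (* Fold L2 regularization into the gradient. *)
      let delta = g +. (weight_decay *. p.(i)) in
      let delta =
        if momentum > 0.0 then (
          (* Blend the running velocity with the fresh delta. *)
          velocity.(i) <- (momentum *. velocity.(i)) +. delta;
          velocity.(i))
        else delta
      in
      p.(i) <- p.(i) -. (learning_rate *. delta))
    grad
```

Unlike this eager sketch, the `%cd` block above only builds an assignments program; OCANNL compiles and runs it on a backend later.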
@@ -881,35 +883,39 @@ Update this before release
 > * Data can arrive in this buffer via copying, direct pointing (for CPUs or devices on the same GPU), or potentially streaming in the future.
 > * Unlike a regular device-to-device transfer that writes to a tensor's destination buffer, a transfer into the merge buffer does not.
 
-{pause up .example title="Data parallel training: merging gradients in OCANNL"}
-
-``` ocaml
-(* Define the merge operation: p.grad =+ p.grad.merge *)
-let grad_merges : Asgns.t array =
-  Array.map all_params ~f:(fun p -> [%cd p.grad =+ p.grad.merge])
-in
-
-(* Compile the merge operation for all necessary device pairs *)
-let grad_merges_to : Backend.routine option array array =
-  Array.mapi ctxs ~f:(fun dst_n ctx ->
-    if occupancy_dst ~dst_n then
-      snd @@ Backend.link_batch ctx
-      @@ Backend.compile_batch ~shared:true ~occupancy:Idx.Empty grad_merges
-    else [||]
-  )
-in
+{pause up .example #data-parallel-code title="Data parallel training: merging gradients in OCANNL"}
+> ``` ocaml
+> (* Define the merge operation: p.grad =+ p.grad.merge *)
+> let grad_merges : Asgns.t array =
+>   Array.map all_params ~f:(fun p ->
+>     [%cd p.grad =+ p.grad.merge])
+> in
+>
+> (* Compile the merge operation for all needed device pairs *)
+> let grad_merges_to : Backend.routine option array array =
+>   Array.mapi ctxs ~f:(fun dst_n ctx ->
+>     if occupancy_dst ~dst_n then
+>       snd @@ Backend.link_batch ctx
+>       @@ Backend.compile_batch ~shared:true
+>            ~occupancy:Idx.Empty grad_merges
+>     else [||]
+>   )
+> in
+>
+> let merge_grads ~(from : int) ~(to_ : int) : unit =
+>   Array.iteri all_params ~f:(fun i p ->
+>     let grad_merge =
+>       Option.value_exn grad_merges_to.(to_).(i) in
+>     (* Fill the merge buffer before running merging. *)
+>     assert (
+>       Backend.device_to_device (Option.value_exn p.diff).grad
+>         ~into_merge_buffer:BT.Copy
+>         ~dst:grad_merge.context ~src:ctxs.(from));
+>     Task.run grad_merge.schedule)
+> in
+> ```
 
-let merge_grads ~(from: int) ~(to_: int) : unit =
-  Array.iteri all_params ~f:(fun i p ->
-    let grad_merge = Option.value_exn grad_merges_to.(to_).(i) in
-    (* Fill the merge buffer before running merging. *)
-    assert (
-      Backend.device_to_device (Option.value_exn p.diff).grad ~into_merge_buffer:BT.Copy
-        ~dst:grad_merge.context ~src:ctxs.(from));
-    (* Synchronization now happens automatically. *)
-    Task.run grad_merge.schedule )
-in
-```
+{pause down=data-parallel-code}
 
 {pause up .block title="OCANNL Features"}
 > * **Declarative** differentiable tensors.
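
As a usage note for the data-parallel example: `merge_grads ~from ~to_` copies one device's gradients into another device's merge buffers and then runs the compiled `p.grad =+ p.grad.merge` routines. A minimal sketch of combining gradients from all devices, assuming a hypothetical `num_devices` binding and folding everything into device 0:

``` ocaml
(* Sketch: accumulate every device's gradients into device 0 by
   repeated merge-buffer transfers. [merge_grads] is defined in the
   example above; [num_devices] is an assumed binding. *)
let reduce_grads_into_zero ~num_devices =
  for src = 1 to num_devices - 1 do
    merge_grads ~from:src ~to_:0
  done
```

A tree-shaped reduction, merging devices pairwise over rounds, would shorten the sequential chain from `num_devices - 1` merges to roughly `log2 num_devices` rounds, at the cost of compiling merge routines for more destination devices.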