Skip to content

Commit 309a89a

Browse files
committed
Auto-set hosted for Train.forward_and_ctx / forward_and_forget
1 parent 2285eaf commit 309a89a

File tree

2 files changed

+60
-618
lines changed

2 files changed

+60
-618
lines changed

lib/train.ml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -488,21 +488,22 @@ let example_train_loop ?(disable_rootness_check = false) ~seed ~batch_size ~init
488488
}
489489

490490
(* Note: this will get nicer with modular explicits. *)
491-
let%track3_sexp forward_and_ctx ?(disable_rootness_check = false)
491+
let%track3_sexp forward_and_ctx ?(hosted=true) ?(disable_rootness_check = false)
492492
(type buffer_ptr dev runner event optimize_ctx)
493493
(module Backend : Backend
494494
with type buffer_ptr = buffer_ptr
495495
and type dev = dev
496496
and type runner = runner
497497
and type optimize_ctx = optimize_ctx
498498
and type event = event) ctx ?(bindings = IDX.empty) t =
499+
if hosted then set_hosted t.Tensor.value;
499500
let routine =
500501
Backend.(link ctx @@ compile ctx.optimize_ctx bindings @@ forward ~disable_rootness_check t)
501502
in
502503
if not disable_rootness_check then Tensor.remove_bprop_root t;
503504
Task.run routine.schedule;
504505
routine.context
505506

506-
let forward_and_forget ?disable_rootness_check backend ctx ?bindings t =
507+
let forward_and_forget ?hosted ?disable_rootness_check backend ctx ?bindings t =
507508
(* FIXME: to properly forget we need to free the incrementally-allocated memory! *)
508-
ignore @@ forward_and_ctx ?disable_rootness_check backend ctx ?bindings t
509+
ignore @@ forward_and_ctx ?hosted ?disable_rootness_check backend ctx ?bindings t

0 commit comments

Comments
 (0)