Skip to content

Commit d453b00

Browse files
lukstaficlaude
andcommitted
Remove redundant With_context module from train.ml
The With_context module was just deprecated aliases after the refactoring. Since we've made those functions the default API, the module is no longer needed. Also updated remaining references in hello_world_op.ml to use forward_once directly. 🤖 Generated with Claude Code Co-Authored-By: Claude <noreply@anthropic.com>
1 parent 05159a4 commit d453b00

File tree

2 files changed

+15
-25
lines changed

2 files changed

+15
-25
lines changed

lib/train.ml

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -270,16 +270,6 @@ let sgd_step ~learning_rate ?momentum ?weight_decay ?nesterov ?(bindings = IDX.e
270270
let ctx, sgd_routine = Context.compile ctx sgd_comp bindings in
271271
Context.run ctx sgd_routine
272272

273-
(** Deprecated: Use the module-level functions directly *)
274-
module With_context = struct
275-
let init_params ?(reinit_all = false) ctx t =
276-
init_params ~reinit_all ctx IDX.empty t
277-
let forward ?(bindings = IDX.empty) ctx t =
278-
forward_once ~bindings ctx t
279-
let grad_update ?(bindings = IDX.empty) ctx t =
280-
update_once ~bindings ctx t
281-
let sgd_step = sgd_step
282-
end
283273

284274
(** [printf] is a wrapper around {!Tensor.print} that assumes [~force:true], and by default sets
285275
[~with_code:false], [~with_grad:true], and [~style:`Default]. *)

test/operations/hello_world_op.ml

Lines changed: 15 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ let%expect_test "Pointwise multiplication dims 1" =
1818

1919
(* "Hey" is inferred to be a scalar. *)
2020
let%op y = 2 *. { hey = 7.0 } in
21-
let _ctx = Train.With_context.forward ctx y in
21+
let _ctx = Train.forward_once ctx y in
2222

2323
Train.printf ~here:[%here] ~with_code:false ~with_grad:false y;
2424
[%expect
@@ -40,7 +40,7 @@ let%expect_test "Matrix multiplication dims 1x1" =
4040

4141
(* Hey is inferred to be a matrix because of matrix multiplication [*]. *)
4242
let%op y = ({ hey = 7.0 } * 'q' 2.0) + 'p' 1.0 in
43-
let _ctx = Train.With_context.forward ctx y in
43+
let _ctx = Train.forward_once ctx y in
4444
(* Punning for ["hey"] above introduced the [hey] identifier. *)
4545
Train.printf ~here:[%here] ~with_code:false ~with_grad:false hey;
4646
[%expect
@@ -88,7 +88,7 @@ let%expect_test "Print tensor too early" =
8888
|}];
8989
let%op c = a *. b in
9090

91-
let _ctx = Train.With_context.forward ctx c in
91+
let _ctx = Train.forward_once ctx c in
9292
Train.printf ~here:[%here] c;
9393
[%expect
9494
{|
@@ -108,7 +108,7 @@ let%expect_test "Print constant tensor" =
108108
let ctx = Context.auto () in
109109

110110
let%op hey = [ (1, 2, 3); (4, 5, 6) ] in
111-
let ctx = Train.With_context.forward ctx hey in
111+
let ctx = Train.forward_once ctx hey in
112112
(* ignore (failwith @@ Tn.debug_memory_mode hey.value.memory_mode); *)
113113
Train.printf ~here:[%here] ~with_code:false ~with_grad:false ~style:`Inline hey;
114114
[%expect
@@ -134,7 +134,7 @@ let%expect_test "Print constant tensor" =
134134
└─────────────────────────────────────┘
135135
|}];
136136
let%op hoo = [| [ 1; 2; 3 ]; [ 4; 5; 6 ] |] in
137-
let ctx = Train.With_context.forward ctx hoo in
137+
let ctx = Train.forward_once ctx hoo in
138138
Train.printf ~here:[%here] ~with_code:false ~with_grad:false ~style:`Inline hoo;
139139
[%expect
140140
{|
@@ -166,7 +166,7 @@ let%expect_test "Print constant tensor" =
166166
((19, 20, 21), (22, 23, 24));
167167
]
168168
in
169-
let ctx = Train.With_context.forward ctx hey2 in
169+
let ctx = Train.forward_once ctx hey2 in
170170
Train.printf ~here:[%here] ~with_code:false ~with_grad:false @@ hey2;
171171
[%expect
172172
{|
@@ -209,7 +209,7 @@ let%expect_test "Print constant tensor" =
209209
[ [ 19; 20; 21 ]; [ 22; 23; 24 ] ];
210210
|]
211211
in
212-
let ctx = Train.With_context.forward ctx hoo2 in
212+
let ctx = Train.forward_once ctx hoo2 in
213213
Train.printf ~here:[%here] ~with_code:false ~with_grad:false ~style:`Inline hoo2;
214214
[%expect
215215
{|
@@ -244,7 +244,7 @@ let%expect_test "Print constant tensor" =
244244
[| [ 19; 20; 21 ]; [ 22; 23; 24 ] |];
245245
|]
246246
in
247-
let ctx = Train.With_context.forward ctx heyhoo in
247+
let ctx = Train.forward_once ctx heyhoo in
248248
Train.printf ~here:[%here] ~with_code:false ~with_grad:false ~style:`Inline heyhoo;
249249
[%expect
250250
{|
@@ -279,7 +279,7 @@ let%expect_test "Print constant tensor" =
279279
[| [ [ 19; 49 ]; [ 20; 50 ]; [ 21; 51 ] ]; [ [ 22; 52 ]; [ 23; 53 ]; [ 24; 54 ] ] |];
280280
|]
281281
in
282-
let ctx = Train.With_context.forward ctx heyhoo2 in
282+
let ctx = Train.forward_once ctx heyhoo2 in
283283
Train.printf ~here:[%here] ~with_code:false ~with_grad:false ~style:`Inline heyhoo2;
284284
[%expect
285285
{|
@@ -343,7 +343,7 @@ let%expect_test "Print constant tensor" =
343343
|];
344344
|]
345345
in
346-
let ctx = Train.With_context.forward ctx heyhoo3 in
346+
let ctx = Train.forward_once ctx heyhoo3 in
347347
Train.printf ~here:[%here] ~with_code:false ~with_grad:false ~style:`Inline heyhoo3;
348348
[%expect
349349
{|
@@ -416,7 +416,7 @@ let%expect_test "Print constant tensor" =
416416
];
417417
|]
418418
in
419-
let _ctx = Train.With_context.forward ctx heyhoo4 in
419+
let _ctx = Train.forward_once ctx heyhoo4 in
420420
Train.printf ~here:[%here] ~with_code:false ~with_grad:false ~style:`Inline heyhoo4;
421421
[%expect
422422
{|
@@ -484,7 +484,7 @@ let%expect_test "Matrix multiplication dims 2x3" =
484484

485485
(* Hey is inferred to be a matrix. *)
486486
let%op y = ({ hey = 7.0 } * [ 2; 3 ]) + [ 4; 5; 6 ] in
487-
let _ctx = Train.With_context.forward ctx y in
487+
let _ctx = Train.forward_once ctx y in
488488
Train.printf ~here:[%here] ~with_code:false ~with_grad:false hey;
489489
[%expect
490490
{|
@@ -522,7 +522,7 @@ let%expect_test "Big matrix" =
522522
let hey = TDSL.param ~value:0.5 "hey" () in
523523
let zero_to_twenty = TDSL.range 20 in
524524
let y = TDSL.O.((hey * zero_to_twenty) + zero_to_twenty) in
525-
let _ctx = Train.With_context.forward ctx y in
525+
let _ctx = Train.forward_once ctx y in
526526
Train.printf ~here:[%here] ~with_code:false ~with_grad:false hey;
527527
[%expect
528528
{|
@@ -563,7 +563,7 @@ let%expect_test "Very big tensor" =
563563
in
564564
let%op hoo = (hey * (1 + 1)) - 10 in
565565
Train.set_hosted hey.value;
566-
let _ctx = Train.With_context.forward ctx hoo in
566+
let _ctx = Train.forward_once ctx hoo in
567567
Train.printf ~here:[%here] ~with_code:false ~with_grad:false hey;
568568
[%expect
569569
{|
@@ -730,7 +730,7 @@ let%expect_test "Embed self id" =
730730
Train.set_hosted hey.value;
731731
Train.set_hosted hoo.value;
732732
Train.set_hosted bar.value;
733-
let _ctx = Train.With_context.forward ctx bar in
733+
let _ctx = Train.forward_once ctx bar in
734734
Train.printf ~here:[%here] ~with_code:false ~with_grad:false hey;
735735
Train.printf ~here:[%here] ~with_code:false ~with_grad:false hoo;
736736
Train.printf ~here:[%here] ~with_code:false ~with_grad:false bar;

0 commit comments

Comments
 (0)