@@ -10,12 +10,9 @@ module Tn = Ir.Tnode
1010
1111module type Backend = Ir.Backend_intf. Backend
1212
13- let () =
13+ let () =
1414 Tensor. unsafe_reinitialize () ;
1515 Rand. init 0 ;
16- Utils. set_log_level 2 ;
17- Utils. settings.output_debug_files_in_build_directory < - true ;
18- Utils. settings.debug_log_from_routines < - true ;
1916 let module Backend = (val Backends. fresh_backend () ) in
2017 let stream = Backend. (new_stream @@ get_device ~ordinal: 0 ) in
2118 let ctx = Backend. make_context stream in
@@ -35,9 +32,37 @@ let () =
3532 let ctx = init.context in
3633 let update = Train. grad_update g in
3734 let step = Train. to_routine (module Backend ) ctx IDX. empty update in
35+ Tn. print_accessible_headers () ;
3836 Utils. capture_stdout_logs @@ fun () ->
3937 Train. run init;
4038 Train. run step;
4139 Tensor. print ~here: [% here] ~with_code: false ~with_grad: false `Default g;
4240 Tensor. print ~here: [% here] ~with_code: false ~with_grad: true `Default a;
4341 Tensor. print ~here: [% here] ~with_code: false ~with_grad: true `Default b
42+
43+ let _suspended () =
44+ Tensor. unsafe_reinitialize () ;
45+ Rand. init 0 ;
46+ let module Backend = (val Backends. fresh_backend () ) in
47+ let stream = Backend. (new_stream @@ get_device ~ordinal: 0 ) in
48+ let ctx = Backend. make_context stream in
49+ let % op c = " a" [ - 4 ] + " b" [ 2 ] in
50+ let % op d = (a *. b) + (b **. 3 ) in
51+ let % op c = c + c + 1 in
52+ let % op c = c + 1 + c + ~- a in
53+ let % op d = d + (d *. 2 ) + relu (b + a) in
54+ let % op d = d + (3 *. d) + relu (b - a) in
55+ let % op e = c - d in
56+ let % op f = e **. 2 in
57+ let % op g = f /. 2 in
58+ let % op g = g + (10. /. f) in
59+ List. iter ~f: (Option. iter ~f: (fun diff -> Train. set_hosted diff.Tensor. grad)) [ a.diff; b.diff ];
60+ let init_params = Tensor. init_params g in
61+ let update = Train. grad_update g in
62+ let step = Train. to_routine (module Backend ) ctx IDX. empty @@ Asgns. sequence [init_params; update] in
63+ Tn. print_accessible_headers () ;
64+ Utils. capture_stdout_logs @@ fun () ->
65+ Train. run step;
66+ Tensor. print ~here: [% here] ~with_code: false ~with_grad: false `Default g;
67+ Tensor. print ~here: [% here] ~with_code: false ~with_grad: true `Default a;
68+ Tensor. print ~here: [% here] ~with_code: false ~with_grad: true `Default b
0 commit comments