
Commit f5fb81d

Fix: Always reinitialize global state at the beginning of let%expect_test
1 parent 466acd7 commit f5fb81d

10 files changed (+740, -619 lines)

CHANGES.md

Lines changed: 1 addition & 0 deletions
@@ -29,6 +29,7 @@
 - Fixed loss of significant digits for small numbers when outputting files.
 - Added missing mixed-precision conversions in the `C_syntax` backend builder.
 - Restored the functionality of debug logging from the cuda backend.
+- Always reinitialize global state at the beginning of `let%expect_test`, to make them more deterministic.
 
 ## [0.4.0] -- 2024-09-04
 

lib/shape.ml

Lines changed: 11 additions & 5 deletions
@@ -238,11 +238,11 @@ end
 
 type update_id = Update_id.t [@@deriving equal, compare, hash, sexp]
 
-let get_update_id =
-  let uid = ref 0 in
-  fun () ->
-    Int.incr uid;
-    Update_id.Update_id !uid
+let update_uid = ref 0
+
+let get_update_id () =
+  Int.incr update_uid;
+  Update_id.Update_id !update_uid
 
 type update_step = { shape : t; logic : logic; id : update_id } [@@deriving sexp]
 (** Data required for a shape inference update step. Ideally, an update should be performed at least
@@ -538,6 +538,12 @@ let state = ref Row.empty_env
 let active_update_steps = ref []
 let active_constraints = ref []
 
+let unsafe_reinitialize () =
+  update_uid := 0;
+  state := Row.empty_env;
+  active_update_steps := [];
+  active_constraints := []
+
 let iter_shapes update_step ~f =
   f update_step.shape;
   match update_step.logic with
lib/shape.mli

Lines changed: 3 additions & 0 deletions
@@ -101,6 +101,9 @@ val to_string_hum :
   t ->
   string
 
+val unsafe_reinitialize : unit -> unit
+(** Bring global state to its initialization values. This invalidates any unfinished inference. *)
+
 (** {2 Internal-ish API.} *)
 
 (** How to propagate shape updates and do the last update of [Tensor.t.shape] when finalizing the

lib/tensor.ml

Lines changed: 7 additions & 0 deletions
@@ -70,6 +70,13 @@ type session_state = {
 let session_state =
   { next_id = 0; forward_roots = Map.empty (module Int); backprop_roots = Map.empty (module Int) }
 
+let unsafe_reinitialize () =
+  session_state.next_id <- 0;
+  session_state.forward_roots <- Map.empty (module Int);
+  session_state.backprop_roots <- Map.empty (module Int);
+  Tn.Registry.clear Tn.registry;
+  Shape.unsafe_reinitialize ()
+
 let is_fwd_root t = Map.mem session_state.forward_roots t.id
 let remove_fwd_root t = session_state.forward_roots <- Map.remove session_state.forward_roots t.id
 let is_bprop_root t = Map.mem session_state.backprop_roots t.id
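
tensor.ml applies the same reset at the session level: the mutable fields of the global session_state record are set back to their initial values, the tensor-node registry is cleared, and the call cascades into Shape.unsafe_reinitialize. A simplified, self-contained sketch of this mutable-record reset pattern (field names loosely mirror the diff; the real fields are maps, not lists):

```ocaml
(* Global session state as a record with mutable fields: a single reset
   function restores every field without reallocating the record, so
   existing references to [session_state] stay valid. *)
type session_state = { mutable next_id : int; mutable forward_roots : int list }

let session_state = { next_id = 0; forward_roots = [] }

let unsafe_reinitialize () =
  session_state.next_id <- 0;
  session_state.forward_roots <- []
```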

lib/tensor.mli

Lines changed: 4 additions & 0 deletions
@@ -188,6 +188,10 @@ val consume_backprop_code : t -> asgns * asgns
     [consume_backprop_code t] ensures [t] is a backprop root, removes it from backprop roots, and
     checks that there are no other backprop roots for tensors with children. *)
 
+val unsafe_reinitialize : unit -> unit
+(** Bring global state to its initialization values. This invalidates any previously defined tensors
+    and tensor nodes. Also invokes {!Shape.unsafe_reinitialize}. *)
+
 (** {2 Printing.} *)
 
 val header : t -> string

test/einsum_trivia.ml

Lines changed: 236 additions & 173 deletions
Large diffs are not rendered by default.

test/hello_world_op.ml

Lines changed: 293 additions & 269 deletions
Large diffs are not rendered by default.

test/micrograd_demo.ml

Lines changed: 2 additions & 0 deletions
@@ -7,6 +7,7 @@ module CDSL = Train.CDSL
 module Rand = Arrayjit.Rand.Lib
 
 let%expect_test "Micrograd README basic example" =
+  Tensor.unsafe_reinitialize ();
   Rand.init 0;
   let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
   let backend = (module Backend : Train.Backend_type with type context = Backend.context) in
@@ -79,6 +80,7 @@ let%expect_test "Micrograd README basic example" =
   |}]
 
 let%expect_test "Micrograd half-moons example" =
+  Tensor.unsafe_reinitialize ();
   Rand.init 5;
   let module Backend = (val Arrayjit.Backends.fresh_backend ()) in
   let backend = (module Backend : Train.Backend_type with type context = Backend.context) in
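
The test-side change is the same one-liner in every updated test: the first statement of each let%expect_test resets the global state, so printed tensor and node ids do not depend on which tests ran earlier in the same executable. A hypothetical skeleton of such a test (the name and body are placeholders; only Tensor.unsafe_reinitialize and Rand.init appear in the diff):

```ocaml
let%expect_test "some deterministic example" =
  (* Reset id counters, shape-inference state and the tensor-node
     registry before anything else, so the output is reproducible. *)
  Tensor.unsafe_reinitialize ();
  Rand.init 0;
  (* ... build tensors, train, and print results here ... *)
  [%expect {| ... |}]
```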

test/moons_demo_parallel.ml

Lines changed: 1 addition & 0 deletions
@@ -78,6 +78,7 @@ let main () =
   PrintBox_text.output Stdio.stdout plot_moons
 
 let%expect_test "Half-moons data parallel" =
+  Tensor.unsafe_reinitialize ();
   main ();
   (* NOTE: as of OCANNL 0.4, moons_demo_parallel, while deterministic on a single machine, gives
      slightly different results on machines with a different hardware, e.g. arm64, ppc. Here we list

test/zero2hero_1of7.ml

Lines changed: 182 additions & 172 deletions
Large diffs are not rendered by default.
