Skip to content

Commit 88f6de2

Browse files
committed
session state Tensor id validation in op
Added checks for tensor IDs in `op` to prevent invalid state usage. Updated `unsafe_reinitialize` documentation to clarify its purpose in preventing session state pollution.
1 parent 6390ff3 commit 88f6de2

File tree

2 files changed

+19
-6
lines changed

2 files changed

+19
-6
lines changed

lib/tensor.ml

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -203,10 +203,20 @@ type grad_spec = Require_grad | Prohibit_grad | If_needed [@@deriving sexp, equa
203203
let%track7_sexp op ~(label : string list) ?(ternary_op = Shape.Pointwise_tern)
204204
?(compose_op = Shape.Pointwise_bin) ?(transpose_op = Shape.Pointwise_un) ?terminal_op ~op_asn
205205
~grad_asn ?(grad_spec = If_needed) make_shape (orig_ts : t list) : t =
206+
List.iter orig_ts ~f:(fun t ->
207+
if t.id >= session_state.next_id then
208+
raise
209+
@@ Session_error
210+
( [%string
211+
"Tensor #%{t.id#Int} %{Tn.debug_name t.value} has an id greater than the last id \
212+
#%{session_state.next_id - 1#Int} -- check your uses of \
213+
Tensor.unsafe_reinitialize, if all your uses are valid, report this as a bug."],
214+
Some t ));
206215
(* The code needs to be included in the order it was computed due to potential non-tree DAGs. *)
207216
let ordered_ts = List.dedup_and_sort orig_ts ~compare:(fun t1 t2 -> Int.ascending t1.id t2.id) in
208217
let id : int = session_state.next_id in
209218
session_state.next_id <- session_state.next_id + 1;
219+
let _session_state_next_id : int = session_state.next_id in
210220
let shape = make_shape ~debug_name:(Tn.get_debug_name ~id ~label ()) ~id in
211221
let default_prec =
212222
let lazy_v_precs = List.map orig_ts ~f:(fun ti -> ti.value.prec) in
@@ -247,8 +257,8 @@ let%track7_sexp op ~(label : string list) ?(ternary_op = Shape.Pointwise_tern)
247257
| [ t1; t2 ] -> [ Shape.Broadcast (compose_op, t1.shape, t2.shape) ]
248258
| [ t1; t2; t3 ] -> [ Shape.Broadcast_tern (ternary_op, t1.shape, t2.shape, t3.shape) ]
249259
| _ ->
250-
(* Let's implement what we need when we need it. *)
251-
assert false
260+
(* Let's implement what we need when we need it. *)
261+
assert false
252262
in
253263
let local_shape_updates =
254264
List.map ~f:(fun logic -> Shape.{ shape; logic; id = get_update_id () }) @@ shape_logics orig_ts
@@ -381,8 +391,8 @@ let%track7_sexp unop ?transpose_op ~op_asn ~grad_asn ?grad_spec t1 ?(label = [])
381391
())
382392
[ t1 ]
383393

384-
let%track7_sexp term ?init_data ?fetch_op ?grad_spec ?(label = []) ?batch_dims ?batch_axes ?input_dims
385-
?output_dims ?input_axes ?output_axes ?deduced () : t =
394+
let%track7_sexp term ?init_data ?fetch_op ?grad_spec ?(label = []) ?batch_dims ?batch_axes
395+
?input_dims ?output_dims ?input_axes ?output_axes ?deduced () : t =
386396
let terminal_op =
387397
match (init_data, fetch_op) with
388398
| Some _, Some _ -> invalid_arg "Tensor.term: both init_data and fetch_op are provided"
@@ -566,7 +576,7 @@ let rec get_random_seed () =
566576
set_random_seed ();
567577
get_random_seed ()
568578

569-
let%track5_sexp unsafe_reinitialize () =
579+
let%track5_sexp unsafe_reinitialize () : unit =
570580
session_state.next_id <- 0;
571581
session_state.forward_roots <- Map.empty (module Int);
572582
session_state.backprop_roots <- Map.empty (module Int);

lib/tensor.mli

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,10 @@ val iter_embedded : f:(tn -> unit) -> t -> unit
226226
val unsafe_reinitialize : unit -> unit
227227
(** Bring global state to its initialization values. This invalidates any previously defined tensors
228228
and tensor nodes. Also reinitializes the modules: {!Shape}, {!Ir.Tnode},
229-
{!Ir.Rand.Random_for_tests}. *)
229+
{!Ir.Rand.Random_for_tests}.
230+
231+
While this function is intended for testing, using it can prevent unintentional session state
232+
pollution errors. *)
230233

231234
val set_random_seed : ?seed:int -> unit -> unit
232235
(** Creates the random seed tensor. If [seed] is provided, it is used to set the random seed.

0 commit comments

Comments
 (0)